Project import generated by Copybara.
GitOrigin-RevId: d9e9e3fb4e31372ec1fb43b178994ca78fa8fe70
diff --git a/features/component_helpers.py b/features/component_helpers.py
new file mode 100644
index 0000000..1392f0b
--- /dev/null
+++ b/features/component_helpers.py
@@ -0,0 +1,127 @@
+# Copyright 2018 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import json
+import logging
+import re
+
+import settings
+import cloudstorage
+
+from features import generate_dataset
+from framework import framework_helpers
+from services import ml_helpers
+from tracker import tracker_bizobj
+
+from googleapiclient import discovery
+from oauth2client.client import GoogleCredentials
+
+
+# Resource name of the component model, in the form
+# 'projects/<classifier project id>/models/<component model name>'. It is
+# passed as `name` to the ML Engine predict and models().get requests made in
+# this module.
+MODEL_NAME = 'projects/{}/models/{}'.format(
+ settings.classifier_project_id, settings.component_model_name)
+
+
+def _GetTopWords(trainer_name):  # pragma: no cover
+  """Download a trainer's top words file and index the words by position.
+
+  Args:
+    trainer_name: Name of the trainer folder in the component ML bucket. The
+        object fetched is '<trainer_name>/topwords.txt'.
+
+  Returns:
+    A dict mapping each whitespace-separated word of the file to its
+    zero-based position. A word that occurs more than once keeps its last
+    position.
+  """
+  # TODO(carapew): Use memcache to get top words rather than storing as a
+  # variable.
+  storage_api = discovery.build(
+      'storage', 'v1',
+      credentials=GoogleCredentials.get_application_default())
+  contents = storage_api.objects().get_media(
+      bucket=settings.component_ml_bucket,
+      object=trainer_name + '/topwords.txt').execute()
+
+  # Keyed by word so that feature generation can look up a word's index
+  # directly instead of searching the list.
+  index_by_word = {}
+  for position, word in enumerate(contents.split()):
+    index_by_word[word] = position
+  return index_by_word
+
+
+def _GetComponentsByIndex(trainer_name):
+  """Load a trainer's component index mapping from Cloud Storage.
+
+  Args:
+    trainer_name: Name of the trainer folder in the component ML bucket that
+        holds component_index.json.
+
+  Returns:
+    The parsed JSON contents of component_index.json.
+  """
+  # TODO(carapew): Memcache the index mapping file.
+  path_parts = (settings.component_ml_bucket, trainer_name)
+  mapping_path = '/%s/%s/component_index.json' % path_parts
+  logging.info('Mapping path full name: %r', mapping_path)
+
+  with cloudstorage.open(mapping_path, 'r') as index_file:
+    logging.info('Index component mapping opened')
+    raw_mapping = index_file.read()
+    logging.info(raw_mapping)
+
+  # The whole file has been read already, so it can be parsed after closing.
+  return json.loads(raw_mapping)
+
+
+@framework_helpers.retry(3)
+def _GetComponentPrediction(ml_engine, instance):
+  """Predict the component from the default model based on the provided text.
+
+  Args:
+    ml_engine: An ML Engine instance for making predictions.
+    instance: The dict object returned from ml_helpers.GenerateFeaturesRaw
+        containing the features generated from the provided text. Only its
+        'word_features' entry is sent to the model.
+
+  Returns:
+    The index of the component with the highest score. ML engine's predict
+    api returns a dict of the format
+    {'predictions': [{'classes': ['0', '1', ...], 'scores': [.00234, ...]}]}
+    where each class has a score at the same index. Classes are sequential,
+    so the index of the highest score also happens to be the component's
+    index. If several scores tie for the highest, the first one wins.
+  """
+  body = {'instances': [{'inputs': instance['word_features']}]}
+  request = ml_engine.projects().predict(name=MODEL_NAME, body=body)
+  response = request.execute()
+
+  # Hand the response to logging as an argument instead of pre-formatting it
+  # with '%'. This matches the other logging calls in this file and avoids a
+  # TypeError should the response ever be a tuple.
+  logging.info('ML Engine API response: %r', response)
+  scores = response['predictions'][0]['scores']
+
+  return scores.index(max(scores))
+
+
+def PredictComponent(raw_text, config):
+  """Get the component ID predicted for the given text.
+
+  Args:
+    raw_text: The raw text for which we want to predict a component.
+    config: The config of the project. Used to decide if the predicted
+        component is valid.
+
+  Returns:
+    The component ID predicted for the provided text, or None if no component
+    was predicted or the predicted component is not defined in config.
+
+  Raises:
+    AttributeError: If the name of the model's default version does not
+        contain 'v_<digits>'.
+  """
+  # Set-up ML engine.
+  ml_engine = ml_helpers.setup_ml_engine()
+
+  # Gets the timestamp number from the folder containing the model's trainer
+  # in order to get the correct files for mappings and features.
+  request = ml_engine.projects().models().get(name=MODEL_NAME)
+  response = request.execute()
+
+  version = re.search(r'v_(\d+)', response['defaultVersion']['name']).group(1)
+  trainer_name = 'component_trainer_%s' % version
+
+  top_words = _GetTopWords(trainer_name)
+  components_by_index = _GetComponentsByIndex(trainer_name)
+  logging.info('Length of top words list: %s', len(top_words))
+
+  clean_text = generate_dataset.CleanText(raw_text)
+  instance = ml_helpers.GenerateFeaturesRaw(
+      [clean_text], settings.component_features, top_words)
+
+  # Get the component id with the highest prediction score. Component ids are
+  # stored in GCS as strings, but represented in the app as longs.
+  best_score_index = _GetComponentPrediction(ml_engine, instance)
+  component_id = components_by_index.get(str(best_score_index))
+  if not component_id:
+    # The index mapping has no component for this prediction, so there is
+    # nothing to validate against the project config.
+    logging.info('No component mapped to index %r', best_score_index)
+    return None
+  component_id = int(component_id)
+
+  # The predicted component id might not exist.
+  if tracker_bizobj.FindComponentDefByID(component_id, config) is None:
+    return None
+
+  return component_id