Project import generated by Copybara.
GitOrigin-RevId: d9e9e3fb4e31372ec1fb43b178994ca78fa8fe70
diff --git a/tools/ml/component.py b/tools/ml/component.py
new file mode 100755
index 0000000..9b401f3
--- /dev/null
+++ b/tools/ml/component.py
@@ -0,0 +1,137 @@
+#!/usr/bin/env python
+# Copyright 2018 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""
+Component classifier command line tools.
+
+Use this command to request a component prediction for a local file from
+the model running in production.
+
+Note that in order for this command to work, you must be logged into
+gcloud in the project under which you wish to run commands.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import argparse
+import json
+import os
+import re
+import sys
+
+import googleapiclient
+from googleapiclient import discovery
+from googleapiclient import errors
+from google.cloud.storage import client, bucket, blob
+from apiclient.discovery import build
+from oauth2client.client import GoogleCredentials
+
+import ml_helpers
+
+# Application Default Credentials, looked up once when the module loads.
+credentials = GoogleCredentials.get_application_default()
+
+# This must be identical to settings.component_features.
+COMPONENT_FEATURES = 5000
+MODEL_NAME = 'component_top_words'  # Used to build the model resource name.
+
+
+def Predict(args):
+ ml = googleapiclient.discovery.build('ml', 'v1', credentials=credentials)
+
+ with open(args.content) as f:
+ content = f.read()
+
+ project_ID = 'projects/%s' % args.project
+ full_model_name = '%s/models/%s' % (project_ID, MODEL_NAME)
+ model_request = ml.projects().models().get(name=full_model_name)
+ model_response = model_request.execute()
+
+ version_name = model_response['defaultVersion']['name']
+
+ model_name = 'component_trainer_' + re.search("v_(\d+)",
+ version_name).group(1)
+
+ client_obj = client.Client(project=args.project)
+ bucket_name = '%s-mlengine' % args.project
+ bucket_obj = bucket.Bucket(client_obj, bucket_name)
+
+ instance = ml_helpers.GenerateFeaturesRaw([content],
+ COMPONENT_FEATURES,
+ getTopWords(bucket_name,
+ model_name))
+
+
+ request = ml.projects().predict(name=full_model_name, body={
+ 'instances': [{'inputs': instance['word_features']}]
+ })
+
+ try:
+ response = request.execute()
+
+
+ bucket_obj.blob = blob.Blob('%s/component_index.json'
+ % model_name, bucket_obj)
+ component_index = bucket_obj.blob.download_as_string()
+ component_index_dict = json.loads(component_index)
+
+ return read_indexes(response, component_index_dict)
+
+ except googleapiclient.errors.HttpError, err:
+ print('There was an error. Check the details:')
+ print(err._get_reason())
+
+
+def getTopWords(bucket_name, model_name):
+ storage = discovery.build('storage', 'v1', credentials=credentials)
+ objects = storage.objects()
+
+ request = objects.get_media(bucket=bucket_name,
+ object=model_name + '/topwords.txt')
+ response = request.execute()
+
+ top_list = response.split()
+ top_words = {}
+ for i in range(len(top_list)):
+ top_words[top_list[i]] = i
+
+ return top_words
+
+
+def read_indexes(response, component_index):
+
+ scores = response['predictions'][0]['scores']
+ highest = scores.index(max(scores))
+
+ component_id = component_index[str(highest)]
+
+ return "Most likely component: index %d, component id %d" % (
+ int(highest), int(component_id))
+
+
+def main():
+ if not credentials and 'GOOGLE_APPLICATION_CREDENTIALS' not in os.environ:
+ print(('GOOGLE_APPLICATION_CREDENTIALS environment variable is not set. '
+ 'Exiting.'))
+ sys.exit(1)
+
+ parser = argparse.ArgumentParser(
+ description='Component classifier utilities.')
+ parser.add_argument('--project', '-p', default='monorail-staging')
+
+ parser.add_argument('--content', '-c', required=True,
+ help='A file containing the content.')
+
+ args = parser.parse_args()
+
+ res = Predict(args)
+
+ print(res)
+
+
+if __name__ == '__main__':  # Run main() only when executed as a script.
+ main()