blob: 9b401f30a657992872e49c3e223c2799edb62129 [file] [log] [blame]
Copybara854996b2021-09-07 19:36:02 +00001#!/usr/bin/env python
2# Copyright 2018 The Chromium Authors. All rights reserved.
3# Use of this source code is governed by a BSD-style
4# license that can be found in the LICENSE file or at
5# https://developers.google.com/open-source/licenses/bsd
6
7"""
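Component classifier command line tools.

Use this command to send the contents of a file to the component classifier
model running in production and print the component it predicts.

Note that in order for this command to work, Application Default Credentials
must be available (for example via `gcloud auth application-default login` or
the GOOGLE_APPLICATION_CREDENTIALS environment variable) for the gcloud
project under which you wish to run commands.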
15"""
16from __future__ import print_function
17from __future__ import division
18from __future__ import absolute_import
19
20import argparse
21import json
22import os
23import re
24import sys
25
26import googleapiclient
27from googleapiclient import discovery
28from googleapiclient import errors
29from google.cloud.storage import client, bucket, blob
30from apiclient.discovery import build
31from oauth2client.client import GoogleCredentials
32
33import ml_helpers
34
# This must be identical with settings.component_features.
COMPONENT_FEATURES = 5000

# ML Engine model whose default version Predict() looks up and queries.
MODEL_NAME = 'component_top_words'

# Application-default credentials shared by every API client built in this
# module. They are resolved once, when the module is imported.
# NOTE(review): per the oauth2client docs, get_application_default() raises
# when no credentials can be found, so importing this module fails without
# them -- confirm against the installed oauth2client version.
credentials = GoogleCredentials.get_application_default()
42
def Predict(args):
  """Requests a component prediction for the contents of a local file.

  Looks up the default version of MODEL_NAME in the project given by
  args.project and derives the trained-model directory from that version's
  name. It then turns the file's text into word features using that
  directory's top-words list and asks the model for a prediction. Finally it
  maps the highest-scoring output back to a component id.

  Args:
    args: Parsed command line arguments. Reads args.project (the Cloud
      project id) and args.content (path to a file holding the text to
      classify).

  Returns:
    The string produced by read_indexes() for the prediction, or None if the
    prediction request failed with an HttpError (details go to stderr).

  Raises:
    ValueError: if the default version name has no 'v_<digits>' part, so the
      trained-model directory cannot be derived from it.
  """
  ml = googleapiclient.discovery.build('ml', 'v1', credentials=credentials)

  with open(args.content) as f:
    content = f.read()

  project_ID = 'projects/%s' % args.project
  full_model_name = '%s/models/%s' % (project_ID, MODEL_NAME)
  model_request = ml.projects().models().get(name=full_model_name)
  model_response = model_request.execute()

  version_name = model_response['defaultVersion']['name']

  # The top-words list and component index live in the bucket under a
  # prefix built from the numeric part of the default version's name.
  # Raw string: '\d' is an invalid escape in a normal string literal.
  version_match = re.search(r'v_(\d+)', version_name)
  if version_match is None:
    raise ValueError(
        'Cannot derive trained model name from version %r' % version_name)
  model_name = 'component_trainer_' + version_match.group(1)

  client_obj = client.Client(project=args.project)
  bucket_name = '%s-mlengine' % args.project
  bucket_obj = bucket.Bucket(client_obj, bucket_name)

  # NOTE(review): assumes GenerateFeaturesRaw returns a dict with a
  # 'word_features' entry -- confirm against ml_helpers.
  instance = ml_helpers.GenerateFeaturesRaw(
      [content], COMPONENT_FEATURES, getTopWords(bucket_name, model_name))

  request = ml.projects().predict(name=full_model_name, body={
      'instances': [{'inputs': instance['word_features']}]
  })

  # Only the prediction call raises googleapiclient's HttpError; keep the
  # try body to that single call. 'except X as e' works on Python 2.6+ and 3.
  try:
    response = request.execute()
  except googleapiclient.errors.HttpError as err:
    print('There was an error. Check the details:', file=sys.stderr)
    print(err._get_reason(), file=sys.stderr)
    return None

  # Use a local Blob instead of assigning to bucket_obj.blob, which would
  # shadow the Bucket.blob() method on this bucket instance.
  index_blob = blob.Blob('%s/component_index.json' % model_name, bucket_obj)
  component_index_dict = json.loads(index_blob.download_as_string())

  return read_indexes(response, component_index_dict)
87
88
def getTopWords(bucket_name, model_name):
  """Loads the top-words vocabulary of a trained model.

  Downloads '<model_name>/topwords.txt' from the given Cloud Storage bucket
  and assigns each whitespace-separated word its position in that file.

  Args:
    bucket_name: Name of the Cloud Storage bucket holding the model files.
    model_name: Object-name prefix of the trained model inside the bucket
      (Predict passes 'component_trainer_<n>').

  Returns:
    A dict mapping each word to its 0-based position in topwords.txt. If a
    word occurs more than once, its last position wins.
  """
  storage = discovery.build('storage', 'v1', credentials=credentials)
  request = storage.objects().get_media(
      bucket=bucket_name, object=model_name + '/topwords.txt')
  # NOTE(review): on Python 3, get_media().execute() presumably returns bytes,
  # so the keys below would be bytes rather than str -- confirm ml_helpers
  # expects that.
  response = request.execute()

  return {word: index for index, word in enumerate(response.split())}
103
104
def read_indexes(response, component_index):
  """Describes the component that received the highest prediction score.

  Args:
    response: Decoded prediction response. Only
      response['predictions'][0]['scores'] is read.
    component_index: Dict from a stringified output index to a component id
      (anything int() accepts).

  Returns:
    A message naming the winning output index and its component id. On ties,
    the lowest index wins.

  Raises:
    ValueError: if the score list is empty.
    KeyError: if the winning index has no entry in component_index.
  """
  first_prediction = response['predictions'][0]
  score_list = first_prediction['scores']

  # First position holding the maximal score.
  best_position = max(range(len(score_list)), key=score_list.__getitem__)
  best_component = int(component_index[str(best_position)])

  return "Most likely component: index %d, component id %d" % (
      best_position, best_component)
114
115
def main():
  """Command line entry point: predicts a component for a content file.

  Exits with status 1 if no credentials are configured or if the prediction
  request fails; otherwise prints the prediction to stdout.
  """
  # NOTE(review): GoogleCredentials.get_application_default() is documented to
  # raise when no credentials are found, so `credentials` is normally truthy
  # here and this check is only defensive.
  if not credentials and 'GOOGLE_APPLICATION_CREDENTIALS' not in os.environ:
    print('GOOGLE_APPLICATION_CREDENTIALS environment variable is not set. '
          'Exiting.', file=sys.stderr)
    sys.exit(1)

  parser = argparse.ArgumentParser(
      description='Component classifier utilities.')
  parser.add_argument('--project', '-p', default='monorail-staging',
                      help='Cloud project hosting the model and its '
                           '<project>-mlengine bucket '
                           '(default: %(default)s).')
  parser.add_argument('--content', '-c', required=True,
                      help='A file containing the content.')
  args = parser.parse_args()

  res = Predict(args)
  if res is None:
    # Predict() already reported the error; make the failure visible to the
    # shell instead of printing "None" and exiting 0.
    sys.exit(1)

  print(res)


if __name__ == '__main__':
  main()