blob: 9b401f30a657992872e49c3e223c2799edb62129 [file] [log] [blame]
#!/usr/bin/env python
# Copyright 2018 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd
"""
Component classifier command line tools.
Use this command to submit predictions to the model running
in production.
Note that in order for this command to work, you must be logged into
gcloud in the project under which you wish to run commands.
"""
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
import argparse
import json
import os
import re
import sys
import googleapiclient
from googleapiclient import discovery
from googleapiclient import errors
from google.cloud.storage import client, bucket, blob
from apiclient.discovery import build
from oauth2client.client import GoogleCredentials
import ml_helpers
# Application Default Credentials, resolved once at import time and shared by
# every googleapiclient service built in this file.
# NOTE(review): this lookup runs on import, before main() performs its own
# credentials check -- confirm how it behaves when no credentials are set up.
credentials = GoogleCredentials.get_application_default()
# Passed to ml_helpers.GenerateFeaturesRaw when building the prediction input
# (presumably the number of word features -- confirm against ml_helpers).
# This must be identical to settings.component_features.
COMPONENT_FEATURES = 5000
# Name of the ML Engine model, used to build the
# 'projects/<project>/models/<MODEL_NAME>' resource name in Predict().
MODEL_NAME = 'component_top_words'
def Predict(args):
  """Predict the most likely component for the content of a file.

  Reads the file at args.content, looks up the default version of the
  MODEL_NAME model in the args.project project, builds word features for
  the content with that version's top-words list, and submits them to the
  model for prediction.

  Args:
    args: Parsed command-line arguments. Reads `content` (path of the file
        to classify) and `project` (the cloud project id).

  Returns:
    The summary string built by read_indexes(), or None when the prediction
    request fails with an HttpError (the reason is printed in that case).

  Raises:
    ValueError: If the default version name has no "v_<digits>" part.
  """
  ml = discovery.build('ml', 'v1', credentials=credentials)

  with open(args.content) as f:
    content = f.read()

  project_id = 'projects/%s' % args.project
  full_model_name = '%s/models/%s' % (project_id, MODEL_NAME)
  model_response = ml.projects().models().get(name=full_model_name).execute()
  version_name = model_response['defaultVersion']['name']

  # The files for a model version are stored under
  # 'component_trainer_<digits>/', where <digits> comes from the version name.
  version_match = re.search(r'v_(\d+)', version_name)
  if version_match is None:
    raise ValueError(
        'Unexpected default version name: %r' % (version_name,))
  trainer_dir = 'component_trainer_' + version_match.group(1)

  client_obj = client.Client(project=args.project)
  bucket_name = '%s-mlengine' % args.project
  bucket_obj = bucket.Bucket(client_obj, bucket_name)

  top_words = getTopWords(bucket_name, trainer_dir)
  instance = ml_helpers.GenerateFeaturesRaw(
      [content], COMPONENT_FEATURES, top_words)

  request = ml.projects().predict(name=full_model_name, body={
      'instances': [{'inputs': instance['word_features']}]
  })

  # Only the prediction call is guarded; it is the googleapiclient request
  # whose HttpError this function reports instead of propagating.
  try:
    response = request.execute()
  except errors.HttpError as err:
    print('There was an error. Check the details:')
    print(err._get_reason())
    return None

  # Keep the blob in a local variable: assigning it to bucket_obj.blob would
  # replace the Bucket object's own blob() method.
  index_blob = blob.Blob('%s/component_index.json' % trainer_dir, bucket_obj)
  component_index_dict = json.loads(index_blob.download_as_string())

  return read_indexes(response, component_index_dict)
def getTopWords(bucket_name, model_name):
  """Download a model's top-words list and map each word to its position.

  Args:
    bucket_name: Name of the Cloud Storage bucket to read from.
    model_name: Object-name prefix; '<model_name>/topwords.txt' is fetched.

  Returns:
    A dict mapping every whitespace-separated word of topwords.txt to its
    zero-based position in the file. If a word occurs more than once, its
    last position wins.
  """
  storage = discovery.build('storage', 'v1', credentials=credentials)
  request = storage.objects().get_media(
      bucket=bucket_name, object=model_name + '/topwords.txt')
  response = request.execute()

  # NOTE(review): on Python 3 the media download is presumably returned as
  # bytes; decode it so the keys are text. On Python 2 bytes is str, so this
  # branch never runs and the result is unchanged. UTF-8 is assumed -- confirm.
  if isinstance(response, bytes) and not isinstance(response, str):
    response = response.decode('utf-8')

  return {word: position for position, word in enumerate(response.split())}
def read_indexes(response, component_index):
  """Summarize a prediction response as a human-readable string.

  Args:
    response: Prediction response dict; the 'scores' list of its first
        entry under 'predictions' is used.
    component_index: Dict mapping a score position (as a string) to a
        component id.

  Returns:
    A string naming the position of the highest score and the component id
    stored for that position.
  """
  first_prediction = response['predictions'][0]
  score_list = first_prediction['scores']
  # list.index returns the first occurrence, so ties go to the lowest
  # position.
  best_position = score_list.index(max(score_list))
  component = component_index[str(best_position)]
  summary = "Most likely component: index %d, component id %d"
  return summary % (int(best_position), int(component))
def main():
  """Parse the command-line flags, run one prediction and print the result.

  Flags:
    --project / -p: Cloud project to query (default 'monorail-staging').
    --content / -c: Path of the file whose content is classified (required).

  Exits with status 1 when no credentials are available or when the
  prediction fails.
  """
  if not credentials and 'GOOGLE_APPLICATION_CREDENTIALS' not in os.environ:
    print('GOOGLE_APPLICATION_CREDENTIALS environment variable is not set. '
          'Exiting.')
    sys.exit(1)

  parser = argparse.ArgumentParser(
      description='Component classifier utilities.')
  parser.add_argument('--project', '-p', default='monorail-staging')
  parser.add_argument('--content', '-c', required=True,
                      help='A file containing the content.')
  args = parser.parse_args()

  res = Predict(args)
  if res is None:
    # Predict() returns None after printing the error details; report the
    # failure through the exit status instead of printing "None".
    sys.exit(1)
  print(res)


if __name__ == '__main__':
  main()