#!/usr/bin/env python
# Copyright 2016 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

"""
Spam classifier command line tools.

Use this command to submit predictions locally or to the model running
in production. See tools/spam/README.md for more context on training
and model operations.

Note that in order for this command to work, you must be logged into
gcloud in the project under which you wish to run commands.
"""
17from __future__ import print_function
18from __future__ import division
19from __future__ import absolute_import
20
21import argparse
22import json
23import os
24import re
25import sys
26import googleapiclient
27
28from google.cloud.storage import client, bucket, blob
29import ml_helpers
30from apiclient.discovery import build
31from oauth2client.client import GoogleCredentials
32
# Number of hash buckets passed to ml_helpers.GenerateFeaturesRaw.
# This must be identical with settings.spam_feature_hashes.
SPAM_FEATURE_HASHES = 500

# Name of the ML Engine model that every command in this file addresses.
MODEL_NAME = 'spam_only_words'

# Application Default Credentials, resolved once when the module is loaded
# and shared by every ML Engine API client built below.
credentials = GoogleCredentials.get_application_default()
39
40
def Predict(args):
  """Send one summary/content pair to the hosted model and print the reply.

  Args:
    args: Parsed command line arguments. Reads args.summary and args.content
        (paths of the text files to classify) and args.project (the Cloud
        project whose MODEL_NAME model is queried).

  Prints the raw API response on success. On an HTTP error from the API the
  error reason is printed instead and the function returns normally.
  """
  ml = googleapiclient.discovery.build('ml', 'v1', credentials=credentials)

  with open(args.summary) as f:
    summary = f.read()
  with open(args.content) as f:
    content = f.read()

  instance = ml_helpers.GenerateFeaturesRaw([summary, content],
                                            SPAM_FEATURE_HASHES)

  project_ID = 'projects/%s' % args.project
  full_model_name = '%s/models/%s' % (project_ID, MODEL_NAME)
  request = ml.projects().predict(name=full_model_name, body={
      'instances': [{'inputs': instance['word_hashes']}]
  })

  # `except X as err` is accepted by Python 2.6+ and Python 3; the older
  # `except X, err` spelling is a syntax error on Python 3.
  try:
    response = request.execute()
  except googleapiclient.errors.HttpError as err:
    print('There was an error. Check the details:')
    print(err._get_reason())
  else:
    print(response)
63 print(err._get_reason())
64
65
def LocalPredict(_):
  """Prompt for an issue and write its features to /tmp/instances.json.

  The summary and description are read interactively from stdin, converted
  with ml_helpers.GenerateFeaturesRaw, and the resulting word hashes are
  written as a single JSON object for `gcloud ml-engine local predict`.

  Args:
    _: Parsed command line arguments; unused.
  """
  print('This will write /tmp/instances.json.')
  print('Then you can call:')
  print(('gcloud ml-engine local predict --json-instances /tmp/instances.json'
         ' --model-dir {model_dir}'))

  # raw_input exists only on Python 2; Python 3 renamed it to input. Prefer
  # raw_input when present because Python 2's input() evaluates what is typed.
  try:
    read_line = raw_input
  except NameError:
    read_line = input

  summary = read_line('Summary: ')
  description = read_line('Description: ')
  instance = ml_helpers.GenerateFeaturesRaw([summary, description],
                                            SPAM_FEATURE_HASHES)

  with open('/tmp/instances.json', 'w') as f:
    json.dump({'inputs': instance['word_hashes']}, f)
78 json.dump({'inputs': instance['word_hashes']}, f)
79
80
def get_auc(model_name, bucket_obj):
  """Read the AUC metrics stored in a model's eval_data.json object.

  Args:
    model_name: Name of the model; the object read is
        '<model_name>/eval_data.json'.
    bucket_obj: Bucket that holds the evaluation data. It is not modified.

  Returns:
    A (auc, auc_precision_recall) tuple taken from the JSON document.

  Raises:
    KeyError: If the document lacks 'auc' or 'auc_precision_recall'.
  """
  # Keep the blob handle in a local. Storing it on bucket_obj.blob would
  # replace the bucket's own blob() factory method on the caller's instance.
  eval_blob = bucket_obj.blob('%s/eval_data.json' % model_name)
  data_dict = json.loads(eval_blob.download_as_string())
  return data_dict['auc'], data_dict['auc_precision_recall']
85 return data_dict['auc'], data_dict['auc_precision_recall']
86
87
def CompareAccuracy(args):
  """Print the stored AUC metrics of args.model1 and then args.model2.

  Both models are looked up with get_auc in the '<project>-mlengine' bucket
  of args.project.

  Args:
    args: Parsed command line arguments. Reads args.project, args.model1
        and args.model2.
  """
  storage_client = client.Client(project=args.project)
  eval_bucket = bucket.Bucket(storage_client, '%s-mlengine' % args.project)

  # The first report ends with an extra newline so that a blank line
  # separates the two reports; the second one does not.
  reports = (
      ('model1', '%s:\nAUC: %f\tAUC Precision/Recall: %f\n'),
      ('model2', '%s:\nAUC: %f\tAUC Precision/Recall: %f'),
  )
  for attr_name, template in reports:
    model_name = getattr(args, attr_name)
    auc, auc_pr = get_auc(model_name, eval_bucket)
    print(template % (model_name, auc, auc_pr))
99 % (args.model2, model2_auc, model2_auc_pr))
100
101
def _GetDefaultModelVersion(project):
  """Return the trainer name of the model's default version in ML Engine.

  Args:
    project: Cloud project ID that hosts MODEL_NAME.

  Returns:
    The 'spam_trainer_<digits>' portion of the default version's
    deploymentUri.

  Raises:
    ValueError: If the deploymentUri contains no 'spam_trainer_<digits>'.
  """
  ml = googleapiclient.discovery.build('ml', 'v1', credentials=credentials)
  request = ml.projects().models().get(
      name='projects/%s/models/%s' % (project, MODEL_NAME))
  response = request.execute()

  deployment_uri = response['defaultVersion']['deploymentUri']
  match = re.search(r'.*(spam_trainer_\d+).*', deployment_uri)
  if match is None:
    raise ValueError(
        'Could not find a spam_trainer_<digits> name in deploymentUri %r'
        % deployment_uri)
  return match.group(1)


def main():
  """Parse the command line and run the selected sub-command.

  Sub-commands are 'predict', 'local-predict' and 'compare-accuracy'. The
  global --project/-p option (default 'monorail-staging') must precede the
  sub-command. Exits with status 1 when no credentials are available.
  """
  if not credentials and 'GOOGLE_APPLICATION_CREDENTIALS' not in os.environ:
    print(('GOOGLE_APPLICATION_CREDENTIALS environment variable is not set. '
           'Exiting.'))
    sys.exit(1)

  parser = argparse.ArgumentParser(description='Spam classifier utilities.')
  parser.add_argument('--project', '-p', default='monorail-staging')

  subparsers = parser.add_subparsers(dest='command')
  # Python 3 treats sub-commands as optional unless told otherwise; without
  # this a missing sub-command would surface later as KeyError: None.
  subparsers.required = True

  predict = subparsers.add_parser('predict',
      help='Submit a prediction to the default model in ML Engine.')
  # Both files are opened unconditionally by Predict, so require them here
  # rather than failing later inside open(None).
  predict.add_argument('--summary', required=True,
      help='A file containing the summary.')
  predict.add_argument('--content', required=True,
      help='A file containing the content.')

  subparsers.add_parser('local-predict',
      help='Create an instance on the local filesystem to use in prediction.')

  compare = subparsers.add_parser('compare-accuracy',
      help='Compare the accuracy of two models.')
  compare.add_argument('--model1',
      help='The first model to find the auc values of. Defaults to the '
           'default version currently deployed in ML Engine.')
  compare.add_argument('--model2',
      help='The second model to find the auc values of.')

  args = parser.parse_args()

  # Fill in the model defaults only for the command that needs them, so the
  # other commands never have to contact ML Engine just to parse arguments.
  if args.command == 'compare-accuracy':
    if args.model1 is None:
      args.model1 = _GetDefaultModelVersion(args.project)
    if args.model2 is None:
      # TODO(carapew): Make second default the most recently deployed model
      args.model2 = ('spam_trainer_1513384515'
                     if args.project == 'monorail-staging' else
                     'spam_trainer_1522141200')

  cmds = {
      'predict': Predict,
      'local-predict': LocalPredict,
      'compare-accuracy': CompareAccuracy,
  }
  res = cmds[args.command](args)

  # The commands print their own output and return None; only emit JSON
  # when a command actually produced a value, instead of a stray "null".
  if res is not None:
    print(json.dumps(res, indent=2))
155 print(json.dumps(res, indent=2))
156
157
# Run the command line interface only when executed as a script, not when
# this module is imported.
if __name__ == '__main__':
  main()