Copybara | 854996b | 2021-09-07 19:36:02 +0000 | [diff] [blame] | 1 | #!/usr/bin/env python |
| 2 | # Copyright 2016 The Chromium Authors. All rights reserved. |
| 3 | # Use of this source code is governed by a BSD-style |
| 4 | # license that can be found in the LICENSE file or at |
| 5 | # https://developers.google.com/open-source/licenses/bsd |
| 6 | |
| 7 | """ |
| 8 | Spam classifier command line tools. |
| 9 | |
| 10 | Use this command to submit predictions locally or to the model running |
| 11 | in production. See tools/spam/README.md for more context on training |
| 12 | and model operations. |
| 13 | |
| 14 | Note that in order for this command to work, you must be logged into |
| 15 | gcloud in the project under which you wish to run commands. |
| 16 | """ |
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

import argparse
import json
import os
import re
import sys

import googleapiclient
import googleapiclient.discovery
import googleapiclient.errors
from apiclient.discovery import build
from google.cloud.storage import client, bucket, blob
from oauth2client.client import ApplicationDefaultCredentialsError
from oauth2client.client import GoogleCredentials

import ml_helpers
| 32 | |
# Application Default Credentials used for every ML Engine API call below.
# get_application_default() raises instead of returning a falsy value, so
# without this guard main()'s "variable is not set" check could never run:
# the script would die with a traceback at import time.
try:
  credentials = GoogleCredentials.get_application_default()
except ApplicationDefaultCredentialsError:
  if 'GOOGLE_APPLICATION_CREDENTIALS' in os.environ:
    # The variable is set but unusable; surface oauth2client's own error.
    raise
  # main() reports this case with a readable message and exits.
  credentials = None

# This must match settings.spam_feature_hashes.
SPAM_FEATURE_HASHES = 500

# ML Engine model queried by the predict and compare-accuracy commands.
MODEL_NAME = 'spam_only_words'
| 39 | |
| 40 | |
def Predict(args):
  """Send one summary/content pair to the deployed spam model and print it.

  Builds word-hash features from the two files, submits them to
  projects/<args.project>/models/MODEL_NAME through the ML Engine v1 API and
  prints the raw response. An HTTP error from the API is reported on stdout
  instead of being raised.

  Args:
    args: parsed command-line namespace. Reads `project`, `summary` (path to
      a file holding the summary text) and `content` (path to a file holding
      the content text).
  """
  with open(args.summary) as summary_file:
    summary = summary_file.read()
  with open(args.content) as content_file:
    content = content_file.read()

  instance = ml_helpers.GenerateFeaturesRaw([summary, content],
                                            SPAM_FEATURE_HASHES)

  ml = googleapiclient.discovery.build('ml', 'v1', credentials=credentials)
  full_model_name = 'projects/%s/models/%s' % (args.project, MODEL_NAME)
  request = ml.projects().predict(
      name=full_model_name,
      body={'instances': [{'inputs': instance['word_hashes']}]})

  # `except X as err` is valid on both Python 2.6+ and Python 3; the old
  # `except X, err` form is a SyntaxError on Python 3.
  try:
    response = request.execute()
  except googleapiclient.errors.HttpError as err:
    print('There was an error. Check the details:')
    # NOTE(review): _get_reason() is a private HttpError helper; newer
    # googleapiclient releases also expose `err.reason` - confirm before
    # upgrading.
    print(err._get_reason())
    return
  print(response)
| 64 | |
| 65 | |
def LocalPredict(_):
  """Prompt for a summary and description and save them as a local instance.

  Writes /tmp/instances.json containing {"inputs": <word hashes>} so it can
  be fed to `gcloud ml-engine local predict`. The parsed arguments are
  ignored.
  """
  print('This will write /tmp/instances.json.')
  print('Then you can call:')
  print(('gcloud ml-engine local predict --json-instances /tmp/instances.json'
         ' --model-dir {model_dir}'))

  # raw_input only exists on Python 2; Python 3 renamed it to input(). Prefer
  # raw_input when present because Python 2's input() evaluates what it reads.
  try:
    read_line = raw_input  # pylint: disable=undefined-variable
  except NameError:
    read_line = input

  summary = read_line('Summary: ')
  description = read_line('Description: ')
  instance = ml_helpers.GenerateFeaturesRaw([summary, description],
                                            SPAM_FEATURE_HASHES)

  with open('/tmp/instances.json', 'w') as f:
    json.dump({'inputs': instance['word_hashes']}, f)
| 79 | |
| 80 | |
def get_auc(model_name, bucket_obj):
  """Read the evaluation AUC metrics stored for a trained model.

  Downloads `<model_name>/eval_data.json` from `bucket_obj` and parses it as
  JSON.

  Args:
    model_name: model directory name inside the bucket (for example a
      `spam_trainer_<digits>` name).
    bucket_obj: Cloud Storage bucket that holds the model's eval data.

  Returns:
    An (auc, auc_precision_recall) tuple read from the JSON's `auc` and
    `auc_precision_recall` keys.
  """
  # Keep the Blob in a local. `bucket_obj.blob` is Bucket's blob() factory
  # method, and assigning to it would overwrite that method on the caller's
  # bucket object.
  eval_blob = bucket_obj.blob('%s/eval_data.json' % model_name)
  data_dict = json.loads(eval_blob.download_as_string())
  return data_dict['auc'], data_dict['auc_precision_recall']
| 86 | |
| 87 | |
def CompareAccuracy(args):
  """Print the stored AUC metrics of two models side by side.

  Both models' eval data are read from the `<args.project>-mlengine` bucket.
  Each model's metrics are printed as soon as they are fetched, with a blank
  line separating the two reports.

  Args:
    args: parsed command-line namespace. Reads `project`, `model1` and
      `model2`.
  """
  project = args.project
  bucket_obj = bucket.Bucket(client.Client(project=project),
                             '%s-mlengine' % project)

  # (model, report template) pairs. Only the first template ends in '\n',
  # which yields the blank separator line between the two reports.
  reports = (
      (args.model1, '%s:\nAUC: %f\tAUC Precision/Recall: %f\n'),
      (args.model2, '%s:\nAUC: %f\tAUC Precision/Recall: %f'),
  )
  for model_name, template in reports:
    auc, auc_pr = get_auc(model_name, bucket_obj)
    print(template % (model_name, auc, auc_pr))
| 100 | |
| 101 | |
def _DefaultModelVersion(project):
  """Return the trainer name behind MODEL_NAME's default version in `project`.

  Looks up projects/<project>/models/MODEL_NAME through the ML Engine v1 API
  and extracts the `spam_trainer_<digits>` component of the default version's
  deploymentUri.

  Raises:
    ValueError: if the deploymentUri has no `spam_trainer_<digits>` component.
  """
  ml = googleapiclient.discovery.build('ml', 'v1', credentials=credentials)
  request = ml.projects().models().get(
      name='projects/%s/models/%s' % (project, MODEL_NAME))
  response = request.execute()
  deployment_uri = response['defaultVersion']['deploymentUri']
  # Raw string: '\d' inside a plain literal is an invalid escape sequence
  # (a warning on Python 3.6+).
  match = re.search(r'.*(spam_trainer_\d+).*', deployment_uri)
  if match is None:
    raise ValueError('No spam_trainer_<digits> name found in deploymentUri %r'
                     % deployment_uri)
  return match.group(1)


def main():
  """Parse the command line and dispatch to the selected subcommand.

  Exits with status 1 when no credentials were found and
  GOOGLE_APPLICATION_CREDENTIALS is unset.
  """
  if not credentials and 'GOOGLE_APPLICATION_CREDENTIALS' not in os.environ:
    print(('GOOGLE_APPLICATION_CREDENTIALS environment variable is not set. '
           'Exiting.'))
    sys.exit(1)

  parser = argparse.ArgumentParser(description='Spam classifier utilities.')
  parser.add_argument('--project', '-p', default='monorail-staging')

  subparsers = parser.add_subparsers(dest='command')
  # Python 3 makes subcommands optional by default; a bare invocation would
  # otherwise reach cmds[None] below and raise KeyError.
  subparsers.required = True

  predict = subparsers.add_parser(
      'predict',
      help='Submit a prediction to the default model in ML Engine.')
  # Predict() opens both paths, so neither may be omitted.
  predict.add_argument('--summary', required=True,
                       help='A file containing the summary.')
  predict.add_argument('--content', required=True,
                       help='A file containing the content.')

  subparsers.add_parser(
      'local-predict',
      help='Create an instance on the local filesystem to use in prediction.')

  compare = subparsers.add_parser(
      'compare-accuracy',
      help='Compare the accuracy of two models.')
  # Both model defaults depend on --project and are filled in after parsing.
  # This way -h shows the complete help, and only a compare-accuracy run
  # without --model1 queries ML Engine.
  compare.add_argument('--model1',
                       help='The first model to find the auc values of.')
  # TODO(carapew): Make second default the most recently deployed model
  compare.add_argument('--model2',
                       help='The second model to find the auc values of.')

  args = parser.parse_args()

  if args.command == 'compare-accuracy':
    if args.model1 is None:
      args.model1 = _DefaultModelVersion(args.project)
    if args.model2 is None:
      args.model2 = ('spam_trainer_1513384515'
                     if args.project == 'monorail-staging' else
                     'spam_trainer_1522141200')

  cmds = {
      'predict': Predict,
      'local-predict': LocalPredict,
      'compare-accuracy': CompareAccuracy,
  }
  res = cmds[args.command](args)

  # The commands print their own output and currently return None; only emit
  # JSON when a command actually returns a value, instead of a bare `null`.
  if res is not None:
    print(json.dumps(res, indent=2))


if __name__ == '__main__':
  main()