Copybara | 854996b | 2021-09-07 19:36:02 +0000 | [diff] [blame^] | 1 | # Copyright 2018 The Chromium Authors. All rights reserved. |
| 2 | # Use of this source code is governed by a BSD-style |
| 3 | # license that can be found in the LICENSE file or at |
| 4 | # https://developers.google.com/open-source/licenses/bsd |
| 5 | |
| 6 | from __future__ import print_function |
| 7 | from __future__ import division |
| 8 | from __future__ import absolute_import |
| 9 | |
| 10 | import json |
| 11 | import logging |
| 12 | import re |
| 13 | |
| 14 | import settings |
| 15 | import cloudstorage |
| 16 | |
| 17 | from features import generate_dataset |
| 18 | from framework import framework_helpers |
| 19 | from services import ml_helpers |
| 20 | from tracker import tracker_bizobj |
| 21 | |
| 22 | from googleapiclient import discovery |
| 23 | from oauth2client.client import GoogleCredentials |
| 24 | |
| 25 | |
# Fully-qualified ML Engine resource name of the component-prediction model,
# assembled from the classifier project and model name in settings.
MODEL_NAME = 'projects/{project}/models/{model}'.format(
    project=settings.classifier_project_id,
    model=settings.component_model_name)
| 28 | |
| 29 | |
def _GetTopWords(trainer_name):  # pragma: no cover
  """Download a trainer's top-words list and index it by position.

  Args:
    trainer_name: Name of the trainer folder inside
      settings.component_ml_bucket (e.g. 'component_trainer_<version>').

  Returns:
    A dict mapping each whitespace-separated token of
    '<trainer_name>/topwords.txt' to its zero-based position in the file.
    If a token repeats, its last position wins.
  """
  # TODO(carapew): Use memcache to get top words rather than storing as a
  # variable.
  storage = discovery.build(
      'storage', 'v1',
      credentials=GoogleCredentials.get_application_default())
  get_request = storage.objects().get_media(
      bucket=settings.component_ml_bucket,
      object=trainer_name + '/topwords.txt')
  # NOTE(review): media downloads may come back as bytes under Python 3, in
  # which case the dict keys are bytes too -- confirm this matches what
  # ml_helpers.GenerateFeaturesRaw looks up.
  contents = get_request.execute()

  # Index the words in a dict so feature generation gets O(1) lookups.
  word_positions = {}
  for position, word in enumerate(contents.split()):
    word_positions[word] = position
  return word_positions
| 43 | |
| 44 | |
def _GetComponentsByIndex(trainer_name):
  """Load a trainer's prediction-index -> component-ID mapping from GCS.

  Args:
    trainer_name: Name of the trainer folder inside
      settings.component_ml_bucket.

  Returns:
    The JSON-decoded contents of
    '/<bucket>/<trainer_name>/component_index.json'.
  """
  # TODO(carapew): Memcache the index mapping file.
  bucket = settings.component_ml_bucket
  mapping_path = '/%s/%s/component_index.json' % (bucket, trainer_name)
  logging.info('Mapping path full name: %r', mapping_path)

  with cloudstorage.open(mapping_path, 'r') as mapping_file:
    logging.info('Index component mapping opened')
    raw_json = mapping_file.read()
    # Logs the full mapping file contents.
    logging.info(raw_json)
    return json.loads(raw_json)
| 56 | |
| 57 | |
@framework_helpers.retry(3)
def _GetComponentPrediction(ml_engine, instance):
  """Predict the component from the default model based on the provided text.

  Wrapped in framework_helpers.retry(3).

  Args:
    ml_engine: An ML Engine instance for making predictions.
    instance: The dict object returned from ml_helpers.GenerateFeaturesRaw
      containing the features generated from the provided text; only its
      'word_features' entry is sent to the model.

  Returns:
    The index of the component with the highest score. ML engine's predict
    api returns a dict of the format
    {'predictions': [{'classes': ['0', '1', ...], 'scores': [.00234, ...]}]}
    where each class has a score at the same index. Classes are sequential,
    so the index of the highest score also happens to be the component's
    index. On ties, the first highest-scoring index is returned.

  Raises:
    ValueError: If the first prediction has an empty 'scores' list.
  """
  body = {'instances': [{'inputs': instance['word_features']}]}
  request = ml_engine.projects().predict(name=MODEL_NAME, body=body)
  response = request.execute()

  # Pass the response as a logging argument instead of %-formatting it
  # eagerly. The string is then only built when the record is emitted,
  # matching the lazy logging style used elsewhere in this module.
  logging.info('ML Engine API response: %r', response)

  # Only a single instance was sent, so only the first prediction matters.
  scores = response['predictions'][0]['scores']
  return scores.index(max(scores))
| 83 | |
| 84 | |
def PredictComponent(raw_text, config):
  """Get the component ID predicted for the given text.

  Args:
    raw_text: The raw text for which we want to predict a component.
    config: The config of the project. Used to decide if the predicted component
      is valid.

  Returns:
    The component ID predicted for the provided text, or None if no valid
    component was predicted. None covers two cases: the predicted index has
    no entry in the trainer's mapping file, or the mapped ID is not a
    component in config.
  """
  # Set-up ML engine.
  ml_engine = ml_helpers.setup_ml_engine()

  # Gets the timestamp number from the folder containing the model's trainer
  # in order to get the correct files for mappings and features.
  request = ml_engine.projects().models().get(name=MODEL_NAME)
  response = request.execute()

  # If the default version name has no 'v_<digits>' segment, re.search
  # returns None and .group raises AttributeError.
  version = re.search(r'v_(\d+)', response['defaultVersion']['name']).group(1)
  trainer_name = 'component_trainer_%s' % version

  top_words = _GetTopWords(trainer_name)
  components_by_index = _GetComponentsByIndex(trainer_name)
  logging.info('Length of top words list: %s', len(top_words))

  clean_text = generate_dataset.CleanText(raw_text)
  instance = ml_helpers.GenerateFeaturesRaw(
      [clean_text], settings.component_features, top_words)

  # Get the component id with the highest prediction score. Component ids are
  # stored in GCS as strings, but represented in the app as longs.
  best_score_index = _GetComponentPrediction(ml_engine, instance)
  component_id = components_by_index.get(str(best_score_index))
  if not component_id:
    # The predicted index has no (usable) entry in the mapping file, so there
    # is no component ID to validate. Return early instead of passing None
    # on to FindComponentDefByID.
    return None
  component_id = int(component_id)

  # The predicted component id might not exist.
  if tracker_bizobj.FindComponentDefByID(component_id, config) is None:
    return None

  return component_id