blob: 1392f0ba6bf321a3d4f9f74ffd47130f59b9225d [file] [log] [blame]
# Copyright 2018 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd
5
6from __future__ import print_function
7from __future__ import division
8from __future__ import absolute_import
9
10import json
11import logging
12import re
13
14import settings
15import cloudstorage
16
17from features import generate_dataset
18from framework import framework_helpers
19from services import ml_helpers
20from tracker import tracker_bizobj
21
22from googleapiclient import discovery
23from oauth2client.client import GoogleCredentials
24
25
# Resource name of the component model, in the form expected by the ML Engine
# client calls in this module. Both parts come from settings.
MODEL_NAME = 'projects/{}/models/{}'.format(
    settings.classifier_project_id,
    settings.component_model_name)
28
29
def _GetTopWords(trainer_name):  # pragma: no cover
  """Fetch the trainer's top words file and index each word by position.

  Args:
    trainer_name: Name of the trainer folder inside the component ML bucket
        that holds topwords.txt.

  Returns:
    A dict mapping each whitespace-separated word of topwords.txt to its
    position in the file. If a word appears more than once, its last position
    is the one kept.
  """
  # TODO(carapew): Use memcache to get top words rather than storing as a
  # variable.
  storage_api = discovery.build(
      'storage', 'v1',
      credentials=GoogleCredentials.get_application_default())
  media_request = storage_api.objects().get_media(
      bucket=settings.component_ml_bucket,
      object=trainer_name + '/topwords.txt')
  contents = media_request.execute()

  # A dict gives faster lookups than a list during feature generation.
  word_positions = {}
  for position, word in enumerate(contents.split()):
    word_positions[word] = position
  return word_positions
43
44
def _GetComponentsByIndex(trainer_name):
  """Load the trainer's component index mapping file.

  Args:
    trainer_name: Name of the trainer folder inside the component ML bucket
        that holds component_index.json.

  Returns:
    The parsed JSON content of component_index.json.
  """
  # TODO(carapew): Memcache the index mapping file.
  bucket_and_trainer = (settings.component_ml_bucket, trainer_name)
  mapping_path = '/%s/%s/component_index.json' % bucket_and_trainer
  logging.info('Mapping path full name: %r', mapping_path)

  with cloudstorage.open(mapping_path, 'r') as mapping_file:
    logging.info('Index component mapping opened')
    raw_mapping = mapping_file.read()
    logging.info(raw_mapping)
    components_by_index = json.loads(raw_mapping)
  return components_by_index
56
57
@framework_helpers.retry(3)
def _GetComponentPrediction(ml_engine, instance):
  """Predict the component from the default model based on the provided text.

  Args:
    ml_engine: An ML Engine instance for making predictions.
    instance: The dict object returned from ml_helpers.GenerateFeaturesRaw
        containing the features generated from the provided text. Only its
        'word_features' entry is sent to the model.

  Returns:
    The index of the component with the highest score. ML engine's predict
    api returns a dict of the format
    {'predictions': [{'classes': ['0', '1', ...], 'scores': [.00234, ...]}]}
    where each class has a score at the same index. Classes are sequential,
    so the index of the highest score also happens to be the component's
    index. When several scores tie for the highest, the first one wins.
  """
  body = {'instances': [{'inputs': instance['word_features']}]}
  request = ml_engine.projects().predict(name=MODEL_NAME, body=body)
  response = request.execute()

  # Hand the response to logging as an argument so the message is only
  # formatted when the record is actually emitted.
  logging.info('ML Engine API response: %r', response)
  scores = response['predictions'][0]['scores']

  return scores.index(max(scores))
83
84
def PredictComponent(raw_text, config):
  """Get the component ID predicted for the given text.

  Args:
    raw_text: The raw text for which we want to predict a component.
    config: The config of the project. Used to decide if the predicted
        component is valid.

  Returns:
    The component ID predicted for the provided text, or None if no usable
    component was predicted. That is the case when the default model version
    name carries no trainer timestamp, when the best-scoring index has no
    mapped component, or when the mapped component is not defined in config.
  """
  # Set-up ML engine.
  ml_engine = ml_helpers.setup_ml_engine()

  # Gets the timestamp number from the folder containing the model's trainer
  # in order to get the correct files for mappings and features.
  request = ml_engine.projects().models().get(name=MODEL_NAME)
  response = request.execute()

  version_name = response['defaultVersion']['name']
  version_match = re.search(r'v_(\d+)', version_name)
  if version_match is None:
    # Without the timestamp the trainer's files cannot be located, so no
    # prediction can be made.
    logging.error(
        'Default model version %r has no v_<timestamp> part', version_name)
    return None
  trainer_name = 'component_trainer_%s' % version_match.group(1)

  top_words = _GetTopWords(trainer_name)
  components_by_index = _GetComponentsByIndex(trainer_name)
  logging.info('Length of top words list: %s', len(top_words))

  clean_text = generate_dataset.CleanText(raw_text)
  instance = ml_helpers.GenerateFeaturesRaw(
      [clean_text], settings.component_features, top_words)

  # Get the component id with the highest prediction score. Component ids are
  # stored in GCS as strings, but represented in the app as longs.
  best_score_index = _GetComponentPrediction(ml_engine, instance)
  component_id = components_by_index.get(str(best_score_index))
  if not component_id:
    # The predicted index has no component mapped to it.
    return None
  component_id = int(component_id)

  # The predicted component id might not exist.
  if tracker_bizobj.FindComponentDefByID(component_id, config) is None:
    return None

  return component_id