# blob: 1392f0ba6bf321a3d4f9f74ffd47130f59b9225d [file] [log] [blame]
# Copyright 2018 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
import json
import logging
import re
import settings
import cloudstorage
from features import generate_dataset
from framework import framework_helpers
from services import ml_helpers
from tracker import tracker_bizobj
from googleapiclient import discovery
from oauth2client.client import GoogleCredentials
# Fully qualified ML Engine resource name of the component-prediction model,
# assembled from the classifier project and model name configured in settings.
MODEL_NAME = 'projects/{project}/models/{model}'.format(
    project=settings.classifier_project_id,
    model=settings.component_model_name)
def _GetTopWords(trainer_name):  # pragma: no cover
  """Fetch a trainer's top-words vocabulary from Cloud Storage.

  Args:
    trainer_name: Name of the trainer folder inside
        settings.component_ml_bucket (e.g. 'component_trainer_<version>').

  Returns:
    A dict mapping each whitespace-separated word of
    '<trainer_name>/topwords.txt' to its position in that file.
  """
  # TODO(carapew): Use memcache to get top words rather than storing as a
  # variable.
  credentials = GoogleCredentials.get_application_default()
  storage = discovery.build('storage', 'v1', credentials=credentials)
  request = storage.objects().get_media(
      bucket=settings.component_ml_bucket,
      object=trainer_name + '/topwords.txt')
  response = request.execute()
  # Media downloads come back as raw bytes under Python 3. Decode them so the
  # vocabulary keys are text; presumably they are matched against the cleaned
  # (text) issue words during feature generation. Under Python 2 bytes is str,
  # so this branch is skipped and behavior is unchanged.
  if isinstance(response, bytes) and not isinstance(response, str):
    response = response.decode('utf-8')
  # This turns the top words list into a dictionary for faster feature
  # generation.
  return {word: idx for idx, word in enumerate(response.split())}
def _GetComponentsByIndex(trainer_name):
  """Load a trainer's class-index to component-ID mapping from GCS.

  Args:
    trainer_name: Name of the trainer folder inside
        settings.component_ml_bucket.

  Returns:
    The JSON-decoded contents of '<trainer_name>/component_index.json'.
  """
  # TODO(carapew): Memcache the index mapping file.
  gcs_path = '/%s/%s/component_index.json' % (
      settings.component_ml_bucket, trainer_name)
  logging.info('Mapping path full name: %r', gcs_path)
  with cloudstorage.open(gcs_path, 'r') as gcs_file:
    logging.info('Index component mapping opened')
    raw_mapping = gcs_file.read()
  # Everything below only needs the already-read contents, not the open file.
  logging.info(raw_mapping)
  return json.loads(raw_mapping)
@framework_helpers.retry(3)
def _GetComponentPrediction(ml_engine, instance):
  """Predict the component from the default model based on the provided text.

  Args:
    ml_engine: An ML Engine instance for making predictions.
    instance: The dict object returned from ml_helpers.GenerateFeaturesRaw
        containing the features generated from the provided text.

  Returns:
    The index of the component with the highest score. ML engine's predict
    api returns a dict of the format
    {'predictions': [{'classes': ['0', '1', ...], 'scores': [.00234, ...]}]}
    where each class has a score at the same index. Classes are sequential,
    so the index of the highest score also happens to be the component's
    index. On ties, the first (lowest) index wins.
  """
  body = {'instances': [{'inputs': instance['word_features']}]}
  request = ml_engine.projects().predict(name=MODEL_NAME, body=body)
  response = request.execute()
  # Pass the response as a logging argument (not pre-formatted with %) so it
  # is only rendered when the record is actually emitted.
  logging.info('ML Engine API response: %r', response)
  scores = response['predictions'][0]['scores']
  return scores.index(max(scores))
def _GetTrainerName(ml_engine):
  """Return the trainer folder name for the model's current default version.

  The version number is taken from the first 'v_<digits>' run in the default
  version's name; it selects the correct mapping and feature files.
  Raises AttributeError if that name contains no 'v_<digits>' run.
  """
  request = ml_engine.projects().models().get(name=MODEL_NAME)
  response = request.execute()
  version = re.search(r'v_(\d+)', response['defaultVersion']['name']).group(1)
  return 'component_trainer_%s' % version


def PredictComponent(raw_text, config):
  """Get the component ID predicted for the given text.

  Args:
    raw_text: The raw text for which we want to predict a component.
    config: The config of the project. Used to decide if the predicted component
        is valid.

  Returns:
    The component ID predicted for the provided text, or None if no
    component was predicted or the predicted component is not defined in
    config.
  """
  # Set-up ML engine.
  ml_engine = ml_helpers.setup_ml_engine()

  trainer_name = _GetTrainerName(ml_engine)
  top_words = _GetTopWords(trainer_name)
  components_by_index = _GetComponentsByIndex(trainer_name)
  logging.info('Length of top words list: %s', len(top_words))

  clean_text = generate_dataset.CleanText(raw_text)
  instance = ml_helpers.GenerateFeaturesRaw(
      [clean_text], settings.component_features, top_words)

  # Get the component id with the highest prediction score. Component ids are
  # stored in GCS as strings, but represented in the app as longs.
  best_score_index = _GetComponentPrediction(ml_engine, instance)
  component_id = components_by_index.get(str(best_score_index))
  if not component_id:
    # The predicted index has no entry in the mapping; don't look up a
    # non-existent id in the config.
    return None
  component_id = int(component_id)

  # The predicted component id might not exist.
  if tracker_bizobj.FindComponentDefByID(component_id, config) is None:
    return None
  return component_id