Project import generated by Copybara.

GitOrigin-RevId: d9e9e3fb4e31372ec1fb43b178994ca78fa8fe70
diff --git a/features/test/component_helpers_test.py b/features/test/component_helpers_test.py
new file mode 100644
index 0000000..aa6c761
--- /dev/null
+++ b/features/test/component_helpers_test.py
@@ -0,0 +1,145 @@
+# Copyright 2018 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for component prediction endpoints."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import json
+import mock
+import sys
+import unittest
+
+from services import service_manager
+from testing import fake
+
+# Mock cloudstorage before it's imported by component_helpers
+sys.modules['cloudstorage'] = mock.Mock()
+from features import component_helpers
+
+
+class FakeMLEngine(object):
+  def __init__(self, test):
+    self.test = test
+    self.expected_features = None
+    self.scores = None
+    self._execute_response = None
+
+  def projects(self):
+    return self
+
+  def models(self):
+    return self
+
+  def predict(self, name, body):
+    self.test.assertEqual(component_helpers.MODEL_NAME, name)
+    self.test.assertEqual(
+        {'instances': [{'inputs': self.expected_features}]}, body)
+    self._execute_response = {'predictions': [{'scores': self.scores}]}
+    return self
+
+  def get(self, name):
+    self.test.assertEqual(component_helpers.MODEL_NAME, name)
+    self._execute_response = {'defaultVersion': {'name': 'v_1234'}}
+    return self
+
+  def execute(self):
+    response = self._execute_response
+    self._execute_response = None
+    return response
+
+
+class ComponentHelpersTest(unittest.TestCase):
+
+  def setUp(self):
+    self.services = service_manager.Services(
+        config=fake.ConfigService(),
+        user=fake.UserService())
+    self.project = fake.Project(project_name='proj')
+
+    self._ml_engine = FakeMLEngine(self)
+    self._top_words = None
+    self._components_by_index = None
+
+    mock.patch(
+        'services.ml_helpers.setup_ml_engine', lambda: self._ml_engine).start()
+    mock.patch(
+        'features.component_helpers._GetTopWords',
+        lambda _: self._top_words).start()
+    mock.patch('cloudstorage.open', self.cloudstorageOpen).start()
+    mock.patch('settings.component_features', 5).start()
+
+    self.addCleanup(mock.patch.stopall)
+
+  def cloudstorageOpen(self, name, mode):
+    """Create a file mock that returns self._components_by_index when read."""
+    open_fn = mock.mock_open(read_data=json.dumps(self._components_by_index))
+    return open_fn(name, mode)
+
+  def testPredict_Normal(self):
+    """Test normal case when predicted component exists."""
+    component_id = self.services.config.CreateComponentDef(
+        cnxn=None, project_id=self.project.project_id, path='Ruta>Baga',
+        docstring='', deprecated=False, admin_ids=[], cc_ids=[], created=None,
+        creator_id=None, label_ids=[])
+    config = self.services.config.GetProjectConfig(
+        None, self.project.project_id)
+
+    self._top_words = {
+        'foo': 0,
+        'bar': 1,
+        'baz': 2}
+    self._components_by_index = {
+        '0': '123',
+        '1': str(component_id),
+        '2': '789'}
+    self._ml_engine.expected_features = [3, 0, 1, 0, 0]
+    self._ml_engine.scores = [5, 10, 3]
+
+    text = 'foo baz foo foo'
+
+    self.assertEqual(
+        component_id, component_helpers.PredictComponent(text, config))
+
+  def testPredict_UnknownComponentIndex(self):
+    """Test case where the prediction is not in components_by_index."""
+    config = self.services.config.GetProjectConfig(
+        None, self.project.project_id)
+
+    self._top_words = {
+        'foo': 0,
+        'bar': 1,
+        'baz': 2}
+    self._components_by_index = {
+        '0': '123',
+        '1': '456',
+        '2': '789'}
+    self._ml_engine.expected_features = [3, 0, 1, 0, 0]
+    self._ml_engine.scores = [5, 10, 3, 1000]
+
+    text = 'foo baz foo foo'
+
+    self.assertIsNone(component_helpers.PredictComponent(text, config))
+
+  def testPredict_InvalidComponentIndex(self):
+    """Test case where the prediction is not a valid component id."""
+    config = self.services.config.GetProjectConfig(
+        None, self.project.project_id)
+
+    self._top_words = {
+        'foo': 0,
+        'bar': 1,
+        'baz': 2}
+    self._components_by_index = {
+        '0': '123',
+        '1': '456',
+        '2': '789'}
+    self._ml_engine.expected_features = [3, 0, 1, 0, 0]
+    self._ml_engine.scores = [5, 10, 3]
+
+    text = 'foo baz foo foo'
+
+    self.assertIsNone(component_helpers.PredictComponent(text, config))