# Copyright 2018 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd
5
6"""Unit tests for component prediction endpoints."""
7from __future__ import print_function
8from __future__ import division
9from __future__ import absolute_import
10
11import json
12import mock
13import sys
14import unittest
15
16from services import service_manager
17from testing import fake
18
19# Mock cloudstorage before it's imported by component_helpers
20sys.modules['cloudstorage'] = mock.Mock()
21from features import component_helpers
22
23
class FakeMLEngine(object):
  """Stand-in for the ML Engine client object used by component_helpers.

  Supports the fluent ``projects().models().<call>(...).execute()`` chain.
  ``predict`` and ``get`` validate their arguments through the owning test
  case's assertions and queue a canned response; the next ``execute`` call
  hands that response back exactly once.
  """

  def __init__(self, test):
    # TestCase whose assertEqual is used to validate incoming requests.
    self.test = test
    # Feature vector predict() expects to find in its request body.
    self.expected_features = None
    # Scores predict() places in its canned response.
    self.scores = None
    # Response queued by the latest predict()/get(); consumed by execute().
    self._execute_response = None

  def projects(self):
    """Return this fake so the call chain continues on it."""
    return self

  def models(self):
    """Return this fake so the call chain continues on it."""
    return self

  def predict(self, name, body):
    """Check a prediction request, then queue a response with self.scores."""
    expected_body = {'instances': [{'inputs': self.expected_features}]}
    self.test.assertEqual(component_helpers.MODEL_NAME, name)
    self.test.assertEqual(expected_body, body)
    prediction = {'scores': self.scores}
    self._execute_response = {'predictions': [prediction]}
    return self

  def get(self, name):
    """Check a model lookup, then queue a response naming a default version."""
    self.test.assertEqual(component_helpers.MODEL_NAME, name)
    default_version = {'name': 'v_1234'}
    self._execute_response = {'defaultVersion': default_version}
    return self

  def execute(self):
    """Hand back the queued response (None if nothing is queued) and clear it."""
    queued, self._execute_response = self._execute_response, None
    return queued
53
54
class ComponentHelpersTest(unittest.TestCase):
  """Tests for component_helpers.PredictComponent with external services faked."""

  # Input text shared by the prediction tests; see _SetUpPrediction for the
  # feature vector it is expected to produce.
  _TEXT = 'foo baz foo foo'

  def setUp(self):
    self.services = service_manager.Services(
        config=fake.ConfigService(),
        user=fake.UserService())
    self.project = fake.Project(project_name='proj')

    self._ml_engine = FakeMLEngine(self)
    # Returned by the patched _GetTopWords; set per test.
    self._top_words = None
    # JSON-encoded and served by the patched cloudstorage.open; set per test.
    self._components_by_index = None

    # Register the cleanup before starting any patch: unittest runs cleanups
    # even when setUp raises, so patches that already started are undone if
    # a later mock.patch(...).start() fails (e.g. an unresolvable target).
    self.addCleanup(mock.patch.stopall)

    # The lambdas read self._top_words / self._ml_engine at call time, so
    # tests can configure them after setUp has run.
    mock.patch(
        'services.ml_helpers.setup_ml_engine', lambda: self._ml_engine).start()
    mock.patch(
        'features.component_helpers._GetTopWords',
        lambda _: self._top_words).start()
    mock.patch('cloudstorage.open', self.cloudstorageOpen).start()
    mock.patch('settings.component_features', 5).start()

  def cloudstorageOpen(self, name, mode='r'):
    """Create a file mock that returns self._components_by_index when read.

    Args:
      name: Path passed to cloudstorage.open; forwarded to the mock.
      mode: Open mode, defaulting to 'r' like the real cloudstorage.open so
          callers that omit it still work.

    Returns:
      A mock file handle whose read() yields the JSON-encoded
      self._components_by_index, as of the time of this call.
    """
    open_fn = mock.mock_open(read_data=json.dumps(self._components_by_index))
    return open_fn(name, mode)

  def _SetUpPrediction(self, components_by_index, scores):
    """Configure the fakes shared by the PredictComponent tests.

    Args:
      components_by_index: Dict from stringified model output index to
          stringified component id, served via the patched cloudstorage.open.
      scores: Per-index scores the fake ML Engine returns from predict().
    """
    self._top_words = {
        'foo': 0,
        'bar': 1,
        'baz': 2}
    self._components_by_index = components_by_index
    # Word counts for self._TEXT: 'foo' (index 0) three times and 'baz'
    # (index 2) once; the length matches the patched
    # settings.component_features of 5.
    self._ml_engine.expected_features = [3, 0, 1, 0, 0]
    self._ml_engine.scores = scores

  def testPredict_Normal(self):
    """Test normal case when predicted component exists."""
    component_id = self.services.config.CreateComponentDef(
        cnxn=None, project_id=self.project.project_id, path='Ruta>Baga',
        docstring='', deprecated=False, admin_ids=[], cc_ids=[], created=None,
        creator_id=None, label_ids=[])
    config = self.services.config.GetProjectConfig(
        None, self.project.project_id)
    # Index 1 has the highest score and maps to the component created above.
    self._SetUpPrediction(
        {'0': '123', '1': str(component_id), '2': '789'}, [5, 10, 3])

    self.assertEqual(
        component_id, component_helpers.PredictComponent(self._TEXT, config))

  def testPredict_UnknownComponentIndex(self):
    """Test case where the prediction is not in components_by_index."""
    config = self.services.config.GetProjectConfig(
        None, self.project.project_id)
    # The highest score is at index 3, which has no components_by_index entry.
    self._SetUpPrediction(
        {'0': '123', '1': '456', '2': '789'}, [5, 10, 3, 1000])

    self.assertIsNone(component_helpers.PredictComponent(self._TEXT, config))

  def testPredict_InvalidComponentIndex(self):
    """Test case where the prediction is not a valid component id."""
    config = self.services.config.GetProjectConfig(
        None, self.project.project_id)
    # Index 1 wins and maps to '456', a component never created in this
    # project's config.
    self._SetUpPrediction(
        {'0': '123', '1': '456', '2': '789'}, [5, 10, 3])

    self.assertIsNone(component_helpers.PredictComponent(self._TEXT, config))
145 self.assertIsNone(component_helpers.PredictComponent(text, config))