blob: c0acf03aea03063bad70eca7723112028cc60a17 [file] [log] [blame]
Copybara854996b2021-09-07 19:36:02 +00001# Copyright 2016 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style
3# license that can be found in the LICENSE file or at
4# https://developers.google.com/open-source/licenses/bsd
5
6from __future__ import print_function
7from __future__ import division
8from __future__ import absolute_import
9
10import base64
11import json
12import logging
13import os
14import time
15import urllib
16import webapp2
17
18from google.appengine.api import app_identity
19from google.appengine.api import urlfetch
20from google.appengine.ext import db
21from google.protobuf import text_format
22
23from infra_libs import ts_mon
24
25import settings
26from framework import framework_constants
27from proto import api_clients_config_pb2
28
29
# Checked-in copy of the API client config. ClientConfigService reads this
# file instead of datastore when running in local or unit-test mode.
CONFIG_FILE_PATH = os.path.join(
    os.path.dirname(os.path.dirname(os.path.realpath(__file__))),
    os.path.join('testing', 'api_clients.cfg'))

# luci-config endpoint for monorail-prod's api_clients.cfg. The JSON response
# carries the file base64-encoded in its 'content' field.
LUCI_CONFIG_URL = ('https://luci-config.appspot.com/_ah/api/config/v1/'
                   'config_sets/services/monorail-prod/config/api_clients.cfg')


# Per-process, lazily populated state. None means "not computed yet":
#   client_config_svc   -- the shared ClientConfigService instance.
#   service_account_map -- client_email -> display_name.
#   qpm_dict            -- client_email -> qpm_limit.
#   allowed_origins_set -- union of every client's allowed_origins.
client_config_svc = service_account_map = qpm_dict = allowed_origins_set = None
42
43
class ClientConfig(db.Model):
  """Datastore entity holding the raw API client config text.

  LoadApiClientConfigs stores a single instance under key_name
  'api_client_configs'. ClientConfigService._ReadFromDatastore reads it back
  and parses it with text_format into a ClientCfg proto.
  """
  # Text-format ClientCfg proto, base64-decoded from the luci-config response.
  configs = db.TextProperty()
46
47
# Note: The cron job must have hit this handler (which stores the ClientConfig
# entity) before ClientConfigService can load configs from datastore.
class LoadApiClientConfigs(webapp2.RequestHandler):
  """Fetches api_clients.cfg from luci-config and stores it in datastore.

  On success, the decoded config text is saved as the ClientConfig entity with
  key_name 'api_client_configs'. This process's memoized views of the config
  are also cleared so they are rebuilt on next use.
  """

  config_loads = ts_mon.CounterMetric(
      'monorail/client_config_svc/loads',
      'Results of fetches from luci-config.',
      [ts_mon.BooleanField('success'), ts_mon.StringField('type')])

  def get(self):
    """Fetch, validate, and persist the API client config.

    Every outcome is recorded in the config_loads metric. Responds with
    HTTP 500 if luci-config returns a non-200 status, or if the response
    cannot be decoded into a valid ClientCfg proto.
    """
    global service_account_map
    global qpm_dict
    global allowed_origins_set
    authorization_token, _ = app_identity.get_access_token(
        framework_constants.OAUTH_SCOPE)
    response = urlfetch.fetch(
        LUCI_CONFIG_URL,
        method=urlfetch.GET,
        follow_redirects=False,
        headers={'Content-Type': 'application/json; charset=UTF-8',
                 'Authorization': 'Bearer ' + authorization_token})

    if response.status_code != 200:
      logging.error('Invalid response from luci-config: %r', response)
      self.config_loads.increment({'success': False, 'type': 'luci-cfg-error'})
      # abort() raises, so the request ends here.
      self.abort(500, 'Invalid response from luci-config')

    try:
      content_text = self._process_response(response)
    except Exception as e:
      self.abort(500, str(e))

    logging.info('luci-config content decoded: %r.', content_text)
    configs = ClientConfig(configs=content_text,
                           key_name='api_client_configs')
    configs.put()
    # Clear every value memoized by GetServiceAccountMap / GetQPMDict /
    # GetAllowedOriginsSet so they are rebuilt from the new config.
    # NOTE(review): the rebuild goes through ClientConfigService, which may
    # still serve its in-memory copy until EXPIRES_IN elapses.
    service_account_map = None
    qpm_dict = None
    allowed_origins_set = None
    self.config_loads.increment({'success': True, 'type': 'success'})

  def _process_response(self, response):
    """Decode a luci-config response into validated ClientCfg text.

    Each failure is logged, counted in config_loads with a type-specific
    label, and re-raised.

    Args:
      response: the urlfetch response from luci-config. Its body should be
        JSON whose 'content' field holds the base64-encoded config file.

    Returns:
      The base64-decoded config text, after checking that it parses as a
      ClientCfg proto.

    Raises:
      ValueError: the body is not JSON, or (Python 3) 'content' is not valid
        base64.
      KeyError: the JSON object has no 'content' key.
      TypeError: (Python 2) 'content' is not valid base64.
      Exception: whatever text_format raises when the text is not a valid
        ClientCfg.
    """
    try:
      content = json.loads(response.content)
    except ValueError:
      logging.error('Response was not JSON: %r', response.content)
      self.config_loads.increment({'success': False, 'type': 'json-load-error'})
      raise

    try:
      config_content = content['content']
    except KeyError:
      logging.error('JSON contained no content: %r', content)
      self.config_loads.increment({'success': False, 'type': 'json-key-error'})
      raise

    try:
      content_text = base64.b64decode(config_content)
    except (TypeError, ValueError):
      # Python 2 reports malformed base64 as TypeError. Python 3 raises
      # binascii.Error, which is a ValueError subclass.
      logging.error('Content was not b64: %r', config_content)
      self.config_loads.increment({'success': False,
                                   'type': 'b64-decode-error'})
      raise

    try:
      cfg = api_clients_config_pb2.ClientCfg()
      text_format.Merge(content_text, cfg)
    except Exception:
      # Any parse failure counts, but SystemExit/KeyboardInterrupt must not
      # be swallowed into a metric.
      logging.error('Content was not a valid ClientCfg proto: %r', content_text)
      self.config_loads.increment({'success': False,
                                   'type': 'proto-load-error'})
      raise

    return content_text
119
120
class ClientConfigService(object):
  """The persistence layer for client config data.

  Keeps a parsed ClientCfg proto in memory and reloads it when it is missing,
  older than EXPIRES_IN, or a fresh read is requested. In local or unit-test
  mode it reads CONFIG_FILE_PATH; otherwise it reads the ClientConfig
  datastore entity.
  """

  # Reload no more than once every 15 minutes.
  # Different GAE instances can load it at different times,
  # so clients may get inconsistent responses shortly after allowlisting.
  EXPIRES_IN = 15 * framework_constants.SECS_PER_MINUTE

  def __init__(self):
    # Parsed ClientCfg proto, or None until the first successful load.
    self.client_configs = None
    # int(time.time()) at the last successful load; 0 means never loaded.
    self.load_time = 0

  def GetConfigs(self, use_cache=True, cur_time=None):
    """Read client configs, reloading them if needed.

    Args:
      use_cache: if False, always reload from the backing store.
      cur_time: current time in seconds, used only for the expiry check.
        Defaults to int(time.time()). An explicit 0 is honored.

    Returns:
      The ClientCfg proto. This may be None if no load has succeeded yet.
    """
    if cur_time is None:
      cur_time = int(time.time())
    expired = cur_time - self.load_time > self.EXPIRES_IN
    if self.client_configs is None or not use_cache or expired:
      if settings.local_mode or settings.unit_test_mode:
        self._ReadFromFilesystem()
      else:
        self._ReadFromDatastore()

    return self.client_configs

  def _SetConfigsFromText(self, content_text):
    """Parse text-format ClientCfg and install it as the current config.

    Raises whatever text_format.Merge raises. The previous config is then
    left untouched.
    """
    cfg = api_clients_config_pb2.ClientCfg()
    text_format.Merge(content_text, cfg)
    self.client_configs = cfg
    self.load_time = int(time.time())

  def _ReadFromFilesystem(self):
    """Load configs from CONFIG_FILE_PATH.

    This is best effort: any failure is logged and the previous configs
    (possibly None) are kept.
    """
    try:
      with open(CONFIG_FILE_PATH, 'r') as f:
        content_text = f.read()
      logging.info('Read client configs from local file.')
      self._SetConfigsFromText(content_text)
    except Exception as e:
      logging.exception('Failed to read client configs: %s', e)

  def _ReadFromDatastore(self):
    """Load configs from the 'api_client_configs' ClientConfig entity.

    A missing entity is logged and the previous configs are kept. Parse
    errors propagate to the caller.
    """
    entity = ClientConfig.get_by_key_name('api_client_configs')
    if entity:
      self._SetConfigsFromText(entity.configs)
    else:
      logging.error('Failed to get api client configs from datastore.')

  def _GetClients(self):
    """Return the repeated `clients` field of the (possibly cached) configs.

    Raises AttributeError if no config has ever loaded. This deliberately does
    not fall back to an empty result, because the module-level getters
    memoize what these accessors return.
    """
    return self.GetConfigs(use_cache=True).clients

  def GetClientIDEmails(self):
    """Get client IDs and Emails.

    Returns:
      A (client_ids, client_emails) pair of parallel lists, one entry per
      configured client, in config order.
    """
    clients = self._GetClients()
    client_ids = [c.client_id for c in clients]
    client_emails = [c.client_email for c in clients]
    return client_ids, client_emails

  def GetDisplayNames(self):
    """Map client_email -> display_name for clients that set a display name."""
    return {c.client_email: c.display_name
            for c in self._GetClients() if c.display_name}

  def GetQPM(self):
    """Map client_email -> qpm_limit for clients that set qpm_limit."""
    return {c.client_email: c.qpm_limit
            for c in self._GetClients() if c.HasField('qpm_limit')}

  def GetAllowedOriginsSet(self):
    """Return the union of allowed_origins across all configured clients."""
    origins = set()
    for client in self._GetClients():
      origins.update(client.allowed_origins)
    return origins
207
208
def GetClientConfigSvc():
  """Return the process-wide ClientConfigService, creating it on first use."""
  global client_config_svc
  if client_config_svc is not None:
    return client_config_svc
  client_config_svc = ClientConfigService()
  return client_config_svc
214
215
def GetServiceAccountMap():
  """Return a memoized dict of client_email -> display_name.

  Only clients whose config sets a display_name appear in the dict. The value
  is computed once and reused until the module-level cache is reset to None.
  """
  global service_account_map
  if service_account_map is not None:
    return service_account_map
  service_account_map = GetClientConfigSvc().GetDisplayNames()
  return service_account_map
223
224
def GetQPMDict():
  """Return a memoized dict of client_email -> qpm_limit.

  Only clients that set qpm_limit appear. The value is computed once and
  reused until the module-level cache is reset to None.
  """
  global qpm_dict
  if qpm_dict is not None:
    return qpm_dict
  qpm_dict = GetClientConfigSvc().GetQPM()
  return qpm_dict
230
231
def GetAllowedOriginsSet():
  """Return the memoized union of allowed_origins across all clients.

  The value is computed once and reused until the module-level cache is
  reset to None.
  """
  global allowed_origins_set
  if allowed_origins_set is not None:
    return allowed_origins_set
  allowed_origins_set = GetClientConfigSvc().GetAllowedOriginsSet()
  return allowed_origins_set