Project import generated by Copybara.

GitOrigin-RevId: d9e9e3fb4e31372ec1fb43b178994ca78fa8fe70
diff --git a/services/client_config_svc.py b/services/client_config_svc.py
new file mode 100644
index 0000000..c0acf03
--- /dev/null
+++ b/services/client_config_svc.py
@@ -0,0 +1,236 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import base64
+import json
+import logging
+import os
+import time
+import urllib
+import webapp2
+
+from google.appengine.api import app_identity
+from google.appengine.api import urlfetch
+from google.appengine.ext import db
+from google.protobuf import text_format
+
+from infra_libs import ts_mon
+
+import settings
+from framework import framework_constants
+from proto import api_clients_config_pb2
+
+
+CONFIG_FILE_PATH = os.path.join(  # Read in local and unit-test modes.
+    os.path.dirname(os.path.dirname(os.path.realpath(__file__))),
+    'testing', 'api_clients.cfg')
+LUCI_CONFIG_URL = (  # Fetched by LoadApiClientConfigs.get().
+    'https://luci-config.appspot.com/_ah/api/config/v1/config_sets'
+    '/services/monorail-prod/config/api_clients.cfg')
+
+
+client_config_svc = None  # ClientConfigService singleton, created lazily.
+service_account_map = None  # Cache for GetServiceAccountMap().
+qpm_dict = None  # Cache for GetQPMDict().
+allowed_origins_set = None  # Cache for GetAllowedOriginsSet().
+
+
+class ClientConfig(db.Model):  # Datastore model holding the stored config.
+  configs = db.TextProperty()  # Text-format ClientCfg proto content.
+
+
+# Note: The cron job must have hit this servlet before datastore reads work.
+class LoadApiClientConfigs(webapp2.RequestHandler):
+  """Fetches api_clients.cfg from luci-config and stores it in datastore."""
+
+  config_loads = ts_mon.CounterMetric(
+      'monorail/client_config_svc/loads',
+      'Results of fetches from luci-config.',
+      [ts_mon.BooleanField('success'), ts_mon.StringField('type')])
+
+  def get(self):
+    """Fetch and validate the config, store it, and reset derived caches."""
+    global service_account_map, qpm_dict, allowed_origins_set
+    authorization_token, _ = app_identity.get_access_token(
+      framework_constants.OAUTH_SCOPE)
+    response = urlfetch.fetch(
+      LUCI_CONFIG_URL, method=urlfetch.GET, follow_redirects=False,
+      headers={'Content-Type': 'application/json; charset=UTF-8',
+              'Authorization': 'Bearer ' + authorization_token})
+
+    if response.status_code != 200:
+      logging.error('Invalid response from luci-config: %r', response)
+      self.config_loads.increment({'success': False, 'type': 'luci-cfg-error'})
+      self.abort(500, 'Invalid response from luci-config')
+
+    try:
+      content_text = self._process_response(response)
+    except Exception as e:
+      self.abort(500, str(e))
+
+    logging.info('luci-config content decoded: %r.', content_text)
+    ClientConfig(key_name='api_client_configs', configs=content_text).put()
+    # Reset every module-level cache derived from the config; without
+    # allowed_origins_set here, that cache would never be refreshed.
+    service_account_map = qpm_dict = allowed_origins_set = None
+    self.config_loads.increment({'success': True, 'type': 'success'})
+
+  def _process_response(self, response):
+    """Return the decoded config text; log, count and re-raise on failure."""
+    try:
+      content = json.loads(response.content)
+    except ValueError:
+      logging.error('Response was not JSON: %r', response.content)
+      self.config_loads.increment({'success': False, 'type': 'json-load-error'})
+      raise
+
+    try:
+      config_content = content['content']
+    except KeyError:
+      logging.error('JSON contained no content: %r', content)
+      self.config_loads.increment({'success': False, 'type': 'json-key-error'})
+      raise
+
+    try:
+      content_text = base64.b64decode(config_content)
+    except (TypeError, ValueError):
+      # Python 2 raises TypeError; Python 3 raises a ValueError subclass.
+      logging.error('Content was not b64: %r', config_content)
+      self.config_loads.increment({'success': False,
+                                   'type': 'b64-decode-error'})
+      raise
+
+    try:
+      cfg = api_clients_config_pb2.ClientCfg()
+      text_format.Merge(content_text, cfg)
+    except Exception:
+      logging.error('Content was not a valid ClientCfg proto: %r', content_text)
+      self.config_loads.increment({'success': False,
+                                   'type': 'proto-load-error'})
+      raise
+
+    return content_text
+
+
+class ClientConfigService(object):
+  """The persistence layer for client config data.
+
+  The Get*() accessors assume a config load has succeeded; until one has,
+  self.client_configs is None and they raise AttributeError.
+  """
+
+  # Reload no more than once every 15 minutes.
+  # Different GAE instances can load it at different times,
+  # so clients may get inconsistent responses shortly after allowlisting.
+  EXPIRES_IN = 15 * framework_constants.SECS_PER_MINUTE
+
+  def __init__(self):
+    self.client_configs = None  # Parsed ClientCfg proto; None until loaded.
+    self.load_time = 0  # int(time.time()) of the last successful load.
+
+  def GetConfigs(self, use_cache=True, cur_time=None):
+    """Return the client configs, reloading them when missing or stale.
+
+    Args:
+      use_cache: if False, reload even when a fresh copy is cached.
+      cur_time: Unix time in seconds for the staleness check; defaults to now.
+
+    Returns:
+      The ClientCfg proto, or None if no load has succeeded yet.
+    """
+    # Compare against None so that an explicit cur_time of 0 is honored.
+    if cur_time is None:
+      cur_time = int(time.time())
+    expired = cur_time - self.load_time > self.EXPIRES_IN
+    if not self.client_configs or not use_cache or expired:
+      if settings.local_mode or settings.unit_test_mode:
+        self._ReadFromFilesystem()
+      else:
+        self._ReadFromDatastore()
+
+    return self.client_configs
+
+  def _ReadFromFilesystem(self):
+    """Load configs from CONFIG_FILE_PATH; failures are logged, not raised."""
+    try:
+      with open(CONFIG_FILE_PATH, 'r') as f:
+        content_text = f.read()
+      logging.info('Read client configs from local file.')
+      cfg = api_clients_config_pb2.ClientCfg()
+      text_format.Merge(content_text, cfg)
+      self.client_configs = cfg
+      self.load_time = int(time.time())
+    except Exception as e:
+      logging.exception('Failed to read client configs: %s', e)
+
+  def _ReadFromDatastore(self):
+    """Load configs from the 'api_client_configs' datastore entity, if any."""
+    entity = ClientConfig.get_by_key_name('api_client_configs')
+    if entity:
+      cfg = api_clients_config_pb2.ClientCfg()
+      text_format.Merge(entity.configs, cfg)
+      self.client_configs = cfg
+      self.load_time = int(time.time())
+    else:
+      logging.error('Failed to get api client configs from datastore.')
+
+  def GetClientIDEmails(self):
+    """Get client IDs and Emails as a (client_ids, client_emails) tuple."""
+    clients = self.GetConfigs(use_cache=True).clients
+    client_ids = [c.client_id for c in clients]
+    client_emails = [c.client_email for c in clients]
+    return client_ids, client_emails
+
+  def GetDisplayNames(self):
+    """Get a dict of client email to display name, for clients that set one."""
+    clients = self.GetConfigs(use_cache=True).clients
+    return {c.client_email: c.display_name for c in clients if c.display_name}
+
+  def GetQPM(self):
+    """Get a dict of client email to qpm limit, for clients that set one."""
+    clients = self.GetConfigs(use_cache=True).clients
+    return {c.client_email: c.qpm_limit
+            for c in clients if c.HasField('qpm_limit')}
+
+  def GetAllowedOriginsSet(self):
+    """Get the set of all allowed origins."""
+    origins = set()
+    for client in self.GetConfigs(use_cache=True).clients:
+      origins.update(client.allowed_origins)
+    return origins
+
+
+def GetClientConfigSvc():
+  """Return the shared ClientConfigService, creating it on first use."""
+  global client_config_svc
+  client_config_svc = client_config_svc or ClientConfigService()
+  return client_config_svc
+
+
+def GetServiceAccountMap():
+  # typ: () -> Mapping[str, str]
+  """Returns the cached dict of client_email to display_name, if one is set."""
+  global service_account_map
+  if service_account_map is None:  # Unset, or reset by LoadApiClientConfigs.
+    service_account_map = GetClientConfigSvc().GetDisplayNames()
+  return service_account_map
+
+
+def GetQPMDict():
+  """Return the cached client_email-to-qpm_limit dict, built on first use."""
+  global qpm_dict
+  qpm_dict = GetClientConfigSvc().GetQPM() if qpm_dict is None else qpm_dict
+  return qpm_dict
+
+
+def GetAllowedOriginsSet():
+  global allowed_origins_set  # Module-level cache of the combined origins.
+  if allowed_origins_set is None:  # None means the cache must be built.
+    allowed_origins_set = GetClientConfigSvc().GetAllowedOriginsSet()
+  return allowed_origins_set  # NOTE(review): not refreshed by this function.