# Copyright 2016 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

import base64
import binascii
import json
import logging
import os
import time
import flask

from google.appengine.api import app_identity
from google.appengine.api import urlfetch
from google.appengine.ext import db
from google.protobuf import text_format

from infra_libs import ts_mon

import settings
from framework import framework_constants
from mrproto import api_clients_config_pb2


# Path to the `testing/api_clients.cfg` file, located one directory above the
# directory that contains this module. ClientConfigService reads this file
# when settings.local_mode or settings.unit_test_mode is set.
CONFIG_FILE_PATH = os.path.join(
    os.path.dirname(os.path.dirname(os.path.realpath(__file__))),
    'testing', 'api_clients.cfg')
# luci-config GetConfig endpoint that _CallLuciConfig() sends its POST to.
LUCI_CONFIG_URL = (
    'https://config.luci.app/prpc/config.service.v2.Configs/GetConfig')


# Module-level caches. None means "not loaded yet"; each one is filled in on
# first use by the matching Get*() function near the bottom of this file.
client_config_svc = None
service_account_map = None
qpm_dict = None
allowed_origins_set = None


class ClientConfig(db.Model):
  """Datastore entity that stores the API client config as text.

  GetLoadApiClientConfigs() writes it under the key name 'api_client_configs'
  and ClientConfigService._ReadFromDatastore() reads it back by that key.
  """
  # ClientCfg proto in text format; parsed with text_format.Merge on read.
  configs = db.TextProperty()


# Counter incremented once per luci-config load attempt in this module.
# Fields:
#   success: False for any failure, True after the config has been stored.
#   type: 'success', or the failure stage ('luci-cfg-error',
#     'json-load-error', 'json-key-error', 'b64-decode-error',
#     'proto-load-error').
_CONFIG_LOADS = ts_mon.CounterMetric(
    'monorail/client_config_svc/loads', 'Results of fetches from luci-config.',
    [ts_mon.BooleanField('success'),
     ts_mon.StringField('type')])

def _process_response(response):
  """Decode and validate the body of a luci-config GetConfig response.

  The body is decoded as UTF-8, the XSSI prefix is removed when present, the
  remainder is parsed as JSON, its 'rawContent' value is base64-decoded, and
  the decoded payload is checked by parsing it as a ClientCfg text proto.

  Every failure is logged and counted in _CONFIG_LOADS, then re-raised
  unchanged so the caller decides how to report it.

  Args:
    response: Fetch result object; only its `content` attribute is read and
      it is expected to hold UTF-8 encoded bytes.

  Returns:
    The base64-decoded 'rawContent' payload (the config in text format).

  Raises:
    AttributeError: `response.content` has no decode() method.
    UnicodeDecodeError: the content is not valid UTF-8.
    ValueError: the content is not valid JSON, or 'rawContent' is not valid
      base64 (binascii.Error is a ValueError subclass).
    KeyError: the JSON object has no 'rawContent' key.
    TypeError: the JSON value is not a mapping, or 'rawContent' is not a
      string.
    Exception: whatever text_format.Merge raises for an invalid ClientCfg.
  """
  xssi_prefix = ")]}'"

  try:
    utf8_decoded_content = response.content.decode('utf-8')
  except AttributeError:
    logging.error('Response content was not binary: %r', response.content)
    _CONFIG_LOADS.increment({'success': False, 'type': 'json-load-error'})
    raise
  except UnicodeDecodeError:
    # Previously this escaped without being logged or counted.
    logging.error('Response content was not UTF-8: %r', response.content)
    _CONFIG_LOADS.increment({'success': False, 'type': 'json-load-error'})
    raise

  try:
    # Strip the XSSI prefix, but only when it is actually there so that a
    # prefix-less JSON body does not lose its first characters.
    stripped_content = utf8_decoded_content
    if stripped_content.startswith(xssi_prefix):
      stripped_content = stripped_content[len(xssi_prefix):]
    json_config = json.loads(stripped_content.strip())
  except ValueError:
    logging.error('Response was not JSON: %r', response.content)
    _CONFIG_LOADS.increment({'success': False, 'type': 'json-load-error'})
    raise

  try:
    config_raw_content = json_config['rawContent']
  except (KeyError, TypeError):
    # TypeError covers valid JSON that is not an object (e.g. a list).
    logging.error('JSON missing rawContent: %r', json_config)
    _CONFIG_LOADS.increment({'success': False, 'type': 'json-key-error'})
    raise

  try:
    content_text = base64.b64decode(config_raw_content)
  except (ValueError, TypeError):
    # binascii.Error is a ValueError; TypeError covers a non-string value.
    logging.error('Content was not b64: %r', config_raw_content)
    _CONFIG_LOADS.increment({'success': False, 'type': 'b64-decode-error'})
    raise

  try:
    cfg = api_clients_config_pb2.ClientCfg()
    text_format.Merge(content_text, cfg)
  except Exception:
    # Not a bare except: KeyboardInterrupt/SystemExit must not be counted as
    # proto errors.
    logging.error('Content was not a valid ClientCfg proto: %r', content_text)
    _CONFIG_LOADS.increment({'success': False, 'type': 'proto-load-error'})
    raise

  return content_text


def _CallLuciConfig() -> urlfetch._URLFetchResult:
  """POST a GetConfig request for api_clients.cfg to luci-config.

  The request is authorized with an access token obtained from app_identity
  for framework_constants.OAUTH_SCOPE, and redirects are not followed.

  Returns:
    The urlfetch response when its status code is 200.

  On any other status code the response is logged, a 'luci-cfg-error' failure
  is counted in _CONFIG_LOADS, and flask.abort(500, ...) is called.
  """
  access_token, _ = app_identity.get_access_token(
      framework_constants.OAUTH_SCOPE)
  request_headers = {
      'Content-Type': 'application/json; charset=utf-8',
      'Authorization': 'Bearer ' + access_token,
      'Accept': 'application/json'
  }
  request_body = json.dumps(
      {
          'configSet': 'services/monorail-prod',
          'path': 'api_clients.cfg'
      })
  response = urlfetch.fetch(
      LUCI_CONFIG_URL,
      method=urlfetch.POST,
      follow_redirects=False,
      headers=request_headers,
      payload=request_body,
  )
  if response.status_code == 200:
    return response

  logging.error('Invalid response from luci-config: %r', response)
  _CONFIG_LOADS.increment({'success': False, 'type': 'luci-cfg-error'})
  flask.abort(500, 'Invalid response from luci-config')
  return response


def GetLoadApiClientConfigs():
  """Fetch the API client config from luci-config and store it.

  The config is fetched with _CallLuciConfig(), validated with
  _process_response(), and written to the datastore as the ClientConfig
  entity with key name 'api_client_configs'. The module-level caches derived
  from the config are then cleared so they are rebuilt on next use.

  Returns:
    An empty string.

  Any exception raised while processing the response is turned into
  flask.abort(500, <exception message>).
  """
  global service_account_map
  global qpm_dict
  global allowed_origins_set
  response = _CallLuciConfig()

  try:
    config_content_text = _process_response(response)
  except Exception as e:
    flask.abort(500, str(e))

  logging.info('luci-config content decoded: %r.', config_content_text)
  configs = ClientConfig(
      configs=config_content_text, key_name='api_client_configs')
  configs.put()
  # Invalidate every cache built from the config. allowed_origins_set is
  # cached the same way as the other two, so it must be reset with them or
  # it would keep serving the origins from the first load.
  service_account_map = None
  qpm_dict = None
  allowed_origins_set = None
  _CONFIG_LOADS.increment({'success': True, 'type': 'success'})

  return ''


class ClientConfigService(object):
  """In-memory cache and accessors for the API client config."""

  # Cached configs are considered stale after 15 minutes.
  # Each GAE instance reloads on its own schedule, so clients may see
  # inconsistent responses for a short while after being allowlisted.
  EXPIRES_IN = 15 * framework_constants.SECS_PER_MINUTE

  def __init__(self):
    self.client_configs = None
    self.load_time = 0

  def GetConfigs(self, use_cache=True, cur_time=None):
    """Return the client configs, reloading them first when needed.

    A reload happens when nothing is cached yet, when use_cache is false, or
    when more than EXPIRES_IN seconds have passed since the last load.

    Args:
      use_cache: When false, always reload.
      cur_time: Timestamp in seconds used for the staleness check; a falsy
        value means "use the current time".

    Returns:
      The cached ClientCfg, or None if nothing has been loaded successfully.
    """
    now = cur_time or int(time.time())
    needs_reload = (
        not self.client_configs or not use_cache or
        now - self.load_time > self.EXPIRES_IN)

    if needs_reload:
      read_from_local_file = settings.local_mode or settings.unit_test_mode
      if read_from_local_file:
        self._ReadFromFilesystem()
      else:
        self._ReadFromDatastore()

    return self.client_configs

  def _StoreParsedConfigs(self, content_text):
    """Parse text-format ClientCfg content and cache it with a load time."""
    parsed_cfg = api_clients_config_pb2.ClientCfg()
    text_format.Merge(content_text, parsed_cfg)
    self.client_configs = parsed_cfg
    self.load_time = int(time.time())

  def _ReadFromFilesystem(self):
    """Load configs from CONFIG_FILE_PATH; failures are logged, not raised."""
    try:
      with open(CONFIG_FILE_PATH, 'r') as config_file:
        file_text = config_file.read()
      logging.info('Read client configs from local file.')
      self._StoreParsedConfigs(file_text)
    except Exception as e:
      logging.exception('Failed to read client configs: %s', e)

  def _ReadFromDatastore(self):
    """Load configs from the 'api_client_configs' ClientConfig entity."""
    stored = ClientConfig.get_by_key_name('api_client_configs')
    if not stored:
      logging.error('Failed to get api client configs from datastore.')
      return
    self._StoreParsedConfigs(stored.configs)

  def GetClientIDEmails(self):
    """Return a (client_ids, client_emails) pair of lists for all clients."""
    self.GetConfigs(use_cache=True)
    clients = self.client_configs.clients
    ids = [client.client_id for client in clients]
    emails = [client.client_email for client in clients]
    return ids, emails

  def GetDisplayNames(self):
    """Return {client_email: display_name} for clients with a display name."""
    self.GetConfigs(use_cache=True)
    return {
        client.client_email: client.display_name
        for client in self.client_configs.clients
        if client.display_name
    }

  def GetQPM(self):
    """Return {client_email: qpm_limit} for clients that set qpm_limit."""
    self.GetConfigs(use_cache=True)
    return {
        client.client_email: client.qpm_limit
        for client in self.client_configs.clients
        if client.HasField('qpm_limit')
    }

  def GetAllowedOriginsSet(self):
    """Return the union of allowed_origins across all clients as a set."""
    self.GetConfigs(use_cache=True)
    return {
        origin
        for client in self.client_configs.clients
        for origin in client.allowed_origins
    }


def GetClientConfigSvc():
  """Return the shared ClientConfigService, creating it on first use."""
  global client_config_svc
  if client_config_svc is not None:
    return client_config_svc
  client_config_svc = ClientConfigService()
  return client_config_svc


def GetServiceAccountMap():
  """Return the cached map of service account email to display name.

  Only service accounts with a specified display_name are included. The map
  is built from GetClientConfigSvc().GetDisplayNames() on first use and then
  reused until the module-level cache is reset to None.
  """
  global service_account_map
  if service_account_map is not None:
    return service_account_map
  service_account_map = GetClientConfigSvc().GetDisplayNames()
  return service_account_map


def GetQPMDict():
  """Return the cached map of client email to QPM limit.

  The map is built from GetClientConfigSvc().GetQPM() on first use and then
  reused until the module-level cache is reset to None.
  """
  global qpm_dict
  if qpm_dict is not None:
    return qpm_dict
  qpm_dict = GetClientConfigSvc().GetQPM()
  return qpm_dict


def GetAllowedOriginsSet():
  """Return the cached set of allowed origins across all clients.

  The set is built from GetClientConfigSvc().GetAllowedOriginsSet() on first
  use and then reused while the module-level cache is not None.
  """
  global allowed_origins_set
  if allowed_origins_set is not None:
    return allowed_origins_set
  allowed_origins_set = GetClientConfigSvc().GetAllowedOriginsSet()
  return allowed_origins_set
