# Copyright 2020 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Classes to manage cached values.

Monorail makes full use of the RAM of GAE frontends to reduce latency
and load on the database.

Even though these caches do invalidation, there are rare race conditions
that can cause a somewhat stale object to be retrieved from memcache and
then put into a RAM cache and used by a given GAE instance for some time.
So, we only use these caches for operations that can tolerate somewhat
stale data, for example, displaying issues in a list or displaying brief
info about related issues.  We never use the cache to load objects as
part of a read-modify-save sequence because that could cause stored data
to revert to a previous state.
"""
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

import logging

from protorpc import protobuf

from google.appengine.api import memcache

import settings
from framework import framework_constants
from framework import logger


# Number of items a RamCache may hold when it is created without a
# (truthy) max_size argument.
DEFAULT_MAX_SIZE = 10000


class RamCache(object):
  """An in-RAM cache with distributed invalidation.

  Items live in a plain dict.  When the cache is full, room is made with
  dict.popitem(), which discards an arbitrary item (on Python 3.7+ it is the
  most recently inserted one).  Invalidations can be recorded through the
  cache_manager so that other instances can drop their copies too.
  """

  def __init__(self, cache_manager, kind, max_size=None):
    """Create an empty cache and register it with cache_manager.

    Args:
      cache_manager: object providing RegisterCache(), StoreInvalidateRows()
          and StoreInvalidateAll(); may be None/falsy to disable distributed
          registration and invalidation.
      kind: identifier of this cache, used for registration and logging.
      max_size: maximum number of items to hold; DEFAULT_MAX_SIZE is used
          when it is not given (or is falsy).
    """
    self.cache_manager = cache_manager
    self.kind = kind
    self.cache = {}
    self.max_size = max_size or DEFAULT_MAX_SIZE
    # The invalidation methods below already tolerate a missing manager.
    if cache_manager:
      cache_manager.RegisterCache(self, kind)

  def CacheItem(self, key, item):
    """Store item at key, discarding another item first if the cache is full.

    Overwriting a key that is already cached does not grow the cache, so no
    other item is discarded in that case.
    """
    if key not in self.cache and len(self.cache) >= self.max_size:
      self.cache.popitem()
    self.cache[key] = item

  def CacheAll(self, new_item_dict):
    """Cache all items in new_item_dict, dropping old items if needed.

    If new_item_dict alone would fill the cache, the whole cache is dropped
    first (so the result may hold more than max_size items).  Otherwise
    just enough old items are discarded to stay within max_size.
    """
    if len(new_item_dict) >= self.max_size:
      logging.warning('Dumping the entire cache! %s', self.kind)
      self.cache = {}
    else:
      # Only keys that are not cached yet increase the size of the cache.
      num_new = sum(1 for key in new_item_dict if key not in self.cache)
      while len(self.cache) + num_new > self.max_size:
        evicted_key, _ = self.cache.popitem()
        if evicted_key in new_item_dict:
          # That key is no longer cached, so it now counts as a new one.
          num_new += 1

    self.cache.update(new_item_dict)

  def GetItem(self, key):
    """Return the cached item if present, otherwise None."""
    return self.cache.get(key)

  def HasItem(self, key):
    """Return True if there is a value cached at the given key."""
    return key in self.cache

  def GetAll(self, keys):
    """Look up the given keys.

    Args:
      keys: a list of cache keys to look up.

    Returns:
      A pair: (hits_dict, misses_list) where hits_dict is a dictionary of
      all the given keys and the values that were found in the cache, and
      misses_list is a list of given keys that were not in the cache.
    """
    hits, misses = {}, []
    for key in keys:
      try:
        hits[key] = self.cache[key]
      except KeyError:
        misses.append(key)

    return hits, misses

  def LocalInvalidate(self, key):
    """Drop the given key from this cache, without distributed notification."""
    if key in self.cache:
      logging.info('Locally invalidating %r in kind=%r', key, self.kind)
    self.cache.pop(key, None)

  def Invalidate(self, cnxn, key):
    """Drop key locally, and append it to the Invalidate DB table."""
    self.InvalidateKeys(cnxn, [key])

  def InvalidateKeys(self, cnxn, keys):
    """Drop keys locally, and append them to the Invalidate DB table.

    The DB rows are written only when a cache_manager is configured.
    """
    for key in keys:
      self.LocalInvalidate(key)
    if self.cache_manager:
      self.cache_manager.StoreInvalidateRows(cnxn, self.kind, keys)

  def LocalInvalidateAll(self):
    """Invalidate all keys locally: just start over with an empty dict."""
    logging.info('Locally invalidating all in kind=%r', self.kind)
    self.cache = {}

  def InvalidateAll(self, cnxn):
    """Invalidate all keys in this cache, locally and via the cache_manager."""
    self.LocalInvalidateAll()
    if self.cache_manager:
      self.cache_manager.StoreInvalidateAll(cnxn, self.kind)


class ShardedRamCache(RamCache):
  """Specialized version of RamCache that stores values in parts.

  Instead of the cache keys being simple integers, they are pairs, e.g.,
  (project_id, shard_id).  Invalidation will invalidate all shards for
  a given main key, e.g., invalidating project_id 16 will drop keys
  (16, 0), (16, 1), (16, 2), ... (16, num_shards - 1).
  """

  def __init__(self, cache_manager, kind, max_size=None, num_shards=10):
    """Create the cache; num_shards is the number of shard ids per main key."""
    super(ShardedRamCache, self).__init__(
        cache_manager, kind, max_size=max_size)
    self.num_shards = num_shards

  def LocalInvalidate(self, key):
    """Drop every shard of the given main key from the local cache.

    Args:
      key: the main key; entries (key, 0) .. (key, num_shards - 1) are
          dropped without distributed notification.
    """
    shard_keys = [(key, shard_id) for shard_id in range(self.num_shards)]
    present_keys = [shard_key for shard_key in shard_keys
                    if shard_key in self.cache]
    # Like RamCache.LocalInvalidate, only log when something is dropped.
    if present_keys:
      logging.info('About to invalidate sharded RAM keys %r', present_keys)
    for shard_key in present_keys:
      del self.cache[shard_key]


class ValueCentricRamCache(RamCache):
  """Specialized version of RamCache that stores values in InvalidateTable.

  This is useful for caches that have non-integer keys: invalidation rows
  record the cached values rather than the keys.
  """

  def LocalInvalidate(self, value):
    """Drop every local entry whose cached value equals value.

    No distributed notification is sent.
    """
    keys_to_drop = [k for k, v in self.cache.items() if v == value]
    for k in keys_to_drop:
      del self.cache[k]

  def InvalidateKeys(self, cnxn, keys):
    """Drop keys locally, and append their values to the Invalidate DB table.

    If any of the keys is not in the local cache, its value is unknown, so
    the whole cache is invalidated instead.
    """
    # Find values to invalidate.  ("in" instead of dict.has_key(), which
    # does not exist on Python 3.)
    values = [self.cache[key] for key in keys if key in self.cache]
    if len(values) != len(keys):
      # If a value is not found in the cache then invalidate the whole cache.
      # This is done to ensure that we are not in an inconsistent state or in a
      # race condition.
      self.InvalidateAll(cnxn)
      return

    for value in values:
      self.LocalInvalidate(value)
    if self.cache_manager:
      self.cache_manager.StoreInvalidateRows(cnxn, self.kind, values)


class AbstractTwoLevelCache(object):
  """A class to manage both RAM and secondary-caching layer to retrieve objects.

  Subclasses must implement the FetchItems() method to get objects from
  the database when both caches miss.
  """

  # When loading a huge number of issues from the database, do it in chunks
  # so as to avoid timeouts.
  _FETCH_BATCH_SIZE = 10000

  def __init__(self, cache_manager, kind, prefix, pb_class, max_size=None):
    """Create the RAM layer and remember how to talk to memcache.

    Args:
      cache_manager: passed to the RAM cache for registration/invalidation.
      kind: identifier of the RAM cache, also used in log records.
      prefix: memcache key prefix for this cache's entries.
      pb_class: how values are serialized for memcache: None (stored as-is),
          int (stored via str()/int()), or a protorpc message class.
      max_size: maximum number of items in the RAM cache.
    """
    self.cache = self._MakeCache(cache_manager, kind, max_size=max_size)
    self.prefix = prefix
    self.pb_class = pb_class

  def _MakeCache(self, cache_manager, kind, max_size=None):
    """Make the RAM cache and register it with the cache_manager."""
    return RamCache(cache_manager, kind, max_size=max_size)

  def CacheItem(self, key, value):
    """Add the given key-value pair to RAM and L2 cache."""
    self.cache.CacheItem(key, value)
    self._WriteToMemcache({key: value})

  def HasItem(self, key):
    """Return True if the given key is in the RAM cache."""
    return self.cache.HasItem(key)

  def GetAnyOnHandItem(self, keys, start=None, end=None):
    """Return the first item among keys[start:end] that is cached in RAM.

    Args:
      keys: sequence of keys to probe, in order.
      start: index of the first key to probe; defaults to 0.
      end: index one past the last key to probe; defaults to len(keys).

    Returns:
      The first RAM-cached value found, or None if none of the probed keys
      is in the RAM cache.
    """
    if start is None:
      start = 0
    if end is None:
      end = len(keys)
    for i in range(start, end):
      key = keys[i]
      if self.cache.HasItem(key):
        return self.cache.GetItem(key)

    # Note: We could check L2 here too, but the round-trips to L2
    # are kind of slow. And, getting too many hits from L2 actually
    # fills our RAM cache too quickly and could lead to thrashing.

    return None

  def GetAll(self, cnxn, keys, use_cache=True, **kwargs):
    """Get values for the given keys from RAM, the L2 cache, or the DB.

    Lookups fall through RAM, then memcache, then FetchItems() in batches of
    _FETCH_BATCH_SIZE keys.  When use_cache is True, memcache hits are
    copied into RAM and DB results are written to both RAM and memcache.

    Args:
      cnxn: connection to the database.
      keys: list of integer keys to look up.
      use_cache: set to False to always hit the database.
      **kwargs: any additional keywords are passed to FetchItems().

    Returns:
      A pair: hits, misses.  Where hits is {key: value} and misses is
        a list of any keys that were not found anywhere.
    """
    if use_cache:
      result_dict, missed_keys = self.cache.GetAll(keys)
    else:
      result_dict, missed_keys = {}, list(keys)

    if missed_keys and use_cache:
      cache_hits, missed_keys = self._ReadFromMemcache(missed_keys)
      result_dict.update(cache_hits)
      self.cache.CacheAll(cache_hits)

    while missed_keys:
      missed_batch = missed_keys[:self._FETCH_BATCH_SIZE]
      missed_keys = missed_keys[self._FETCH_BATCH_SIZE:]
      retrieved_dict = self.FetchItems(cnxn, missed_batch, **kwargs)
      result_dict.update(retrieved_dict)
      if use_cache:
        self.cache.CacheAll(retrieved_dict)
        self._WriteToMemcache(retrieved_dict)

    still_missing_keys = [key for key in keys if key not in result_dict]
    if still_missing_keys:
      # The keys were not found in the caches or the DB.
      logger.log(
          {
              'log_type': 'database/missing_keys',
              'kind': self.cache.kind,
              'prefix': self.prefix,
              'count': len(still_missing_keys),
              'keys': str(still_missing_keys)
          })
    return result_dict, still_missing_keys

  def LocalInvalidateAll(self):
    """Drop all RAM entries on this instance, without distributed notice."""
    self.cache.LocalInvalidateAll()

  def LocalInvalidate(self, key):
    """Drop key from the RAM cache on this instance, without distributed notice."""
    self.cache.LocalInvalidate(key)

  def InvalidateKeys(self, cnxn, keys):
    """Drop the given keys from both RAM and L2 cache."""
    self.cache.InvalidateKeys(cnxn, keys)
    self._DeleteFromMemcache(keys)

  def InvalidateAllKeys(self, cnxn, keys):
    """Drop the given keys from L2 cache and invalidate all keys in RAM.

    Useful for avoiding inserting many rows into the Invalidate table when
    invalidating a large group of keys all at once. Only use when necessary.
    """
    self.cache.InvalidateAll(cnxn)
    self._DeleteFromMemcache(keys)

  def GetAllAlreadyInRam(self, keys):
    """Look only in RAM to return {key: values}, missed_keys."""
    return self.cache.GetAll(keys)

  def InvalidateAllRamEntries(self, cnxn):
    """Drop all RAM cache entries. It will refill as needed from L2 cache."""
    self.cache.InvalidateAll(cnxn)

  def FetchItems(self, cnxn, keys, **kwargs):
    """On RAM and L2 cache miss, hit the database."""
    raise NotImplementedError()

  def _ReadFromMemcache(self, keys):
    # type: (Sequence[int]) -> Tuple[Mapping[int, Any], Sequence[int]]
    """Read the given keys from memcache, return {key: value}, missing_keys.

    Every hit is also stored in the RAM cache.
    """
    cache_hits = {}
    cached_dict = memcache.get_multi(
        [self._KeyToStr(key) for key in keys],
        key_prefix=self.prefix,
        namespace=settings.memcache_namespace)

    for key_str, serialized_value in cached_dict.items():
      value = self._StrToValue(serialized_value)
      key = self._StrToKey(key_str)
      cache_hits[key] = value
      # NOTE(review): GetAll() also calls self.cache.CacheAll() on these
      # hits, so they get cached in RAM twice; kept here in case other
      # callers rely on this method populating RAM.
      self.cache.CacheItem(key, value)

    still_missing_keys = [key for key in keys if key not in cache_hits]
    return cache_hits, still_missing_keys

  def _WriteToMemcache(self, retrieved_dict):
    # type: (Mapping[int, Any]) -> None
    """Write entries for each key-value pair to memcache.  Encode PBs.

    If memcache rejects the values with ValueError, the same keys are
    deleted from memcache so that no stale values are left behind.
    """
    strs_to_cache = {
        self._KeyToStr(key): self._ValueToStr(value)
        for key, value in retrieved_dict.items()}

    try:
      memcache.add_multi(
          strs_to_cache,
          key_prefix=self.prefix,
          time=framework_constants.CACHE_EXPIRATION,
          namespace=settings.memcache_namespace)
    except ValueError as err:
      # If memcache does not accept the values, ensure that no stale
      # values are left, then bail out.
      logging.error('Got memcache error: %r', err)
      # Pass the original keys: _DeleteFromMemcache applies _KeyToStr itself,
      # so already-converted strings would be converted a second time.
      self._DeleteFromMemcache(list(retrieved_dict.keys()))

  def _DeleteFromMemcache(self, keys):
    # type: (Sequence[int]) -> None
    """Delete key-values from memcache.

    Args:
      keys: keys in application form; each is converted with _KeyToStr().
    """
    logger.log(
        {
            'log_type': 'cache/memcache/delete',
            'kind': self.cache.kind,
            'prefix': self.prefix,
            'count': len(keys),
            'keys': str(keys)
        })
    # seconds=5 presumably locks these keys against re-adds for a short
    # period (App Engine memcache delete semantics) -- confirm.
    memcache.delete_multi(
        [self._KeyToStr(key) for key in keys],
        seconds=5,
        key_prefix=self.prefix,
        namespace=settings.memcache_namespace)

  def _KeyToStr(self, key):
    # type: (int) -> str
    """Convert our int IDs to strings for use as memcache keys."""
    return str(key)

  def _StrToKey(self, key_str):
    # type: (str) -> int
    """Convert memcache keys back to the ints that we use as IDs."""
    return int(key_str)

  def _ValueToStr(self, value):
    # type: (Any) -> str
    """Serialize an application object so that it can be stored in L2 cache."""
    if not self.pb_class:
      return value
    elif self.pb_class == int:
      return str(value)
    else:
      return protobuf.encode_message(value)

  def _StrToValue(self, serialized_value):
    # type: (str) -> Any
    """Deserialize L2 cache string into an application object."""
    if not self.pb_class:
      return serialized_value
    elif self.pb_class == int:
      return int(serialized_value)
    else:
      return protobuf.decode_message(self.pb_class, serialized_value)
