| # Copyright 2020 The Chromium Authors. All rights reserved. |
| # Use of this source code is governed by a BSD-style license that can be |
| # found in the LICENSE file. |
| """Classes to manage cached values. |
| |
| Monorail makes full use of the RAM of GAE frontends to reduce latency |
| and load on the database. |
| |
| Even though these caches do invalidation, there are rare race conditions |
| that can cause a somewhat stale object to be retrieved from memcache and |
| then put into a RAM cache and used by a given GAE instance for some time. |
| So, we only use these caches for operations that can tolerate somewhat |
| stale data. For example, displaying issues in a list or displaying brief |
| info about related issues. We never use the cache to load objects as |
| part of a read-modify-save sequence because that could cause stored data |
| to revert to a previous state. |
| """ |
| from __future__ import print_function |
| from __future__ import division |
| from __future__ import absolute_import |
| |
| import logging |
| import redis |
| |
| from protorpc import protobuf |
| |
| from google.appengine.api import memcache |
| |
| import settings |
| from framework import framework_constants |
| from framework import redis_utils |
| from proto import tracker_pb2 |
| |
| |
# Capacity (number of entries) used by RamCache when no truthy max_size is
# passed to its constructor (None or 0 both fall back to this value).
DEFAULT_MAX_SIZE = 10000
| |
| |
class RamCache(object):
  """An in-RAM cache with distributed invalidation.

  Entries are kept in a plain dict that is bounded by max_size (except right
  after CacheAll() is handed max_size or more items at once). Invalidations
  are applied locally and, when a cache_manager is present, also recorded
  through it (the Invalidate DB table) so they can be distributed.
  """

  def __init__(self, cache_manager, kind, max_size=None):
    """Create an empty cache and register it with cache_manager.

    Args:
      cache_manager: object whose RegisterCache(cache, kind) is called here
          and which records invalidations for this cache.
      kind: name identifying this cache when registering and invalidating.
      max_size: maximum number of entries; None or 0 means DEFAULT_MAX_SIZE.
    """
    self.cache_manager = cache_manager
    self.kind = kind
    self.cache = {}
    self.max_size = max_size or DEFAULT_MAX_SIZE
    cache_manager.RegisterCache(self, kind)

  def CacheItem(self, key, item):
    """Store item at key, evicting an arbitrary entry if the cache is full.

    Overwriting a key that is already cached does not grow the cache, so no
    other entry is evicted in that case.
    """
    if key not in self.cache and len(self.cache) >= self.max_size:
      # dict.popitem() drops an arbitrary entry (the most recently inserted
      # one on Python 3.7+); precise LRU behavior is not needed here.
      self.cache.popitem()

    self.cache[key] = item

  def CacheAll(self, new_item_dict):
    """Cache all items in the given dict, dropping old items if needed.

    If new_item_dict alone has max_size or more items, the whole cache is
    discarded first and all new items are stored, even if that exceeds
    max_size.
    """
    if len(new_item_dict) >= self.max_size:
      logging.warning('Dumping the entire cache! %s', self.kind)
      self.cache = {}
    else:
      while len(self.cache) + len(new_item_dict) > self.max_size:
        self.cache.popitem()

    self.cache.update(new_item_dict)

  def GetItem(self, key):
    """Return the cached item if present, otherwise None."""
    return self.cache.get(key)

  def HasItem(self, key):
    """Return True if there is a value cached at the given key."""
    return key in self.cache

  def GetAll(self, keys):
    """Look up the given keys.

    Args:
      keys: a list of cache keys to look up.

    Returns:
      A pair: (hits_dict, misses_list) where hits_dict is a dictionary of
      all the given keys and the values that were found in the cache, and
      misses_list is a list of given keys that were not in the cache.
    """
    hits, misses = {}, []
    for key in keys:
      try:
        hits[key] = self.cache[key]
      except KeyError:
        misses.append(key)

    return hits, misses

  def LocalInvalidate(self, key):
    """Drop the given key from this cache, without distributed notification."""
    if key in self.cache:
      logging.info('Locally invalidating %r in kind=%r', key, self.kind)
    self.cache.pop(key, None)

  def Invalidate(self, cnxn, key):
    """Drop key locally, and append it to the Invalidate DB table."""
    self.InvalidateKeys(cnxn, [key])

  def InvalidateKeys(self, cnxn, keys):
    """Drop keys locally, and append them to the Invalidate DB table.

    Args:
      cnxn: connection to the database, passed through to the cache_manager.
      keys: keys to drop; they are iterated here and then handed to the
          cache_manager, so pass a re-iterable collection (e.g. a list).
    """
    for key in keys:
      self.LocalInvalidate(key)
    if self.cache_manager:
      self.cache_manager.StoreInvalidateRows(cnxn, self.kind, keys)

  def LocalInvalidateAll(self):
    """Invalidate all keys locally: just start over with an empty dict."""
    logging.info('Locally invalidating all in kind=%r', self.kind)
    self.cache = {}

  def InvalidateAll(self, cnxn):
    """Invalidate all keys locally and record that through the cache_manager."""
    self.LocalInvalidateAll()
    if self.cache_manager:
      self.cache_manager.StoreInvalidateAll(cnxn, self.kind)
| |
| |
class ShardedRamCache(RamCache):
  """RamCache variant whose entries are keyed by (main_key, shard_id) pairs.

  A logical value is split across num_shards entries, e.g. a project's data
  stored under (project_id, 0) ... (project_id, num_shards - 1). Invalidating
  a main key drops every one of its shards; for example, invalidating
  project_id 16 drops (16, 0), (16, 1), ... (16, 9) with the default
  num_shards of 10.
  """

  def __init__(self, cache_manager, kind, max_size=None, num_shards=10):
    """Build the underlying RamCache and remember how many shards exist."""
    super(ShardedRamCache, self).__init__(
        cache_manager, kind, max_size=max_size)
    self.num_shards = num_shards

  def LocalInvalidate(self, key):
    """Drop all shard entries belonging to the given main key, locally only."""
    shard_keys = [(key, shard) for shard in range(self.num_shards)]
    present = [shard_key for shard_key in shard_keys
               if shard_key in self.cache]
    logging.info('About to invalidate shared RAM keys %r', present)
    for shard_key in shard_keys:
      self.cache.pop(shard_key, None)
| |
| |
class ValueCentricRamCache(RamCache):
  """Specialized version of RamCache that stores values in InvalidateTable.

  This is useful for caches that have non integer keys. Invalidation is
  driven by cached *values*: invalidating a value drops every key that
  currently maps to it, and the values (not the keys) are what get recorded
  through the cache_manager.
  """

  def LocalInvalidate(self, value):
    """Drop every local entry whose cached value equals the given value."""
    # Collect the matching keys first so the dict is not mutated while it is
    # being iterated.
    keys_to_drop = [k for k, v in self.cache.items() if v == value]
    for k in keys_to_drop:
      self.cache.pop(k, None)

  def InvalidateKeys(self, cnxn, keys):
    """Drop keys locally, and append their values to the Invalidate DB table.

    Args:
      cnxn: connection to the database, passed through to the cache_manager.
      keys: keys whose cached values should be invalidated. They are
          iterated and also measured with len(), so pass a list or similar.

    If any key is not in the local cache, its value is unknown here, so the
    whole cache is invalidated instead.
    """
    # `key in dict` works on Python 2 and 3; dict.has_key() no longer exists
    # in Python 3.
    values = [self.cache[key] for key in keys if key in self.cache]
    if len(values) == len(keys):
      for value in values:
        self.LocalInvalidate(value)
      if self.cache_manager:
        self.cache_manager.StoreInvalidateRows(cnxn, self.kind, values)
    else:
      # If a value is not found in the cache then invalidate the whole cache.
      # This is done to ensure that we are not in an inconsistent state or in a
      # race condition.
      self.InvalidateAll(cnxn)
| |
| |
class AbstractTwoLevelCache(object):
  """A class to manage both RAM and secondary-caching layer to retrieve objects.

  Lookups consult the in-RAM cache first, then the L2 cache (Redis when it is
  enabled and its connection was verified, otherwise memcache), and finally
  the database. Subclasses must implement the FetchItems() method to get
  objects from the database when both caches miss.
  """

  # When loading a huge number of issues from the database, do it in chunks
  # so as to avoid timeouts.
  _FETCH_BATCH_SIZE = 10000

  def __init__(
      self,
      cache_manager,
      kind,
      prefix,
      pb_class,
      max_size=None,
      use_redis=False,
      redis_client=None):
    """Create the RAM cache and choose the L2 backend.

    Args:
      cache_manager: passed to _MakeCache() to build the RAM cache.
      kind: name of this cache kind; also used as the message when verifying
          the Redis connection.
      prefix: string prefixed to every L2 cache key.
      pb_class: None, int, or a protorpc message class; selects how values
          are serialized for the L2 cache.
      max_size: optional capacity of the RAM cache.
      use_redis: if True, try to use Redis as the L2 cache.
      redis_client: optional Redis client; one is created if not given.
    """
    self.cache = self._MakeCache(cache_manager, kind, max_size=max_size)
    self.prefix = prefix
    self.pb_class = pb_class

    if use_redis:
      self.redis_client = redis_client or redis_utils.CreateRedisClient()
      # L2 operations fall back to memcache whenever this is falsy; presumably
      # VerifyRedisConnection() returns False when the connection check fails.
      self.use_redis = redis_utils.VerifyRedisConnection(
          self.redis_client, msg=kind)
    else:
      self.redis_client = None
      self.use_redis = False

  def _MakeCache(self, cache_manager, kind, max_size=None):
    """Make the RAM cache and register it with the cache_manager."""
    return RamCache(cache_manager, kind, max_size=max_size)

  def CacheItem(self, key, value):
    """Add the given key-value pair to RAM and L2 cache."""
    self.cache.CacheItem(key, value)
    self._WriteToCache({key: value})

  def HasItem(self, key):
    """Return True if the given key is in the RAM cache."""
    return self.cache.HasItem(key)

  def GetAnyOnHandItem(self, keys, start=None, end=None):
    """Return the RAM-cached value of the first cached key, or None.

    Only keys at indices start..end-1 (defaulting to the whole list) are
    considered, and only the RAM cache is consulted.
    """
    if start is None:
      start = 0
    if end is None:
      end = len(keys)
    for i in range(start, end):
      key = keys[i]
      if self.cache.HasItem(key):
        return self.cache.GetItem(key)

    # Note: We could check L2 here too, but the round-trips to L2
    # are kind of slow. And, getting too many hits from L2 actually
    # fills our RAM cache too quickly and could lead to thrashing.

    return None

  def GetAll(self, cnxn, keys, use_cache=True, **kwargs):
    """Get values for the given keys from RAM, the L2 cache, or the DB.

    Args:
      cnxn: connection to the database.
      keys: list of integer keys to look up.
      use_cache: set to False to always hit the database. In that case the
          fetched values are also not written to either cache.
      **kwargs: any additional keywords are passed to FetchItems().

    Returns:
      A pair: hits, misses. Where hits is {key: value} and misses is
      a list of any keys that were not found anywhere.
    """
    if use_cache:
      result_dict, missed_keys = self.cache.GetAll(keys)
    else:
      result_dict, missed_keys = {}, list(keys)

    if missed_keys:
      if use_cache:
        cache_hits, missed_keys = self._ReadFromCache(missed_keys)
        result_dict.update(cache_hits)
        self.cache.CacheAll(cache_hits)

    # Fetch remaining misses in batches to keep each DB query bounded.
    while missed_keys:
      missed_batch = missed_keys[:self._FETCH_BATCH_SIZE]
      missed_keys = missed_keys[self._FETCH_BATCH_SIZE:]
      retrieved_dict = self.FetchItems(cnxn, missed_batch, **kwargs)
      result_dict.update(retrieved_dict)
      if use_cache:
        self.cache.CacheAll(retrieved_dict)
        self._WriteToCache(retrieved_dict)

    still_missing_keys = [key for key in keys if key not in result_dict]
    return result_dict, still_missing_keys

  def LocalInvalidateAll(self):
    """Drop every entry from the RAM cache only."""
    self.cache.LocalInvalidateAll()

  def LocalInvalidate(self, key):
    """Drop the given key from the RAM cache only."""
    self.cache.LocalInvalidate(key)

  def InvalidateKeys(self, cnxn, keys):
    """Drop the given keys from both RAM and L2 cache."""
    self.cache.InvalidateKeys(cnxn, keys)
    self._DeleteFromCache(keys)

  def InvalidateAllKeys(self, cnxn, keys):
    """Drop the given keys from L2 cache and invalidate all keys in RAM.

    Useful for avoiding inserting many rows into the Invalidate table when
    invalidating a large group of keys all at once. Only use when necessary.
    """
    self.cache.InvalidateAll(cnxn)
    self._DeleteFromCache(keys)

  def GetAllAlreadyInRam(self, keys):
    """Look only in RAM to return {key: values}, missed_keys."""
    result_dict, missed_keys = self.cache.GetAll(keys)
    return result_dict, missed_keys

  def InvalidateAllRamEntries(self, cnxn):
    """Drop all RAM cache entries. It will refill as needed from L2 cache."""
    self.cache.InvalidateAll(cnxn)

  def FetchItems(self, cnxn, keys, **kwargs):
    """On RAM and L2 cache miss, hit the database.

    Subclasses must override this to return {key: value} for the keys found.
    """
    raise NotImplementedError()

  def _ReadFromCache(self, keys):
    # type: (Sequence[int]) -> Tuple[Mapping[int, Any], Sequence[int]]
    """Reads a list of keys from secondary caching service.

    Redis will be used if Redis is enabled and connection is valid;
    otherwise, memcache will be used.

    Args:
      keys: List of integer keys to look up in L2 cache.

    Returns:
      A pair: hits, misses. Where hits is {key: value} and misses is
      a list of any keys that were not found anywhere.
    """
    if self.use_redis:
      return self._ReadFromRedis(keys)
    else:
      return self._ReadFromMemcache(keys)

  def _WriteToCache(self, retrieved_dict):
    # type: (Mapping[int, Any]) -> None
    """Writes a set of key-value pairs to secondary caching service.

    Redis will be used if Redis is enabled and connection is valid;
    otherwise, memcache will be used.

    Args:
      retrieved_dict: Dictionary contains pairs of key-values to write to cache.
    """
    if self.use_redis:
      return self._WriteToRedis(retrieved_dict)
    else:
      return self._WriteToMemcache(retrieved_dict)

  def _DeleteFromCache(self, keys):
    # type: (Sequence[int]) -> None
    """Selects which cache to delete from.

    Redis will be used if Redis is enabled and connection is valid;
    otherwise, memcache will be used.

    Args:
      keys: List of integer keys to delete from cache.
    """
    if self.use_redis:
      return self._DeleteFromRedis(keys)
    else:
      return self._DeleteFromMemcache(keys)

  def _ReadFromMemcache(self, keys):
    # type: (Sequence[int]) -> Tuple[Mapping[int, Any], Sequence[int]]
    """Read the given keys from memcache, return {key: value}, missing_keys.

    Every hit is also stored in the RAM cache.
    """
    cache_hits = {}
    cached_dict = memcache.get_multi(
        [self._KeyToStr(key) for key in keys],
        key_prefix=self.prefix,
        namespace=settings.memcache_namespace)

    for key_str, serialized_value in cached_dict.items():
      value = self._StrToValue(serialized_value)
      key = self._StrToKey(key_str)
      cache_hits[key] = value
      self.cache.CacheItem(key, value)

    still_missing_keys = [key for key in keys if key not in cache_hits]
    return cache_hits, still_missing_keys

  def _WriteToMemcache(self, retrieved_dict):
    # type: (Mapping[int, Any]) -> None
    """Write entries for each key-value pair to memcache. Encode PBs.

    If memcache rejects the batch with a ValueError, the same keys are
    deleted from memcache so that no stale values are left behind.
    """
    strs_to_cache = {
        self._KeyToStr(key): self._ValueToStr(value)
        for key, value in retrieved_dict.items()}

    try:
      memcache.add_multi(
          strs_to_cache,
          key_prefix=self.prefix,
          time=framework_constants.CACHE_EXPIRATION,
          namespace=settings.memcache_namespace)
    except ValueError as identifier:
      # If memcache does not accept the values, ensure that no stale
      # values are left, then bail out.
      logging.error('Got memcache error: %r', identifier)
      # Pass the original keys: _DeleteFromMemcache() applies _KeyToStr()
      # itself, so passing the already-converted strings would convert them
      # twice and break subclasses whose _KeyToStr() is not idempotent.
      self._DeleteFromMemcache(list(retrieved_dict.keys()))
      return

  def _DeleteFromMemcache(self, keys):
    # type: (Sequence[int]) -> None
    """Delete key-values from memcache.

    Args:
      keys: the cache keys (not their string forms); each is converted with
          _KeyToStr() here.
    """
    # NOTE(review): per the App Engine memcache API, `seconds` is a lock
    # period during which add() of these keys is refused -- confirm.
    memcache.delete_multi(
        [self._KeyToStr(key) for key in keys],
        seconds=5,
        key_prefix=self.prefix,
        namespace=settings.memcache_namespace)

  def _WriteToRedis(self, retrieved_dict):
    # type: (Mapping[int, Any]) -> None
    """Write entries for each key-value pair to Redis. Encode PBs.

    On a Redis error, every key in the batch is deleted from Redis so no
    partially written batch is left behind.

    Args:
      retrieved_dict: Dictionary of key-value pairs to write to Redis.
    """
    try:
      for key, value in retrieved_dict.items():
        redis_key = redis_utils.FormatRedisKey(key, prefix=self.prefix)
        redis_value = self._ValueToStr(value)

        self.redis_client.setex(
            redis_key, framework_constants.CACHE_EXPIRATION, redis_value)
    except redis.RedisError as identifier:
      logging.error(
          'Redis error occurred during write operation: %s', identifier)
      self._DeleteFromRedis(list(retrieved_dict.keys()))
      return
    logging.info(
        'cached batch of %d values in redis %s', len(retrieved_dict),
        self.prefix)

  def _ReadFromRedis(self, keys):
    # type: (Sequence[int]) -> Tuple[Mapping[int, Any], Sequence[int]]
    """Read the given keys from Redis, return {key: value}, missing keys.

    Every hit is also stored in the RAM cache. A falsy reply (e.g. None for
    a key that is not in Redis) counts as a miss, and if the read fails with
    a Redis error every key is treated as a miss.

    Args:
      keys: List of integer keys to read from Redis.

    Returns:
      A pair: hits, misses. Where hits is {key: value} and misses is
      a list of any keys that were not found anywhere.
    """
    cache_hits = {}
    missing_keys = []
    try:
      values_list = self.redis_client.mget(
          [redis_utils.FormatRedisKey(key, prefix=self.prefix) for key in keys])
    except redis.RedisError as identifier:
      logging.error(
          'Redis error occurred during read operation: %s', identifier)
      values_list = [None] * len(keys)

    for key, serialized_value in zip(keys, values_list):
      if serialized_value:
        value = self._StrToValue(serialized_value)
        cache_hits[key] = value
        self.cache.CacheItem(key, value)
      else:
        missing_keys.append(key)
    logging.info(
        'decoded %d values from redis %s, missing %d', len(cache_hits),
        self.prefix, len(missing_keys))
    return cache_hits, missing_keys

  def _DeleteFromRedis(self, keys):
    # type: (Sequence[int]) -> None
    """Delete key-values from redis.

    Redis errors are logged and otherwise ignored.

    Args:
      keys: List of integer keys to delete.
    """
    if not keys:
      # Redis' DEL command takes at least one key; calling it with none
      # would only produce an error that gets logged below.
      return
    try:
      self.redis_client.delete(
          *[
              redis_utils.FormatRedisKey(key, prefix=self.prefix)
              for key in keys
          ])
    except redis.RedisError as identifier:
      logging.error(
          'Redis error occurred during delete operation %s', identifier)

  def _KeyToStr(self, key):
    # type: (int) -> str
    """Convert our int IDs to strings for use as memcache keys."""
    return str(key)

  def _StrToKey(self, key_str):
    # type: (str) -> int
    """Convert memcache keys back to the ints that we use as IDs."""
    return int(key_str)

  def _ValueToStr(self, value):
    # type: (Any) -> str
    """Serialize an application object so that it can be stored in L2 cache.

    With Redis this delegates to redis_utils.SerializeValue(). With memcache,
    values are stored as-is when pb_class is falsy, as str() when pb_class is
    int, and protobuf-encoded otherwise.
    """
    if self.use_redis:
      return redis_utils.SerializeValue(value, pb_class=self.pb_class)
    else:
      if not self.pb_class:
        return value
      elif self.pb_class == int:
        return str(value)
      else:
        return protobuf.encode_message(value)

  def _StrToValue(self, serialized_value):
    # type: (str) -> Any
    """Deserialize L2 cache string into an application object.

    This is the inverse of _ValueToStr() for the same use_redis/pb_class.
    """
    if self.use_redis:
      return redis_utils.DeserializeValue(
          serialized_value, pb_class=self.pb_class)
    else:
      if not self.pb_class:
        return serialized_value
      elif self.pb_class == int:
        return int(serialized_value)
      else:
        return protobuf.decode_message(self.pb_class, serialized_value)