Copybara | 854996b | 2021-09-07 19:36:02 +0000 | [diff] [blame] | 1 | # Copyright 2020 The Chromium Authors. All rights reserved. |
| 2 | # Use of this source code is governed by a BSD-style license that can be |
| 3 | # found in the LICENSE file. |
| 4 | """Classes to manage cached values. |
| 5 | |
| 6 | Monorail makes full use of the RAM of GAE frontends to reduce latency |
| 7 | and load on the database. |
| 8 | |
| 9 | Even though these caches do invalidation, there are rare race conditions |
| 10 | that can cause a somewhat stale object to be retrieved from memcache and |
| 11 | then put into a RAM cache and used by a given GAE instance for some time. |
| 12 | So, we only use these caches for operations that can tolerate somewhat |
| 13 | stale data. For example, displaying issues in a list or displaying brief |
| 14 | info about related issues. We never use the cache to load objects as |
| 15 | part of a read-modify-save sequence because that could cause stored data |
| 16 | to revert to a previous state. |
| 17 | """ |
| 18 | from __future__ import print_function |
| 19 | from __future__ import division |
| 20 | from __future__ import absolute_import |
| 21 | |
| 22 | import logging |
| 23 | import redis |
| 24 | |
| 25 | from protorpc import protobuf |
| 26 | |
| 27 | from google.appengine.api import memcache |
| 28 | |
| 29 | import settings |
| 30 | from framework import framework_constants |
| 31 | from framework import redis_utils |
| 32 | from proto import tracker_pb2 |
| 33 | |
| 34 | |
# Default cap on the number of entries kept in each RamCache; bounds the
# RAM used per cache kind on a GAE frontend instance.
DEFAULT_MAX_SIZE = 10000
| 36 | |
| 37 | |
class RamCache(object):
  """An in-RAM cache with distributed invalidation.

  Attributes:
    cache_manager: object that registers this cache and relays distributed
        invalidation events to it; may be falsy in tests.
    kind: string name of the type of objects cached, e.g. 'issue'.
    cache: dict {key: value} of the currently cached items.
    max_size: maximum number of entries to keep in RAM.
  """

  def __init__(self, cache_manager, kind, max_size=None):
    """Create an empty cache and register it for invalidation events.

    Args:
      cache_manager: manager to register this cache with.
      kind: string name for the kind of objects stored.
      max_size: optional cap on entries; defaults to DEFAULT_MAX_SIZE.
    """
    self.cache_manager = cache_manager
    self.kind = kind
    self.cache = {}
    self.max_size = max_size or DEFAULT_MAX_SIZE
    cache_manager.RegisterCache(self, kind)

  def CacheItem(self, key, item):
    """Store item at key in this cache, discarding a random item if needed."""
    if len(self.cache) >= self.max_size:
      # popitem() evicts an arbitrary entry; approximate eviction is
      # acceptable here and avoids LRU bookkeeping overhead.
      self.cache.popitem()

    self.cache[key] = item

  def CacheAll(self, new_item_dict):
    """Cache all items in the given dict, dropping old items if needed."""
    if len(new_item_dict) >= self.max_size:
      # The new batch alone fills the cache, so keep only the new items.
      # Note: logging.warn() is a deprecated alias; use logging.warning().
      logging.warning('Dumping the entire cache! %s', self.kind)
      self.cache = {}
    else:
      # Evict arbitrary old entries until the new batch fits.
      while len(self.cache) + len(new_item_dict) > self.max_size:
        self.cache.popitem()

    self.cache.update(new_item_dict)

  def GetItem(self, key):
    """Return the cached item if present, otherwise None."""
    return self.cache.get(key)

  def HasItem(self, key):
    """Return True if there is a value cached at the given key."""
    return key in self.cache

  def GetAll(self, keys):
    """Look up the given keys.

    Args:
      keys: a list of cache keys to look up.

    Returns:
      A pair: (hits_dict, misses_list) where hits_dict is a dictionary of
      all the given keys and the values that were found in the cache, and
      misses_list is a list of given keys that were not in the cache.
    """
    hits, misses = {}, []
    for key in keys:
      try:
        hits[key] = self.cache[key]
      except KeyError:
        misses.append(key)

    return hits, misses

  def LocalInvalidate(self, key):
    """Drop the given key from this cache, without distributed notification."""
    if key in self.cache:
      logging.info('Locally invalidating %r in kind=%r', key, self.kind)
    # pop() with a default is a no-op when the key is already absent.
    self.cache.pop(key, None)

  def Invalidate(self, cnxn, key):
    """Drop key locally, and append it to the Invalidate DB table."""
    self.InvalidateKeys(cnxn, [key])

  def InvalidateKeys(self, cnxn, keys):
    """Drop keys locally, and append them to the Invalidate DB table."""
    for key in keys:
      self.LocalInvalidate(key)
    if self.cache_manager:
      self.cache_manager.StoreInvalidateRows(cnxn, self.kind, keys)

  def LocalInvalidateAll(self):
    """Invalidate all keys locally: just start over with an empty dict."""
    logging.info('Locally invalidating all in kind=%r', self.kind)
    self.cache = {}

  def InvalidateAll(self, cnxn):
    """Invalidate all keys in this cache."""
    self.LocalInvalidateAll()
    if self.cache_manager:
      self.cache_manager.StoreInvalidateAll(cnxn, self.kind)
| 121 | |
| 122 | |
class ShardedRamCache(RamCache):
  """A RamCache whose entries are split into per-key shards.

  Cache keys are pairs rather than simple integers, e.g.,
  (project_id, shard_id).  Invalidating a main key removes every shard
  belonging to it: invalidating project_id 16 drops keys
  (16, 0), (16, 1), (16, 2), ... (16, 9).
  """

  def __init__(self, cache_manager, kind, max_size=None, num_shards=10):
    """Register the cache and record how many shards each main key has."""
    super(ShardedRamCache, self).__init__(
        cache_manager, kind, max_size=max_size)
    self.num_shards = num_shards

  def LocalInvalidate(self, key):
    """Use the specified value to drop entries from the local cache."""
    sharded_keys = [(key, shard_id) for shard_id in range(self.num_shards)]
    logging.info('About to invalidate shared RAM keys %r',
                 [sharded_key for sharded_key in sharded_keys
                  if sharded_key in self.cache])
    for sharded_key in sharded_keys:
      self.cache.pop(sharded_key, None)
| 144 | |
| 145 | |
class ValueCentricRamCache(RamCache):
  """Specialized version of RamCache that stores values in InvalidateTable.

  This is useful for caches that have non integer keys.
  """

  def LocalInvalidate(self, value):
    """Use the specified value to drop entries from the local cache."""
    keys_to_drop = []
    # Loop through and collect all keys with the specified value.  Collect
    # first, then pop, so the dict is not mutated while being iterated.
    for k, v in self.cache.items():
      if v == value:
        keys_to_drop.append(k)
    for k in keys_to_drop:
      self.cache.pop(k, None)

  def InvalidateKeys(self, cnxn, keys):
    """Drop keys locally, and append their values to the Invalidate DB table."""
    # Find values to invalidate.
    # Bug fix: dict.has_key() was removed in Python 3; use the `in` operator.
    values = [self.cache[key] for key in keys if key in self.cache]
    if len(values) == len(keys):
      for value in values:
        self.LocalInvalidate(value)
      if self.cache_manager:
        self.cache_manager.StoreInvalidateRows(cnxn, self.kind, values)
    else:
      # If a value is not found in the cache then invalidate the whole cache.
      # This is done to ensure that we are not in an inconsistent state or in a
      # race condition.
      self.InvalidateAll(cnxn)
| 176 | |
| 177 | |
class AbstractTwoLevelCache(object):
  """A class to manage both RAM and secondary-caching layer to retrieve objects.

  Subclasses must implement the FetchItems() method to get objects from
  the database when both caches miss.
  """

  # When loading a huge number of issues from the database, do it in chunks
  # so as to avoid timeouts.
  _FETCH_BATCH_SIZE = 10000

  def __init__(
      self,
      cache_manager,
      kind,
      prefix,
      pb_class,
      max_size=None,
      use_redis=False,
      redis_client=None):
    """Set up the RAM cache and the L2 (memcache or Redis) layer.

    Args:
      cache_manager: manager used to register the RAM cache for
          distributed invalidation.
      kind: string name of the kind of objects cached, e.g. 'issue'.
      prefix: string prefix used to namespace keys in the L2 cache.
      pb_class: protobuf class used to (de)serialize values; int for
          plain integer values; None for raw string values.
      max_size: optional cap on the number of RAM cache entries.
      use_redis: if True, use Redis as the L2 cache instead of memcache.
      redis_client: optional existing Redis client to reuse; a new one is
          created when omitted and use_redis is True.
    """
    self.cache = self._MakeCache(cache_manager, kind, max_size=max_size)
    self.prefix = prefix
    self.pb_class = pb_class

    if use_redis:
      self.redis_client = redis_client or redis_utils.CreateRedisClient()
      # Fall back to memcache if the Redis connection cannot be verified.
      self.use_redis = redis_utils.VerifyRedisConnection(
          self.redis_client, msg=kind)
    else:
      self.redis_client = None
      self.use_redis = False

  def _MakeCache(self, cache_manager, kind, max_size=None):
    """Make the RAM cache and register it with the cache_manager."""
    return RamCache(cache_manager, kind, max_size=max_size)

  def CacheItem(self, key, value):
    """Add the given key-value pair to RAM and L2 cache."""
    self.cache.CacheItem(key, value)
    self._WriteToCache({key: value})

  def HasItem(self, key):
    """Return True if the given key is in the RAM cache."""
    return self.cache.HasItem(key)

  def GetAnyOnHandItem(self, keys, start=None, end=None):
    """Try to find one of the specified items in RAM.

    Args:
      keys: list of candidate keys.
      start: optional index of the first key to consider (default 0).
      end: optional index one past the last key to consider.

    Returns:
      The first cached value found, or None if none of the keys are in RAM.
    """
    if start is None:
      start = 0
    if end is None:
      end = len(keys)
    for i in range(start, end):
      key = keys[i]
      if self.cache.HasItem(key):
        return self.cache.GetItem(key)

    # Note: We could check L2 here too, but the round-trips to L2
    # are kind of slow. And, getting too many hits from L2 actually
    # fills our RAM cache too quickly and could lead to thrashing.

    return None

  def GetAll(self, cnxn, keys, use_cache=True, **kwargs):
    """Get values for the given keys from RAM, the L2 cache, or the DB.

    Args:
      cnxn: connection to the database.
      keys: list of integer keys to look up.
      use_cache: set to False to always hit the database.
      **kwargs: any additional keywords are passed to FetchItems().

    Returns:
      A pair: hits, misses. Where hits is {key: value} and misses is
      a list of any keys that were not found anywhere.
    """
    if use_cache:
      result_dict, missed_keys = self.cache.GetAll(keys)
    else:
      result_dict, missed_keys = {}, list(keys)

    if missed_keys:
      if use_cache:
        cache_hits, missed_keys = self._ReadFromCache(missed_keys)
        result_dict.update(cache_hits)
        self.cache.CacheAll(cache_hits)

    # Fetch remaining misses from the DB in batches to avoid timeouts.
    while missed_keys:
      missed_batch = missed_keys[:self._FETCH_BATCH_SIZE]
      missed_keys = missed_keys[self._FETCH_BATCH_SIZE:]
      retrieved_dict = self.FetchItems(cnxn, missed_batch, **kwargs)
      result_dict.update(retrieved_dict)
      if use_cache:
        self.cache.CacheAll(retrieved_dict)
        self._WriteToCache(retrieved_dict)

    still_missing_keys = [key for key in keys if key not in result_dict]
    return result_dict, still_missing_keys

  def LocalInvalidateAll(self):
    """Drop all RAM cache entries without distributed notification."""
    self.cache.LocalInvalidateAll()

  def LocalInvalidate(self, key):
    """Drop the given key from RAM without distributed notification."""
    self.cache.LocalInvalidate(key)

  def InvalidateKeys(self, cnxn, keys):
    """Drop the given keys from both RAM and L2 cache."""
    self.cache.InvalidateKeys(cnxn, keys)
    self._DeleteFromCache(keys)

  def InvalidateAllKeys(self, cnxn, keys):
    """Drop the given keys from L2 cache and invalidate all keys in RAM.

    Useful for avoiding inserting many rows into the Invalidate table when
    invalidating a large group of keys all at once. Only use when necessary.
    """
    self.cache.InvalidateAll(cnxn)
    self._DeleteFromCache(keys)

  def GetAllAlreadyInRam(self, keys):
    """Look only in RAM to return {key: values}, missed_keys."""
    result_dict, missed_keys = self.cache.GetAll(keys)
    return result_dict, missed_keys

  def InvalidateAllRamEntries(self, cnxn):
    """Drop all RAM cache entries. It will refill as needed from L2 cache."""
    self.cache.InvalidateAll(cnxn)

  def FetchItems(self, cnxn, keys, **kwargs):
    """On RAM and L2 cache miss, hit the database."""
    raise NotImplementedError()

  def _ReadFromCache(self, keys):
    # type: (Sequence[int]) -> Mapping[str, Any], Sequence[int]
    """Reads a list of keys from secondary caching service.

    Redis will be used if Redis is enabled and connection is valid;
    otherwise, memcache will be used.

    Args:
      keys: List of integer keys to look up in L2 cache.

    Returns:
      A pair: hits, misses. Where hits is {key: value} and misses is
      a list of any keys that were not found anywhere.
    """
    if self.use_redis:
      return self._ReadFromRedis(keys)
    else:
      return self._ReadFromMemcache(keys)

  def _WriteToCache(self, retrieved_dict):
    # type: (Mapping[int, Any]) -> None
    """Writes a set of key-value pairs to secondary caching service.

    Redis will be used if Redis is enabled and connection is valid;
    otherwise, memcache will be used.

    Args:
      retrieved_dict: Dictionary contains pairs of key-values to write to cache.
    """
    if self.use_redis:
      return self._WriteToRedis(retrieved_dict)
    else:
      return self._WriteToMemcache(retrieved_dict)

  def _DeleteFromCache(self, keys):
    # type: (Sequence[int]) -> None
    """Selects which cache to delete from.

    Redis will be used if Redis is enabled and connection is valid;
    otherwise, memcache will be used.

    Args:
      keys: List of integer keys to delete from cache.
    """
    if self.use_redis:
      return self._DeleteFromRedis(keys)
    else:
      return self._DeleteFromMemcache(keys)

  def _ReadFromMemcache(self, keys):
    # type: (Sequence[int]) -> Mapping[str, Any], Sequence[int]
    """Read the given keys from memcache, return {key: value}, missing_keys."""
    cache_hits = {}
    cached_dict = memcache.get_multi(
        [self._KeyToStr(key) for key in keys],
        key_prefix=self.prefix,
        namespace=settings.memcache_namespace)

    for key_str, serialized_value in cached_dict.items():
      value = self._StrToValue(serialized_value)
      key = self._StrToKey(key_str)
      cache_hits[key] = value
      # Promote L2 hits into the RAM cache for faster future access.
      self.cache.CacheItem(key, value)

    still_missing_keys = [key for key in keys if key not in cache_hits]
    return cache_hits, still_missing_keys

  def _WriteToMemcache(self, retrieved_dict):
    # type: (Mapping[int, int]) -> None
    """Write entries for each key-value pair to memcache. Encode PBs."""
    strs_to_cache = {
        self._KeyToStr(key): self._ValueToStr(value)
        for key, value in retrieved_dict.items()}

    try:
      # NOTE(review): add_multi() does not overwrite existing entries;
      # presumably invalidation deletes stale keys first — confirm.
      memcache.add_multi(
          strs_to_cache,
          key_prefix=self.prefix,
          time=framework_constants.CACHE_EXPIRATION,
          namespace=settings.memcache_namespace)
    except ValueError as identifier:
      # If memcache does not accept the values, ensure that no stale
      # values are left, then bail out.
      logging.error('Got memcache error: %r', identifier)
      self._DeleteFromMemcache(list(strs_to_cache.keys()))
      return

  def _DeleteFromMemcache(self, keys):
    # type: (Sequence[str]) -> None
    """Delete key-values from memcache. """
    memcache.delete_multi(
        [self._KeyToStr(key) for key in keys],
        seconds=5,
        key_prefix=self.prefix,
        namespace=settings.memcache_namespace)

  def _WriteToRedis(self, retrieved_dict):
    # type: (Mapping[int, Any]) -> None
    """Write entries for each key-value pair to Redis. Encode PBs.

    Args:
      retrieved_dict: Dictionary of key-value pairs to write to Redis.
    """
    try:
      for key, value in retrieved_dict.items():
        redis_key = redis_utils.FormatRedisKey(key, prefix=self.prefix)
        redis_value = self._ValueToStr(value)

        self.redis_client.setex(
            redis_key, framework_constants.CACHE_EXPIRATION, redis_value)
    except redis.RedisError as identifier:
      logging.error(
          'Redis error occurred during write operation: %s', identifier)
      self._DeleteFromRedis(list(retrieved_dict.keys()))
      return
    logging.info(
        'cached batch of %d values in redis %s', len(retrieved_dict),
        self.prefix)

  def _ReadFromRedis(self, keys):
    # type: (Sequence[int]) -> Mapping[str, Any], Sequence[int]
    """Read the given keys from Redis, return {key: value}, missing keys.

    Args:
      keys: List of integer keys to read from Redis.

    Returns:
      A pair: hits, misses. Where hits is {key: value} and misses is
      a list of any keys that were not found anywhere.
    """
    cache_hits = {}
    missing_keys = []
    try:
      values_list = self.redis_client.mget(
          [redis_utils.FormatRedisKey(key, prefix=self.prefix) for key in keys])
    except redis.RedisError as identifier:
      logging.error(
          'Redis error occurred during read operation: %s', identifier)
      # Treat a Redis failure as a complete miss so callers fall back to DB.
      values_list = [None] * len(keys)

    for key, serialized_value in zip(keys, values_list):
      if serialized_value:
        value = self._StrToValue(serialized_value)
        cache_hits[key] = value
        # Promote L2 hits into the RAM cache for faster future access.
        self.cache.CacheItem(key, value)
      else:
        missing_keys.append(key)
    logging.info(
        'decoded %d values from redis %s, missing %d', len(cache_hits),
        self.prefix, len(missing_keys))
    return cache_hits, missing_keys

  def _DeleteFromRedis(self, keys):
    # type: (Sequence[int]) -> None
    """Delete key-values from redis.

    Args:
      keys: List of integer keys to delete.
    """
    if not keys:
      # Bug fix: redis-py's delete() forwards its args to the DEL command,
      # and DEL with zero keys is a server-side error.  Skip the call.
      return
    try:
      self.redis_client.delete(
          *[
              redis_utils.FormatRedisKey(key, prefix=self.prefix)
              for key in keys
          ])
    except redis.RedisError as identifier:
      logging.error(
          'Redis error occurred during delete operation %s', identifier)

  def _KeyToStr(self, key):
    # type: (int) -> str
    """Convert our int IDs to strings for use as memcache keys."""
    return str(key)

  def _StrToKey(self, key_str):
    # type: (str) -> int
    """Convert memcache keys back to the ints that we use as IDs."""
    return int(key_str)

  def _ValueToStr(self, value):
    # type: (Any) -> str
    """Serialize an application object so that it can be stored in L2 cache."""
    if self.use_redis:
      return redis_utils.SerializeValue(value, pb_class=self.pb_class)
    else:
      if not self.pb_class:
        # No protobuf class: values are cached as-is.
        return value
      elif self.pb_class == int:
        return str(value)
      else:
        return protobuf.encode_message(value)

  def _StrToValue(self, serialized_value):
    # type: (str) -> Any
    """Deserialize L2 cache string into an application object."""
    if self.use_redis:
      return redis_utils.DeserializeValue(
          serialized_value, pb_class=self.pb_class)
    else:
      if not self.pb_class:
        return serialized_value
      elif self.pb_class == int:
        return int(serialized_value)
      else:
        return protobuf.decode_message(self.pb_class, serialized_value)