blob: 07702bf268591433d722182953d78d9015da00c4 [file] [log] [blame]
# Copyright 2020 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Classes to manage cached values.

Monorail makes full use of the RAM of GAE frontends to reduce latency
and load on the database.

Even though these caches do invalidation, there are rare race conditions
that can cause a somewhat stale object to be retrieved from memcache and
then put into a RAM cache and used by a given GAE instance for some time.
So, we only use these caches for operations that can tolerate somewhat
stale data. For example, displaying issues in a list or displaying brief
info about related issues. We never use the cache to load objects as
part of a read-modify-save sequence because that could cause stored data
to revert to a previous state.
"""
18from __future__ import print_function
19from __future__ import division
20from __future__ import absolute_import
21
22import logging
23import redis
24
25from protorpc import protobuf
26
27from google.appengine.api import memcache
28
29import settings
30from framework import framework_constants
31from framework import redis_utils
32from proto import tracker_pb2
33
34
# Default maximum number of entries a RamCache may hold before it starts
# evicting arbitrary items to make room (see RamCache.CacheItem/CacheAll).
DEFAULT_MAX_SIZE = 10000
36
37
class RamCache(object):
  """An in-RAM cache with distributed invalidation."""

  def __init__(self, cache_manager, kind, max_size=None):
    """Register with cache_manager so distributed invalidation reaches us.

    Args:
      cache_manager: CacheManager that broadcasts invalidations to caches.
      kind: string cache kind, used as the key in the Invalidate DB table.
      max_size: maximum number of entries; defaults to DEFAULT_MAX_SIZE.
    """
    self.cache_manager = cache_manager
    self.kind = kind
    self.cache = {}
    self.max_size = max_size or DEFAULT_MAX_SIZE
    cache_manager.RegisterCache(self, kind)

  def CacheItem(self, key, item):
    """Store item at key in this cache, discarding a random item if needed."""
    if len(self.cache) >= self.max_size:
      # popitem() discards an arbitrary entry; that is acceptable because
      # this is only a cache, not the source of truth.
      self.cache.popitem()

    self.cache[key] = item

  def CacheAll(self, new_item_dict):
    """Cache all items in the given dict, dropping old items if needed."""
    if len(new_item_dict) >= self.max_size:
      # The new items alone would fill the cache, so just start over.
      # Note: logging.warn() is deprecated; use logging.warning().
      logging.warning('Dumping the entire cache! %s', self.kind)
      self.cache = {}
    else:
      # Evict arbitrary old entries until the new items fit.
      while len(self.cache) + len(new_item_dict) > self.max_size:
        self.cache.popitem()

    self.cache.update(new_item_dict)

  def GetItem(self, key):
    """Return the cached item if present, otherwise None."""
    return self.cache.get(key)

  def HasItem(self, key):
    """Return True if there is a value cached at the given key."""
    return key in self.cache

  def GetAll(self, keys):
    """Look up the given keys.

    Args:
      keys: a list of cache keys to look up.

    Returns:
      A pair: (hits_dict, misses_list) where hits_dict is a dictionary of
      all the given keys and the values that were found in the cache, and
      misses_list is a list of given keys that were not in the cache.
    """
    hits, misses = {}, []
    for key in keys:
      try:
        hits[key] = self.cache[key]
      except KeyError:
        misses.append(key)

    return hits, misses

  def LocalInvalidate(self, key):
    """Drop the given key from this cache, without distributed notification."""
    if key in self.cache:
      logging.info('Locally invalidating %r in kind=%r', key, self.kind)
    self.cache.pop(key, None)

  def Invalidate(self, cnxn, key):
    """Drop key locally, and append it to the Invalidate DB table."""
    self.InvalidateKeys(cnxn, [key])

  def InvalidateKeys(self, cnxn, keys):
    """Drop keys locally, and append them to the Invalidate DB table."""
    for key in keys:
      self.LocalInvalidate(key)
    if self.cache_manager:
      self.cache_manager.StoreInvalidateRows(cnxn, self.kind, keys)

  def LocalInvalidateAll(self):
    """Invalidate all keys locally: just start over with an empty dict."""
    logging.info('Locally invalidating all in kind=%r', self.kind)
    self.cache = {}

  def InvalidateAll(self, cnxn):
    """Invalidate all keys in this cache, locally and distributed."""
    self.LocalInvalidateAll()
    if self.cache_manager:
      self.cache_manager.StoreInvalidateAll(cnxn, self.kind)
121
122
class ShardedRamCache(RamCache):
  """Specialized version of RamCache that stores values in parts.

  Instead of the cache keys being simple integers, they are pairs, e.g.,
  (project_id, shard_id). Invalidation will invalidate all shards for
  a given main key, e.g, invalidating project_id 16 will drop keys
  (16, 0), (16, 1), (16, 2), ... (16, 9).
  """

  def __init__(self, cache_manager, kind, max_size=None, num_shards=10):
    """Register the cache and remember how many shards each key has."""
    super(ShardedRamCache, self).__init__(
        cache_manager, kind, max_size=max_size)
    self.num_shards = num_shards

  def LocalInvalidate(self, key):
    """Drop all shards of the given key from the local cache."""
    # Build the (key, shard_id) pairs once so logging and removal agree.
    shard_keys = [(key, shard_id) for shard_id in range(self.num_shards)]
    # Fixed log message typo: these are "sharded" RAM keys, not "shared".
    logging.info('About to invalidate sharded RAM keys %r',
                 [shard_key for shard_key in shard_keys
                  if shard_key in self.cache])
    for shard_key in shard_keys:
      self.cache.pop(shard_key, None)
144
145
class ValueCentricRamCache(RamCache):
  """Specialized version of RamCache that stores values in InvalidateTable.

  This is useful for caches that have non integer keys.
  """

  def LocalInvalidate(self, value):
    """Use the specified value to drop entries from the local cache."""
    # Collect all keys mapped to the specified value, then drop them.
    keys_to_drop = [k for k, v in self.cache.items() if v == value]
    for k in keys_to_drop:
      self.cache.pop(k, None)

  def InvalidateKeys(self, cnxn, keys):
    """Drop keys locally, and append their values to the Invalidate DB table."""
    # Find values to invalidate.  (dict.has_key() was removed in Python 3;
    # the `in` operator is the portable membership test.)
    values = [self.cache[key] for key in keys if key in self.cache]
    if len(values) == len(keys):
      for value in values:
        self.LocalInvalidate(value)
      if self.cache_manager:
        self.cache_manager.StoreInvalidateRows(cnxn, self.kind, values)
    else:
      # If a value is not found in the cache then invalidate the whole cache.
      # This is done to ensure that we are not in an inconsistent state or in a
      # race condition.
      self.InvalidateAll(cnxn)
176
177
class AbstractTwoLevelCache(object):
  """A class to manage both RAM and secondary-caching layer to retrieve objects.

  Subclasses must implement the FetchItems() method to get objects from
  the database when both caches miss.
  """

  # When loading a huge number of issues from the database, do it in chunks
  # so as to avoid timeouts.
  _FETCH_BATCH_SIZE = 10000

  def __init__(
      self,
      cache_manager,
      kind,
      prefix,
      pb_class,
      max_size=None,
      use_redis=False,
      redis_client=None):
    """Set up the RAM cache and select the L2 cache (memcache or Redis).

    Args:
      cache_manager: CacheManager used for distributed invalidation.
      kind: string kind of this cache, also used in invalidation rows.
      prefix: string prepended to stringified keys in the L2 cache so that
        different caches sharing the same backend do not collide.
      pb_class: protorpc message class used to (de)serialize values, or
        int for integer values, or None for plain string values.
      max_size: optional maximum number of entries in the RAM cache.
      use_redis: if True, try to use Redis as the L2 cache.
      redis_client: optional pre-built Redis client to reuse; if omitted,
        a new one is created when use_redis is True.
    """

    self.cache = self._MakeCache(cache_manager, kind, max_size=max_size)
    self.prefix = prefix
    self.pb_class = pb_class

    if use_redis:
      self.redis_client = redis_client or redis_utils.CreateRedisClient()
      # Only actually use Redis if the connection can be verified.
      self.use_redis = redis_utils.VerifyRedisConnection(
          self.redis_client, msg=kind)
    else:
      self.redis_client = None
      self.use_redis = False

  def _MakeCache(self, cache_manager, kind, max_size=None):
    """Make the RAM cache and register it with the cache_manager."""
    return RamCache(cache_manager, kind, max_size=max_size)

  def CacheItem(self, key, value):
    """Add the given key-value pair to RAM and L2 cache."""
    self.cache.CacheItem(key, value)
    self._WriteToCache({key: value})

  def HasItem(self, key):
    """Return True if the given key is in the RAM cache."""
    return self.cache.HasItem(key)

  def GetAnyOnHandItem(self, keys, start=None, end=None):
    """Try to find one of the specified items in RAM.

    Args:
      keys: list of cache keys to probe, in order.
      start: optional index of the first key to probe; defaults to 0.
      end: optional index just past the last key to probe; defaults to
        len(keys).

    Returns:
      The first cached value found, or None if no key is in RAM.
    """
    if start is None:
      start = 0
    if end is None:
      end = len(keys)
    for i in range(start, end):
      key = keys[i]
      if self.cache.HasItem(key):
        return self.cache.GetItem(key)

    # Note: We could check L2 here too, but the round-trips to L2
    # are kind of slow. And, getting too many hits from L2 actually
    # fills our RAM cache too quickly and could lead to thrashing.

    return None

  def GetAll(self, cnxn, keys, use_cache=True, **kwargs):
    """Get values for the given keys from RAM, the L2 cache, or the DB.

    Args:
      cnxn: connection to the database.
      keys: list of integer keys to look up.
      use_cache: set to False to always hit the database.
      **kwargs: any additional keywords are passed to FetchItems().

    Returns:
      A pair: hits, misses. Where hits is {key: value} and misses is
      a list of any keys that were not found anywhere.
    """
    if use_cache:
      result_dict, missed_keys = self.cache.GetAll(keys)
    else:
      result_dict, missed_keys = {}, list(keys)

    if missed_keys:
      if use_cache:
        # Check the L2 cache and promote any hits into the RAM cache.
        cache_hits, missed_keys = self._ReadFromCache(missed_keys)
        result_dict.update(cache_hits)
        self.cache.CacheAll(cache_hits)

    # Hit the database in batches to avoid request timeouts.
    while missed_keys:
      missed_batch = missed_keys[:self._FETCH_BATCH_SIZE]
      missed_keys = missed_keys[self._FETCH_BATCH_SIZE:]
      retrieved_dict = self.FetchItems(cnxn, missed_batch, **kwargs)
      result_dict.update(retrieved_dict)
      if use_cache:
        self.cache.CacheAll(retrieved_dict)
        self._WriteToCache(retrieved_dict)

    # Any key not found in RAM, L2, or the DB is reported as missing.
    still_missing_keys = [key for key in keys if key not in result_dict]
    return result_dict, still_missing_keys

  def LocalInvalidateAll(self):
    """Drop all RAM entries locally, without distributed notification."""
    self.cache.LocalInvalidateAll()

  def LocalInvalidate(self, key):
    """Drop the given key from RAM only, without distributed notification."""
    self.cache.LocalInvalidate(key)

  def InvalidateKeys(self, cnxn, keys):
    """Drop the given keys from both RAM and L2 cache."""
    self.cache.InvalidateKeys(cnxn, keys)
    self._DeleteFromCache(keys)

  def InvalidateAllKeys(self, cnxn, keys):
    """Drop the given keys from L2 cache and invalidate all keys in RAM.

    Useful for avoiding inserting many rows into the Invalidate table when
    invalidating a large group of keys all at once. Only use when necessary.
    """
    self.cache.InvalidateAll(cnxn)
    self._DeleteFromCache(keys)

  def GetAllAlreadyInRam(self, keys):
    """Look only in RAM to return {key: values}, missed_keys."""
    result_dict, missed_keys = self.cache.GetAll(keys)
    return result_dict, missed_keys

  def InvalidateAllRamEntries(self, cnxn):
    """Drop all RAM cache entries. It will refill as needed from L2 cache."""
    self.cache.InvalidateAll(cnxn)

  def FetchItems(self, cnxn, keys, **kwargs):
    """On RAM and L2 cache miss, hit the database."""
    raise NotImplementedError()

  def _ReadFromCache(self, keys):
    # type: (Sequence[int]) -> Mapping[int, Any], Sequence[int]
    """Reads a list of keys from secondary caching service.

    Redis will be used if Redis is enabled and connection is valid;
    otherwise, memcache will be used.

    Args:
      keys: List of integer keys to look up in L2 cache.

    Returns:
      A pair: hits, misses. Where hits is {key: value} and misses is
      a list of any keys that were not found anywhere.
    """
    if self.use_redis:
      return self._ReadFromRedis(keys)
    else:
      return self._ReadFromMemcache(keys)

  def _WriteToCache(self, retrieved_dict):
    # type: (Mapping[int, Any]) -> None
    """Writes a set of key-value pairs to secondary caching service.

    Redis will be used if Redis is enabled and connection is valid;
    otherwise, memcache will be used.

    Args:
      retrieved_dict: Dictionary contains pairs of key-values to write to cache.
    """
    if self.use_redis:
      return self._WriteToRedis(retrieved_dict)
    else:
      return self._WriteToMemcache(retrieved_dict)

  def _DeleteFromCache(self, keys):
    # type: (Sequence[int]) -> None
    """Selects which cache to delete from.

    Redis will be used if Redis is enabled and connection is valid;
    otherwise, memcache will be used.

    Args:
      keys: List of integer keys to delete from cache.
    """
    if self.use_redis:
      return self._DeleteFromRedis(keys)
    else:
      return self._DeleteFromMemcache(keys)

  def _ReadFromMemcache(self, keys):
    # type: (Sequence[int]) -> Mapping[int, Any], Sequence[int]
    """Read the given keys from memcache, return {key: value}, missing_keys."""
    cache_hits = {}
    cached_dict = memcache.get_multi(
        [self._KeyToStr(key) for key in keys],
        key_prefix=self.prefix,
        namespace=settings.memcache_namespace)

    for key_str, serialized_value in cached_dict.items():
      value = self._StrToValue(serialized_value)
      key = self._StrToKey(key_str)
      cache_hits[key] = value
      # Promote the L2 hit into the RAM cache as well.
      self.cache.CacheItem(key, value)

    still_missing_keys = [key for key in keys if key not in cache_hits]
    return cache_hits, still_missing_keys

  def _WriteToMemcache(self, retrieved_dict):
    # type: (Mapping[int, Any]) -> None
    """Write entries for each key-value pair to memcache. Encode PBs."""
    strs_to_cache = {
        self._KeyToStr(key): self._ValueToStr(value)
        for key, value in retrieved_dict.items()}

    try:
      # NOTE(review): add_multi (rather than set_multi) presumably avoids
      # clobbering entries that are already present -- confirm intent.
      memcache.add_multi(
          strs_to_cache,
          key_prefix=self.prefix,
          time=framework_constants.CACHE_EXPIRATION,
          namespace=settings.memcache_namespace)
    except ValueError as identifier:
      # If memcache does not accept the values, ensure that no stale
      # values are left, then bail out.
      logging.error('Got memcache error: %r', identifier)
      self._DeleteFromMemcache(list(strs_to_cache.keys()))
      return

  def _DeleteFromMemcache(self, keys):
    # type: (Sequence[str]) -> None
    """Delete key-values from memcache."""
    memcache.delete_multi(
        [self._KeyToStr(key) for key in keys],
        seconds=5,
        key_prefix=self.prefix,
        namespace=settings.memcache_namespace)

  def _WriteToRedis(self, retrieved_dict):
    # type: (Mapping[int, Any]) -> None
    """Write entries for each key-value pair to Redis. Encode PBs.

    Args:
      retrieved_dict: Dictionary of key-value pairs to write to Redis.
    """
    try:
      for key, value in retrieved_dict.items():
        redis_key = redis_utils.FormatRedisKey(key, prefix=self.prefix)
        redis_value = self._ValueToStr(value)

        # Each entry expires after CACHE_EXPIRATION seconds.
        self.redis_client.setex(
            redis_key, framework_constants.CACHE_EXPIRATION, redis_value)
    except redis.RedisError as identifier:
      logging.error(
          'Redis error occurred during write operation: %s', identifier)
      # Remove any partially-written entries so no stale data remains.
      self._DeleteFromRedis(list(retrieved_dict.keys()))
      return
    logging.info(
        'cached batch of %d values in redis %s', len(retrieved_dict),
        self.prefix)

  def _ReadFromRedis(self, keys):
    # type: (Sequence[int]) -> Mapping[int, Any], Sequence[int]
    """Read the given keys from Redis, return {key: value}, missing keys.

    Args:
      keys: List of integer keys to read from Redis.

    Returns:
      A pair: hits, misses. Where hits is {key: value} and misses is
      a list of any keys that were not found anywhere.
    """
    cache_hits = {}
    missing_keys = []
    try:
      values_list = self.redis_client.mget(
          [redis_utils.FormatRedisKey(key, prefix=self.prefix) for key in keys])
    except redis.RedisError as identifier:
      logging.error(
          'Redis error occurred during read operation: %s', identifier)
      # Treat a Redis failure as a complete miss for every key.
      values_list = [None] * len(keys)

    for key, serialized_value in zip(keys, values_list):
      if serialized_value:
        value = self._StrToValue(serialized_value)
        cache_hits[key] = value
        # Promote the L2 hit into the RAM cache as well.
        self.cache.CacheItem(key, value)
      else:
        missing_keys.append(key)
    logging.info(
        'decoded %d values from redis %s, missing %d', len(cache_hits),
        self.prefix, len(missing_keys))
    return cache_hits, missing_keys

  def _DeleteFromRedis(self, keys):
    # type: (Sequence[int]) -> None
    """Delete key-values from redis.

    Args:
      keys: List of integer keys to delete.
    """
    try:
      self.redis_client.delete(
          *[
              redis_utils.FormatRedisKey(key, prefix=self.prefix)
              for key in keys
          ])
    except redis.RedisError as identifier:
      logging.error(
          'Redis error occurred during delete operation %s', identifier)

  def _KeyToStr(self, key):
    # type: (int) -> str
    """Convert our int IDs to strings for use as memcache keys."""
    return str(key)

  def _StrToKey(self, key_str):
    # type: (str) -> int
    """Convert memcache keys back to the ints that we use as IDs."""
    return int(key_str)

  def _ValueToStr(self, value):
    # type: (Any) -> str
    """Serialize an application object so that it can be stored in L2 cache."""
    if self.use_redis:
      return redis_utils.SerializeValue(value, pb_class=self.pb_class)
    else:
      if not self.pb_class:
        # Plain string values are stored as-is.
        return value
      elif self.pb_class == int:
        return str(value)
      else:
        return protobuf.encode_message(value)

  def _StrToValue(self, serialized_value):
    # type: (str) -> Any
    """Deserialize L2 cache string into an application object."""
    if self.use_redis:
      return redis_utils.DeserializeValue(
          serialized_value, pb_class=self.pb_class)
    else:
      if not self.pb_class:
        # Plain string values are returned as-is.
        return serialized_value
      elif self.pb_class == int:
        return int(serialized_value)
      else:
        return protobuf.decode_message(self.pb_class, serialized_value)