# Copyright 2016 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

"""Request rate limiting implementation.

This is intended to be used for automatic DDoS protection.
"""
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

import datetime
import logging
import os
import settings
import time

from infra_libs import ts_mon

from google.appengine.api import memcache
from google.appengine.api.modules import modules
from google.appengine.api import users

from services import client_config_svc


N_MINUTES = 5
EXPIRE_AFTER_SECS = 60 * 60
DEFAULT_LIMIT = 60 * N_MINUTES  # 300 page requests in 5 minutes is 1 QPS.
DEFAULT_API_QPM = 1000  # For example, chromiumdash uses ~64 per page, 8s each.
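# How the limiter counts (descriptive note): each request increments a
# per-minute memcache bucket, and each check sums the newest bucket plus the
# previous N_MINUTES - 1 buckets. So DEFAULT_LIMIT allows 60 * 5 == 300
# requests in any 5-minute window, i.e. a sustained 1 QPS.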

ANON_USER = 'anon'

COUNTRY_HEADER = 'X-AppEngine-Country'

COUNTRY_LIMITS = {
    # Two-letter country code: max requests per N_MINUTES.
    # This limit will apply to all requests coming from this country.
    # To add a country code, see GAE logs and use the appropriate code
    # from https://en.wikipedia.org/wiki/ISO_3166-1_alpha-2
    # E.g., 'cn': 300,  # Limit to 1 QPS.
}

# Modules not in this list will not have rate limiting applied by this
# class.
MODULE_ALLOWLIST = ['default', 'api']


def _CacheKeys(request, now_sec):
  """Returns (keysets, country, ip, user_email).

  keysets is a list of lists of memcache keys. The keys in each list share
  a rate-limit prefix and end in a timestamp, starting with the most recent
  minute and going back one minute at a time.
  """
  now = datetime.datetime.fromtimestamp(now_sec)
  country = request.headers.get(COUNTRY_HEADER, 'ZZ')
  ip = request.remote_addr
  minute_buckets = [now - datetime.timedelta(minutes=m) for m in
                    range(N_MINUTES)]
  user = users.get_current_user()
  user_email = user.email() if user else ANON_USER

  # <IP, country, user_email> to be rendered into each key prefix.
  prefixes = []

  # All logged-in users get a per-user rate limit, regardless of IP
  # and country.
  if user:
    prefixes.append(['ALL', 'ALL', user.email()])
  else:
    # All anon requests get a per-IP ratelimit.
    prefixes.append([ip, 'ALL', 'ALL'])

  # All requests from a problematic country get a per-country rate limit,
  # regardless of the user (even a non-logged-in one) or IP.
  if country in COUNTRY_LIMITS:
    prefixes.append(['ALL', country, 'ALL'])

  keysets = []
  for prefix in prefixes:
    keysets.append(['ratelimit-%s-%s' % ('-'.join(prefix),
                    str(minute_bucket.replace(second=0, microsecond=0)))
                    for minute_bucket in minute_buckets])

  return keysets, country, ip, user_email

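# Illustrative output (hypothetical user and time): for a signed-in user
# user@example.com checked at 19:36, the per-user keyset would look like
# ['ratelimit-ALL-ALL-user@example.com-2021-09-07 19:36:00',
#  'ratelimit-ALL-ALL-user@example.com-2021-09-07 19:35:00', ...],
# going back N_MINUTES one-minute buckets in all.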

def _CreateApiCacheKeys(client_id, client_email, now_sec):
  country = os.environ.get('HTTP_X_APPENGINE_COUNTRY')
  ip = os.environ.get('REMOTE_ADDR')
  now = datetime.datetime.fromtimestamp(now_sec)
  minute_buckets = [now - datetime.timedelta(minutes=m) for m in
                    range(N_MINUTES)]
  minute_strs = [str(minute_bucket.replace(second=0, microsecond=0))
                 for minute_bucket in minute_buckets]
  keys = []

  if client_id and client_id != 'anonymous':
    keys.append(['apiratelimit-%s-%s' % (client_id, minute_str)
                 for minute_str in minute_strs])
  if client_email:
    keys.append(['apiratelimit-%s-%s' % (client_email, minute_str)
                 for minute_str in minute_strs])
  else:
    keys.append(['apiratelimit-%s-%s' % (ip, minute_str)
                 for minute_str in minute_strs])
  if country in COUNTRY_LIMITS:
    keys.append(['apiratelimit-%s-%s' % (country, minute_str)
                 for minute_str in minute_strs])

  return keys

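# Illustrative output (hypothetical client): for client_id '12345' with no
# client_email, _CreateApiCacheKeys returns a per-client keyset such as
# ['apiratelimit-12345-2021-09-07 19:36:00', ...] plus a per-IP keyset,
# each going back N_MINUTES one-minute buckets.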

class RateLimiter(object):

  blocked_requests = ts_mon.CounterMetric(
      'monorail/ratelimiter/blocked_request',
      'Count of requests that exceeded the rate limit and were blocked.',
      None)
  limit_exceeded = ts_mon.CounterMetric(
      'monorail/ratelimiter/rate_exceeded',
      'Count of requests that exceeded the rate limit.',
      None)
  cost_thresh_exceeded = ts_mon.CounterMetric(
      'monorail/ratelimiter/cost_thresh_exceeded',
      'Count of requests that were expensive to process',
      None)
  checks = ts_mon.CounterMetric(
      'monorail/ratelimiter/check',
      'Count of checks done, by fail/success type.',
      [ts_mon.StringField('type')])

  def __init__(self, _cache=memcache, fail_open=True, **_kwargs):
    self.fail_open = fail_open

  def CheckStart(self, request, now=None):
    if (modules.get_current_module_name() not in MODULE_ALLOWLIST or
        users.is_current_user_admin()):
      return
    logging.info('X-AppEngine-Country: %s' %
                 request.headers.get(COUNTRY_HEADER, 'ZZ'))

    if now is None:
      now = time.time()

    keysets, country, ip, user_email = _CacheKeys(request, now)
    # There are either two or three sets of keys in keysets.
    # Three if the user's country is in COUNTRY_LIMITS, otherwise two.
    self._AuxCheckStart(
        keysets, COUNTRY_LIMITS.get(country, DEFAULT_LIMIT),
        settings.ratelimiting_enabled,
        RateLimitExceeded(country=country, ip=ip, user_email=user_email))

  def _AuxCheckStart(self, keysets, limit, ratelimiting_enabled,
                     exception_obj):
    for keys in keysets:
      count = 0
      try:
        counters = memcache.get_multi(keys)
        count = sum(counters.values())
        self.checks.increment({'type': 'success'})
      except Exception as e:
        logging.error(e)
        if not self.fail_open:
          self.checks.increment({'type': 'fail_closed'})
          raise exception_obj
        self.checks.increment({'type': 'fail_open'})

      if count > limit:
        # Since webapp2 won't let us return a 429 error code
        # <http://tools.ietf.org/html/rfc6585#section-4>, we can't
        # monitor rate limit exceeded events with our standard tools.
        # We return a 400 with a custom error message to the client,
        # and this logging is so we can monitor it internally.
        logging.info('%s, %d' % (str(exception_obj), count))

        self.limit_exceeded.increment()

        if ratelimiting_enabled:
          self.blocked_requests.increment()
          raise exception_obj

      k = keys[0]
      # Only update the latest *time* bucket for each prefix (reverse chron).
      memcache.add(k, 0, time=EXPIRE_AFTER_SECS)
      memcache.incr(k, initial_value=0)
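      # Note: memcache.add is a no-op when the key already exists, so this
      # add-then-incr pair initializes the current minute's bucket once and
      # then counts the request, avoiding a read-modify-write race.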

  def CheckEnd(self, request, now, start_time):
    """If a request was expensive to process, charge some extra points
    against this set of buckets.

    We pass in both now and start_time so we can update the buckets
    based on keys created from start_time instead of now.
    now and start_time are float seconds.
    """
    if modules.get_current_module_name() not in MODULE_ALLOWLIST:
      return

    elapsed_ms = int((now - start_time) * 1000)
    # Would it kill the python lib maintainers to have timedelta.total_ms()?
    penalty = elapsed_ms // settings.ratelimiting_ms_per_count - 1
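    # Illustrative arithmetic (assuming settings.ratelimiting_ms_per_count
    # were 1000): a request that took 3.5s has elapsed_ms == 3500, so
    # penalty == 3500 // 1000 - 1 == 2 extra points against its buckets.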
    if penalty >= 1:
      # TODO: Look into caching the keys instead of generating them twice
      # for every request. Say, return them from CheckStart so they can
      # be passed back in here later.
      keysets, country, ip, user_email = _CacheKeys(request, start_time)

      self._AuxCheckEnd(
          keysets,
          'Rate Limit Cost Threshold Exceeded: %s, %s, %s' % (
              country, ip, user_email),
          penalty)

  def _AuxCheckEnd(self, keysets, log_str, penalty):
    self.cost_thresh_exceeded.increment()
    for keys in keysets:
      logging.info(log_str)

      # Only update the latest *time* bucket for each prefix (reverse chron).
      k = keys[0]
      memcache.add(k, 0, time=EXPIRE_AFTER_SECS)
      memcache.incr(k, delta=penalty, initial_value=0)


class ApiRateLimiter(RateLimiter):

  blocked_requests = ts_mon.CounterMetric(
      'monorail/apiratelimiter/blocked_request',
      'Count of requests that exceeded the rate limit and were blocked.',
      None)
  limit_exceeded = ts_mon.CounterMetric(
      'monorail/apiratelimiter/rate_exceeded',
      'Count of requests that exceeded the rate limit.',
      None)
  cost_thresh_exceeded = ts_mon.CounterMetric(
      'monorail/apiratelimiter/cost_thresh_exceeded',
      'Count of requests that were expensive to process',
      None)
  checks = ts_mon.CounterMetric(
      'monorail/apiratelimiter/check',
      'Count of checks done, by fail/success type.',
      [ts_mon.StringField('type')])

  #pylint: disable=arguments-differ
  def CheckStart(self, client_id, client_email, now=None):
    if now is None:
      now = time.time()

    keysets = _CreateApiCacheKeys(client_id, client_email, now)
    qpm_limit = client_config_svc.GetQPMDict().get(
        client_email, DEFAULT_API_QPM)
    if qpm_limit < DEFAULT_API_QPM:
      qpm_limit = DEFAULT_API_QPM
    window_limit = qpm_limit * N_MINUTES
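    # Illustrative arithmetic: at the DEFAULT_API_QPM of 1000, the 5-minute
    # window allows 1000 * 5 == 5000 requests before blocking.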
    self._AuxCheckStart(
        keysets, window_limit,
        settings.api_ratelimiting_enabled,
        ApiRateLimitExceeded(client_id, client_email))

  #pylint: disable=arguments-differ
  def CheckEnd(self, client_id, client_email, now, start_time):
    elapsed_ms = int((now - start_time) * 1000)
    penalty = elapsed_ms // settings.ratelimiting_ms_per_count - 1

    if penalty >= 1:
      keysets = _CreateApiCacheKeys(client_id, client_email, start_time)
      self._AuxCheckEnd(
          keysets,
          'API Rate Limit Cost Threshold Exceeded: %s, %s' % (
              client_id, client_email),
          penalty)


class RateLimitExceeded(Exception):
  def __init__(self, country=None, ip=None, user_email=None, **_kwargs):
    self.country = country
    self.ip = ip
    self.user_email = user_email
    message = 'RateLimitExceeded: %s, %s, %s' % (
        self.country, self.ip, self.user_email)
    super(RateLimitExceeded, self).__init__(message)


class ApiRateLimitExceeded(Exception):
  def __init__(self, client_id, client_email):
    self.client_id = client_id
    self.client_email = client_email
    message = 'ApiRateLimitExceeded: %s, %s' % (
        self.client_id, self.client_email)
    super(ApiRateLimitExceeded, self).__init__(message)
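

# Illustrative usage (hypothetical caller, not part of this module): a
# request handler calls CheckStart before doing work and CheckEnd after,
# so that slow requests are charged extra points against the same buckets.
#
#   limiter = RateLimiter()
#   start_time = time.time()
#   try:
#     limiter.CheckStart(request)
#   except RateLimitExceeded as e:
#     return error_400(e)  # hypothetical error helper
#   ... handle the request ...
#   limiter.CheckEnd(request, time.time(), start_time)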