blob: 4e1a0b862a080005096d6c5a63145897e1638557 [file] [log] [blame]
Adrià Vilanova Martínezf19ea432024-01-23 20:20:52 +01001# Copyright 2016 The Chromium Authors
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
Copybara854996b2021-09-07 19:36:02 +00004
5"""Request rate limiting implementation.
6
This is intended to be used for automatic DDoS protection.
8
9"""
10from __future__ import print_function
11from __future__ import division
12from __future__ import absolute_import
13
14import datetime
15import logging
16import os
17import settings
18import time
19
20from infra_libs import ts_mon
21
22from google.appengine.api import memcache
23from google.appengine.api.modules import modules
24from google.appengine.api import users
25
26from services import client_config_svc
27
28
# Length, in minutes, of the sliding window that requests are counted over.
# Counts are kept in one-minute memcache buckets, N_MINUTES buckets per key.
N_MINUTES = 5
# Expiry applied to each per-minute memcache counter when it is created.
EXPIRE_AFTER_SECS = 60 * 60
DEFAULT_LIMIT = 60 * N_MINUTES  # 300 page requests in 5 minutes is 1 QPS.
DEFAULT_API_QPM = 1000  # For example, chromiumdash uses ~64 per page, 8s each.

# User identifier reported for requests with no signed-in user.
ANON_USER = 'anon'

# Request header holding the requester's country code.
COUNTRY_HEADER = 'X-AppEngine-Country'

COUNTRY_LIMITS = {
    # Two-letter country code: max requests per N_MINUTES.
    # For page requests this limit applies to all requests coming from this
    # country, and it also replaces DEFAULT_LIMIT for the other key sets of
    # those requests. For API requests only membership in this dict matters
    # (it adds a per-country bucket); the limit used is the client's QPM.
    # Keys are matched exactly against the header value, which uses
    # upper-case ISO 3166-1 alpha-2 codes, so write e.g. 'CN', not 'cn'.
    # To add a country code, see GAE logs and use the
    # appropriate code from https://en.wikipedia.org/wiki/ISO_3166-1_alpha-2
    # E.g., 'CN': 300,  # Limit to 1 QPS.
}

# Modules not in this list will not have rate limiting applied by this
# class.
MODULE_ALLOWLIST = ['default', 'api']
50
51
def _CacheKeys(request, now_sec):
  """Build the memcache keys used to rate-limit a web request.

  Args:
    request: the incoming request; its COUNTRY_HEADER header (default 'ZZ')
        and remote_addr are read.
    now_sec: float seconds since the epoch; keys cover the minute containing
        this time and the N_MINUTES - 1 minutes before it.

  Returns:
    A tuple (keysets, country, ip, user_email). keysets is a list of lists of
    strings. Each inner list holds keys with the same prefix and a timestamp
    suffix, starting with the most recent minute and decrementing by 1 minute
    each time. The first inner list is per-user for signed-in users and
    per-IP otherwise; a second, per-country list is added when country is in
    COUNTRY_LIMITS. user_email is ANON_USER when nobody is signed in.
  """
  now = datetime.datetime.fromtimestamp(now_sec)
  country = request.headers.get(COUNTRY_HEADER, 'ZZ')
  ip = request.remote_addr
  # Truncated minute timestamps, newest first; shared by every prefix.
  minute_strs = [
      str((now - datetime.timedelta(minutes=m)).replace(
          second=0, microsecond=0))
      for m in range(N_MINUTES)]
  user = users.get_current_user()
  user_email = user.email() if user else ANON_USER

  # <IP, country, user_email> to be rendered into each key prefix.
  prefixes = []

  # All logged-in users get a per-user rate limit, regardless of IP and country.
  if user:
    prefixes.append(['ALL', 'ALL', user_email])
  else:
    # All anon requests get a per-IP ratelimit.
    prefixes.append([ip, 'ALL', 'ALL'])

  # All requests from a problematic country get a per-country rate limit,
  # regardless of the user (even a non-logged-in one) or IP.
  if country in COUNTRY_LIMITS:
    prefixes.append(['ALL', country, 'ALL'])

  keysets = [
      ['ratelimit-%s-%s' % ('-'.join(prefix), minute_str)
       for minute_str in minute_strs]
      for prefix in prefixes]

  return keysets, country, ip, user_email
87
88
def _CreateApiCacheKeys(client_id, client_email, now_sec):
  """Build the memcache keys used to rate-limit an API request.

  The requester's country and IP are read from the HTTP_X_APPENGINE_COUNTRY
  and REMOTE_ADDR environment variables.

  Args:
    client_id: API client id; skipped when falsy or 'anonymous'.
    client_email: requester email; when falsy the IP is used instead.
    now_sec: float seconds since the epoch.

  Returns:
    A list of keysets. Each keyset is a list of N_MINUTES keys that share one
    identity (client id, client email or IP, and optionally country) and
    differ by a per-minute timestamp suffix, most recent minute first.
  """
  country = os.environ.get('HTTP_X_APPENGINE_COUNTRY')
  ip = os.environ.get('REMOTE_ADDR')
  now = datetime.datetime.fromtimestamp(now_sec)

  stamps = []
  for offset in range(N_MINUTES):
    bucket_start = now - datetime.timedelta(minutes=offset)
    stamps.append(str(bucket_start.replace(second=0, microsecond=0)))

  # Identities to rate-limit this request under, in key-set order.
  identities = []
  if client_id and client_id != 'anonymous':
    identities.append(client_id)
  identities.append(client_email if client_email else ip)
  if country in COUNTRY_LIMITS:
    identities.append(country)

  return [['apiratelimit-%s-%s' % (identity, stamp) for stamp in stamps]
          for identity in identities]
113
114
class RateLimiter(object):
  """Memcache-backed rate limiter for web page requests.

  Each checked request increments a counter in the current one-minute
  memcache bucket of every key set built by _CacheKeys. CheckStart rejects a
  request when the sum over the last N_MINUTES buckets of any key set exceeds
  the limit, and CheckEnd charges extra counts to slow requests.
  """

  blocked_requests = ts_mon.CounterMetric(
      'monorail/ratelimiter/blocked_request',
      'Count of requests that exceeded the rate limit and were blocked.',
      None)
  limit_exceeded = ts_mon.CounterMetric(
      'monorail/ratelimiter/rate_exceeded',
      'Count of requests that exceeded the rate limit.',
      None)
  cost_thresh_exceeded = ts_mon.CounterMetric(
      'monorail/ratelimiter/cost_thresh_exceeded',
      'Count of requests that were expensive to process',
      None)
  checks = ts_mon.CounterMetric(
      'monorail/ratelimiter/check',
      'Count of checks done, by fail/success type.',
      [ts_mon.StringField('type')])

  def __init__(self, _cache=memcache, fail_open=True, **_kwargs):
    """Constructor.

    Args:
      _cache: accepted for compatibility but unused; the module-level
          memcache API is always used.
      fail_open: if True, a failed memcache read lets the request proceed;
          if False, the request is rejected with the rate-limit exception.
      **_kwargs: ignored.
    """
    self.fail_open = fail_open

  def CheckStart(self, request, now=None):
    """Count this request and raise if it exceeds the rate limit.

    Requests to modules outside MODULE_ALLOWLIST and requests from admins are
    neither checked nor counted.

    Args:
      request: the incoming request.
      now: optional float seconds since the epoch; defaults to time.time().

    Raises:
      RateLimitExceeded: if over the limit and settings.ratelimiting_enabled,
          or if the memcache read failed and fail_open is False.
    """
    if (modules.get_current_module_name() not in MODULE_ALLOWLIST or
        users.is_current_user_admin()):
      return
    logging.info('X-AppEngine-Country: %s',
                 request.headers.get(COUNTRY_HEADER, 'ZZ'))

    if now is None:
      now = time.time()

    keysets, country, ip, user_email = _CacheKeys(request, now)
    # There are either one or two sets of keys in keysets: a per-user set
    # (per-IP for anonymous requests), plus a per-country set if the user's
    # country is in COUNTRY_LIMITS. The same limit applies to every set.
    self._AuxCheckStart(
        keysets, COUNTRY_LIMITS.get(country, DEFAULT_LIMIT),
        settings.ratelimiting_enabled,
        RateLimitExceeded(country=country, ip=ip, user_email=user_email))

  def _AuxCheckStart(self, keysets, limit, ratelimiting_enabled,
                     exception_obj):
    """Check each key set against limit and count this request in each.

    Args:
      keysets: list of key lists; each key list holds per-minute bucket keys,
          most recent first.
      limit: max total count allowed across the buckets of one key list.
      ratelimiting_enabled: if False, over-limit requests are only logged and
          counted in metrics, not blocked.
      exception_obj: the exception instance raised when blocking.

    Raises:
      exception_obj: when a key set is over limit and ratelimiting_enabled is
          true, or when the memcache read fails and self.fail_open is False.
    """
    for keys in keysets:
      count = 0
      try:
        counters = memcache.get_multi(keys)
        count = sum(counters.values())
      except Exception as e:  # pylint: disable=broad-except
        # logging.exception keeps the traceback so the failure is diagnosable.
        logging.exception(e)
        if not self.fail_open:
          self.checks.increment({'type': 'fail_closed'})
          raise exception_obj
        self.checks.increment({'type': 'fail_open'})
      else:
        # Only a successful read counts as a success check.
        self.checks.increment({'type': 'success'})

      if count > limit:
        # Since webapp2 won't let us return a 429 error code
        # <http://tools.ietf.org/html/rfc6585#section-4>, we can't
        # monitor rate limit exceeded events with our standard tools.
        # We return a 400 with a custom error message to the client,
        # and this logging is so we can monitor it internally.
        logging.info('%s, %d', exception_obj, count)

        self.limit_exceeded.increment()

        if ratelimiting_enabled:
          self.blocked_requests.increment()
          raise exception_obj

      k = keys[0]
      # Only update the latest *time* bucket for each prefix (reverse chron).
      # add() creates the counter with an expiry if it is missing; incr() then
      # bumps it.
      memcache.add(k, 0, time=EXPIRE_AFTER_SECS)
      memcache.incr(k, initial_value=0)

  def CheckEnd(self, request, now, start_time):
    """If a request was expensive to process, charge some extra points
    against this set of buckets.

    We pass in both now and start_time so we can update the buckets
    based on keys created from start_time instead of now.

    Args:
      request: the request that was processed.
      now: float seconds since the epoch when processing finished.
      start_time: float seconds since the epoch when processing started.
    """
    if modules.get_current_module_name() not in MODULE_ALLOWLIST:
      return

    elapsed_ms = int((now - start_time) * 1000)
    # Each full settings.ratelimiting_ms_per_count of processing time beyond
    # the first costs one extra count.
    penalty = elapsed_ms // settings.ratelimiting_ms_per_count - 1
    if penalty >= 1:
      # TODO: Look into caching the keys instead of generating them twice
      # for every request. Say, return them from CheckStart so they can
      # be passed back in here later.
      keysets, country, ip, user_email = _CacheKeys(request, start_time)

      self._AuxCheckEnd(
          keysets,
          'Rate Limit Cost Threshold Exceeded: %s, %s, %s' % (
              country, ip, user_email),
          penalty)

  def _AuxCheckEnd(self, keysets, log_str, penalty):
    """Add penalty to the most recent bucket of each key set.

    Args:
      keysets: list of key lists, each most recent bucket first.
      log_str: message logged once for this expensive request.
      penalty: positive number of extra counts to charge.
    """
    self.cost_thresh_exceeded.increment()
    # Log once per request rather than repeating the line for each key set.
    logging.info(log_str)
    for keys in keysets:
      # Only update the latest *time* bucket for each prefix (reverse chron).
      k = keys[0]
      memcache.add(k, 0, time=EXPIRE_AFTER_SECS)
      memcache.incr(k, delta=penalty, initial_value=0)
223
224
class ApiRateLimiter(RateLimiter):
  """Rate limiter for API requests.

  Reuses RateLimiter's bucket checking and penalty charging, but builds keys
  from the API client's identity via _CreateApiCacheKeys and limits each
  client by its configured queries-per-minute (never below DEFAULT_API_QPM).
  Separate ts_mon metrics keep API traffic apart from page traffic.
  """

  blocked_requests = ts_mon.CounterMetric(
      'monorail/apiratelimiter/blocked_request',
      'Count of requests that exceeded the rate limit and were blocked.',
      None)
  limit_exceeded = ts_mon.CounterMetric(
      'monorail/apiratelimiter/rate_exceeded',
      'Count of requests that exceeded the rate limit.',
      None)
  cost_thresh_exceeded = ts_mon.CounterMetric(
      'monorail/apiratelimiter/cost_thresh_exceeded',
      'Count of requests that were expensive to process',
      None)
  checks = ts_mon.CounterMetric(
      'monorail/apiratelimiter/check',
      'Count of checks done, by fail/success type.',
      [ts_mon.StringField('type')])

  #pylint: disable=arguments-differ
  def CheckStart(self, client_id, client_email, now=None):
    """Count this API request and raise if the client is over its limit.

    Args:
      client_id: API client id.
      client_email: requester email, also used to look up a custom QPM.
      now: optional float seconds since the epoch; defaults to time.time().

    Raises:
      ApiRateLimitExceeded: when over the limit and API rate limiting is
          enabled, or when the memcache read fails and fail_open is False.
    """
    if now is None:
      now = time.time()

    keysets = _CreateApiCacheKeys(client_id, client_email, now)
    configured_qpm = client_config_svc.GetQPMDict().get(
        client_email, DEFAULT_API_QPM)
    # A configured QPM may raise a client's limit but never lower it.
    per_minute = max(configured_qpm, DEFAULT_API_QPM)
    self._AuxCheckStart(
        keysets, per_minute * N_MINUTES,
        settings.api_ratelimiting_enabled,
        ApiRateLimitExceeded(client_id, client_email))

  #pylint: disable=arguments-differ
  def CheckEnd(self, client_id, client_email, now, start_time):
    """Charge extra counts to the client if its request was slow.

    Args:
      client_id: API client id.
      client_email: requester email.
      now: float seconds since the epoch when processing finished.
      start_time: float seconds since the epoch when processing started;
          the buckets charged are those for this time.
    """
    elapsed_ms = int((now - start_time) * 1000)
    penalty = elapsed_ms // settings.ratelimiting_ms_per_count - 1
    if penalty < 1:
      return

    self._AuxCheckEnd(
        _CreateApiCacheKeys(client_id, client_email, start_time),
        'API Rate Limit Cost Threshold Exceeded: %s, %s' % (
            client_id, client_email),
        penalty)
273
274
class RateLimitExceeded(Exception):
  """Raised when a page request exceeds its rate limit.

  Attributes:
    country: country code of the request, or None.
    ip: remote address of the request, or None.
    user_email: email of the requester, or None.
  """

  def __init__(self, country=None, ip=None, user_email=None, **_kwargs):
    self.country, self.ip, self.user_email = country, ip, user_email
    super(RateLimitExceeded, self).__init__(
        'RateLimitExceeded: %s, %s, %s' % (country, ip, user_email))
283
284
class ApiRateLimitExceeded(Exception):
  """Raised when an API client exceeds its rate limit.

  The message uses the same 'RateLimitExceeded' prefix as page requests.

  Attributes:
    client_id: API client id of the request.
    client_email: requester email of the request.
  """

  def __init__(self, client_id, client_email):
    self.client_id, self.client_email = client_id, client_email
    super(ApiRateLimitExceeded, self).__init__(
        'RateLimitExceeded: %s, %s' % (client_id, client_email))
291 super(ApiRateLimitExceeded, self).__init__(message)