blob: b2bbb253c62204bba9dc3b2cb6853c43eea1eeed [file] [log] [blame]
Copybara854996b2021-09-07 19:36:02 +00001# Copyright 2016 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style
3# license that can be found in the LICENSE file or at
4# https://developers.google.com/open-source/licenses/bsd
5
6"""Request rate limiting implementation.
7
8This is intended to be used for automatic DDoS protection.
9
10"""
11from __future__ import print_function
12from __future__ import division
13from __future__ import absolute_import
14
15import datetime
16import logging
17import os
18import settings
19import time
20
21from infra_libs import ts_mon
22
23from google.appengine.api import memcache
24from google.appengine.api.modules import modules
25from google.appengine.api import users
26
27from services import client_config_svc
28
29
# Width, in minutes, of the sliding window over which requests are summed.
N_MINUTES = 5
# Lifetime, in seconds, given to each per-minute memcache counter.
EXPIRE_AFTER_SECS = 60 * 60
# Page requests allowed per window: 300 in 5 minutes averages 1 QPS.
DEFAULT_LIMIT = 60 * N_MINUTES
# API requests allowed per minute unless a client is configured higher.
# For example, chromiumdash uses ~64 per page, 8s each.
DEFAULT_API_QPM = 1000

# Stand-in email reported for requests that have no signed-in user.
ANON_USER = 'anon'

# Request header from which the originating country code is read.
COUNTRY_HEADER = 'X-AppEngine-Country'

# Maps a two-letter country code to the maximum number of requests allowed
# per N_MINUTES window for ALL requests coming from that country.
# To add a country, look at the GAE logs and pick the matching code from
# https://en.wikipedia.org/wiki/ISO_3166-1_alpha-2
# The lookup is an exact (case-sensitive) match against the header value, so
# spell the key exactly as it appears in the logs.
# E.g., 'cn': 300,  # Limit to 1 QPS.
COUNTRY_LIMITS = {
}

# Only requests served by these modules are rate limited by the page
# RateLimiter; any other module is let through without counting.
MODULE_ALLOWLIST = ['default', 'api']
51
52
def _CacheKeys(request, now_sec):
  """Build the memcache counter keys that apply to a page request.

  Args:
    request: the incoming request object; its headers (country header) and
        remote_addr are read.
    now_sec: seconds since the epoch; determines the most recent bucket.

  Returns:
    A 4-tuple (keysets, country, ip, user_email).

    keysets is a list of lists. Each inner list holds N_MINUTES keys that
    share one prefix and end in a minute timestamp, starting with the most
    recent minute and going back one minute per entry. There is one inner
    list for the signed-in user (or for the IP address when nobody is
    signed in), plus one more for the country when it is listed in
    COUNTRY_LIMITS.

    country falls back to 'ZZ' when the header is absent, and user_email
    falls back to ANON_USER when nobody is signed in.
  """
  now = datetime.datetime.fromtimestamp(now_sec)
  country = request.headers.get(COUNTRY_HEADER, 'ZZ')
  ip = request.remote_addr
  user = users.get_current_user()
  user_email = user.email() if user else ANON_USER

  # The minute suffixes do not depend on the prefix, so render them once
  # instead of once per prefix.  Newest minute first.
  minute_strs = [
      str((now - datetime.timedelta(minutes=m)).replace(
          second=0, microsecond=0))
      for m in range(N_MINUTES)]

  # <IP, country, user_email> to be rendered into each key prefix.
  prefixes = []

  # All logged-in users get a per-user rate limit, regardless of IP and
  # country.
  if user:
    prefixes.append(['ALL', 'ALL', user_email])
  else:
    # All anon requests get a per-IP ratelimit.
    prefixes.append([ip, 'ALL', 'ALL'])

  # All requests from a problematic country get a per-country rate limit,
  # regardless of the user (even a non-logged-in one) or IP.
  if country in COUNTRY_LIMITS:
    prefixes.append(['ALL', country, 'ALL'])

  keysets = []
  for prefix in prefixes:
    prefix_str = '-'.join(prefix)
    keysets.append(['ratelimit-%s-%s' % (prefix_str, minute_str)
                    for minute_str in minute_strs])

  return keysets, country, ip, user_email
88
89
def _CreateApiCacheKeys(client_id, client_email, now_sec):
  """Build the memcache counter keys that apply to an API request.

  Args:
    client_id: API client ID; ignored when empty or 'anonymous'.
    client_email: email of the requester; when empty the caller's IP
        address (REMOTE_ADDR from the environment) is used instead.
    now_sec: seconds since the epoch; determines the most recent bucket.

  Returns:
    A list of lists of keys. Each inner list has N_MINUTES keys for one
    rate-limited subject (client ID, email or IP, country), newest minute
    first. The country list is only present when the request's country is
    listed in COUNTRY_LIMITS.
  """
  request_country = os.environ.get('HTTP_X_APPENGINE_COUNTRY')
  remote_ip = os.environ.get('REMOTE_ADDR')

  # Render the minute suffixes, newest first.
  latest = datetime.datetime.fromtimestamp(now_sec)
  stamps = []
  for minutes_back in range(N_MINUTES):
    bucket = latest - datetime.timedelta(minutes=minutes_back)
    stamps.append(str(bucket.replace(second=0, microsecond=0)))

  # Collect every subject this request is counted against, in order.
  subjects = []
  if client_id and client_id != 'anonymous':
    subjects.append(client_id)
  subjects.append(client_email if client_email else remote_ip)
  if request_country in COUNTRY_LIMITS:
    subjects.append(request_country)

  return [['apiratelimit-%s-%s' % (subject, stamp) for stamp in stamps]
          for subject in subjects]
114
115
class RateLimiter(object):
  """Rate limiter for page requests, backed by per-minute memcache counters.

  CheckStart() is called before a request is handled: it sums the counters
  of the last N_MINUTES for every keyset that applies to the request, raises
  when a sum is over the limit, and otherwise charges one count.
  CheckEnd() is called afterwards and charges extra counts for requests that
  were expensive to process.
  """

  blocked_requests = ts_mon.CounterMetric(
      'monorail/ratelimiter/blocked_request',
      'Count of requests that exceeded the rate limit and were blocked.',
      None)
  limit_exceeded = ts_mon.CounterMetric(
      'monorail/ratelimiter/rate_exceeded',
      'Count of requests that exceeded the rate limit.',
      None)
  cost_thresh_exceeded = ts_mon.CounterMetric(
      'monorail/ratelimiter/cost_thresh_exceeded',
      'Count of requests that were expensive to process',
      None)
  checks = ts_mon.CounterMetric(
      'monorail/ratelimiter/check',
      'Count of checks done, by fail/success type.',
      [ts_mon.StringField('type')])

  def __init__(self, _cache=memcache, fail_open=True, **_kwargs):
    """Initialize the limiter.

    Args:
      _cache: accepted for interface compatibility but not used here; the
          methods of this class call the module-level memcache API directly.
      fail_open: when True, a failure to read the counters lets the request
          through; when False, it raises the rate-limit exception.
      **_kwargs: ignored.
    """
    self.fail_open = fail_open

  def CheckStart(self, request, now=None):
    """Check the request against its limits and charge one count.

    Requests served by a module outside MODULE_ALLOWLIST, and requests made
    by an admin user, are neither checked nor counted.

    Args:
      request: the incoming request object.
      now: seconds since the epoch; defaults to the current time.

    Raises:
      RateLimitExceeded: if a limit is exceeded while
          settings.ratelimiting_enabled is true, or if the counters cannot
          be read and this limiter was created with fail_open=False.
    """
    if (modules.get_current_module_name() not in MODULE_ALLOWLIST or
        users.is_current_user_admin()):
      return
    logging.info('X-AppEngine-Country: %s',
                 request.headers.get(COUNTRY_HEADER, 'ZZ'))

    if now is None:
      now = time.time()

    keysets, country, ip, user_email = _CacheKeys(request, now)
    # There are either one or two sets of keys in keysets: one for the
    # user (or the IP address for anonymous requests), plus one more if the
    # user's country is in COUNTRY_LIMITS.
    self._AuxCheckStart(
        keysets, COUNTRY_LIMITS.get(country, DEFAULT_LIMIT),
        settings.ratelimiting_enabled,
        RateLimitExceeded(country=country, ip=ip, user_email=user_email))

  def _AuxCheckStart(self, keysets, limit, ratelimiting_enabled,
                     exception_obj):
    """Enforce `limit` on every keyset, then charge one count to each.

    Args:
      keysets: list of lists of memcache keys, newest minute first.
      limit: maximum allowed sum of the counters of one keyset.
      ratelimiting_enabled: when false, over-limit requests are only logged
          and counted in metrics, never blocked.
      exception_obj: the exception instance to raise when blocking.

    Raises:
      exception_obj: see CheckStart().  Keysets that were processed before
          the offending one have already been charged by then.
    """
    for keys in keysets:
      count = 0
      try:
        counters = memcache.get_multi(keys)
        count = sum(counters.values())
        self.checks.increment({'type': 'success'})
      except Exception as e:
        # Only the counter read is guarded; errors from the add/incr calls
        # below are not caught here.
        logging.error(e)
        if not self.fail_open:
          self.checks.increment({'type': 'fail_closed'})
          raise exception_obj
        # Failing open: carry on with count == 0, i.e. under the limit.
        self.checks.increment({'type': 'fail_open'})

      if count > limit:
        # Since webapp2 won't let us return a 429 error code
        # <http://tools.ietf.org/html/rfc6585#section-4>, we can't
        # monitor rate limit exceeded events with our standard tools.
        # We return a 400 with a custom error message to the client,
        # and this logging is so we can monitor it internally.
        # str() is used rather than the `.message` attribute, which is
        # deprecated on Python 2.6+ and does not exist on Python 3.
        logging.info('%s, %d', str(exception_obj), count)

        self.limit_exceeded.increment()

        if ratelimiting_enabled:
          self.blocked_requests.increment()
          raise exception_obj

      k = keys[0]
      # Only update the latest *time* bucket for each prefix (reverse chron).
      # add() creates the counter with an expiry if it does not exist yet.
      memcache.add(k, 0, time=EXPIRE_AFTER_SECS)
      memcache.incr(k, initial_value=0)

  def CheckEnd(self, request, now, start_time):
    """If a request was expensive to process, charge some extra points
    against this set of buckets.

    We pass in both now and start_time so we can update the buckets
    based on keys created from start_time instead of now.

    Args:
      request: the request object that was passed to CheckStart().
      now: float seconds since the epoch at which processing finished.
      start_time: float seconds since the epoch at which processing began.
    """
    if modules.get_current_module_name() not in MODULE_ALLOWLIST:
      return

    elapsed_ms = int((now - start_time) * 1000)
    # One extra count per settings.ratelimiting_ms_per_count of elapsed
    # time, minus one (presumably for the count CheckStart already charged).
    penalty = elapsed_ms // settings.ratelimiting_ms_per_count - 1
    if penalty >= 1:
      # TODO: Look into caching the keys instead of generating them twice
      # for every request. Say, return them from CheckStart so they can
      # be passed back in here later.
      keysets, country, ip, user_email = _CacheKeys(request, start_time)

      self._AuxCheckEnd(
          keysets,
          'Rate Limit Cost Threshold Exceeded: %s, %s, %s' % (
              country, ip, user_email),
          penalty)

  def _AuxCheckEnd(self, keysets, log_str, penalty):
    """Charge `penalty` extra counts to the newest bucket of every keyset.

    Args:
      keysets: list of lists of memcache keys, newest minute first.
      log_str: message to log for this expensive request.
      penalty: number of counts to add to each keyset's newest bucket.
    """
    self.cost_thresh_exceeded.increment()
    # The message is the same for every keyset, so log it once per request
    # rather than once per keyset.
    logging.info(log_str)
    for keys in keysets:
      # Only update the latest *time* bucket for each prefix (reverse chron).
      k = keys[0]
      memcache.add(k, 0, time=EXPIRE_AFTER_SECS)
      memcache.incr(k, delta=penalty, initial_value=0)
224
225
class ApiRateLimiter(RateLimiter):
  """Rate limiter for API requests, keyed by client ID, email/IP and country.

  Reuses the counting logic of RateLimiter but reports to its own metrics
  and takes the per-client limit from the client configuration.
  """

  blocked_requests = ts_mon.CounterMetric(
      'monorail/apiratelimiter/blocked_request',
      'Count of requests that exceeded the rate limit and were blocked.',
      None)
  limit_exceeded = ts_mon.CounterMetric(
      'monorail/apiratelimiter/rate_exceeded',
      'Count of requests that exceeded the rate limit.',
      None)
  cost_thresh_exceeded = ts_mon.CounterMetric(
      'monorail/apiratelimiter/cost_thresh_exceeded',
      'Count of requests that were expensive to process',
      None)
  checks = ts_mon.CounterMetric(
      'monorail/apiratelimiter/check',
      'Count of checks done, by fail/success type.',
      [ts_mon.StringField('type')])

  #pylint: disable=arguments-differ
  def CheckStart(self, client_id, client_email, now=None):
    """Check an API request against its limit and charge one count.

    Args:
      client_id: API client ID of the caller.
      client_email: email of the caller; also the key used to look up the
          configured queries-per-minute limit.
      now: seconds since the epoch; defaults to the current time.

    Raises:
      ApiRateLimitExceeded: raised by the shared check when the limit is
          exceeded while settings.api_ratelimiting_enabled is true, or on a
          counter read failure when this limiter does not fail open.
    """
    when = time.time() if now is None else now
    api_keysets = _CreateApiCacheKeys(client_id, client_email, when)

    # A configured limit may raise the allowance above DEFAULT_API_QPM but
    # never lower it below that floor.
    configured_qpm = client_config_svc.GetQPMDict().get(
        client_email, DEFAULT_API_QPM)
    effective_qpm = max(configured_qpm, DEFAULT_API_QPM)

    # Counters are summed over N_MINUTES, so scale the per-minute limit.
    self._AuxCheckStart(
        api_keysets, effective_qpm * N_MINUTES,
        settings.api_ratelimiting_enabled,
        ApiRateLimitExceeded(client_id, client_email))

  #pylint: disable=arguments-differ
  def CheckEnd(self, client_id, client_email, now, start_time):
    """Charge extra counts when an API request was expensive to process.

    Args:
      client_id: API client ID of the caller.
      client_email: email of the caller.
      now: float seconds since the epoch at which processing finished.
      start_time: float seconds since the epoch at which processing began;
          the charged buckets are derived from this time.
    """
    duration_ms = int((now - start_time) * 1000)
    extra_counts = duration_ms // settings.ratelimiting_ms_per_count - 1
    if extra_counts < 1:
      return

    self._AuxCheckEnd(
        _CreateApiCacheKeys(client_id, client_email, start_time),
        'API Rate Limit Cost Threshold Exceeded: %s, %s' % (
            client_id, client_email),
        extra_counts)
274
275
class RateLimitExceeded(Exception):
  """Raised when a page request goes over its rate limit.

  The country, ip and user_email of the offending request are kept as
  attributes and are also rendered into the exception message.
  """

  def __init__(self, country=None, ip=None, user_email=None, **_kwargs):
    # Extra keyword arguments are accepted and ignored.
    super(RateLimitExceeded, self).__init__(
        'RateLimitExceeded: %s, %s, %s' % (country, ip, user_email))
    self.country = country
    self.ip = ip
    self.user_email = user_email
284
285
class ApiRateLimitExceeded(Exception):
  """Raised when an API request goes over its rate limit.

  The client_id and client_email of the offending caller are kept as
  attributes and are also rendered into the exception message.
  """

  def __init__(self, client_id, client_email):
    super(ApiRateLimitExceeded, self).__init__(
        'RateLimitExceeded: %s, %s' % (client_id, client_email))
    self.client_id = client_id
    self.client_email = client_email
292 super(ApiRateLimitExceeded, self).__init__(message)