blob: bfc8de76be99c4e6826be648215f272f70645c81 [file] [log] [blame]
Adrià Vilanova Martínezf19ea432024-01-23 20:20:52 +01001# Copyright 2016 The Chromium Authors
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
Copybara854996b2021-09-07 19:36:02 +00004
5"""Unit tests for RateLimiter.
6"""
7
8from __future__ import division
9from __future__ import print_function
10from __future__ import absolute_import
11
12import unittest
13
14from google.appengine.api import memcache
15from google.appengine.ext import testbed
16
Adrià Vilanova Martínez9f9ade52022-10-10 23:20:11 +020017try:
18 from mox3 import mox
19except ImportError:
20 import mox
Copybara854996b2021-09-07 19:36:02 +000021import os
22import settings
23
24from framework import ratelimiter
25from services import service_manager
26from services import client_config_svc
27from testing import fake
28from testing import testing_helpers
29
30
class RateLimiterTest(unittest.TestCase):
  """Tests for ratelimiter.RateLimiter."""

  def setUp(self):
    self.testbed = testbed.Testbed()
    self.testbed.activate()
    self.testbed.init_memcache_stub()
    self.testbed.init_user_stub()

    self.mox = mox.Mox()
    self.services = service_manager.Services(
        config=fake.ConfigService(),
        issue=fake.IssueService(),
        user=fake.UserService(),
        project=fake.ProjectService(),
    )
    self.project = self.services.project.TestAddProject('proj', project_id=987)

    self.ratelimiter = ratelimiter.RateLimiter()
    # The limits and the enable flag are module-level globals. Override them
    # through _SetForTest so their previous values are restored at cleanup;
    # otherwise any change made here or inside a test (e.g. disabling rate
    # limiting) would leak into tests that run later in the same process.
    self._SetForTest(ratelimiter, 'COUNTRY_LIMITS', {})
    os.environ['USER_EMAIL'] = ''
    self._SetForTest(settings, 'ratelimiting_enabled', True)
    self._SetForTest(ratelimiter, 'DEFAULT_LIMIT', 10)

  def tearDown(self):
    self.testbed.deactivate()
    self.mox.UnsetStubs()
    self.mox.ResetAll()

  def _SetForTest(self, obj, name, value):
    """Set `obj.name` to `value` and restore the prior state at cleanup.

    Cleanups run even when the test (or setUp) fails, so the override can
    never outlive the current test.

    Args:
      obj: module or object whose attribute is overridden.
      name: attribute name.
      value: value to use for the duration of the current test.
    """
    if hasattr(obj, name):
      self.addCleanup(setattr, obj, name, getattr(obj, name))
    else:
      self.addCleanup(delattr, obj, name)
    setattr(obj, name, value)

  def _PrefillCounters(self, keysets, count):
    """Store `count` in memcache under every key of every keyset."""
    for keyset in keysets:
      memcache.add_multi({key: count for key in keyset})

  def testCheckStart_pass(self):
    request, _ = testing_helpers.GetRequestObjects(
        project=self.project)
    request.headers['X-AppEngine-Country'] = 'US'
    request.remote_addr = '192.168.1.0'
    self.ratelimiter.CheckStart(request)
    # Should not throw an exception.

  def testCheckStart_fail(self):
    request, _ = testing_helpers.GetRequestObjects(
        project=self.project)
    request.headers['X-AppEngine-Country'] = 'US'
    request.remote_addr = '192.168.1.0'
    now = 0.0
    cachekeysets, _, _, _ = ratelimiter._CacheKeys(request, now)
    # Put every counter for this request at the limit before checking.
    self._PrefillCounters(cachekeysets, ratelimiter.DEFAULT_LIMIT)
    with self.assertRaises(ratelimiter.RateLimitExceeded):
      self.ratelimiter.CheckStart(request, now)

  def testCheckStart_expiredEntries(self):
    request, _ = testing_helpers.GetRequestObjects(
        project=self.project)
    request.headers['X-AppEngine-Country'] = 'US'
    request.remote_addr = '192.168.1.0'
    now = 0.0
    cachekeysets, _, _, _ = ratelimiter._CacheKeys(request, now)
    self._PrefillCounters(cachekeysets, ratelimiter.DEFAULT_LIMIT)

    # Jump well past EXPIRE_AFTER_SECS; the counters filled at time 0 should
    # no longer cause a rejection.
    now = now + 2 * ratelimiter.EXPIRE_AFTER_SECS
    self.ratelimiter.CheckStart(request, now)
    # Should not throw an exception.

  def testCheckStart_repeatedCalls(self):
    request, _ = testing_helpers.GetRequestObjects(
        project=self.project)
    request.headers['X-AppEngine-Country'] = 'US'
    request.remote_addr = '192.168.1.0'
    now = 0.0

    # Call CheckStart once every two minutes (120s). Should be ok.
    for _ in range(ratelimiter.N_MINUTES):
      self.ratelimiter.CheckStart(request, now)
      now = now + 120.0

    # Call CheckStart more than DEFAULT_LIMIT times in the same minute.
    with self.assertRaises(ratelimiter.RateLimitExceeded):
      for _ in range(ratelimiter.DEFAULT_LIMIT + 2):  # pragma: no branch
        now = now + 0.001
        self.ratelimiter.CheckStart(request, now)

  def testCheckStart_differentIPs(self):
    now = 0.0

    # Exceed DEFAULT_LIMIT calls in total, but vary remote_addr so that
    # different remote addresses aren't ratelimited together.
    for m in range(ratelimiter.DEFAULT_LIMIT * 2):
      request, _ = testing_helpers.GetRequestObjects(
          project=self.project)
      request.headers['X-AppEngine-Country'] = 'US'
      request.remote_addr = '192.168.1.%d' % (m % 16)
      ratelimiter._CacheKeys(request, now)
      self.ratelimiter.CheckStart(request, now)
      now = now + 0.001

    # Exceed the limit, but only for one IP address. The
    # others should be fine.
    with self.assertRaises(ratelimiter.RateLimitExceeded):
      for m in range(ratelimiter.DEFAULT_LIMIT):  # pragma: no branch
        request, _ = testing_helpers.GetRequestObjects(
            project=self.project)
        request.headers['X-AppEngine-Country'] = 'US'
        request.remote_addr = '192.168.1.0'
        ratelimiter._CacheKeys(request, now)
        self.ratelimiter.CheckStart(request, now)
        now = now + 0.001

    # Now make requests from all of the other IP addresses besides .0.
    for m in range(ratelimiter.DEFAULT_LIMIT * 2):
      request, _ = testing_helpers.GetRequestObjects(
          project=self.project)
      request.headers['X-AppEngine-Country'] = 'US'
      # Skip .0 since it has already exceeded the limit.
      request.remote_addr = '192.168.1.%d' % (m + 1)
      ratelimiter._CacheKeys(request, now)
      self.ratelimiter.CheckStart(request, now)
      now = now + 0.001

  def testCheckStart_sameIPDifferentUserIDs(self):
    # E.g. several users behind a NAT.
    now = 0.0

    # Exceed DEFAULT_LIMIT calls in total, but vary the user so that
    # different users behind the same IP aren't ratelimited together.
    for m in range(ratelimiter.DEFAULT_LIMIT * 2):
      request, _ = testing_helpers.GetRequestObjects(
          project=self.project)
      request.remote_addr = '192.168.1.0'
      os.environ['USER_EMAIL'] = '%s@example.com' % m
      request.headers['X-AppEngine-Country'] = 'US'
      ratelimiter._CacheKeys(request, now)
      self.ratelimiter.CheckStart(request, now)
      now = now + 0.001

    # Exceed the limit, but only for one user+IP combination. The
    # others should be fine.
    with self.assertRaises(ratelimiter.RateLimitExceeded):
      for m in range(ratelimiter.DEFAULT_LIMIT + 2):  # pragma: no branch
        request, _ = testing_helpers.GetRequestObjects(
            project=self.project)
        request.headers['X-AppEngine-Country'] = 'US'
        request.remote_addr = '192.168.1.0'
        os.environ['USER_EMAIL'] = '42@example.com'
        ratelimiter._CacheKeys(request, now)
        self.ratelimiter.CheckStart(request, now)
        now = now + 0.001

    # Now make requests as users other than 42, from the same IP.
    for m in range(ratelimiter.DEFAULT_LIMIT * 2):
      request, _ = testing_helpers.GetRequestObjects(
          project=self.project)
      request.headers['X-AppEngine-Country'] = 'US'
      request.remote_addr = '192.168.1.0'
      # Skip user 42 since it has already exceeded the limit.
      os.environ['USER_EMAIL'] = '%s@example.com' % (43 + m)
      ratelimiter._CacheKeys(request, now)
      self.ratelimiter.CheckStart(request, now)
      now = now + 0.001

  def testCheckStart_ratelimitingDisabled(self):
    # Safe to assign directly: setUp registered a cleanup that restores the
    # original value of this flag after the test.
    settings.ratelimiting_enabled = False
    request, _ = testing_helpers.GetRequestObjects(
        project=self.project)
    request.headers['X-AppEngine-Country'] = 'US'
    request.remote_addr = '192.168.1.0'
    now = 0.0

    # Call CheckStart a lot. Should be ok.
    for _ in range(ratelimiter.DEFAULT_LIMIT):
      self.ratelimiter.CheckStart(request, now)
      now = now + 0.001

  def testCheckStart_perCountryLoggedOutLimit(self):
    # Mutates the fresh dict installed by setUp, which is swapped back for the
    # original at cleanup.
    ratelimiter.COUNTRY_LIMITS['US'] = 10

    request, _ = testing_helpers.GetRequestObjects(
        project=self.project)
    request.headers[ratelimiter.COUNTRY_HEADER] = 'US'
    request.remote_addr = '192.168.1.1'
    now = 0.0

    with self.assertRaises(ratelimiter.RateLimitExceeded):
      for m in range(ratelimiter.DEFAULT_LIMIT + 2):  # pragma: no branch
        self.ratelimiter.CheckStart(request, now)
        # Vary remote address to make sure the limit covers
        # the whole country, regardless of IP.
        request.remote_addr = '192.168.1.%d' % m
        now = now + 0.001

    # CheckStart for a country that isn't covered by a country-specific limit.
    request.headers['X-AppEngine-Country'] = 'UK'
    for m in range(11):
      self.ratelimiter.CheckStart(request, now)
      # Vary remote address so these requests are spread across IPs rather
      # than counted against a single one.
      request.remote_addr = '192.168.1.%d' % m
      now = now + 0.001

    # And regular rate limits work per-IP.
    request.remote_addr = '192.168.1.1'
    with self.assertRaises(ratelimiter.RateLimitExceeded):
      for m in range(ratelimiter.DEFAULT_LIMIT):  # pragma: no branch
        # Same IP every time, so the per-IP limit is eventually exceeded.
        self.ratelimiter.CheckStart(request, now)
        now = now + 0.001

  def testCheckEnd_SlowRequest(self):
    """A request that takes long enough is charged extra cost in CheckEnd."""
    request, _ = testing_helpers.GetRequestObjects(
        project=self.project)
    request.headers[ratelimiter.COUNTRY_HEADER] = 'US'
    request.remote_addr = '192.168.1.1'
    start_time = 0.0

    # Send some requests, all under the limit.
    for _ in range(ratelimiter.DEFAULT_LIMIT-1):
      start_time = start_time + 0.001
      self.ratelimiter.CheckStart(request, start_time)
      now = start_time + 0.010
      self.ratelimiter.CheckEnd(request, now, start_time)

    # Now issue another request, this time taking long
    # enough to get the cost threshold penalty.
    # Fast forward enough to impact a later bucket than the
    # previous requests.
    start_time = now + 120.0
    self.ratelimiter.CheckStart(request, start_time)

    # Take longer than the threshold to process the request.
    elapsed_ms = settings.ratelimiting_ms_per_count * 2
    now = start_time + elapsed_ms / 1000

    # The request finished, taking long enough to count as two.
    self.ratelimiter.CheckEnd(request, now, start_time)

    with self.assertRaises(ratelimiter.RateLimitExceeded):
      # One more request after the expensive query should
      # throw an exception.
      self.ratelimiter.CheckStart(request, start_time)

  def testCheckEnd_FastRequest(self):
    """Fast requests are not penalized: DEFAULT_LIMIT of them all pass."""
    request, _ = testing_helpers.GetRequestObjects(
        project=self.project)
    request.headers[ratelimiter.COUNTRY_HEADER] = 'asdasd'
    request.remote_addr = '192.168.1.1'
    start_time = 0.0

    # Send some requests, all under the limit. Should not throw.
    for _ in range(ratelimiter.DEFAULT_LIMIT):
      self.ratelimiter.CheckStart(request, start_time)
      now = start_time + 0.01
      self.ratelimiter.CheckEnd(request, now, start_time)
      start_time = now + 0.01
291
292
class ApiRateLimiterTest(unittest.TestCase):
  """Tests for ratelimiter.ApiRateLimiter."""

  def setUp(self):
    # Module-level flags are overridden via _SetForTest so their previous
    # values are restored at cleanup and do not leak into other tests.
    self._SetForTest(settings, 'ratelimiting_enabled', True)
    self.testbed = testbed.Testbed()
    self.testbed.activate()
    self.testbed.init_memcache_stub()

    self.services = service_manager.Services(
        config=fake.ConfigService(),
        issue=fake.IssueService(),
        user=fake.UserService(),
        project=fake.ProjectService(),
    )

    self.client_id = '123456789'
    self.client_email = 'test@example.com'

    self.ratelimiter = ratelimiter.ApiRateLimiter()
    self._SetForTest(settings, 'api_ratelimiting_enabled', True)

  def tearDown(self):
    self.testbed.deactivate()

  def _SetForTest(self, obj, name, value):
    """Set `obj.name` to `value` and restore the prior state at cleanup.

    Args:
      obj: module or object whose attribute is overridden.
      name: attribute name.
      value: value to use for the duration of the current test.
    """
    if hasattr(obj, name):
      self.addCleanup(setattr, obj, name, getattr(obj, name))
    else:
      self.addCleanup(delattr, obj, name)
    setattr(obj, name, value)

  def _SetClientQPM(self, qpm):
    """Give self.client_email a custom QPM for the current test only.

    The dict returned by client_config_svc.GetQPMDict() is mutated in place,
    so the entry is removed again via a cleanup. Unlike a trailing `del` at
    the end of the test, a cleanup also runs when the test fails.
    """
    qpm_dict = client_config_svc.GetQPMDict()
    qpm_dict[self.client_email] = qpm
    self.addCleanup(qpm_dict.pop, self.client_email, None)

  def _PrefillCounters(self, keysets, count):
    """Store `count` in memcache under every key of every keyset."""
    for keyset in keysets:
      memcache.add_multi({key: count for key in keyset})

  def testCheckStart_Allowed(self):
    now = 0.0
    self.ratelimiter.CheckStart(self.client_id, self.client_email, now)
    self.ratelimiter.CheckStart(self.client_id, None, now)
    self.ratelimiter.CheckStart(None, None, now)
    self.ratelimiter.CheckStart('anonymous', None, now)

  def testCheckStart_Rejected(self):
    now = 0.0
    keysets = ratelimiter._CreateApiCacheKeys(
        self.client_id, self.client_email, now)
    self._PrefillCounters(keysets, ratelimiter.DEFAULT_API_QPM + 1)
    with self.assertRaises(ratelimiter.ApiRateLimitExceeded):
      self.ratelimiter.CheckStart(self.client_id, self.client_email, now)

  def testCheckStart_Allowed_HigherQPMSpecified(self):
    """Client goes over the default, but has a higher QPM set."""
    now = 0.0
    keysets = ratelimiter._CreateApiCacheKeys(
        self.client_id, self.client_email, now)
    self._SetClientQPM(ratelimiter.DEFAULT_API_QPM + 10)
    # The client used 1 request more than the default limit in each minute of
    # the sample window. That is rejected under the default QPM (see
    # testCheckStart_Rejected) but is within this client's higher QPM.
    self._PrefillCounters(keysets, ratelimiter.DEFAULT_API_QPM + 1)
    self.ratelimiter.CheckStart(self.client_id, self.client_email, now)

  def testCheckStart_Allowed_LowQPMIgnored(self):
    """A QPM lower than the default is ignored: usage at the default passes."""
    now = 0.0
    keysets = ratelimiter._CreateApiCacheKeys(
        self.client_id, self.client_email, now)
    self._SetClientQPM(ratelimiter.DEFAULT_API_QPM - 10)
    self._PrefillCounters(keysets, ratelimiter.DEFAULT_API_QPM)
    self.ratelimiter.CheckStart(self.client_id, self.client_email, now)

  def testCheckStart_Rejected_LowQPMIgnored(self):
    """A QPM lower than the default is ignored: usage over it is rejected."""
    now = 0.0
    keysets = ratelimiter._CreateApiCacheKeys(
        self.client_id, self.client_email, now)
    self._SetClientQPM(ratelimiter.DEFAULT_API_QPM - 10)
    self._PrefillCounters(keysets, ratelimiter.DEFAULT_API_QPM + 1)
    with self.assertRaises(ratelimiter.ApiRateLimitExceeded):
      self.ratelimiter.CheckStart(self.client_id, self.client_email, now)

  def testCheckEnd(self):
    start_time = 0.0
    keysets = ratelimiter._CreateApiCacheKeys(
        self.client_id, self.client_email, start_time)

    # A request that finishes 0.1s after it started is not charged extra.
    now = 0.1
    self.ratelimiter.CheckEnd(
        self.client_id, self.client_email, now, start_time)
    counters = memcache.get_multi(keysets[0])
    count = sum(counters.values())
    # No extra cost charged
    self.assertEqual(0, count)

    # Taking twice settings.ratelimiting_ms_per_count is charged extra.
    elapsed_ms = settings.ratelimiting_ms_per_count * 2
    now = start_time + elapsed_ms / 1000
    self.ratelimiter.CheckEnd(
        self.client_id, self.client_email, now, start_time)
    counters = memcache.get_multi(keysets[0])
    count = sum(counters.values())
    # Extra cost charged
    self.assertEqual(1, count)