blob: 84230e83749b4a10dc43263eb484c18d785e174d [file] [log] [blame]
Copybara854996b2021-09-07 19:36:02 +00001# Copyright 2016 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style
3# license that can be found in the LICENSE file or at
4# https://developers.google.com/open-source/licenses/bsd
5
6"""Unit tests for RateLimiter.
7"""
8
9from __future__ import division
10from __future__ import print_function
11from __future__ import absolute_import
12
13import unittest
14
15from google.appengine.api import memcache
16from google.appengine.ext import testbed
17
Adrià Vilanova Martínez9f9ade52022-10-10 23:20:11 +020018try:
19 from mox3 import mox
20except ImportError:
21 import mox
Copybara854996b2021-09-07 19:36:02 +000022import os
23import settings
24
25from framework import ratelimiter
26from services import service_manager
27from services import client_config_svc
28from testing import fake
29from testing import testing_helpers
30
31
class RateLimiterTest(unittest.TestCase):
  """Tests for ratelimiter.RateLimiter.CheckStart and CheckEnd."""

  def setUp(self):
    self.testbed = testbed.Testbed()
    self.testbed.activate()
    self.testbed.init_memcache_stub()
    self.testbed.init_user_stub()

    self.mox = mox.Mox()
    self.services = service_manager.Services(
        config=fake.ConfigService(),
        issue=fake.IssueService(),
        user=fake.UserService(),
        project=fake.ProjectService(),
        )
    self.project = self.services.project.TestAddProject('proj', project_id=987)

    self.ratelimiter = ratelimiter.RateLimiter()
    # Module-level configuration is overridden through _Override so that the
    # previous values are restored after every test instead of leaking into
    # other test cases that run in the same process.
    self._Override(ratelimiter, 'COUNTRY_LIMITS', {})
    os.environ['USER_EMAIL'] = ''
    self._Override(settings, 'ratelimiting_enabled', True)
    self._Override(ratelimiter, 'DEFAULT_LIMIT', 10)

  def tearDown(self):
    self.testbed.deactivate()
    self.mox.UnsetStubs()
    self.mox.ResetAll()

  def _Override(self, owner, name, value):
    """Set owner.<name> to value until the current test finishes.

    A cleanup (run by unittest after tearDown) puts back the value that was
    present before, or deletes the attribute if there was none. Because the
    saved value is restored unconditionally, direct assignments made later
    inside a test (e.g. settings.ratelimiting_enabled = False) are undone too.

    Args:
      owner: module or object whose attribute is overridden.
      name: name of the attribute to override.
      value: value to install for the duration of the test.
    """
    missing = object()
    original = getattr(owner, name, missing)
    if original is missing:
      self.addCleanup(delattr, owner, name)
    else:
      self.addCleanup(setattr, owner, name, original)
    setattr(owner, name, value)

  def _FillCounters(self, cachekeysets, count):
    """Store count under every memcache key of every set in cachekeysets."""
    for cachekeys in cachekeysets:
      memcache.add_multi({key: count for key in cachekeys})

  def testCheckStart_pass(self):
    """A single request is allowed."""
    request, _ = testing_helpers.GetRequestObjects(
        project=self.project)
    request.headers['X-AppEngine-Country'] = 'US'
    request.remote_addr = '192.168.1.0'
    self.ratelimiter.CheckStart(request)
    # Should not throw an exception.

  def testCheckStart_fail(self):
    """A request is rejected once all of its counters are at DEFAULT_LIMIT."""
    request, _ = testing_helpers.GetRequestObjects(
        project=self.project)
    request.headers['X-AppEngine-Country'] = 'US'
    request.remote_addr = '192.168.1.0'
    now = 0.0
    cachekeysets, _, _, _ = ratelimiter._CacheKeys(request, now)
    self._FillCounters(cachekeysets, ratelimiter.DEFAULT_LIMIT)
    with self.assertRaises(ratelimiter.RateLimitExceeded):
      self.ratelimiter.CheckStart(request, now)

  def testCheckStart_expiredEntries(self):
    """Counters seeded 2 * EXPIRE_AFTER_SECS earlier do not block a request."""
    request, _ = testing_helpers.GetRequestObjects(
        project=self.project)
    request.headers['X-AppEngine-Country'] = 'US'
    request.remote_addr = '192.168.1.0'
    now = 0.0
    cachekeysets, _, _, _ = ratelimiter._CacheKeys(request, now)
    self._FillCounters(cachekeysets, ratelimiter.DEFAULT_LIMIT)

    now = now + 2 * ratelimiter.EXPIRE_AFTER_SECS
    self.ratelimiter.CheckStart(request, now)
    # Should not throw an exception.

  def testCheckStart_repeatedCalls(self):
    """Spaced-out calls pass; a burst within one minute is rejected."""
    request, _ = testing_helpers.GetRequestObjects(
        project=self.project)
    request.headers['X-AppEngine-Country'] = 'US'
    request.remote_addr = '192.168.1.0'
    now = 0.0

    # Call CheckStart N_MINUTES times, once every two minutes (120s).
    # Should be ok.
    for _ in range(ratelimiter.N_MINUTES):
      self.ratelimiter.CheckStart(request, now)
      now = now + 120.0

    # Call CheckStart more than DEFAULT_LIMIT times in the same minute.
    with self.assertRaises(ratelimiter.RateLimitExceeded):
      for _ in range(ratelimiter.DEFAULT_LIMIT + 2):  # pragma: no branch
        now = now + 0.001
        self.ratelimiter.CheckStart(request, now)

  def testCheckStart_differentIPs(self):
    """Requests from different remote addresses are limited separately."""
    now = 0.0

    # Make more than DEFAULT_LIMIT calls in total, but spread them over 16
    # remote addresses so different addresses aren't ratelimited together.
    for m in range(ratelimiter.DEFAULT_LIMIT * 2):
      request, _ = testing_helpers.GetRequestObjects(
          project=self.project)
      request.headers['X-AppEngine-Country'] = 'US'
      request.remote_addr = '192.168.1.%d' % (m % 16)
      ratelimiter._CacheKeys(request, now)
      self.ratelimiter.CheckStart(request, now)
      now = now + 0.001

    # Exceed the limit, but only for one IP address. The
    # others should be fine.
    with self.assertRaises(ratelimiter.RateLimitExceeded):
      for m in range(ratelimiter.DEFAULT_LIMIT):  # pragma: no branch
        request, _ = testing_helpers.GetRequestObjects(
            project=self.project)
        request.headers['X-AppEngine-Country'] = 'US'
        request.remote_addr = '192.168.1.0'
        ratelimiter._CacheKeys(request, now)
        self.ratelimiter.CheckStart(request, now)
        now = now + 0.001

    # Now proceed to make requests for all of the other IP
    # addresses besides .0.
    for m in range(ratelimiter.DEFAULT_LIMIT * 2):
      request, _ = testing_helpers.GetRequestObjects(
          project=self.project)
      request.headers['X-AppEngine-Country'] = 'US'
      # Skip .0 since it's already exceeded the limit.
      request.remote_addr = '192.168.1.%d' % (m + 1)
      ratelimiter._CacheKeys(request, now)
      self.ratelimiter.CheckStart(request, now)
      now = now + 0.001

  def testCheckStart_sameIPDifferentUserIDs(self):
    """Different users behind one IP (e.g. a NAT) are limited separately."""
    now = 0.0

    # Exceed DEFAULT_LIMIT calls, but vary user_id so different
    # users behind the same IP aren't ratelimited together.
    for m in range(ratelimiter.DEFAULT_LIMIT * 2):
      request, _ = testing_helpers.GetRequestObjects(
          project=self.project)
      request.remote_addr = '192.168.1.0'
      os.environ['USER_EMAIL'] = '%s@example.com' % m
      request.headers['X-AppEngine-Country'] = 'US'
      ratelimiter._CacheKeys(request, now)
      self.ratelimiter.CheckStart(request, now)
      now = now + 0.001

    # Exceed the limit, but only for one userID+IP address. The
    # others should be fine.
    with self.assertRaises(ratelimiter.RateLimitExceeded):
      for m in range(ratelimiter.DEFAULT_LIMIT + 2):  # pragma: no branch
        request, _ = testing_helpers.GetRequestObjects(
            project=self.project)
        request.headers['X-AppEngine-Country'] = 'US'
        request.remote_addr = '192.168.1.0'
        os.environ['USER_EMAIL'] = '42@example.com'
        ratelimiter._CacheKeys(request, now)
        self.ratelimiter.CheckStart(request, now)
        now = now + 0.001

    # Now proceed to make requests for other user IDs
    # besides 42.
    for m in range(ratelimiter.DEFAULT_LIMIT * 2):
      request, _ = testing_helpers.GetRequestObjects(
          project=self.project)
      request.headers['X-AppEngine-Country'] = 'US'
      # Same IP as above; only user 42 has exceeded the limit on it.
      request.remote_addr = '192.168.1.0'
      os.environ['USER_EMAIL'] = '%s@example.com' % (43 + m)
      ratelimiter._CacheKeys(request, now)
      self.ratelimiter.CheckStart(request, now)
      now = now + 0.001

  def testCheckStart_ratelimitingDisabled(self):
    """With settings.ratelimiting_enabled False, calls are not rejected."""
    # Restored after the test by the cleanup registered in setUp().
    settings.ratelimiting_enabled = False
    request, _ = testing_helpers.GetRequestObjects(
        project=self.project)
    request.headers['X-AppEngine-Country'] = 'US'
    request.remote_addr = '192.168.1.0'
    now = 0.0

    # Call CheckStart a lot. Should be ok.
    for _ in range(ratelimiter.DEFAULT_LIMIT):
      self.ratelimiter.CheckStart(request, now)
      now = now + 0.001

  def testCheckStart_perCountryLoggedOutLimit(self):
    """COUNTRY_LIMITS caps a whole country; per-IP limits still apply."""
    ratelimiter.COUNTRY_LIMITS['US'] = 10

    request, _ = testing_helpers.GetRequestObjects(
        project=self.project)
    request.headers[ratelimiter.COUNTRY_HEADER] = 'US'
    request.remote_addr = '192.168.1.1'
    now = 0.0

    with self.assertRaises(ratelimiter.RateLimitExceeded):
      for m in range(ratelimiter.DEFAULT_LIMIT + 2):  # pragma: no branch
        self.ratelimiter.CheckStart(request, now)
        # Vary remote address to make sure the limit covers
        # the whole country, regardless of IP.
        request.remote_addr = '192.168.1.%d' % m
        now = now + 0.001

    # CheckStart for a country that isn't covered by a country-specific limit.
    # 11 requests would exceed the 'US' limit of 10 set above; the remote
    # address is varied so the requests are spread across IPs.
    request.headers['X-AppEngine-Country'] = 'UK'
    for m in range(11):
      self.ratelimiter.CheckStart(request, now)
      request.remote_addr = '192.168.1.%d' % m
      now = now + 0.001

    # And regular rate limits work per-IP: keep the same remote address so
    # that this single IP is rejected.
    request.remote_addr = '192.168.1.1'
    with self.assertRaises(ratelimiter.RateLimitExceeded):
      for m in range(ratelimiter.DEFAULT_LIMIT):  # pragma: no branch
        self.ratelimiter.CheckStart(request, now)
        now = now + 0.001

  def testCheckEnd_SlowRequest(self):
    """CheckEnd charges extra cost for slow requests.

    A request taking 2 * settings.ratelimiting_ms_per_count milliseconds
    counts as two, which pushes the next CheckStart over the limit.
    """
    request, _ = testing_helpers.GetRequestObjects(
        project=self.project)
    request.headers[ratelimiter.COUNTRY_HEADER] = 'US'
    request.remote_addr = '192.168.1.1'
    start_time = 0.0

    # Send some requests, all under the limit.
    for _ in range(ratelimiter.DEFAULT_LIMIT-1):
      start_time = start_time + 0.001
      self.ratelimiter.CheckStart(request, start_time)
      now = start_time + 0.010
      self.ratelimiter.CheckEnd(request, now, start_time)

    # Now issue some more request, this time taking long
    # enough to get the cost threshold penalty.
    # Fast forward enough to impact a later bucket than the
    # previous requests.
    start_time = now + 120.0
    self.ratelimiter.CheckStart(request, start_time)

    # Take longer than the threshold to process the request.
    elapsed_ms = settings.ratelimiting_ms_per_count * 2
    now = start_time + elapsed_ms / 1000

    # The request finished, taking long enough to count as two.
    self.ratelimiter.CheckEnd(request, now, start_time)

    with self.assertRaises(ratelimiter.RateLimitExceeded):
      # One more request after the expensive query should
      # throw an exception.
      self.ratelimiter.CheckStart(request, start_time)

  def testCheckEnd_FastRequest(self):
    """DEFAULT_LIMIT quick requests, each followed by CheckEnd, all pass."""
    request, _ = testing_helpers.GetRequestObjects(
        project=self.project)
    request.headers[ratelimiter.COUNTRY_HEADER] = 'asdasd'
    request.remote_addr = '192.168.1.1'
    start_time = 0.0

    # Send some requests, all under the limit.
    for _ in range(ratelimiter.DEFAULT_LIMIT):
      self.ratelimiter.CheckStart(request, start_time)
      now = start_time + 0.01
      self.ratelimiter.CheckEnd(request, now, start_time)
      start_time = now + 0.01
292
293
class ApiRateLimiterTest(unittest.TestCase):
  """Tests for ratelimiter.ApiRateLimiter.CheckStart and CheckEnd."""

  def setUp(self):
    # Global settings are overridden through _Override so they are restored
    # after every test instead of leaking into other test cases.
    self._Override(settings, 'ratelimiting_enabled', True)
    self.testbed = testbed.Testbed()
    self.testbed.activate()
    self.testbed.init_memcache_stub()

    self.services = service_manager.Services(
        config=fake.ConfigService(),
        issue=fake.IssueService(),
        user=fake.UserService(),
        project=fake.ProjectService(),
        )

    self.client_id = '123456789'
    self.client_email = 'test@example.com'

    self.ratelimiter = ratelimiter.ApiRateLimiter()
    self._Override(settings, 'api_ratelimiting_enabled', True)

  def tearDown(self):
    self.testbed.deactivate()

  def _Override(self, owner, name, value):
    """Set owner.<name> to value until the current test finishes.

    A cleanup (run by unittest after tearDown) puts back the previous value,
    or deletes the attribute if there was none.

    Args:
      owner: module or object whose attribute is overridden.
      name: name of the attribute to override.
      value: value to install for the duration of the test.
    """
    missing = object()
    original = getattr(owner, name, missing)
    if original is missing:
      self.addCleanup(delattr, owner, name)
    else:
      self.addCleanup(setattr, owner, name, original)
    setattr(owner, name, value)

  def _SetClientQPM(self, qpm):
    """Give self.client_email a custom QPM for the current test only.

    The entry is written into the dict returned by
    client_config_svc.GetQPMDict(); these tests rely on ApiRateLimiter
    consulting that same dict. A cleanup removes the entry even when the
    test fails, so it cannot leak into later tests.
    """
    qpm_dict = client_config_svc.GetQPMDict()
    qpm_dict[self.client_email] = qpm
    self.addCleanup(qpm_dict.pop, self.client_email, None)

  def _FillCounters(self, keysets, count):
    """Store count under every memcache key of every set in keysets."""
    for keyset in keysets:
      memcache.add_multi({key: count for key in keyset})

  def testCheckStart_Allowed(self):
    now = 0.0
    self.ratelimiter.CheckStart(self.client_id, self.client_email, now)
    self.ratelimiter.CheckStart(self.client_id, None, now)
    self.ratelimiter.CheckStart(None, None, now)
    self.ratelimiter.CheckStart('anonymous', None, now)

  def testCheckStart_Rejected(self):
    """Counters above DEFAULT_API_QPM cause ApiRateLimitExceeded."""
    now = 0.0
    keysets = ratelimiter._CreateApiCacheKeys(
        self.client_id, self.client_email, now)
    self._FillCounters(keysets, ratelimiter.DEFAULT_API_QPM + 1)
    with self.assertRaises(ratelimiter.ApiRateLimitExceeded):
      self.ratelimiter.CheckStart(self.client_id, self.client_email, now)

  def testCheckStart_Allowed_HigherQPMSpecified(self):
    """Client goes over the default, but has a higher QPM set."""
    now = 0.0
    keysets = ratelimiter._CreateApiCacheKeys(
        self.client_id, self.client_email, now)
    self._SetClientQPM(ratelimiter.DEFAULT_API_QPM + 10)
    # Every counter is seeded one request over the default limit, which is
    # still below the higher QPM configured for this client.
    self._FillCounters(keysets, ratelimiter.DEFAULT_API_QPM + 1)
    self.ratelimiter.CheckStart(self.client_id, self.client_email, now)

  def testCheckStart_Allowed_LowQPMIgnored(self):
    """Client specifies a QPM lower than the default and default is used."""
    now = 0.0
    keysets = ratelimiter._CreateApiCacheKeys(
        self.client_id, self.client_email, now)
    self._SetClientQPM(ratelimiter.DEFAULT_API_QPM - 10)
    self._FillCounters(keysets, ratelimiter.DEFAULT_API_QPM)
    self.ratelimiter.CheckStart(self.client_id, self.client_email, now)

  def testCheckStart_Rejected_LowQPMIgnored(self):
    """Client specifies a QPM lower than the default and default is used."""
    now = 0.0
    keysets = ratelimiter._CreateApiCacheKeys(
        self.client_id, self.client_email, now)
    self._SetClientQPM(ratelimiter.DEFAULT_API_QPM - 10)
    self._FillCounters(keysets, ratelimiter.DEFAULT_API_QPM + 1)
    with self.assertRaises(ratelimiter.ApiRateLimitExceeded):
      self.ratelimiter.CheckStart(self.client_id, self.client_email, now)

  def testCheckEnd(self):
    """CheckEnd only charges extra cost for slow requests."""
    start_time = 0.0
    keysets = ratelimiter._CreateApiCacheKeys(
        self.client_id, self.client_email, start_time)

    now = 0.1
    self.ratelimiter.CheckEnd(
        self.client_id, self.client_email, now, start_time)
    counters = memcache.get_multi(keysets[0])
    count = sum(counters.values())
    # No extra cost charged
    self.assertEqual(0, count)

    elapsed_ms = settings.ratelimiting_ms_per_count * 2
    now = start_time + elapsed_ms / 1000
    self.ratelimiter.CheckEnd(
        self.client_id, self.client_email, now, start_time)
    counters = memcache.get_multi(keysets[0])
    count = sum(counters.values())
    # Extra cost charged
    self.assertEqual(1, count)
401 self.assertEqual(1, count)