blob: b351f8cbd2118122fb38807e333a3e32cb7c272f [file] [log] [blame]
# Copyright 2016 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style
3# license that can be found in the LICENSE file or at
4# https://developers.google.com/open-source/licenses/bsd
5
6"""Unit tests for RateLimiter.
7"""
8
9from __future__ import division
10from __future__ import print_function
11from __future__ import absolute_import
12
13import unittest
14
15from google.appengine.api import memcache
16from google.appengine.ext import testbed
17
18import mox
19import os
20import settings
21
22from framework import ratelimiter
23from services import service_manager
24from services import client_config_svc
25from testing import fake
26from testing import testing_helpers
27
28
class RateLimiterTest(unittest.TestCase):
  """Tests for ratelimiter.RateLimiter (page-request rate limiting)."""

  def _SetModuleAttr(self, module, name, value):
    """Set module.<name> to value, restoring the previous state afterwards.

    The restore is registered with addCleanup, so module-level knobs tweaked
    by setUp or by a test cannot leak into later test cases or other test
    modules, even when the test fails part-way through.
    """
    if hasattr(module, name):
      self.addCleanup(setattr, module, name, getattr(module, name))
    else:
      self.addCleanup(delattr, module, name)
    setattr(module, name, value)

  def _MakeRequest(self, remote_addr, country='US'):
    """Return a fake request for self.project from remote_addr in country."""
    request, _ = testing_helpers.GetRequestObjects(project=self.project)
    request.headers[ratelimiter.COUNTRY_HEADER] = country
    request.remote_addr = remote_addr
    return request

  def _FillCounters(self, request, now, count):
    """Preload every rate-limit counter for request at time now with count."""
    cachekeysets, _, _, _ = ratelimiter._CacheKeys(request, now)
    for cachekeys in cachekeysets:
      memcache.add_multi({key: count for key in cachekeys})

  def setUp(self):
    # Several tests rewrite USER_EMAIL; put the pre-test value back afterwards.
    orig_user_email = os.environ.get('USER_EMAIL')

    def _RestoreUserEmail():
      if orig_user_email is None:
        os.environ.pop('USER_EMAIL', None)
      else:
        os.environ['USER_EMAIL'] = orig_user_email
    self.addCleanup(_RestoreUserEmail)

    self.testbed = testbed.Testbed()
    self.testbed.activate()
    self.testbed.init_memcache_stub()
    self.testbed.init_user_stub()

    self.mox = mox.Mox()
    self.services = service_manager.Services(
        config=fake.ConfigService(),
        issue=fake.IssueService(),
        user=fake.UserService(),
        project=fake.ProjectService(),
    )
    self.project = self.services.project.TestAddProject('proj', project_id=987)

    self.ratelimiter = ratelimiter.RateLimiter()
    # Pin module-level knobs to known values; each is restored after the test.
    self._SetModuleAttr(ratelimiter, 'COUNTRY_LIMITS', {})
    os.environ['USER_EMAIL'] = ''
    self._SetModuleAttr(settings, 'ratelimiting_enabled', True)
    self._SetModuleAttr(ratelimiter, 'DEFAULT_LIMIT', 10)

  def tearDown(self):
    self.testbed.deactivate()
    self.mox.UnsetStubs()
    self.mox.ResetAll()

  def testCheckStart_pass(self):
    """A single fresh request is allowed."""
    request = self._MakeRequest('192.168.1.0')
    self.ratelimiter.CheckStart(request)  # Should not raise.

  def testCheckStart_fail(self):
    """A client whose counters are already at the limit is rejected."""
    request = self._MakeRequest('192.168.1.0')
    now = 0.0
    self._FillCounters(request, now, ratelimiter.DEFAULT_LIMIT)
    with self.assertRaises(ratelimiter.RateLimitExceeded):
      self.ratelimiter.CheckStart(request, now)

  def testCheckStart_expiredEntries(self):
    """Counters from long ago no longer count against the client."""
    request = self._MakeRequest('192.168.1.0')
    now = 0.0
    self._FillCounters(request, now, ratelimiter.DEFAULT_LIMIT)

    now += 2 * ratelimiter.EXPIRE_AFTER_SECS
    self.ratelimiter.CheckStart(request, now)  # Should not raise.

  def testCheckStart_repeatedCalls(self):
    """Spread-out calls pass; a burst within one minute is rejected."""
    request = self._MakeRequest('192.168.1.0')
    now = 0.0

    # Call CheckStart once every two minutes. Should be ok.
    for _ in range(ratelimiter.N_MINUTES):
      self.ratelimiter.CheckStart(request, now)
      now += 120.0

    # Call CheckStart more than DEFAULT_LIMIT times in the same minute.
    with self.assertRaises(ratelimiter.RateLimitExceeded):
      for _ in range(ratelimiter.DEFAULT_LIMIT + 2):  # pragma: no branch
        now += 0.001
        self.ratelimiter.CheckStart(request, now)

  def testCheckStart_differentIPs(self):
    """Requests from different IP addresses are not rate-limited together."""
    now = 0.0

    # Exceed DEFAULT_LIMIT calls in total, but spread them over 16 remote
    # addresses so that no single address reaches the limit.
    for m in range(ratelimiter.DEFAULT_LIMIT * 2):
      request = self._MakeRequest('192.168.1.%d' % (m % 16))
      self.ratelimiter.CheckStart(request, now)
      now += 0.001

    # Exceed the limit, but only for one IP address.
    with self.assertRaises(ratelimiter.RateLimitExceeded):
      for _ in range(ratelimiter.DEFAULT_LIMIT):  # pragma: no branch
        request = self._MakeRequest('192.168.1.0')
        self.ratelimiter.CheckStart(request, now)
        now += 0.001

    # Requests from the other IP addresses should still be fine.
    for m in range(ratelimiter.DEFAULT_LIMIT * 2):
      # Skip .0 since it has already exceeded the limit.
      request = self._MakeRequest('192.168.1.%d' % (m + 1))
      self.ratelimiter.CheckStart(request, now)
      now += 0.001

  def testCheckStart_sameIPDifferentUserIDs(self):
    """Different users sharing one IP (e.g. behind a NAT) are limited apart."""
    now = 0.0

    # Exceed DEFAULT_LIMIT calls from one IP, but vary the signed-in user so
    # that no single user reaches the limit.
    for m in range(ratelimiter.DEFAULT_LIMIT * 2):
      request = self._MakeRequest('192.168.1.0')
      os.environ['USER_EMAIL'] = '%s@example.com' % m
      self.ratelimiter.CheckStart(request, now)
      now += 0.001

    # Exceed the limit, but only for one user at that IP address.
    with self.assertRaises(ratelimiter.RateLimitExceeded):
      for _ in range(ratelimiter.DEFAULT_LIMIT + 2):  # pragma: no branch
        request = self._MakeRequest('192.168.1.0')
        os.environ['USER_EMAIL'] = '42@example.com'
        self.ratelimiter.CheckStart(request, now)
        now += 0.001

    # Other users at the same IP address should still be fine.
    for m in range(ratelimiter.DEFAULT_LIMIT * 2):
      request = self._MakeRequest('192.168.1.0')
      # Skip user 42 since it has already exceeded the limit.
      os.environ['USER_EMAIL'] = '%s@example.com' % (43 + m)
      self.ratelimiter.CheckStart(request, now)
      now += 0.001

  def testCheckStart_ratelimitingDisabled(self):
    """Nothing is rejected while settings.ratelimiting_enabled is off."""
    self._SetModuleAttr(settings, 'ratelimiting_enabled', False)
    request = self._MakeRequest('192.168.1.0')
    now = 0.0

    # Call CheckStart well over DEFAULT_LIMIT times within one minute; with
    # rate limiting enabled this burst is rejected (see
    # testCheckStart_repeatedCalls). Should be ok here.
    for _ in range(ratelimiter.DEFAULT_LIMIT * 2):
      self.ratelimiter.CheckStart(request, now)
      now += 0.001

  def testCheckStart_perCountryLoggedOutLimit(self):
    """A country-specific limit covers all logged-out requests from it."""
    country_limit = 10
    # setUp installed a fresh dict, so this does not touch the real limits.
    ratelimiter.COUNTRY_LIMITS['US'] = country_limit

    request = self._MakeRequest('192.168.1.1')
    now = 0.0

    with self.assertRaises(ratelimiter.RateLimitExceeded):
      for m in range(ratelimiter.DEFAULT_LIMIT + 2):  # pragma: no branch
        self.ratelimiter.CheckStart(request, now)
        # Vary remote address to make sure the limit covers
        # the whole country, regardless of IP.
        request.remote_addr = '192.168.1.%d' % m
        now += 0.001

    # A country without a country-specific limit accepts more than
    # country_limit requests when they are spread over several IPs.
    request.headers[ratelimiter.COUNTRY_HEADER] = 'UK'
    for m in range(country_limit + 1):
      self.ratelimiter.CheckStart(request, now)
      request.remote_addr = '192.168.1.%d' % m
      now += 0.001

    # Regular per-IP rate limits still apply.
    request.remote_addr = '192.168.1.1'
    with self.assertRaises(ratelimiter.RateLimitExceeded):
      for _ in range(ratelimiter.DEFAULT_LIMIT):  # pragma: no branch
        self.ratelimiter.CheckStart(request, now)
        now += 0.001

  def testCheckEnd_SlowRequest(self):
    """Slow requests are charged extra against the rate limit.

    A request that takes twice settings.ratelimiting_ms_per_count is
    expected to count as two requests.
    """
    request = self._MakeRequest('192.168.1.1')
    start_time = 0.0

    # Send some quick requests, all under the limit.
    for _ in range(ratelimiter.DEFAULT_LIMIT - 1):
      start_time += 0.001
      self.ratelimiter.CheckStart(request, start_time)
      now = start_time + 0.010
      self.ratelimiter.CheckEnd(request, now, start_time)

    # Now issue one more request, this time taking long enough to get the
    # cost threshold penalty. Fast forward enough to impact a later bucket
    # than the previous requests.
    start_time = now + 120.0
    self.ratelimiter.CheckStart(request, start_time)

    # Take longer than the threshold to process the request.
    elapsed_ms = settings.ratelimiting_ms_per_count * 2
    now = start_time + elapsed_ms / 1000

    # The request finished, taking long enough to count as two.
    self.ratelimiter.CheckEnd(request, now, start_time)

    with self.assertRaises(ratelimiter.RateLimitExceeded):
      # One more request after the expensive query should raise.
      self.ratelimiter.CheckStart(request, start_time)

  def testCheckEnd_FastRequest(self):
    """Quick requests are not charged extra against the rate limit."""
    # 'asdasd' is an arbitrary, non-ISO country code.
    request = self._MakeRequest('192.168.1.1', country='asdasd')
    start_time = 0.0

    # Send DEFAULT_LIMIT quick requests; none should be rejected.
    for _ in range(ratelimiter.DEFAULT_LIMIT):
      self.ratelimiter.CheckStart(request, start_time)
      now = start_time + 0.01
      self.ratelimiter.CheckEnd(request, now, start_time)
      start_time = now + 0.01
289
290
class ApiRateLimiterTest(unittest.TestCase):
  """Tests for ratelimiter.ApiRateLimiter (per-client API rate limiting)."""

  def _SetModuleAttr(self, module, name, value):
    """Set module.<name> to value, restoring the previous state afterwards.

    The restore is registered with addCleanup, so settings flipped here
    cannot leak into later test cases or other test modules.
    """
    if hasattr(module, name):
      self.addCleanup(setattr, module, name, getattr(module, name))
    else:
      self.addCleanup(delattr, module, name)
    setattr(module, name, value)

  def _OverrideClientQPM(self, qpm):
    """Give self.client_email a custom QPM entry for this test only.

    The entry lives in the dict returned by client_config_svc.GetQPMDict().
    It is removed by a cleanup rather than at the end of the test body, so it
    is dropped even when an assertion fails first.
    """
    qpm_dict = client_config_svc.GetQPMDict()
    self.addCleanup(qpm_dict.pop, self.client_email, None)
    qpm_dict[self.client_email] = qpm

  def _FillCounters(self, now, count):
    """Preload every API counter for this client at time now with count."""
    keysets = ratelimiter._CreateApiCacheKeys(
        self.client_id, self.client_email, now)
    for keyset in keysets:
      memcache.add_multi({key: count for key in keyset})

  def setUp(self):
    self._SetModuleAttr(settings, 'ratelimiting_enabled', True)
    self.testbed = testbed.Testbed()
    self.testbed.activate()
    self.testbed.init_memcache_stub()

    self.services = service_manager.Services(
        config=fake.ConfigService(),
        issue=fake.IssueService(),
        user=fake.UserService(),
        project=fake.ProjectService(),
    )

    self.client_id = '123456789'
    self.client_email = 'test@example.com'

    self.ratelimiter = ratelimiter.ApiRateLimiter()
    self._SetModuleAttr(settings, 'api_ratelimiting_enabled', True)

  def tearDown(self):
    self.testbed.deactivate()

  def testCheckStart_Allowed(self):
    """Fresh clients are allowed with or without a client id and email."""
    now = 0.0
    self.ratelimiter.CheckStart(self.client_id, self.client_email, now)
    self.ratelimiter.CheckStart(self.client_id, None, now)
    self.ratelimiter.CheckStart(None, None, now)
    self.ratelimiter.CheckStart('anonymous', None, now)

  def testCheckStart_Rejected(self):
    """A client with every counter above DEFAULT_API_QPM is rejected."""
    now = 0.0
    self._FillCounters(now, ratelimiter.DEFAULT_API_QPM + 1)
    with self.assertRaises(ratelimiter.ApiRateLimitExceeded):
      self.ratelimiter.CheckStart(self.client_id, self.client_email, now)

  def testCheckStart_Allowed_HigherQPMSpecified(self):
    """Client goes over the default, but has a higher QPM set."""
    now = 0.0
    self._OverrideClientQPM(ratelimiter.DEFAULT_API_QPM + 10)
    # The client used 1 request more than the default limit in each of the
    # 5 minutes in our 5 minute sample window, so 5 over in total.
    self._FillCounters(now, ratelimiter.DEFAULT_API_QPM + 1)
    self.ratelimiter.CheckStart(self.client_id, self.client_email, now)

  def testCheckStart_Allowed_LowQPMIgnored(self):
    """Client specifies a QPM lower than the default and default is used."""
    now = 0.0
    self._OverrideClientQPM(ratelimiter.DEFAULT_API_QPM - 10)
    # At exactly the default QPM: allowed despite the lower custom value.
    self._FillCounters(now, ratelimiter.DEFAULT_API_QPM)
    self.ratelimiter.CheckStart(self.client_id, self.client_email, now)

  def testCheckStart_Rejected_LowQPMIgnored(self):
    """Client specifies a QPM lower than the default and default is used."""
    now = 0.0
    self._OverrideClientQPM(ratelimiter.DEFAULT_API_QPM - 10)
    # Above the default QPM: rejected.
    self._FillCounters(now, ratelimiter.DEFAULT_API_QPM + 1)
    with self.assertRaises(ratelimiter.ApiRateLimitExceeded):
      self.ratelimiter.CheckStart(self.client_id, self.client_email, now)

  def testCheckEnd(self):
    """Slow requests are charged an extra count; fast ones are not."""
    start_time = 0.0
    keysets = ratelimiter._CreateApiCacheKeys(
        self.client_id, self.client_email, start_time)

    # A 100ms request: no extra cost charged.
    now = 0.1
    self.ratelimiter.CheckEnd(
        self.client_id, self.client_email, now, start_time)
    self.assertEqual(0, sum(memcache.get_multi(keysets[0]).values()))

    # A request lasting twice settings.ratelimiting_ms_per_count: one extra
    # count is charged.
    elapsed_ms = settings.ratelimiting_ms_per_count * 2
    now = start_time + elapsed_ms / 1000
    self.ratelimiter.CheckEnd(
        self.client_id, self.client_email, now, start_time)
    self.assertEqual(1, sum(memcache.get_multi(keysets[0]).values()))