blob: 4e1a0b862a080005096d6c5a63145897e1638557 [file] [log] [blame]
# Copyright 2016 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Request rate limiting implementation.
This is intended to be used for automatic DDoS protection.
"""
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
import datetime
import logging
import os
import settings
import time
from infra_libs import ts_mon
from google.appengine.api import memcache
from google.appengine.api.modules import modules
from google.appengine.api import users
from services import client_config_svc
# Length of the sliding rate-limit window, in minutes.
N_MINUTES = 5
# Lifetime of each per-minute memcache counter before it expires.
EXPIRE_AFTER_SECS = 60 * 60
DEFAULT_LIMIT = 60 * N_MINUTES  # 300 page requests in 5 minutes is 1 QPS.
DEFAULT_API_QPM = 1000  # For example, chromiumdash uses ~64 per page, 8s each.

# Placeholder user value used in cache keys for requests with no signed-in user.
ANON_USER = 'anon'

# App Engine-supplied request header carrying the origin country code.
COUNTRY_HEADER = 'X-AppEngine-Country'

COUNTRY_LIMITS = {
    # Two-letter country code: max requests per N_MINUTES
    # This limit will apply to all requests coming
    # from this country.
    # To add a country code, see GAE logs and use the
    # appropriate code from https://en.wikipedia.org/wiki/ISO_3166-1_alpha-2
    # E.g., 'cn': 300,  # Limit to 1 QPS.
}

# Modules not in this list will not have rate limiting applied by this
# class.
MODULE_ALLOWLIST = ['default', 'api']
def _CacheKeys(request, now_sec):
  """Compute the memcache key sets used to rate limit this page request.

  Args:
    request: incoming request object with .headers and .remote_addr.
    now_sec: float seconds since the epoch.

  Returns:
    A tuple (keysets, country, ip, user_email).  keysets is a list of key
    lists; each inner list shares one <IP, country, user_email> prefix and
    carries minute-granularity timestamp suffixes counting back N_MINUTES
    from now, most recent first.
  """
  now = datetime.datetime.fromtimestamp(now_sec)
  country = request.headers.get(COUNTRY_HEADER, 'ZZ')
  ip = request.remote_addr
  user = users.get_current_user()
  user_email = user.email() if user else ANON_USER

  if user:
    # Logged-in users get a per-user rate limit, regardless of IP/country.
    prefixes = [['ALL', 'ALL', user.email()]]
  else:
    # Anonymous requests get a per-IP rate limit.
    prefixes = [[ip, 'ALL', 'ALL']]
  if country in COUNTRY_LIMITS:
    # Requests from a problematic country also get a per-country limit,
    # regardless of the user (even a logged-in one) or IP.
    prefixes.append(['ALL', country, 'ALL'])

  minute_buckets = [now - datetime.timedelta(minutes=offset)
                    for offset in range(N_MINUTES)]
  keysets = [
      ['ratelimit-%s-%s' % ('-'.join(prefix),
                            str(bucket.replace(second=0, microsecond=0)))
       for bucket in minute_buckets]
      for prefix in prefixes]
  return keysets, country, ip, user_email
def _CreateApiCacheKeys(client_id, client_email, now_sec):
  """Compute the memcache key lists used to rate limit an API request.

  Keys are derived from the client id (when present and not 'anonymous'),
  from the client email (falling back to the caller IP for anonymous
  requests), and from the origin country when it appears in COUNTRY_LIMITS.

  Args:
    client_id: OAuth client id string, possibly 'anonymous' or falsy.
    client_email: email of the caller, or None/'' when unknown.
    now_sec: float seconds since the epoch.

  Returns:
    A list of key lists, one inner list per applicable prefix, each with
    minute-granularity suffixes covering the last N_MINUTES.
  """
  country = os.environ.get('HTTP_X_APPENGINE_COUNTRY')
  ip = os.environ.get('REMOTE_ADDR')
  now = datetime.datetime.fromtimestamp(now_sec)
  minute_strs = [
      str((now - datetime.timedelta(minutes=m)).replace(
          second=0, microsecond=0))
      for m in range(N_MINUTES)]

  def _KeysFor(prefix):
    # One memcache key per minute bucket for this prefix.
    return ['apiratelimit-%s-%s' % (prefix, stamp) for stamp in minute_strs]

  keys = []
  if client_id and client_id != 'anonymous':
    keys.append(_KeysFor(client_id))
  if client_email:
    keys.append(_KeysFor(client_email))
  else:
    keys.append(_KeysFor(ip))
  if country in COUNTRY_LIMITS:
    keys.append(_KeysFor(country))
  return keys
class RateLimiter(object):
  """Sliding-window rate limiter for page requests.

  Request counts are kept in per-minute memcache buckets keyed by user,
  IP, or country (see _CacheKeys).  A request is blocked when the sum over
  the last N_MINUTES buckets for any applicable key set exceeds its limit.
  """

  # Monitoring metrics; ApiRateLimiter shadows these with its own names.
  blocked_requests = ts_mon.CounterMetric(
      'monorail/ratelimiter/blocked_request',
      'Count of requests that exceeded the rate limit and were blocked.',
      None)
  limit_exceeded = ts_mon.CounterMetric(
      'monorail/ratelimiter/rate_exceeded',
      'Count of requests that exceeded the rate limit.',
      None)
  cost_thresh_exceeded = ts_mon.CounterMetric(
      'monorail/ratelimiter/cost_thresh_exceeded',
      'Count of requests that were expensive to process',
      None)
  checks = ts_mon.CounterMetric(
      'monorail/ratelimiter/check',
      'Count of checks done, by fail/success type.',
      [ts_mon.StringField('type')])

  def __init__(self, _cache=memcache, fail_open=True, **_kwargs):
    # fail_open: when True, a memcache error lets the request through;
    # when False, a memcache error is treated as over-limit.
    # NOTE(review): _cache is accepted but never stored; the methods below
    # call memcache directly — presumably kept for signature compatibility
    # with tests or callers.  Confirm before removing.
    self.fail_open = fail_open

  def CheckStart(self, request, now=None):
    """Count this request and raise if any applicable limit is exceeded.

    Skips rate limiting for admins and for modules outside
    MODULE_ALLOWLIST.

    Args:
      request: incoming request object with .headers and .remote_addr.
      now: float seconds since the epoch; defaults to time.time().

    Raises:
      RateLimitExceeded: if a limit is hit and rate limiting is enabled,
          or (regardless of the enable flag) when memcache fails and
          fail_open is False.
    """
    if (modules.get_current_module_name() not in MODULE_ALLOWLIST or
        users.is_current_user_admin()):
      return
    logging.info('X-AppEngine-Country: %s' %
                 request.headers.get(COUNTRY_HEADER, 'ZZ'))
    if now is None:
      now = time.time()
    keysets, country, ip, user_email = _CacheKeys(request, now)
    # keysets holds one or two key sets: one for the per-user (or per-IP)
    # prefix, plus one more when the country is in COUNTRY_LIMITS.
    self._AuxCheckStart(
        keysets, COUNTRY_LIMITS.get(country, DEFAULT_LIMIT),
        settings.ratelimiting_enabled,
        RateLimitExceeded(country=country, ip=ip, user_email=user_email))

  def _AuxCheckStart(self, keysets, limit, ratelimiting_enabled,
                     exception_obj):
    """Sum the recent minute buckets for each keyset and enforce the limit.

    Side effect: increments the newest minute bucket of each keyset, so
    every checked request is counted toward future checks.

    Args:
      keysets: list of key lists, newest bucket first in each list.
      limit: max allowed sum over the window.
      ratelimiting_enabled: when False, over-limit events are only logged
          and counted, not blocked.
      exception_obj: pre-built exception instance to raise when blocking.
    """
    for keys in keysets:
      count = 0
      try:
        counters = memcache.get_multi(keys)
        count = sum(counters.values())
        self.checks.increment({'type': 'success'})
      except Exception as e:
        # Best-effort: a cache failure only blocks when fail_open is False.
        logging.error(e)
        if not self.fail_open:
          self.checks.increment({'type': 'fail_closed'})
          raise exception_obj
        self.checks.increment({'type': 'fail_open'})
      if count > limit:
        # Since webapp2 won't let us return a 429 error code
        # <http://tools.ietf.org/html/rfc6585#section-4>, we can't
        # monitor rate limit exceeded events with our standard tools.
        # We return a 400 with a custom error message to the client,
        # and this logging is so we can monitor it internally.
        logging.info('%s, %d' % (str(exception_obj), count))
        self.limit_exceeded.increment()
        if ratelimiting_enabled:
          self.blocked_requests.increment()
          raise exception_obj
      k = keys[0]
      # Only update the latest *time* bucket for each prefix (reverse chron).
      # add() is a no-op if the key exists; incr() then bumps it atomically.
      memcache.add(k, 0, time=EXPIRE_AFTER_SECS)
      memcache.incr(k, initial_value=0)

  def CheckEnd(self, request, now, start_time):
    """If a request was expensive to process, charge some extra points
    against this set of buckets.
    We pass in both now and start_time so we can update the buckets
    based on keys created from start_time instead of now.
    now and start_time are float seconds.
    """
    if (modules.get_current_module_name() not in MODULE_ALLOWLIST):
      return
    elapsed_ms = int((now - start_time) * 1000)
    # One penalty point per ratelimiting_ms_per_count ms of handling time,
    # beyond the first "free" interval.
    penalty = elapsed_ms // settings.ratelimiting_ms_per_count - 1
    if penalty >= 1:
      # TODO: Look into caching the keys instead of generating them twice
      # for every request. Say, return them from CheckStart so they can
      # be passed back in here later.
      keysets, country, ip, user_email = _CacheKeys(request, start_time)
      self._AuxCheckEnd(
          keysets,
          'Rate Limit Cost Threshold Exceeded: %s, %s, %s' % (
              country, ip, user_email),
          penalty)

  def _AuxCheckEnd(self, keysets, log_str, penalty):
    """Add penalty points to the newest bucket of each keyset."""
    self.cost_thresh_exceeded.increment()
    for keys in keysets:
      logging.info(log_str)
      # Only update the latest *time* bucket for each prefix (reverse chron).
      k = keys[0]
      memcache.add(k, 0, time=EXPIRE_AFTER_SECS)
      memcache.incr(k, delta=penalty, initial_value=0)
class ApiRateLimiter(RateLimiter):
  """Rate limiter for API requests, keyed by client id and email."""

  blocked_requests = ts_mon.CounterMetric(
      'monorail/apiratelimiter/blocked_request',
      'Count of requests that exceeded the rate limit and were blocked.',
      None)
  limit_exceeded = ts_mon.CounterMetric(
      'monorail/apiratelimiter/rate_exceeded',
      'Count of requests that exceeded the rate limit.',
      None)
  cost_thresh_exceeded = ts_mon.CounterMetric(
      'monorail/apiratelimiter/cost_thresh_exceeded',
      'Count of requests that were expensive to process',
      None)
  checks = ts_mon.CounterMetric(
      'monorail/apiratelimiter/check',
      'Count of checks done, by fail/success type.',
      [ts_mon.StringField('type')])

  #pylint: disable=arguments-differ
  def CheckStart(self, client_id, client_email, now=None):
    """Count this API request and raise if the client is over its limit.

    Raises:
      ApiRateLimitExceeded: if the window limit is hit and API rate
          limiting is enabled.
    """
    now = time.time() if now is None else now
    keysets = _CreateApiCacheKeys(client_id, client_email, now)
    # Per-client QPM overrides apply, but never below the default floor.
    configured_qpm = client_config_svc.GetQPMDict().get(
        client_email, DEFAULT_API_QPM)
    window_limit = max(configured_qpm, DEFAULT_API_QPM) * N_MINUTES
    self._AuxCheckStart(
        keysets, window_limit,
        settings.api_ratelimiting_enabled,
        ApiRateLimitExceeded(client_id, client_email))

  #pylint: disable=arguments-differ
  def CheckEnd(self, client_id, client_email, now, start_time):
    """Charge extra points when the request was expensive to serve."""
    elapsed_ms = int((now - start_time) * 1000)
    penalty = elapsed_ms // settings.ratelimiting_ms_per_count - 1
    if penalty < 1:
      return
    keysets = _CreateApiCacheKeys(client_id, client_email, start_time)
    self._AuxCheckEnd(
        keysets,
        'API Rate Limit Cost Threshold Exceeded: %s, %s' % (
            client_id, client_email),
        penalty)
class RateLimitExceeded(Exception):
  """Raised when a page request exceeds its rate limit."""

  def __init__(self, country=None, ip=None, user_email=None, **_kwargs):
    self.country = country
    self.ip = ip
    self.user_email = user_email
    details = '%s, %s, %s' % (self.country, self.ip, self.user_email)
    super(RateLimitExceeded, self).__init__('RateLimitExceeded: ' + details)
class ApiRateLimitExceeded(Exception):
  """Raised when an API client exceeds its rate limit.

  Attributes:
    client_id: OAuth client id of the caller.
    client_email: email of the caller, if known.
  """

  def __init__(self, client_id, client_email):
    self.client_id = client_id
    self.client_email = client_email
    # Fix: the message previously began with 'RateLimitExceeded' (a
    # copy-paste from the page-request exception above), which made API
    # limit hits indistinguishable from page limit hits in logs.
    message = 'ApiRateLimitExceeded: %s, %s' % (
        self.client_id, self.client_email)
    super(ApiRateLimitExceeded, self).__init__(message)