Project import generated by Copybara.
GitOrigin-RevId: d9e9e3fb4e31372ec1fb43b178994ca78fa8fe70
diff --git a/framework/sql.py b/framework/sql.py
new file mode 100644
index 0000000..d99b045
--- /dev/null
+++ b/framework/sql.py
@@ -0,0 +1,1048 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""A set of classes for interacting with tables in SQL."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import logging
+import random
+import re
+import sys
+import time
+
+from six import string_types
+
+import settings
+
+if not settings.unit_test_mode:
+ import MySQLdb
+
+from framework import exceptions
+from framework import framework_helpers
+
+from infra_libs import ts_mon
+
+from six.moves.queue import Queue
+
+
+class ConnectionPool(object):
+ """Manage a set of database connections such that they may be re-used.
+ """
+
+ def __init__(self, poolsize=1):
+ self.poolsize = poolsize
+ self.queues = {}
+
+ def get(self, instance, database):
+ """Retun a database connection, or throw an exception if none can
+ be made.
+ """
+ key = instance + '/' + database
+
+    if key not in self.queues:
+ queue = Queue(self.poolsize)
+ self.queues[key] = queue
+
+ queue = self.queues[key]
+
+ if queue.empty():
+ cnxn = cnxn_ctor(instance, database)
+ else:
+ cnxn = queue.get()
+ # Make sure the connection is still good.
+ cnxn.ping()
+ cnxn.commit()
+
+ return cnxn
+
+  def release(self, cnxn):
+    """Return a connection to the pool, or close it if the pool is full."""
+    if cnxn.pool_key not in self.queues:
+      raise ValueError('unknown pool key: %s' % cnxn.pool_key)
+
+ q = self.queues[cnxn.pool_key]
+ if q.full():
+ cnxn.close()
+ else:
+ q.put(cnxn)
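+
+
+# A minimal ConnectionPool usage sketch (the instance and database names
+# below are hypothetical):
+#
+#   pool = ConnectionPool(poolsize=2)
+#   cnxn = pool.get('project:region:instance', 'monorail')
+#   try:
+#     ...  # Use the connection.
+#   finally:
+#     pool.release(cnxn)  # Requeued for reuse, or closed if the pool is full.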
+
+
+@framework_helpers.retry(1, delay=1, backoff=2)
+def cnxn_ctor(instance, database):
+  """Create a MySQLdb connection to the given instance and database."""
+ logging.info('About to connect to SQL instance %r db %r', instance, database)
+ if settings.unit_test_mode:
+ raise ValueError('unit tests should not need real database connections')
+ try:
+ if settings.local_mode:
+ start_time = time.time()
+ cnxn = MySQLdb.connect(
+ host='127.0.0.1', port=3306, db=database, user='root', charset='utf8')
+ else:
+ start_time = time.time()
+ cnxn = MySQLdb.connect(
+ unix_socket='/cloudsql/' + instance, db=database, user='root',
+ charset='utf8')
+ duration = int((time.time() - start_time) * 1000)
+ DB_CNXN_LATENCY.add(duration)
+ CONNECTION_COUNT.increment({'success': True})
+ except MySQLdb.OperationalError:
+ CONNECTION_COUNT.increment({'success': False})
+ raise
+ cnxn.pool_key = instance + '/' + database
+ cnxn.is_bad = False
+ return cnxn
+
+
+# One connection pool per database instance (the primary and each replica
+# are separate instances). We keep four connections per instance because
+# we fetch issue comments, stars, spam verdicts, and spam verdict history
+# in parallel with promises.
+cnxn_pool = ConnectionPool(settings.db_cnxn_pool_size)
+
+# MonorailConnection maintains a dictionary of connections to SQL databases.
+# Each is identified by an int shard ID.
+# And there is one connection to the primary DB identified by key PRIMARY_CNXN.
+PRIMARY_CNXN = 'primary_cnxn'
+
+# When one replica is temporarily unresponsive, we can use a different one.
+BAD_SHARD_AVOIDANCE_SEC = 45
+
+
+CONNECTION_COUNT = ts_mon.CounterMetric(
+ 'monorail/sql/connection_count',
+ 'Count of connections made to the SQL database.',
+ [ts_mon.BooleanField('success')])
+
+DB_CNXN_LATENCY = ts_mon.CumulativeDistributionMetric(
+ 'monorail/sql/db_cnxn_latency',
+ 'Time needed to establish a DB connection.',
+ None)
+
+DB_QUERY_LATENCY = ts_mon.CumulativeDistributionMetric(
+ 'monorail/sql/db_query_latency',
+ 'Time needed to make a DB query.',
+ [ts_mon.StringField('type')])
+
+DB_COMMIT_LATENCY = ts_mon.CumulativeDistributionMetric(
+ 'monorail/sql/db_commit_latency',
+ 'Time needed to make a DB commit.',
+ None)
+
+DB_ROLLBACK_LATENCY = ts_mon.CumulativeDistributionMetric(
+ 'monorail/sql/db_rollback_latency',
+ 'Time needed to make a DB rollback.',
+ None)
+
+DB_RETRY_COUNT = ts_mon.CounterMetric(
+ 'monorail/sql/db_retry_count',
+ 'Count of queries retried.',
+ None)
+
+DB_QUERY_COUNT = ts_mon.CounterMetric(
+ 'monorail/sql/db_query_count',
+ 'Count of queries sent to the DB.',
+ [ts_mon.StringField('type')])
+
+DB_COMMIT_COUNT = ts_mon.CounterMetric(
+ 'monorail/sql/db_commit_count',
+ 'Count of commits sent to the DB.',
+ None)
+
+DB_ROLLBACK_COUNT = ts_mon.CounterMetric(
+ 'monorail/sql/db_rollback_count',
+ 'Count of rollbacks sent to the DB.',
+ None)
+
+DB_RESULT_ROWS = ts_mon.CumulativeDistributionMetric(
+ 'monorail/sql/db_result_rows',
+ 'Number of results returned by a DB query.',
+ None)
+
+
+def RandomShardID():
+ """Return a random shard ID to load balance across replicas."""
+ return random.randint(0, settings.num_logical_shards - 1)
+
+
+class MonorailConnection(object):
+ """Create and manage connections to the SQL servers.
+
+ We only store connections in the context of a single user request, not
+ across user requests. The main purpose of this class is to make using
+ sharded tables easier.
+ """
+ unavailable_shards = {} # {shard_id: timestamp of failed attempt}
+
+ def __init__(self):
+ self.sql_cnxns = {} # {PRIMARY_CNXN: cnxn, shard_id: cnxn, ...}
+
+ @framework_helpers.retry(1, delay=0.1, backoff=2)
+ def GetPrimaryConnection(self):
+ """Return a connection to the primary SQL DB."""
+ if PRIMARY_CNXN not in self.sql_cnxns:
+ self.sql_cnxns[PRIMARY_CNXN] = cnxn_pool.get(
+ settings.db_instance, settings.db_database_name)
+ logging.info(
+ 'created a primary connection %r', self.sql_cnxns[PRIMARY_CNXN])
+
+ return self.sql_cnxns[PRIMARY_CNXN]
+
+ @framework_helpers.retry(1, delay=0.1, backoff=2)
+ def GetConnectionForShard(self, shard_id):
+ """Return a connection to the DB replica that will be used for shard_id."""
+ if shard_id not in self.sql_cnxns:
+ physical_shard_id = shard_id % settings.num_logical_shards
+
+ replica_name = settings.db_replica_names[
+ physical_shard_id % len(settings.db_replica_names)]
+ shard_instance_name = (
+ settings.physical_db_name_format % replica_name)
+      # Mark the shard as unavailable while we try to connect; on failure
+      # the mark sticks, so this shard is avoided until the entry ages out.
+      self.unavailable_shards[shard_id] = int(time.time())
+ self.sql_cnxns[shard_id] = cnxn_pool.get(
+ shard_instance_name, settings.db_database_name)
+ del self.unavailable_shards[shard_id]
+ logging.info('created a replica connection for shard %d', shard_id)
+
+ return self.sql_cnxns[shard_id]
+
+ def Execute(self, stmt_str, stmt_args, shard_id=None, commit=True, retries=2):
+ """Execute the given SQL statement on one of the relevant databases."""
+ if shard_id is None:
+ # No shard was specified, so hit the primary.
+ sql_cnxn = self.GetPrimaryConnection()
+ else:
+ if shard_id in self.unavailable_shards:
+ bad_age_sec = int(time.time()) - self.unavailable_shards[shard_id]
+ if bad_age_sec < BAD_SHARD_AVOIDANCE_SEC:
+ logging.info('Avoiding bad replica %r, age %r', shard_id, bad_age_sec)
+ shard_id = (shard_id + 1) % settings.num_logical_shards
+ sql_cnxn = self.GetConnectionForShard(shard_id)
+
+ try:
+ return self._ExecuteWithSQLConnection(
+ sql_cnxn, stmt_str, stmt_args, commit=commit)
+ except MySQLdb.OperationalError as e:
+ logging.exception(e)
+ logging.info('retries: %r', retries)
+ if retries > 0:
+ DB_RETRY_COUNT.increment()
+        self.sql_cnxns = {}  # Drop old MySQL connections and make new ones.
+ return self.Execute(
+ stmt_str, stmt_args, shard_id=shard_id, commit=commit,
+ retries=retries - 1)
+ else:
+ raise e
+
+ def _ExecuteWithSQLConnection(
+ self, sql_cnxn, stmt_str, stmt_args, commit=True):
+ """Execute a statement on the given database and return a cursor."""
+
+ start_time = time.time()
+ cursor = sql_cnxn.cursor()
+ cursor.execute('SET NAMES utf8mb4')
+ if stmt_str.startswith('INSERT') or stmt_str.startswith('REPLACE'):
+ cursor.executemany(stmt_str, stmt_args)
+ duration = (time.time() - start_time) * 1000
+ DB_QUERY_LATENCY.add(duration, {'type': 'write'})
+ DB_QUERY_COUNT.increment({'type': 'write'})
+ else:
+ cursor.execute(stmt_str, args=stmt_args)
+ duration = (time.time() - start_time) * 1000
+ DB_QUERY_LATENCY.add(duration, {'type': 'read'})
+ DB_QUERY_COUNT.increment({'type': 'read'})
+ DB_RESULT_ROWS.add(cursor.rowcount)
+
+ if stmt_str.startswith('INSERT') or stmt_str.startswith('REPLACE'):
+ formatted_statement = '%s %s' % (stmt_str, stmt_args)
+ else:
+ formatted_statement = stmt_str % tuple(stmt_args)
+ logging.info(
+ '%d rows in %d ms: %s', cursor.rowcount, int(duration),
+ formatted_statement.replace('\n', ' '))
+
+ if commit and not stmt_str.startswith('SELECT'):
+ try:
+ sql_cnxn.commit()
+ duration = (time.time() - start_time) * 1000
+ DB_COMMIT_LATENCY.add(duration)
+ DB_COMMIT_COUNT.increment()
+ except MySQLdb.DatabaseError:
+ sql_cnxn.rollback()
+ duration = (time.time() - start_time) * 1000
+ DB_ROLLBACK_LATENCY.add(duration)
+ DB_ROLLBACK_COUNT.increment()
+
+ return cursor
+
+ def Commit(self):
+ """Explicitly commit any pending txns. Normally done automatically."""
+ sql_cnxn = self.GetPrimaryConnection()
+ try:
+ sql_cnxn.commit()
+ except MySQLdb.DatabaseError:
+ logging.exception('Commit failed for cnxn, rolling back')
+ sql_cnxn.rollback()
+
+ def Close(self):
+ """Safely close any connections that are still open."""
+ for sql_cnxn in self.sql_cnxns.values():
+ try:
+ sql_cnxn.rollback() # Abandon any uncommitted changes.
+ cnxn_pool.release(sql_cnxn)
+ except MySQLdb.DatabaseError:
+ # This might happen if the cnxn is somehow already closed.
+        logging.exception('DatabaseError when trying to close cnxn')
+
+
+class SQLTableManager(object):
+ """Helper class to make it easier to deal with an SQL table."""
+
+ def __init__(self, table_name):
+ self.table_name = table_name
+
+ def Select(
+ self, cnxn, distinct=False, cols=None, left_joins=None,
+ joins=None, where=None, or_where_conds=False, group_by=None,
+ order_by=None, limit=None, offset=None, shard_id=None, use_clause=None,
+ having=None, **kwargs):
+ """Compose and execute an SQL SELECT statement on this table.
+
+ Args:
+ cnxn: MonorailConnection to the databases.
+ distinct: If True, add DISTINCT keyword.
+ cols: List of columns to retrieve, defaults to '*'.
+ left_joins: List of LEFT JOIN (str, args) pairs.
+ joins: List of regular JOIN (str, args) pairs.
+ where: List of (str, args) for WHERE clause.
+ or_where_conds: Set to True to use OR in the WHERE conds.
+ group_by: List of strings for GROUP BY clause.
+ order_by: List of (str, args) for ORDER BY clause.
+ limit: Optional LIMIT on the number of rows returned.
+ offset: Optional OFFSET when using LIMIT.
+ shard_id: Int ID of the shard to query.
+ use_clause: Optional string USE clause to tell the DB which index to use.
+      having: Optional list of (str, args) for the HAVING clause.
+ **kwargs: WHERE-clause equality and set-membership conditions.
+
+ Keyword args are used to build up more WHERE conditions that compare
+    column values to constants. The keyword argument foo='bar' translates to
+    'foo = "bar"', and foo=[3, 4, 5] translates to 'foo IN (3, 4, 5)'.
+
+ Returns:
+ A list of rows, each row is a tuple of values for the requested cols.
+ """
+ cols = cols or ['*'] # If columns not specified, retrieve all columns.
+ stmt = Statement.MakeSelect(
+ self.table_name, cols, distinct=distinct,
+ or_where_conds=or_where_conds)
+ if use_clause:
+ stmt.AddUseClause(use_clause)
+ if having:
+ stmt.AddHavingTerms(having)
+ stmt.AddJoinClauses(left_joins or [], left=True)
+ stmt.AddJoinClauses(joins or [])
+ stmt.AddWhereTerms(where or [], **kwargs)
+ stmt.AddGroupByTerms(group_by or [])
+ stmt.AddOrderByTerms(order_by or [])
+ stmt.SetLimitAndOffset(limit, offset)
+ stmt_str, stmt_args = stmt.Generate()
+
+ cursor = cnxn.Execute(stmt_str, stmt_args, shard_id=shard_id)
+ rows = cursor.fetchall()
+ cursor.close()
+ return rows
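+
+  # Example (hypothetical table and values): with
+  # issue_tbl = SQLTableManager('Issue'), the call
+  #
+  #   issue_tbl.Select(cnxn, cols=['id', 'summary'],
+  #                    project_id=123, status=['New', 'Accepted'])
+  #
+  # generates roughly
+  #
+  #   SELECT id, summary FROM Issue
+  #   WHERE project_id = %s
+  #     AND status IN (%s,%s)
+  #
+  # with args [123, 'New', 'Accepted'], so MySQLdb escapes every value.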
+
+ def SelectRow(
+ self, cnxn, cols=None, default=None, where=None, **kwargs):
+ """Run a query that is expected to return just one row."""
+ rows = self.Select(cnxn, distinct=True, cols=cols, where=where, **kwargs)
+ if len(rows) == 1:
+ return rows[0]
+ elif not rows:
+ logging.info('SelectRow got 0 results, so using default %r', default)
+ return default
+ else:
+      raise ValueError('SelectRow got %d results, expected only 1' % len(rows))
+
+ def SelectValue(self, cnxn, col, default=None, where=None, **kwargs):
+ """Run a query that is expected to return just one row w/ one value."""
+ row = self.SelectRow(
+ cnxn, cols=[col], default=[default], where=where, **kwargs)
+ return row[0]
+
+ def InsertRows(
+ self, cnxn, cols, row_values, replace=False, ignore=False,
+ commit=True, return_generated_ids=False):
+ """Insert all the given rows.
+
+ Args:
+ cnxn: MonorailConnection object.
+ cols: List of column names to set.
+ row_values: List of lists with values to store. The length of each
+ nested list should be equal to len(cols).
+ replace: Set to True if inserted values should replace existing DB rows
+ that have the same DB keys.
+ ignore: Set to True to ignore rows that would duplicate existing DB keys.
+ commit: Set to False if this operation is part of a series of operations
+ that should not be committed until the final one is done.
+ return_generated_ids: Set to True to return a list of generated
+ autoincrement IDs for inserted rows. This requires us to insert rows
+ one at a time.
+
+ Returns:
+ If return_generated_ids is set to True, this method returns a list of the
+ auto-increment IDs generated by the DB. Otherwise, [] is returned.
+ """
+ if not row_values:
+      return []  # Nothing to insert.
+
+ generated_ids = []
+ if return_generated_ids:
+ # We must insert the rows one-at-a-time to know the generated IDs.
+ for row_value in row_values:
+ stmt = Statement.MakeInsert(
+ self.table_name, cols, [row_value], replace=replace, ignore=ignore)
+ stmt_str, stmt_args = stmt.Generate()
+ cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit)
+ if cursor.lastrowid:
+ generated_ids.append(cursor.lastrowid)
+ cursor.close()
+ return generated_ids
+
+ stmt = Statement.MakeInsert(
+ self.table_name, cols, row_values, replace=replace, ignore=ignore)
+ stmt_str, stmt_args = stmt.Generate()
+ cnxn.Execute(stmt_str, stmt_args, commit=commit)
+ return []
+
+ def InsertRow(
+ self, cnxn, replace=False, ignore=False, commit=True, **kwargs):
+ """Insert a single row into the table.
+
+ Args:
+ cnxn: MonorailConnection object.
+ replace: Set to True if inserted values should replace existing DB rows
+ that have the same DB keys.
+ ignore: Set to True to ignore rows that would duplicate existing DB keys.
+ commit: Set to False if this operation is part of a series of operations
+ that should not be committed until the final one is done.
+ **kwargs: column=value assignments to specify what to store in the DB.
+
+ Returns:
+ The generated autoincrement ID of the key column if one was generated.
+ Otherwise, return None.
+ """
+ cols = sorted(kwargs.keys())
+ row = tuple(kwargs[col] for col in cols)
+ generated_ids = self.InsertRows(
+ cnxn, cols, [row], replace=replace, ignore=ignore,
+ commit=commit, return_generated_ids=True)
+ if generated_ids:
+ return generated_ids[0]
+ else:
+ return None
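+
+  # Example (hypothetical values): issue_tbl.InsertRow(
+  #     cnxn, project_id=123, summary='A bug')
+  # generates roughly
+  #
+  #   INSERT INTO Issue (project_id, summary)
+  #   VALUES (%s,%s)
+  #
+  # and returns the autoincrement ID assigned to the new row.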
+
+ def Update(self, cnxn, delta, where=None, commit=True, limit=None, **kwargs):
+ """Update one or more rows.
+
+ Args:
+ cnxn: MonorailConnection object.
+ delta: Dictionary of {column: new_value} assignments.
+ where: Optional list of WHERE conditions saying which rows to update.
+ commit: Set to False if this operation is part of a series of operations
+ that should not be committed until the final one is done.
+ limit: Optional LIMIT on the number of rows updated.
+ **kwargs: WHERE-clause equality and set-membership conditions.
+
+ Returns:
+ Int number of rows updated.
+ """
+ if not delta:
+ return 0 # Nothing is being changed
+
+ stmt = Statement.MakeUpdate(self.table_name, delta)
+ stmt.AddWhereTerms(where, **kwargs)
+ stmt.SetLimitAndOffset(limit, None)
+ stmt_str, stmt_args = stmt.Generate()
+
+ cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit)
+ result = cursor.rowcount
+ cursor.close()
+ return result
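+
+  # Example (hypothetical values): issue_tbl.Update(
+  #     cnxn, {'status': 'Fixed'}, id=[1, 2])
+  # generates roughly
+  #
+  #   UPDATE Issue SET status=%s
+  #   WHERE id IN (%s,%s)
+  #
+  # with args ['Fixed', 1, 2].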
+
+ def IncrementCounterValue(self, cnxn, col_name, where=None, **kwargs):
+ """Atomically increment a counter stored in MySQL, return new value.
+
+ Args:
+ cnxn: MonorailConnection object.
+ col_name: int column to increment.
+ where: Optional list of WHERE conditions saying which rows to update.
+ **kwargs: WHERE-clause equality and set-membership conditions. The
+ where and kwargs together should narrow the update down to exactly
+ one row.
+
+ Returns:
+ The new, post-increment value of the counter.
+ """
+ stmt = Statement.MakeIncrement(self.table_name, col_name)
+ stmt.AddWhereTerms(where, **kwargs)
+ stmt_str, stmt_args = stmt.Generate()
+
+ cursor = cnxn.Execute(stmt_str, stmt_args)
+ assert cursor.rowcount == 1, (
+ 'missing or ambiguous counter: %r' % cursor.rowcount)
+ result = cursor.lastrowid
+ cursor.close()
+ return result
+
+ def Delete(self, cnxn, where=None, or_where_conds=False, commit=True,
+ limit=None, **kwargs):
+ """Delete the specified table rows.
+
+ Args:
+ cnxn: MonorailConnection object.
+ where: Optional list of WHERE conditions saying which rows to update.
+ or_where_conds: Set to True to use OR in the WHERE conds.
+ commit: Set to False if this operation is part of a series of operations
+ that should not be committed until the final one is done.
+ limit: Optional LIMIT on the number of rows deleted.
+ **kwargs: WHERE-clause equality and set-membership conditions.
+
+ Returns:
+      Int number of rows deleted.
+ """
+ # Deleting the whole table is never intended in Monorail.
+ assert where or kwargs
+
+ stmt = Statement.MakeDelete(self.table_name, or_where_conds=or_where_conds)
+ stmt.AddWhereTerms(where, **kwargs)
+ stmt.SetLimitAndOffset(limit, None)
+ stmt_str, stmt_args = stmt.Generate()
+
+ cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit)
+ result = cursor.rowcount
+ cursor.close()
+ return result
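+
+  # Example (hypothetical values): issue_tbl.Delete(cnxn, id=[3, 4])
+  # generates roughly
+  #
+  #   DELETE FROM Issue
+  #   WHERE id IN (%s,%s)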
+
+
+class Statement(object):
+ """A class to help build complex SQL statements w/ full escaping.
+
+ Start with a Make*() method, then fill in additional clauses as needed,
+ then call Generate() to return the SQL string and argument list. We pass
+ the string and args to MySQLdb separately so that it can do escaping on
+ the arg values as appropriate to prevent SQL-injection attacks.
+
+  The only values that are not escaped by MySQLdb are the table names,
+  column names, and bits of SQL syntax, all of which are hard-coded
+  in our application.
+ """
+
+ @classmethod
+ def MakeSelect(cls, table_name, cols, distinct=False, or_where_conds=False):
+ """Construct a SELECT statement."""
+ assert _IsValidTableName(table_name)
+ assert all(_IsValidColumnName(col) for col in cols)
+ main_clause = 'SELECT%s %s FROM %s' % (
+ (' DISTINCT' if distinct else ''), ', '.join(cols), table_name)
+ return cls(main_clause, or_where_conds=or_where_conds)
+
+ @classmethod
+ def MakeInsert(
+ cls, table_name, cols, new_values, replace=False, ignore=False):
+ """Construct an INSERT statement."""
+    if replace:
+ return cls.MakeReplace(table_name, cols, new_values, ignore)
+ assert _IsValidTableName(table_name)
+ assert all(_IsValidColumnName(col) for col in cols)
+ ignore_word = ' IGNORE' if ignore else ''
+ main_clause = 'INSERT%s INTO %s (%s)' % (
+ ignore_word, table_name, ', '.join(cols))
+ return cls(main_clause, insert_args=new_values)
+
+ @classmethod
+ def MakeReplace(
+ cls, table_name, cols, new_values, ignore=False):
+ """Construct an INSERT...ON DUPLICATE KEY UPDATE... statement.
+
+ Uses the INSERT/UPDATE syntax because REPLACE is literally a DELETE
+ followed by an INSERT, which doesn't play well with foreign keys.
+ INSERT/UPDATE is an atomic check of whether the primary key exists,
+ followed by an INSERT if it doesn't or an UPDATE if it does.
+ """
+ assert _IsValidTableName(table_name)
+ assert all(_IsValidColumnName(col) for col in cols)
+ ignore_word = ' IGNORE' if ignore else ''
+ main_clause = 'INSERT%s INTO %s (%s)' % (
+ ignore_word, table_name, ', '.join(cols))
+ return cls(main_clause, insert_args=new_values, duplicate_update_cols=cols)
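+
+  # For example (hypothetical values), MakeReplace('Issue', ['id', 'summary'],
+  # [(1, 'A bug')]) generates roughly
+  #
+  #   INSERT INTO Issue (id, summary)
+  #   VALUES (%s,%s)
+  #   ON DUPLICATE KEY UPDATE id=VALUES(id), summary=VALUES(summary)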
+
+ @classmethod
+ def MakeUpdate(cls, table_name, delta):
+ """Construct an UPDATE statement."""
+ assert _IsValidTableName(table_name)
+ assert all(_IsValidColumnName(col) for col in delta.keys())
+ update_strs = []
+ update_args = []
+ for col, val in delta.items():
+ update_strs.append(col + '=%s')
+ update_args.append(val)
+
+ main_clause = 'UPDATE %s SET %s' % (
+ table_name, ', '.join(update_strs))
+ return cls(main_clause, update_args=update_args)
+
+ @classmethod
+ def MakeIncrement(cls, table_name, col_name, step=1):
+ """Construct an UPDATE statement that increments and returns a counter."""
+ assert _IsValidTableName(table_name)
+ assert _IsValidColumnName(col_name)
+
+ main_clause = (
+ 'UPDATE %s SET %s = LAST_INSERT_ID(%s + %%s)' % (
+ table_name, col_name, col_name))
+ update_args = [step]
+ return cls(main_clause, update_args=update_args)
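+
+  # For example (hypothetical names), MakeIncrement('CounterTable', 'counter')
+  # generates
+  #
+  #   UPDATE CounterTable SET counter = LAST_INSERT_ID(counter + %s)
+  #
+  # MySQL treats LAST_INSERT_ID(expr) as setting the session's last-insert
+  # ID to expr, so cursor.lastrowid yields the post-increment value without
+  # a second query.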
+
+ @classmethod
+ def MakeDelete(cls, table_name, or_where_conds=False):
+ """Construct a DELETE statement."""
+ assert _IsValidTableName(table_name)
+ main_clause = 'DELETE FROM %s' % table_name
+ return cls(main_clause, or_where_conds=or_where_conds)
+
+ def __init__(
+ self, main_clause, insert_args=None, update_args=None,
+ duplicate_update_cols=None, or_where_conds=False):
+ self.main_clause = main_clause # E.g., SELECT or DELETE
+ self.or_where_conds = or_where_conds
+ self.insert_args = insert_args or [] # For INSERT statements
+ for row_value in self.insert_args:
+ if not all(_IsValidDBValue(val) for val in row_value):
+ raise exceptions.InputException('Invalid DB value %r' % (row_value,))
+ self.update_args = update_args or [] # For UPDATEs
+ for val in self.update_args:
+ if not _IsValidDBValue(val):
+ raise exceptions.InputException('Invalid DB value %r' % val)
+ self.duplicate_update_cols = duplicate_update_cols or [] # For REPLACE-ish
+
+ self.use_clauses = []
+ self.join_clauses, self.join_args = [], []
+ self.where_conds, self.where_args = [], []
+ self.having_conds, self.having_args = [], []
+ self.group_by_terms, self.group_by_args = [], []
+ self.order_by_terms, self.order_by_args = [], []
+ self.limit, self.offset = None, None
+
+ def Generate(self):
+ """Return an SQL string having %s placeholders and args to fill them in."""
+ clauses = [self.main_clause] + self.use_clauses + self.join_clauses
+ if self.where_conds:
+ if self.or_where_conds:
+ clauses.append('WHERE ' + '\n OR '.join(self.where_conds))
+ else:
+ clauses.append('WHERE ' + '\n AND '.join(self.where_conds))
+ if self.group_by_terms:
+ clauses.append('GROUP BY ' + ', '.join(self.group_by_terms))
+ if self.having_conds:
+ assert self.group_by_terms
+ clauses.append('HAVING %s' % ','.join(self.having_conds))
+ if self.order_by_terms:
+ clauses.append('ORDER BY ' + ', '.join(self.order_by_terms))
+
+ if self.limit and self.offset:
+ clauses.append('LIMIT %d OFFSET %d' % (self.limit, self.offset))
+ elif self.limit:
+ clauses.append('LIMIT %d' % self.limit)
+ elif self.offset:
+      clauses.append('LIMIT %d OFFSET %d' % (sys.maxsize, self.offset))
+
+ if self.insert_args:
+ clauses.append('VALUES (' + PlaceHolders(self.insert_args[0]) + ')')
+ args = self.insert_args
+ if self.duplicate_update_cols:
+ clauses.append('ON DUPLICATE KEY UPDATE %s' % (
+ ', '.join(['%s=VALUES(%s)' % (col, col)
+ for col in self.duplicate_update_cols])))
+ assert not (self.join_args + self.update_args + self.where_args +
+ self.group_by_args + self.order_by_args + self.having_args)
+ else:
+ args = (self.join_args + self.update_args + self.where_args +
+ self.group_by_args + self.having_args + self.order_by_args)
+ assert not (self.insert_args + self.duplicate_update_cols)
+
+ args = _BoolsToInts(args)
+ stmt_str = '\n'.join(clause for clause in clauses if clause)
+
+ assert _IsValidStatement(stmt_str), stmt_str
+ return stmt_str, args
+
+ def AddUseClause(self, use_clause):
+ """Add a USE clause (giving the DB a hint about which indexes to use)."""
+ assert _IsValidUseClause(use_clause), use_clause
+ self.use_clauses.append(use_clause)
+
+ def AddJoinClauses(self, join_pairs, left=False):
+ """Save JOIN clauses based on the given list of join conditions."""
+ for join, args in join_pairs:
+ assert _IsValidJoin(join), join
+ assert join.count('%s') == len(args), join
+ self.join_clauses.append(
+ ' %sJOIN %s' % (('LEFT ' if left else ''), join))
+ self.join_args.extend(args)
+
+ def AddGroupByTerms(self, group_by_term_list):
+ """Save info needed to generate the GROUP BY clause."""
+ assert all(_IsValidGroupByTerm(term) for term in group_by_term_list)
+ self.group_by_terms.extend(group_by_term_list)
+
+ def AddOrderByTerms(self, order_by_pairs):
+ """Save info needed to generate the ORDER BY clause."""
+ for term, args in order_by_pairs:
+ assert _IsValidOrderByTerm(term), term
+ assert term.count('%s') == len(args), term
+ self.order_by_terms.append(term)
+ self.order_by_args.extend(args)
+
+ def SetLimitAndOffset(self, limit, offset):
+ """Save info needed to generate the LIMIT OFFSET clause."""
+ self.limit = limit
+ self.offset = offset
+
+ def AddWhereTerms(self, where_cond_pairs, **kwargs):
+ """Generate a WHERE clause."""
+ where_cond_pairs = where_cond_pairs or []
+
+ for cond, args in where_cond_pairs:
+ assert _IsValidWhereCond(cond), cond
+ assert cond.count('%s') == len(args), cond
+ self.where_conds.append(cond)
+ self.where_args.extend(args)
+
+ for col, val in sorted(kwargs.items()):
+ assert _IsValidColumnName(col), col
+ eq = True
+ if col.endswith('_not'):
+ col = col[:-4]
+ eq = False
+
+ if isinstance(val, set):
+        val = list(val)  # The MySQL interface cannot handle sets.
+
+ if val is None or val == []:
+ if val == [] and self.main_clause and self.main_clause.startswith(
+ 'UPDATE'):
+ # https://crbug.com/monorail/6735: Avoid empty arrays for UPDATE.
+ raise exceptions.InputException('Invalid update DB value %r' % col)
+ op = 'IS' if eq else 'IS NOT'
+ self.where_conds.append(col + ' ' + op + ' NULL')
+ elif isinstance(val, list):
+ op = 'IN' if eq else 'NOT IN'
+ # Sadly, MySQLdb cannot escape lists, so we flatten to multiple "%s"s
+ self.where_conds.append(
+ col + ' ' + op + ' (' + PlaceHolders(val) + ')')
+ self.where_args.extend(val)
+ else:
+ op = '=' if eq else '!='
+ self.where_conds.append(col + ' ' + op + ' %s')
+ self.where_args.append(val)
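+
+  # Examples of the keyword translation (hypothetical columns and values):
+  #   status='New'             ->  status = %s
+  #   status_not='New'         ->  status != %s
+  #   status=['New', 'Fixed']  ->  status IN (%s,%s)
+  #   owner_id=None            ->  owner_id IS NULL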
+
+ def AddHavingTerms(self, having_cond_pairs):
+ """Generate a HAVING clause."""
+ for cond, args in having_cond_pairs:
+ assert _IsValidHavingCond(cond), cond
+ assert cond.count('%s') == len(args), cond
+ self.having_conds.append(cond)
+ self.having_args.extend(args)
+
+
+def PlaceHolders(sql_args):
+ """Return a comma-separated list of %s placeholders for the given args."""
+ return ','.join('%s' for _ in sql_args)
+
+
+TABLE_PAT = '[A-Z][_a-zA-Z0-9]+'
+COLUMN_PAT = '[a-z][_a-z]+'
+COMPARE_OP_PAT = '(<|>|=|!=|>=|<=|LIKE|NOT LIKE)'
+SHORTHAND = {
+ 'table': TABLE_PAT,
+ 'column': COLUMN_PAT,
+ 'tab_col': r'(%s\.)?%s' % (TABLE_PAT, COLUMN_PAT),
+ 'placeholder': '%s', # That's a literal %s that gets passed to MySQLdb
+ 'multi_placeholder': '%s(, ?%s)*',
+ 'compare_op': COMPARE_OP_PAT,
+ 'opt_asc_desc': '( ASC| DESC)?',
+ 'opt_alias': '( AS %s)?' % TABLE_PAT,
+ 'email_cond': (r'\(?'
+ r'('
+ r'(LOWER\(Spare\d+\.email\) IS NULL OR )?'
+ r'LOWER\(Spare\d+\.email\) '
+ r'(%s %%s|IN \(%%s(, ?%%s)*\))'
+ r'( (AND|OR) )?'
+ r')+'
+ r'\)?' % COMPARE_OP_PAT),
+ 'hotlist_cond': (r'\(?'
+ r'('
+ r'(LOWER\(Cond\d+\.name\) IS NULL OR )?'
+ r'LOWER\(Cond\d+\.name\) '
+ r'(%s %%s|IN \(%%s(, ?%%s)*\))'
+ r'( (AND|OR) )?'
+ r')+'
+ r'\)?' % COMPARE_OP_PAT),
+ 'phase_cond': (r'\(?'
+ r'('
+ r'(LOWER\(Phase\d+\.name\) IS NULL OR )?'
+ r'LOWER\(Phase\d+\.name\) '
+ r'(%s %%s|IN \(%%s(, ?%%s)*\))?'
+ r'( (AND|OR) )?'
+ r')+'
+ r'\)?' % COMPARE_OP_PAT),
+ 'approval_cond': (r'\(?'
+ r'('
+ r'(LOWER\(Cond\d+\.status\) IS NULL OR )?'
+ r'LOWER\(Cond\d+\.status\) '
+ r'(%s %%s|IN \(%%s(, ?%%s)*\))'
+ r'( (AND|OR) )?'
+ r')+'
+ r'\)?' % COMPARE_OP_PAT),
+ }
+
+
+def _MakeRE(regex_str):
+ """Return a regular expression object, expanding our shorthand as needed."""
+ return re.compile(regex_str.format(**SHORTHAND))
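+
+
+# For example, _MakeRE('^{tab_col}$') compiles to the pattern
+# r'^([A-Z][_a-zA-Z0-9]+\.)?[a-z][_a-z]+$', i.e. a column name with an
+# optional table prefix.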
+
+
+TABLE_RE = _MakeRE('^{table}$')
+TAB_COL_RE = _MakeRE('^{tab_col}$')
+USE_CLAUSE_RE = _MakeRE(
+ r'^USE INDEX \({column}\) USE INDEX FOR ORDER BY \({column}\)$')
+HAVING_RE_LIST = [
+ _MakeRE(r'^COUNT\(\*\) {compare_op} {placeholder}$')]
+COLUMN_RE_LIST = [
+ TAB_COL_RE,
+ _MakeRE(r'\*'),
+ _MakeRE(r'COUNT\(\*\)'),
+ _MakeRE(r'COUNT\({tab_col}\)'),
+ _MakeRE(r'COUNT\(DISTINCT\({tab_col}\)\)'),
+ _MakeRE(r'MAX\({tab_col}\)'),
+ _MakeRE(r'MIN\({tab_col}\)'),
+ _MakeRE(r'GROUP_CONCAT\((DISTINCT )?{tab_col}( ORDER BY {tab_col})?' \
+ r'( SEPARATOR \'.*\')?\)'),
+ ]
+JOIN_RE_LIST = [
+ TABLE_RE,
+ _MakeRE(
+ r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
+ r'( AND {tab_col} = {tab_col})?'
+ r'( AND {tab_col} IN \({multi_placeholder}\))?$'),
+ _MakeRE(
+ r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
+ r'( AND {tab_col} = {tab_col})?'
+ r'( AND {tab_col} = {placeholder})?'
+ r'( AND {tab_col} IN \({multi_placeholder}\))?'
+ r'( AND {tab_col} = {tab_col})?$'),
+ _MakeRE(
+ r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
+ r'( AND {tab_col} = {tab_col})?'
+ r'( AND {tab_col} = {placeholder})?'
+ r'( AND {tab_col} IN \({multi_placeholder}\))?'
+ r'( AND {tab_col} IS NULL)?'
+ r'( AND \({tab_col} IS NULL'
+ r' OR {tab_col} NOT IN \({multi_placeholder}\)\))?$'),
+ _MakeRE(
+ r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
+ r'( AND {tab_col} = {tab_col})?'
+ r'( AND {tab_col} = {placeholder})?'
+ r' AND \(?{tab_col} {compare_op} {placeholder}\)?'
+ r'( AND {tab_col} = {tab_col})?$'),
+ _MakeRE(
+ r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
+ r'( AND {tab_col} = {tab_col})?'
+ r'( AND {tab_col} = {placeholder})?'
+ r' AND {tab_col} = {tab_col}$'),
+ _MakeRE(
+ r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
+ r'( AND {tab_col} = {tab_col})?'
+ r'( AND {tab_col} = {placeholder})?'
+ r' AND \({tab_col} IS NULL OR'
+ r' {tab_col} {compare_op} {placeholder}\)$'),
+ _MakeRE(
+ r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
+ r' AND \({tab_col} IS NOT NULL AND {tab_col} != {placeholder}\)'),
+ _MakeRE(
+ r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
+ r' AND LOWER\({tab_col}\) = LOWER\({placeholder}\)'),
+ _MakeRE(
+ r'^{table}{opt_alias} ON {tab_col} = {tab_col} AND {email_cond}$'),
+ _MakeRE(
+ r'^{table}{opt_alias} ON {email_cond}$'),
+ _MakeRE(
+ r'^{table}{opt_alias} ON '
+ r'\({tab_col} = {tab_col} OR {tab_col} = {tab_col}\)$'),
+ _MakeRE(
+ r'^\({table} AS {table} JOIN User AS {table} '
+ r'ON {tab_col} = {tab_col} AND {email_cond}\) '
+ r'ON Issue(Snapshot)?.id = {tab_col}'
+ r'( AND {tab_col} IS NULL)?'),
+ _MakeRE(
+ r'^\({table} JOIN Hotlist AS {table} '
+ r'ON {tab_col} = {tab_col} AND {hotlist_cond}\) '
+ r'ON Issue.id = {tab_col}?'),
+ _MakeRE(
+ r'^\({table} AS {table} JOIN IssuePhaseDef AS {table} '
+ r'ON {tab_col} = {tab_col} AND {phase_cond}\) '
+ r'ON Issue.id = {tab_col}?'),
+ _MakeRE(
+ r'^IssuePhaseDef AS {table} ON {phase_cond}'),
+ _MakeRE(
+ r'^Issue2ApprovalValue AS {table} ON {tab_col} = {tab_col} '
+ r'AND {tab_col} = {placeholder} AND {approval_cond}'),
+ _MakeRE(
+ r'^{table} AS {table} ON {tab_col} = {tab_col} '
+ r'LEFT JOIN {table} AS {table} ON {tab_col} = {tab_col}'),
+ ]
+ORDER_BY_RE_LIST = [
+ _MakeRE(r'^{tab_col}{opt_asc_desc}$'),
+ _MakeRE(r'^LOWER\({tab_col}\){opt_asc_desc}$'),
+ _MakeRE(r'^ISNULL\({tab_col}\){opt_asc_desc}$'),
+ _MakeRE(r'^\(ISNULL\({tab_col}\) AND ISNULL\({tab_col}\)\){opt_asc_desc}$'),
+ _MakeRE(r'^FIELD\({tab_col}, {multi_placeholder}\){opt_asc_desc}$'),
+ _MakeRE(r'^FIELD\(IF\(ISNULL\({tab_col}\), {tab_col}, {tab_col}\), '
+ r'{multi_placeholder}\){opt_asc_desc}$'),
+ _MakeRE(r'^CONCAT\({tab_col}, {tab_col}\){opt_asc_desc}$'),
+ ]
+GROUP_BY_RE_LIST = [
+ TAB_COL_RE,
+ ]
+WHERE_COND_RE_LIST = [
+ _MakeRE(r'^TRUE$'),
+ _MakeRE(r'^FALSE$'),
+ _MakeRE(r'^{tab_col} IS NULL$'),
+ _MakeRE(r'^{tab_col} IS NOT NULL$'),
+ _MakeRE(r'^{tab_col} {compare_op} {tab_col}$'),
+ _MakeRE(r'^{tab_col} {compare_op} {placeholder}$'),
+ _MakeRE(r'^{tab_col} %% {placeholder} = {placeholder}$'),
+ _MakeRE(r'^{tab_col} IN \({multi_placeholder}\)$'),
+ _MakeRE(r'^{tab_col} NOT IN \({multi_placeholder}\)$'),
+ _MakeRE(r'^LOWER\({tab_col}\) IS NULL$'),
+ _MakeRE(r'^LOWER\({tab_col}\) IS NOT NULL$'),
+ _MakeRE(r'^LOWER\({tab_col}\) {compare_op} {placeholder}$'),
+ _MakeRE(r'^LOWER\({tab_col}\) IN \({multi_placeholder}\)$'),
+ _MakeRE(r'^LOWER\({tab_col}\) NOT IN \({multi_placeholder}\)$'),
+ _MakeRE(r'^LOWER\({tab_col}\) LIKE {placeholder}$'),
+ _MakeRE(r'^LOWER\({tab_col}\) NOT LIKE {placeholder}$'),
+ _MakeRE(r'^timestep < \(SELECT MAX\(j.timestep\) FROM Invalidate AS j '
+ r'WHERE j.kind = %s '
+ r'AND j.cache_key = Invalidate.cache_key\)$'),
+  _MakeRE(r'^\({tab_col} IS NULL OR {tab_col} {compare_op} {placeholder}\) '
+          r'AND \({tab_col} IS NULL OR {tab_col} {compare_op} {placeholder}'
+          r'\)$'),
+  _MakeRE(r'^\({tab_col} IS NOT NULL AND {tab_col} {compare_op} '
+          r'{placeholder}\) OR \({tab_col} IS NOT NULL AND {tab_col} '
+          r'{compare_op} {placeholder}\)$'),
+ ]
+
+# Note: We never use ';' for multiple statements, '@' for SQL variables, or
+# any quoted strings in stmt_str (quotes are put in by MySQLdb for args).
+STMT_STR_RE = re.compile(
+ r'\A(SELECT|UPDATE|DELETE|INSERT|REPLACE) [\'-+=!<>%*.,()\w\s]+\Z',
+ re.MULTILINE)
+
+
+def _IsValidDBValue(val):
+ if isinstance(val, string_types):
+ return '\x00' not in val
+ return True
+
+
+def _IsValidTableName(table_name):
+ return TABLE_RE.match(table_name)
+
+
+def _IsValidColumnName(column_expr):
+ return any(regex.match(column_expr) for regex in COLUMN_RE_LIST)
+
+
+def _IsValidUseClause(use_clause):
+ return USE_CLAUSE_RE.match(use_clause)
+
+
+def _IsValidHavingCond(cond):
+ if cond.startswith('(') and cond.endswith(')'):
+ cond = cond[1:-1]
+
+ if ' OR ' in cond:
+ return all(_IsValidHavingCond(c) for c in cond.split(' OR '))
+
+ if ' AND ' in cond:
+ return all(_IsValidHavingCond(c) for c in cond.split(' AND '))
+
+ return any(regex.match(cond) for regex in HAVING_RE_LIST)
+
+
+def _IsValidJoin(join):
+ return any(regex.match(join) for regex in JOIN_RE_LIST)
+
+
+def _IsValidOrderByTerm(term):
+ return any(regex.match(term) for regex in ORDER_BY_RE_LIST)
+
+
+def _IsValidGroupByTerm(term):
+ return any(regex.match(term) for regex in GROUP_BY_RE_LIST)
+
+
+def _IsValidWhereCond(cond):
+ if cond.startswith('NOT '):
+ cond = cond[4:]
+ if cond.startswith('(') and cond.endswith(')'):
+ cond = cond[1:-1]
+
+ if any(regex.match(cond) for regex in WHERE_COND_RE_LIST):
+ return True
+
+ if ' OR ' in cond:
+ return all(_IsValidWhereCond(c) for c in cond.split(' OR '))
+
+ if ' AND ' in cond:
+ return all(_IsValidWhereCond(c) for c in cond.split(' AND '))
+
+ return False
+
+
+def _IsValidStatement(stmt_str):
+ """Final check to make sure there is no funny junk sneaking in somehow."""
+ return (STMT_STR_RE.match(stmt_str) and
+ '--' not in stmt_str)
+
+
+def _BoolsToInts(arg_list):
+ """Convert any True values to 1s and Falses to 0s.
+
+  Google's copy of MySQLdb has bool-to-int conversion disabled, yet we
+  still seem to need it: otherwise boolean args are converted to strings
+  and always interpreted as 0 (which is FALSE).
+
+ Args:
+    arg_list: (nested) list of SQL statement argument values, which may
+ include some boolean values.
+
+ Returns:
+ The same list, but with True replaced by 1 and False replaced by 0.
+ """
+ result = []
+ for arg in arg_list:
+ if isinstance(arg, (list, tuple)):
+ result.append(_BoolsToInts(arg))
+ elif arg is True:
+ result.append(1)
+ elif arg is False:
+ result.append(0)
+ else:
+ result.append(arg)
+
+ return result
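+
+
+# For example, _BoolsToInts([True, 5, [False, 'x']]) returns [1, 5, [0, 'x']].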