# Copyright 2016 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

"""A set of classes for interacting with tables in SQL."""
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

import logging
import random
import re
import sys
import time

from six import string_types

import settings

if not settings.unit_test_mode:
  import MySQLdb

from framework import exceptions
from framework import framework_helpers

from infra_libs import ts_mon

from Queue import Queue


class ConnectionPool(object):
  """Manage a set of database connections such that they may be re-used.
  """

  def __init__(self, poolsize=1):
    self.poolsize = poolsize
    self.queues = {}

  def get(self, instance, database):
    """Return a database connection, or raise an exception if none can
    be made.
    """
    key = instance + '/' + database

    if key not in self.queues:
      queue = Queue(self.poolsize)
      self.queues[key] = queue

    queue = self.queues[key]

    if queue.empty():
      cnxn = cnxn_ctor(instance, database)
    else:
      cnxn = queue.get()
      # Make sure the connection is still good.
      cnxn.ping()
      cnxn.commit()

    return cnxn

  def release(self, cnxn):
    if cnxn.pool_key not in self.queues:
      raise BaseException('unknown pool key: %s' % cnxn.pool_key)

    q = self.queues[cnxn.pool_key]
    if q.full():
      cnxn.close()
    else:
      q.put(cnxn)

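# An illustrative usage sketch for the pool (not executed here; the instance
# and database names are hypothetical):
#
#   cnxn = cnxn_pool.get('project:region:instance', 'monorail')
#   try:
#     ...  # run queries on cnxn
#   finally:
#     cnxn_pool.release(cnxn)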

@framework_helpers.retry(1, delay=1, backoff=2)
def cnxn_ctor(instance, database):
  logging.info('About to connect to SQL instance %r db %r', instance, database)
  if settings.unit_test_mode:
    raise ValueError('unit tests should not need real database connections')
  try:
    if settings.local_mode:
      start_time = time.time()
      cnxn = MySQLdb.connect(
          host='127.0.0.1', port=3306, db=database, user='root', charset='utf8')
    else:
      start_time = time.time()
      cnxn = MySQLdb.connect(
          unix_socket='/cloudsql/' + instance,
          db=database,
          user='monorail-mysql',
          charset='utf8')
    duration = int((time.time() - start_time) * 1000)
    DB_CNXN_LATENCY.add(duration)
    CONNECTION_COUNT.increment({'success': True})
  except MySQLdb.OperationalError:
    CONNECTION_COUNT.increment({'success': False})
    raise
  cnxn.pool_key = instance + '/' + database
  cnxn.is_bad = False
  return cnxn


# One connection pool per database instance (primary, replicas are each an
# instance). We'll have four connections per instance because we fetch
# issue comments, stars, spam verdicts and spam verdict history in parallel
# with promises.
cnxn_pool = ConnectionPool(settings.db_cnxn_pool_size)

# MonorailConnection maintains a dictionary of connections to SQL databases.
# Each is identified by an int shard ID.
# There is also one connection to the primary DB, identified by key PRIMARY_CNXN.
PRIMARY_CNXN = 'primary_cnxn'

# When one replica is temporarily unresponsive, we can use a different one.
BAD_SHARD_AVOIDANCE_SEC = 45


CONNECTION_COUNT = ts_mon.CounterMetric(
    'monorail/sql/connection_count',
    'Count of connections made to the SQL database.',
    [ts_mon.BooleanField('success')])

DB_CNXN_LATENCY = ts_mon.CumulativeDistributionMetric(
    'monorail/sql/db_cnxn_latency',
    'Time needed to establish a DB connection.',
    None)

DB_QUERY_LATENCY = ts_mon.CumulativeDistributionMetric(
    'monorail/sql/db_query_latency',
    'Time needed to make a DB query.',
    [ts_mon.StringField('type')])

DB_COMMIT_LATENCY = ts_mon.CumulativeDistributionMetric(
    'monorail/sql/db_commit_latency',
    'Time needed to make a DB commit.',
    None)

DB_ROLLBACK_LATENCY = ts_mon.CumulativeDistributionMetric(
    'monorail/sql/db_rollback_latency',
    'Time needed to make a DB rollback.',
    None)

DB_RETRY_COUNT = ts_mon.CounterMetric(
    'monorail/sql/db_retry_count',
    'Count of queries retried.',
    None)

DB_QUERY_COUNT = ts_mon.CounterMetric(
    'monorail/sql/db_query_count',
    'Count of queries sent to the DB.',
    [ts_mon.StringField('type')])

DB_COMMIT_COUNT = ts_mon.CounterMetric(
    'monorail/sql/db_commit_count',
    'Count of commits sent to the DB.',
    None)

DB_ROLLBACK_COUNT = ts_mon.CounterMetric(
    'monorail/sql/db_rollback_count',
    'Count of rollbacks sent to the DB.',
    None)

DB_RESULT_ROWS = ts_mon.CumulativeDistributionMetric(
    'monorail/sql/db_result_rows',
    'Number of results returned by a DB query.',
    None)


def RandomShardID():
  """Return a random shard ID to load balance across replicas."""
  return random.randint(0, settings.num_logical_shards - 1)


class MonorailConnection(object):
  """Create and manage connections to the SQL servers.

  We only store connections in the context of a single user request, not
  across user requests. The main purpose of this class is to make using
  sharded tables easier.
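
  A rough request-scoped usage sketch (issue_tbl below is a hypothetical
  SQLTableManager instance):

    cnxn = MonorailConnection()
    try:
      rows = issue_tbl.Select(cnxn, cols=['id'], shard_id=RandomShardID())
    finally:
      cnxn.Close()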
  """
  unavailable_shards = {}  # {shard_id: timestamp of failed attempt}

  def __init__(self):
    self.sql_cnxns = {}  # {PRIMARY_CNXN: cnxn, shard_id: cnxn, ...}

  @framework_helpers.retry(1, delay=0.1, backoff=2)
  def GetPrimaryConnection(self):
    """Return a connection to the primary SQL DB."""
    if PRIMARY_CNXN not in self.sql_cnxns:
      self.sql_cnxns[PRIMARY_CNXN] = cnxn_pool.get(
          settings.db_instance, settings.db_database_name)
      logging.info(
          'created a primary connection %r', self.sql_cnxns[PRIMARY_CNXN])

    return self.sql_cnxns[PRIMARY_CNXN]

  @framework_helpers.retry(1, delay=0.1, backoff=2)
  def GetConnectionForShard(self, shard_id):
    """Return a connection to the DB replica that will be used for shard_id."""
    if shard_id not in self.sql_cnxns:
      physical_shard_id = shard_id % settings.num_logical_shards

      replica_name = settings.db_replica_names[
          physical_shard_id % len(settings.db_replica_names)]
      shard_instance_name = (
          settings.physical_db_name_format % replica_name)
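      # Mark this shard as unavailable while we try to connect; if the
      # connection attempt raises, the entry stays and Execute() will avoid
      # this shard for BAD_SHARD_AVOIDANCE_SEC seconds.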
      self.unavailable_shards[shard_id] = int(time.time())
      self.sql_cnxns[shard_id] = cnxn_pool.get(
          shard_instance_name, settings.db_database_name)
      del self.unavailable_shards[shard_id]
      logging.info('created a replica connection for shard %d', shard_id)

    return self.sql_cnxns[shard_id]

  def Execute(self, stmt_str, stmt_args, shard_id=None, commit=True, retries=2):
    """Execute the given SQL statement on one of the relevant databases."""
    if shard_id is None:
      # No shard was specified, so hit the primary.
      sql_cnxn = self.GetPrimaryConnection()
    else:
      if shard_id in self.unavailable_shards:
        bad_age_sec = int(time.time()) - self.unavailable_shards[shard_id]
        if bad_age_sec < BAD_SHARD_AVOIDANCE_SEC:
          logging.info('Avoiding bad replica %r, age %r', shard_id, bad_age_sec)
          shard_id = (shard_id + 1) % settings.num_logical_shards
      sql_cnxn = self.GetConnectionForShard(shard_id)

    try:
      return self._ExecuteWithSQLConnection(
          sql_cnxn, stmt_str, stmt_args, commit=commit)
    except MySQLdb.OperationalError as e:
      logging.exception(e)
      logging.info('retries: %r', retries)
      if retries > 0:
        DB_RETRY_COUNT.increment()
        self.sql_cnxns = {}  # Drop all old mysql connections and make new.
        return self.Execute(
            stmt_str, stmt_args, shard_id=shard_id, commit=commit,
            retries=retries - 1)
      else:
        raise e

  def _ExecuteWithSQLConnection(
      self, sql_cnxn, stmt_str, stmt_args, commit=True):
    """Execute a statement on the given database and return a cursor."""

    start_time = time.time()
    cursor = sql_cnxn.cursor()
    cursor.execute('SET NAMES utf8mb4')
    if stmt_str.startswith('INSERT') or stmt_str.startswith('REPLACE'):
      cursor.executemany(stmt_str, stmt_args)
      duration = (time.time() - start_time) * 1000
      DB_QUERY_LATENCY.add(duration, {'type': 'write'})
      DB_QUERY_COUNT.increment({'type': 'write'})
    else:
      cursor.execute(stmt_str, args=stmt_args)
      duration = (time.time() - start_time) * 1000
      DB_QUERY_LATENCY.add(duration, {'type': 'read'})
      DB_QUERY_COUNT.increment({'type': 'read'})
    DB_RESULT_ROWS.add(cursor.rowcount)

    if stmt_str.startswith('INSERT') or stmt_str.startswith('REPLACE'):
      formatted_statement = '%s %s' % (stmt_str, stmt_args)
    else:
      formatted_statement = stmt_str % tuple(stmt_args)
    logging.info(
        '%d rows in %d ms: %s', cursor.rowcount, int(duration),
        formatted_statement.replace('\n', ' '))

    if commit and not stmt_str.startswith('SELECT'):
      try:
        sql_cnxn.commit()
        duration = (time.time() - start_time) * 1000
        DB_COMMIT_LATENCY.add(duration)
        DB_COMMIT_COUNT.increment()
      except MySQLdb.DatabaseError:
        sql_cnxn.rollback()
        duration = (time.time() - start_time) * 1000
        DB_ROLLBACK_LATENCY.add(duration)
        DB_ROLLBACK_COUNT.increment()

    return cursor

  def Commit(self):
    """Explicitly commit any pending txns. Normally done automatically."""
    sql_cnxn = self.GetPrimaryConnection()
    try:
      sql_cnxn.commit()
    except MySQLdb.DatabaseError:
      logging.exception('Commit failed for cnxn, rolling back')
      sql_cnxn.rollback()

  def Close(self):
    """Safely close any connections that are still open."""
    for sql_cnxn in self.sql_cnxns.values():
      try:
        sql_cnxn.rollback()  # Abandon any uncommitted changes.
        cnxn_pool.release(sql_cnxn)
      except MySQLdb.DatabaseError:
        # This might happen if the cnxn is somehow already closed.
        logging.exception('ProgrammingError when trying to close cnxn')


class SQLTableManager(object):
  """Helper class to make it easier to deal with an SQL table."""

  def __init__(self, table_name):
    self.table_name = table_name

  def Select(
      self, cnxn, distinct=False, cols=None, left_joins=None,
      joins=None, where=None, or_where_conds=False, group_by=None,
      order_by=None, limit=None, offset=None, shard_id=None, use_clause=None,
      having=None, **kwargs):
    """Compose and execute an SQL SELECT statement on this table.

    Args:
      cnxn: MonorailConnection to the databases.
      distinct: If True, add DISTINCT keyword.
      cols: List of columns to retrieve, defaults to '*'.
      left_joins: List of LEFT JOIN (str, args) pairs.
      joins: List of regular JOIN (str, args) pairs.
      where: List of (str, args) for WHERE clause.
      or_where_conds: Set to True to use OR in the WHERE conds.
      group_by: List of strings for GROUP BY clause.
      order_by: List of (str, args) for ORDER BY clause.
      limit: Optional LIMIT on the number of rows returned.
      offset: Optional OFFSET when using LIMIT.
      shard_id: Int ID of the shard to query.
      use_clause: Optional string USE clause to tell the DB which index to use.
      having: List of (str, args) for an optional HAVING clause.
      **kwargs: WHERE-clause equality and set-membership conditions.

    Keyword args are used to build up more WHERE conditions that compare
    column values to constants. The keyword argument foo='bar' translates to
    'foo = "bar"', and foo=[3, 4, 5] translates to 'foo IN (3, 4, 5)'.
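
    For example (an illustrative sketch; the column names are hypothetical),
    Select(cnxn, cols=['email'], user_id=[1, 2], rank_not=None) adds
    'user_id IN (%s,%s)' and 'rank IS NOT NULL' to the WHERE clause.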

    Returns:
      A list of rows, each row is a tuple of values for the requested cols.
    """
    cols = cols or ['*']  # If columns not specified, retrieve all columns.
    stmt = Statement.MakeSelect(
        self.table_name, cols, distinct=distinct,
        or_where_conds=or_where_conds)
    if use_clause:
      stmt.AddUseClause(use_clause)
    if having:
      stmt.AddHavingTerms(having)
    stmt.AddJoinClauses(left_joins or [], left=True)
    stmt.AddJoinClauses(joins or [])
    stmt.AddWhereTerms(where or [], **kwargs)
    stmt.AddGroupByTerms(group_by or [])
    stmt.AddOrderByTerms(order_by or [])
    stmt.SetLimitAndOffset(limit, offset)
    stmt_str, stmt_args = stmt.Generate()

    cursor = cnxn.Execute(stmt_str, stmt_args, shard_id=shard_id)
    rows = cursor.fetchall()
    cursor.close()
    return rows

  def SelectRow(
      self, cnxn, cols=None, default=None, where=None, **kwargs):
    """Run a query that is expected to return just one row."""
    rows = self.Select(cnxn, distinct=True, cols=cols, where=where, **kwargs)
    if len(rows) == 1:
      return rows[0]
    elif not rows:
      logging.info('SelectRow got 0 results, so using default %r', default)
      return default
    else:
      raise ValueError('SelectRow got %d results, expected only 1' % len(rows))

  def SelectValue(self, cnxn, col, default=None, where=None, **kwargs):
    """Run a query that is expected to return just one row w/ one value."""
    row = self.SelectRow(
        cnxn, cols=[col], default=[default], where=where, **kwargs)
    return row[0]

  def InsertRows(
      self, cnxn, cols, row_values, replace=False, ignore=False,
      commit=True, return_generated_ids=False):
    """Insert all the given rows.

    Args:
      cnxn: MonorailConnection object.
      cols: List of column names to set.
      row_values: List of lists with values to store. The length of each
        nested list should be equal to len(cols).
      replace: Set to True if inserted values should replace existing DB rows
        that have the same DB keys.
      ignore: Set to True to ignore rows that would duplicate existing DB keys.
      commit: Set to False if this operation is part of a series of operations
        that should not be committed until the final one is done.
      return_generated_ids: Set to True to return a list of generated
        autoincrement IDs for inserted rows. This requires us to insert rows
        one at a time.

    Returns:
      If return_generated_ids is set to True, this method returns a list of the
      auto-increment IDs generated by the DB. Otherwise, [] is returned.
    """
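    # Illustrative shapes only (hypothetical data): cols=['issue_id', 'rank']
    # with row_values=[(1, 10), (2, 20)] inserts two rows in a single call.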
    if not row_values:
      return None  # Nothing to insert

    generated_ids = []
    if return_generated_ids:
      # We must insert the rows one-at-a-time to know the generated IDs.
      for row_value in row_values:
        stmt = Statement.MakeInsert(
            self.table_name, cols, [row_value], replace=replace, ignore=ignore)
        stmt_str, stmt_args = stmt.Generate()
        cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit)
        if cursor.lastrowid:
          generated_ids.append(cursor.lastrowid)
        cursor.close()
      return generated_ids

    stmt = Statement.MakeInsert(
        self.table_name, cols, row_values, replace=replace, ignore=ignore)
    stmt_str, stmt_args = stmt.Generate()
    cnxn.Execute(stmt_str, stmt_args, commit=commit)
    return []


  def InsertRow(
      self, cnxn, replace=False, ignore=False, commit=True, **kwargs):
    """Insert a single row into the table.

    Args:
      cnxn: MonorailConnection object.
      replace: Set to True if inserted values should replace existing DB rows
        that have the same DB keys.
      ignore: Set to True to ignore rows that would duplicate existing DB keys.
      commit: Set to False if this operation is part of a series of operations
        that should not be committed until the final one is done.
      **kwargs: column=value assignments to specify what to store in the DB.

    Returns:
      The generated autoincrement ID of the key column if one was generated.
      Otherwise, return None.
    """
    cols = sorted(kwargs.keys())
    row = tuple(kwargs[col] for col in cols)
    generated_ids = self.InsertRows(
        cnxn, cols, [row], replace=replace, ignore=ignore,
        commit=commit, return_generated_ids=True)
    if generated_ids:
      return generated_ids[0]
    else:
      return None

  def Update(self, cnxn, delta, where=None, commit=True, limit=None, **kwargs):
    """Update one or more rows.

    Args:
      cnxn: MonorailConnection object.
      delta: Dictionary of {column: new_value} assignments.
      where: Optional list of WHERE conditions saying which rows to update.
      commit: Set to False if this operation is part of a series of operations
        that should not be committed until the final one is done.
      limit: Optional LIMIT on the number of rows updated.
      **kwargs: WHERE-clause equality and set-membership conditions.

    Returns:
      Int number of rows updated.
    """
    if not delta:
      return 0  # Nothing is being changed

    stmt = Statement.MakeUpdate(self.table_name, delta)
    stmt.AddWhereTerms(where, **kwargs)
    stmt.SetLimitAndOffset(limit, None)
    stmt_str, stmt_args = stmt.Generate()

    cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit)
    result = cursor.rowcount
    cursor.close()
    return result

  def IncrementCounterValue(self, cnxn, col_name, where=None, **kwargs):
    """Atomically increment a counter stored in MySQL, return new value.

    Args:
      cnxn: MonorailConnection object.
      col_name: int column to increment.
      where: Optional list of WHERE conditions saying which rows to update.
      **kwargs: WHERE-clause equality and set-membership conditions. The
        where and kwargs together should narrow the update down to exactly
        one row.

    Returns:
      The new, post-increment value of the counter.
    """
    stmt = Statement.MakeIncrement(self.table_name, col_name)
    stmt.AddWhereTerms(where, **kwargs)
    stmt_str, stmt_args = stmt.Generate()

    cursor = cnxn.Execute(stmt_str, stmt_args)
    assert cursor.rowcount == 1, (
        'missing or ambiguous counter: %r' % cursor.rowcount)
    result = cursor.lastrowid
    cursor.close()
    return result

  def Delete(self, cnxn, where=None, or_where_conds=False, commit=True,
             limit=None, **kwargs):
    """Delete the specified table rows.

    Args:
      cnxn: MonorailConnection object.
      where: Optional list of WHERE conditions saying which rows to delete.
      or_where_conds: Set to True to use OR in the WHERE conds.
      commit: Set to False if this operation is part of a series of operations
        that should not be committed until the final one is done.
      limit: Optional LIMIT on the number of rows deleted.
      **kwargs: WHERE-clause equality and set-membership conditions.

    Returns:
      Int number of rows deleted.
    """
    # Deleting the whole table is never intended in Monorail.
    assert where or kwargs

    stmt = Statement.MakeDelete(self.table_name, or_where_conds=or_where_conds)
    stmt.AddWhereTerms(where, **kwargs)
    stmt.SetLimitAndOffset(limit, None)
    stmt_str, stmt_args = stmt.Generate()

    cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit)
    result = cursor.rowcount
    cursor.close()
    return result


class Statement(object):
  """A class to help build complex SQL statements w/ full escaping.

  Start with a Make*() method, then fill in additional clauses as needed,
  then call Generate() to return the SQL string and argument list. We pass
  the string and args to MySQLdb separately so that it can do escaping on
  the arg values as appropriate to prevent SQL-injection attacks.

  The only values that are not escaped by MySQLdb are the table names
  and column names, and bits of SQL syntax, all of which are hard-coded
  in our application.
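
  An illustrative sketch of the pattern (the table and column names are
  hypothetical):

    stmt = Statement.MakeSelect('Issue', ['id', 'summary'])
    stmt.AddWhereTerms([('Issue.project_id = %s', [789])])
    stmt_str, args = stmt.Generate()

  Here stmt_str holds the SELECT and WHERE clauses with a %s placeholder,
  and args is [789]; both are then passed on to MySQLdb together.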
  """

  @classmethod
  def MakeSelect(cls, table_name, cols, distinct=False, or_where_conds=False):
    """Construct a SELECT statement."""
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in cols)
    main_clause = 'SELECT%s %s FROM %s' % (
        (' DISTINCT' if distinct else ''), ', '.join(cols), table_name)
    return cls(main_clause, or_where_conds=or_where_conds)

  @classmethod
  def MakeInsert(
      cls, table_name, cols, new_values, replace=False, ignore=False):
    """Construct an INSERT statement."""
    if replace:
      return cls.MakeReplace(table_name, cols, new_values, ignore)
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in cols)
    ignore_word = ' IGNORE' if ignore else ''
    main_clause = 'INSERT%s INTO %s (%s)' % (
        ignore_word, table_name, ', '.join(cols))
    return cls(main_clause, insert_args=new_values)

  @classmethod
  def MakeReplace(
      cls, table_name, cols, new_values, ignore=False):
    """Construct an INSERT...ON DUPLICATE KEY UPDATE... statement.

    Uses the INSERT/UPDATE syntax because REPLACE is literally a DELETE
    followed by an INSERT, which doesn't play well with foreign keys.
    INSERT/UPDATE is an atomic check of whether the primary key exists,
    followed by an INSERT if it doesn't or an UPDATE if it does.
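
    For an illustrative call like MakeReplace('Foo', ['id', 'rank'], [(1, 2)])
    (hypothetical table and columns), Generate() produces roughly:

      INSERT INTO Foo (id, rank)
      VALUES (%s,%s)
      ON DUPLICATE KEY UPDATE id=VALUES(id), rank=VALUES(rank)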
    """
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in cols)
    ignore_word = ' IGNORE' if ignore else ''
    main_clause = 'INSERT%s INTO %s (%s)' % (
        ignore_word, table_name, ', '.join(cols))
    return cls(main_clause, insert_args=new_values, duplicate_update_cols=cols)

  @classmethod
  def MakeUpdate(cls, table_name, delta):
    """Construct an UPDATE statement."""
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in delta.keys())
    update_strs = []
    update_args = []
    for col, val in delta.items():
      update_strs.append(col + '=%s')
      update_args.append(val)

    main_clause = 'UPDATE %s SET %s' % (
        table_name, ', '.join(update_strs))
    return cls(main_clause, update_args=update_args)

  @classmethod
  def MakeIncrement(cls, table_name, col_name, step=1):
    """Construct an UPDATE statement that increments and returns a counter."""
    assert _IsValidTableName(table_name)
    assert _IsValidColumnName(col_name)

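    # LAST_INSERT_ID(expr) makes MySQL report expr as this connection's last
    # insert ID, so IncrementCounterValue() can read the post-increment value
    # back from cursor.lastrowid without a second query.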
    main_clause = (
        'UPDATE %s SET %s = LAST_INSERT_ID(%s + %%s)' % (
            table_name, col_name, col_name))
    update_args = [step]
    return cls(main_clause, update_args=update_args)

  @classmethod
  def MakeDelete(cls, table_name, or_where_conds=False):
    """Construct a DELETE statement."""
    assert _IsValidTableName(table_name)
    main_clause = 'DELETE FROM %s' % table_name
    return cls(main_clause, or_where_conds=or_where_conds)

  def __init__(
      self, main_clause, insert_args=None, update_args=None,
      duplicate_update_cols=None, or_where_conds=False):
    self.main_clause = main_clause  # E.g., SELECT or DELETE
    self.or_where_conds = or_where_conds
    self.insert_args = insert_args or []  # For INSERT statements
    for row_value in self.insert_args:
      if not all(_IsValidDBValue(val) for val in row_value):
        raise exceptions.InputException('Invalid DB value %r' % (row_value,))
    self.update_args = update_args or []  # For UPDATEs
    for val in self.update_args:
      if not _IsValidDBValue(val):
        raise exceptions.InputException('Invalid DB value %r' % val)
    self.duplicate_update_cols = duplicate_update_cols or []  # For REPLACE-ish

    self.use_clauses = []
    self.join_clauses, self.join_args = [], []
    self.where_conds, self.where_args = [], []
    self.having_conds, self.having_args = [], []
    self.group_by_terms, self.group_by_args = [], []
    self.order_by_terms, self.order_by_args = [], []
    self.limit, self.offset = None, None

  def Generate(self):
    """Return an SQL string having %s placeholders and args to fill them in."""
    clauses = [self.main_clause] + self.use_clauses + self.join_clauses
    if self.where_conds:
      if self.or_where_conds:
        clauses.append('WHERE ' + '\n OR '.join(self.where_conds))
      else:
        clauses.append('WHERE ' + '\n AND '.join(self.where_conds))
    if self.group_by_terms:
      clauses.append('GROUP BY ' + ', '.join(self.group_by_terms))
    if self.having_conds:
      assert self.group_by_terms
      clauses.append('HAVING %s' % ','.join(self.having_conds))
    if self.order_by_terms:
      clauses.append('ORDER BY ' + ', '.join(self.order_by_terms))

    if self.limit and self.offset:
      clauses.append('LIMIT %d OFFSET %d' % (self.limit, self.offset))
    elif self.limit:
      clauses.append('LIMIT %d' % self.limit)
    elif self.offset:
      clauses.append('LIMIT %d OFFSET %d' % (sys.maxint, self.offset))

    if self.insert_args:
      clauses.append('VALUES (' + PlaceHolders(self.insert_args[0]) + ')')
      args = self.insert_args
      if self.duplicate_update_cols:
        clauses.append('ON DUPLICATE KEY UPDATE %s' % (
            ', '.join(['%s=VALUES(%s)' % (col, col)
                       for col in self.duplicate_update_cols])))
      assert not (self.join_args + self.update_args + self.where_args +
                  self.group_by_args + self.order_by_args + self.having_args)
    else:
      args = (self.join_args + self.update_args + self.where_args +
              self.group_by_args + self.having_args + self.order_by_args)
      assert not (self.insert_args + self.duplicate_update_cols)

    args = _BoolsToInts(args)
    stmt_str = '\n'.join(clause for clause in clauses if clause)

    assert _IsValidStatement(stmt_str), stmt_str
    return stmt_str, args

  def AddUseClause(self, use_clause):
    """Add a USE clause (giving the DB a hint about which indexes to use)."""
    assert _IsValidUseClause(use_clause), use_clause
    self.use_clauses.append(use_clause)

  def AddJoinClauses(self, join_pairs, left=False):
    """Save JOIN clauses based on the given list of join conditions."""
    for join, args in join_pairs:
      assert _IsValidJoin(join), join
      assert join.count('%s') == len(args), join
      self.join_clauses.append(
          ' %sJOIN %s' % (('LEFT ' if left else ''), join))
      self.join_args.extend(args)

  def AddGroupByTerms(self, group_by_term_list):
    """Save info needed to generate the GROUP BY clause."""
    assert all(_IsValidGroupByTerm(term) for term in group_by_term_list)
    self.group_by_terms.extend(group_by_term_list)

  def AddOrderByTerms(self, order_by_pairs):
    """Save info needed to generate the ORDER BY clause."""
    for term, args in order_by_pairs:
      assert _IsValidOrderByTerm(term), term
      assert term.count('%s') == len(args), term
      self.order_by_terms.append(term)
      self.order_by_args.extend(args)

  def SetLimitAndOffset(self, limit, offset):
    """Save info needed to generate the LIMIT OFFSET clause."""
    self.limit = limit
    self.offset = offset

  def AddWhereTerms(self, where_cond_pairs, **kwargs):
    """Generate a WHERE clause."""
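    # Keyword args follow a small convention (illustrative, hypothetical
    # column names):
    #   label_id=[1, 2]  ->  'label_id IN (%s,%s)'
    #   status_id=None   ->  'status_id IS NULL'
    #   user_id_not=5    ->  'user_id != %s'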
    where_cond_pairs = where_cond_pairs or []

    for cond, args in where_cond_pairs:
      assert _IsValidWhereCond(cond), cond
      assert cond.count('%s') == len(args), cond
      self.where_conds.append(cond)
      self.where_args.extend(args)

    for col, val in sorted(kwargs.items()):
      assert _IsValidColumnName(col), col
      eq = True
      if col.endswith('_not'):
        col = col[:-4]
        eq = False

      if isinstance(val, set):
        val = list(val)  # MySQL interface cannot handle sets.

      if val is None or val == []:
        if val == [] and self.main_clause and self.main_clause.startswith(
            'UPDATE'):
          # https://crbug.com/monorail/6735: Avoid empty arrays for UPDATE.
          raise exceptions.InputException('Invalid update DB value %r' % col)
        op = 'IS' if eq else 'IS NOT'
        self.where_conds.append(col + ' ' + op + ' NULL')
      elif isinstance(val, list):
        op = 'IN' if eq else 'NOT IN'
        # Sadly, MySQLdb cannot escape lists, so we flatten to multiple "%s"s
        self.where_conds.append(
            col + ' ' + op + ' (' + PlaceHolders(val) + ')')
        self.where_args.extend(val)
      else:
        op = '=' if eq else '!='
        self.where_conds.append(col + ' ' + op + ' %s')
        self.where_args.append(val)

  def AddHavingTerms(self, having_cond_pairs):
    """Generate a HAVING clause."""
    for cond, args in having_cond_pairs:
      assert _IsValidHavingCond(cond), cond
      assert cond.count('%s') == len(args), cond
      self.having_conds.append(cond)
      self.having_args.extend(args)


def PlaceHolders(sql_args):
  """Return a comma-separated list of %s placeholders for the given args."""
  return ','.join('%s' for _ in sql_args)


TABLE_PAT = '[A-Z][_a-zA-Z0-9]+'
COLUMN_PAT = '[a-z][_a-z]+'
COMPARE_OP_PAT = '(<|>|=|!=|>=|<=|LIKE|NOT LIKE)'
SHORTHAND = {
    'table': TABLE_PAT,
    'column': COLUMN_PAT,
    'tab_col': r'(%s\.)?%s' % (TABLE_PAT, COLUMN_PAT),
    'placeholder': '%s',  # That's a literal %s that gets passed to MySQLdb
    'multi_placeholder': '%s(, ?%s)*',
    'compare_op': COMPARE_OP_PAT,
    'opt_asc_desc': '( ASC| DESC)?',
    'opt_alias': '( AS %s)?' % TABLE_PAT,
    'email_cond': (r'\(?'
                   r'('
                   r'(LOWER\(Spare\d+\.email\) IS NULL OR )?'
                   r'LOWER\(Spare\d+\.email\) '
                   r'(%s %%s|IN \(%%s(, ?%%s)*\))'
                   r'( (AND|OR) )?'
                   r')+'
                   r'\)?' % COMPARE_OP_PAT),
    'hotlist_cond': (r'\(?'
                     r'('
                     r'(LOWER\(Cond\d+\.name\) IS NULL OR )?'
                     r'LOWER\(Cond\d+\.name\) '
                     r'(%s %%s|IN \(%%s(, ?%%s)*\))'
                     r'( (AND|OR) )?'
                     r')+'
                     r'\)?' % COMPARE_OP_PAT),
    'phase_cond': (r'\(?'
                   r'('
                   r'(LOWER\(Phase\d+\.name\) IS NULL OR )?'
                   r'LOWER\(Phase\d+\.name\) '
                   r'(%s %%s|IN \(%%s(, ?%%s)*\))?'
                   r'( (AND|OR) )?'
                   r')+'
                   r'\)?' % COMPARE_OP_PAT),
    'approval_cond': (r'\(?'
                      r'('
                      r'(LOWER\(Cond\d+\.status\) IS NULL OR )?'
                      r'LOWER\(Cond\d+\.status\) '
                      r'(%s %%s|IN \(%%s(, ?%%s)*\))'
                      r'( (AND|OR) )?'
                      r')+'
                      r'\)?' % COMPARE_OP_PAT),
    }


def _MakeRE(regex_str):
  """Return a regular expression object, expanding our shorthand as needed."""
  return re.compile(regex_str.format(**SHORTHAND))

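# For example (illustrative), the shorthand pattern
# '^{tab_col} {compare_op} {placeholder}$' expands to
# r'^([A-Z][_a-zA-Z0-9]+\.)?[a-z][_a-z]+ (<|>|=|!=|>=|<=|LIKE|NOT LIKE) %s$'.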

TABLE_RE = _MakeRE('^{table}$')
TAB_COL_RE = _MakeRE('^{tab_col}$')
USE_CLAUSE_RE = _MakeRE(
    r'^USE INDEX \({column}\) USE INDEX FOR ORDER BY \({column}\)$')
HAVING_RE_LIST = [
    _MakeRE(r'^COUNT\(\*\) {compare_op} {placeholder}$')]
COLUMN_RE_LIST = [
    TAB_COL_RE,
    _MakeRE(r'\*'),
    _MakeRE(r'COUNT\(\*\)'),
    _MakeRE(r'COUNT\({tab_col}\)'),
    _MakeRE(r'COUNT\(DISTINCT\({tab_col}\)\)'),
    _MakeRE(r'MAX\({tab_col}\)'),
    _MakeRE(r'MIN\({tab_col}\)'),
    _MakeRE(r'GROUP_CONCAT\((DISTINCT )?{tab_col}( ORDER BY {tab_col})?'
            r'( SEPARATOR \'.*\')?\)'),
    ]
JOIN_RE_LIST = [
    TABLE_RE,
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} IN \({multi_placeholder}\))?$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} = {placeholder})?'
        r'( AND {tab_col} IN \({multi_placeholder}\))?'
        r'( AND {tab_col} = {tab_col})?$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} = {placeholder})?'
        r'( AND {tab_col} IN \({multi_placeholder}\))?'
        r'( AND {tab_col} IS NULL)?'
        r'( AND \({tab_col} IS NULL'
        r' OR {tab_col} NOT IN \({multi_placeholder}\)\))?$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} = {placeholder})?'
        r' AND \(?{tab_col} {compare_op} {placeholder}\)?'
        r'( AND {tab_col} = {tab_col})?$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} = {placeholder})?'
        r' AND {tab_col} = {tab_col}$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} = {placeholder})?'
        r' AND \({tab_col} IS NULL OR'
        r' {tab_col} {compare_op} {placeholder}\)$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r' AND \({tab_col} IS NOT NULL AND {tab_col} != {placeholder}\)'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r' AND LOWER\({tab_col}\) = LOWER\({placeholder}\)'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col} AND {email_cond}$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {email_cond}$'),
    _MakeRE(
        r'^{table}{opt_alias} ON '
        r'\({tab_col} = {tab_col} OR {tab_col} = {tab_col}\)$'),
    _MakeRE(
        r'^\({table} AS {table} JOIN User AS {table} '
        r'ON {tab_col} = {tab_col} AND {email_cond}\) '
        r'ON Issue(Snapshot)?.id = {tab_col}'
        r'( AND {tab_col} IS NULL)?'),
    _MakeRE(
        r'^\({table} JOIN Hotlist AS {table} '
        r'ON {tab_col} = {tab_col} AND {hotlist_cond}\) '
        r'ON Issue.id = {tab_col}?'),
    _MakeRE(
        r'^\({table} AS {table} JOIN IssuePhaseDef AS {table} '
        r'ON {tab_col} = {tab_col} AND {phase_cond}\) '
        r'ON Issue.id = {tab_col}?'),
    _MakeRE(
        r'^IssuePhaseDef AS {table} ON {phase_cond}'),
    _MakeRE(
        r'^Issue2ApprovalValue AS {table} ON {tab_col} = {tab_col} '
        r'AND {tab_col} = {placeholder} AND {approval_cond}'),
    _MakeRE(
        r'^{table} AS {table} ON {tab_col} = {tab_col} '
        r'LEFT JOIN {table} AS {table} ON {tab_col} = {tab_col}'),
    ]
ORDER_BY_RE_LIST = [
    _MakeRE(r'^{tab_col}{opt_asc_desc}$'),
    _MakeRE(r'^LOWER\({tab_col}\){opt_asc_desc}$'),
    _MakeRE(r'^ISNULL\({tab_col}\){opt_asc_desc}$'),
    _MakeRE(r'^\(ISNULL\({tab_col}\) AND ISNULL\({tab_col}\)\){opt_asc_desc}$'),
    _MakeRE(r'^FIELD\({tab_col}, {multi_placeholder}\){opt_asc_desc}$'),
    _MakeRE(r'^FIELD\(IF\(ISNULL\({tab_col}\), {tab_col}, {tab_col}\), '
            r'{multi_placeholder}\){opt_asc_desc}$'),
    _MakeRE(r'^CONCAT\({tab_col}, {tab_col}\){opt_asc_desc}$'),
    ]
GROUP_BY_RE_LIST = [
    TAB_COL_RE,
    ]
WHERE_COND_RE_LIST = [
    _MakeRE(r'^TRUE$'),
    _MakeRE(r'^FALSE$'),
    _MakeRE(r'^{tab_col} IS NULL$'),
    _MakeRE(r'^{tab_col} IS NOT NULL$'),
    _MakeRE(r'^{tab_col} {compare_op} {tab_col}$'),
    _MakeRE(r'^{tab_col} {compare_op} {placeholder}$'),
    _MakeRE(r'^{tab_col} %% {placeholder} = {placeholder}$'),
    _MakeRE(r'^{tab_col} IN \({multi_placeholder}\)$'),
    _MakeRE(r'^{tab_col} NOT IN \({multi_placeholder}\)$'),
    _MakeRE(r'^LOWER\({tab_col}\) IS NULL$'),
    _MakeRE(r'^LOWER\({tab_col}\) IS NOT NULL$'),
    _MakeRE(r'^LOWER\({tab_col}\) {compare_op} {placeholder}$'),
    _MakeRE(r'^LOWER\({tab_col}\) IN \({multi_placeholder}\)$'),
    _MakeRE(r'^LOWER\({tab_col}\) NOT IN \({multi_placeholder}\)$'),
    _MakeRE(r'^LOWER\({tab_col}\) LIKE {placeholder}$'),
    _MakeRE(r'^LOWER\({tab_col}\) NOT LIKE {placeholder}$'),
    _MakeRE(r'^timestep < \(SELECT MAX\(j.timestep\) FROM Invalidate AS j '
            r'WHERE j.kind = %s '
            r'AND j.cache_key = Invalidate.cache_key\)$'),
    _MakeRE(r'^\({tab_col} IS NULL OR {tab_col} {compare_op} {placeholder}\) '
            r'AND \({tab_col} IS NULL OR {tab_col} {compare_op} {placeholder}'
            r'\)$'),
    _MakeRE(r'^\({tab_col} IS NOT NULL AND {tab_col} {compare_op} '
            r'{placeholder}\) OR \({tab_col} IS NOT NULL AND {tab_col} '
            r'{compare_op} {placeholder}\)$'),
    ]

# Note: We never use ';' for multiple statements, '@' for SQL variables, or
# any quoted strings in stmt_str (quotes are put in by MySQLdb for args).
STMT_STR_RE = re.compile(
    r'\A(SELECT|UPDATE|DELETE|INSERT|REPLACE) [\'-+=!<>%*.,()\w\s]+\Z',
    re.MULTILINE)


def _IsValidDBValue(val):
  if isinstance(val, string_types):
    return '\x00' not in val
  return True


def _IsValidTableName(table_name):
  return TABLE_RE.match(table_name)


def _IsValidColumnName(column_expr):
  return any(regex.match(column_expr) for regex in COLUMN_RE_LIST)


def _IsValidUseClause(use_clause):
  return USE_CLAUSE_RE.match(use_clause)


def _IsValidHavingCond(cond):
  if cond.startswith('(') and cond.endswith(')'):
    cond = cond[1:-1]

  if ' OR ' in cond:
    return all(_IsValidHavingCond(c) for c in cond.split(' OR '))

  if ' AND ' in cond:
    return all(_IsValidHavingCond(c) for c in cond.split(' AND '))

  return any(regex.match(cond) for regex in HAVING_RE_LIST)


def _IsValidJoin(join):
  return any(regex.match(join) for regex in JOIN_RE_LIST)


def _IsValidOrderByTerm(term):
  return any(regex.match(term) for regex in ORDER_BY_RE_LIST)


def _IsValidGroupByTerm(term):
  return any(regex.match(term) for regex in GROUP_BY_RE_LIST)


def _IsValidWhereCond(cond):
  if cond.startswith('NOT '):
    cond = cond[4:]
  if cond.startswith('(') and cond.endswith(')'):
    cond = cond[1:-1]

  if any(regex.match(cond) for regex in WHERE_COND_RE_LIST):
    return True

  if ' OR ' in cond:
    return all(_IsValidWhereCond(c) for c in cond.split(' OR '))

  if ' AND ' in cond:
    return all(_IsValidWhereCond(c) for c in cond.split(' AND '))

  return False


def _IsValidStatement(stmt_str):
  """Final check to make sure there is no funny junk sneaking in somehow."""
  return (STMT_STR_RE.match(stmt_str) and
          '--' not in stmt_str)


def _BoolsToInts(arg_list):
  """Convert any True values to 1s and Falses to 0s.

  Google's copy of MySQLdb has bool-to-int conversion disabled, and yet it
  seems to be needed; otherwise bools are converted to strings and always
  interpreted as 0 (which is FALSE).

  Args:
    arg_list: (nested) list of SQL statement argument values, which may
      include some boolean values.

  Returns:
    The same list, but with True replaced by 1 and False replaced by 0.
  """
  result = []
  for arg in arg_list:
    if isinstance(arg, (list, tuple)):
      result.append(_BoolsToInts(arg))
    elif arg is True:
      result.append(1)
    elif arg is False:
      result.append(0)
    else:
      result.append(arg)

  return result