blob: 0fb8043042dc447adc379951e709a6852450e0cd [file] [log] [blame]
Copybara854996b2021-09-07 19:36:02 +00001# Copyright 2016 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style
3# license that can be found in the LICENSE file or at
4# https://developers.google.com/open-source/licenses/bsd
5
6"""A set of classes for interacting with tables in SQL."""
7from __future__ import print_function
8from __future__ import division
9from __future__ import absolute_import
10
11import logging
12import random
13import re
14import sys
15import time
16
17from six import string_types
18
19import settings
20
21if not settings.unit_test_mode:
22 import MySQLdb
23
24from framework import exceptions
25from framework import framework_helpers
Adrià Vilanova Martínez9f9ade52022-10-10 23:20:11 +020026from framework import logger
Copybara854996b2021-09-07 19:36:02 +000027
28from infra_libs import ts_mon
29
30from Queue import Queue
31
32
class ConnectionPool(object):
  """Manage a set of database connections such that they may be re-used.

  Connections are pooled per (instance, database) pair.  get() hands out an
  idle pooled connection when one exists (pinging it first to confirm it is
  still alive) and creates a fresh one otherwise; release() returns a
  connection to its pool, closing it if the pool is already full.
  """

  def __init__(self, poolsize=1):
    # Maximum number of idle connections kept per instance/database pair.
    self.poolsize = poolsize
    # {instance + '/' + database: Queue of idle connections}
    self.queues = {}

  def get(self, instance, database):
    """Return a database connection, or throw an exception if none can
    be made.
    """
    key = instance + '/' + database

    if key not in self.queues:
      queue = Queue(self.poolsize)
      self.queues[key] = queue

    queue = self.queues[key]

    if queue.empty():
      cnxn = cnxn_ctor(instance, database)
    else:
      cnxn = queue.get()
      # Make sure the connection is still good.
      cnxn.ping()
      cnxn.commit()

    return cnxn

  def release(self, cnxn):
    """Return a connection to its pool, closing it if the pool is full.

    Raises:
      ValueError: if the connection was not created through this pool.
    """
    if cnxn.pool_key not in self.queues:
      # Was `raise BaseException(...)`: BaseException skips ordinary
      # `except Exception` handlers (like KeyboardInterrupt does), which is
      # never wanted for a programming error.  ValueError is still caught
      # by any existing BaseException/Exception handlers.
      raise ValueError('unknown pool key: %s' % cnxn.pool_key)

    q = self.queues[cnxn.pool_key]
    if q.full():
      cnxn.close()
    else:
      q.put(cnxn)
72
73
@framework_helpers.retry(1, delay=1, backoff=2)
def cnxn_ctor(instance, database):
  """Create a new MySQL connection to the given instance and database.

  In local_mode, connects over TCP to a MySQL server on localhost;
  otherwise connects through the Cloud SQL unix socket for `instance`.
  Connection latency and success/failure counts are recorded in ts_mon
  metrics.  The retry decorator re-attempts the whole function once on
  failure.

  Raises:
    ValueError: always, in unit_test_mode (tests must not touch a real DB).
    MySQLdb.OperationalError: if the connection cannot be established.
  """
  logging.info('About to connect to SQL instance %r db %r', instance, database)
  if settings.unit_test_mode:
    raise ValueError('unit tests should not need real database connections')
  try:
    if settings.local_mode:
      start_time = time.time()
      cnxn = MySQLdb.connect(
          host='127.0.0.1', port=3306, db=database, user='root', charset='utf8')
    else:
      start_time = time.time()
      cnxn = MySQLdb.connect(
          unix_socket='/cloudsql/' + instance,
          db=database,
          user='monorail-mysql',
          charset='utf8')
    # Record connection setup time in milliseconds.
    duration = int((time.time() - start_time) * 1000)
    DB_CNXN_LATENCY.add(duration)
    CONNECTION_COUNT.increment({'success': True})
  except MySQLdb.OperationalError:
    CONNECTION_COUNT.increment({'success': False})
    raise
  # pool_key ties the connection back to its ConnectionPool queue;
  # is_bad marks connections that should not be re-used.
  cnxn.pool_key = instance + '/' + database
  cnxn.is_bad = False
  return cnxn
100
101
# One connection pool per database instance (primary, replicas are each an
# instance). We'll have four connections per instance because we fetch
# issue comments, stars, spam verdicts and spam verdict history in parallel
# with promises.
cnxn_pool = ConnectionPool(settings.db_cnxn_pool_size)

# MonorailConnection maintains a dictionary of connections to SQL databases.
# Each is identified by an int shard ID.
# And there is one connection to the primary DB identified by key PRIMARY_CNXN.
PRIMARY_CNXN = 'primary_cnxn'

# When one replica is temporarily unresponsive, we can use a different one.
# A shard that failed to connect is avoided for this many seconds.
BAD_SHARD_AVOIDANCE_SEC = 45


# ts_mon metrics recording DB connection, query, commit, and rollback
# counts and latency distributions (latencies are in milliseconds).
CONNECTION_COUNT = ts_mon.CounterMetric(
    'monorail/sql/connection_count',
    'Count of connections made to the SQL database.',
    [ts_mon.BooleanField('success')])

DB_CNXN_LATENCY = ts_mon.CumulativeDistributionMetric(
    'monorail/sql/db_cnxn_latency',
    'Time needed to establish a DB connection.',
    None)

DB_QUERY_LATENCY = ts_mon.CumulativeDistributionMetric(
    'monorail/sql/db_query_latency',
    'Time needed to make a DB query.',
    [ts_mon.StringField('type')])

DB_COMMIT_LATENCY = ts_mon.CumulativeDistributionMetric(
    'monorail/sql/db_commit_latency',
    'Time needed to make a DB commit.',
    None)

DB_ROLLBACK_LATENCY = ts_mon.CumulativeDistributionMetric(
    'monorail/sql/db_rollback_latency',
    'Time needed to make a DB rollback.',
    None)

DB_RETRY_COUNT = ts_mon.CounterMetric(
    'monorail/sql/db_retry_count',
    'Count of queries retried.',
    None)

DB_QUERY_COUNT = ts_mon.CounterMetric(
    'monorail/sql/db_query_count',
    'Count of queries sent to the DB.',
    [ts_mon.StringField('type')])

DB_COMMIT_COUNT = ts_mon.CounterMetric(
    'monorail/sql/db_commit_count',
    'Count of commits sent to the DB.',
    None)

DB_ROLLBACK_COUNT = ts_mon.CounterMetric(
    'monorail/sql/db_rollback_count',
    'Count of rollbacks sent to the DB.',
    None)

DB_RESULT_ROWS = ts_mon.CumulativeDistributionMetric(
    'monorail/sql/db_result_rows',
    'Number of results returned by a DB query.',
    None)
166
167
def RandomShardID():
  """Return a random shard ID to load balance across replicas."""
  highest_shard = settings.num_logical_shards - 1
  return random.randint(0, highest_shard)
171
172
class MonorailConnection(object):
  """Create and manage connections to the SQL servers.

  We only store connections in the context of a single user request, not
  across user requests. The main purpose of this class is to make using
  sharded tables easier.
  """
  # NOTE: class-level attribute, so the bad-shard bookkeeping is shared by
  # all MonorailConnection instances in this process.
  unavailable_shards = {}  # {shard_id: timestamp of failed attempt}

  def __init__(self):
    self.sql_cnxns = {}  # {PRIMARY_CNXN: cnxn, shard_id: cnxn, ...}

  @framework_helpers.retry(1, delay=0.1, backoff=2)
  def GetPrimaryConnection(self):
    """Return a connection to the primary SQL DB."""
    if PRIMARY_CNXN not in self.sql_cnxns:
      self.sql_cnxns[PRIMARY_CNXN] = cnxn_pool.get(
          settings.db_instance, settings.db_database_name)
      logging.info(
          'created a primary connection %r', self.sql_cnxns[PRIMARY_CNXN])

    return self.sql_cnxns[PRIMARY_CNXN]

  @framework_helpers.retry(1, delay=0.1, backoff=2)
  def GetConnectionForShard(self, shard_id):
    """Return a connection to the DB replica that will be used for shard_id."""
    if shard_id not in self.sql_cnxns:
      # Map the logical shard onto one of the configured replicas.
      physical_shard_id = shard_id % settings.num_logical_shards

      replica_name = settings.db_replica_names[
          physical_shard_id % len(settings.db_replica_names)]
      shard_instance_name = (
          settings.physical_db_name_format % replica_name)
      # Pessimistically mark the shard bad before connecting, and clear the
      # mark only on success: if cnxn_pool.get() raises (even after the
      # retry decorator), the shard stays flagged for Execute() to avoid.
      self.unavailable_shards[shard_id] = int(time.time())
      self.sql_cnxns[shard_id] = cnxn_pool.get(
          shard_instance_name, settings.db_database_name)
      del self.unavailable_shards[shard_id]
      logging.info('created a replica connection for shard %d', shard_id)

    return self.sql_cnxns[shard_id]

  def Execute(self, stmt_str, stmt_args, shard_id=None, commit=True, retries=2):
    """Execute the given SQL statement on one of the relevant databases.

    Args:
      stmt_str: SQL statement string with %s placeholders.
      stmt_args: list of values to fill into the placeholders.
      shard_id: int shard to query; None means use the primary DB.
      commit: pass False to defer the commit to a later statement.
      retries: number of times to retry after an OperationalError.

    Returns:
      A DB cursor holding the results.
    """
    if shard_id is None:
      # No shard was specified, so hit the primary.
      sql_cnxn = self.GetPrimaryConnection()
    else:
      if shard_id in self.unavailable_shards:
        bad_age_sec = int(time.time()) - self.unavailable_shards[shard_id]
        if bad_age_sec < BAD_SHARD_AVOIDANCE_SEC:
          # Recently-failed replica: deterministically fall over to the
          # next shard instead.
          logging.info('Avoiding bad replica %r, age %r', shard_id, bad_age_sec)
          shard_id = (shard_id + 1) % settings.num_logical_shards
      sql_cnxn = self.GetConnectionForShard(shard_id)

    try:
      return self._ExecuteWithSQLConnection(
          sql_cnxn, stmt_str, stmt_args, commit=commit)
    except MySQLdb.OperationalError as e:
      logging.exception(e)
      logging.info('retries: %r', retries)
      if retries > 0:
        DB_RETRY_COUNT.increment()
        self.sql_cnxns = {}  # Drop all old mysql connections and make new.
        return self.Execute(
            stmt_str, stmt_args, shard_id=shard_id, commit=commit,
            retries=retries - 1)
      else:
        raise e

  def _ExecuteWithSQLConnection(
      self, sql_cnxn, stmt_str, stmt_args, commit=True):
    """Execute a statement on the given database and return a cursor."""

    start_time = time.time()
    cursor = sql_cnxn.cursor()
    # Force full-Unicode handling for this session regardless of the
    # connection's configured charset.
    cursor.execute('SET NAMES utf8mb4')
    if stmt_str.startswith('INSERT') or stmt_str.startswith('REPLACE'):
      # Writes carry one arg-tuple per row, hence executemany().
      cursor.executemany(stmt_str, stmt_args)
      duration = (time.time() - start_time) * 1000
      DB_QUERY_LATENCY.add(duration, {'type': 'write'})
      DB_QUERY_COUNT.increment({'type': 'write'})
    else:
      cursor.execute(stmt_str, args=stmt_args)
      duration = (time.time() - start_time) * 1000
      DB_QUERY_LATENCY.add(duration, {'type': 'read'})
      DB_QUERY_COUNT.increment({'type': 'read'})
    DB_RESULT_ROWS.add(cursor.rowcount)

    # Build a one-line rendering of the statement for logging only; this
    # interpolation is never sent to the DB.
    if stmt_str.startswith('INSERT') or stmt_str.startswith('REPLACE'):
      formatted_statement = ('%s %s' % (stmt_str, stmt_args)).replace('\n', ' ')
    else:
      formatted_statement = (stmt_str % tuple(stmt_args)).replace('\n', ' ')
    logging.info(
        '%d rows in %d ms: %s', cursor.rowcount, int(duration),
        formatted_statement)
    if duration >= 2000:
      # Statements slower than 2 seconds also go to the structured logger.
      logger.log({
          'log_type': 'database/query',
          'statement': formatted_statement,
          'type': formatted_statement.split(' ')[0],
          'duration': duration / 1000,
          'row_count': cursor.rowcount,
      })

    if commit and not stmt_str.startswith('SELECT'):
      try:
        sql_cnxn.commit()
        duration = (time.time() - start_time) * 1000
        DB_COMMIT_LATENCY.add(duration)
        DB_COMMIT_COUNT.increment()
      except MySQLdb.DatabaseError:
        # Commit failed: undo the statement rather than leave the txn open.
        sql_cnxn.rollback()
        duration = (time.time() - start_time) * 1000
        DB_ROLLBACK_LATENCY.add(duration)
        DB_ROLLBACK_COUNT.increment()

    return cursor

  def Commit(self):
    """Explicitly commit any pending txns.  Normally done automatically."""
    sql_cnxn = self.GetPrimaryConnection()
    try:
      sql_cnxn.commit()
    except MySQLdb.DatabaseError:
      logging.exception('Commit failed for cnxn, rolling back')
      sql_cnxn.rollback()

  def Close(self):
    """Safely close any connections that are still open."""
    for sql_cnxn in self.sql_cnxns.values():
      try:
        sql_cnxn.rollback()  # Abandon any uncommitted changes.
        cnxn_pool.release(sql_cnxn)
      except MySQLdb.DatabaseError:
        # This might happen if the cnxn is somehow already closed.
        logging.exception('ProgrammingError when trying to close cnxn')
309
310
class SQLTableManager(object):
  """Helper class to make it easier to deal with an SQL table."""

  def __init__(self, table_name):
    self.table_name = table_name

  def Select(
      self, cnxn, distinct=False, cols=None, left_joins=None,
      joins=None, where=None, or_where_conds=False, group_by=None,
      order_by=None, limit=None, offset=None, shard_id=None, use_clause=None,
      having=None, **kwargs):
    """Compose and execute an SQL SELECT statement on this table.

    Args:
      cnxn: MonorailConnection to the databases.
      distinct: If True, add DISTINCT keyword.
      cols: List of columns to retrieve, defaults to '*'.
      left_joins: List of LEFT JOIN (str, args) pairs.
      joins: List of regular JOIN (str, args) pairs.
      where: List of (str, args) for WHERE clause.
      or_where_conds: Set to True to use OR in the WHERE conds.
      group_by: List of strings for GROUP BY clause.
      order_by: List of (str, args) for ORDER BY clause.
      limit: Optional LIMIT on the number of rows returned.
      offset: Optional OFFSET when using LIMIT.
      shard_id: Int ID of the shard to query.
      use_clause: Optional string USE clause to tell the DB which index to use.
      having: List of (str, args) for Optional HAVING clause
      **kwargs: WHERE-clause equality and set-membership conditions.

    Keyword args are used to build up more WHERE conditions that compare
    column values to constants. Key word Argument foo='bar' translates to 'foo
    = "bar"', and foo=[3, 4, 5] translates to 'foo IN (3, 4, 5)'.

    Returns:
      A list of rows, each row is a tuple of values for the requested cols.
    """
    cols = cols or ['*']  # If columns not specified, retrieve all columns.
    stmt = Statement.MakeSelect(
        self.table_name, cols, distinct=distinct,
        or_where_conds=or_where_conds)
    if use_clause:
      stmt.AddUseClause(use_clause)
    if having:
      stmt.AddHavingTerms(having)
    stmt.AddJoinClauses(left_joins or [], left=True)
    stmt.AddJoinClauses(joins or [])
    stmt.AddWhereTerms(where or [], **kwargs)
    stmt.AddGroupByTerms(group_by or [])
    stmt.AddOrderByTerms(order_by or [])
    stmt.SetLimitAndOffset(limit, offset)
    stmt_str, stmt_args = stmt.Generate()

    cursor = cnxn.Execute(stmt_str, stmt_args, shard_id=shard_id)
    rows = cursor.fetchall()
    cursor.close()
    return rows

  def SelectRow(
      self, cnxn, cols=None, default=None, where=None, **kwargs):
    """Run a query that is expected to return just one row.

    Args:
      cnxn: MonorailConnection object.
      cols: List of columns to retrieve, defaults to '*'.
      default: Value to return when the query matches no rows.
      where: Optional list of WHERE conditions.
      **kwargs: WHERE-clause equality and set-membership conditions.

    Returns:
      The single matching row (a tuple), or `default` if none matched.

    Raises:
      ValueError: if more than one distinct row matched.
    """
    # DISTINCT collapses identical rows so that duplicates of one logical
    # result do not trigger the ambiguity error below.
    rows = self.Select(cnxn, distinct=True, cols=cols, where=where, **kwargs)
    if len(rows) == 1:
      return rows[0]
    elif not rows:
      logging.info('SelectRow got 0 results, so using default %r', default)
      return default
    else:
      # Bug fix: the count was previously passed as a second exception
      # argument instead of being %-interpolated into the message.
      raise ValueError(
          'SelectRow got %d results, expected only 1' % len(rows))

  def SelectValue(self, cnxn, col, default=None, where=None, **kwargs):
    """Run a query that is expected to return just one row w/ one value."""
    row = self.SelectRow(
        cnxn, cols=[col], default=[default], where=where, **kwargs)
    return row[0]

  def InsertRows(
      self, cnxn, cols, row_values, replace=False, ignore=False,
      commit=True, return_generated_ids=False):
    """Insert all the given rows.

    Args:
      cnxn: MonorailConnection object.
      cols: List of column names to set.
      row_values: List of lists with values to store. The length of each
          nested list should be equal to len(cols).
      replace: Set to True if inserted values should replace existing DB rows
          that have the same DB keys.
      ignore: Set to True to ignore rows that would duplicate existing DB keys.
      commit: Set to False if this operation is part of a series of operations
          that should not be committed until the final one is done.
      return_generated_ids: Set to True to return a list of generated
          autoincrement IDs for inserted rows.  This requires us to insert rows
          one at a time.

    Returns:
      If return_generated_ids is set to True, this method returns a list of the
      auto-increment IDs generated by the DB. Otherwise, [] is returned.
      (When row_values is empty, None is returned; callers treat any falsy
      result the same way.)
    """
    if not row_values:
      return None  # Nothing to insert

    generated_ids = []
    if return_generated_ids:
      # We must insert the rows one-at-a-time to know the generated IDs.
      for row_value in row_values:
        stmt = Statement.MakeInsert(
            self.table_name, cols, [row_value], replace=replace, ignore=ignore)
        stmt_str, stmt_args = stmt.Generate()
        cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit)
        if cursor.lastrowid:
          generated_ids.append(cursor.lastrowid)
        cursor.close()
      return generated_ids

    stmt = Statement.MakeInsert(
        self.table_name, cols, row_values, replace=replace, ignore=ignore)
    stmt_str, stmt_args = stmt.Generate()
    cnxn.Execute(stmt_str, stmt_args, commit=commit)
    return []

  def InsertRow(
      self, cnxn, replace=False, ignore=False, commit=True, **kwargs):
    """Insert a single row into the table.

    Args:
      cnxn: MonorailConnection object.
      replace: Set to True if inserted values should replace existing DB rows
          that have the same DB keys.
      ignore: Set to True to ignore rows that would duplicate existing DB keys.
      commit: Set to False if this operation is part of a series of operations
          that should not be committed until the final one is done.
      **kwargs: column=value assignments to specify what to store in the DB.

    Returns:
      The generated autoincrement ID of the key column if one was generated.
      Otherwise, return None.
    """
    # Sort columns so the generated SQL is deterministic.
    cols = sorted(kwargs.keys())
    row = tuple(kwargs[col] for col in cols)
    generated_ids = self.InsertRows(
        cnxn, cols, [row], replace=replace, ignore=ignore,
        commit=commit, return_generated_ids=True)
    if generated_ids:
      return generated_ids[0]
    else:
      return None

  def Update(self, cnxn, delta, where=None, commit=True, limit=None, **kwargs):
    """Update one or more rows.

    Args:
      cnxn: MonorailConnection object.
      delta: Dictionary of {column: new_value} assignments.
      where: Optional list of WHERE conditions saying which rows to update.
      commit: Set to False if this operation is part of a series of operations
          that should not be committed until the final one is done.
      limit: Optional LIMIT on the number of rows updated.
      **kwargs: WHERE-clause equality and set-membership conditions.

    Returns:
      Int number of rows updated.
    """
    if not delta:
      return 0   # Nothing is being changed

    stmt = Statement.MakeUpdate(self.table_name, delta)
    stmt.AddWhereTerms(where, **kwargs)
    stmt.SetLimitAndOffset(limit, None)
    stmt_str, stmt_args = stmt.Generate()

    cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit)
    result = cursor.rowcount
    cursor.close()
    return result

  def IncrementCounterValue(self, cnxn, col_name, where=None, **kwargs):
    """Atomically increment a counter stored in MySQL, return new value.

    Args:
      cnxn: MonorailConnection object.
      col_name: int column to increment.
      where: Optional list of WHERE conditions saying which rows to update.
      **kwargs: WHERE-clause equality and set-membership conditions.  The
          where and kwargs together should narrow the update down to exactly
          one row.

    Returns:
      The new, post-increment value of the counter.
    """
    stmt = Statement.MakeIncrement(self.table_name, col_name)
    stmt.AddWhereTerms(where, **kwargs)
    stmt_str, stmt_args = stmt.Generate()

    cursor = cnxn.Execute(stmt_str, stmt_args)
    assert cursor.rowcount == 1, (
        'missing or ambiguous counter: %r' % cursor.rowcount)
    # MakeIncrement uses LAST_INSERT_ID(...) so the new value comes back
    # through cursor.lastrowid.
    result = cursor.lastrowid
    cursor.close()
    return result

  def Delete(self, cnxn, where=None, or_where_conds=False, commit=True,
             limit=None, **kwargs):
    """Delete the specified table rows.

    Args:
      cnxn: MonorailConnection object.
      where: Optional list of WHERE conditions saying which rows to update.
      or_where_conds: Set to True to use OR in the WHERE conds.
      commit: Set to False if this operation is part of a series of operations
          that should not be committed until the final one is done.
      limit: Optional LIMIT on the number of rows deleted.
      **kwargs: WHERE-clause equality and set-membership conditions.

    Returns:
      Int number of rows updated.
    """
    # Deleting the whole table is never intended in Monorail.
    assert where or kwargs

    stmt = Statement.MakeDelete(self.table_name, or_where_conds=or_where_conds)
    stmt.AddWhereTerms(where, **kwargs)
    stmt.SetLimitAndOffset(limit, None)
    stmt_str, stmt_args = stmt.Generate()

    cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit)
    result = cursor.rowcount
    cursor.close()
    return result
541
542
class Statement(object):
  """A class to help build complex SQL statements w/ full escaping.

  Start with a Make*() method, then fill in additional clauses as needed,
  then call Generate() to return the SQL string and argument list.  We pass
  the string and args to MySQLdb separately so that it can do escaping on
  the arg values as appropriate to prevent SQL-injection attacks.

  The only values that are not escaped by MySQLdb are the table names
  and column names, and bits of SQL syntax, all of which is hard-coded
  in our application.
  """

  @classmethod
  def MakeSelect(cls, table_name, cols, distinct=False, or_where_conds=False):
    """Construct a SELECT statement."""
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in cols)
    main_clause = 'SELECT%s %s FROM %s' % (
        (' DISTINCT' if distinct else ''), ', '.join(cols), table_name)
    return cls(main_clause, or_where_conds=or_where_conds)

  @classmethod
  def MakeInsert(
      cls, table_name, cols, new_values, replace=False, ignore=False):
    """Construct an INSERT statement."""
    # Was `if replace == True:` -- a non-idiomatic identity-style
    # comparison; a plain truth test expresses the intent directly.
    if replace:
      return cls.MakeReplace(table_name, cols, new_values, ignore)
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in cols)
    ignore_word = ' IGNORE' if ignore else ''
    main_clause = 'INSERT%s INTO %s (%s)' % (
        ignore_word, table_name, ', '.join(cols))
    return cls(main_clause, insert_args=new_values)

  @classmethod
  def MakeReplace(
      cls, table_name, cols, new_values, ignore=False):
    """Construct an INSERT...ON DUPLICATE KEY UPDATE... statement.

    Uses the INSERT/UPDATE syntax because REPLACE is literally a DELETE
    followed by an INSERT, which doesn't play well with foreign keys.
    INSERT/UPDATE is an atomic check of whether the primary key exists,
    followed by an INSERT if it doesn't or an UPDATE if it does.
    """
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in cols)
    ignore_word = ' IGNORE' if ignore else ''
    main_clause = 'INSERT%s INTO %s (%s)' % (
        ignore_word, table_name, ', '.join(cols))
    return cls(main_clause, insert_args=new_values, duplicate_update_cols=cols)

  @classmethod
  def MakeUpdate(cls, table_name, delta):
    """Construct an UPDATE statement."""
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in delta.keys())
    update_strs = []
    update_args = []
    for col, val in delta.items():
      update_strs.append(col + '=%s')
      update_args.append(val)

    main_clause = 'UPDATE %s SET %s' % (
        table_name, ', '.join(update_strs))
    return cls(main_clause, update_args=update_args)

  @classmethod
  def MakeIncrement(cls, table_name, col_name, step=1):
    """Construct an UPDATE statement that increments and returns a counter."""
    assert _IsValidTableName(table_name)
    assert _IsValidColumnName(col_name)

    # LAST_INSERT_ID(expr) makes the incremented value retrievable via
    # cursor.lastrowid after execution.
    main_clause = (
        'UPDATE %s SET %s = LAST_INSERT_ID(%s + %%s)' % (
            table_name, col_name, col_name))
    update_args = [step]
    return cls(main_clause, update_args=update_args)

  @classmethod
  def MakeDelete(cls, table_name, or_where_conds=False):
    """Construct a DELETE statement."""
    assert _IsValidTableName(table_name)
    main_clause = 'DELETE FROM %s' % table_name
    return cls(main_clause, or_where_conds=or_where_conds)

  def __init__(
      self, main_clause, insert_args=None, update_args=None,
      duplicate_update_cols=None, or_where_conds=False):
    self.main_clause = main_clause  # E.g., SELECT or DELETE
    self.or_where_conds = or_where_conds
    self.insert_args = insert_args or []  # For INSERT statements
    for row_value in self.insert_args:
      if not all(_IsValidDBValue(val) for val in row_value):
        raise exceptions.InputException('Invalid DB value %r' % (row_value,))
    self.update_args = update_args or []  # For UPDATEs
    for val in self.update_args:
      if not _IsValidDBValue(val):
        raise exceptions.InputException('Invalid DB value %r' % val)
    self.duplicate_update_cols = duplicate_update_cols or []  # For REPLACE-ish

    self.use_clauses = []
    self.join_clauses, self.join_args = [], []
    self.where_conds, self.where_args = [], []
    self.having_conds, self.having_args = [], []
    self.group_by_terms, self.group_by_args = [], []
    self.order_by_terms, self.order_by_args = [], []
    self.limit, self.offset = None, None

  def Generate(self):
    """Return an SQL string having %s placeholders and args to fill them in."""
    clauses = [self.main_clause] + self.use_clauses + self.join_clauses
    if self.where_conds:
      if self.or_where_conds:
        clauses.append('WHERE ' + '\n OR '.join(self.where_conds))
      else:
        clauses.append('WHERE ' + '\n AND '.join(self.where_conds))
    if self.group_by_terms:
      clauses.append('GROUP BY ' + ', '.join(self.group_by_terms))
    if self.having_conds:
      assert self.group_by_terms
      clauses.append('HAVING %s' % ','.join(self.having_conds))
    if self.order_by_terms:
      clauses.append('ORDER BY ' + ', '.join(self.order_by_terms))

    if self.limit and self.offset:
      clauses.append('LIMIT %d OFFSET %d' % (self.limit, self.offset))
    elif self.limit:
      clauses.append('LIMIT %d' % self.limit)
    elif self.offset:
      # MySQL requires a LIMIT whenever OFFSET is given; a huge limit
      # effectively means "no limit".  Bug fix: sys.maxint does not exist
      # on Python 3; sys.maxsize is available on both 2 and 3.
      clauses.append('LIMIT %d OFFSET %d' % (sys.maxsize, self.offset))

    if self.insert_args:
      clauses.append('VALUES (' + PlaceHolders(self.insert_args[0]) + ')')
      args = self.insert_args
      if self.duplicate_update_cols:
        clauses.append('ON DUPLICATE KEY UPDATE %s' % (
            ', '.join(['%s=VALUES(%s)' % (col, col)
                       for col in self.duplicate_update_cols])))
      assert not (self.join_args + self.update_args + self.where_args +
                  self.group_by_args + self.order_by_args + self.having_args)
    else:
      args = (self.join_args + self.update_args + self.where_args +
              self.group_by_args + self.having_args + self.order_by_args)
      assert not (self.insert_args + self.duplicate_update_cols)

    args = _BoolsToInts(args)
    stmt_str = '\n'.join(clause for clause in clauses if clause)

    assert _IsValidStatement(stmt_str), stmt_str
    return stmt_str, args

  def AddUseClause(self, use_clause):
    """Add a USE clause (giving the DB a hint about which indexes to use)."""
    assert _IsValidUseClause(use_clause), use_clause
    self.use_clauses.append(use_clause)

  def AddJoinClauses(self, join_pairs, left=False):
    """Save JOIN clauses based on the given list of join conditions."""
    for join, args in join_pairs:
      assert _IsValidJoin(join), join
      assert join.count('%s') == len(args), join
      self.join_clauses.append(
          ' %sJOIN %s' % (('LEFT ' if left else ''), join))
      self.join_args.extend(args)

  def AddGroupByTerms(self, group_by_term_list):
    """Save info needed to generate the GROUP BY clause."""
    assert all(_IsValidGroupByTerm(term) for term in group_by_term_list)
    self.group_by_terms.extend(group_by_term_list)

  def AddOrderByTerms(self, order_by_pairs):
    """Save info needed to generate the ORDER BY clause."""
    for term, args in order_by_pairs:
      assert _IsValidOrderByTerm(term), term
      assert term.count('%s') == len(args), term
      self.order_by_terms.append(term)
      self.order_by_args.extend(args)

  def SetLimitAndOffset(self, limit, offset):
    """Save info needed to generate the LIMIT OFFSET clause."""
    self.limit = limit
    self.offset = offset

  def AddWhereTerms(self, where_cond_pairs, **kwargs):
    """Generate a WHERE clause."""
    where_cond_pairs = where_cond_pairs or []

    for cond, args in where_cond_pairs:
      assert _IsValidWhereCond(cond), cond
      assert cond.count('%s') == len(args), cond
      self.where_conds.append(cond)
      self.where_args.extend(args)

    # Sort kwargs so generated SQL is deterministic.
    for col, val in sorted(kwargs.items()):
      assert _IsValidColumnName(col), col
      eq = True
      if col.endswith('_not'):
        # A trailing "_not" on the keyword negates the comparison.
        col = col[:-4]
        eq = False

      if isinstance(val, set):
        val = list(val)  # MySQL interface cannot handle sets.

      if val is None or val == []:
        if val == [] and self.main_clause and self.main_clause.startswith(
            'UPDATE'):
          # https://crbug.com/monorail/6735: Avoid empty arrays for UPDATE.
          raise exceptions.InputException('Invalid update DB value %r' % col)
        op = 'IS' if eq else 'IS NOT'
        self.where_conds.append(col + ' ' + op + ' NULL')
      elif isinstance(val, list):
        op = 'IN' if eq else 'NOT IN'
        # Sadly, MySQLdb cannot escape lists, so we flatten to multiple "%s"s
        self.where_conds.append(
            col + ' ' + op + ' (' + PlaceHolders(val) + ')')
        self.where_args.extend(val)
      else:
        op = '=' if eq else '!='
        self.where_conds.append(col + ' ' + op + ' %s')
        self.where_args.append(val)

  def AddHavingTerms(self, having_cond_pairs):
    """Generate a HAVING clause."""
    for cond, args in having_cond_pairs:
      assert _IsValidHavingCond(cond), cond
      assert cond.count('%s') == len(args), cond
      self.having_conds.append(cond)
      self.having_args.extend(args)
772
773
def PlaceHolders(sql_args):
  """Return a comma-separated list of %s placeholders for the given args."""
  return ','.join(['%s'] * len(sql_args))
777
778
# Pattern for a valid table name, e.g. "Issue2Label".
TABLE_PAT = '[A-Z][_a-zA-Z0-9]+'
# Pattern for a valid column name, e.g. "issue_id".
COLUMN_PAT = '[a-z][_a-z]+'
# Comparison operators permitted in generated WHERE/HAVING conditions.
COMPARE_OP_PAT = '(<|>|=|!=|>=|<=|LIKE|NOT LIKE)'
# Named sub-patterns that _MakeRE() expands into validation regexes via
# str.format().  Keys are referenced as "{table}", "{tab_col}", etc.
SHORTHAND = {
    'table': TABLE_PAT,
    'column': COLUMN_PAT,
    'tab_col': r'(%s\.)?%s' % (TABLE_PAT, COLUMN_PAT),
    'placeholder': '%s',  # That's a literal %s that gets passed to MySQLdb
    'multi_placeholder': '%s(, ?%s)*',
    'compare_op': COMPARE_OP_PAT,
    'opt_asc_desc': '( ASC| DESC)?',
    'opt_alias': '( AS %s)?' % TABLE_PAT,
    'email_cond': (r'\(?'
                   r'('
                   r'(LOWER\(Spare\d+\.email\) IS NULL OR )?'
                   r'LOWER\(Spare\d+\.email\) '
                   r'(%s %%s|IN \(%%s(, ?%%s)*\))'
                   r'( (AND|OR) )?'
                   r')+'
                   r'\)?' % COMPARE_OP_PAT),
    'hotlist_cond': (r'\(?'
                     r'('
                     r'(LOWER\(Cond\d+\.name\) IS NULL OR )?'
                     r'LOWER\(Cond\d+\.name\) '
                     r'(%s %%s|IN \(%%s(, ?%%s)*\))'
                     r'( (AND|OR) )?'
                     r')+'
                     r'\)?' % COMPARE_OP_PAT),
    'phase_cond': (r'\(?'
                   r'('
                   r'(LOWER\(Phase\d+\.name\) IS NULL OR )?'
                   r'LOWER\(Phase\d+\.name\) '
                   r'(%s %%s|IN \(%%s(, ?%%s)*\))?'
                   r'( (AND|OR) )?'
                   r')+'
                   r'\)?' % COMPARE_OP_PAT),
    'approval_cond': (r'\(?'
                      r'('
                      r'(LOWER\(Cond\d+\.status\) IS NULL OR )?'
                      r'LOWER\(Cond\d+\.status\) '
                      r'(%s %%s|IN \(%%s(, ?%%s)*\))'
                      r'( (AND|OR) )?'
                      r')+'
                      r'\)?' % COMPARE_OP_PAT),
    }
824
825
def _MakeRE(regex_str):
  """Return a regular expression object, expanding our shorthand as needed."""
  expanded = regex_str.format(**SHORTHAND)
  return re.compile(expanded)
829
830
# Validation regexes used by the _IsValid*() helpers to whitelist the SQL
# fragments that Statement will accept.  Anything not matching one of these
# patterns is rejected before it can reach the database.
TABLE_RE = _MakeRE('^{table}$')
TAB_COL_RE = _MakeRE('^{tab_col}$')
USE_CLAUSE_RE = _MakeRE(
    r'^USE INDEX \({column}\) USE INDEX FOR ORDER BY \({column}\)$')
# Accepted HAVING conditions: only COUNT(*) comparisons.
HAVING_RE_LIST = [
    _MakeRE(r'^COUNT\(\*\) {compare_op} {placeholder}$')]
# Accepted SELECT column expressions: plain columns plus a fixed set of
# aggregate functions.
COLUMN_RE_LIST = [
    TAB_COL_RE,
    _MakeRE(r'\*'),
    _MakeRE(r'COUNT\(\*\)'),
    _MakeRE(r'COUNT\({tab_col}\)'),
    _MakeRE(r'COUNT\(DISTINCT\({tab_col}\)\)'),
    _MakeRE(r'MAX\({tab_col}\)'),
    _MakeRE(r'MIN\({tab_col}\)'),
    _MakeRE(r'GROUP_CONCAT\((DISTINCT )?{tab_col}( ORDER BY {tab_col})?' \
                  r'( SEPARATOR \'.*\')?\)'),
    ]
# Accepted JOIN clauses.  Each entry matches one join shape used somewhere
# in the application; new join shapes require a new pattern here.
JOIN_RE_LIST = [
    TABLE_RE,
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} IN \({multi_placeholder}\))?$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} = {placeholder})?'
        r'( AND {tab_col} IN \({multi_placeholder}\))?'
        r'( AND {tab_col} = {tab_col})?$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} = {placeholder})?'
        r'( AND {tab_col} IN \({multi_placeholder}\))?'
        r'( AND {tab_col} IS NULL)?'
        r'( AND \({tab_col} IS NULL'
        r' OR {tab_col} NOT IN \({multi_placeholder}\)\))?$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} = {placeholder})?'
        r' AND \(?{tab_col} {compare_op} {placeholder}\)?'
        r'( AND {tab_col} = {tab_col})?$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} = {placeholder})?'
        r' AND {tab_col} = {tab_col}$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} = {placeholder})?'
        r' AND \({tab_col} IS NULL OR'
        r' {tab_col} {compare_op} {placeholder}\)$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r' AND \({tab_col} IS NOT NULL AND {tab_col} != {placeholder}\)'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r' AND LOWER\({tab_col}\) = LOWER\({placeholder}\)'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col} AND {email_cond}$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {email_cond}$'),
    _MakeRE(
        r'^{table}{opt_alias} ON '
        r'\({tab_col} = {tab_col} OR {tab_col} = {tab_col}\)$'),
    _MakeRE(
        r'^\({table} AS {table} JOIN User AS {table} '
        r'ON {tab_col} = {tab_col} AND {email_cond}\) '
        r'ON Issue(Snapshot)?.id = {tab_col}'
        r'( AND {tab_col} IS NULL)?'),
    _MakeRE(
        r'^\({table} JOIN Hotlist AS {table} '
        r'ON {tab_col} = {tab_col} AND {hotlist_cond}\) '
        r'ON Issue.id = {tab_col}?'),
    _MakeRE(
        r'^\({table} AS {table} JOIN IssuePhaseDef AS {table} '
        r'ON {tab_col} = {tab_col} AND {phase_cond}\) '
        r'ON Issue.id = {tab_col}?'),
    _MakeRE(
        r'^IssuePhaseDef AS {table} ON {phase_cond}'),
    _MakeRE(
        r'^Issue2ApprovalValue AS {table} ON {tab_col} = {tab_col} '
        r'AND {tab_col} = {placeholder} AND {approval_cond}'),
    _MakeRE(
        r'^{table} AS {table} ON {tab_col} = {tab_col} '
        r'LEFT JOIN {table} AS {table} ON {tab_col} = {tab_col}'),
    ]
# Allowed ORDER BY terms: a column reference, possibly wrapped in LOWER,
# ISNULL, FIELD, or CONCAT, with an optional ASC/DESC suffix.
ORDER_BY_RE_LIST = [
    _MakeRE(r'^{tab_col}{opt_asc_desc}$'),
    _MakeRE(r'^LOWER\({tab_col}\){opt_asc_desc}$'),
    _MakeRE(r'^ISNULL\({tab_col}\){opt_asc_desc}$'),
    _MakeRE(r'^\(ISNULL\({tab_col}\) AND ISNULL\({tab_col}\)\){opt_asc_desc}$'),
    _MakeRE(r'^FIELD\({tab_col}, {multi_placeholder}\){opt_asc_desc}$'),
    _MakeRE(r'^FIELD\(IF\(ISNULL\({tab_col}\), {tab_col}, {tab_col}\), '
            r'{multi_placeholder}\){opt_asc_desc}$'),
    _MakeRE(r'^CONCAT\({tab_col}, {tab_col}\){opt_asc_desc}$'),
    ]
# GROUP BY terms must be plain column references.
GROUP_BY_RE_LIST = [
    TAB_COL_RE,
    ]
# Allowed forms for a single WHERE condition.  _IsValidWhereCond() also
# accepts NOT-prefixed, parenthesized, and AND/OR-joined combinations of
# these.  All pattern fragments are raw strings: the previous non-raw
# continuation lines contained regex escapes like \( that are invalid
# string escape sequences (a DeprecationWarning since Python 3.6, and a
# future SyntaxError).  Adding the r prefixes does not change the
# resulting regex text.
WHERE_COND_RE_LIST = [
    _MakeRE(r'^TRUE$'),
    _MakeRE(r'^FALSE$'),
    _MakeRE(r'^{tab_col} IS NULL$'),
    _MakeRE(r'^{tab_col} IS NOT NULL$'),
    _MakeRE(r'^{tab_col} {compare_op} {tab_col}$'),
    _MakeRE(r'^{tab_col} {compare_op} {placeholder}$'),
    # '%%' matches a literal doubled percent (the escaped modulo operator
    # as it appears in MySQLdb-style statement strings).
    _MakeRE(r'^{tab_col} %% {placeholder} = {placeholder}$'),
    _MakeRE(r'^{tab_col} IN \({multi_placeholder}\)$'),
    _MakeRE(r'^{tab_col} NOT IN \({multi_placeholder}\)$'),
    _MakeRE(r'^LOWER\({tab_col}\) IS NULL$'),
    _MakeRE(r'^LOWER\({tab_col}\) IS NOT NULL$'),
    _MakeRE(r'^LOWER\({tab_col}\) {compare_op} {placeholder}$'),
    _MakeRE(r'^LOWER\({tab_col}\) IN \({multi_placeholder}\)$'),
    _MakeRE(r'^LOWER\({tab_col}\) NOT IN \({multi_placeholder}\)$'),
    _MakeRE(r'^LOWER\({tab_col}\) LIKE {placeholder}$'),
    _MakeRE(r'^LOWER\({tab_col}\) NOT LIKE {placeholder}$'),
    # The subquery form used against the Invalidate table.
    _MakeRE(r'^timestep < \(SELECT MAX\(j.timestep\) FROM Invalidate AS j '
            r'WHERE j.kind = %s '
            r'AND j.cache_key = Invalidate.cache_key\)$'),
    _MakeRE(r'^\({tab_col} IS NULL OR {tab_col} {compare_op} {placeholder}\) '
            r'AND \({tab_col} IS NULL OR {tab_col} {compare_op} {placeholder}'
            r'\)$'),
    _MakeRE(r'^\({tab_col} IS NOT NULL AND {tab_col} {compare_op} '
            r'{placeholder}\) OR \({tab_col} IS NOT NULL AND {tab_col} '
            r'{compare_op} {placeholder}\)$'),
    ]
960
# Note: We never use ';' for multiple statements, '@' for SQL variables, or
# any quoted strings in stmt_str (quotes are put in by MySQLdb for args).
# Overall statement shape: a single allowed verb followed only by
# characters from a conservative allow-list.
STMT_STR_RE = re.compile(
    r'\A(SELECT|UPDATE|DELETE|INSERT|REPLACE) [\'-+=!<>%*.,()\w\s]+\Z',
    re.MULTILINE)
966
967
def _IsValidDBValue(val):
  """Return False for string values that contain a NUL byte, else True.

  Non-string values are always considered valid.
  """
  if not isinstance(val, string_types):
    return True
  return '\x00' not in val
972
973
def _IsValidTableName(table_name):
  """Return a match object if table_name has an allowed form, else None."""
  matched = TABLE_RE.match(table_name)
  return matched
976
977
def _IsValidColumnName(column_expr):
  """Return True if column_expr matches one of the allowed column forms."""
  for pattern in COLUMN_RE_LIST:
    if pattern.match(column_expr):
      return True
  return False
980
981
def _IsValidUseClause(use_clause):
  """Return a match object if use_clause has the allowed form, else None."""
  matched = USE_CLAUSE_RE.match(use_clause)
  return matched
984
def _IsValidHavingCond(cond):
  """Return True if cond is an allowed HAVING condition, recursively.

  A condition may be wrapped in one level of parentheses and may combine
  sub-conditions with OR or AND; each part must itself be valid.
  """
  # Strip one level of surrounding parentheses.
  if cond.startswith('(') and cond.endswith(')'):
    cond = cond[1:-1]

  # Check OR before AND, mirroring the order of the original checks.
  for connective in (' OR ', ' AND '):
    if connective in cond:
      return all(_IsValidHavingCond(part) for part in cond.split(connective))

  return any(regex.match(cond) for regex in HAVING_RE_LIST)
996
997
def _IsValidJoin(join):
  """Return True if the JOIN clause matches one of the allowed patterns."""
  for pattern in JOIN_RE_LIST:
    if pattern.match(join):
      return True
  return False
1000
1001
def _IsValidOrderByTerm(term):
  """Return True if term matches one of the allowed ORDER BY forms."""
  for pattern in ORDER_BY_RE_LIST:
    if pattern.match(term):
      return True
  return False
1004
1005
def _IsValidGroupByTerm(term):
  """Return True if term matches one of the allowed GROUP BY forms."""
  for pattern in GROUP_BY_RE_LIST:
    if pattern.match(term):
      return True
  return False
1008
1009
def _IsValidWhereCond(cond):
  """Return True if cond is an allowed WHERE condition, recursively.

  A condition may carry a leading NOT, may be wrapped in one level of
  parentheses, and may combine sub-conditions with OR or AND; each part
  must itself be valid.
  """
  # Strip an optional NOT prefix and one level of parentheses.
  if cond.startswith('NOT '):
    cond = cond[len('NOT '):]
  if cond.startswith('(') and cond.endswith(')'):
    cond = cond[1:-1]

  # A directly recognized condition needs no further decomposition.
  if any(regex.match(cond) for regex in WHERE_COND_RE_LIST):
    return True

  # Check OR before AND, mirroring the order of the original checks.
  for connective in (' OR ', ' AND '):
    if connective in cond:
      return all(_IsValidWhereCond(part) for part in cond.split(connective))

  return False
1026
1027
def _IsValidStatement(stmt_str):
  """Final check to make sure there is no funny junk sneaking in somehow."""
  # SQL comment markers ('--') are never legitimate in our statements.
  matched = STMT_STR_RE.match(stmt_str)
  return matched and '--' not in stmt_str
1032
1033
1034def _BoolsToInts(arg_list):
1035 """Convert any True values to 1s and Falses to 0s.
1036
1037 Google's copy of MySQLdb has bool-to-int conversion disabled,
1038 and yet it seems to be needed otherwise they are converted
1039 to strings and always interpreted as 0 (which is FALSE).
1040
1041 Args:
1042 arg_list: (nested) list of SQL statment argument values, which may
1043 include some boolean values.
1044
1045 Returns:
1046 The same list, but with True replaced by 1 and False replaced by 0.
1047 """
1048 result = []
1049 for arg in arg_list:
1050 if isinstance(arg, (list, tuple)):
1051 result.append(_BoolsToInts(arg))
1052 elif arg is True:
1053 result.append(1)
1054 elif arg is False:
1055 result.append(0)
1056 else:
1057 result.append(arg)
1058
1059 return result