# Copyright 2016 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style
3# license that can be found in the LICENSE file or at
4# https://developers.google.com/open-source/licenses/bsd
5
6"""A set of classes for interacting with tables in SQL."""
7from __future__ import print_function
8from __future__ import division
9from __future__ import absolute_import
10
11import logging
12import random
13import re
14import sys
15import time
16
17from six import string_types
18
19import settings
20
21if not settings.unit_test_mode:
22 import MySQLdb
23
24from framework import exceptions
25from framework import framework_helpers
26
27from infra_libs import ts_mon
28
29from Queue import Queue
30
31
class ConnectionPool(object):
  """Maintain reusable database connections, one bounded queue per DB.

  Connections are keyed by 'instance/database'.  Releasing a connection
  into an already-full queue closes it instead of pooling it.
  """

  def __init__(self, poolsize=1):
    self.poolsize = poolsize  # Max idle connections kept per key.
    self.queues = {}  # {pool_key: Queue of idle connections}

  def get(self, instance, database):
    """Return a database connection, or throw an exception if none can
    be made.
    """
    key = instance + '/' + database
    if key not in self.queues:
      self.queues[key] = Queue(self.poolsize)
    pool_queue = self.queues[key]

    if pool_queue.empty():
      # Nothing idle to reuse: make a brand new connection.
      return cnxn_ctor(instance, database)

    cnxn = pool_queue.get()
    # Make sure the pooled connection is still good before handing it out.
    cnxn.ping()
    cnxn.commit()
    return cnxn

  def release(self, cnxn):
    """Return a connection to its pool, closing it if the pool is full."""
    if cnxn.pool_key not in self.queues:
      raise BaseException('unknown pool key: %s' % cnxn.pool_key)

    pool_queue = self.queues[cnxn.pool_key]
    if pool_queue.full():
      cnxn.close()
    else:
      pool_queue.put(cnxn)
71
72
@framework_helpers.retry(1, delay=1, backoff=2)
def cnxn_ctor(instance, database):
  """Create a new MySQL connection to the given instance and database.

  Records connection latency and success/failure metrics.  The decorator
  retries once (with delay) if this raises.  Never called in unit tests.
  """
  logging.info('About to connect to SQL instance %r db %r', instance, database)
  if settings.unit_test_mode:
    raise ValueError('unit tests should not need real database connections')
  try:
    start_time = time.time()
    if settings.local_mode:
      # Local dev server talks to a local MySQL over TCP.
      cnxn = MySQLdb.connect(
          host='127.0.0.1', port=3306, db=database, user='root', charset='utf8')
    else:
      # In production we go through the Cloud SQL unix socket.
      cnxn = MySQLdb.connect(
          unix_socket='/cloudsql/' + instance, db=database, user='root',
          charset='utf8')
    elapsed_ms = int((time.time() - start_time) * 1000)
    DB_CNXN_LATENCY.add(elapsed_ms)
    CONNECTION_COUNT.increment({'success': True})
  except MySQLdb.OperationalError:
    CONNECTION_COUNT.increment({'success': False})
    raise
  # Tag the connection so ConnectionPool.release() knows where it belongs.
  cnxn.pool_key = instance + '/' + database
  cnxn.is_bad = False
  return cnxn
97
98
# One connection pool per database instance (primary, replicas are each an
# instance). We'll have four connections per instance because we fetch
# issue comments, stars, spam verdicts and spam verdict history in parallel
# with promises.
cnxn_pool = ConnectionPool(settings.db_cnxn_pool_size)

# MonorailConnection maintains a dictionary of connections to SQL databases.
# Each is identified by an int shard ID.
# And there is one connection to the primary DB identified by key PRIMARY_CNXN.
PRIMARY_CNXN = 'primary_cnxn'

# When one replica is temporarily unresponsive, we can use a different one.
BAD_SHARD_AVOIDANCE_SEC = 45


# ts_mon metrics exported by this module for SQL monitoring dashboards.
CONNECTION_COUNT = ts_mon.CounterMetric(
    'monorail/sql/connection_count',
    'Count of connections made to the SQL database.',
    [ts_mon.BooleanField('success')])

DB_CNXN_LATENCY = ts_mon.CumulativeDistributionMetric(
    'monorail/sql/db_cnxn_latency',
    'Time needed to establish a DB connection.',
    None)

DB_QUERY_LATENCY = ts_mon.CumulativeDistributionMetric(
    'monorail/sql/db_query_latency',
    'Time needed to make a DB query.',
    [ts_mon.StringField('type')])

DB_COMMIT_LATENCY = ts_mon.CumulativeDistributionMetric(
    'monorail/sql/db_commit_latency',
    'Time needed to make a DB commit.',
    None)

DB_ROLLBACK_LATENCY = ts_mon.CumulativeDistributionMetric(
    'monorail/sql/db_rollback_latency',
    'Time needed to make a DB rollback.',
    None)

DB_RETRY_COUNT = ts_mon.CounterMetric(
    'monorail/sql/db_retry_count',
    'Count of queries retried.',
    None)

DB_QUERY_COUNT = ts_mon.CounterMetric(
    'monorail/sql/db_query_count',
    'Count of queries sent to the DB.',
    [ts_mon.StringField('type')])

DB_COMMIT_COUNT = ts_mon.CounterMetric(
    'monorail/sql/db_commit_count',
    'Count of commits sent to the DB.',
    None)

DB_ROLLBACK_COUNT = ts_mon.CounterMetric(
    'monorail/sql/db_rollback_count',
    'Count of rollbacks sent to the DB.',
    None)

DB_RESULT_ROWS = ts_mon.CumulativeDistributionMetric(
    'monorail/sql/db_result_rows',
    'Number of results returned by a DB query.',
    None)
163
164
def RandomShardID():
  """Return a random shard ID to load balance across replicas."""
  # randrange(n) is equivalent to randint(0, n - 1).
  return random.randrange(settings.num_logical_shards)
168
169
class MonorailConnection(object):
  """Create and manage connections to the SQL servers.

  We only store connections in the context of a single user request, not
  across user requests. The main purpose of this class is to make using
  sharded tables easier.
  """
  # Class-level (shared across instances): maps shard_id to the Unix
  # timestamp of the most recent failed connection attempt for that shard.
  unavailable_shards = {}  # {shard_id: timestamp of failed attempt}

  def __init__(self):
    self.sql_cnxns = {}  # {PRIMARY_CNXN: cnxn, shard_id: cnxn, ...}

  @framework_helpers.retry(1, delay=0.1, backoff=2)
  def GetPrimaryConnection(self):
    """Return a connection to the primary SQL DB."""
    if PRIMARY_CNXN not in self.sql_cnxns:
      self.sql_cnxns[PRIMARY_CNXN] = cnxn_pool.get(
          settings.db_instance, settings.db_database_name)
      logging.info(
          'created a primary connection %r', self.sql_cnxns[PRIMARY_CNXN])

    return self.sql_cnxns[PRIMARY_CNXN]

  @framework_helpers.retry(1, delay=0.1, backoff=2)
  def GetConnectionForShard(self, shard_id):
    """Return a connection to the DB replica that will be used for shard_id."""
    if shard_id not in self.sql_cnxns:
      # Many logical shards map onto fewer physical replicas, round-robin.
      physical_shard_id = shard_id % settings.num_logical_shards

      replica_name = settings.db_replica_names[
          physical_shard_id % len(settings.db_replica_names)]
      shard_instance_name = (
          settings.physical_db_name_format % replica_name)
      # Pessimistically mark the shard as bad before connecting.  On success
      # the entry is deleted; if cnxn_pool.get() raises, the mark remains so
      # Execute() avoids this shard for BAD_SHARD_AVOIDANCE_SEC.
      self.unavailable_shards[shard_id] = int(time.time())
      self.sql_cnxns[shard_id] = cnxn_pool.get(
          shard_instance_name, settings.db_database_name)
      del self.unavailable_shards[shard_id]
      logging.info('created a replica connection for shard %d', shard_id)

    return self.sql_cnxns[shard_id]

  def Execute(self, stmt_str, stmt_args, shard_id=None, commit=True, retries=2):
    """Execute the given SQL statement on one of the relevant databases.

    Args:
      stmt_str: SQL statement with %s placeholders.
      stmt_args: values to substitute into the placeholders.
      shard_id: optional int shard; None means use the primary DB.
      commit: pass False to defer the commit (e.g., mid-transaction).
      retries: number of times to retry on OperationalError.

    Returns:
      A DB cursor with the statement's results.
    """
    if shard_id is None:
      # No shard was specified, so hit the primary.
      sql_cnxn = self.GetPrimaryConnection()
    else:
      if shard_id in self.unavailable_shards:
        bad_age_sec = int(time.time()) - self.unavailable_shards[shard_id]
        if bad_age_sec < BAD_SHARD_AVOIDANCE_SEC:
          # Recently-failed replica: route the query to the next shard.
          logging.info('Avoiding bad replica %r, age %r', shard_id, bad_age_sec)
          shard_id = (shard_id + 1) % settings.num_logical_shards
      sql_cnxn = self.GetConnectionForShard(shard_id)

    try:
      return self._ExecuteWithSQLConnection(
          sql_cnxn, stmt_str, stmt_args, commit=commit)
    except MySQLdb.OperationalError as e:
      logging.exception(e)
      logging.info('retries: %r', retries)
      if retries > 0:
        DB_RETRY_COUNT.increment()
        self.sql_cnxns = {}  # Drop all old mysql connections and make new.
        return self.Execute(
            stmt_str, stmt_args, shard_id=shard_id, commit=commit,
            retries=retries - 1)
      else:
        raise e

  def _ExecuteWithSQLConnection(
      self, sql_cnxn, stmt_str, stmt_args, commit=True):
    """Execute a statement on the given database and return a cursor."""

    start_time = time.time()
    cursor = sql_cnxn.cursor()
    # utf8mb4 allows full Unicode (e.g., 4-byte chars) in this session.
    cursor.execute('SET NAMES utf8mb4')
    if stmt_str.startswith('INSERT') or stmt_str.startswith('REPLACE'):
      # stmt_args is a list of row-value tuples for bulk writes.
      cursor.executemany(stmt_str, stmt_args)
      duration = (time.time() - start_time) * 1000
      DB_QUERY_LATENCY.add(duration, {'type': 'write'})
      DB_QUERY_COUNT.increment({'type': 'write'})
    else:
      cursor.execute(stmt_str, args=stmt_args)
      duration = (time.time() - start_time) * 1000
      DB_QUERY_LATENCY.add(duration, {'type': 'read'})
      DB_QUERY_COUNT.increment({'type': 'read'})
    DB_RESULT_ROWS.add(cursor.rowcount)

    # Log a human-readable approximation of the executed statement.
    if stmt_str.startswith('INSERT') or stmt_str.startswith('REPLACE'):
      formatted_statement = '%s %s' % (stmt_str, stmt_args)
    else:
      formatted_statement = stmt_str % tuple(stmt_args)
    logging.info(
        '%d rows in %d ms: %s', cursor.rowcount, int(duration),
        formatted_statement.replace('\n', ' '))

    if commit and not stmt_str.startswith('SELECT'):
      try:
        sql_cnxn.commit()
        duration = (time.time() - start_time) * 1000
        DB_COMMIT_LATENCY.add(duration)
        DB_COMMIT_COUNT.increment()
      except MySQLdb.DatabaseError:
        # Commit failed: undo the statement rather than leave it pending.
        sql_cnxn.rollback()
        duration = (time.time() - start_time) * 1000
        DB_ROLLBACK_LATENCY.add(duration)
        DB_ROLLBACK_COUNT.increment()

    return cursor

  def Commit(self):
    """Explicitly commit any pending txns. Normally done automatically."""
    sql_cnxn = self.GetPrimaryConnection()
    try:
      sql_cnxn.commit()
    except MySQLdb.DatabaseError:
      logging.exception('Commit failed for cnxn, rolling back')
      sql_cnxn.rollback()

  def Close(self):
    """Safely close any connections that are still open."""
    for sql_cnxn in self.sql_cnxns.values():
      try:
        sql_cnxn.rollback()  # Abandon any uncommitted changes.
        cnxn_pool.release(sql_cnxn)
      except MySQLdb.DatabaseError:
        # This might happen if the cnxn is somehow already closed.
        logging.exception('ProgrammingError when trying to close cnxn')
298
299
class SQLTableManager(object):
  """Helper class to make it easier to deal with an SQL table."""

  def __init__(self, table_name):
    self.table_name = table_name

  def Select(
      self, cnxn, distinct=False, cols=None, left_joins=None,
      joins=None, where=None, or_where_conds=False, group_by=None,
      order_by=None, limit=None, offset=None, shard_id=None, use_clause=None,
      having=None, **kwargs):
    """Compose and execute an SQL SELECT statement on this table.

    Args:
      cnxn: MonorailConnection to the databases.
      distinct: If True, add DISTINCT keyword.
      cols: List of columns to retrieve, defaults to '*'.
      left_joins: List of LEFT JOIN (str, args) pairs.
      joins: List of regular JOIN (str, args) pairs.
      where: List of (str, args) for WHERE clause.
      or_where_conds: Set to True to use OR in the WHERE conds.
      group_by: List of strings for GROUP BY clause.
      order_by: List of (str, args) for ORDER BY clause.
      limit: Optional LIMIT on the number of rows returned.
      offset: Optional OFFSET when using LIMIT.
      shard_id: Int ID of the shard to query.
      use_clause: Optional string USE clause to tell the DB which index to use.
      having: List of (str, args) for Optional HAVING clause
      **kwargs: WHERE-clause equality and set-membership conditions.

    Keyword args are used to build up more WHERE conditions that compare
    column values to constants. Key word Argument foo='bar' translates to 'foo
    = "bar"', and foo=[3, 4, 5] translates to 'foo IN (3, 4, 5)'.

    Returns:
      A list of rows, each row is a tuple of values for the requested cols.
    """
    cols = cols or ['*']  # If columns not specified, retrieve all columns.
    stmt = Statement.MakeSelect(
        self.table_name, cols, distinct=distinct,
        or_where_conds=or_where_conds)
    if use_clause:
      stmt.AddUseClause(use_clause)
    if having:
      stmt.AddHavingTerms(having)
    stmt.AddJoinClauses(left_joins or [], left=True)
    stmt.AddJoinClauses(joins or [])
    stmt.AddWhereTerms(where or [], **kwargs)
    stmt.AddGroupByTerms(group_by or [])
    stmt.AddOrderByTerms(order_by or [])
    stmt.SetLimitAndOffset(limit, offset)
    stmt_str, stmt_args = stmt.Generate()

    cursor = cnxn.Execute(stmt_str, stmt_args, shard_id=shard_id)
    rows = cursor.fetchall()
    cursor.close()
    return rows

  def SelectRow(
      self, cnxn, cols=None, default=None, where=None, **kwargs):
    """Run a query that is expected to return just one row.

    Args:
      cnxn: MonorailConnection object.
      cols: List of columns to retrieve, defaults to '*'.
      default: Value to return when the query matches no rows.
      where: Optional list of WHERE conditions.
      **kwargs: WHERE-clause equality and set-membership conditions.

    Returns:
      The single matching row tuple, or `default` if none matched.

    Raises:
      ValueError: if more than one row matched.
    """
    rows = self.Select(cnxn, distinct=True, cols=cols, where=where, **kwargs)
    if len(rows) == 1:
      return rows[0]
    elif not rows:
      logging.info('SelectRow got 0 results, so using default %r', default)
      return default
    else:
      # BUGFIX: use %-formatting; previously the count was passed as a
      # second constructor arg, producing a malformed exception message.
      raise ValueError(
          'SelectRow got %d results, expected only 1' % len(rows))

  def SelectValue(self, cnxn, col, default=None, where=None, **kwargs):
    """Run a query that is expected to return just one row w/ one value."""
    row = self.SelectRow(
        cnxn, cols=[col], default=[default], where=where, **kwargs)
    return row[0]

  def InsertRows(
      self, cnxn, cols, row_values, replace=False, ignore=False,
      commit=True, return_generated_ids=False):
    """Insert all the given rows.

    Args:
      cnxn: MonorailConnection object.
      cols: List of column names to set.
      row_values: List of lists with values to store. The length of each
          nested list should be equal to len(cols).
      replace: Set to True if inserted values should replace existing DB rows
          that have the same DB keys.
      ignore: Set to True to ignore rows that would duplicate existing DB keys.
      commit: Set to False if this operation is part of a series of operations
          that should not be committed until the final one is done.
      return_generated_ids: Set to True to return a list of generated
          autoincrement IDs for inserted rows.  This requires us to insert rows
          one at a time.

    Returns:
      If return_generated_ids is set to True, this method returns a list of the
      auto-increment IDs generated by the DB.  Otherwise, [] is returned.
      If row_values is empty, None is returned and nothing is inserted.
    """
    if not row_values:
      return None  # Nothing to insert

    generated_ids = []
    if return_generated_ids:
      # We must insert the rows one-at-a-time to know the generated IDs.
      for row_value in row_values:
        stmt = Statement.MakeInsert(
            self.table_name, cols, [row_value], replace=replace, ignore=ignore)
        stmt_str, stmt_args = stmt.Generate()
        cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit)
        if cursor.lastrowid:
          generated_ids.append(cursor.lastrowid)
        cursor.close()
      return generated_ids

    stmt = Statement.MakeInsert(
        self.table_name, cols, row_values, replace=replace, ignore=ignore)
    stmt_str, stmt_args = stmt.Generate()
    cnxn.Execute(stmt_str, stmt_args, commit=commit)
    return []

  def InsertRow(
      self, cnxn, replace=False, ignore=False, commit=True, **kwargs):
    """Insert a single row into the table.

    Args:
      cnxn: MonorailConnection object.
      replace: Set to True if inserted values should replace existing DB rows
          that have the same DB keys.
      ignore: Set to True to ignore rows that would duplicate existing DB keys.
      commit: Set to False if this operation is part of a series of operations
          that should not be committed until the final one is done.
      **kwargs: column=value assignments to specify what to store in the DB.

    Returns:
      The generated autoincrement ID of the key column if one was generated.
      Otherwise, return None.
    """
    # Sort columns so the statement text is deterministic for given kwargs.
    cols = sorted(kwargs.keys())
    row = tuple(kwargs[col] for col in cols)
    generated_ids = self.InsertRows(
        cnxn, cols, [row], replace=replace, ignore=ignore,
        commit=commit, return_generated_ids=True)
    if generated_ids:
      return generated_ids[0]
    else:
      return None

  def Update(self, cnxn, delta, where=None, commit=True, limit=None, **kwargs):
    """Update one or more rows.

    Args:
      cnxn: MonorailConnection object.
      delta: Dictionary of {column: new_value} assignments.
      where: Optional list of WHERE conditions saying which rows to update.
      commit: Set to False if this operation is part of a series of operations
          that should not be committed until the final one is done.
      limit: Optional LIMIT on the number of rows updated.
      **kwargs: WHERE-clause equality and set-membership conditions.

    Returns:
      Int number of rows updated.
    """
    if not delta:
      return 0  # Nothing is being changed

    stmt = Statement.MakeUpdate(self.table_name, delta)
    stmt.AddWhereTerms(where, **kwargs)
    stmt.SetLimitAndOffset(limit, None)
    stmt_str, stmt_args = stmt.Generate()

    cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit)
    result = cursor.rowcount
    cursor.close()
    return result

  def IncrementCounterValue(self, cnxn, col_name, where=None, **kwargs):
    """Atomically increment a counter stored in MySQL, return new value.

    Args:
      cnxn: MonorailConnection object.
      col_name: int column to increment.
      where: Optional list of WHERE conditions saying which rows to update.
      **kwargs: WHERE-clause equality and set-membership conditions.  The
          where and kwargs together should narrow the update down to exactly
          one row.

    Returns:
      The new, post-increment value of the counter.
    """
    stmt = Statement.MakeIncrement(self.table_name, col_name)
    stmt.AddWhereTerms(where, **kwargs)
    stmt_str, stmt_args = stmt.Generate()

    cursor = cnxn.Execute(stmt_str, stmt_args)
    assert cursor.rowcount == 1, (
        'missing or ambiguous counter: %r' % cursor.rowcount)
    # MakeIncrement uses LAST_INSERT_ID(col + step), so the new counter
    # value is reported back through cursor.lastrowid.
    result = cursor.lastrowid
    cursor.close()
    return result

  def Delete(self, cnxn, where=None, or_where_conds=False, commit=True,
             limit=None, **kwargs):
    """Delete the specified table rows.

    Args:
      cnxn: MonorailConnection object.
      where: Optional list of WHERE conditions saying which rows to update.
      or_where_conds: Set to True to use OR in the WHERE conds.
      commit: Set to False if this operation is part of a series of operations
          that should not be committed until the final one is done.
      limit: Optional LIMIT on the number of rows deleted.
      **kwargs: WHERE-clause equality and set-membership conditions.

    Returns:
      Int number of rows updated.
    """
    # Deleting the whole table is never intended in Monorail.
    assert where or kwargs

    stmt = Statement.MakeDelete(self.table_name, or_where_conds=or_where_conds)
    stmt.AddWhereTerms(where, **kwargs)
    stmt.SetLimitAndOffset(limit, None)
    stmt_str, stmt_args = stmt.Generate()

    cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit)
    result = cursor.rowcount
    cursor.close()
    return result
530
531
class Statement(object):
  """A class to help build complex SQL statements w/ full escaping.

  Start with a Make*() method, then fill in additional clauses as needed,
  then call Generate() to return the SQL string and argument list.  We pass
  the string and args to MySQLdb separately so that it can do escaping on
  the arg values as appropriate to prevent SQL-injection attacks.

  The only values that are not escaped by MySQLdb are the table names
  and column names, and bits of SQL syntax, all of which is hard-coded
  in our application.
  """

  @classmethod
  def MakeSelect(cls, table_name, cols, distinct=False, or_where_conds=False):
    """Construct a SELECT statement."""
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in cols)
    main_clause = 'SELECT%s %s FROM %s' % (
        (' DISTINCT' if distinct else ''), ', '.join(cols), table_name)
    return cls(main_clause, or_where_conds=or_where_conds)

  @classmethod
  def MakeInsert(
      cls, table_name, cols, new_values, replace=False, ignore=False):
    """Construct an INSERT statement, or delegate to MakeReplace."""
    if replace:  # Idiom fix: was `replace == True`, which missed truthy values.
      return cls.MakeReplace(table_name, cols, new_values, ignore)
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in cols)
    ignore_word = ' IGNORE' if ignore else ''
    main_clause = 'INSERT%s INTO %s (%s)' % (
        ignore_word, table_name, ', '.join(cols))
    return cls(main_clause, insert_args=new_values)

  @classmethod
  def MakeReplace(
      cls, table_name, cols, new_values, ignore=False):
    """Construct an INSERT...ON DUPLICATE KEY UPDATE... statement.

    Uses the INSERT/UPDATE syntax because REPLACE is literally a DELETE
    followed by an INSERT, which doesn't play well with foreign keys.
    INSERT/UPDATE is an atomic check of whether the primary key exists,
    followed by an INSERT if it doesn't or an UPDATE if it does.
    """
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in cols)
    ignore_word = ' IGNORE' if ignore else ''
    main_clause = 'INSERT%s INTO %s (%s)' % (
        ignore_word, table_name, ', '.join(cols))
    return cls(main_clause, insert_args=new_values, duplicate_update_cols=cols)

  @classmethod
  def MakeUpdate(cls, table_name, delta):
    """Construct an UPDATE statement from a {column: new_value} dict."""
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in delta.keys())
    update_strs = []
    update_args = []
    for col, val in delta.items():
      update_strs.append(col + '=%s')
      update_args.append(val)

    main_clause = 'UPDATE %s SET %s' % (
        table_name, ', '.join(update_strs))
    return cls(main_clause, update_args=update_args)

  @classmethod
  def MakeIncrement(cls, table_name, col_name, step=1):
    """Construct an UPDATE statement that increments and returns a counter."""
    assert _IsValidTableName(table_name)
    assert _IsValidColumnName(col_name)

    # LAST_INSERT_ID(expr) makes the incremented value readable via
    # cursor.lastrowid after the UPDATE executes.
    main_clause = (
        'UPDATE %s SET %s = LAST_INSERT_ID(%s + %%s)' % (
            table_name, col_name, col_name))
    update_args = [step]
    return cls(main_clause, update_args=update_args)

  @classmethod
  def MakeDelete(cls, table_name, or_where_conds=False):
    """Construct a DELETE statement."""
    assert _IsValidTableName(table_name)
    main_clause = 'DELETE FROM %s' % table_name
    return cls(main_clause, or_where_conds=or_where_conds)

  def __init__(
      self, main_clause, insert_args=None, update_args=None,
      duplicate_update_cols=None, or_where_conds=False):
    self.main_clause = main_clause  # E.g., SELECT or DELETE
    self.or_where_conds = or_where_conds
    self.insert_args = insert_args or []  # For INSERT statements
    for row_value in self.insert_args:
      if not all(_IsValidDBValue(val) for val in row_value):
        raise exceptions.InputException('Invalid DB value %r' % (row_value,))
    self.update_args = update_args or []  # For UPDATEs
    for val in self.update_args:
      if not _IsValidDBValue(val):
        raise exceptions.InputException('Invalid DB value %r' % val)
    self.duplicate_update_cols = duplicate_update_cols or []  # For REPLACE-ish

    self.use_clauses = []
    self.join_clauses, self.join_args = [], []
    self.where_conds, self.where_args = [], []
    self.having_conds, self.having_args = [], []
    self.group_by_terms, self.group_by_args = [], []
    self.order_by_terms, self.order_by_args = [], []
    self.limit, self.offset = None, None

  def Generate(self):
    """Return an SQL string having %s placeholders and args to fill them in."""
    clauses = [self.main_clause] + self.use_clauses + self.join_clauses
    if self.where_conds:
      if self.or_where_conds:
        clauses.append('WHERE ' + '\n  OR '.join(self.where_conds))
      else:
        clauses.append('WHERE ' + '\n  AND '.join(self.where_conds))
    if self.group_by_terms:
      clauses.append('GROUP BY ' + ', '.join(self.group_by_terms))
    if self.having_conds:
      assert self.group_by_terms
      clauses.append('HAVING %s' % ','.join(self.having_conds))
    if self.order_by_terms:
      clauses.append('ORDER BY ' + ', '.join(self.order_by_terms))

    if self.limit and self.offset:
      clauses.append('LIMIT %d OFFSET %d' % (self.limit, self.offset))
    elif self.limit:
      clauses.append('LIMIT %d' % self.limit)
    elif self.offset:
      # MySQL has no OFFSET without LIMIT; use a huge LIMIT instead.
      # sys.maxsize (== sys.maxint on CPython 2 64-bit) also works on Python 3.
      clauses.append('LIMIT %d OFFSET %d' % (sys.maxsize, self.offset))

    if self.insert_args:
      clauses.append('VALUES (' + PlaceHolders(self.insert_args[0]) + ')')
      args = self.insert_args
      if self.duplicate_update_cols:
        clauses.append('ON DUPLICATE KEY UPDATE %s' % (
            ', '.join(['%s=VALUES(%s)' % (col, col)
                       for col in self.duplicate_update_cols])))
      assert not (self.join_args + self.update_args + self.where_args +
                  self.group_by_args + self.order_by_args + self.having_args)
    else:
      args = (self.join_args + self.update_args + self.where_args +
              self.group_by_args + self.having_args + self.order_by_args)
      assert not (self.insert_args + self.duplicate_update_cols)

    args = _BoolsToInts(args)
    stmt_str = '\n'.join(clause for clause in clauses if clause)

    assert _IsValidStatement(stmt_str), stmt_str
    return stmt_str, args

  def AddUseClause(self, use_clause):
    """Add a USE clause (giving the DB a hint about which indexes to use)."""
    assert _IsValidUseClause(use_clause), use_clause
    self.use_clauses.append(use_clause)

  def AddJoinClauses(self, join_pairs, left=False):
    """Save JOIN clauses based on the given list of join conditions."""
    for join, args in join_pairs:
      assert _IsValidJoin(join), join
      assert join.count('%s') == len(args), join
      self.join_clauses.append(
          '  %sJOIN %s' % (('LEFT ' if left else ''), join))
      self.join_args.extend(args)

  def AddGroupByTerms(self, group_by_term_list):
    """Save info needed to generate the GROUP BY clause."""
    assert all(_IsValidGroupByTerm(term) for term in group_by_term_list)
    self.group_by_terms.extend(group_by_term_list)

  def AddOrderByTerms(self, order_by_pairs):
    """Save info needed to generate the ORDER BY clause."""
    for term, args in order_by_pairs:
      assert _IsValidOrderByTerm(term), term
      assert term.count('%s') == len(args), term
      self.order_by_terms.append(term)
      self.order_by_args.extend(args)

  def SetLimitAndOffset(self, limit, offset):
    """Save info needed to generate the LIMIT OFFSET clause."""
    self.limit = limit
    self.offset = offset

  def AddWhereTerms(self, where_cond_pairs, **kwargs):
    """Generate a WHERE clause."""
    where_cond_pairs = where_cond_pairs or []

    for cond, args in where_cond_pairs:
      assert _IsValidWhereCond(cond), cond
      assert cond.count('%s') == len(args), cond
      self.where_conds.append(cond)
      self.where_args.extend(args)

    # Sort kwargs so the generated SQL text is deterministic.
    for col, val in sorted(kwargs.items()):
      assert _IsValidColumnName(col), col
      eq = True
      if col.endswith('_not'):
        col = col[:-4]
        eq = False

      if isinstance(val, set):
        val = list(val)  # MySQL interface cannot handle sets.

      if val is None or val == []:
        if val == [] and self.main_clause and self.main_clause.startswith(
            'UPDATE'):
          # https://crbug.com/monorail/6735: Avoid empty arrays for UPDATE.
          raise exceptions.InputException('Invalid update DB value %r' % col)
        op = 'IS' if eq else 'IS NOT'
        self.where_conds.append(col + ' ' + op + ' NULL')
      elif isinstance(val, list):
        op = 'IN' if eq else 'NOT IN'
        # Sadly, MySQLdb cannot escape lists, so we flatten to multiple "%s"s
        self.where_conds.append(
            col + ' ' + op + ' (' + PlaceHolders(val) + ')')
        self.where_args.extend(val)
      else:
        op = '=' if eq else '!='
        self.where_conds.append(col + ' ' + op + ' %s')
        self.where_args.append(val)

  def AddHavingTerms(self, having_cond_pairs):
    """Generate a HAVING clause."""
    for cond, args in having_cond_pairs:
      assert _IsValidHavingCond(cond), cond
      assert cond.count('%s') == len(args), cond
      self.having_conds.append(cond)
      self.having_args.extend(args)
761
762
def PlaceHolders(sql_args):
  """Return a comma-separated list of %s placeholders for the given args."""
  placeholders = []
  for _ in sql_args:
    placeholders.append('%s')
  return ','.join(placeholders)
766
767
# Regular-expression building blocks used to validate every piece of SQL
# syntax that our application hard-codes (table names, column names,
# comparison operators).  See the Statement class docstring: these guard
# against SQL injection in the parts MySQLdb cannot escape.
TABLE_PAT = '[A-Z][_a-zA-Z0-9]+'
COLUMN_PAT = '[a-z][_a-z]+'
COMPARE_OP_PAT = '(<|>|=|!=|>=|<=|LIKE|NOT LIKE)'
# Named fragments expanded into the validation regexes by _MakeRE() via
# str.format().  The *_cond entries match the composite ON-clause
# conditions produced by our search query builders.
SHORTHAND = {
    'table': TABLE_PAT,
    'column': COLUMN_PAT,
    'tab_col': r'(%s\.)?%s' % (TABLE_PAT, COLUMN_PAT),
    'placeholder': '%s',  # That's a literal %s that gets passed to MySQLdb
    'multi_placeholder': '%s(, ?%s)*',
    'compare_op': COMPARE_OP_PAT,
    'opt_asc_desc': '( ASC| DESC)?',
    'opt_alias': '( AS %s)?' % TABLE_PAT,
    'email_cond': (r'\(?'
                   r'('
                   r'(LOWER\(Spare\d+\.email\) IS NULL OR )?'
                   r'LOWER\(Spare\d+\.email\) '
                   r'(%s %%s|IN \(%%s(, ?%%s)*\))'
                   r'( (AND|OR) )?'
                   r')+'
                   r'\)?' % COMPARE_OP_PAT),
    'hotlist_cond': (r'\(?'
                     r'('
                     r'(LOWER\(Cond\d+\.name\) IS NULL OR )?'
                     r'LOWER\(Cond\d+\.name\) '
                     r'(%s %%s|IN \(%%s(, ?%%s)*\))'
                     r'( (AND|OR) )?'
                     r')+'
                     r'\)?' % COMPARE_OP_PAT),
    'phase_cond': (r'\(?'
                   r'('
                   r'(LOWER\(Phase\d+\.name\) IS NULL OR )?'
                   r'LOWER\(Phase\d+\.name\) '
                   r'(%s %%s|IN \(%%s(, ?%%s)*\))?'
                   r'( (AND|OR) )?'
                   r')+'
                   r'\)?' % COMPARE_OP_PAT),
    'approval_cond': (r'\(?'
                      r'('
                      r'(LOWER\(Cond\d+\.status\) IS NULL OR )?'
                      r'LOWER\(Cond\d+\.status\) '
                      r'(%s %%s|IN \(%%s(, ?%%s)*\))'
                      r'( (AND|OR) )?'
                      r')+'
                      r'\)?' % COMPARE_OP_PAT),
    }
813
814
def _MakeRE(regex_str):
  """Return a regular expression object, expanding our shorthand as needed."""
  expanded = regex_str.format(**SHORTHAND)
  return re.compile(expanded)
818
819
820TABLE_RE = _MakeRE('^{table}$')
821TAB_COL_RE = _MakeRE('^{tab_col}$')
822USE_CLAUSE_RE = _MakeRE(
823 r'^USE INDEX \({column}\) USE INDEX FOR ORDER BY \({column}\)$')
824HAVING_RE_LIST = [
825 _MakeRE(r'^COUNT\(\*\) {compare_op} {placeholder}$')]
826COLUMN_RE_LIST = [
827 TAB_COL_RE,
828 _MakeRE(r'\*'),
829 _MakeRE(r'COUNT\(\*\)'),
830 _MakeRE(r'COUNT\({tab_col}\)'),
831 _MakeRE(r'COUNT\(DISTINCT\({tab_col}\)\)'),
832 _MakeRE(r'MAX\({tab_col}\)'),
833 _MakeRE(r'MIN\({tab_col}\)'),
834 _MakeRE(r'GROUP_CONCAT\((DISTINCT )?{tab_col}( ORDER BY {tab_col})?' \
835 r'( SEPARATOR \'.*\')?\)'),
836 ]
# Allow-list of shapes that a JOIN clause may take.  A join string is
# accepted iff it matches one of these patterns (see _IsValidJoin below).
# Note that some patterns are anchored with ^...$ while others have no
# trailing $ — unanchored tails allow extra trailing text.
JOIN_RE_LIST = [
    TABLE_RE,
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} IN \({multi_placeholder}\))?$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} = {placeholder})?'
        r'( AND {tab_col} IN \({multi_placeholder}\))?'
        r'( AND {tab_col} = {tab_col})?$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} = {placeholder})?'
        r'( AND {tab_col} IN \({multi_placeholder}\))?'
        r'( AND {tab_col} IS NULL)?'
        r'( AND \({tab_col} IS NULL'
        r' OR {tab_col} NOT IN \({multi_placeholder}\)\))?$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} = {placeholder})?'
        r' AND \(?{tab_col} {compare_op} {placeholder}\)?'
        r'( AND {tab_col} = {tab_col})?$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} = {placeholder})?'
        r' AND {tab_col} = {tab_col}$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} = {placeholder})?'
        r' AND \({tab_col} IS NULL OR'
        r' {tab_col} {compare_op} {placeholder}\)$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r' AND \({tab_col} IS NOT NULL AND {tab_col} != {placeholder}\)'),
    # Case-insensitive equality join on a column.
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r' AND LOWER\({tab_col}\) = LOWER\({placeholder}\)'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col} AND {email_cond}$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {email_cond}$'),
    _MakeRE(
        r'^{table}{opt_alias} ON '
        r'\({tab_col} = {tab_col} OR {tab_col} = {tab_col}\)$'),
    # Parenthesized sub-join against the User table.
    _MakeRE(
        r'^\({table} AS {table} JOIN User AS {table} '
        r'ON {tab_col} = {tab_col} AND {email_cond}\) '
        r'ON Issue(Snapshot)?.id = {tab_col}'
        r'( AND {tab_col} IS NULL)?'),
    # NOTE(review): the trailing '{tab_col}?' makes only the *last character*
    # of the column expression optional — possibly a typo for a '$' anchor;
    # confirm against callers before changing.
    _MakeRE(
        r'^\({table} JOIN Hotlist AS {table} '
        r'ON {tab_col} = {tab_col} AND {hotlist_cond}\) '
        r'ON Issue.id = {tab_col}?'),
    # NOTE(review): same trailing '{tab_col}?' oddity as above.
    _MakeRE(
        r'^\({table} AS {table} JOIN IssuePhaseDef AS {table} '
        r'ON {tab_col} = {tab_col} AND {phase_cond}\) '
        r'ON Issue.id = {tab_col}?'),
    _MakeRE(
        r'^IssuePhaseDef AS {table} ON {phase_cond}'),
    _MakeRE(
        r'^Issue2ApprovalValue AS {table} ON {tab_col} = {tab_col} '
        r'AND {tab_col} = {placeholder} AND {approval_cond}'),
    _MakeRE(
        r'^{table} AS {table} ON {tab_col} = {tab_col} '
        r'LEFT JOIN {table} AS {table} ON {tab_col} = {tab_col}'),
    ]
# Allow-list of shapes that an ORDER BY term may take, each with an
# optional ASC/DESC suffix (see _IsValidOrderByTerm below).
ORDER_BY_RE_LIST = [
    _MakeRE(r'^{tab_col}{opt_asc_desc}$'),
    _MakeRE(r'^LOWER\({tab_col}\){opt_asc_desc}$'),
    _MakeRE(r'^ISNULL\({tab_col}\){opt_asc_desc}$'),
    _MakeRE(r'^\(ISNULL\({tab_col}\) AND ISNULL\({tab_col}\)\){opt_asc_desc}$'),
    # FIELD(col, v1, v2, ...) sorts by position in an explicit value list.
    _MakeRE(r'^FIELD\({tab_col}, {multi_placeholder}\){opt_asc_desc}$'),
    _MakeRE(r'^FIELD\(IF\(ISNULL\({tab_col}\), {tab_col}, {tab_col}\), '
            r'{multi_placeholder}\){opt_asc_desc}$'),
    _MakeRE(r'^CONCAT\({tab_col}, {tab_col}\){opt_asc_desc}$'),
    ]
# Allow-list for GROUP BY terms: only a plain table.column reference is
# permitted (see _IsValidGroupByTerm below).
GROUP_BY_RE_LIST = [
    TAB_COL_RE,
    ]
# Allow-list of shapes that a WHERE condition may take.  A condition is
# accepted iff it fully matches one of these patterns, possibly combined
# with AND/OR/NOT (see _IsValidWhereCond below).
# All pattern fragments use raw strings: the last two entries previously
# used plain strings containing '\(' and '\)', which are invalid escape
# sequences and trigger DeprecationWarning/SyntaxWarning on Python 3.
# The resulting string values are unchanged.
WHERE_COND_RE_LIST = [
    _MakeRE(r'^TRUE$'),
    _MakeRE(r'^FALSE$'),
    _MakeRE(r'^{tab_col} IS NULL$'),
    _MakeRE(r'^{tab_col} IS NOT NULL$'),
    _MakeRE(r'^{tab_col} {compare_op} {tab_col}$'),
    _MakeRE(r'^{tab_col} {compare_op} {placeholder}$'),
    # '%%' is a literal modulo operator ('%' escaped for later formatting).
    _MakeRE(r'^{tab_col} %% {placeholder} = {placeholder}$'),
    _MakeRE(r'^{tab_col} IN \({multi_placeholder}\)$'),
    _MakeRE(r'^{tab_col} NOT IN \({multi_placeholder}\)$'),
    _MakeRE(r'^LOWER\({tab_col}\) IS NULL$'),
    _MakeRE(r'^LOWER\({tab_col}\) IS NOT NULL$'),
    _MakeRE(r'^LOWER\({tab_col}\) {compare_op} {placeholder}$'),
    _MakeRE(r'^LOWER\({tab_col}\) IN \({multi_placeholder}\)$'),
    _MakeRE(r'^LOWER\({tab_col}\) NOT IN \({multi_placeholder}\)$'),
    _MakeRE(r'^LOWER\({tab_col}\) LIKE {placeholder}$'),
    _MakeRE(r'^LOWER\({tab_col}\) NOT LIKE {placeholder}$'),
    # Special-case subquery used for cache invalidation bookkeeping.
    _MakeRE(r'^timestep < \(SELECT MAX\(j.timestep\) FROM Invalidate AS j '
            r'WHERE j.kind = %s '
            r'AND j.cache_key = Invalidate.cache_key\)$'),
    _MakeRE(r'^\({tab_col} IS NULL OR {tab_col} {compare_op} {placeholder}\) '
            r'AND \({tab_col} IS NULL OR {tab_col} {compare_op} {placeholder}'
            r'\)$'),
    _MakeRE(r'^\({tab_col} IS NOT NULL AND {tab_col} {compare_op} '
            r'{placeholder}\) OR \({tab_col} IS NOT NULL AND {tab_col} '
            r'{compare_op} {placeholder}\)$'),
    ]
949
# Note: We never use ';' for multiple statements, '@' for SQL variables, or
# any quoted strings in stmt_str (quotes are put in by MySQLdb for args).
# Final gatekeeper pattern: the whole statement must start with one of the
# allowed verbs and contain only characters from this conservative set.
STMT_STR_RE = re.compile(
    r'\A(SELECT|UPDATE|DELETE|INSERT|REPLACE) [\'-+=!<>%*.,()\w\s]+\Z',
    re.MULTILINE)
955
956
def _IsValidDBValue(val):
  """Return False for string values that contain NUL bytes, else True.

  Non-string values are always considered valid.
  """
  if not isinstance(val, string_types):
    return True
  return '\x00' not in val
961
962
def _IsValidTableName(table_name):
  """Match table_name against the allow-listed table-name pattern.

  Returns a truthy match object when valid, or None when not.
  """
  match = TABLE_RE.match(table_name)
  return match
965
966
def _IsValidColumnName(column_expr):
  """Return True if column_expr matches one of the allowed column shapes."""
  for pattern in COLUMN_RE_LIST:
    if pattern.match(column_expr):
      return True
  return False
969
970
def _IsValidUseClause(use_clause):
  """Match use_clause against the allowed USE-clause pattern.

  Returns a truthy match object when valid, or None when not.
  """
  match = USE_CLAUSE_RE.match(use_clause)
  return match
973
def _IsValidHavingCond(cond):
  """Return True if cond is an allowed HAVING condition (may be compound)."""
  # Strip one pair of enclosing parentheses, if present.
  if cond.startswith('(') and cond.endswith(')'):
    cond = cond[1:-1]

  # Recursively validate compound conditions; ' OR ' is tried before ' AND '.
  for connective in (' OR ', ' AND '):
    if connective in cond:
      return all(
          _IsValidHavingCond(part) for part in cond.split(connective))

  return any(regex.match(cond) for regex in HAVING_RE_LIST)
985
986
def _IsValidJoin(join):
  """Return True if join matches one of the allow-listed JOIN shapes."""
  for pattern in JOIN_RE_LIST:
    if pattern.match(join):
      return True
  return False
989
990
def _IsValidOrderByTerm(term):
  """Return True if term matches one of the allowed ORDER BY shapes."""
  for pattern in ORDER_BY_RE_LIST:
    if pattern.match(term):
      return True
  return False
993
994
def _IsValidGroupByTerm(term):
  """Return True if term matches one of the allowed GROUP BY shapes."""
  for pattern in GROUP_BY_RE_LIST:
    if pattern.match(term):
      return True
  return False
997
998
def _IsValidWhereCond(cond):
  """Return True if cond is an allowed WHERE condition (may be compound)."""
  # Strip an optional leading NOT and one pair of enclosing parentheses.
  if cond.startswith('NOT '):
    cond = cond[len('NOT '):]
  if cond.startswith('(') and cond.endswith(')'):
    cond = cond[1:-1]

  # A simple condition must match one of the allow-listed shapes.
  for regex in WHERE_COND_RE_LIST:
    if regex.match(cond):
      return True

  # Otherwise, recursively validate compound conditions;
  # ' OR ' is tried before ' AND '.
  for connective in (' OR ', ' AND '):
    if connective in cond:
      return all(
          _IsValidWhereCond(part) for part in cond.split(connective))

  return False
1015
1016
def _IsValidStatement(stmt_str):
  """Final check to make sure there is no funny junk sneaking in somehow.

  The statement must match the conservative overall pattern and must not
  contain a SQL comment marker.
  """
  match = STMT_STR_RE.match(stmt_str)
  return match and '--' not in stmt_str
1021
1022
1023def _BoolsToInts(arg_list):
1024 """Convert any True values to 1s and Falses to 0s.
1025
1026 Google's copy of MySQLdb has bool-to-int conversion disabled,
1027 and yet it seems to be needed otherwise they are converted
1028 to strings and always interpreted as 0 (which is FALSE).
1029
1030 Args:
1031 arg_list: (nested) list of SQL statment argument values, which may
1032 include some boolean values.
1033
1034 Returns:
1035 The same list, but with True replaced by 1 and False replaced by 0.
1036 """
1037 result = []
1038 for arg in arg_list:
1039 if isinstance(arg, (list, tuple)):
1040 result.append(_BoolsToInts(arg))
1041 elif arg is True:
1042 result.append(1)
1043 elif arg is False:
1044 result.append(0)
1045 else:
1046 result.append(arg)
1047
1048 return result