blob: a0a86af251ec3ff82e862522f263120bbf73205a [file] [log] [blame]
Adrià Vilanova Martínezf19ea432024-01-23 20:20:52 +01001# Copyright 2016 The Chromium Authors
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
Copybara854996b2021-09-07 19:36:02 +00004
5"""A set of classes for interacting with tables in SQL."""
6from __future__ import print_function
7from __future__ import division
8from __future__ import absolute_import
9
10import logging
11import random
12import re
13import sys
14import time
15
Adrià Vilanova Martínezf19ea432024-01-23 20:20:52 +010016import six
17from six.moves import queue
Copybara854996b2021-09-07 19:36:02 +000018
19import settings
20
21if not settings.unit_test_mode:
22 import MySQLdb
23
24from framework import exceptions
25from framework import framework_helpers
Adrià Vilanova Martínez9f9ade52022-10-10 23:20:11 +020026from framework import logger
Copybara854996b2021-09-07 19:36:02 +000027
28from infra_libs import ts_mon
29
Copybara854996b2021-09-07 19:36:02 +000030
class ConnectionPool(object):
  """Manage a set of database connections such that they may be re-used.

  Idle connections are kept in one bounded queue per 'instance/database'
  key. The pool may be used from several threads at once, so queue access
  uses the non-blocking get_nowait()/put_nowait() calls. A separate
  empty()/full() check followed by a blocking get()/put() could block forever
  if another thread changed the queue in between.
  """

  def __init__(self, poolsize=1):
    """Create a pool keeping at most `poolsize` idle connections per key."""
    self.poolsize = poolsize
    self.queues = {}  # {'instance/database': queue.Queue of idle cnxns}

  def get(self, instance, database):
    """Return a database connection, or raise an exception if none can be made.

    An idle pooled connection is reused when available, after a ping() and
    commit() to make sure it is still usable. Otherwise a new connection is
    created with cnxn_ctor().

    Args:
      instance: name of the SQL instance.
      database: name of the database within that instance.

    Returns:
      A DB connection for the requested instance and database.
    """
    key = instance + '/' + database

    q = self.queues.get(key)
    if q is None:
      # setdefault() keeps the first queue if two threads race to create one.
      q = self.queues.setdefault(key, queue.Queue(self.poolsize))

    try:
      cnxn = q.get_nowait()
    except queue.Empty:
      # No idle connection for this key, so open a fresh one.
      return cnxn_ctor(instance, database)

    # Make sure the pooled connection is still good.
    cnxn.ping()
    cnxn.commit()
    return cnxn

  def release(self, cnxn):
    """Return cnxn to its pool, or close it if that pool is already full.

    Args:
      cnxn: a connection whose `pool_key` names a pool created by get().

    Raises:
      ValueError: if cnxn.pool_key does not name a known pool.
    """
    q = self.queues.get(cnxn.pool_key)
    if q is None:
      raise ValueError('unknown pool key: %s' % cnxn.pool_key)

    try:
      q.put_nowait(cnxn)
    except queue.Full:
      # Enough idle connections are already pooled; drop this one.
      cnxn.close()
70
71
@framework_helpers.retry(1, delay=1, backoff=2)
def cnxn_ctor(instance, database):
  """Open a brand-new MySQL connection to `database` on `instance`.

  In local mode this connects over TCP to 127.0.0.1:3306. Otherwise it
  connects through the '/cloudsql/<instance>' unix socket. Connection
  latency (ms) and success/failure counts are recorded in ts_mon.

  Args:
    instance: SQL instance name, used to build the unix socket path.
    database: name of the database to select.

  Returns:
    A MySQLdb connection carrying extra `pool_key` and `is_bad` attributes.

  Raises:
    ValueError: when called while settings.unit_test_mode is set.
    MySQLdb.OperationalError: when the connection attempt fails.
  """
  logging.info('About to connect to SQL instance %r db %r', instance, database)
  if settings.unit_test_mode:
    raise ValueError('unit tests should not need real database connections')

  try:
    began = time.time()
    if settings.local_mode:
      connection = MySQLdb.connect(
          host='127.0.0.1', port=3306, db=database, user='root',
          password='root', charset='utf8')
    else:
      connection = MySQLdb.connect(
          unix_socket='/cloudsql/' + instance,
          db=database,
          user='monorail-mysql',
          charset='utf8')
    elapsed_ms = int((time.time() - began) * 1000)
    DB_CNXN_LATENCY.add(elapsed_ms)
    CONNECTION_COUNT.increment({'success': True})
  except MySQLdb.OperationalError:
    CONNECTION_COUNT.increment({'success': False})
    raise

  # Tag the connection so ConnectionPool.release() knows where it belongs.
  connection.pool_key = instance + '/' + database
  connection.is_bad = False
  return connection
98
99
# One connection pool per database instance (primary, replicas are each an
# instance). Up to settings.db_cnxn_pool_size idle connections are kept per
# instance/database key. More than one is useful because we fetch issue
# comments, stars, spam verdicts and spam verdict history in parallel with
# promises.
cnxn_pool = ConnectionPool(settings.db_cnxn_pool_size)

# MonorailConnection maintains a dictionary of connections to SQL databases.
# Each replica connection is keyed by an int shard ID, and there is one
# connection to the primary DB keyed by PRIMARY_CNXN.
PRIMARY_CNXN = 'primary_cnxn'

# When one replica is temporarily unresponsive, we can use a different one.
# A shard whose connection attempt failed is avoided for this many seconds.
BAD_SHARD_AVOIDANCE_SEC = 45
113
114
# ts_mon metrics for SQL traffic. The latency distributions below are fed
# values in milliseconds by the code in this module.

# Connection attempts, split by whether MySQLdb.connect() succeeded.
CONNECTION_COUNT = ts_mon.CounterMetric(
    'monorail/sql/connection_count',
    'Count of connections made to the SQL database.',
    [ts_mon.BooleanField('success')])

DB_CNXN_LATENCY = ts_mon.CumulativeDistributionMetric(
    'monorail/sql/db_cnxn_latency',
    'Time needed to establish a DB connection.',
    None)

# The 'type' field is 'write' for INSERT/REPLACE statements, else 'read'.
DB_QUERY_LATENCY = ts_mon.CumulativeDistributionMetric(
    'monorail/sql/db_query_latency',
    'Time needed to make a DB query.',
    [ts_mon.StringField('type')])

DB_COMMIT_LATENCY = ts_mon.CumulativeDistributionMetric(
    'monorail/sql/db_commit_latency',
    'Time needed to make a DB commit.',
    None)

DB_ROLLBACK_LATENCY = ts_mon.CumulativeDistributionMetric(
    'monorail/sql/db_rollback_latency',
    'Time needed to make a DB rollback.',
    None)

# Incremented each time MonorailConnection.Execute() retries a statement.
DB_RETRY_COUNT = ts_mon.CounterMetric(
    'monorail/sql/db_retry_count',
    'Count of queries retried.',
    None)

# The 'type' field is 'write' for INSERT/REPLACE statements, else 'read'.
DB_QUERY_COUNT = ts_mon.CounterMetric(
    'monorail/sql/db_query_count',
    'Count of queries sent to the DB.',
    [ts_mon.StringField('type')])

DB_COMMIT_COUNT = ts_mon.CounterMetric(
    'monorail/sql/db_commit_count',
    'Count of commits sent to the DB.',
    None)

DB_ROLLBACK_COUNT = ts_mon.CounterMetric(
    'monorail/sql/db_rollback_count',
    'Count of rollbacks sent to the DB.',
    None)

# Fed cursor.rowcount after each statement.
DB_RESULT_ROWS = ts_mon.CumulativeDistributionMetric(
    'monorail/sql/db_result_rows',
    'Number of results returned by a DB query.',
    None)
164
165
def RandomShardID():
  """Pick a uniformly random logical shard ID to spread load over replicas.

  Returns:
    An int in the half-open range [0, settings.num_logical_shards).
  """
  return random.randrange(settings.num_logical_shards)
169
170
class MonorailConnection(object):
  """Create and manage connections to the SQL servers.

  We only store connections in the context of a single user request, not
  across user requests. The main purpose of this class is to make using
  sharded tables easier.
  """
  # Class attribute shared by every instance: {shard_id: timestamp of failed
  # attempt}. A shard is recorded here while a connection to it is being
  # made, and the entry stays if that attempt raises.
  unavailable_shards = {}

  def __init__(self):
    self.sql_cnxns = {}  # {PRIMARY_CNXN: cnxn, shard_id: cnxn, ...}

  @framework_helpers.retry(1, delay=0.1, backoff=2)
  def GetPrimaryConnection(self):
    """Return a connection to the primary SQL DB, creating it if needed."""
    if PRIMARY_CNXN not in self.sql_cnxns:
      self.sql_cnxns[PRIMARY_CNXN] = cnxn_pool.get(
          settings.db_instance, settings.db_database_name)
      logging.info(
          'created a primary connection %r', self.sql_cnxns[PRIMARY_CNXN])

    return self.sql_cnxns[PRIMARY_CNXN]

  @framework_helpers.retry(1, delay=0.1, backoff=2)
  def GetConnectionForShard(self, shard_id):
    """Return a connection to the DB replica that will be used for shard_id.

    If connecting raises, shard_id stays in unavailable_shards so that
    Execute() steers queries for that shard elsewhere for a while.
    """
    if shard_id not in self.sql_cnxns:
      physical_shard_id = shard_id % settings.num_logical_shards

      replica_name = settings.db_replica_names[
          physical_shard_id % len(settings.db_replica_names)]
      shard_instance_name = (
          settings.physical_db_name_format % replica_name)
      self.unavailable_shards[shard_id] = int(time.time())
      self.sql_cnxns[shard_id] = cnxn_pool.get(
          shard_instance_name, settings.db_database_name)
      # unavailable_shards is shared by all instances, so another caller may
      # already have removed this entry; pop() avoids a KeyError.
      self.unavailable_shards.pop(shard_id, None)
      logging.info('created a replica connection for shard %d', shard_id)

    return self.sql_cnxns[shard_id]

  def Execute(self, stmt_str, stmt_args, shard_id=None, commit=True, retries=2):
    """Execute the given SQL statement on one of the relevant databases.

    Args:
      stmt_str: SQL string with %s placeholders.
      stmt_args: placeholder values. For INSERT/REPLACE statements this is a
          list of rows, which are run with executemany().
      shard_id: optional int shard ID; None means use the primary DB.
      commit: set to False to skip committing a non-SELECT statement.
      retries: how many times to drop all connections and retry after a
          MySQLdb.OperationalError.

    Returns:
      The cursor that ran the statement.

    Raises:
      MySQLdb.OperationalError: if the statement still fails when no
          retries remain.
    """
    if shard_id is None:
      # No shard was specified, so hit the primary.
      sql_cnxn = self.GetPrimaryConnection()
    else:
      # .get() instead of "in" + index: the shared dict may change between
      # the two lookups.
      failed_at = self.unavailable_shards.get(shard_id)
      if failed_at is not None:
        bad_age_sec = int(time.time()) - failed_at
        if bad_age_sec < BAD_SHARD_AVOIDANCE_SEC:
          logging.info('Avoiding bad replica %r, age %r', shard_id, bad_age_sec)
          shard_id = (shard_id + 1) % settings.num_logical_shards
      sql_cnxn = self.GetConnectionForShard(shard_id)

    try:
      return self._ExecuteWithSQLConnection(
          sql_cnxn, stmt_str, stmt_args, commit=commit)
    except MySQLdb.OperationalError as e:
      logging.exception(e)
      logging.info('retries: %r', retries)
      if retries <= 0:
        raise  # Bare raise keeps the original traceback.
      DB_RETRY_COUNT.increment()
      # Drop all old mysql connections and make new. They are not returned
      # to cnxn_pool because any of them may be broken.
      self.sql_cnxns = {}
      return self.Execute(
          stmt_str, stmt_args, shard_id=shard_id, commit=commit,
          retries=retries - 1)

  def _ExecuteWithSQLConnection(
      self, sql_cnxn, stmt_str, stmt_args, commit=True):
    """Execute a statement on the given database and return a cursor.

    INSERT and REPLACE statements are run with executemany(); everything
    else is run with execute(). Latency (ms), counts and result sizes go to
    ts_mon. Statements taking 2 seconds or more are also sent to the
    structured logger.

    Unless commit is False, non-SELECT statements are committed. If the
    commit raises MySQLdb.DatabaseError, the transaction is rolled back, the
    error is logged, and the cursor is still returned.
    """
    start_time = time.time()
    is_write = stmt_str.startswith(('INSERT', 'REPLACE'))
    cursor = sql_cnxn.cursor()
    cursor.execute('SET NAMES utf8mb4')
    if is_write:
      cursor.executemany(stmt_str, stmt_args)
      query_type = 'write'
    else:
      cursor.execute(stmt_str, args=stmt_args)
      query_type = 'read'
    duration = (time.time() - start_time) * 1000
    DB_QUERY_LATENCY.add(duration, {'type': query_type})
    DB_QUERY_COUNT.increment({'type': query_type})
    DB_RESULT_ROWS.add(cursor.rowcount)

    if is_write:
      # Row values are appended rather than interpolated for multi-row writes.
      formatted_statement = ('%s %s' % (stmt_str, stmt_args)).replace('\n', ' ')
    else:
      formatted_statement = (stmt_str % tuple(stmt_args)).replace('\n', ' ')
    logging.info(
        '%d rows in %d ms: %s', cursor.rowcount, int(duration),
        formatted_statement)
    if duration >= 2000:
      logger.log(
          {
              'log_type': 'database/query',
              'statement': formatted_statement[:100000],
              'type': formatted_statement.split(' ')[0],
              'duration': duration / 1000,
              'row_count': cursor.rowcount,
          })

    if commit and not stmt_str.startswith('SELECT'):
      # Time only the commit (or rollback) itself, not the preceding query.
      commit_start = time.time()
      try:
        sql_cnxn.commit()
        DB_COMMIT_LATENCY.add((time.time() - commit_start) * 1000)
        DB_COMMIT_COUNT.increment()
      except MySQLdb.DatabaseError:
        logging.exception('Commit failed for cnxn, rolling back')
        rollback_start = time.time()
        sql_cnxn.rollback()
        DB_ROLLBACK_LATENCY.add((time.time() - rollback_start) * 1000)
        DB_ROLLBACK_COUNT.increment()

    return cursor

  def Commit(self):
    """Explicitly commit any pending txns. Normally done automatically.

    A failed commit is logged and rolled back rather than raised.
    """
    sql_cnxn = self.GetPrimaryConnection()
    try:
      sql_cnxn.commit()
    except MySQLdb.DatabaseError:
      logging.exception('Commit failed for cnxn, rolling back')
      sql_cnxn.rollback()

  def Close(self):
    """Safely close any connections that are still open.

    Each connection is rolled back and handed back to cnxn_pool. The
    connection dict is then cleared. Once released, a connection can be
    handed to another caller of cnxn_pool.get(), so this object must not
    keep using it.
    """
    for sql_cnxn in self.sql_cnxns.values():
      try:
        sql_cnxn.rollback()  # Abandon any uncommitted changes.
        cnxn_pool.release(sql_cnxn)
      except MySQLdb.DatabaseError:
        # This might happen if the cnxn is somehow already closed.
        logging.exception('ProgrammingError when trying to close cnxn')
    self.sql_cnxns = {}
308
309
class SQLTableManager(object):
  """Helper class to make it easier to deal with an SQL table."""

  def __init__(self, table_name):
    """Create a manager for the SQL table named table_name."""
    self.table_name = table_name

  def Select(
      self, cnxn, distinct=False, cols=None, left_joins=None,
      joins=None, where=None, or_where_conds=False, group_by=None,
      order_by=None, limit=None, offset=None, shard_id=None, use_clause=None,
      having=None, **kwargs):
    """Compose and execute an SQL SELECT statement on this table.

    Args:
      cnxn: MonorailConnection to the databases.
      distinct: If True, add DISTINCT keyword.
      cols: List of columns to retrieve, defaults to '*'.
      left_joins: List of LEFT JOIN (str, args) pairs.
      joins: List of regular JOIN (str, args) pairs.
      where: List of (str, args) for WHERE clause.
      or_where_conds: Set to True to use OR in the WHERE conds.
      group_by: List of strings for GROUP BY clause.
      order_by: List of (str, args) for ORDER BY clause.
      limit: Optional LIMIT on the number of rows returned.
      offset: Optional OFFSET when using LIMIT.
      shard_id: Int ID of the shard to query.
      use_clause: Optional string USE clause to tell the DB which index to use.
      having: List of (str, args) for Optional HAVING clause
      **kwargs: WHERE-clause equality and set-membership conditions.

    Keyword args are used to build up more WHERE conditions that compare
    column values to constants. Keyword argument foo='bar' translates to
    'foo = "bar"', and foo=[3, 4, 5] translates to 'foo IN (3, 4, 5)'.
    A '_not' suffix negates the test, and None becomes 'IS NULL'.

    Returns:
      A list of rows, each row is a tuple of values for the requested cols.
    """
    cols = cols or ['*']  # If columns not specified, retrieve all columns.
    stmt = Statement.MakeSelect(
        self.table_name, cols, distinct=distinct,
        or_where_conds=or_where_conds)
    if use_clause:
      stmt.AddUseClause(use_clause)
    if having:
      stmt.AddHavingTerms(having)
    stmt.AddJoinClauses(left_joins or [], left=True)
    stmt.AddJoinClauses(joins or [])
    stmt.AddWhereTerms(where or [], **kwargs)
    stmt.AddGroupByTerms(group_by or [])
    stmt.AddOrderByTerms(order_by or [])
    stmt.SetLimitAndOffset(limit, offset)
    stmt_str, stmt_args = stmt.Generate()

    cursor = cnxn.Execute(stmt_str, stmt_args, shard_id=shard_id)
    rows = cursor.fetchall()
    cursor.close()
    return rows

  def SelectRow(
      self, cnxn, cols=None, default=None, where=None, **kwargs):
    """Run a query that is expected to return at most one row.

    Returns:
      The single matching row, or `default` if no row matched.

    Raises:
      ValueError: if more than one row matched.
    """
    rows = self.Select(cnxn, distinct=True, cols=cols, where=where, **kwargs)
    if len(rows) == 1:
      return rows[0]
    elif not rows:
      logging.info('SelectRow got 0 results, so using default %r', default)
      return default
    else:
      # Format the message here; ValueError(fmt, arg) would never fill it in.
      raise ValueError(
          'SelectRow got %d results, expected only 1' % len(rows))

  def SelectValue(self, cnxn, col, default=None, where=None, **kwargs):
    """Run a query that is expected to return just one row w/ one value.

    Returns:
      The value of `col` in the matching row, or `default` if none matched.
    """
    row = self.SelectRow(
        cnxn, cols=[col], default=[default], where=where, **kwargs)
    return row[0]

  def InsertRows(
      self, cnxn, cols, row_values, replace=False, ignore=False,
      commit=True, return_generated_ids=False):
    """Insert all the given rows.

    Args:
      cnxn: MonorailConnection object.
      cols: List of column names to set.
      row_values: List of lists with values to store. The length of each
          nested list should be equal to len(cols).
      replace: Set to True if inserted values should replace existing DB rows
          that have the same DB keys.
      ignore: Set to True to ignore rows that would duplicate existing DB keys.
      commit: Set to False if this operation is part of a series of operations
          that should not be committed until the final one is done.
      return_generated_ids: Set to True to return a list of generated
          autoincrement IDs for inserted rows. This requires us to insert rows
          one at a time.

    Returns:
      If return_generated_ids is set to True, this method returns a list of the
      auto-increment IDs generated by the DB. Otherwise, [] is returned. [] is
      also returned when row_values is empty.
    """
    if not row_values:
      return []  # Nothing to insert; keep the documented list return type.

    generated_ids = []
    if return_generated_ids:
      # We must insert the rows one-at-a-time to know the generated IDs.
      for row_value in row_values:
        stmt = Statement.MakeInsert(
            self.table_name, cols, [row_value], replace=replace, ignore=ignore)
        stmt_str, stmt_args = stmt.Generate()
        cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit)
        if cursor.lastrowid:
          generated_ids.append(cursor.lastrowid)
        cursor.close()
      return generated_ids

    stmt = Statement.MakeInsert(
        self.table_name, cols, row_values, replace=replace, ignore=ignore)
    stmt_str, stmt_args = stmt.Generate()
    cnxn.Execute(stmt_str, stmt_args, commit=commit)
    return []

  def InsertRow(
      self, cnxn, replace=False, ignore=False, commit=True, **kwargs):
    """Insert a single row into the table.

    Args:
      cnxn: MonorailConnection object.
      replace: Set to True if inserted values should replace existing DB rows
          that have the same DB keys.
      ignore: Set to True to ignore rows that would duplicate existing DB keys.
      commit: Set to False if this operation is part of a series of operations
          that should not be committed until the final one is done.
      **kwargs: column=value assignments to specify what to store in the DB.

    Returns:
      The generated autoincrement ID of the key column if one was generated.
      Otherwise, return None.
    """
    cols = sorted(kwargs.keys())
    row = tuple(kwargs[col] for col in cols)
    generated_ids = self.InsertRows(
        cnxn, cols, [row], replace=replace, ignore=ignore,
        commit=commit, return_generated_ids=True)
    if generated_ids:
      return generated_ids[0]
    else:
      return None

  def Update(self, cnxn, delta, where=None, commit=True, limit=None, **kwargs):
    """Update one or more rows.

    Args:
      cnxn: MonorailConnection object.
      delta: Dictionary of {column: new_value} assignments.
      where: Optional list of WHERE conditions saying which rows to update.
      commit: Set to False if this operation is part of a series of operations
          that should not be committed until the final one is done.
      limit: Optional LIMIT on the number of rows updated.
      **kwargs: WHERE-clause equality and set-membership conditions.

    Returns:
      Int number of rows updated (0 if delta is empty).
    """
    if not delta:
      return 0  # Nothing is being changed

    stmt = Statement.MakeUpdate(self.table_name, delta)
    stmt.AddWhereTerms(where, **kwargs)
    stmt.SetLimitAndOffset(limit, None)
    stmt_str, stmt_args = stmt.Generate()

    cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit)
    result = cursor.rowcount
    cursor.close()
    return result

  def IncrementCounterValue(self, cnxn, col_name, where=None, **kwargs):
    """Atomically increment a counter stored in MySQL, return new value.

    Args:
      cnxn: MonorailConnection object.
      col_name: int column to increment.
      where: Optional list of WHERE conditions saying which rows to update.
      **kwargs: WHERE-clause equality and set-membership conditions. The
          where and kwargs together should narrow the update down to exactly
          one row.

    Returns:
      The new, post-increment value of the counter.
    """
    stmt = Statement.MakeIncrement(self.table_name, col_name)
    stmt.AddWhereTerms(where, **kwargs)
    stmt_str, stmt_args = stmt.Generate()

    cursor = cnxn.Execute(stmt_str, stmt_args)
    assert cursor.rowcount == 1, (
        'missing or ambiguous counter: %r' % cursor.rowcount)
    # MakeIncrement wraps the new value in LAST_INSERT_ID(), so it comes
    # back as the cursor's lastrowid.
    result = cursor.lastrowid
    cursor.close()
    return result

  def Delete(self, cnxn, where=None, or_where_conds=False, commit=True,
             limit=None, **kwargs):
    """Delete the specified table rows.

    Args:
      cnxn: MonorailConnection object.
      where: Optional list of WHERE conditions saying which rows to delete.
      or_where_conds: Set to True to use OR in the WHERE conds.
      commit: Set to False if this operation is part of a series of operations
          that should not be committed until the final one is done.
      limit: Optional LIMIT on the number of rows deleted.
      **kwargs: WHERE-clause equality and set-membership conditions.

    Returns:
      Int number of rows deleted.
    """
    # Deleting the whole table is never intended in Monorail.
    assert where or kwargs

    stmt = Statement.MakeDelete(self.table_name, or_where_conds=or_where_conds)
    stmt.AddWhereTerms(where, **kwargs)
    stmt.SetLimitAndOffset(limit, None)
    stmt_str, stmt_args = stmt.Generate()

    cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit)
    result = cursor.rowcount
    cursor.close()
    return result
540
541
class Statement(object):
  """A class to help build complex SQL statements w/ full escaping.

  Start with a Make*() method, then fill in additional clauses as needed,
  then call Generate() to return the SQL string and argument list. We pass
  the string and args to MySQLdb separately so that it can do escaping on
  the arg values as appropriate to prevent SQL-injection attacks.

  The only values that are not escaped by MySQLdb are the table names
  and column names, and bits of SQL syntax, all of which is hard-coded
  in our application.
  """

  @classmethod
  def MakeSelect(cls, table_name, cols, distinct=False, or_where_conds=False):
    """Construct a SELECT statement over the given columns of table_name."""
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in cols)
    main_clause = 'SELECT%s %s FROM %s' % (
        (' DISTINCT' if distinct else ''), ', '.join(cols), table_name)
    return cls(main_clause, or_where_conds=or_where_conds)

  @classmethod
  def MakeInsert(
      cls, table_name, cols, new_values, replace=False, ignore=False):
    """Construct an INSERT statement.

    When replace is truthy this delegates to MakeReplace(), which builds an
    INSERT ... ON DUPLICATE KEY UPDATE statement instead.
    """
    if replace:
      return cls.MakeReplace(table_name, cols, new_values, ignore)
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in cols)
    ignore_word = ' IGNORE' if ignore else ''
    main_clause = 'INSERT%s INTO %s (%s)' % (
        ignore_word, table_name, ', '.join(cols))
    return cls(main_clause, insert_args=new_values)

  @classmethod
  def MakeReplace(
      cls, table_name, cols, new_values, ignore=False):
    """Construct an INSERT...ON DUPLICATE KEY UPDATE... statement.

    Uses the INSERT/UPDATE syntax because REPLACE is literally a DELETE
    followed by an INSERT, which doesn't play well with foreign keys.
    INSERT/UPDATE is an atomic check of whether the primary key exists,
    followed by an INSERT if it doesn't or an UPDATE if it does.
    """
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in cols)
    ignore_word = ' IGNORE' if ignore else ''
    main_clause = 'INSERT%s INTO %s (%s)' % (
        ignore_word, table_name, ', '.join(cols))
    return cls(main_clause, insert_args=new_values, duplicate_update_cols=cols)

  @classmethod
  def MakeUpdate(cls, table_name, delta):
    """Construct an UPDATE statement that sets each column in delta."""
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in delta.keys())
    update_strs = []
    update_args = []
    for col, val in delta.items():
      update_strs.append(col + '=%s')
      update_args.append(val)

    main_clause = 'UPDATE %s SET %s' % (
        table_name, ', '.join(update_strs))
    return cls(main_clause, update_args=update_args)

  @classmethod
  def MakeIncrement(cls, table_name, col_name, step=1):
    """Construct an UPDATE statement that increments and returns a counter.

    The new value is wrapped in LAST_INSERT_ID() so that it can be read back
    from the cursor's lastrowid.
    """
    assert _IsValidTableName(table_name)
    assert _IsValidColumnName(col_name)

    main_clause = (
        'UPDATE %s SET %s = LAST_INSERT_ID(%s + %%s)' % (
            table_name, col_name, col_name))
    update_args = [step]
    return cls(main_clause, update_args=update_args)

  @classmethod
  def MakeDelete(cls, table_name, or_where_conds=False):
    """Construct a DELETE statement."""
    assert _IsValidTableName(table_name)
    main_clause = 'DELETE FROM %s' % table_name
    return cls(main_clause, or_where_conds=or_where_conds)

  def __init__(
      self, main_clause, insert_args=None, update_args=None,
      duplicate_update_cols=None, or_where_conds=False):
    """Store the main clause and validate any INSERT/UPDATE values.

    Raises:
      exceptions.InputException: if a row value or update value is not a
          valid DB value.
    """
    self.main_clause = main_clause  # E.g., SELECT or DELETE
    self.or_where_conds = or_where_conds
    self.insert_args = insert_args or []  # For INSERT statements
    for row_value in self.insert_args:
      if not all(_IsValidDBValue(val) for val in row_value):
        raise exceptions.InputException('Invalid DB value %r' % (row_value,))
    self.update_args = update_args or []  # For UPDATEs
    for val in self.update_args:
      if not _IsValidDBValue(val):
        # Wrap in a tuple so a tuple-valued val cannot break the % format.
        raise exceptions.InputException('Invalid DB value %r' % (val,))
    self.duplicate_update_cols = duplicate_update_cols or []  # For REPLACE-ish

    self.use_clauses = []
    self.join_clauses, self.join_args = [], []
    self.where_conds, self.where_args = [], []
    self.having_conds, self.having_args = [], []
    self.group_by_terms, self.group_by_args = [], []
    self.order_by_terms, self.order_by_args = [], []
    self.limit, self.offset = None, None

  def Generate(self):
    """Return an SQL string having %s placeholders and args to fill them in.

    Args are returned in the same order as their placeholders appear in the
    generated SQL: joins, update values, WHERE, GROUP BY, HAVING, ORDER BY.
    For INSERT statements the args are the list of rows instead.
    """
    clauses = [self.main_clause] + self.use_clauses + self.join_clauses
    if self.where_conds:
      if self.or_where_conds:
        clauses.append('WHERE ' + '\n  OR '.join(self.where_conds))
      else:
        clauses.append('WHERE ' + '\n  AND '.join(self.where_conds))
    if self.group_by_terms:
      clauses.append('GROUP BY ' + ', '.join(self.group_by_terms))
    if self.having_conds:
      assert self.group_by_terms
      # Several HAVING conditions must be combined with AND; a comma-separated
      # list is not valid SQL.
      clauses.append('HAVING %s' % ' AND '.join(self.having_conds))
    if self.order_by_terms:
      clauses.append('ORDER BY ' + ', '.join(self.order_by_terms))

    # A limit or offset of 0 (or None) is treated as "not set".
    if self.limit and self.offset:
      clauses.append('LIMIT %d OFFSET %d' % (self.limit, self.offset))
    elif self.limit:
      clauses.append('LIMIT %d' % self.limit)
    elif self.offset:
      # MySQL requires a LIMIT before OFFSET, so use a huge one.
      clauses.append('LIMIT %d OFFSET %d' % (sys.maxsize, self.offset))

    if self.insert_args:
      clauses.append('VALUES (' + PlaceHolders(self.insert_args[0]) + ')')
      args = self.insert_args
      if self.duplicate_update_cols:
        clauses.append('ON DUPLICATE KEY UPDATE %s' % (
            ', '.join(['%s=VALUES(%s)' % (col, col)
                       for col in self.duplicate_update_cols])))
      assert not (self.join_args + self.update_args + self.where_args +
                  self.group_by_args + self.order_by_args + self.having_args)
    else:
      args = (self.join_args + self.update_args + self.where_args +
              self.group_by_args + self.having_args + self.order_by_args)
      assert not (self.insert_args + self.duplicate_update_cols)

    args = _BoolsToInts(args)
    stmt_str = '\n'.join(clause for clause in clauses if clause)

    assert _IsValidStatement(stmt_str), stmt_str
    return stmt_str, args

  def AddUseClause(self, use_clause):
    """Add a USE clause (giving the DB a hint about which indexes to use)."""
    assert _IsValidUseClause(use_clause), use_clause
    self.use_clauses.append(use_clause)

  def AddJoinClauses(self, join_pairs, left=False):
    """Save JOIN clauses based on the given list of join conditions."""
    for join, args in join_pairs:
      assert _IsValidJoin(join), join
      assert join.count('%s') == len(args), join
      self.join_clauses.append(
          '  %sJOIN %s' % (('LEFT ' if left else ''), join))
      self.join_args.extend(args)

  def AddGroupByTerms(self, group_by_term_list):
    """Save info needed to generate the GROUP BY clause."""
    assert all(_IsValidGroupByTerm(term) for term in group_by_term_list)
    self.group_by_terms.extend(group_by_term_list)

  def AddOrderByTerms(self, order_by_pairs):
    """Save info needed to generate the ORDER BY clause."""
    for term, args in order_by_pairs:
      assert _IsValidOrderByTerm(term), term
      assert term.count('%s') == len(args), term
      self.order_by_terms.append(term)
      self.order_by_args.extend(args)

  def SetLimitAndOffset(self, limit, offset):
    """Save info needed to generate the LIMIT OFFSET clause."""
    self.limit = limit
    self.offset = offset

  def AddWhereTerms(self, where_cond_pairs, **kwargs):
    """Generate a WHERE clause.

    Args:
      where_cond_pairs: optional list of (condition string, args) pairs.
      **kwargs: column=value conditions. A '_not' suffix on the column name
          negates the test. None or [] becomes IS (NOT) NULL, a list or set
          becomes (NOT) IN, and any other value becomes = / !=.

    Raises:
      exceptions.InputException: if an empty list is given for an UPDATE.
    """
    where_cond_pairs = where_cond_pairs or []

    for cond, args in where_cond_pairs:
      assert _IsValidWhereCond(cond), cond
      assert cond.count('%s') == len(args), cond
      self.where_conds.append(cond)
      self.where_args.extend(args)

    for col, val in sorted(kwargs.items()):
      assert _IsValidColumnName(col), col
      eq = True
      if col.endswith('_not'):
        col = col[:-4]
        eq = False

      if isinstance(val, set):
        val = list(val)  # MySQL inteface cannot handle sets.

      if val is None or val == []:
        if val == [] and self.main_clause and self.main_clause.startswith(
            'UPDATE'):
          # https://crbug.com/monorail/6735: Avoid empty arrays for UPDATE.
          raise exceptions.InputException('Invalid update DB value %r' % col)
        op = 'IS' if eq else 'IS NOT'
        self.where_conds.append(col + ' ' + op + ' NULL')
      elif isinstance(val, list):
        op = 'IN' if eq else 'NOT IN'
        # Sadly, MySQLdb cannot escape lists, so we flatten to multiple "%s"s
        self.where_conds.append(
            col + ' ' + op + ' (' + PlaceHolders(val) + ')')
        self.where_args.extend(val)
      else:
        op = '=' if eq else '!='
        self.where_conds.append(col + ' ' + op + ' %s')
        self.where_args.append(val)

  def AddHavingTerms(self, having_cond_pairs):
    """Generate a HAVING clause from (condition string, args) pairs."""
    for cond, args in having_cond_pairs:
      assert _IsValidHavingCond(cond), cond
      assert cond.count('%s') == len(args), cond
      self.having_conds.append(cond)
      self.having_args.extend(args)
771
772
def PlaceHolders(sql_args):
  """Build a comma-separated run of '%s' markers, one per element of sql_args.

  sql_args may be any iterable. The markers are joined with ',' and no
  spaces, and an empty iterable yields ''.
  """
  marks = ['%s' for _arg in sql_args]
  return ','.join(marks)
776
777
# Regex building blocks used to validate the hard-coded SQL fragments that
# Statement accepts. Table names start with an uppercase letter. Column names
# use only lowercase letters and underscores. Both are at least 2 characters.
TABLE_PAT = '[A-Z][_a-zA-Z0-9]+'
COLUMN_PAT = '[a-z][_a-z]+'
COMPARE_OP_PAT = '(<|>|=|!=|>=|<=|LIKE|NOT LIKE)'
# Named pieces substituted into {name} fields by _MakeRE() via str.format().
# In the *_cond patterns, '%%s' becomes a literal '%s' placeholder once the
# '% COMPARE_OP_PAT' substitution has been applied.
SHORTHAND = {
    'table': TABLE_PAT,
    'column': COLUMN_PAT,
    'tab_col': r'(%s\.)?%s' % (TABLE_PAT, COLUMN_PAT),
    'placeholder': '%s',  # That's a literal %s that gets passed to MySQLdb
    'multi_placeholder': '%s(, ?%s)*',
    'compare_op': COMPARE_OP_PAT,
    'opt_asc_desc': '( ASC| DESC)?',
    'opt_alias': '( AS %s)?' % TABLE_PAT,
    'email_cond': (r'\(?'
                   r'('
                   r'(LOWER\(Spare\d+\.email\) IS NULL OR )?'
                   r'LOWER\(Spare\d+\.email\) '
                   r'(%s %%s|IN \(%%s(, ?%%s)*\))'
                   r'( (AND|OR) )?'
                   r')+'
                   r'\)?' % COMPARE_OP_PAT),
    'hotlist_cond': (r'\(?'
                     r'('
                     r'(LOWER\(Cond\d+\.name\) IS NULL OR )?'
                     r'LOWER\(Cond\d+\.name\) '
                     r'(%s %%s|IN \(%%s(, ?%%s)*\))'
                     r'( (AND|OR) )?'
                     r')+'
                     r'\)?' % COMPARE_OP_PAT),
    # NOTE(review): unlike the other *_cond patterns, the comparison group
    # below is optional (trailing '?'); confirm that this is intentional.
    'phase_cond': (r'\(?'
                   r'('
                   r'(LOWER\(Phase\d+\.name\) IS NULL OR )?'
                   r'LOWER\(Phase\d+\.name\) '
                   r'(%s %%s|IN \(%%s(, ?%%s)*\))?'
                   r'( (AND|OR) )?'
                   r')+'
                   r'\)?' % COMPARE_OP_PAT),
    'approval_cond': (r'\(?'
                      r'('
                      r'(LOWER\(Cond\d+\.status\) IS NULL OR )?'
                      r'LOWER\(Cond\d+\.status\) '
                      r'(%s %%s|IN \(%%s(, ?%%s)*\))'
                      r'( (AND|OR) )?'
                      r')+'
                      r'\)?' % COMPARE_OP_PAT),
    }
823
824
def _MakeRE(regex_str):
  """Compile regex_str after filling its {name} fields from SHORTHAND.

  Because str.format() does the substitution, any literal brace in the
  regex must be written doubled ('{{' or '}}').
  """
  expanded = regex_str.format(**SHORTHAND)
  return re.compile(expanded)
828
829
# Anchored validators for a bare table name and an optionally
# table-qualified column name.
TABLE_RE = _MakeRE('^{table}$')
TAB_COL_RE = _MakeRE('^{tab_col}$')
# The only accepted index-hint form: one index for lookup, one for ordering.
USE_CLAUSE_RE = _MakeRE(
    r'^USE INDEX \({column}\) USE INDEX FOR ORDER BY \({column}\)$')
# HAVING conditions may only compare COUNT(*) against a placeholder.
HAVING_RE_LIST = [
    _MakeRE(r'^COUNT\(\*\) {compare_op} {placeholder}$')]
# Accepted forms for items in a SELECT column list (presumably checked by
# _IsValidColumnName, which is not shown here).
# NOTE(review): apart from TAB_COL_RE these patterns carry no ^/$ anchors;
# whether a whole-string match is required depends on how the validator
# applies them; confirm.
COLUMN_RE_LIST = [
    TAB_COL_RE,
    _MakeRE(r'\*'),
    _MakeRE(r'COUNT\(\*\)'),
    _MakeRE(r'COUNT\({tab_col}\)'),
    _MakeRE(r'COUNT\(DISTINCT\({tab_col}\)\)'),
    _MakeRE(r'MAX\({tab_col}\)'),
    _MakeRE(r'MIN\({tab_col}\)'),
    _MakeRE(r'GROUP_CONCAT\((DISTINCT )?{tab_col}( ORDER BY {tab_col})?' \
            r'( SEPARATOR \'.*\')?\)'),
    ]
# Allowed JOIN clause texts (checked by _IsValidJoin, which accepts a join if
# any of these patterns re.match()es it).  Values are always %s placeholders.
# NOTE(review): several entries below have no trailing '$', so with re.match
# any text after the matched prefix is not rejected by this check.
JOIN_RE_LIST = [
    # A bare table name with no ON clause.
    TABLE_RE,
    # Equi-join, optional second column equality, optional IN-list filter.
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} IN \({multi_placeholder}\))?$'),
    # Equi-join with optional equality-to-placeholder, IN-list, and a
    # trailing column equality.
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} = {placeholder})?'
        r'( AND {tab_col} IN \({multi_placeholder}\))?'
        r'( AND {tab_col} = {tab_col})?$'),
    # As above, plus optional IS NULL and "(col IS NULL OR col NOT IN (...))".
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} = {placeholder})?'
        r'( AND {tab_col} IN \({multi_placeholder}\))?'
        r'( AND {tab_col} IS NULL)?'
        r'( AND \({tab_col} IS NULL'
        r' OR {tab_col} NOT IN \({multi_placeholder}\)\))?$'),
    # Equi-join with a required comparison against a placeholder, optionally
    # parenthesized.
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} = {placeholder})?'
        r' AND \(?{tab_col} {compare_op} {placeholder}\)?'
        r'( AND {tab_col} = {tab_col})?$'),
    # Equi-join with a required trailing column-to-column equality.
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} = {placeholder})?'
        r' AND {tab_col} = {tab_col}$'),
    # Equi-join with "(col IS NULL OR col <op> %s)".
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} = {placeholder})?'
        r' AND \({tab_col} IS NULL OR'
        r' {tab_col} {compare_op} {placeholder}\)$'),
    # Equi-join with "(col IS NOT NULL AND col != %s)" (no trailing '$').
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r' AND \({tab_col} IS NOT NULL AND {tab_col} != {placeholder}\)'),
    # Equi-join with a case-insensitive equality (no trailing '$').
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r' AND LOWER\({tab_col}\) = LOWER\({placeholder}\)'),
    # Joins filtered by the email_cond shorthand (with or without equi-join).
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col} AND {email_cond}$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {email_cond}$'),
    # Join on either of two column equalities.
    _MakeRE(
        r'^{table}{opt_alias} ON '
        r'\({tab_col} = {tab_col} OR {tab_col} = {tab_col}\)$'),
    # Parenthesized nested join to User filtered by email_cond, then joined
    # to Issue or IssueSnapshot.
    # NOTE(review): the '.' in 'Issue(Snapshot)?.id' is unescaped and so
    # matches any character.
    _MakeRE(
        r'^\({table} AS {table} JOIN User AS {table} '
        r'ON {tab_col} = {tab_col} AND {email_cond}\) '
        r'ON Issue(Snapshot)?.id = {tab_col}'
        r'( AND {tab_col} IS NULL)?'),
    # Nested join to Hotlist filtered by hotlist_cond.
    # NOTE(review): the trailing '?' in '{tab_col}?' modifies the last element
    # of the expanded column pattern rather than making the whole column
    # optional; confirm this is intended.
    _MakeRE(
        r'^\({table} JOIN Hotlist AS {table} '
        r'ON {tab_col} = {tab_col} AND {hotlist_cond}\) '
        r'ON Issue.id = {tab_col}?'),
    # Nested join to IssuePhaseDef filtered by phase_cond (same '{tab_col}?'
    # caveat as above).
    _MakeRE(
        r'^\({table} AS {table} JOIN IssuePhaseDef AS {table} '
        r'ON {tab_col} = {tab_col} AND {phase_cond}\) '
        r'ON Issue.id = {tab_col}?'),
    # Direct join to IssuePhaseDef filtered by phase_cond.
    _MakeRE(
        r'^IssuePhaseDef AS {table} ON {phase_cond}'),
    # Join to Issue2ApprovalValue with an equality and approval_cond filter.
    _MakeRE(
        r'^Issue2ApprovalValue AS {table} ON {tab_col} = {tab_col} '
        r'AND {tab_col} = {placeholder} AND {approval_cond}'),
    # Aliased equi-join followed by a LEFT JOIN of another aliased table.
    _MakeRE(
        r'^{table} AS {table} ON {tab_col} = {tab_col} '
        r'LEFT JOIN {table} AS {table} ON {tab_col} = {tab_col}'),
    ]
# Allowed ORDER BY terms (checked by _IsValidOrderByTerm).  Each pattern is
# anchored at both ends and allows an optional trailing ASC/DESC.
ORDER_BY_RE_LIST = [
    _MakeRE(r'^{tab_col}{opt_asc_desc}$'),
    _MakeRE(r'^LOWER\({tab_col}\){opt_asc_desc}$'),
    _MakeRE(r'^ISNULL\({tab_col}\){opt_asc_desc}$'),
    _MakeRE(r'^\(ISNULL\({tab_col}\) AND ISNULL\({tab_col}\)\){opt_asc_desc}$'),
    # FIELD(col, %s, %s, ...) orders rows by position in a placeholder list.
    _MakeRE(r'^FIELD\({tab_col}, {multi_placeholder}\){opt_asc_desc}$'),
    # Same, choosing between two columns depending on whether a third is NULL.
    _MakeRE(r'^FIELD\(IF\(ISNULL\({tab_col}\), {tab_col}, {tab_col}\), '
            r'{multi_placeholder}\){opt_asc_desc}$'),
    _MakeRE(r'^CONCAT\({tab_col}, {tab_col}\){opt_asc_desc}$'),
    ]
# Allowed GROUP BY terms (checked by _IsValidGroupByTerm): a single,
# optionally table-qualified column.
GROUP_BY_RE_LIST = [
    TAB_COL_RE,
    ]
# Simple WHERE conditions accepted by _IsValidWhereCond.  That function tries
# these full-string patterns first and only then splits compound conditions
# on ' OR ' / ' AND ' and validates each part recursively.
WHERE_COND_RE_LIST = [
    _MakeRE(r'^TRUE$'),
    _MakeRE(r'^FALSE$'),
    _MakeRE(r'^{tab_col} IS NULL$'),
    _MakeRE(r'^{tab_col} IS NOT NULL$'),
    _MakeRE(r'^{tab_col} {compare_op} {tab_col}$'),
    _MakeRE(r'^{tab_col} {compare_op} {placeholder}$'),
    # '%%' here matches two literal percent signs in the statement text;
    # presumably the escaped modulo operator for MySQLdb's %-style parameter
    # substitution -- NOTE(review): confirm.
    _MakeRE(r'^{tab_col} %% {placeholder} = {placeholder}$'),
    _MakeRE(r'^{tab_col} IN \({multi_placeholder}\)$'),
    _MakeRE(r'^{tab_col} NOT IN \({multi_placeholder}\)$'),
    _MakeRE(r'^LOWER\({tab_col}\) IS NULL$'),
    _MakeRE(r'^LOWER\({tab_col}\) IS NOT NULL$'),
    _MakeRE(r'^LOWER\({tab_col}\) {compare_op} {placeholder}$'),
    _MakeRE(r'^LOWER\({tab_col}\) IN \({multi_placeholder}\)$'),
    _MakeRE(r'^LOWER\({tab_col}\) NOT IN \({multi_placeholder}\)$'),
    _MakeRE(r'^LOWER\({tab_col}\) LIKE {placeholder}$'),
    _MakeRE(r'^LOWER\({tab_col}\) NOT LIKE {placeholder}$'),
    # Fixed subquery against the Invalidate table; the %s here is a literal
    # MySQLdb placeholder, not a SHORTHAND key.
    _MakeRE(
        r'^timestep < \(SELECT MAX\(j.timestep\) FROM Invalidate AS j '
        r'WHERE j.kind = %s '
        r'AND j.cache_key = Invalidate.cache_key\)$'),
    # Parenthesized compound conditions matched as a whole.
    # NOTE(review): _IsValidWhereCond may strip an outer '(' ... ')' pair before
    # matching; confirm these patterns are reachable for the condition strings
    # callers actually pass.
    _MakeRE(
        r'^\({tab_col} IS NULL OR {tab_col} {compare_op} {placeholder}\) '
        r'AND \({tab_col} IS NULL OR {tab_col} {compare_op} {placeholder}'
        r'\)$'),
    _MakeRE(
        r'^\({tab_col} IS NOT NULL AND {tab_col} {compare_op} '
        r'{placeholder}\) OR \({tab_col} IS NOT NULL AND {tab_col} '
        r'{compare_op} {placeholder}\)$'),
]
Copybara854996b2021-09-07 19:36:02 +0000962
# Final whitelist applied to a whole SQL statement by _IsValidStatement.
# Statements never chain with ';', never use '@' SQL variables, and never
# inline quoted values (quotes are put in by MySQLdb for args).  The single
# quote is still allowed, presumably for the GROUP_CONCAT(... SEPARATOR '...')
# form that COLUMN_RE_LIST permits.
# After the leading keyword, only these characters are allowed:
#   ' ( ) * + = ! < > % . ,  plus word characters and whitespace.
# A literal '-' is NOT in the set.
STMT_STR_RE = re.compile(
    r'\A'
    r'(SELECT|UPDATE|DELETE|INSERT|REPLACE) '
    r"['()*+=!<>%.,\w\s]+"
    r'\Z',
    re.MULTILINE)  # No-op here: the pattern uses \A/\Z, not ^/$.
968
969
def _IsValidDBValue(val):
  """Return False only for a string value that contains a NUL character.

  Values that are not strings (per six.string_types) are always accepted.
  """
  is_text = isinstance(val, six.string_types)
  return not (is_text and '\x00' in val)
974
975
def _IsValidTableName(table_name):
  """Check table_name against TABLE_RE.

  Returns:
    The re match object (truthy) when the whole of table_name matches,
    otherwise None.
  """
  match_result = TABLE_RE.match(table_name)
  return match_result
978
979
def _IsValidColumnName(column_expr):
  """Return True if column_expr matches any pattern in COLUMN_RE_LIST."""
  for pattern in COLUMN_RE_LIST:
    if pattern.match(column_expr):
      return True
  return False
982
983
def _IsValidUseClause(use_clause):
  """Check an index-hint clause against USE_CLAUSE_RE.

  Returns:
    The re match object (truthy) on success, otherwise None.
  """
  match_result = USE_CLAUSE_RE.match(use_clause)
  return match_result
986
def _IsValidHavingCond(cond):
  """Return True if cond is an allowed HAVING condition.

  One pair of parentheses that encloses the entire condition is removed.
  The condition is then split on ' OR ' (or, failing that, ' AND ') and
  every part is validated recursively; a simple condition must match one of
  HAVING_RE_LIST.

  Args:
    cond: SQL HAVING condition text, using %s placeholders for values.

  Returns:
    True if the condition is allowed, otherwise False.
  """
  if cond.startswith('(') and cond.endswith(')'):
    # Strip the outer parentheses only when the first '(' is closed by the
    # final ')'.  Stripping '(A) OR (B)' unconditionally would leave the
    # unbalanced fragments 'A)' and '(B', which can never match.
    depth = 0
    wraps_whole = True
    for ch in cond[:-1]:
      if ch == '(':
        depth += 1
      elif ch == ')':
        depth -= 1
      if depth == 0:
        # The opening '(' closed before the last character.
        wraps_whole = False
        break
    if wraps_whole:
      cond = cond[1:-1]

  if ' OR ' in cond:
    return all(_IsValidHavingCond(c) for c in cond.split(' OR '))

  if ' AND ' in cond:
    return all(_IsValidHavingCond(c) for c in cond.split(' AND '))

  return any(regex.match(cond) for regex in HAVING_RE_LIST)
998
999
def _IsValidJoin(join):
  """Return True if the JOIN clause text matches a pattern in JOIN_RE_LIST."""
  for pattern in JOIN_RE_LIST:
    if pattern.match(join):
      return True
  return False
1002
1003
def _IsValidOrderByTerm(term):
  """Return True if term matches a pattern in ORDER_BY_RE_LIST."""
  for pattern in ORDER_BY_RE_LIST:
    if pattern.match(term):
      return True
  return False
1006
1007
def _IsValidGroupByTerm(term):
  """Return True if term matches a pattern in GROUP_BY_RE_LIST."""
  for pattern in GROUP_BY_RE_LIST:
    if pattern.match(term):
      return True
  return False
1010
1011
def _IsValidWhereCond(cond):
  """Return True if cond is an allowed WHERE condition.

  A leading 'NOT ' is ignored, and one pair of parentheses that encloses the
  entire condition is removed.  The condition is valid if it matches one of
  WHERE_COND_RE_LIST as a whole; otherwise it is split on ' OR ' (or, failing
  that, ' AND ') and every part must be valid.

  Args:
    cond: SQL WHERE condition text, using %s placeholders for values.

  Returns:
    True if the condition is allowed, otherwise False.
  """
  if cond.startswith('NOT '):
    cond = cond[4:]
  if cond.startswith('(') and cond.endswith(')'):
    # Strip the outer parentheses only when the first '(' is closed by the
    # final ')'.  Stripping '(A) AND (B)' unconditionally would produce the
    # unbalanced 'A) AND (B', which can neither match the parenthesized
    # compound patterns in WHERE_COND_RE_LIST nor split into valid parts.
    depth = 0
    wraps_whole = True
    for ch in cond[:-1]:
      if ch == '(':
        depth += 1
      elif ch == ')':
        depth -= 1
      if depth == 0:
        # The opening '(' closed before the last character.
        wraps_whole = False
        break
    if wraps_whole:
      cond = cond[1:-1]

  if any(regex.match(cond) for regex in WHERE_COND_RE_LIST):
    return True

  if ' OR ' in cond:
    return all(_IsValidWhereCond(c) for c in cond.split(' OR '))

  if ' AND ' in cond:
    return all(_IsValidWhereCond(c) for c in cond.split(' AND '))

  return False
1028
1029
def _IsValidStatement(stmt_str):
  """Run a final check that no unexpected text slipped into a statement.

  Returns:
    None if stmt_str does not match STMT_STR_RE.  Otherwise returns a bool
    that is False when the statement contains a '--' comment marker.  The
    '--' test is defense in depth: STMT_STR_RE's character set does not
    include '-'.
  """
  if not STMT_STR_RE.match(stmt_str):
    return None
  return '--' not in stmt_str
1034
1035
1036def _BoolsToInts(arg_list):
1037 """Convert any True values to 1s and Falses to 0s.
1038
1039 Google's copy of MySQLdb has bool-to-int conversion disabled,
1040 and yet it seems to be needed otherwise they are converted
1041 to strings and always interpreted as 0 (which is FALSE).
1042
1043 Args:
1044 arg_list: (nested) list of SQL statment argument values, which may
1045 include some boolean values.
1046
1047 Returns:
1048 The same list, but with True replaced by 1 and False replaced by 0.
1049 """
1050 result = []
1051 for arg in arg_list:
1052 if isinstance(arg, (list, tuple)):
1053 result.append(_BoolsToInts(arg))
1054 elif arg is True:
1055 result.append(1)
1056 elif arg is False:
1057 result.append(0)
1058 else:
1059 result.append(arg)
1060
1061 return result