blob: e8408cc85b449c4dbb58183fab4a3830dbcefb08 [file] [log] [blame]
# Copyright 2016 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
Copybara854996b2021-09-07 19:36:02 +00004
5"""Unit tests for the sql module."""
6from __future__ import print_function
7from __future__ import division
8from __future__ import absolute_import
9
10import logging
11import mock
12import time
13import unittest
14
15import settings
16from framework import exceptions
17from framework import sql
18
19
class MockSQLCnxn(object):
  """Test double that acts as both a SQL connection and its own cursor.

  cursor() hands back the object itself, so statements run through the
  "cursor" are recorded on the same instance that tests inspect.
  """

  def __init__(self, instance, database):
    # Identity of the fake connection.
    self.instance = instance
    self.database = database
    self.pool_key = '/'.join((instance, database))
    # Record of the most recently executed statement, for test assertions.
    self.last_executed = None
    self.last_executed_args = None
    # Canned values that tests may overwrite before exercising code.
    self.result_rows = None
    self.rowcount = 0
    self.lastrowid = None
    # Simulated connection health and transaction state.
    self.is_bad = False
    self.has_uncommitted = False

  def execute(self, stmt_str, args=None):
    """Record the formatted statement and note whether a write is pending."""
    params = tuple(args) if args else ()
    self.last_executed = stmt_str % params
    # Anything other than SET or SELECT counts as an uncommitted change.
    is_read_only = stmt_str.startswith(('SET', 'SELECT'))
    if not is_read_only:
      self.has_uncommitted = True

  def executemany(self, stmt_str, args):
    """Record a bulk statement template together with its argument rows."""
    # The template is stored unformatted: args carries many values for each
    # %s placeholder, so it cannot be interpolated into one string.
    self.last_executed = stmt_str
    self.last_executed_args = tuple(args)

    # sql.py only calls executemany() for INSERT.
    assert stmt_str.startswith('INSERT')
    self.lastrowid = 123

  def fetchall(self):
    """Return whatever rows the test stored in result_rows."""
    return self.result_rows

  def cursor(self):
    """Return this same object, which doubles as the cursor."""
    return self

  def commit(self):
    """Clear the pending-write flag."""
    self.has_uncommitted = False

  def close(self):
    """Check that no write was left uncommitted when closing."""
    assert not self.has_uncommitted

  def rollback(self):
    """Clear the pending-write flag, exactly as commit() does."""
    self.has_uncommitted = False

  def ping(self):
    """Raise if the test has marked this connection as bad."""
    if self.is_bad:
      raise BaseException('connection error!')
68
# Replace the sql module's connection constructor with the mock above, so the
# connections handed out in the tests below are MockSQLCnxn instances.
sql.cnxn_ctor = MockSQLCnxn
70
71
class ConnectionPoolingTest(unittest.TestCase):
  """Tests for sql.ConnectionPool, using MockSQLCnxn as the connection type.

  Integer results (queue sizes, number of queues) are compared by value with
  assertEqual.  Identity checks are kept only for the bool singletons.
  """

  def testGet(self):
    """Connections can be taken from and released back to a database queue."""
    pool_size = 2
    num_dbs = 2
    p = sql.ConnectionPool(pool_size)

    for i in range(num_dbs):
      for _ in range(pool_size):
        c = p.get('test', 'db%d' % i)
        self.assertIsNotNone(c)
        p.release(c)

    cnxn1 = p.get('test', 'db0')
    q = p.queues[cnxn1.pool_key]
    self.assertEqual(0, q.qsize())

    p.release(cnxn1)
    self.assertEqual(pool_size - 1, q.qsize())
    self.assertIs(q.full(), False)
    self.assertIs(q.empty(), False)

    cnxn2 = p.get('test', 'db0')
    q = p.queues[cnxn2.pool_key]
    self.assertEqual(0, q.qsize())
    self.assertIs(q.full(), False)
    self.assertIs(q.empty(), True)

  def testGetAndReturnPooledCnxn(self):
    """One queue exists per database, and a bad connection is not reused."""
    p = sql.ConnectionPool(2)

    cnxn1 = p.get('test', 'db1')
    self.assertEqual(1, len(p.queues))

    cnxn2 = p.get('test', 'db2')
    self.assertEqual(2, len(p.queues))

    # Should use the existing pool.
    cnxn3 = p.get('test', 'db1')
    self.assertEqual(2, len(p.queues))

    p.release(cnxn3)
    p.release(cnxn2)

    cnxn1.is_bad = True
    p.release(cnxn1)
    # cnxn1 should not be returned from the pool if we
    # ask for a connection to its database.

    cnxn4 = p.get('test', 'db1')

    self.assertIsNot(cnxn1, cnxn4)
    self.assertEqual(2, len(p.queues))
    self.assertIs(cnxn4.is_bad, False)

  def testGetAndReturnPooledCnxn_badCnxn(self):
    """Getting a released bad connection raises; other queues still work."""
    p = sql.ConnectionPool(2)

    cnxn1 = p.get('test', 'db1')
    cnxn2 = p.get('test', 'db2')
    cnxn3 = p.get('test', 'db1')

    cnxn3.is_bad = True

    p.release(cnxn3)
    q = p.queues[cnxn3.pool_key]
    self.assertEqual(1, q.qsize())

    # MockSQLCnxn.ping() raises BaseException for a connection marked bad.
    with self.assertRaises(BaseException):
      cnxn3 = p.get('test', 'db1')

    q = p.queues[cnxn2.pool_key]
    self.assertEqual(0, q.qsize())
    p.release(cnxn2)
    self.assertEqual(1, q.qsize())

    p.release(cnxn1)
    q = p.queues[cnxn1.pool_key]
    self.assertEqual(1, q.qsize())
152
class MonorailConnectionTest(unittest.TestCase):
  """Tests for how sql.MonorailConnection hands out and uses SQL connections."""

  def setUp(self):
    self.cnxn = sql.MonorailConnection()
    # Remember the global settings so that tearDown() can put them back.
    self.orig_local_mode = settings.local_mode
    self.orig_num_logical_shards = settings.num_logical_shards
    settings.local_mode = False

  def tearDown(self):
    settings.local_mode = self.orig_local_mode
    settings.num_logical_shards = self.orig_num_logical_shards

  def _AssertExecuteRoutedTo(self, monorail_cnxn, expected_sql_cnxn, **kwargs):
    """Run Execute() with the low-level call patched and check its target.

    Args:
      monorail_cnxn: the MonorailConnection whose Execute() is invoked.
      expected_sql_cnxn: the SQL connection the statement should be sent to.
      **kwargs: extra keyword arguments passed through to Execute().
    """
    with mock.patch.object(
        monorail_cnxn, '_ExecuteWithSQLConnection') as fake_execute:
      fake_execute.return_value = 'db result'
      result = monorail_cnxn.Execute('statement', [], **kwargs)
      self.assertEqual('db result', result)
      fake_execute.assert_called_once_with(
          expected_sql_cnxn, 'statement', [], commit=True)

  def testGetPrimaryConnection(self):
    """The primary connection uses the configured DB and is handed out again."""
    first = self.cnxn.GetPrimaryConnection()
    self.assertEqual(settings.db_instance, first.instance)
    self.assertEqual(settings.db_database_name, first.database)

    second = self.cnxn.GetPrimaryConnection()
    self.assertIs(second, first)

  def testGetConnectionForShard(self):
    """A shard connection targets a replica and is handed out again."""
    first = self.cnxn.GetConnectionForShard(1)
    replica_index = 1 % len(settings.db_replica_names)
    replica_name = settings.db_replica_names[replica_index]
    expected_instance = settings.physical_db_name_format % replica_name
    self.assertEqual(expected_instance, first.instance)
    self.assertEqual(settings.db_database_name, first.database)

    second = self.cnxn.GetConnectionForShard(1)
    self.assertIs(second, first)

  def testClose(self):
    """After Close(), the primary connection has no uncommitted changes."""
    primary = self.cnxn.GetPrimaryConnection()
    self.cnxn.Close()
    self.assertFalse(primary.has_uncommitted)

  def testExecute_Primary(self):
    """Execute() with no shard passes the statement to the primary sql cnxn."""
    primary = self.cnxn.GetPrimaryConnection()
    self._AssertExecuteRoutedTo(self.cnxn, primary)

  def testExecute_Shard(self):
    """Execute() with a shard passes the statement to the shard sql cnxn."""
    shard_id = 1
    shard_cnxn = self.cnxn.GetConnectionForShard(shard_id)
    self._AssertExecuteRoutedTo(self.cnxn, shard_cnxn, shard_id=shard_id)

  def testExecute_Shard_Unavailable(self):
    """If a shard is unavailable, we try the next one."""
    shard_id = 1
    shard_cnxn = self.cnxn.GetConnectionForShard(shard_id)
    next_shard_cnxn = self.cnxn.GetConnectionForShard(shard_id + 1)

    # Simulate a recent failure on shard 1.
    self.cnxn.unavailable_shards[1] = int(time.time()) - 3

    self._AssertExecuteRoutedTo(self.cnxn, next_shard_cnxn, shard_id=shard_id)

    # Even a new MonorailConnection instance shares the same state.
    other_cnxn = sql.MonorailConnection()
    other_next_shard_cnxn = other_cnxn.GetConnectionForShard(shard_id + 1)

    self._AssertExecuteRoutedTo(
        other_cnxn, other_next_shard_cnxn, shard_id=shard_id)

    # Simulate an old failure on shard 1, allowing us to try using it again.
    self.cnxn.unavailable_shards[1] = (
        int(time.time()) - sql.BAD_SHARD_AVOIDANCE_SEC - 2)

    self._AssertExecuteRoutedTo(self.cnxn, shard_cnxn, shard_id=shard_id)
243
244
class TableManagerTest(unittest.TestCase):
  """Tests of the statements that sql.SQLTableManager sends to the primary."""

  def setUp(self):
    self.emp_tbl = sql.SQLTableManager('Employee')
    self.cnxn = sql.MonorailConnection()
    self.primary_cnxn = self.cnxn.GetPrimaryConnection()

  def _AssertLastExecuted(self, expected_stmt):
    """Check the statement most recently recorded by the primary connection."""
    self.assertEqual(expected_stmt, self.primary_cnxn.last_executed)

  def testSelect_Trivial(self):
    """Select() with no conditions selects every column of the table."""
    self.primary_cnxn.result_rows = [(111, True), (222, False)]
    fetched = self.emp_tbl.Select(self.cnxn)
    self._AssertLastExecuted('SELECT * FROM Employee')
    self.assertEqual([(111, True), (222, False)], fetched)

  def testSelect_Conditions(self):
    """Keyword arguments to Select() become WHERE conditions."""
    self.primary_cnxn.result_rows = [(111,)]
    fetched = self.emp_tbl.Select(
        self.cnxn, cols=['emp_id'], fulltime=True, dept_id=[10, 20])
    expected_stmt = (
        'SELECT emp_id FROM Employee'
        '\nWHERE dept_id IN (10,20)'
        '\n AND fulltime = 1')
    self._AssertLastExecuted(expected_stmt)
    self.assertEqual([(111,)], fetched)

  def testSelectRow(self):
    """SelectRow() issues a DISTINCT query and returns the single row."""
    self.primary_cnxn.result_rows = [(111,)]
    fetched_row = self.emp_tbl.SelectRow(
        self.cnxn, cols=['emp_id'], fulltime=True, dept_id=[10, 20])
    expected_stmt = (
        'SELECT DISTINCT emp_id FROM Employee'
        '\nWHERE dept_id IN (10,20)'
        '\n AND fulltime = 1')
    self._AssertLastExecuted(expected_stmt)
    self.assertEqual((111,), fetched_row)

  def testSelectRow_NoMatches(self):
    """SelectRow() gives None, or the given default, when nothing matches."""
    self.primary_cnxn.result_rows = []
    fetched_row = self.emp_tbl.SelectRow(
        self.cnxn, cols=['emp_id'], fulltime=True, dept_id=[99])
    expected_stmt = (
        'SELECT DISTINCT emp_id FROM Employee'
        '\nWHERE dept_id IN (99)'
        '\n AND fulltime = 1')
    self._AssertLastExecuted(expected_stmt)
    self.assertEqual(None, fetched_row)

    fetched_row = self.emp_tbl.SelectRow(
        self.cnxn, cols=['emp_id'], fulltime=True, dept_id=[99],
        default=(-1,))
    self.assertEqual((-1,), fetched_row)

  def testSelectValue(self):
    """SelectValue() issues a DISTINCT query and returns the single value."""
    self.primary_cnxn.result_rows = [(111,)]
    fetched_value = self.emp_tbl.SelectValue(
        self.cnxn, 'emp_id', fulltime=True, dept_id=[10, 20])
    expected_stmt = (
        'SELECT DISTINCT emp_id FROM Employee'
        '\nWHERE dept_id IN (10,20)'
        '\n AND fulltime = 1')
    self._AssertLastExecuted(expected_stmt)
    self.assertEqual(111, fetched_value)

  def testSelectValue_NoMatches(self):
    """SelectValue() gives None, or the given default, when nothing matches."""
    self.primary_cnxn.result_rows = []
    fetched_value = self.emp_tbl.SelectValue(
        self.cnxn, 'emp_id', fulltime=True, dept_id=[99])
    expected_stmt = (
        'SELECT DISTINCT emp_id FROM Employee'
        '\nWHERE dept_id IN (99)'
        '\n AND fulltime = 1')
    self._AssertLastExecuted(expected_stmt)
    self.assertEqual(None, fetched_value)

    fetched_value = self.emp_tbl.SelectValue(
        self.cnxn, 'emp_id', fulltime=True, dept_id=[99],
        default=-1)
    self.assertEqual(-1, fetched_value)

  def testInsertRow(self):
    """InsertRow() sends one row and returns the mock's lastrowid."""
    self.primary_cnxn.rowcount = 1
    new_id = self.emp_tbl.InsertRow(self.cnxn, emp_id=111, fulltime=True)
    expected_stmt = (
        'INSERT INTO Employee (emp_id, fulltime)'
        '\nVALUES (%s,%s)')
    self._AssertLastExecuted(expected_stmt)
    self.assertEqual(([111, 1],), self.primary_cnxn.last_executed_args)
    self.assertEqual(123, new_id)

  def testInsertRows_Empty(self):
    """InsertRows() with no rows executes nothing and returns None."""
    new_id = self.emp_tbl.InsertRows(
        self.cnxn, ['emp_id', 'fulltime'], [])
    self.assertIsNone(self.primary_cnxn.last_executed)
    self.assertIsNone(self.primary_cnxn.last_executed_args)
    self.assertEqual(None, new_id)

  def testInsertRows(self):
    """InsertRows() sends every row, with booleans converted to ints."""
    self.primary_cnxn.rowcount = 2
    new_ids = self.emp_tbl.InsertRows(
        self.cnxn, ['emp_id', 'fulltime'], [(111, True), (222, False)])
    expected_stmt = (
        'INSERT INTO Employee (emp_id, fulltime)'
        '\nVALUES (%s,%s)')
    self._AssertLastExecuted(expected_stmt)
    self.assertEqual(([111, 1], [222, 0]), self.primary_cnxn.last_executed_args)
    self.assertEqual([], new_ids)

  def testUpdate(self):
    """Update() returns the connection's rowcount."""
    self.primary_cnxn.rowcount = 2
    affected = self.emp_tbl.Update(
        self.cnxn, {'fulltime': True}, emp_id=[111, 222])
    expected_stmt = (
        'UPDATE Employee SET fulltime=1'
        '\nWHERE emp_id IN (111,222)')
    self._AssertLastExecuted(expected_stmt)
    self.assertEqual(2, affected)

  def testUpdate_Limit(self):
    """Update() with a limit appends a LIMIT clause."""
    self.emp_tbl.Update(
        self.cnxn, {'fulltime': True}, limit=8, emp_id=[111, 222])
    expected_stmt = (
        'UPDATE Employee SET fulltime=1'
        '\nWHERE emp_id IN (111,222)'
        '\nLIMIT 8')
    self._AssertLastExecuted(expected_stmt)

  def testIncrementCounterValue(self):
    """IncrementCounterValue() returns the connection's lastrowid."""
    self.primary_cnxn.rowcount = 1
    self.primary_cnxn.lastrowid = 9
    counter_value = self.emp_tbl.IncrementCounterValue(
        self.cnxn, 'years_worked', emp_id=111)
    expected_stmt = (
        'UPDATE Employee SET years_worked = LAST_INSERT_ID(years_worked + 1)'
        '\nWHERE emp_id = 111')
    self._AssertLastExecuted(expected_stmt)
    self.assertEqual(9, counter_value)

  def testDelete(self):
    """Delete() returns the connection's rowcount."""
    self.primary_cnxn.rowcount = 1
    affected = self.emp_tbl.Delete(self.cnxn, fulltime=True)
    expected_stmt = (
        'DELETE FROM Employee'
        '\nWHERE fulltime = 1')
    self._AssertLastExecuted(expected_stmt)
    self.assertEqual(1, affected)

  def testDelete_Limit(self):
    """Delete() with a limit appends a LIMIT clause."""
    self.emp_tbl.Delete(self.cnxn, fulltime=True, limit=3)
    expected_stmt = (
        'DELETE FROM Employee'
        '\nWHERE fulltime = 1'
        '\nLIMIT 3')
    self._AssertLastExecuted(expected_stmt)
385
386
class StatementTest(unittest.TestCase):
  """Tests of the SQL text and argument lists built by sql.Statement.

  Each test builds a statement, calls Generate(), and checks the returned
  (statement string, args) pair.
  """

  def testMakeSelect(self):
    """MakeSelect() lists the columns and optionally adds DISTINCT."""
    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'SELECT emp_id, fulltime FROM Employee',
        stmt_str)
    self.assertEqual([], args)

    stmt = sql.Statement.MakeSelect(
        'Employee', ['emp_id', 'fulltime'], distinct=True)
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'SELECT DISTINCT emp_id, fulltime FROM Employee',
        stmt_str)
    self.assertEqual([], args)

  def testMakeInsert(self):
    """MakeInsert() handles the plain, replace=True, and ignore=True forms."""
    stmt = sql.Statement.MakeInsert(
        'Employee', ['emp_id', 'fulltime'], [(111, True), (222, False)])
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'INSERT INTO Employee (emp_id, fulltime)'
        '\nVALUES (%s,%s)',
        stmt_str)
    self.assertEqual([[111, 1], [222, 0]], args)

    stmt = sql.Statement.MakeInsert(
        'Employee', ['emp_id', 'fulltime'], [(111, False)], replace=True)
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'INSERT INTO Employee (emp_id, fulltime)'
        '\nVALUES (%s,%s)'
        '\nON DUPLICATE KEY UPDATE '
        'emp_id=VALUES(emp_id), fulltime=VALUES(fulltime)',
        stmt_str)
    self.assertEqual([[111, 0]], args)

    stmt = sql.Statement.MakeInsert(
        'Employee', ['emp_id', 'fulltime'], [(111, False)], ignore=True)
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'INSERT IGNORE INTO Employee (emp_id, fulltime)'
        '\nVALUES (%s,%s)',
        stmt_str)
    self.assertEqual([[111, 0]], args)

  def testMakeInsert_InvalidString(self):
    """A string value containing a null byte is rejected by MakeInsert()."""
    with self.assertRaises(exceptions.InputException):
      sql.Statement.MakeInsert(
          'Employee', ['emp_id', 'name'], [(111, 'First \x00 Last')])

  def testMakeUpdate(self):
    """MakeUpdate() emits a SET clause with the value passed as an arg."""
    stmt = sql.Statement.MakeUpdate('Employee', {'fulltime': True})
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'UPDATE Employee SET fulltime=%s',
        stmt_str)
    self.assertEqual([1], args)

  def testMakeUpdate_InvalidString(self):
    """A string value containing a null byte is rejected by MakeUpdate()."""
    with self.assertRaises(exceptions.InputException):
      sql.Statement.MakeUpdate('Employee', {'name': 'First \x00 Last'})

  def testMakeIncrement(self):
    """MakeIncrement() passes the step as an arg, with 1 as the default."""
    stmt = sql.Statement.MakeIncrement('Employee', 'years_worked')
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'UPDATE Employee SET years_worked = LAST_INSERT_ID(years_worked + %s)',
        stmt_str)
    self.assertEqual([1], args)

    stmt = sql.Statement.MakeIncrement('Employee', 'years_worked', step=5)
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'UPDATE Employee SET years_worked = LAST_INSERT_ID(years_worked + %s)',
        stmt_str)
    self.assertEqual([5], args)

  def testMakeDelete(self):
    """MakeDelete() with no conditions emits a bare DELETE."""
    stmt = sql.Statement.MakeDelete('Employee')
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'DELETE FROM Employee',
        stmt_str)
    self.assertEqual([], args)

  def testAddUseClause(self):
    """The USE clause is emitted after FROM and before ORDER BY."""
    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
    stmt.AddUseClause('USE INDEX (emp_id) USE INDEX FOR ORDER BY (emp_id)')
    stmt.AddOrderByTerms([('emp_id', [])])
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'SELECT emp_id, fulltime FROM Employee'
        '\nUSE INDEX (emp_id) USE INDEX FOR ORDER BY (emp_id)'
        '\nORDER BY emp_id',
        stmt_str)
    self.assertEqual([], args)

  def testAddJoinClause_Empty(self):
    """Adding an empty list of joins leaves the statement unchanged."""
    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
    stmt.AddJoinClauses([])
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'SELECT emp_id, fulltime FROM Employee',
        stmt_str)
    self.assertEqual([], args)

  def testAddJoinClause(self):
    """Joins are emitted in the order added; left=True gives LEFT JOIN."""
    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
    stmt.AddJoinClauses([('CorporateHoliday', [])])
    stmt.AddJoinClauses(
        [('Product ON Project.inventor_id = emp_id', [])], left=True)
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'SELECT emp_id, fulltime FROM Employee'
        '\n JOIN CorporateHoliday'
        '\n LEFT JOIN Product ON Project.inventor_id = emp_id',
        stmt_str)
    self.assertEqual([], args)

  def testAddGroupByTerms_Empty(self):
    """Adding an empty list of GROUP BY terms leaves the statement unchanged."""
    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
    stmt.AddGroupByTerms([])
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'SELECT emp_id, fulltime FROM Employee',
        stmt_str)
    self.assertEqual([], args)

  def testAddGroupByTerms(self):
    """GROUP BY terms are joined with commas."""
    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
    stmt.AddGroupByTerms(['dept_id', 'location_id'])
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'SELECT emp_id, fulltime FROM Employee'
        '\nGROUP BY dept_id, location_id',
        stmt_str)
    self.assertEqual([], args)

  def testAddOrderByTerms_Empty(self):
    """Adding an empty list of ORDER BY terms leaves the statement unchanged."""
    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
    stmt.AddOrderByTerms([])
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'SELECT emp_id, fulltime FROM Employee',
        stmt_str)
    self.assertEqual([], args)

  def testAddOrderByTerms(self):
    """ORDER BY terms are joined with commas."""
    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
    stmt.AddOrderByTerms([('dept_id', []), ('emp_id DESC', [])])
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'SELECT emp_id, fulltime FROM Employee'
        '\nORDER BY dept_id, emp_id DESC',
        stmt_str)
    self.assertEqual([], args)

  def testSetLimitAndOffset(self):
    """A zero offset emits only LIMIT; a nonzero offset adds OFFSET."""
    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
    stmt.SetLimitAndOffset(100, 0)
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'SELECT emp_id, fulltime FROM Employee'
        '\nLIMIT 100',
        stmt_str)
    self.assertEqual([], args)

    stmt.SetLimitAndOffset(100, 500)
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'SELECT emp_id, fulltime FROM Employee'
        '\nLIMIT 100 OFFSET 500',
        stmt_str)
    self.assertEqual([], args)

  def testAddWhereTerms_Select(self):
    """A list-valued keyword becomes an IN (...) condition on a SELECT."""
    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
    stmt.AddWhereTerms([], emp_id=[111, 222])
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'SELECT emp_id, fulltime FROM Employee'
        '\nWHERE emp_id IN (%s,%s)',
        stmt_str)
    self.assertEqual([111, 222], args)

  def testAddWhereTerms_Update(self):
    """On an UPDATE, the SET args come before the WHERE args."""
    stmt = sql.Statement.MakeUpdate('Employee', {'fulltime': True})
    stmt.AddWhereTerms([], emp_id=[111, 222])
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'UPDATE Employee SET fulltime=%s'
        '\nWHERE emp_id IN (%s,%s)',
        stmt_str)
    self.assertEqual([1, 111, 222], args)

  def testAddWhereTerms_Delete(self):
    """A list-valued keyword becomes an IN (...) condition on a DELETE."""
    stmt = sql.Statement.MakeDelete('Employee')
    stmt.AddWhereTerms([], emp_id=[111, 222])
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'DELETE FROM Employee'
        '\nWHERE emp_id IN (%s,%s)',
        stmt_str)
    self.assertEqual([111, 222], args)

  def testAddWhereTerms_Empty(self):
    """Add empty terms should have no effect."""
    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
    stmt.AddWhereTerms([])
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'SELECT emp_id, fulltime FROM Employee',
        stmt_str)
    self.assertEqual([], args)

  def testAddWhereTerms_UpdateEmptyArray(self):
    """Add empty array should throw an exception."""
    stmt = sql.Statement.MakeUpdate('SpamVerdict', {'user_id': 1})
    # See https://crbug.com/monorail/6735.
    # NOTE: This block used to end with an assertion on an undefined
    # `mock_log`, placed after the raising call, so it could never execute.
    # It was dropped so that a missing exception is reported as an assertion
    # failure rather than a NameError.
    # TODO: to check the logged message, patch the logging call that sql.py
    # makes and assert on it after this block.
    with self.assertRaises(exceptions.InputException):
      stmt.AddWhereTerms([], user_id=[])

  def testAddWhereTerms_MulitpleTerms(self):
    """Explicit condition pairs come first, followed by keyword conditions."""
    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
    stmt.AddWhereTerms(
        [('emp_id %% %s = %s', [2, 0])], fulltime=True, emp_id_not=222)
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'SELECT emp_id, fulltime FROM Employee'
        '\nWHERE emp_id %% %s = %s'
        '\n AND emp_id != %s'
        '\n AND fulltime = %s',
        stmt_str)
    self.assertEqual([2, 0, 222, 1], args)

  def testAddHavingTerms_NoGroupBy(self):
    """Generate() fails an assertion if HAVING is used without GROUP BY."""
    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
    stmt.AddHavingTerms([('COUNT(*) > %s', [10])])
    self.assertRaises(AssertionError, stmt.Generate)

  def testAddHavingTerms_WithGroupBy(self):
    """HAVING is emitted after GROUP BY, with its args collected."""
    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
    stmt.AddGroupByTerms(['dept_id', 'location_id'])
    stmt.AddHavingTerms([('COUNT(*) > %s', [10])])
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'SELECT emp_id, fulltime FROM Employee'
        '\nGROUP BY dept_id, location_id'
        '\nHAVING COUNT(*) > %s',
        stmt_str)
    self.assertEqual([10], args)
642
643
class FunctionsTest(unittest.TestCase):
  """Tests for the module-level helper functions of sql."""

  def testIsValidDBValue_NonString(self):
    """Integers, booleans, and None are all reported as valid."""
    for value in (12, True, False, None):
      self.assertTrue(sql._IsValidDBValue(value))

  def testIsValidDBValue_String(self):
    """Strings are valid unless they contain a null byte."""
    for text in ('', 'hello', u'hello'):
      self.assertTrue(sql._IsValidDBValue(text))
    self.assertFalse(sql._IsValidDBValue('null \x00 byte'))

  def testBoolsToInts_NoChanges(self):
    """Values without booleans are unchanged, except tuples become lists."""
    # Each entry is (expected result, argument passed to _BoolsToInts).
    cases = [
        (['hello'], ['hello']),
        ([['hello']], [['hello']]),
        ([['hello']], [('hello',)]),
        ([12], [12]),
        ([[12]], [[12]]),
        ([[12]], [(12,)]),
        ([12, 13, 'hi', [99, 'yo']], [12, 13, 'hi', [99, 'yo']]),
    ]
    for expected, given in cases:
      self.assertEqual(expected, sql._BoolsToInts(given))

  def testBoolsToInts_WithChanges(self):
    """True and False become 1 and 0, including inside nested sequences."""
    # Each entry is (expected result, argument passed to _BoolsToInts).
    cases = [
        ([1, 0], [True, False]),
        ([[1, 0]], [[True, False]]),
        ([[1, 0]], [(True, False)]),
        ([12, 1, 'hi', [0, 'yo']], [12, True, 'hi', [False, 'yo']]),
    ]
    for expected, given in cases:
      self.assertEqual(expected, sql._BoolsToInts(given))

  def testRandomShardID(self):
    """A random shard ID must always be a valid shard ID."""
    shard_id = sql.RandomShardID()
    in_range = 0 <= shard_id < settings.num_logical_shards
    self.assertTrue(in_range)