blob: f073e24ab97e358f662de7837ef9017c15ca4eb7 [file] [log] [blame]
# Copyright 2016 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd
5
6"""Unit tests for the sql module."""
7from __future__ import print_function
8from __future__ import division
9from __future__ import absolute_import
10
11import logging
12import mock
13import time
14import unittest
15
16import settings
17from framework import exceptions
18from framework import sql
19
20
class MockSQLCnxn(object):
  """Fake MySQL connection that also serves as its own cursor.

  Instead of talking to a database it records the most recent statement,
  hands back whatever rows a test placed in result_rows, and tracks whether
  a write is pending so that close() can catch uncommitted statements.
  """

  def __init__(self, instance, database):
    self.instance = instance
    self.database = database
    # State that tests inspect or preset.
    self.last_executed = None
    self.last_executed_args = None
    self.result_rows = None
    self.rowcount = 0
    self.lastrowid = None
    # Key under which a connection pool files this connection.
    self.pool_key = '/'.join((instance, database))
    # Health and transaction flags.
    self.is_bad = False
    self.has_uncommitted = False

  def cursor(self):
    """Return this object, which implements the cursor methods too."""
    return self

  def execute(self, stmt_str, args=None):
    """Record stmt_str with args interpolated; flag anything but SET/SELECT."""
    interpolated = stmt_str % (tuple(args) if args else ())
    self.last_executed = interpolated
    read_only = stmt_str.startswith('SET') or stmt_str.startswith('SELECT')
    if not read_only:
      self.has_uncommitted = True

  def executemany(self, stmt_str, args):
    """Record an INSERT and its per-row args, and fake a generated row id."""
    # Each element of args fills every %s once, so there is no single
    # interpolated statement to record; keep the raw template instead.
    self.last_executed = stmt_str
    self.last_executed_args = tuple(args)
    # sql.py only calls executemany() for INSERT.
    assert stmt_str.startswith('INSERT')
    self.lastrowid = 123

  def fetchall(self):
    """Return the rows the test preloaded into result_rows."""
    return self.result_rows

  def commit(self):
    """Mark pending writes as committed."""
    self.has_uncommitted = False

  def rollback(self):
    """Discard pending writes, which here only clears the flag."""
    self.has_uncommitted = False

  def close(self):
    """Fail if a write was executed but never committed or rolled back."""
    assert not self.has_uncommitted

  def ping(self):
    """Simulate a liveness check that fails once is_bad has been set."""
    if not self.is_bad:
      return
    raise BaseException('connection error!')
68
69
# Make the sql module build MockSQLCnxn objects instead of real database
# connections. The tests below depend on this: connections handed out by
# sql.ConnectionPool and sql.MonorailConnection are expected to carry the
# mock's attributes (is_bad, pool_key, last_executed, ...).
sql.cnxn_ctor = MockSQLCnxn
71
72
73class ConnectionPoolingTest(unittest.TestCase):
74
75 def testGet(self):
76 pool_size = 2
77 num_dbs = 2
78 p = sql.ConnectionPool(pool_size)
79
80 for i in range(num_dbs):
81 for _ in range(pool_size):
82 c = p.get('test', 'db%d' % i)
83 self.assertIsNotNone(c)
84 p.release(c)
85
86 cnxn1 = p.get('test', 'db0')
87 q = p.queues[cnxn1.pool_key]
88 self.assertIs(q.qsize(), 0)
89
90 p.release(cnxn1)
91 self.assertIs(q.qsize(), pool_size - 1)
92 self.assertIs(q.full(), False)
93 self.assertIs(q.empty(), False)
94
95 cnxn2 = p.get('test', 'db0')
96 q = p.queues[cnxn2.pool_key]
97 self.assertIs(q.qsize(), 0)
98 self.assertIs(q.full(), False)
99 self.assertIs(q.empty(), True)
100
101 def testGetAndReturnPooledCnxn(self):
102 p = sql.ConnectionPool(2)
103
104 cnxn1 = p.get('test', 'db1')
105 self.assertIs(len(p.queues), 1)
106
107 cnxn2 = p.get('test', 'db2')
108 self.assertIs(len(p.queues), 2)
109
110 # Should use the existing pool.
111 cnxn3 = p.get('test', 'db1')
112 self.assertIs(len(p.queues), 2)
113
114 p.release(cnxn3)
115 p.release(cnxn2)
116
117 cnxn1.is_bad = True
118 p.release(cnxn1)
119 # cnxn1 should not be returned from the pool if we
120 # ask for a connection to its database.
121
122 cnxn4 = p.get('test', 'db1')
123
124 self.assertIsNot(cnxn1, cnxn4)
125 self.assertIs(len(p.queues), 2)
126 self.assertIs(cnxn4.is_bad, False)
127
128 def testGetAndReturnPooledCnxn_badCnxn(self):
129 p = sql.ConnectionPool(2)
130
131 cnxn1 = p.get('test', 'db1')
132 cnxn2 = p.get('test', 'db2')
133 cnxn3 = p.get('test', 'db1')
134
135 cnxn3.is_bad = True
136
137 p.release(cnxn3)
138 q = p.queues[cnxn3.pool_key]
139 self.assertIs(q.qsize(), 1)
140
141 with self.assertRaises(BaseException):
142 cnxn3 = p.get('test', 'db1')
143
144 q = p.queues[cnxn2.pool_key]
145 self.assertIs(q.qsize(), 0)
146 p.release(cnxn2)
147 self.assertIs(q.qsize(), 1)
148
149 p.release(cnxn1)
150 q = p.queues[cnxn1.pool_key]
151 self.assertIs(q.qsize(), 1)
152
153
154class MonorailConnectionTest(unittest.TestCase):
155
156 def setUp(self):
157 self.cnxn = sql.MonorailConnection()
158 self.orig_local_mode = settings.local_mode
159 self.orig_num_logical_shards = settings.num_logical_shards
160 settings.local_mode = False
161
162 def tearDown(self):
163 settings.local_mode = self.orig_local_mode
164 settings.num_logical_shards = self.orig_num_logical_shards
165
166 def testGetPrimaryConnection(self):
167 sql_cnxn = self.cnxn.GetPrimaryConnection()
168 self.assertEqual(settings.db_instance, sql_cnxn.instance)
169 self.assertEqual(settings.db_database_name, sql_cnxn.database)
170
171 sql_cnxn2 = self.cnxn.GetPrimaryConnection()
172 self.assertIs(sql_cnxn2, sql_cnxn)
173
174 def testGetConnectionForShard(self):
175 sql_cnxn = self.cnxn.GetConnectionForShard(1)
176 replica_name = settings.db_replica_names[
177 1 % len(settings.db_replica_names)]
178 self.assertEqual(settings.physical_db_name_format % replica_name,
179 sql_cnxn.instance)
180 self.assertEqual(settings.db_database_name, sql_cnxn.database)
181
182 sql_cnxn2 = self.cnxn.GetConnectionForShard(1)
183 self.assertIs(sql_cnxn2, sql_cnxn)
184
185 def testClose(self):
186 sql_cnxn = self.cnxn.GetPrimaryConnection()
187 self.cnxn.Close()
188 self.assertFalse(sql_cnxn.has_uncommitted)
189
190 def testExecute_Primary(self):
191 """Execute() with no shard passes the statement to the primary sql cnxn."""
192 sql_cnxn = self.cnxn.GetPrimaryConnection()
193 with mock.patch.object(self.cnxn, '_ExecuteWithSQLConnection') as ewsc:
194 ewsc.return_value = 'db result'
195 actual_result = self.cnxn.Execute('statement', [])
196 self.assertEqual('db result', actual_result)
197 ewsc.assert_called_once_with(sql_cnxn, 'statement', [], commit=True)
198
199 def testExecute_Shard(self):
200 """Execute() with a shard passes the statement to the shard sql cnxn."""
201 shard_id = 1
202 sql_cnxn_1 = self.cnxn.GetConnectionForShard(shard_id)
203 with mock.patch.object(self.cnxn, '_ExecuteWithSQLConnection') as ewsc:
204 ewsc.return_value = 'db result'
205 actual_result = self.cnxn.Execute('statement', [], shard_id=shard_id)
206 self.assertEqual('db result', actual_result)
207 ewsc.assert_called_once_with(sql_cnxn_1, 'statement', [], commit=True)
208
209 def testExecute_Shard_Unavailable(self):
210 """If a shard is unavailable, we try the next one."""
211 shard_id = 1
212 sql_cnxn_1 = self.cnxn.GetConnectionForShard(shard_id)
213 sql_cnxn_2 = self.cnxn.GetConnectionForShard(shard_id + 1)
214
215 # Simulate a recent failure on shard 1.
216 self.cnxn.unavailable_shards[1] = int(time.time()) - 3
217
218 with mock.patch.object(self.cnxn, '_ExecuteWithSQLConnection') as ewsc:
219 ewsc.return_value = 'db result'
220 actual_result = self.cnxn.Execute('statement', [], shard_id=shard_id)
221 self.assertEqual('db result', actual_result)
222 ewsc.assert_called_once_with(sql_cnxn_2, 'statement', [], commit=True)
223
224 # Even a new MonorailConnection instance shares the same state.
225 other_cnxn = sql.MonorailConnection()
226 other_sql_cnxn_2 = other_cnxn.GetConnectionForShard(shard_id + 1)
227
228 with mock.patch.object(other_cnxn, '_ExecuteWithSQLConnection') as ewsc:
229 ewsc.return_value = 'db result'
230 actual_result = other_cnxn.Execute('statement', [], shard_id=shard_id)
231 self.assertEqual('db result', actual_result)
232 ewsc.assert_called_once_with(
233 other_sql_cnxn_2, 'statement', [], commit=True)
234
235 # Simulate an old failure on shard 1, allowing us to try using it again.
236 self.cnxn.unavailable_shards[1] = (
237 int(time.time()) - sql.BAD_SHARD_AVOIDANCE_SEC - 2)
238
239 with mock.patch.object(self.cnxn, '_ExecuteWithSQLConnection') as ewsc:
240 ewsc.return_value = 'db result'
241 actual_result = self.cnxn.Execute('statement', [], shard_id=shard_id)
242 self.assertEqual('db result', actual_result)
243 ewsc.assert_called_once_with(sql_cnxn_1, 'statement', [], commit=True)
244
245
246class TableManagerTest(unittest.TestCase):
247
248 def setUp(self):
249 self.emp_tbl = sql.SQLTableManager('Employee')
250 self.cnxn = sql.MonorailConnection()
251 self.primary_cnxn = self.cnxn.GetPrimaryConnection()
252
253 def testSelect_Trivial(self):
254 self.primary_cnxn.result_rows = [(111, True), (222, False)]
255 rows = self.emp_tbl.Select(self.cnxn)
256 self.assertEqual('SELECT * FROM Employee', self.primary_cnxn.last_executed)
257 self.assertEqual([(111, True), (222, False)], rows)
258
259 def testSelect_Conditions(self):
260 self.primary_cnxn.result_rows = [(111,)]
261 rows = self.emp_tbl.Select(
262 self.cnxn, cols=['emp_id'], fulltime=True, dept_id=[10, 20])
263 self.assertEqual(
264 'SELECT emp_id FROM Employee'
265 '\nWHERE dept_id IN (10,20)'
266 '\n AND fulltime = 1', self.primary_cnxn.last_executed)
267 self.assertEqual([(111,)], rows)
268
269 def testSelectRow(self):
270 self.primary_cnxn.result_rows = [(111,)]
271 row = self.emp_tbl.SelectRow(
272 self.cnxn, cols=['emp_id'], fulltime=True, dept_id=[10, 20])
273 self.assertEqual(
274 'SELECT DISTINCT emp_id FROM Employee'
275 '\nWHERE dept_id IN (10,20)'
276 '\n AND fulltime = 1', self.primary_cnxn.last_executed)
277 self.assertEqual((111,), row)
278
279 def testSelectRow_NoMatches(self):
280 self.primary_cnxn.result_rows = []
281 row = self.emp_tbl.SelectRow(
282 self.cnxn, cols=['emp_id'], fulltime=True, dept_id=[99])
283 self.assertEqual(
284 'SELECT DISTINCT emp_id FROM Employee'
285 '\nWHERE dept_id IN (99)'
286 '\n AND fulltime = 1', self.primary_cnxn.last_executed)
287 self.assertEqual(None, row)
288
289 row = self.emp_tbl.SelectRow(
290 self.cnxn, cols=['emp_id'], fulltime=True, dept_id=[99],
291 default=(-1,))
292 self.assertEqual((-1,), row)
293
294 def testSelectValue(self):
295 self.primary_cnxn.result_rows = [(111,)]
296 val = self.emp_tbl.SelectValue(
297 self.cnxn, 'emp_id', fulltime=True, dept_id=[10, 20])
298 self.assertEqual(
299 'SELECT DISTINCT emp_id FROM Employee'
300 '\nWHERE dept_id IN (10,20)'
301 '\n AND fulltime = 1', self.primary_cnxn.last_executed)
302 self.assertEqual(111, val)
303
304 def testSelectValue_NoMatches(self):
305 self.primary_cnxn.result_rows = []
306 val = self.emp_tbl.SelectValue(
307 self.cnxn, 'emp_id', fulltime=True, dept_id=[99])
308 self.assertEqual(
309 'SELECT DISTINCT emp_id FROM Employee'
310 '\nWHERE dept_id IN (99)'
311 '\n AND fulltime = 1', self.primary_cnxn.last_executed)
312 self.assertEqual(None, val)
313
314 val = self.emp_tbl.SelectValue(
315 self.cnxn, 'emp_id', fulltime=True, dept_id=[99],
316 default=-1)
317 self.assertEqual(-1, val)
318
319 def testInsertRow(self):
320 self.primary_cnxn.rowcount = 1
321 generated_id = self.emp_tbl.InsertRow(self.cnxn, emp_id=111, fulltime=True)
322 self.assertEqual(
323 'INSERT INTO Employee (emp_id, fulltime)'
324 '\nVALUES (%s,%s)', self.primary_cnxn.last_executed)
325 self.assertEqual(([111, 1],), self.primary_cnxn.last_executed_args)
326 self.assertEqual(123, generated_id)
327
328 def testInsertRows_Empty(self):
329 generated_id = self.emp_tbl.InsertRows(
330 self.cnxn, ['emp_id', 'fulltime'], [])
331 self.assertIsNone(self.primary_cnxn.last_executed)
332 self.assertIsNone(self.primary_cnxn.last_executed_args)
333 self.assertEqual(None, generated_id)
334
335 def testInsertRows(self):
336 self.primary_cnxn.rowcount = 2
337 generated_ids = self.emp_tbl.InsertRows(
338 self.cnxn, ['emp_id', 'fulltime'], [(111, True), (222, False)])
339 self.assertEqual(
340 'INSERT INTO Employee (emp_id, fulltime)'
341 '\nVALUES (%s,%s)', self.primary_cnxn.last_executed)
342 self.assertEqual(([111, 1], [222, 0]), self.primary_cnxn.last_executed_args)
343 self.assertEqual([], generated_ids)
344
345 def testUpdate(self):
346 self.primary_cnxn.rowcount = 2
347 rowcount = self.emp_tbl.Update(
348 self.cnxn, {'fulltime': True}, emp_id=[111, 222])
349 self.assertEqual(
350 'UPDATE Employee SET fulltime=1'
351 '\nWHERE emp_id IN (111,222)', self.primary_cnxn.last_executed)
352 self.assertEqual(2, rowcount)
353
354 def testUpdate_Limit(self):
355 self.emp_tbl.Update(
356 self.cnxn, {'fulltime': True}, limit=8, emp_id=[111, 222])
357 self.assertEqual(
358 'UPDATE Employee SET fulltime=1'
359 '\nWHERE emp_id IN (111,222)'
360 '\nLIMIT 8', self.primary_cnxn.last_executed)
361
362 def testIncrementCounterValue(self):
363 self.primary_cnxn.rowcount = 1
364 self.primary_cnxn.lastrowid = 9
365 new_counter_val = self.emp_tbl.IncrementCounterValue(
366 self.cnxn, 'years_worked', emp_id=111)
367 self.assertEqual(
368 'UPDATE Employee SET years_worked = LAST_INSERT_ID(years_worked + 1)'
369 '\nWHERE emp_id = 111', self.primary_cnxn.last_executed)
370 self.assertEqual(9, new_counter_val)
371
372 def testDelete(self):
373 self.primary_cnxn.rowcount = 1
374 rowcount = self.emp_tbl.Delete(self.cnxn, fulltime=True)
375 self.assertEqual(
376 'DELETE FROM Employee'
377 '\nWHERE fulltime = 1', self.primary_cnxn.last_executed)
378 self.assertEqual(1, rowcount)
379
380 def testDelete_Limit(self):
381 self.emp_tbl.Delete(self.cnxn, fulltime=True, limit=3)
382 self.assertEqual(
383 'DELETE FROM Employee'
384 '\nWHERE fulltime = 1'
385 '\nLIMIT 3', self.primary_cnxn.last_executed)
386
387
388class StatementTest(unittest.TestCase):
389
390 def testMakeSelect(self):
391 stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
392 stmt_str, args = stmt.Generate()
393 self.assertEqual(
394 'SELECT emp_id, fulltime FROM Employee',
395 stmt_str)
396 self.assertEqual([], args)
397
398 stmt = sql.Statement.MakeSelect(
399 'Employee', ['emp_id', 'fulltime'], distinct=True)
400 stmt_str, args = stmt.Generate()
401 self.assertEqual(
402 'SELECT DISTINCT emp_id, fulltime FROM Employee',
403 stmt_str)
404 self.assertEqual([], args)
405
406 def testMakeInsert(self):
407 stmt = sql.Statement.MakeInsert(
408 'Employee', ['emp_id', 'fulltime'], [(111, True), (222, False)])
409 stmt_str, args = stmt.Generate()
410 self.assertEqual(
411 'INSERT INTO Employee (emp_id, fulltime)'
412 '\nVALUES (%s,%s)',
413 stmt_str)
414 self.assertEqual([[111, 1], [222, 0]], args)
415
416 stmt = sql.Statement.MakeInsert(
417 'Employee', ['emp_id', 'fulltime'], [(111, False)], replace=True)
418 stmt_str, args = stmt.Generate()
419 self.assertEqual(
420 'INSERT INTO Employee (emp_id, fulltime)'
421 '\nVALUES (%s,%s)'
422 '\nON DUPLICATE KEY UPDATE '
423 'emp_id=VALUES(emp_id), fulltime=VALUES(fulltime)',
424 stmt_str)
425 self.assertEqual([[111, 0]], args)
426
427 stmt = sql.Statement.MakeInsert(
428 'Employee', ['emp_id', 'fulltime'], [(111, False)], ignore=True)
429 stmt_str, args = stmt.Generate()
430 self.assertEqual(
431 'INSERT IGNORE INTO Employee (emp_id, fulltime)'
432 '\nVALUES (%s,%s)',
433 stmt_str)
434 self.assertEqual([[111, 0]], args)
435
436 def testMakeInsert_InvalidString(self):
437 with self.assertRaises(exceptions.InputException):
438 sql.Statement.MakeInsert(
439 'Employee', ['emp_id', 'name'], [(111, 'First \x00 Last')])
440
441 def testMakeUpdate(self):
442 stmt = sql.Statement.MakeUpdate('Employee', {'fulltime': True})
443 stmt_str, args = stmt.Generate()
444 self.assertEqual(
445 'UPDATE Employee SET fulltime=%s',
446 stmt_str)
447 self.assertEqual([1], args)
448
449 def testMakeUpdate_InvalidString(self):
450 with self.assertRaises(exceptions.InputException):
451 sql.Statement.MakeUpdate('Employee', {'name': 'First \x00 Last'})
452
453 def testMakeIncrement(self):
454 stmt = sql.Statement.MakeIncrement('Employee', 'years_worked')
455 stmt_str, args = stmt.Generate()
456 self.assertEqual(
457 'UPDATE Employee SET years_worked = LAST_INSERT_ID(years_worked + %s)',
458 stmt_str)
459 self.assertEqual([1], args)
460
461 stmt = sql.Statement.MakeIncrement('Employee', 'years_worked', step=5)
462 stmt_str, args = stmt.Generate()
463 self.assertEqual(
464 'UPDATE Employee SET years_worked = LAST_INSERT_ID(years_worked + %s)',
465 stmt_str)
466 self.assertEqual([5], args)
467
468 def testMakeDelete(self):
469 stmt = sql.Statement.MakeDelete('Employee')
470 stmt_str, args = stmt.Generate()
471 self.assertEqual(
472 'DELETE FROM Employee',
473 stmt_str)
474 self.assertEqual([], args)
475
476 def testAddUseClause(self):
477 stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
478 stmt.AddUseClause('USE INDEX (emp_id) USE INDEX FOR ORDER BY (emp_id)')
479 stmt.AddOrderByTerms([('emp_id', [])])
480 stmt_str, args = stmt.Generate()
481 self.assertEqual(
482 'SELECT emp_id, fulltime FROM Employee'
483 '\nUSE INDEX (emp_id) USE INDEX FOR ORDER BY (emp_id)'
484 '\nORDER BY emp_id',
485 stmt_str)
486 self.assertEqual([], args)
487
488 def testAddJoinClause_Empty(self):
489 stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
490 stmt.AddJoinClauses([])
491 stmt_str, args = stmt.Generate()
492 self.assertEqual(
493 'SELECT emp_id, fulltime FROM Employee',
494 stmt_str)
495 self.assertEqual([], args)
496
497 def testAddJoinClause(self):
498 stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
499 stmt.AddJoinClauses([('CorporateHoliday', [])])
500 stmt.AddJoinClauses(
501 [('Product ON Project.inventor_id = emp_id', [])], left=True)
502 stmt_str, args = stmt.Generate()
503 self.assertEqual(
504 'SELECT emp_id, fulltime FROM Employee'
505 '\n JOIN CorporateHoliday'
506 '\n LEFT JOIN Product ON Project.inventor_id = emp_id',
507 stmt_str)
508 self.assertEqual([], args)
509
510 def testAddGroupByTerms_Empty(self):
511 stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
512 stmt.AddGroupByTerms([])
513 stmt_str, args = stmt.Generate()
514 self.assertEqual(
515 'SELECT emp_id, fulltime FROM Employee',
516 stmt_str)
517 self.assertEqual([], args)
518
519 def testAddGroupByTerms(self):
520 stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
521 stmt.AddGroupByTerms(['dept_id', 'location_id'])
522 stmt_str, args = stmt.Generate()
523 self.assertEqual(
524 'SELECT emp_id, fulltime FROM Employee'
525 '\nGROUP BY dept_id, location_id',
526 stmt_str)
527 self.assertEqual([], args)
528
529 def testAddOrderByTerms_Empty(self):
530 stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
531 stmt.AddOrderByTerms([])
532 stmt_str, args = stmt.Generate()
533 self.assertEqual(
534 'SELECT emp_id, fulltime FROM Employee',
535 stmt_str)
536 self.assertEqual([], args)
537
538 def testAddOrderByTerms(self):
539 stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
540 stmt.AddOrderByTerms([('dept_id', []), ('emp_id DESC', [])])
541 stmt_str, args = stmt.Generate()
542 self.assertEqual(
543 'SELECT emp_id, fulltime FROM Employee'
544 '\nORDER BY dept_id, emp_id DESC',
545 stmt_str)
546 self.assertEqual([], args)
547
548 def testSetLimitAndOffset(self):
549 stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
550 stmt.SetLimitAndOffset(100, 0)
551 stmt_str, args = stmt.Generate()
552 self.assertEqual(
553 'SELECT emp_id, fulltime FROM Employee'
554 '\nLIMIT 100',
555 stmt_str)
556 self.assertEqual([], args)
557
558 stmt.SetLimitAndOffset(100, 500)
559 stmt_str, args = stmt.Generate()
560 self.assertEqual(
561 'SELECT emp_id, fulltime FROM Employee'
562 '\nLIMIT 100 OFFSET 500',
563 stmt_str)
564 self.assertEqual([], args)
565
566 def testAddWhereTerms_Select(self):
567 stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
568 stmt.AddWhereTerms([], emp_id=[111, 222])
569 stmt_str, args = stmt.Generate()
570 self.assertEqual(
571 'SELECT emp_id, fulltime FROM Employee'
572 '\nWHERE emp_id IN (%s,%s)',
573 stmt_str)
574 self.assertEqual([111, 222], args)
575
576 def testAddWhereTerms_Update(self):
577 stmt = sql.Statement.MakeUpdate('Employee', {'fulltime': True})
578 stmt.AddWhereTerms([], emp_id=[111, 222])
579 stmt_str, args = stmt.Generate()
580 self.assertEqual(
581 'UPDATE Employee SET fulltime=%s'
582 '\nWHERE emp_id IN (%s,%s)',
583 stmt_str)
584 self.assertEqual([1, 111, 222], args)
585
586 def testAddWhereTerms_Delete(self):
587 stmt = sql.Statement.MakeDelete('Employee')
588 stmt.AddWhereTerms([], emp_id=[111, 222])
589 stmt_str, args = stmt.Generate()
590 self.assertEqual(
591 'DELETE FROM Employee'
592 '\nWHERE emp_id IN (%s,%s)',
593 stmt_str)
594 self.assertEqual([111, 222], args)
595
596 def testAddWhereTerms_Empty(self):
597 """Add empty terms should have no effect."""
598 stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
599 stmt.AddWhereTerms([])
600 stmt_str, args = stmt.Generate()
601 self.assertEqual(
602 'SELECT emp_id, fulltime FROM Employee',
603 stmt_str)
604 self.assertEqual([], args)
605
606 def testAddWhereTerms_UpdateEmptyArray(self):
607 """Add empty array should throw an exception."""
608 stmt = sql.Statement.MakeUpdate('SpamVerdict', {'user_id': 1})
609 # See https://crbug.com/monorail/6735.
610 with self.assertRaises(exceptions.InputException):
611 stmt.AddWhereTerms([], user_id=[])
612 mock_log.assert_called_once_with('Invalid update DB value %r', 'user_id')
613
614 def testAddWhereTerms_MulitpleTerms(self):
615 stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
616 stmt.AddWhereTerms(
617 [('emp_id %% %s = %s', [2, 0])], fulltime=True, emp_id_not=222)
618 stmt_str, args = stmt.Generate()
619 self.assertEqual(
620 'SELECT emp_id, fulltime FROM Employee'
621 '\nWHERE emp_id %% %s = %s'
622 '\n AND emp_id != %s'
623 '\n AND fulltime = %s',
624 stmt_str)
625 self.assertEqual([2, 0, 222, 1], args)
626
627 def testAddHavingTerms_NoGroupBy(self):
628 stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
629 stmt.AddHavingTerms([('COUNT(*) > %s', [10])])
630 self.assertRaises(AssertionError, stmt.Generate)
631
632 def testAddHavingTerms_WithGroupBy(self):
633 stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
634 stmt.AddGroupByTerms(['dept_id', 'location_id'])
635 stmt.AddHavingTerms([('COUNT(*) > %s', [10])])
636 stmt_str, args = stmt.Generate()
637 self.assertEqual(
638 'SELECT emp_id, fulltime FROM Employee'
639 '\nGROUP BY dept_id, location_id'
640 '\nHAVING COUNT(*) > %s',
641 stmt_str)
642 self.assertEqual([10], args)
643
644
645class FunctionsTest(unittest.TestCase):
646
647 def testIsValidDBValue_NonString(self):
648 self.assertTrue(sql._IsValidDBValue(12))
649 self.assertTrue(sql._IsValidDBValue(True))
650 self.assertTrue(sql._IsValidDBValue(False))
651 self.assertTrue(sql._IsValidDBValue(None))
652
653 def testIsValidDBValue_String(self):
654 self.assertTrue(sql._IsValidDBValue(''))
655 self.assertTrue(sql._IsValidDBValue('hello'))
656 self.assertTrue(sql._IsValidDBValue(u'hello'))
657 self.assertFalse(sql._IsValidDBValue('null \x00 byte'))
658
659 def testBoolsToInts_NoChanges(self):
660 self.assertEqual(['hello'], sql._BoolsToInts(['hello']))
661 self.assertEqual([['hello']], sql._BoolsToInts([['hello']]))
662 self.assertEqual([['hello']], sql._BoolsToInts([('hello',)]))
663 self.assertEqual([12], sql._BoolsToInts([12]))
664 self.assertEqual([[12]], sql._BoolsToInts([[12]]))
665 self.assertEqual([[12]], sql._BoolsToInts([(12,)]))
666 self.assertEqual(
667 [12, 13, 'hi', [99, 'yo']],
668 sql._BoolsToInts([12, 13, 'hi', [99, 'yo']]))
669
670 def testBoolsToInts_WithChanges(self):
671 self.assertEqual([1, 0], sql._BoolsToInts([True, False]))
672 self.assertEqual([[1, 0]], sql._BoolsToInts([[True, False]]))
673 self.assertEqual([[1, 0]], sql._BoolsToInts([(True, False)]))
674 self.assertEqual(
675 [12, 1, 'hi', [0, 'yo']],
676 sql._BoolsToInts([12, True, 'hi', [False, 'yo']]))
677
678 def testRandomShardID(self):
679 """A random shard ID must always be a valid shard ID."""
680 shard_id = sql.RandomShardID()
681 self.assertTrue(0 <= shard_id < settings.num_logical_shards)