| # Copyright 2016 The Chromium Authors |
| # Use of this source code is governed by a BSD-style license that can be |
| # found in the LICENSE file. |
| |
| """Unit tests for the sql module.""" |
| from __future__ import print_function |
| from __future__ import division |
| from __future__ import absolute_import |
| |
| import logging |
| import mock |
| import time |
| import unittest |
| |
| import settings |
| from framework import exceptions |
| from framework import sql |
| |
| |
class MockSQLCnxn(object):
  """Test double that acts as both a SQL connection and its cursor.

  It remembers the most recent statement it was given so tests can inspect
  it, and tracks whether a statement other than SET/SELECT is still waiting
  for a commit or rollback.
  """

  def __init__(self, instance, database):
    """Create a fake connection to `database` on `instance`."""
    # Identity of the fake connection.
    self.instance = instance
    self.database = database
    self.pool_key = instance + '/' + database
    # What the latest execute()/executemany() call received.
    self.last_executed = None
    self.last_executed_args = None
    # Values reported back to callers; tests overwrite them as needed.
    self.result_rows = None
    self.rowcount = 0
    self.lastrowid = None
    # Health and transaction state.
    self.is_bad = False
    self.has_uncommitted = False

  def execute(self, stmt_str, args=None):
    """Record stmt_str with args substituted in; note uncommitted work."""
    params = tuple(args) if args else ()
    self.last_executed = stmt_str % params
    needs_no_commit = stmt_str.startswith(('SET', 'SELECT'))
    if not needs_no_commit:
      self.has_uncommitted = True

  def executemany(self, stmt_str, args):
    """Record an INSERT statement and its rows, then fake a generated ID."""
    # The statement is stored unformatted because args holds many values for
    # each %s placeholder.
    self.last_executed = stmt_str
    self.last_executed_args = tuple(args)

    # sql.py only calls executemany() for INSERT.
    assert stmt_str.startswith('INSERT')
    self.lastrowid = 123

  def fetchall(self):
    """Return whatever rows the test stored in result_rows."""
    return self.result_rows

  def cursor(self):
    """Return this same object, which doubles as its own cursor."""
    return self

  def commit(self):
    """Mark any pending work as committed."""
    self.has_uncommitted = False

  def close(self):
    """Fail if the connection is closed while work is still uncommitted."""
    assert not self.has_uncommitted

  def rollback(self):
    """Discard any pending work."""
    self.has_uncommitted = False

  def ping(self):
    """Raise BaseException when the connection has been flagged as bad."""
    if not self.is_bad:
      return
    raise BaseException('connection error!')
| |
| |
# Install the mock as sql.cnxn_ctor so that the connections handed out by the
# sql module in these tests are MockSQLCnxn objects (the tests below read
# mock-only attributes such as last_executed and has_uncommitted from them).
sql.cnxn_ctor = MockSQLCnxn
| |
| |
| class ConnectionPoolingTest(unittest.TestCase): |
| |
| def testGet(self): |
| pool_size = 2 |
| num_dbs = 2 |
| p = sql.ConnectionPool(pool_size) |
| |
| for i in range(num_dbs): |
| for _ in range(pool_size): |
| c = p.get('test', 'db%d' % i) |
| self.assertIsNotNone(c) |
| p.release(c) |
| |
| cnxn1 = p.get('test', 'db0') |
| q = p.queues[cnxn1.pool_key] |
| self.assertIs(q.qsize(), 0) |
| |
| p.release(cnxn1) |
| self.assertIs(q.qsize(), pool_size - 1) |
| self.assertIs(q.full(), False) |
| self.assertIs(q.empty(), False) |
| |
| cnxn2 = p.get('test', 'db0') |
| q = p.queues[cnxn2.pool_key] |
| self.assertIs(q.qsize(), 0) |
| self.assertIs(q.full(), False) |
| self.assertIs(q.empty(), True) |
| |
| def testGetAndReturnPooledCnxn(self): |
| p = sql.ConnectionPool(2) |
| |
| cnxn1 = p.get('test', 'db1') |
| self.assertIs(len(p.queues), 1) |
| |
| cnxn2 = p.get('test', 'db2') |
| self.assertIs(len(p.queues), 2) |
| |
| # Should use the existing pool. |
| cnxn3 = p.get('test', 'db1') |
| self.assertIs(len(p.queues), 2) |
| |
| p.release(cnxn3) |
| p.release(cnxn2) |
| |
| cnxn1.is_bad = True |
| p.release(cnxn1) |
| # cnxn1 should not be returned from the pool if we |
| # ask for a connection to its database. |
| |
| cnxn4 = p.get('test', 'db1') |
| |
| self.assertIsNot(cnxn1, cnxn4) |
| self.assertIs(len(p.queues), 2) |
| self.assertIs(cnxn4.is_bad, False) |
| |
| def testGetAndReturnPooledCnxn_badCnxn(self): |
| p = sql.ConnectionPool(2) |
| |
| cnxn1 = p.get('test', 'db1') |
| cnxn2 = p.get('test', 'db2') |
| cnxn3 = p.get('test', 'db1') |
| |
| cnxn3.is_bad = True |
| |
| p.release(cnxn3) |
| q = p.queues[cnxn3.pool_key] |
| self.assertIs(q.qsize(), 1) |
| |
| with self.assertRaises(BaseException): |
| cnxn3 = p.get('test', 'db1') |
| |
| q = p.queues[cnxn2.pool_key] |
| self.assertIs(q.qsize(), 0) |
| p.release(cnxn2) |
| self.assertIs(q.qsize(), 1) |
| |
| p.release(cnxn1) |
| q = p.queues[cnxn1.pool_key] |
| self.assertIs(q.qsize(), 1) |
| |
| |
| class MonorailConnectionTest(unittest.TestCase): |
| |
| def setUp(self): |
| self.cnxn = sql.MonorailConnection() |
| self.orig_local_mode = settings.local_mode |
| self.orig_num_logical_shards = settings.num_logical_shards |
| settings.local_mode = False |
| |
| def tearDown(self): |
| settings.local_mode = self.orig_local_mode |
| settings.num_logical_shards = self.orig_num_logical_shards |
| |
| def testGetPrimaryConnection(self): |
| sql_cnxn = self.cnxn.GetPrimaryConnection() |
| self.assertEqual(settings.db_instance, sql_cnxn.instance) |
| self.assertEqual(settings.db_database_name, sql_cnxn.database) |
| |
| sql_cnxn2 = self.cnxn.GetPrimaryConnection() |
| self.assertIs(sql_cnxn2, sql_cnxn) |
| |
| def testGetConnectionForShard(self): |
| sql_cnxn = self.cnxn.GetConnectionForShard(1) |
| replica_name = settings.db_replica_names[ |
| 1 % len(settings.db_replica_names)] |
| self.assertEqual(settings.physical_db_name_format % replica_name, |
| sql_cnxn.instance) |
| self.assertEqual(settings.db_database_name, sql_cnxn.database) |
| |
| sql_cnxn2 = self.cnxn.GetConnectionForShard(1) |
| self.assertIs(sql_cnxn2, sql_cnxn) |
| |
| def testClose(self): |
| sql_cnxn = self.cnxn.GetPrimaryConnection() |
| self.cnxn.Close() |
| self.assertFalse(sql_cnxn.has_uncommitted) |
| |
| def testExecute_Primary(self): |
| """Execute() with no shard passes the statement to the primary sql cnxn.""" |
| sql_cnxn = self.cnxn.GetPrimaryConnection() |
| with mock.patch.object(self.cnxn, '_ExecuteWithSQLConnection') as ewsc: |
| ewsc.return_value = 'db result' |
| actual_result = self.cnxn.Execute('statement', []) |
| self.assertEqual('db result', actual_result) |
| ewsc.assert_called_once_with(sql_cnxn, 'statement', [], commit=True) |
| |
| def testExecute_Shard(self): |
| """Execute() with a shard passes the statement to the shard sql cnxn.""" |
| shard_id = 1 |
| sql_cnxn_1 = self.cnxn.GetConnectionForShard(shard_id) |
| with mock.patch.object(self.cnxn, '_ExecuteWithSQLConnection') as ewsc: |
| ewsc.return_value = 'db result' |
| actual_result = self.cnxn.Execute('statement', [], shard_id=shard_id) |
| self.assertEqual('db result', actual_result) |
| ewsc.assert_called_once_with(sql_cnxn_1, 'statement', [], commit=True) |
| |
| def testExecute_Shard_Unavailable(self): |
| """If a shard is unavailable, we try the next one.""" |
| shard_id = 1 |
| sql_cnxn_1 = self.cnxn.GetConnectionForShard(shard_id) |
| sql_cnxn_2 = self.cnxn.GetConnectionForShard(shard_id + 1) |
| |
| # Simulate a recent failure on shard 1. |
| self.cnxn.unavailable_shards[1] = int(time.time()) - 3 |
| |
| with mock.patch.object(self.cnxn, '_ExecuteWithSQLConnection') as ewsc: |
| ewsc.return_value = 'db result' |
| actual_result = self.cnxn.Execute('statement', [], shard_id=shard_id) |
| self.assertEqual('db result', actual_result) |
| ewsc.assert_called_once_with(sql_cnxn_2, 'statement', [], commit=True) |
| |
| # Even a new MonorailConnection instance shares the same state. |
| other_cnxn = sql.MonorailConnection() |
| other_sql_cnxn_2 = other_cnxn.GetConnectionForShard(shard_id + 1) |
| |
| with mock.patch.object(other_cnxn, '_ExecuteWithSQLConnection') as ewsc: |
| ewsc.return_value = 'db result' |
| actual_result = other_cnxn.Execute('statement', [], shard_id=shard_id) |
| self.assertEqual('db result', actual_result) |
| ewsc.assert_called_once_with( |
| other_sql_cnxn_2, 'statement', [], commit=True) |
| |
| # Simulate an old failure on shard 1, allowing us to try using it again. |
| self.cnxn.unavailable_shards[1] = ( |
| int(time.time()) - sql.BAD_SHARD_AVOIDANCE_SEC - 2) |
| |
| with mock.patch.object(self.cnxn, '_ExecuteWithSQLConnection') as ewsc: |
| ewsc.return_value = 'db result' |
| actual_result = self.cnxn.Execute('statement', [], shard_id=shard_id) |
| self.assertEqual('db result', actual_result) |
| ewsc.assert_called_once_with(sql_cnxn_1, 'statement', [], commit=True) |
| |
| |
| class TableManagerTest(unittest.TestCase): |
| |
| def setUp(self): |
| self.emp_tbl = sql.SQLTableManager('Employee') |
| self.cnxn = sql.MonorailConnection() |
| self.primary_cnxn = self.cnxn.GetPrimaryConnection() |
| |
| def testSelect_Trivial(self): |
| self.primary_cnxn.result_rows = [(111, True), (222, False)] |
| rows = self.emp_tbl.Select(self.cnxn) |
| self.assertEqual('SELECT * FROM Employee', self.primary_cnxn.last_executed) |
| self.assertEqual([(111, True), (222, False)], rows) |
| |
| def testSelect_Conditions(self): |
| self.primary_cnxn.result_rows = [(111,)] |
| rows = self.emp_tbl.Select( |
| self.cnxn, cols=['emp_id'], fulltime=True, dept_id=[10, 20]) |
| self.assertEqual( |
| 'SELECT emp_id FROM Employee' |
| '\nWHERE dept_id IN (10,20)' |
| '\n AND fulltime = 1', self.primary_cnxn.last_executed) |
| self.assertEqual([(111,)], rows) |
| |
| def testSelectRow(self): |
| self.primary_cnxn.result_rows = [(111,)] |
| row = self.emp_tbl.SelectRow( |
| self.cnxn, cols=['emp_id'], fulltime=True, dept_id=[10, 20]) |
| self.assertEqual( |
| 'SELECT DISTINCT emp_id FROM Employee' |
| '\nWHERE dept_id IN (10,20)' |
| '\n AND fulltime = 1', self.primary_cnxn.last_executed) |
| self.assertEqual((111,), row) |
| |
| def testSelectRow_NoMatches(self): |
| self.primary_cnxn.result_rows = [] |
| row = self.emp_tbl.SelectRow( |
| self.cnxn, cols=['emp_id'], fulltime=True, dept_id=[99]) |
| self.assertEqual( |
| 'SELECT DISTINCT emp_id FROM Employee' |
| '\nWHERE dept_id IN (99)' |
| '\n AND fulltime = 1', self.primary_cnxn.last_executed) |
| self.assertEqual(None, row) |
| |
| row = self.emp_tbl.SelectRow( |
| self.cnxn, cols=['emp_id'], fulltime=True, dept_id=[99], |
| default=(-1,)) |
| self.assertEqual((-1,), row) |
| |
| def testSelectValue(self): |
| self.primary_cnxn.result_rows = [(111,)] |
| val = self.emp_tbl.SelectValue( |
| self.cnxn, 'emp_id', fulltime=True, dept_id=[10, 20]) |
| self.assertEqual( |
| 'SELECT DISTINCT emp_id FROM Employee' |
| '\nWHERE dept_id IN (10,20)' |
| '\n AND fulltime = 1', self.primary_cnxn.last_executed) |
| self.assertEqual(111, val) |
| |
| def testSelectValue_NoMatches(self): |
| self.primary_cnxn.result_rows = [] |
| val = self.emp_tbl.SelectValue( |
| self.cnxn, 'emp_id', fulltime=True, dept_id=[99]) |
| self.assertEqual( |
| 'SELECT DISTINCT emp_id FROM Employee' |
| '\nWHERE dept_id IN (99)' |
| '\n AND fulltime = 1', self.primary_cnxn.last_executed) |
| self.assertEqual(None, val) |
| |
| val = self.emp_tbl.SelectValue( |
| self.cnxn, 'emp_id', fulltime=True, dept_id=[99], |
| default=-1) |
| self.assertEqual(-1, val) |
| |
| def testInsertRow(self): |
| self.primary_cnxn.rowcount = 1 |
| generated_id = self.emp_tbl.InsertRow(self.cnxn, emp_id=111, fulltime=True) |
| self.assertEqual( |
| 'INSERT INTO Employee (emp_id, fulltime)' |
| '\nVALUES (%s,%s)', self.primary_cnxn.last_executed) |
| self.assertEqual(([111, 1],), self.primary_cnxn.last_executed_args) |
| self.assertEqual(123, generated_id) |
| |
| def testInsertRows_Empty(self): |
| generated_id = self.emp_tbl.InsertRows( |
| self.cnxn, ['emp_id', 'fulltime'], []) |
| self.assertIsNone(self.primary_cnxn.last_executed) |
| self.assertIsNone(self.primary_cnxn.last_executed_args) |
| self.assertEqual(None, generated_id) |
| |
| def testInsertRows(self): |
| self.primary_cnxn.rowcount = 2 |
| generated_ids = self.emp_tbl.InsertRows( |
| self.cnxn, ['emp_id', 'fulltime'], [(111, True), (222, False)]) |
| self.assertEqual( |
| 'INSERT INTO Employee (emp_id, fulltime)' |
| '\nVALUES (%s,%s)', self.primary_cnxn.last_executed) |
| self.assertEqual(([111, 1], [222, 0]), self.primary_cnxn.last_executed_args) |
| self.assertEqual([], generated_ids) |
| |
| def testUpdate(self): |
| self.primary_cnxn.rowcount = 2 |
| rowcount = self.emp_tbl.Update( |
| self.cnxn, {'fulltime': True}, emp_id=[111, 222]) |
| self.assertEqual( |
| 'UPDATE Employee SET fulltime=1' |
| '\nWHERE emp_id IN (111,222)', self.primary_cnxn.last_executed) |
| self.assertEqual(2, rowcount) |
| |
| def testUpdate_Limit(self): |
| self.emp_tbl.Update( |
| self.cnxn, {'fulltime': True}, limit=8, emp_id=[111, 222]) |
| self.assertEqual( |
| 'UPDATE Employee SET fulltime=1' |
| '\nWHERE emp_id IN (111,222)' |
| '\nLIMIT 8', self.primary_cnxn.last_executed) |
| |
| def testIncrementCounterValue(self): |
| self.primary_cnxn.rowcount = 1 |
| self.primary_cnxn.lastrowid = 9 |
| new_counter_val = self.emp_tbl.IncrementCounterValue( |
| self.cnxn, 'years_worked', emp_id=111) |
| self.assertEqual( |
| 'UPDATE Employee SET years_worked = LAST_INSERT_ID(years_worked + 1)' |
| '\nWHERE emp_id = 111', self.primary_cnxn.last_executed) |
| self.assertEqual(9, new_counter_val) |
| |
| def testDelete(self): |
| self.primary_cnxn.rowcount = 1 |
| rowcount = self.emp_tbl.Delete(self.cnxn, fulltime=True) |
| self.assertEqual( |
| 'DELETE FROM Employee' |
| '\nWHERE fulltime = 1', self.primary_cnxn.last_executed) |
| self.assertEqual(1, rowcount) |
| |
| def testDelete_Limit(self): |
| self.emp_tbl.Delete(self.cnxn, fulltime=True, limit=3) |
| self.assertEqual( |
| 'DELETE FROM Employee' |
| '\nWHERE fulltime = 1' |
| '\nLIMIT 3', self.primary_cnxn.last_executed) |
| |
| |
| class StatementTest(unittest.TestCase): |
| |
| def testMakeSelect(self): |
| stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime']) |
| stmt_str, args = stmt.Generate() |
| self.assertEqual( |
| 'SELECT emp_id, fulltime FROM Employee', |
| stmt_str) |
| self.assertEqual([], args) |
| |
| stmt = sql.Statement.MakeSelect( |
| 'Employee', ['emp_id', 'fulltime'], distinct=True) |
| stmt_str, args = stmt.Generate() |
| self.assertEqual( |
| 'SELECT DISTINCT emp_id, fulltime FROM Employee', |
| stmt_str) |
| self.assertEqual([], args) |
| |
| def testMakeInsert(self): |
| stmt = sql.Statement.MakeInsert( |
| 'Employee', ['emp_id', 'fulltime'], [(111, True), (222, False)]) |
| stmt_str, args = stmt.Generate() |
| self.assertEqual( |
| 'INSERT INTO Employee (emp_id, fulltime)' |
| '\nVALUES (%s,%s)', |
| stmt_str) |
| self.assertEqual([[111, 1], [222, 0]], args) |
| |
| stmt = sql.Statement.MakeInsert( |
| 'Employee', ['emp_id', 'fulltime'], [(111, False)], replace=True) |
| stmt_str, args = stmt.Generate() |
| self.assertEqual( |
| 'INSERT INTO Employee (emp_id, fulltime)' |
| '\nVALUES (%s,%s)' |
| '\nON DUPLICATE KEY UPDATE ' |
| 'emp_id=VALUES(emp_id), fulltime=VALUES(fulltime)', |
| stmt_str) |
| self.assertEqual([[111, 0]], args) |
| |
| stmt = sql.Statement.MakeInsert( |
| 'Employee', ['emp_id', 'fulltime'], [(111, False)], ignore=True) |
| stmt_str, args = stmt.Generate() |
| self.assertEqual( |
| 'INSERT IGNORE INTO Employee (emp_id, fulltime)' |
| '\nVALUES (%s,%s)', |
| stmt_str) |
| self.assertEqual([[111, 0]], args) |
| |
| def testMakeInsert_InvalidString(self): |
| with self.assertRaises(exceptions.InputException): |
| sql.Statement.MakeInsert( |
| 'Employee', ['emp_id', 'name'], [(111, 'First \x00 Last')]) |
| |
| def testMakeUpdate(self): |
| stmt = sql.Statement.MakeUpdate('Employee', {'fulltime': True}) |
| stmt_str, args = stmt.Generate() |
| self.assertEqual( |
| 'UPDATE Employee SET fulltime=%s', |
| stmt_str) |
| self.assertEqual([1], args) |
| |
| def testMakeUpdate_InvalidString(self): |
| with self.assertRaises(exceptions.InputException): |
| sql.Statement.MakeUpdate('Employee', {'name': 'First \x00 Last'}) |
| |
| def testMakeIncrement(self): |
| stmt = sql.Statement.MakeIncrement('Employee', 'years_worked') |
| stmt_str, args = stmt.Generate() |
| self.assertEqual( |
| 'UPDATE Employee SET years_worked = LAST_INSERT_ID(years_worked + %s)', |
| stmt_str) |
| self.assertEqual([1], args) |
| |
| stmt = sql.Statement.MakeIncrement('Employee', 'years_worked', step=5) |
| stmt_str, args = stmt.Generate() |
| self.assertEqual( |
| 'UPDATE Employee SET years_worked = LAST_INSERT_ID(years_worked + %s)', |
| stmt_str) |
| self.assertEqual([5], args) |
| |
| def testMakeDelete(self): |
| stmt = sql.Statement.MakeDelete('Employee') |
| stmt_str, args = stmt.Generate() |
| self.assertEqual( |
| 'DELETE FROM Employee', |
| stmt_str) |
| self.assertEqual([], args) |
| |
| def testAddUseClause(self): |
| stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime']) |
| stmt.AddUseClause('USE INDEX (emp_id) USE INDEX FOR ORDER BY (emp_id)') |
| stmt.AddOrderByTerms([('emp_id', [])]) |
| stmt_str, args = stmt.Generate() |
| self.assertEqual( |
| 'SELECT emp_id, fulltime FROM Employee' |
| '\nUSE INDEX (emp_id) USE INDEX FOR ORDER BY (emp_id)' |
| '\nORDER BY emp_id', |
| stmt_str) |
| self.assertEqual([], args) |
| |
| def testAddJoinClause_Empty(self): |
| stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime']) |
| stmt.AddJoinClauses([]) |
| stmt_str, args = stmt.Generate() |
| self.assertEqual( |
| 'SELECT emp_id, fulltime FROM Employee', |
| stmt_str) |
| self.assertEqual([], args) |
| |
| def testAddJoinClause(self): |
| stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime']) |
| stmt.AddJoinClauses([('CorporateHoliday', [])]) |
| stmt.AddJoinClauses( |
| [('Product ON Project.inventor_id = emp_id', [])], left=True) |
| stmt_str, args = stmt.Generate() |
| self.assertEqual( |
| 'SELECT emp_id, fulltime FROM Employee' |
| '\n JOIN CorporateHoliday' |
| '\n LEFT JOIN Product ON Project.inventor_id = emp_id', |
| stmt_str) |
| self.assertEqual([], args) |
| |
| def testAddGroupByTerms_Empty(self): |
| stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime']) |
| stmt.AddGroupByTerms([]) |
| stmt_str, args = stmt.Generate() |
| self.assertEqual( |
| 'SELECT emp_id, fulltime FROM Employee', |
| stmt_str) |
| self.assertEqual([], args) |
| |
| def testAddGroupByTerms(self): |
| stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime']) |
| stmt.AddGroupByTerms(['dept_id', 'location_id']) |
| stmt_str, args = stmt.Generate() |
| self.assertEqual( |
| 'SELECT emp_id, fulltime FROM Employee' |
| '\nGROUP BY dept_id, location_id', |
| stmt_str) |
| self.assertEqual([], args) |
| |
| def testAddOrderByTerms_Empty(self): |
| stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime']) |
| stmt.AddOrderByTerms([]) |
| stmt_str, args = stmt.Generate() |
| self.assertEqual( |
| 'SELECT emp_id, fulltime FROM Employee', |
| stmt_str) |
| self.assertEqual([], args) |
| |
| def testAddOrderByTerms(self): |
| stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime']) |
| stmt.AddOrderByTerms([('dept_id', []), ('emp_id DESC', [])]) |
| stmt_str, args = stmt.Generate() |
| self.assertEqual( |
| 'SELECT emp_id, fulltime FROM Employee' |
| '\nORDER BY dept_id, emp_id DESC', |
| stmt_str) |
| self.assertEqual([], args) |
| |
| def testSetLimitAndOffset(self): |
| stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime']) |
| stmt.SetLimitAndOffset(100, 0) |
| stmt_str, args = stmt.Generate() |
| self.assertEqual( |
| 'SELECT emp_id, fulltime FROM Employee' |
| '\nLIMIT 100', |
| stmt_str) |
| self.assertEqual([], args) |
| |
| stmt.SetLimitAndOffset(100, 500) |
| stmt_str, args = stmt.Generate() |
| self.assertEqual( |
| 'SELECT emp_id, fulltime FROM Employee' |
| '\nLIMIT 100 OFFSET 500', |
| stmt_str) |
| self.assertEqual([], args) |
| |
| def testAddWhereTerms_Select(self): |
| stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime']) |
| stmt.AddWhereTerms([], emp_id=[111, 222]) |
| stmt_str, args = stmt.Generate() |
| self.assertEqual( |
| 'SELECT emp_id, fulltime FROM Employee' |
| '\nWHERE emp_id IN (%s,%s)', |
| stmt_str) |
| self.assertEqual([111, 222], args) |
| |
| def testAddWhereTerms_Update(self): |
| stmt = sql.Statement.MakeUpdate('Employee', {'fulltime': True}) |
| stmt.AddWhereTerms([], emp_id=[111, 222]) |
| stmt_str, args = stmt.Generate() |
| self.assertEqual( |
| 'UPDATE Employee SET fulltime=%s' |
| '\nWHERE emp_id IN (%s,%s)', |
| stmt_str) |
| self.assertEqual([1, 111, 222], args) |
| |
| def testAddWhereTerms_Delete(self): |
| stmt = sql.Statement.MakeDelete('Employee') |
| stmt.AddWhereTerms([], emp_id=[111, 222]) |
| stmt_str, args = stmt.Generate() |
| self.assertEqual( |
| 'DELETE FROM Employee' |
| '\nWHERE emp_id IN (%s,%s)', |
| stmt_str) |
| self.assertEqual([111, 222], args) |
| |
| def testAddWhereTerms_Empty(self): |
| """Add empty terms should have no effect.""" |
| stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime']) |
| stmt.AddWhereTerms([]) |
| stmt_str, args = stmt.Generate() |
| self.assertEqual( |
| 'SELECT emp_id, fulltime FROM Employee', |
| stmt_str) |
| self.assertEqual([], args) |
| |
| def testAddWhereTerms_UpdateEmptyArray(self): |
| """Add empty array should throw an exception.""" |
| stmt = sql.Statement.MakeUpdate('SpamVerdict', {'user_id': 1}) |
| # See https://crbug.com/monorail/6735. |
| with self.assertRaises(exceptions.InputException): |
| stmt.AddWhereTerms([], user_id=[]) |
| mock_log.assert_called_once_with('Invalid update DB value %r', 'user_id') |
| |
| def testAddWhereTerms_MulitpleTerms(self): |
| stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime']) |
| stmt.AddWhereTerms( |
| [('emp_id %% %s = %s', [2, 0])], fulltime=True, emp_id_not=222) |
| stmt_str, args = stmt.Generate() |
| self.assertEqual( |
| 'SELECT emp_id, fulltime FROM Employee' |
| '\nWHERE emp_id %% %s = %s' |
| '\n AND emp_id != %s' |
| '\n AND fulltime = %s', |
| stmt_str) |
| self.assertEqual([2, 0, 222, 1], args) |
| |
| def testAddHavingTerms_NoGroupBy(self): |
| stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime']) |
| stmt.AddHavingTerms([('COUNT(*) > %s', [10])]) |
| self.assertRaises(AssertionError, stmt.Generate) |
| |
| def testAddHavingTerms_WithGroupBy(self): |
| stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime']) |
| stmt.AddGroupByTerms(['dept_id', 'location_id']) |
| stmt.AddHavingTerms([('COUNT(*) > %s', [10])]) |
| stmt_str, args = stmt.Generate() |
| self.assertEqual( |
| 'SELECT emp_id, fulltime FROM Employee' |
| '\nGROUP BY dept_id, location_id' |
| '\nHAVING COUNT(*) > %s', |
| stmt_str) |
| self.assertEqual([10], args) |
| |
| |
| class FunctionsTest(unittest.TestCase): |
| |
| def testIsValidDBValue_NonString(self): |
| self.assertTrue(sql._IsValidDBValue(12)) |
| self.assertTrue(sql._IsValidDBValue(True)) |
| self.assertTrue(sql._IsValidDBValue(False)) |
| self.assertTrue(sql._IsValidDBValue(None)) |
| |
| def testIsValidDBValue_String(self): |
| self.assertTrue(sql._IsValidDBValue('')) |
| self.assertTrue(sql._IsValidDBValue('hello')) |
| self.assertTrue(sql._IsValidDBValue(u'hello')) |
| self.assertFalse(sql._IsValidDBValue('null \x00 byte')) |
| |
| def testBoolsToInts_NoChanges(self): |
| self.assertEqual(['hello'], sql._BoolsToInts(['hello'])) |
| self.assertEqual([['hello']], sql._BoolsToInts([['hello']])) |
| self.assertEqual([['hello']], sql._BoolsToInts([('hello',)])) |
| self.assertEqual([12], sql._BoolsToInts([12])) |
| self.assertEqual([[12]], sql._BoolsToInts([[12]])) |
| self.assertEqual([[12]], sql._BoolsToInts([(12,)])) |
| self.assertEqual( |
| [12, 13, 'hi', [99, 'yo']], |
| sql._BoolsToInts([12, 13, 'hi', [99, 'yo']])) |
| |
| def testBoolsToInts_WithChanges(self): |
| self.assertEqual([1, 0], sql._BoolsToInts([True, False])) |
| self.assertEqual([[1, 0]], sql._BoolsToInts([[True, False]])) |
| self.assertEqual([[1, 0]], sql._BoolsToInts([(True, False)])) |
| self.assertEqual( |
| [12, 1, 'hi', [0, 'yo']], |
| sql._BoolsToInts([12, True, 'hi', [False, 'yo']])) |
| |
| def testRandomShardID(self): |
| """A random shard ID must always be a valid shard ID.""" |
| shard_id = sql.RandomShardID() |
| self.assertTrue(0 <= shard_id < settings.num_logical_shards) |