Project import generated by Copybara.

GitOrigin-RevId: d9e9e3fb4e31372ec1fb43b178994ca78fa8fe70
diff --git a/framework/test/sql_test.py b/framework/test/sql_test.py
new file mode 100644
index 0000000..f073e24
--- /dev/null
+++ b/framework/test/sql_test.py
@@ -0,0 +1,681 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for the sql module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import logging
+import mock
+import time
+import unittest
+
+import settings
+from framework import exceptions
+from framework import sql
+
+
+class MockSQLCnxn(object):
+  """This class mocks the connection and cursor classes."""
+
+  def __init__(self, instance, database):
+    self.instance = instance
+    self.database = database
+    self.last_executed = None
+    self.last_executed_args = None
+    self.result_rows = None
+    self.rowcount = 0
+    self.lastrowid = None
+    self.pool_key = instance + '/' + database
+    self.is_bad = False
+    self.has_uncommitted = False
+
+  def execute(self, stmt_str, args=None):
+    self.last_executed = stmt_str % tuple(args or [])
+    if not stmt_str.startswith(('SET', 'SELECT')):
+      self.has_uncommitted = True
+
+  def executemany(self, stmt_str, args):
+    # We cannot format the string because args has many values for each %s.
+    self.last_executed = stmt_str
+    self.last_executed_args = tuple(args)
+
+    # sql.py only calls executemany() for INSERT.
+    assert stmt_str.startswith('INSERT')
+    self.lastrowid = 123
+
+  def fetchall(self):
+    return self.result_rows
+
+  def cursor(self):
+    return self
+
+  def commit(self):
+    self.has_uncommitted = False
+
+  def close(self):
+    assert not self.has_uncommitted
+
+  def rollback(self):
+    self.has_uncommitted = False
+
+  def ping(self):
+    if self.is_bad:
+      raise BaseException('connection error!')
+
+
+sql.cnxn_ctor = MockSQLCnxn
+
+
+class ConnectionPoolingTest(unittest.TestCase):
+
+  def testGet(self):
+    pool_size = 2
+    num_dbs = 2
+    p = sql.ConnectionPool(pool_size)
+
+    for i in range(num_dbs):
+      for _ in range(pool_size):
+        c = p.get('test', 'db%d' % i)
+        self.assertIsNotNone(c)
+        p.release(c)
+
+    cnxn1 = p.get('test', 'db0')
+    q = p.queues[cnxn1.pool_key]
+    self.assertIs(q.qsize(), 0)
+
+    p.release(cnxn1)
+    self.assertIs(q.qsize(), pool_size - 1)
+    self.assertIs(q.full(), False)
+    self.assertIs(q.empty(), False)
+
+    cnxn2 = p.get('test', 'db0')
+    q = p.queues[cnxn2.pool_key]
+    self.assertIs(q.qsize(), 0)
+    self.assertIs(q.full(), False)
+    self.assertIs(q.empty(), True)
+
+  def testGetAndReturnPooledCnxn(self):
+    p = sql.ConnectionPool(2)
+
+    cnxn1 = p.get('test', 'db1')
+    self.assertIs(len(p.queues), 1)
+
+    cnxn2 = p.get('test', 'db2')
+    self.assertIs(len(p.queues), 2)
+
+    # Should use the existing pool.
+    cnxn3 = p.get('test', 'db1')
+    self.assertIs(len(p.queues), 2)
+
+    p.release(cnxn3)
+    p.release(cnxn2)
+
+    cnxn1.is_bad = True
+    p.release(cnxn1)
+    # cnxn1 should not be returned from the pool if we
+    # ask for a connection to its database.
+
+    cnxn4 = p.get('test', 'db1')
+
+    self.assertIsNot(cnxn1, cnxn4)
+    self.assertIs(len(p.queues), 2)
+    self.assertIs(cnxn4.is_bad, False)
+
+  def testGetAndReturnPooledCnxn_badCnxn(self):
+    p = sql.ConnectionPool(2)
+
+    cnxn1 = p.get('test', 'db1')
+    cnxn2 = p.get('test', 'db2')
+    cnxn3 = p.get('test', 'db1')
+
+    cnxn3.is_bad = True
+
+    p.release(cnxn3)
+    q = p.queues[cnxn3.pool_key]
+    self.assertIs(q.qsize(), 1)
+
+    with self.assertRaises(BaseException):
+      cnxn3 = p.get('test', 'db1')
+
+    q = p.queues[cnxn2.pool_key]
+    self.assertIs(q.qsize(), 0)
+    p.release(cnxn2)
+    self.assertIs(q.qsize(), 1)
+
+    p.release(cnxn1)
+    q = p.queues[cnxn1.pool_key]
+    self.assertIs(q.qsize(), 1)
+
+
+class MonorailConnectionTest(unittest.TestCase):
+
+  def setUp(self):
+    self.cnxn = sql.MonorailConnection()
+    self.orig_local_mode = settings.local_mode
+    self.orig_num_logical_shards = settings.num_logical_shards
+    settings.local_mode = False
+
+  def tearDown(self):
+    settings.local_mode = self.orig_local_mode
+    settings.num_logical_shards = self.orig_num_logical_shards
+
+  def testGetPrimaryConnection(self):
+    sql_cnxn = self.cnxn.GetPrimaryConnection()
+    self.assertEqual(settings.db_instance, sql_cnxn.instance)
+    self.assertEqual(settings.db_database_name, sql_cnxn.database)
+
+    sql_cnxn2 = self.cnxn.GetPrimaryConnection()
+    self.assertIs(sql_cnxn2, sql_cnxn)
+
+  def testGetConnectionForShard(self):
+    sql_cnxn = self.cnxn.GetConnectionForShard(1)
+    replica_name = settings.db_replica_names[
+      1 % len(settings.db_replica_names)]
+    self.assertEqual(settings.physical_db_name_format % replica_name,
+                      sql_cnxn.instance)
+    self.assertEqual(settings.db_database_name, sql_cnxn.database)
+
+    sql_cnxn2 = self.cnxn.GetConnectionForShard(1)
+    self.assertIs(sql_cnxn2, sql_cnxn)
+
+  def testClose(self):
+    sql_cnxn = self.cnxn.GetPrimaryConnection()
+    self.cnxn.Close()
+    self.assertFalse(sql_cnxn.has_uncommitted)
+
+  def testExecute_Primary(self):
+    """Execute() with no shard passes the statement to the primary sql cnxn."""
+    sql_cnxn = self.cnxn.GetPrimaryConnection()
+    with mock.patch.object(self.cnxn, '_ExecuteWithSQLConnection') as ewsc:
+      ewsc.return_value = 'db result'
+      actual_result = self.cnxn.Execute('statement', [])
+      self.assertEqual('db result', actual_result)
+      ewsc.assert_called_once_with(sql_cnxn, 'statement', [], commit=True)
+
+  def testExecute_Shard(self):
+    """Execute() with a shard passes the statement to the shard sql cnxn."""
+    shard_id = 1
+    sql_cnxn_1 = self.cnxn.GetConnectionForShard(shard_id)
+    with mock.patch.object(self.cnxn, '_ExecuteWithSQLConnection') as ewsc:
+      ewsc.return_value = 'db result'
+      actual_result = self.cnxn.Execute('statement', [], shard_id=shard_id)
+      self.assertEqual('db result', actual_result)
+      ewsc.assert_called_once_with(sql_cnxn_1, 'statement', [], commit=True)
+
+  def testExecute_Shard_Unavailable(self):
+    """If a shard is unavailable, we try the next one."""
+    shard_id = 1
+    sql_cnxn_1 = self.cnxn.GetConnectionForShard(shard_id)
+    sql_cnxn_2 = self.cnxn.GetConnectionForShard(shard_id + 1)
+
+    # Simulate a recent failure on shard 1.
+    self.cnxn.unavailable_shards[1] = int(time.time()) - 3
+
+    with mock.patch.object(self.cnxn, '_ExecuteWithSQLConnection') as ewsc:
+      ewsc.return_value = 'db result'
+      actual_result = self.cnxn.Execute('statement', [], shard_id=shard_id)
+      self.assertEqual('db result', actual_result)
+      ewsc.assert_called_once_with(sql_cnxn_2, 'statement', [], commit=True)
+
+    # Even a new MonorailConnection instance shares the same state.
+    other_cnxn = sql.MonorailConnection()
+    other_sql_cnxn_2 = other_cnxn.GetConnectionForShard(shard_id + 1)
+
+    with mock.patch.object(other_cnxn, '_ExecuteWithSQLConnection') as ewsc:
+      ewsc.return_value = 'db result'
+      actual_result = other_cnxn.Execute('statement', [], shard_id=shard_id)
+      self.assertEqual('db result', actual_result)
+      ewsc.assert_called_once_with(
+          other_sql_cnxn_2, 'statement', [], commit=True)
+
+    # Simulate an old failure on shard 1, allowing us to try using it again.
+    self.cnxn.unavailable_shards[1] = (
+        int(time.time()) - sql.BAD_SHARD_AVOIDANCE_SEC - 2)
+
+    with mock.patch.object(self.cnxn, '_ExecuteWithSQLConnection') as ewsc:
+      ewsc.return_value = 'db result'
+      actual_result = self.cnxn.Execute('statement', [], shard_id=shard_id)
+      self.assertEqual('db result', actual_result)
+      ewsc.assert_called_once_with(sql_cnxn_1, 'statement', [], commit=True)
+
+
+class TableManagerTest(unittest.TestCase):
+
+  def setUp(self):
+    self.emp_tbl = sql.SQLTableManager('Employee')
+    self.cnxn = sql.MonorailConnection()
+    self.primary_cnxn = self.cnxn.GetPrimaryConnection()
+
+  def testSelect_Trivial(self):
+    self.primary_cnxn.result_rows = [(111, True), (222, False)]
+    rows = self.emp_tbl.Select(self.cnxn)
+    self.assertEqual('SELECT * FROM Employee', self.primary_cnxn.last_executed)
+    self.assertEqual([(111, True), (222, False)], rows)
+
+  def testSelect_Conditions(self):
+    self.primary_cnxn.result_rows = [(111,)]
+    rows = self.emp_tbl.Select(
+        self.cnxn, cols=['emp_id'], fulltime=True, dept_id=[10, 20])
+    self.assertEqual(
+        'SELECT emp_id FROM Employee'
+        '\nWHERE dept_id IN (10,20)'
+        '\n  AND fulltime = 1', self.primary_cnxn.last_executed)
+    self.assertEqual([(111,)], rows)
+
+  def testSelectRow(self):
+    self.primary_cnxn.result_rows = [(111,)]
+    row = self.emp_tbl.SelectRow(
+        self.cnxn, cols=['emp_id'], fulltime=True, dept_id=[10, 20])
+    self.assertEqual(
+        'SELECT DISTINCT emp_id FROM Employee'
+        '\nWHERE dept_id IN (10,20)'
+        '\n  AND fulltime = 1', self.primary_cnxn.last_executed)
+    self.assertEqual((111,), row)
+
+  def testSelectRow_NoMatches(self):
+    self.primary_cnxn.result_rows = []
+    row = self.emp_tbl.SelectRow(
+        self.cnxn, cols=['emp_id'], fulltime=True, dept_id=[99])
+    self.assertEqual(
+        'SELECT DISTINCT emp_id FROM Employee'
+        '\nWHERE dept_id IN (99)'
+        '\n  AND fulltime = 1', self.primary_cnxn.last_executed)
+    self.assertEqual(None, row)
+
+    row = self.emp_tbl.SelectRow(
+        self.cnxn, cols=['emp_id'], fulltime=True, dept_id=[99],
+        default=(-1,))
+    self.assertEqual((-1,), row)
+
+  def testSelectValue(self):
+    self.primary_cnxn.result_rows = [(111,)]
+    val = self.emp_tbl.SelectValue(
+        self.cnxn, 'emp_id', fulltime=True, dept_id=[10, 20])
+    self.assertEqual(
+        'SELECT DISTINCT emp_id FROM Employee'
+        '\nWHERE dept_id IN (10,20)'
+        '\n  AND fulltime = 1', self.primary_cnxn.last_executed)
+    self.assertEqual(111, val)
+
+  def testSelectValue_NoMatches(self):
+    self.primary_cnxn.result_rows = []
+    val = self.emp_tbl.SelectValue(
+        self.cnxn, 'emp_id', fulltime=True, dept_id=[99])
+    self.assertEqual(
+        'SELECT DISTINCT emp_id FROM Employee'
+        '\nWHERE dept_id IN (99)'
+        '\n  AND fulltime = 1', self.primary_cnxn.last_executed)
+    self.assertEqual(None, val)
+
+    val = self.emp_tbl.SelectValue(
+        self.cnxn, 'emp_id', fulltime=True, dept_id=[99],
+        default=-1)
+    self.assertEqual(-1, val)
+
+  def testInsertRow(self):
+    self.primary_cnxn.rowcount = 1
+    generated_id = self.emp_tbl.InsertRow(self.cnxn, emp_id=111, fulltime=True)
+    self.assertEqual(
+        'INSERT INTO Employee (emp_id, fulltime)'
+        '\nVALUES (%s,%s)', self.primary_cnxn.last_executed)
+    self.assertEqual(([111, 1],), self.primary_cnxn.last_executed_args)
+    self.assertEqual(123, generated_id)
+
+  def testInsertRows_Empty(self):
+    generated_id = self.emp_tbl.InsertRows(
+        self.cnxn, ['emp_id', 'fulltime'], [])
+    self.assertIsNone(self.primary_cnxn.last_executed)
+    self.assertIsNone(self.primary_cnxn.last_executed_args)
+    self.assertEqual(None, generated_id)
+
+  def testInsertRows(self):
+    self.primary_cnxn.rowcount = 2
+    generated_ids = self.emp_tbl.InsertRows(
+        self.cnxn, ['emp_id', 'fulltime'], [(111, True), (222, False)])
+    self.assertEqual(
+        'INSERT INTO Employee (emp_id, fulltime)'
+        '\nVALUES (%s,%s)', self.primary_cnxn.last_executed)
+    self.assertEqual(([111, 1], [222, 0]), self.primary_cnxn.last_executed_args)
+    self.assertEqual([], generated_ids)
+
+  def testUpdate(self):
+    self.primary_cnxn.rowcount = 2
+    rowcount = self.emp_tbl.Update(
+        self.cnxn, {'fulltime': True}, emp_id=[111, 222])
+    self.assertEqual(
+        'UPDATE Employee SET fulltime=1'
+        '\nWHERE emp_id IN (111,222)', self.primary_cnxn.last_executed)
+    self.assertEqual(2, rowcount)
+
+  def testUpdate_Limit(self):
+    self.emp_tbl.Update(
+        self.cnxn, {'fulltime': True}, limit=8, emp_id=[111, 222])
+    self.assertEqual(
+        'UPDATE Employee SET fulltime=1'
+        '\nWHERE emp_id IN (111,222)'
+        '\nLIMIT 8', self.primary_cnxn.last_executed)
+
+  def testIncrementCounterValue(self):
+    self.primary_cnxn.rowcount = 1
+    self.primary_cnxn.lastrowid = 9
+    new_counter_val = self.emp_tbl.IncrementCounterValue(
+        self.cnxn, 'years_worked', emp_id=111)
+    self.assertEqual(
+        'UPDATE Employee SET years_worked = LAST_INSERT_ID(years_worked + 1)'
+        '\nWHERE emp_id = 111', self.primary_cnxn.last_executed)
+    self.assertEqual(9, new_counter_val)
+
+  def testDelete(self):
+    self.primary_cnxn.rowcount = 1
+    rowcount = self.emp_tbl.Delete(self.cnxn, fulltime=True)
+    self.assertEqual(
+        'DELETE FROM Employee'
+        '\nWHERE fulltime = 1', self.primary_cnxn.last_executed)
+    self.assertEqual(1, rowcount)
+
+  def testDelete_Limit(self):
+    self.emp_tbl.Delete(self.cnxn, fulltime=True, limit=3)
+    self.assertEqual(
+        'DELETE FROM Employee'
+        '\nWHERE fulltime = 1'
+        '\nLIMIT 3', self.primary_cnxn.last_executed)
+
+
+class StatementTest(unittest.TestCase):
+
+  def testMakeSelect(self):
+    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+    stmt_str, args = stmt.Generate()
+    self.assertEqual(
+        'SELECT emp_id, fulltime FROM Employee',
+        stmt_str)
+    self.assertEqual([], args)
+
+    stmt = sql.Statement.MakeSelect(
+        'Employee', ['emp_id', 'fulltime'], distinct=True)
+    stmt_str, args = stmt.Generate()
+    self.assertEqual(
+        'SELECT DISTINCT emp_id, fulltime FROM Employee',
+        stmt_str)
+    self.assertEqual([], args)
+
+  def testMakeInsert(self):
+    stmt = sql.Statement.MakeInsert(
+        'Employee', ['emp_id', 'fulltime'], [(111, True), (222, False)])
+    stmt_str, args = stmt.Generate()
+    self.assertEqual(
+        'INSERT INTO Employee (emp_id, fulltime)'
+        '\nVALUES (%s,%s)',
+        stmt_str)
+    self.assertEqual([[111, 1], [222, 0]], args)
+
+    stmt = sql.Statement.MakeInsert(
+        'Employee', ['emp_id', 'fulltime'], [(111, False)], replace=True)
+    stmt_str, args = stmt.Generate()
+    self.assertEqual(
+        'INSERT INTO Employee (emp_id, fulltime)'
+        '\nVALUES (%s,%s)'
+        '\nON DUPLICATE KEY UPDATE '
+        'emp_id=VALUES(emp_id), fulltime=VALUES(fulltime)',
+        stmt_str)
+    self.assertEqual([[111, 0]], args)
+
+    stmt = sql.Statement.MakeInsert(
+        'Employee', ['emp_id', 'fulltime'], [(111, False)], ignore=True)
+    stmt_str, args = stmt.Generate()
+    self.assertEqual(
+        'INSERT IGNORE INTO Employee (emp_id, fulltime)'
+        '\nVALUES (%s,%s)',
+        stmt_str)
+    self.assertEqual([[111, 0]], args)
+
+  def testMakeInsert_InvalidString(self):
+    with self.assertRaises(exceptions.InputException):
+      sql.Statement.MakeInsert(
+          'Employee', ['emp_id', 'name'], [(111, 'First \x00 Last')])
+
+  def testMakeUpdate(self):
+    stmt = sql.Statement.MakeUpdate('Employee', {'fulltime': True})
+    stmt_str, args = stmt.Generate()
+    self.assertEqual(
+        'UPDATE Employee SET fulltime=%s',
+        stmt_str)
+    self.assertEqual([1], args)
+
+  def testMakeUpdate_InvalidString(self):
+    with self.assertRaises(exceptions.InputException):
+      sql.Statement.MakeUpdate('Employee', {'name': 'First \x00 Last'})
+
+  def testMakeIncrement(self):
+    stmt = sql.Statement.MakeIncrement('Employee', 'years_worked')
+    stmt_str, args = stmt.Generate()
+    self.assertEqual(
+        'UPDATE Employee SET years_worked = LAST_INSERT_ID(years_worked + %s)',
+        stmt_str)
+    self.assertEqual([1], args)
+
+    stmt = sql.Statement.MakeIncrement('Employee', 'years_worked', step=5)
+    stmt_str, args = stmt.Generate()
+    self.assertEqual(
+        'UPDATE Employee SET years_worked = LAST_INSERT_ID(years_worked + %s)',
+        stmt_str)
+    self.assertEqual([5], args)
+
+  def testMakeDelete(self):
+    stmt = sql.Statement.MakeDelete('Employee')
+    stmt_str, args = stmt.Generate()
+    self.assertEqual(
+        'DELETE FROM Employee',
+        stmt_str)
+    self.assertEqual([], args)
+
+  def testAddUseClause(self):
+    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+    stmt.AddUseClause('USE INDEX (emp_id) USE INDEX FOR ORDER BY (emp_id)')
+    stmt.AddOrderByTerms([('emp_id', [])])
+    stmt_str, args = stmt.Generate()
+    self.assertEqual(
+        'SELECT emp_id, fulltime FROM Employee'
+        '\nUSE INDEX (emp_id) USE INDEX FOR ORDER BY (emp_id)'
+        '\nORDER BY emp_id',
+        stmt_str)
+    self.assertEqual([], args)
+
+  def testAddJoinClause_Empty(self):
+    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+    stmt.AddJoinClauses([])
+    stmt_str, args = stmt.Generate()
+    self.assertEqual(
+        'SELECT emp_id, fulltime FROM Employee',
+        stmt_str)
+    self.assertEqual([], args)
+
+  def testAddJoinClause(self):
+    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+    stmt.AddJoinClauses([('CorporateHoliday', [])])
+    stmt.AddJoinClauses(
+        [('Product ON Project.inventor_id = emp_id', [])], left=True)
+    stmt_str, args = stmt.Generate()
+    self.assertEqual(
+        'SELECT emp_id, fulltime FROM Employee'
+        '\n  JOIN CorporateHoliday'
+        '\n  LEFT JOIN Product ON Project.inventor_id = emp_id',
+        stmt_str)
+    self.assertEqual([], args)
+
+  def testAddGroupByTerms_Empty(self):
+    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+    stmt.AddGroupByTerms([])
+    stmt_str, args = stmt.Generate()
+    self.assertEqual(
+        'SELECT emp_id, fulltime FROM Employee',
+        stmt_str)
+    self.assertEqual([], args)
+
+  def testAddGroupByTerms(self):
+    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+    stmt.AddGroupByTerms(['dept_id', 'location_id'])
+    stmt_str, args = stmt.Generate()
+    self.assertEqual(
+        'SELECT emp_id, fulltime FROM Employee'
+        '\nGROUP BY dept_id, location_id',
+        stmt_str)
+    self.assertEqual([], args)
+
+  def testAddOrderByTerms_Empty(self):
+    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+    stmt.AddOrderByTerms([])
+    stmt_str, args = stmt.Generate()
+    self.assertEqual(
+        'SELECT emp_id, fulltime FROM Employee',
+        stmt_str)
+    self.assertEqual([], args)
+
+  def testAddOrderByTerms(self):
+    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+    stmt.AddOrderByTerms([('dept_id', []), ('emp_id DESC', [])])
+    stmt_str, args = stmt.Generate()
+    self.assertEqual(
+        'SELECT emp_id, fulltime FROM Employee'
+        '\nORDER BY dept_id, emp_id DESC',
+        stmt_str)
+    self.assertEqual([], args)
+
+  def testSetLimitAndOffset(self):
+    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+    stmt.SetLimitAndOffset(100, 0)
+    stmt_str, args = stmt.Generate()
+    self.assertEqual(
+        'SELECT emp_id, fulltime FROM Employee'
+        '\nLIMIT 100',
+        stmt_str)
+    self.assertEqual([], args)
+
+    stmt.SetLimitAndOffset(100, 500)
+    stmt_str, args = stmt.Generate()
+    self.assertEqual(
+        'SELECT emp_id, fulltime FROM Employee'
+        '\nLIMIT 100 OFFSET 500',
+        stmt_str)
+    self.assertEqual([], args)
+
+  def testAddWhereTerms_Select(self):
+    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+    stmt.AddWhereTerms([], emp_id=[111, 222])
+    stmt_str, args = stmt.Generate()
+    self.assertEqual(
+        'SELECT emp_id, fulltime FROM Employee'
+        '\nWHERE emp_id IN (%s,%s)',
+        stmt_str)
+    self.assertEqual([111, 222], args)
+
+  def testAddWhereTerms_Update(self):
+    stmt = sql.Statement.MakeUpdate('Employee', {'fulltime': True})
+    stmt.AddWhereTerms([], emp_id=[111, 222])
+    stmt_str, args = stmt.Generate()
+    self.assertEqual(
+        'UPDATE Employee SET fulltime=%s'
+        '\nWHERE emp_id IN (%s,%s)',
+        stmt_str)
+    self.assertEqual([1, 111, 222], args)
+
+  def testAddWhereTerms_Delete(self):
+    stmt = sql.Statement.MakeDelete('Employee')
+    stmt.AddWhereTerms([], emp_id=[111, 222])
+    stmt_str, args = stmt.Generate()
+    self.assertEqual(
+        'DELETE FROM Employee'
+        '\nWHERE emp_id IN (%s,%s)',
+        stmt_str)
+    self.assertEqual([111, 222], args)
+
+  def testAddWhereTerms_Empty(self):
+    """Add empty terms should have no effect."""
+    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+    stmt.AddWhereTerms([])
+    stmt_str, args = stmt.Generate()
+    self.assertEqual(
+        'SELECT emp_id, fulltime FROM Employee',
+        stmt_str)
+    self.assertEqual([], args)
+
+  def testAddWhereTerms_UpdateEmptyArray(self):
+    """Add empty array should throw an exception."""
+    stmt = sql.Statement.MakeUpdate('SpamVerdict', {'user_id': 1})
+    # See https://crbug.com/monorail/6735.
+    with self.assertRaises(exceptions.InputException):
+      stmt.AddWhereTerms([], user_id=[])
+      mock_log.assert_called_once_with('Invalid update DB value %r', 'user_id')
+
+  def testAddWhereTerms_MulitpleTerms(self):
+    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+    stmt.AddWhereTerms(
+        [('emp_id %% %s = %s', [2, 0])], fulltime=True, emp_id_not=222)
+    stmt_str, args = stmt.Generate()
+    self.assertEqual(
+        'SELECT emp_id, fulltime FROM Employee'
+        '\nWHERE emp_id %% %s = %s'
+        '\n  AND emp_id != %s'
+        '\n  AND fulltime = %s',
+        stmt_str)
+    self.assertEqual([2, 0, 222, 1], args)
+
+  def testAddHavingTerms_NoGroupBy(self):
+    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+    stmt.AddHavingTerms([('COUNT(*) > %s', [10])])
+    self.assertRaises(AssertionError, stmt.Generate)
+
+  def testAddHavingTerms_WithGroupBy(self):
+    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+    stmt.AddGroupByTerms(['dept_id', 'location_id'])
+    stmt.AddHavingTerms([('COUNT(*) > %s', [10])])
+    stmt_str, args = stmt.Generate()
+    self.assertEqual(
+        'SELECT emp_id, fulltime FROM Employee'
+        '\nGROUP BY dept_id, location_id'
+        '\nHAVING COUNT(*) > %s',
+        stmt_str)
+    self.assertEqual([10], args)
+
+
+class FunctionsTest(unittest.TestCase):
+
+  def testIsValidDBValue_NonString(self):
+    self.assertTrue(sql._IsValidDBValue(12))
+    self.assertTrue(sql._IsValidDBValue(True))
+    self.assertTrue(sql._IsValidDBValue(False))
+    self.assertTrue(sql._IsValidDBValue(None))
+
+  def testIsValidDBValue_String(self):
+    self.assertTrue(sql._IsValidDBValue(''))
+    self.assertTrue(sql._IsValidDBValue('hello'))
+    self.assertTrue(sql._IsValidDBValue(u'hello'))
+    self.assertFalse(sql._IsValidDBValue('null \x00 byte'))
+
+  def testBoolsToInts_NoChanges(self):
+    self.assertEqual(['hello'], sql._BoolsToInts(['hello']))
+    self.assertEqual([['hello']], sql._BoolsToInts([['hello']]))
+    self.assertEqual([['hello']], sql._BoolsToInts([('hello',)]))
+    self.assertEqual([12], sql._BoolsToInts([12]))
+    self.assertEqual([[12]], sql._BoolsToInts([[12]]))
+    self.assertEqual([[12]], sql._BoolsToInts([(12,)]))
+    self.assertEqual(
+        [12, 13, 'hi', [99, 'yo']],
+        sql._BoolsToInts([12, 13, 'hi', [99, 'yo']]))
+
+  def testBoolsToInts_WithChanges(self):
+    self.assertEqual([1, 0], sql._BoolsToInts([True, False]))
+    self.assertEqual([[1, 0]], sql._BoolsToInts([[True, False]]))
+    self.assertEqual([[1, 0]], sql._BoolsToInts([(True, False)]))
+    self.assertEqual(
+        [12, 1, 'hi', [0, 'yo']],
+        sql._BoolsToInts([12, True, 'hi', [False, 'yo']]))
+
+  def testRandomShardID(self):
+    """A random shard ID must always be a valid shard ID."""
+    shard_id = sql.RandomShardID()
+    self.assertTrue(0 <= shard_id < settings.num_logical_shards)