Project import generated by Copybara.
GitOrigin-RevId: d9e9e3fb4e31372ec1fb43b178994ca78fa8fe70
diff --git a/framework/test/__init__.py b/framework/test/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/framework/test/__init__.py
diff --git a/framework/test/alerts_test.py b/framework/test/alerts_test.py
new file mode 100644
index 0000000..0c398c1
--- /dev/null
+++ b/framework/test/alerts_test.py
@@ -0,0 +1,43 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for alert display helpers."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import time
+import unittest
+
+import ezt
+
+from framework import alerts
+from testing import fake
+from testing import testing_helpers
+
+
+class AlertsViewTest(unittest.TestCase):
+
+ def testTimestamp(self):
+ """Alerts show only for a fresh ts param, not a stale or missing one."""
+ project = fake.Project(project_name='testproj')
+
+ now = int(time.time())
+ mr = testing_helpers.MakeMonorailRequest(
+ path='/p/testproj/?updated=10&ts=%s' % now, project=project)
+ alerts_view = alerts.AlertsView(mr)
+ self.assertEqual(10, alerts_view.updated)
+ self.assertEqual(ezt.boolean(True), alerts_view.show)
+
+ now -= 10  # A ts from 10 seconds ago is treated as stale.
+ mr = testing_helpers.MakeMonorailRequest(
+ path='/p/testproj/?updated=10&ts=%s' % now, project=project)
+ alerts_view = alerts.AlertsView(mr)
+ self.assertEqual(ezt.boolean(False), alerts_view.show)
+
+ mr = testing_helpers.MakeMonorailRequest(  # No ts param at all.
+ path='/p/testproj/?updated=10', project=project)
+ alerts_view = alerts.AlertsView(mr)
+ self.assertEqual(ezt.boolean(False), alerts_view.show)
diff --git a/framework/test/authdata_test.py b/framework/test/authdata_test.py
new file mode 100644
index 0000000..a0e7313
--- /dev/null
+++ b/framework/test/authdata_test.py
@@ -0,0 +1,55 @@
+# Copyright 2017 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for the authdata module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import mock
+import unittest
+
+from google.appengine.api import users
+
+from framework import authdata
+from services import service_manager
+from testing import fake
+
+
+class AuthDataTest(unittest.TestCase):
+
+ def setUp(self):
+ self.cnxn = fake.MonorailConnection()
+ self.services = service_manager.Services(
+ user=fake.UserService(),
+ usergroup=fake.UserGroupService())
+ self.user_1 = self.services.user.TestAddUser('test@example.com', 111)
+
+ def testFromRequest(self):
+
+ class FakeUser(object):  # Value returned by the mocked get_current_user().
+ email = lambda _: self.user_1.email
+
+ with mock.patch.object(users, 'get_current_user',
+ autospec=True) as mock_get_current_user:
+ mock_get_current_user.return_value = FakeUser()
+ auth = authdata.AuthData.FromRequest(self.cnxn, self.services)
+ self.assertEqual(auth.user_id, 111)
+
+ def testFromEmail(self):
+ auth = authdata.AuthData.FromEmail(
+ self.cnxn, self.user_1.email, self.services)
+ self.assertEqual(auth.user_id, 111)
+ self.assertEqual(auth.user_pb.email, self.user_1.email)
+
+ def testFromUserId(self):
+ auth = authdata.AuthData.FromUserID(self.cnxn, 111, self.services)
+ self.assertEqual(auth.user_id, 111)
+ self.assertEqual(auth.user_pb.email, self.user_1.email)
+
+ def testFromUser(self):
+ auth = authdata.AuthData.FromUser(self.cnxn, self.user_1, self.services)
+ self.assertEqual(auth.user_id, 111)
+ self.assertEqual(auth.user_pb.email, self.user_1.email)
diff --git a/framework/test/banned_test.py b/framework/test/banned_test.py
new file mode 100644
index 0000000..73b9f03
--- /dev/null
+++ b/framework/test/banned_test.py
@@ -0,0 +1,58 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unittests for monorail.framework.banned."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+import webapp2
+
+from framework import banned
+from framework import monorailrequest
+from services import service_manager
+from testing import testing_helpers
+
+
+class BannedTest(unittest.TestCase):
+
+ def setUp(self):
+ self.services = service_manager.Services()
+
+ def testAssertBasePermission(self):
+ servlet = banned.Banned('request', 'response', services=self.services)
+
+ mr = monorailrequest.MonorailRequest(self.services)
+ mr.auth.user_id = 0 # Anon user cannot see banned page.
+ with self.assertRaises(webapp2.HTTPException) as cm:
+ servlet.AssertBasePermission(mr)
+ self.assertEqual(404, cm.exception.code)
+
+ mr.auth.user_id = 111 # User who is not banned cannot view banned page.
+ with self.assertRaises(webapp2.HTTPException) as cm:
+ servlet.AssertBasePermission(mr)
+ self.assertEqual(404, cm.exception.code)
+
+ # A banned user may view the banned page, so no exception is raised.
+ mr.auth.user_pb.banned = 'spammer'
+ servlet.AssertBasePermission(mr)
+
+ def testGatherPageData(self):
+ servlet = banned.Banned('request', 'response', services=self.services)
+ self.assertNotEqual(servlet.template, None)
+
+ _request, mr = testing_helpers.GetRequestObjects()
+ page_data = servlet.GatherPageData(mr)
+
+ self.assertFalse(page_data['is_plus_address'])
+ self.assertEqual(None, page_data['currentPageURLEncoded'])
+
+ mr.auth.user_pb.email = 'user+shadystuff@example.com'  # A "+" address.
+ page_data = servlet.GatherPageData(mr)
+
+ self.assertTrue(page_data['is_plus_address'])
+ self.assertEqual(None, page_data['currentPageURLEncoded'])
diff --git a/framework/test/cloud_tasks_helpers_test.py b/framework/test/cloud_tasks_helpers_test.py
new file mode 100644
index 0000000..09ad2cd
--- /dev/null
+++ b/framework/test/cloud_tasks_helpers_test.py
@@ -0,0 +1,88 @@
+# Copyright 2020 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+"""Tests for the cloud tasks helper module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from google.api_core import exceptions
+
+import mock
+import unittest
+
+from framework import cloud_tasks_helpers
+import settings
+
+
+class CloudTasksHelpersTest(unittest.TestCase):
+
+ @mock.patch('framework.cloud_tasks_helpers._get_client')
+ def test_create_task(self, get_client_mock):
+
+ queue = 'somequeue'
+ task = {
+ 'app_engine_http_request':
+ {
+ 'http_method': 'GET',
+ 'relative_uri': '/some_url'
+ }
+ }
+ cloud_tasks_helpers.create_task(task, queue=queue)
+
+ get_client_mock().queue_path.assert_called_with(
+ settings.app_id, settings.CLOUD_TASKS_REGION, queue)
+ get_client_mock().create_task.assert_called_once()
+ ((_parent, called_task), _kwargs) = get_client_mock().create_task.call_args
+ self.assertEqual(called_task, task)
+
+ @mock.patch('framework.cloud_tasks_helpers._get_client')
+ def test_create_task_raises(self, get_client_mock):
+ task = {'app_engine_http_request': {}}
+ # Errors raised by the client propagate to the caller.
+ get_client_mock().create_task.side_effect = exceptions.GoogleAPICallError(
+ 'oh no!')
+
+ with self.assertRaises(exceptions.GoogleAPICallError):
+ cloud_tasks_helpers.create_task(task)
+
+ @mock.patch('framework.cloud_tasks_helpers._get_client')
+ def test_create_task_retries(self, get_client_mock):
+ task = {'app_engine_http_request': {}}
+
+ cloud_tasks_helpers.create_task(task)  # No retry passed explicitly.
+
+ (_args, kwargs) = get_client_mock().create_task.call_args
+ self.assertEqual(kwargs.get('retry'), cloud_tasks_helpers._DEFAULT_RETRY)
+
+ def test_generate_simple_task(self):
+ actual = cloud_tasks_helpers.generate_simple_task(
+ '/alphabet/letters', {
+ 'a': 'a',
+ 'b': 'b'
+ })
+ expected = {
+ 'app_engine_http_request':
+ {
+ 'relative_uri': '/alphabet/letters',
+ 'body': 'a=a&b=b',
+ 'headers': {
+ 'Content-type': 'application/x-www-form-urlencoded'
+ }
+ }
+ }
+ self.assertEqual(actual, expected)
+ # An empty params dict produces an empty body.
+ actual = cloud_tasks_helpers.generate_simple_task('/alphabet/letters', {})
+ expected = {
+ 'app_engine_http_request':
+ {
+ 'relative_uri': '/alphabet/letters',
+ 'body': '',
+ 'headers': {
+ 'Content-type': 'application/x-www-form-urlencoded'
+ }
+ }
+ }
+ self.assertEqual(actual, expected)
diff --git a/framework/test/csv_helpers_test.py b/framework/test/csv_helpers_test.py
new file mode 100644
index 0000000..19c89c5
--- /dev/null
+++ b/framework/test/csv_helpers_test.py
@@ -0,0 +1,61 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for csv_helpers functions."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+from framework import csv_helpers
+
+
+class IssueListCSVFunctionsTest(unittest.TestCase):
+
+ def testRewriteColspec(self):
+ self.assertEqual('', csv_helpers.RewriteColspec(''))
+
+ self.assertEqual('a B c', csv_helpers.RewriteColspec('a B c'))
+
+ self.assertEqual('a Summary AllLabels B Opened OpenedTimestamp c',
+ csv_helpers.RewriteColspec('a summary B opened c'))
+
+ self.assertEqual('Closed ClosedTimestamp Modified ModifiedTimestamp',
+ csv_helpers.RewriteColspec('Closed Modified'))
+
+ self.assertEqual('OwnerModified OwnerModifiedTimestamp',
+ csv_helpers.RewriteColspec('OwnerModified'))
+
+ def testReformatRowsForCSV(self):
+ # TODO(jojwang): write this test
+ pass
+
+ def testEscapeCSV(self):
+ self.assertEqual('', csv_helpers.EscapeCSV(None))
+ self.assertEqual(0, csv_helpers.EscapeCSV(0))  # Not stringified.
+ self.assertEqual('', csv_helpers.EscapeCSV(''))
+ self.assertEqual('hello', csv_helpers.EscapeCSV('hello'))
+ self.assertEqual('hello', csv_helpers.EscapeCSV(' hello '))
+
+ # Single quotes pass through; double quotes are escaped as two.
+ self.assertEqual("say 'hello'", csv_helpers.EscapeCSV("say 'hello'"))
+ self.assertEqual('say ""hello""', csv_helpers.EscapeCSV('say "hello"'))
+
+ # Things that look like formulas are prefixed with a single quote because
+ # some formula functions can have side-effects. See:
+ # https://www.contextis.com/resources/blog/comma-separated-vulnerabilities/
+ self.assertEqual("'=2+2", csv_helpers.EscapeCSV('=2+2'))
+ self.assertEqual("'=CMD| del *.*", csv_helpers.EscapeCSV('=CMD| del *.*'))
+
+ # Some spreadsheets apparently allow formula cells that start with
+ # plus, minus, and at-signs.
+ self.assertEqual("'+2+2", csv_helpers.EscapeCSV('+2+2'))
+ self.assertEqual("'-2+2", csv_helpers.EscapeCSV('-2+2'))
+ self.assertEqual("'@2+2", csv_helpers.EscapeCSV('@2+2'))
+ # Non-ASCII text passes through unchanged.
+ self.assertEqual(
+ u'division\xc3\xb7sign',
+ csv_helpers.EscapeCSV(u'division\xc3\xb7sign'))
diff --git a/framework/test/deleteusers_test.py b/framework/test/deleteusers_test.py
new file mode 100644
index 0000000..4cadbbd
--- /dev/null
+++ b/framework/test/deleteusers_test.py
@@ -0,0 +1,214 @@
+# Copyright 2019 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for deleteusers classes."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import logging
+import mock
+import unittest
+import urllib
+
+from framework import cloud_tasks_helpers
+from framework import deleteusers
+from framework import framework_constants
+from framework import urls
+from services import service_manager
+from testing import fake
+from testing import testing_helpers
+
+class TestWipeoutSyncCron(unittest.TestCase):
+
+ def setUp(self):
+ self.services = service_manager.Services(user=fake.UserService())
+ self.task = deleteusers.WipeoutSyncCron(
+ request=None, response=None, services=self.services)
+ self.user_1 = self.services.user.TestAddUser('user1@example.com', 111)
+ self.user_2 = self.services.user.TestAddUser('user2@example.com', 222)
+ self.user_3 = self.services.user.TestAddUser('user3@example.com', 333)
+
+ def generate_simple_task(self, url, body):  # Builds an expected task dict.
+ return {
+ 'app_engine_http_request':
+ {
+ 'relative_uri': url,
+ 'body': body,
+ 'headers': {
+ 'Content-type': 'application/x-www-form-urlencoded'
+ }
+ }
+ }
+
+ @mock.patch('framework.cloud_tasks_helpers._get_client')
+ def testHandleRequest(self, get_client_mock):
+ mr = testing_helpers.MakeMonorailRequest(
+ path='url/url?batchsize=2',
+ services=self.services)
+ self.task.HandleRequest(mr)
+ # 3 users with batchsize=2: two user-list batches plus one delete task.
+ self.assertEqual(get_client_mock().create_task.call_count, 3)
+
+ expected_task = self.generate_simple_task(
+ urls.SEND_WIPEOUT_USER_LISTS_TASK + '.do', 'limit=2&offset=0')
+ get_client_mock().create_task.assert_any_call(
+ get_client_mock().queue_path(),
+ expected_task,
+ retry=cloud_tasks_helpers._DEFAULT_RETRY)
+
+ expected_task = self.generate_simple_task(
+ urls.SEND_WIPEOUT_USER_LISTS_TASK + '.do', 'limit=2&offset=2')
+ get_client_mock().create_task.assert_any_call(
+ get_client_mock().queue_path(),
+ expected_task,
+ retry=cloud_tasks_helpers._DEFAULT_RETRY)
+
+ expected_task = self.generate_simple_task(
+ urls.DELETE_WIPEOUT_USERS_TASK + '.do', '')
+ get_client_mock().create_task.assert_any_call(
+ get_client_mock().queue_path(),
+ expected_task,
+ retry=cloud_tasks_helpers._DEFAULT_RETRY)
+
+ @mock.patch('framework.cloud_tasks_helpers._get_client')
+ def testHandleRequest_NoBatchSizeParam(self, get_client_mock):
+ mr = testing_helpers.MakeMonorailRequest(services=self.services)
+ self.task.HandleRequest(mr)
+ # Without a batchsize param, MAX_BATCH_SIZE is used as the limit.
+ expected_task = self.generate_simple_task(
+ urls.SEND_WIPEOUT_USER_LISTS_TASK + '.do',
+ 'limit={}&offset=0'.format(deleteusers.MAX_BATCH_SIZE))
+ get_client_mock().create_task.assert_any_call(
+ get_client_mock().queue_path(),
+ expected_task,
+ retry=cloud_tasks_helpers._DEFAULT_RETRY)
+
+ @mock.patch('framework.cloud_tasks_helpers._get_client')
+ def testHandleRequest_NoUsers(self, get_client_mock):
+ mr = testing_helpers.MakeMonorailRequest()
+ self.services.user.users_by_id = {}  # Remove all users.
+ self.task.HandleRequest(mr)
+
+ calls = get_client_mock().create_task.call_args_list
+ self.assertEqual(len(calls), 0)
+
+
+class SendWipeoutUserListsTaskTest(unittest.TestCase):
+
+ def setUp(self):
+ self.services = service_manager.Services(user=fake.UserService())
+ self.task = deleteusers.SendWipeoutUserListsTask(
+ request=None, response=None, services=self.services)
+ self.task.sendUserLists = mock.Mock()
+ mock.patch.object(deleteusers, 'authorize', return_value='service').start()
+ self.addCleanup(mock.patch.stopall)  # Restore deleteusers.authorize.
+ self.user_1, self.user_2, self.user_3 = [self.services.user.TestAddUser(
+ 'user%d@example.com' % n, 111 * n) for n in (1, 2, 3)]
+
+ def testHandleRequest(self):
+ mr = testing_helpers.MakeMonorailRequest(path='url/url?limit=2&offset=1')
+ self.task.HandleRequest(mr)
+ deleteusers.authorize.assert_called_once_with()
+ self.task.sendUserLists.assert_called_once_with(
+ 'service', [
+ {'id': self.user_2.email},
+ {'id': self.user_3.email}])
+
+ def testHandleRequest_NoLimit(self):
+ mr = testing_helpers.MakeMonorailRequest()
+ self.services.user.users_by_id = {}
+ with self.assertRaisesRegexp(AssertionError, 'Missing param limit'):
+ self.task.HandleRequest(mr)
+
+ def testHandleRequest_NoOffset(self):
+ mr = testing_helpers.MakeMonorailRequest(path='url/url?limit=3')
+ self.services.user.users_by_id = {}
+ with self.assertRaisesRegexp(AssertionError, 'Missing param offset'):
+ self.task.HandleRequest(mr)
+
+ def testHandleRequest_ZeroOffset(self):
+ mr = testing_helpers.MakeMonorailRequest(path='url/url?limit=2&offset=0')
+ self.task.HandleRequest(mr)
+ self.task.sendUserLists.assert_called_once_with(
+ 'service', [
+ {'id': self.user_1.email},
+ {'id': self.user_2.email}])
+
+
+class DeleteWipeoutUsersTaskTest(unittest.TestCase):
+
+ def setUp(self):
+ self.services = service_manager.Services()
+ mock.patch.object(deleteusers, 'authorize', return_value='service').start()
+ self.addCleanup(mock.patch.stopall)  # Restore deleteusers.authorize.
+ self.task = deleteusers.DeleteWipeoutUsersTask(
+ request=None, response=None, services=self.services)
+ self.task.fetchDeletedUsers = mock.Mock(return_value=[
+ {'id': 'user1@gmail.com'}, {'id': 'user2@gmail.com'},
+ {'id': 'user3@gmail.com'}, {'id': 'user4@gmail.com'}])
+
+ def generate_simple_task(self, url, body):  # Builds an expected task dict.
+ return {
+ 'app_engine_http_request':
+ {
+ 'relative_uri': url,
+ 'body': body,
+ 'headers': {
+ 'Content-type': 'application/x-www-form-urlencoded'
+ }
+ }
+ }
+
+ @mock.patch('framework.cloud_tasks_helpers._get_client')
+ def testHandleRequest(self, get_client_mock):
+ mr = testing_helpers.MakeMonorailRequest(path='url/url?limit=3')
+ self.task.HandleRequest(mr)
+
+ deleteusers.authorize.assert_called_once_with()
+ self.task.fetchDeletedUsers.assert_called_once_with('service')
+ ((_app_id, _region, queue),
+ _kwargs) = get_client_mock().queue_path.call_args
+ self.assertEqual(queue, framework_constants.QUEUE_DELETE_USERS)
+ # limit=3 splits the 4 deleted users into batches of 3 and 1.
+ self.assertEqual(get_client_mock().create_task.call_count, 2)
+
+ query = urllib.urlencode(
+ {'emails': 'user1@gmail.com,user2@gmail.com,user3@gmail.com'})
+ expected_task = self.generate_simple_task(
+ urls.DELETE_USERS_TASK + '.do', query)
+
+ get_client_mock().create_task.assert_any_call(
+ get_client_mock().queue_path(),
+ expected_task,
+ retry=cloud_tasks_helpers._DEFAULT_RETRY)
+
+ query = urllib.urlencode({'emails': 'user4@gmail.com'})
+ expected_task = self.generate_simple_task(
+ urls.DELETE_USERS_TASK + '.do', query)
+
+ get_client_mock().create_task.assert_any_call(
+ get_client_mock().queue_path(),
+ expected_task,
+ retry=cloud_tasks_helpers._DEFAULT_RETRY)
+
+ @mock.patch('framework.cloud_tasks_helpers._get_client')
+ def testHandleRequest_DefaultMax(self, get_client_mock):
+ mr = testing_helpers.MakeMonorailRequest(path='url/url')
+ self.task.HandleRequest(mr)
+
+ deleteusers.authorize.assert_called_once_with()
+ self.task.fetchDeletedUsers.assert_called_once_with('service')
+ self.assertEqual(get_client_mock().create_task.call_count, 1)
+
+ emails = 'user1@gmail.com,user2@gmail.com,user3@gmail.com,user4@gmail.com'
+ query = urllib.urlencode({'emails': emails})
+ expected_task = self.generate_simple_task(
+ urls.DELETE_USERS_TASK + '.do', query)
+
+ get_client_mock().create_task.assert_any_call(
+ get_client_mock().queue_path(),
+ expected_task,
+ retry=cloud_tasks_helpers._DEFAULT_RETRY)
diff --git a/framework/test/emailfmt_test.py b/framework/test/emailfmt_test.py
new file mode 100644
index 0000000..dd7cca3
--- /dev/null
+++ b/framework/test/emailfmt_test.py
@@ -0,0 +1,821 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for monorail.framework.emailfmt."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import mock
+import unittest
+
+from google.appengine.ext import testbed
+
+import settings
+from framework import emailfmt
+from framework import framework_views
+from proto import project_pb2
+from testing import testing_helpers
+
+from google.appengine.api import apiproxy_stub_map
+
+
+class EmailFmtTest(unittest.TestCase):
+
+ def setUp(self):
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_datastore_v3_stub()
+ self.testbed.init_memcache_stub()
+
+ def tearDown(self):
+ self.testbed.deactivate()
+
+ def testValidateReferencesHeader(self):
+ project = project_pb2.Project()
+ project.project_name = 'open-open'
+ subject = 'slipped disk'
+ expected = emailfmt.MakeMessageID(
+ 'jrobbins@gmail.com', subject,
+ '%s@%s' % (project.project_name, emailfmt.MailDomain()))
+ self.assertTrue(
+ emailfmt.ValidateReferencesHeader(
+ expected, project, 'jrobbins@gmail.com', subject))
+ # A different subject does not validate.
+ self.assertFalse(
+ emailfmt.ValidateReferencesHeader(
+ expected, project, 'jrobbins@gmail.com', 'something else'))
+ # A different sender does not validate.
+ self.assertFalse(
+ emailfmt.ValidateReferencesHeader(
+ expected, project, 'someoneelse@gmail.com', subject))
+ # A different project does not validate.
+ project.project_name = 'other-project'
+ self.assertFalse(
+ emailfmt.ValidateReferencesHeader(
+ expected, project, 'jrobbins@gmail.com', subject))
+
+ def testParseEmailMessage(self):
+ msg = testing_helpers.MakeMessage(testing_helpers.HEADER_LINES, 'awesome!')
+
+ (from_addr, to_addrs, cc_addrs, references, incident_id,
+ subject, body) = emailfmt.ParseEmailMessage(msg)
+
+ self.assertEqual('user@example.com', from_addr)
+ self.assertEqual(['proj@monorail.example.com'], to_addrs)
+ self.assertEqual(['ningerso@chromium.org'], cc_addrs)
+ # Expected msg-id was generated from a previous known-good test run.
+ self.assertEqual(['<0=969704940193871313=13442892928193434663='
+ 'proj@monorail.example.com>'],
+ references)
+ self.assertEqual('', incident_id)
+ self.assertEqual('Issue 123 in proj: broken link', subject)
+ self.assertEqual('awesome!', body)
+ # References from an explicit header are included as well.
+ references_header = ('References', '<1234@foo.com> <5678@bar.com>')
+ msg = testing_helpers.MakeMessage(
+ testing_helpers.HEADER_LINES + [references_header], 'awesome!')
+ (from_addr, to_addrs, cc_addrs, references, incident_id, subject,
+ body) = emailfmt.ParseEmailMessage(msg)
+ self.assertItemsEqual(
+ ['<5678@bar.com>',
+ '<0=969704940193871313=13442892928193434663='
+ 'proj@monorail.example.com>',
+ '<1234@foo.com>'],
+ references)
+
+ def testParseEmailMessage_Bulk(self):
+ for precedence in ['Bulk', 'Junk']:  # These yield all-empty results.
+ msg = testing_helpers.MakeMessage(
+ testing_helpers.HEADER_LINES + [('Precedence', precedence)],
+ 'I am on vacation!')
+
+ (from_addr, to_addrs, cc_addrs, references, incident_id, subject,
+ body) = emailfmt.ParseEmailMessage(msg)
+
+ self.assertEqual('', from_addr)
+ self.assertEqual([], to_addrs)
+ self.assertEqual([], cc_addrs)
+ self.assertEqual('', references)
+ self.assertEqual('', incident_id)
+ self.assertEqual('', subject)
+ self.assertEqual('', body)
+
+ def testExtractAddrs(self):
+ header_val = ''
+ self.assertEqual(
+ [], emailfmt._ExtractAddrs(header_val))
+
+ header_val = 'J. Robbins <a@b.com>, c@d.com,\n Nick "Name" Dude <e@f.com>'
+ self.assertEqual(
+ ['a@b.com', 'c@d.com', 'e@f.com'],
+ emailfmt._ExtractAddrs(header_val))
+
+ header_val = ('hot: J. O\'Robbins <a@b.com>; '
+ 'cool: "friendly" <e.g-h@i-j.k-L.com>')
+ self.assertEqual(
+ ['a@b.com', 'e.g-h@i-j.k-L.com'],
+ emailfmt._ExtractAddrs(header_val))
+
+ def CheckIdentifiedValues(
+ self, project_addr, subject, expected_project_name, expected_local_id,
+ expected_verb=None, expected_label=None):
+ """Check project name, local id, verb, and label against expectations."""
+ project_name, verb, label = emailfmt.IdentifyProjectVerbAndLabel(
+ project_addr)
+ local_id = emailfmt.IdentifyIssue(project_name, subject)
+ self.assertEqual(expected_project_name, project_name)
+ self.assertEqual(expected_local_id, local_id)
+ self.assertEqual(expected_verb, verb)
+ self.assertEqual(expected_label, label)
+
+ def testIdentifyProjectAndIssues_Normal(self):
+ """Parse normal issue notification subject lines."""
+ self.CheckIdentifiedValues(
+ 'proj@monorail.example.com',
+ 'Issue 123 in proj: the dogs wont eat the dogfood',
+ 'proj', 123)
+
+ self.CheckIdentifiedValues(
+ 'Proj@MonoRail.Example.Com',
+ 'Issue 123 in proj: the dogs wont eat the dogfood',
+ 'proj', 123)
+
+ self.CheckIdentifiedValues(
+ 'proj-4-u@test-example3.com',
+ 'Issue 123 in proj-4-u: this one goes to: 11',
+ 'proj-4-u', 123)
+
+ self.CheckIdentifiedValues(
+ 'night@monorail.example.com',
+ 'Issue 451 in day: something is fishy',
+ 'night', None)
+
+ def testIdentifyProjectAndIssues_Compact(self):
+ """Parse compact subject lines."""
+ self.CheckIdentifiedValues(
+ 'proj@monorail.example.com',
+ 'proj:123: the dogs wont eat the dogfood',
+ 'proj', 123)
+
+ self.CheckIdentifiedValues(
+ 'Proj@MonoRail.Example.Com',
+ 'proj:123: the dogs wont eat the dogfood',
+ 'proj', 123)
+
+ self.CheckIdentifiedValues(
+ 'proj-4-u@test-example3.com',
+ 'proj-4-u:123: this one goes to: 11',
+ 'proj-4-u', 123)
+
+ self.CheckIdentifiedValues(
+ 'night@monorail.example.com',
+ 'day:451: something is fishy',
+ 'night', None)
+
+ def testIdentifyProjectAndIssues_NotAMatch(self):
+ """These subject lines do not match the ones we send."""
+ self.CheckIdentifiedValues(
+ 'no_reply@chromium.org',
+ 'Issue 234 in project foo: ignore this one',
+ None, None)
+
+ self.CheckIdentifiedValues(
+ 'no_reply@chromium.org',
+ 'foo-234: ignore this one',
+ None, None)
+
+ def testStripSubjectPrefixes(self):
+ self.assertEqual(
+ '',
+ emailfmt._StripSubjectPrefixes(''))
+
+ self.assertEqual(
+ 'this is it',
+ emailfmt._StripSubjectPrefixes('this is it'))
+
+ self.assertEqual(
+ 'this is it',
+ emailfmt._StripSubjectPrefixes('re: this is it'))
+
+ self.assertEqual(
+ 'this is it',
+ emailfmt._StripSubjectPrefixes('Re: Fwd: aw:this is it'))
+
+ self.assertEqual(
+ 'This - . IS it',
+ emailfmt._StripSubjectPrefixes('This - . IS it'))
+
+
+class MailDomainTest(unittest.TestCase):
+
+ def testTrivialCases(self):  # NOTE(review): this class sets up no testbed.
+ self.assertEqual(
+ 'testbed-test.appspotmail.com',
+ emailfmt.MailDomain())
+
+
+class NoReplyAddressTest(unittest.TestCase):
+
+ def testNoCommenter(self):
+ self.assertEqual(
+ 'no_reply@testbed-test.appspotmail.com',
+ emailfmt.NoReplyAddress())
+
+ def testWithCommenter(self):
+ commenter_view = framework_views.StuffUserView(
+ 111, 'user@example.com', True)
+ self.assertEqual(
+ 'user via monorail '
+ '<no_reply+v2.111@testbed-test.appspotmail.com>',
+ emailfmt.NoReplyAddress(
+ commenter_view=commenter_view, reveal_addr=True))
+
+ def testObscuredCommenter(self):  # reveal_addr=False abbreviates the name.
+ commenter_view = framework_views.StuffUserView(
+ 111, 'user@example.com', True)
+ self.assertEqual(
+ u'u\u2026 via monorail '
+ '<no_reply+v2.111@testbed-test.appspotmail.com>',
+ emailfmt.NoReplyAddress(
+ commenter_view=commenter_view, reveal_addr=False))
+
+
+class FormatFromAddrTest(unittest.TestCase):
+
+ def setUp(self):
+ self.project = project_pb2.Project(project_name='monorail')
+ self.old_send_email_as_format = settings.send_email_as_format
+ settings.send_email_as_format = 'monorail@%(domain)s'
+ self.old_send_noreply_email_as_format = (
+ settings.send_noreply_email_as_format)
+ settings.send_noreply_email_as_format = 'monorail+noreply@%(domain)s'
+
+ def tearDown(self):  # Restore the settings overridden in setUp.
+ settings.send_email_as_format = self.old_send_email_as_format
+ settings.send_noreply_email_as_format = (
+ self.old_send_noreply_email_as_format)
+
+ def testNoCommenter(self):
+ self.assertEqual('monorail@chromium.org',
+ emailfmt.FormatFromAddr(self.project))
+
+ @mock.patch('settings.branded_domains',
+ {'monorail': 'bugs.branded.com', '*': 'bugs.chromium.org'})
+ def testNoCommenter_Branded(self):
+ self.assertEqual('monorail@branded.com',
+ emailfmt.FormatFromAddr(self.project))
+
+ def testNoCommenterWithNoReply(self):
+ self.assertEqual('monorail+noreply@chromium.org',
+ emailfmt.FormatFromAddr(self.project, can_reply_to=False))
+
+ @mock.patch('settings.branded_domains',
+ {'monorail': 'bugs.branded.com', '*': 'bugs.chromium.org'})
+ def testNoCommenterWithNoReply_Branded(self):
+ self.assertEqual('monorail+noreply@branded.com',
+ emailfmt.FormatFromAddr(self.project, can_reply_to=False))
+
+ def testWithCommenter(self):
+ commenter_view = framework_views.StuffUserView(
+ 111, 'user@example.com', True)
+ self.assertEqual(
+ u'user via monorail <monorail+v2.111@chromium.org>',
+ emailfmt.FormatFromAddr(
+ self.project, commenter_view=commenter_view, reveal_addr=True))
+
+ @mock.patch('settings.branded_domains',
+ {'monorail': 'bugs.branded.com', '*': 'bugs.chromium.org'})
+ def testWithCommenter_Branded(self):
+ commenter_view = framework_views.StuffUserView(
+ 111, 'user@example.com', True)
+ self.assertEqual(
+ u'user via monorail <monorail+v2.111@branded.com>',
+ emailfmt.FormatFromAddr(
+ self.project, commenter_view=commenter_view, reveal_addr=True))
+
+ def testObscuredCommenter(self):
+ commenter_view = framework_views.StuffUserView(
+ 111, 'user@example.com', True)
+ self.assertEqual(
+ u'u\u2026 via monorail <monorail+v2.111@chromium.org>',
+ emailfmt.FormatFromAddr(
+ self.project, commenter_view=commenter_view, reveal_addr=False))
+
+ def testServiceAccountCommenter(self):
+ johndoe_bot = '123456789@developer.gserviceaccount.com'
+ commenter_view = framework_views.StuffUserView(
+ 111, johndoe_bot, True)
+ self.assertEqual(
+ ('johndoe via monorail <monorail+v2.111@chromium.org>'),
+ emailfmt.FormatFromAddr(
+ self.project, commenter_view=commenter_view, reveal_addr=False))
+
+
+class NormalizeHeaderWhitespaceTest(unittest.TestCase):
+
+ def testTrivialCases(self):
+ self.assertEqual(
+ '',
+ emailfmt.NormalizeHeader(''))
+
+ self.assertEqual(
+ '',
+ emailfmt.NormalizeHeader(' \t\n'))
+
+ self.assertEqual(
+ 'a',
+ emailfmt.NormalizeHeader('a'))
+
+ self.assertEqual(
+ 'a b',
+ emailfmt.NormalizeHeader(' a b '))
+
+ def testLongSummary(self):  # Truncated to MAX_HEADER_CHARS_CONSIDERED.
+ big_string = 'x' * 500
+ self.assertEqual(
+ big_string[:emailfmt.MAX_HEADER_CHARS_CONSIDERED],
+ emailfmt.NormalizeHeader(big_string))
+
+ big_string = 'x y ' * 500
+ self.assertEqual(
+ big_string[:emailfmt.MAX_HEADER_CHARS_CONSIDERED],
+ emailfmt.NormalizeHeader(big_string))
+
+ big_string = 'x ' * 100  # The trailing space is stripped.
+ self.assertEqual(
+ 'x ' * 99 + 'x',
+ emailfmt.NormalizeHeader(big_string))
+
+ def testNormalCase(self):  # Tabs and newlines collapse to single spaces.
+ self.assertEqual(
+ '[a] b: c d',
+ emailfmt.NormalizeHeader('[a] b:\tc\n\td'))
+
+
+class MakeMessageIDTest(unittest.TestCase):
+
+ def setUp(self):
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_datastore_v3_stub()
+ self.testbed.init_memcache_stub()
+
+ def tearDown(self):
+ self.testbed.deactivate()
+
+ def testMakeMessageIDTest(self):
+ message_id = emailfmt.MakeMessageID(
+ 'to@to.com', 'subject', 'from@from.com')
+ self.assertTrue(message_id.startswith('<0='))
+ self.assertEqual('testbed-test.appspotmail.com>',
+ message_id.split('@')[-1])
+
+ patcher = mock.patch.object(settings, 'mail_domain', None, create=True)
+ patcher.start()
+ self.addCleanup(patcher.stop)  # Don't leak mail_domain to other tests.
+ message_id = emailfmt.MakeMessageID(
+ 'to@to.com', 'subject', 'from@from.com')
+ self.assertTrue(message_id.startswith('<0='))
+ self.assertEqual('testbed-test.appspotmail.com>',
+ message_id.split('@')[-1])
+ # The same inputs always produce the same message-id.
+ self.assertEqual(
+ message_id,
+ emailfmt.MakeMessageID('to@to.com', 'subject', 'from@from.com'))
+
+ message_id_ws_1 = emailfmt.MakeMessageID(
+ 'to@to.com',
+ 'this is a very long subject that is sure to be wordwrapped by gmail',
+ 'from@from.com')
+ message_id_ws_2 = emailfmt.MakeMessageID(
+ 'to@to.com',
+ 'this is a very long subject that \n\tis sure to be '
+ 'wordwrapped \t\tby gmail',
+ 'from@from.com')
+ self.assertEqual(message_id_ws_1, message_id_ws_2)
+
+
+class GetReferencesTest(unittest.TestCase):
+
+ def setUp(self):
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_datastore_v3_stub()
+ self.testbed.init_memcache_stub()
+
+ def tearDown(self):
+ self.testbed.deactivate()
+
+ def testNotPartOfThread(self):  # Passing None yields no references.
+ refs = emailfmt.GetReferences(
+ 'a@a.com', 'hi', None, emailfmt.NoReplyAddress())
+ self.assertEqual(0, len(refs))
+
+ def testAnywhereInThread(self):  # Passing 0 yields refs starting '<0='.
+ refs = emailfmt.GetReferences(
+ 'a@a.com', 'hi', 0, emailfmt.NoReplyAddress())
+ self.assertTrue(len(refs))
+ self.assertTrue(refs.startswith('<0='))
+
+
+class StripQuotedTextTest(unittest.TestCase):
+
+ def CheckExpected(self, expected_output, test_input):
+ actual_output = emailfmt.StripQuotedText(test_input)
+ self.assertEqual(expected_output, actual_output)
+
+ def testAllNewText(self):
+ self.CheckExpected('', '')
+ self.CheckExpected('', '\n')
+ self.CheckExpected('', '\n\n')
+ self.CheckExpected('new', 'new')
+ self.CheckExpected('new', '\nnew\n')
+ self.CheckExpected('new\ntext', '\nnew\ntext\n')
+ self.CheckExpected('new\n\ntext', '\nnew\n\ntext\n')
+
+ def testQuotedLines(self):
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '\n'
+ '> something you said\n'
+ '> that took two lines'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '\n'
+ '> something you said\n'
+ '> that took two lines'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('> something you said\n'
+ '> that took two lines\n'
+ 'new\n'
+ 'text\n'
+ '\n'))
+
+ self.CheckExpected(
+ ('newtext'),
+ ('> something you said\n'
+ '> that took two lines\n'
+ 'newtext'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ 'On Mon, Jan 1, 2023, So-and-so <so@and-so.com> Wrote:\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ 'On Mon, Jan 1, 2023, So-and-so <so@and-so.com> Wrote:\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ 'On Mon, Jan 1, 2023, user@example.com via Monorail\n'
+ '<monorail@chromium.com> Wrote:\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ 'On Jan 14, 2016 6:19 AM, "user@example.com via Monorail" <\n'
+ 'monorail@chromium.com> Wrote:\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ 'On Jan 14, 2016 6:19 AM, "user@example.com via Monorail" <\n'
+ 'monorail@monorail-prod.appspotmail.com> wrote:\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ 'On Mon, Jan 1, 2023, So-and-so so@and-so.com wrote:\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ 'On Wed, Sep 8, 2010 at 6:56 PM, So =AND= <so@gmail.com>wrote:\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ 'On Mon, Jan 1, 2023, So-and-so <so@and-so.com> Wrote:\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ 'project-name@testbed-test.appspotmail.com wrote:\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ 'project-name@testbed-test.appspotmail.com a \xc3\xa9crit :\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ 'project.domain.com@testbed-test.appspotmail.com a \xc3\xa9crit :\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ '2023/01/4 <so@and-so.com>\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ '2023/01/4 <so-and@so.com>\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ def testBoundaryLines(self):
+
+ self.CheckExpected(
+ ('new'),
+ ('new\n'
+ '---- forwarded message ======\n'
+ '\n'
+ 'something you said\n'
+ '> in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new'),
+ ('new\n'
+ '-----Original Message-----\n'
+ '\n'
+ 'something you said\n'
+ '> in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new'),
+ ('new\n'
+ '\n'
+ 'Updates:\n'
+ '\tStatus: Fixed\n'
+ '\n'
+ 'notification text\n'))
+
+ self.CheckExpected(
+ ('new'),
+ ('new\n'
+ '\n'
+ 'Comment #1 on issue 9 by username: Is there ...'
+ 'notification text\n'))
+
+ def testSignatures(self):
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '-- \n'
+ 'Name\n'
+ 'phone\n'
+ 'funny quote, or legal disclaimers\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '--\n'
+ 'Name\n'
+ 'phone\n'
+ 'funny quote, or legal disclaimers\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '--\n'
+ 'Name\n'
+ 'ginormous signature\n'
+ 'phone\n'
+ 'address\n'
+ 'address\n'
+ 'address\n'
+ 'homepage\n'
+ 'social network A\n'
+ 'social network B\n'
+ 'social network C\n'
+ 'funny quote\n'
+ '4 lines about why email should be short\n'
+ 'legal disclaimers\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '_______________\n'
+ 'Name\n'
+ 'phone\n'
+ 'funny quote, or legal disclaimers\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '\n'
+ 'Thanks,\n'
+ 'Name\n'
+ '\n'
+ '_______________\n'
+ 'Name\n'
+ 'phone\n'
+ 'funny quote, or legal disclaimers\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '\n'
+ 'Thanks,\n'
+ 'Name'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '\n'
+ 'Cheers,\n'
+ 'Name'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '\n'
+ 'Regards\n'
+ 'Name'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '\n'
+ 'best regards'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '\n'
+ 'THX'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '\n'
+ 'Thank you,\n'
+ 'Name'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '\n'
+ 'Sent from my iPhone'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '\n'
+ 'Sent from my iPod'))
diff --git a/framework/test/exceptions_test.py b/framework/test/exceptions_test.py
new file mode 100644
index 0000000..8fe2295
--- /dev/null
+++ b/framework/test/exceptions_test.py
@@ -0,0 +1,64 @@
+# Copyright 2020 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+"""Unittest for the exceptions module."""
+
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+from framework import exceptions
+from framework import permissions
+
+
+class ErrorsManagerTest(unittest.TestCase):
+
+ def testRaiseIfErrors_Errors(self):
+ """We raise the given exception if there are errors."""
+ err_aggregator = exceptions.ErrorAggregator(exceptions.InputException)
+
+ err_aggregator.AddErrorMessage('The chickens are missing.')
+ err_aggregator.AddErrorMessage('The foxes are free.')
+ with self.assertRaisesRegexp(
+ exceptions.InputException,
+ 'The chickens are missing.\nThe foxes are free.'):
+ err_aggregator.RaiseIfErrors()
+
+ def testErrorsManager_NoErrors(self):
+ """ We don't raise exceptions if there are not errors. """
+ err_aggregator = exceptions.ErrorAggregator(exceptions.InputException)
+ err_aggregator.RaiseIfErrors()
+
+ def testWithinContext_ExceptionPassedIn(self):
+ """We do not suppress exceptions raised within wrapped code."""
+
+ with self.assertRaisesRegexp(exceptions.InputException,
+ 'We should raise this'):
+ with exceptions.ErrorAggregator(exceptions.InputException) as errors:
+ errors.AddErrorMessage('We should ignore this error.')
+ raise exceptions.InputException('We should raise this')
+
+ def testWithinContext_NoExceptionPassedIn(self):
+ """We raise an exception for any errors if no exceptions are passed in."""
+ with self.assertRaisesRegexp(exceptions.InputException,
+ 'We can raise this now.'):
+ with exceptions.ErrorAggregator(exceptions.InputException) as errors:
+ errors.AddErrorMessage('We can raise this now.')
+ return True
+
+ def testAddErrorMessage(self):
+ """We properly handle string formatting when needed."""
+ err_aggregator = exceptions.ErrorAggregator(exceptions.InputException)
+ err_aggregator.AddErrorMessage('No args')
+ err_aggregator.AddErrorMessage('No args2', 'unused', unused2=1)
+ err_aggregator.AddErrorMessage('{}', 'One arg')
+ err_aggregator.AddErrorMessage('{}, {two}', '1', two='2')
+
+ # Verify exceptions formatting a message don't clear the earlier messages.
+ with self.assertRaises(IndexError):
+ err_aggregator.AddErrorMessage('{}')
+
+ expected = ['No args', 'No args2', 'One arg', '1, 2']
+ self.assertEqual(err_aggregator.error_messages, expected)
diff --git a/framework/test/filecontent_test.py b/framework/test/filecontent_test.py
new file mode 100644
index 0000000..4843b47
--- /dev/null
+++ b/framework/test/filecontent_test.py
@@ -0,0 +1,188 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for the filecontent module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+from framework import filecontent
+
+
+class MimeTest(unittest.TestCase):
+ """Test methods for the mime module."""
+
+ _TEST_EXTENSIONS_TO_CTYPES = {
+ 'html': 'text/plain',
+ 'htm': 'text/plain',
+ 'jpg': 'image/jpeg',
+ 'jpeg': 'image/jpeg',
+ 'pdf': 'application/pdf',
+ }
+
+ _CODE_EXTENSIONS = [
+ 'py', 'java', 'mf', 'bat', 'sh', 'php', 'vb', 'pl', 'sql',
+ 'patch', 'diff',
+ ]
+
+ def testCommonExtensions(self):
+ """Tests some common extensions for their expected content types."""
+ for ext, ctype in self._TEST_EXTENSIONS_TO_CTYPES.items():
+ self.assertEqual(
+ filecontent.GuessContentTypeFromFilename('file.%s' % ext),
+ ctype)
+
+ def testCaseDoesNotMatter(self):
+ """Ensure that case (upper/lower) of extension does not matter."""
+ for ext, ctype in self._TEST_EXTENSIONS_TO_CTYPES.items():
+ ext = ext.upper()
+ self.assertEqual(
+ filecontent.GuessContentTypeFromFilename('file.%s' % ext),
+ ctype)
+
+ for ext in self._CODE_EXTENSIONS:
+ ext = ext.upper()
+ self.assertEqual(
+ filecontent.GuessContentTypeFromFilename('code.%s' % ext),
+ 'text/plain')
+
+ def testCodeIsText(self):
+ """Ensure that code extensions are text/plain."""
+ for ext in self._CODE_EXTENSIONS:
+ self.assertEqual(
+ filecontent.GuessContentTypeFromFilename('code.%s' % ext),
+ 'text/plain')
+
+ def testNoExtensionIsText(self):
+ """Ensure that no extension indicates text/plain."""
+ self.assertEqual(
+ filecontent.GuessContentTypeFromFilename('noextension'),
+ 'text/plain')
+
+ def testUnknownExtension(self):
+ """Ensure that an obviously unknown extension returns is binary."""
+ self.assertEqual(
+ filecontent.GuessContentTypeFromFilename('f.madeupextension'),
+ 'application/octet-stream')
+
+ def testNoShockwaveFlash(self):
+ """Ensure that Shockwave files will NOT be served w/ that content type."""
+ self.assertEqual(
+ filecontent.GuessContentTypeFromFilename('bad.swf'),
+ 'application/octet-stream')
+
+
+class DecodeFileContentsTest(unittest.TestCase):
+
+ def IsBinary(self, contents):
+ _contents, is_binary, _is_long = (
+ filecontent.DecodeFileContents(contents))
+ return is_binary
+
+ def testFileIsBinaryEmpty(self):
+ self.assertFalse(self.IsBinary(''))
+
+ def testFileIsBinaryShortText(self):
+ self.assertFalse(self.IsBinary('This is some plain text.'))
+
+ def testLineLengthDetection(self):
+ unicode_str = (
+ u'Some non-ascii chars - '
+ u'\xa2\xfa\xb6\xe7\xfc\xea\xd0\xf4\xe6\xf0\xce\xf6\xbe')
+ short_line = unicode_str.encode('iso-8859-1')
+ long_line = (unicode_str * 100)[:filecontent._MAX_SOURCE_LINE_LEN_LOWER+1]
+ long_line = long_line.encode('iso-8859-1')
+
+ lines = [short_line] * 100
+ lines.append(long_line)
+
+ # High lower ratio - text
+ self.assertFalse(self.IsBinary('\n'.join(lines)))
+
+ lines.extend([long_line] * 99)
+
+ # 50/50 lower/upper ratio - binary
+ self.assertTrue(self.IsBinary('\n'.join(lines)))
+
+ # Single line too long - binary
+ lines = [short_line] * 100
+ lines.append(short_line * 100) # Very long line
+ self.assertTrue(self.IsBinary('\n'.join(lines)))
+
+ def testFileIsBinaryLongText(self):
+ self.assertFalse(self.IsBinary('This is plain text. \n' * 100))
+ # long utf-8 lines are OK
+ self.assertFalse(self.IsBinary('This one long line. ' * 100))
+
+ def testFileIsBinaryLongBinary(self):
+ bin_string = ''.join([chr(c) for c in range(122, 252)])
+ self.assertTrue(self.IsBinary(bin_string * 100))
+
+ def testFileIsTextByPath(self):
+ bin_string = ''.join([chr(c) for c in range(122, 252)] * 100)
+ unicode_str = (
+ u'Some non-ascii chars - '
+ u'\xa2\xfa\xb6\xe7\xfc\xea\xd0\xf4\xe6\xf0\xce\xf6\xbe')
+ long_line = (unicode_str * 100)[:filecontent._MAX_SOURCE_LINE_LEN_LOWER+1]
+ long_line = long_line.encode('iso-8859-1')
+
+ for contents in [bin_string, long_line]:
+ self.assertTrue(filecontent.DecodeFileContents(contents, path=None)[1])
+ self.assertTrue(filecontent.DecodeFileContents(contents, path='')[1])
+ self.assertTrue(filecontent.DecodeFileContents(contents, path='foo')[1])
+ self.assertTrue(
+ filecontent.DecodeFileContents(contents, path='foo.bin')[1])
+ self.assertTrue(
+ filecontent.DecodeFileContents(contents, path='foo.zzz')[1])
+ for path in ['a/b/Makefile.in', 'README', 'a/file.js', 'b.txt']:
+ self.assertFalse(
+ filecontent.DecodeFileContents(contents, path=path)[1])
+
+ def testFileIsBinaryByCommonExtensions(self):
+ contents = 'this is not examined'
+ self.assertTrue(filecontent.DecodeFileContents(
+ contents, path='junk.zip')[1])
+ self.assertTrue(filecontent.DecodeFileContents(
+ contents, path='JUNK.ZIP')[1])
+ self.assertTrue(filecontent.DecodeFileContents(
+ contents, path='/build/HelloWorld.o')[1])
+ self.assertTrue(filecontent.DecodeFileContents(
+ contents, path='/build/Hello.class')[1])
+ self.assertTrue(filecontent.DecodeFileContents(
+ contents, path='/trunk/libs.old/swing.jar')[1])
+
+ self.assertFalse(filecontent.DecodeFileContents(
+ contents, path='HelloWorld.cc')[1])
+ self.assertFalse(filecontent.DecodeFileContents(
+ contents, path='Hello.java')[1])
+ self.assertFalse(filecontent.DecodeFileContents(
+ contents, path='README')[1])
+ self.assertFalse(filecontent.DecodeFileContents(
+ contents, path='READ.ME')[1])
+ self.assertFalse(filecontent.DecodeFileContents(
+ contents, path='README.txt')[1])
+ self.assertFalse(filecontent.DecodeFileContents(
+ contents, path='README.TXT')[1])
+ self.assertFalse(filecontent.DecodeFileContents(
+ contents, path='/trunk/src/com/monorail/Hello.java')[1])
+ self.assertFalse(filecontent.DecodeFileContents(
+ contents, path='/branches/1.2/resource.el')[1])
+ self.assertFalse(filecontent.DecodeFileContents(
+ contents, path='/wiki/PageName.wiki')[1])
+
+ def testUnreasonablyLongFile(self):
+ contents = '\n' * (filecontent.SOURCE_FILE_MAX_LINES + 2)
+ _contents, is_binary, is_long = filecontent.DecodeFileContents(
+ contents)
+ self.assertFalse(is_binary)
+ self.assertTrue(is_long)
+
+ contents = '\n' * 100
+ _contents, is_binary, is_long = filecontent.DecodeFileContents(
+ contents)
+ self.assertFalse(is_binary)
+ self.assertFalse(is_long)
diff --git a/framework/test/framework_bizobj_test.py b/framework/test/framework_bizobj_test.py
new file mode 100644
index 0000000..131ebb5
--- /dev/null
+++ b/framework/test/framework_bizobj_test.py
@@ -0,0 +1,696 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for monorail.framework.framework_bizobj."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+import mock
+
+import settings
+from framework import authdata
+from framework import framework_bizobj
+from framework import framework_constants
+from proto import project_pb2
+from proto import tracker_pb2
+from proto import user_pb2
+from services import service_manager
+from services import client_config_svc
+from testing import fake
+from testing import testing_helpers
+from tracker import tracker_bizobj
+
+
+class CreateUserDisplayNamesAndEmailsTest(unittest.TestCase):
+
+ def setUp(self):
+ self.cnxn = fake.MonorailConnection()
+ self.services = service_manager.Services(
+ project=fake.ProjectService(),
+ user=fake.UserService(),
+ usergroup=fake.UserGroupService())
+
+ self.user_1 = self.services.user.TestAddUser(
+ 'user_1@test.com', 111, obscure_email=True)
+ self.user_2 = self.services.user.TestAddUser(
+ 'user_2@test.com', 222, obscure_email=False)
+ self.user_3 = self.services.user.TestAddUser(
+ 'user_3@test.com', 333, obscure_email=True)
+ self.user_4 = self.services.user.TestAddUser(
+ 'user_4@test.com', 444, obscure_email=False)
+ self.service_account = self.services.user.TestAddUser(
+ 'service@account.com', 999, obscure_email=True)
+ self.user_deleted = self.services.user.TestAddUser(
+ '', framework_constants.DELETED_USER_ID)
+ self.requester = self.services.user.TestAddUser('user_5@test.com', 555)
+ self.user_auth = authdata.AuthData(
+ user_id=self.requester.user_id, email=self.requester.email)
+ self.project = self.services.project.TestAddProject(
+ 'proj',
+ project_id=789,
+ owner_ids=[self.user_1.user_id],
+ committer_ids=[self.user_2.user_id, self.service_account.user_id])
+
+ @mock.patch('services.client_config_svc.GetServiceAccountMap')
+ def testUserCreateDisplayNamesAndEmails_NonProjectMembers(
+ self, fake_account_map):
+ fake_account_map.return_value = {'service@account.com': 'Service'}
+ users = [self.user_1, self.user_2, self.user_3, self.user_4,
+ self.service_account, self.user_deleted]
+ (display_names_by_id,
+ display_emails_by_id) = framework_bizobj.CreateUserDisplayNamesAndEmails(
+ self.cnxn, self.services, self.user_auth, users)
+ expected_display_names = {
+ self.user_1.user_id: testing_helpers.ObscuredEmail(self.user_1.email),
+ self.user_2.user_id: self.user_2.email,
+ self.user_3.user_id: testing_helpers.ObscuredEmail(self.user_3.email),
+ self.user_4.user_id: self.user_4.email,
+ self.service_account.user_id: 'Service',
+ self.user_deleted.user_id: framework_constants.DELETED_USER_NAME}
+ expected_display_emails = {
+ self.user_1.user_id:
+ testing_helpers.ObscuredEmail(self.user_1.email),
+ self.user_2.user_id:
+ self.user_2.email,
+ self.user_3.user_id:
+ testing_helpers.ObscuredEmail(self.user_3.email),
+ self.user_4.user_id:
+ self.user_4.email,
+ self.service_account.user_id:
+ testing_helpers.ObscuredEmail(self.service_account.email),
+ self.user_deleted.user_id: '',
+ }
+ self.assertEqual(display_names_by_id, expected_display_names)
+ self.assertEqual(display_emails_by_id, expected_display_emails)
+
+ @mock.patch('services.client_config_svc.GetServiceAccountMap')
+ def testUserCreateDisplayNamesAndEmails_ProjectMember(self, fake_account_map):
+ fake_account_map.return_value = {'service@account.com': 'Service'}
+ users = [self.user_1, self.user_2, self.user_3, self.user_4,
+ self.service_account, self.user_deleted]
+ self.project.committer_ids.append(self.requester.user_id)
+ (display_names_by_id,
+ display_emails_by_id) = framework_bizobj.CreateUserDisplayNamesAndEmails(
+ self.cnxn, self.services, self.user_auth, users)
+ expected_display_names = {
+ self.user_1.user_id: self.user_1.email, # Project member
+ self.user_2.user_id: self.user_2.email, # Project member and unobscured
+ self.user_3.user_id: testing_helpers.ObscuredEmail(self.user_3.email),
+ self.user_4.user_id: self.user_4.email, # Unobscured email
+ self.service_account.user_id: 'Service',
+ self.user_deleted.user_id: framework_constants.DELETED_USER_NAME
+ }
+ expected_display_emails = {
+ self.user_1.user_id: self.user_1.email, # Project member
+ self.user_2.user_id: self.user_2.email, # Project member and unobscured
+ self.user_3.user_id: testing_helpers.ObscuredEmail(self.user_3.email),
+ self.user_4.user_id: self.user_4.email, # Unobscured email
+ self.service_account.user_id: self.service_account.email,
+ self.user_deleted.user_id: ''
+ }
+ self.assertEqual(display_names_by_id, expected_display_names)
+ self.assertEqual(display_emails_by_id, expected_display_emails)
+
+ @mock.patch('services.client_config_svc.GetServiceAccountMap')
+ def testUserCreateDisplayNamesAndEmails_Admin(self, fake_account_map):
+ fake_account_map.return_value = {'service@account.com': 'Service'}
+ users = [self.user_1, self.user_2, self.user_3, self.user_4,
+ self.service_account, self.user_deleted]
+ self.user_auth.user_pb.is_site_admin = True
+ (display_names_by_id,
+ display_emails_by_id) = framework_bizobj.CreateUserDisplayNamesAndEmails(
+ self.cnxn, self.services, self.user_auth, users)
+ expected_display_names = {
+ self.user_1.user_id: self.user_1.email,
+ self.user_2.user_id: self.user_2.email,
+ self.user_3.user_id: self.user_3.email,
+ self.user_4.user_id: self.user_4.email,
+ self.service_account.user_id: 'Service',
+ self.user_deleted.user_id: framework_constants.DELETED_USER_NAME}
+ expected_display_emails = {
+ self.user_1.user_id: self.user_1.email,
+ self.user_2.user_id: self.user_2.email,
+ self.user_3.user_id: self.user_3.email,
+ self.user_4.user_id: self.user_4.email,
+ self.service_account.user_id: self.service_account.email,
+ self.user_deleted.user_id: ''
+ }
+
+ self.assertEqual(display_names_by_id, expected_display_names)
+ self.assertEqual(display_emails_by_id, expected_display_emails)
+
+
+class ParseAndObscureAddressTest(unittest.TestCase):
+
+ def testParseAndObscureAddress(self):
+ email = 'sir.chicken@farm.test'
+ (username, user_domain, obscured_username,
+ obscured_email) = framework_bizobj.ParseAndObscureAddress(email)
+
+ self.assertEqual(username, 'sir.chicken')
+ self.assertEqual(user_domain, 'farm.test')
+ self.assertEqual(obscured_username, 'sir.c')
+ self.assertEqual(obscured_email, 'sir.c...@farm.test')
+
+
+class FilterViewableEmailsTest(unittest.TestCase):
+
+ def setUp(self):
+ self.cnxn = fake.MonorailConnection()
+ self.services = service_manager.Services(
+ project=fake.ProjectService(),
+ user=fake.UserService(),
+ usergroup=fake.UserGroupService())
+ self.user_1 = self.services.user.TestAddUser(
+ 'user_1@test.com', 111, obscure_email=True)
+ self.user_2 = self.services.user.TestAddUser(
+ 'user_2@test.com', 222, obscure_email=False)
+ self.requester = self.services.user.TestAddUser(
+ 'user_5@test.com', 555, obscure_email=True)
+ self.user_auth = authdata.AuthData(
+ user_id=self.requester.user_id, email=self.requester.email)
+ self.user_auth.user_pb.email = self.user_auth.email
+ self.project = self.services.project.TestAddProject(
+ 'proj', project_id=789, owner_ids=[111], committer_ids=[222])
+
+ def testFilterViewableEmail_Anon(self):
+ anon = authdata.AuthData()
+ other_users = [self.user_1, self.user_2]
+ filtered_users = framework_bizobj.FilterViewableEmails(
+ self.cnxn, self.services, anon, other_users)
+ self.assertEqual(filtered_users, [])
+
+ def testFilterViewableEmail_Self(self):
+ filtered_users = framework_bizobj.FilterViewableEmails(
+ self.cnxn, self.services, self.user_auth, [self.user_auth.user_pb])
+ self.assertEqual(filtered_users, [self.user_auth.user_pb])
+
+ def testFilterViewableEmail_SiteAdmin(self):
+ self.user_auth.user_pb.is_site_admin = True
+ other_users = [self.user_1, self.user_2]
+ filtered_users = framework_bizobj.FilterViewableEmails(
+ self.cnxn, self.services, self.user_auth, other_users)
+ self.assertEqual(filtered_users, other_users)
+
+ def testFilterViewableEmail_InDisplayNameGroup(self):
+ display_name_group_id = 666
+ self.services.usergroup.TestAddGroupSettings(
+ display_name_group_id, 'display-perm-perm@email.com')
+ settings.full_emails_perm_groups = ['display-perm-perm@email.com']
+ self.user_auth.effective_ids.add(display_name_group_id)
+
+ other_users = [self.user_1, self.user_2]
+ filtered_users = framework_bizobj.FilterViewableEmails(
+ self.cnxn, self.services, self.user_auth, other_users)
+ self.assertEqual(filtered_users, other_users)
+
+ def testFilterViewableEmail_NonMember(self):
+ other_users = [self.user_1, self.user_2]
+ filtered_users = framework_bizobj.FilterViewableEmails(
+ self.cnxn, self.services, self.user_auth, other_users)
+ self.assertEqual(filtered_users, [])
+
+ def testFilterViewableEmail_ProjectMember(self):
+ self.project.committer_ids.append(self.requester.user_id)
+ other_users = [self.user_1, self.user_2]
+ filtered_users = framework_bizobj.FilterViewableEmails(
+ self.cnxn, self.services, self.user_auth, other_users)
+ self.assertEqual(filtered_users, other_users)
+
+
+# TODO(https://crbug.com/monorail/8192): Remove deprecated tests.
+class DeprecatedShouldRevealEmailTest(unittest.TestCase):
+
+ def setUp(self):
+ self.cnxn = fake.MonorailConnection()
+ self.services = service_manager.Services(
+ project=fake.ProjectService(),
+ user=fake.UserService(),
+ usergroup=fake.UserGroupService())
+ self.user_1 = self.services.user.TestAddUser(
+ 'user_1@test.com', 111, obscure_email=True)
+ self.user_2 = self.services.user.TestAddUser(
+ 'user_2@test.com', 222, obscure_email=False)
+ self.requester = self.services.user.TestAddUser(
+ 'user_5@test.com', 555, obscure_email=True)
+ self.user_auth = authdata.AuthData(
+ user_id=self.requester.user_id, email=self.requester.email)
+ self.user_auth.user_pb.email = self.user_auth.email
+ self.project = self.services.project.TestAddProject(
+ 'proj', project_id=789, owner_ids=[111], committer_ids=[222])
+
+ def testDeprecatedShouldRevealEmail_Anon(self):
+ anon = authdata.AuthData()
+ self.assertFalse(
+ framework_bizobj.DeprecatedShouldRevealEmail(
+ anon, self.project, self.user_1.email))
+ self.assertFalse(
+ framework_bizobj.DeprecatedShouldRevealEmail(
+ anon, self.project, self.user_2.email))
+
+ def testDeprecatedShouldRevealEmail_Self(self):
+ self.assertTrue(
+ framework_bizobj.DeprecatedShouldRevealEmail(
+ self.user_auth, self.project, self.user_auth.user_pb.email))
+
+ def testDeprecatedShouldRevealEmail_SiteAdmin(self):
+ self.user_auth.user_pb.is_site_admin = True
+ self.assertTrue(
+ framework_bizobj.DeprecatedShouldRevealEmail(
+ self.user_auth, self.project, self.user_1.email))
+ self.assertTrue(
+ framework_bizobj.DeprecatedShouldRevealEmail(
+ self.user_auth, self.project, self.user_2.email))
+
+ def testDeprecatedShouldRevealEmail_ProjectMember(self):
+ self.project.committer_ids.append(self.requester.user_id)
+ self.assertTrue(
+ framework_bizobj.DeprecatedShouldRevealEmail(
+ self.user_auth, self.project, self.user_1.email))
+ self.assertTrue(
+ framework_bizobj.DeprecatedShouldRevealEmail(
+ self.user_auth, self.project, self.user_2.email))
+
+ def testDeprecatedShouldRevealEmail_NonMember(self):
+ self.assertFalse(
+ framework_bizobj.DeprecatedShouldRevealEmail(
+ self.user_auth, self.project, self.user_1.email))
+ self.assertFalse(
+ framework_bizobj.DeprecatedShouldRevealEmail(
+ self.user_auth, self.project, self.user_2.email))
+
+
+class ArtifactTest(unittest.TestCase):
+
+ def setUp(self):
+ # No custom fields. Exclusive prefixes: Type, Priority, Milestone.
+ self.config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+
+ def testMergeLabels_Labels(self):
+ # Empty case.
+ (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+ [], [], [], self.config)
+ self.assertEqual(merged_labels, [])
+ self.assertEqual(update_add, [])
+ self.assertEqual(update_remove, [])
+
+ # No-op case.
+ (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+ ['a', 'b'], [], [], self.config)
+ self.assertEqual(merged_labels, ['a', 'b'])
+ self.assertEqual(update_add, [])
+ self.assertEqual(update_remove, [])
+
+ # Adding and removing at the same time.
+ (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+ ['a', 'b', 'd'], ['c'], ['d'], self.config)
+ self.assertEqual(merged_labels, ['a', 'b', 'c'])
+ self.assertEqual(update_add, ['c'])
+ self.assertEqual(update_remove, ['d'])
+
+ # Removing a non-matching label has no effect.
+ (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+ ['a', 'b', 'd'], ['d'], ['e'], self.config)
+ self.assertEqual(merged_labels, ['a', 'b', 'd'])
+ self.assertEqual(update_add, []) # d was already there.
+ self.assertEqual(update_remove, []) # there was no e.
+
+ # We can add and remove at the same time.
+ (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+ ['Priority-Medium', 'OpSys-OSX'], ['Hot'], ['OpSys-OSX'], self.config)
+ self.assertEqual(merged_labels, ['Priority-Medium', 'Hot'])
+ self.assertEqual(update_add, ['Hot'])
+ self.assertEqual(update_remove, ['OpSys-OSX'])
+
+ # Adding Priority-High replaces Priority-Medium.
+ (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+ ['Priority-Medium', 'OpSys-OSX'], ['Priority-High', 'OpSys-Win'], [],
+ self.config)
+ self.assertEqual(merged_labels, ['OpSys-OSX', 'Priority-High', 'OpSys-Win'])
+ self.assertEqual(update_add, ['Priority-High', 'OpSys-Win'])
+ self.assertEqual(update_remove, [])
+
+ # Adding Priority-High and Priority-Low replaces with High only.
+ (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+ ['Priority-Medium', 'OpSys-OSX'],
+ ['Priority-High', 'Priority-Low'], [], self.config)
+ self.assertEqual(merged_labels, ['OpSys-OSX', 'Priority-High'])
+ self.assertEqual(update_add, ['Priority-High'])
+ self.assertEqual(update_remove, [])
+
+ # Removing a mix of matching and non-matching labels only does matching.
+ (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+ ['Priority-Medium', 'OpSys-OSX'], [], ['Priority-Medium', 'OpSys-Win'],
+ self.config)
+ self.assertEqual(merged_labels, ['OpSys-OSX'])
+ self.assertEqual(update_add, [])
+ self.assertEqual(update_remove, ['Priority-Medium'])
+
+ # Multi-part labels work as expected.
+ (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+ ['Priority-Medium', 'OpSys-OSX-11'],
+ ['Priority-Medium-Rare', 'OpSys-OSX-13'], [], self.config)
+ self.assertEqual(
+ merged_labels, ['OpSys-OSX-11', 'Priority-Medium-Rare', 'OpSys-OSX-13'])
+ self.assertEqual(update_add, ['Priority-Medium-Rare', 'OpSys-OSX-13'])
+ self.assertEqual(update_remove, [])
+
+ # Multi-part exclusive prefixes only filter labels that match whole prefix.
+ self.config.exclusive_label_prefixes.append('Branch-Name')
+ (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+ ['Branch-Name-xyz'],
+ ['Branch-Prediction', 'Branch-Name-Beta'], [], self.config)
+ self.assertEqual(merged_labels, ['Branch-Prediction', 'Branch-Name-Beta'])
+ self.assertEqual(update_add, ['Branch-Prediction', 'Branch-Name-Beta'])
+ self.assertEqual(update_remove, [])
+
+  def testMergeLabels_SingleValuedEnums(self):
+    """Single-valued enum fields make their label prefix exclusive."""
+    # Give each FieldDef a distinct field_id so the two cannot collide.
+    self.config.field_defs.append(tracker_pb2.FieldDef(
+        field_id=1, field_name='Size', is_multivalued=False,
+        field_type=tracker_pb2.FieldTypes.ENUM_TYPE))
+    self.config.field_defs.append(tracker_pb2.FieldDef(
+        field_id=2, field_name='Branch-Name', is_multivalued=False,
+        field_type=tracker_pb2.FieldTypes.ENUM_TYPE))
+
+    # We can add a label for a single-valued enum.
+    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+        ['Priority-Medium', 'OpSys-OSX'], ['Size-L'], [], self.config)
+    self.assertEqual(merged_labels, ['Priority-Medium', 'OpSys-OSX', 'Size-L'])
+    self.assertEqual(update_add, ['Size-L'])
+    self.assertEqual(update_remove, [])
+
+    # Adding and removing the same label adds it.
+    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+        ['Priority-Medium'], ['Size-M'], ['Size-M'], self.config)
+    self.assertEqual(merged_labels, ['Priority-Medium', 'Size-M'])
+    self.assertEqual(update_add, ['Size-M'])
+    self.assertEqual(update_remove, [])
+
+    # Adding Size-L replaces Size-M.
+    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+        ['Priority-Medium', 'Size-M'], ['Size-L', 'OpSys-Win'], [],
+        self.config)
+    self.assertEqual(merged_labels, ['Priority-Medium', 'Size-L', 'OpSys-Win'])
+    self.assertEqual(update_add, ['Size-L', 'OpSys-Win'])
+    self.assertEqual(update_remove, [])
+
+    # Adding Size-L and Size-XL replaces with L only.
+    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+        ['Size-M', 'OpSys-OSX'], ['Size-L', 'Size-XL'], [], self.config)
+    self.assertEqual(merged_labels, ['OpSys-OSX', 'Size-L'])
+    self.assertEqual(update_add, ['Size-L'])
+    self.assertEqual(update_remove, [])
+
+    # Multi-part labels work as expected.
+    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+        ['Size-M', 'OpSys-OSX'], ['Size-M-USA'], [], self.config)
+    self.assertEqual(merged_labels, ['OpSys-OSX', 'Size-M-USA'])
+    self.assertEqual(update_add, ['Size-M-USA'])
+    self.assertEqual(update_remove, [])
+
+    # Multi-part enum names only filter labels that match whole name.
+    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+        ['Branch-Name-xyz'],
+        ['Branch-Prediction', 'Branch-Name-Beta'], [], self.config)
+    self.assertEqual(merged_labels, ['Branch-Prediction', 'Branch-Name-Beta'])
+    self.assertEqual(update_add, ['Branch-Prediction', 'Branch-Name-Beta'])
+    self.assertEqual(update_remove, [])
+
+  def testMergeLabels_MultiValuedEnums(self):
+    """Multi-valued enum fields allow several labels with the same prefix."""
+    # Give each FieldDef a distinct field_id so the two cannot collide.
+    self.config.field_defs.append(tracker_pb2.FieldDef(
+        field_id=1, field_name='OpSys', is_multivalued=True,
+        field_type=tracker_pb2.FieldTypes.ENUM_TYPE))
+    self.config.field_defs.append(tracker_pb2.FieldDef(
+        field_id=2, field_name='Branch-Name', is_multivalued=True,
+        field_type=tracker_pb2.FieldTypes.ENUM_TYPE))
+
+    # We can add a label for a multi-valued enum.
+    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+        ['Priority-Medium'], ['OpSys-Win'], [], self.config)
+    self.assertEqual(merged_labels, ['Priority-Medium', 'OpSys-Win'])
+    self.assertEqual(update_add, ['OpSys-Win'])
+    self.assertEqual(update_remove, [])
+
+    # We can remove a matching label for a multi-valued enum.
+    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+        ['Priority-Medium', 'OpSys-Win'], [], ['OpSys-Win'], self.config)
+    self.assertEqual(merged_labels, ['Priority-Medium'])
+    self.assertEqual(update_add, [])
+    self.assertEqual(update_remove, ['OpSys-Win'])
+
+    # We can remove a non-matching label and it is a no-op.
+    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+        ['Priority-Medium', 'OpSys-OSX'], [], ['OpSys-Win'], self.config)
+    self.assertEqual(merged_labels, ['Priority-Medium', 'OpSys-OSX'])
+    self.assertEqual(update_add, [])
+    self.assertEqual(update_remove, [])
+
+    # Adding and removing the same label adds it.
+    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+        ['Priority-Medium'], ['OpSys-Win'], ['OpSys-Win'], self.config)
+    self.assertEqual(merged_labels, ['Priority-Medium', 'OpSys-Win'])
+    self.assertEqual(update_add, ['OpSys-Win'])
+    self.assertEqual(update_remove, [])
+
+    # We can add a label for a multi-valued enum, even if matching exists.
+    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+        ['Priority-Medium', 'OpSys-OSX'], ['OpSys-Win'], [], self.config)
+    self.assertEqual(
+        merged_labels, ['Priority-Medium', 'OpSys-OSX', 'OpSys-Win'])
+    self.assertEqual(update_add, ['OpSys-Win'])
+    self.assertEqual(update_remove, [])
+
+    # Adding two at the same time is fine.
+    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+        ['Size-M', 'OpSys-OSX'], ['OpSys-Win', 'OpSys-Vax'], [], self.config)
+    self.assertEqual(
+        merged_labels, ['Size-M', 'OpSys-OSX', 'OpSys-Win', 'OpSys-Vax'])
+    self.assertEqual(update_add, ['OpSys-Win', 'OpSys-Vax'])
+    self.assertEqual(update_remove, [])
+
+    # Multi-part labels work as expected.
+    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+        ['Size-M', 'OpSys-OSX'], ['OpSys-Win-10'], [], self.config)
+    self.assertEqual(merged_labels, ['Size-M', 'OpSys-OSX', 'OpSys-Win-10'])
+    self.assertEqual(update_add, ['OpSys-Win-10'])
+    self.assertEqual(update_remove, [])
+
+    # Multi-part enum names don't mess up anything.
+    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+        ['Branch-Name-xyz'],
+        ['Branch-Prediction', 'Branch-Name-Beta'], [], self.config)
+    self.assertEqual(
+        merged_labels,
+        ['Branch-Name-xyz', 'Branch-Prediction', 'Branch-Name-Beta'])
+    self.assertEqual(update_add, ['Branch-Prediction', 'Branch-Name-Beta'])
+    self.assertEqual(update_remove, [])
+
+
+class CanonicalizeLabelTest(unittest.TestCase):
+  """Tests for framework_bizobj.CanonicalizeLabel."""
+
+  def testCanonicalizeLabel(self):
+    # (expected, raw) pairs: None passes through; spaces are dropped.
+    cases = [(None, None), ('FooBar', 'Foo Bar '),
+             ('Foo.Bar', 'Foo . Bar '), ('Foo-Bar', 'Foo - Bar ')]
+    for expected, raw in cases:
+      self.assertEqual(expected, framework_bizobj.CanonicalizeLabel(raw))
+
+
+class UserIsInProjectTest(unittest.TestCase):
+
+  def testUserIsInProject(self):
+    project = project_pb2.Project()
+    # A project with no members contains nobody.
+    for ids in ({10}, set()):
+      self.assertFalse(framework_bizobj.UserIsInProject(project, ids))
+
+    project.owner_ids.extend([1, 2, 3])
+    project.committer_ids.extend([4, 5, 6])
+    project.contributor_ids.extend([7, 8, 9])
+    # One member from each role is recognized; an outsider is not.
+    for member_id in (1, 4, 7):
+      self.assertTrue(framework_bizobj.UserIsInProject(project, {member_id}))
+    self.assertFalse(framework_bizobj.UserIsInProject(project, {10}))
+
+    # Membership via group membership.
+    self.assertTrue(framework_bizobj.UserIsInProject(project, {10, 4}))
+    # Membership via several group memberships.
+    self.assertTrue(framework_bizobj.UserIsInProject(project, {1, 4}))
+
+    # Several irrelevant group memberships.
+    self.assertFalse(framework_bizobj.UserIsInProject(project, {10, 11, 12}))
+
+
+class IsValidColumnSpecTest(unittest.TestCase):
+  """Tests for framework_bizobj.IsValidColumnSpec."""
+
+  def testIsValidColumnSpec(self):
+    # Names may contain inner '-' and '.'; an empty spec is also valid.
+    valid_specs = ['some columns hey-honk hay.honk', 'some', '']
+    for spec in valid_specs:
+      self.assertTrue(framework_bizobj.IsValidColumnSpec(spec))
+
+  def testIsValidColumnSpec_NotValid(self):
+    # Names that start or end with '-' or '.' are rejected.
+    invalid_specs = [
+        'some columns hey-honk hay.',
+        'some columns hey-',
+        '-some columns hey',
+        'some .columns hey',
+    ]
+    for spec in invalid_specs:
+      self.assertFalse(framework_bizobj.IsValidColumnSpec(spec))
+
+
+class ValidatePrefTest(unittest.TestCase):
+  """Tests for framework_bizobj.ValidatePref."""
+
+  def testUnknown(self):
+    # The error message names the unknown pref.
+    message = framework_bizobj.ValidatePref('shoe_size', 'true')
+    for fragment in ('shoe_size', 'Unknown'):
+      self.assertIn(fragment, message)
+    self.assertIn('Unknown', framework_bizobj.ValidatePref('', 'true'))
+
+  def testTooLong(self):
+    # A 100-character value is rejected as too long.
+    message = framework_bizobj.ValidatePref('code_font', 'x' * 100)
+    for fragment in ('code_font', 'too long'):
+      self.assertIn(fragment, message)
+
+  def testKnownValid(self):
+    for value in ('true', 'false'):
+      self.assertIsNone(framework_bizobj.ValidatePref('code_font', value))
+
+  def testKnownInvalid(self):
+    # Neither an empty value nor an arbitrary word is accepted.
+    for value in ('', 'sometimes'):
+      self.assertIn(
+          'Invalid', framework_bizobj.ValidatePref('code_font', value))
+
+
+class IsRestrictNewIssuesUserTest(unittest.TestCase):
+
+  def setUp(self):
+    self.cnxn = fake.MonorailConnection()
+    self.services = service_manager.Services(
+        usergroup=fake.UserGroupService(), user=fake.UserService())
+    self.services.user.TestAddUser('corp_user@example.com', 111)
+    self.services.user.TestAddUser('corp_group@example.com', 888)
+    # 888 is the group whose email the tests patch into settings.
+    self.services.usergroup.TestAddGroupSettings(888, 'corp_group@example.com')
+
+  @mock.patch(
+      'settings.restrict_new_issues_user_groups', ['corp_group@example.com'])
+  def testNonRestrictNewIssuesUser(self):
+    """User 111 belongs to no group named in the patched setting."""
+    self.assertFalse(framework_bizobj.IsRestrictNewIssuesUser(
+        self.cnxn, self.services, 111))
+
+  @mock.patch(
+      'settings.restrict_new_issues_user_groups', ['corp_group@example.com'])
+  def testRestrictNewIssuesUser(self):
+    """User 111 is a member of the group named in the patched setting."""
+    self.services.usergroup.TestAddMembers(888, [111, 222])
+    self.assertTrue(framework_bizobj.IsRestrictNewIssuesUser(
+        self.cnxn, self.services, 111))
+
+
+class IsPublicIssueNoticeUserTest(unittest.TestCase):
+
+  def setUp(self):
+    self.cnxn = fake.MonorailConnection()
+    self.services = service_manager.Services(
+        usergroup=fake.UserGroupService(), user=fake.UserService())
+    self.services.user.TestAddUser('corp_user@example.com', 111)
+    self.services.user.TestAddUser('corp_group@example.com', 888)
+    self.services.usergroup.TestAddGroupSettings(888, 'corp_group@example.com')
+
+  @mock.patch(
+      'settings.public_issue_notice_user_groups', ['corp_group@example.com'])
+  def testNonPublicIssueNoticeUser(self):
+    """User 111 belongs to no group named in the patched setting."""
+    self.assertFalse(framework_bizobj.IsPublicIssueNoticeUser(
+        self.cnxn, self.services, 111))
+
+  @mock.patch(
+      'settings.public_issue_notice_user_groups', ['corp_group@example.com'])
+  def testPublicIssueNoticeUser(self):
+    """User 111 is a member of the group named in the patched setting."""
+    self.services.usergroup.TestAddMembers(888, [111, 222])  # 111 joins 888.
+    self.assertTrue(framework_bizobj.IsPublicIssueNoticeUser(
+        self.cnxn, self.services, 111))
+
+
+class GetEffectiveIdsTest(unittest.TestCase):
+
+  def setUp(self):
+    self.cnxn = fake.MonorailConnection()
+    self.services = service_manager.Services(
+        user=fake.UserService(), usergroup=fake.UserGroupService())
+    self.services.user.TestAddUser('test@example.com', 111)
+
+  def testNoMemberships(self):
+    """No user groups means effective_ids == {user_id}."""
+    effective_ids = framework_bizobj.GetEffectiveIds(
+        self.cnxn, self.services, [111])
+    self.assertEqual(effective_ids, {111: {111}})
+
+  def testNormalMemberships(self):
+    """effective_ids should be {user_id, group_id...}."""
+    self.services.usergroup.TestAddMembers(888, [111])
+    self.services.usergroup.TestAddMembers(999, [111])
+    effective_ids = framework_bizobj.GetEffectiveIds(
+        self.cnxn, self.services, [111])
+    self.assertEqual(effective_ids, {111: {111, 888, 999}})
+
+  def testComputedUserGroup(self):
+    """A computed group (everyone@) applies without explicit membership."""
+    self.services.usergroup.TestAddGroupSettings(888, 'everyone@example.com')
+    effective_ids = framework_bizobj.GetEffectiveIds(
+        self.cnxn, self.services, [111])
+    self.assertEqual(effective_ids, {111: {111, 888}})
+
+  def testAccountHasParent(self):
+    """The parent's effective_ids are added to child's."""
+    child = self.services.user.TestAddUser('child@example.com', 111)
+    child.linked_parent_id = 222
+    parent = self.services.user.TestAddUser('parent@example.com', 222)
+    parent.linked_child_ids = [111]
+    effective_ids = framework_bizobj.GetEffectiveIds(
+        self.cnxn, self.services, [111])
+    self.assertEqual(effective_ids, {111: {111, 222}})
+
+    self.services.usergroup.TestAddMembers(888, [111])
+    self.services.usergroup.TestAddMembers(999, [222])
+    effective_ids = framework_bizobj.GetEffectiveIds(
+        self.cnxn, self.services, [111])
+    self.assertEqual(effective_ids, {111: {111, 222, 888, 999}})
+
+  def testAccountHasChildren(self):
+    """All linked child effective_ids are added to parent's."""
+    child1 = self.services.user.TestAddUser('child1@example.com', 111)
+    child1.linked_parent_id = 333
+    child2 = self.services.user.TestAddUser('child2@example.com', 222)
+    child2.linked_parent_id = 333
+    parent = self.services.user.TestAddUser('parent@example.com', 333)
+    parent.linked_child_ids = [111, 222]
+
+    effective_ids = framework_bizobj.GetEffectiveIds(
+        self.cnxn, self.services, [333])
+    self.assertEqual(effective_ids, {333: {111, 222, 333}})
+
+    self.services.usergroup.TestAddMembers(888, [111])
+    self.services.usergroup.TestAddMembers(999, [222])
+    effective_ids = framework_bizobj.GetEffectiveIds(
+        self.cnxn, self.services, [333])
+    self.assertEqual(effective_ids, {333: {111, 222, 333, 888, 999}})
diff --git a/framework/test/framework_helpers_test.py b/framework/test/framework_helpers_test.py
new file mode 100644
index 0000000..1d0146c
--- /dev/null
+++ b/framework/test/framework_helpers_test.py
@@ -0,0 +1,563 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for the framework_helpers module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import mock
+import unittest
+
+import mox
+import time
+
+from businesslogic import work_env
+from framework import framework_helpers
+from framework import framework_views
+from proto import features_pb2
+from proto import project_pb2
+from proto import user_pb2
+from services import service_manager
+from testing import fake
+from testing import testing_helpers
+
+
+class HelperFunctionsTest(unittest.TestCase):
+
+  def setUp(self):
+    self.mox = mox.Mox()
+    self.time = self.mox.CreateMock(framework_helpers.time)
+    framework_helpers.time = self.time  # retry() now sleeps on the mock.
+
+  def tearDown(self):
+    framework_helpers.time = time  # Restore the real time module.
+    self.mox.UnsetStubs()
+    self.mox.ResetAll()
+
+  def testRetryDecorator_ExceedFailures(self):
+    calls = []  # Records one entry per attempt.
+
+    # Use a function that always fails.
+    @framework_helpers.retry(2, delay=1, backoff=2)
+    def alwaysFail(call_log):
+      call_log.append(1)
+      raise Exception('Failed')
+
+    # Each retry sleeps first; the delay grows by the backoff factor.
+    self.time.sleep(1).AndReturn(None)
+    self.time.sleep(2).AndReturn(None)
+    self.mox.ReplayAll()
+    with self.assertRaises(Exception):
+      alwaysFail(calls)
+    self.mox.VerifyAll()
+    # One initial attempt plus two retries.
+    self.assertEqual(3, len(calls))
+
+  def testRetryDecorator_EventuallySucceed(self):
+    calls = []  # Records one entry per attempt.
+
+    # Use a function that succeeds on the 2nd attempt.
+    @framework_helpers.retry(2, delay=1, backoff=2)
+    def failOnce(call_log):
+      call_log.append(1)
+      if len(call_log) < 2:
+        raise Exception('Failed')
+
+    # Only the single retry sleeps.
+    self.time.sleep(1).AndReturn(None)
+    self.mox.ReplayAll()
+    failOnce(calls)
+    self.mox.VerifyAll()
+    # The failed first attempt plus one successful retry.
+    self.assertEqual(2, len(calls))
+
+  def testGetRoleName(self):
+    proj = project_pb2.Project()
+    proj.owner_ids.append(111)
+    proj.committer_ids.append(222)
+    proj.contributor_ids.append(333)
+
+    # Non-members have no role.
+    self.assertEqual(None, framework_helpers.GetRoleName(set(), proj))
+
+    # When several IDs match, Owner beats Committer beats Contributor.
+    expected_roles = [
+        ('Owner', {111}), ('Committer', {222}), ('Contributor', {333}),
+        ('Owner', {111, 222, 999}), ('Committer', {222, 333, 999}),
+        ('Contributor', {333, 999}),
+    ]
+    for role, effective_ids in expected_roles:
+      self.assertEqual(
+          role, framework_helpers.GetRoleName(effective_ids, proj))
+
+  def testGetHotlistRoleName(self):
+    hotlist = features_pb2.Hotlist()
+    hotlist.owner_ids.append(111)
+    hotlist.editor_ids.append(222)
+    hotlist.follower_ids.append(333)
+
+    # Users with no hotlist role get None.
+    self.assertEqual(None, framework_helpers.GetHotlistRoleName(set(), hotlist))
+
+    # (expected_role, effective_ids) pairs.
+    expected_roles = [
+        ('Owner', {111}),
+        ('Editor', {222}),
+        ('Follower', {333}),
+        # Several matching IDs: Owner beats Editor beats Follower.
+        ('Owner', {111, 222, 999}),
+        ('Editor', {222, 333, 999}),
+        ('Follower', {333, 999}),
+    ]
+    for role, effective_ids in expected_roles:
+      self.assertEqual(
+          role, framework_helpers.GetHotlistRoleName(effective_ids, hotlist))
+
+
+class UrlFormattingTest(unittest.TestCase):
+  """Tests for URL formatting."""
+
+  def _RecognizedParams(self, mr):
+    """Return (name, value) pairs for each of RECOGNIZED_PARAMS.
+
+    Args:
+      mr: the MonorailRequest whose GetParam() supplies the values.
+    """
+    return [(name, mr.GetParam(name))
+            for name in framework_helpers.RECOGNIZED_PARAMS]
+
+  def testFormatMovedProjectURL(self):
+    """Project foo has been moved to bar. User is visiting /p/foo/..."""
+    mr = testing_helpers.MakeMonorailRequest()
+    mr.current_page_url = '/p/foo/'
+    self.assertEqual(
+        '/p/bar/',
+        framework_helpers.FormatMovedProjectURL(mr, 'bar'))
+
+    mr.current_page_url = '/p/foo/issues/list'
+    self.assertEqual(
+        '/p/bar/issues/list',
+        framework_helpers.FormatMovedProjectURL(mr, 'bar'))
+
+    mr.current_page_url = '/p/foo/issues/detail?id=123'
+    self.assertEqual(
+        '/p/bar/issues/detail?id=123',
+        framework_helpers.FormatMovedProjectURL(mr, 'bar'))
+
+    mr.current_page_url = '/p/foo/issues/detail?id=123#c7'
+    self.assertEqual(
+        '/p/bar/issues/detail?id=123#c7',
+        framework_helpers.FormatMovedProjectURL(mr, 'bar'))
+
+  def testFormatURL(self):
+    """With no recognized params in the request, the bare path comes back."""
+    mr = testing_helpers.MakeMonorailRequest()
+    path = '/dude/wheres/my/car'
+    url = framework_helpers.FormatURL(self._RecognizedParams(mr), path)
+    self.assertEqual(path, url)
+
+  def testFormatURLWithRecognizedParams(self):
+    """Recognized params present in the request are copied into the URL."""
+    # Give every recognized param the value 123 in the request.
+    query = ['%s=%s' % (name, 123)
+             for name in framework_helpers.RECOGNIZED_PARAMS]
+    path = '/dude/wheres/my/car'
+    expected = '%s?%s' % (path, '&'.join(query))
+    mr = testing_helpers.MakeMonorailRequest(path=expected)
+    # No added params.
+    url = framework_helpers.FormatURL(self._RecognizedParams(mr), path)
+    self.assertEqual(expected, url)
+
+  def testFormatURLWithKeywordArgs(self):
+    """Keyword args are appended after the recognized request params."""
+    # Leave out 'can' and 'start'; they are passed as keyword args below.
+    query_pairs = ['%s=%s' % (name, 123)
+                   for name in framework_helpers.RECOGNIZED_PARAMS
+                   if name not in ('can', 'start')]
+    path = '/dude/wheres/my/car'
+    mr = testing_helpers.MakeMonorailRequest(
+        path='%s?%s' % (path, '&'.join(query_pairs)))
+    query_pairs.append('can=yep')
+    query_pairs.append('start=486')
+    query_string = '&'.join(query_pairs)
+    expected = '%s?%s' % (path, query_string)
+    url = framework_helpers.FormatURL(
+        self._RecognizedParams(mr), path, can='yep', start=486)
+    self.assertEqual(expected, url)
+
+  def testFormatURLWithKeywordArgsAndID(self):
+    """An id keyword arg is emitted before all other params."""
+    query_pairs = ['id=200']  # id should be the first parameter.
+    # Leave out 'can' and 'start'; they are passed as keyword args below.
+    query_pairs.extend(
+        '%s=%s' % (name, 123) for name in framework_helpers.RECOGNIZED_PARAMS
+        if name not in ('can', 'start'))
+    path = '/dude/wheres/my/car'
+    mr = testing_helpers.MakeMonorailRequest(
+        path='%s?%s' % (path, '&'.join(query_pairs)))
+    query_pairs.append('can=yep')
+    query_pairs.append('start=486')
+    query_string = '&'.join(query_pairs)
+    expected = '%s?%s' % (path, query_string)
+    url = framework_helpers.FormatURL(
+        self._RecognizedParams(mr), path, can='yep', start=486, id=200)
+    self.assertEqual(expected, url)
+
+  def testFormatURLWithStrangeParams(self):
+    """Arbitrary keyword args are added too; spaces in values are escaped."""
+    mr = testing_helpers.MakeMonorailRequest(path='/foo?start=0')
+    # Here the extra args come out in name order: path, r, sketchy.
+    url = framework_helpers.FormatURL(
+        self._RecognizedParams(mr), '/foo',
+        r=0, path='/foo/bar', sketchy='/foo/ bar baz ')
+    self.assertEqual(
+        '/foo?start=0&path=/foo/bar&r=0&sketchy=/foo/%20bar%20baz%20',
+        url)
+
+  def testFormatAbsoluteURL(self):
+    """The Host header and project prefix form the absolute URL."""
+    _request, mr = testing_helpers.GetRequestObjects(
+        path='/p/proj/some-path',
+        headers={'Host': 'www.test.com'})
+    self.assertEqual(
+        'http://www.test.com/p/proj/some/path',
+        framework_helpers.FormatAbsoluteURL(mr, '/some/path'))
+
+  def testFormatAbsoluteURL_CommonRequestParams(self):
+    """Recognized params such as can= are copied unless copy_params=False."""
+    _request, mr = testing_helpers.GetRequestObjects(
+        path='/p/proj/some-path?foo=bar&can=1',
+        headers={'Host': 'www.test.com'})
+    self.assertEqual(
+        'http://www.test.com/p/proj/some/path?can=1',
+        framework_helpers.FormatAbsoluteURL(mr, '/some/path'))
+    self.assertEqual(
+        'http://www.test.com/p/proj/some/path',
+        framework_helpers.FormatAbsoluteURL(
+            mr, '/some/path', copy_params=False))
+
+  def testFormatAbsoluteURL_NoProject(self):
+    """With include_project=False the URL is just the host plus the path."""
+    path = '/some/path'
+    _request, mr = testing_helpers.GetRequestObjects(
+        headers={'Host': 'www.test.com'}, path=path)
+    url = framework_helpers.FormatAbsoluteURL(mr, path, include_project=False)
+    self.assertEqual(url, 'http://www.test.com/some/path')
+
+  def testGetHostPort_Local(self):
+    """We use testing-app.appspot.com when running locally."""
+    self.assertEqual('testing-app.appspot.com',
+                     framework_helpers.GetHostPort())
+    self.assertEqual('testing-app.appspot.com',
+                     framework_helpers.GetHostPort(project_name='proj'))
+
+  @mock.patch('settings.preferred_domains',
+              {'testing-app.appspot.com': 'example.com'})
+  def testGetHostPort_PreferredDomain(self):
+    """A prod server can have a preferred domain."""
+    self.assertEqual('example.com',
+                     framework_helpers.GetHostPort())
+    self.assertEqual('example.com',
+                     framework_helpers.GetHostPort(project_name='proj'))
+
+  @mock.patch('settings.branded_domains',
+              {'proj': 'branded.com', '*': 'unbranded.com'})
+  @mock.patch('settings.preferred_domains',
+              {'testing-app.appspot.com': 'example.com'})
+  def testGetHostPort_BrandedDomain(self):
+    """Projects can be served from a branded domain, with a '*' fallback."""
+    self.assertEqual('example.com',
+                     framework_helpers.GetHostPort())
+    self.assertEqual('branded.com',
+                     framework_helpers.GetHostPort(project_name='proj'))
+    self.assertEqual('unbranded.com',
+                     framework_helpers.GetHostPort(project_name='other-proj'))
+
+  def testIssueCommentURL(self):
+    """IssueCommentURL links to an issue and, optionally, one comment."""
+    hostport = 'port.someplex.com'
+    proj = project_pb2.Project()
+    proj.project_name = 'proj'
+
+    url = 'https://port.someplex.com/p/proj/issues/detail?id=2'
+    actual_url = framework_helpers.IssueCommentURL(
+        hostport, proj, 2)
+    self.assertEqual(actual_url, url)
+
+    url = 'https://port.someplex.com/p/proj/issues/detail?id=2#c2'
+    actual_url = framework_helpers.IssueCommentURL(
+        hostport, proj, 2, seq_num=2)
+    self.assertEqual(actual_url, url)
+
+
+class WordWrapSuperLongLinesTest(unittest.TestCase):
+  """Tests for framework_helpers.WordWrapSuperLongLines."""
+
+  def testEmptyLogMessage(self):
+    text = ''
+    self.assertEqual(framework_helpers.WordWrapSuperLongLines(text), '')
+
+  def testShortLines(self):
+    # Lines that are already short pass through unchanged.
+    text = 'one\ntwo\nthree\n'
+    expected = 'one\ntwo\nthree\n'
+    self.assertEqual(framework_helpers.WordWrapSuperLongLines(text), expected)
+
+  def testOneLongLine(self):
+    # A single over-long line is broken at word boundaries.
+    line = ('This is a super long line that just goes on and on '
+            'and it seems like it will never stop because it is '
+            'super long and it was entered by a user who had no '
+            'familiarity with the return key.')
+    expected = ('This is a super long line that just goes on and on and it '
+                'seems like it will never stop because it\n'
+                'is super long and it was entered by a user who had no '
+                'familiarity with the return key.')
+    self.assertEqual(framework_helpers.WordWrapSuperLongLines(line), expected)
+
+    doubled = ('This is a super long line that just goes on and on '
+               'and it seems like it will never stop because it is '
+               'super long and it was entered by a user who had no '
+               'familiarity with the return key. '
+               'This is a super long line that just goes on and on '
+               'and it seems like it will never stop because it is '
+               'super long and it was entered by a user who had no '
+               'familiarity with the return key.')
+    expected2 = ('This is a super long line that just goes on and on and it '
+                 'seems like it will never stop because it\n'
+                 'is super long and it was entered by a user who had no '
+                 'familiarity with the return key. This is a\n'
+                 'super long line that just goes on and on and it seems like '
+                 'it will never stop because it is super\n'
+                 'long and it was entered by a user who had no familiarity '
+                 'with the return key.')
+    self.assertEqual(
+        framework_helpers.WordWrapSuperLongLines(doubled), expected2)
+
+  def testMixOfShortAndLong(self):
+    text = ('[Author: mpcomplete]\n'
+            '\n'
+            # The description paragraph is one long line.
+            'Fix a memory leak in JsArray and JsObject for the IE and NPAPI '
+            'ports. Each time you call GetElement* or GetProperty* to '
+            'retrieve string or object token, the token would be leaked. '
+            'I added a JsScopedToken to ensure that the right thing is '
+            'done when the object leaves scope, depending on the platform.\n'
+            '\n'
+            'R=zork\n'
+            'CC=google-gears-eng@googlegroups.com\n'
+            'DELTA=108 (52 added, 36 deleted, 20 changed)\n'
+            'OCL=5932446\n'
+            'SCL=5933728\n')
+    expected = (
+        '[Author: mpcomplete]\n'
+        '\n'
+        'Fix a memory leak in JsArray and JsObject for the IE and NPAPI '
+        'ports. Each time you call\n'
+        'GetElement* or GetProperty* to retrieve string or object token, the '
+        'token would be leaked. I added\n'
+        'a JsScopedToken to ensure that the right thing is done when the '
+        'object leaves scope, depending on\n'
+        'the platform.\n'
+        '\n'
+        'R=zork\n'
+        'CC=google-gears-eng@googlegroups.com\n'
+        'DELTA=108 (52 added, 36 deleted, 20 changed)\n'
+        'OCL=5932446\n'
+        'SCL=5933728\n')
+    self.assertEqual(
+        framework_helpers.WordWrapSuperLongLines(text), expected)
+
+
+class ComputeListDeltasTest(unittest.TestCase):
+
+  def DoOne(self, old=None, new=None, added=None, removed=None):
+    """Assert ComputeListDeltas(old, new) yields the given added/removed."""
+    deltas = framework_helpers.ComputeListDeltas(old, new)
+    actual_added, actual_removed = deltas
+    self.assertItemsEqual(added, actual_added)  # Order is not checked.
+    self.assertItemsEqual(removed, actual_removed)
+
+  def testEmptyLists(self):
+    self.DoOne(old=[], new=[], added=[], removed=[])
+    self.DoOne(old=[1, 2], new=[], added=[], removed=[1, 2])
+    self.DoOne(old=[], new=[1, 2], added=[1, 2], removed=[])
+
+  def testUnchanged(self):
+    self.DoOne(old=[1], new=[1], added=[], removed=[])
+    self.DoOne(old=[1, 2], new=[1, 2], added=[], removed=[])
+    self.DoOne(old=[1, 2], new=[2, 1], added=[], removed=[])
+
+  def testCompleteChange(self):
+    self.DoOne(old=[1, 2], new=[3, 4], added=[3, 4], removed=[1, 2])
+
+  def testGeneralChange(self):
+    self.DoOne(old=[1, 2], new=[2], added=[], removed=[1])
+    self.DoOne(old=[1], new=[1, 2], added=[2], removed=[])
+    self.DoOne(old=[1, 2], new=[2, 3], added=[3], removed=[1])
+
+
+class UserSettingsTest(unittest.TestCase):
+
+  def setUp(self):
+    self.mr = testing_helpers.MakeMonorailRequest()
+    self.cnxn = 'cnxn'
+    self.services = service_manager.Services(
+        usergroup=fake.UserGroupService(),
+        user=fake.UserService())
+
+  def testGatherUnifiedSettingsPageData(self):
+    auth = self.mr.auth
+    auth.user_view = framework_views.StuffUserView(100, 'user@invalid', True)
+    auth.user_view.profile_url = '/u/profile/url'
+    # Only public_issue_notice is set; restrict_new_issues is left unset.
+    userprefs = user_pb2.UserPrefs(
+        prefs=[user_pb2.UserPrefValue(name='public_issue_notice', value='true')])
+    page_data = framework_helpers.UserSettings.GatherUnifiedSettingsPageData(
+        auth.user_id, auth.user_view, auth.user_pb, userprefs)
+
+    # The page data must contain exactly these keys.
+    expected_keys = [
+        'settings_user', 'settings_user_pb', 'settings_user_is_banned',
+        'self', 'profile_url_fragment', 'preview_on_hover',
+        'settings_user_prefs',
+    ]
+    self.assertItemsEqual(expected_keys, list(page_data.keys()))
+
+    # The leading '/u/' is dropped from the profile URL.
+    self.assertEqual('profile/url', page_data['profile_url_fragment'])
+    settings_prefs = page_data['settings_user_prefs']
+    self.assertTrue(settings_prefs.public_issue_notice)
+    self.assertFalse(settings_prefs.restrict_new_issues)
+
+  def testGatherUnifiedSettingsPageData_NoUserPrefs(self):
+    """If UserPrefs were not loaded, consider them all false."""
+    auth = self.mr.auth
+    auth.user_view = framework_views.StuffUserView(100, 'user@invalid', True)
+
+    page_data = framework_helpers.UserSettings.GatherUnifiedSettingsPageData(
+        auth.user_id, auth.user_view, auth.user_pb, None)
+
+    settings_prefs = page_data['settings_user_prefs']
+    self.assertFalse(settings_prefs.public_issue_notice)
+    self.assertFalse(settings_prefs.restrict_new_issues)
+
+  def testProcessBanForm(self):
+    """We can ban and unban users."""
+    user = self.services.user.TestAddUser('one@example.com', 111)
+    ban_form = {'banned': 1, 'banned_reason': 'rude'}
+    framework_helpers.UserSettings.ProcessBanForm(
+        self.cnxn, self.services.user, ban_form, 111, user)
+    self.assertEqual('rude', user.banned)
+
+    unban_form = {}  # Omitting 'banned' lifts the ban.
+    framework_helpers.UserSettings.ProcessBanForm(
+        self.cnxn, self.services.user, unban_form, 111, user)
+    self.assertEqual('', user.banned)
+
+  def testProcessSettingsForm_OldStylePrefs(self):
+    """We can set prefs that are stored in the User PB."""
+    user = self.services.user.TestAddUser('one@example.com', 111)
+    settings_form = {'obscure_email': 1, 'notify': 1}
+    with work_env.WorkEnv(self.mr, self.services) as we:
+      framework_helpers.UserSettings.ProcessSettingsForm(
+          we, settings_form, user)
+
+    self.assertTrue(user.obscure_email)
+    self.assertTrue(user.notify_issue_change)  # Set via the 'notify' field.
+    self.assertFalse(user.notify_starred_ping)
+
+  def testProcessSettingsForm_NewStylePrefs(self):
+    """We can set prefs that are stored in the UserPrefs PB."""
+    user = self.services.user.TestAddUser('one@example.com', 111)
+    settings_form = {'restrict_new_issues': 1}
+    with work_env.WorkEnv(self.mr, self.services) as we:
+      framework_helpers.UserSettings.ProcessSettingsForm(
+          we, settings_form, user)
+      userprefs = we.GetUserPrefs(111)
+
+    # Prefs absent from the form are stored as 'false'.
+    actual = {upv.name: upv.value for upv in userprefs.prefs}
+    expected = {
+        'restrict_new_issues': 'true',
+        'public_issue_notice': 'false',
+    }
+    self.assertEqual(expected, actual)
+
+
+class MurmurHash3Test(unittest.TestCase):
+
+  def testMurmurHash(self):
+    test_data = [  # (text, expected_hash)
+        ('', 0),
+        ('agable@chromium.org', 4092810879),
+        (u'jrobbins@chromium.org', 904770043),
+        ('seanmccullough%google.com@gtempaccount.com', 1301269279),
+        ('rmistry+monorail@chromium.org', 4186878788),
+        ('jparent+foo@', 2923900874),
+        ('@example.com', 3043483168),
+    ]
+    texts, expected_hashes = zip(*test_data)
+    actual_hashes = [framework_helpers.MurmurHash3_x86_32(t) for t in texts]
+    self.assertListEqual(actual_hashes, list(expected_hashes))
+
+  def testMurmurHashWithSeed(self):
+    test_data = [  # (text, seed, expected_hash)
+        ('', 1113155926, 2270882445),
+        ('agable@chromium.org', 772936925, 3995066671),
+        (u'jrobbins@chromium.org', 1519359761, 1273489513),
+        ('seanmccullough%google.com@gtempaccount.com', 49913829, 1202521153),
+        ('rmistry+monorail@chromium.org', 314860298, 3636123309),
+        ('jparent+foo@', 195791379, 332453977),
+        ('@example.com', 521490555, 257496459),
+    ]
+    actual_hashes = [framework_helpers.MurmurHash3_x86_32(text, seed)
+                     for (text, seed, _) in test_data]
+    self.assertListEqual(actual_hashes, [h for (_, _, h) in test_data])
+
+
+class MakeRandomKeyTest(unittest.TestCase):
+
+  def testMakeRandomKey_Normal(self):
+    # Two default keys are each 128 chars long and differ from each other.
+    keys = [framework_helpers.MakeRandomKey() for _ in range(2)]
+    for key in keys:
+      self.assertEqual(128, len(key))
+    self.assertNotEqual(keys[0], keys[1])
+
+  def testMakeRandomKey_Length(self):
+    self.assertEqual(128, len(framework_helpers.MakeRandomKey()))
+    key16 = framework_helpers.MakeRandomKey(length=16)
+    self.assertEqual(16, len(key16))
+
+  def testMakeRandomKey_Chars(self):
+    # With a one-character alphabet the output is fully determined.
+    key = framework_helpers.MakeRandomKey(chars='a', length=4)
+    self.assertEqual('aaaa', key)
+
+
+class IsServiceAccountTest(unittest.TestCase):
+
+  def testIsServiceAccount(self):
+    appspot = 'abc@appspot.gserviceaccount.com'
+    developer = '@developer.gserviceaccount.com'
+    # Not a gserviceaccount.com address, but still a service account.
+    bugdroid = 'bugdroid1@chromium.org'
+    user = 'test@example.com'
+    service_accounts = [appspot, developer, bugdroid]
+
+    # Without a client_emails hint.
+    for email in service_accounts:
+      self.assertTrue(framework_helpers.IsServiceAccount(email))
+    self.assertFalse(framework_helpers.IsServiceAccount(user))
+
+    # Same answers when those accounts are also listed in client_emails.
+    client_emails = {appspot, developer, bugdroid}
+    for email in service_accounts:
+      self.assertTrue(framework_helpers.IsServiceAccount(
+          email, client_emails=client_emails))
+    self.assertFalse(framework_helpers.IsServiceAccount(
+        user, client_emails=client_emails))
diff --git a/framework/test/framework_views_test.py b/framework/test/framework_views_test.py
new file mode 100644
index 0000000..57f9fd1
--- /dev/null
+++ b/framework/test/framework_views_test.py
@@ -0,0 +1,326 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for framework_views classes."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import time
+import unittest
+
+from framework import framework_constants
+from framework import framework_views
+from framework import monorailrequest
+from proto import project_pb2
+from proto import tracker_pb2
+from proto import user_pb2
+import settings
+from services import service_manager
+from testing import fake
+
+
+LONG_STR = 'VeryLongStringThatCertainlyWillNotFit'  # No '-' separator.
+LONG_PART_STR = 'OnePartThatWillNotFit-OneShort'  # Long part, then short.
+
+
+class LabelViewTest(unittest.TestCase):
+  """Tests for framework_views.LabelView."""
+
+  def testLabelView(self):
+    # An empty label name is accepted as-is.
+    self.assertEqual('', framework_views.LabelView('', None).name)
+
+    # Each case: label name, expected prefix, expected value.
+    doubled = '%s-%s' % (LONG_STR, LONG_STR)
+    cases = [
+        ('Priority-High', 'Priority', 'High'),
+        (doubled, LONG_STR, LONG_STR),
+        (LONG_PART_STR, 'OnePartThatWillNotFit', 'OneShort'),
+    ]
+    for label, want_prefix, want_value in cases:
+      label_view = framework_views.LabelView(label, None)
+      self.assertEqual(label, label_view.name)
+      self.assertEqual('', label_view.docstring)
+      self.assertEqual(want_prefix, label_view.prefix)
+      self.assertEqual(want_value, label_view.value)
+
+    # is_restrict: None for a plain label, truthy for Restrict-View-Commit.
+    self.assertIsNone(
+        framework_views.LabelView('Priority-High', None).is_restrict)
+    restricted = framework_views.LabelView('Restrict-View-Commit', None)
+    self.assertTrue(restricted.is_restrict)
+
+    # With a config, the docstring comes from the matching LabelDef, and
+    # labels without a definition get an empty docstring.
+    config = tracker_pb2.ProjectIssueConfig()
+    config.well_known_labels.append(tracker_pb2.LabelDef(
+        label='Priority-High', label_docstring='Must ship in this milestone'))
+    self.assertEqual(
+        'Must ship in this milestone',
+        framework_views.LabelView('Priority-High', config).docstring)
+    self.assertEqual(
+        '', framework_views.LabelView('Priority-Foo', config).docstring)
+
+
+class StatusViewTest(unittest.TestCase):
+  """Tests for framework_views.StatusView."""
+
+  def testStatusView(self):
+    self.assertEqual('', framework_views.StatusView('', None).name)
+
+    # Without a config: no docstring and means_open is 'yes'.
+    for status in ('Accepted', LONG_STR):
+      status_view = framework_views.StatusView(status, None)
+      self.assertEqual(status, status_view.name)
+      self.assertEqual('', status_view.docstring)
+      self.assertEqual('yes', status_view.means_open)
+
+    slam_dunk = tracker_pb2.StatusDef(
+        status='SlamDunk', status_docstring='Code fixed and taught a lesson',
+        means_open=False)
+    config = tracker_pb2.ProjectIssueConfig()
+    config.well_known_statuses.append(slam_dunk)
+
+    # A well-known status takes its docstring and openness from the config.
+    known = framework_views.StatusView('SlamDunk', config)
+    self.assertEqual('Code fixed and taught a lesson', known.docstring)
+    self.assertFalse(known.means_open)
+
+    # A status missing from the config gets an empty docstring.
+    unknown = framework_views.StatusView('SlammedBack', config)
+    self.assertEqual('', unknown.docstring)
+
+
+class UserViewTest(unittest.TestCase):
+  """Tests for framework_views.UserView availability and deleted users."""
+
+  def setUp(self):
+    self.user = user_pb2.User(user_id=111)
+
+  def testGetAvailability_Anon(self):
+    self.user.user_id = 0
+    user_view = framework_views.UserView(self.user)
+    self.assertIsNone(user_view.avail_message)
+    self.assertIsNone(user_view.avail_state)
+
+  def testGetAvailability_Banned(self):
+    self.user.banned = 'spamming'
+    user_view = framework_views.UserView(self.user)
+    self.assertEqual('Banned', user_view.avail_message)
+    self.assertEqual('banned', user_view.avail_state)
+
+  def testGetAvailability_Vacation(self):
+    self.user.vacation_message = 'gone fishing'
+    user_view = framework_views.UserView(self.user)
+    self.assertEqual('gone fishing', user_view.avail_message)
+    self.assertEqual('none', user_view.avail_state)
+
+    self.user.vacation_message = (
+        'Gone fishing as really long time with lots of friends and reading '
+        'a long novel by a famous author. I wont have internet access but '
+        'If you urgently need anything you can call Alice or Bob for most '
+        'things otherwise call Charlie. Wish me luck! ')
+    user_view = framework_views.UserView(self.user)
+    self.assertGreaterEqual(len(user_view.avail_message), 50)
+    self.assertLess(len(user_view.avail_message_short), 50)
+    self.assertEqual('none', user_view.avail_state)
+
+  def testGetAvailability_Bouncing(self):
+    self.user.email_bounce_timestamp = 1234567890
+    user_view = framework_views.UserView(self.user)
+    self.assertEqual('Email to this user bounced', user_view.avail_message)
+    self.assertEqual(user_view.avail_message_short, user_view.avail_message)
+    self.assertEqual('none', user_view.avail_state)
+
+  def testGetAvailability_Groups(self):
+    user_view = framework_views.UserView(self.user, is_group=True)
+    self.assertIsNone(user_view.avail_message)
+    self.assertIsNone(user_view.avail_state)
+
+    self.user.email = 'likely-user-group@example.com'
+    user_view = framework_views.UserView(self.user)
+    self.assertIsNone(user_view.avail_message)
+    self.assertIsNone(user_view.avail_state)
+
+  def testGetAvailability_NeverVisited(self):
+    self.user.last_visit_timestamp = 0
+    user_view = framework_views.UserView(self.user)
+    self.assertEqual('User never visited', user_view.avail_message)
+    self.assertEqual('never', user_view.avail_state)
+
+  def testGetAvailability_NotRecent(self):
+    now = int(time.time())
+    self.user.last_visit_timestamp = now - 20 * framework_constants.SECS_PER_DAY
+    user_view = framework_views.UserView(self.user)
+    self.assertEqual('Last visit 20 days ago', user_view.avail_message)
+    self.assertEqual('unsure', user_view.avail_state)
+
+  def testGetAvailability_ReallyLongTime(self):
+    now = int(time.time())
+    self.user.last_visit_timestamp = now - 99 * framework_constants.SECS_PER_DAY
+    user_view = framework_views.UserView(self.user)
+    self.assertEqual('Last visit > 30 days ago', user_view.avail_message)
+    self.assertEqual('none', user_view.avail_state)
+
+  def testDeletedUser(self):
+    user_view = framework_views.UserView(user_pb2.User(user_id=1))
+    self.assertEqual(
+        framework_constants.DELETED_USER_NAME, user_view.display_name)
+    self.assertEqual('', user_view.email)
+    self.assertEqual('', user_view.obscure_email)
+    self.assertEqual('', user_view.profile_url)
+
+class RevealEmailsToMembersTest(unittest.TestCase):
+  """Tests for framework_views.RevealAllEmailsToMembers."""
+
+  def setUp(self):
+    self.cnxn = fake.MonorailConnection()
+    self.services = service_manager.Services(
+        project=fake.ProjectService(),
+        user=fake.UserService(),
+        usergroup=fake.UserGroupService())
+    self.mr = monorailrequest.MonorailRequest(None)
+    # Roles: owner 111, committer 222, contributors 333 and 888.
+    self.mr.project = self.services.project.TestAddProject(
+        'proj',
+        project_id=789,
+        owner_ids=[111],
+        committer_ids=[222],
+        contrib_ids=[333, 888])
+    user = self.services.user.TestAddUser('test@example.com', 1000)
+    self.mr.auth.user_pb = user
+
+  def _CheckReveal(
+      self, logged_in_user_id, expected, viewed_user_id, extra_args):
+    """Runs RevealAllEmailsToMembers for one viewer and checks the outcome.
+
+    Args:
+      logged_in_user_id: user ID of the signed-in user doing the viewing.
+      expected: True if the viewed user's email should end up revealed.
+      viewed_user_id: user ID of the user whose email may be revealed.
+      extra_args: tuple of extra positional arguments appended to the
+          RevealAllEmailsToMembers call (e.g. the project).
+    """
+    user_view = framework_views.StuffUserView(
+        viewed_user_id, 'user@example.com', True)
+    # Key the view by the ID it describes so that non-default
+    # viewed_user_id values are exercised correctly.
+    users_by_id = {viewed_user_id: user_view}
+    self.mr.auth.user_id = logged_in_user_id
+    self.mr.auth.effective_ids = {logged_in_user_id}
+    # Assert display name is obscured before the reveal.
+    self.assertEqual('u...@example.com', user_view.display_name)
+    # Assert profile url contains user ID before the reveal.
+    self.assertEqual('/u/%s/' % viewed_user_id, user_view.profile_url)
+    framework_views.RevealAllEmailsToMembers(
+        self.cnxn, self.services, self.mr.auth, users_by_id, *extra_args)
+    # obscure_email is cleared exactly when the email is revealed.
+    self.assertEqual(expected, not user_view.obscure_email)
+    if expected:
+      # Assert display name is now revealed.
+      self.assertEqual('user@example.com', user_view.display_name)
+      # Assert profile url contains the email.
+      self.assertEqual('/u/user@example.com/', user_view.profile_url)
+    else:
+      # Assert display name is still hidden.
+      self.assertEqual('u...@example.com', user_view.display_name)
+      # Assert profile url still contains user ID.
+      self.assertEqual('/u/%s/' % viewed_user_id, user_view.profile_url)
+
+  def CheckRevealAllToMember(
+      self, logged_in_user_id, expected, viewed_user_id=333, group_id=None):
+    """Checks the reveal using the current, project-less call."""
+    # xxx re-implement groups: group_id is accepted but currently ignored.
+    del group_id
+    self._CheckReveal(logged_in_user_id, expected, viewed_user_id, ())
+
+  # TODO(https://crbug.com/monorail/8192): Remove this method and related test.
+  def DeprecatedCheckRevealAllToMember(
+      self, logged_in_user_id, expected, viewed_user_id=333, group_id=None):
+    """Checks the reveal using the deprecated call that passes the project."""
+    del group_id  # xxx re-implement groups; currently ignored.
+    self._CheckReveal(
+        logged_in_user_id, expected, viewed_user_id, (self.mr.project,))
+
+  def testDontRevealEmailsToPrivilegedDomain(self):
+    """We no longer give this advantage based on email address domain."""
+    for priviledged_user_domain in settings.priviledged_user_domains:
+      self.mr.auth.user_pb.email = 'test@' + priviledged_user_domain
+      # User 100001 has no role in the project.
+      self.CheckRevealAllToMember(100001, False)
+
+  def testRevealEmailToSelf(self):
+    """The viewer (333) sees their own email revealed."""
+    logged_in_user = self.services.user.TestAddUser('user@example.com', 333)
+    self.mr.auth.user_pb = logged_in_user
+    self.CheckRevealAllToMember(333, True)
+
+  def testRevealAllEmailsToMembers_Collaborators(self):
+    """Members see the email; anon (0) and non-members (444) do not."""
+    self.CheckRevealAllToMember(0, False)
+    self.CheckRevealAllToMember(111, True)
+    self.CheckRevealAllToMember(222, True)
+    self.CheckRevealAllToMember(333, True)
+    self.CheckRevealAllToMember(444, False)
+
+    # Viewed user has indirect role in the project via a group.
+    self.CheckRevealAllToMember(0, False, group_id=888)
+    self.CheckRevealAllToMember(111, True, group_id=888)
+    # xxx re-implement
+    # self.CheckRevealAllToMember(
+    #     111, True, viewed_user_id=444, group_id=888)
+
+    # Logged in user has indirect role in the project via a group.
+    self.CheckRevealAllToMember(888, True)
+
+  def testDeprecatedRevealAllEmailsToMembers_Collaborators(self):
+    """Same checks, via the deprecated call that passes the project."""
+    self.DeprecatedCheckRevealAllToMember(0, False)
+    self.DeprecatedCheckRevealAllToMember(111, True)
+    self.DeprecatedCheckRevealAllToMember(222, True)
+    self.DeprecatedCheckRevealAllToMember(333, True)
+    self.DeprecatedCheckRevealAllToMember(444, False)
+
+    # Viewed user has indirect role in the project via a group.
+    self.DeprecatedCheckRevealAllToMember(0, False, group_id=888)
+    self.DeprecatedCheckRevealAllToMember(111, True, group_id=888)
+
+    # Logged in user has indirect role in the project via a group.
+    self.DeprecatedCheckRevealAllToMember(888, True)
+
+  def testRevealAllEmailsToMembers_Admins(self):
+    """Non-member viewer 555 sees the email only once marked site admin."""
+    self.CheckRevealAllToMember(555, False)
+    self.mr.auth.user_pb.is_site_admin = True
+    self.CheckRevealAllToMember(555, True)
+
+
+class RevealAllEmailsTest(unittest.TestCase):
+  """Tests for framework_views.RevealAllEmails."""
+
+  def testRevealAllEmail(self):
+    # (user_id, email, display name while the email is obscured)
+    expectations = [
+        (111, 'a@a.com', 'a...@a.com'),
+        (222, 'b@b.com', 'b...@b.com'),
+        (333, 'c@c.com', 'c...@c.com'),
+        (999, 'z@z.com', 'z...@z.com'),
+    ]
+    users_by_id = {}
+    for user_id, email, _ in expectations:
+      users_by_id[user_id] = framework_views.StuffUserView(
+          user_id, email, True)
+
+    # Display names are obscured before the reveal.
+    for user_id, _, obscured in expectations:
+      self.assertEqual(obscured, users_by_id[user_id].display_name)
+
+    framework_views.RevealAllEmails(users_by_id)
+
+    # Afterwards nothing is obscured and the full email is displayed.
+    for user_id, email, _ in expectations:
+      self.assertFalse(users_by_id[user_id].obscure_email)
+      self.assertEqual(email, users_by_id[user_id].display_name)
diff --git a/framework/test/gcs_helpers_test.py b/framework/test/gcs_helpers_test.py
new file mode 100644
index 0000000..3500e40
--- /dev/null
+++ b/framework/test/gcs_helpers_test.py
@@ -0,0 +1,185 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for the gcs_helpers module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import mock
+import unittest
+import uuid
+
+import mox
+
+from google.appengine.api import app_identity
+from google.appengine.api import images
+from google.appengine.api import urlfetch
+from google.appengine.ext import testbed
+from third_party import cloudstorage
+
+from framework import filecontent
+from framework import gcs_helpers
+from testing import fake
+from testing import testing_helpers
+
+
+class GcsHelpersTest(unittest.TestCase):
+  """Tests for gcs_helpers, with mox stubs for GCS and App Engine APIs."""
+
+  def setUp(self):
+    self.mox = mox.Mox()
+    self.testbed = testbed.Testbed()
+    self.testbed.activate()
+    self.testbed.init_memcache_stub()
+
+  def tearDown(self):
+    self.mox.UnsetStubs()
+    self.mox.ResetAll()
+    self.testbed.deactivate()
+
+  def testDeleteObjectFromGCS(self):
+    object_id = 'aaaaa'
+    bucket_name = 'test_bucket'
+    object_path = '/' + bucket_name + object_id
+
+    self.mox.StubOutWithMock(app_identity, 'get_default_gcs_bucket_name')
+    app_identity.get_default_gcs_bucket_name().AndReturn(bucket_name)
+
+    self.mox.StubOutWithMock(cloudstorage, 'delete')
+    cloudstorage.delete(object_path)
+
+    self.mox.ReplayAll()
+
+    gcs_helpers.DeleteObjectFromGCS(object_id)
+    self.mox.VerifyAll()
+
+  def testStoreObjectInGCS_ResizableMimeType(self):
+    guid = 'aaaaa'
+    project_id = 100
+    object_id = '/%s/attachments/%s' % (project_id, guid)
+    bucket_name = 'test_bucket'
+    object_path = '/' + bucket_name + object_id
+    mime_type = 'image/png'
+    content = 'content'
+    thumb_content = 'thumb_content'
+
+    self.mox.StubOutWithMock(app_identity, 'get_default_gcs_bucket_name')
+    app_identity.get_default_gcs_bucket_name().AndReturn(bucket_name)
+
+    self.mox.StubOutWithMock(uuid, 'uuid4')
+    uuid.uuid4().AndReturn(guid)
+
+    self.mox.StubOutWithMock(cloudstorage, 'open')
+    cloudstorage.open(object_path, 'w', mime_type, options={}).AndReturn(
+        fake.FakeFile())
+    # Resizable types also open a '-thumbnail' object for writing.
+    cloudstorage.open(object_path + '-thumbnail', 'w', mime_type).AndReturn(
+        fake.FakeFile())
+
+    self.mox.StubOutWithMock(images, 'resize')
+    images.resize(content, gcs_helpers.DEFAULT_THUMB_WIDTH,
+                  gcs_helpers.DEFAULT_THUMB_HEIGHT).AndReturn(thumb_content)
+
+    self.mox.ReplayAll()
+
+    ret_id = gcs_helpers.StoreObjectInGCS(
+        content, mime_type, project_id, gcs_helpers.DEFAULT_THUMB_WIDTH,
+        gcs_helpers.DEFAULT_THUMB_HEIGHT)
+    self.mox.VerifyAll()
+    self.assertEqual(object_id, ret_id)
+
+  def testStoreObjectInGCS_NotResizableMimeType(self):
+    guid = 'aaaaa'
+    project_id = 100
+    object_id = '/%s/attachments/%s' % (project_id, guid)
+    bucket_name = 'test_bucket'
+    object_path = '/' + bucket_name + object_id
+    mime_type = 'not_resizable_mime_type'
+    content = 'content'
+
+    self.mox.StubOutWithMock(app_identity, 'get_default_gcs_bucket_name')
+    app_identity.get_default_gcs_bucket_name().AndReturn(bucket_name)
+
+    self.mox.StubOutWithMock(uuid, 'uuid4')
+    uuid.uuid4().AndReturn(guid)
+
+    self.mox.StubOutWithMock(cloudstorage, 'open')
+    # The filename is expected in an inline Content-Disposition option.
+    options = {'Content-Disposition': 'inline; filename="file.ext"'}
+    cloudstorage.open(object_path, 'w', mime_type, options=options).AndReturn(
+        fake.FakeFile())
+
+    self.mox.ReplayAll()
+
+    ret_id = gcs_helpers.StoreObjectInGCS(
+        content, mime_type, project_id, gcs_helpers.DEFAULT_THUMB_WIDTH,
+        gcs_helpers.DEFAULT_THUMB_HEIGHT, filename='file.ext')
+    self.mox.VerifyAll()
+    self.assertEqual(object_id, ret_id)
+
+  def testCheckMimeTypeResizable(self):
+    for resizable_mime_type in gcs_helpers.RESIZABLE_MIME_TYPES:
+      gcs_helpers.CheckMimeTypeResizable(resizable_mime_type)
+
+    with self.assertRaises(gcs_helpers.UnsupportedMimeType):
+      gcs_helpers.CheckMimeTypeResizable('not_resizable_mime_type')
+
+  def testStoreLogoInGCS(self):
+    file_name = 'test_file.png'
+    mime_type = 'image/png'
+    content = 'test content'
+    project_id = 100
+    object_id = 123
+
+    self.mox.StubOutWithMock(filecontent, 'GuessContentTypeFromFilename')
+    filecontent.GuessContentTypeFromFilename(file_name).AndReturn(mime_type)
+
+    self.mox.StubOutWithMock(gcs_helpers, 'StoreObjectInGCS')
+    gcs_helpers.StoreObjectInGCS(
+        content, mime_type, project_id,
+        thumb_width=gcs_helpers.LOGO_THUMB_WIDTH,
+        thumb_height=gcs_helpers.LOGO_THUMB_HEIGHT).AndReturn(object_id)
+
+    self.mox.ReplayAll()
+
+    ret_id = gcs_helpers.StoreLogoInGCS(file_name, content, project_id)
+    self.mox.VerifyAll()
+    self.assertEqual(object_id, ret_id)
+
+  @mock.patch('google.appengine.api.urlfetch.fetch')
+  def testFetchSignedURL_Success(self, mock_fetch):
+    mock_fetch.return_value = testing_helpers.Blank(
+        headers={'Location': 'signed url'})
+    actual = gcs_helpers._FetchSignedURL('signing req url')
+    mock_fetch.assert_called_with('signing req url', follow_redirects=False)
+    self.assertEqual('signed url', actual)
+
+  @mock.patch('google.appengine.api.urlfetch.fetch')
+  def testFetchSignedURL_UnderpopulatedResult(self, mock_fetch):
+    mock_fetch.return_value = testing_helpers.Blank(headers={})
+    self.assertRaises(KeyError, gcs_helpers._FetchSignedURL, 'signing req url')
+
+  @mock.patch('google.appengine.api.urlfetch.fetch')
+  def testFetchSignedURL_DownloadError(self, mock_fetch):
+    mock_fetch.side_effect = urlfetch.DownloadError
+    self.assertRaises(
+        urlfetch.DownloadError, gcs_helpers._FetchSignedURL, 'signing req url')
+
+  @mock.patch('framework.gcs_helpers._FetchSignedURL')
+  def testSignUrl_Success(self, mock_FetchSignedURL):
+    with mock.patch(
+        'google.appengine.api.app_identity.get_access_token') as gat:
+      gat.return_value = ['token']
+      mock_FetchSignedURL.return_value = 'signed url'
+      signed_url = gcs_helpers.SignUrl('bucket', '/object')
+      self.assertEqual('signed url', signed_url)
+
+  @mock.patch('framework.gcs_helpers._FetchSignedURL')
+  def testSignUrl_DownloadError(self, mock_FetchSignedURL):
+    mock_FetchSignedURL.side_effect = urlfetch.DownloadError
+    # A signing failure yields a placeholder URL rather than an exception.
+    self.assertEqual(
+        '/missing-gcs-url', gcs_helpers.SignUrl('bucket', '/object'))
diff --git a/framework/test/grid_view_helpers_test.py b/framework/test/grid_view_helpers_test.py
new file mode 100644
index 0000000..df3ecc6
--- /dev/null
+++ b/framework/test/grid_view_helpers_test.py
@@ -0,0 +1,201 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for grid_view_helpers classes and functions."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+from framework import framework_constants
+from framework import framework_views
+from framework import grid_view_helpers
+from proto import tracker_pb2
+from testing import fake
+from tracker import tracker_bizobj
+
+
+class GridViewHelpersTest(unittest.TestCase):
+  """Tests for grid_view_helpers."""
+
+  def setUp(self):
+    self.config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+
+    # art1 has only derived values; art2 has explicit values and is merged
+    # into issue 200001.
+    self.art1 = fake.MakeTestIssue(
+        789, 1, 'a summary', '', 0, derived_owner_id=111, star_count=12,
+        derived_labels='Priority-Medium Hot Mstone-1 Mstone-2',
+        derived_status='Overdue')
+    self.art2 = fake.MakeTestIssue(
+        789, 1, 'a summary', 'New', 111, star_count=12, merged_into=200001,
+        labels='Priority-Medium Type-DEFECT Hot Mstone-1 Mstone-2')
+    self.users_by_id = {
+        111: framework_views.StuffUserView(111, 'foo@example.com', True),
+    }
+
+  def testSortGridHeadings(self):
+    config = fake.MakeTestConfig(
+        789, labels=('Priority-High Priority-Medium Priority-Low Hot Cold '
+                     'Milestone-Near Milestone-Far '
+                     'Day-Sun Day-Mon Day-Tue Day-Wed Day-Thu Day-Fri Day-Sat'),
+        statuses=('New Accepted Started Fixed WontFix Invalid Duplicate'))
+    config.field_defs = [
+        tracker_pb2.FieldDef(field_id=1, project_id=789, field_name='Day',
+                             field_type=tracker_pb2.FieldTypes.ENUM_TYPE)]
+    # Placeholder values: sorting must not call any of these accessors.
+    asc_accessors = {
+        'id': 'some function that is not called',
+        'reporter': 'some function that is not called',
+        'opened': 'some function that is not called',
+        'modified': 'some function that is not called',
+    }
+
+    # Verify that status headings are sorted according to the status
+    # values defined in the config.
+    col_name = 'status'
+    headings = ['Duplicate', 'Limbo', 'New', 'OnHold', 'Accepted', 'Fixed']
+    sorted_headings = grid_view_helpers.SortGridHeadings(
+        col_name, headings, self.users_by_id, config, asc_accessors)
+    self.assertEqual(
+        ['New', 'Accepted', 'Fixed', 'Duplicate', 'Limbo', 'OnHold'],
+        sorted_headings)
+
+    # Verify that special columns are sorted alphabetically or numerically.
+    col_name = 'id'
+    headings = [1, 2, 5, 3, 4]
+    sorted_headings = grid_view_helpers.SortGridHeadings(
+        col_name, headings, self.users_by_id, config, asc_accessors)
+    self.assertEqual([1, 2, 3, 4, 5], sorted_headings)
+
+    # Verify that label value headings are sorted according to the label
+    # values defined in the config.
+    col_name = 'priority'
+    headings = ['Medium', 'High', 'Low', 'dont-care']
+    sorted_headings = grid_view_helpers.SortGridHeadings(
+        col_name, headings, self.users_by_id, config, asc_accessors)
+    self.assertEqual(['High', 'Medium', 'Low', 'dont-care'], sorted_headings)
+
+    # Verify that enum headings are sorted according to the label
+    # values defined in the config.
+    col_name = 'day'
+    headings = ['Tue', 'Fri', 'Sun', 'Dogday', 'Wed', 'Caturday', 'Low']
+    sorted_headings = grid_view_helpers.SortGridHeadings(
+        col_name, headings, self.users_by_id, config, asc_accessors)
+    self.assertEqual(
+        ['Sun', 'Tue', 'Wed', 'Fri', 'Caturday', 'Dogday', 'Low'],
+        sorted_headings)
+
+  def testGetArtifactAttr_Explicit(self):
+    label_values = grid_view_helpers.MakeLabelValuesDict(self.art2)
+
+    id_vals = grid_view_helpers.GetArtifactAttr(
+        self.art2, 'id', self.users_by_id, label_values, self.config, {})
+    self.assertEqual([1], id_vals)
+    summary_vals = grid_view_helpers.GetArtifactAttr(
+        self.art2, 'summary', self.users_by_id, label_values, self.config, {})
+    self.assertEqual(['a summary'], summary_vals)
+    status_vals = grid_view_helpers.GetArtifactAttr(
+        self.art2, 'status', self.users_by_id, label_values, self.config, {})
+    self.assertEqual(['New'], status_vals)
+    stars_vals = grid_view_helpers.GetArtifactAttr(
+        self.art2, 'stars', self.users_by_id, label_values, self.config, {})
+    self.assertEqual([12], stars_vals)
+    owner_vals = grid_view_helpers.GetArtifactAttr(
+        self.art2, 'owner', self.users_by_id, label_values, self.config, {})
+    self.assertEqual(['f...@example.com'], owner_vals)
+    priority_vals = grid_view_helpers.GetArtifactAttr(
+        self.art2, 'priority', self.users_by_id, label_values, self.config, {})
+    self.assertEqual(['Medium'], priority_vals)
+    mstone_vals = grid_view_helpers.GetArtifactAttr(
+        self.art2, 'mstone', self.users_by_id, label_values, self.config, {})
+    self.assertEqual(['1', '2'], mstone_vals)
+    foo_vals = grid_view_helpers.GetArtifactAttr(
+        self.art2, 'foo', self.users_by_id, label_values, self.config, {})
+    self.assertEqual([framework_constants.NO_VALUES], foo_vals)
+    art3 = fake.MakeTestIssue(
+        987, 5, 'unecessary summary', 'New', 111, star_count=12,
+        issue_id=200001, project_name='other-project')
+    related_issues = {200001: art3}
+    merged_into_vals = grid_view_helpers.GetArtifactAttr(
+        self.art2, 'mergedinto', self.users_by_id, label_values,
+        self.config, related_issues)
+    self.assertEqual(['other-project:5'], merged_into_vals)
+
+  def testGetArtifactAttr_Derived(self):
+    label_values = grid_view_helpers.MakeLabelValuesDict(self.art1)
+    status_vals = grid_view_helpers.GetArtifactAttr(
+        self.art1, 'status', self.users_by_id, label_values, self.config, {})
+    self.assertEqual(['Overdue'], status_vals)
+    owner_vals = grid_view_helpers.GetArtifactAttr(
+        self.art1, 'owner', self.users_by_id, label_values, self.config, {})
+    self.assertEqual(['f...@example.com'], owner_vals)
+    priority_vals = grid_view_helpers.GetArtifactAttr(
+        self.art1, 'priority', self.users_by_id, label_values, self.config, {})
+    self.assertEqual(['Medium'], priority_vals)
+    mstone_vals = grid_view_helpers.GetArtifactAttr(
+        self.art1, 'mstone', self.users_by_id, label_values, self.config, {})
+    self.assertEqual(['1', '2'], mstone_vals)
+
+  def testMakeLabelValuesDict_Empty(self):
+    art = fake.MakeTestIssue(
+        789, 1, 'a summary', '', 0, derived_owner_id=111, star_count=12)
+    label_values = grid_view_helpers.MakeLabelValuesDict(art)
+    self.assertEqual({}, label_values)
+
+  def testMakeLabelValuesDict(self):
+    art = fake.MakeTestIssue(
+        789, 1, 'a summary', '', 0, derived_owner_id=111, star_count=12,
+        labels=['Priority-Medium', 'Hot', 'Mstone-1', 'Mstone-2'])
+    label_values = grid_view_helpers.MakeLabelValuesDict(art)
+    self.assertEqual(
+        {'priority': ['Medium'], 'mstone': ['1', '2']},
+        label_values)
+
+    art = fake.MakeTestIssue(
+        789, 1, 'a summary', '', 0, derived_owner_id=111, star_count=12,
+        labels='Priority-Medium Hot Mstone-1'.split(),
+        derived_labels=['Mstone-2'])
+    label_values = grid_view_helpers.MakeLabelValuesDict(art)
+    self.assertEqual(
+        {'priority': ['Medium'], 'mstone': ['1', '2']},
+        label_values)
+
+  def testMakeDrillDownSearch(self):
+    self.assertEqual('-has:milestone ',
+                     grid_view_helpers.MakeDrillDownSearch('milestone', '----'))
+    self.assertEqual('milestone=22 ',
+                     grid_view_helpers.MakeDrillDownSearch('milestone', '22'))
+    self.assertEqual(
+        'owner=a@example.com ',
+        grid_view_helpers.MakeDrillDownSearch('owner', 'a@example.com'))
+
+  def testAnyArtifactHasNoAttr_Empty(self):
+    artifacts = []
+    all_label_values = {}
+    self.assertFalse(grid_view_helpers.AnyArtifactHasNoAttr(
+        artifacts, 'milestone', self.users_by_id, all_label_values,
+        self.config, {}))
+
+  def testAnyArtifactHasNoAttr(self):
+    artifacts = [self.art1]
+    all_label_values = {
+        self.art1.local_id: grid_view_helpers.MakeLabelValuesDict(self.art1),
+    }
+    self.assertFalse(grid_view_helpers.AnyArtifactHasNoAttr(
+        artifacts, 'mstone', self.users_by_id, all_label_values,
+        self.config, {}))
+    self.assertTrue(grid_view_helpers.AnyArtifactHasNoAttr(
+        artifacts, 'milestone', self.users_by_id, all_label_values,
+        self.config, {}))
+
+  def testGetGridViewData(self):
+    # TODO(jojwang): write this test
+    self.skipTest('Not yet written.')
+
+  def testPrepareForMakeGridData(self):
+    # TODO(jojwang): write this test
+    self.skipTest('Not yet written.')
diff --git a/framework/test/jsonfeed_test.py b/framework/test/jsonfeed_test.py
new file mode 100644
index 0000000..0a569e2
--- /dev/null
+++ b/framework/test/jsonfeed_test.py
@@ -0,0 +1,141 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for jsonfeed module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import httplib
+import logging
+import unittest
+
+from google.appengine.api import app_identity
+
+from framework import jsonfeed
+from framework import servlet
+from framework import xsrf
+from services import service_manager
+from testing import testing_helpers
+
+
+class JsonFeedTest(unittest.TestCase):
+  """Tests for jsonfeed.JsonFeed request handling and access checks."""
+
+  def _AssertForbidden(self, feed, handler):
+    """Calls handler and checks that feed rejected the request with a 403."""
+    self.assertIsNone(handler())
+    self.assertFalse(feed.handle_request_called)
+    self.assertEqual(httplib.FORBIDDEN, feed.response.status)
+
+  def testGet(self):
+    """Tests handling of GET requests."""
+    feed = TestableJsonFeed()
+
+    # all expected args are present + a bonus arg that should be ignored
+    feed.mr = testing_helpers.MakeMonorailRequest(
+        path='/foo/bar/wee?sna=foo', method='GET',
+        params={'a': '123', 'z': 'zebra'})
+    feed.get()
+
+    self.assertTrue(feed.handle_request_called)
+    self.assertEqual(1, len(feed.json_data))
+
+  def testPost(self):
+    """Tests handling of POST requests."""
+    feed = TestableJsonFeed()
+    feed.mr = testing_helpers.MakeMonorailRequest(
+        path='/foo/bar/wee?sna=foo', method='POST',
+        params={'a': '123', 'z': 'zebra'})
+
+    feed.post()
+
+    self.assertTrue(feed.handle_request_called)
+    self.assertEqual(1, len(feed.json_data))
+
+  def testSecurityTokenChecked_BadToken(self):
+    """A signed-in user's request is rejected without a valid token."""
+    feed = TestableJsonFeed()
+    feed.mr = testing_helpers.MakeMonorailRequest(
+        user_info={'user_id': 555})
+    # Note that feed.mr has no token set.
+    self.assertRaises(xsrf.TokenIncorrect, feed.get)
+    self.assertRaises(xsrf.TokenIncorrect, feed.post)
+
+    feed.mr.token = 'bad token'
+    self.assertRaises(xsrf.TokenIncorrect, feed.get)
+    self.assertRaises(xsrf.TokenIncorrect, feed.post)
+
+  def testSecurityTokenChecked_HandlerDoesNotNeedToken(self):
+    """Token checks are skipped when CHECK_SECURITY_TOKEN is False."""
+    feed = TestableJsonFeed()
+    feed.mr = testing_helpers.MakeMonorailRequest(
+        user_info={'user_id': 555})
+    # Note that feed.mr has no token set.
+    feed.CHECK_SECURITY_TOKEN = False
+    feed.get()
+    feed.post()
+
+  def testSecurityTokenChecked_AnonUserDoesNotNeedToken(self):
+    feed = TestableJsonFeed()
+    feed.mr = testing_helpers.MakeMonorailRequest()
+    # Note that feed.mr has no token set, but also no auth.user_id.
+    feed.get()
+    feed.post()
+
+  def testSameAppOnly_ExternallyAccessible(self):
+    """Without CHECK_SAME_APP set, a request with no inbound app ID works."""
+    feed = TestableJsonFeed()
+    feed.mr = testing_helpers.MakeMonorailRequest()
+    # Note that request has no X-Appengine-Inbound-Appid set.
+    feed.get()
+    feed.post()
+
+  def testSameAppOnly_InternalOnlyCalledFromSameApp(self):
+    """With CHECK_SAME_APP, requests from this app are handled."""
+    feed = TestableJsonFeed()
+    feed.CHECK_SAME_APP = True
+    feed.mr = testing_helpers.MakeMonorailRequest()
+    app_id = app_identity.get_application_id()
+    feed.mr.request.headers['X-Appengine-Inbound-Appid'] = app_id
+    feed.get()
+    feed.post()
+
+  def testSameAppOnly_InternalOnlyCalledExternally(self):
+    feed = TestableJsonFeed()
+    feed.CHECK_SAME_APP = True
+    feed.mr = testing_helpers.MakeMonorailRequest()
+    # Note that request has no X-Appengine-Inbound-Appid set.
+    self._AssertForbidden(feed, feed.get)
+    self._AssertForbidden(feed, feed.post)
+
+  def testSameAppOnly_InternalOnlyCalledFromWrongApp(self):
+    feed = TestableJsonFeed()
+    feed.CHECK_SAME_APP = True
+    feed.mr = testing_helpers.MakeMonorailRequest()
+    feed.mr.request.headers['X-Appengine-Inbound-Appid'] = 'wrong'
+    self._AssertForbidden(feed, feed.get)
+    self._AssertForbidden(feed, feed.post)
+
+
+class TestableJsonFeed(jsonfeed.JsonFeed):
+  """JsonFeed double that records calls instead of rendering output."""
+
+  def __init__(self, request=None):
+    super(TestableJsonFeed, self).__init__(
+        request or 'req', testing_helpers.Blank(),
+        services=service_manager.Services())
+    self.response_data = None
+    self.handle_request_called = False  # Set by HandleRequest.
+    self.json_data = None  # Set by _RenderJsonResponse.
+
+  def HandleRequest(self, mr):
+    """Records the call and echoes back the 'a' request parameter."""
+    self.handle_request_called = True
+    return {'a': mr.GetParam('a')}
+
+  def _RenderJsonResponse(self, json_data):
+    """Skips the real output chain; keeps json_data for inspection."""
+    self.json_data = json_data
diff --git a/framework/test/monitoring_test.py b/framework/test/monitoring_test.py
new file mode 100644
index 0000000..edbd15d
--- /dev/null
+++ b/framework/test/monitoring_test.py
@@ -0,0 +1,86 @@
+# Copyright 2020 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+
+"""Unit tests for the monitoring module."""
+
+import unittest
+from framework import monitoring
+# Metric fields for a status-200 call to a fake v3 method, shared by tests.
+COMMON_TEST_FIELDS = monitoring.GetCommonFields(200, 'monorail.v3.MethodName')
+
+
+class MonitoringTest(unittest.TestCase):
+
+ def testIncrementAPIRequestsCount(self):
+ # Non-service account email gets hidden.
+ monitoring.IncrementAPIRequestsCount(
+ 'v3', 'monorail-prod', client_email='client-email@chicken.com')
+ self.assertEqual(
+ 1,
+ monitoring.API_REQUESTS_COUNT.get(
+ fields={
+ 'client_id': 'monorail-prod',
+ 'client_email': 'user@email.com',
+ 'version': 'v3'
+ }))
+
+ # None email address gets replaced by 'anonymous'.
+ monitoring.IncrementAPIRequestsCount('v3', 'monorail-prod')
+ self.assertEqual(
+ 1,
+ monitoring.API_REQUESTS_COUNT.get(
+ fields={
+ 'client_id': 'monorail-prod',
+ 'client_email': 'anonymous',
+ 'version': 'v3'
+ }))
+
+ # Service account email is not hidden
+ monitoring.IncrementAPIRequestsCount(
+ 'endpoints',
+ 'monorail-prod',
+ client_email='123456789@developer.gserviceaccount.com')
+ self.assertEqual(
+ 1,
+ monitoring.API_REQUESTS_COUNT.get(
+ fields={
+ 'client_id': 'monorail-prod',
+ 'client_email': '123456789@developer.gserviceaccount.com',
+ 'version': 'endpoints'
+ }))
+
+ def testGetCommonFields(self):
+ fields = monitoring.GetCommonFields(200, 'monorail.v3.TestName')
+ self.assertEqual(
+ {
+ 'status': 200,
+ 'name': 'monorail.v3.TestName',
+ 'is_robot': False
+ }, fields)
+
+ def testAddServerDurations(self):
+ self.assertIsNone(
+ monitoring.SERVER_DURATIONS.get(fields=COMMON_TEST_FIELDS))
+ monitoring.AddServerDurations(500, COMMON_TEST_FIELDS)
+ self.assertIsNotNone(
+ monitoring.SERVER_DURATIONS.get(fields=COMMON_TEST_FIELDS))
+
+ def testIncrementServerResponseStatusCount(self):
+ monitoring.IncrementServerResponseStatusCount(COMMON_TEST_FIELDS)
+ self.assertEqual(
+ 1, monitoring.SERVER_RESPONSE_STATUS.get(fields=COMMON_TEST_FIELDS))
+
+ def testAddServerRequesteBytes(self):
+ self.assertIsNone(
+ monitoring.SERVER_REQUEST_BYTES.get(fields=COMMON_TEST_FIELDS))
+ monitoring.AddServerRequesteBytes(1234, COMMON_TEST_FIELDS)
+ self.assertIsNotNone(
+ monitoring.SERVER_REQUEST_BYTES.get(fields=COMMON_TEST_FIELDS))
+
+ def testAddServerResponseBytes(self):
+ self.assertIsNone(
+ monitoring.SERVER_RESPONSE_BYTES.get(fields=COMMON_TEST_FIELDS))
+ monitoring.AddServerResponseBytes(9876, COMMON_TEST_FIELDS)
+ self.assertIsNotNone(
+ monitoring.SERVER_RESPONSE_BYTES.get(fields=COMMON_TEST_FIELDS))
diff --git a/framework/test/monorailcontext_test.py b/framework/test/monorailcontext_test.py
new file mode 100644
index 0000000..ed93920
--- /dev/null
+++ b/framework/test/monorailcontext_test.py
@@ -0,0 +1,89 @@
+# Copyright 2018 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for MonorailContext."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+import mox
+
+from framework import authdata
+from framework import monorailcontext
+from framework import permissions
+from framework import profiler
+from framework import template_helpers
+from framework import sql
+from services import service_manager
+from testing import fake
+
+
+class MonorailContextTest(unittest.TestCase):
+
+ def setUp(self):
+ self.mox = mox.Mox()
+ self.cnxn = fake.MonorailConnection()
+ self.services = service_manager.Services(
+ user=fake.UserService(),
+ usergroup=fake.UserGroupService(),
+ project=fake.ProjectService())
+ self.project = self.services.project.TestAddProject(
+ 'proj', project_id=789, owner_ids=[111])
+ self.user = self.services.user.TestAddUser('owner@example.com', 111)
+
+ def tearDown(self):
+ self.mox.UnsetStubs()
+ self.mox.ResetAll()
+
+ def testConstructor_PassingAuthAndPerms(self):
+ """We can easily make an mc for testing."""
+ auth = authdata.AuthData(user_id=111, email='owner@example.com')
+ mc = monorailcontext.MonorailContext(
+ None, cnxn=self.cnxn, auth=auth, perms=permissions.USER_PERMISSIONSET)
+ self.assertEqual(self.cnxn, mc.cnxn)
+ self.assertEqual(auth, mc.auth)
+ self.assertEqual(permissions.USER_PERMISSIONSET, mc.perms)
+ self.assertTrue(isinstance(mc.profiler, profiler.Profiler))
+ self.assertEqual([], mc.warnings)
+ self.assertTrue(isinstance(mc.errors, template_helpers.EZTError))
+
+ mc.CleanUp()
+ self.assertIsNone(mc.cnxn)
+
+ def testConstructor_AsUsedInApp(self):
+ """We can make an mc like it is done in the app or a test."""
+ self.mox.StubOutClassWithMocks(sql, 'MonorailConnection')
+ mock_cnxn = sql.MonorailConnection()
+ mock_cnxn.Close()
+ requester = 'new-user@example.com'
+ self.mox.ReplayAll()
+
+ mc = monorailcontext.MonorailContext(self.services, requester=requester)
+ mc.LookupLoggedInUserPerms(self.project)
+ self.assertEqual(mock_cnxn, mc.cnxn)
+ self.assertEqual(requester, mc.auth.email)
+ self.assertEqual(permissions.USER_PERMISSIONSET, mc.perms)
+ self.assertTrue(isinstance(mc.profiler, profiler.Profiler))
+ self.assertEqual([], mc.warnings)
+ self.assertTrue(isinstance(mc.errors, template_helpers.EZTError))
+
+ mc.CleanUp()
+ self.assertIsNone(mc.cnxn)
+
+ # Double Cleanup or Cleanup with no cnxn is not a crash.
+ mc.CleanUp()
+ self.assertIsNone(mc.cnxn)
+
+ def testRepr(self):
+ """We get nice debugging strings."""
+ auth = authdata.AuthData(user_id=111, email='owner@example.com')
+ mc = monorailcontext.MonorailContext(
+ None, cnxn=self.cnxn, auth=auth, perms=permissions.USER_PERMISSIONSET)
+ repr_str = '%r' % mc
+ self.assertTrue(repr_str.startswith('MonorailContext('))
+ self.assertIn('owner@example.com', repr_str)
+ self.assertIn('view', repr_str)
diff --git a/framework/test/monorailrequest_test.py b/framework/test/monorailrequest_test.py
new file mode 100644
index 0000000..fcd30c3
--- /dev/null
+++ b/framework/test/monorailrequest_test.py
@@ -0,0 +1,613 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for the monorailrequest module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import endpoints
+import mock
+import re
+import unittest
+
+import mox
+import six
+
+from google.appengine.api import oauth
+from google.appengine.api import users
+
+import webapp2
+
+from framework import exceptions
+from framework import monorailrequest
+from framework import permissions
+from proto import project_pb2
+from proto import tracker_pb2
+from services import service_manager
+from testing import fake
+from testing import testing_helpers
+from tracker import tracker_constants
+
+
+class HostportReTest(unittest.TestCase):
+
+ def testGood(self):
+ test_data = [
+ 'localhost:8080',
+ 'app.appspot.com',
+ 'bugs-staging.chromium.org',
+ 'vers10n-h3x-dot-app-id.appspot.com',
+ ]
+ for hostport in test_data:
+ self.assertTrue(monorailrequest._HOSTPORT_RE.match(hostport),
+ msg='Incorrectly rejected %r' % hostport)
+
+ def testBad(self):
+ test_data = [
+ '',
+ ' ',
+ '\t',
+ '\n',
+ '\'',
+ '"',
+ 'version"cruft-dot-app-id.appspot.com',
+ '\nother header',
+ 'version&cruft-dot-app-id.appspot.com',
+ ]
+ for hostport in test_data:
+ self.assertFalse(monorailrequest._HOSTPORT_RE.match(hostport),
+ msg='Incorrectly accepted %r' % hostport)
+
+
+class MonorailApiRequestUnitTest(unittest.TestCase):
+
+ def setUp(self):
+ self.cnxn = 'fake cnxn'
+ self.services = service_manager.Services(
+ config=fake.ConfigService(),
+ issue=fake.IssueService(),
+ project=fake.ProjectService(),
+ user=fake.UserService(),
+ usergroup=fake.UserGroupService())
+ self.project = self.services.project.TestAddProject(
+ 'proj', project_id=789)
+ self.services.user.TestAddUser('requester@example.com', 111)
+ self.issue = fake.MakeTestIssue(
+ 789, 1, 'sum', 'New', 111)
+ self.services.issue.TestAddIssue(self.issue)
+
+ self.patcher_1 = mock.patch('endpoints.get_current_user')
+ self.mock_endpoints_gcu = self.patcher_1.start()
+ self.mock_endpoints_gcu.return_value = None
+ self.patcher_2 = mock.patch('google.appengine.api.oauth.get_current_user')
+ self.mock_oauth_gcu = self.patcher_2.start()
+ self.mock_oauth_gcu.return_value = testing_helpers.Blank(
+ email=lambda: 'requester@example.com')
+
+ def tearDown(self):
+ mock.patch.stopall()
+
+ def testInit_NoProjectIssueOrViewedUser(self):
+ request = testing_helpers.Blank()
+ mar = monorailrequest.MonorailApiRequest(
+ request, self.services, cnxn=self.cnxn)
+ self.assertIsNone(mar.project)
+ self.assertIsNone(mar.issue)
+
+ def testInit_WithProject(self):
+ request = testing_helpers.Blank(projectId='proj')
+ mar = monorailrequest.MonorailApiRequest(
+ request, self.services, cnxn=self.cnxn)
+ self.assertEqual(self.project, mar.project)
+ self.assertIsNone(mar.issue)
+
+ def testInit_WithProjectAndIssue(self):
+ request = testing_helpers.Blank(
+ projectId='proj', issueId=1)
+ mar = monorailrequest.MonorailApiRequest(
+ request, self.services, cnxn=self.cnxn)
+ self.assertEqual(self.project, mar.project)
+ self.assertEqual(self.issue, mar.issue)
+
+ def testGetParam_Normal(self):
+ request = testing_helpers.Blank(q='owner:me')
+ mar = monorailrequest.MonorailApiRequest(
+ request, self.services, cnxn=self.cnxn)
+ self.assertEqual(None, mar.GetParam('unknown'))
+ self.assertEqual(100, mar.GetParam('num'))
+ self.assertEqual('owner:me', mar.GetParam('q'))
+
+ request = testing_helpers.Blank(q='owner:me', maxResults=200)
+ mar = monorailrequest.MonorailApiRequest(
+ request, self.services, cnxn=self.cnxn)
+ self.assertEqual(200, mar.GetParam('num'))
+
+
+class MonorailRequestUnitTest(unittest.TestCase):
+
+  def setUp(self):
+    self.services = service_manager.Services(
+        project=fake.ProjectService(),
+        user=fake.UserService(),
+        usergroup=fake.UserGroupService(),
+        features=fake.FeaturesService())
+    self.project = self.services.project.TestAddProject('proj')
+    self.hotlist = self.services.features.TestAddHotlist(
+        'TestHotlist', owner_ids=[111])
+    self.services.user.TestAddUser('jrobbins@example.com', 111)
+
+    self.mox = mox.Mox()
+    self.mox.StubOutWithMock(users, 'get_current_user')
+    users.get_current_user().AndReturn(None)  # No signed-in user.
+    self.mox.ReplayAll()
+
+  def tearDown(self):
+    self.mox.UnsetStubs()
+
+  def testGetIntParam_ConvertsQueryParamToInt(self):
+    notice_id = 12345
+    mr = testing_helpers.MakeMonorailRequest(
+        path='/foo?notice=%s' % notice_id)
+
+    value = mr.GetIntParam('notice')
+    self.assertTrue(isinstance(value, int))
+    self.assertEqual(notice_id, value)
+
+  def testGetIntParam_ConvertsQueryParamToLong(self):
+    notice_id = 12345678901234567890  # Too big for a signed 64-bit int.
+    mr = testing_helpers.MakeMonorailRequest(
+        path='/foo?notice=%s' % notice_id)
+
+    value = mr.GetIntParam('notice')
+    self.assertTrue(isinstance(value, six.integer_types))
+    self.assertEqual(notice_id, value)
+
+  def testGetIntListParam_NoParam(self):
+    mr = monorailrequest.MonorailRequest(self.services)
+    mr.ParseRequest(webapp2.Request.blank('servlet'), self.services)
+    self.assertEqual(mr.GetIntListParam('ids'), None)
+    self.assertEqual(mr.GetIntListParam('ids', default_value=['test']),
+                     ['test'])
+
+  def testGetIntListParam_OneValue(self):
+    mr = monorailrequest.MonorailRequest(self.services)
+    mr.ParseRequest(webapp2.Request.blank('servlet?ids=11'), self.services)
+    self.assertEqual(mr.GetIntListParam('ids'), [11])
+    self.assertEqual(mr.GetIntListParam('ids', default_value=['test']),
+                     [11])
+
+  def testGetIntListParam_MultiValue(self):
+    mr = monorailrequest.MonorailRequest(self.services)
+    mr.ParseRequest(
+        webapp2.Request.blank('servlet?ids=21,22,23'), self.services)
+    self.assertEqual(mr.GetIntListParam('ids'), [21, 22, 23])
+    self.assertEqual(mr.GetIntListParam('ids', default_value=['test']),
+                     [21, 22, 23])
+
+  def testGetIntListParam_BogusValue(self):
+    mr = monorailrequest.MonorailRequest(self.services)
+    with self.assertRaises(exceptions.InputException):
+      mr.ParseRequest(
+          webapp2.Request.blank('servlet?ids=not_an_int'), self.services)
+
+  def testGetIntListParam_Malformed(self):
+    mr = monorailrequest.MonorailRequest(self.services)
+    with self.assertRaises(exceptions.InputException):
+      mr.ParseRequest(
+          webapp2.Request.blank('servlet?ids=31,32,,'), self.services)
+
+  def testDefaultValuesNoUrl(self):
+    """If request has no param, default param values should be used."""
+    mr = monorailrequest.MonorailRequest(self.services)
+    mr.ParseRequest(webapp2.Request.blank('servlet'), self.services)
+    self.assertEqual(mr.GetParam('r', 3), 3)
+    self.assertEqual(mr.GetIntParam('r', 3), 3)
+    self.assertEqual(mr.GetPositiveIntParam('r', 3), 3)
+    self.assertEqual(mr.GetIntListParam('r', [3, 4]), [3, 4])
+
+  def _MRWithMockRequest(
+      self, path, headers=None, *mr_args, **mr_kwargs):
+    request = webapp2.Request.blank(path, headers=headers)
+    mr = monorailrequest.MonorailRequest(self.services, *mr_args, **mr_kwargs)
+    mr.ParseRequest(request, self.services)
+    return mr
+
+  def testParseQueryParameters(self):
+    mr = self._MRWithMockRequest(
+        '/p/proj/issues/list?q=foo+OR+bar&num=50')
+    self.assertEqual('foo OR bar', mr.query)
+    self.assertEqual(50, mr.num)
+
+  def testParseQueryParameters_ModeMissing(self):
+    mr = self._MRWithMockRequest(
+        '/p/proj/issues/list?q=foo+OR+bar&num=50')
+    self.assertEqual('list', mr.mode)
+
+  def testParseQueryParameters_ModeList(self):
+    mr = self._MRWithMockRequest(
+        '/p/proj/issues/list?q=foo+OR+bar&num=50&mode=')
+    self.assertEqual('list', mr.mode)
+
+  def testParseQueryParameters_ModeGrid(self):
+    mr = self._MRWithMockRequest(
+        '/p/proj/issues/list?q=foo+OR+bar&num=50&mode=grid')
+    self.assertEqual('grid', mr.mode)
+
+  def testParseQueryParameters_ModeChart(self):
+    mr = self._MRWithMockRequest(
+        '/p/proj/issues/list?q=foo+OR+bar&num=50&mode=chart')
+    self.assertEqual('chart', mr.mode)
+
+  def testParseRequest_Scheme(self):
+    mr = self._MRWithMockRequest('/p/proj/')
+    self.assertEqual('http', mr.request.scheme)
+
+  def testParseRequest_HostportAndCurrentPageURL(self):
+    mr = self._MRWithMockRequest('/p/proj/', headers={
+        'Host': 'example.com',
+        'Cookie': 'asdf',
+        })
+    self.assertEqual('http', mr.request.scheme)
+    self.assertEqual('example.com', mr.request.host)
+    self.assertEqual('http://example.com/p/proj/', mr.current_page_url)
+
+  def testParseRequest_ProjectFound(self):
+    mr = self._MRWithMockRequest('/p/proj/')
+    self.assertEqual(mr.project, self.project)
+
+  def testParseRequest_ProjectNotFound(self):
+    with self.assertRaises(exceptions.NoSuchProjectException):
+      self._MRWithMockRequest('/p/no-such-proj/')
+
+  def testViewedUser_WithEmail(self):
+    mr = self._MRWithMockRequest('/u/jrobbins@example.com/')
+    self.assertEqual('jrobbins@example.com', mr.viewed_username)
+    self.assertEqual(111, mr.viewed_user_auth.user_id)
+    self.assertEqual(
+        self.services.user.GetUser('fake cnxn', 111),
+        mr.viewed_user_auth.user_pb)
+
+  def testViewedUser_WithUserID(self):
+    mr = self._MRWithMockRequest('/u/111/')
+    self.assertEqual('jrobbins@example.com', mr.viewed_username)
+    self.assertEqual(111, mr.viewed_user_auth.user_id)
+    self.assertEqual(
+        self.services.user.GetUser('fake cnxn', 111),
+        mr.viewed_user_auth.user_pb)
+
+  def testViewedUser_NoSuchEmail(self):
+    with self.assertRaises(webapp2.HTTPException) as cm:
+      self._MRWithMockRequest('/u/unknownuser@example.com/')
+    self.assertEqual(404, cm.exception.code)
+
+  def testViewedUser_NoSuchUserID(self):
+    with self.assertRaises(exceptions.NoSuchUserException):
+      self._MRWithMockRequest('/u/234521111/')
+
+  def testGetParam(self):
+    mr = testing_helpers.MakeMonorailRequest(
+        path='/foo?syn=error!&a=a&empty=',
+        params=dict(over1='over_value1', over2='over_value2'))
+
+    # Values that fail antitamper_re, even defaults, raise InputException.
+    self.assertRaises(exceptions.InputException, mr.GetParam, 'a',
+                      antitamper_re=re.compile(r'^$'))
+    self.assertRaises(exceptions.InputException, mr.GetParam,
+                      'undefined', default_value='default',
+                      antitamper_re=re.compile(r'^$'))
+
+    # An empty value matches r'^$' and is returned instead of the default.
+    self.assertEqual('', mr.GetParam(
+        'empty', default_value='default', antitamper_re=re.compile(r'^$')))
+
+    # A missing param falls back to default_value.
+    self.assertEqual('default', mr.GetParam(
+        'undefined', default_value='default'))
+
+  def testComputeColSpec(self):
+    # No config passed, and nothing in URL
+    mr = testing_helpers.MakeMonorailRequest(
+        path='/p/proj/issues/detail?id=123')
+    mr.ComputeColSpec(None)
+    self.assertEqual(tracker_constants.DEFAULT_COL_SPEC, mr.col_spec)
+
+    # No config passed, but set in URL
+    mr = testing_helpers.MakeMonorailRequest(
+        path='/p/proj/issues/detail?id=123&colspec=a b C')
+    mr.ComputeColSpec(None)
+    self.assertEqual('a b C', mr.col_spec)
+
+    config = tracker_pb2.ProjectIssueConfig()
+
+    # No default in the config, and nothing in URL
+    mr = testing_helpers.MakeMonorailRequest(
+        path='/p/proj/issues/detail?id=123')
+    mr.ComputeColSpec(config)
+    self.assertEqual(tracker_constants.DEFAULT_COL_SPEC, mr.col_spec)
+
+    # No default in the config, but set in URL
+    mr = testing_helpers.MakeMonorailRequest(
+        path='/p/proj/issues/detail?id=123&colspec=a b C')
+    mr.ComputeColSpec(config)
+    self.assertEqual('a b C', mr.col_spec)
+
+    config.default_col_spec = 'd e f'
+
+    # Default in the config, and nothing in URL
+    mr = testing_helpers.MakeMonorailRequest(
+        path='/p/proj/issues/detail?id=123')
+    mr.ComputeColSpec(config)
+    self.assertEqual('d e f', mr.col_spec)
+
+    # Default in the config, but overridden via URL
+    mr = testing_helpers.MakeMonorailRequest(
+        path='/p/proj/issues/detail?id=123&colspec=a b C')
+    mr.ComputeColSpec(config)
+    self.assertEqual('a b C', mr.col_spec)
+
+    # Hotlist columns in a non-hotlist page's colspec yield the default spec.
+    mr = testing_helpers.MakeMonorailRequest(
+        path='p/proj/issues/detail?id=123&colspec=Rank Adder Adder Owner')
+    mr.ComputeColSpec(None)
+    self.assertEqual(tracker_constants.DEFAULT_COL_SPEC, mr.col_spec)
+
+    # Hotlist columns are kept when the request is for a hotlist page.
+    mr = testing_helpers.MakeMonorailRequest(
+        path='u/jrobbins@example.com/hotlists/TestHotlist?colspec=Rank Adder',
+        hotlist=self.hotlist)
+    mr.ComputeColSpec(None)
+    self.assertEqual('Rank Adder', mr.col_spec)
+
+  def testComputeColSpec_XSS(self):
+    config_1 = tracker_pb2.ProjectIssueConfig()
+    config_2 = tracker_pb2.ProjectIssueConfig()
+    config_2.default_col_spec = "id '+alert(1)+'"
+    mr_1 = testing_helpers.MakeMonorailRequest(
+        path='/p/proj/issues/detail?id=123')
+    mr_2 = testing_helpers.MakeMonorailRequest(
+        path="/p/proj/issues/detail?id=123&colspec=id '+alert(1)+'")
+
+    # Normal colspec in config but malicious request
+    self.assertRaises(
+        exceptions.InputException,
+        mr_2.ComputeColSpec, config_1)
+
+    # Malicious colspec in config but normal request
+    self.assertRaises(
+        exceptions.InputException,
+        mr_1.ComputeColSpec, config_2)
+
+    # Malicious colspec in config and malicious request
+    self.assertRaises(
+        exceptions.InputException,
+        mr_2.ComputeColSpec, config_2)
+
+
+class CalcDefaultQueryTest(unittest.TestCase):
+
+ def setUp(self):
+ self.project = project_pb2.Project()
+ self.project.project_name = 'proj'
+ self.project.owner_ids = [111]
+ self.config = tracker_pb2.ProjectIssueConfig()
+
+ def testIssueListURL_NotDefaultCan(self):
+ mr = monorailrequest.MonorailRequest(None)
+ mr.query = None
+ mr.can = 1
+ self.assertEqual('', mr._CalcDefaultQuery())
+
+ def testIssueListURL_NoProject(self):
+ mr = monorailrequest.MonorailRequest(None)
+ mr.query = None
+ mr.can = 2
+ self.assertEqual('', mr._CalcDefaultQuery())
+
+ def testIssueListURL_NoConfig(self):
+ mr = monorailrequest.MonorailRequest(None)
+ mr.query = None
+ mr.can = 2
+ mr.project = self.project
+ self.assertEqual('', mr._CalcDefaultQuery())
+
+ def testIssueListURL_NotCustomized(self):
+ mr = monorailrequest.MonorailRequest(None)
+ mr.query = None
+ mr.can = 2
+ mr.project = self.project
+ mr.config = self.config
+ self.assertEqual('', mr._CalcDefaultQuery())
+
+ def testIssueListURL_Customized_Nonmember(self):
+ mr = monorailrequest.MonorailRequest(None)
+ mr.query = None
+ mr.can = 2
+ mr.project = self.project
+ mr.config = self.config
+ mr.config.member_default_query = 'owner:me'
+ self.assertEqual('', mr._CalcDefaultQuery())
+
+ mr.auth = testing_helpers.Blank(effective_ids=set())
+ self.assertEqual('', mr._CalcDefaultQuery())
+
+ mr.auth = testing_helpers.Blank(effective_ids={999})
+ self.assertEqual('', mr._CalcDefaultQuery())
+
+ def testIssueListURL_Customized_Member(self):
+ mr = monorailrequest.MonorailRequest(None)
+ mr.query = None
+ mr.can = 2
+ mr.project = self.project
+ mr.config = self.config
+ mr.config.member_default_query = 'owner:me'
+ mr.auth = testing_helpers.Blank(effective_ids={111})
+ self.assertEqual('owner:me', mr._CalcDefaultQuery())
+
+
+class TestMonorailRequestFunctions(unittest.TestCase):
+
+ def testExtractPathIdentifiers_ProjectOnly(self):
+ (username, project_name, hotlist_id,
+ hotlist_name) = monorailrequest._ParsePathIdentifiers(
+ '/p/proj/issues/list?q=foo+OR+bar&ts=1234')
+ self.assertIsNone(username)
+ self.assertIsNone(hotlist_id)
+ self.assertIsNone(hotlist_name)
+ self.assertEqual('proj', project_name)
+
+ def testExtractPathIdentifiers_ViewedUserOnly(self):
+ (username, project_name, hotlist_id,
+ hotlist_name) = monorailrequest._ParsePathIdentifiers(
+ '/u/jrobbins@example.com/')
+ self.assertEqual('jrobbins@example.com', username)
+ self.assertIsNone(project_name)
+ self.assertIsNone(hotlist_id)
+ self.assertIsNone(hotlist_name)
+
+ def testExtractPathIdentifiers_ViewedUserURLSpace(self):
+ (username, project_name, hotlist_id,
+ hotlist_name) = monorailrequest._ParsePathIdentifiers(
+ '/u/jrobbins@example.com/updates')
+ self.assertEqual('jrobbins@example.com', username)
+ self.assertIsNone(project_name)
+ self.assertIsNone(hotlist_id)
+ self.assertIsNone(hotlist_name)
+
+ def testExtractPathIdentifiers_ViewedGroupURLSpace(self):
+ (username, project_name, hotlist_id,
+ hotlist_name) = monorailrequest._ParsePathIdentifiers(
+ '/g/user-group@example.com/updates')
+ self.assertEqual('user-group@example.com', username)
+ self.assertIsNone(project_name)
+ self.assertIsNone(hotlist_id)
+ self.assertIsNone(hotlist_name)
+
+ def testExtractPathIdentifiers_HotlistIssuesURLSpaceById(self):
+ (username, project_name, hotlist_id,
+ hotlist_name) = monorailrequest._ParsePathIdentifiers(
+ '/u/jrobbins@example.com/hotlists/13124?q=stuff&ts=more')
+ self.assertIsNone(hotlist_name)
+ self.assertIsNone(project_name)
+ self.assertEqual('jrobbins@example.com', username)
+ self.assertEqual(13124, hotlist_id)
+
+ def testExtractPathIdentifiers_HotlistIssuesURLSpaceByName(self):
+ (username, project_name, hotlist_id,
+ hotlist_name) = monorailrequest._ParsePathIdentifiers(
+ '/u/jrobbins@example.com/hotlists/testname?q=stuff&ts=more')
+ self.assertIsNone(project_name)
+ self.assertIsNone(hotlist_id)
+ self.assertEqual('jrobbins@example.com', username)
+ self.assertEqual('testname', hotlist_name)
+
+ def testParseColSpec(self):
+ parse = monorailrequest.ParseColSpec
+ self.assertEqual(['PageName', 'Summary', 'Changed', 'ChangedBy'],
+ parse(u'PageName Summary Changed ChangedBy'))
+ self.assertEqual(['Foo-Bar', 'Foo-Bar-Baz', 'Release-1.2', 'Hey', 'There'],
+ parse('Foo-Bar Foo-Bar-Baz Release-1.2 Hey!There'))
+ self.assertEqual(
+ ['\xe7\xaa\xbf\xe8\x8b\xa5\xe7\xb9\xb9'.decode('utf-8'),
+ '\xe5\x9f\xba\xe5\x9c\xb0\xe3\x81\xaf'.decode('utf-8')],
+ parse('\xe7\xaa\xbf\xe8\x8b\xa5\xe7\xb9\xb9 '
+ '\xe5\x9f\xba\xe5\x9c\xb0\xe3\x81\xaf'.decode('utf-8')))
+
+ def testParseColSpec_Dedup(self):
+ """An attacker cannot inflate response size by repeating a column."""
+ parse = monorailrequest.ParseColSpec
+ self.assertEqual([], parse(''))
+ self.assertEqual(
+ ['Aa', 'b', 'c/d'],
+ parse(u'Aa Aa AA AA AA b Aa aa c/d d c aA b aa B C/D D/aa/c'))
+ self.assertEqual(
+ ['A', 'b', 'c/d', 'e', 'f'],
+ parse(u'A b c/d e f g h i j a/k l m/c/a n/o'))
+
+ def testParseColSpec_Huge(self):
+ """An attacker cannot inflate response size with a huge column name."""
+ parse = monorailrequest.ParseColSpec
+ self.assertEqual(
+ ['Aa', 'b', 'c/d'],
+ parse(u'Aa Aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa b c/d'))
+
+ def testParseColSpec_Ignore(self):
+ """We ignore groupby and grid axes that would be useless."""
+ parse = monorailrequest.ParseColSpec
+ self.assertEqual(
+ ['Aa', 'b', 'c/d'],
+ parse(u'Aa AllLabels alllabels Id b opened/summary c/d',
+ ignore=tracker_constants.NOT_USED_IN_GRID_AXES))
+
+
+class TestPermissionLookup(unittest.TestCase):
+ OWNER_ID = 1
+ OTHER_USER_ID = 2
+
+ def setUp(self):
+ self.services = service_manager.Services(
+ project=fake.ProjectService(),
+ user=fake.UserService(),
+ usergroup=fake.UserGroupService())
+ self.services.user.TestAddUser('owner@gmail.com', self.OWNER_ID)
+ self.services.user.TestAddUser('user@gmail.com', self.OTHER_USER_ID)
+ self.live_project = self.services.project.TestAddProject(
+ 'live', owner_ids=[self.OWNER_ID])
+ self.archived_project = self.services.project.TestAddProject(
+ 'archived', owner_ids=[self.OWNER_ID],
+ state=project_pb2.ProjectState.ARCHIVED)
+ self.members_only_project = self.services.project.TestAddProject(
+ 'members-only', owner_ids=[self.OWNER_ID],
+ access=project_pb2.ProjectAccess.MEMBERS_ONLY)
+
+ self.mox = mox.Mox()
+
+ def tearDown(self):
+ self.mox.UnsetStubs()
+
+ def CheckPermissions(self, perms, expect_view, expect_commit, expect_edit):
+ may_view = perms.HasPerm(permissions.VIEW, None, None)
+ self.assertEqual(expect_view, may_view)
+ may_commit = perms.HasPerm(permissions.COMMIT, None, None)
+ self.assertEqual(expect_commit, may_commit)
+ may_edit = perms.HasPerm(permissions.EDIT_PROJECT, None, None)
+ self.assertEqual(expect_edit, may_edit)
+
+ def MakeRequestAsUser(self, project_name, email):
+ self.mox.StubOutWithMock(users, 'get_current_user')
+ users.get_current_user().AndReturn(testing_helpers.Blank(
+ email=lambda: email))
+ self.mox.ReplayAll()
+
+ request = webapp2.Request.blank('/p/' + project_name)
+ mr = monorailrequest.MonorailRequest(self.services)
+ with mr.profiler.Phase('parse user info'):
+ mr.ParseRequest(request, self.services)
+ print('mr.auth is %r' % mr.auth)
+ return mr
+
+ def testOwnerPermissions_Live(self):
+ mr = self.MakeRequestAsUser('live', 'owner@gmail.com')
+ self.CheckPermissions(mr.perms, True, True, True)
+
+ def testOwnerPermissions_Archived(self):
+ mr = self.MakeRequestAsUser('archived', 'owner@gmail.com')
+ self.CheckPermissions(mr.perms, True, False, True)
+
+ def testOwnerPermissions_MembersOnly(self):
+ mr = self.MakeRequestAsUser('members-only', 'owner@gmail.com')
+ self.CheckPermissions(mr.perms, True, True, True)
+
+ def testExternalUserPermissions_Live(self):
+ mr = self.MakeRequestAsUser('live', 'user@gmail.com')
+ self.CheckPermissions(mr.perms, True, False, False)
+
+ def testExternalUserPermissions_Archived(self):
+ mr = self.MakeRequestAsUser('archived', 'user@gmail.com')
+ self.CheckPermissions(mr.perms, False, False, False)
+
+ def testExternalUserPermissions_MembersOnly(self):
+ mr = self.MakeRequestAsUser('members-only', 'user@gmail.com')
+ self.CheckPermissions(mr.perms, False, False, False)
diff --git a/framework/test/paginate_test.py b/framework/test/paginate_test.py
new file mode 100644
index 0000000..99adaa9
--- /dev/null
+++ b/framework/test/paginate_test.py
@@ -0,0 +1,145 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for pagination classes."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+from google.appengine.ext import testbed
+
+from framework import exceptions
+from framework import paginate
+from testing import testing_helpers
+from proto import secrets_pb2
+
+
+class PageTokenTest(unittest.TestCase):
+
+ def setUp(self):
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_memcache_stub()
+ self.testbed.init_datastore_v3_stub()
+
+ def testGeneratePageToken_DiffRequests(self):
+ request_cont_1 = secrets_pb2.ListRequestContents(
+ parent='same', page_size=1, order_by='same', query='same')
+ request_cont_2 = secrets_pb2.ListRequestContents(
+ parent='same', page_size=2, order_by='same', query='same')
+ start = 10
+ self.assertNotEqual(
+ paginate.GeneratePageToken(request_cont_1, start),
+ paginate.GeneratePageToken(request_cont_2, start))
+
+ def testValidateAndParsePageToken(self):
+ request_cont_1 = secrets_pb2.ListRequestContents(
+ parent='projects/chicken', page_size=1, order_by='boks', query='hay')
+ start = 2
+ token = paginate.GeneratePageToken(request_cont_1, start)
+ self.assertEqual(
+ start,
+ paginate.ValidateAndParsePageToken(token, request_cont_1))
+
+ def testValidateAndParsePageToken_InvalidContents(self):
+ request_cont_1 = secrets_pb2.ListRequestContents(
+ parent='projects/chicken', page_size=1, order_by='boks', query='hay')
+ start = 2
+ token = paginate.GeneratePageToken(request_cont_1, start)
+
+ request_cont_diff = secrets_pb2.ListRequestContents(
+ parent='projects/goose', page_size=1, order_by='boks', query='hay')
+ with self.assertRaises(exceptions.PageTokenException):
+ paginate.ValidateAndParsePageToken(token, request_cont_diff)
+
+ def testValidateAndParsePageToken_InvalidSerializedToken(self):
+ request_cont = secrets_pb2.ListRequestContents()
+ with self.assertRaises(exceptions.PageTokenException):
+ paginate.ValidateAndParsePageToken('sldkfj87', request_cont)
+
+ def testValidateAndParsePageToken_InvalidTokenFormat(self):
+ request_cont = secrets_pb2.ListRequestContents()
+ with self.assertRaises(exceptions.PageTokenException):
+ paginate.ValidateAndParsePageToken('///sldkfj87', request_cont)
+
+
+class PaginateTest(unittest.TestCase):
+
+ def testVirtualPagination(self):
+ # Paginating 0 results on a page that can hold 100.
+ mr = testing_helpers.MakeMonorailRequest(path='/issues/list')
+ total_count = 0
+ items_per_page = 100
+ start = 0
+ vp = paginate.VirtualPagination(total_count, items_per_page, start)
+ self.assertEqual(vp.num, 100)
+ self.assertEqual(vp.start, 1)
+ self.assertEqual(vp.last, 0)
+ self.assertFalse(vp.visible)
+
+ # Paginating 12 results on a page that can hold 100.
+ mr = testing_helpers.MakeMonorailRequest(path='/issues/list')
+ vp = paginate.VirtualPagination(12, 100, 0)
+ self.assertEqual(vp.num, 100)
+ self.assertEqual(vp.start, 1)
+ self.assertEqual(vp.last, 12)
+ self.assertTrue(vp.visible)
+
+ # Paginating 12 results on a page that can hold 10.
+ mr = testing_helpers.MakeMonorailRequest(path='/issues/list?num=10')
+ vp = paginate.VirtualPagination(12, 10, 0)
+ self.assertEqual(vp.num, 10)
+ self.assertEqual(vp.start, 1)
+ self.assertEqual(vp.last, 10)
+ self.assertTrue(vp.visible)
+
+ # Paginating 12 results starting at 5 on page that can hold 10.
+ mr = testing_helpers.MakeMonorailRequest(
+ path='/issues/list?start=5&num=10')
+ vp = paginate.VirtualPagination(12, 10, 5)
+ self.assertEqual(vp.num, 10)
+ self.assertEqual(vp.start, 6)
+ self.assertEqual(vp.last, 12)
+ self.assertTrue(vp.visible)
+
+ # Paginating 123 results on a page that can hold 100.
+ mr = testing_helpers.MakeMonorailRequest(path='/issues/list')
+ vp = paginate.VirtualPagination(123, 100, 0)
+ self.assertEqual(vp.num, 100)
+ self.assertEqual(vp.start, 1)
+ self.assertEqual(vp.last, 100)
+ self.assertTrue(vp.visible)
+
+ # Paginating 123 results on second page that can hold 100.
+ mr = testing_helpers.MakeMonorailRequest(path='/issues/list?start=100')
+ vp = paginate.VirtualPagination(123, 100, 100)
+ self.assertEqual(vp.num, 100)
+ self.assertEqual(vp.start, 101)
+ self.assertEqual(vp.last, 123)
+ self.assertTrue(vp.visible)
+
+ # Paginating a huge number of objects will show at most 1000 per page.
+ mr = testing_helpers.MakeMonorailRequest(path='/issues/list?num=9999')
+ vp = paginate.VirtualPagination(12345, 9999, 0)
+ self.assertEqual(vp.num, 1000)
+ self.assertEqual(vp.start, 1)
+ self.assertEqual(vp.last, 1000)
+ self.assertTrue(vp.visible)
+
+ # Test urls for a hotlist pagination
+ mr = testing_helpers.MakeMonorailRequest(
+ path='/u/hotlists/17?num=5&start=4')
+ mr.hotlist_id = 17
+ mr.auth.user_id = 112
+ vp = paginate.VirtualPagination(12, 5, 4,
+ list_page_url='/u/112/hotlists/17')
+ self.assertEqual(vp.num, 5)
+ self.assertEqual(vp.start, 5)
+ self.assertEqual(vp.last, 9)
+ self.assertTrue(vp.visible)
+ self.assertEqual('/u/112/hotlists/17?num=5&start=9', vp.next_url)
+ self.assertEqual('/u/112/hotlists/17?num=5&start=0', vp.prev_url)
diff --git a/framework/test/permissions_test.py b/framework/test/permissions_test.py
new file mode 100644
index 0000000..0917b53
--- /dev/null
+++ b/framework/test/permissions_test.py
@@ -0,0 +1,1860 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for permissions.py."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import time
+import unittest
+
+import mox
+
+import settings
+from framework import authdata
+from framework import framework_constants
+from framework import framework_views
+from framework import permissions
+from proto import features_pb2
+from proto import project_pb2
+from proto import site_pb2
+from proto import tracker_pb2
+from proto import user_pb2
+from proto import usergroup_pb2
+from testing import fake
+from testing import testing_helpers
+from tracker import tracker_bizobj
+
+
+class PermissionSetTest(unittest.TestCase):
+
+ def setUp(self):
+ self.perms = permissions.PermissionSet(['A', 'b', 'Cc'])
+ self.proj = project_pb2.Project()
+ self.proj.contributor_ids.append(111)
+ self.proj.contributor_ids.append(222)
+ self.proj.extra_perms.append(project_pb2.Project.ExtraPerms(
+ member_id=111, perms=['Cc', 'D', 'e', 'Ff']))
+ self.proj.extra_perms.append(project_pb2.Project.ExtraPerms(
+ member_id=222, perms=['G', 'H']))
+ # user 3 used to be a member and had extra perms, but no longer in project.
+ self.proj.extra_perms.append(project_pb2.Project.ExtraPerms(
+ member_id=333, perms=['G', 'H']))
+
+ def testGetAttr(self):
+ self.assertTrue(self.perms.a)
+ self.assertTrue(self.perms.A)
+ self.assertTrue(self.perms.b)
+ self.assertTrue(self.perms.Cc)
+ self.assertTrue(self.perms.CC)
+
+ self.assertFalse(self.perms.z)
+ self.assertFalse(self.perms.Z)
+
+ def testCanUsePerm_Anonymous(self):
+ effective_ids = set()
+ self.assertTrue(self.perms.CanUsePerm('A', effective_ids, self.proj, []))
+ self.assertFalse(self.perms.CanUsePerm('D', effective_ids, self.proj, []))
+ self.assertFalse(self.perms.CanUsePerm('Z', effective_ids, self.proj, []))
+
+ def testCanUsePerm_SignedInNoGroups(self):
+ effective_ids = {111}
+ self.assertTrue(self.perms.CanUsePerm('A', effective_ids, self.proj, []))
+ self.assertTrue(self.perms.CanUsePerm('D', effective_ids, self.proj, []))
+ self.assertTrue(self.perms.CanUsePerm(
+ 'D', effective_ids, self.proj, ['Restrict-D-A']))
+ self.assertFalse(self.perms.CanUsePerm('G', effective_ids, self.proj, []))
+ self.assertFalse(self.perms.CanUsePerm('Z', effective_ids, self.proj, []))
+
+ effective_ids = {222}
+ self.assertTrue(self.perms.CanUsePerm('A', effective_ids, self.proj, []))
+ self.assertFalse(self.perms.CanUsePerm('D', effective_ids, self.proj, []))
+ self.assertTrue(self.perms.CanUsePerm('G', effective_ids, self.proj, []))
+ self.assertFalse(self.perms.CanUsePerm('Z', effective_ids, self.proj, []))
+ self.assertFalse(self.perms.CanUsePerm(
+ 'Z', effective_ids, self.proj, ['Restrict-Z-A']))
+
+ def testCanUsePerm_SignedInWithGroups(self):
+ effective_ids = {111, 222, 333}
+ self.assertTrue(self.perms.CanUsePerm('A', effective_ids, self.proj, []))
+ self.assertTrue(self.perms.CanUsePerm('D', effective_ids, self.proj, []))
+ self.assertTrue(self.perms.CanUsePerm('G', effective_ids, self.proj, []))
+ self.assertTrue(self.perms.CanUsePerm(
+ 'G', effective_ids, self.proj, ['Restrict-G-D']))
+ self.assertFalse(self.perms.CanUsePerm('Z', effective_ids, self.proj, []))
+ self.assertFalse(self.perms.CanUsePerm(
+ 'G', effective_ids, self.proj, ['Restrict-G-Z']))
+
+ def testCanUsePerm_FormerMember(self):
+ effective_ids = {333}
+ self.assertTrue(self.perms.CanUsePerm('A', effective_ids, self.proj, []))
+ self.assertFalse(self.perms.CanUsePerm('D', effective_ids, self.proj, []))
+ self.assertFalse(self.perms.CanUsePerm('G', effective_ids, self.proj, []))
+ self.assertFalse(self.perms.CanUsePerm('Z', effective_ids, self.proj, []))
+
+ def testHasPerm_InPermSet(self):
+ self.assertTrue(self.perms.HasPerm('a', 0, None))
+ self.assertTrue(self.perms.HasPerm('a', 0, self.proj))
+ self.assertTrue(self.perms.HasPerm('A', 0, None))
+ self.assertTrue(self.perms.HasPerm('A', 0, self.proj))
+ self.assertFalse(self.perms.HasPerm('Z', 0, None))
+ self.assertFalse(self.perms.HasPerm('Z', 0, self.proj))
+
+ def testHasPerm_InExtraPerms(self):
+ self.assertTrue(self.perms.HasPerm('d', 111, self.proj))
+ self.assertTrue(self.perms.HasPerm('D', 111, self.proj))
+ self.assertTrue(self.perms.HasPerm('Cc', 111, self.proj))
+ self.assertTrue(self.perms.HasPerm('CC', 111, self.proj))
+ self.assertFalse(self.perms.HasPerm('Z', 111, self.proj))
+
+ self.assertFalse(self.perms.HasPerm('d', 222, self.proj))
+ self.assertFalse(self.perms.HasPerm('D', 222, self.proj))
+
+ # Only current members can have extra permissions
+ self.proj.contributor_ids = []
+ self.assertFalse(self.perms.HasPerm('d', 111, self.proj))
+
+ # TODO(jrobbins): also test consider_restrictions=False and
+ # restriction labels directly in this class.
+
+ def testHasPerm_OverrideExtraPerms(self):
+ # D is an extra perm for 111...
+ self.assertTrue(self.perms.HasPerm('d', 111, self.proj))
+ self.assertTrue(self.perms.HasPerm('D', 111, self.proj))
+ # ...unless we tell HasPerm it isn't.
+ self.assertFalse(self.perms.HasPerm('d', 111, self.proj, []))
+ self.assertFalse(self.perms.HasPerm('D', 111, self.proj, []))
+ # Perms in self.perms are still considered
+ self.assertTrue(self.perms.HasPerm('Cc', 111, self.proj, []))
+ self.assertTrue(self.perms.HasPerm('CC', 111, self.proj, []))
+ # Z is not an extra perm...
+ self.assertFalse(self.perms.HasPerm('Z', 111, self.proj))
+ # ...unless we tell HasPerm it is.
+ self.assertTrue(self.perms.HasPerm('Z', 111, self.proj, ['z']))
+
+ def testHasPerm_GrantedPerms(self):
+ self.assertTrue(self.perms.CanUsePerm(
+ 'A', {111}, self.proj, [], granted_perms=['z']))
+ self.assertTrue(self.perms.CanUsePerm(
+ 'a', {111}, self.proj, [], granted_perms=['z']))
+ self.assertTrue(self.perms.CanUsePerm(
+ 'a', {111}, self.proj, [], granted_perms=['a']))
+ self.assertTrue(self.perms.CanUsePerm(
+ 'Z', {111}, self.proj, [], granted_perms=['y', 'z']))
+ self.assertTrue(self.perms.CanUsePerm(
+ 'z', {111}, self.proj, [], granted_perms=['y', 'z']))
+ self.assertFalse(self.perms.CanUsePerm(
+ 'z', {111}, self.proj, [], granted_perms=['y']))
+
+ def testDebugString(self):
+ self.assertEqual('PermissionSet()',
+ permissions.PermissionSet([]).DebugString())
+ self.assertEqual('PermissionSet(a)',
+ permissions.PermissionSet(['A']).DebugString())
+ self.assertEqual('PermissionSet(a, b, cc)', self.perms.DebugString())
+
+ def testRepr(self):
+ self.assertEqual('PermissionSet(frozenset([]))',
+ permissions.PermissionSet([]).__repr__())
+ self.assertEqual('PermissionSet(frozenset([\'a\']))',
+ permissions.PermissionSet(['A']).__repr__())
+
+
+class PermissionsTest(unittest.TestCase):
+
+ NOW = 1277762224 # Any timestamp will do, we only compare it to itself +/- 1
+ COMMITTER_USER_ID = 111
+ OWNER_USER_ID = 222
+ CONTRIB_USER_ID = 333
+ SITE_ADMIN_USER_ID = 444
+
+ def MakeProject(self, project_name, state, add_members=True, access=None):
+ args = dict(project_name=project_name, state=state)
+ if add_members:
+ args.update(owner_ids=[self.OWNER_USER_ID],
+ committer_ids=[self.COMMITTER_USER_ID],
+ contributor_ids=[self.CONTRIB_USER_ID])
+
+ if access:
+ args.update(access=access)
+
+ return fake.Project(**args)
+
+ def setUp(self):
+ self.live_project = self.MakeProject('live', project_pb2.ProjectState.LIVE)
+ self.archived_project = self.MakeProject(
+ 'archived', project_pb2.ProjectState.ARCHIVED)
+ self.other_live_project = self.MakeProject(
+ 'other_live', project_pb2.ProjectState.LIVE, add_members=False)
+ self.members_only_project = self.MakeProject(
+ 's3kr3t', project_pb2.ProjectState.LIVE,
+ access=project_pb2.ProjectAccess.MEMBERS_ONLY)
+
+ self.nonmember = user_pb2.User()
+ self.member = user_pb2.User()
+ self.owner = user_pb2.User()
+ self.contrib = user_pb2.User()
+ self.site_admin = user_pb2.User()
+ self.site_admin.is_site_admin = True
+ self.borg_user = user_pb2.User(email=settings.borg_service_account)
+
+ self.normal_artifact = tracker_pb2.Issue()
+ self.normal_artifact.labels.extend(['hot', 'Key-Value'])
+ self.normal_artifact.reporter_id = 111
+
+ # Two PermissionSets w/ permissions outside of any project.
+ self.normal_user_perms = permissions.GetPermissions(
+ None, {111}, None)
+ self.admin_perms = permissions.PermissionSet(
+ [permissions.ADMINISTER_SITE,
+ permissions.CREATE_PROJECT])
+
+ self.mox = mox.Mox()
+
+ def tearDown(self):
+ self.mox.UnsetStubs()
+
+ def testGetPermissions_Admin(self):
+ self.assertEqual(
+ permissions.ADMIN_PERMISSIONSET,
+ permissions.GetPermissions(self.site_admin, None, None))
+
+ def testGetPermissions_BorgServiceAccount(self):
+ self.assertEqual(
+ permissions.GROUP_IMPORT_BORG_PERMISSIONSET,
+ permissions.GetPermissions(self.borg_user, None, None))
+
+ def CheckPermissions(self, perms, expected_list):
+ expect_view, expect_commit, expect_edit_project = expected_list
+ self.assertEqual(
+ expect_view, perms.HasPerm(permissions.VIEW, None, None))
+ self.assertEqual(
+ expect_commit, perms.HasPerm(permissions.COMMIT, None, None))
+ self.assertEqual(
+ expect_edit_project,
+ perms.HasPerm(permissions.EDIT_PROJECT, None, None))
+
+ def testAnonPermissions(self):
+ perms = permissions.GetPermissions(None, set(), self.live_project)
+ self.CheckPermissions(perms, [True, False, False])
+
+ perms = permissions.GetPermissions(None, set(), self.members_only_project)
+ self.CheckPermissions(perms, [False, False, False])
+
+ def testNonmemberPermissions(self):
+ perms = permissions.GetPermissions(
+ self.nonmember, {123}, self.live_project)
+ self.CheckPermissions(perms, [True, False, False])
+
+ perms = permissions.GetPermissions(
+ self.nonmember, {123}, self.members_only_project)
+ self.CheckPermissions(perms, [False, False, False])
+
+ def testMemberPermissions(self):
+ perms = permissions.GetPermissions(
+ self.member, {self.COMMITTER_USER_ID}, self.live_project)
+ self.CheckPermissions(perms, [True, True, False])
+
+ perms = permissions.GetPermissions(
+ self.member, {self.COMMITTER_USER_ID}, self.other_live_project)
+ self.CheckPermissions(perms, [True, False, False])
+
+ perms = permissions.GetPermissions(
+ self.member, {self.COMMITTER_USER_ID}, self.members_only_project)
+ self.CheckPermissions(perms, [True, True, False])
+
+ def testOwnerPermissions(self):
+ perms = permissions.GetPermissions(
+ self.owner, {self.OWNER_USER_ID}, self.live_project)
+ self.CheckPermissions(perms, [True, True, True])
+
+ perms = permissions.GetPermissions(
+ self.owner, {self.OWNER_USER_ID}, self.other_live_project)
+ self.CheckPermissions(perms, [True, False, False])
+
+ perms = permissions.GetPermissions(
+ self.owner, {self.OWNER_USER_ID}, self.members_only_project)
+ self.CheckPermissions(perms, [True, True, True])
+
+ def testContributorPermissions(self):
+ perms = permissions.GetPermissions(
+ self.contrib, {self.CONTRIB_USER_ID}, self.live_project)
+ self.CheckPermissions(perms, [True, False, False])
+
+ perms = permissions.GetPermissions(
+ self.contrib, {self.CONTRIB_USER_ID}, self.other_live_project)
+ self.CheckPermissions(perms, [True, False, False])
+
+ perms = permissions.GetPermissions(
+ self.contrib, {self.CONTRIB_USER_ID}, self.members_only_project)
+ self.CheckPermissions(perms, [True, False, False])
+
+ def testLookupPermset_ExactMatch(self):
+ self.assertEqual(
+ permissions.USER_PERMISSIONSET,
+ permissions._LookupPermset(
+ permissions.USER_ROLE, project_pb2.ProjectState.LIVE,
+ project_pb2.ProjectAccess.ANYONE))
+
+ def testLookupPermset_WildcardAccess(self):
+ self.assertEqual(
+ permissions.OWNER_ACTIVE_PERMISSIONSET,
+ permissions._LookupPermset(
+ permissions.OWNER_ROLE, project_pb2.ProjectState.LIVE,
+ project_pb2.ProjectAccess.MEMBERS_ONLY))
+
+ def testGetPermissionKey_AnonUser(self):
+ self.assertEqual(
+ (permissions.ANON_ROLE, permissions.UNDEFINED_STATUS,
+ permissions.UNDEFINED_ACCESS),
+ permissions._GetPermissionKey(None, None))
+ self.assertEqual(
+ (permissions.ANON_ROLE, project_pb2.ProjectState.LIVE,
+ project_pb2.ProjectAccess.ANYONE),
+ permissions._GetPermissionKey(None, self.live_project))
+
+ def testGetPermissionKey_ExpiredProject(self):
+ self.archived_project.delete_time = self.NOW
+ # In an expired project, the user's committe role does not count.
+ self.assertEqual(
+ (permissions.USER_ROLE, project_pb2.ProjectState.ARCHIVED,
+ project_pb2.ProjectAccess.ANYONE),
+ permissions._GetPermissionKey(
+ self.COMMITTER_USER_ID, self.archived_project,
+ expired_before=self.NOW + 1))
+ # If not expired yet, the user's committe role still counts.
+ self.assertEqual(
+ (permissions.COMMITTER_ROLE, project_pb2.ProjectState.ARCHIVED,
+ project_pb2.ProjectAccess.ANYONE),
+ permissions._GetPermissionKey(
+ self.COMMITTER_USER_ID, self.archived_project,
+ expired_before=self.NOW - 1))
+
+ def testGetPermissionKey_DefinedRoles(self):
+ self.assertEqual(
+ (permissions.OWNER_ROLE, project_pb2.ProjectState.LIVE,
+ project_pb2.ProjectAccess.ANYONE),
+ permissions._GetPermissionKey(
+ self.OWNER_USER_ID, self.live_project))
+ self.assertEqual(
+ (permissions.COMMITTER_ROLE, project_pb2.ProjectState.LIVE,
+ project_pb2.ProjectAccess.ANYONE),
+ permissions._GetPermissionKey(
+ self.COMMITTER_USER_ID, self.live_project))
+ self.assertEqual(
+ (permissions.CONTRIBUTOR_ROLE, project_pb2.ProjectState.LIVE,
+ project_pb2.ProjectAccess.ANYONE),
+ permissions._GetPermissionKey(
+ self.CONTRIB_USER_ID, self.live_project))
+
+ def testGetPermissionKey_Nonmember(self):
+ self.assertEqual(
+ (permissions.USER_ROLE, project_pb2.ProjectState.LIVE,
+ project_pb2.ProjectAccess.ANYONE),
+ permissions._GetPermissionKey(
+ 999, self.live_project))
+
+ def testPermissionsImmutable(self):
+ self.assertTrue(isinstance(
+ permissions.EMPTY_PERMISSIONSET.perm_names, frozenset))
+ self.assertTrue(isinstance(
+ permissions.READ_ONLY_PERMISSIONSET.perm_names, frozenset))
+ self.assertTrue(isinstance(
+ permissions.COMMITTER_ACTIVE_PERMISSIONSET.perm_names, frozenset))
+ self.assertTrue(isinstance(
+ permissions.OWNER_ACTIVE_PERMISSIONSET.perm_names, frozenset))
+
+ def testGetExtraPerms(self):
+ project = project_pb2.Project()
+ project.committer_ids.append(222)
+ # User 1 is a former member with left-over extra perms that don't count.
+ project.extra_perms.append(project_pb2.Project.ExtraPerms(
+ member_id=111, perms=['a', 'b', 'c']))
+ project.extra_perms.append(project_pb2.Project.ExtraPerms(
+ member_id=222, perms=['a', 'b', 'c']))
+
+ self.assertListEqual(
+ [],
+ permissions.GetExtraPerms(project, 111))
+ self.assertListEqual(
+ ['a', 'b', 'c'],
+ permissions.GetExtraPerms(project, 222))
+ self.assertListEqual(
+ [],
+ permissions.GetExtraPerms(project, 333))
+
+ def testCanDeleteComment_NoPermissionSet(self):
+ """Test that if no PermissionSet is given, we can't delete comments."""
+ comment = tracker_pb2.IssueComment()
+ commenter = user_pb2.User()
+ # If no PermissionSet is given, the user cannot delete the comment.
+ self.assertFalse(permissions.CanDeleteComment(
+ comment, commenter, 111, None))
+ # Same, with no user specified.
+ self.assertFalse(permissions.CanDeleteComment(
+ comment, commenter, framework_constants.NO_USER_SPECIFIED, None))
+
+ def testCanDeleteComment_AnonUsersCannotDelete(self):
+ """Test that anon users can't delete comments."""
+ comment = tracker_pb2.IssueComment()
+ commenter = user_pb2.User()
+ perms = permissions.PermissionSet([permissions.DELETE_ANY])
+
+ # No logged in user, even with perms from somewhere.
+ self.assertFalse(permissions.CanDeleteComment(
+ comment, commenter, framework_constants.NO_USER_SPECIFIED, perms))
+
+ # No logged in user, even if artifact was already deleted.
+ comment.deleted_by = 111
+ self.assertFalse(permissions.CanDeleteComment(
+ comment, commenter, framework_constants.NO_USER_SPECIFIED, perms))
+
+  def testCanDeleteComment_DeleteAny(self):
+    """Test that users with DeleteAny permission can (un)delete any comment.
+
+    Except for spam comments or comments by banned users, which are checked
+    by the CannotDelete* tests. The same comment object is mutated between
+    assertions, so the order of the checks below matters.
+    """
+    comment = tracker_pb2.IssueComment(user_id=111)
+    commenter = user_pb2.User()
+    perms = permissions.PermissionSet([permissions.DELETE_ANY])
+
+    # Users with DeleteAny permission can delete their own comments.
+    self.assertTrue(permissions.CanDeleteComment(
+        comment, commenter, 111, perms))
+
+    # And also comments by other users.
+    comment.user_id = 999
+    self.assertTrue(permissions.CanDeleteComment(
+        comment, commenter, 111, perms))
+
+    # As well as undelete comments they deleted.
+    comment.deleted_by = 111
+    self.assertTrue(permissions.CanDeleteComment(
+        comment, commenter, 111, perms))
+
+    # Or that other users deleted.
+    comment.deleted_by = 222
+    self.assertTrue(permissions.CanDeleteComment(
+        comment, commenter, 111, perms))
+
+ def testCanDeleteComment_DeleteOwn(self):
+ """Test that users with DeleteOwn permission can delete any comment.
+
+ Except for spam comments or comments by banned users.
+ """
+ comment = tracker_pb2.IssueComment(user_id=111)
+ commenter = user_pb2.User()
+ perms = permissions.PermissionSet([permissions.DELETE_OWN])
+
+ # Users with DeleteOwn permission can delete their own comments.
+ self.assertTrue(permissions.CanDeleteComment(
+ comment, commenter, 111, perms))
+
+ # But not comments by other users
+ comment.user_id = 999
+ self.assertFalse(permissions.CanDeleteComment(
+ comment, commenter, 111, perms))
+
+ # They can undelete comments they deleted.
+ comment.user_id = 111
+ comment.deleted_by = 111
+ self.assertTrue(permissions.CanDeleteComment(
+ comment, commenter, 111, perms))
+
+ # But not comments that other users deleted.
+ comment.deleted_by = 222
+ self.assertFalse(permissions.CanDeleteComment(
+ comment, commenter, 111, perms))
+
+ def testCanDeleteComment_CannotDeleteSpamComments(self):
+ """Test that nobody can (un)delete comments marked as spam."""
+ comment = tracker_pb2.IssueComment(user_id=111, is_spam=True)
+ commenter = user_pb2.User()
+
+ # Nobody can delete comments marked as spam.
+ self.assertFalse(permissions.CanDeleteComment(
+ comment, commenter, 111,
+ permissions.PermissionSet([permissions.DELETE_OWN])))
+ self.assertFalse(permissions.CanDeleteComment(
+ comment, commenter, 222,
+ permissions.PermissionSet([permissions.DELETE_ANY])))
+
+ # Nobody can undelete comments marked as spam.
+ comment.deleted_by = 222
+ self.assertFalse(permissions.CanDeleteComment(
+ comment, commenter, 111,
+ permissions.PermissionSet([permissions.DELETE_OWN])))
+ self.assertFalse(permissions.CanDeleteComment(
+ comment, commenter, 222,
+ permissions.PermissionSet([permissions.DELETE_ANY])))
+
+ def testCanDeleteComment_CannotDeleteCommentsByBannedUser(self):
+ """Test that nobody can (un)delete comments by banned users."""
+ comment = tracker_pb2.IssueComment(user_id=111)
+ commenter = user_pb2.User(banned='Some reason')
+
+ # Nobody can delete comments by banned users.
+ self.assertFalse(permissions.CanDeleteComment(
+ comment, commenter, 111,
+ permissions.PermissionSet([permissions.DELETE_OWN])))
+ self.assertFalse(permissions.CanDeleteComment(
+ comment, commenter, 222,
+ permissions.PermissionSet([permissions.DELETE_ANY])))
+
+ # Nobody can undelete comments by banned users.
+ comment.deleted_by = 222
+ self.assertFalse(permissions.CanDeleteComment(
+ comment, commenter, 111,
+ permissions.PermissionSet([permissions.DELETE_OWN])))
+ self.assertFalse(permissions.CanDeleteComment(
+ comment, commenter, 222,
+ permissions.PermissionSet([permissions.DELETE_ANY])))
+
+ def testCanFlagComment_FlagSpamCanReport(self):
+ """Test that users with FlagSpam permissions can report comments."""
+ comment = tracker_pb2.IssueComment()
+ commenter = user_pb2.User()
+
+ can_flag, is_flagged = permissions.CanFlagComment(
+ comment, commenter, [], 111,
+ permissions.PermissionSet([permissions.FLAG_SPAM]))
+
+ self.assertTrue(can_flag)
+ self.assertFalse(is_flagged)
+
+ def testCanFlagComment_FlagSpamCanUnReportOwn(self):
+ """Test that users with FlagSpam permission can un-report comments they
+ previously reported."""
+ comment = tracker_pb2.IssueComment()
+ commenter = user_pb2.User()
+
+ can_flag, is_flagged = permissions.CanFlagComment(
+ comment, commenter, [111], 111,
+ permissions.PermissionSet([permissions.FLAG_SPAM]))
+
+ self.assertTrue(can_flag)
+ self.assertTrue(is_flagged)
+
+ def testCanFlagComment_FlagSpamCannotUnReportOthers(self):
+ """Test that users with FlagSpam permission doesn't know if other users have
+ reported a comment as spam."""
+ comment = tracker_pb2.IssueComment()
+ commenter = user_pb2.User()
+
+ can_flag, is_flagged = permissions.CanFlagComment(
+ comment, commenter, [222], 111,
+ permissions.PermissionSet([permissions.FLAG_SPAM]))
+
+ self.assertTrue(can_flag)
+ self.assertFalse(is_flagged)
+
+ def testCanFlagComment_FlagSpamCannotUnFlag(self):
+ comment = tracker_pb2.IssueComment(is_spam=True)
+ commenter = user_pb2.User()
+
+ can_flag, is_flagged = permissions.CanFlagComment(
+ comment, commenter, [111], 111,
+ permissions.PermissionSet([permissions.FLAG_SPAM]))
+
+ self.assertFalse(can_flag)
+ self.assertTrue(is_flagged)
+
+ def testCanFlagComment_VerdictSpamCanFlag(self):
+ """Test that users with FlagSpam permissions can flag comments."""
+ comment = tracker_pb2.IssueComment()
+ commenter = user_pb2.User()
+
+ can_flag, is_flagged = permissions.CanFlagComment(
+ comment, commenter, [], 111,
+ permissions.PermissionSet([permissions.VERDICT_SPAM]))
+
+ self.assertTrue(can_flag)
+ self.assertFalse(is_flagged)
+
+ def testCanFlagComment_VerdictSpamCanUnFlag(self):
+ """Test that users with FlagSpam permissions can un-flag comments."""
+ comment = tracker_pb2.IssueComment(is_spam=True)
+ commenter = user_pb2.User()
+
+ can_flag, is_flagged = permissions.CanFlagComment(
+ comment, commenter, [], 111,
+ permissions.PermissionSet([permissions.VERDICT_SPAM]))
+
+ self.assertTrue(can_flag)
+ self.assertTrue(is_flagged)
+
+ def testCanFlagComment_CannotFlagNoPermission(self):
+ """Test that users without permission cannot flag comments."""
+ comment = tracker_pb2.IssueComment()
+ commenter = user_pb2.User()
+
+ can_flag, is_flagged = permissions.CanFlagComment(
+ comment, commenter, [], 111,
+ permissions.PermissionSet([permissions.DELETE_ANY]))
+
+ self.assertFalse(can_flag)
+ self.assertFalse(is_flagged)
+
+ def testCanFlagComment_CannotUnFlagNoPermission(self):
+ """Test that users without permission cannot un-flag comments."""
+ comment = tracker_pb2.IssueComment(is_spam=True)
+ commenter = user_pb2.User()
+
+ can_flag, is_flagged = permissions.CanFlagComment(
+ comment, commenter, [], 111,
+ # Users need the VerdictSpam permission to be able to un-flag comments.
+ permissions.PermissionSet([
+ permissions.DELETE_ANY, permissions.FLAG_SPAM]))
+
+ self.assertFalse(can_flag)
+ self.assertTrue(is_flagged)
+
+ def testCanFlagComment_CannotFlagCommentByBannedUser(self):
+ """Test that nobady can flag comments by banned users."""
+ comment = tracker_pb2.IssueComment()
+ commenter = user_pb2.User(banned='Some reason')
+
+ can_flag, is_flagged = permissions.CanFlagComment(
+ comment, commenter, [], 111,
+ permissions.PermissionSet([
+ permissions.FLAG_SPAM, permissions.VERDICT_SPAM]))
+
+ self.assertFalse(can_flag)
+ self.assertFalse(is_flagged)
+
+ def testCanFlagComment_CannotUnFlagCommentByBannedUser(self):
+ """Test that nobady can un-flag comments by banned users."""
+ comment = tracker_pb2.IssueComment(is_spam=True)
+ commenter = user_pb2.User(banned='Some reason')
+
+ can_flag, is_flagged = permissions.CanFlagComment(
+ comment, commenter, [], 111,
+ permissions.PermissionSet([
+ permissions.FLAG_SPAM, permissions.VERDICT_SPAM]))
+
+ self.assertFalse(can_flag)
+ self.assertTrue(is_flagged)
+
+ def testCanFlagComment_CanUnFlagDeletedSpamComment(self):
+ """Test that we can un-flag a deleted comment that is spam."""
+ comment = tracker_pb2.IssueComment(is_spam=True, deleted_by=111)
+ commenter = user_pb2.User()
+
+ can_flag, is_flagged = permissions.CanFlagComment(
+ comment, commenter, [], 222,
+ permissions.PermissionSet([permissions.VERDICT_SPAM]))
+
+ self.assertTrue(can_flag)
+ self.assertTrue(is_flagged)
+
+ def testCanFlagComment_CannotFlagDeletedComment(self):
+ """Test that nobody can flag a deleted comment that is not spam."""
+ comment = tracker_pb2.IssueComment(deleted_by=111)
+ commenter = user_pb2.User()
+
+ can_flag, is_flagged = permissions.CanFlagComment(
+ comment, commenter, [], 111,
+ permissions.PermissionSet([
+ permissions.FLAG_SPAM, permissions.VERDICT_SPAM,
+ permissions.DELETE_ANY, permissions.DELETE_OWN]))
+
+ self.assertFalse(can_flag)
+ self.assertFalse(is_flagged)
+
+ def testCanViewComment_Normal(self):
+ """Test that we can view comments."""
+ comment = tracker_pb2.IssueComment()
+ commenter = user_pb2.User()
+ # We assume that CanViewIssue was already called. There are no further
+ # restrictions to view this comment.
+ self.assertTrue(permissions.CanViewComment(
+ comment, commenter, 111, None))
+
+ def testCanViewComment_CannotViewCommentsByBannedUser(self):
+ """Test that nobody can view comments by banned users."""
+ comment = tracker_pb2.IssueComment(user_id=111)
+ commenter = user_pb2.User(banned='Some reason')
+
+ # Nobody can view comments by banned users.
+ self.assertFalse(permissions.CanViewComment(
+ comment, commenter, 111, permissions.ADMIN_PERMISSIONSET))
+
+ def testCanViewComment_OnlyModeratorsCanViewSpamComments(self):
+ """Test that only users with VerdictSpam can view spam comments."""
+ comment = tracker_pb2.IssueComment(user_id=111, is_spam=True)
+ commenter = user_pb2.User()
+
+ # Users with VerdictSpam permission can view comments marked as spam.
+ self.assertTrue(permissions.CanViewComment(
+ comment, commenter, 222,
+ permissions.PermissionSet([permissions.VERDICT_SPAM])))
+
+ # Other users cannot view comments marked as spam, even if it is their own
+ # comment.
+ self.assertFalse(permissions.CanViewComment(
+ comment, commenter, 111,
+ permissions.PermissionSet([
+ permissions.FLAG_SPAM, permissions.DELETE_ANY,
+ permissions.DELETE_OWN])))
+
+  def testCanViewComment_DeletedComment(self):
+    """Test that only users who could undelete a deleted comment can view it.
+
+    The comment's user_id and deleted_by fields are mutated between the
+    assertions, so the order of the checks below matters.
+    """
+    comment = tracker_pb2.IssueComment(user_id=111, deleted_by=222)
+    commenter = user_pb2.User()
+
+    # Users with DeleteAny permission can view all deleted comments.
+    self.assertTrue(permissions.CanViewComment(
+        comment, commenter, 333,
+        permissions.PermissionSet([permissions.DELETE_ANY])))
+
+    # Users with DeleteOwn permissions can only see their own comments if they
+    # deleted them.
+    comment.user_id = comment.deleted_by = 333
+    self.assertTrue(permissions.CanViewComment(
+        comment, commenter, 333,
+        permissions.PermissionSet([permissions.DELETE_OWN])))
+
+    # But not their own comments that someone else (111) deleted.
+    comment.deleted_by = 111
+    self.assertFalse(permissions.CanViewComment(
+        comment, commenter, 333,
+        permissions.PermissionSet([permissions.DELETE_OWN])))
+
+ def testCanViewInboundMessage(self):
+ comment = tracker_pb2.IssueComment(user_id=111)
+
+ # Users can view their own inbound messages
+ self.assertTrue(permissions.CanViewInboundMessage(
+ comment, 111, permissions.EMPTY_PERMISSIONSET))
+
+ # Users with the ViewInboundMessages permissions can view inbound messages.
+ self.assertTrue(permissions.CanViewInboundMessage(
+ comment, 333,
+ permissions.PermissionSet([permissions.VIEW_INBOUND_MESSAGES])))
+
+ # Other users cannot view inbound messages.
+ self.assertFalse(permissions.CanViewInboundMessage(
+ comment, 333,
+ permissions.PermissionSet([permissions.VIEW])))
+
+ def testCanViewNormalArifact(self):
+ # Anyone can view a non-restricted artifact.
+ self.assertTrue(permissions.CanView(
+ {111}, permissions.READ_ONLY_PERMISSIONSET,
+ self.live_project, []))
+
+ def testCanCreateProject_NoPerms(self):
+ """Signed out users cannot create projects."""
+ self.assertFalse(permissions.CanCreateProject(
+ permissions.EMPTY_PERMISSIONSET))
+
+ self.assertFalse(permissions.CanCreateProject(
+ permissions.READ_ONLY_PERMISSIONSET))
+
+ def testCanCreateProject_Admin(self):
+ """Site admins can create projects."""
+ self.assertTrue(permissions.CanCreateProject(
+ permissions.ADMIN_PERMISSIONSET))
+
+ def testCanCreateProject_RegularUser(self):
+ """Signed in non-admins can create a project if settings allow ANYONE."""
+ try:
+ orig_restriction = settings.project_creation_restriction
+ ANYONE = site_pb2.UserTypeRestriction.ANYONE
+ ADMIN_ONLY = site_pb2.UserTypeRestriction.ADMIN_ONLY
+ NO_ONE = site_pb2.UserTypeRestriction.NO_ONE
+ perms = permissions.PermissionSet([permissions.CREATE_PROJECT])
+
+ settings.project_creation_restriction = ANYONE
+ self.assertTrue(permissions.CanCreateProject(perms))
+
+ settings.project_creation_restriction = ADMIN_ONLY
+ self.assertFalse(permissions.CanCreateProject(perms))
+
+ settings.project_creation_restriction = NO_ONE
+ self.assertFalse(permissions.CanCreateProject(perms))
+ self.assertFalse(permissions.CanCreateProject(
+ permissions.ADMIN_PERMISSIONSET))
+ finally:
+ settings.project_creation_restriction = orig_restriction
+
+ def testCanCreateGroup_AnyoneWithCreateGroup(self):
+ orig_setting = settings.group_creation_restriction
+ try:
+ settings.group_creation_restriction = site_pb2.UserTypeRestriction.ANYONE
+ self.assertTrue(permissions.CanCreateGroup(
+ permissions.PermissionSet([permissions.CREATE_GROUP])))
+ self.assertFalse(permissions.CanCreateGroup(
+ permissions.PermissionSet([])))
+ finally:
+ settings.group_creation_restriction = orig_setting
+
+ def testCanCreateGroup_AdminOnly(self):
+ orig_setting = settings.group_creation_restriction
+ try:
+ ADMIN_ONLY = site_pb2.UserTypeRestriction.ADMIN_ONLY
+ settings.group_creation_restriction = ADMIN_ONLY
+ self.assertTrue(permissions.CanCreateGroup(
+ permissions.PermissionSet([permissions.ADMINISTER_SITE])))
+ self.assertFalse(permissions.CanCreateGroup(
+ permissions.PermissionSet([permissions.CREATE_GROUP])))
+ self.assertFalse(permissions.CanCreateGroup(
+ permissions.PermissionSet([])))
+ finally:
+ settings.group_creation_restriction = orig_setting
+
+ def testCanCreateGroup_UnspecifiedSetting(self):
+ orig_setting = settings.group_creation_restriction
+ try:
+ settings.group_creation_restriction = None
+ self.assertFalse(permissions.CanCreateGroup(
+ permissions.PermissionSet([permissions.ADMINISTER_SITE])))
+ self.assertFalse(permissions.CanCreateGroup(
+ permissions.PermissionSet([permissions.CREATE_GROUP])))
+ self.assertFalse(permissions.CanCreateGroup(
+ permissions.PermissionSet([])))
+ finally:
+ settings.group_creation_restriction = orig_setting
+
+ def testCanEditGroup_HasPerm(self):
+ self.assertTrue(permissions.CanEditGroup(
+ permissions.PermissionSet([permissions.EDIT_GROUP]), None, None))
+
+ def testCanEditGroup_IsOwner(self):
+ self.assertTrue(permissions.CanEditGroup(
+ permissions.PermissionSet([]), {111}, {111}))
+
+ def testCanEditGroup_Otherwise(self):
+ self.assertFalse(permissions.CanEditGroup(
+ permissions.PermissionSet([]), {111}, {222}))
+
+ def testCanViewGroupMembers_HasPerm(self):
+ self.assertTrue(permissions.CanViewGroupMembers(
+ permissions.PermissionSet([permissions.VIEW_GROUP]),
+ None, None, None, None, None))
+
+ def testCanViewGroupMembers_IsMemberOfFriendProject(self):
+ group_settings = usergroup_pb2.MakeSettings('owners', friend_projects=[890])
+ self.assertFalse(permissions.CanViewGroupMembers(
+ permissions.PermissionSet([]),
+ {111}, group_settings, {222}, {333}, {789}))
+ self.assertTrue(permissions.CanViewGroupMembers(
+ permissions.PermissionSet([]),
+ {111}, group_settings, {222}, {333}, {789, 890}))
+
+ def testCanViewGroupMembers_VisibleToOwner(self):
+ group_settings = usergroup_pb2.MakeSettings('owners')
+ self.assertFalse(permissions.CanViewGroupMembers(
+ permissions.PermissionSet([]),
+ {111}, group_settings, {222}, {333}, {789}))
+ self.assertFalse(permissions.CanViewGroupMembers(
+ permissions.PermissionSet([]),
+ {222}, group_settings, {222}, {333}, {789}))
+ self.assertTrue(permissions.CanViewGroupMembers(
+ permissions.PermissionSet([]),
+ {333}, group_settings, {222}, {333}, {789}))
+
+ def testCanViewGroupMembers_IsVisibleToMember(self):
+ group_settings = usergroup_pb2.MakeSettings('members')
+ self.assertFalse(permissions.CanViewGroupMembers(
+ permissions.PermissionSet([]),
+ {111}, group_settings, {222}, {333}, {789}))
+ self.assertTrue(permissions.CanViewGroupMembers(
+ permissions.PermissionSet([]),
+ {222}, group_settings, {222}, {333}, {789}))
+ self.assertTrue(permissions.CanViewGroupMembers(
+ permissions.PermissionSet([]),
+ {333}, group_settings, {222}, {333}, {789}))
+
+ def testCanViewGroupMembers_AnyoneCanView(self):
+ group_settings = usergroup_pb2.MakeSettings('anyone')
+ self.assertTrue(permissions.CanViewGroupMembers(
+ permissions.PermissionSet([]),
+ {111}, group_settings, {222}, {333}, {789}))
+
+ def testIsBanned_AnonUser(self):
+ user_view = framework_views.StuffUserView(None, None, True)
+ self.assertFalse(permissions.IsBanned(None, user_view))
+
+ def testIsBanned_NormalUser(self):
+ user = user_pb2.User()
+ user_view = framework_views.StuffUserView(None, None, True)
+ self.assertFalse(permissions.IsBanned(user, user_view))
+
+ def testIsBanned_BannedUser(self):
+ user = user_pb2.User()
+ user.banned = 'spammer'
+ user_view = framework_views.StuffUserView(None, None, True)
+ self.assertTrue(permissions.IsBanned(user, user_view))
+
+ def testIsBanned_BadDomainUser(self):
+ user = user_pb2.User()
+ self.assertFalse(permissions.IsBanned(user, None))
+
+ user_view = framework_views.StuffUserView(None, None, True)
+ user_view.domain = 'spammer.com'
+ self.assertFalse(permissions.IsBanned(user, user_view))
+
+ orig_banned_user_domains = settings.banned_user_domains
+ settings.banned_user_domains = ['spammer.com', 'phisher.com']
+ self.assertTrue(permissions.IsBanned(user, user_view))
+ settings.banned_user_domains = orig_banned_user_domains
+
+ def testIsBanned_PlusAddressUser(self):
+ """We don't allow users who have + in their email address."""
+ user = user_pb2.User(email='user@example.com')
+ self.assertFalse(permissions.IsBanned(user, None))
+
+ user.email = 'user+shadystuff@example.com'
+ self.assertTrue(permissions.IsBanned(user, None))
+
+ def testCanExpungeUser_Admin(self):
+ mr = testing_helpers.MakeMonorailRequest()
+ mr.perms = permissions.ADMIN_PERMISSIONSET
+ self.assertTrue(permissions.CanExpungeUsers(mr))
+
+ def testGetCustomPermissions(self):
+ project = project_pb2.Project()
+ self.assertListEqual([], permissions.GetCustomPermissions(project))
+
+ project.extra_perms.append(project_pb2.Project.ExtraPerms(
+ perms=['Core', 'Elite', 'Gold']))
+ self.assertListEqual(['Core', 'Elite', 'Gold'],
+ permissions.GetCustomPermissions(project))
+
+ project.extra_perms.append(project_pb2.Project.ExtraPerms(
+ perms=['Silver', 'Gold', 'Bronze']))
+ self.assertListEqual(['Bronze', 'Core', 'Elite', 'Gold', 'Silver'],
+ permissions.GetCustomPermissions(project))
+
+ # View is not returned because it is a starndard permission.
+ project.extra_perms.append(project_pb2.Project.ExtraPerms(
+ perms=['Bronze', permissions.VIEW]))
+ self.assertListEqual(['Bronze', 'Core', 'Elite', 'Gold', 'Silver'],
+ permissions.GetCustomPermissions(project))
+
+ def testUserCanViewProject(self):
+ """Checks UserCanViewProject() for live and archived projects.
+
+ The live project is visible to a member and to anonymous users. The
+ archived project, while its delete_time is after the stubbed clock, is
+ visible to its owner and to site admins but not to anonymous users. Once
+ delete_time has passed, only site admins can view it.
+ """
+ # Freeze time.time() at self.NOW. Eight expectations are recorded, matching
+ # the eight UserCanViewProject() calls below; mox is strict about the count.
+ self.mox.StubOutWithMock(time, 'time')
+ for _ in range(8):
+ time.time().AndReturn(self.NOW)
+ self.mox.ReplayAll()
+
+ # Live project: visible to a member and to anonymous users.
+ self.assertTrue(permissions.UserCanViewProject(
+ self.member, {self.COMMITTER_USER_ID}, self.live_project))
+ self.assertTrue(permissions.UserCanViewProject(
+ None, None, self.live_project))
+
+ # Archived project (fixture defined outside this view, presumably setUp)
+ # whose delete_time is still ahead: owners and site admins only.
+ self.archived_project.delete_time = self.NOW + 1
+ self.assertFalse(permissions.UserCanViewProject(
+ None, None, self.archived_project))
+ self.assertTrue(permissions.UserCanViewProject(
+ self.owner, {self.OWNER_USER_ID}, self.archived_project))
+ self.assertTrue(permissions.UserCanViewProject(
+ self.site_admin, {self.SITE_ADMIN_USER_ID},
+ self.archived_project))
+
+ # delete_time has passed: even the owner loses access; site admins keep it.
+ self.archived_project.delete_time = self.NOW - 1
+ self.assertFalse(permissions.UserCanViewProject(
+ None, None, self.archived_project))
+ self.assertFalse(permissions.UserCanViewProject(
+ self.owner, {self.OWNER_USER_ID}, self.archived_project))
+ self.assertTrue(permissions.UserCanViewProject(
+ self.site_admin, {self.SITE_ADMIN_USER_ID},
+ self.archived_project))
+
+ # Check that every recorded time.time() expectation was consumed.
+ self.mox.VerifyAll()
+
+ def CheckExpired(self, state, expected_to_be_reapable):
+ proj = project_pb2.Project()
+ proj.state = state
+ proj.delete_time = self.NOW + 1
+ self.assertFalse(permissions.IsExpired(proj))
+
+ proj.delete_time = self.NOW - 1
+ self.assertEqual(expected_to_be_reapable, permissions.IsExpired(proj))
+
+ proj.delete_time = self.NOW - 1
+ self.assertFalse(permissions.IsExpired(proj, expired_before=self.NOW - 2))
+
+ def testIsExpired_Live(self):
+ self.CheckExpired(project_pb2.ProjectState.LIVE, False)
+
+ def testIsExpired_Archived(self):
+ """Archived projects are expired once their delete_time has passed."""
+ # Freeze time.time() at self.NOW. Two expectations are recorded; presumably
+ # they are consumed by the two IsExpired() calls in CheckExpired() that do
+ # not pass expired_before.
+ self.mox.StubOutWithMock(time, 'time')
+ for _ in range(2):
+ time.time().AndReturn(self.NOW)
+ self.mox.ReplayAll()
+
+ self.CheckExpired(project_pb2.ProjectState.ARCHIVED, True)
+
+ # Check that every recorded time.time() expectation was consumed.
+ self.mox.VerifyAll()
+
+
+class PermissionsCheckTest(unittest.TestCase):
+
+ def setUp(self):
+ self.perms = permissions.PermissionSet(['a', 'b', 'c'])
+
+ self.proj = project_pb2.Project()
+ self.proj.committer_ids.append(111)
+ self.proj.extra_perms.append(project_pb2.Project.ExtraPerms(
+ member_id=111, perms=['d']))
+
+ # Note: z is an example of a perm that the user does not have.
+ # Note: q is an example of an irrelevant perm that the user does not have.
+
+ def DoCanUsePerm(self, perm, project='default', user_id=None, restrict=''):
+ """Wrapper function to call CanUsePerm()."""
+ if project == 'default':
+ project = self.proj
+ return self.perms.CanUsePerm(
+ perm, {user_id or 111}, project, restrict.split())
+
+ def testHasPermNoRestrictions(self):
+ self.assertTrue(self.DoCanUsePerm('a'))
+ self.assertTrue(self.DoCanUsePerm('A'))
+ self.assertFalse(self.DoCanUsePerm('z'))
+ self.assertTrue(self.DoCanUsePerm('d'))
+ self.assertFalse(self.DoCanUsePerm('d', user_id=222))
+ self.assertFalse(self.DoCanUsePerm('d', project=project_pb2.Project()))
+
+ def testHasPermOperationRestrictions(self):
+ self.assertTrue(self.DoCanUsePerm('a', restrict='Restrict-a-b'))
+ self.assertTrue(self.DoCanUsePerm('a', restrict='Restrict-b-z'))
+ self.assertTrue(self.DoCanUsePerm('a', restrict='Restrict-a-d'))
+ self.assertTrue(self.DoCanUsePerm('d', restrict='Restrict-d-a'))
+ self.assertTrue(self.DoCanUsePerm(
+ 'd', restrict='Restrict-q-z Restrict-q-d Restrict-d-a'))
+
+ self.assertFalse(self.DoCanUsePerm('a', restrict='Restrict-a-z'))
+ self.assertFalse(self.DoCanUsePerm('d', restrict='Restrict-d-z'))
+ self.assertFalse(self.DoCanUsePerm(
+ 'd', restrict='Restrict-d-a Restrict-d-z'))
+
+ def testHasPermOutsideProjectScope(self):
+ self.assertTrue(self.DoCanUsePerm('a', project=None))
+ self.assertTrue(self.DoCanUsePerm(
+ 'a', project=None, restrict='Restrict-a-c'))
+ self.assertTrue(self.DoCanUsePerm(
+ 'a', project=None, restrict='Restrict-q-z'))
+
+ self.assertFalse(self.DoCanUsePerm('z', project=None))
+ self.assertFalse(self.DoCanUsePerm(
+ 'a', project=None, restrict='Restrict-a-d'))
+
+
+class CanViewProjectContributorListTest(unittest.TestCase):
+
+ def testCanViewProjectContributorList_NoProject(self):
+ mr = testing_helpers.MakeMonorailRequest(path='/')
+ self.assertFalse(permissions.CanViewContributorList(mr, mr.project))
+
+ def testCanViewProjectContributorList_NormalProject(self):
+ project = project_pb2.Project()
+ mr = testing_helpers.MakeMonorailRequest(
+ path='/p/proj/', project=project)
+ self.assertTrue(permissions.CanViewContributorList(mr, mr.project))
+
+ def testCanViewProjectContributorList_ProjectWithOptionSet(self):
+ project = project_pb2.Project()
+ project.only_owners_see_contributors = True
+
+ for perms in [permissions.READ_ONLY_PERMISSIONSET,
+ permissions.CONTRIBUTOR_ACTIVE_PERMISSIONSET,
+ permissions.CONTRIBUTOR_INACTIVE_PERMISSIONSET]:
+ mr = testing_helpers.MakeMonorailRequest(
+ path='/p/proj/', project=project, perms=perms)
+ self.assertFalse(permissions.CanViewContributorList(mr, mr.project))
+
+ for perms in [permissions.COMMITTER_ACTIVE_PERMISSIONSET,
+ permissions.COMMITTER_INACTIVE_PERMISSIONSET,
+ permissions.OWNER_ACTIVE_PERMISSIONSET,
+ permissions.OWNER_INACTIVE_PERMISSIONSET,
+ permissions.ADMIN_PERMISSIONSET]:
+ mr = testing_helpers.MakeMonorailRequest(
+ path='/p/proj/', project=project, perms=perms)
+ self.assertTrue(permissions.CanViewContributorList(mr, mr.project))
+
+
+class ShouldCheckForAbandonmentTest(unittest.TestCase):
+
+ def setUp(self):
+ self.mr = testing_helpers.Blank(
+ project=project_pb2.Project(),
+ auth=authdata.AuthData())
+
+ def testOwner(self):
+ self.mr.auth.effective_ids = {111}
+ self.mr.perms = permissions.OWNER_ACTIVE_PERMISSIONSET
+ self.assertTrue(permissions.ShouldCheckForAbandonment(self.mr))
+
+ def testNonOwner(self):
+ self.mr.auth.effective_ids = {222}
+ self.mr.perms = permissions.COMMITTER_ACTIVE_PERMISSIONSET
+ self.assertFalse(permissions.ShouldCheckForAbandonment(self.mr))
+ self.mr.perms = permissions.CONTRIBUTOR_ACTIVE_PERMISSIONSET
+ self.assertFalse(permissions.ShouldCheckForAbandonment(self.mr))
+ self.mr.perms = permissions.USER_PERMISSIONSET
+ self.assertFalse(permissions.ShouldCheckForAbandonment(self.mr))
+ self.mr.perms = permissions.EMPTY_PERMISSIONSET
+ self.assertFalse(permissions.ShouldCheckForAbandonment(self.mr))
+
+ def testSiteAdmin(self):
+ self.mr.auth.effective_ids = {111}
+ self.mr.perms = permissions.ADMIN_PERMISSIONSET
+ self.assertFalse(permissions.ShouldCheckForAbandonment(self.mr))
+
+
+class RestrictionLabelsTest(unittest.TestCase):
+
+ ORIG_SUMMARY = 'this is the orginal summary'
+ ORIG_LABELS = ['one', 'two']
+
+ def testIsRestrictLabel(self):
+ self.assertFalse(permissions.IsRestrictLabel('Usability'))
+ self.assertTrue(permissions.IsRestrictLabel('Restrict-View-CoreTeam'))
+ # Doing it again will test the cached results.
+ self.assertFalse(permissions.IsRestrictLabel('Usability'))
+ self.assertTrue(permissions.IsRestrictLabel('Restrict-View-CoreTeam'))
+
+ self.assertFalse(permissions.IsRestrictLabel('Usability', perm='View'))
+ self.assertTrue(permissions.IsRestrictLabel(
+ 'Restrict-View-CoreTeam', perm='View'))
+
+ # This one is a restriction label, but not the kind that we want.
+ self.assertFalse(permissions.IsRestrictLabel(
+ 'Restrict-View-CoreTeam', perm='Delete'))
+
+ def testGetRestrictions_NoIssue(self):
+ self.assertEqual([], permissions.GetRestrictions(None))
+
+ def testGetRestrictions_PermSpecified(self):
+ """We can return restiction labels related to the given perm."""
+ art = fake.MakeTestIssue(
+ 789, 1, self.ORIG_SUMMARY, 'New', 0, labels=self.ORIG_LABELS)
+ self.assertEqual([], permissions.GetRestrictions(art, perm='view'))
+
+ art = fake.MakeTestIssue(
+ 789, 1, self.ORIG_SUMMARY, 'New', 0,
+ labels=['Restrict-View-Core', 'Hot',
+ 'Restrict-EditIssue-Commit', 'Restrict-EditIssue-Core'])
+ self.assertEqual(
+ ['restrict-view-core'],
+ permissions.GetRestrictions(art, perm='view'))
+ self.assertEqual(
+ ['restrict-view-core'],
+ permissions.GetRestrictions(art, perm='View'))
+ self.assertEqual(
+ ['restrict-editissue-commit', 'restrict-editissue-core'],
+ permissions.GetRestrictions(art, perm='EditIssue'))
+
+ def testGetRestrictions_NoPerm(self):
+ art = fake.MakeTestIssue(
+ 789, 1, self.ORIG_SUMMARY, 'New', 0, labels=self.ORIG_LABELS)
+ self.assertEqual([], permissions.GetRestrictions(art))
+
+ art = fake.MakeTestIssue(
+ 789, 1, self.ORIG_SUMMARY, 'New', 0,
+ labels=['Restrict-MissingThirdPart', 'Hot'])
+ self.assertEqual([], permissions.GetRestrictions(art))
+
+ art = fake.MakeTestIssue(
+ 789, 1, self.ORIG_SUMMARY, 'New', 0,
+ labels=['Restrict-View-Core', 'Hot'])
+ self.assertEqual(['restrict-view-core'], permissions.GetRestrictions(art))
+
+ art = fake.MakeTestIssue(
+ 789, 1, self.ORIG_SUMMARY, 'New', 0,
+ labels=['Restrict-View-Core', 'Hot'],
+ derived_labels=['Color-Red', 'Restrict-EditIssue-GoldMembers'])
+ self.assertEqual(
+ ['restrict-view-core', 'restrict-editissue-goldmembers'],
+ permissions.GetRestrictions(art))
+
+ art = fake.MakeTestIssue(
+ 789, 1, self.ORIG_SUMMARY, 'New', 0,
+ labels=['restrict-view-core', 'hot'],
+ derived_labels=['Color-Red', 'RESTRICT-EDITISSUE-GOLDMEMBERS'])
+ self.assertEqual(
+ ['restrict-view-core', 'restrict-editissue-goldmembers'],
+ permissions.GetRestrictions(art))
+
+
+REPORTER_ID = 111
+OWNER_ID = 222
+CC_ID = 333
+OTHER_ID = 444
+APPROVER_ID = 555
+
+
+class IssuePermissionsTest(unittest.TestCase):
+
+ REGULAR_ISSUE = tracker_pb2.Issue()
+ REGULAR_ISSUE.reporter_id = REPORTER_ID
+
+ DELETED_ISSUE = tracker_pb2.Issue()
+ DELETED_ISSUE.deleted = True
+ DELETED_ISSUE.reporter_id = REPORTER_ID
+
+ RESTRICTED_ISSUE = tracker_pb2.Issue()
+ RESTRICTED_ISSUE.reporter_id = REPORTER_ID
+ RESTRICTED_ISSUE.owner_id = OWNER_ID
+ RESTRICTED_ISSUE.cc_ids.append(CC_ID)
+ RESTRICTED_ISSUE.approval_values.append(
+ tracker_pb2.ApprovalValue(approver_ids=[APPROVER_ID])
+ )
+ RESTRICTED_ISSUE.labels.append('Restrict-View-Commit')
+
+ RESTRICTED_ISSUE2 = tracker_pb2.Issue()
+ RESTRICTED_ISSUE2.reporter_id = REPORTER_ID
+ # RESTRICTED_ISSUE2 has no owner
+ RESTRICTED_ISSUE2.cc_ids.append(CC_ID)
+ RESTRICTED_ISSUE2.labels.append('Restrict-View-Commit')
+
+ RESTRICTED_ISSUE3 = tracker_pb2.Issue()
+ RESTRICTED_ISSUE3.reporter_id = REPORTER_ID
+ RESTRICTED_ISSUE3.owner_id = OWNER_ID
+ # Restrict to a permission that no one has.
+ RESTRICTED_ISSUE3.labels.append('Restrict-EditIssue-Foo')
+
+ PROJECT = project_pb2.Project()
+
+ ADMIN_PERMS = permissions.ADMIN_PERMISSIONSET
+ PERMS = permissions.EMPTY_PERMISSIONSET
+
+ def testUpdateIssuePermissions_Normal(self):
+ perms = permissions.UpdateIssuePermissions(
+ permissions.COMMITTER_ACTIVE_PERMISSIONSET, self.PROJECT,
+ self.REGULAR_ISSUE, {})
+
+ self.assertEqual(
+ ['addissuecomment',
+ 'commit',
+ 'createissue',
+ 'deleteown',
+ 'editissue',
+ 'flagspam',
+ 'setstar',
+ 'verdictspam',
+ 'view',
+ 'viewcontributorlist',
+ 'viewinboundmessages',
+ 'viewquota'],
+ sorted(perms.perm_names))
+
+ def testUpdateIssuePermissions_FromConfig(self):
+ config = tracker_pb2.ProjectIssueConfig(
+ field_defs=[tracker_pb2.FieldDef(field_id=123, grants_perm='Granted')])
+ issue = tracker_pb2.Issue(
+ field_values=[tracker_pb2.FieldValue(field_id=123, user_id=111)])
+ perms = permissions.UpdateIssuePermissions(
+ permissions.USER_PERMISSIONSET, self.PROJECT, issue, {111},
+ config=config)
+ self.assertIn('granted', perms.perm_names)
+
+ def testUpdateIssuePermissions_ExtraPerms(self):
+ project = project_pb2.Project()
+ project.committer_ids.append(999)
+ project.extra_perms.append(
+ project_pb2.Project.ExtraPerms(member_id=999, perms=['EditIssue']))
+ perms = permissions.UpdateIssuePermissions(
+ permissions.USER_PERMISSIONSET, project,
+ self.REGULAR_ISSUE, {999})
+ self.assertIn('editissue', perms.perm_names)
+
+ def testUpdateIssuePermissions_ExtraPermsAreSubjectToRestrictions(self):
+ project = project_pb2.Project()
+ project.committer_ids.append(999)
+ project.extra_perms.append(
+ project_pb2.Project.ExtraPerms(member_id=999, perms=['EditIssue']))
+ perms = permissions.UpdateIssuePermissions(
+ permissions.USER_PERMISSIONSET, project,
+ self.RESTRICTED_ISSUE3, {999})
+ self.assertNotIn('editissue', perms.perm_names)
+
+ def testUpdateIssuePermissions_GrantedPermsAreNotSubjectToRestrictions(self):
+ perms = permissions.UpdateIssuePermissions(
+ permissions.USER_PERMISSIONSET, self.PROJECT, self.RESTRICTED_ISSUE3,
+ {}, granted_perms=['EditIssue'])
+ self.assertIn('editissue', perms.perm_names)
+
+ def testUpdateIssuePermissions_RespectConsiderRestrictions(self):
+ perms = permissions.UpdateIssuePermissions(
+ permissions.ADMIN_PERMISSIONSET, self.PROJECT, self.RESTRICTED_ISSUE3,
+ {})
+ self.assertIn('editissue', perms.perm_names)
+
+ def testUpdateIssuePermissions_RestrictionsAreConsideredIndividually(self):
+ issue = tracker_pb2.Issue(
+ labels=[
+ 'Restrict-Perm1-Perm2',
+ 'Restrict-Perm2-Perm3'])
+ perms = permissions.UpdateIssuePermissions(
+ permissions.PermissionSet(['Perm1', 'Perm2', 'View']),
+ self.PROJECT, issue, {})
+ self.assertIn('perm1', perms.perm_names)
+ self.assertNotIn('perm2', perms.perm_names)
+
+ def testUpdateIssuePermissions_DeletedNoPermissions(self):
+ issue = tracker_pb2.Issue(
+ labels=['Restrict-View-Foo'],
+ deleted=True)
+ perms = permissions.UpdateIssuePermissions(
+ permissions.COMMITTER_ACTIVE_PERMISSIONSET, self.PROJECT, issue, {})
+ self.assertEqual([], sorted(perms.perm_names))
+
+ def testUpdateIssuePermissions_ViewDeleted(self):
+ perms = permissions.UpdateIssuePermissions(
+ permissions.COMMITTER_ACTIVE_PERMISSIONSET, self.PROJECT,
+ self.DELETED_ISSUE, {})
+ self.assertEqual(['view'], sorted(perms.perm_names))
+
+ def testUpdateIssuePermissions_ViewAndDeleteDeleted(self):
+ perms = permissions.UpdateIssuePermissions(
+ permissions.OWNER_ACTIVE_PERMISSIONSET, self.PROJECT,
+ self.DELETED_ISSUE, {})
+ self.assertEqual(['deleteissue', 'view'], sorted(perms.perm_names))
+
+ def testUpdateIssuePermissions_ViewRestrictions(self):
+ perms = permissions.UpdateIssuePermissions(
+ permissions.USER_PERMISSIONSET, self.PROJECT, self.RESTRICTED_ISSUE, {})
+ self.assertNotIn('view', perms.perm_names)
+
+ def testUpdateIssuePermissions_RolesBypassViewRestrictions(self):
+ for role in {OWNER_ID, REPORTER_ID, CC_ID, APPROVER_ID}:
+ perms = permissions.UpdateIssuePermissions(
+ permissions.USER_PERMISSIONSET, self.PROJECT, self.RESTRICTED_ISSUE,
+ {role})
+ self.assertIn('view', perms.perm_names)
+
+ def testUpdateIssuePermissions_RolesAllowViewingDeleted(self):
+ issue = tracker_pb2.Issue(
+ reporter_id=REPORTER_ID,
+ owner_id=OWNER_ID,
+ cc_ids=[CC_ID],
+ approval_values=[tracker_pb2.ApprovalValue(approver_ids=[APPROVER_ID])],
+ labels=['Restrict-View-Foo'],
+ deleted=True)
+ for role in {OWNER_ID, REPORTER_ID, CC_ID, APPROVER_ID}:
+ perms = permissions.UpdateIssuePermissions(
+ permissions.USER_PERMISSIONSET, self.PROJECT, issue, {role})
+ self.assertIn('view', perms.perm_names)
+
+ def testUpdateIssuePermissions_GrantedViewPermission(self):
+ perms = permissions.UpdateIssuePermissions(
+ permissions.USER_PERMISSIONSET, self.PROJECT, self.RESTRICTED_ISSUE,
+ {}, ['commit'])
+ self.assertIn('view', perms.perm_names)
+
+ def testUpdateIssuePermissions_EditRestrictions(self):
+ perms = permissions.UpdateIssuePermissions(
+ permissions.COMMITTER_ACTIVE_PERMISSIONSET, self.PROJECT,
+ self.RESTRICTED_ISSUE3, {REPORTER_ID, CC_ID, APPROVER_ID})
+ self.assertNotIn('editissue', perms.perm_names)
+
+ def testUpdateIssuePermissions_OwnerBypassEditRestrictions(self):
+ perms = permissions.UpdateIssuePermissions(
+ permissions.COMMITTER_ACTIVE_PERMISSIONSET, self.PROJECT,
+ self.RESTRICTED_ISSUE3, {OWNER_ID})
+ self.assertIn('editissue', perms.perm_names)
+
+ def testUpdateIssuePermissions_CustomPermissionGrantsEditPermission(self):
+ project = project_pb2.Project()
+ project.committer_ids.append(999)
+ project.extra_perms.append(
+ project_pb2.Project.ExtraPerms(member_id=999, perms=['Foo']))
+ perms = permissions.UpdateIssuePermissions(
+ permissions.COMMITTER_ACTIVE_PERMISSIONSET, project,
+ self.RESTRICTED_ISSUE3, {999})
+ self.assertIn('editissue', perms.perm_names)
+
+ def testCanViewIssue_Deleted(self):
+ self.assertFalse(permissions.CanViewIssue(
+ {REPORTER_ID}, permissions.OWNER_ACTIVE_PERMISSIONSET,
+ self.PROJECT, self.DELETED_ISSUE))
+ self.assertTrue(permissions.CanViewIssue(
+ {REPORTER_ID}, permissions.OWNER_ACTIVE_PERMISSIONSET,
+ self.PROJECT, self.DELETED_ISSUE, allow_viewing_deleted=True))
+ self.assertTrue(permissions.CanViewIssue(
+ {REPORTER_ID}, permissions.OWNER_ACTIVE_PERMISSIONSET,
+ self.PROJECT, self.REGULAR_ISSUE))
+
+ def testCanViewIssue_Regular(self):
+ self.assertTrue(permissions.CanViewIssue(
+ {REPORTER_ID}, permissions.OWNER_ACTIVE_PERMISSIONSET,
+ self.PROJECT, self.REGULAR_ISSUE))
+ self.assertTrue(permissions.CanViewIssue(
+ {REPORTER_ID}, permissions.COMMITTER_ACTIVE_PERMISSIONSET,
+ self.PROJECT, self.REGULAR_ISSUE))
+ self.assertTrue(permissions.CanViewIssue(
+ {REPORTER_ID},
+ permissions.CONTRIBUTOR_ACTIVE_PERMISSIONSET,
+ self.PROJECT, self.REGULAR_ISSUE))
+ self.assertTrue(permissions.CanViewIssue(
+ {REPORTER_ID}, permissions.USER_PERMISSIONSET,
+ self.PROJECT, self.REGULAR_ISSUE))
+ self.assertTrue(permissions.CanViewIssue(
+ {REPORTER_ID}, permissions.READ_ONLY_PERMISSIONSET,
+ self.PROJECT, self.REGULAR_ISSUE))
+ self.assertTrue(permissions.CanViewIssue(
+ set(), permissions.READ_ONLY_PERMISSIONSET,
+ self.PROJECT, self.REGULAR_ISSUE))
+
+ def testCanViewIssue_Restricted(self):
+ # Project owner can always view issue.
+ self.assertTrue(permissions.CanViewIssue(
+ {OTHER_ID}, permissions.OWNER_ACTIVE_PERMISSIONSET,
+ self.PROJECT, self.RESTRICTED_ISSUE))
+ # Member can view because they have Commit perm.
+ self.assertTrue(permissions.CanViewIssue(
+ {OTHER_ID}, permissions.COMMITTER_ACTIVE_PERMISSIONSET,
+ self.PROJECT, self.RESTRICTED_ISSUE))
+ # Contributors normally do not have Commit perm.
+ self.assertFalse(permissions.CanViewIssue(
+ {OTHER_ID}, permissions.CONTRIBUTOR_ACTIVE_PERMISSIONSET,
+ self.PROJECT, self.RESTRICTED_ISSUE))
+ # Non-members do not have Commit perm.
+ self.assertFalse(permissions.CanViewIssue(
+ {OTHER_ID}, permissions.USER_PERMISSIONSET,
+ self.PROJECT, self.RESTRICTED_ISSUE))
+ # Anon user's do not have Commit perm.
+ self.assertFalse(permissions.CanViewIssue(
+ set(), permissions.READ_ONLY_PERMISSIONSET,
+ self.PROJECT, self.RESTRICTED_ISSUE))
+
+ def testCanViewIssue_RestrictedParticipants(self):
+ # Reporter can always view issue
+ self.assertTrue(permissions.CanViewIssue(
+ {REPORTER_ID}, permissions.READ_ONLY_PERMISSIONSET,
+ self.PROJECT, self.RESTRICTED_ISSUE))
+ # Issue owner can always view issue
+ self.assertTrue(permissions.CanViewIssue(
+ {OWNER_ID}, permissions.READ_ONLY_PERMISSIONSET,
+ self.PROJECT, self.RESTRICTED_ISSUE))
+ # CC'd user can always view issue
+ self.assertTrue(permissions.CanViewIssue(
+ {CC_ID}, permissions.READ_ONLY_PERMISSIONSET,
+ self.PROJECT, self.RESTRICTED_ISSUE))
+ # Non-participants cannot view issue if they don't have the needed perm.
+ self.assertFalse(permissions.CanViewIssue(
+ {OTHER_ID}, permissions.READ_ONLY_PERMISSIONSET,
+ self.PROJECT, self.RESTRICTED_ISSUE))
+ # Anon user's do not have Commit perm.
+ self.assertFalse(permissions.CanViewIssue(
+ set(), permissions.READ_ONLY_PERMISSIONSET,
+ self.PROJECT, self.RESTRICTED_ISSUE))
+ # Anon user's cannot match owner 0.
+ self.assertFalse(permissions.CanViewIssue(
+ set(), permissions.READ_ONLY_PERMISSIONSET,
+ self.PROJECT, self.RESTRICTED_ISSUE2))
+ # Approvers can always view issue
+ self.assertTrue(permissions.CanViewIssue(
+ {APPROVER_ID}, permissions.READ_ONLY_PERMISSIONSET,
+ self.PROJECT, self.RESTRICTED_ISSUE))
+
+ def testCannotViewIssueIfCannotViewProject(self):
+ """Cross-project search should not be a backdoor to viewing issues."""
+ # Reporter cannot view issue if they not long have access to the project.
+ self.assertFalse(permissions.CanViewIssue(
+ {REPORTER_ID}, permissions.EMPTY_PERMISSIONSET,
+ self.PROJECT, self.REGULAR_ISSUE))
+ # Issue owner cannot always view issue
+ self.assertFalse(permissions.CanViewIssue(
+ {OWNER_ID}, permissions.EMPTY_PERMISSIONSET,
+ self.PROJECT, self.REGULAR_ISSUE))
+ # CC'd user cannot always view issue
+ self.assertFalse(permissions.CanViewIssue(
+ {CC_ID}, permissions.EMPTY_PERMISSIONSET,
+ self.PROJECT, self.REGULAR_ISSUE))
+ # Non-participants cannot view issue if they don't have the needed perm.
+ self.assertFalse(permissions.CanViewIssue(
+ {OTHER_ID}, permissions.EMPTY_PERMISSIONSET,
+ self.PROJECT, self.REGULAR_ISSUE))
+ # Anon user's do not have Commit perm.
+ self.assertFalse(permissions.CanViewIssue(
+ set(), permissions.EMPTY_PERMISSIONSET, self.PROJECT,
+ self.REGULAR_ISSUE))
+ # Anon user's cannot match owner 0.
+ self.assertFalse(permissions.CanViewIssue(
+ set(), permissions.EMPTY_PERMISSIONSET, self.PROJECT,
+ self.REGULAR_ISSUE))
+
+ def testCanEditIssue(self):
+ # Anon users cannot edit issues.
+ self.assertFalse(permissions.CanEditIssue(
+ {}, permissions.READ_ONLY_PERMISSIONSET,
+ self.PROJECT, self.REGULAR_ISSUE))
+
+ # Non-members and contributors cannot edit issues,
+ # even if they reported them.
+ self.assertFalse(permissions.CanEditIssue(
+ {REPORTER_ID}, permissions.READ_ONLY_PERMISSIONSET,
+ self.PROJECT, self.REGULAR_ISSUE))
+ self.assertFalse(permissions.CanEditIssue(
+ {REPORTER_ID}, permissions.CONTRIBUTOR_ACTIVE_PERMISSIONSET,
+ self.PROJECT, self.REGULAR_ISSUE))
+
+ # Project committers and project owners can edit issues, regardless
+ # of their role in the issue.
+ self.assertTrue(permissions.CanEditIssue(
+ {REPORTER_ID}, permissions.COMMITTER_ACTIVE_PERMISSIONSET,
+ self.PROJECT, self.REGULAR_ISSUE))
+ self.assertTrue(permissions.CanEditIssue(
+ {REPORTER_ID}, permissions.OWNER_ACTIVE_PERMISSIONSET,
+ self.PROJECT, self.REGULAR_ISSUE))
+ self.assertTrue(permissions.CanEditIssue(
+ {OWNER_ID}, permissions.COMMITTER_ACTIVE_PERMISSIONSET,
+ self.PROJECT, self.REGULAR_ISSUE))
+ self.assertTrue(permissions.CanEditIssue(
+ {OWNER_ID}, permissions.OWNER_ACTIVE_PERMISSIONSET,
+ self.PROJECT, self.REGULAR_ISSUE))
+ self.assertTrue(permissions.CanEditIssue(
+ {OTHER_ID}, permissions.COMMITTER_ACTIVE_PERMISSIONSET,
+ self.PROJECT, self.REGULAR_ISSUE))
+ self.assertTrue(permissions.CanEditIssue(
+ {OTHER_ID}, permissions.OWNER_ACTIVE_PERMISSIONSET,
+ self.PROJECT, self.REGULAR_ISSUE))
+
+ def testCanEditIssue_Restricted(self):
+ # Anon users cannot edit restricted issues.
+ self.assertFalse(permissions.CanEditIssue(
+ {}, permissions.COMMITTER_ACTIVE_PERMISSIONSET,
+ self.PROJECT, self.RESTRICTED_ISSUE3))
+
+ # Project committers cannot edit issues with a restriction to a custom
+ # permission that they don't have.
+ self.assertFalse(permissions.CanEditIssue(
+ {OTHER_ID}, permissions.COMMITTER_ACTIVE_PERMISSIONSET,
+ self.PROJECT, self.RESTRICTED_ISSUE3))
+
+ # *Issue* owners can always edit the issues that they own, even if
+ # those issues are restricted to perms that they don't have.
+ self.assertTrue(permissions.CanEditIssue(
+ {OWNER_ID}, permissions.CONTRIBUTOR_ACTIVE_PERMISSIONSET,
+ self.PROJECT, self.RESTRICTED_ISSUE3))
+
+ # Project owners can always edit, they cannot lock themselves out.
+ self.assertTrue(permissions.CanEditIssue(
+ {OTHER_ID}, permissions.OWNER_ACTIVE_PERMISSIONSET,
+ self.PROJECT, self.RESTRICTED_ISSUE3))
+
+ # A committer with edit permission but not view permission
+ # should not be able to edit the issue.
+ self.assertFalse(permissions.CanEditIssue(
+ {OTHER_ID}, permissions.CONTRIBUTOR_ACTIVE_PERMISSIONSET,
+ self.PROJECT, self.RESTRICTED_ISSUE2))
+
+ def testCanCommentIssue_HasPerm(self):
+ self.assertTrue(permissions.CanCommentIssue(
+ {111}, permissions.PermissionSet([permissions.ADD_ISSUE_COMMENT]),
+ None, None))
+ self.assertFalse(permissions.CanCommentIssue(
+ {111}, permissions.PermissionSet([]),
+ None, None))
+
+ def testCanCommentIssue_HasExtraPerm(self):
+ project = project_pb2.Project()
+ project.committer_ids.append(111)
+ extra_perm = project_pb2.Project.ExtraPerms(
+ member_id=111, perms=[permissions.ADD_ISSUE_COMMENT])
+ project.extra_perms.append(extra_perm)
+ self.assertTrue(permissions.CanCommentIssue(
+ {111}, permissions.PermissionSet([]),
+ project, None))
+ self.assertFalse(permissions.CanCommentIssue(
+ {222}, permissions.PermissionSet([]),
+ project, None))
+
+ def testCanCommentIssue_Restricted(self):
+ issue = tracker_pb2.Issue(labels=['Restrict-AddIssueComment-CoreTeam'])
+ # User is granted exactly the perm they need specifically in this issue.
+ self.assertTrue(permissions.CanCommentIssue(
+ {111}, permissions.PermissionSet([]),
+ None, issue, granted_perms=['addissuecomment']))
+ # User is granted CoreTeam, which satifies the restriction, and allows
+ # them to use the AddIssueComment permission that they have and would
+ # normally be able to use in an unrestricted issue.
+ self.assertTrue(permissions.CanCommentIssue(
+ {111}, permissions.PermissionSet([permissions.ADD_ISSUE_COMMENT]),
+ None, issue, granted_perms=['coreteam']))
+ # User was granted CoreTeam, but never had AddIssueComment.
+ self.assertFalse(permissions.CanCommentIssue(
+ {111}, permissions.PermissionSet([]),
+ None, issue, granted_perms=['coreteam']))
+ # User has AddIssueComment, but cannot satisfy restriction.
+ self.assertFalse(permissions.CanCommentIssue(
+ {111}, permissions.PermissionSet([permissions.ADD_ISSUE_COMMENT]),
+ None, issue))
+
+ def testCanCommentIssue_Granted(self):
+ self.assertTrue(permissions.CanCommentIssue(
+ {111}, permissions.PermissionSet([]),
+ None, None, granted_perms=['addissuecomment']))
+ self.assertFalse(permissions.CanCommentIssue(
+ {111}, permissions.PermissionSet([]),
+ None, None))
+
+ def testCanUpdateApprovalStatus_Approver(self):
+ # restricted status
+ self.assertTrue(permissions.CanUpdateApprovalStatus(
+ {111, 222}, permissions.PermissionSet([]), self.PROJECT,
+ [222], tracker_pb2.ApprovalStatus.APPROVED))
+
+ # non-restricted status
+ self.assertTrue(permissions.CanUpdateApprovalStatus(
+ {111, 222}, permissions.PermissionSet([]), self.PROJECT,
+ [222], tracker_pb2.ApprovalStatus.NEEDS_REVIEW))
+
+ def testCanUpdateApprovalStatus_SiteAdmin(self):
+ # restricted status
+ self.assertTrue(permissions.CanUpdateApprovalStatus(
+ {444}, permissions.PermissionSet([permissions.EDIT_ISSUE_APPROVAL]),
+ self.PROJECT, [222], tracker_pb2.ApprovalStatus.NOT_APPROVED))
+
+ # non-restricted status
+ self.assertTrue(permissions.CanUpdateApprovalStatus(
+ {444}, permissions.PermissionSet([permissions.EDIT_ISSUE_APPROVAL]),
+ self.PROJECT, [222], tracker_pb2.ApprovalStatus.NEEDS_REVIEW))
+
+ def testCanUpdateApprovalStatus_NonApprover(self):
+ # non-restricted status
+ self.assertTrue(permissions.CanUpdateApprovalStatus(
+ {111, 222}, permissions.PermissionSet([]), self.PROJECT,
+ [333], tracker_pb2.ApprovalStatus.NEED_INFO))
+
+ # restricted status
+ self.assertFalse(permissions.CanUpdateApprovalStatus(
+ {111, 222}, permissions.PermissionSet([]), self.PROJECT,
+ [333], tracker_pb2.ApprovalStatus.NA))
+
+ def testCanUpdateApprovers_Approver(self):
+ self.assertTrue(permissions.CanUpdateApprovers(
+ {111, 222}, permissions.PermissionSet([]), self.PROJECT,
+ [222]))
+
+ def testCanUpdateApprovers_SiteAdmins(self):
+ self.assertTrue(permissions.CanUpdateApprovers(
+ {444}, permissions.PermissionSet([permissions.EDIT_ISSUE_APPROVAL]),
+ self.PROJECT, [222]))
+
+ def testCanUpdateApprovers_NonApprover(self):
+ self.assertFalse(permissions.CanUpdateApprovers(
+ {111, 222}, permissions.PermissionSet([]), self.PROJECT,
+ [333]))
+
+ def testCanViewComponentDef_ComponentAdmin(self):
+ cd = tracker_pb2.ComponentDef(admin_ids=[111])
+ perms = permissions.PermissionSet([])
+ self.assertTrue(permissions.CanViewComponentDef(
+ {111}, perms, None, cd))
+ self.assertFalse(permissions.CanViewComponentDef(
+ {999}, perms, None, cd))
+
+ def testCanViewComponentDef_NormalUser(self):
+ cd = tracker_pb2.ComponentDef()
+ self.assertTrue(permissions.CanViewComponentDef(
+ {111}, permissions.PermissionSet([permissions.VIEW]),
+ None, cd))
+ self.assertFalse(permissions.CanViewComponentDef(
+ {111}, permissions.PermissionSet([]),
+ None, cd))
+
+ def testCanEditComponentDef_ComponentAdmin(self):
+ cd = tracker_pb2.ComponentDef(admin_ids=[111], path='Whole')
+ sub_cd = tracker_pb2.ComponentDef(admin_ids=[222], path='Whole>Part')
+ config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+ config.component_defs.append(cd)
+ config.component_defs.append(sub_cd)
+ perms = permissions.PermissionSet([])
+ self.assertTrue(permissions.CanEditComponentDef(
+ {111}, perms, None, cd, config))
+ self.assertFalse(permissions.CanEditComponentDef(
+ {222}, perms, None, cd, config))
+ self.assertFalse(permissions.CanEditComponentDef(
+ {999}, perms, None, cd, config))
+ self.assertTrue(permissions.CanEditComponentDef(
+ {111}, perms, None, sub_cd, config))
+ self.assertTrue(permissions.CanEditComponentDef(
+ {222}, perms, None, sub_cd, config))
+ self.assertFalse(permissions.CanEditComponentDef(
+ {999}, perms, None, sub_cd, config))
+
+ def testCanEditComponentDef_ProjectOwners(self):
+ cd = tracker_pb2.ComponentDef(path='Whole')
+ config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+ config.component_defs.append(cd)
+ self.assertTrue(permissions.CanEditComponentDef(
+ {111}, permissions.PermissionSet([permissions.EDIT_PROJECT]),
+ None, cd, config))
+ self.assertFalse(permissions.CanEditComponentDef(
+ {111}, permissions.PermissionSet([]),
+ None, cd, config))
+
+ def testCanViewFieldDef_FieldAdmin(self):
+ fd = tracker_pb2.FieldDef(admin_ids=[111])
+ perms = permissions.PermissionSet([])
+ self.assertTrue(permissions.CanViewFieldDef(
+ {111}, perms, None, fd))
+ self.assertFalse(permissions.CanViewFieldDef(
+ {999}, perms, None, fd))
+
+ def testCanViewFieldDef_NormalUser(self):
+ fd = tracker_pb2.FieldDef()
+ self.assertTrue(permissions.CanViewFieldDef(
+ {111}, permissions.PermissionSet([permissions.VIEW]),
+ None, fd))
+ self.assertFalse(permissions.CanViewFieldDef(
+ {111}, permissions.PermissionSet([]),
+ None, fd))
+
+ def testCanEditFieldDef_FieldAdmin(self):
+ fd = tracker_pb2.FieldDef(admin_ids=[111])
+ perms = permissions.PermissionSet([])
+ self.assertTrue(permissions.CanEditFieldDef(
+ {111}, perms, None, fd))
+ self.assertFalse(permissions.CanEditFieldDef(
+ {999}, perms, None, fd))
+
+ def testCanEditFieldDef_ProjectOwners(self):
+ fd = tracker_pb2.FieldDef()
+ self.assertTrue(permissions.CanEditFieldDef(
+ {111}, permissions.PermissionSet([permissions.EDIT_PROJECT]),
+ None, fd))
+ self.assertFalse(permissions.CanEditFieldDef(
+ {111}, permissions.PermissionSet([]),
+ None, fd))
+
+ def testCanEditValueForFieldDef_NotRestrictedField(self):
+ fd = tracker_pb2.FieldDef()
+ perms = permissions.PermissionSet([])
+ self.assertTrue(permissions.CanEditValueForFieldDef({111}, perms, None, fd))
+
+ def testCanEditValueForFieldDef_RestrictedFieldEditor(self):
+ fd = tracker_pb2.FieldDef(is_restricted_field=True, editor_ids=[111])
+ perms = permissions.PermissionSet([])
+ self.assertTrue(permissions.CanEditValueForFieldDef({111}, perms, None, fd))
+ self.assertFalse(
+ permissions.CanEditValueForFieldDef({999}, perms, None, fd))
+
+ def testCanEditValueForFieldDef_RestrictedFieldAdmin(self):
+ fd = tracker_pb2.FieldDef(is_restricted_field=True, admin_ids=[111])
+ perms = permissions.PermissionSet([])
+ self.assertTrue(permissions.CanEditValueForFieldDef({111}, perms, None, fd))
+ self.assertFalse(
+ permissions.CanEditValueForFieldDef({999}, perms, None, fd))
+
+ def testCanEditValueForFieldDef_ProjectOwners(self):
+ fd = tracker_pb2.FieldDef(is_restricted_field=True)
+ self.assertTrue(
+ permissions.CanEditValueForFieldDef(
+ {111}, permissions.PermissionSet([permissions.EDIT_PROJECT]), None,
+ fd))
+ self.assertFalse(
+ permissions.CanEditValueForFieldDef(
+ {111}, permissions.PermissionSet([]), None, fd))
+
+ def testCanViewTemplate_TemplateAdmin(self):
+ td = tracker_pb2.TemplateDef(admin_ids=[111])
+ perms = permissions.PermissionSet([])
+ self.assertTrue(permissions.CanViewTemplate(
+ {111}, perms, None, td))
+ self.assertFalse(permissions.CanViewTemplate(
+ {999}, perms, None, td))
+
+ def testCanViewTemplate_MembersOnly(self):
+ td = tracker_pb2.TemplateDef(members_only=True)
+ project = project_pb2.Project(committer_ids=[111])
+ self.assertTrue(permissions.CanViewTemplate(
+ {111}, permissions.PermissionSet([]),
+ project, td))
+ self.assertFalse(permissions.CanViewTemplate(
+ {999}, permissions.PermissionSet([]),
+ project, td))
+
+ def testCanViewTemplate_AnyoneWhoCanViewProject(self):
+ td = tracker_pb2.TemplateDef()
+ self.assertTrue(permissions.CanViewTemplate(
+ {111}, permissions.PermissionSet([permissions.VIEW]),
+ None, td))
+ self.assertFalse(permissions.CanViewTemplate(
+ {111}, permissions.PermissionSet([]),
+ None, td))
+
+ def testCanEditTemplate_TemplateAdmin(self):
+ td = tracker_pb2.TemplateDef(admin_ids=[111])
+ perms = permissions.PermissionSet([])
+ self.assertTrue(permissions.CanEditTemplate(
+ {111}, perms, None, td))
+ self.assertFalse(permissions.CanEditTemplate(
+ {999}, perms, None, td))
+
+ def testCanEditTemplate_ProjectOwners(self):
+ td = tracker_pb2.TemplateDef()
+ self.assertTrue(permissions.CanEditTemplate(
+ {111}, permissions.PermissionSet([permissions.EDIT_PROJECT]),
+ None, td))
+ self.assertFalse(permissions.CanEditTemplate(
+ {111}, permissions.PermissionSet([]),
+ None, td))
+
+ def testCanViewHotlist_Private(self):
+ hotlist = features_pb2.Hotlist()
+ hotlist.is_private = True
+ hotlist.owner_ids.append(111)
+ hotlist.editor_ids.append(222)
+
+ self.assertTrue(permissions.CanViewHotlist({222}, self.PERMS, hotlist))
+ self.assertTrue(permissions.CanViewHotlist({111, 333}, self.PERMS, hotlist))
+ self.assertTrue(
+ permissions.CanViewHotlist({111, 333}, self.ADMIN_PERMS, hotlist))
+ self.assertFalse(
+ permissions.CanViewHotlist({333, 444}, self.PERMS, hotlist))
+ self.assertTrue(
+ permissions.CanViewHotlist({333, 444}, self.ADMIN_PERMS, hotlist))
+
+ def testCanViewHotlist_Public(self):
+ hotlist = features_pb2.Hotlist()
+ hotlist.is_private = False
+ hotlist.owner_ids.append(111)
+ hotlist.editor_ids.append(222)
+
+ self.assertTrue(permissions.CanViewHotlist({222}, self.PERMS, hotlist))
+ self.assertTrue(permissions.CanViewHotlist({111, 333}, self.PERMS, hotlist))
+ self.assertTrue(permissions.CanViewHotlist({333, 444}, self.PERMS, hotlist))
+ self.assertTrue(
+ permissions.CanViewHotlist({333, 444}, self.ADMIN_PERMS, hotlist))
+
+ def testCanEditHotlist(self):
+ hotlist = features_pb2.Hotlist()
+ hotlist.owner_ids.append(111)
+ hotlist.editor_ids.append(222)
+
+ self.assertTrue(permissions.CanEditHotlist({222}, self.PERMS, hotlist))
+ self.assertTrue(permissions.CanEditHotlist({111, 333}, self.PERMS, hotlist))
+ self.assertTrue(
+ permissions.CanEditHotlist({111, 333}, self.ADMIN_PERMS, hotlist))
+ self.assertFalse(
+ permissions.CanEditHotlist({333, 444}, self.PERMS, hotlist))
+ self.assertTrue(
+ permissions.CanEditHotlist({333, 444}, self.ADMIN_PERMS, hotlist))
+
+ def testCanAdministerHotlist(self):
+ hotlist = features_pb2.Hotlist()
+ hotlist.owner_ids.append(111)
+ hotlist.editor_ids.append(222)
+
+ self.assertFalse(
+ permissions.CanAdministerHotlist({222}, self.PERMS, hotlist))
+ self.assertTrue(
+ permissions.CanAdministerHotlist({111, 333}, self.PERMS, hotlist))
+ self.assertTrue(
+ permissions.CanAdministerHotlist({111, 333}, self.ADMIN_PERMS, hotlist))
+ self.assertFalse(
+ permissions.CanAdministerHotlist({333, 444}, self.PERMS, hotlist))
+ self.assertTrue(
+ permissions.CanAdministerHotlist({333, 444}, self.ADMIN_PERMS, hotlist))
diff --git a/framework/test/profiler_test.py b/framework/test/profiler_test.py
new file mode 100644
index 0000000..3cc7e85
--- /dev/null
+++ b/framework/test/profiler_test.py
@@ -0,0 +1,138 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Test for monorail.framework.profiler."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+from framework import profiler
+
+
+class MockPatchResponse(object):
+ def execute(self):
+ pass
+
+
+class MockCloudTraceProjects(object):
+ def __init__(self):
+ self.patch_response = MockPatchResponse()
+ self.project_id = None
+ self.body = None
+
+ def patchTraces(self, projectId, body):
+ self.project_id = projectId
+ self.body = body
+ return self.patch_response
+
+
+class MockCloudTraceApi(object):
+ def __init__(self):
+ self.mock_projects = MockCloudTraceProjects()
+
+ def projects(self):
+ return self.mock_projects
+
+
+class ProfilerTest(unittest.TestCase):
+
+ def testTopLevelPhase(self):
+ prof = profiler.Profiler()
+ self.assertEqual(prof.current_phase.name, 'overall profile')
+ self.assertEqual(prof.current_phase.parent, None)
+ self.assertEqual(prof.current_phase, prof.top_phase)
+ self.assertEqual(prof.next_color, 0)
+
+ def testSinglePhase(self):
+ prof = profiler.Profiler()
+ self.assertEqual(prof.current_phase.name, 'overall profile')
+ with prof.Phase('test'):
+ self.assertEqual(prof.current_phase.name, 'test')
+ self.assertEqual(prof.current_phase.parent.name, 'overall profile')
+ self.assertEqual(prof.current_phase.name, 'overall profile')
+ self.assertEqual(prof.next_color, 1)
+
+ def testSinglePhase_SuperLongName(self):
+ prof = profiler.Profiler()
+ self.assertEqual(prof.current_phase.name, 'overall profile')
+ long_name = 'x' * 1000
+ with prof.Phase(long_name):
+ self.assertEqual(
+ 'x' * profiler.MAX_PHASE_NAME_LENGTH, prof.current_phase.name)
+
+ def testSubphaseExecption(self):
+ prof = profiler.Profiler()
+ try:
+ with prof.Phase('foo'):
+ with prof.Phase('bar'):
+ pass
+ with prof.Phase('baz'):
+ raise Exception('whoops')
+ except Exception as e:
+ self.assertEqual(e.message, 'whoops')
+ finally:
+ self.assertEqual(prof.current_phase.name, 'overall profile')
+ self.assertEqual(prof.top_phase.subphases[0].subphases[1].name, 'baz')
+
+ def testSpanJson(self):
+ mock_trace_api = MockCloudTraceApi()
+ mock_trace_context = '1234/5678;xxxxx'
+
+ prof = profiler.Profiler(mock_trace_context, mock_trace_api)
+ with prof.Phase('foo'):
+ with prof.Phase('bar'):
+ pass
+ with prof.Phase('baz'):
+ pass
+
+ # Shouldn't this be automatic?
+ prof.current_phase.End()
+
+ self.assertEqual(prof.current_phase.name, 'overall profile')
+ self.assertEqual(prof.top_phase.subphases[0].subphases[1].name, 'baz')
+ span_json = prof.top_phase.SpanJson()
+ self.assertEqual(len(span_json), 4)
+
+ for span in span_json:
+ self.assertTrue(span['endTime'] > span['startTime'])
+
+ # pylint: disable=unbalanced-tuple-unpacking
+ span1, span2, span3, span4 = span_json
+
+ self.assertEqual(span1['name'], 'overall profile')
+ self.assertEqual(span2['name'], 'foo')
+ self.assertEqual(span3['name'], 'bar')
+ self.assertEqual(span4['name'], 'baz')
+
+ self.assertTrue(span1['startTime'] < span2['startTime'])
+ self.assertTrue(span1['startTime'] < span3['startTime'])
+ self.assertTrue(span1['startTime'] < span4['startTime'])
+
+ self.assertTrue(span1['endTime'] > span2['endTime'])
+ self.assertTrue(span1['endTime'] > span3['endTime'])
+ self.assertTrue(span1['endTime'] > span4['endTime'])
+
+
+ def testReportCloudTrace(self):
+ mock_trace_api = MockCloudTraceApi()
+ mock_trace_context = '1234/5678;xxxxx'
+
+ prof = profiler.Profiler(mock_trace_context, mock_trace_api)
+ with prof.Phase('foo'):
+ with prof.Phase('bar'):
+ pass
+ with prof.Phase('baz'):
+ pass
+
+ # Shouldn't this be automatic?
+ prof.current_phase.End()
+
+ self.assertEqual(prof.current_phase.name, 'overall profile')
+ self.assertEqual(prof.top_phase.subphases[0].subphases[1].name, 'baz')
+
+ prof.ReportTrace()
+ self.assertEqual(mock_trace_api.mock_projects.project_id, 'testing-app')
diff --git a/framework/test/ratelimiter_test.py b/framework/test/ratelimiter_test.py
new file mode 100644
index 0000000..b351f8c
--- /dev/null
+++ b/framework/test/ratelimiter_test.py
@@ -0,0 +1,398 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for RateLimiter.
+"""
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+
+import unittest
+
+from google.appengine.api import memcache
+from google.appengine.ext import testbed
+
+import mox
+import os
+import settings
+
+from framework import ratelimiter
+from services import service_manager
+from services import client_config_svc
+from testing import fake
+from testing import testing_helpers
+
+
+class RateLimiterTest(unittest.TestCase):
+ def setUp(self):
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_memcache_stub()
+ self.testbed.init_user_stub()
+
+ self.mox = mox.Mox()
+ self.services = service_manager.Services(
+ config=fake.ConfigService(),
+ issue=fake.IssueService(),
+ user=fake.UserService(),
+ project=fake.ProjectService(),
+ )
+ self.project = self.services.project.TestAddProject('proj', project_id=987)
+
+ self.ratelimiter = ratelimiter.RateLimiter()
+ ratelimiter.COUNTRY_LIMITS = {}
+ os.environ['USER_EMAIL'] = ''
+ settings.ratelimiting_enabled = True
+ ratelimiter.DEFAULT_LIMIT = 10
+
+ def tearDown(self):
+ self.testbed.deactivate()
+ self.mox.UnsetStubs()
+ self.mox.ResetAll()
+ # settings.ratelimiting_enabled = True
+
+ def testCheckStart_pass(self):
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.headers['X-AppEngine-Country'] = 'US'
+ request.remote_addr = '192.168.1.0'
+ self.ratelimiter.CheckStart(request)
+ # Should not throw an exception.
+
+ def testCheckStart_fail(self):
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.headers['X-AppEngine-Country'] = 'US'
+ request.remote_addr = '192.168.1.0'
+ now = 0.0
+ cachekeysets, _, _, _ = ratelimiter._CacheKeys(request, now)
+ values = [{key: ratelimiter.DEFAULT_LIMIT for key in cachekeys} for
+ cachekeys in cachekeysets]
+ for value in values:
+ memcache.add_multi(value)
+ with self.assertRaises(ratelimiter.RateLimitExceeded):
+ self.ratelimiter.CheckStart(request, now)
+
+ def testCheckStart_expiredEntries(self):
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.headers['X-AppEngine-Country'] = 'US'
+ request.remote_addr = '192.168.1.0'
+ now = 0.0
+ cachekeysets, _, _, _ = ratelimiter._CacheKeys(request, now)
+ values = [{key: ratelimiter.DEFAULT_LIMIT for key in cachekeys} for
+ cachekeys in cachekeysets]
+ for value in values:
+ memcache.add_multi(value)
+
+ now = now + 2 * ratelimiter.EXPIRE_AFTER_SECS
+ self.ratelimiter.CheckStart(request, now)
+ # Should not throw an exception.
+
+ def testCheckStart_repeatedCalls(self):
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.headers['X-AppEngine-Country'] = 'US'
+ request.remote_addr = '192.168.1.0'
+ now = 0.0
+
+ # Call CheckStart once every minute. Should be ok.
+ for _ in range(ratelimiter.N_MINUTES):
+ self.ratelimiter.CheckStart(request, now)
+ now = now + 120.0
+
+ # Call CheckStart more than DEFAULT_LIMIT times in the same minute.
+ with self.assertRaises(ratelimiter.RateLimitExceeded):
+ for _ in range(ratelimiter.DEFAULT_LIMIT + 2): # pragma: no branch
+ now = now + 0.001
+ self.ratelimiter.CheckStart(request, now)
+
+ def testCheckStart_differentIPs(self):
+ now = 0.0
+
+ ratelimiter.COUNTRY_LIMITS = {}
+ # Exceed DEFAULT_LIMIT calls, but vary remote_addr so different
+ # remote addresses aren't ratelimited together.
+ for m in range(ratelimiter.DEFAULT_LIMIT * 2):
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.headers['X-AppEngine-Country'] = 'US'
+ request.remote_addr = '192.168.1.%d' % (m % 16)
+ ratelimiter._CacheKeys(request, now)
+ self.ratelimiter.CheckStart(request, now)
+ now = now + 0.001
+
+ # Exceed the limit, but only for one IP address. The
+ # others should be fine.
+ with self.assertRaises(ratelimiter.RateLimitExceeded):
+ for m in range(ratelimiter.DEFAULT_LIMIT): # pragma: no branch
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.headers['X-AppEngine-Country'] = 'US'
+ request.remote_addr = '192.168.1.0'
+ ratelimiter._CacheKeys(request, now)
+ self.ratelimiter.CheckStart(request, now)
+ now = now + 0.001
+
+ # Now proceed to make requests for all of the other IP
+ # addresses besides .0.
+ for m in range(ratelimiter.DEFAULT_LIMIT * 2):
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.headers['X-AppEngine-Country'] = 'US'
+ # Skip .0 since it's already exceeded the limit.
+ request.remote_addr = '192.168.1.%d' % (m + 1)
+ ratelimiter._CacheKeys(request, now)
+ self.ratelimiter.CheckStart(request, now)
+ now = now + 0.001
+
+ def testCheckStart_sameIPDifferentUserIDs(self):
+ # Behind a NAT, e.g.
+ now = 0.0
+
+ # Exceed DEFAULT_LIMIT calls, but vary user_id so different
+ # users behind the same IP aren't ratelimited together.
+ for m in range(ratelimiter.DEFAULT_LIMIT * 2):
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.remote_addr = '192.168.1.0'
+ os.environ['USER_EMAIL'] = '%s@example.com' % m
+ request.headers['X-AppEngine-Country'] = 'US'
+ ratelimiter._CacheKeys(request, now)
+ self.ratelimiter.CheckStart(request, now)
+ now = now + 0.001
+
+ # Exceed the limit, but only for one userID+IP address. The
+ # others should be fine.
+ with self.assertRaises(ratelimiter.RateLimitExceeded):
+ for m in range(ratelimiter.DEFAULT_LIMIT + 2): # pragma: no branch
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.headers['X-AppEngine-Country'] = 'US'
+ request.remote_addr = '192.168.1.0'
+ os.environ['USER_EMAIL'] = '42@example.com'
+ ratelimiter._CacheKeys(request, now)
+ self.ratelimiter.CheckStart(request, now)
+ now = now + 0.001
+
+ # Now proceed to make requests for other user IDs
+ # besides 42.
+ for m in range(ratelimiter.DEFAULT_LIMIT * 2):
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.headers['X-AppEngine-Country'] = 'US'
+ # Skip .0 since it's already exceeded the limit.
+ request.remote_addr = '192.168.1.0'
+ os.environ['USER_EMAIL'] = '%s@example.com' % (43 + m)
+ ratelimiter._CacheKeys(request, now)
+ self.ratelimiter.CheckStart(request, now)
+ now = now + 0.001
+
+ def testCheckStart_ratelimitingDisabled(self):
+ settings.ratelimiting_enabled = False
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.headers['X-AppEngine-Country'] = 'US'
+ request.remote_addr = '192.168.1.0'
+ now = 0.0
+
+ # Call CheckStart a lot. Should be ok.
+ for _ in range(ratelimiter.DEFAULT_LIMIT):
+ self.ratelimiter.CheckStart(request, now)
+ now = now + 0.001
+
+ def testCheckStart_perCountryLoggedOutLimit(self):
+ ratelimiter.COUNTRY_LIMITS['US'] = 10
+
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.headers[ratelimiter.COUNTRY_HEADER] = 'US'
+ request.remote_addr = '192.168.1.1'
+ now = 0.0
+
+ with self.assertRaises(ratelimiter.RateLimitExceeded):
+ for m in range(ratelimiter.DEFAULT_LIMIT + 2): # pragma: no branch
+ self.ratelimiter.CheckStart(request, now)
+ # Vary remote address to make sure the limit covers
+ # the whole country, regardless of IP.
+ request.remote_addr = '192.168.1.%d' % m
+ now = now + 0.001
+
+ # CheckStart for a country that isn't covered by a country-specific limit.
+ request.headers['X-AppEngine-Country'] = 'UK'
+ for m in range(11):
+ self.ratelimiter.CheckStart(request, now)
+ # Vary remote address to make sure the limit covers
+ # the whole country, regardless of IP.
+ request.remote_addr = '192.168.1.%d' % m
+ now = now + 0.001
+
+ # And regular rate limits work per-IP.
+ request.remote_addr = '192.168.1.1'
+ with self.assertRaises(ratelimiter.RateLimitExceeded):
+ for m in range(ratelimiter.DEFAULT_LIMIT): # pragma: no branch
+ self.ratelimiter.CheckStart(request, now)
+ # Vary remote address to make sure the limit covers
+ # the whole country, regardless of IP.
+ now = now + 0.001
+
+ def testCheckEnd_SlowRequest(self):
+ """We count one request for each 1000ms."""
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.headers[ratelimiter.COUNTRY_HEADER] = 'US'
+ request.remote_addr = '192.168.1.1'
+ start_time = 0.0
+
+ # Send some requests, all under the limit.
+ for _ in range(ratelimiter.DEFAULT_LIMIT-1):
+ start_time = start_time + 0.001
+ self.ratelimiter.CheckStart(request, start_time)
+ now = start_time + 0.010
+ self.ratelimiter.CheckEnd(request, now, start_time)
+
+ # Now issue some more request, this time taking long
+ # enough to get the cost threshold penalty.
+ # Fast forward enough to impact a later bucket than the
+ # previous requests.
+ start_time = now + 120.0
+ self.ratelimiter.CheckStart(request, start_time)
+
+ # Take longer than the threshold to process the request.
+ elapsed_ms = settings.ratelimiting_ms_per_count * 2
+ now = start_time + elapsed_ms / 1000
+
+ # The request finished, taking long enough to count as two.
+ self.ratelimiter.CheckEnd(request, now, start_time)
+
+ with self.assertRaises(ratelimiter.RateLimitExceeded):
+ # One more request after the expensive query should
+ # throw an excpetion.
+ self.ratelimiter.CheckStart(request, start_time)
+
+ def testCheckEnd_FastRequest(self):
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.headers[ratelimiter.COUNTRY_HEADER] = 'asdasd'
+ request.remote_addr = '192.168.1.1'
+ start_time = 0.0
+
+ # Send some requests, all under the limit.
+ for _ in range(ratelimiter.DEFAULT_LIMIT):
+ self.ratelimiter.CheckStart(request, start_time)
+ now = start_time + 0.01
+ self.ratelimiter.CheckEnd(request, now, start_time)
+ start_time = now + 0.01
+
+
+class ApiRateLimiterTest(unittest.TestCase):
+
+ def setUp(self):
+ settings.ratelimiting_enabled = True
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_memcache_stub()
+
+ self.services = service_manager.Services(
+ config=fake.ConfigService(),
+ issue=fake.IssueService(),
+ user=fake.UserService(),
+ project=fake.ProjectService(),
+ )
+
+ self.client_id = '123456789'
+ self.client_email = 'test@example.com'
+
+ self.ratelimiter = ratelimiter.ApiRateLimiter()
+ settings.api_ratelimiting_enabled = True
+
+ def tearDown(self):
+ self.testbed.deactivate()
+
+ def testCheckStart_Allowed(self):
+ now = 0.0
+ self.ratelimiter.CheckStart(self.client_id, self.client_email, now)
+ self.ratelimiter.CheckStart(self.client_id, None, now)
+ self.ratelimiter.CheckStart(None, None, now)
+ self.ratelimiter.CheckStart('anonymous', None, now)
+
+ def testCheckStart_Rejected(self):
+ now = 0.0
+ keysets = ratelimiter._CreateApiCacheKeys(
+ self.client_id, self.client_email, now)
+ values = [{key: ratelimiter.DEFAULT_API_QPM + 1 for key in keyset} for
+ keyset in keysets]
+ for value in values:
+ memcache.add_multi(value)
+ with self.assertRaises(ratelimiter.ApiRateLimitExceeded):
+ self.ratelimiter.CheckStart(self.client_id, self.client_email, now)
+
+ def testCheckStart_Allowed_HigherQPMSpecified(self):
+ """Client goes over the default, but has a higher QPM set."""
+ now = 0.0
+ keysets = ratelimiter._CreateApiCacheKeys(
+ self.client_id, self.client_email, now)
+ qpm_dict = client_config_svc.GetQPMDict()
+ qpm_dict[self.client_email] = ratelimiter.DEFAULT_API_QPM + 10
+ # The client used 1 request more than the default limit in each of the
+ # 5 minutes in our 5 minute sample window, so 5 over to the total.
+ values = [{key: ratelimiter.DEFAULT_API_QPM + 1 for key in keyset} for
+ keyset in keysets]
+ for value in values:
+ memcache.add_multi(value)
+ self.ratelimiter.CheckStart(self.client_id, self.client_email, now)
+ del qpm_dict[self.client_email]
+
+ def testCheckStart_Allowed_LowQPMIgnored(self):
+ """Client specifies a QPM lower than the default and default is used."""
+ now = 0.0
+ keysets = ratelimiter._CreateApiCacheKeys(
+ self.client_id, self.client_email, now)
+ qpm_dict = client_config_svc.GetQPMDict()
+ qpm_dict[self.client_email] = ratelimiter.DEFAULT_API_QPM - 10
+ values = [{key: ratelimiter.DEFAULT_API_QPM for key in keyset} for
+ keyset in keysets]
+ for value in values:
+ memcache.add_multi(value)
+ self.ratelimiter.CheckStart(self.client_id, self.client_email, now)
+ del qpm_dict[self.client_email]
+
+ def testCheckStart_Rejected_LowQPMIgnored(self):
+ """Client specifies a QPM lower than the default and default is used."""
+ now = 0.0
+ keysets = ratelimiter._CreateApiCacheKeys(
+ self.client_id, self.client_email, now)
+ qpm_dict = client_config_svc.GetQPMDict()
+ qpm_dict[self.client_email] = ratelimiter.DEFAULT_API_QPM - 10
+ values = [{key: ratelimiter.DEFAULT_API_QPM + 1 for key in keyset} for
+ keyset in keysets]
+ for value in values:
+ memcache.add_multi(value)
+ with self.assertRaises(ratelimiter.ApiRateLimitExceeded):
+ self.ratelimiter.CheckStart(self.client_id, self.client_email, now)
+ del qpm_dict[self.client_email]
+
+ def testCheckEnd(self):
+ start_time = 0.0
+ keysets = ratelimiter._CreateApiCacheKeys(
+ self.client_id, self.client_email, start_time)
+
+ now = 0.1
+ self.ratelimiter.CheckEnd(
+ self.client_id, self.client_email, now, start_time)
+ counters = memcache.get_multi(keysets[0])
+ count = sum(counters.values())
+ # No extra cost charged
+ self.assertEqual(0, count)
+
+ elapsed_ms = settings.ratelimiting_ms_per_count * 2
+ now = start_time + elapsed_ms / 1000
+ self.ratelimiter.CheckEnd(
+ self.client_id, self.client_email, now, start_time)
+ counters = memcache.get_multi(keysets[0])
+ count = sum(counters.values())
+ # Extra cost charged
+ self.assertEqual(1, count)
diff --git a/framework/test/reap_test.py b/framework/test/reap_test.py
new file mode 100644
index 0000000..f1a907d
--- /dev/null
+++ b/framework/test/reap_test.py
@@ -0,0 +1,131 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for the reap module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+import mock
+import mox
+
+from mock import Mock
+
+from framework import reap
+from framework import sql
+from proto import project_pb2
+from services import service_manager
+from services import template_svc
+from testing import fake
+from testing import testing_helpers
+
+
+class ReapTest(unittest.TestCase):
+
+ def setUp(self):
+ self.project_service = fake.ProjectService()
+ self.issue_service = fake.IssueService()
+ self.issue_star_service = fake.IssueStarService()
+ self.config_service = fake.ConfigService()
+ self.features_service = fake.FeaturesService()
+ self.project_star_service = fake.ProjectStarService()
+ self.services = service_manager.Services(
+ project=self.project_service,
+ issue=self.issue_service,
+ issue_star=self.issue_star_service,
+ config=self.config_service,
+ features=self.features_service,
+ project_star=self.project_star_service,
+ template=Mock(spec=template_svc.TemplateService),
+ user=fake.UserService(),
+ usergroup=fake.UserGroupService())
+
+ self.proj1_id = 1001
+ self.proj1_issue_id = 111
+ self.proj1 = self.project_service.TestAddProject(
+ name='proj1', project_id=self.proj1_id)
+ self.proj2_id = 1002
+ self.proj2_issue_id = 112
+ self.proj2 = self.project_service.TestAddProject(
+ name='proj2', project_id=self.proj2_id)
+
+ self.mox = mox.Mox()
+ self.cnxn = self.mox.CreateMock(sql.MonorailConnection)
+ self.project_service.project_tbl = self.mox.CreateMock(sql.SQLTableManager)
+ self.issue_service.issue_tbl = self.mox.CreateMock(sql.SQLTableManager)
+
+ def tearDown(self):
+ self.mox.UnsetStubs()
+ self.mox.ResetAll()
+
+ def setUpMarkDoomedProjects(self):
+ self.project_service.project_tbl.Select(
+ self.cnxn, cols=['project_id'], limit=1000, state='archived',
+ where=mox.IgnoreArg()).AndReturn([[self.proj1_id]])
+
+ def testMarkDoomedProjects(self):
+ self.setUpMarkDoomedProjects()
+ reaper = reap.Reap('req', 'resp', services=self.services)
+
+ self.mox.ReplayAll()
+ doomed_project_ids = reaper._MarkDoomedProjects(self.cnxn)
+ self.mox.VerifyAll()
+
+ self.assertEqual([self.proj1_id], doomed_project_ids)
+ self.assertEqual(project_pb2.ProjectState.DELETABLE, self.proj1.state)
+ self.assertEqual('DELETABLE_%s' % self.proj1_id, self.proj1.project_name)
+
+ def setUpExpungeParts(self):
+ self.project_service.project_tbl.Select(
+ self.cnxn, cols=['project_id'], limit=100,
+ state='deletable').AndReturn([[self.proj1_id], [self.proj2_id]])
+ self.issue_service.issue_tbl.Select(
+ self.cnxn, cols=['id'], limit=1000,
+ project_id=self.proj1_id).AndReturn([[self.proj1_issue_id]])
+ self.issue_service.issue_tbl.Select(
+ self.cnxn, cols=['id'], limit=1000,
+ project_id=self.proj2_id).AndReturn([[self.proj2_issue_id]])
+
+ def testExpungeDeletableProjects(self):
+ self.setUpExpungeParts()
+ reaper = reap.Reap('req', 'resp', services=self.services)
+
+ self.mox.ReplayAll()
+ expunged_project_ids = reaper._ExpungeDeletableProjects(self.cnxn)
+ self.mox.VerifyAll()
+
+ self.assertEqual([self.proj1_id, self.proj2_id], expunged_project_ids)
+ # Verify all expected expunge methods were called.
+ self.assertEqual(
+ [self.proj1_issue_id, self.proj2_issue_id],
+ self.services.issue_star.expunged_item_ids)
+ self.assertEqual(
+ [self.proj1_issue_id, self.proj2_issue_id],
+ self.services.issue.expunged_issues)
+ self.assertEqual(
+ [self.proj1_id, self.proj2_id], self.services.config.expunged_configs)
+ self.assertEqual(
+ [self.proj1_id, self.proj2_id],
+ self.services.features.expunged_saved_queries)
+ self.assertEqual(
+ [self.proj1_id, self.proj2_id],
+ self.services.features.expunged_filter_rules)
+ self.assertEqual(
+ [self.proj1_id, self.proj2_id],
+ self.services.issue.expunged_former_locations)
+ self.assertEqual(
+ [self.proj1_id, self.proj2_id], self.services.issue.expunged_local_ids)
+ self.assertEqual(
+ [self.proj1_id, self.proj2_id],
+ self.services.features.expunged_quick_edit)
+ self.assertEqual(
+ [self.proj1_id, self.proj2_id],
+ self.services.project_star.expunged_item_ids)
+ self.assertEqual(0, len(self.services.project.test_projects))
+ self.services.template.ExpungeProjectTemplates.assert_has_calls([
+ mock.call(self.cnxn, 1001),
+ mock.call(self.cnxn, 1002)])
diff --git a/framework/test/redis_utils_test.py b/framework/test/redis_utils_test.py
new file mode 100644
index 0000000..a4128ce
--- /dev/null
+++ b/framework/test/redis_utils_test.py
@@ -0,0 +1,64 @@
+# Copyright 2020 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+"""Tests for the Redis utility module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import fakeredis
+import unittest
+
+from framework import redis_utils
+from proto import features_pb2
+
+
+class RedisHelperTest(unittest.TestCase):
+
+ def testFormatRedisKey(self):
+ redis_key = redis_utils.FormatRedisKey(111)
+ self.assertEqual('111', redis_key)
+ redis_key = redis_utils.FormatRedisKey(222, prefix='foo:')
+ self.assertEqual('foo:222', redis_key)
+ redis_key = redis_utils.FormatRedisKey(333, prefix='bar')
+ self.assertEqual('bar:333', redis_key)
+
+ def testCreateRedisClient(self):
+ self.assertIsNone(redis_utils.connection_pool)
+ redis_client_1 = redis_utils.CreateRedisClient()
+ self.assertIsNotNone(redis_client_1)
+ self.assertIsNotNone(redis_utils.connection_pool)
+ redis_client_2 = redis_utils.CreateRedisClient()
+ self.assertIsNotNone(redis_client_2)
+ self.assertIsNot(redis_client_1, redis_client_2)
+
+ def testConnectionVerification(self):
+ server = fakeredis.FakeServer()
+ client = None
+ self.assertFalse(redis_utils.VerifyRedisConnection(client))
+ server.connected = True
+ client = fakeredis.FakeRedis(server=server)
+ self.assertTrue(redis_utils.VerifyRedisConnection(client))
+ server.connected = False
+ self.assertFalse(redis_utils.VerifyRedisConnection(client))
+
+ def testSerializeDeserializeInt(self):
+ serialized_int = redis_utils.SerializeValue(123)
+ self.assertEqual('123', serialized_int)
+ self.assertEquals(123, redis_utils.DeserializeValue(serialized_int))
+
+ def testSerializeDeserializeStr(self):
+ serialized = redis_utils.SerializeValue('123')
+ self.assertEqual('"123"', serialized)
+ self.assertEquals('123', redis_utils.DeserializeValue(serialized))
+
+ def testSerializeDeserializePB(self):
+ features = features_pb2.Hotlist.HotlistItem(
+ issue_id=7949, rank=0, adder_id=333, date_added=1525)
+ serialized = redis_utils.SerializeValue(
+ features, pb_class=features_pb2.Hotlist.HotlistItem)
+ self.assertIsInstance(serialized, str)
+ deserialized = redis_utils.DeserializeValue(
+ serialized, pb_class=features_pb2.Hotlist.HotlistItem)
+ self.assertIsInstance(deserialized, features_pb2.Hotlist.HotlistItem)
+ self.assertEquals(deserialized, features)
diff --git a/framework/test/registerpages_helpers_test.py b/framework/test/registerpages_helpers_test.py
new file mode 100644
index 0000000..61c489e
--- /dev/null
+++ b/framework/test/registerpages_helpers_test.py
@@ -0,0 +1,59 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for URL handler registration helper functions."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+import webapp2
+
+from framework import registerpages_helpers
+
+
+class SendRedirectInScopeTest(unittest.TestCase):
+
+ def testMakeRedirectInScope_Error(self):
+ self.assertRaises(
+ AssertionError,
+ registerpages_helpers.MakeRedirectInScope, 'no/initial/slash', 'p')
+ self.assertRaises(
+ AssertionError,
+ registerpages_helpers.MakeRedirectInScope, '', 'p')
+
+ def testMakeRedirectInScope_Normal(self):
+ factory = registerpages_helpers.MakeRedirectInScope('/', 'p')
+ # Non-dasher, normal case
+ request = webapp2.Request.blank(
+ path='/p/foo', headers={'Host': 'example.com'})
+ response = webapp2.Response()
+ redirector = factory(request, response)
+ redirector.get()
+ self.assertEqual(response.location, '//example.com/p/foo/')
+ self.assertEqual(response.status, '301 Moved Permanently')
+
+ def testMakeRedirectInScope_Temporary(self):
+ factory = registerpages_helpers.MakeRedirectInScope(
+ '/', 'p', permanent=False)
+ request = webapp2.Request.blank(
+ path='/p/foo', headers={'Host': 'example.com'})
+ response = webapp2.Response()
+ redirector = factory(request, response)
+ redirector.get()
+ self.assertEqual(response.location, '//example.com/p/foo/')
+ self.assertEqual(response.status, '302 Moved Temporarily')
+
+ def testMakeRedirectInScope_KeepQueryString(self):
+ factory = registerpages_helpers.MakeRedirectInScope(
+ '/', 'p', keep_qs=True)
+ request = webapp2.Request.blank(
+ path='/p/foo?q=1', headers={'Host': 'example.com'})
+ response = webapp2.Response()
+ redirector = factory(request, response)
+ redirector.get()
+ self.assertEqual(response.location, '//example.com/p/foo/?q=1')
+ self.assertEqual(response.status, '302 Moved Temporarily')
diff --git a/framework/test/servlet_helpers_test.py b/framework/test/servlet_helpers_test.py
new file mode 100644
index 0000000..a2fe687
--- /dev/null
+++ b/framework/test/servlet_helpers_test.py
@@ -0,0 +1,168 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for servlet base class helper functions."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+from google.appengine.ext import testbed
+
+
+from framework import permissions
+from framework import servlet_helpers
+from proto import project_pb2
+from proto import tracker_pb2
+from testing import testing_helpers
+
+
+class EztDataTest(unittest.TestCase):
+
+ def testGetBannerTime(self):
+ """Tests GetBannerTime method."""
+ timestamp = [2019, 6, 13, 18, 30]
+
+ banner_time = servlet_helpers.GetBannerTime(timestamp)
+ self.assertEqual(1560450600, banner_time)
+
+
+class AssertBasePermissionTest(unittest.TestCase):
+
+ def testAccessGranted(self):
+ _, mr = testing_helpers.GetRequestObjects(path='/hosting')
+ # No exceptions should be raised.
+ servlet_helpers.AssertBasePermission(mr)
+
+ mr.auth.user_id = 123
+ # No exceptions should be raised.
+ servlet_helpers.AssertBasePermission(mr)
+ servlet_helpers.AssertBasePermissionForUser(
+ mr.auth.user_pb, mr.auth.user_view)
+
+ def testBanned(self):
+ _, mr = testing_helpers.GetRequestObjects(path='/hosting')
+ mr.auth.user_pb.banned = 'spammer'
+ self.assertRaises(
+ permissions.BannedUserException,
+ servlet_helpers.AssertBasePermissionForUser,
+ mr.auth.user_pb, mr.auth.user_view)
+ self.assertRaises(
+ permissions.BannedUserException,
+ servlet_helpers.AssertBasePermission, mr)
+
+ def testPlusAddressAccount(self):
+ _, mr = testing_helpers.GetRequestObjects(path='/hosting')
+ mr.auth.user_pb.email = 'mailinglist+spammer@chromium.org'
+ self.assertRaises(
+ permissions.BannedUserException,
+ servlet_helpers.AssertBasePermissionForUser,
+ mr.auth.user_pb, mr.auth.user_view)
+ self.assertRaises(
+ permissions.BannedUserException,
+ servlet_helpers.AssertBasePermission, mr)
+
+ def testNoAccessToProject(self):
+ project = project_pb2.Project()
+ project.project_name = 'proj'
+ project.access = project_pb2.ProjectAccess.MEMBERS_ONLY
+ _, mr = testing_helpers.GetRequestObjects(path='/p/proj/', project=project)
+ mr.perms = permissions.EMPTY_PERMISSIONSET
+ self.assertRaises(
+ permissions.PermissionException,
+ servlet_helpers.AssertBasePermission, mr)
+
+
+# Custom issue-entry form URL that the tests below assign to
+# config.custom_issue_entry_url.
+FORM_URL = 'http://example.com/issues/form.php'
+
+
+class ComputeIssueEntryURLTest(unittest.TestCase):
+
+ def setUp(self):
+ self.project = project_pb2.Project()
+ self.project.project_name = 'proj'
+ self.config = tracker_pb2.ProjectIssueConfig()
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_user_stub()
+ self.testbed.init_memcache_stub()
+ self.testbed.init_datastore_v3_stub()
+
+ def tearDown(self):
+ self.testbed.deactivate()
+
+ def testComputeIssueEntryURL_Normal(self):
+ _request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/issues/detail?id=123&q=term',
+ project=self.project)
+
+ url = servlet_helpers.ComputeIssueEntryURL(mr, self.config)
+ self.assertEqual('/p/proj/issues/entry', url)
+
+ def testComputeIssueEntryURL_Customized(self):
+ _request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/issues/detail?id=123&q=term',
+ project=self.project)
+ mr.auth.user_id = 111
+ self.config.custom_issue_entry_url = FORM_URL
+
+ url = servlet_helpers.ComputeIssueEntryURL(mr, self.config)
+ self.assertTrue(url.startswith(FORM_URL))
+ self.assertIn('token=', url)
+ self.assertIn('role=', url)
+ self.assertIn('continue=', url)
+
+class IssueListURLTest(unittest.TestCase):
+
+ def setUp(self):
+ self.project = project_pb2.Project()
+ self.project.project_name = 'proj'
+ self.project.owner_ids = [111]
+ self.config = tracker_pb2.ProjectIssueConfig()
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_user_stub()
+ self.testbed.init_memcache_stub()
+ self.testbed.init_datastore_v3_stub()
+
+ def tearDown(self):
+ self.testbed.deactivate()
+
+ def testIssueListURL_NotCustomized(self):
+ _request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/issues', project=self.project)
+
+ url = servlet_helpers.IssueListURL(mr, self.config)
+ self.assertEqual('/p/proj/issues/list', url)
+
+ def testIssueListURL_Customized_Nonmember(self):
+ _request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/issues', project=self.project)
+ self.config.member_default_query = 'owner:me'
+
+ url = servlet_helpers.IssueListURL(mr, self.config)
+ self.assertEqual('/p/proj/issues/list', url)
+
+ def testIssueListURL_Customized_Member(self):
+ _request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/issues', project=self.project,
+ user_info={'effective_ids': {111}})
+ self.config.member_default_query = 'owner:me'
+
+ url = servlet_helpers.IssueListURL(mr, self.config)
+ self.assertEqual('/p/proj/issues/list?q=owner%3Ame', url)
+
+ def testIssueListURL_Customized_RetainQS(self):
+ _request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/issues', project=self.project,
+ user_info={'effective_ids': {111}})
+ self.config.member_default_query = 'owner:me'
+
+ url = servlet_helpers.IssueListURL(mr, self.config, query_string='')
+ self.assertEqual('/p/proj/issues/list?q=owner%3Ame', url)
+
+ url = servlet_helpers.IssueListURL(mr, self.config, query_string='q=Pri=1')
+ self.assertEqual('/p/proj/issues/list?q=Pri=1', url)
diff --git a/framework/test/servlet_test.py b/framework/test/servlet_test.py
new file mode 100644
index 0000000..40d5ed2
--- /dev/null
+++ b/framework/test/servlet_test.py
@@ -0,0 +1,474 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for servlet base class module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import time
+import mock
+import unittest
+
+from google.appengine.api import app_identity
+from google.appengine.ext import testbed
+
+import webapp2
+
+from framework import framework_constants
+from framework import servlet
+from framework import xsrf
+from proto import project_pb2
+from proto import tracker_pb2
+from proto import user_pb2
+from services import service_manager
+from testing import fake
+from testing import testing_helpers
+
+
+class TestableServlet(servlet.Servlet):
+ """A tiny concrete subclass of abstract class Servlet."""
+
+ def __init__(self, request, response, services=None, do_post_redirect=True):
+ super(TestableServlet, self).__init__(request, response, services=services)
+ self.do_post_redirect = do_post_redirect
+ self.seen_post_data = None
+
+ def ProcessFormData(self, _mr, post_data):
+ self.seen_post_data = post_data
+ if self.do_post_redirect:
+ return '/This/Is?The=Next#Page'
+ else:
+ self.response.write('sending raw data to browser')
+
+
+class ServletTest(unittest.TestCase):
+
+ def setUp(self):
+ services = service_manager.Services(
+ project=fake.ProjectService(),
+ project_star=fake.ProjectStarService(),
+ user=fake.UserService(),
+ usergroup=fake.UserGroupService())
+ services.user.TestAddUser('user@example.com', 111)
+ self.page_class = TestableServlet(
+ webapp2.Request.blank('/'), webapp2.Response(), services=services)
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_user_stub()
+ self.testbed.init_memcache_stub()
+ self.testbed.init_datastore_v3_stub()
+
+ def tearDown(self):
+ self.testbed.deactivate()
+
+ def testDefaultValues(self):
+ self.assertEqual(None, self.page_class._MAIN_TAB_MODE)
+ self.assertTrue(self.page_class._TEMPLATE_PATH.endswith('/templates/'))
+ self.assertEqual(None, self.page_class._PAGE_TEMPLATE)
+
+ def testGatherBaseData(self):
+ project = self.page_class.services.project.TestAddProject(
+ 'testproj', state=project_pb2.ProjectState.LIVE)
+ project.cached_content_timestamp = 12345
+
+ (_request, mr) = testing_helpers.GetRequestObjects(
+ path='/p/testproj/feeds', project=project)
+ nonce = '1a2b3c4d5e6f7g'
+
+ base_data = self.page_class.GatherBaseData(mr, nonce)
+
+ self.assertEqual(base_data['nonce'], nonce)
+ self.assertEqual(base_data['projectname'], 'testproj')
+ self.assertEqual(base_data['project'].cached_content_timestamp, 12345)
+ self.assertEqual(base_data['project_alert'], None)
+
+ self.assertTrue(base_data['currentPageURL'].endswith('/p/testproj/feeds'))
+ self.assertTrue(
+ base_data['currentPageURLEncoded'].endswith('%2Fp%2Ftestproj%2Ffeeds'))
+
+ def testFormHandlerURL(self):
+ self.assertEqual('/edit.do', self.page_class._FormHandlerURL('/'))
+ self.assertEqual(
+ '/something/edit.do',
+ self.page_class._FormHandlerURL('/something/'))
+ self.assertEqual(
+ '/something/edit.do',
+ self.page_class._FormHandlerURL('/something/edit.do'))
+ self.assertEqual(
+ '/something/detail_ezt.do',
+ self.page_class._FormHandlerURL('/something/detail_ezt'))
+
+ def testProcessForm_BadToken(self):
+ user_id = 111
+ token = 'no soup for you'
+
+ request, mr = testing_helpers.GetRequestObjects(
+ path='/we/we/we?so=excited',
+ params={
+ 'yesterday': 'thursday',
+ 'today': 'friday',
+ 'token': token
+ },
+ user_info={'user_id': user_id},
+ method='POST',
+ )
+ self.assertRaises(
+ xsrf.TokenIncorrect, self.page_class._DoFormProcessing, request, mr)
+ self.assertEqual(None, self.page_class.seen_post_data)
+
+ def testProcessForm_XhrAllowed_BadToken(self):
+ user_id = 111
+ token = 'no soup for you'
+
+ self.page_class.ALLOW_XHR = True
+
+ request, mr = testing_helpers.GetRequestObjects(
+ path='/we/we/we?so=excited',
+ params={
+ 'yesterday': 'thursday',
+ 'today': 'friday',
+ 'token': token
+ },
+ user_info={'user_id': user_id},
+ method='POST',
+ )
+ self.assertRaises(
+ xsrf.TokenIncorrect, self.page_class._DoFormProcessing, request, mr)
+ self.assertEqual(None, self.page_class.seen_post_data)
+
+ def testProcessForm_XhrAllowed_AcceptsPathToken(self):
+ user_id = 111
+ token = xsrf.GenerateToken(user_id, '/we/we/we')
+
+ self.page_class.ALLOW_XHR = True
+
+ request, mr = testing_helpers.GetRequestObjects(
+ path='/we/we/we?so=excited',
+ params={
+ 'yesterday': 'thursday',
+ 'today': 'friday',
+ 'token': token
+ },
+ user_info={'user_id': user_id},
+ method='POST',
+ )
+ with self.assertRaises(webapp2.HTTPException) as cm:
+ self.page_class._DoFormProcessing(request, mr)
+ self.assertEqual(302, cm.exception.code) # forms redirect on success
+
+ self.assertDictEqual(
+ {
+ 'yesterday': 'thursday',
+ 'today': 'friday',
+ 'token': token
+ }, dict(self.page_class.seen_post_data))
+
+ def testProcessForm_XhrAllowed_AcceptsXhrToken(self):
+ user_id = 111
+ token = xsrf.GenerateToken(user_id, 'xhr')
+
+ self.page_class.ALLOW_XHR = True
+
+ request, mr = testing_helpers.GetRequestObjects(
+ path='/we/we/we?so=excited',
+ params={'yesterday': 'thursday', 'today': 'friday', 'token': token},
+ user_info={'user_id': user_id},
+ method='POST',
+ )
+ with self.assertRaises(webapp2.HTTPException) as cm:
+ self.page_class._DoFormProcessing(request, mr)
+ self.assertEqual(302, cm.exception.code) # forms redirect on success
+
+ self.assertDictEqual(
+ {
+ 'yesterday': 'thursday',
+ 'today': 'friday',
+ 'token': token
+ }, dict(self.page_class.seen_post_data))
+
+ def testProcessForm_RawResponse(self):
+ user_id = 111
+ token = xsrf.GenerateToken(user_id, '/we/we/we')
+
+ request, mr = testing_helpers.GetRequestObjects(
+ path='/we/we/we?so=excited',
+ params={'yesterday': 'thursday', 'today': 'friday', 'token': token},
+ user_info={'user_id': user_id},
+ method='POST',
+ )
+ self.page_class.do_post_redirect = False
+ self.page_class._DoFormProcessing(request, mr)
+ self.assertEqual(
+ 'sending raw data to browser',
+ self.page_class.response.body)
+
+ def testProcessForm_Normal(self):
+ user_id = 111
+ token = xsrf.GenerateToken(user_id, '/we/we/we')
+
+ request, mr = testing_helpers.GetRequestObjects(
+ path='/we/we/we?so=excited',
+ params={'yesterday': 'thursday', 'today': 'friday', 'token': token},
+ user_info={'user_id': user_id},
+ method='POST',
+ )
+ with self.assertRaises(webapp2.HTTPException) as cm:
+ self.page_class._DoFormProcessing(request, mr)
+ self.assertEqual(302, cm.exception.code) # forms redirect on success
+
+ self.assertDictEqual(
+ {'yesterday': 'thursday', 'today': 'friday', 'token': token},
+ dict(self.page_class.seen_post_data))
+
+ def testCalcProjectAlert(self):
+ project = fake.Project(
+ project_name='alerttest', state=project_pb2.ProjectState.LIVE)
+
+ project_alert = servlet._CalcProjectAlert(project)
+ self.assertEqual(project_alert, None)
+
+ project.state = project_pb2.ProjectState.ARCHIVED
+ project_alert = servlet._CalcProjectAlert(project)
+ self.assertEqual(
+ project_alert,
+ 'Project is archived: read-only by members only.')
+
+ delete_time = int(time.time() + framework_constants.SECS_PER_DAY * 1.5)
+ project.delete_time = delete_time
+ project_alert = servlet._CalcProjectAlert(project)
+ self.assertEqual(project_alert, 'Scheduled for deletion in 1 day.')
+
+ delete_time = int(time.time() + framework_constants.SECS_PER_DAY * 2.5)
+ project.delete_time = delete_time
+ project_alert = servlet._CalcProjectAlert(project)
+ self.assertEqual(project_alert, 'Scheduled for deletion in 2 days.')
+
+ def testCheckForMovedProject_NoRedirect(self):
+ project = fake.Project(
+ project_name='proj', state=project_pb2.ProjectState.LIVE)
+ request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj', project=project)
+ self.page_class._CheckForMovedProject(mr, request)
+
+ request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/source/browse/p/adminAdvanced', project=project)
+ self.page_class._CheckForMovedProject(mr, request)
+
+ def testCheckForMovedProject_Redirect(self):
+ project = fake.Project(project_name='proj', moved_to='http://example.com')
+ request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj', project=project)
+ with self.assertRaises(webapp2.HTTPException) as cm:
+ self.page_class._CheckForMovedProject(mr, request)
+ self.assertEqual(302, cm.exception.code) # redirect because project moved
+
+ request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/source/browse/p/adminAdvanced', project=project)
+ with self.assertRaises(webapp2.HTTPException) as cm:
+ self.page_class._CheckForMovedProject(mr, request)
+ self.assertEqual(302, cm.exception.code) # redirect because project moved
+
+ def testCheckForMovedProject_AdminAdvanced(self):
+ """We do not redirect away from the page that edits project state."""
+ project = fake.Project(project_name='proj', moved_to='http://example.com')
+ request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/adminAdvanced', project=project)
+ self.page_class._CheckForMovedProject(mr, request)
+
+ request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/adminAdvanced?ts=123234', project=project)
+ self.page_class._CheckForMovedProject(mr, request)
+
+ request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/adminAdvanced.do', project=project)
+ self.page_class._CheckForMovedProject(mr, request)
+
+ @mock.patch('settings.branded_domains',
+ {'proj': 'branded.example.com', '*': 'bugs.chromium.org'})
+ def testMaybeRedirectToBrandedDomain_RedirBrandedProject(self):
+ """We redirect for a branded project if the user typed a different host."""
+ project = fake.Project(project_name='proj')
+ request, _mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/path', project=project)
+ with self.assertRaises(webapp2.HTTPException) as cm:
+ self.page_class._MaybeRedirectToBrandedDomain(request, 'proj')
+ self.assertEqual(302, cm.exception.code) # forms redirect on success
+ self.assertEqual('https://branded.example.com/p/proj/path?redir=1',
+ cm.exception.location)
+
+ request, _mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/path?query', project=project)
+ with self.assertRaises(webapp2.HTTPException) as cm:
+ self.page_class._MaybeRedirectToBrandedDomain(request, 'proj')
+ self.assertEqual(302, cm.exception.code) # forms redirect on success
+ self.assertEqual('https://branded.example.com/p/proj/path?query&redir=1',
+ cm.exception.location)
+
+ @mock.patch('settings.branded_domains',
+ {'proj': 'branded.example.com', '*': 'bugs.chromium.org'})
+ def testMaybeRedirectToBrandedDomain_AvoidRedirLoops(self):
+ """Don't redirect for a branded project if already redirected."""
+ project = fake.Project(project_name='proj')
+ request, _mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/path?redir=1', project=project)
+ # No redirect happens.
+ self.page_class._MaybeRedirectToBrandedDomain(request, 'proj')
+
+ @mock.patch('settings.branded_domains',
+ {'proj': 'branded.example.com', '*': 'bugs.chromium.org'})
+ def testMaybeRedirectToBrandedDomain_NonProjectPage(self):
+ """Don't redirect for a branded project if not in any project."""
+ request, _mr = testing_helpers.GetRequestObjects(
+ path='/u/user@example.com')
+ # No redirect happens.
+ self.page_class._MaybeRedirectToBrandedDomain(request, None)
+
+ @mock.patch('settings.branded_domains',
+ {'proj': 'branded.example.com', '*': 'bugs.chromium.org'})
+ def testMaybeRedirectToBrandedDomain_AlreadyOnBrandedHost(self):
+ """Don't redirect for a branded project if already on branded domain."""
+ project = fake.Project(project_name='proj')
+ request, _mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/path', project=project)
+ request.host = 'branded.example.com'
+ # No redirect happens.
+ self.page_class._MaybeRedirectToBrandedDomain(request, 'proj')
+
+ @mock.patch('settings.branded_domains',
+ {'proj': 'branded.example.com', '*': 'bugs.chromium.org'})
+ def testMaybeRedirectToBrandedDomain_Localhost(self):
+ """Don't redirect for a branded project on localhost."""
+ project = fake.Project(project_name='proj')
+ request, _mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/path', project=project)
+ request.host = 'localhost:8080'
+ # No redirect happens.
+ self.page_class._MaybeRedirectToBrandedDomain(request, 'proj')
+
+ request.host = '0.0.0.0:8080'
+ # No redirect happens.
+ self.page_class._MaybeRedirectToBrandedDomain(request, 'proj')
+
+ @mock.patch('settings.branded_domains',
+ {'proj': 'branded.example.com', '*': 'bugs.chromium.org'})
+ def testMaybeRedirectToBrandedDomain_NotBranded(self):
+ """Don't redirect for a non-branded project."""
+ project = fake.Project(project_name='other')
+ request, _mr = testing_helpers.GetRequestObjects(
+ path='/p/other/path?query', project=project)
+ request.host = 'branded.example.com' # But other project is unbranded.
+
+ with self.assertRaises(webapp2.HTTPException) as cm:
+ self.page_class._MaybeRedirectToBrandedDomain(request, 'other')
+ self.assertEqual(302, cm.exception.code) # forms redirect on success
+ self.assertEqual('https://bugs.chromium.org/p/other/path?query&redir=1',
+ cm.exception.location)
+
+ def testGatherHelpData_Normal(self):
+ project = fake.Project(project_name='proj')
+ _request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj', project=project)
+ help_data = self.page_class.GatherHelpData(mr, {})
+ self.assertEqual(None, help_data['cue'])
+ self.assertEqual(None, help_data['account_cue'])
+
+ def testGatherHelpData_VacationReminder(self):
+ project = fake.Project(project_name='proj')
+ _request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj', project=project)
+ mr.auth.user_id = 111
+ mr.auth.user_pb.vacation_message = 'Gone skiing'
+ help_data = self.page_class.GatherHelpData(mr, {})
+ self.assertEqual('you_are_on_vacation', help_data['cue'])
+
+ self.page_class.services.user.SetUserPrefs(
+ 'cnxn', 111,
+ [user_pb2.UserPrefValue(name='you_are_on_vacation', value='true')])
+ help_data = self.page_class.GatherHelpData(mr, {})
+ self.assertEqual(None, help_data['cue'])
+ self.assertEqual(None, help_data['account_cue'])
+
+ def testGatherHelpData_YouAreBouncing(self):
+ project = fake.Project(project_name='proj')
+ _request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj', project=project)
+ mr.auth.user_id = 111
+ mr.auth.user_pb.email_bounce_timestamp = 1497647529
+ help_data = self.page_class.GatherHelpData(mr, {})
+ self.assertEqual('your_email_bounced', help_data['cue'])
+
+ self.page_class.services.user.SetUserPrefs(
+ 'cnxn', 111,
+ [user_pb2.UserPrefValue(name='your_email_bounced', value='true')])
+ help_data = self.page_class.GatherHelpData(mr, {})
+ self.assertEqual(None, help_data['cue'])
+ self.assertEqual(None, help_data['account_cue'])
+
+ def testGatherHelpData_ChildAccount(self):
+ """Display a warning when user is signed in to a child account."""
+ project = fake.Project(project_name='proj')
+ _request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj', project=project)
+ mr.auth.user_pb.linked_parent_id = 111
+ help_data = self.page_class.GatherHelpData(mr, {})
+ self.assertEqual(None, help_data['cue'])
+ self.assertEqual('switch_to_parent_account', help_data['account_cue'])
+ self.assertEqual('user@example.com', help_data['parent_email'])
+
+ def testGatherDebugData_Visibility(self):
+ project = fake.Project(
+ project_name='testtest', state=project_pb2.ProjectState.LIVE)
+ _request, mr = testing_helpers.GetRequestObjects(
+ path='/p/foo/servlet_path', project=project)
+ debug_data = self.page_class.GatherDebugData(mr, {})
+ self.assertEqual('off', debug_data['dbg'])
+
+ _request, mr = testing_helpers.GetRequestObjects(
+ path='/p/foo/servlet_path?debug=1', project=project)
+ debug_data = self.page_class.GatherDebugData(mr, {})
+ self.assertEqual('on', debug_data['dbg'])
+
+
+class ProjectIsRestrictedTest(unittest.TestCase):
+
+ def testNonRestrictedProject(self):
+ proj = project_pb2.Project()
+ mr = testing_helpers.MakeMonorailRequest()
+ mr.project = proj
+
+ proj.access = project_pb2.ProjectAccess.ANYONE
+ proj.state = project_pb2.ProjectState.LIVE
+ self.assertFalse(servlet._ProjectIsRestricted(mr))
+
+ proj.state = project_pb2.ProjectState.ARCHIVED
+ self.assertFalse(servlet._ProjectIsRestricted(mr))
+
+ def testRestrictedProject(self):
+ proj = project_pb2.Project()
+ mr = testing_helpers.MakeMonorailRequest()
+ mr.project = proj
+
+ proj.state = project_pb2.ProjectState.LIVE
+ proj.access = project_pb2.ProjectAccess.MEMBERS_ONLY
+ self.assertTrue(servlet._ProjectIsRestricted(mr))
+
+class VersionBaseTest(unittest.TestCase):
+
+ @mock.patch('settings.local_mode', True)
+ def testLocalhost(self):
+ request = webapp2.Request.blank('/', base_url='http://localhost:8080')
+ actual = servlet._VersionBaseURL(request)
+ expected = 'http://localhost:8080'
+ self.assertEqual(expected, actual)
+
+ @mock.patch('settings.local_mode', False)
+ @mock.patch('google.appengine.api.app_identity.get_default_version_hostname')
+ def testProd(self, mock_gdvh):
+ mock_gdvh.return_value = 'monorail-prod.appspot.com'
+ request = webapp2.Request.blank('/', base_url='https://bugs.chromium.org')
+ actual = servlet._VersionBaseURL(request)
+ expected = 'https://test-dot-monorail-prod.appspot.com'
+ self.assertEqual(expected, actual)
diff --git a/framework/test/sorting_test.py b/framework/test/sorting_test.py
new file mode 100644
index 0000000..4b1feb3
--- /dev/null
+++ b/framework/test/sorting_test.py
@@ -0,0 +1,360 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for sorting.py functions."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+# For convenient debugging
+import logging
+
+import mox
+
+from framework import sorting
+from framework import framework_views
+from proto import tracker_pb2
+from testing import fake
+from testing import testing_helpers
+from tracker import tracker_bizobj
+
+
def MakeDescending(accessor):
  """Wrap accessor so the values it produces sort in reverse order."""
  descending = True
  return sorting._MaybeMakeDescending(accessor, descending)
+
+
class DescendingValueTest(unittest.TestCase):
  """Tests for sorting.DescendingValue and sorting._MaybeMakeDescending."""

  def testMinString(self):
    """When sorting desc, a min string will sort last instead of first."""
    actual = sorting.DescendingValue.MakeDescendingValue(sorting.MIN_STRING)
    self.assertEqual(sorting.MAX_STRING, actual)

  def testMaxString(self):
    """When sorting desc, a max string will sort first instead of last."""
    actual = sorting.DescendingValue.MakeDescendingValue(sorting.MAX_STRING)
    self.assertEqual(sorting.MIN_STRING, actual)

  def testDescValues(self):
    """The point of DescendingValue is to reverse the sort order."""
    anti_a = sorting.DescendingValue.MakeDescendingValue('a')
    anti_b = sorting.DescendingValue.MakeDescendingValue('b')
    self.assertGreater(anti_a, anti_b)

  def testMaybeMakeDescending(self):
    """It returns an accessor that makes DescendingValue iff arg is True."""
    # Compare identity against a captured object rather than a str literal:
    # `x is 'a'` depends on string interning and is a SyntaxWarning in 3.8+.
    raw_value = 'a'
    asc_accessor = sorting._MaybeMakeDescending(lambda _issue: raw_value, False)
    self.assertIs(raw_value, asc_accessor('fake issue'))

    desc_accessor = sorting._MaybeMakeDescending(lambda _issue: 'a', True)
    desc_value = desc_accessor('fake issue')
    self.assertIsInstance(desc_value, sorting.DescendingValue)
+
+
class SortingTest(unittest.TestCase):
  """Tests for the sort-key accessor builders in framework/sorting.py."""

  def setUp(self):
    self.mox = mox.Mox()
    self.default_cols = 'a b c'
    self.builtin_cols = 'a b x y z'
    self.config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
    self.config.component_defs.append(tracker_bizobj.MakeComponentDef(
        11, 789, 'Database', 'doc', False, [], [], 0, 0))
    self.config.component_defs.append(tracker_bizobj.MakeComponentDef(
        22, 789, 'User Interface', 'doc', True, [], [], 0, 0))
    self.config.component_defs.append(tracker_bizobj.MakeComponentDef(
        33, 789, 'Installer', 'doc', False, [], [], 0, 0))

  def tearDown(self):
    self.mox.UnsetStubs()
    self.mox.ResetAll()

  def testMakeSingleSortKeyAccessor_Status(self):
    """Sorting by status should create an accessor for that column."""
    self.mox.StubOutWithMock(sorting, '_IndexOrLexical')
    status_names = [wks.status for wks in self.config.well_known_statuses]
    sorting._IndexOrLexical(status_names, 'status accessor')
    self.mox.ReplayAll()

    sorting._MakeSingleSortKeyAccessor(
        'status', self.config, {'status': 'status accessor'}, [], {}, [])
    self.mox.VerifyAll()

  def testMakeSingleSortKeyAccessor_Component(self):
    """Sorting by component should create an accessor for that column."""
    self.mox.StubOutWithMock(sorting, '_IndexListAccessor')
    # Expected well-known order (presumably by path: Database, Installer,
    # User Interface) -- differs from the order they were added in setUp.
    component_ids = [11, 33, 22]
    sorting._IndexListAccessor(component_ids, 'component accessor')
    self.mox.ReplayAll()

    sorting._MakeSingleSortKeyAccessor(
        'component', self.config, {'component': 'component accessor'}, [], {},
        [])
    self.mox.VerifyAll()

  def testMakeSingleSortKeyAccessor_OtherBuiltInColumns(self):
    """Sorting a built-in column should create an accessor for that column."""
    accessor = sorting._MakeSingleSortKeyAccessor(
        'buildincol', self.config, {'buildincol': 'accessor'}, [], {}, [])
    self.assertEqual('accessor', accessor)

  def testMakeSingleSortKeyAccessor_WithPostProcessor(self):
    """Sorting a built-in user column should create a user accessor."""
    self.mox.StubOutWithMock(sorting, '_MakeAccessorWithPostProcessor')
    users_by_id = {111: 'fake user'}
    sorting._MakeAccessorWithPostProcessor(
        users_by_id, 'mock owner accessor', 'mock postprocessor')
    self.mox.ReplayAll()

    sorting._MakeSingleSortKeyAccessor(
        'owner', self.config, {'owner': 'mock owner accessor'},
        {'owner': 'mock postprocessor'}, users_by_id, [])
    self.mox.VerifyAll()

  def testIndexOrLexical(self):
    """Well-known values sort by index; others sort lexically; blanks last."""
    well_known_values = ['x-a', 'x-b', 'x-c', 'x-d']
    art = 'this is a fake artifact'

    # Case 1: accessor generates no values.
    base_accessor = lambda _art: None
    accessor = sorting._IndexOrLexical(well_known_values, base_accessor)
    self.assertEqual(sorting.MAX_STRING, accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual(sorting.DescendingValue(sorting.MAX_STRING),
                     neg_accessor(art))

    # Case 2: accessor generates a value, but it is an empty value.
    base_accessor = lambda _art: ''
    accessor = sorting._IndexOrLexical(well_known_values, base_accessor)
    self.assertEqual(sorting.MAX_STRING, accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual(sorting.DescendingValue(sorting.MAX_STRING),
                     neg_accessor(art))

    # Case 3: A single well-known value
    base_accessor = lambda _art: 'x-c'
    accessor = sorting._IndexOrLexical(well_known_values, base_accessor)
    self.assertEqual(2, accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual(-2, neg_accessor(art))

    # Case 4: A single odd-ball value
    base_accessor = lambda _art: 'x-zzz'
    accessor = sorting._IndexOrLexical(well_known_values, base_accessor)
    self.assertEqual('x-zzz', accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual(
        sorting.DescendingValue('x-zzz'), neg_accessor(art))

  def testIndexListAccessor_SomeWellKnownValues(self):
    """Values sort according to their position in the well-known list."""
    well_known_values = [11, 33, 22]  # These represent component IDs.
    art = fake.MakeTestIssue(789, 1, 'sum 1', 'New', 111)
    base_accessor = lambda issue: issue.component_ids
    accessor = sorting._IndexListAccessor(well_known_values, base_accessor)

    # Case 1: accessor generates no values.
    self.assertEqual(sorting.MAX_STRING, accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual(sorting.MAX_STRING, neg_accessor(art))

    # Case 2: A single well-known value
    art.component_ids = [33]
    self.assertEqual([1], accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual([-1], neg_accessor(art))

    # Case 3: Multiple well-known and odd-ball values
    art.component_ids = [33, 11, 99]
    self.assertEqual([0, 1, sorting.MAX_STRING], accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual([sorting.MAX_STRING, -1, 0],
                     neg_accessor(art))

  def testIndexListAccessor_NoWellKnownValues(self):
    """When there are no well-known values, all values sort last."""
    well_known_values = []  # Nothing pre-defined, so everything is oddball
    art = fake.MakeTestIssue(789, 1, 'sum 1', 'New', 111)
    base_accessor = lambda issue: issue.component_ids
    accessor = sorting._IndexListAccessor(well_known_values, base_accessor)

    # Case 1: accessor generates no values.
    self.assertEqual(sorting.MAX_STRING, accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual(sorting.MAX_STRING, neg_accessor(art))

    # Case 2: A single oddball value
    art.component_ids = [33]
    self.assertEqual([sorting.MAX_STRING], accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual([sorting.MAX_STRING], neg_accessor(art))

    # Case 3: Multiple odd-ball values
    art.component_ids = [33, 11, 99]
    self.assertEqual(
        [sorting.MAX_STRING, sorting.MAX_STRING, sorting.MAX_STRING],
        accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual(
        [sorting.MAX_STRING, sorting.MAX_STRING, sorting.MAX_STRING],
        neg_accessor(art))

  def testIndexOrLexicalList(self):
    """Label values with a given prefix sort by index, then lexically."""
    well_known_values = ['Pri-High', 'Pri-Med', 'Pri-Low']
    art = fake.MakeTestIssue(789, 1, 'sum 1', 'New', 111, merged_into=200001)

    # Case 1: accessor generates no values.
    accessor = sorting._IndexOrLexicalList(well_known_values, [], 'pri', {})
    self.assertEqual(sorting.MAX_STRING, accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual(sorting.MAX_STRING, neg_accessor(art))

    # Case 2: A single well-known value
    art.labels = ['Pri-Med']
    accessor = sorting._IndexOrLexicalList(well_known_values, [], 'pri', {})
    self.assertEqual([1], accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual([-1], neg_accessor(art))

    # Case 3: Multiple well-known and odd-ball values
    art.labels = ['Pri-zzz', 'Pri-Med', 'yyy', 'Pri-High']
    accessor = sorting._IndexOrLexicalList(well_known_values, [], 'pri', {})
    self.assertEqual([0, 1, 'zzz'], accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual([sorting.DescendingValue('zzz'), -1, 0],
                     neg_accessor(art))

    # Case 4: Multi-part prefix.
    well_known_values.extend(['X-Y-Header', 'X-Y-Footer'])
    art.labels = ['X-Y-Footer', 'X-Y-Zone', 'X-Y-Header', 'X-Y-Area']
    accessor = sorting._IndexOrLexicalList(well_known_values, [], 'x-y', {})
    self.assertEqual([3, 4, 'area', 'zone'], accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual([sorting.DescendingValue('zone'),
                      sorting.DescendingValue('area'), -4, -3],
                     neg_accessor(art))

  def testIndexOrLexicalList_CustomFields(self):
    """A custom field sharing a label prefix contributes its values too."""
    art = fake.MakeTestIssue(789, 1, 'sum 2', 'New', 111)
    art.labels = ['samename-value1']
    art.field_values = [tracker_bizobj.MakeFieldValue(
        3, 6078, None, None, None, None, False)]

    all_field_defs = [
        tracker_bizobj.MakeFieldDef(
            3, 789, 'samename', tracker_pb2.FieldTypes.INT_TYPE,
            None, None, False, False, False, None, None, None, False, None,
            None, None, None, 'cow spots', False),
        tracker_bizobj.MakeFieldDef(
            4, 788, 'samename', tracker_pb2.FieldTypes.APPROVAL_TYPE,
            None, None, False, False, False, None, None, None, False, None,
            None, None, None, 'cow spots', False),
        tracker_bizobj.MakeFieldDef(
            4, 788, 'notsamename', tracker_pb2.FieldTypes.APPROVAL_TYPE,
            None, None, False, False, False, None, None, None, False, None,
            None, None, None, 'should get filtered out', False)
    ]

    accessor = sorting._IndexOrLexicalList([], all_field_defs, 'samename', {})
    self.assertEqual([6078, 'value1'], accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual(
        [sorting.DescendingValue('value1'), -6078], neg_accessor(art))

  def testIndexOrLexicalList_PhaseCustomFields(self):
    """'phase.field' selects only field values from the named phase."""
    art = fake.MakeTestIssue(789, 1, 'sum 2', 'New', 111)
    art.labels = ['summer.goats-value1']
    art.field_values = [
        tracker_bizobj.MakeFieldValue(
            3, 33, None, None, None, None, False, phase_id=77),
        tracker_bizobj.MakeFieldValue(
            3, 34, None, None, None, None, False, phase_id=77),
        tracker_bizobj.MakeFieldValue(
            3, 1000, None, None, None, None, False, phase_id=78)]
    art.phases = [tracker_pb2.Phase(phase_id=77, name='summer'),
                  tracker_pb2.Phase(phase_id=78, name='winter')]

    all_field_defs = [
        tracker_bizobj.MakeFieldDef(
            3, 789, 'goats', tracker_pb2.FieldTypes.INT_TYPE,
            None, None, False, False, True, None, None, None, False, None,
            None, None, None, 'goats love mineral', False, is_phase_field=True),
        tracker_bizobj.MakeFieldDef(
            4, 788, 'boo', tracker_pb2.FieldTypes.APPROVAL_TYPE,
            None, None, False, False, False, None, None, None, False, None,
            None, None, None, 'ahh', False),
    ]

    accessor = sorting._IndexOrLexicalList(
        [], all_field_defs, 'summer.goats', {})
    self.assertEqual([33, 34, 'value1'], accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual(
        [sorting.DescendingValue('value1'), -34, -33], neg_accessor(art))

  def testIndexOrLexicalList_ApprovalStatus(self):
    """An approval field name sorts by the approval's status value."""
    art = fake.MakeTestIssue(789, 1, 'sum 2', 'New', 111)
    art.labels = ['samename-value1']
    art.approval_values = [tracker_pb2.ApprovalValue(approval_id=4)]

    all_field_defs = [
        tracker_bizobj.MakeFieldDef(
            3, 789, 'samename', tracker_pb2.FieldTypes.INT_TYPE,
            None, None, False, False, False, None, None, None, False, None,
            None, None, None, 'cow spots', False),
        tracker_bizobj.MakeFieldDef(
            4, 788, 'samename', tracker_pb2.FieldTypes.APPROVAL_TYPE,
            None, None, False, False, False, None, None, None, False, None,
            None, None, None, 'cow spots', False)
    ]

    accessor = sorting._IndexOrLexicalList([], all_field_defs, 'samename', {})
    self.assertEqual([0, 'value1'], accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual([sorting.DescendingValue('value1'),
                      sorting.DescendingValue(0)],
                     neg_accessor(art))

  def testIndexOrLexicalList_ApprovalApprover(self):
    """'<approval>-approver' sorts by the approvers' email addresses."""
    art = fake.MakeTestIssue(789, 1, 'sum 2', 'New', 111)
    art.labels = ['samename-approver-value1']
    art.approval_values = [
        tracker_pb2.ApprovalValue(approval_id=4, approver_ids=[333])]

    all_field_defs = [
        tracker_bizobj.MakeFieldDef(
            4, 788, 'samename', tracker_pb2.FieldTypes.APPROVAL_TYPE,
            None, None, False, False, False, None, None, None, False, None,
            None, None, None, 'cow spots', False)
    ]
    users_by_id = {333: framework_views.StuffUserView(333, 'a@test.com', True)}

    accessor = sorting._IndexOrLexicalList(
        [], all_field_defs, 'samename-approver', users_by_id)
    self.assertEqual(['a@test.com', 'value1'], accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual([sorting.DescendingValue('value1'),
                      sorting.DescendingValue('a@test.com')],
                     neg_accessor(art))

  def testComputeSortDirectives(self):
    """Directives merge user, group-by and default specs, then project/id."""
    config = tracker_pb2.ProjectIssueConfig()
    self.assertEqual(
        ['project', 'id'], sorting.ComputeSortDirectives(config, '', ''))

    self.assertEqual(
        ['a', 'b', 'c', 'project', 'id'],
        sorting.ComputeSortDirectives(config, '', 'a b C'))

    config.default_sort_spec = 'id -reporter Owner'
    self.assertEqual(
        ['id', '-reporter', 'owner', 'project'],
        sorting.ComputeSortDirectives(config, '', ''))

    self.assertEqual(
        ['x', '-b', 'a', 'c', '-owner', 'id', '-reporter', 'project'],
        sorting.ComputeSortDirectives(config, 'x -b', 'A -b c -owner'))
diff --git a/framework/test/sql_test.py b/framework/test/sql_test.py
new file mode 100644
index 0000000..f073e24
--- /dev/null
+++ b/framework/test/sql_test.py
@@ -0,0 +1,681 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for the sql module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import logging
+import mock
+import time
+import unittest
+
+import settings
+from framework import exceptions
+from framework import sql
+
+
class MockSQLCnxn(object):
  """In-memory stand-in for both a DB-API connection and its cursor.

  Instead of talking to a database it records the last statement it was
  asked to run, and it tracks whether a write is pending so that close()
  can assert every write was committed or rolled back.
  """

  # Statements with these prefixes do not leave the connection dirty.
  _READ_ONLY_PREFIXES = ('SET', 'SELECT')

  def __init__(self, instance, database):
    self.instance = instance
    self.database = database
    self.pool_key = instance + '/' + database
    # Captured from, or handed back to, the code under test.
    self.last_executed = None
    self.last_executed_args = None
    self.result_rows = None
    self.rowcount = 0
    self.lastrowid = None
    # Knobs and bookkeeping used by the tests.
    self.is_bad = False
    self.has_uncommitted = False

  def cursor(self):
    """The mock doubles as its own cursor."""
    return self

  def execute(self, stmt_str, args=None):
    """Record stmt_str with args interpolated; flag writes as uncommitted."""
    params = tuple(args) if args else ()
    self.last_executed = stmt_str % params
    if stmt_str.startswith(self._READ_ONLY_PREFIXES):
      return
    self.has_uncommitted = True

  def executemany(self, stmt_str, args):
    """Record a batched statement and its rows without interpolating."""
    # Every %s stands for a whole column of rows, so keep the raw string.
    self.last_executed = stmt_str
    self.last_executed_args = tuple(args)

    # sql.py only calls executemany() for INSERT.
    assert stmt_str.startswith('INSERT')
    self.lastrowid = 123

  def fetchall(self):
    """Return whatever rows the test preloaded into result_rows."""
    return self.result_rows

  def commit(self):
    self.has_uncommitted = False

  def rollback(self):
    self.has_uncommitted = False

  def close(self):
    """Closing with a pending write is a test bug."""
    assert not self.has_uncommitted

  def ping(self):
    """Simulate a liveness check that fails once the test marks us bad."""
    if not self.is_bad:
      return
    raise BaseException('connection error!')
+
+
# Make sql.py build MockSQLCnxn objects for every connection it opens, so no
# test in this module talks to a real database.
setattr(sql, 'cnxn_ctor', MockSQLCnxn)
+
+
class ConnectionPoolingTest(unittest.TestCase):
  """Tests for sql.ConnectionPool backed by MockSQLCnxn connections.

  Counts are compared with assertEqual: assertIs on ints only passes because
  CPython caches small integers, which is an implementation detail.
  """

  def testGet(self):
    pool_size = 2
    num_dbs = 2
    p = sql.ConnectionPool(pool_size)

    for i in range(num_dbs):
      for _ in range(pool_size):
        c = p.get('test', 'db%d' % i)
        self.assertIsNotNone(c)
        p.release(c)

    cnxn1 = p.get('test', 'db0')
    q = p.queues[cnxn1.pool_key]
    self.assertEqual(0, q.qsize())

    p.release(cnxn1)
    self.assertEqual(pool_size - 1, q.qsize())
    self.assertFalse(q.full())
    self.assertFalse(q.empty())

    cnxn2 = p.get('test', 'db0')
    q = p.queues[cnxn2.pool_key]
    self.assertEqual(0, q.qsize())
    self.assertFalse(q.full())
    self.assertTrue(q.empty())

  def testGetAndReturnPooledCnxn(self):
    p = sql.ConnectionPool(2)

    cnxn1 = p.get('test', 'db1')
    self.assertEqual(1, len(p.queues))

    cnxn2 = p.get('test', 'db2')
    self.assertEqual(2, len(p.queues))

    # Should use the existing pool.
    cnxn3 = p.get('test', 'db1')
    self.assertEqual(2, len(p.queues))

    p.release(cnxn3)
    p.release(cnxn2)

    cnxn1.is_bad = True
    p.release(cnxn1)
    # cnxn1 should not be returned from the pool if we
    # ask for a connection to its database.

    cnxn4 = p.get('test', 'db1')

    self.assertIsNot(cnxn1, cnxn4)
    self.assertEqual(2, len(p.queues))
    self.assertFalse(cnxn4.is_bad)

  def testGetAndReturnPooledCnxn_badCnxn(self):
    p = sql.ConnectionPool(2)

    cnxn1 = p.get('test', 'db1')
    cnxn2 = p.get('test', 'db2')
    cnxn3 = p.get('test', 'db1')

    cnxn3.is_bad = True

    p.release(cnxn3)
    q = p.queues[cnxn3.pool_key]
    self.assertEqual(1, q.qsize())

    # MockSQLCnxn.ping() raises BaseException for a bad connection.
    with self.assertRaises(BaseException):
      p.get('test', 'db1')

    q = p.queues[cnxn2.pool_key]
    self.assertEqual(0, q.qsize())
    p.release(cnxn2)
    self.assertEqual(1, q.qsize())

    p.release(cnxn1)
    q = p.queues[cnxn1.pool_key]
    self.assertEqual(1, q.qsize())
+
+
class MonorailConnectionTest(unittest.TestCase):
  """Tests for sql.MonorailConnection connection choice and execution."""

  def setUp(self):
    self.cnxn = sql.MonorailConnection()
    self.orig_local_mode = settings.local_mode
    self.orig_num_logical_shards = settings.num_logical_shards
    # unavailable_shards is shared by all MonorailConnection instances (see
    # testExecute_Shard_Unavailable). Snapshot it so tearDown can undo any
    # changes; otherwise a test failing midway leaves a shard marked bad and
    # other tests become order-dependent.
    self.orig_unavailable_shards = dict(self.cnxn.unavailable_shards)
    settings.local_mode = False

  def tearDown(self):
    settings.local_mode = self.orig_local_mode
    settings.num_logical_shards = self.orig_num_logical_shards
    self.cnxn.unavailable_shards.clear()
    self.cnxn.unavailable_shards.update(self.orig_unavailable_shards)

  def testGetPrimaryConnection(self):
    sql_cnxn = self.cnxn.GetPrimaryConnection()
    self.assertEqual(settings.db_instance, sql_cnxn.instance)
    self.assertEqual(settings.db_database_name, sql_cnxn.database)

    sql_cnxn2 = self.cnxn.GetPrimaryConnection()
    self.assertIs(sql_cnxn2, sql_cnxn)

  def testGetConnectionForShard(self):
    sql_cnxn = self.cnxn.GetConnectionForShard(1)
    replica_name = settings.db_replica_names[
        1 % len(settings.db_replica_names)]
    self.assertEqual(settings.physical_db_name_format % replica_name,
                     sql_cnxn.instance)
    self.assertEqual(settings.db_database_name, sql_cnxn.database)

    sql_cnxn2 = self.cnxn.GetConnectionForShard(1)
    self.assertIs(sql_cnxn2, sql_cnxn)

  def testClose(self):
    sql_cnxn = self.cnxn.GetPrimaryConnection()
    self.cnxn.Close()
    self.assertFalse(sql_cnxn.has_uncommitted)

  def testExecute_Primary(self):
    """Execute() with no shard passes the statement to the primary sql cnxn."""
    sql_cnxn = self.cnxn.GetPrimaryConnection()
    with mock.patch.object(self.cnxn, '_ExecuteWithSQLConnection') as ewsc:
      ewsc.return_value = 'db result'
      actual_result = self.cnxn.Execute('statement', [])
      self.assertEqual('db result', actual_result)
      ewsc.assert_called_once_with(sql_cnxn, 'statement', [], commit=True)

  def testExecute_Shard(self):
    """Execute() with a shard passes the statement to the shard sql cnxn."""
    shard_id = 1
    sql_cnxn_1 = self.cnxn.GetConnectionForShard(shard_id)
    with mock.patch.object(self.cnxn, '_ExecuteWithSQLConnection') as ewsc:
      ewsc.return_value = 'db result'
      actual_result = self.cnxn.Execute('statement', [], shard_id=shard_id)
      self.assertEqual('db result', actual_result)
      ewsc.assert_called_once_with(sql_cnxn_1, 'statement', [], commit=True)

  def testExecute_Shard_Unavailable(self):
    """If a shard is unavailable, we try the next one."""
    shard_id = 1
    sql_cnxn_1 = self.cnxn.GetConnectionForShard(shard_id)
    sql_cnxn_2 = self.cnxn.GetConnectionForShard(shard_id + 1)

    # Simulate a recent failure on shard 1.
    self.cnxn.unavailable_shards[1] = int(time.time()) - 3

    with mock.patch.object(self.cnxn, '_ExecuteWithSQLConnection') as ewsc:
      ewsc.return_value = 'db result'
      actual_result = self.cnxn.Execute('statement', [], shard_id=shard_id)
      self.assertEqual('db result', actual_result)
      ewsc.assert_called_once_with(sql_cnxn_2, 'statement', [], commit=True)

    # Even a new MonorailConnection instance shares the same state.
    other_cnxn = sql.MonorailConnection()
    other_sql_cnxn_2 = other_cnxn.GetConnectionForShard(shard_id + 1)

    with mock.patch.object(other_cnxn, '_ExecuteWithSQLConnection') as ewsc:
      ewsc.return_value = 'db result'
      actual_result = other_cnxn.Execute('statement', [], shard_id=shard_id)
      self.assertEqual('db result', actual_result)
      ewsc.assert_called_once_with(
          other_sql_cnxn_2, 'statement', [], commit=True)

    # Simulate an old failure on shard 1, allowing us to try using it again.
    self.cnxn.unavailable_shards[1] = (
        int(time.time()) - sql.BAD_SHARD_AVOIDANCE_SEC - 2)

    with mock.patch.object(self.cnxn, '_ExecuteWithSQLConnection') as ewsc:
      ewsc.return_value = 'db result'
      actual_result = self.cnxn.Execute('statement', [], shard_id=shard_id)
      self.assertEqual('db result', actual_result)
      ewsc.assert_called_once_with(sql_cnxn_1, 'statement', [], commit=True)
+
+
class TableManagerTest(unittest.TestCase):
  """Checks the SQL that sql.SQLTableManager sends and what it returns."""

  def setUp(self):
    self.emp_tbl = sql.SQLTableManager('Employee')
    self.cnxn = sql.MonorailConnection()
    self.primary_cnxn = self.cnxn.GetPrimaryConnection()

  def _AssertLastExecuted(self, expected_sql):
    """Assert that the mock primary connection last ran expected_sql."""
    self.assertEqual(expected_sql, self.primary_cnxn.last_executed)

  def testSelect_Trivial(self):
    self.primary_cnxn.result_rows = [(111, True), (222, False)]
    result = self.emp_tbl.Select(self.cnxn)
    self._AssertLastExecuted('SELECT * FROM Employee')
    self.assertEqual([(111, True), (222, False)], result)

  def testSelect_Conditions(self):
    self.primary_cnxn.result_rows = [(111,)]
    result = self.emp_tbl.Select(
        self.cnxn, cols=['emp_id'], fulltime=True, dept_id=[10, 20])
    self._AssertLastExecuted(
        'SELECT emp_id FROM Employee'
        '\nWHERE dept_id IN (10,20)'
        '\n  AND fulltime = 1')
    self.assertEqual([(111,)], result)

  def testSelectRow(self):
    self.primary_cnxn.result_rows = [(111,)]
    first_row = self.emp_tbl.SelectRow(
        self.cnxn, cols=['emp_id'], fulltime=True, dept_id=[10, 20])
    self._AssertLastExecuted(
        'SELECT DISTINCT emp_id FROM Employee'
        '\nWHERE dept_id IN (10,20)'
        '\n  AND fulltime = 1')
    self.assertEqual((111,), first_row)

  def testSelectRow_NoMatches(self):
    """With no rows SelectRow returns None, or the given default."""
    self.primary_cnxn.result_rows = []
    query = dict(cols=['emp_id'], fulltime=True, dept_id=[99])
    first_row = self.emp_tbl.SelectRow(self.cnxn, **query)
    self._AssertLastExecuted(
        'SELECT DISTINCT emp_id FROM Employee'
        '\nWHERE dept_id IN (99)'
        '\n  AND fulltime = 1')
    self.assertEqual(None, first_row)

    fallback = self.emp_tbl.SelectRow(self.cnxn, default=(-1,), **query)
    self.assertEqual((-1,), fallback)

  def testSelectValue(self):
    self.primary_cnxn.result_rows = [(111,)]
    value = self.emp_tbl.SelectValue(
        self.cnxn, 'emp_id', fulltime=True, dept_id=[10, 20])
    self._AssertLastExecuted(
        'SELECT DISTINCT emp_id FROM Employee'
        '\nWHERE dept_id IN (10,20)'
        '\n  AND fulltime = 1')
    self.assertEqual(111, value)

  def testSelectValue_NoMatches(self):
    """With no rows SelectValue returns None, or the given default."""
    self.primary_cnxn.result_rows = []
    value = self.emp_tbl.SelectValue(
        self.cnxn, 'emp_id', fulltime=True, dept_id=[99])
    self._AssertLastExecuted(
        'SELECT DISTINCT emp_id FROM Employee'
        '\nWHERE dept_id IN (99)'
        '\n  AND fulltime = 1')
    self.assertEqual(None, value)

    fallback = self.emp_tbl.SelectValue(
        self.cnxn, 'emp_id', fulltime=True, dept_id=[99], default=-1)
    self.assertEqual(-1, fallback)

  def testInsertRow(self):
    self.primary_cnxn.rowcount = 1
    new_id = self.emp_tbl.InsertRow(self.cnxn, emp_id=111, fulltime=True)
    self._AssertLastExecuted(
        'INSERT INTO Employee (emp_id, fulltime)'
        '\nVALUES (%s,%s)')
    self.assertEqual(([111, 1],), self.primary_cnxn.last_executed_args)
    self.assertEqual(123, new_id)

  def testInsertRows_Empty(self):
    """Inserting zero rows issues no statement and returns None."""
    new_ids = self.emp_tbl.InsertRows(self.cnxn, ['emp_id', 'fulltime'], [])
    self.assertIsNone(self.primary_cnxn.last_executed)
    self.assertIsNone(self.primary_cnxn.last_executed_args)
    self.assertEqual(None, new_ids)

  def testInsertRows(self):
    self.primary_cnxn.rowcount = 2
    new_ids = self.emp_tbl.InsertRows(
        self.cnxn, ['emp_id', 'fulltime'], [(111, True), (222, False)])
    self._AssertLastExecuted(
        'INSERT INTO Employee (emp_id, fulltime)'
        '\nVALUES (%s,%s)')
    self.assertEqual(
        ([111, 1], [222, 0]), self.primary_cnxn.last_executed_args)
    self.assertEqual([], new_ids)

  def testUpdate(self):
    self.primary_cnxn.rowcount = 2
    num_updated = self.emp_tbl.Update(
        self.cnxn, {'fulltime': True}, emp_id=[111, 222])
    self._AssertLastExecuted(
        'UPDATE Employee SET fulltime=1'
        '\nWHERE emp_id IN (111,222)')
    self.assertEqual(2, num_updated)

  def testUpdate_Limit(self):
    self.emp_tbl.Update(
        self.cnxn, {'fulltime': True}, limit=8, emp_id=[111, 222])
    self._AssertLastExecuted(
        'UPDATE Employee SET fulltime=1'
        '\nWHERE emp_id IN (111,222)'
        '\nLIMIT 8')

  def testIncrementCounterValue(self):
    """The new counter value comes back via the cursor's lastrowid."""
    self.primary_cnxn.rowcount = 1
    self.primary_cnxn.lastrowid = 9
    counter = self.emp_tbl.IncrementCounterValue(
        self.cnxn, 'years_worked', emp_id=111)
    self._AssertLastExecuted(
        'UPDATE Employee SET years_worked = LAST_INSERT_ID(years_worked + 1)'
        '\nWHERE emp_id = 111')
    self.assertEqual(9, counter)

  def testDelete(self):
    self.primary_cnxn.rowcount = 1
    num_deleted = self.emp_tbl.Delete(self.cnxn, fulltime=True)
    self._AssertLastExecuted(
        'DELETE FROM Employee'
        '\nWHERE fulltime = 1')
    self.assertEqual(1, num_deleted)

  def testDelete_Limit(self):
    self.emp_tbl.Delete(self.cnxn, fulltime=True, limit=3)
    self._AssertLastExecuted(
        'DELETE FROM Employee'
        '\nWHERE fulltime = 1'
        '\nLIMIT 3')
+
+
+class StatementTest(unittest.TestCase):
+
+ def testMakeSelect(self):
+ stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+ stmt_str, args = stmt.Generate()
+ self.assertEqual(
+ 'SELECT emp_id, fulltime FROM Employee',
+ stmt_str)
+ self.assertEqual([], args)
+
+ stmt = sql.Statement.MakeSelect(
+ 'Employee', ['emp_id', 'fulltime'], distinct=True)
+ stmt_str, args = stmt.Generate()
+ self.assertEqual(
+ 'SELECT DISTINCT emp_id, fulltime FROM Employee',
+ stmt_str)
+ self.assertEqual([], args)
+
+ def testMakeInsert(self):
+ stmt = sql.Statement.MakeInsert(
+ 'Employee', ['emp_id', 'fulltime'], [(111, True), (222, False)])
+ stmt_str, args = stmt.Generate()
+ self.assertEqual(
+ 'INSERT INTO Employee (emp_id, fulltime)'
+ '\nVALUES (%s,%s)',
+ stmt_str)
+ self.assertEqual([[111, 1], [222, 0]], args)
+
+ stmt = sql.Statement.MakeInsert(
+ 'Employee', ['emp_id', 'fulltime'], [(111, False)], replace=True)
+ stmt_str, args = stmt.Generate()
+ self.assertEqual(
+ 'INSERT INTO Employee (emp_id, fulltime)'
+ '\nVALUES (%s,%s)'
+ '\nON DUPLICATE KEY UPDATE '
+ 'emp_id=VALUES(emp_id), fulltime=VALUES(fulltime)',
+ stmt_str)
+ self.assertEqual([[111, 0]], args)
+
+ stmt = sql.Statement.MakeInsert(
+ 'Employee', ['emp_id', 'fulltime'], [(111, False)], ignore=True)
+ stmt_str, args = stmt.Generate()
+ self.assertEqual(
+ 'INSERT IGNORE INTO Employee (emp_id, fulltime)'
+ '\nVALUES (%s,%s)',
+ stmt_str)
+ self.assertEqual([[111, 0]], args)
+
+ def testMakeInsert_InvalidString(self):
+ with self.assertRaises(exceptions.InputException):
+ sql.Statement.MakeInsert(
+ 'Employee', ['emp_id', 'name'], [(111, 'First \x00 Last')])
+
+ def testMakeUpdate(self):
+ stmt = sql.Statement.MakeUpdate('Employee', {'fulltime': True})
+ stmt_str, args = stmt.Generate()
+ self.assertEqual(
+ 'UPDATE Employee SET fulltime=%s',
+ stmt_str)
+ self.assertEqual([1], args)
+
+ def testMakeUpdate_InvalidString(self):
+ with self.assertRaises(exceptions.InputException):
+ sql.Statement.MakeUpdate('Employee', {'name': 'First \x00 Last'})
+
+ def testMakeIncrement(self):
+ stmt = sql.Statement.MakeIncrement('Employee', 'years_worked')
+ stmt_str, args = stmt.Generate()
+ self.assertEqual(
+ 'UPDATE Employee SET years_worked = LAST_INSERT_ID(years_worked + %s)',
+ stmt_str)
+ self.assertEqual([1], args)
+
+ stmt = sql.Statement.MakeIncrement('Employee', 'years_worked', step=5)
+ stmt_str, args = stmt.Generate()
+ self.assertEqual(
+ 'UPDATE Employee SET years_worked = LAST_INSERT_ID(years_worked + %s)',
+ stmt_str)
+ self.assertEqual([5], args)
+
+ def testMakeDelete(self):
+ stmt = sql.Statement.MakeDelete('Employee')
+ stmt_str, args = stmt.Generate()
+ self.assertEqual(
+ 'DELETE FROM Employee',
+ stmt_str)
+ self.assertEqual([], args)
+
+ def testAddUseClause(self):
+ stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+ stmt.AddUseClause('USE INDEX (emp_id) USE INDEX FOR ORDER BY (emp_id)')
+ stmt.AddOrderByTerms([('emp_id', [])])
+ stmt_str, args = stmt.Generate()
+ self.assertEqual(
+ 'SELECT emp_id, fulltime FROM Employee'
+ '\nUSE INDEX (emp_id) USE INDEX FOR ORDER BY (emp_id)'
+ '\nORDER BY emp_id',
+ stmt_str)
+ self.assertEqual([], args)
+
+ def testAddJoinClause_Empty(self):
+ stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+ stmt.AddJoinClauses([])
+ stmt_str, args = stmt.Generate()
+ self.assertEqual(
+ 'SELECT emp_id, fulltime FROM Employee',
+ stmt_str)
+ self.assertEqual([], args)
+
+ def testAddJoinClause(self):
+ stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+ stmt.AddJoinClauses([('CorporateHoliday', [])])
+ stmt.AddJoinClauses(
+ [('Product ON Project.inventor_id = emp_id', [])], left=True)
+ stmt_str, args = stmt.Generate()
+ self.assertEqual(
+ 'SELECT emp_id, fulltime FROM Employee'
+ '\n JOIN CorporateHoliday'
+ '\n LEFT JOIN Product ON Project.inventor_id = emp_id',
+ stmt_str)
+ self.assertEqual([], args)
+
+ def testAddGroupByTerms_Empty(self):
+ stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+ stmt.AddGroupByTerms([])
+ stmt_str, args = stmt.Generate()
+ self.assertEqual(
+ 'SELECT emp_id, fulltime FROM Employee',
+ stmt_str)
+ self.assertEqual([], args)
+
+ def testAddGroupByTerms(self):
+ stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+ stmt.AddGroupByTerms(['dept_id', 'location_id'])
+ stmt_str, args = stmt.Generate()
+ self.assertEqual(
+ 'SELECT emp_id, fulltime FROM Employee'
+ '\nGROUP BY dept_id, location_id',
+ stmt_str)
+ self.assertEqual([], args)
+
+ def testAddOrderByTerms_Empty(self):
+ stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+ stmt.AddOrderByTerms([])
+ stmt_str, args = stmt.Generate()
+ self.assertEqual(
+ 'SELECT emp_id, fulltime FROM Employee',
+ stmt_str)
+ self.assertEqual([], args)
+
+ def testAddOrderByTerms(self):
+ stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+ stmt.AddOrderByTerms([('dept_id', []), ('emp_id DESC', [])])
+ stmt_str, args = stmt.Generate()
+ self.assertEqual(
+ 'SELECT emp_id, fulltime FROM Employee'
+ '\nORDER BY dept_id, emp_id DESC',
+ stmt_str)
+ self.assertEqual([], args)
+
+  def testSetLimitAndOffset(self):
+    """LIMIT is always emitted; OFFSET is omitted when the offset is 0.
+
+    Both cases reuse one statement, so the second call must replace the
+    limit/offset set by the first.
+    """
+    select = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+    cases = [
+        (100, 0,
+         'SELECT emp_id, fulltime FROM Employee'
+         '\nLIMIT 100'),
+        (100, 500,
+         'SELECT emp_id, fulltime FROM Employee'
+         '\nLIMIT 100 OFFSET 500'),
+        ]
+    for limit, offset, expected_sql in cases:
+      select.SetLimitAndOffset(limit, offset)
+      sql_text, sql_args = select.Generate()
+      self.assertEqual(expected_sql, sql_text)
+      self.assertEqual([], sql_args)
+
+  def testAddWhereTerms_Select(self):
+    """A list-valued keyword term becomes an IN clause with one placeholder
+    per value, and the values become the statement's args."""
+    select = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+    select.AddWhereTerms([], emp_id=[111, 222])
+    expected_sql = (
+        'SELECT emp_id, fulltime FROM Employee'
+        '\nWHERE emp_id IN (%s,%s)')
+    sql_text, sql_args = select.Generate()
+    self.assertEqual(expected_sql, sql_text)
+    self.assertEqual([111, 222], sql_args)
+
+  def testAddWhereTerms_Update(self):
+    """WHERE args follow the SET args; the bool True is passed as 1."""
+    update = sql.Statement.MakeUpdate('Employee', {'fulltime': True})
+    update.AddWhereTerms([], emp_id=[111, 222])
+    expected_sql = (
+        'UPDATE Employee SET fulltime=%s'
+        '\nWHERE emp_id IN (%s,%s)')
+    sql_text, sql_args = update.Generate()
+    self.assertEqual(expected_sql, sql_text)
+    self.assertEqual([1, 111, 222], sql_args)
+
+  def testAddWhereTerms_Delete(self):
+    """WHERE terms are appended to a DELETE statement as well."""
+    delete = sql.Statement.MakeDelete('Employee')
+    delete.AddWhereTerms([], emp_id=[111, 222])
+    expected_sql = (
+        'DELETE FROM Employee'
+        '\nWHERE emp_id IN (%s,%s)')
+    sql_text, sql_args = delete.Generate()
+    self.assertEqual(expected_sql, sql_text)
+    self.assertEqual([111, 222], sql_args)
+
+  def testAddWhereTerms_Empty(self):
+    """Adding no WHERE terms leaves the statement without a WHERE clause."""
+    select = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+    select.AddWhereTerms([])
+    expected_sql = 'SELECT emp_id, fulltime FROM Employee'
+    sql_text, sql_args = select.Generate()
+    self.assertEqual(expected_sql, sql_text)
+    self.assertEqual([], sql_args)
+
+  def testAddWhereTerms_UpdateEmptyArray(self):
+    """Adding a keyword WHERE term with an empty value list raises.
+
+    AddWhereTerms is expected to reject the empty list with InputException
+    rather than build a clause with no values in it.
+    """
+    stmt = sql.Statement.MakeUpdate('SpamVerdict', {'user_id': 1})
+    # See https://crbug.com/monorail/6735.
+    # A previous version also asserted on `mock_log`, a name this test never
+    # defined (there is no mock.patch decorator or parameter). That assertion
+    # was either unreachable after the raise or a NameError, so it is removed.
+    with self.assertRaises(exceptions.InputException):
+      stmt.AddWhereTerms([], user_id=[])
+
+  def testAddWhereTerms_MulitpleTerms(self):
+    """Explicit conditions and keyword terms are AND-ed together.
+
+    The explicit condition comes first. The `_not` suffix on a keyword turns
+    into `!=`. The True value is passed as 1.
+    """
+    select = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+    select.AddWhereTerms(
+        [('emp_id %% %s = %s', [2, 0])], fulltime=True, emp_id_not=222)
+    expected_sql = (
+        'SELECT emp_id, fulltime FROM Employee'
+        '\nWHERE emp_id %% %s = %s'
+        '\n AND emp_id != %s'
+        '\n AND fulltime = %s')
+    sql_text, sql_args = select.Generate()
+    self.assertEqual(expected_sql, sql_text)
+    self.assertEqual([2, 0, 222, 1], sql_args)
+
+  def testAddHavingTerms_NoGroupBy(self):
+    """Generating a HAVING clause without any GROUP BY terms asserts."""
+    select = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+    select.AddHavingTerms([('COUNT(*) > %s', [10])])
+    with self.assertRaises(AssertionError):
+      select.Generate()
+
+  def testAddHavingTerms_WithGroupBy(self):
+    """HAVING follows GROUP BY, and its args are appended to the args."""
+    select = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+    select.AddGroupByTerms(['dept_id', 'location_id'])
+    select.AddHavingTerms([('COUNT(*) > %s', [10])])
+    expected_sql = (
+        'SELECT emp_id, fulltime FROM Employee'
+        '\nGROUP BY dept_id, location_id'
+        '\nHAVING COUNT(*) > %s')
+    sql_text, sql_args = select.Generate()
+    self.assertEqual(expected_sql, sql_text)
+    self.assertEqual([10], sql_args)
+
+
+class FunctionsTest(unittest.TestCase):
+  """Tests for module-level helper functions of the sql module."""
+
+  def testIsValidDBValue_NonString(self):
+    for value in (12, True, False, None):
+      self.assertTrue(sql._IsValidDBValue(value))
+
+  def testIsValidDBValue_String(self):
+    for value in ('', 'hello', u'hello'):
+      self.assertTrue(sql._IsValidDBValue(value))
+    # A string containing a NUL byte is rejected.
+    self.assertFalse(sql._IsValidDBValue('null \x00 byte'))
+
+  def testBoolsToInts_NoChanges(self):
+    """Values without bools pass through; tuples come back as lists."""
+    cases = [
+        (['hello'], ['hello']),
+        ([['hello']], [['hello']]),
+        ([['hello']], [('hello',)]),
+        ([12], [12]),
+        ([[12]], [[12]]),
+        ([[12]], [(12,)]),
+        ([12, 13, 'hi', [99, 'yo']], [12, 13, 'hi', [99, 'yo']]),
+        ]
+    for expected, given in cases:
+      self.assertEqual(expected, sql._BoolsToInts(given))
+
+  def testBoolsToInts_WithChanges(self):
+    """True/False become 1/0, including inside nested lists and tuples."""
+    cases = [
+        ([1, 0], [True, False]),
+        ([[1, 0]], [[True, False]]),
+        ([[1, 0]], [(True, False)]),
+        ([12, 1, 'hi', [0, 'yo']], [12, True, 'hi', [False, 'yo']]),
+        ]
+    for expected, given in cases:
+      self.assertEqual(expected, sql._BoolsToInts(given))
+
+  def testRandomShardID(self):
+    """A random shard ID must always be a valid shard ID."""
+    shard_id = sql.RandomShardID()
+    self.assertGreaterEqual(shard_id, 0)
+    self.assertLess(shard_id, settings.num_logical_shards)
diff --git a/framework/test/table_view_helpers_test.py b/framework/test/table_view_helpers_test.py
new file mode 100644
index 0000000..0260308
--- /dev/null
+++ b/framework/test/table_view_helpers_test.py
@@ -0,0 +1,753 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for table_view_helpers classes and functions."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import collections
+import unittest
+import logging
+
+from framework import framework_views
+from framework import table_view_helpers
+from proto import tracker_pb2
+from testing import fake
+from testing import testing_helpers
+from tracker import tracker_bizobj
+
+
+# A search that matched no issues.
+EMPTY_SEARCH_RESULTS = []
+
+# Four 'New' issues in project 789 with owner 111 and one star each. Their
+# key-value labels (Priority, Mstone, Visibility) show up as label columns.
+SEARCH_RESULTS_WITH_LABELS = [
+    fake.MakeTestIssue(
+        789, 1, 'sum 1', 'New', 111,
+        labels='Priority-High Mstone-1', merged_into=200001, star_count=1),
+    fake.MakeTestIssue(
+        789, 2, 'sum 2', 'New', 111,
+        labels='Priority-High Mstone-1', merged_into=1, star_count=1),
+    fake.MakeTestIssue(
+        789, 3, 'sum 3', 'New', 111,
+        labels='Priority-Low Mstone-1.1', merged_into=1, star_count=1),
+    # 'Visibility-Super-High' checks that only the first dash separates the
+    # label's column name from its value.
+    fake.MakeTestIssue(
+        789, 4, 'sum 4', 'New', 111,
+        labels='Visibility-Super-High', star_count=1),
+    ]
+
+
+def MakeTestIssue(local_id, issue_id, summary):
+  """Returns a bare tracker_pb2.Issue with only these three fields set."""
+  return tracker_pb2.Issue(
+      local_id=local_id, issue_id=issue_id, summary=summary)
+
+
+class TableCellTest(unittest.TestCase):
+  """Tests for the individual TableCell* classes in table_view_helpers."""
+
+  USERS_BY_ID = {}
+
+  def setUp(self):
+    self.config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+    self.config.field_defs = [
+        tracker_bizobj.MakeFieldDef(
+            1, 789, 'Goats', tracker_pb2.FieldTypes.INT_TYPE, None, None,
+            False, False, False, None, None, None, False, None, None, None,
+            None, 'Num of Goats in the season', False, is_phase_field=True),
+        tracker_bizobj.MakeFieldDef(
+            2, 789, 'DogNames', tracker_pb2.FieldTypes.STR_TYPE, None, None,
+            False, False, False, None, None, None, False, None, None, None,
+            None, 'good dog names', False),
+        tracker_bizobj.MakeFieldDef(
+            3, 789, 'Approval', tracker_pb2.FieldTypes.APPROVAL_TYPE,
+            None, None, False, False, False, None, None, None, False, None,
+            None, None, None, 'Tracks review from cows', False)
+        ]
+    self.config.approval_defs = [tracker_pb2.ApprovalDef(approval_id=3)]
+    self.issue = MakeTestIssue(
+        local_id=1, issue_id=100001, summary='One')
+    # Field 1 ('Goats') is a phase field with a value for each of phases 23
+    # and 24. Field 2 ('DogNames') is a plain string field.
+    self.issue.field_values = [
+        tracker_bizobj.MakeFieldValue(
+            1, 34, None, None, None, None, False, phase_id=23),
+        tracker_bizobj.MakeFieldValue(
+            1, 35, None, None, None, None, False, phase_id=24),
+        tracker_bizobj.MakeFieldValue(
+            2, None, 'Waffles', None, None, None, False),
+        ]
+    self.issue.phases = [
+        tracker_pb2.Phase(phase_id=23, name='winter'),
+        tracker_pb2.Phase(phase_id=24, name='summer')]
+    self.issue.approval_values = [
+        tracker_pb2.ApprovalValue(
+            approval_id=3, approver_ids=[111, 222, 333])]
+    # Approver 333 has no entry here. User 222 is created with a True third
+    # argument; the approver test below expects its email as
+    # 'f...@example.com'.
+    self.users_by_id = {
+        111: framework_views.StuffUserView(111, 'foo@example.com', False),
+        222: framework_views.StuffUserView(222, 'foo2@example.com', True),
+        }
+
+    self.summary_table_cell_kws = {
+        'col': None,
+        'users_by_id': {},
+        'non_col_labels': [('lab', False)],
+        'label_values': {},
+        'related_issues': {},
+        'config': 'fake_config',
+        }
+
+  def testTableCellSummary(self):
+    """TableCellSummary stores the data given to it."""
+    cell = table_view_helpers.TableCellSummary(
+        MakeTestIssue(4, 4, 'Lame default summary.'),
+        **self.summary_table_cell_kws)
+    self.assertEqual(cell.type, table_view_helpers.CELL_TYPE_SUMMARY)
+    self.assertEqual(cell.values[0].item, 'Lame default summary.')
+    self.assertEqual(cell.non_column_labels[0].value, 'lab')
+
+  def testTableCellSummary_NoPythonEscaping(self):
+    """TableCellSummary stores the summary without escaping it in python."""
+    cell = table_view_helpers.TableCellSummary(
+        MakeTestIssue(4, 4, '<b>bold</b> "summary".'),
+        **self.summary_table_cell_kws)
+    self.assertEqual(cell.values[0].item, '<b>bold</b> "summary".')
+
+  def testTableCellCustom_normal(self):
+    """TableCellCustom stores the value of a custom FieldValue."""
+    cell_dognames = table_view_helpers.TableCellCustom(
+        self.issue, col='dognames', config=self.config)
+    self.assertEqual(cell_dognames.type, table_view_helpers.CELL_TYPE_ATTR)
+    self.assertEqual(cell_dognames.values[0].item, 'Waffles')
+
+  def testTableCellCustom_phasefields(self):
+    """TableCellCustom picks the phase field value for the phase in `col`."""
+    cell_winter = table_view_helpers.TableCellCustom(
+        self.issue, col='winter.goats', config=self.config)
+    self.assertEqual(cell_winter.type, table_view_helpers.CELL_TYPE_ATTR)
+    self.assertEqual(cell_winter.values[0].item, 34)
+
+    cell_summer = table_view_helpers.TableCellCustom(
+        self.issue, col='summer.goats', config=self.config)
+    self.assertEqual(cell_summer.type, table_view_helpers.CELL_TYPE_ATTR)
+    self.assertEqual(cell_summer.values[0].item, 35)
+
+  def testTableCellApprovalStatus(self):
+    """TableCellApprovalStatus stores the status of an ApprovalValue."""
+    cell = table_view_helpers.TableCellApprovalStatus(
+        self.issue, col='Approval', config=self.config)
+    self.assertEqual(cell.type, table_view_helpers.CELL_TYPE_ATTR)
+    self.assertEqual(cell.values[0].item, 'NOT_SET')
+
+  def testTableCellApprovalApprover(self):
+    """TableCellApprovalApprover stores the approvers of an ApprovalValue."""
+    cell = table_view_helpers.TableCellApprovalApprover(
+        self.issue, col='Approval-approver', config=self.config,
+        users_by_id=self.users_by_id)
+    self.assertEqual(cell.type, table_view_helpers.CELL_TYPE_ATTR)
+    self.assertEqual(len(cell.values), 2)
+    # assertItemsEqual exists only on Python 2 (Python 3 renamed it to
+    # assertCountEqual). Comparing sorted lists is an order-insensitive check
+    # that works on both.
+    self.assertEqual(
+        sorted([cell.values[0].item, cell.values[1].item]),
+        sorted(['foo@example.com', 'f...@example.com']))
+
+  # TODO(jrobbins): TableCellProject, TableCellStars
+
+
+
+class TableViewHelpersTest(unittest.TestCase):
+  """Tests for the module-level helper functions of table_view_helpers."""
+
+  def setUp(self):
+    self.default_cols = 'a b c'
+    self.builtin_cols = ['a', 'b', 'x', 'y', 'z']
+    self.config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+
+  def _MakeColumnConfig(self, default_col_spec):
+    """Returns a fresh project 789 config for ComputeUnshownColumns tests.
+
+    Args:
+      default_col_spec: default column spec string, e.g. 'a b c'.
+
+    Returns:
+      A default ProjectIssueConfig with the given default_col_spec and no
+      well-known labels.
+    """
+    config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+    config.default_col_spec = default_col_spec
+    config.well_known_labels = []
+    return config
+
+  def testComputeUnshownColumns_CommonCase(self):
+    shown_cols = ['a', 'b', 'c']
+    config = self._MakeColumnConfig(self.default_cols)
+
+    unshown = table_view_helpers.ComputeUnshownColumns(
+        EMPTY_SEARCH_RESULTS, shown_cols, config, self.builtin_cols)
+    self.assertEqual(unshown, ['x', 'y', 'z'])
+
+    unshown = table_view_helpers.ComputeUnshownColumns(
+        SEARCH_RESULTS_WITH_LABELS, shown_cols, config, self.builtin_cols)
+    self.assertEqual(
+        unshown, ['Mstone', 'Priority', 'Visibility', 'x', 'y', 'z'])
+
+  def testComputeUnshownColumns_MoreBuiltins(self):
+    shown_cols = ['a', 'b', 'c', 'x', 'y']
+    config = self._MakeColumnConfig(self.default_cols)
+
+    unshown = table_view_helpers.ComputeUnshownColumns(
+        EMPTY_SEARCH_RESULTS, shown_cols, config, self.builtin_cols)
+    self.assertEqual(unshown, ['z'])
+
+    unshown = table_view_helpers.ComputeUnshownColumns(
+        SEARCH_RESULTS_WITH_LABELS, shown_cols, config, self.builtin_cols)
+    self.assertEqual(unshown, ['Mstone', 'Priority', 'Visibility', 'z'])
+
+  def testComputeUnshownColumns_NotAllDefaults(self):
+    shown_cols = ['a', 'b']
+    config = self._MakeColumnConfig(self.default_cols)
+
+    unshown = table_view_helpers.ComputeUnshownColumns(
+        EMPTY_SEARCH_RESULTS, shown_cols, config, self.builtin_cols)
+    self.assertEqual(unshown, ['c', 'x', 'y', 'z'])
+
+    unshown = table_view_helpers.ComputeUnshownColumns(
+        SEARCH_RESULTS_WITH_LABELS, shown_cols, config, self.builtin_cols)
+    self.assertEqual(
+        unshown, ['Mstone', 'Priority', 'Visibility', 'c', 'x', 'y', 'z'])
+
+  def testComputeUnshownColumns_ExtraNonDefaults(self):
+    shown_cols = ['a', 'b', 'c', 'd', 'e', 'f']
+    config = self._MakeColumnConfig(self.default_cols)
+
+    unshown = table_view_helpers.ComputeUnshownColumns(
+        EMPTY_SEARCH_RESULTS, shown_cols, config, self.builtin_cols)
+    self.assertEqual(unshown, ['x', 'y', 'z'])
+
+    unshown = table_view_helpers.ComputeUnshownColumns(
+        SEARCH_RESULTS_WITH_LABELS, shown_cols, config, self.builtin_cols)
+    self.assertEqual(
+        unshown, ['Mstone', 'Priority', 'Visibility', 'x', 'y', 'z'])
+
+  def testComputeUnshownColumns_UserColumnsShown(self):
+    shown_cols = ['a', 'b', 'c', 'Priority']
+    config = self._MakeColumnConfig(self.default_cols)
+
+    unshown = table_view_helpers.ComputeUnshownColumns(
+        EMPTY_SEARCH_RESULTS, shown_cols, config, self.builtin_cols)
+    self.assertEqual(unshown, ['x', 'y', 'z'])
+
+    unshown = table_view_helpers.ComputeUnshownColumns(
+        SEARCH_RESULTS_WITH_LABELS, shown_cols, config, self.builtin_cols)
+    self.assertEqual(unshown, ['Mstone', 'Visibility', 'x', 'y', 'z'])
+
+  def testComputeUnshownColumns_EverythingShown(self):
+    shown_cols = [
+        'a', 'b', 'c', 'x', 'y', 'z', 'Priority', 'Mstone', 'Visibility']
+    config = self._MakeColumnConfig(self.default_cols)
+
+    unshown = table_view_helpers.ComputeUnshownColumns(
+        EMPTY_SEARCH_RESULTS, shown_cols, config, self.builtin_cols)
+    self.assertEqual(unshown, [])
+
+    unshown = table_view_helpers.ComputeUnshownColumns(
+        SEARCH_RESULTS_WITH_LABELS, shown_cols, config, self.builtin_cols)
+    self.assertEqual(unshown, [])
+
+  def testComputeUnshownColumns_NothingShown(self):
+    shown_cols = []
+    config = self._MakeColumnConfig(self.default_cols)
+
+    unshown = table_view_helpers.ComputeUnshownColumns(
+        EMPTY_SEARCH_RESULTS, shown_cols, config, self.builtin_cols)
+    self.assertEqual(unshown, ['a', 'b', 'c', 'x', 'y', 'z'])
+
+    unshown = table_view_helpers.ComputeUnshownColumns(
+        SEARCH_RESULTS_WITH_LABELS, shown_cols, config, self.builtin_cols)
+    self.assertEqual(
+        unshown,
+        ['Mstone', 'Priority', 'Visibility', 'a', 'b', 'c', 'x', 'y', 'z'])
+
+  def testComputeUnshownColumns_NoBuiltins(self):
+    shown_cols = ['a', 'b', 'c']
+    config = self._MakeColumnConfig('')
+    builtin_cols = []
+
+    unshown = table_view_helpers.ComputeUnshownColumns(
+        EMPTY_SEARCH_RESULTS, shown_cols, config, builtin_cols)
+    self.assertEqual(unshown, [])
+
+    unshown = table_view_helpers.ComputeUnshownColumns(
+        SEARCH_RESULTS_WITH_LABELS, shown_cols, config, builtin_cols)
+    self.assertEqual(unshown, ['Mstone', 'Priority', 'Visibility'])
+
+  def testComputeUnshownColumns_FieldDefs(self):
+    search_results = [
+        fake.MakeTestIssue(
+            789, 1, 'sum 1', 'New', 111,
+            field_values=[
+                tracker_bizobj.MakeFieldValue(
+                    5, 74, None, None, None, None, False, phase_id=4),
+                tracker_bizobj.MakeFieldValue(
+                    6, 78, None, None, None, None, False, phase_id=5)],
+            phases=[
+                tracker_pb2.Phase(phase_id=4, name='goats'),
+                tracker_pb2.Phase(phase_id=5, name='sheep')]),
+        fake.MakeTestIssue(
+            789, 2, 'sum 2', 'New', 111,
+            field_values=[
+                tracker_bizobj.MakeFieldValue(
+                    5, 74, None, None, None, None, False, phase_id=3),
+                tracker_bizobj.MakeFieldValue(
+                    6, 77, None, None, None, None, False, phase_id=3)],
+            phases=[
+                tracker_pb2.Phase(phase_id=3, name='Goats'),
+                tracker_pb2.Phase(phase_id=3, name='Goats-Exp')]),
+        ]
+
+    shown_cols = ['a', 'b', 'a1', 'a2-approver', 'f3', 'goats.g1', 'sheep.g2']
+    config = self._MakeColumnConfig('')
+    config.field_defs = [
+        tracker_bizobj.MakeFieldDef(
+            1, 789, 'a1', tracker_pb2.FieldTypes.APPROVAL_TYPE,
+            None, None, False, False, False, None, None, None, False, None,
+            None, None, None, 'Tracks review from cows', False),
+        tracker_bizobj.MakeFieldDef(
+            2, 789, 'a2', tracker_pb2.FieldTypes.APPROVAL_TYPE,
+            None, None, False, False, False, None, None, None, False, None,
+            None, None, None, 'Tracks review from chickens', False),
+        tracker_bizobj.MakeFieldDef(
+            3, 789, 'f3', tracker_pb2.FieldTypes.STR_TYPE,
+            None, None, False, False, False, None, None, None, False, None,
+            None, None, None, 'cow names', False),
+        tracker_bizobj.MakeFieldDef(
+            4, 789, 'f4', tracker_pb2.FieldTypes.INT_TYPE,
+            None, None, False, False, False, None, None, None, False, None,
+            None, None, None, 'chicken gobbles', False),
+        tracker_bizobj.MakeFieldDef(
+            5, 789, 'g1', tracker_pb2.FieldTypes.INT_TYPE,
+            None, None, False, False, False, None, None, None, False, None,
+            None, None, None, 'fluff', False, is_phase_field=True),
+        tracker_bizobj.MakeFieldDef(
+            6, 789, 'g2', tracker_pb2.FieldTypes.INT_TYPE,
+            None, None, False, False, False, None, None, None, False, None,
+            None, None, None, 'poof', False, is_phase_field=True),
+        ]
+    builtin_cols = []
+
+    unshown = table_view_helpers.ComputeUnshownColumns(
+        search_results, shown_cols, config, builtin_cols)
+    self.assertEqual(unshown, [
+        'a1-approver', 'a2', 'f4',
+        'goats-exp.g1', 'goats-exp.g2', 'goats.g2', 'sheep.g1'])
+
+  def testExtractUniqueValues_NoColumns(self):
+    column_values = table_view_helpers.ExtractUniqueValues(
+        [], SEARCH_RESULTS_WITH_LABELS, {}, self.config, {})
+    self.assertEqual([], column_values)
+
+  def testExtractUniqueValues_NoResults(self):
+    cols = ['type', 'priority', 'owner', 'status', 'stars', 'attachments']
+    column_values = table_view_helpers.ExtractUniqueValues(
+        cols, EMPTY_SEARCH_RESULTS, {}, self.config, {})
+    self.assertEqual(6, len(column_values))
+    for index, col in enumerate(cols):
+      self.assertEqual(index, column_values[index].col_index)
+      self.assertEqual(col, column_values[index].column_name)
+      self.assertEqual([], column_values[index].filter_values)
+
+  def testExtractUniqueValues_ExplicitResults(self):
+    cols = ['priority', 'owner', 'status', 'stars', 'mstone', 'foo']
+    users_by_id = {
+        111: framework_views.StuffUserView(111, 'foo@example.com', True),
+        }
+    column_values = table_view_helpers.ExtractUniqueValues(
+        cols, SEARCH_RESULTS_WITH_LABELS, users_by_id, self.config, {})
+    self.assertEqual(len(cols), len(column_values))
+
+    self.assertEqual('priority', column_values[0].column_name)
+    self.assertEqual(['High', 'Low'], column_values[0].filter_values)
+
+    self.assertEqual('owner', column_values[1].column_name)
+    self.assertEqual(['f...@example.com'], column_values[1].filter_values)
+
+    self.assertEqual('status', column_values[2].column_name)
+    self.assertEqual(['New'], column_values[2].filter_values)
+
+    self.assertEqual('stars', column_values[3].column_name)
+    self.assertEqual([1], column_values[3].filter_values)
+
+    self.assertEqual('mstone', column_values[4].column_name)
+    self.assertEqual(['1', '1.1'], column_values[4].filter_values)
+
+    self.assertEqual('foo', column_values[5].column_name)
+    self.assertEqual([], column_values[5].filter_values)
+
+    # TODO: also cover a 'mergedinto' column. Its expected filter values
+    # were sketched as ['1', 'other-project:1'] but are not asserted yet.
+
+  def testExtractUniqueValues_CombinedColumns(self):
+    cols = ['priority/pri', 'owner', 'status', 'stars', 'mstone/milestone']
+    users_by_id = {
+        111: framework_views.StuffUserView(111, 'foo@example.com', True),
+        }
+    issue = fake.MakeTestIssue(
+        789, 5, 'sum 5', 'New', 111, merged_into=200001,
+        labels='Priority-High Pri-0 Milestone-1.0 mstone-1',
+        star_count=15)
+
+    column_values = table_view_helpers.ExtractUniqueValues(
+        cols, SEARCH_RESULTS_WITH_LABELS + [issue], users_by_id,
+        self.config, {})
+    self.assertEqual(5, len(column_values))
+
+    self.assertEqual('priority/pri', column_values[0].column_name)
+    self.assertEqual(['0', 'High', 'Low'], column_values[0].filter_values)
+
+    self.assertEqual('owner', column_values[1].column_name)
+    self.assertEqual(['f...@example.com'], column_values[1].filter_values)
+
+    self.assertEqual('status', column_values[2].column_name)
+    self.assertEqual(['New'], column_values[2].filter_values)
+
+    self.assertEqual('stars', column_values[3].column_name)
+    self.assertEqual([1, 15], column_values[3].filter_values)
+
+    self.assertEqual('mstone/milestone', column_values[4].column_name)
+    self.assertEqual(['1', '1.0', '1.1'], column_values[4].filter_values)
+
+  def testExtractUniqueValues_DerivedValues(self):
+    cols = ['priority', 'milestone', 'owner', 'status']
+    users_by_id = {
+        111: framework_views.StuffUserView(111, 'foo@example.com', True),
+        222: framework_views.StuffUserView(222, 'bar@example.com', True),
+        333: framework_views.StuffUserView(333, 'lol@example.com', True),
+        }
+    search_results = [
+        fake.MakeTestIssue(
+            789, 1, 'sum 1', '', 111, labels='Priority-High Milestone-1.0',
+            derived_labels='Milestone-2.0 Foo', derived_status='Started'),
+        fake.MakeTestIssue(
+            789, 2, 'sum 2', 'New', 111, labels='Priority-High Milestone-1.0',
+            derived_owner_id=333),  # Not seen because of owner_id
+        fake.MakeTestIssue(
+            789, 3, 'sum 3', 'New', 0, labels='Priority-Low Milestone-1.1',
+            derived_owner_id=222),
+        ]
+
+    column_values = table_view_helpers.ExtractUniqueValues(
+        cols, search_results, users_by_id, self.config, {})
+    self.assertEqual(4, len(column_values))
+
+    self.assertEqual('priority', column_values[0].column_name)
+    self.assertEqual(['High', 'Low'], column_values[0].filter_values)
+
+    self.assertEqual('milestone', column_values[1].column_name)
+    self.assertEqual(['1.0', '1.1', '2.0'], column_values[1].filter_values)
+
+    self.assertEqual('owner', column_values[2].column_name)
+    self.assertEqual(
+        ['b...@example.com', 'f...@example.com'],
+        column_values[2].filter_values)
+
+    self.assertEqual('status', column_values[3].column_name)
+    self.assertEqual(['New', 'Started'], column_values[3].filter_values)
+
+  def testExtractUniqueValues_ColumnsRobustness(self):
+    cols = ['reporter', 'cc', 'owner', 'status', 'attachments']
+    search_results = [
+        tracker_pb2.Issue(),
+        ]
+    column_values = table_view_helpers.ExtractUniqueValues(
+        cols, search_results, {}, self.config, {})
+
+    self.assertEqual(5, len(column_values))
+    for col_val in column_values:
+      if col_val.column_name == 'attachments':
+        self.assertEqual([0], col_val.filter_values)
+      else:
+        self.assertEqual([], col_val.filter_values)
+
+  def testMakeTableData_Empty(self):
+    visible_results = []
+    lower_columns = []
+    cell_factories = {}
+    # Use the same positional order as the other MakeTableData calls in this
+    # class: the {} argument comes before cell_factories. The old call had
+    # the two swapped, which only went unnoticed because there are no results.
+    table_data = table_view_helpers.MakeTableData(
+        visible_results, [], lower_columns, lower_columns, {},
+        cell_factories, 'unused function', {}, set(), self.config)
+    self.assertEqual([], table_data)
+
+    lower_columns = ['type', 'priority', 'summary', 'stars']
+    cell_factories = {
+        'summary': table_view_helpers.TableCellSummary,
+        'stars': table_view_helpers.TableCellStars,
+        }
+
+    table_data = table_view_helpers.MakeTableData(
+        visible_results, [], lower_columns, [], {},
+        cell_factories, 'unused function', {}, set(), self.config)
+    self.assertEqual([], table_data)
+
+  def testMakeTableData_Normal(self):
+    art = fake.MakeTestIssue(
+        789, 1, 'sum 1', 'New', 111, labels='Type-Defect Priority-Medium')
+    visible_results = [art]
+    lower_columns = ['type', 'priority', 'summary', 'stars']
+    cell_factories = {
+        'summary': table_view_helpers.TableCellSummary,
+        'stars': table_view_helpers.TableCellStars,
+        }
+
+    table_data = table_view_helpers.MakeTableData(
+        visible_results, [], lower_columns, lower_columns, {},
+        cell_factories, lambda art: 'id', {}, set(), self.config)
+    self.assertEqual(1, len(table_data))
+    row = table_data[0]
+    self.assertEqual(4, len(row.cells))
+    self.assertEqual('Defect', row.cells[0].values[0].item)
+
+  def testMakeTableData_Groups(self):
+    art = fake.MakeTestIssue(
+        789, 1, 'sum 1', 'New', 111, labels='Type-Defect Priority-Medium')
+    visible_results = [art]
+    lower_columns = ['type', 'priority', 'summary', 'stars']
+    lower_group_by = ['priority']
+    cell_factories = {
+        'summary': table_view_helpers.TableCellSummary,
+        'stars': table_view_helpers.TableCellStars,
+        }
+
+    table_data = table_view_helpers.MakeTableData(
+        visible_results, [], lower_columns, lower_group_by, {},
+        cell_factories, lambda art: 'id', {}, set(), self.config)
+    self.assertEqual(1, len(table_data))
+    row = table_data[0]
+    self.assertEqual(1, len(row.group.cells))
+    self.assertEqual('Medium', row.group.cells[0].values[0].item)
+
+  def testMakeRowData(self):
+    art = fake.MakeTestIssue(
+        789, 1, 'sum 1', 'New', 111, labels='Type-Defect Priority-Medium',
+        star_count=1)
+    columns = ['type', 'priority', 'summary', 'stars']
+
+    cell_factories = [table_view_helpers.TableCellKeyLabels,
+                      table_view_helpers.TableCellKeyLabels,
+                      table_view_helpers.TableCellSummary,
+                      table_view_helpers.TableCellStars]
+
+    # The result is a table_view_helpers.TableRow object with a "cells" field
+    # containing a list of table_view_helpers.TableCell objects.
+    result = table_view_helpers.MakeRowData(
+        art, columns, {}, cell_factories, {}, set(), self.config, {})
+
+    self.assertEqual(len(columns), len(result.cells))
+
+    # Each cell records its own position in the row.
+    for index, cell in enumerate(result.cells):
+      self.assertEqual(index, cell.col_index)
+
+    self.assertEqual(table_view_helpers.CELL_TYPE_ATTR, result.cells[0].type)
+    self.assertEqual('Defect', result.cells[0].values[0].item)
+    self.assertFalse(result.cells[0].values[0].is_derived)
+
+    self.assertEqual(table_view_helpers.CELL_TYPE_ATTR, result.cells[1].type)
+    self.assertEqual('Medium', result.cells[1].values[0].item)
+    self.assertFalse(result.cells[1].values[0].is_derived)
+
+    self.assertEqual(
+        table_view_helpers.CELL_TYPE_SUMMARY, result.cells[2].type)
+    self.assertEqual('sum 1', result.cells[2].values[0].item)
+    self.assertFalse(result.cells[2].values[0].is_derived)
+
+    self.assertEqual(table_view_helpers.CELL_TYPE_ATTR, result.cells[3].type)
+    self.assertEqual(1, result.cells[3].values[0].item)
+    self.assertFalse(result.cells[3].values[0].is_derived)
+
+  def testAccumulateLabelValues_Empty(self):
+    label_values, non_col_labels = collections.defaultdict(list), []
+    table_view_helpers._AccumulateLabelValues(
+        [], [], label_values, non_col_labels)
+    self.assertEqual({}, label_values)
+    self.assertEqual([], non_col_labels)
+
+    label_values, non_col_labels = collections.defaultdict(list), []
+    table_view_helpers._AccumulateLabelValues(
+        [], ['Type', 'Priority'], label_values, non_col_labels)
+    self.assertEqual({}, label_values)
+    self.assertEqual([], non_col_labels)
+
+  def testAccumulateLabelValues_OneWordLabels(self):
+    label_values, non_col_labels = collections.defaultdict(list), []
+    table_view_helpers._AccumulateLabelValues(
+        ['HelloThere'], [], label_values, non_col_labels)
+    self.assertEqual({}, label_values)
+    self.assertEqual([('HelloThere', False)], non_col_labels)
+
+    label_values, non_col_labels = collections.defaultdict(list), []
+    table_view_helpers._AccumulateLabelValues(
+        ['HelloThere'], [], label_values, non_col_labels, is_derived=True)
+    self.assertEqual({}, label_values)
+    self.assertEqual([('HelloThere', True)], non_col_labels)
+
+  def testAccumulateLabelValues_KeyValueLabels(self):
+    label_values, non_col_labels = collections.defaultdict(list), []
+    table_view_helpers._AccumulateLabelValues(
+        ['Type-Defect', 'Milestone-Soon'], ['type', 'milestone'],
+        label_values, non_col_labels)
+    self.assertEqual(
+        {'type': [('Defect', False)],
+         'milestone': [('Soon', False)]},
+        label_values)
+    self.assertEqual([], non_col_labels)
+
+  def testAccumulateLabelValues_MultiValueLabels(self):
+    label_values, non_col_labels = collections.defaultdict(list), []
+    table_view_helpers._AccumulateLabelValues(
+        ['OS-Mac', 'OS-Linux'], ['os', 'arch'],
+        label_values, non_col_labels)
+    self.assertEqual(
+        {'os': [('Mac', False), ('Linux', False)]},
+        label_values)
+    self.assertEqual([], non_col_labels)
+
+  def testAccumulateLabelValues_MultiPartLabels(self):
+    label_values, non_col_labels = collections.defaultdict(list), []
+    table_view_helpers._AccumulateLabelValues(
+        ['OS-Mac-Server', 'OS-Mac-Laptop'], ['os', 'os-mac'],
+        label_values, non_col_labels)
+    self.assertEqual(
+        {'os': [('Mac-Server', False), ('Mac-Laptop', False)],
+         'os-mac': [('Server', False), ('Laptop', False)],
+        },
+        label_values)
+    self.assertEqual([], non_col_labels)
+
+  def testChooseCellFactory(self):
+    """We choose the right kind of table cell for the specified column."""
+    cell_factories = {
+        'summary': table_view_helpers.TableCellSummary,
+        'stars': table_view_helpers.TableCellStars,
+        }
+    os_fd = tracker_bizobj.MakeFieldDef(
+        1, 789, 'os', tracker_pb2.FieldTypes.ENUM_TYPE, None, None, False,
+        False, False, None, None, None, False, None, None, None, None,
+        'Operating system', False)
+    deadline_fd = tracker_bizobj.MakeFieldDef(
+        2, 789, 'deadline', tracker_pb2.FieldTypes.DATE_TYPE, None, None, False,
+        False, False, None, None, None, False, None, None, None, None,
+        'Deadline to resolve issue', False)
+    approval_fd = tracker_bizobj.MakeFieldDef(
+        3, 789, 'CowApproval', tracker_pb2.FieldTypes.APPROVAL_TYPE,
+        None, None, False,
+        False, False, None, None, None, False, None, None, None, None,
+        'Tracks reviews from cows', False)
+    goats_fd = tracker_bizobj.MakeFieldDef(
+        4, 789, 'goats', tracker_pb2.FieldTypes.INT_TYPE, None, None, False,
+        False, False, None, None, None, False, None, None, None, None,
+        'Num goats in each phase', False, is_phase_field=True)
+    self.config.field_defs = [os_fd, deadline_fd, approval_fd, goats_fd]
+
+    # The column is defined in cell_factories.
+    actual = table_view_helpers.ChooseCellFactory(
+        'summary', cell_factories, self.config)
+    self.assertEqual(table_view_helpers.TableCellSummary, actual)
+
+    # The column is a composite column.
+    actual = table_view_helpers.ChooseCellFactory(
+        'summary/stars', cell_factories, self.config)
+    self.assertEqual('FactoryClass', actual.__name__)
+
+    # The column is an enum custom field, so it is treated like a label.
+    actual = table_view_helpers.ChooseCellFactory(
+        'os', cell_factories, self.config)
+    self.assertEqual(table_view_helpers.TableCellKeyLabels, actual)
+
+    # The column is a non-enum custom field.
+    actual = table_view_helpers.ChooseCellFactory(
+        'deadline', cell_factories, self.config)
+    self.assertEqual(
+        [(table_view_helpers.TableCellCustom, 'deadline'),
+         (table_view_helpers.TableCellKeyLabels, 'deadline')],
+        actual.factory_col_list)
+
+    # The column is an approval custom field.
+    actual = table_view_helpers.ChooseCellFactory(
+        'CowApproval', cell_factories, self.config)
+    self.assertEqual(
+        [(table_view_helpers.TableCellApprovalStatus, 'CowApproval'),
+         (table_view_helpers.TableCellKeyLabels, 'CowApproval')],
+        actual.factory_col_list)
+
+    # The column is an approval custom field with '-approver'.
+    actual = table_view_helpers.ChooseCellFactory(
+        'CowApproval-approver', cell_factories, self.config)
+    self.assertEqual(
+        [(table_view_helpers.TableCellApprovalApprover, 'CowApproval-approver'),
+         (table_view_helpers.TableCellKeyLabels, 'CowApproval-approver')],
+        actual.factory_col_list)
+
+    # The column specifies a phase custom field.
+    actual = table_view_helpers.ChooseCellFactory(
+        'winter.goats', cell_factories, self.config)
+    self.assertEqual(
+        [(table_view_helpers.TableCellCustom, 'winter.goats'),
+         (table_view_helpers.TableCellKeyLabels, 'winter.goats')],
+        actual.factory_col_list)
+
+    # Columns that don't match any of the other cases are assumed to be
+    # labels.
+    actual = table_view_helpers.ChooseCellFactory(
+        'reward', cell_factories, self.config)
+    self.assertEqual(table_view_helpers.TableCellKeyLabels, actual)
+
+  def testCompositeFactoryTableCell_Empty(self):
+    """If we made a composite of zero columns, it would have no values."""
+    composite = table_view_helpers.CompositeFactoryTableCell([])
+    cell = composite('artifact')
+    self.assertEqual([], cell.values)
+
+  def testCompositeFactoryTableCell_Normal(self):
+    """If we make a composite, it has values from each of the sub cells."""
+    composite = table_view_helpers.CompositeFactoryTableCell(
+        [(sub_factory_1, 'col1'),
+         (sub_factory_2, 'col2')])
+
+    cell = composite('artifact')
+    self.assertEqual(
+        ['sub_cell_1_col1',
+         'sub_cell_2_col2'],
+        cell.values)
+
+  def testCompositeColTableCell_Empty(self):
+    """If we made a composite of zero columns, it would have no values."""
+    composite = table_view_helpers.CompositeColTableCell([], {}, self.config)
+    cell = composite('artifact')
+    self.assertEqual([], cell.values)
+
+  def testCompositeColTableCell_Normal(self):
+    """If we make a composite, it has values from each of the sub cells."""
+    composite = table_view_helpers.CompositeColTableCell(
+        ['col1', 'col2'],
+        {'col1': sub_factory_1, 'col2': sub_factory_2},
+        self.config)
+    cell = composite('artifact')
+    self.assertEqual(
+        ['sub_cell_1_col1',
+         'sub_cell_2_col2'],
+        cell.values)
+
+
+def sub_factory_1(_art, **kw):
+  """Fake cell factory: a single value tagged with the column, no labels."""
+  column_name = kw['col']
+  cell_values = ['sub_cell_1_%s' % column_name]
+  return testing_helpers.Blank(values=cell_values, non_column_labels=[])
+
+
+def sub_factory_2(_art, **kw):
+  """Second fake cell factory, distinguishable from sub_factory_1."""
+  column_name = kw['col']
+  cell_values = ['sub_cell_2_%s' % column_name]
+  return testing_helpers.Blank(values=cell_values, non_column_labels=[])
diff --git a/framework/test/template_helpers_test.py b/framework/test/template_helpers_test.py
new file mode 100644
index 0000000..85296fa
--- /dev/null
+++ b/framework/test/template_helpers_test.py
@@ -0,0 +1,216 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for template_helpers module."""
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+
+import unittest
+
+from framework import pbproxy_test_pb2
+from framework import template_helpers
+
+
+class HelpersUnitTest(unittest.TestCase):
+
+  def testDictionaryProxy(self):
+    """Checks EZTItem attribute access and its str() representation."""
+    # Basic in-and-out test: constructor kwargs read back as attributes.
+    item = template_helpers.EZTItem(label='foo', group_name='bar')
+
+    self.assertEqual('foo', item.label)
+    self.assertEqual('bar', item.group_name)
+
+    # Be sure __str__ includes the fields.
+    self.assertEqual(
+        "EZTItem({'group_name': 'bar', 'label': 'foo'})", str(item))
+
+  def testPBProxy(self):
+    """Checks that PBProxy wraps protobuf objects as expected."""
+    # check that protobuf fields are accessible in ".attribute" form
+    pbe = pbproxy_test_pb2.PBProxyExample()
+    pbe.nickname = 'foo'
+    pbe.invited = False
+    pbep = template_helpers.PBProxy(pbe)
+    self.assertEqual(pbep.nickname, 'foo')
+    # _bool suffix maps the False 'invited' field to None (EZT boolean false)
+    self.assertEqual(pbep.invited_bool, None)
+
+    # check that a new field can be added to the PBProxy
+    pbep.baz = 'bif'
+    self.assertEqual(pbep.baz, 'bif')
+
+    # check that a PBProxy-local field can hide a protobuf field
+    pbep.nickname = 'local foo'
+    self.assertEqual(pbep.nickname, 'local foo')
+
+    # check that a nested protobuf is recursively wrapped with a PBProxy
+    pbn = pbproxy_test_pb2.PBProxyNested()
+    pbn.nested = pbproxy_test_pb2.PBProxyExample()
+    pbn.nested.nickname = 'bar'
+    pbn.nested.invited = True
+    pbnp = template_helpers.PBProxy(pbn)
+    self.assertEqual(pbnp.nested.nickname, 'bar')
+    # _bool suffix maps the True 'invited' field to 'yes' (EZT boolean true)
+    self.assertEqual(pbnp.nested.invited_bool, 'yes')
+
+    # check that 'repeated' lists of items produce a list of strings
+    pbn.multiple_strings.append('1')
+    pbn.multiple_strings.append('2')
+    self.assertEqual(pbnp.multiple_strings, ['1', '2'])
+
+    # check that 'repeated' messages produce lists of PBProxy instances
+    pbe1 = pbproxy_test_pb2.PBProxyExample()
+    pbn.multiple_pbes.append(pbe1)
+    pbe1.nickname = '1'
+    pbe1.invited = True
+    pbe2 = pbproxy_test_pb2.PBProxyExample()
+    pbn.multiple_pbes.append(pbe2)
+    pbe2.nickname = '2'
+    pbe2.invited = False
+    self.assertEqual(pbnp.multiple_pbes[0].nickname, '1')
+    self.assertEqual(pbnp.multiple_pbes[0].invited_bool, 'yes')
+    self.assertEqual(pbnp.multiple_pbes[1].nickname, '2')
+    self.assertEqual(pbnp.multiple_pbes[1].invited_bool, None)
+
+  def testFitTextMethods(self):
+    """Tests FitUnsafeText with an eye on i18n."""
+    # b'' for raw UTF-8 bytes; u'' so \u escapes decode on both py2 and py3.
+    test_data = (
+        u'This is a short string.',
+
+        u'This is a much longer string. '
+        u'This is a much longer string. '
+        u'This is a much longer string. '
+        u'This is a much longer string. '
+        u'This is a much longer string. '
+        u'This is a much longer string. '
+        u'This is a much longer string. '
+        u'This is a much longer string. '
+        u'This is a much longer string. '
+        u'This is a much longer string. ',
+
+        # This is a short escaped i18n string
+        b'\xd5\xa1\xd5\xba\xd5\xa1\xd5\xaf\xd5\xab'.decode('utf-8'),
+
+        # This is a longer i18n string
+        b'\xd5\xa1\xd5\xba\xd5\xa1\xd5\xaf\xd5\xab '
+        b'\xe6\x88\x91\xe8\x83\xbd\xe5\x90\x9e '
+        b'\xd5\xa1\xd5\xba\xd5\xa1\xd5\xaf\xd5\xab '
+        b'\xe6\x88\x91\xe8\x83\xbd\xe5\x90\x9e '
+        b'\xd5\xa1\xd5\xba\xd5\xa1\xd5\xaf\xd5\xab '
+        b'\xe6\x88\x91\xe8\x83\xbd\xe5\x90\x9e '
+        b'\xd5\xa1\xd5\xba\xd5\xa1\xd5\xaf\xd5\xab '
+        b'\xe6\x88\x91\xe8\x83\xbd\xe5\x90\x9e '.decode('utf-8'),
+
+        # This is a longer i18n string that was causing trouble.
+        u'\u041d\u0430 \u0431\u0435\u0440\u0435\u0433\u0443'
+        u' \u043f\u0443\u0441\u0442\u044b\u043d\u043d\u044b\u0445'
+        u' \u0432\u043e\u043b\u043d \u0421\u0442\u043e\u044f\u043b'
+        u' \u043e\u043d, \u0434\u0443\u043c'
+        u' \u0432\u0435\u043b\u0438\u043a\u0438\u0445'
+        u' \u043f\u043e\u043b\u043d, \u0418'
+        u' \u0432\u0434\u0430\u043b\u044c'
+        u' \u0433\u043b\u044f\u0434\u0435\u043b.'
+        u' \u041f\u0440\u0435\u0434 \u043d\u0438\u043c'
+        u' \u0448\u0438\u0440\u043e\u043a\u043e'
+        u' \u0420\u0435\u043a\u0430'
+        u' \u043d\u0435\u0441\u043b\u0430\u0441\u044f;'
+        u' \u0431\u0435\u0434\u043d\u044b\u0439'
+        u' \u0447\u0451\u043b\u043d \u041f\u043e'
+        u' \u043d\u0435\u0439'
+        u' \u0441\u0442\u0440\u0435\u043c\u0438\u043b\u0441\u044f'
+        u' \u043e\u0434\u0438\u043d\u043e\u043a\u043e.'
+        u' \u041f\u043e \u043c\u0448\u0438\u0441\u0442\u044b\u043c,'
+        u' \u0442\u043e\u043f\u043a\u0438\u043c'
+        u' \u0431\u0435\u0440\u0435\u0433\u0430\u043c'
+        u' \u0427\u0435\u0440\u043d\u0435\u043b\u0438'
+        u' \u0438\u0437\u0431\u044b \u0437\u0434\u0435\u0441\u044c'
+        u' \u0438 \u0442\u0430\u043c, \u041f\u0440\u0438\u044e\u0442'
+        u' \u0443\u0431\u043e\u0433\u043e\u0433\u043e'
+        u' \u0447\u0443\u0445\u043e\u043d\u0446\u0430;'
+        u' \u0418 \u043b\u0435\u0441,'
+        u' \u043d\u0435\u0432\u0435\u0434\u043e\u043c\u044b\u0439'
+        u' \u043b\u0443\u0447\u0430\u043c \u0412'
+        u' \u0442\u0443\u043c\u0430\u043d\u0435'
+        u' \u0441\u043f\u0440\u044f\u0442\u0430\u043d\u043d\u043e'
+        u'\u0433\u043e \u0441\u043e\u043b\u043d\u0446\u0430,'
+        u' \u041a\u0440\u0443\u0433\u043e\u043c'
+        u' \u0448\u0443\u043c\u0435\u043b.')
+
+    for unicode_s in test_data:
+      # Get the length in characters, not bytes.
+      length = len(unicode_s)
+
+      # Test the FitUnsafeText method at the length boundary.
+      fitted_unsafe_text = template_helpers.FitUnsafeText(unicode_s, length)
+      self.assertEqual(fitted_unsafe_text, unicode_s)
+
+      # Set some values that exercise FitUnsafeText well.
+      available_space = length // 2
+      max_trailing = length // 4
+      # Break the string at various places - symmetric range around 0
+      for i in range(1 - max_trailing, max_trailing):
+        # Test the FitUnsafeText method.
+        fitted_unsafe_text = template_helpers.FitUnsafeText(
+            unicode_s, available_space - i)
+        self.assertEqual(fitted_unsafe_text[:available_space - i],
+                         unicode_s[:available_space - i])
+
+    # Test a string that is already unicode
+    u_string = u'This is already unicode'
+    fitted_unsafe_text = template_helpers.FitUnsafeText(u_string, 100)
+    self.assertEqual(u_string, fitted_unsafe_text)
+
+    # Test a string that is already unicode, and has non-ascii in it.
+    u_string = u'This is already unicode este\u0301tico'
+    fitted_unsafe_text = template_helpers.FitUnsafeText(u_string, 100)
+    self.assertEqual(u_string, fitted_unsafe_text)
+
+  def testEZTError(self):
+    errors = template_helpers.EZTError()
+    self.assertFalse(errors.AnyErrors())
+
+    errors.error_a = 'A'
+    self.assertTrue(errors.AnyErrors())
+    self.assertEqual('A', errors.error_a)
+
+    errors.SetError('error_b', 'B')
+    self.assertTrue(errors.AnyErrors())
+    self.assertEqual('A', errors.error_a)
+    self.assertEqual('B', errors.error_b)
+
+  def testBytesKbOrMb(self):
+    self.assertEqual('1023 bytes', template_helpers.BytesKbOrMb(1023))
+    self.assertEqual('1.0 KB', template_helpers.BytesKbOrMb(1024))
+    self.assertEqual('1023 KB', template_helpers.BytesKbOrMb(1024 * 1023))
+    self.assertEqual('1.0 MB', template_helpers.BytesKbOrMb(1024 * 1024))
+    self.assertEqual('98.0 MB', template_helpers.BytesKbOrMb(98 * 1024 * 1024))
+    self.assertEqual('99 MB', template_helpers.BytesKbOrMb(99 * 1024 * 1024))
+
+
+class TextRunTest(unittest.TestCase):
+
+  def testLink(self):
+    run = template_helpers.TextRun(
+        'content', tag='a', href='http://example.com')
+    expected = '<a href="http://example.com">content</a>'
+    self.assertEqual(expected, run.FormatForHTMLEmail())
+    # Markup characters in the href and the link text must be HTML-escaped.
+    run = template_helpers.TextRun(
+        'con<tent>', tag='a', href='http://exa"mple.com')
+    expected = '<a href="http://exa&quot;mple.com">con&lt;tent&gt;</a>'
+    self.assertEqual(expected, run.FormatForHTMLEmail())
+
+  def testText(self):
+    run = template_helpers.TextRun('content')
+    expected = 'content'
+    self.assertEqual(expected, run.FormatForHTMLEmail())
+    # Plain (non-link) text is HTML-escaped as well.
+    run = template_helpers.TextRun('con<tent>')
+    expected = 'con&lt;tent&gt;'
+    self.assertEqual(expected, run.FormatForHTMLEmail())
diff --git a/framework/test/timestr_test.py b/framework/test/timestr_test.py
new file mode 100644
index 0000000..ad11249
--- /dev/null
+++ b/framework/test/timestr_test.py
@@ -0,0 +1,95 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unittest for timestr module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import calendar
+import datetime
+import time
+import unittest
+
+from framework import timestr
+
+
+class TimeStrTest(unittest.TestCase):
+  """Unit tests for timestr routines."""
+
+  def testFormatAbsoluteDate(self):
+    now = datetime.datetime(2008, 1, 1)  # GetDate sees later reassignments.
+
+    def GetDate(*args):
+      date = datetime.datetime(*args)
+      return timestr.FormatAbsoluteDate(
+          calendar.timegm(date.utctimetuple()), clock=lambda: now)
+
+    self.assertEqual(GetDate(2008, 1, 1), 'Today')
+    self.assertEqual(GetDate(2007, 12, 31), 'Yesterday')
+    self.assertEqual(GetDate(2007, 12, 30), 'Dec 30')
+    self.assertEqual(GetDate(2007, 1, 1), 'Jan 2007')
+    self.assertEqual(GetDate(2007, 1, 2), 'Jan 2007')
+    self.assertEqual(GetDate(2007, 12, 29), 'Dec 29')
+    self.assertEqual(GetDate(2006, 12, 31), 'Dec 2006')
+    self.assertEqual(GetDate(2007, 7, 1), 'Jul 1')
+    self.assertEqual(GetDate(2007, 6, 30), 'Jun 2007')
+    self.assertEqual(GetDate(2008, 1, 3), 'Jan 2008')
+
+    # Leap year: Feb 29 reads as 'Yesterday' on Mar 1.
+    now = datetime.datetime(2008, 3, 1)
+    self.assertEqual(GetDate(2008, 2, 29), 'Yesterday')
+
+    # Clock skew: times just ahead of the clock still read as 'Today'.
+    now = datetime.datetime(2008, 1, 1, 23, 59, 59)
+    self.assertEqual(GetDate(2008, 1, 2), 'Today')
+    now = datetime.datetime(2007, 12, 31, 23, 59, 59)
+    self.assertEqual(GetDate(2008, 1, 1), 'Today')
+    self.assertEqual(GetDate(2008, 1, 2), 'Jan 2008')
+
+  def testFormatRelativeDate(self):
+    now = time.mktime(datetime.datetime(2008, 1, 1).timetuple())
+
+    def TestSecsAgo(secs_ago, expected, expected_days_only):
+      test_time = now - secs_ago
+      actual = timestr.FormatRelativeDate(
+          test_time, clock=lambda: now)
+      self.assertEqual(actual, expected)
+      actual_days_only = timestr.FormatRelativeDate(
+          test_time, clock=lambda: now, days_only=True)
+      self.assertEqual(actual_days_only, expected_days_only)
+
+    TestSecsAgo(10 * 24 * 60 * 60, '', '10 days ago')
+    TestSecsAgo(5 * 24 * 60 * 60 - 1, '4 days ago', '4 days ago')
+    TestSecsAgo(5 * 60 * 60 - 1, '4 hours ago', '')
+    TestSecsAgo(5 * 60 - 1, '4 minutes ago', '')
+    TestSecsAgo(2 * 60 - 1, '1 minute ago', '')
+    TestSecsAgo(60 - 1, 'moments ago', '')
+    TestSecsAgo(0, 'moments ago', '')
+    TestSecsAgo(-10, 'moments ago', '')
+    TestSecsAgo(-100, '', '')
+
+  def testGetHumanScaleDate(self):
+    """Tests GetHumanScaleDate()."""
+    now = time.mktime(datetime.datetime(2008, 4, 10, 20, 50, 30).timetuple())
+
+    def GetDate(*args):
+      date = datetime.datetime(*args)
+      timestamp = time.mktime(date.timetuple())
+      return timestr.GetHumanScaleDate(timestamp, now=now)
+
+    self.assertEqual(GetDate(2008, 4, 10, 15), ('Today', '5 hours ago'))
+    self.assertEqual(GetDate(2008, 4, 10, 19, 55), ('Today', '55 min ago'))
+    self.assertEqual(GetDate(2008, 4, 10, 20, 48, 35), ('Today', '1 min ago'))
+    self.assertEqual(GetDate(2008, 4, 10, 20, 49, 35), ('Today', 'moments ago'))
+    self.assertEqual(GetDate(2008, 4, 10, 20, 50, 55), ('Today', 'moments ago'))
+    self.assertEqual(GetDate(2008, 4, 9, 15), ('Yesterday', '29 hours ago'))
+    self.assertEqual(GetDate(2008, 4, 5, 15), ('Last 7 days', 'Apr 05, 2008'))
+    self.assertEqual(GetDate(2008, 3, 22, 15), ('Last 30 days', 'Mar 22, 2008'))
+    self.assertEqual(
+        GetDate(2008, 1, 2, 15), ('Earlier this year', 'Jan 02, 2008'))
+    self.assertEqual(
+        GetDate(2007, 12, 31, 15), ('Before this year', 'Dec 31, 2007'))
+    self.assertEqual(GetDate(2008, 4, 11, 20, 49, 35), ('Future', 'Later'))
diff --git a/framework/test/ts_mon_js_test.py b/framework/test/ts_mon_js_test.py
new file mode 100644
index 0000000..bcd4060
--- /dev/null
+++ b/framework/test/ts_mon_js_test.py
@@ -0,0 +1,73 @@
+# Copyright 2018 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for MonorailTSMonJSHandler."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import json
+import unittest
+from mock import patch
+
+import webapp2
+from google.appengine.ext import testbed
+
+from framework.ts_mon_js import MonorailTSMonJSHandler
+from services import service_manager
+
+
+class MonorailTSMonJSHandlerTest(unittest.TestCase):
+
+  def setUp(self):
+    self.testbed = testbed.Testbed()
+    self.testbed.activate()
+    self.testbed.init_user_stub()
+
+  def tearDown(self):
+    self.testbed.deactivate()
+
+  @patch('framework.xsrf.ValidateToken')  # Stubbed: request sends no token.
+  @patch('time.time')
+  def testSubmitMetrics(self, _mockTime, _mockValidateToken):
+    """POSTing a well-formed metrics payload gets a 201 and body 'Ok.'."""
+    _mockTime.return_value = 1537821859
+    req = webapp2.Request.blank('/_/ts_mon_js')
+    req.body = json.dumps({
+      'metrics': [{
+        'MetricInfo': {
+          'Name': 'monorail/frontend/issue_update_latency',
+          'ValueType': 2,
+        },
+        'Cells': [{
+          'value': {
+            'sum': 1234,
+            'count': 4321,
+            'buckets': {
+              0: 123,
+              1: 321,
+              2: 213,
+            },
+          },
+          'fields': {
+            'client_id': '789',
+            'host_name': 'rutabaga',
+            'document_visible': True,
+          },
+          'start_time': 1537821859 - 60,  # One minute before mocked time.
+        }],
+      }],
+    })
+    res = webapp2.Response()
+    ts_mon_handler = MonorailTSMonJSHandler(request=req, response=res)
+    class MockApp(object):  # Minimal app stand-in exposing only .config.
+      def __init__(self):
+        self.config = {'services': service_manager.Services()}
+    ts_mon_handler.app = MockApp()
+
+    ts_mon_handler.post()
+
+    self.assertEqual(res.status_int, 201)
+    self.assertEqual(res.body, 'Ok.')
diff --git a/framework/test/validate_test.py b/framework/test/validate_test.py
new file mode 100644
index 0000000..9ea17fe
--- /dev/null
+++ b/framework/test/validate_test.py
@@ -0,0 +1,128 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""This file provides unit tests for Validate functions."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+from framework import validate
+
+
+class ValidateUnitTest(unittest.TestCase):
+  """Set of unit tests for validation functions."""
+
+  GOOD_EMAIL_ADDRESSES = [  # Must pass IsValidEmail and IsValidMailTo.
+      'user@example.com',
+      'user@e.com',
+      'user+tag@example.com',
+      'u.ser@example.com',
+      'us.er@example.com',
+      'u.s.e.r@example.com',
+      'user@ex-ample.com',
+      'user@ex.ample.com',
+      'user@e.x.ample.com',
+      'user@exampl.e.com',
+      'user@e-x-ample.com',
+      'user@e-x-a-m-p-l-e.com',
+      'user@e-x.am-ple.com',
+      'user@e--xample.com',
+      ]
+
+  BAD_EMAIL_ADDRESSES = [  # Must fail IsValidEmail and IsValidMailTo.
+      ' leading.whitespace@example.com',
+      'trailing.whitespace@example.com ',
+      '(paren.quoted@example.com)',
+      '<angle.quoted@example.com>',
+      'trailing.@example.com',
+      'trailing.dot.@example.com',
+      '.leading@example.com',
+      '.leading.dot@example.com',
+      'user@example.com.',
+      'us..er@example.com',
+      'user@ex..ample.com',
+      'user@example..com',
+      'user@ex-.ample.com',
+      'user@-example.com',
+      'user@.example.com',
+      'user@example-.com',
+      'user@example',
+      'user@example.',
+      'user@example.c',
+      'user@example.comcomcomc',
+      'user@example.co-m',
+      'user@exa_mple.com',
+      'user@exa-_mple.com',
+      'user@example.c0m',
+      ]
+
+  def testIsValidEmail(self):
+    """Tests IsValidEmail against GOOD_ and BAD_EMAIL_ADDRESSES."""
+    for email in self.GOOD_EMAIL_ADDRESSES:
+      self.assertTrue(validate.IsValidEmail(email), msg='Rejected:%r' % email)
+
+    for email in self.BAD_EMAIL_ADDRESSES:
+      self.assertFalse(validate.IsValidEmail(email), msg='Accepted:%r' % email)
+
+  def testIsValidMailTo(self):  # Same address lists with a 'mailto:' prefix.
+    for email in self.GOOD_EMAIL_ADDRESSES:
+      self.assertTrue(
+          validate.IsValidMailTo('mailto:' + email),
+          msg='Rejected:%r' % ('mailto:' + email))
+
+    for email in self.BAD_EMAIL_ADDRESSES:
+      self.assertFalse(
+          validate.IsValidMailTo('mailto:' + email),
+          msg='Accepted:%r' % ('mailto:' + email))
+
+  GOOD_URLS = [  # Must pass IsValidURL.
+      'http://google.com',
+      'http://maps.google.com/',
+      'https://secure.protocol.com',
+      'https://dash-domain.com',
+      'http://www.google.com/search?q=foo&hl=en',
+      'https://a.very.long.domain.name.net/with/a/long/path/inf0/too',
+      'http://funny.ws/',
+      'http://we.love.anchors.info/page.html#anchor',
+      'http://redundant-slashes.com//in/path//info',
+      'http://trailingslashe.com/in/path/info/',
+      'http://domain.with.port.com:8080',
+      'http://domain.with.port.com:8080/path/info',
+      'ftp://ftp.gnu.org',
+      'ftp://some.server.some.place.com',
+      'http://b/123456',
+      'http://cl/123456/',
+      ]
+
+  BAD_URLS = [  # Must fail IsValidURL.
+      ' http://leading.whitespace.com',
+      'http://trailing.domain.whitespace.com ',
+      'http://trailing.whitespace.com/after/path/info ',
+      'http://underscore_domain.com/',
+      'http://space in domain.com',
+      'http://user@example.com',  # standard, but we purposely don't accept it.
+      'http://user:pass@ex.com',  # standard, but we purposely don't accept it.
+      'http://:password@ex.com',  # standard, but we purposely don't accept it.
+      'missing-http.com',
+      'http:missing-slashes.com',
+      'http:/only-one-slash.com',
+      'http://trailing.dot.',
+      'mailto:bad.scheme',
+      'javascript:attempt-to-inject',
+      'http://short-with-no-final-slash',
+      'http:///',
+      'http:///no.host.name',
+      'http://:8080/',
+      'http://badport.com:808a0/ ',  # NOTE(review): also has trailing space.
+      ]
+
+  def testURL(self):  # IsValidURL accepts GOOD_URLS and rejects BAD_URLS.
+    for url in self.GOOD_URLS:
+      self.assertTrue(validate.IsValidURL(url), msg='Rejected:%r' % url)
+
+    for url in self.BAD_URLS:
+      self.assertFalse(validate.IsValidURL(url), msg='Accepted:%r' % url)
diff --git a/framework/test/warmup_test.py b/framework/test/warmup_test.py
new file mode 100644
index 0000000..d8ddb65
--- /dev/null
+++ b/framework/test/warmup_test.py
@@ -0,0 +1,36 @@
+# Copyright 2017 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for the warmup servlet."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+from testing import testing_helpers
+
+from framework import sql
+from framework import warmup
+from services import service_manager
+
+
+class WarmupTest(unittest.TestCase):
+  """Tests for the Warmup servlet's HandleRequest."""
+
+  def setUp(self):
+    self.services = service_manager.Services()
+    # Warmup gets placeholder 'req'/'res' strings; the test below calls
+    # HandleRequest directly with a MonorailRequest from testing_helpers.
+    self.servlet = warmup.Warmup(
+        'req', 'res', services=self.services)
+
+  def testHandleRequest_NothingToDo(self):
+    """HandleRequest returns {'success': 1} for a default request."""
+    mr = testing_helpers.MakeMonorailRequest()
+    actual_json_data = self.servlet.HandleRequest(mr)
+    self.assertEqual(
+        {'success': 1},
+        actual_json_data)
diff --git a/framework/test/xsrf_test.py b/framework/test/xsrf_test.py
new file mode 100644
index 0000000..aa04570
--- /dev/null
+++ b/framework/test/xsrf_test.py
@@ -0,0 +1,113 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for XSRF utility functions."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import time
+import unittest
+
+from mock import patch
+
+from google.appengine.ext import testbed
+
+import settings
+from framework import xsrf
+
+
+class XsrfTest(unittest.TestCase):
+  """Set of unit tests for blocking XSRF attacks."""
+
+  def setUp(self):
+    self.testbed = testbed.Testbed()
+    self.testbed.activate()
+    self.testbed.init_memcache_stub()
+    self.testbed.init_datastore_v3_stub()
+
+  def tearDown(self):
+    self.testbed.deactivate()
+
+  def testGenerateToken_AnonUserGetsAToken(self):
+    self.assertNotEqual('', xsrf.GenerateToken(0, '/path'))
+
+  def testGenerateToken_DifferentUsersGetDifferentTokens(self):
+    self.assertNotEqual(
+        xsrf.GenerateToken(111, '/path'),
+        xsrf.GenerateToken(222, '/path'))
+
+    self.assertNotEqual(
+        xsrf.GenerateToken(111, '/path'),
+        xsrf.GenerateToken(0, '/path'))
+
+  def testGenerateToken_DifferentPathsGetDifferentTokens(self):
+    self.assertNotEqual(
+        xsrf.GenerateToken(111, '/path/one'),
+        xsrf.GenerateToken(111, '/path/two'))
+
+  def testValidToken(self):
+    token = xsrf.GenerateToken(111, '/path')
+    xsrf.ValidateToken(token, 111, '/path')  # no exception raised
+
+  def testMalformedToken(self):
+    self.assertRaises(
+        xsrf.TokenIncorrect,
+        xsrf.ValidateToken, 'bad', 111, '/path')
+    self.assertRaises(
+        xsrf.TokenIncorrect,
+        xsrf.ValidateToken, '', 111, '/path')
+    # A '<hex>:<number>'-shaped token not issued for this user/path fails.
+    self.assertRaises(
+        xsrf.TokenIncorrect,
+        xsrf.ValidateToken, '098a08fe08b08c08a05e:9721973123', 111, '/path')
+
+  def testWrongUser(self):
+    token = xsrf.GenerateToken(111, '/path')
+    self.assertRaises(
+        xsrf.TokenIncorrect,
+        xsrf.ValidateToken, token, 222, '/path')
+
+  def testWrongPath(self):
+    token = xsrf.GenerateToken(111, '/path/one')
+    self.assertRaises(
+        xsrf.TokenIncorrect,
+        xsrf.ValidateToken, token, 111, '/path/two')
+
+  @patch('time.time')
+  def testValidateToken_Expiration(self, mockTime):
+    test_time = 1526671379
+    mockTime.return_value = test_time
+    token = xsrf.GenerateToken(111, '/path')
+    xsrf.ValidateToken(token, 111, '/path')
+
+    mockTime.return_value = test_time + 1
+    xsrf.ValidateToken(token, 111, '/path')
+
+    mockTime.return_value = test_time + xsrf.TOKEN_TIMEOUT_SEC
+    xsrf.ValidateToken(token, 111, '/path')
+    # Same user and path, so only expiry can cause the rejection below.
+    mockTime.return_value = test_time + xsrf.TOKEN_TIMEOUT_SEC + 1
+    self.assertRaises(
+        xsrf.TokenIncorrect,
+        xsrf.ValidateToken, token, 111, '/path')
+
+  @patch('time.time')
+  def testValidateToken_Future(self, mockTime):
+    """We reject tokens from the future."""
+    test_time = 1526671379
+    mockTime.return_value = test_time
+    token = xsrf.GenerateToken(111, '/path')
+    xsrf.ValidateToken(token, 111, '/path')
+
+    # The clock of the GAE instance doing the checking might be slightly slow.
+    mockTime.return_value = test_time - 1
+    xsrf.ValidateToken(token, 111, '/path')
+
+    # But, if the difference is too much, someone is trying to fake a token.
+    mockTime.return_value = test_time - xsrf.CLOCK_SKEW_SEC - 1
+    self.assertRaises(
+        xsrf.TokenIncorrect,
+        xsrf.ValidateToken, token, 111, '/path')