Project import generated by Copybara.
GitOrigin-RevId: d9e9e3fb4e31372ec1fb43b178994ca78fa8fe70
diff --git a/framework/__init__.py b/framework/__init__.py
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/framework/__init__.py
@@ -0,0 +1 @@
+
diff --git a/framework/alerts.py b/framework/alerts.py
new file mode 100644
index 0000000..1d24f77
--- /dev/null
+++ b/framework/alerts.py
@@ -0,0 +1,57 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Helpers for showing alerts at the top of the page.
+
+These alerts are then displayed by alerts.ezt.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import time
+
+import ezt
+
+# Expiration time for special features of timestamped links.
+# This is not for security, just for informational messages that
+# make sense in the context of a user session, but that should
+# not appear days later if the user follows a bookmarked link.
+_LINK_EXPIRATION_SEC = 8
+
+
+class AlertsView(object):
+ """EZT object for showing alerts at the top of the page."""
+
+ def __init__(self, mr):
+ # Used to show message confirming item was updated
+ self.updated = mr.GetIntParam('updated')
+
+ # Used to show message confirming item was moved and the location of the new
+ # item.
+ self.moved_to_project = mr.GetParam('moved_to_project')
+ self.moved_to_id = mr.GetIntParam('moved_to_id')
+ self.moved = self.moved_to_project and self.moved_to_id
+
+ # Used to show message confirming item was copied and the location of the
+ # new item.
+ self.copied_from_id = mr.GetIntParam('copied_from_id')
+ self.copied_to_project = mr.GetParam('copied_to_project')
+ self.copied_to_id = mr.GetIntParam('copied_to_id')
+ self.copied = self.copied_to_project and self.copied_to_id
+
+ # Used to show message confirming items deleted
+ self.deleted = mr.GetParam('deleted')
+
+ # If present, we will show message confirming that data was saved
+ self.saved = mr.GetParam('saved')
+
+ link_generation_timestamp = mr.GetIntParam('ts', default_value=0)
+ now = int(time.time())
+ ts_links_are_valid = now - link_generation_timestamp < _LINK_EXPIRATION_SEC
+
+ show_alert = ts_links_are_valid and (
+ self.updated or self.moved or self.copied or self.deleted or self.saved)
+ self.show = ezt.boolean(show_alert)
diff --git a/framework/authdata.py b/framework/authdata.py
new file mode 100644
index 0000000..3c1bee9
--- /dev/null
+++ b/framework/authdata.py
@@ -0,0 +1,145 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Classes to hold information parsed from a request.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+from google.appengine.api import users
+
+from proto import user_pb2
+from framework import framework_bizobj
+from framework import framework_views
+
+
+class AuthData(object):
+ """This object holds authentication data about a user.
+
+ This is used by MonorailRequest as it determines which user the
+ requester is authenticated as and fetches the user's data. It can
+ also be used to lookup perms for user IDs specified in issue fields.
+
+ Attributes:
+ user_id: The user ID of the user (or 0 if not signed in).
+ effective_ids: A set of user IDs that includes the signed in user's
+ direct user ID and the user IDs of all their user groups.
+ This set will be empty for anonymous users.
+ user_view: UserView object for the signed-in user.
+ user_pb: User object for the signed-in user.
+ email: email address for the user, or None.
+ """
+
+ def __init__(self, user_id=0, email=None):
+ self.user_id = user_id
+ self.effective_ids = {user_id} if user_id else set()
+ self.user_view = None
+ self.user_pb = user_pb2.MakeUser(user_id)
+ self.email = email
+
+ @classmethod
+ def FromRequest(cls, cnxn, services):
+    """Determines auth information from the request and fetches user data.
+
+ If everything works and the user is signed in, then all of the public
+ attributes of the AuthData instance will be filled in appropriately.
+
+ Args:
+ cnxn: connection to the SQL database.
+ services: Interface to all persistence storage backends.
+
+ Returns:
+ A new AuthData object.
+ """
+ user = users.get_current_user()
+ if user is None:
+ return cls()
+ else:
+ # We create a User row for each user who visits the site.
+ # TODO(jrobbins): we should really only do it when they take action.
+ return cls.FromEmail(cnxn, user.email(), services, autocreate=True)
+
+ @classmethod
+ def FromEmail(cls, cnxn, email, services, autocreate=False):
+ """Determine auth information for the given user email address.
+
+ Args:
+ cnxn: monorail connection to the database.
+ email: string email address of the user.
+ services: connections to backend servers.
+ autocreate: set to True to create a new row in the Users table if needed.
+
+ Returns:
+ A new AuthData object.
+
+ Raises:
+      exceptions.NoSuchUserException: If the user of the email does not exist.
+ """
+ auth = cls()
+ auth.email = email
+ if email:
+ auth.user_id = services.user.LookupUserID(
+ cnxn, email, autocreate=autocreate)
+ assert auth.user_id
+ cls._FinishInitialization(cnxn, auth, services, user_pb=None)
+
+ return auth
+
+ @classmethod
+ def FromUserID(cls, cnxn, user_id, services):
+ """Determine auth information for the given user ID.
+
+ Args:
+ cnxn: monorail connection to the database.
+ user_id: int user ID of the user.
+ services: connections to backend servers.
+
+ Returns:
+ A new AuthData object.
+ """
+ auth = cls()
+ auth.user_id = user_id
+ if auth.user_id:
+ auth.email = services.user.LookupUserEmail(cnxn, user_id)
+ cls._FinishInitialization(cnxn, auth, services, user_pb=None)
+
+ return auth
+
+ @classmethod
+ def FromUser(cls, cnxn, user, services):
+ """Determine auth information for the given user.
+
+ Args:
+ cnxn: monorail connection to the database.
+ user: user protobuf.
+ services: connections to backend servers.
+
+ Returns:
+ A new AuthData object.
+ """
+ auth = cls()
+ auth.user_id = user.user_id
+ if auth.user_id:
+ auth.email = user.email
+ cls._FinishInitialization(cnxn, auth, services, user)
+
+ return auth
+
+
+ @classmethod
+ def _FinishInitialization(cls, cnxn, auth, services, user_pb=None):
+    """Fill in the rest of the fields based on the user_id."""
+ effective_ids_dict = framework_bizobj.GetEffectiveIds(
+ cnxn, services, [auth.user_id])
+ auth.effective_ids = effective_ids_dict[auth.user_id]
+ auth.user_pb = user_pb or services.user.GetUser(cnxn, auth.user_id)
+ if auth.user_pb:
+ auth.user_view = framework_views.UserView(auth.user_pb)
+
+ def __repr__(self):
+ """Return a string more useful for debugging."""
+ return 'AuthData(email=%r, user_id=%r, effective_ids=%r)' % (
+ self.email, self.user_id, self.effective_ids)
diff --git a/framework/banned.py b/framework/banned.py
new file mode 100644
index 0000000..cb0e220
--- /dev/null
+++ b/framework/banned.py
@@ -0,0 +1,54 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""A class to display a message explaining that the user has been banned.
+
+We can ban a user for anti-social behavior. We indicate that the user is
+banned by adding a 'banned' field to their User PB in the DB. Whenever
+a user with a banned indicator visits any page, AssertBasePermission()
+checks has_banned and redirects to this page.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import logging
+
+import ezt
+
+from framework import permissions
+from framework import servlet
+
+
+class Banned(servlet.Servlet):
+ """The Banned page shows a message explaining that the user is banned."""
+
+ _PAGE_TEMPLATE = 'framework/banned-page.ezt'
+
+ def AssertBasePermission(self, mr):
+ """Allow banned users to see this page, and prevent non-banned users."""
+ # Note, we do not call Servlet.AssertBasePermission because
+ # that would redirect banned users here again in an endless loop.
+
+ # We only show this page to users who are banned. If a non-banned user
+ # follows a link to this URL, don't show the banned message, because that
+ # would lead to a big misunderstanding.
+ if not permissions.IsBanned(mr.auth.user_pb, mr.auth.user_view):
+ logging.info('non-banned user: %s', mr.auth.user_pb)
+ self.abort(404)
+
+ def GatherPageData(self, mr):
+ """Build up a dictionary of data values to use when rendering the page."""
+ # Aside from plus-addresses, we do not display the specific
+ # reason for banning.
+ is_plus_address = '+' in (mr.auth.user_pb.email or '')
+
+ return {
+ 'is_plus_address': ezt.boolean(is_plus_address),
+
+ # Make the "Sign Out" link just sign out, don't try to bring the
+ # user back to this page after they sign out.
+ 'currentPageURLEncoded': None,
+ }
diff --git a/framework/clientmon.py b/framework/clientmon.py
new file mode 100644
index 0000000..cc4917c
--- /dev/null
+++ b/framework/clientmon.py
@@ -0,0 +1,52 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""A class to log client-side javascript error reports.
+
+Updates frontend/js_errors ts_mon metric.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import json
+import logging
+
+from framework import jsonfeed
+
+from infra_libs import ts_mon
+
+class ClientMonitor(jsonfeed.JsonFeed):
+ """JSON feed to track client side js errors in ts_mon."""
+
+ js_errors = ts_mon.CounterMetric('frontend/js_errors',
+ 'Number of uncaught client-side JS errors.',
+ None)
+
+ def HandleRequest(self, mr):
+ """Build up a dictionary of data values to use when rendering the page.
+
+ Args:
+ mr: commonly used info parsed from the request.
+
+ Returns:
+ Dict of values used by EZT for rendering the page.
+ """
+
+ post_data = mr.request.POST
+ errors = post_data.get('errors')
+ try:
+ errors = json.loads(errors)
+
+ total_errors = 0
+ for error_key in errors:
+ total_errors += errors[error_key]
+ logging.error('client monitor report (%d): %s', total_errors,
+ post_data.get('errors'))
+ self.js_errors.increment_by(total_errors)
+ except Exception as e:
+ logging.error('Problem processing client monitor report: %r', e)
+
+ return {}
diff --git a/framework/cloud_tasks_helpers.py b/framework/cloud_tasks_helpers.py
new file mode 100644
index 0000000..a00fa0d
--- /dev/null
+++ b/framework/cloud_tasks_helpers.py
@@ -0,0 +1,99 @@
+# Copyright 2020 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+"""A helper module for interfacing with google cloud tasks.
+
+This module wraps Google Cloud Tasks, link to its documentation:
+https://googleapis.dev/python/cloudtasks/1.3.0/gapic/v2/api.html
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+import urllib
+
+from google.api_core import exceptions
+from google.api_core import retry
+
+import settings
+
+if not settings.unit_test_mode:
+ import grpc
+ from google.cloud import tasks
+
+_client = None
+# Default exponential backoff retry config for enqueueing, not to be confused
+# with retry config for dispatching, which exists per queue.
+_DEFAULT_RETRY = retry.Retry(initial=.1, maximum=1.6, multiplier=2, deadline=10)
+
+
+def _get_client():
+ # type: () -> tasks.CloudTasksClient
+ """Returns a cloud tasks client."""
+ global _client
+ if not _client:
+ if settings.local_mode:
+ _client = tasks.CloudTasksClient(
+ channel=grpc.insecure_channel(settings.CLOUD_TASKS_EMULATOR_ADDRESS))
+ else:
+ _client = tasks.CloudTasksClient()
+ return _client
+
+
+def create_task(task, queue='default', **kwargs):
+ # type: (Union[dict, tasks.types.Task], str, **Any) ->
+ # tasks.types.Task
+ """Tries and catches creating a cloud task.
+
+  This exposes a simplified task creation interface by wrapping
+ tasks.CloudTasksClient.create_task; see its documentation:
+ https://googleapis.dev/python/cloudtasks/1.5.0/gapic/v2/api.html#google.cloud.tasks_v2.CloudTasksClient.create_task
+
+ Args:
+ task: A dict or Task describing the task to add.
+ queue: A string indicating name of the queue to add task to.
+ kwargs: Additional arguments to pass to cloud task client's create_task
+
+ Returns:
+ Successfully created Task object.
+
+ Raises:
+ AttributeError: If input task is malformed or missing attributes.
+ google.api_core.exceptions.GoogleAPICallError: If the request failed for any
+ reason.
+ google.api_core.exceptions.RetryError: If the request failed due to a
+ retryable error and retry attempts failed.
+ ValueError: If the parameters are invalid.
+ """
+ client = _get_client()
+
+ parent = client.queue_path(
+ settings.app_id, settings.CLOUD_TASKS_REGION, queue)
+ target = task.get('app_engine_http_request').get('relative_uri')
+ kwargs.setdefault('retry', _DEFAULT_RETRY)
+ logging.info('Enqueueing %s task to %s', target, parent)
+ return client.create_task(parent, task, **kwargs)
+
+
+def generate_simple_task(url, params):
+ # type: (str, dict) -> dict
+ """Construct a basic cloud tasks Task for an appengine handler.
+ Args:
+ url: Url path that handles the task.
+ params: Url query parameters dict.
+
+ Returns:
+ Dict representing a cloud tasks Task object.
+ """
+ return {
+ 'app_engine_http_request':
+ {
+ 'relative_uri': url,
+ 'body': urllib.urlencode(params),
+ 'headers': {
+ 'Content-type': 'application/x-www-form-urlencoded'
+ }
+ }
+ }
diff --git a/framework/csp_report.py b/framework/csp_report.py
new file mode 100644
index 0000000..83e3126
--- /dev/null
+++ b/framework/csp_report.py
@@ -0,0 +1,22 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Servlet for Content Security Policy violation reporting.
+See http://www.html5rocks.com/en/tutorials/security/content-security-policy/
+for more information on how this mechanism works.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import webapp2
+import logging
+
+
+class CSPReportPage(webapp2.RequestHandler):
+ """CSPReportPage serves CSP violation reports."""
+
+ def post(self):
+ logging.error('CSP Violation: %s' % self.request.body)
diff --git a/framework/csv_helpers.py b/framework/csv_helpers.py
new file mode 100644
index 0000000..3dd10c7
--- /dev/null
+++ b/framework/csv_helpers.py
@@ -0,0 +1,74 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Helper functions for creating CSV pagedata."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import types
+
+from framework import framework_helpers
+
+
+# Whenever the user requests one of these columns, we replace it with the
+# list of alternate columns. In effect, we split the requested column
+# into two CSV columns.
+_CSV_COLS_TO_REPLACE = {
+ 'summary': ['Summary', 'AllLabels'],
+ 'opened': ['Opened', 'OpenedTimestamp'],
+ 'closed': ['Closed', 'ClosedTimestamp'],
+ 'modified': ['Modified', 'ModifiedTimestamp'],
+ 'ownermodified': ['OwnerModified', 'OwnerModifiedTimestamp'],
+ 'statusmodified': ['StatusModified', 'StatusModifiedTimestamp'],
+ 'componentmodified': ['ComponentModified', 'ComponentModifiedTimestamp'],
+ 'ownerlastvisit': ['OwnerLastVisit', 'OwnerLastVisitDaysAgo'],
+ }
+
+
+def RewriteColspec(col_spec):
+ """Rewrite the given colspec to expand special CSV columns."""
+ new_cols = []
+
+ for col in col_spec.split():
+ rewriten_cols = _CSV_COLS_TO_REPLACE.get(col.lower(), [col])
+ new_cols.extend(rewriten_cols)
+
+ return ' '.join(new_cols)
+
+
+def ReformatRowsForCSV(mr, page_data, url_path):
+ """Rewrites/adds to the given page_data so the CSV templates can use it."""
+ # CSV files are at risk for the PDF content sniffing by Acrobat Reader
+ page_data['prevent_sniffing'] = True
+
+ # If we're truncating the results, add a URL to the next page of results
+ page_data['next_csv_link'] = None
+ pagination = page_data['pagination']
+ if pagination.next_url:
+ page_data['next_csv_link'] = framework_helpers.FormatAbsoluteURL(
+ mr, url_path, start=pagination.last)
+ page_data['item_count'] = pagination.last - pagination.start + 1
+
+ for row in page_data['table_data']:
+ for cell in row.cells:
+ for value in cell.values:
+ value.item = EscapeCSV(value.item)
+ return page_data
+
+
+def EscapeCSV(s):
+ """Return a version of string S that is safe as part of a CSV file."""
+ if s is None:
+ return ''
+ if isinstance(s, types.StringTypes):
+ s = s.strip().replace('"', '""')
+ # Prefix any formula cells because some spreadsheets have built-in
+    # formula functions that can actually have side-effects on the user's
+ # computer.
+ if s.startswith(('=', '-', '+', '@')):
+ s = "'" + s
+
+ return s
diff --git a/framework/deleteusers.py b/framework/deleteusers.py
new file mode 100644
index 0000000..0c23ac5
--- /dev/null
+++ b/framework/deleteusers.py
@@ -0,0 +1,139 @@
+# Copyright 2019 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+
+"""Cron and task handlers for syncing with wipeout-lite and deleting users."""
+
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import json
+import logging
+import httplib2
+
+from google.appengine.api import app_identity
+
+from businesslogic import work_env
+from framework import cloud_tasks_helpers
+from framework import framework_constants
+from framework import jsonfeed
+from framework import urls
+from oauth2client.client import GoogleCredentials
+
+WIPEOUT_ENDPOINT = 'https://emporia-pa.googleapis.com/v1/apps/%s'
+MAX_BATCH_SIZE = 10000
+MAX_DELETE_USERS_SIZE = 1000
+
+
+def authorize():
+ credentials = GoogleCredentials.get_application_default()
+ credentials = credentials.create_scoped(framework_constants.OAUTH_SCOPE)
+ return credentials.authorize(httplib2.Http(timeout=60))
+
+
+class WipeoutSyncCron(jsonfeed.InternalTask):
+ """Enqueue tasks for sending user lists to wipeout-lite and deleting deleted
+ users fetched from wipeout-lite."""
+
+ def HandleRequest(self, mr):
+ batch_param = mr.GetIntParam('batchsize', default_value=MAX_BATCH_SIZE)
+ # Use batch_param as batch_size unless it is None or 0.
+ batch_size = min(batch_param, MAX_BATCH_SIZE)
+ total_users = self.services.user.TotalUsersCount(mr.cnxn)
+ total_batches = int(total_users / batch_size)
+ # Add an extra batch to process remainder user emails.
+ if total_users % batch_size:
+ total_batches += 1
+ if not total_batches:
+ logging.info('No users to report.')
+ return
+
+ for i in range(total_batches):
+ params = dict(limit=batch_size, offset=i * batch_size)
+ task = cloud_tasks_helpers.generate_simple_task(
+ urls.SEND_WIPEOUT_USER_LISTS_TASK + '.do', params)
+ cloud_tasks_helpers.create_task(
+ task, queue=framework_constants.QUEUE_SEND_WIPEOUT_USER_LISTS)
+
+ task = cloud_tasks_helpers.generate_simple_task(
+ urls.DELETE_WIPEOUT_USERS_TASK + '.do', {})
+ cloud_tasks_helpers.create_task(
+ task, queue=framework_constants.QUEUE_FETCH_WIPEOUT_DELETED_USERS)
+
+
+class SendWipeoutUserListsTask(jsonfeed.InternalTask):
+ """Sends a batch of monorail users to wipeout-lite."""
+
+ def HandleRequest(self, mr):
+ limit = mr.GetIntParam('limit')
+ assert limit != None, 'Missing param limit'
+ offset = mr.GetIntParam('offset')
+ assert offset != None, 'Missing param offset'
+ emails = self.services.user.GetAllUserEmailsBatch(
+ mr.cnxn, limit=limit, offset=offset)
+ accounts = [{'id': email} for email in emails]
+ service = authorize()
+ self.sendUserLists(service, accounts)
+
+ def sendUserLists(self, service, accounts):
+ app_id = app_identity.get_application_id()
+ endpoint = WIPEOUT_ENDPOINT % app_id
+ resp, data = service.request(
+ '%s/verifiedaccounts' % endpoint,
+ method='POST',
+ headers={'Content-Type': 'application/json; charset=UTF-8'},
+ body=json.dumps(accounts))
+ logging.info(
+ 'Received response, %s with contents, %s', resp, data)
+
+
+class DeleteWipeoutUsersTask(jsonfeed.InternalTask):
+ """Fetches deleted users from wipeout-lite and enqueues tasks to delete
+ those users from Monorail's DB."""
+
+ def HandleRequest(self, mr):
+ limit = mr.GetIntParam('limit', MAX_DELETE_USERS_SIZE)
+ limit = min(limit, MAX_DELETE_USERS_SIZE)
+ service = authorize()
+ deleted_user_data = self.fetchDeletedUsers(service)
+ deleted_emails = [user_object['id'] for user_object in deleted_user_data]
+ total_batches = int(len(deleted_emails) / limit)
+ if len(deleted_emails) % limit:
+ total_batches += 1
+
+ for i in range(total_batches):
+ start = i * limit
+ end = start + limit
+ params = dict(emails=','.join(deleted_emails[start:end]))
+ task = cloud_tasks_helpers.generate_simple_task(
+ urls.DELETE_USERS_TASK + '.do', params)
+ cloud_tasks_helpers.create_task(
+ task, queue=framework_constants.QUEUE_DELETE_USERS)
+
+ def fetchDeletedUsers(self, service):
+ app_id = app_identity.get_application_id()
+ endpoint = WIPEOUT_ENDPOINT % app_id
+ resp, data = service.request(
+ '%s/deletedaccounts' % endpoint,
+ method='GET',
+ headers={'Content-Type': 'application/json; charset=UTF-8'})
+ logging.info(
+ 'Received response, %s with contents, %s', resp, data)
+ return json.loads(data)
+
+
+class DeleteUsersTask(jsonfeed.InternalTask):
+ """Deletes users from Monorail's DB."""
+
+ def HandleRequest(self, mr):
+ """Delete users with the emails given in the 'emails' param."""
+ emails = mr.GetListParam('emails', default_value=[])
+ assert len(emails) <= MAX_DELETE_USERS_SIZE, (
+ 'We cannot delete more than %d users at once, current users: %d' %
+ (MAX_DELETE_USERS_SIZE, len(emails)))
+ if len(emails) == 0:
+ logging.info("No user emails found in deletion request")
+ return
+ with work_env.WorkEnv(mr, self.services) as we:
+ we.ExpungeUsers(emails, check_perms=False)
diff --git a/framework/emailfmt.py b/framework/emailfmt.py
new file mode 100644
index 0000000..2933fea
--- /dev/null
+++ b/framework/emailfmt.py
@@ -0,0 +1,422 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Functions that format or parse email messages in Monorail.
+
+Specifically, this module has the logic for generating various email
+header lines that help match inbound and outbound email to the project
+and artifact that generated it.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import hmac
+import logging
+import re
+import rfc822
+
+import six
+
+from google.appengine.api import app_identity
+
+import settings
+from framework import framework_constants
+from services import client_config_svc
+from services import secrets_svc
+
+# TODO(jrobbins): Parsing very large messages is slow, and we are not going
+# to handle attachments at first, so there is no reason to consider large
+# emails.
+MAX_BODY_SIZE = 100 * 1024
+MAX_HEADER_CHARS_CONSIDERED = 255
+
+
+def _checkEmailHeaderPrefix(key):
+ """Ensures that a given email header starts with X-Alert2Monorail prefix."""
+ # this is to catch typos in the email header prefix and raises an exception
+ # during package loading time.
+ assert key.startswith('X-Alert2Monorail')
+ return key
+
+
+class AlertEmailHeader(object):
+ """A list of the email header keys supported by Alert2Monorail."""
+ # pylint: disable=bad-whitespace
+ #
+ # The prefix has been hard-coded without string substitution to make them
+ # searchable with the header keys.
+ INCIDENT_ID = 'X-Incident-Id'
+ OWNER = _checkEmailHeaderPrefix('X-Alert2Monorail-owner')
+ CC = _checkEmailHeaderPrefix('X-Alert2Monorail-cc')
+ PRIORITY = _checkEmailHeaderPrefix('X-Alert2Monorail-priority')
+ STATUS = _checkEmailHeaderPrefix('X-Alert2Monorail-status')
+ COMPONENT = _checkEmailHeaderPrefix('X-Alert2Monorail-component')
+ OS = _checkEmailHeaderPrefix('X-Alert2Monorail-os')
+ TYPE = _checkEmailHeaderPrefix('X-Alert2Monorail-type')
+ LABEL = _checkEmailHeaderPrefix('X-Alert2Monorail-label')
+
+
+def IsBodyTooBigToParse(body):
+ """Return True if the email message body is too big to process."""
+ return len(body) > MAX_BODY_SIZE
+
+
+def IsProjectAddressOnToLine(project_addr, to_addrs):
+ """Return True if an email was explicitly sent directly to us."""
+ return project_addr in to_addrs
+
+
+def ParseEmailMessage(msg):
+ """Parse the given MessageRouterMessage and return relevant fields.
+
+ Args:
+ msg: email.message.Message object for the email message sent to us.
+
+ Returns:
+ A tuple: from_addr, to_addrs, cc_addrs, references,
+ incident_id, subject, body.
+ """
+ # Ignore messages that are probably not from humans, see:
+ # http://google.com/search?q=precedence+bulk+junk
+ precedence = msg.get('precedence', '')
+ if precedence.lower() in ['bulk', 'junk']:
+ logging.info('Precedence: %r indicates an autoresponder', precedence)
+ return '', [], [], '', '', '', ''
+
+ from_addrs = _ExtractAddrs(msg.get('from', ''))
+ if from_addrs:
+ from_addr = from_addrs[0]
+ else:
+ from_addr = ''
+
+ to_addrs = _ExtractAddrs(msg.get('to', ''))
+ cc_addrs = _ExtractAddrs(msg.get('cc', ''))
+
+ in_reply_to = msg.get('in-reply-to', '')
+ incident_id = msg.get(AlertEmailHeader.INCIDENT_ID, '')
+ references = msg.get('references', '').split()
+ references = list({ref for ref in [in_reply_to] + references if ref})
+ subject = _StripSubjectPrefixes(msg.get('subject', ''))
+
+ body = u''
+ for part in msg.walk():
+ # We only process plain text emails.
+ if part.get_content_type() == 'text/plain':
+ body = part.get_payload(decode=True)
+ if not isinstance(body, six.text_type):
+ body = body.decode('utf-8')
+ break # Only consider the first text part.
+
+ return (from_addr, to_addrs, cc_addrs, references, incident_id, subject,
+ body)
+
+
+def _ExtractAddrs(header_value):
+ """Given a message header value, return email address found there."""
+ friendly_addr_pairs = list(rfc822.AddressList(header_value))
+ return [addr for _friendly, addr in friendly_addr_pairs]
+
+
+def _StripSubjectPrefixes(subject):
+ """Strip off any 'Re:', 'Fwd:', etc. subject line prefixes."""
+ prefix = _FindSubjectPrefix(subject)
+ while prefix:
+ subject = subject[len(prefix):].strip()
+ prefix = _FindSubjectPrefix(subject)
+
+ return subject
+
+
+def _FindSubjectPrefix(subject):
+ """If the given subject starts with a prefix, return that prefix."""
+ for prefix in ['re:', 'aw:', 'fwd:', 'fw:']:
+ if subject.lower().startswith(prefix):
+ return prefix
+
+ return None
+
+
+def MailDomain():
+  """Return the domain name where this app can receive email."""
+ if settings.unit_test_mode:
+ return 'testbed-test.appspotmail.com'
+
+ # If running on a GAFYD domain, you must define an app alias on the
+ # Application Settings admin web page. If you cannot reserve the matching
+ # APP_ID for the alias, then specify it in settings.mail_domain.
+ if settings.mail_domain:
+ return settings.mail_domain
+
+ app_id = app_identity.get_application_id()
+ if ':' in app_id:
+ app_id = app_id.split(':')[-1]
+
+ return '%s.appspotmail.com' % app_id
+
+
+def FormatFriendly(commenter_view, sender, reveal_addr):
+ """Format the From: line to include the commenter's friendly name if given."""
+ if commenter_view:
+ site_name = settings.site_name.lower()
+ if commenter_view.email in client_config_svc.GetServiceAccountMap():
+ friendly = commenter_view.display_name
+ elif reveal_addr:
+ friendly = commenter_view.email
+ else:
+ friendly = u'%s\u2026@%s' % (
+ commenter_view.obscured_username, commenter_view.domain)
+ if '@' in sender:
+ sender_username, sender_domain = sender.split('@', 1)
+ sender = '%s+v2.%d@%s' % (
+ sender_username, commenter_view.user_id, sender_domain)
+ friendly = friendly.split('@')[0]
+ return '%s via %s <%s>' % (friendly, site_name, sender)
+ else:
+ return sender
+
+
+def NoReplyAddress(commenter_view=None, reveal_addr=False):
+ """Return an address that ignores all messages sent to it."""
+ # Note: We use "no_reply" with an underscore to avoid potential conflict
+ # with any project name. Project names cannot have underscores.
+ # Note: This does not take branded domains into account, but this address
+ # is only used for email error messages and in the reply-to address
+ # when the user is not allowed to reply.
+ sender = 'no_reply@%s' % MailDomain()
+ return FormatFriendly(commenter_view, sender, reveal_addr)
+
+
+def FormatFromAddr(project, commenter_view=None, reveal_addr=False,
+ can_reply_to=True):
+ """Return a string to be used on the email From: line.
+
+ Args:
+ project: Project PB for the project that the email is sent from.
+ commenter_view: Optional UserView of the user who made a comment. We use
+ the user's (potentially obscured) email address as their friendly name.
+ reveal_addr: Optional bool. If False then the address is obscured.
+ can_reply_to: Optional bool. If True then settings.send_email_as is used,
+ otherwise settings.send_noreply_email_as is used.
+
+ Returns:
+ A string that should be used in the From: line of outbound email
+ notifications for the given project.
+ """
+ addr_format = (settings.send_email_as_format if can_reply_to
+ else settings.send_noreply_email_as_format)
+ domain = settings.branded_domains.get(
+ project.project_name, settings.branded_domains.get('*'))
+ domain = domain or 'chromium.org'
+ if domain.count('.') > 1:
+ domain = '.'.join(domain.split('.')[-2:])
+ addr = addr_format % {'domain': domain}
+ return FormatFriendly(commenter_view, addr, reveal_addr)
+
+
+def NormalizeHeader(s):
+ """Make our message-ids robust against mail client spacing and truncation."""
+ words = _StripSubjectPrefixes(s).split() # Split on any runs of whitespace.
+ normalized = ' '.join(words)
+ truncated = normalized[:MAX_HEADER_CHARS_CONSIDERED]
+ return truncated
+
+
+def MakeMessageID(to_addr, subject, from_addr):
+ """Make a unique (but deterministic) email Message-Id: value."""
+ normalized_subject = NormalizeHeader(subject)
+ if isinstance(normalized_subject, six.text_type):
+ normalized_subject = normalized_subject.encode('utf-8')
+ mail_hmac_key = secrets_svc.GetEmailKey()
+ return '<0=%s=%s=%s@%s>' % (
+ hmac.new(mail_hmac_key, to_addr).hexdigest(),
+ hmac.new(mail_hmac_key, normalized_subject).hexdigest(),
+ from_addr.split('@')[0],
+ MailDomain())
+
+
def GetReferences(to_addr, subject, seq_num, project_from_addr):
  """Make a References: header value so messages thread properly.

  Args:
    to_addr: address that the email message will be sent to.
    subject: subject line of the email message.
    seq_num: sequence number of the message in the thread, e.g., 0, 1, 2,
        ..., or None if the message is not part of a thread.
    project_from_addr: address that the message will be sent from.

  Returns:
    A string Message-ID that does not correspond to any actual email message
    that was ever sent, but serves to unite all the messages that belong
    together in a thread.  An empty string for unthreaded messages.
  """
  if seq_num is None:
    return ''
  return MakeMessageID(to_addr, subject, project_from_addr)
+
+
def ValidateReferencesHeader(message_ref, project, from_addr, subject):
  """Check that the References header is one that we could have sent.

  Args:
    message_ref: one of the References header values from the inbound email.
    project: Project PB for the affected project.
    from_addr: string email address that inbound email was sent from.
    subject: string base subject line of inbound email.

  Returns:
    True if it looks like this is a reply to a message that we sent to the
    same address that replied.  Otherwise, False.
  """
  # Recompute the Message-ID that we would have generated when notifying
  # from_addr about this subject, and require an exact match.
  project_from_addr = '%s@%s' % (project.project_name, MailDomain())
  expected_ref = MakeMessageID(from_addr, subject, project_from_addr)

  # TODO(jrobbins): project option to not check from_addr.
  # TODO(jrobbins): project inbound auth token.
  return message_ref == expected_ref
+
+
# Matches inbound addresses of the form project+verb+label@domain, where
# the +verb and +label parts are optional.
PROJECT_EMAIL_RE = re.compile(
    r'(?P<project>[-a-z0-9]+)'
    r'(\+(?P<verb>[a-z0-9]+)(\+(?P<label>[a-z0-9-]+))?)?'
    r'@(?P<domain>[-a-z0-9.]+)')

# Matches outbound notification subjects like "Issue 123 in proj: summary".
ISSUE_CHANGE_SUBJECT_RE = re.compile(
    r'Issue (?P<local_id>[0-9]+) in '
    r'(?P<project>[-a-z0-9]+): '
    r'(?P<summary>.+)')

# Matches the compact subject form "proj:123: summary".
ISSUE_CHANGE_COMPACT_SUBJECT_RE = re.compile(
    r'(?P<project>[-a-z0-9]+):'
    r'(?P<local_id>[0-9]+): '
    r'(?P<summary>.+)')
+
+
def IdentifyIssue(project_name, subject):
  """Parse the issue local ID out of an inbound email subject line.

  Args:
    project_name: string the project to search for the issue in.
    subject: string email subject line received, it must match the one
        sent.  Leading prefixes like "Re:" should already have been stripped.

  Returns:
    An int local_id for the issue, or None if no id is found or the id is
    not valid.
  """
  issue_project_name, local_id_str = _MatchSubject(subject)

  # Reject replies whose subject names a different project.
  if project_name != issue_project_name:
    return None

  logging.info('project_name = %r', project_name)
  logging.info('local_id_str = %r', local_id_str)

  try:
    return int(local_id_str)
  except (ValueError, TypeError):
    # Covers both an unparsable id and a None from _MatchSubject.
    return None
+
+
def IdentifyProjectVerbAndLabel(project_addr):
  """Parse (project, verb, label) out of an inbound To: address."""
  # Ignore any inbound email sent to a "no_reply@" address.
  if project_addr.startswith('no_reply@'):
    return None, None, None

  m = PROJECT_EMAIL_RE.match(project_addr.lower())
  if not m:
    return None, None, None

  return m.group('project'), m.group('verb'), m.group('label')
+
+
def _MatchSubject(subject):
  """Parse the project and artifact id from a subject line.

  Returns:
    A (project_name, local_id_str) pair, or (None, None) when the subject
    matches neither recognized subject format.
  """
  for subject_re in (ISSUE_CHANGE_SUBJECT_RE, ISSUE_CHANGE_COMPACT_SUBJECT_RE):
    m = subject_re.match(subject)
    if m:
      return m.group('project'), m.group('local_id')

  return None, None
+
+
+# TODO(jrobbins): For now, we strip out lines that look like quoted
+# text and then will give the user the option to see the whole email.
+# For 2.0 of this feature, we should change the Comment PB to have
+# runs of text with different properties so that the UI can present
+# "- Show quoted text -" and expand it in-line.
+
+# TODO(jrobbins): For now, we look for lines that indicate quoted
+# text (e.g., they start with ">"). But, we should also collapse
+# multiple lines that are identical to other lines in previous
+# non-deleted comments on the same issue, regardless of quote markers.
+
+
# We cut off the message if we see something that looks like a signature and
# it is near the bottom of the message.
SIGNATURE_BOUNDARY_RE = re.compile(
    r'^(([-_=]+ ?)+|'
    r'cheers|(best |warm |kind )?regards|thx|thanks|thank you|'
    r'Sent from my i?Phone|Sent from my iPod)'
    r',? *$', re.I)

# Only this many trailing lines of a message are scanned for a signature
# boundary; a matching line higher up is assumed to be message content.
MAX_SIGNATURE_LINES = 8
+
FORWARD_OR_EXPLICIT_SIG_PATS = [
    r'[^0-9a-z]+(forwarded|original) message[^0-9a-z]+\s*$',
    r'Updates:\s*$',
    r'Comment #\d+ on issue \d+ by \S+:',
    # If we see this anywhere in the message, treat the rest as a signature.
    r'--\s*$',
    ]
# Matches the first forward/signature marker and everything after it, so a
# single sub('') removes the remainder of the message from that point on.
FORWARD_OR_EXPLICIT_SIG_PATS_AND_REST_RE = re.compile(
    r'^(%s)(.|\n)*' % '|'.join(FORWARD_OR_EXPLICIT_SIG_PATS),
    flags=re.MULTILINE | re.IGNORECASE)
+
# This handles gmail well, and it's pretty broad without seeming like
# it would cause false positives.
QUOTE_PATS = [
    r'^On .*\s+<\s*\S+?@[-a-z0-9.]+>\s*wrote:\s*$',
    r'^On .* \S+?@[-a-z0-9.]+\s*wrote:\s*$',
    r'^\S+?@[-a-z0-9.]+ \(\S+?@[-a-z0-9.]+\)\s*wrote:\s*$',
    # The dots in the appspotmail domain are escaped so that "." only
    # matches a literal dot rather than any character.
    r'\S+?@[-a-z0-9]+\.appspotmail\.com\s.*wrote:\s*$',
    r'\S+?@[-a-z0-9]+\.appspotmail\.com\s+.*a\s+\xc3\xa9crit\s*:\s*$',
    r'^\d+/\d+/\d+ +<\S+@[-a-z0-9.]+>:?\s*$',
    r'^>.*$',
    ]
# Matches one or more consecutive quoted lines together with any blank
# lines surrounding the block, so the whole block can be collapsed to a
# single blank line by StripQuotedText().
QUOTED_BLOCKS_RE = re.compile(
    r'(^\s*\n)*((%s)\n?)+(^\s*\n)*' % '|'.join(QUOTE_PATS),
    flags=re.MULTILINE | re.IGNORECASE)
+
+
def StripQuotedText(description):
  """Strip all quoted text lines out of the given comment text."""
  # If the rest of the message is forwarded text, drop it entirely.
  remaining = FORWARD_OR_EXPLICIT_SIG_PATS_AND_REST_RE.sub('', description)
  # Collapse each quoted block of lines, together with its surrounding blank
  # lines, into at most one blank line.
  remaining = QUOTED_BLOCKS_RE.sub('\n', remaining)

  lines = remaining.strip().split('\n')
  # Scan only the last few lines for a signature boundary; if one is found,
  # keep just the lines above it.
  sig_zone_start = max(0, len(lines) - MAX_SIGNATURE_LINES)
  for idx in range(sig_zone_start, len(lines)):
    if SIGNATURE_BOUNDARY_RE.match(lines[idx]):
      lines = lines[:idx]
      break

  return '\n'.join(lines).strip()
diff --git a/framework/exceptions.py b/framework/exceptions.py
new file mode 100644
index 0000000..51c9951
--- /dev/null
+++ b/framework/exceptions.py
@@ -0,0 +1,184 @@
+# Copyright 2017 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Exception classes used throughout monorail.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+
class ErrorAggregator(object):
  """Accumulates error messages and raises one exception for all of them.

  May be used as a context manager: if the `with` body completes without
  raising, a single exception of the configured type is raised when any
  error messages were accumulated.
  """

  def __init__(self, exc_type):
    # type: (type) -> None
    self.exc_type = exc_type
    self.error_messages = []

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_value, exc_traceback):
    # Only raise our aggregate exception if the `with` body completed
    # cleanly; an in-flight exception is allowed to propagate unchanged.
    # (Identity comparison with None is the correct idiom here.)
    if exc_type is None:
      self.RaiseIfErrors()

  def AddErrorMessage(self, message, *args, **kwargs):
    # type: (str, *Any, **Any) -> None
    """Add a new error message.

    Args:
      message: An error message, to be formatted using *args and **kwargs.
      *args: passed in to str.format.
      **kwargs: passed in to str.format.
    """
    self.error_messages.append(message.format(*args, **kwargs))

  def RaiseIfErrors(self):
    # type: () -> None
    """If there are errors, raise one exception combining all messages."""
    if self.error_messages:
      raise self.exc_type("\n".join(self.error_messages))
+
+
class Error(Exception):
  """Base class for errors from this module."""


class ActionNotSupported(Error):
  """The user is trying to do something we do not support."""


class InputException(Error):
  """Error in user input processing."""


class ProjectAlreadyExists(Error):
  """Tried to create a project that already exists."""


class FieldDefAlreadyExists(Error):
  """Tried to create a custom field that already exists."""


class ComponentDefAlreadyExists(Error):
  """Tried to create a component that already exists."""


class NoSuchProjectException(Error):
  """No project with the specified name exists."""


class NoSuchTemplateException(Error):
  """No template with the specified name exists."""


class NoSuchUserException(Error):
  """No user with the specified name exists."""


class NoSuchIssueException(Error):
  """The requested issue was not found."""


class NoSuchAttachmentException(Error):
  """The requested attachment was not found."""


class NoSuchCommentException(Error):
  """The requested comment was not found."""


class NoSuchAmendmentException(Error):
  """The requested amendment was not found."""


class NoSuchComponentException(Error):
  """No component with the specified name exists."""


class InvalidComponentNameException(Error):
  """The component name is invalid."""


class InvalidHotlistException(Error):
  """The specified hotlist is invalid."""


class NoSuchFieldDefException(Error):
  """No field def for specified project exists."""


class InvalidFieldTypeException(Error):
  """Expected field type and actual field type do not match."""


class NoSuchIssueApprovalException(Error):
  """The requested approval for the issue was not found."""


class CircularGroupException(Error):
  """Circular nested group exception."""


class GroupExistsException(Error):
  """Group already exists exception."""


class NoSuchGroupException(Error):
  """Requested group was not found exception."""


class InvalidExternalIssueReference(Error):
  """Improperly formatted external issue reference.

  External issue references must be of the form:

    $tracker_shortname/$tracker_specific_id

  For example, issuetracker.google.com issues:

    b/123456789
  """


class PageTokenException(Error):
  """Incorrect page tokens."""


class FilterRuleException(Error):
  """Violates a filter rule that should show error."""


class OverAttachmentQuota(Error):
  """Project will exceed quota if the current operation is allowed."""
diff --git a/framework/excessiveactivity.py b/framework/excessiveactivity.py
new file mode 100644
index 0000000..3737e9a
--- /dev/null
+++ b/framework/excessiveactivity.py
@@ -0,0 +1,25 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""A class to display the an error page for excessive activity.
+
+This page is shown when the user performs a given type of action
+too many times in a 24-hour period or exceeds a lifetime limit.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+from framework import servlet
+
+
class ExcessiveActivity(servlet.Servlet):
  """Renders an error page when a user exceeds an activity rate limit."""

  _PAGE_TEMPLATE = 'framework/excessive-activity-page.ezt'

  def GatherPageData(self, _mr):
    """Return the EZT data dictionary; this page needs no extra values."""
    return {}
diff --git a/framework/filecontent.py b/framework/filecontent.py
new file mode 100644
index 0000000..15d2940
--- /dev/null
+++ b/framework/filecontent.py
@@ -0,0 +1,204 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Utility routines for dealing with MIME types and decoding text files."""
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+
+import itertools
+import logging
+
+from framework import framework_constants
+
+
# Maps lowercased filename extensions to the content type used to serve
# attachments with that extension.
_EXTENSION_TO_CTYPE_TABLE = {
    # These are images/PDFs that we trust the browser to display.
    'gif': 'image/gif',
    'jpg': 'image/jpeg',
    'jpeg': 'image/jpeg',
    'png': 'image/png',
    'webp': 'image/webp',
    'ico': 'image/x-icon',
    'svg': 'image/svg+xml',
    'pdf': 'application/pdf',
    'ogv': 'video/ogg',
    'mov': 'video/quicktime',
    'mp4': 'video/mp4',
    # NOTE(review): mpg/mpeg are mapped to video/mp4 — presumably so the
    # browser attempts playback; confirm this is intentional.
    'mpg': 'video/mp4',
    'mpeg': 'video/mp4',
    'webm': 'video/webm',

    # We do not serve mimetypes that cause the browser to launch a local
    # app because that is not required for issue tracking and it is a
    # potential security risk.
}
+
+
def GuessContentTypeFromFilename(filename):
  """Guess a file's content type based on the filename extension.

  Args:
    filename: String name of a file.

  Returns:
    MIME type string to use when serving this file.  We only use text/plain
    for text files, appropriate image content-types, or
    application/octet-stream for virtually all binary files.  This limits
    the richness of the user's experience, e.g., the user cannot open an MS
    Office application directly by clicking on an attachment, but it is
    safer.
  """
  if '.' in filename:
    ext = filename.rsplit('.', 1)[-1].lower()
  else:
    ext = ''
  if ext in COMMON_TEXT_FILE_EXTENSIONS:
    return 'text/plain'
  return _EXTENSION_TO_CTYPE_TABLE.get(ext, 'application/octet-stream')
+
+
# Constants used in detecting if a file has binary content.
# All line lengths must be below the upper limit, and a specific ratio of
# lines must be below the lower limit.
_MAX_SOURCE_LINE_LEN_LOWER = 350
_MAX_SOURCE_LINE_LEN_UPPER = 800
_SOURCE_LINE_LEN_LOWER_RATIO = 0.9

# Message to display for undecodable commit log or author values.
UNDECODABLE_LOG_CONTENT = '[Cannot be displayed]'

# How large a repository file may be (in bytes) before we don't try to
# display it.
SOURCE_FILE_MAX_SIZE = 1000 * 1024
# Maximum number of lines before a file is considered "long".
SOURCE_FILE_MAX_LINES = 50000
+
# The source code browser will not attempt to display any filename ending
# with one of these extensions.
COMMON_BINARY_FILE_EXTENSIONS = {
    'gif', 'jpg', 'jpeg', 'psd', 'ico', 'icon', 'xbm', 'xpm', 'xwd', 'pcx',
    # BUG FIX: 'vsd,' 'mpg' was an implicit string concatenation that put
    # the single bogus element 'vsd,mpg' in the set, so neither 'vsd' nor
    # 'mpg' was treated as binary.
    'bmp', 'png', 'vsd', 'mpg', 'mpeg', 'wmv', 'wmf', 'avi', 'flv', 'snd',
    'mp3', 'wma', 'exe', 'dll', 'bin', 'class', 'o', 'so', 'lib', 'dylib',
    'jar', 'ear', 'war', 'par', 'msi', 'tar', 'zip', 'rar', 'cab', 'z', 'gz',
    'bz2', 'dmg', 'iso', 'rpm', 'pdf', 'eps', 'tif', 'tiff', 'xls', 'ppt',
    'graffie', 'violet', 'webm', 'webp',
    }
+
# The source code browser will display file contents as text data for files
# with the following extensions or exact filenames (assuming they decode
# correctly).  All comparisons are done on lowercased values.
COMMON_TEXT_FILE_EXTENSIONS = (
    set(framework_constants.PRETTIFY_CLASS_MAP.keys()) | {
        '',
        'ada',
        'asan',
        'asm',
        'asp',
        'bat',
        'cgi',
        'csv',
        'diff',
        'el',
        'emacs',
        'jsp',
        'log',
        'markdown',
        'md',
        'mf',
        'patch',
        'plist',
        'properties',
        'r',
        'rc',
        'txt',
        'vim',
        'wiki',
        'xemacs',
        'yacc',
    })
# Exact (lowercased) filenames that are treated as text regardless of
# extension, e.g. "README".
COMMON_TEXT_FILENAMES = (
    set(framework_constants.PRETTIFY_FILENAME_CLASS_MAP.keys()) |
    {'authors', 'install', 'readme'})
+
+
def DecodeFileContents(file_contents, path=None):
  """Try converting file contents to unicode using utf-8 or latin-1.

  This is applicable to untrusted maybe-text from vcs files or inbound emails.

  We try decoding the file as utf-8, then fall back on latin-1. In the former
  case, we call the file a text file; in the latter case, we guess whether
  the file is text or binary based on line length.

  If we guess text when the file is binary, the user sees safely encoded
  gibberish. If the other way around, the user sees a message that we will
  not display the file.

  TODO(jrobbins): we could try the user-supplied encoding, iff it
  is one of the encodings that we know that we can handle.

  Args:
    file_contents: byte string from uploaded file. It could be text in almost
        any encoding, or binary. We cannot trust the user-supplied encoding
        in the mime-type property.
    path: string pathname of file.

  Returns:
    The tuple (unicode_string, is_binary, is_long):
      - The unicode version of the string.
      - is_binary is true if the string could not be decoded as text.
      - is_long is true if the file has more than SOURCE_FILE_MAX_LINES lines.
  """
  # If the filename is one that typically identifies a binary file, then
  # just treat it as binary without any further analysis.
  ext = None
  if path and '.' in path:
    ext = path.split('.')[-1]
    if ext.lower() in COMMON_BINARY_FILE_EXTENSIONS:
      # If the file is binary, we don't care about the length, since we don't
      # show or diff it.
      return u'', True, False

  # If the string can be decoded as utf-8, we treat it as textual.
  try:
    u_str = file_contents.decode('utf-8', 'strict')
    is_long = len(u_str.split('\n')) > SOURCE_FILE_MAX_LINES
    return u_str, False, is_long
  except UnicodeDecodeError:
    logging.info('not a utf-8 file: %s bytes', len(file_contents))

  # Fall back on latin-1. This will always succeed, since every byte maps to
  # something in latin-1, even if that something is gibberish.
  u_str = file_contents.decode('latin-1', 'strict')

  lines = u_str.split('\n')
  is_long = len(lines) > SOURCE_FILE_MAX_LINES
  # Treat decodable files with certain filenames and/or extensions as text
  # files. This avoids problems with common file types using our text/binary
  # heuristic rules below.
  if path:
    name = path.split('/')[-1]
    if (name.lower() in COMMON_TEXT_FILENAMES or
        (ext and ext.lower() in COMMON_TEXT_FILE_EXTENSIONS)):
      return u_str, False, is_long

  # HEURISTIC: Binary files can qualify as latin-1, so we need to
  # check further. Any real source code is going to be divided into
  # reasonably sized lines. All lines must be below an upper character limit,
  # and most lines must be below a lower limit. This allows some exceptions
  # to the lower limit, but is more restrictive than just using a single
  # large character limit.
  is_binary = False
  lower_count = 0
  for line in itertools.islice(lines, SOURCE_FILE_MAX_LINES):
    size = len(line)
    if size <= _MAX_SOURCE_LINE_LEN_LOWER:
      lower_count += 1
    elif size > _MAX_SOURCE_LINE_LEN_UPPER:
      is_binary = True
      break

  # BUG FIX: only the first SOURCE_FILE_MAX_LINES lines are examined above,
  # so the ratio must be computed against the number of examined lines.
  # Dividing by the total line count made any text file with more than
  # SOURCE_FILE_MAX_LINES / ratio lines get misclassified as binary.
  examined = min(len(lines), SOURCE_FILE_MAX_LINES)
  ratio = lower_count / float(max(1, examined))
  if ratio < _SOURCE_LINE_LEN_LOWER_RATIO:
    is_binary = True

  return u_str, is_binary, is_long
diff --git a/framework/framework_bizobj.py b/framework/framework_bizobj.py
new file mode 100644
index 0000000..bacaec5
--- /dev/null
+++ b/framework/framework_bizobj.py
@@ -0,0 +1,512 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Business objects for Monorail's framework.
+
+These are classes and functions that operate on the objects that
+users care about in Monorail but that are not part of just one specific
+component: e.g., projects, users, and labels.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import functools
+import itertools
+import re
+
+import six
+
+import settings
+from framework import exceptions
+from framework import framework_constants
+from proto import tracker_pb2
+from services import client_config_svc
+
+
# Pattern to match a valid column header name.
RE_COLUMN_NAME = r'\w+[\w+-.]*\w+'

# Compiled regexp to match a valid column specification: zero or more
# column names separated by single whitespace characters.  The pattern is
# a raw string so that \s is a regex escape, not a (deprecated) string
# escape.
RE_COLUMN_SPEC = re.compile(r'(%s(\s%s)*)*$' % (RE_COLUMN_NAME, RE_COLUMN_NAME))
+
+
def WhichUsersShareAProject(cnxn, services, user_effective_ids, other_users):
  # type: (MonorailConnection, Services, Sequence[int],
  #     Collection[user_pb2.User]) -> Collection[user_pb2.User]
  """Returns the users that share a project with given user_effective_ids.

  Args:
    cnxn: MonorailConnection to the database.
    services: Services object for connections to backend services.
    user_effective_ids: The user's set of effective_ids.
    other_users: The list of users to be filtered for email visibility.

  Returns:
    Collection of users that share a project with at least one effective_id.
  """
  memberships_by_effective_id = services.project.GetProjectMemberships(
      cnxn, user_effective_ids)
  authed_user_projects = set(
      itertools.chain.from_iterable(memberships_by_effective_id.values()))

  all_other_effective_ids = GetEffectiveIds(
      cnxn, services, [other.user_id for other in other_users])

  sharing_users = []
  for other_user in other_users:
    other_effective_ids = all_other_effective_ids[other_user.user_id]

    # Never filter out the requesting user themselves.
    if any(eid in user_effective_ids for eid in other_effective_ids):
      sharing_users.append(other_user)
      continue

    other_memberships = services.project.GetProjectMemberships(
        cnxn, other_effective_ids)
    other_projects = itertools.chain.from_iterable(other_memberships.values())
    if any(proj in authed_user_projects for proj in other_projects):
      sharing_users.append(other_user)
  return sharing_users
+
+
def FilterViewableEmails(cnxn, services, user_auth, other_users):
  # type: (MonorailConnection, Services, AuthData,
  #     Collection[user_pb2.User]) -> Collection[user_pb2.User]
  """Returns the users whose emails are visible to `user_auth`.

  Args:
    cnxn: MonorailConnection to the database.
    services: Services object for connections to backend services.
    user_auth: The AuthData of the user viewing the email addresses.
    other_users: The list of users to be filtered for email visibility.

  Returns:
    Collection of users that should reveal their emails.
  """
  # Case 1: Anon users don't see anything revealed.
  if user_auth.user_pb is None:
    return []

  # Case 2: site admins always see unobscured email addresses.
  if user_auth.user_pb.is_site_admin:
    return other_users

  # Case 3: Members of any groups in settings.full_emails_perm_groups
  # can view unobscured email addresses.
  for group_email in settings.full_emails_perm_groups:
    group_id = services.usergroup.LookupUserGroupID(cnxn, group_email)
    if group_id in user_auth.effective_ids:
      return other_users

  # Case 4: Users see unobscured emails as long as they share a common
  # Project.
  return WhichUsersShareAProject(
      cnxn, services, user_auth.effective_ids, other_users)
+
+
def DoUsersShareAProject(cnxn, services, user_effective_ids, other_user_id):
  # type: (MonorailConnection, Services, Sequence[int], int) -> bool
  """Determine whether two users share at least one Project.

  The user_effective_ids may include group ids or the other_user_id may be a
  member of a group that results in transitive Project ownership.

  Args:
    cnxn: MonorailConnection to the database.
    services: Services object for connections to backend services.
    user_effective_ids: The effective ids of the authorized User.
    other_user_id: The other user's user_id to compare against.

  Returns:
    True if one or more Projects are shared between the Users.
  """
  projects_by_user_effective_id = services.project.GetProjectMemberships(
      cnxn, user_effective_ids)
  # BUG FIX: this must be a set, not a bare itertools.chain iterator.  The
  # `in` tests below are repeated once per project of the other user, and a
  # chain iterator is exhausted by the first membership test, which made
  # later tests always fail (cf. WhichUsersShareAProject, which already
  # materializes a set here).
  authed_user_projects = set(
      itertools.chain.from_iterable(projects_by_user_effective_id.values()))

  # Get effective ids for other user to handle transitive Project membership.
  other_user_effective_ids = GetEffectiveIds(cnxn, services, other_user_id)
  projects_by_other_user_effective_ids = services.project.GetProjectMemberships(
      cnxn, other_user_effective_ids)
  other_user_projects = itertools.chain.from_iterable(
      projects_by_other_user_effective_ids.values())

  return any(project in authed_user_projects for project in other_user_projects)
+
+
# TODO(https://crbug.com/monorail/8192): Remove this method.
def DeprecatedShouldRevealEmail(user_auth, project, viewed_email):
  # type: (AuthData, Project, str) -> bool
  """
  Deprecated V1 API logic to decide whether to publish a user's email
  address. Avoid updating this method.

  Args:
    user_auth: The AuthData of the user viewing the email addresses.
    project: The Project PB to which the viewed user belongs.
    viewed_email: The email of the viewed user.

  Returns:
    True if email addresses should be published to the logged-in user.
  """
  # Case 1: Anon users don't see anything revealed.
  if user_auth.user_pb is None:
    return False

  # Case 2: site admins always see unobscured email addresses.
  if user_auth.user_pb.is_site_admin:
    return True

  # Case 3: Project members see the unobscured email of everyone in a project.
  if project and UserIsInProject(project, user_auth.effective_ids):
    return True

  # Case 4: Do not obscure your own email.
  if viewed_email and user_auth.user_pb.email == viewed_email:
    return True

  return False
+
+
def ParseAndObscureAddress(email):
  # type: (str) -> Tuple[str, str, str, str]
  """Break the given email into username and domain, and obscure.

  Args:
    email: string email address to process

  Returns:
    A 4-tuple (username, domain, obscured_username, obscured_email).
    The obscured_username is truncated more aggressively than how Google
    Groups does it: it truncates at 5 characters or truncates OFF 3
    characters, whichever results in a shorter obscured_username.
  """
  # Don't fail if the User table has an unexpected email address format:
  # partition() yields an empty domain when there is no '@'.
  username, _, user_domain = email.partition('@')

  # Ignore any "+tag" suffix when computing the obscured form.
  base_username = username.split('+')[0]
  # Keep at most 5 chars, or drop the last 3, whichever leaves less;
  # always keep at least 1 char.
  cutoff_point = min(5, max(1, len(base_username) - 3))
  obscured_username = base_username[:cutoff_point]
  obscured_email = '%s...@%s' % (obscured_username, user_domain)

  return username, user_domain, obscured_username, obscured_email
+
+
def CreateUserDisplayNamesAndEmails(cnxn, services, user_auth, users):
  # type: (MonorailConnection, Services, AuthData,
  #     Collection[user_pb2.User]) ->
  #     Tuple[Mapping[int, str], Mapping[int, str]]
  """Create the display names and emails of the given users based on the
  current user.

  Args:
    cnxn: MonorailConnection to the database.
    services: Services object for connections to backend services.
    user_auth: AuthData object that identifies the logged in user.
    users: Collection of User PB objects.

  Returns:
    A Tuple containing two Dicts of user_ids to display names and user_ids to
    emails. If a given User does not have an email, there will be an empty
    string in both.
  """
  # NOTE: Currently only service accounts can have display_names set. For all
  # other users and service accounts with no display_names specified, we use the
  # obscured or unobscured emails for both `display_names` and `emails`.
  # See crbug.com/monorail/8510.
  display_names = {}
  emails = {}

  # Pass 1: handle the simple cases; collect users whose obscured email
  # might still be revealed by the visibility check in pass 2.
  maybe_revealed_users = []
  for user in users:
    if user.user_id == framework_constants.DELETED_USER_ID:
      display_names[user.user_id] = framework_constants.DELETED_USER_NAME
      emails[user.user_id] = ''
    elif not user.email:
      display_names[user.user_id] = ''
      emails[user.user_id] = ''
    elif not user.obscure_email:
      display_names[user.user_id] = user.email
      emails[user.user_id] = user.email
    else:
      # Default to hiding user email.
      (_username, _domain, _obs_username,
       obs_email) = ParseAndObscureAddress(user.email)
      display_names[user.user_id] = obs_email
      emails[user.user_id] = obs_email
      maybe_revealed_users.append(user)

  # Pass 2: reveal emails that are viewable by the logged-in user.
  viewable_users = FilterViewableEmails(
      cnxn, services, user_auth, maybe_revealed_users)
  for user in viewable_users:
    display_names[user.user_id] = user.email
    emails[user.user_id] = user.email

  # Pass 3: use Client.display_names for service accounts that have one
  # specified, overriding the email-based display name.
  for user in users:
    if user.email in client_config_svc.GetServiceAccountMap():
      display_names[user.user_id] = client_config_svc.GetServiceAccountMap()[
          user.email]

  return display_names, emails
+
+
def UserOwnsProject(project, effective_ids):
  """Return True if any of the effective_ids is an owner of the project."""
  owner_ids = project.owner_ids or set()
  return any(uid in effective_ids for uid in owner_ids)
+
+
def UserIsInProject(project, effective_ids):
  """Return True if any of the effective_ids is a project member.

  Args:
    project: Project PB for the current project.
    effective_ids: set of int user IDs for the current user (including all
        user groups).  This will be an empty set for anonymous users.

  Returns:
    A truthy value if the user has any direct or indirect role in the
    project (owner, committer, or contributor), otherwise a falsy value.
  """
  if UserOwnsProject(project, effective_ids):
    return True
  committers = project.committer_ids or set()
  contributors = project.contributor_ids or set()
  return (not effective_ids.isdisjoint(committers) or
          not effective_ids.isdisjoint(contributors))
+
+
def IsPriviledgedDomainUser(email):
  """Return True if the user's account is from a priviledged domain."""
  if not email or '@' not in email:
    return False
  # Note: the domain is everything after the FIRST '@', matching how
  # addresses are split elsewhere in this module.
  user_domain = email.split('@', 1)[1]
  return user_domain in settings.priviledged_user_domains
+
+
def IsValidColumnSpec(col_spec):
  # type: (str) -> bool
  """Return a truthy match object if the column specification is valid."""
  # RE_COLUMN_SPEC is already compiled, so call its match method directly.
  return RE_COLUMN_SPEC.match(col_spec)
+
+
# String translation table used to catch common typos in label names:
# the listed punctuation and whitespace characters are deleted, and '='
# is rewritten to '-'.
_CANONICALIZATION_TRANSLATION_TABLE = {
    ord(delete_u_char): None
    for delete_u_char in u'!"#$%&\'()*+,/:;<>?@[\\]^`{|}~\t\n\x0b\x0c\r '
    }
_CANONICALIZATION_TRANSLATION_TABLE.update({ord(u'='): ord(u'-')})
+
+
def CanonicalizeLabel(user_input):
  """Canonicalize a given label or status value.

  When the user enters a string that represents a label or an enum, convert
  it to a canonical form that makes it more likely to match existing values.

  Args:
    user_input: string that the user typed for a label, or None.

  Returns:
    Canonical form of that label as a unicode string, or None if the input
    was None.
  """
  if user_input is None:
    return user_input

  # Ensure we work on unicode so the translation table applies.
  if not isinstance(user_input, six.text_type):
    user_input = user_input.decode('utf-8')

  return user_input.translate(_CANONICALIZATION_TRANSLATION_TABLE)
+
+
def MergeLabels(labels_list, labels_add, labels_remove, config):
  """Update a list of labels with the given add and remove label lists.

  Args:
    labels_list: list of current labels.
    labels_add: labels that the user wants to add.
    labels_remove: labels that the user wants to remove.
    config: ProjectIssueConfig with info about exclusive prefixes and
        enum fields.

  Returns:
    (merged_labels, update_labels_add, update_labels_remove):
    A new list of labels with the given labels added and removed, and
    any exclusive label prefixes taken into account.  Then two
    lists of update strings to explain the changes that were actually
    made.
  """
  # Label comparisons are case-insensitive: skip adds that are already
  # present and removes that are not present.
  old_lower_labels = [lab.lower() for lab in labels_list]
  labels_add = [lab for lab in labels_add
                if lab.lower() not in old_lower_labels]
  labels_remove = [lab for lab in labels_remove
                   if lab.lower() in old_lower_labels]
  labels_remove_lower = [lab.lower() for lab in labels_remove]
  # An exclusive prefix (e.g., "priority-") allows at most one label with
  # that prefix; single-valued enum fields behave the same way.
  exclusive_prefixes = [
      lab.lower() + '-' for lab in config.exclusive_label_prefixes]
  for fd in config.field_defs:
    if (fd.field_type == tracker_pb2.FieldTypes.ENUM_TYPE and
        not fd.is_multivalued):
      exclusive_prefixes.append(fd.field_name.lower() + '-')

  # We match prefix strings rather than splitting on dash because
  # an exclusive-prefix or field name may contain dashes.
  def MatchPrefix(lab, prefixes):
    # Return the matching "prefix-" string, or False when none matches.
    for prefix_dash in prefixes:
      if lab.lower().startswith(prefix_dash):
        return prefix_dash
    return False

  # Dedup any added labels. E.g., ignore attempts to add Priority twice.
  excl_add = []
  dedupped_labels_add = []
  for lab in labels_add:
    matched_prefix_dash = MatchPrefix(lab, exclusive_prefixes)
    if matched_prefix_dash:
      if matched_prefix_dash not in excl_add:
        excl_add.append(matched_prefix_dash)
        dedupped_labels_add.append(lab)
    else:
      dedupped_labels_add.append(lab)

  # "Old minus exclusive" is the set of old label values minus any
  # that should be overwritten by newly set exclusive labels.
  old_minus_excl = []
  for lab in labels_list:
    matched_prefix_dash = MatchPrefix(lab, excl_add)
    if not matched_prefix_dash:
      old_minus_excl.append(lab)

  merged_labels = [lab for lab in old_minus_excl + dedupped_labels_add
                   if lab.lower() not in labels_remove_lower]

  return merged_labels, dedupped_labels_add, labels_remove
+
+
# Pattern to match a valid hotlist name: a letter followed by letters,
# digits, dashes, or dots.
RE_HOTLIST_NAME_PATTERN = r"[a-zA-Z][-0-9a-zA-Z\.]*"

# Compiled regexp to match the hotlist name and nothing more before or after.
# NOTE(review): re.VERBOSE is inert here since the pattern contains no
# whitespace or '#' characters; kept as-is for compatibility.
RE_HOTLIST_NAME = re.compile(
    '^%s$' % RE_HOTLIST_NAME_PATTERN, re.VERBOSE)
+
+
def IsValidHotlistName(s):
  """Return True iff the given string is a valid hotlist name.

  Args:
    s: candidate hotlist name string.

  Returns:
    A boolean.  bool() is applied because re.match returns a Match
    object (or None), not the boolean that the docstring promises.
  """
  return bool(
      RE_HOTLIST_NAME.match(s) and
      len(s) <= framework_constants.MAX_HOTLIST_NAME_LENGTH)
+
+
# All currently-supported prefs are booleans serialized as the exact strings
# 'true' or 'false'.  The \Z anchor (re.match already anchors the start)
# rejects values like 'truex' or 'true\n' that the previous unanchored
# pattern accepted.
_BOOLEAN_PREF_RE = re.compile(r'(true|false)\Z')

USER_PREF_DEFS = {
    'code_font': _BOOLEAN_PREF_RE,
    'render_markdown': _BOOLEAN_PREF_RE,

    # These are for dismissible cues.  True means the user has dismissed them.
    'privacy_click_through': _BOOLEAN_PREF_RE,
    'corp_mode_click_through': _BOOLEAN_PREF_RE,
    'code_of_conduct': _BOOLEAN_PREF_RE,
    'dit_keystrokes': _BOOLEAN_PREF_RE,
    'italics_mean_derived': _BOOLEAN_PREF_RE,
    'availability_msgs': _BOOLEAN_PREF_RE,
    'your_email_bounced': _BOOLEAN_PREF_RE,
    'search_for_numbers': _BOOLEAN_PREF_RE,
    'restrict_new_issues': _BOOLEAN_PREF_RE,
    'public_issue_notice': _BOOLEAN_PREF_RE,
    'you_are_on_vacation': _BOOLEAN_PREF_RE,
    'how_to_join_project': _BOOLEAN_PREF_RE,
    'document_team_duties': _BOOLEAN_PREF_RE,
    'showing_ids_instead_of_tiles': _BOOLEAN_PREF_RE,
    'issue_timestamps': _BOOLEAN_PREF_RE,
    'stale_fulltext': _BOOLEAN_PREF_RE,
    }
# Upper bound on the length of any stored pref value.
MAX_PREF_VALUE_LENGTH = 80
+
+
def ValidatePref(name, value):
  """Return an error message if the server does not support a pref value.

  Args:
    name: user preference name being set.
    value: proposed string value for that preference.

  Returns:
    An error message string, or None when the name/value pair is acceptable.
  """
  validator = USER_PREF_DEFS.get(name)
  if validator is None:
    return 'Unknown pref name: %r' % name
  if len(value) > MAX_PREF_VALUE_LENGTH:
    return 'Value for pref name %r is too long' % name
  if not validator.match(value):
    return 'Invalid pref value %r for %r' % (value, name)
  return None
+
+
def _IsMemberOfAnyGroup(cnxn, services, user_id, group_emails):
  # type: (MonorailConnection, Services, int, Sequence[str]) -> bool
  """Return True iff the user belongs to any of the named user groups."""
  user_group_ids = services.usergroup.LookupMemberships(cnxn, user_id)
  # autocreate=True so the configured groups exist even on a fresh site.
  groups_dict = services.user.LookupUserIDs(
      cnxn, group_emails, autocreate=True)
  group_ids = set(groups_dict.values())
  return any(gid in group_ids for gid in user_group_ids)


def IsRestrictNewIssuesUser(cnxn, services, user_id):
  # type: (MonorailConnection, Services, int) -> bool
  """Returns true iff user's new issues should be restricted by default."""
  return _IsMemberOfAnyGroup(
      cnxn, services, user_id, settings.restrict_new_issues_user_groups)


def IsPublicIssueNoticeUser(cnxn, services, user_id):
  # type: (MonorailConnection, Services, int) -> bool
  """Returns true iff user should see a public issue notice by default."""
  return _IsMemberOfAnyGroup(
      cnxn, services, user_id, settings.public_issue_notice_user_groups)
+
+
def GetEffectiveIds(cnxn, services, user_ids):
  # type: (MonorailConnection, Services, Collection[int]) ->
  #     Mapping[int, Collection[int]]
  """
  Given a set of user IDs, return a mapping of user_id to a set of effective
  IDs that include the user's own ID and all of their user groups.  For
  anonymous users the mapping will contain only the user_id itself.
  """
  # Get direct group memberships for user_ids.
  effective_ids_by_user_id = services.usergroup.LookupAllMemberships(
      cnxn, user_ids)
  # A user's own ID is always one of their effective IDs.
  for user_id, effective_ids in effective_ids_by_user_id.items():
    effective_ids.add(user_id)
  # Get User objects so we can read emails and linked-account info.
  users_by_id = services.user.GetUsersByIDs(cnxn, user_ids)
  for user_id, user in users_by_id.items():
    if user and user.email:
      # Memberships computed from the user's email domain.
      effective_ids_by_user_id[user_id].update(
          _ComputeMembershipsByEmail(cnxn, services, user.email))

      # Add related parent and child account ids.
      related_ids = []
      if user.linked_parent_id:
        related_ids.append(user.linked_parent_id)
      if user.linked_child_ids:
        related_ids.extend(user.linked_child_ids)

      # Add any related effective_ids: linked accounts also contribute
      # their own group memberships.
      if related_ids:
        effective_ids_by_user_id[user_id].update(related_ids)
        effective_ids_by_related_id = services.usergroup.LookupAllMemberships(
            cnxn, related_ids)
        related_effective_ids = functools.reduce(
            set.union, effective_ids_by_related_id.values(), set())
        effective_ids_by_user_id[user_id].update(related_effective_ids)
  return effective_ids_by_user_id
+
+
def _ComputeMembershipsByEmail(cnxn, services, email):
  # type: (MonorailConnection, Services, str) -> Collection[int]
  """Return a list [group_id] of computed user groups for a user email.

  Only the domain portion of the address is passed on: computed
  memberships are looked up by email domain.
  """
  # ParseAndObscureAddress returns (username, domain, obscured_username,
  # obscured_email); index 1 is the email domain.
  user_email_domain = ParseAndObscureAddress(email)[1]
  return services.usergroup.LookupComputedMemberships(cnxn, user_email_domain)
diff --git a/framework/framework_constants.py b/framework/framework_constants.py
new file mode 100644
index 0000000..1490135
--- /dev/null
+++ b/framework/framework_constants.py
@@ -0,0 +1,184 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Some constants used throughout Monorail."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import os
+import re
+
+
# Number of seconds in various periods.
SECS_PER_MINUTE = 60
SECS_PER_HOUR = SECS_PER_MINUTE * 60
SECS_PER_DAY = SECS_PER_HOUR * 24
SECS_PER_MONTH = SECS_PER_DAY * 30  # Approximate: a 30-day month.
SECS_PER_YEAR = SECS_PER_DAY * 365  # Approximate: ignores leap years.

# When we write to memcache, let the values expire so that we don't
# get any unexpected super-old values as we make code changes over the
# years.  Also, searches can contain date terms like [opened<today-1]
# that would become wrong if cached for a long time.
CACHE_EXPIRATION = 6 * SECS_PER_HOUR

# Fulltext indexing happens asynchronously and we get no notification
# when the indexing operation has completed.  So, when we cache searches
# that use fulltext terms, the results might be stale.  We still do
# cache them and use the cached values, but we expire them so that the
# results cannot be stale for a long period of time.
FULLTEXT_MEMCACHE_EXPIRATION = 3 * SECS_PER_MINUTE

# Size in bytes of the largest form submission that we will accept.
MAX_POST_BODY_SIZE = 10 * 1024 * 1024  # = 10 MB

# Special issue ID to use when an issue is explicitly not specified.
NO_ISSUE_SPECIFIED = 0

# Special user ID and name to use when no user was specified.
NO_USER_SPECIFIED = 0
NO_SESSION_SPECIFIED = 0
NO_USER_NAME = '----'
# Placeholder identity used for accounts that have been wiped out.
DELETED_USER_NAME = 'a_deleted_user'
DELETED_USER_ID = 1
USER_NOT_FOUND_NAME = 'user_not_found'

# Task queue names for the user-deletion (wipeout) pipeline.
QUEUE_SEND_WIPEOUT_USER_LISTS = 'wipeoutsendusers'
QUEUE_FETCH_WIPEOUT_DELETED_USERS = 'wipeoutdeleteusers'
QUEUE_DELETE_USERS = 'deleteusers'

# We remember the time of each user's last page view, but to reduce the
# number of database writes, we only update it if it is newer by an hour.
VISIT_RESOLUTION = 1 * SECS_PER_HOUR

# String to display when some field has no value.
NO_VALUES = '----'

# If the user enters one or more dashes, that means "no value".  This is
# useful in bulk edit, inbound email, and commit log commands where a blank
# field means "keep what was there" or is ignored.
NO_VALUE_RE = re.compile(r'^-+$')

# Used to loosely validate column spec.  Mainly guards against malicious
# input.
COLSPEC_RE = re.compile(r'^[-.\w\s/]*$', re.UNICODE)
COLSPEC_COL_RE = re.compile(r'[-.\w/]+', re.UNICODE)
MAX_COL_PARTS = 25
MAX_COL_LEN = 50

# Used to loosely validate sort spec.  Mainly guards against malicious input.
SORTSPEC_RE = re.compile(r'^[-.\w\s/]*$', re.UNICODE)
MAX_SORT_PARTS = 6

# For the artifact search box autosizing when the user types a long query.
MIN_ARTIFACT_SEARCH_FIELD_SIZE = 38
MAX_ARTIFACT_SEARCH_FIELD_SIZE = 75
AUTOSIZE_STEP = 3
+
# Regular expressions used in parsing label and status configuration text.
IDENTIFIER_REGEX = r'[-.\w]+'
IDENTIFIER_RE = re.compile(IDENTIFIER_REGEX, re.UNICODE)
# Labels and status values that are prefixed by a pound-sign are not
# displayed in autocomplete menus.  Captures the (possibly #-prefixed)
# identifier and an optional '=' docstring that follows it.
IDENTIFIER_DOCSTRING_RE = re.compile(
    r'^(#?%s)[ \t]*=?[ \t]*(.*)$' % IDENTIFIER_REGEX,
    re.MULTILINE | re.UNICODE)

# Number of label text fields that we can display on a web form for issues.
MAX_LABELS = 24

# Default number of comments to display on an artifact detail page at one
# time.  Other comments will be paginated.
DEFAULT_COMMENTS_PER_PAGE = 100

# Content type to use when serving JSON.
CONTENT_TYPE_JSON = 'application/json; charset=UTF-8'
# NOTE(review): 'nosniff' looks like an X-Content-Type-Options header
# value — confirm at the use site.
CONTENT_TYPE_JSON_OPTIONS = 'nosniff'

# Maximum comments to index to keep the search index from choking.  E.g., if
# an artifact had 1200 comments, only 0..99 and 701..1200 would be indexed.
# This mainly affects advocacy issues which are highly redundant anyway.
INITIAL_COMMENTS_TO_INDEX = 100
FINAL_COMMENTS_TO_INDEX = 500

# This is the longest string that GAE search will accept in one field.
# The entire search document is also limited to 1MB, so our limit is
# 200 * 1024 chars so that each can be 4 bytes and the comments leave room
# for metadata.
# https://cloud.google.com/appengine/docs/standard/python/search/#documents
MAX_FTS_FIELD_SIZE = 200 * 1024
+
# Base path to EZT templates: the 'templates' directory that is a sibling
# of this file's directory.
this_dir = os.path.dirname(__file__)
# os.path instead of str.rindex('/'): rindex raises ValueError when the
# directory string contains no separator (e.g. a bare relative path), and
# os.path is portable.  The trailing '' keeps the trailing separator that
# callers of TEMPLATE_PATH expect.
TEMPLATE_PATH = os.path.join(os.path.dirname(this_dir), 'templates', '')
+
# Defaults for dooming a project.
DEFAULT_DOOM_REASON = 'No longer needed'
DEFAULT_DOOM_PERIOD = SECS_PER_DAY * 90

MAX_PROJECT_PEOPLE = 1000

MAX_HOTLIST_NAME_LENGTH = 80

# When logging potentially long debugging strings, only show this many chars.
LOGGING_MAX_LENGTH = 2000

# Maps languages supported by google-code-prettify
# to the class name that should be added to code blocks in that language.
# This list should be kept in sync with the handlers registered
# in lang-*.js and prettify.js from the prettify project.
PRETTIFY_CLASS_MAP = {
    ext: 'lang-' + ext
    for ext in [
        # Supported in lang-*.js
        'apollo', 'agc', 'aea', 'lisp', 'el', 'cl', 'scm',
        'css', 'go', 'hs', 'lua', 'fs', 'ml', 'proto', 'scala', 'sql', 'vb',
        'vbs', 'vhdl', 'vhd', 'wiki', 'yaml', 'yml', 'clj',
        # Supported in prettify.js
        'htm', 'html', 'mxml', 'xhtml', 'xml', 'xsl',
        'c', 'cc', 'cpp', 'cxx', 'cyc', 'm',
        'json', 'cs', 'java', 'bsh', 'csh', 'sh', 'cv', 'py', 'perl', 'pl',
        'pm', 'rb', 'js', 'coffee',
    ]}

# Languages which are not specifically mentioned in prettify.js
# but which render intelligibly with the default handler.  The empty class
# name leaves language detection to prettify itself.
PRETTIFY_CLASS_MAP.update(
    (ext, '') for ext in [
        'hpp', 'hxx', 'hh', 'h', 'inl', 'idl', 'swig', 'd',
        'php', 'tcl', 'aspx', 'cfc', 'cfm',
        'ent', 'mod', 'as',
        'y', 'lex', 'awk', 'n', 'pde',
    ])

# Languages which are not specifically mentioned in prettify.js
# but which should be rendered using a certain prettify module.
PRETTIFY_CLASS_MAP.update({
    'docbook': 'lang-xml',
    'dtd': 'lang-xml',
    'duby': 'lang-rb',
    'mk': 'lang-sh',
    'mak': 'lang-sh',
    'make': 'lang-sh',
    'mirah': 'lang-rb',
    'ss': 'lang-lisp',
    'vcproj': 'lang-xml',
    'xsd': 'lang-xml',
    'xslt': 'lang-xml',
})

# Well-known filenames (rather than extensions) mapped to prettify classes.
PRETTIFY_FILENAME_CLASS_MAP = {
    'makefile': 'lang-sh',
    'makefile.in': 'lang-sh',
    'doxyfile': 'lang-sh',  # Key-value pairs with hash comments
    '.checkstyle': 'lang-xml',
    '.classpath': 'lang-xml',
    '.project': 'lang-xml',
}

# OAuth scope granting access to the caller's email address.
OAUTH_SCOPE = 'https://www.googleapis.com/auth/userinfo.email'
# OAuth scope for the Monorail API itself.
MONORAIL_SCOPE = 'https://www.googleapis.com/auth/monorail'

# Allowlist of characters permitted in a filename.
FILENAME_RE = re.compile('^[-_.a-zA-Z0-9 #+()]+$')
diff --git a/framework/framework_helpers.py b/framework/framework_helpers.py
new file mode 100644
index 0000000..b7199b1
--- /dev/null
+++ b/framework/framework_helpers.py
@@ -0,0 +1,660 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Helper functions and classes used throughout Monorail."""
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+
import collections
import functools
import logging
import random
import string
import textwrap
import threading
import time
import traceback
import urllib
import urlparse
+
+from google.appengine.api import app_identity
+
+import ezt
+import six
+
+import settings
+from framework import framework_bizobj
+from framework import framework_constants
+from framework import template_helpers
+from framework import timestr
+from framework import urls
+from proto import user_pb2
+from services import client_config_svc
+
# AttachmentUpload holds the information of an incoming uploaded
# attachment before it gets saved as a gcs file and saved to the DB.
# Fields: filename (str), contents (str), mimetype (str).
AttachmentUpload = collections.namedtuple(
    'AttachmentUpload', ['filename', 'contents', 'mimetype'])

# For random key generation: length and alphabet (letters + digits).
RANDOM_KEY_LENGTH = 128
RANDOM_KEY_CHARACTERS = string.ascii_letters + string.digits

# Params recognized by FormatURL, in the order they will appear in the url.
RECOGNIZED_PARAMS = ['can', 'start', 'num', 'q', 'colspec', 'groupby', 'sort',
                     'show', 'format', 'me', 'table_title', 'projects',
                     'hotlist_id']
+
+
def retry(tries, delay=1, backoff=2):
  """A retry decorator with exponential backoff.

  Functions are retried when Exceptions occur.

  Args:
    tries: int Number of times to retry, set to 0 to disable retry.
    delay: float Initial sleep time in seconds.
    backoff: float Must be greater than 1, further failures would sleep
        delay*=backoff seconds.

  Returns:
    A decorator that wraps the target function with retry behavior.

  Raises:
    ValueError: if any argument is outside its documented range.
  """
  if backoff <= 1:
    raise ValueError("backoff must be greater than 1")
  if tries < 0:
    raise ValueError("tries must be 0 or greater")
  if delay <= 0:
    raise ValueError("delay must be greater than 0")

  def decorator(func):
    # functools.wraps preserves func.__name__/__doc__ on the wrapper so
    # logs and introspection point at the real function, not 'wrapper'.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
      _tries, _delay = tries, delay
      _tries += 1  # Ensure we call func at least once.
      while _tries > 0:
        try:
          return func(*args, **kwargs)
        except Exception:
          _tries -= 1
          if _tries == 0:
            logging.error('Exceeded maximum number of retries for %s.',
                          func.__name__)
            raise
          trace_str = traceback.format_exc()
          logging.warning('Retrying %s due to Exception: %s',
                          func.__name__, trace_str)
          time.sleep(_delay)
          _delay *= backoff  # Wait longer the next time we fail.
    return wrapper
  return decorator
+
+
class PromiseCallback(object):
  """Executes the work of a Promise and then dereferences everything.

  Dropping the references after the callback runs lets the garbage
  collector reclaim the promise, callback, and arguments as soon as the
  work is done, even though the worker thread object may live longer.
  """

  def __init__(self, promise, callback, *args, **kwargs):
    self.promise = promise
    self.callback = callback
    self.args = args
    self.kwargs = kwargs

  def __call__(self):
    try:
      self.promise._WorkOnPromise(self.callback, *self.args, **self.kwargs)
    finally:
      # Release every reference, even if _WorkOnPromise raised.
      self.promise = None
      self.callback = None
      self.args = None
      self.kwargs = None
+
+
class Promise(object):
  """Class for promises to deliver a value in the future.

  A thread is started to run callback(args); that thread
  should return the value that it generates, or raise an exception.
  p.WaitAndGetValue() will block until a value is available.
  If an exception was raised, p.WaitAndGetValue() will re-raise the
  same exception.
  """

  def __init__(self, callback, *args, **kwargs):
    """Initialize the promise and immediately call the supplied function.

    Args:
      callback: Function that takes the args and returns the promise value.
      *args: Any arguments to the target function.
      **kwargs: Any keyword args for the target function.
    """

    self.has_value = False
    self.value = None
    # Signaled once the worker thread has stored a value or an exception.
    self.event = threading.Event()
    self.exception = None

    # PromiseCallback drops its references after running so the thread does
    # not keep the callback and its arguments alive longer than needed.
    promise_callback = PromiseCallback(self, callback, *args, **kwargs)

    # Execute the callback in another thread.
    promise_thread = threading.Thread(target=promise_callback)
    promise_thread.start()

  def _WorkOnPromise(self, callback, *args, **kwargs):
    """Run callback to compute the promised value.  Save any exceptions."""
    try:
      self.value = callback(*args, **kwargs)
    except Exception as e:
      trace_str = traceback.format_exc()
      logging.info('Exception while working on promise: %s\n', trace_str)
      # Add the stack trace at this point to the exception.  That way, in the
      # logs, we can see what happened further up in the call stack
      # than WaitAndGetValue(), which re-raises exceptions.
      e.pre_promise_trace = trace_str
      self.exception = e
    finally:
      # has_value is set before event.set() so a woken waiter always
      # observes a fully-populated result.
      self.has_value = True
      self.event.set()

  def WaitAndGetValue(self):
    """Block until my value is available, then return it or raise exception."""
    self.event.wait()
    if self.exception:
      raise self.exception  # pylint: disable=raising-bad-type
    return self.value
+
+
def FormatAbsoluteURLForDomain(
    host, project_name, servlet_name, scheme='https', **kwargs):
  """A variant of FormatAbsoluteURL for when request objects are not available.

  Args:
    host: string with hostname and optional port, e.g. 'localhost:8080'.
    project_name: the destination project name, if any.
    servlet_name: site or project-local url fragement of dest page.
    scheme: url scheme, e.g., 'http' or 'https'.
    **kwargs: additional query string parameters may be specified as named
        arguments to this function.

  Returns:
    A full url beginning with 'http[s]://'.
  """
  path_and_args = FormatURL(None, servlet_name, **kwargs)

  if host:
    # Swap in the preferred domain name, preserving any ':port' suffix.
    domain_and_port = host.split(':')
    domain_and_port[0] = GetPreferredDomain(domain_and_port[0])
    host = ':'.join(domain_and_port)

  absolute_domain_url = '%s://%s' % (scheme, host)
  if project_name:
    return '%s/p/%s%s' % (absolute_domain_url, project_name, path_and_args)
  return absolute_domain_url + path_and_args
+
+
def FormatAbsoluteURL(
    mr, servlet_name, include_project=True, project_name=None,
    scheme=None, copy_params=True, **kwargs):
  """Return an absolute URL to a servlet with old and new params.

  Args:
    mr: info parsed from the current request.
    servlet_name: site or project-local url fragement of dest page.
    include_project: if True, include the project home url as part of the
        destination URL (as long as it is specified either in mr
        or as the project_name param.)
    project_name: the destination project name, to override
        mr.project_name if include_project is True.
    scheme: either 'http' or 'https', to override mr.request.scheme.
    copy_params: if True, copy well-known parameters from the existing
        request.
    **kwargs: additional query string parameters may be specified as named
        arguments to this function.

  Returns:
    A full url beginning with 'http[s]://'.  The destination URL will be in
    the same domain as the current request.
  """
  if copy_params:
    carried_params = [(name, mr.GetParam(name)) for name in RECOGNIZED_PARAMS]
  else:
    carried_params = None
  path_and_args = FormatURL(carried_params, servlet_name, **kwargs)

  project_base = ''
  if include_project:
    project_base = '/p/%s' % (project_name or mr.project_name)

  return '%s://%s%s%s' % (
      scheme or mr.request.scheme, mr.request.host, project_base,
      path_and_args)
+
+
def FormatMovedProjectURL(mr, moved_to):
  """Return a transformation of the given url into the given project.

  Args:
    mr: common information parsed from the HTTP request.
    moved_to: A string from a project's moved_to field that matches
        project_constants.RE_PROJECT_NAME.

  Returns:
    The url transposed into the given destination project.
  """
  parsed = urlparse.urlparse(mr.current_page_url)
  # Strip off the leading "/p/<moved from project>" path prefix.
  remainder_path = '/' + parsed.path.split('/', 3)[3]
  rest_of_url = urlparse.urlunparse(
      ('', '', remainder_path, parsed.params, parsed.query, parsed.fragment))
  return '/p/%s%s' % (moved_to, rest_of_url)
+
+
def GetNeededDomain(project_name, current_domain):
  """Return the branded domain for the project iff not on current_domain."""
  # No redirect for appspot domains or host:port (e.g. dev server) hosts.
  if (not current_domain
      or '.appspot.com' in current_domain
      or ':' in current_domain):
    return None
  desired_domain = settings.branded_domains.get(
      project_name, settings.branded_domains.get('*'))
  if desired_domain == current_domain:
    return None
  return desired_domain
+
+
def FormatURL(recognized_params, url, **kwargs):
  # type: (Sequence[Tuple(str, str)], str, **Any) -> str
  """Return a project relative URL to a servlet with old and new params.

  Args:
    recognized_params: Default query parameters to include.
    url: Base URL. Could be a relative path for an EZT Servlet or an
        absolute path for a separate service (ie: besearch).
    **kwargs: Additional query parameters to add.

  Returns:
    A URL with the specified query parameters.
  """
  # Standard params not overridden in **kwargs come first, followed by
  # kwargs.  The exception is the 'id' param: if present it always comes
  # first.  See bugs.chromium.org/p/monorail/issues/detail?id=374
  all_params = []
  if kwargs.get('id'):
    all_params.append(('id', kwargs['id']))
  # TODO(jojwang): update all calls to FormatURL to only include non-None
  # recognized_params
  if recognized_params:
    for param in recognized_params:
      if param[0] not in kwargs:
        all_params.append(param)

  # Ignore the 'id' kwarg since we already added it above.
  remaining_kwargs = [item for item in kwargs.items() if item[0] != 'id']
  all_params.extend(sorted(remaining_kwargs))
  return _FormatQueryString(url, all_params)
+
+
def _FormatQueryString(url, params):
  # type: (str, Sequence[Tuple(str, str)]) -> str
  """URLencode a list of parameters and attach them to the end of a URL.

  Args:
    url: URL to append the querystring to.
    params: List of query parameters to append; pairs whose value is None
        are skipped.

  Returns:
    A URL with the specified query parameters.
  """
  encoded_pairs = [
      '%s=%s' % (name, urllib.quote(six.text_type(value).encode('utf-8')))
      for name, value in params if value is not None]
  param_string = '&'.join(encoded_pairs)

  if not param_string:
    separator = ''
  elif '?' in url:
    # The URL already has a query string; append with '&'.
    separator = '&'
  else:
    separator = '?'
  return url + separator + param_string
+
+
def WordWrapSuperLongLines(s, max_cols=100):
  """Reformat input that was not word-wrapped by the browser.

  Rather than wrap the whole thing, we only wrap super-long lines and
  keep all the reasonable lines formatted as-is.

  Args:
    s: the string to be word-wrapped, it may have embedded newlines.
    max_cols: int maximum line length.

  Returns:
    Wrapped text string.
  """
  wrapped_text = '\n'.join(
      textwrap.fill(line, max_cols) for line in s.splitlines())

  # The splitlines/join logic above can lose one final blank line.
  if s.endswith(('\n', '\r')):
    wrapped_text += '\n'

  return wrapped_text
+
+
def StaticCacheHeaders():
  """Returns HTTP headers for static content, based on the current time."""
  # Static content is safe to cache for a full year.
  expires_ts = int(time.time()) + framework_constants.SECS_PER_YEAR
  headers = [
      ('Cache-Control',
       'max-age=%d, private' % framework_constants.SECS_PER_YEAR),
      ('Last-Modified', timestr.TimeForHTMLHeader()),
      ('Expires', timestr.TimeForHTMLHeader(when=expires_ts)),
  ]
  logging.info('static headers are %r', headers)
  return headers
+
+
def ComputeListDeltas(old_list, new_list):
  """Given an old and new list, return the items added and removed.

  Args:
    old_list: old list of values for comparison.
    new_list: new list of values for comparison.

  Returns:
    Two lists: one with all the values added (in new_list but was not
    in old_list), and one with all the values removed (not in new_list
    but was in old_list).
  """
  if old_list == new_list:
    return [], []  # A common case: nothing was added or removed.

  old_set = set(old_list)
  new_set = set(new_list)
  return list(new_set - old_set), list(old_set - new_set)
+
+
def GetRoleName(effective_ids, project):
  """Determines the name of the role a member has for a given project.

  Args:
    effective_ids: set of user IDs to get the role name for.
    project: Project PB containing the different member lists.

  Returns:
    The name of the highest role held, or None for non-members.
  """
  # Ordered from most to least privileged; the first match wins.
  role_checks = (
      ('Owner', project.owner_ids),
      ('Committer', project.committer_ids),
      ('Contributor', project.contributor_ids))
  for role_name, member_ids in role_checks:
    if not effective_ids.isdisjoint(member_ids):
      return role_name
  return None
+
+
def GetHotlistRoleName(effective_ids, hotlist):
  """Determines the name of the role a member has for a given hotlist."""
  # Ordered from most to least privileged; the first match wins.
  role_checks = (
      ('Owner', hotlist.owner_ids),
      ('Editor', hotlist.editor_ids),
      ('Follower', hotlist.follower_ids))
  for role_name, member_ids in role_checks:
    if not effective_ids.isdisjoint(member_ids):
      return role_name
  return None
+
+
class UserSettings(object):
  """Abstract class providing static methods for user settings forms."""

  @classmethod
  def GatherUnifiedSettingsPageData(
      cls, logged_in_user_id, settings_user_view, settings_user,
      settings_user_prefs):
    """Gather EZT variables needed for the unified user settings form.

    Args:
      logged_in_user_id: The user ID of the acting user.
      settings_user_view: The UserView of the target user.
      settings_user: The User PB of the target user.
      settings_user_prefs: UserPrefs object for the view user.

    Returns:
      A dictionary giving the names and values of all the variables to
      be exported to EZT to support the unified user settings form template.
    """

    # Start with every known pref name set to None, then overlay stored
    # prefs.  'false' maps to None rather than False because EZT tests
    # plain truthiness.
    settings_user_prefs_view = template_helpers.EZTItem(
        **{name: None for name in framework_bizobj.USER_PREF_DEFS})
    if settings_user_prefs:
      for upv in settings_user_prefs.prefs:
        if upv.value == 'true':
          setattr(settings_user_prefs_view, upv.name, True)
        elif upv.value == 'false':
          setattr(settings_user_prefs_view, upv.name, None)

    # Lazy %-args: the message is only formatted if INFO logging is enabled
    # (the previous code eagerly formatted with the % operator).
    logging.info('settings_user_prefs_view is %r', settings_user_prefs_view)
    return {
        'settings_user': settings_user_view,
        'settings_user_pb': template_helpers.PBProxy(settings_user),
        'settings_user_is_banned': ezt.boolean(settings_user.banned),
        'self': ezt.boolean(logged_in_user_id == settings_user_view.user_id),
        'profile_url_fragment': (
            settings_user_view.profile_url[len('/u/'):]),
        'preview_on_hover': ezt.boolean(settings_user.preview_on_hover),
        'settings_user_prefs': settings_user_prefs_view,
        }

  @classmethod
  def ProcessBanForm(
      cls, cnxn, user_service, post_data, user_id, user):
    """Process the posted form data from the ban user form.

    Args:
      cnxn: connection to the SQL database.
      user_service: An instance of UserService for saving changes.
      post_data: The parsed post data from the form submission request.
      user_id: The user id of the target user.
      user: The user PB of the target user.
    """
    user_service.UpdateUserBan(
        cnxn, user_id, user, is_banned='banned' in post_data,
        banned_reason=post_data.get('banned_reason', ''))

  @classmethod
  def ProcessSettingsForm(
      cls, we, post_data, user, admin=False):
    """Process the posted form data from the unified user settings form.

    Args:
      we: A WorkEnvironment with cnxn and services.
      post_data: The parsed post data from the form submission request.
      user: The user PB of the target user.
      admin: Whether settings reserved for admins are supported.
    """
    obscure_email = 'obscure_email' in post_data

    # Admin-only fields are passed through only when the caller says the
    # acting user is an admin; otherwise they are silently ignored.
    kwargs = {}
    if admin:
      kwargs.update(is_site_admin='site_admin' in post_data)
      kwargs.update(is_banned='banned' in post_data,
                    banned_reason=post_data.get('banned_reason', ''))

    we.UpdateUserSettings(
        user, notify='notify' in post_data,
        notify_starred='notify_starred' in post_data,
        email_compact_subject='email_compact_subject' in post_data,
        email_view_widget='email_view_widget' in post_data,
        notify_starred_ping='notify_starred_ping' in post_data,
        preview_on_hover='preview_on_hover' in post_data,
        obscure_email=obscure_email,
        vacation_message=post_data.get('vacation_message', ''),
        **kwargs)

    # Checkbox prefs are stored as explicit 'true'/'false' strings.
    user_prefs = []
    for pref_name in ['restrict_new_issues', 'public_issue_notice']:
      user_prefs.append(user_pb2.UserPrefValue(
          name=pref_name,
          value=('true' if pref_name in post_data else 'false')))
    we.SetUserPrefs(user.user_id, user_prefs)
+
+
def GetHostPort(project_name=None):
  """Get string domain name and port number."""

  app_id = app_identity.get_application_id()
  if ':' in app_id:
    domain, app_id = app_id.split(':')
  else:
    domain = ''

  # Google-internal partitions are served from googleplex.com.
  suffix = 'googleplex.com' if domain.startswith('google') else 'appspot.com'
  hostport = '%s.%s' % (app_id, suffix)

  live_site_domain = GetPreferredDomain(hostport)
  if project_name:
    # Branded projects may need to be served from their own domain.
    project_needed_domain = GetNeededDomain(project_name, live_site_domain)
    if project_needed_domain:
      return project_needed_domain

  return live_site_domain
+
+
def IssueCommentURL(
    hostport, project, local_id, seq_num=None):
  """Return a URL pointing directly to the specified comment."""
  detail_url = FormatAbsoluteURLForDomain(
      hostport, project.project_name, urls.ISSUE_DETAIL, id=local_id)
  if seq_num:
    # Deep-link to the comment via its fragment anchor.
    return '%s#c%d' % (detail_url, seq_num)
  return detail_url
+
+
def MurmurHash3_x86_32(key, seed=0x0):
  """Implements the x86/32-bit version of Murmur Hash 3.0.

  MurmurHash3 is written by Austin Appleby, and is placed in the public
  domain.  See https://code.google.com/p/smhasher/ for details.

  This pure python implementation of the x86/32 bit version of MurmurHash3
  is written by Fredrik Kihlander and also placed in the public domain.
  See https://github.com/wc-duck/pymmh3 for details.

  The MurmurHash3 algorithm is chosen for these reasons:
  * It is fast, even when implemented in pure python.
  * It is remarkably well distributed, and unlikely to cause collisions.
  * It is stable and unchanging (any improvements will be in MurmurHash4).
  * It is well-tested, and easily usable in other contexts (such as bulk
    data imports).

  Args:
    key (string): the data that you want hashed
    seed (int): An offset, treated as essentially part of the key.

  Returns:
    A 32-bit integer (can be interpreted as either signed or unsigned).
  """
  data = bytearray(key.encode('utf-8'))
  length = len(data)

  c1 = 0xcc9e2d51
  c2 = 0x1b873593

  def _mix_chunk(k):
    """Scramble one 32-bit chunk before folding it into the hash."""
    k = (k * c1) & 0xFFFFFFFF
    k = (k << 15 | k >> 17) & 0xFFFFFFFF  # 32-bit rotate left by 15.
    return (k * c2) & 0xFFFFFFFF

  h = seed
  nblocks = length // 4

  # Body: consume each full 32-bit little-endian block.
  for start in range(0, nblocks * 4, 4):
    k = (data[start + 3] << 24 | data[start + 2] << 16 |
         data[start + 1] << 8 | data[start])
    h ^= _mix_chunk(k)
    h = (h << 13 | h >> 19) & 0xFFFFFFFF  # 32-bit rotate left by 13.
    h = (h * 5 + 0xe6546b64) & 0xFFFFFFFF

  # Tail: fold in the remaining 1-3 bytes, if any.
  tail_index = nblocks * 4
  tail_size = length & 3
  k = 0
  if tail_size >= 3:
    k ^= data[tail_index + 2] << 16
  if tail_size >= 2:
    k ^= data[tail_index + 1] << 8
  if tail_size >= 1:
    k ^= data[tail_index]
  if tail_size:
    h ^= _mix_chunk(k)

  # Finalization mix ("fmix"): force the final bits to avalanche.
  h ^= length
  h ^= h >> 16
  h = (h * 0x85ebca6b) & 0xFFFFFFFF
  h ^= h >> 13
  h = (h * 0xc2b2ae35) & 0xFFFFFFFF
  h ^= h >> 16
  return h
+
+
# A single SystemRandom instance: backed by os.urandom, so it is suitable
# for security-sensitive key generation, unlike the default module-level
# Mersenne Twister generator.
_SYSTEM_RANDOM = random.SystemRandom()


def MakeRandomKey(length=RANDOM_KEY_LENGTH, chars=RANDOM_KEY_CHARACTERS):
  """Return a string with lots of random characters.

  Args:
    length: number of characters in the generated key.
    chars: alphabet to draw characters from.

  Returns:
    A random string of `length` characters chosen from `chars`.
  """
  return ''.join(_SYSTEM_RANDOM.choice(chars) for _ in range(length))
+
+
def IsServiceAccount(email, client_emails=None):
  """Return a boolean value whether this email is a service account."""
  # Google service accounts always use this domain suffix.
  if email.endswith('gserviceaccount.com'):
    return True
  if client_emails is None:
    # Lazily fetch the allowlisted client emails from config.
    _client_ids, client_emails = (
        client_config_svc.GetClientConfigSvc().GetClientIDEmails())
  return email in client_emails
+
+
def GetPreferredDomain(domain):
  """Get preferred domain to display.

  The preferred domain replaces app_id for default version of monorail-prod
  and monorail-staging.
  """
  # Fall back to the given domain itself when no preference is configured.
  return settings.preferred_domains.get(domain, domain)
+
+
def GetUserAvailability(user, is_group=False):
  """Return (str, str) that explains why the user might not be available.

  The second element of the returned tuple is a short keyword describing
  the availability state; both elements are None for available users.
  """
  if not user.user_id:
    return None, None
  if user.banned:
    return 'Banned', 'banned'
  if user.vacation_message:
    return user.vacation_message, 'none'
  if user.email_bounce_timestamp:
    return 'Email to this user bounced', 'none'
  # No availability shown for user groups, or addresses that are
  # likely to be mailing lists.
  if is_group or (user.email and '-' in user.email):
    return None, None
  if not user.last_visit_timestamp:
    return 'User never visited', 'never'

  secs_ago = int(time.time()) - user.last_visit_timestamp
  if secs_ago > 30 * framework_constants.SECS_PER_DAY:
    return 'Last visit > 30 days ago', 'none'
  if secs_ago > 15 * framework_constants.SECS_PER_DAY:
    # Only format the relative date when it is actually displayed.
    last_visit_str = timestr.FormatRelativeDate(
        user.last_visit_timestamp, days_only=True)
    return ('Last visit %s' % last_visit_str), 'unsure'
  return None, None
diff --git a/framework/framework_views.py b/framework/framework_views.py
new file mode 100644
index 0000000..17dead8
--- /dev/null
+++ b/framework/framework_views.py
@@ -0,0 +1,215 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""View classes to make it easy to display framework objects in EZT."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import logging
+import time
+
+import ezt
+
+from framework import framework_bizobj
+from framework import framework_constants
+from framework import framework_helpers
+from framework import permissions
+from framework import template_helpers
+from framework import timestr
+from proto import user_pb2
+from services import client_config_svc
+import settings
+
+
+_LABEL_DISPLAY_CHARS = 30
+_LABEL_PART_DISPLAY_CHARS = 15
+
+
+class LabelView(object):
+ """Wrapper class that makes it easier to display a label via EZT."""
+
+ def __init__(self, label, config):
+ """Make several values related to this label available as attrs.
+
+ Args:
+ label: artifact label string. E.g., 'Priority-High' or 'Frontend'.
+ config: PB with a well_known_labels list, or None.
+ """
+ self.name = label
+ self.is_restrict = ezt.boolean(permissions.IsRestrictLabel(label))
+
+ self.docstring = ''
+ if config:
+ for wkl in config.well_known_labels:
+ if label.lower() == wkl.label.lower():
+ self.docstring = wkl.label_docstring
+
+ if '-' in label:
+ self.prefix, self.value = label.split('-', 1)
+ else:
+ self.prefix, self.value = '', label
+
+
+class StatusView(object):
+ """Wrapper class that makes it easier to display a status via EZT."""
+
+ def __init__(self, status, config):
+ """Make several values related to this status available as attrs.
+
+ Args:
+ status: artifact status string. E.g., 'New' or 'Accepted'.
+ config: PB with a well_known_statuses list, or None.
+ """
+
+ self.name = status
+
+ self.docstring = ''
+ self.means_open = ezt.boolean(True)
+ if config:
+ for wks in config.well_known_statuses:
+ if status.lower() == wks.status.lower():
+ self.docstring = wks.status_docstring
+ self.means_open = ezt.boolean(wks.means_open)
+
+
+class UserView(object):
+  """Wrapper class to easily display basic user information in a template."""
+
+  def __init__(self, user, is_group=False):
+    # type: (user_pb2.User, bool) -> None
+    self.user = user
+    self.is_group = is_group
+    email = user.email or ''
+    self.user_id = user.user_id
+    self.email = email
+    # Obscured users get a profile URL keyed by numeric ID so that the
+    # address does not leak through the link itself.
+    if user.obscure_email:
+      self.profile_url = '/u/%s/' % user.user_id
+    else:
+      self.profile_url = '/u/%s/' % email
+    self.obscure_email = user.obscure_email
+    self.banned = ''
+
+    (self.username, self.domain, self.obscured_username,
+     obscured_email) = framework_bizobj.ParseAndObscureAddress(email)
+    # No need to obfuscate or reveal client email.
+    # Instead display a human-readable username.
+    if self.user_id == framework_constants.DELETED_USER_ID:
+      self.display_name = framework_constants.DELETED_USER_NAME
+      self.obscure_email = ''
+      self.profile_url = ''
+    elif self.email in client_config_svc.GetServiceAccountMap():
+      self.display_name = client_config_svc.GetServiceAccountMap()[self.email]
+    elif not self.obscure_email:
+      self.display_name = email
+    else:
+      self.display_name = obscured_email
+
+    self.avail_message, self.avail_state = (
+        framework_helpers.GetUserAvailability(user, is_group))
+    # Shortened availability message for space-constrained UI elements.
+    self.avail_message_short = template_helpers.FitUnsafeText(
+        self.avail_message, 35)
+
+  def RevealEmail(self):
+    """Switch this view to the unobscured address, except service accounts."""
+    if not self.email:
+      return
+    if self.email not in client_config_svc.GetServiceAccountMap():
+      self.obscure_email = False
+      self.display_name = self.email
+      self.profile_url = '/u/%s/' % self.email
+
+
+def MakeAllUserViews(
+ cnxn, user_service, *list_of_user_id_lists, **kw):
+ """Make a dict {user_id: user_view, ...} for all user IDs given."""
+ distinct_user_ids = set()
+ distinct_user_ids.update(*list_of_user_id_lists)
+ if None in distinct_user_ids:
+ distinct_user_ids.remove(None)
+ group_ids = kw.get('group_ids', [])
+ user_dict = user_service.GetUsersByIDs(cnxn, distinct_user_ids)
+ return {user_id: UserView(user_pb, is_group=user_id in group_ids)
+ for user_id, user_pb in user_dict.items()}
+
+
+def MakeUserView(cnxn, user_service, user_id):
+ """Make a UserView for the given user ID."""
+ user = user_service.GetUser(cnxn, user_id)
+ return UserView(user)
+
+
+def StuffUserView(user_id, email, obscure_email):
+ """Construct a UserView with the given parameters for testing."""
+ user = user_pb2.MakeUser(user_id, email=email, obscure_email=obscure_email)
+ return UserView(user)
+
+
+# TODO(https://crbug.com/monorail/8192): Remove optional project.
+def RevealAllEmailsToMembers(cnxn, services, auth, users_by_id, project=None):
+  # type: (MonorailConnection, Services, AuthData, Collection[user_pb2.User],
+  #   Optional[project_pb2.Project] -> None)
+  """Reveal emails based on the authenticated user.
+
+  The actual behavior can be determined by looking into
+  framework_bizobj.ShouldRevealEmail. Look at https://crbug.com/monorail/8030
+  for context.
+  This method should be deleted when endpoints and ezt pages are deprecated.
+
+  Args:
+    cnxn: MonorailConnection to the database.
+    services: Services object for connections to backend services.
+    auth: AuthData object that identifies the logged in user.
+    users_by_id: dictionary of UserView's that might be displayed.
+    project: Optional Project PB for the current project.
+
+  Returns:
+    Nothing, but the UserViews in users_by_id may be modified to
+    publish email address.
+  """
+  # With a project in hand, use the deprecated per-project visibility
+  # check on each view individually.
+  if project:
+    for user_view in users_by_id.values():
+      if framework_bizobj.DeprecatedShouldRevealEmail(auth, project,
+          user_view.email):
+        user_view.RevealEmail()
+  # Otherwise, filter all candidate views in one bulk query.
+  else:
+    viewable_users = framework_bizobj.FilterViewableEmails(
+        cnxn, services, auth, users_by_id.values())
+    for user_view in viewable_users:
+      user_view.RevealEmail()
+
+
+def RevealAllEmails(users_by_id):
+ """Allow anyone to see unobscured email addresses of project members.
+
+ The modified view objects should only be used to generate views for other
+ project members.
+
+ Args:
+ users_by_id: dictionary of UserViews that will be displayed.
+
+ Returns:
+ Nothing, but the UserViews in users_by_id may be modified to
+ publish email address.
+ """
+ for user_view in users_by_id.values():
+ user_view.RevealEmail()
+
+
+def GetViewedUserDisplayName(mr):
+ """Get display name of the viewed user given the logged-in user."""
+ # Do not obscure email if current user is a site admin. Do not obscure
+ # email if current user is viewing their own profile. For all other
+ # cases do whatever obscure_email setting for the user is.
+ viewing_self = mr.auth.user_id == mr.viewed_user_auth.user_id
+ email_obscured = (not(mr.auth.user_pb.is_site_admin or viewing_self)
+ and mr.viewed_user_auth.user_view.obscure_email)
+ if email_obscured:
+ (_username, _domain, _obscured_username,
+ obscured_email) = framework_bizobj.ParseAndObscureAddress(
+ mr.viewed_user_auth.email)
+ viewed_user_display_name = obscured_email
+ else:
+ viewed_user_display_name = mr.viewed_user_auth.email
+
+ return viewed_user_display_name
diff --git a/framework/gcs_helpers.py b/framework/gcs_helpers.py
new file mode 100644
index 0000000..a01b565
--- /dev/null
+++ b/framework/gcs_helpers.py
@@ -0,0 +1,207 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Set of helpers for interacting with Google Cloud Storage."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import base64
+import logging
+import os
+import time
+import urllib
+import uuid
+
+from datetime import datetime, timedelta
+
+from google.appengine.api import app_identity
+from google.appengine.api import images
+from google.appengine.api import memcache
+from google.appengine.api import urlfetch
+from third_party import cloudstorage
+from third_party.cloudstorage import errors
+
+from framework import filecontent
+from framework import framework_constants
+from framework import framework_helpers
+
+
+ATTACHMENT_TTL = timedelta(seconds=30)
+
+IS_DEV_APPSERVER = (
+ 'development' in os.environ.get('SERVER_SOFTWARE', '').lower())
+
+RESIZABLE_MIME_TYPES = [
+ 'image/png', 'image/jpg', 'image/jpeg', 'image/gif', 'image/webp',
+ ]
+
+DEFAULT_THUMB_WIDTH = 250
+DEFAULT_THUMB_HEIGHT = 200
+LOGO_THUMB_WIDTH = 110
+LOGO_THUMB_HEIGHT = 30
+MAX_ATTACH_SIZE_TO_COPY = 10 * 1024 * 1024 # 10 MB
+# GCS signatures are valid for 10 minutes by default, but cache them for
+# 5 minutes just to be on the safe side.
+GCS_SIG_TTL = 60 * 5
+
+
+def _Now():
+  """Return the current time in UTC as a naive datetime."""
+  return datetime.utcnow()
+
+
+class UnsupportedMimeType(Exception):
+  """Raised when an upload is not one of the resizable image MIME types."""
+  pass
+
+
+def DeleteObjectFromGCS(object_id):
+ object_path = ('/' + app_identity.get_default_gcs_bucket_name() + object_id)
+ cloudstorage.delete(object_path)
+
+
+def StoreObjectInGCS(
+ content, mime_type, project_id, thumb_width=DEFAULT_THUMB_WIDTH,
+ thumb_height=DEFAULT_THUMB_HEIGHT, filename=None):
+ bucket_name = app_identity.get_default_gcs_bucket_name()
+ guid = uuid.uuid4()
+ object_id = '/%s/attachments/%s' % (project_id, guid)
+ object_path = '/' + bucket_name + object_id
+ options = {}
+ if filename:
+ if not framework_constants.FILENAME_RE.match(filename):
+ logging.info('bad file name: %s' % filename)
+ filename = 'attachment.dat'
+ options['Content-Disposition'] = 'inline; filename="%s"' % filename
+ logging.info('Writing with options %r', options)
+ with cloudstorage.open(object_path, 'w', mime_type, options=options) as f:
+ f.write(content)
+
+ if mime_type in RESIZABLE_MIME_TYPES:
+ # Create and save a thumbnail too.
+ thumb_content = None
+ try:
+ thumb_content = images.resize(content, thumb_width, thumb_height)
+ except images.LargeImageError:
+ # Don't log the whole exception because we don't need to see
+ # this on the Cloud Error Reporting page.
+ logging.info('Got LargeImageError on image with %d bytes', len(content))
+ except Exception, e:
+ # Do not raise exception for incorrectly formed images.
+ # See https://bugs.chromium.org/p/monorail/issues/detail?id=597 for more
+ # detail.
+ logging.exception(e)
+ if thumb_content:
+ thumb_path = '%s-thumbnail' % object_path
+ with cloudstorage.open(thumb_path, 'w', 'image/png') as f:
+ f.write(thumb_content)
+
+ return object_id
+
+
+def CheckMimeTypeResizable(mime_type):
+ if mime_type not in RESIZABLE_MIME_TYPES:
+ raise UnsupportedMimeType(
+ 'Please upload a logo with one of the following mime types:\n%s' %
+ ', '.join(RESIZABLE_MIME_TYPES))
+
+
+def StoreLogoInGCS(file_name, content, project_id):
+ mime_type = filecontent.GuessContentTypeFromFilename(file_name)
+ CheckMimeTypeResizable(mime_type)
+ if '\\' in file_name: # IE insists on giving us the whole path.
+ file_name = file_name[file_name.rindex('\\') + 1:]
+ return StoreObjectInGCS(
+ content, mime_type, project_id, thumb_width=LOGO_THUMB_WIDTH,
+ thumb_height=LOGO_THUMB_HEIGHT)
+
+
+@framework_helpers.retry(3, delay=0.25, backoff=1.25)
+def _FetchSignedURL(url):
+ """Request that devstorage API signs a GCS content URL."""
+ resp = urlfetch.fetch(url, follow_redirects=False)
+ redir = resp.headers["Location"]
+ return redir
+
+
+def SignUrl(bucket, object_id):
+  """Get a signed URL to download a GCS object.
+
+  Args:
+    bucket: string name of the GCS bucket.
+    object_id: string object ID of the file within that bucket.
+
+  Returns:
+    A signed URL, or '/missing-gcs-url' if signing failed.
+  """
+  try:
+    # Signed URLs are cached for GCS_SIG_TTL seconds because signing
+    # requires an extra round trip to the devstorage API.
+    cache_key = 'gcs-object-url-%s' % object_id
+    cached = memcache.get(key=cache_key)
+    if cached is not None:
+      return cached
+
+    if IS_DEV_APPSERVER:
+      # The local dev appserver serves GCS objects without signatures.
+      attachment_url = '/_ah/gcs/%s%s' % (bucket, object_id)
+    else:
+      result = ('https://www.googleapis.com/storage/v1/b/'
+                '{bucket}/o/{object_id}?access_token={token}&alt=media')
+      scopes = ['https://www.googleapis.com/auth/devstorage.read_only']
+      # The REST API expects the object name without a leading slash.
+      if object_id[0] == '/':
+        object_id = object_id[1:]
+      url = result.format(
+          bucket=bucket,
+          object_id=urllib.quote_plus(object_id),
+          token=app_identity.get_access_token(scopes)[0])
+      attachment_url = _FetchSignedURL(url)
+
+    if not memcache.set(key=cache_key, value=attachment_url, time=GCS_SIG_TTL):
+      logging.error('Could not cache gcs url %s for %s', attachment_url,
+                    object_id)
+
+    return attachment_url
+
+  except Exception as e:
+    logging.exception(e)
+    return '/missing-gcs-url'
+
+
+def MaybeCreateDownload(bucket_name, object_id, filename):
+  """If the obj is not huge, and no download version exists, create it.
+
+  Args:
+    bucket_name: string name of the GCS bucket.
+    object_id: string object ID of the attachment within that bucket.
+    filename: string filename to use in the Content-Disposition header.
+
+  Returns:
+    True if a download version exists (or was just created); False if it
+    cannot be created (dev server, or the source object is too big).
+  """
+  src = '/%s%s' % (bucket_name, object_id)
+  dst = '/%s%s-download' % (bucket_name, object_id)
+  cloudstorage.validate_file_path(src)
+  cloudstorage.validate_file_path(dst)
+  logging.info('Maybe create %r from %r', dst, src)
+
+  if IS_DEV_APPSERVER:
+    logging.info('dev environment never makes download copies.')
+    return False
+
+  # If "Download" object already exists, we are done.
+  try:
+    cloudstorage.stat(dst)
+    logging.info('Download version of attachment already exists')
+    return True
+  except errors.NotFoundError:
+    pass
+
+  # If "View" object is huge, give up.
+  src_stat = cloudstorage.stat(src)
+  if src_stat.st_size > MAX_ATTACH_SIZE_TO_COPY:
+    logging.info('Download version of attachment would be too big')
+    return False
+
+  # Copy the object, changing only the Content-Disposition so browsers
+  # save it as an attachment instead of displaying it inline.
+  with cloudstorage.open(src, 'r') as infile:
+    content = infile.read()
+  logging.info('opened GCS object and read %r bytes', len(content))
+  content_type = src_stat.content_type
+  options = {
+      'Content-Disposition': 'attachment; filename="%s"' % filename,
+  }
+  logging.info('Writing with options %r', options)
+  with cloudstorage.open(dst, 'w', content_type, options=options) as outfile:
+    outfile.write(content)
+  logging.info('done writing')
+
+  return True
diff --git a/framework/grid_view_helpers.py b/framework/grid_view_helpers.py
new file mode 100644
index 0000000..44af6b7
--- /dev/null
+++ b/framework/grid_view_helpers.py
@@ -0,0 +1,491 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Classes and functions for displaying grids of project artifacts.
+
+A grid is a two-dimensional display of items where the user can choose
+the X and Y axes.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import ezt
+
+import collections
+import logging
+import settings
+
+from features import features_constants
+from framework import framework_constants
+from framework import sorting
+from framework import table_view_helpers
+from framework import template_helpers
+from framework import urls
+from proto import tracker_pb2
+from tracker import tracker_bizobj
+from tracker import tracker_constants
+from tracker import tracker_helpers
+
+
+# We shorten long attribute values to fit into the table cells.
+_MAX_CELL_DISPLAY_CHARS = 70
+
+
+def SortGridHeadings(col_name, heading_value_list, users_by_id, config,
+                     asc_accessors):
+  """Sort the grid headings according to well-known status and label order.
+
+  Args:
+    col_name: String column name that is used on that grid axis.
+    heading_value_list: List of grid row or column heading values.
+    users_by_id: Dict mapping user_ids to UserViews.
+    config: ProjectIssueConfig PB for the current project.
+    asc_accessors: Dict (col_name -> function()) for special columns.
+
+  Returns:
+    The same heading values, but sorted in a logical order.
+  """
+  # Decorate each heading with a sort key, sort, then undecorate.
+  decorated_list = []
+  fd = tracker_bizobj.FindFieldDef(col_name, config)
+  if fd and fd.field_type != tracker_pb2.FieldTypes.ENUM_TYPE:  # Handle fields.
+    for value in heading_value_list:
+      field_value = tracker_bizobj.GetFieldValueWithRawValue(
+          fd.field_type, None, users_by_id, value)
+      decorated_list.append([field_value, field_value])
+  elif col_name == 'status':
+    # Statuses sort in the order they appear in the project config.
+    wk_statuses = [wks.status.lower()
+                   for wks in config.well_known_statuses]
+    decorated_list = [(_WKSortingValue(value.lower(), wk_statuses), value)
+                      for value in heading_value_list]
+
+  elif col_name in asc_accessors:  # Special cols still sort alphabetically.
+    decorated_list = [(value, value)
+                      for value in heading_value_list]
+
+  else:  # Anything else is assumed to be a label prefix
+    # Label values sort in the order their well-known labels are configured.
+    col_name_dash = col_name + '-'
+    wk_labels = []
+    for wkl in config.well_known_labels:
+      lab_lower = wkl.label.lower()
+      if lab_lower.startswith(col_name_dash):
+        wk_labels.append(lab_lower.split('-', 1)[-1])
+    decorated_list = [(_WKSortingValue(value.lower(), wk_labels), value)
+                      for value in heading_value_list]
+
+  decorated_list.sort()
+  result = [decorated_tuple[1] for decorated_tuple in decorated_list]
+  logging.info('Headers for %s are: %r', col_name, result)
+  return result
+
+
+def _WKSortingValue(value, well_known_list):
+ """Return a value used to sort headings so that well-known ones are first."""
+ if not value:
+ return sorting.MAX_STRING # Undefined values sort last.
+ try:
+ # well-known values sort by index
+ return well_known_list.index(value)
+ except ValueError:
+ return value # odd-ball values lexicographically after all well-known ones
+
+
+def MakeGridData(
+    artifacts, x_attr, x_headings, y_attr, y_headings, users_by_id,
+    artifact_view_factory, all_label_values, config, related_issues,
+    hotlist_context_dict=None):
+  """Return a list of grid row items for display by EZT.
+
+  Args:
+    artifacts: a list of issues to consider showing.
+    x_attr: lowercase name of the attribute that defines the x-axis.
+    x_headings: list of values for column headings.
+    y_attr: lowercase name of the attribute that defines the y-axis.
+    y_headings: list of values for row headings.
+    users_by_id: dict {user_id: user_view, ...} for referenced users.
+    artifact_view_factory: constructor for grid tiles.
+    all_label_values: pre-parsed dictionary of values from the key-value
+        labels on each issue: {issue_id: {key: [val,...], ...}, ...}
+    config: ProjectIssueConfig PB for the current project.
+    related_issues: dict {issue_id: issue} of pre-fetched related issues.
+    hotlist_context_dict: dict{issue_id: {hotlist_item_field: field_value, ..}}
+
+  Returns:
+    A list of EZTItems, each representing one grid row, and each having
+    a nested list of grid cells.
+
+    Each grid row has a row name, and a list of cells. Each cell has a
+    list of tiles. Each tile represents one artifact. Artifacts are
+    represented once in each cell that they match, so one artifact that
+    has multiple values for a certain attribute can occur in multiple cells.
+  """
+  x_attr = x_attr.lower()
+  y_attr = y_attr.lower()
+
+  # A flat dictionary {(x, y): [cell, ...], ...} for the whole grid.
+  x_y_data = collections.defaultdict(list)
+
+  # Put each issue into the grid cell(s) where it belongs.
+  for art in artifacts:
+    if hotlist_context_dict:
+      hotlist_issues_context = hotlist_context_dict[art.issue_id]
+    else:
+      hotlist_issues_context = None
+    label_value_dict = all_label_values[art.local_id]
+    x_vals = GetArtifactAttr(
+        art, x_attr, users_by_id, label_value_dict, config, related_issues,
+        hotlist_issue_context=hotlist_issues_context)
+    y_vals = GetArtifactAttr(
+        art, y_attr, users_by_id, label_value_dict, config, related_issues,
+        hotlist_issue_context=hotlist_issues_context)
+    tile = artifact_view_factory(art)
+
+    # Put the current issue into each cell where it belongs, which will usually
+    # be exactly 1 cell, but it could be a few.
+    if x_attr != '--' and y_attr != '--':  # User specified both axes.
+      for x in x_vals:
+        for y in y_vals:
+          x_y_data[x, y].append(tile)
+    elif y_attr != '--':  # User only specified Y axis.
+      for y in y_vals:
+        x_y_data['All', y].append(tile)
+    elif x_attr != '--':  # User only specified X axis.
+      for x in x_vals:
+        x_y_data[x, 'All'].append(tile)
+    else:  # User specified neither axis.
+      x_y_data['All', 'All'].append(tile)
+
+  # Convert the dictionary to a list-of-lists so that EZT can iterate over it.
+  grid_data = []
+  i = 0
+  for y in y_headings:
+    cells_in_row = []
+    for x in x_headings:
+      tiles = x_y_data[x, y]
+      # data_idx is a running index over every tile in the whole grid.
+      for tile in tiles:
+        tile.data_idx = i
+        i += 1
+
+      drill_down = ''
+      if x_attr != '--':
+        drill_down = MakeDrillDownSearch(x_attr, x)
+      if y_attr != '--':
+        drill_down += MakeDrillDownSearch(y_attr, y)
+
+      cells_in_row.append(template_helpers.EZTItem(
+          tiles=tiles, count=len(tiles), drill_down=drill_down))
+    grid_data.append(template_helpers.EZTItem(
+        grid_y_heading=y, cells_in_row=cells_in_row))
+
+  return grid_data
+
+
+def MakeDrillDownSearch(attr, value):
+ """Constructs search term for drill-down.
+
+ Args:
+ attr: lowercase name of the attribute to narrow the search on.
+ value: value to narrow the search to.
+
+ Returns:
+ String with user-query term to narrow a search to the given attr value.
+ """
+ if value == framework_constants.NO_VALUES:
+ return '-has:%s ' % attr
+ else:
+ return '%s=%s ' % (attr, value)
+
+
+def MakeLabelValuesDict(art):
+ """Return a dict of label values and a list of one-word labels.
+
+ Args:
+ art: artifact object, e.g., an issue PB.
+
+ Returns:
+ A dict {prefix: [suffix,...], ...} for each key-value label.
+ """
+ label_values = collections.defaultdict(list)
+ for label_name in tracker_bizobj.GetLabels(art):
+ if '-' in label_name:
+ key, value = label_name.split('-', 1)
+ label_values[key.lower()].append(value)
+
+ return label_values
+
+
+def GetArtifactAttr(
+    art, attribute_name, users_by_id, label_attr_values_dict,
+    config, related_issues, hotlist_issue_context=None):
+  """Return the requested attribute values of the given artifact.
+
+  Args:
+    art: a tracked artifact with labels, local_id, summary, stars, and owner.
+    attribute_name: lowercase string name of attribute to get.
+    users_by_id: dictionary of UserViews already created.
+    label_attr_values_dict: dictionary {'key': [value, ...], }.
+    config: ProjectIssueConfig PB for the current project.
+    related_issues: dict {issue_id: issue} of pre-fetched related issues.
+    hotlist_issue_context: dict of {hotlist_issue_field: field_value,..}
+
+  Returns:
+    A list of string attribute values, or [framework_constants.NO_VALUES]
+    if the artifact has no value for that attribute.
+  """
+  # Dispatch on the built-in attribute names first.
+  if attribute_name == '--':
+    return []
+  if attribute_name == 'id':
+    return [art.local_id]
+  if attribute_name == 'summary':
+    return [art.summary]
+  if attribute_name == 'status':
+    return [tracker_bizobj.GetStatus(art)]
+  if attribute_name == 'stars':
+    return [art.star_count]
+  if attribute_name == 'attachments':
+    return [art.attachment_count]
+  # TODO(jrobbins): support blocking
+  if attribute_name == 'project':
+    return [art.project_name]
+  if attribute_name == 'mergedinto':
+    # NOTE(review): the `!= 0` clause is redundant given the truthiness
+    # check just before it.
+    if art.merged_into and art.merged_into != 0:
+      return [tracker_bizobj.FormatIssueRef((
+          related_issues[art.merged_into].project_name,
+          related_issues[art.merged_into].local_id))]
+    else:
+      return [framework_constants.NO_VALUES]
+  if attribute_name == 'blocked':
+    return ['Yes' if art.blocked_on_iids else 'No']
+  if attribute_name == 'blockedon':
+    if not art.blocked_on_iids:
+      return [framework_constants.NO_VALUES]
+    else:
+      return [tracker_bizobj.FormatIssueRef((
+          related_issues[blocked_on_iid].project_name,
+          related_issues[blocked_on_iid].local_id)) for
+              blocked_on_iid in art.blocked_on_iids]
+  if attribute_name == 'blocking':
+    if not art.blocking_iids:
+      return [framework_constants.NO_VALUES]
+    return [tracker_bizobj.FormatIssueRef((
+        related_issues[blocking_iid].project_name,
+        related_issues[blocking_iid].local_id)) for
+            blocking_iid in art.blocking_iids]
+  # 'adder' and 'added' only make sense for issues shown within a hotlist.
+  if attribute_name == 'adder':
+    if hotlist_issue_context:
+      adder_id = hotlist_issue_context['adder_id']
+      return [users_by_id[adder_id].display_name]
+    else:
+      return [framework_constants.NO_VALUES]
+  if attribute_name == 'added':
+    if hotlist_issue_context:
+      return [hotlist_issue_context['date_added']]
+    else:
+      return [framework_constants.NO_VALUES]
+  if attribute_name == 'reporter':
+    return [users_by_id[art.reporter_id].display_name]
+  if attribute_name == 'owner':
+    owner_id = tracker_bizobj.GetOwnerId(art)
+    if not owner_id:
+      return [framework_constants.NO_VALUES]
+    else:
+      return [users_by_id[owner_id].display_name]
+  if attribute_name == 'cc':
+    cc_ids = tracker_bizobj.GetCcIds(art)
+    if not cc_ids:
+      return [framework_constants.NO_VALUES]
+    else:
+      return [users_by_id[cc_id].display_name for cc_id in cc_ids]
+  if attribute_name == 'component':
+    comp_ids = list(art.component_ids) + list(art.derived_component_ids)
+    if not comp_ids:
+      return [framework_constants.NO_VALUES]
+    else:
+      paths = []
+      for comp_id in comp_ids:
+        cd = tracker_bizobj.FindComponentDefByID(comp_id, config)
+        if cd:
+          paths.append(cd.path)
+      return paths
+
+  # Check to see if it is a field. Process as field only if it is not an enum
+  # type because enum types are stored as key-value labels.
+  fd = tracker_bizobj.FindFieldDef(attribute_name, config)
+  if fd and fd.field_type != tracker_pb2.FieldTypes.ENUM_TYPE:
+    values = []
+    for fv in art.field_values:
+      if fv.field_id == fd.field_id:
+        value = tracker_bizobj.GetFieldValueWithRawValue(
+            fd.field_type, fv, users_by_id, None)
+        values.append(value)
+    return values
+
+  # Since it is not a built-in attribute or a field, it must be a key-value
+  # label.
+  return label_attr_values_dict.get(
+      attribute_name, [framework_constants.NO_VALUES])
+
+
+def AnyArtifactHasNoAttr(
+ artifacts, attr_name, users_by_id, all_label_values, config,
+ related_issues, hotlist_context_dict=None):
+ """Return true if any artifact does not have a value for attr_name."""
+ # TODO(jrobbins): all_label_values needs to be keyed by issue_id to allow
+ # cross-project grid views.
+ for art in artifacts:
+ if hotlist_context_dict:
+ hotlist_issue_context = hotlist_context_dict[art.issue_id]
+ else:
+ hotlist_issue_context = None
+ vals = GetArtifactAttr(
+ art, attr_name.lower(), users_by_id, all_label_values[art.local_id],
+ config, related_issues, hotlist_issue_context=hotlist_issue_context)
+ if framework_constants.NO_VALUES in vals:
+ return True
+
+ return False
+
+
+def GetGridViewData(
+    mr, results, config, users_by_id, starred_iid_set,
+    grid_limited, related_issues, hotlist_context_dict=None):
+  """EZT template values to render a Grid View of issues.
+
+  Args:
+    mr: commonly used info parsed from the request.
+    results: The Issue PBs that are the search results to be displayed.
+    config: The ProjectConfig PB for the project this view is in.
+    users_by_id: A dictionary {user_id: user_view,...} for all the users
+        involved in results.
+    starred_iid_set: Set of issues that the user has starred.
+    grid_limited: True if the results were limited to fit within the grid.
+    related_issues: dict {issue_id: issue} of pre-fetched related issues.
+    hotlist_context_dict: dict for building a hotlist grid table
+
+  Returns:
+    Dictionary for EZT template rendering of the Grid View.
+  """
+  # We need ordered_columns because EZT loops have no loop-counter available.
+  # And, we use column number in the Javascript to hide/show columns.
+  columns = mr.col_spec.split()
+  ordered_columns = [template_helpers.EZTItem(col_index=i, name=col)
+                     for i, col in enumerate(columns)]
+  other_built_in_cols = (features_constants.OTHER_BUILT_IN_COLS if
+                         hotlist_context_dict else
+                         tracker_constants.OTHER_BUILT_IN_COLS)
+  unshown_columns = table_view_helpers.ComputeUnshownColumns(
+      results, columns, config, other_built_in_cols)
+
+  grid_x_attr = (mr.x or config.default_x_attr or '--').lower()
+  grid_y_attr = (mr.y or config.default_y_attr or '--').lower()
+
+  # Prevent the user from using an axis that we don't support.
+  for bad_axis in tracker_constants.NOT_USED_IN_GRID_AXES:
+    lower_bad_axis = bad_axis.lower()
+    if grid_x_attr == lower_bad_axis:
+      grid_x_attr = '--'
+    if grid_y_attr == lower_bad_axis:
+      grid_y_attr = '--'
+  # Using the same attribute on both X and Y is not useful.
+  if grid_x_attr == grid_y_attr:
+    grid_x_attr = '--'
+
+  all_label_values = {}
+  for art in results:
+    all_label_values[art.local_id] = (
+        MakeLabelValuesDict(art))
+
+  # Build column headings: unique values on the X axis, plus a NO_VALUES
+  # bucket when some issue lacks a value, sorted into well-known order.
+  if grid_x_attr == '--':
+    grid_x_headings = ['All']
+  else:
+    grid_x_items = table_view_helpers.ExtractUniqueValues(
+        [grid_x_attr], results, users_by_id, config, related_issues,
+        hotlist_context_dict=hotlist_context_dict)
+    grid_x_headings = grid_x_items[0].filter_values
+    if AnyArtifactHasNoAttr(
+        results, grid_x_attr, users_by_id, all_label_values,
+        config, related_issues, hotlist_context_dict= hotlist_context_dict):
+      grid_x_headings.append(framework_constants.NO_VALUES)
+    grid_x_headings = SortGridHeadings(
+        grid_x_attr, grid_x_headings, users_by_id, config,
+        tracker_helpers.SORTABLE_FIELDS)
+
+  # Same construction for the row headings on the Y axis.
+  if grid_y_attr == '--':
+    grid_y_headings = ['All']
+  else:
+    grid_y_items = table_view_helpers.ExtractUniqueValues(
+        [grid_y_attr], results, users_by_id, config, related_issues,
+        hotlist_context_dict=hotlist_context_dict)
+    grid_y_headings = grid_y_items[0].filter_values
+    if AnyArtifactHasNoAttr(
+        results, grid_y_attr, users_by_id, all_label_values,
+        config, related_issues, hotlist_context_dict= hotlist_context_dict):
+      grid_y_headings.append(framework_constants.NO_VALUES)
+    grid_y_headings = SortGridHeadings(
+        grid_y_attr, grid_y_headings, users_by_id, config,
+        tracker_helpers.SORTABLE_FIELDS)
+
+  logging.info('grid_x_headings = %s', grid_x_headings)
+  logging.info('grid_y_headings = %s', grid_y_headings)
+  grid_data = PrepareForMakeGridData(
+      results, starred_iid_set, grid_x_attr, grid_x_headings,
+      grid_y_attr, grid_y_headings, users_by_id, all_label_values,
+      config, related_issues, hotlist_context_dict=hotlist_context_dict)
+
+  # Axis choices are the shown and unshown columns, minus unsupported axes.
+  grid_axis_choice_dict = {}
+  for oc in ordered_columns:
+    grid_axis_choice_dict[oc.name] = True
+  for uc in unshown_columns:
+    grid_axis_choice_dict[uc] = True
+  for bad_axis in tracker_constants.NOT_USED_IN_GRID_AXES:
+    if bad_axis in grid_axis_choice_dict:
+      del grid_axis_choice_dict[bad_axis]
+  grid_axis_choices = list(grid_axis_choice_dict.keys())
+  grid_axis_choices.sort()
+
+  # Fall back from tiles to bare IDs when there are too many results.
+  grid_cell_mode = mr.cells
+  if len(results) > settings.max_tiles_in_grid and mr.cells == 'tiles':
+    grid_cell_mode = 'ids'
+
+  grid_view_data = {
+      'grid_limited': ezt.boolean(grid_limited),
+      'grid_shown': len(results),
+      'grid_x_headings': grid_x_headings,
+      'grid_y_headings': grid_y_headings,
+      'grid_data': grid_data,
+      'grid_axis_choices': grid_axis_choices,
+      'grid_cell_mode': grid_cell_mode,
+      'results': results,  # Really only useful in if-any.
+  }
+  return grid_view_data
+
+
+def PrepareForMakeGridData(
+    allowed_results, starred_iid_set, x_attr,
+    grid_col_values, y_attr, grid_row_values, users_by_id, all_label_values,
+    config, related_issues, hotlist_context_dict=None):
+  """Return all data needed for EZT to render the body of the grid view."""
+
+  def IssueViewFactory(issue):
+    # Lightweight EZT tile for one issue; starred and data_idx are
+    # placeholders that get filled in after the grid is assembled.
+    return template_helpers.EZTItem(
+        summary=issue.summary, local_id=issue.local_id, issue_id=issue.issue_id,
+        status=issue.status or issue.derived_status, starred=None, data_idx=0,
+        project_name=issue.project_name)
+
+  grid_data = MakeGridData(
+      allowed_results, x_attr, grid_col_values, y_attr, grid_row_values,
+      users_by_id, IssueViewFactory, all_label_values, config, related_issues,
+      hotlist_context_dict=hotlist_context_dict)
+  issue_dict = {issue.issue_id: issue for issue in allowed_results}
+  # Decorate every tile with its star state, a detail-page URL, and a
+  # project:id reference string.
+  for grid_row in grid_data:
+    for grid_cell in grid_row.cells_in_row:
+      for tile in grid_cell.tiles:
+        if tile.issue_id in starred_iid_set:
+          tile.starred = ezt.boolean(True)
+        issue = issue_dict[tile.issue_id]
+        tile.issue_url = tracker_helpers.FormatRelativeIssueURL(
+            issue.project_name, urls.ISSUE_DETAIL, id=tile.local_id)
+        tile.issue_ref = issue.project_name + ':' + str(tile.local_id)
+
+  return grid_data
diff --git a/framework/jsonfeed.py b/framework/jsonfeed.py
new file mode 100644
index 0000000..44e9cea
--- /dev/null
+++ b/framework/jsonfeed.py
@@ -0,0 +1,134 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""This file defines a subclass of Servlet for JSON feeds.
+
+A "feed" is a servlet that is accessed by another part of our system and that
+responds with a JSON value rather than HTML to display in a browser.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import httplib
+import json
+import logging
+
+from google.appengine.api import app_identity
+
+import settings
+
+from framework import framework_constants
+from framework import permissions
+from framework import servlet
+from framework import xsrf
+from search import query2ast
+
+# This causes a JS error for a hacker trying to do a cross-site inclusion.
+XSSI_PREFIX = ")]}'\n"
+
+
class JsonFeed(servlet.Servlet):
  """A convenient base class for JSON feeds.

  Subclasses override HandleRequest() to return a dict; this class handles
  XSRF validation, permission checks, and JSON serialization/response output.
  """

  # By default, JSON output is compact. Subclasses can set this to
  # an integer, like 4, for pretty-printed output.
  JSON_INDENT = None

  # Some JSON handlers can only be accessed from our own app.
  CHECK_SAME_APP = False

  def HandleRequest(self, _mr):
    """Override this method to implement handling of the request.

    Args:
      mr: common information parsed from the HTTP request.

    Returns:
      A dictionary of json data.
    """
    raise servlet.MethodNotSupportedError()

  def _DoRequestHandling(self, request, mr):
    """Do permission checking, page processing, and response formatting."""
    try:
      # TODO(jrobbins): check the XSRF token even for anon users
      # after the next deployment.
      if self.CHECK_SECURITY_TOKEN and mr.auth.user_id:
        # Validate the XSRF token with the specific request path for this
        # servlet. But, not every XHR request has a distinct token, so just
        # use 'xhr' for ones that don't.
        # TODO(jrobbins): make specific tokens for:
        # user and project stars, issue options, check names.
        try:
          logging.info('request in jsonfeed is %r', request)
          xsrf.ValidateToken(mr.token, mr.auth.user_id, request.path)
        except xsrf.TokenIncorrect:
          # Fall back to the generic XHR token path before rejecting.
          logging.info('using token path "xhr"')
          xsrf.ValidateToken(mr.token, mr.auth.user_id, xsrf.XHR_SERVLET_PATH)

      if self.CHECK_SAME_APP and not settings.local_mode:
        # Reject calls that did not originate from this same GAE application.
        calling_app_id = request.headers.get('X-Appengine-Inbound-Appid')
        if calling_app_id != app_identity.get_application_id():
          self.response.status = httplib.FORBIDDEN
          return

      self._CheckForMovedProject(mr, request)
      self.AssertBasePermission(mr)

      json_data = self.HandleRequest(mr)

      self._RenderJsonResponse(json_data)

    except query2ast.InvalidQueryError as e:
      logging.warning('Trapped InvalidQueryError: %s', e)
      logging.exception(e)
      # NOTE(review): e.message exists only in Python 2; prefer str(e) when
      # porting to Python 3.
      msg = e.message if e.message else 'invalid query'
      self.abort(400, msg)
    except permissions.PermissionException as e:
      logging.info('Trapped PermissionException %s', e)
      self.response.status = httplib.FORBIDDEN

  # pylint: disable=unused-argument
  # pylint: disable=arguments-differ
  # Note: unused arguments necessary because they are specified in
  # registerpages.py as an extra URL validation step even though we
  # do our own URL parsing in monorailrequest.py
  def get(self, project_name=None, viewed_username=None, hotlist_id=None):
    """Collect page-specific and generic info, then render the page.

    Args:
      project_name: string project name parsed from the URL by webapp2,
        but we also parse it out in our code.
      viewed_username: string user email parsed from the URL by webapp2,
        but we also parse it out in our code.
      hotlist_id: string hotlist id parsed from the URL by webapp2,
        but we also parse it out in our code.
    """
    self._DoRequestHandling(self.mr.request, self.mr)

  # pylint: disable=unused-argument
  # pylint: disable=arguments-differ
  def post(self, project_name=None, viewed_username=None, hotlist_id=None):
    """Parse the request, check base perms, and call form-specific code."""
    self._DoRequestHandling(self.mr.request, self.mr)

  def _RenderJsonResponse(self, json_data):
    """Serialize the data as JSON so that it can be sent to the browser."""
    json_str = json.dumps(json_data, indent=self.JSON_INDENT)
    logging.debug(
        'Sending JSON response: %r length: %r',
        json_str[:framework_constants.LOGGING_MAX_LENGTH], len(json_str))
    self.response.content_type = framework_constants.CONTENT_TYPE_JSON
    self.response.headers['X-Content-Type-Options'] = (
        framework_constants.CONTENT_TYPE_JSON_OPTIONS)
    # The XSSI prefix is written before the JSON so that a cross-site
    # <script> inclusion of this URL hits a JS error instead of the data.
    self.response.write(XSSI_PREFIX)
    self.response.write(json_str)
+
+
class InternalTask(JsonFeed):
  """Internal tasks are JSON feeds that can only be reached by our own code."""

  # These requests do not come from browser forms, so there is no XSRF
  # token to validate.
  CHECK_SECURITY_TOKEN = False
diff --git a/framework/monitoring.py b/framework/monitoring.py
new file mode 100644
index 0000000..6ddeeb9
--- /dev/null
+++ b/framework/monitoring.py
@@ -0,0 +1,109 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+
+"""Monitoring ts_mon custom to monorail."""
+
+from infra_libs import ts_mon
+from framework import framework_helpers
+
+
def GetCommonFields(status, name, is_robot=False):
  # type: (int, str, bool) -> Dict[str, Union[int, str, bool]]
  """Return the metric field dict shared by the server metrics below."""
  return dict(status=status, name=name, is_robot=is_robot)
+
+
# Counts every API call, keyed by caller identity and API version.
# Incremented via IncrementAPIRequestsCount() below.
API_REQUESTS_COUNT = ts_mon.CounterMetric(
    'monorail/api_requests',
    'Number of requests to Monorail APIs',
    [ts_mon.StringField('client_id'),
     ts_mon.StringField('client_email'),
     ts_mon.StringField('version')])
+
def IncrementAPIRequestsCount(version, client_id, client_email=None):
  # type: (str, str, Optional[str]) -> None
  """Add one to the ts_mon API request counter.

  Args:
    version: API version string used as a metric field.
    client_id: ID of the calling client, used as a metric field.
    client_email: email of the caller, or None for anonymous callers.
  """
  if not client_email:
    client_email = 'anonymous'
  elif not framework_helpers.IsServiceAccount(client_email):
    # Only service accounts are recorded verbatim; human accounts are
    # collapsed to a single placeholder to avoid metric field-value
    # explosion and to keep PII out of monitoring.
    client_email = 'user@email.com'

  API_REQUESTS_COUNT.increment_by(
      1,
      {
          'client_id': client_id,
          'client_email': client_email,
          'version': version,
      })
+
+
# 90% of durations are in the range 11-1873ms. Growth factor 10^0.06 puts that
# range into 37 buckets. Max finite bucket value is 12 minutes.
DURATION_BUCKETER = ts_mon.GeometricBucketer(10**0.06)

# 90% of sizes are in the range 0.17-217014 bytes. Growth factor 10^0.1 puts
# that range into 54 buckets. Max finite bucket value is 6.3GB.
SIZE_BUCKETER = ts_mon.GeometricBucketer(10**0.1)

# TODO(https://crbug.com/monorail/9281): Differentiate internal/external calls.
SERVER_DURATIONS = ts_mon.CumulativeDistributionMetric(
    'monorail/server_durations',
    'Time elapsed between receiving a request and sending a'
    ' response (including parsing) in milliseconds.', [
        ts_mon.IntegerField('status'),
        ts_mon.StringField('name'),
        ts_mon.BooleanField('is_robot'),
    ],
    bucketer=DURATION_BUCKETER)


def AddServerDurations(elapsed_ms, fields):
  # type: (int, Dict[str, Union[int, bool]]) -> None
  """Record one request's wall-clock duration in the ts_mon distribution."""
  SERVER_DURATIONS.add(elapsed_ms, fields=fields)
+
+
SERVER_RESPONSE_STATUS = ts_mon.CounterMetric(
    'monorail/server_response_status',
    'Number of responses sent by HTTP status code.', [
        ts_mon.IntegerField('status'),
        ts_mon.StringField('name'),
        ts_mon.BooleanField('is_robot'),
    ])


def IncrementServerResponseStatusCount(fields):
  # type: (Dict[str, Union[int, bool]]) -> None
  """Count one response toward its HTTP status code in ts_mon."""
  SERVER_RESPONSE_STATUS.increment(fields=fields)
+
+
SERVER_REQUEST_BYTES = ts_mon.CumulativeDistributionMetric(
    'monorail/server_request_bytes',
    'Bytes received per http request (body only).', [
        ts_mon.IntegerField('status'),
        ts_mon.StringField('name'),
        ts_mon.BooleanField('is_robot'),
    ],
    bucketer=SIZE_BUCKETER)


# NOTE(review): function name has a typo ('Requeste'); kept as-is because
# callers elsewhere reference this spelling.
def AddServerRequesteBytes(request_length, fields):
  # type: (int, Dict[str, Union[int, bool]]) -> None
  """Record one request's body size (in bytes) in the ts_mon distribution."""
  SERVER_REQUEST_BYTES.add(request_length, fields=fields)
+
+
SERVER_RESPONSE_BYTES = ts_mon.CumulativeDistributionMetric(
    'monorail/server_response_bytes',
    'Bytes sent per http request (content only).', [
        ts_mon.IntegerField('status'),
        ts_mon.StringField('name'),
        ts_mon.BooleanField('is_robot'),
    ],
    bucketer=SIZE_BUCKETER)


def AddServerResponseBytes(response_length, fields):
  # type: (int, Dict[str, Union[int, bool]]) -> None
  """Record one response's content size (in bytes) in the ts_mon distribution."""
  SERVER_RESPONSE_BYTES.add(response_length, fields=fields)
diff --git a/framework/monorailcontext.py b/framework/monorailcontext.py
new file mode 100644
index 0000000..76ecff4
--- /dev/null
+++ b/framework/monorailcontext.py
@@ -0,0 +1,76 @@
+# Copyright 2018 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Context object to hold utility objects used during request processing.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import logging
+
+from framework import authdata
+from framework import permissions
+from framework import profiler
+from framework import sql
+from framework import template_helpers
+
+
class MonorailContext(object):
  """Context with objects used in request handling mechanics.

  Attributes:
    cnxn: MonorailConnection to the SQL DB.
    auth: AuthData object that identifies the account making the request.
    perms: PermissionSet for requesting user, set by LookupLoggedInUserPerms().
    profiler: Profiler object.
    warnings: A list of warnings to present to the user.
    errors: A list of errors to present to the user.

  Unlike MonorailRequest, this object does not parse any part of the request,
  retrieve any business objects (other than the User PB for the requesting
  user), or check any permissions.
  """

  def __init__(
      self, services, cnxn=None, requester=None, auth=None, perms=None,
      autocreate=True):
    """Construct a MonorailContext.

    Args:
      services: Connection to backends.
      cnxn: Optional connection to SQL database.
      requester: String email address of user making the request or None.
      auth: AuthData object used during testing.
      perms: PermissionSet used during testing.
      autocreate: Set to False to require that a row in the User table already
        exists for this user, otherwise raise NoSuchUserException.
    """
    # A caller-supplied cnxn/auth (used in tests) takes precedence over
    # creating real ones.
    self.cnxn = cnxn or sql.MonorailConnection()
    self.auth = auth or authdata.AuthData.FromEmail(
        self.cnxn, requester, services, autocreate=autocreate)
    self.perms = perms  # Usually None until LookupLoggedInUserPerms() called.
    self.profiler = profiler.Profiler()

    # TODO(jrobbins): make self.errors not be UI-centric.
    self.warnings = []
    self.errors = template_helpers.EZTError()

  def LookupLoggedInUserPerms(self, project):
    """Look up perms for user making a request in project (can be None)."""
    with self.profiler.Phase('looking up signed in user permissions'):
      self.perms = permissions.GetPermissions(
          self.auth.user_pb, self.auth.effective_ids, project)

  def CleanUp(self):
    """Close the DB cnxn and any other clean up."""
    if self.cnxn:
      self.cnxn.Close()
    # Drop the reference so a second CleanUp() call is a no-op.
    self.cnxn = None

  def __repr__(self):
    """Return a string more useful for debugging."""
    return '%s(cnxn=%r, auth=%r, perms=%r)' % (
        self.__class__.__name__, self.cnxn, self.auth, self.perms)
diff --git a/framework/monorailrequest.py b/framework/monorailrequest.py
new file mode 100644
index 0000000..e51aa15
--- /dev/null
+++ b/framework/monorailrequest.py
@@ -0,0 +1,713 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Classes to hold information parsed from a request.
+
+To simplify our servlets and avoid duplication of code, we parse some
+info out of the request as soon as we get it and then pass a MonorailRequest
+object to the servlet-specific request handler methods.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import endpoints
+import logging
+import re
+import urllib
+
+import ezt
+import six
+
+from google.appengine.api import app_identity
+from google.appengine.api import oauth
+
+import webapp2
+
+import settings
+from businesslogic import work_env
+from features import features_constants
+from framework import authdata
+from framework import exceptions
+from framework import framework_bizobj
+from framework import framework_constants
+from framework import framework_views
+from framework import monorailcontext
+from framework import permissions
+from framework import profiler
+from framework import sql
+from framework import template_helpers
+from proto import api_pb2_v1
+from tracker import tracker_bizobj
+from tracker import tracker_constants
+
+
+_HOSTPORT_RE = re.compile('^[-a-z0-9.]+(:\d+)?$', re.I)
+
+
+# TODO(jrobbins): Stop extending MonorailContext and change whole servlet
+# framework to pass around separate objects for mc and mr.
class MonorailRequestBase(monorailcontext.MonorailContext):
  """A base class with common attributes for internal and external requests."""

  def __init__(self, services, requester=None, cnxn=None):
    """Set up base request state; no project is selected yet."""
    super(MonorailRequestBase, self).__init__(
        services, cnxn=cnxn, requester=requester)

    # Subclasses fill these in after parsing the URL or API payload.
    self.project_name = None
    self.project = None
    self.config = None

  @property
  def project_id(self):
    """Return the int ID of the current project, or None if no project."""
    if self.project:
      return self.project.project_id
    return None
+
+
class MonorailApiRequest(MonorailRequestBase):
  """A class to hold information parsed from the Endpoints API request."""

  # pylint: disable=attribute-defined-outside-init
  def __init__(self, request, services, cnxn=None):
    # The requester identity comes from the Endpoints/OAuth layer rather
    # than from any cookie-based session.
    requester_object = (
        endpoints.get_current_user() or
        oauth.get_current_user(
            framework_constants.OAUTH_SCOPE))
    requester = requester_object.email().lower()
    super(MonorailApiRequest, self).__init__(
        services, requester=requester, cnxn=cnxn)
    self.me_user_id = self.auth.user_id
    self.viewed_username = None
    self.viewed_user_auth = None
    self.issue = None
    self.granted_perms = set()

    # query parameters
    self.params = {
        'can': 1,
        'start': 0,
        'num': tracker_constants.DEFAULT_RESULTS_PER_PAGE,
        'q': '',
        'sort': '',
        'groupby': '',
        'projects': [],
        'hotlists': []
    }
    self.use_cached_searches = True
    self.mode = None

    if hasattr(request, 'projectId'):
      self.project_name = request.projectId
      with work_env.WorkEnv(self, services) as we:
        self.project = we.GetProjectByName(self.project_name)
        self.params['projects'].append(self.project_name)
        self.config = we.GetProjectConfig(self.project_id)
        if hasattr(request, 'additionalProject'):
          self.params['projects'].extend(request.additionalProject)
          # Deduplicate in case additionalProject repeats the main project.
          self.params['projects'] = list(set(self.params['projects']))
    self.LookupLoggedInUserPerms(self.project)
    # A second WorkEnv is opened here because issue lookup needs the perms
    # established by LookupLoggedInUserPerms() above.
    if hasattr(request, 'projectId'):
      with work_env.WorkEnv(self, services) as we:
        if hasattr(request, 'issueId'):
          self.issue = we.GetIssueByLocalID(
              self.project_id, request.issueId, use_cache=False)
          self.granted_perms = tracker_bizobj.GetGrantedPerms(
              self.issue, self.auth.effective_ids, self.config)
    if hasattr(request, 'userId'):
      self.viewed_username = request.userId.lower()
      if self.viewed_username == 'me':
        self.viewed_username = requester
      self.viewed_user_auth = authdata.AuthData.FromEmail(
          self.cnxn, self.viewed_username, services)
    elif hasattr(request, 'groupName'):
      self.viewed_username = request.groupName.lower()
      try:
        self.viewed_user_auth = authdata.AuthData.FromEmail(
            self.cnxn, self.viewed_username, services)
      except exceptions.NoSuchUserException:
        self.viewed_user_auth = None

    # Build q.
    if hasattr(request, 'q') and request.q:
      self.params['q'] = request.q
    if hasattr(request, 'publishedMax') and request.publishedMax:
      self.params['q'] += ' opened<=%d' % request.publishedMax
    if hasattr(request, 'publishedMin') and request.publishedMin:
      self.params['q'] += ' opened>=%d' % request.publishedMin
    if hasattr(request, 'updatedMax') and request.updatedMax:
      self.params['q'] += ' modified<=%d' % request.updatedMax
    if hasattr(request, 'updatedMin') and request.updatedMin:
      self.params['q'] += ' modified>=%d' % request.updatedMin
    if hasattr(request, 'owner') and request.owner:
      self.params['q'] += ' owner:%s' % request.owner
    if hasattr(request, 'status') and request.status:
      self.params['q'] += ' status:%s' % request.status
    if hasattr(request, 'label') and request.label:
      self.params['q'] += ' label:%s' % request.label

    # Map the API's CannedQuery enum values to internal canned query IDs.
    if hasattr(request, 'can') and request.can:
      if request.can == api_pb2_v1.CannedQuery.all:
        self.params['can'] = 1
      elif request.can == api_pb2_v1.CannedQuery.new:
        self.params['can'] = 6
      elif request.can == api_pb2_v1.CannedQuery.open:
        self.params['can'] = 2
      elif request.can == api_pb2_v1.CannedQuery.owned:
        self.params['can'] = 3
      elif request.can == api_pb2_v1.CannedQuery.reported:
        self.params['can'] = 4
      elif request.can == api_pb2_v1.CannedQuery.starred:
        self.params['can'] = 5
      elif request.can == api_pb2_v1.CannedQuery.to_verify:
        self.params['can'] = 7
      else:  # Endpoints should have caught this.
        # NOTE(review): printf-style args are passed to InputException here,
        # while other raises in this file pre-format with %. Confirm the
        # exception class formats its args.
        raise exceptions.InputException(
            'Canned query %s is not supported.', request.can)
    if hasattr(request, 'startIndex') and request.startIndex:
      self.params['start'] = request.startIndex
    if hasattr(request, 'maxResults') and request.maxResults:
      self.params['num'] = request.maxResults
    if hasattr(request, 'sort') and request.sort:
      self.params['sort'] = request.sort

    # Expose the parsed params as attributes, normalizing the sort and
    # groupby column specs.
    self.query_project_names = self.GetParam('projects')
    self.group_by_spec = self.GetParam('groupby')
    self.group_by_spec = ' '.join(ParseColSpec(
        self.group_by_spec, ignore=tracker_constants.NOT_USED_IN_GRID_AXES))
    self.sort_spec = self.GetParam('sort')
    self.sort_spec = ' '.join(ParseColSpec(self.sort_spec))
    self.query = self.GetParam('q')
    self.can = self.GetParam('can')
    self.start = self.GetParam('start')
    self.num = self.GetParam('num')

  def GetParam(self, query_param_name, default_value=None,
               _antitamper_re=None):
    """Return the value of a logical query param, or the default value."""
    return self.params.get(query_param_name, default_value)

  def GetPositiveIntParam(self, query_param_name, default_value=None):
    """Returns 0 if the user-provided value is less than 0."""
    return max(self.GetParam(query_param_name, default_value=default_value),
               0)
+
+
class MonorailRequest(MonorailRequestBase):
  """A class to hold information parsed from the HTTP request.

  The goal of MonorailRequest is to do almost all URL path and query string
  processing in one place, which makes the servlet code simpler.

  Attributes:
    cnxn: connection to the SQL databases.
    logged_in_user_id: int user ID of the signed-in user, or None.
    effective_ids: set of signed-in user ID and all their user group IDs.
    user_pb: User object for the signed in user.
    project_name: string name of the current project.
    project_id: int ID of the current project.
    viewed_username: string username of the user whose profile is being viewed.
    can: int "canned query" number to scope the user's search.
    num: int number of results to show per pagination page.
    start: int position in result set to show on this pagination page.
    etc: there are many more, all read-only.
  """

  # pylint: disable=attribute-defined-outside-init
  def __init__(self, services, params=None):
    """Initialize the MonorailRequest object."""
    # Note: mr starts off assuming anon until ParseRequest() is called.
    super(MonorailRequest, self).__init__(services)
    self.form_overrides = {}
    if params:
      self.form_overrides.update(params)
    self.debug_enabled = False
    self.use_cached_searches = True

    self.hotlist_id = None
    self.hotlist = None
    self.hotlist_name = None

    self.viewed_username = None
    self.viewed_user_auth = authdata.AuthData()

  def ParseRequest(self, request, services, do_user_lookups=True):
    """Parse tons of useful info from the given request object.

    Args:
      request: webapp2 Request object w/ path and query params.
      services: connections to backend servers including DB.
      do_user_lookups: Set to False to disable lookups during testing.
    """
    with self.profiler.Phase('basic parsing'):
      self.request = request
      self.current_page_url = request.url
      self.current_page_url_encoded = urllib.quote_plus(self.current_page_url)

      # Only accept a hostport from the request that looks valid.
      if not _HOSTPORT_RE.match(request.host):
        raise exceptions.InputException(
            'request.host looks funny: %r', request.host)

      logging.info('Request: %s', self.current_page_url)

    with self.profiler.Phase('path parsing'):
      (viewed_user_val, self.project_name,
       self.hotlist_id, self.hotlist_name) = _ParsePathIdentifiers(
           self.request.path)
      self.viewed_username = _GetViewedEmail(
          viewed_user_val, self.cnxn, services)
    with self.profiler.Phase('qs parsing'):
      self._ParseQueryParameters()
    with self.profiler.Phase('overrides parsing'):
      self._ParseFormOverrides()

    if not self.project:  # It can be already set in unit tests.
      self._LookupProject(services)
    if self.project_id and services.config:
      self.config = services.config.GetProjectConfig(self.cnxn, self.project_id)

    if do_user_lookups:
      if self.viewed_username:
        self._LookupViewedUser(services)
      self._LookupLoggedInUser(services)
      # TODO(jrobbins): re-implement HandleLurkerViewingSelf()

    if not self.hotlist:
      self._LookupHotlist(services)

    if self.query is None:
      self.query = self._CalcDefaultQuery()

    # Debug mode is allowed only in local mode or for users holding the
    # VIEW_DEBUG permission.
    prod_debug_allowed = self.perms.HasPerm(
        permissions.VIEW_DEBUG, self.auth.user_id, None)
    self.debug_enabled = (request.params.get('debug') and
                          (settings.local_mode or prod_debug_allowed))
    # temporary option for perf testing on staging instance.
    if request.params.get('disable_cache'):
      if settings.local_mode or 'staging' in request.host:
        self.use_cached_searches = False

  def _CalcDefaultQuery(self):
    """When URL has no q= param, return the default for members or ''."""
    # Members get the project's configured default query, but only for the
    # "open issues" canned scope (can == 2).
    if (self.can == 2 and self.project and self.auth.effective_ids and
        framework_bizobj.UserIsInProject(self.project, self.auth.effective_ids)
        and self.config):
      return self.config.member_default_query
    else:
      return ''

  def _ParseQueryParameters(self):
    """Parse and convert all the query string params used in any servlet."""
    self.start = self.GetPositiveIntParam('start', default_value=0)
    self.num = self.GetPositiveIntParam(
        'num', default_value=tracker_constants.DEFAULT_RESULTS_PER_PAGE)
    # Prevent DoS attacks that try to make us serve really huge result pages.
    self.num = min(self.num, settings.max_artifact_search_results_per_page)

    self.invalidation_timestep = self.GetIntParam(
        'invalidation_timestep', default_value=0)

    self.continue_issue_id = self.GetIntParam(
        'continue_issue_id', default_value=0)
    self.redir = self.GetParam('redir')

    # Search scope, a.k.a., canned query ID
    # TODO(jrobbins): make configurable
    self.can = self.GetIntParam(
        'can', default_value=tracker_constants.OPEN_ISSUES_CAN)

    # Search query
    self.query = self.GetParam('q')

    # Sorting of search results (needed for result list and flipper)
    self.sort_spec = self.GetParam(
        'sort', default_value='',
        antitamper_re=framework_constants.SORTSPEC_RE)
    self.sort_spec = ' '.join(ParseColSpec(self.sort_spec))

    # Note: This is set later in request handling by ComputeColSpec().
    self.col_spec = None

    # Grouping of search results (needed for result list and flipper)
    self.group_by_spec = self.GetParam(
        'groupby', default_value='',
        antitamper_re=framework_constants.SORTSPEC_RE)
    self.group_by_spec = ' '.join(ParseColSpec(
        self.group_by_spec, ignore=tracker_constants.NOT_USED_IN_GRID_AXES))

    # For issue list and grid mode.
    self.cursor = self.GetParam('cursor')
    self.preview = self.GetParam('preview')
    self.mode = self.GetParam('mode') or 'list'
    self.x = self.GetParam('x', default_value='')
    self.y = self.GetParam('y', default_value='')
    self.cells = self.GetParam('cells', default_value='ids')

    # For the dashboard and issue lists included in the dashboard.
    self.ajah = self.GetParam('ajah')  # AJAH = Asynchronous Javascript And HTML
    self.table_title = self.GetParam('table_title')
    self.panel_id = self.GetIntParam('panel')

    # For pagination of updates lists
    self.before = self.GetPositiveIntParam('before')
    self.after = self.GetPositiveIntParam('after')

    # For cron tasks and backend calls
    self.lower_bound = self.GetIntParam('lower_bound')
    self.upper_bound = self.GetIntParam('upper_bound')
    self.shard_id = self.GetIntParam('shard_id')

    # For specifying which objects to operate on
    self.local_id = self.GetIntParam('id')
    self.local_id_list = self.GetIntListParam('ids')
    self.seq = self.GetIntParam('seq')
    self.aid = self.GetIntParam('aid')
    self.signed_aid = self.GetParam('signed_aid')
    self.specified_user_id = self.GetIntParam('u', default_value=0)
    self.specified_logged_in_user_id = self.GetIntParam(
        'logged_in_user_id', default_value=0)
    self.specified_me_user_ids = self.GetIntListParam('me_user_ids')

    # TODO(jrobbins): Phase this out after next deployment.  If an old
    # version of the default GAE module sends a request with the old
    # me_user_id= parameter, then accept it.
    specified_me_user_id = self.GetIntParam(
        'me_user_id', default_value=0)
    if specified_me_user_id:
      self.specified_me_user_ids = [specified_me_user_id]

    self.specified_project = self.GetParam('project')
    self.specified_project_id = self.GetIntParam('project_id')
    self.query_project_names = self.GetListParam('projects', default_value=[])
    self.template_name = self.GetParam('template')
    self.component_path = self.GetParam('component')
    self.field_name = self.GetParam('field')

    # For image attachments
    self.inline = bool(self.GetParam('inline'))
    self.thumb = bool(self.GetParam('thumb'))

    # For JS callbacks
    self.token = self.GetParam('token')
    self.starred = bool(self.GetIntParam('starred'))

    # For issue reindexing utility servlet
    self.auto_submit = self.GetParam('auto_submit')

    # For issue dependency reranking servlet
    self.parent_id = self.GetIntParam('parent_id')
    self.target_id = self.GetIntParam('target_id')
    self.moved_ids = self.GetIntListParam('moved_ids')
    self.split_above = self.GetBoolParam('split_above')

    # For adding issues to hotlists servlet
    self.hotlist_ids_remove = self.GetIntListParam('hotlist_ids_remove')
    self.hotlist_ids_add = self.GetIntListParam('hotlist_ids_add')
    self.issue_refs = self.GetListParam('issue_refs')

  def _ParseFormOverrides(self):
    """Support deep linking by allowing the user to set form fields via QS."""
    allowed_overrides = {
        'template_name': self.GetParam('template_name'),
        'initial_summary': self.GetParam('summary'),
        'initial_description': (self.GetParam('description') or
                                self.GetParam('comment')),
        'initial_comment': self.GetParam('comment'),
        'initial_status': self.GetParam('status'),
        'initial_owner': self.GetParam('owner'),
        'initial_cc': self.GetParam('cc'),
        'initial_blocked_on': self.GetParam('blockedon'),
        'initial_blocking': self.GetParam('blocking'),
        'initial_merge_into': self.GetIntParam('mergeinto'),
        'initial_components': self.GetParam('components'),
        'initial_hotlists': self.GetParam('hotlists'),

        # For the people pages
        'initial_add_members': self.GetParam('add_members'),
        'initially_expanded_form': ezt.boolean(self.GetParam('expand_form')),

        # For user group admin pages
        'initial_name': (self.GetParam('group_name') or
                         self.GetParam('proposed_project_name')),
    }

    # Only keep the overrides that were actually provided in the query string.
    self.form_overrides.update(
        (k, v) for (k, v) in allowed_overrides.items()
        if v is not None)

  def _LookupViewedUser(self, services):
    """Get information about the viewed user (if any) from the request."""
    try:
      with self.profiler.Phase('get viewed user, if any'):
        self.viewed_user_auth = authdata.AuthData.FromEmail(
            self.cnxn, self.viewed_username, services, autocreate=False)
    except exceptions.NoSuchUserException:
      logging.info('could not find user %r', self.viewed_username)
      webapp2.abort(404, 'user not found')

    if not self.viewed_user_auth.user_id:
      webapp2.abort(404, 'user not found')

  def _LookupProject(self, services):
    """Get information about the current project (if any) from the request.

    Raises:
      NoSuchProjectException if there is no project with that name.
    """
    logging.info('project_name is %r', self.project_name)
    if self.project_name:
      self.project = services.project.GetProjectByName(
          self.cnxn, self.project_name)
      if not self.project:
        raise exceptions.NoSuchProjectException()

  def _LookupHotlist(self, services):
    """Get information about the current hotlist (if any) from the request."""
    with self.profiler.Phase('get current hotlist, if any'):
      if self.hotlist_name:
        # A hotlist URL names both the hotlist and its owning user; the pair
        # is the lookup key.
        hotlist_id_dict = services.features.LookupHotlistIDs(
            self.cnxn, [self.hotlist_name], [self.viewed_user_auth.user_id])
        try:
          self.hotlist_id = hotlist_id_dict[(
              self.hotlist_name, self.viewed_user_auth.user_id)]
        except KeyError:
          webapp2.abort(404, 'invalid hotlist')

      if not self.hotlist_id:
        logging.info('no hotlist_id or bad hotlist_name, so no hotlist')
      else:
        self.hotlist = services.features.GetHotlistByID(
            self.cnxn, self.hotlist_id)
        if not self.hotlist or (
            self.viewed_user_auth.user_id and
            self.viewed_user_auth.user_id not in self.hotlist.owner_ids):
          webapp2.abort(404, 'invalid hotlist')

  def _LookupLoggedInUser(self, services):
    """Get information about the signed-in user (if any) from the request."""
    self.auth = authdata.AuthData.FromRequest(self.cnxn, services)
    self.me_user_id = (self.GetIntParam('me') or
                       self.viewed_user_auth.user_id or self.auth.user_id)

    self.LookupLoggedInUserPerms(self.project)

  def ComputeColSpec(self, config):
    """Set col_spec based on param, default in the config, or site default."""
    if self.col_spec is not None:
      return  # Already set.
    default_col_spec = ''
    if config:
      default_col_spec = config.default_col_spec

    col_spec = self.GetParam(
        'colspec', default_value=default_col_spec,
        antitamper_re=framework_constants.COLSPEC_RE)
    cols_lower = col_spec.lower().split()
    if self.project and any(
        hotlist_col in cols_lower for hotlist_col in [
            'rank', 'adder', 'added']):
      # If the list is a project list and the 'colspec' is a carry-over
      # from hotlists, set col_spec to None so it will be set to default
      # in the next if statement.
      col_spec = None

    if not col_spec:
      # If col spec is still empty then default to the global col spec.
      col_spec = tracker_constants.DEFAULT_COL_SPEC

    self.col_spec = ' '.join(ParseColSpec(col_spec,
                             max_parts=framework_constants.MAX_COL_PARTS))

  def PrepareForReentry(self, echo_data):
    """Expose the results of form processing as if it was a new GET.

    This method is called only when the user submits a form with invalid
    information which they are being asked to correct. Updating the MR
    object allows the normal servlet get() method to populate the form with
    the entered values and error messages.

    Args:
      echo_data: dict of {page_data_key: value_to_reoffer, ...} that will
          override whatever HTML form values are normally shown to the
          user when they initially view the form.  This allows them to
          fix user input that was not valid.
    """
    self.form_overrides.update(echo_data)

  def GetParam(self, query_param_name, default_value=None,
               antitamper_re=None):
    """Get a query parameter from the URL as a utf8 string."""
    value = self.request.params.get(query_param_name)
    assert value is None or isinstance(value, six.text_type)
    using_default = value is None
    if using_default:
      value = default_value

    # NOTE(review): if value is still None here (no default given) while
    # antitamper_re is set, antitamper_re.match(None) raises TypeError
    # instead of InputException — confirm callers always pass a default
    # along with antitamper_re.
    if antitamper_re and not antitamper_re.match(value):
      if using_default:
        logging.error('Default value fails antitamper for %s field: %s',
                      query_param_name, value)
      else:
        logging.info('User seems to have tampered with %s field: %s',
                     query_param_name, value)
      raise exceptions.InputException()

    return value

  def GetIntParam(self, query_param_name, default_value=None):
    """Get an integer param from the URL or default."""
    value = self.request.params.get(query_param_name)
    if value is None or value == '':
      return default_value

    try:
      return int(value)
    except (TypeError, ValueError):
      raise exceptions.InputException(
          'Invalid value for integer param: %r' % value)

  def GetPositiveIntParam(self, query_param_name, default_value=None):
    """Returns 0 if the user-provided value is less than 0."""
    return max(self.GetIntParam(query_param_name, default_value=default_value),
               0)

  def GetListParam(self, query_param_name, default_value=None):
    """Get a list of strings from the URL or default."""
    params = self.request.params.get(query_param_name)
    if params is None:
      return default_value
    if not params:
      return []
    return params.split(',')

  def GetIntListParam(self, query_param_name, default_value=None):
    """Get a list of ints from the URL or default."""
    param_list = self.GetListParam(query_param_name)
    if param_list is None:
      return default_value

    try:
      return [int(p) for p in param_list]
    except (TypeError, ValueError):
      raise exceptions.InputException('Invalid value for integer list param')

  def GetBoolParam(self, query_param_name, default_value=None):
    """Get a boolean param from the URL or default."""
    value = self.request.params.get(query_param_name)
    if value is None:
      return default_value

    # Any present value other than empty or 'false' counts as True.
    if (not value) or (value.lower() == 'false'):
      return False
    return True
+
+
+def _ParsePathIdentifiers(path):
+ """Parse out the workspace being requested (if any).
+
+ Args:
+ path: A string beginning with the request's path info.
+
+ Returns:
+ (viewed_user_val, project_name).
+ """
+ viewed_user_val = None
+ project_name = None
+ hotlist_id = None
+ hotlist_name = None
+
+ # Strip off any query params
+ split_path = path.lstrip('/').split('?')[0].split('/')
+ if len(split_path) >= 2:
+ if split_path[0] == 'hotlists':
+ if split_path[1].isdigit():
+ hotlist_id = int(split_path[1])
+ if split_path[0] == 'p':
+ project_name = split_path[1]
+ if split_path[0] == 'u' or split_path[0] == 'users':
+ viewed_user_val = urllib.unquote(split_path[1])
+ if len(split_path) >= 4 and split_path[2] == 'hotlists':
+ try:
+ hotlist_id = int(
+ urllib.unquote(split_path[3].split('.')[0]))
+ except ValueError:
+ raw_last_path = (split_path[3][:-3] if
+ split_path[3].endswith('.do') else split_path[3])
+ last_path = urllib.unquote(raw_last_path)
+ match = framework_bizobj.RE_HOTLIST_NAME.match(
+ last_path)
+ if not match:
+ raise exceptions.InputException(
+ 'Could not parse hotlist id or name')
+ else:
+ hotlist_name = last_path.lower()
+
+ if split_path[0] == 'g':
+ viewed_user_val = urllib.unquote(split_path[1])
+
+ return viewed_user_val, project_name, hotlist_id, hotlist_name
+
+
+def _GetViewedEmail(viewed_user_val, cnxn, services):
+ """Returns the viewed user's email.
+
+ Args:
+ viewed_user_val: Could be either int (user_id) or str (email).
+ cnxn: connection to the SQL database.
+ services: Interface to all persistence storage backends.
+
+ Returns:
+ viewed_email
+ """
+ if not viewed_user_val:
+ return None
+
+ try:
+ viewed_userid = int(viewed_user_val)
+ viewed_email = services.user.LookupUserEmail(cnxn, viewed_userid)
+ if not viewed_email:
+ logging.info('userID %s not found', viewed_userid)
+ webapp2.abort(404, 'user not found')
+ except ValueError:
+ viewed_email = viewed_user_val
+
+ return viewed_email
+
+
def ParseColSpec(
    col_spec, max_parts=framework_constants.MAX_SORT_PARTS,
    ignore=None):
  """Split a string column spec into a list of column names.

  We dedup col parts because an attacker could try to DoS us or guess
  zero or one result by measuring the time to process a request that
  has a very long column list.

  Args:
    col_spec: a unicode string containing a list of labels.
    max_parts: optional int maximum number of parts to consider.
    ignore: optional list of column name parts to ignore.

  Returns:
    A list of the extracted labels. Non-alphanumeric
    characters other than the period will be stripped from the text.
  """
  cols = framework_constants.COLSPEC_COL_RE.findall(col_spec)
  result = []  # List of column headers with no duplicates.
  # Set of column parts that we have processed so far.
  seen = set()
  if ignore:
    # Pre-seed the dedup set with the ignored parts so they are skipped,
    # and raise the cap so they do not consume the caller's part budget.
    seen = set(ignore_col.lower() for ignore_col in ignore)
    max_parts += len(ignore)

  for col in cols:
    # Each column may be a composite like "a/b"; keep only the parts that
    # are new, within the part budget, and not unreasonably long.
    parts = []
    for part in col.split('/'):
      if (part.lower() not in seen and len(seen) < max_parts
          and len(part) < framework_constants.MAX_COL_LEN):
        parts.append(part)
        seen.add(part.lower())
    if parts:
      result.append('/'.join(parts))
  return result
diff --git a/framework/paginate.py b/framework/paginate.py
new file mode 100644
index 0000000..bbe0998
--- /dev/null
+++ b/framework/paginate.py
@@ -0,0 +1,202 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Classes that help display pagination widgets for result sets."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import base64
+import logging
+import hmac
+
+import ezt
+from google.protobuf import message
+
+import settings
+from framework import exceptions
+from framework import framework_helpers
+from services import secrets_svc
+from proto import secrets_pb2
+
+
def GeneratePageToken(request_contents, start):
  # type: (secrets_pb2.ListRequestContents, int) -> str
  """Encrypts a List request's contents and generates a next page token.

  Args:
    request_contents: ListRequestContents object that holds data given by the
      request.
    start: int start index that should be used for the subsequent request.

  Returns:
    String next_page_token that is a serialized PageTokenContents object.
  """
  # NOTE(review): no digestmod is given, so this relies on hmac's historical
  # MD5 default; Python 3.8+ requires an explicit digestmod — confirm the
  # intended hash before migrating.
  digester = hmac.new(secrets_svc.GetPaginationKey())
  digester.update(request_contents.SerializeToString())
  token_contents = secrets_pb2.PageTokenContents(
      start=start,
      encrypted_list_request_contents=digester.digest())
  serialized_token = token_contents.SerializeToString()
  # Page tokens must be URL-safe strings (see aip.dev/158)
  # and proto string fields must be utf-8 strings while
  # `SerializeToString()` returns binary bytes contained in a str type.
  # So we must encode with web-safe base64 format.
  # NOTE(review): b64encode is NOT the web-safe variant ('+' and '/' can
  # appear in its output); urlsafe_b64encode would match the stated intent,
  # but changing it must be coordinated with ValidateAndParsePageToken's
  # b64decode and invalidates previously issued tokens.
  return base64.b64encode(serialized_token)
+
+
def ValidateAndParsePageToken(token, request_contents):
  # type: (str, secrets_pb2.ListRequestContents) -> int
  """Returns the start index of the page if the token is valid.

  Args:
    token: String token given in a ListFoo API request.
    request_contents: ListRequestContents object that holds data given by the
      request.

  Returns:
    The start index that should be used when getting the requested page.

  Raises:
    PageTokenException: if the token is invalid or incorrect for the given
      request_contents.
  """
  token_contents = secrets_pb2.PageTokenContents()
  try:
    decoded_serialized_token = base64.b64decode(token)
    token_contents.ParseFromString(decoded_serialized_token)
  except (message.DecodeError, TypeError):
    raise exceptions.PageTokenException('Invalid page token.')

  start = token_contents.start
  # Recompute the token for this request at the claimed start index; a
  # mismatch means the client changed request parameters mid-pagination
  # (or forged the token).
  expected_token = GeneratePageToken(request_contents, start)
  # compare_digest does a constant-time comparison to thwart timing attacks.
  if hmac.compare_digest(token, expected_token):
    return start
  raise exceptions.PageTokenException(
      'Request parameters must match those from the previous request.')
+
+
+# If extracting items_per_page and start values from a MonorailRequest object,
+# keep in mind that mr.num and mr.GetPositiveIntParam may return different
+# values. mr.num is the result of calling mr.GetPositiveIntParam with a default
+# value.
class VirtualPagination(object):
  """Class to calc Prev and Next pagination links based on result counts."""

  def __init__(self, total_count, items_per_page, start, list_page_url=None,
               count_up=True, start_param_name='start', num_param_name='num',
               max_num=None, url_params=None, project_name=None):
    """Given 'num' and 'start' params, determine Prev and Next links.

    Args:
      total_count: total number of artifacts that satisfy the query.
      items_per_page: number of items to display on each page, e.g., 25.
      start: the start index of the pagination page.
      list_page_url: URL of the web application page that is displaying
        the list of artifacts. Used to build the Prev and Next URLs.
        If None, no URLs will be built.
      count_up: if False, count down from total_count.
      start_param_name: query string parameter name for the start value
        of the pagination page.
      num_param_name: query string parameter name for the number of items
        to show on a pagination page.
      max_num: optional limit on the value of the num param. If not given,
        settings.max_artifact_search_results_per_page is used.
      url_params: list of (param_name, param_value) we want to keep
        in any new urls.
      project_name: the name of the project we are operating in.
    """
    self.total_count = total_count
    self.prev_url = ''
    self.reload_url = ''
    self.next_url = ''

    if max_num is None:
      max_num = settings.max_artifact_search_results_per_page

    # Clamp the requested page size to the site-wide maximum.
    self.num = min(items_per_page, max_num)

    if count_up:
      # Counting up: the page covers [start, start + num), bounded by
      # total_count.
      self.start = start or 0
      self.last = min(self.total_count, self.start + self.num)
      prev_start = max(0, self.start - self.num)
      next_start = self.start + self.num
    else:
      # Counting down: the page covers (start - num, start], bounded by zero.
      self.start = start or self.total_count
      self.last = max(0, self.start - self.num)
      prev_start = min(self.total_count, self.start + self.num)
      next_start = self.start - self.num

    if list_page_url:
      if project_name:
        list_servlet_rel_url = '/p/%s%s' % (
            project_name, list_page_url)
      else:
        list_servlet_rel_url = list_page_url

      self.reload_url = framework_helpers.FormatURL(
          url_params, list_servlet_rel_url,
          **{start_param_name: self.start, num_param_name: self.num})

      # Only offer Prev/Next when there is actually somewhere to go.
      if prev_start != self.start:
        self.prev_url = framework_helpers.FormatURL(
            url_params, list_servlet_rel_url,
            **{start_param_name: prev_start, num_param_name: self.num})
      if ((count_up and next_start < self.total_count) or
          (not count_up and next_start >= 1)):
        self.next_url = framework_helpers.FormatURL(
            url_params, list_servlet_rel_url,
            **{start_param_name: next_start, num_param_name: self.num})

    # Hide the pagination widget when the current page is empty.
    self.visible = ezt.boolean(self.last != self.start)

    # Adjust indices to one-based values for display to users.
    if count_up:
      self.start += 1
    else:
      self.last += 1

  def DebugString(self):
    """Return a string that is useful in on-page debugging."""
    return '%s - %s of %s; prev_url:%s; next_url:%s' % (
        self.start, self.last, self.total_count, self.prev_url, self.next_url)
+
+
class ArtifactPagination(VirtualPagination):
  """Class to calc Prev and Next pagination links based on a results list."""

  def __init__(
      self, results, items_per_page, start, project_name, list_page_url,
      total_count=None, limit_reached=False, skipped=0, url_params=None):
    """Given 'num' and 'start' params, determine Prev and Next links.

    Args:
      results: a list of artifact ids that satisfy the query.
      items_per_page: number of items to display on each page, e.g., 25.
      start: the start index of the pagination page.
      project_name: the name of the project we are operating in.
      list_page_url: URL of the web application page that is displaying
        the list of artifacts. Used to build the Prev and Next URLs.
      total_count: specify total result count rather than the length of results
      limit_reached: optional boolean that indicates that more results could
        not be fetched because a limit was reached.
      skipped: optional int number of items that were skipped and left off the
        front of results.
      url_params: list of (param_name, param_value) we want to keep
        in any new urls.
    """
    if total_count is None:
      total_count = skipped + len(results)
    super(ArtifactPagination, self).__init__(
        total_count, items_per_page, start, list_page_url=list_page_url,
        project_name=project_name, url_params=url_params)

    self.limit_reached = ezt.boolean(limit_reached)
    # Slice out just the results that belong on the current page, converting
    # the one-based display start back into a zero-based list offset.
    first_offset = self.start - 1 - skipped
    last_offset = first_offset + self.num
    assert 0 <= first_offset <= last_offset
    self.visible_results = results[first_offset:last_offset]
diff --git a/framework/pbproxy_test_pb2.py b/framework/pbproxy_test_pb2.py
new file mode 100644
index 0000000..3c47ae1
--- /dev/null
+++ b/framework/pbproxy_test_pb2.py
@@ -0,0 +1,24 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Message classes for use by template_helpers_test."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+from protorpc import messages
+
+
class PBProxyExample(messages.Message):
  """A simple protocol buffer to test template_helpers.PBProxy."""
  # Optional string field; unset values are None.
  nickname = messages.StringField(1)
  # Boolean field with an explicit default, for testing default handling.
  invited = messages.BooleanField(2, default=False)
+
+
class PBProxyNested(messages.Message):
  """A simple protocol buffer to test template_helpers.PBProxy."""
  # Singly-nested message, for testing proxying of sub-messages.
  nested = messages.MessageField(PBProxyExample, 1)
  # Repeated scalar field, for testing proxying of value lists.
  multiple_strings = messages.StringField(2, repeated=True)
  # Repeated message field, for testing proxying of message lists.
  multiple_pbes = messages.MessageField(PBProxyExample, 3, repeated=True)
diff --git a/framework/permissions.py b/framework/permissions.py
new file mode 100644
index 0000000..eb40dc7
--- /dev/null
+++ b/framework/permissions.py
@@ -0,0 +1,1242 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Classes and functions to implement permission checking.
+
+The main data structure is a simple map from (user role, project status,
+project_access_level) to specific perms.
+
+A perm is simply a string that indicates that the user has a given
+permission. The servlets and templates can test whether the current
+user has permission to see a UI element or perform an action by
+testing for the presence of the corresponding perm in the user's
+permission set.
+
+The user role is one of admin, owner, member, outsider user, or anon.
+The project status is one of the project states defined in project_pb2,
+or a special constant defined below. Likewise for access level.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import bisect
+import collections
+import logging
+import time
+
+import ezt
+
+import settings
+from framework import framework_bizobj
+from framework import framework_constants
+from proto import project_pb2
+from proto import site_pb2
+from proto import tracker_pb2
+from proto import usergroup_pb2
+from tracker import tracker_bizobj
+
# Constants that define permissions.
# Note that perms with a leading "_" can never be granted
# to users who are not site admins.
VIEW = 'View'
EDIT_PROJECT = 'EditProject'
CREATE_PROJECT = 'CreateProject'
PUBLISH_PROJECT = '_PublishProject' # for making "doomed" projects LIVE
VIEW_DEBUG = '_ViewDebug' # on-page debugging info
EDIT_OTHER_USERS = '_EditOtherUsers' # can edit other user's prefs, ban, etc.
CUSTOMIZE_PROCESS = 'CustomizeProcess' # can use some enterprise features
VIEW_EXPIRED_PROJECT = '_ViewExpiredProject' # view long-deleted projects
# View the list of contributors even in hub-and-spoke projects.
VIEW_CONTRIBUTOR_LIST = 'ViewContributorList'

# Quota
VIEW_QUOTA = 'ViewQuota'
EDIT_QUOTA = 'EditQuota'

# Permissions for editing user groups
CREATE_GROUP = 'CreateGroup'
EDIT_GROUP = 'EditGroup'
DELETE_GROUP = 'DeleteGroup'
VIEW_GROUP = 'ViewGroup'

# Perms for Source tools
# TODO(jrobbins): Monorail is just issue tracking with no version control, so
# phase out use of the term "Commit", sometime after Monorail's initial launch.
COMMIT = 'Commit'

# Perms for issue tracking
CREATE_ISSUE = 'CreateIssue'
EDIT_ISSUE = 'EditIssue'
EDIT_ISSUE_OWNER = 'EditIssueOwner'
EDIT_ISSUE_SUMMARY = 'EditIssueSummary'
EDIT_ISSUE_STATUS = 'EditIssueStatus'
EDIT_ISSUE_CC = 'EditIssueCc'
EDIT_ISSUE_APPROVAL = 'EditIssueApproval'
DELETE_ISSUE = 'DeleteIssue'
# This allows certain API clients to attribute comments to other users.
# The permission is not offered in the UI, but it can be typed in as
# a custom permission name. The ID of the API client is also recorded.
IMPORT_COMMENT = 'ImportComment'
ADD_ISSUE_COMMENT = 'AddIssueComment'
VIEW_INBOUND_MESSAGES = 'ViewInboundMessages'
CREATE_HOTLIST = 'CreateHotlist'
# Note, there is no separate DELETE_ATTACHMENT perm. We
# allow a user to delete an attachment iff they could soft-delete
# the comment that holds the attachment.

# Note: the "_" in the perm name makes it impossible for a
# project owner to grant it to anyone as an extra perm.
ADMINISTER_SITE = '_AdministerSite'

# Permissions to soft-delete artifact comment
DELETE_ANY = 'DeleteAny'
DELETE_OWN = 'DeleteOwn'

# Granting this allows owners to delegate some team management work.
EDIT_ANY_MEMBER_NOTES = 'EditAnyMemberNotes'

# Permission to star/unstar any artifact.
SET_STAR = 'SetStar'

# Permission to flag any artifact as spam.
FLAG_SPAM = 'FlagSpam'
VERDICT_SPAM = 'VerdictSpam'
MODERATE_SPAM = 'ModerateSpam'

# Permissions for custom fields.
EDIT_FIELD_DEF = 'EditFieldDef'
EDIT_FIELD_DEF_VALUE = 'EditFieldDefValue'

# Permissions for user hotlists.
ADMINISTER_HOTLIST = 'AdministerHotlist'
EDIT_HOTLIST = 'EditHotlist'
VIEW_HOTLIST = 'ViewHotlist'
HOTLIST_OWNER_PERMISSIONS = [ADMINISTER_HOTLIST, EDIT_HOTLIST]
HOTLIST_EDITOR_PERMISSIONS = [EDIT_HOTLIST]

# NOTE(review): presumably approval statuses that only approvers may set —
# confirm against the EDIT_ISSUE_APPROVAL checks at the use sites.
RESTRICTED_APPROVAL_STATUSES = [
    tracker_pb2.ApprovalStatus.NA,
    tracker_pb2.ApprovalStatus.APPROVED,
    tracker_pb2.ApprovalStatus.NOT_APPROVED]

STANDARD_ADMIN_PERMISSIONS = [
    EDIT_PROJECT, CREATE_PROJECT, PUBLISH_PROJECT, VIEW_DEBUG,
    EDIT_OTHER_USERS, CUSTOMIZE_PROCESS,
    VIEW_QUOTA, EDIT_QUOTA, ADMINISTER_SITE,
    EDIT_ANY_MEMBER_NOTES, VERDICT_SPAM, MODERATE_SPAM]

STANDARD_ISSUE_PERMISSIONS = [
    VIEW, EDIT_ISSUE, ADD_ISSUE_COMMENT, DELETE_ISSUE, FLAG_SPAM]

# Monorail has no source control, but keep COMMIT for backward compatibility.
STANDARD_SOURCE_PERMISSIONS = [COMMIT]

STANDARD_COMMENT_PERMISSIONS = [DELETE_OWN, DELETE_ANY]

STANDARD_OTHER_PERMISSIONS = [CREATE_ISSUE, FLAG_SPAM, SET_STAR]

STANDARD_PERMISSIONS = (STANDARD_ADMIN_PERMISSIONS +
                        STANDARD_ISSUE_PERMISSIONS +
                        STANDARD_SOURCE_PERMISSIONS +
                        STANDARD_COMMENT_PERMISSIONS +
                        STANDARD_OTHER_PERMISSIONS)

# roles: the user-role component of keys into _PERMISSIONS_TABLE below.
SITE_ADMIN_ROLE = 'admin'
OWNER_ROLE = 'owner'
COMMITTER_ROLE = 'committer'
CONTRIBUTOR_ROLE = 'contributor'
USER_ROLE = 'user'
ANON_ROLE = 'anon'

# Project state out-of-band values for keys
UNDEFINED_STATUS = 'undefined_status'
UNDEFINED_ACCESS = 'undefined_access'
WILDCARD_ACCESS = 'wildcard_access'
+
+
class PermissionSet(object):
  """Class to represent the set of permissions available to the user."""

  def __init__(self, perm_names, consider_restrictions=True):
    """Create a PermissionSet with the given permissions.

    Args:
      perm_names: a list of permission name strings.
      consider_restrictions: if true, the user's permissions can be blocked
          by restriction labels on an artifact. Project owners and site
          admins do not consider restrictions so that they cannot
          "lock themselves out" of editing an issue.
    """
    # Permission names are compared case-insensitively, so normalize once.
    self.perm_names = frozenset(p.lower() for p in perm_names)
    self.consider_restrictions = consider_restrictions

  def __getattr__(self, perm_name):
    """Easy permission testing in EZT. E.g., [if-any perms.format_drive]."""
    # NOTE: __getattr__ intercepts every undefined attribute lookup on this
    # object, so a misspelled permission name silently yields an EZT False
    # rather than raising AttributeError.
    return ezt.boolean(self.HasPerm(perm_name, None, None))

  def CanUsePerm(
      self, perm_name, effective_ids, project, restriction_labels,
      granted_perms=None):
    """Return True if the user can use the given permission.

    Args:
      perm_name: string name of permission, e.g., 'EditIssue'.
      effective_ids: set of int user IDs for the user (including any groups),
          or an empty set if user is not signed in.
      project: Project PB for the project being accessed, or None if not
          in a project.
      restriction_labels: list of strings that restrict permission usage.
      granted_perms: optional list of lowercase strings of permissions that the
          user is granted only within the scope of one issue, e.g., by being
          named in a user-type custom field that grants permissions.

    Restriction labels have 3 parts, e.g.:
    'Restrict-EditIssue-InnerCircle' blocks the use of just the
    EditIssue permission, unless the user also has the InnerCircle
    permission. This allows fine-grained restrictions on specific
    actions, such as editing, commenting, or deleting.

    Restriction labels and permissions are case-insensitive.

    Returns:
      True if the user can use the given permission, or False
      if they cannot (either because they don't have that permission
      or because it is blocked by a relevant restriction label).
    """
    # TODO(jrobbins): room for performance improvement: avoid set creation and
    # repeated string operations.
    granted_perms = granted_perms or set()
    perm_lower = perm_name.lower()
    if perm_lower in granted_perms:
      return True

    needed_perms = {perm_lower}
    if self.consider_restrictions:
      for label in restriction_labels:
        label = label.lower()
        # format: Restrict-Action-ToThisPerm
        # NOTE(review): assumes every restriction label has at least two
        # dashes; a malformed label would raise ValueError here — confirm
        # labels are validated upstream.
        _kw, requested_perm, needed_perm = label.split('-', 2)
        if requested_perm == perm_lower and needed_perm not in granted_perms:
          needed_perms.add(needed_perm)

    if not effective_ids:
      effective_ids = {framework_constants.NO_USER_SPECIFIED}

    # Get all extra perms for all effective ids.
    # Id X might have perm A and Y might have B, if both A and B are needed
    # True should be returned.
    extra_perms = set()
    for user_id in effective_ids:
      extra_perms.update(p.lower() for p in GetExtraPerms(project, user_id))
    return all(self.HasPerm(perm, None, None, extra_perms)
               for perm in needed_perms)

  def HasPerm(self, perm_name, user_id, project, extra_perms=None):
    """Return True if the user has the given permission (ignoring user groups).

    Args:
      perm_name: string name of permission, e.g., 'EditIssue'.
      user_id: int user id of the user, or None if user is not signed in.
      project: Project PB for the project being accessed, or None if not
          in a project.
      extra_perms: list of extra perms. If not given, GetExtraPerms will be
          called to get them.

    Returns:
      True if the user has the given perm.
    """
    perm_name = perm_name.lower()

    # Return early if possible.
    if perm_name in self.perm_names:
      return True

    if extra_perms is None:
      # TODO(jrobbins): room for performance improvement: pre-compute
      # extra perms (maybe merge them into the perms object), avoid
      # redundant call to lower().
      return any(
          p.lower() == perm_name
          for p in GetExtraPerms(project, user_id))

    return perm_name in extra_perms

  def DebugString(self):
    """Return a useful string to show when debugging."""
    return 'PermissionSet(%s)' % ', '.join(sorted(self.perm_names))

  def __repr__(self):
    return '%s(%r)' % (self.__class__.__name__, self.perm_names)
+
+
EMPTY_PERMISSIONSET = PermissionSet([])

# Signed-out visitors to a publicly readable project.
READ_ONLY_PERMISSIONSET = PermissionSet([VIEW])

# Signed-in non-members in a project with access == ANYONE.
USER_PERMISSIONSET = PermissionSet([
    VIEW, FLAG_SPAM, SET_STAR,
    CREATE_ISSUE, ADD_ISSUE_COMMENT,
    DELETE_OWN])

CONTRIBUTOR_ACTIVE_PERMISSIONSET = PermissionSet(
    [VIEW,
     FLAG_SPAM, VERDICT_SPAM, SET_STAR,
     CREATE_ISSUE, ADD_ISSUE_COMMENT,
     DELETE_OWN])

CONTRIBUTOR_INACTIVE_PERMISSIONSET = PermissionSet(
    [VIEW])

COMMITTER_ACTIVE_PERMISSIONSET = PermissionSet(
    [VIEW, COMMIT, VIEW_CONTRIBUTOR_LIST,
     FLAG_SPAM, VERDICT_SPAM, SET_STAR, VIEW_QUOTA,
     CREATE_ISSUE, ADD_ISSUE_COMMENT, EDIT_ISSUE, VIEW_INBOUND_MESSAGES,
     DELETE_OWN])

COMMITTER_INACTIVE_PERMISSIONSET = PermissionSet(
    [VIEW, VIEW_CONTRIBUTOR_LIST,
     VIEW_INBOUND_MESSAGES, VIEW_QUOTA])

# Owners ignore restriction labels so they cannot lock themselves out.
OWNER_ACTIVE_PERMISSIONSET = PermissionSet(
    [VIEW, VIEW_CONTRIBUTOR_LIST, EDIT_PROJECT, COMMIT,
     FLAG_SPAM, VERDICT_SPAM, SET_STAR, VIEW_QUOTA,
     CREATE_ISSUE, ADD_ISSUE_COMMENT, EDIT_ISSUE, DELETE_ISSUE,
     VIEW_INBOUND_MESSAGES,
     DELETE_ANY, EDIT_ANY_MEMBER_NOTES],
    consider_restrictions=False)

OWNER_INACTIVE_PERMISSIONSET = PermissionSet(
    [VIEW, VIEW_CONTRIBUTOR_LIST, EDIT_PROJECT,
     VIEW_INBOUND_MESSAGES, VIEW_QUOTA],
    consider_restrictions=False)

# Granted to site admins everywhere, regardless of project membership
# (see GetPermissions below).
ADMIN_PERMISSIONSET = PermissionSet(
    [VIEW, VIEW_CONTRIBUTOR_LIST,
     CREATE_PROJECT, EDIT_PROJECT, PUBLISH_PROJECT, VIEW_DEBUG,
     COMMIT, CUSTOMIZE_PROCESS, FLAG_SPAM, VERDICT_SPAM, SET_STAR,
     ADMINISTER_SITE, VIEW_EXPIRED_PROJECT, EDIT_OTHER_USERS,
     VIEW_QUOTA, EDIT_QUOTA,
     CREATE_ISSUE, ADD_ISSUE_COMMENT, EDIT_ISSUE, DELETE_ISSUE,
     EDIT_ISSUE_APPROVAL,
     VIEW_INBOUND_MESSAGES,
     DELETE_ANY, EDIT_ANY_MEMBER_NOTES,
     CREATE_GROUP, EDIT_GROUP, DELETE_GROUP, VIEW_GROUP,
     MODERATE_SPAM, CREATE_HOTLIST],
    consider_restrictions=False)

# Granted to the borg service account, which may only manage user groups
# (see GetPermissions below).
GROUP_IMPORT_BORG_PERMISSIONSET = PermissionSet(
    [CREATE_GROUP, VIEW_GROUP, EDIT_GROUP])

# Permissions for project pages, e.g., the project summary page
_PERMISSIONS_TABLE = {

    # Project owners can view and edit artifacts in a LIVE project.
    (OWNER_ROLE, project_pb2.ProjectState.LIVE, WILDCARD_ACCESS):
      OWNER_ACTIVE_PERMISSIONSET,

    # Project owners can view, but not edit artifacts in ARCHIVED.
    # Note: EDIT_PROJECT is not enough permission to change an ARCHIVED project
    # back to LIVE if a delete_time was set.
    (OWNER_ROLE, project_pb2.ProjectState.ARCHIVED, WILDCARD_ACCESS):
      OWNER_INACTIVE_PERMISSIONSET,

    # Project members can view their own project, regardless of state.
    (COMMITTER_ROLE, project_pb2.ProjectState.LIVE, WILDCARD_ACCESS):
      COMMITTER_ACTIVE_PERMISSIONSET,
    (COMMITTER_ROLE, project_pb2.ProjectState.ARCHIVED, WILDCARD_ACCESS):
      COMMITTER_INACTIVE_PERMISSIONSET,

    # Project contributors can view their own project, regardless of state.
    (CONTRIBUTOR_ROLE, project_pb2.ProjectState.LIVE, WILDCARD_ACCESS):
      CONTRIBUTOR_ACTIVE_PERMISSIONSET,
    (CONTRIBUTOR_ROLE, project_pb2.ProjectState.ARCHIVED, WILDCARD_ACCESS):
      CONTRIBUTOR_INACTIVE_PERMISSIONSET,

    # Non-members users can read and comment in projects with access == ANYONE
    (USER_ROLE, project_pb2.ProjectState.LIVE,
     project_pb2.ProjectAccess.ANYONE):
      USER_PERMISSIONSET,

    # Anonymous users can only read projects with access == ANYONE.
    (ANON_ROLE, project_pb2.ProjectState.LIVE,
     project_pb2.ProjectAccess.ANYONE):
      READ_ONLY_PERMISSIONSET,

    # Permissions for site pages, e.g., creating a new project
    (USER_ROLE, UNDEFINED_STATUS, UNDEFINED_ACCESS):
      PermissionSet([CREATE_PROJECT, CREATE_GROUP, CREATE_HOTLIST]),
    }
+
def GetPermissions(user, effective_ids, project):
  """Return a permission set appropriate for the user and project.

  Args:
    user: The User PB for the signed-in user, or None for anon users.
    effective_ids: set of int user IDs for the current user and all user
        groups that they are a member of. This will be an empty set for
        anonymous users.
    project: either a Project protobuf, or None for a page whose scope is
        wider than a single project.

  Returns:
    a PermissionSet object for the current user and project (or for
    site-wide operations if project is None).

  If an exact match for the user's role and project status is found, that is
  returned. Otherwise, we look for permissions for the user's role that is
  not specific to any project status, or not specific to any project access
  level. If neither of those are defined, we give the user an empty
  permission set.
  """
  # Site admins get ADMIN_PERMISSIONSET regardless of groups or projects.
  if user and user.is_site_admin:
    return ADMIN_PERMISSIONSET

  # Grant the borg job permission to view/edit groups.
  if user and user.email == settings.borg_service_account:
    return GROUP_IMPORT_BORG_PERMISSIONSET

  # Anon users don't need to accumulate anything.
  if not effective_ids:
    anon_role, anon_status, anon_access = _GetPermissionKey(None, project)
    return _LookupPermset(anon_role, anon_status, anon_access)

  # Check for a signed-in user with no roles in the current project.
  if not project or not framework_bizobj.UserIsInProject(
      project, effective_ids):
    _ignored_role, status, access = _GetPermissionKey(None, project)
    return _LookupPermset(USER_ROLE, status, access)

  # A signed-in member gets the union of all their PermissionSets from the
  # table, one lookup per effective user id.
  union_of_perms = set()
  honor_restrictions = True
  for member_id in effective_ids:
    role, status, access = _GetPermissionKey(member_id, project)
    member_perms = _LookupPermset(role, status, access)
    union_of_perms.update(member_perms.perm_names)
    # If any role allows the user to ignore restriction labels, then
    # ignore them overall.
    if not member_perms.consider_restrictions:
      honor_restrictions = False

  return PermissionSet(
      union_of_perms, consider_restrictions=honor_restrictions)
+
+
+def UpdateIssuePermissions(
+ perms, project, issue, effective_ids, granted_perms=None, config=None):
+ """Update the PermissionSet for an specific issue.
+
+ Take into account granted permissions and label restrictions to filter the
+ permissions, and updates the VIEW and EDIT_ISSUE permissions depending on the
+ role of the user in the issue (i.e. owner, reporter, cc or approver).
+
+ Args:
+ perms: The PermissionSet to update.
+ project: The Project PB for the issue project.
+ issue: The Issue PB.
+ effective_ids: Set of int user IDs for the current user and all user
+ groups that they are a member of. This will be an empty set for
+ anonymous users.
+ granted_perms: optional list of strings of permissions that the user is
+ granted only within the scope of one issue, e.g., by being named in
+ a user-type custom field that grants permissions.
+ config: optional ProjectIssueConfig PB where granted perms should be
+ extracted from, if granted_perms is not given.
+ """
+ if config:
+ granted_perms = tracker_bizobj.GetGrantedPerms(
+ issue, effective_ids, config)
+ elif granted_perms is None:
+ granted_perms = []
+
+ # If the user has no permission to view the project, it has no permissions on
+ # this issue.
+ if not perms.HasPerm(VIEW, None, None):
+ return EMPTY_PERMISSIONSET
+
+ # Compute the restrictions for the given issue and store them in a dictionary
+ # of {perm: set(needed_perms)}.
+ restrictions = collections.defaultdict(set)
+ if perms.consider_restrictions:
+ for label in GetRestrictions(issue):
+ label = label.lower()
+ # format: Restrict-Action-ToThisPerm
+ _, requested_perm, needed_perm = label.split('-', 2)
+ restrictions[requested_perm.lower()].add(needed_perm.lower())
+
+ # Store the user permissions, and the extra permissions of all effective IDs
+ # in the given project.
+ all_perms = set(perms.perm_names)
+ for effective_id in effective_ids:
+ all_perms.update(p.lower() for p in GetExtraPerms(project, effective_id))
+
+ # And filter them applying the restriction labels.
+ filtered_perms = set()
+ for perm_name in all_perms:
+ perm_name = perm_name.lower()
+ restricted = any(
+ restriction not in all_perms and restriction not in granted_perms
+ for restriction in restrictions.get(perm_name, []))
+ if not restricted:
+ filtered_perms.add(perm_name)
+
+ # Add any granted permissions.
+ filtered_perms.update(granted_perms)
+
+  # The VIEW perm might have been removed due to restrictions, but the issue
+  # owner, reporter, cc'd users, and approvers can always view the issue.
+ allowed_ids = set(
+ tracker_bizobj.GetCcIds(issue)
+ + tracker_bizobj.GetApproverIds(issue)
+ + [issue.reporter_id, tracker_bizobj.GetOwnerId(issue)])
+ if effective_ids and not allowed_ids.isdisjoint(effective_ids):
+ filtered_perms.add(VIEW.lower())
+
+ # If the issue is deleted, only the VIEW and DELETE_ISSUE permissions are
+ # relevant.
+ if issue.deleted:
+ if VIEW.lower() not in filtered_perms:
+ return EMPTY_PERMISSIONSET
+ if DELETE_ISSUE.lower() in filtered_perms:
+ return PermissionSet([VIEW, DELETE_ISSUE], perms.consider_restrictions)
+ return PermissionSet([VIEW], perms.consider_restrictions)
+
+ # The EDIT_ISSUE permission might have been removed due to restrictions, but
+ # the owner always has permission to edit it.
+ if effective_ids and tracker_bizobj.GetOwnerId(issue) in effective_ids:
+ filtered_perms.add(EDIT_ISSUE.lower())
+
+ return PermissionSet(filtered_perms, perms.consider_restrictions)
+
+
def _LookupPermset(role, status, access):
  """Look up the best-matching PermissionSet in _PERMISSIONS_TABLE.

  Args:
    role: string name of the user's role in the project.
    status: a Project PB status value, or UNDEFINED_STATUS.
    access: a Project PB access value, or UNDEFINED_ACCESS.

  Returns:
    The PermissionSet registered for that exact (role, status, access)
    combination, else the wildcard-access entry for that role and status,
    else EMPTY_PERMISSIONSET.
  """
  exact_key = (role, status, access)
  if exact_key in _PERMISSIONS_TABLE:
    return _PERMISSIONS_TABLE[exact_key]
  wildcard_key = (role, status, WILDCARD_ACCESS)
  return _PERMISSIONS_TABLE.get(wildcard_key, EMPTY_PERMISSIONSET)
+
+
def _GetPermissionKey(user_id, project, expired_before=None):
  """Return the (role, status, access) lookup key for the user and project.

  Args:
    user_id: int user ID, or None for anonymous visitors.
    project: Project PB, or None when outside any project context.
    expired_before: optional time value used by tests of expiration.

  Returns:
    A (role, status, access) tuple suitable for _LookupPermset().
  """
  # Pick the most-privileged role that applies.  Roles are deliberately
  # not honored in expired projects.
  if user_id is None:
    role = ANON_ROLE
  elif project and IsExpired(project, expired_before=expired_before):
    role = USER_ROLE  # Do not honor roles in expired projects.
  elif project and user_id in project.owner_ids:
    role = OWNER_ROLE
  elif project and user_id in project.committer_ids:
    role = COMMITTER_ROLE
  elif project and user_id in project.contributor_ids:
    role = CONTRIBUTOR_ROLE
  else:
    role = USER_ROLE

  status = UNDEFINED_STATUS if project is None else project.state
  access = UNDEFINED_ACCESS if project is None else project.access

  return role, status, access
+
+
def GetExtraPerms(project, member_id):
  """Return a list of extra perms for the user in the project.

  Args:
    project: Project PB for the current project.
    member_id: user id of a project owner, member, or contributor.

  Returns:
    A list of extra permission strings granted to the specified user in
    this project; often empty.
  """
  _idx, extra_perms_pb = FindExtraPerms(project, member_id)
  return list(extra_perms_pb.perms) if extra_perms_pb else []
+
+
def FindExtraPerms(project, member_id):
  """Return an ExtraPerms PB for the given user in the project.

  Args:
    project: Project PB for the current project, or None if the user is
      not currently in a project.
    member_id: user ID of a project owner, member, or contributor.

  Returns:
    A pair (idx, extra_perms).
    * If project is None or member_id is not part of the project, both are None.
    * If member_id has no extra_perms, extra_perms is None, and idx points to
      the position where it should go to keep the ExtraPerms sorted in project.
    * Otherwise, idx is the position of member_id in the project's extra_perms,
      and extra_perms is an ExtraPerms PB.
  """
  if not project:
    # TODO(jrobbins): maybe define extra perms for site-wide operations.
    return None, None

  # Users who have no current role cannot have any extra perms. Don't
  # consider effective_ids (which includes user groups) for this check.
  if not framework_bizobj.UserIsInProject(project, {member_id}):
    return None, None

  # project.extra_perms is kept sorted by member_id, so binary-search the
  # list of member IDs for the insertion point of member_id.
  sorted_member_ids = [ep.member_id for ep in project.extra_perms]
  idx = bisect.bisect_left(sorted_member_ids, member_id)
  if idx == len(sorted_member_ids) or sorted_member_ids[idx] != member_id:
    return idx, None
  return idx, project.extra_perms[idx]
+
+
def GetCustomPermissions(project):
  """Return a sorted list of non-standard perms granted in a project."""
  custom_permissions = {
      perm
      for extra_perms in project.extra_perms
      for perm in extra_perms.perms
      if perm not in STANDARD_PERMISSIONS}
  return sorted(custom_permissions)
+
+
def UserCanViewProject(user, effective_ids, project, expired_before=None):
  """Return True if the user can view the given project.

  Args:
    user: User protobuf for the user trying to view the project.
    effective_ids: set of int user IDs of the user trying to view the project
      (including any groups), or an empty set for anonymous users.
    project: the Project protobuf to check.
    expired_before: optional time value for testing.

  Returns:
    True if the user should be allowed to view the project.
  """
  perms = GetPermissions(user, effective_ids, project)
  # Viewing an expired (reap-eligible) project needs a stronger perm.
  if IsExpired(project, expired_before=expired_before):
    needed_perm = VIEW_EXPIRED_PROJECT
  else:
    needed_perm = VIEW
  return perms.CanUsePerm(needed_perm, effective_ids, project, [])
+
+
def IsExpired(project, expired_before=None):
  """Return True if a project deletion has been pending long enough already.

  Args:
    project: The project being viewed.
    expired_before: If supplied, this method will return True only if the
      project expired before the given time.

  Returns:
    True if the project is eligible for reaping.
  """
  # Only archived projects can expire.
  if project.state != project_pb2.ProjectState.ARCHIVED:
    return False

  cutoff = expired_before if expired_before is not None else int(time.time())
  return project.delete_time and project.delete_time < cutoff
+
+
def CanDeleteComment(comment, commenter, user_id, perms):
  """Returns true if the user can (un)delete the given comment.

  UpdateIssuePermissions must have been called first.

  Args:
    comment: An IssueComment PB object.
    commenter: An User PB object with the user who created the comment.
    user_id: The ID of the user whose permission we want to check.
    perms: The PermissionSet with the issue permissions.

  Returns:
    True if the user can (un)delete the comment.
  """
  # Anonymous users and users with no permissions can never delete.
  if not user_id or not perms:
    return False

  # Comments by banned users and spam comments cannot be (un)deleted by
  # anyone; spam comments should be un-flagged instead.
  if commenter.banned or comment.is_spam:
    return False

  # Site admins and project owners can delete any comment.
  if perms.HasPerm(DELETE_ANY, None, None, []):
    return True

  # A comment deleted by someone else cannot be undeleted by this user.
  if comment.deleted_by and comment.deleted_by != user_id:
    return False

  # Users may delete their own comments if they hold DeleteOwn.
  if comment.user_id == user_id and perms.HasPerm(DELETE_OWN, None, None, []):
    return True

  return False
+
+
def CanFlagComment(comment, commenter, comment_reporters, user_id, perms):
  """Returns true if the user can flag the given comment.

  UpdateIssuePermissions must have been called first.
  Assumes that the user has permission to view the issue.

  Args:
    comment: An IssueComment PB object.
    commenter: An User PB object with the user who created the comment.
    comment_reporters: A collection of user IDs who flagged the comment as spam.
    user_id: The ID of the user for whom we're checking permissions.
    perms: The PermissionSet with the issue permissions.

  Returns:
    A tuple (can_flag, is_flagged).  can_flag is True if the user can flag
    the comment, and is_flagged is True if the user sees the comment marked
    as spam.
  """
  # Comments by banned users cannot be flagged by anyone.
  if commenter.banned:
    return False, comment.is_spam

  # A comment deleted for a reason other than spam cannot be (un-)flagged.
  if comment.deleted_by and not comment.is_spam:
    return False, comment.is_spam

  # Users with VerdictSpam see the true spam state and may change it.
  # Once flagged as spam, everyone sees it flagged, but only VerdictSpam
  # holders may un-flag it.
  has_verdict_spam = perms.HasPerm(VERDICT_SPAM, None, None, [])
  if has_verdict_spam or comment.is_spam:
    return has_verdict_spam, comment.is_spam

  # Otherwise, the user may report the comment as spam if they hold
  # FlagSpam, and they see it flagged only if they personally reported it.
  has_flag_spam = perms.HasPerm(FLAG_SPAM, None, None, [])
  return has_flag_spam, user_id in comment_reporters
+
+
def CanViewComment(comment, commenter, user_id, perms):
  """Returns true if the user can view the given comment.

  UpdateIssuePermissions must have been called first.
  Assumes that the user has permission to view the issue.

  Args:
    comment: An IssueComment PB object.
    commenter: An User PB object with the user who created the comment.
    user_id: The ID of the user whose permission we want to check.
    perms: The PermissionSet with the issue permissions.

  Returns:
    True if the user can view the comment.
  """
  # Comments by banned users are hidden from everyone.
  if commenter.banned:
    return False

  if comment.is_spam:
    # Visibility of spam-flagged comments tracks the ability to un-flag
    # them, which does not depend on who reported the comment as spam.
    can_flag, _is_flagged = CanFlagComment(
        comment, commenter, [], user_id, perms)
    return can_flag

  if comment.deleted_by:
    # Visibility of deleted comments tracks the ability to un-delete them.
    return CanDeleteComment(comment, commenter, user_id, perms)

  return True
+
+
def CanViewInboundMessage(comment, user_id, perms):
  """Returns true if the user can view the given comment's inbound message.

  UpdateIssuePermissions must have been called first.
  Assumes that the user has permission to view the comment.

  Args:
    comment: An IssueComment PB object.
    user_id: The ID of the user whose permission we want to check.
    perms: The PermissionSet with the issue permissions.

  Returns:
    True if the user can view the comment's inbound message.
  """
  # The comment's author may always view their own inbound message.
  return (perms.HasPerm(VIEW_INBOUND_MESSAGES, None, None, [])
          or comment.user_id == user_id)
+
+
def CanView(effective_ids, perms, project, restrictions, granted_perms=None):
  """Checks if user has permission to view an issue."""
  can_view = perms.CanUsePerm(
      VIEW, effective_ids, project, restrictions, granted_perms=granted_perms)
  return can_view
+
+
def CanCreateProject(perms):
  """Return True if the given user may create a project.

  Args:
    perms: Permissionset for the current user.

  Returns:
    True if the user should be allowed to create a project.
  """
  restriction = settings.project_creation_restriction
  # "ANYONE" means anyone who has the needed perm.
  if restriction == site_pb2.UserTypeRestriction.ANYONE:
    return perms.HasPerm(CREATE_PROJECT, None, None)
  if restriction == site_pb2.UserTypeRestriction.ADMIN_ONLY:
    return perms.HasPerm(ADMINISTER_SITE, None, None)
  return False
+
+
def CanCreateGroup(perms):
  """Return True if the given user may create a user group.

  Args:
    perms: Permissionset for the current user.

  Returns:
    True if the user should be allowed to create a group.
  """
  restriction = settings.group_creation_restriction
  # "ANYONE" means anyone who has the needed perm.
  if restriction == site_pb2.UserTypeRestriction.ANYONE:
    return perms.HasPerm(CREATE_GROUP, None, None)
  if restriction == site_pb2.UserTypeRestriction.ADMIN_ONLY:
    return perms.HasPerm(ADMINISTER_SITE, None, None)
  return False
+
+
def CanEditGroup(perms, effective_ids, group_owner_ids):
  """Return True if the given user may edit a user group.

  Args:
    perms: Permissionset for the current user.
    effective_ids: set of user IDs for the logged in user.
    group_owner_ids: set of user IDs of the user group owners.

  Returns:
    True if the user should be allowed to edit the group.
  """
  if perms.HasPerm(EDIT_GROUP, None, None):
    return True
  # Group owners may always edit their own group.
  return not effective_ids.isdisjoint(group_owner_ids)
+
+
def CanViewGroupMembers(perms, effective_ids, group_settings, member_ids,
                        owner_ids, user_project_ids):
  """Return True if the given user may view a user group's members.

  Args:
    perms: Permissionset for the current user.
    effective_ids: set of user IDs for the logged in user.
    group_settings: PB of UserGroupSettings.
    member_ids: A list of member ids of this user group.
    owner_ids: A list of owner ids of this user group.
    user_project_ids: A list of project ids which the user has a role.

  Returns:
    True if the user should be allowed to view the group's members.
  """
  if perms.HasPerm(VIEW_GROUP, None, None):
    return True

  # Membership in a "friend" project of the group also grants visibility.
  if (group_settings.friend_projects and user_project_ids
      and set(group_settings.friend_projects).intersection(user_project_ids)):
    return True

  visibility = group_settings.who_can_view_members
  if visibility == usergroup_pb2.MemberVisibility.OWNERS:
    return not effective_ids.isdisjoint(owner_ids)
  if visibility == usergroup_pb2.MemberVisibility.MEMBERS:
    return (not effective_ids.isdisjoint(member_ids) or
            not effective_ids.isdisjoint(owner_ids))
  # Any other visibility setting means anyone may view.
  return True
+
+
def IsBanned(user, user_view):
  """Return True if this user is banned from using our site."""
  if user is None:
    return False  # Anonymous browsing is always welcome.

  if user.banned:
    return True  # The "Banned" checkbox was checked for this user.

  # Some spammers create many accounts with the same domain.
  if user_view and user_view.domain in settings.banned_user_domains:
    return True

  # Spammers can make plus-addr Google accounts in unexpected domains.
  if '+' in (user.email or ''):
    return True

  return False
+
+
def CanBan(mr, services):
  """Return True if the user is allowed to ban other users, site-wide."""
  if mr.perms.HasPerm(ADMINISTER_SITE, None, None):
    return True

  # Owners of at least one project may also ban users.
  owned, _committer, _contrib = services.project.GetUserRolesInAllProjects(
      mr.cnxn, mr.auth.effective_ids)
  return len(owned) > 0
+
+
def CanExpungeUsers(mr):
  """Return True if the user is allowed to delete user accounts."""
  return mr.perms.HasPerm(ADMINISTER_SITE, None, None)
+
+
def CanViewContributorList(mr, project):
  """Return True if we should display the list of project contributors.

  This is used on the project summary page, when deciding to offer the
  project People page link, and when generating autocomplete options
  that include project members.

  Args:
    mr: commonly used info parsed from the request.
    project: the Project we're interested in.

  Returns:
    True if we should display the project contributor list.
  """
  if not project:
    return False  # We are not even in a project context.

  if not project.only_owners_see_contributors:
    return True  # Contributor list is not restricted.

  # If it is hub-and-spoke, check for the perm that allows the user to
  # view it anyway.
  return mr.perms.HasPerm(
      VIEW_CONTRIBUTOR_LIST, mr.auth.user_id, project)
+
+
def ShouldCheckForAbandonment(mr):
  """Return True if user should be warned before changing/deleting their role.

  Args:
    mr: common info parsed from the user's request.

  Returns:
    True if user should be warned before changing/deleting their role.
  """
  # Site admins keep access regardless, so they never need the warning.
  is_admin = mr.perms.CanUsePerm(
      ADMINISTER_SITE, mr.auth.effective_ids, mr.project, [])
  if is_admin:
    return False

  # Only warn users who can currently edit the project.
  return mr.perms.CanUsePerm(
      EDIT_PROJECT, mr.auth.effective_ids, mr.project, [])
+
+
# For speed, we cache labels that we have already classified as being
# restriction labels or not.  These sets cover restrictions in general,
# not any particular perm.
_KNOWN_RESTRICTION_LABELS = set()
_KNOWN_NON_RESTRICTION_LABELS = set()


def IsRestrictLabel(label, perm=''):
  """Returns True if a given label is a restriction label.

  Args:
    label: string for the label to examine.
    perm: a permission that can be restricted (e.g. 'View' or 'Edit').
        Defaults to '' to mean 'any'.

  Returns:
    True if a given label is a restriction label (of the specified perm)
  """
  # Fast paths: consult the caches of previously classified labels.
  if label in _KNOWN_NON_RESTRICTION_LABELS:
    return False
  if not perm and label in _KNOWN_RESTRICTION_LABELS:
    return True

  # Restriction labels have the form Restrict-<Action>-<NeededPerm>.
  wanted_prefix = ('restrict-%s-' % perm.lower()) if perm else 'restrict-'
  matched = label.lower().startswith(wanted_prefix) and label.count('-') >= 2

  # Cache the verdict.  A perm-specific miss is not cached as a
  # non-restriction, because the label might restrict some other perm.
  if matched:
    _KNOWN_RESTRICTION_LABELS.add(label)
  elif not perm:
    _KNOWN_NON_RESTRICTION_LABELS.add(label)

  return matched
+
+
def HasRestrictions(issue, perm=''):
  """Return True if the issue has any restrictions (on the specified perm)."""
  all_labels = list(issue.labels) + list(issue.derived_labels)
  return any(IsRestrictLabel(lab, perm=perm) for lab in all_labels)
+
+
def GetRestrictions(issue, perm=''):
  """Return a list of lowercased restriction labels on the given issue."""
  if not issue:
    return []

  labels = tracker_bizobj.GetLabels(issue)
  return [lab.lower() for lab in labels if IsRestrictLabel(lab, perm=perm)]
+
+
def CanViewIssue(
    effective_ids, perms, project, issue, allow_viewing_deleted=False,
    granted_perms=None):
  """Checks if user has permission to view an artifact.

  Args:
    effective_ids: set of user IDs for the logged in user and any user
        group memberships. Should be an empty set for anon users.
    perms: PermissionSet for the user.
    project: Project PB for the project that contains this issue.
    issue: Issue PB for the issue being viewed.
    allow_viewing_deleted: True if the user should be allowed to view
        deleted artifacts.
    granted_perms: optional list of strings of permissions that the user is
        granted only within the scope of one issue, e.g., by being named in
        a user-type custom field that grants permissions.

  Returns:
    True iff the user can view the specified issue.
  """
  if issue.deleted and not allow_viewing_deleted:
    return False

  issue_perms = UpdateIssuePermissions(
      perms, project, issue, effective_ids, granted_perms=granted_perms)
  return issue_perms.HasPerm(VIEW, None, None)
+
+
def CanEditIssue(effective_ids, perms, project, issue, granted_perms=None):
  """Return True if a user can edit an issue.

  Args:
    effective_ids: set of user IDs for the logged in user and any user
        group memberships. Should be an empty set for anon users.
    perms: PermissionSet for the user.
    project: Project PB for the project that contains this issue.
    issue: Issue PB for the issue being viewed.
    granted_perms: optional list of strings of permissions that the user is
        granted only within the scope of one issue, e.g., by being named in
        a user-type custom field that grants permissions.

  Returns:
    True iff the user can edit the specified issue.
  """
  issue_perms = UpdateIssuePermissions(
      perms, project, issue, effective_ids, granted_perms=granted_perms)
  return issue_perms.HasPerm(EDIT_ISSUE, None, None)
+
+
def CanCommentIssue(effective_ids, perms, project, issue, granted_perms=None):
  """Return True if a user can comment on an issue."""
  restrictions = GetRestrictions(issue)
  return perms.CanUsePerm(
      ADD_ISSUE_COMMENT, effective_ids, project, restrictions,
      granted_perms=granted_perms)
+
+
def CanUpdateApprovalStatus(
    effective_ids, perms, project, approver_ids, new_status):
  """Return True if a user can change the approval status to the new status."""
  # Approval approvers can always change the approval status, and anyone
  # may set a non-restricted status.
  is_approver = not effective_ids.isdisjoint(approver_ids)
  if is_approver or new_status not in RESTRICTED_APPROVAL_STATUSES:
    return True

  return perms.CanUsePerm(EDIT_ISSUE_APPROVAL, effective_ids, project, [])
+
+
def CanUpdateApprovers(effective_ids, perms, project, current_approver_ids):
  """Return True if a user can edit the list of approvers for an approval."""
  # Current approvers may always edit the approver list themselves.
  is_current_approver = not effective_ids.isdisjoint(current_approver_ids)
  return is_current_approver or perms.CanUsePerm(
      EDIT_ISSUE_APPROVAL, effective_ids, project, [])
+
+
def CanViewComponentDef(effective_ids, perms, project, component_def):
  """Return True if a user can view the given component definition."""
  is_component_admin = not effective_ids.isdisjoint(component_def.admin_ids)
  if is_component_admin:
    return True  # Component admins can view that component.

  # TODO(jrobbins): check restrictions on the component definition.
  return perms.CanUsePerm(VIEW, effective_ids, project, [])
+
+
def CanEditComponentDef(effective_ids, perms, project, component_def, config):
  """Return True if a user can edit the given component definition."""
  if not effective_ids.isdisjoint(component_def.admin_ids):
    return True  # Component admins can edit that component.

  # Admins of any ancestor component may also edit this component.
  ancestors = tracker_bizobj.FindAncestorComponents(config, component_def)
  if any(not effective_ids.isdisjoint(parent.admin_ids)
         for parent in ancestors):
    return True

  return perms.CanUsePerm(EDIT_PROJECT, effective_ids, project, [])
+
+
def CanViewFieldDef(effective_ids, perms, project, field_def):
  """Return True if a user can view the given field definition."""
  is_field_admin = not effective_ids.isdisjoint(field_def.admin_ids)
  if is_field_admin:
    return True  # Field admins can view that field.

  # TODO(jrobbins): check restrictions on the field definition.
  return perms.CanUsePerm(VIEW, effective_ids, project, [])
+
+
def CanEditFieldDef(effective_ids, perms, project, field_def):
  """Return True if a user can edit the given field definition."""
  is_field_admin = not effective_ids.isdisjoint(field_def.admin_ids)
  if is_field_admin:
    return True  # Field admins can edit that field.

  return perms.CanUsePerm(EDIT_PROJECT, effective_ids, project, [])
+
+
def CanEditValueForFieldDef(effective_ids, perms, project, field_def):
  """Return True if a user can edit the given field definition value.
  This method does not check that a user can edit the project issues."""
  # Anonymous users can never edit field values.
  if not effective_ids:
    return False
  # Unrestricted fields may be set by anyone; restricted fields require the
  # user to be a designated editor or a field/project admin.
  allowed = (
      not field_def.is_restricted_field or
      not effective_ids.isdisjoint(field_def.editor_ids) or
      CanEditFieldDef(effective_ids, perms, project, field_def))
  return allowed
+
+
def CanViewTemplate(effective_ids, perms, project, template):
  """Return True if a user can view the given issue template."""
  if not effective_ids.isdisjoint(template.admin_ids):
    return True  # Template admins can view that template.

  if template.members_only:
    # Members-only templates are only shown to project members.
    return framework_bizobj.UserIsInProject(project, effective_ids)

  # Other templates are shown to anyone allowed to view project content.
  return perms.CanUsePerm(VIEW, effective_ids, project, [])
+
+
def CanEditTemplate(effective_ids, perms, project, template):
  """Return True if a user can edit the given issue template."""
  if not effective_ids.isdisjoint(template.admin_ids):
    return True  # Template admins can edit that template.

  return perms.CanUsePerm(EDIT_PROJECT, effective_ids, project, [])
+
+
def CanViewHotlist(effective_ids, perms, hotlist):
  """Return True if a user can view the given hotlist."""
  # Public hotlists and site admins are always allowed.
  if not hotlist.is_private or perms.HasPerm(ADMINISTER_SITE, None, None):
    return True

  # Otherwise, the user must be an owner or editor of the hotlist.
  hotlist_members = hotlist.owner_ids + hotlist.editor_ids
  return any(user_id in hotlist_members for user_id in effective_ids)
+
+
def CanEditHotlist(effective_ids, perms, hotlist):
  """Return True if a user is editor(add/remove issues and change rankings)."""
  if perms.HasPerm(ADMINISTER_SITE, None, None):
    return True
  hotlist_members = hotlist.owner_ids + hotlist.editor_ids
  return any(user_id in hotlist_members for user_id in effective_ids)
+
+
def CanAdministerHotlist(effective_ids, perms, hotlist):
  """Return True if user is owner(add/remove members, edit/delete hotlist)."""
  if perms.HasPerm(ADMINISTER_SITE, None, None):
    return True
  return any(user_id in hotlist.owner_ids for user_id in effective_ids)
+
+
def CanCreateHotlist(perms):
  """Return True if the given user may create a hotlist.

  Args:
    perms: Permissionset for the current user.

  Returns:
    True if the user should be allowed to create a hotlist.
  """
  # "ANYONE" means anyone who has the needed perm.
  if (settings.hotlist_creation_restriction ==
      site_pb2.UserTypeRestriction.ANYONE):
    return perms.HasPerm(CREATE_HOTLIST, None, None)

  if (settings.hotlist_creation_restriction ==
      site_pb2.UserTypeRestriction.ADMIN_ONLY):
    return perms.HasPerm(ADMINISTER_SITE, None, None)

  # Deny under any other restriction setting.  This explicit False matches
  # the sibling functions CanCreateProject and CanCreateGroup; previously
  # this function implicitly returned None here.
  return False
+
+
class Error(Exception):
  """Base exception class for all errors raised by this module."""
+
+
class PermissionException(Error):
  """Raised when the user is not authorized to make the current request."""
+
+
class BannedUserException(Error):
  """Raised when a user who has been banned attempts to use our service."""
diff --git a/framework/profiler.py b/framework/profiler.py
new file mode 100644
index 0000000..362585f
--- /dev/null
+++ b/framework/profiler.py
@@ -0,0 +1,200 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""A simple profiler object to track how time is spent on a request.
+
+The profiler is called from application code at the beginning and
+end of each major phase and subphase of processing. The profiler
+object keeps track of how much time was spent on each phase or subphase.
+
+This class is useful when developers need to understand where
+server-side time is being spent. It includes durations in
+milliseconds, and a simple bar chart on the HTML page.
+
+On-page debugging and performance info is useful because it makes it easier
+to explore performance interactively.
+"""
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+
+import datetime
+import logging
+import random
+import re
+import threading
+import time
+
+from infra_libs import ts_mon
+
+from contextlib import contextmanager
+
+from google.appengine.api import app_identity
+
# ts_mon distribution metric recording how long each profiler phase took,
# in milliseconds, labeled by (normalized) phase name.
PHASE_TIME = ts_mon.CumulativeDistributionMetric(
    'monorail/servlet/phase_time',
    'Time spent in profiler phases, in ms',
    [ts_mon.StringField('phase')])

# trace_service requires names less than 128 bytes
# https://cloud.google.com/trace/docs/reference/v1/rest/v1/projects.traces#Trace
MAX_PHASE_NAME_LENGTH = 128
+
+
class Profiler(object):
  """Object to record and help display request processing profiling info.

  The Profiler class holds a list of phase objects, which can hold additional
  phase objects (which are subphases).  Each phase or subphase represents some
  meaningful part of this application's HTTP request processing.
  """

  # Rotating palette of 3-digit hex colors for the on-page bar chart.
  # NOTE(review): '630' appears twice in this list; confirm whether one of
  # them was meant to be a different color before changing it.
  _COLORS = ['900', '090', '009', '360', '306', '036',
             '630', '630', '063', '333']

  def __init__(self, opt_trace_context=None, opt_trace_service=None):
    """Each request processing profile begins with an empty list of phases.

    Args:
      opt_trace_context: optional 'TRACE_ID/SPAN_ID;o=TRACE_TRUE' string
          taken from the incoming request, used when reporting traces.
      opt_trace_service: optional Cloud Trace API client used by
          ReportTrace() to upload spans.
    """
    # The top phase covers the entire profile; -1 means it has no color.
    self.top_phase = _Phase('overall profile', -1, None)
    self.current_phase = self.top_phase
    self.next_color = 0
    # Remember the creating thread; only that thread is profiled.
    self.original_thread_id = threading.current_thread().ident
    self.trace_context = opt_trace_context
    self.trace_service = opt_trace_service
    self.project_id = app_identity.get_application_id()

  def StartPhase(self, name='unspecified phase'):
    """Begin a (sub)phase by pushing a new phase onto a stack."""
    if self.original_thread_id != threading.current_thread().ident:
      return  # We only profile the main thread.
    color = self._COLORS[self.next_color % len(self._COLORS)]
    self.next_color += 1
    self.current_phase = _Phase(name, color, self.current_phase)

  def EndPhase(self):
    """End a (sub)phase by popping the phase stack."""
    if self.original_thread_id != threading.current_thread().ident:
      return  # We only profile the main thread.
    self.current_phase = self.current_phase.End()

  @contextmanager
  def Phase(self, name='unspecified phase'):
    """Context manager to automatically begin and end (sub)phases."""
    self.StartPhase(name)
    try:
      yield
    finally:
      self.EndPhase()

  def LogStats(self):
    """Log sufficiently-long phases and subphases, for debugging purposes."""
    self.top_phase.End()
    lines = ['Stats:']
    self.top_phase.AccumulateStatLines(self.top_phase.elapsed_seconds, lines)
    logging.info('\n'.join(lines))

  def ReportTrace(self):
    """Send a profile trace to Google Cloud Tracing."""
    self.top_phase.End()
    spans = self.top_phase.SpanJson()
    if not self.trace_service or not self.trace_context:
      logging.info('would have sent trace: %s', spans)
      return

    # Format of trace_context: 'TRACE_ID/SPAN_ID;o=TRACE_TRUE'
    # (from https://cloud.google.com/trace/docs/troubleshooting#force-trace)
    # TODO(crbug/monorail:7086): Respect the o=TRACE_TRUE part.
    # Note: on App Engine it seems ';o=1' is omitted rather than set to 0.
    trace_id, root_span_id = self.trace_context.split(';')[0].split('/')
    for span in spans:
      # TODO(crbug/monorail:7087): Consider setting `parentSpanId` to
      # `root_span_id` for the children of `top_phase`.
      if 'parentSpanId' not in span:
        span['parentSpanId'] = root_span_id
    traces_body = {
        'projectId': self.project_id,
        'traceId': trace_id,
        'spans': spans,
    }
    body = {
        'traces': [traces_body]
    }
    # TODO(crbug/monorail:7088): Do this async so it doesn't delay the response.
    request = self.trace_service.projects().patchTraces(
        projectId=self.project_id, body=body)
    _res = request.execute()
+
+
class _Phase(object):
  """A _Phase instance represents a period of time during request processing."""

  def __init__(self, name, color, parent):
    """Initialize a (sub)phase with the given name and current system clock.

    Args:
      name: human-readable phase name; truncated to the trace_service
          128-byte span-name limit.
      color: 3-digit hex color string for the on-page bar chart, or -1 for
          the top-level phase (which is never drawn).
      parent: the enclosing _Phase, or None for the top-level phase.
    """
    self.start = time.time()
    self.name = name[:MAX_PHASE_NAME_LENGTH]
    self.color = color
    self.subphases = []
    self.elapsed_seconds = None
    self.ms = 'in_progress'  # shown if the phase never records a finish.
    self.uncategorized_ms = None
    self.parent = parent
    if self.parent is not None:
      self.parent._RegisterSubphase(self)

    # Random 64-bit ID used as the Cloud Trace span ID.
    self.id = str(random.getrandbits(64))

  def _RegisterSubphase(self, subphase):
    """Add a subphase to this phase."""
    self.subphases.append(subphase)

  def End(self):
    """Record the time between the start and end of this (sub)phase.

    Returns:
      The parent _Phase, so callers can pop back up the phase stack.
    """
    self.elapsed_seconds = time.time() - self.start
    self.ms = int(self.elapsed_seconds * 1000)
    for sub in self.subphases:
      if sub.elapsed_seconds is None:
        # logging.warn is deprecated; use logging.warning.
        logging.warning('issue3182: subphase is %r', sub and sub.name)
    categorized = sum(sub.elapsed_seconds or 0.0 for sub in self.subphases)
    self.uncategorized_ms = int((self.elapsed_seconds - categorized) * 1000)
    return self.parent

  def AccumulateStatLines(self, total_seconds, lines, indent=''):
    """Append a formatted stat line for this phase subtree to `lines`.

    Args:
      total_seconds: elapsed seconds of the overall profile, used as the
          denominator when computing each phase's percentage.
      lines: list of strings that output lines are appended to.
      indent: leading whitespace indicating subphase nesting depth.
    """
    # Only phases that took longer than 30ms are interesting.
    if self.ms <= 30:
      return

    # Use true division here: with floor division (elapsed // total * 100)
    # every subphase shorter than the whole profile rounded down to 0%.
    percent = self.elapsed_seconds * 100 / total_seconds
    lines.append('%s%5d ms (%2d%%): %s' % (indent, self.ms, percent, self.name))

    # Remove IDs etc to reduce the phase name cardinality for ts_mon.
    normalized_phase = re.sub('[0-9]+', '', self.name)
    PHASE_TIME.add(self.ms, {'phase': normalized_phase})

    for subphase in self.subphases:
      subphase.AccumulateStatLines(total_seconds, lines, indent=indent + ' ')

  def SpanJson(self):
    """Return a list of GCP Cloud Trace span dicts for this phase subtree."""
    end_time = self.start + self.elapsed_seconds

    span = {
        'kind': 'RPC_SERVER',
        'name': self.name,
        'spanId': self.id,
        'startTime':
            datetime.datetime.fromtimestamp(self.start).isoformat() + 'Z',
        'endTime': datetime.datetime.fromtimestamp(end_time).isoformat() + 'Z',
    }

    if self.parent is not None and self.parent.id is not None:
      span['parentSpanId'] = self.parent.id

    spans = [span]
    for subphase in self.subphases:
      spans.extend(subphase.SpanJson())

    return spans
diff --git a/framework/ratelimiter.py b/framework/ratelimiter.py
new file mode 100644
index 0000000..b2bbb25
--- /dev/null
+++ b/framework/ratelimiter.py
@@ -0,0 +1,292 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Request rate limiting implementation.
+
+This is intented to be used for automatic DDoS protection.
+
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import datetime
+import logging
+import os
+import settings
+import time
+
+from infra_libs import ts_mon
+
+from google.appengine.api import memcache
+from google.appengine.api.modules import modules
+from google.appengine.api import users
+
+from services import client_config_svc
+
+
# Width of the rate-limit window, in one-minute buckets.
N_MINUTES = 5
# Seconds after which a per-minute counter key expires from memcache.
EXPIRE_AFTER_SECS = 60 * 60
DEFAULT_LIMIT = 60 * N_MINUTES  # 300 page requests in 5 minutes is 1 QPS.
DEFAULT_API_QPM = 1000  # For example, chromiumdash uses ~64 per page, 8s each.

# Placeholder used in cache keys when there is no signed-in user.
ANON_USER = 'anon'

# App Engine request header carrying the ISO 3166-1 alpha-2 country code.
COUNTRY_HEADER = 'X-AppEngine-Country'

COUNTRY_LIMITS = {
    # Two-letter country code: max requests per N_MINUTES
    # This limit will apply to all requests coming
    # from this country.
    # To add a country code, see GAE logs and use the
    # appropriate code from https://en.wikipedia.org/wiki/ISO_3166-1_alpha-2
    # E.g., 'cn': 300,  # Limit to 1 QPS.
}

# Modules not in this list will not have rate limiting applied by this
# class.
MODULE_ALLOWLIST = ['default', 'api']
+
+
def _CacheKeys(request, now_sec):
  """Build the memcache key sets used to count this page request.

  Returns:
    A tuple (keysets, country, ip, user_email).  keysets is a list of key
    lists; each inner list shares one prefix and has a timestamp suffix per
    minute bucket, starting with the most recent and decrementing by one
    minute each time.
  """
  now = datetime.datetime.fromtimestamp(now_sec)
  country = request.headers.get(COUNTRY_HEADER, 'ZZ')
  ip = request.remote_addr
  minute_buckets = [
      now - datetime.timedelta(minutes=m) for m in range(N_MINUTES)]
  user = users.get_current_user()
  user_email = user.email() if user else ANON_USER

  # <IP, country, user_email> to be rendered into each key prefix.
  prefixes = []
  if user:
    # All logged-in users get a per-user rate limit, regardless of IP and
    # country.
    prefixes.append(['ALL', 'ALL', user.email()])
  else:
    # All anon requests get a per-IP ratelimit.
    prefixes.append([ip, 'ALL', 'ALL'])
  if country in COUNTRY_LIMITS:
    # All requests from a problematic country get a per-country rate limit,
    # regardless of the user (even a non-logged-in one) or IP.
    prefixes.append(['ALL', country, 'ALL'])

  keysets = []
  for prefix in prefixes:
    joined_prefix = '-'.join(prefix)
    keysets.append([
        'ratelimit-%s-%s' % (
            joined_prefix, str(bucket.replace(second=0, microsecond=0)))
        for bucket in minute_buckets])

  return keysets, country, ip, user_email
+
+
def _CreateApiCacheKeys(client_id, client_email, now_sec):
  """Build the memcache key lists used to rate limit one API request.

  Args:
    client_id: OAuth client id string, possibly 'anonymous'.
    client_email: email address of the calling client, if known.
    now_sec: float timestamp of the request in seconds.

  Returns:
    A list of key lists.  Each inner list holds one key per minute bucket
    (most recent first) for a single client/IP/country prefix.
  """
  country = os.environ.get('HTTP_X_APPENGINE_COUNTRY')
  ip = os.environ.get('REMOTE_ADDR')
  now = datetime.datetime.fromtimestamp(now_sec)
  minute_buckets = [now - datetime.timedelta(minutes=m) for m in
                    range(N_MINUTES)]
  minute_strs = [str(minute_bucket.replace(second=0, microsecond=0))
                 for minute_bucket in minute_buckets]
  keys = []

  if client_id and client_id != 'anonymous':
    keys.append(['apiratelimit-%s-%s' % (client_id, minute_str)
                 for minute_str in minute_strs])
  if client_email:
    keys.append(['apiratelimit-%s-%s' % (client_email, minute_str)
                 for minute_str in minute_strs])
  else:
    # No identifiable client email: fall back to per-IP keys, plus
    # per-country keys for countries that have explicit limits.
    keys.append(['apiratelimit-%s-%s' % (ip, minute_str)
                 for minute_str in minute_strs])
    if country in COUNTRY_LIMITS:
      keys.append(['apiratelimit-%s-%s' % (country, minute_str)
                   for minute_str in minute_strs])

  return keys
+
+
class RateLimiter(object):
  """Per-client request rate limiter backed by memcache counters.

  Request counts are kept in one-minute buckets; a client's current rate is
  the sum of its buckets over the last N_MINUTES (see _CacheKeys).
  """

  # ts_mon metrics are class-level so all instances share them; the
  # ApiRateLimiter subclass overrides them with its own metric names.
  blocked_requests = ts_mon.CounterMetric(
      'monorail/ratelimiter/blocked_request',
      'Count of requests that exceeded the rate limit and were blocked.',
      None)
  limit_exceeded = ts_mon.CounterMetric(
      'monorail/ratelimiter/rate_exceeded',
      'Count of requests that exceeded the rate limit.',
      None)
  cost_thresh_exceeded = ts_mon.CounterMetric(
      'monorail/ratelimiter/cost_thresh_exceeded',
      'Count of requests that were expensive to process',
      None)
  checks = ts_mon.CounterMetric(
      'monorail/ratelimiter/check',
      'Count of checks done, by fail/success type.',
      [ts_mon.StringField('type')])

  def __init__(self, _cache=memcache, fail_open=True, **_kwargs):
    # fail_open: when True, a memcache failure lets the request through
    # rather than rejecting it.  _cache and _kwargs are accepted for
    # call-site compatibility but unused here.
    self.fail_open = fail_open

  def CheckStart(self, request, now=None):
    """Raise the rate-limit exception if this request is over limit.

    Args:
      request: the incoming request object.
      now: optional float timestamp; defaults to time.time().  Injectable
          for tests.

    Raises:
      RateLimitExceeded: the client is over its limit (only when rate
          limiting is enabled in settings, or fail-closed on cache errors).
    """
    # Admin users and non-allowlisted modules are never rate limited.
    if (modules.get_current_module_name() not in MODULE_ALLOWLIST or
        users.is_current_user_admin()):
      return
    logging.info('X-AppEngine-Country: %s' %
                 request.headers.get(COUNTRY_HEADER, 'ZZ'))

    if now is None:
      now = time.time()

    keysets, country, ip, user_email = _CacheKeys(request, now)
    # keysets holds one or two key lists: one for the user (or IP for anon
    # requests), plus one more when the country is in COUNTRY_LIMITS.
    self._AuxCheckStart(
        keysets, COUNTRY_LIMITS.get(country, DEFAULT_LIMIT),
        settings.ratelimiting_enabled,
        RateLimitExceeded(country=country, ip=ip, user_email=user_email))

  def _AuxCheckStart(self, keysets, limit, ratelimiting_enabled,
                     exception_obj):
    """Sum each keyset's counters, enforce the limit, then count this hit.

    Args:
      keysets: list of key lists as produced by _CacheKeys or
          _CreateApiCacheKeys.
      limit: max allowed request count summed over the window.
      ratelimiting_enabled: when False, over-limit requests are logged and
          counted but not actually blocked.
      exception_obj: pre-built exception instance to raise when blocking.
    """
    for keys in keysets:
      count = 0
      try:
        counters = memcache.get_multi(keys)
        count = sum(counters.values())
        self.checks.increment({'type': 'success'})
      except Exception as e:
        # Cache failure: either let the request through (fail open) or
        # treat it as over limit (fail closed), per self.fail_open.
        logging.error(e)
        if not self.fail_open:
          self.checks.increment({'type': 'fail_closed'})
          raise exception_obj
        self.checks.increment({'type': 'fail_open'})

      if count > limit:
        # Since webapp2 won't let us return a 429 error code
        # <http://tools.ietf.org/html/rfc6585#section-4>, we can't
        # monitor rate limit exceeded events with our standard tools.
        # We return a 400 with a custom error message to the client,
        # and this logging is so we can monitor it internally.
        logging.info('%s, %d' % (exception_obj.message, count))

        self.limit_exceeded.increment()

        if ratelimiting_enabled:
          self.blocked_requests.increment()
          raise exception_obj

      k = keys[0]
      # Only update the latest *time* bucket for each prefix (reverse chron).
      memcache.add(k, 0, time=EXPIRE_AFTER_SECS)
      memcache.incr(k, initial_value=0)

  def CheckEnd(self, request, now, start_time):
    """If a request was expensive to process, charge some extra points
    against this set of buckets.
    We pass in both now and start_time so we can update the buckets
    based on keys created from start_time instead of now.
    now and start_time are float seconds.
    """
    if (modules.get_current_module_name() not in MODULE_ALLOWLIST):
      return

    elapsed_ms = int((now - start_time) * 1000)
    # Would it kill the python lib maintainers to have timedelta.total_ms()?
    # One penalty point per ratelimiting_ms_per_count of processing time,
    # minus the one point CheckStart already charged.
    penalty = elapsed_ms // settings.ratelimiting_ms_per_count - 1
    if penalty >= 1:
      # TODO: Look into caching the keys instead of generating them twice
      # for every request. Say, return them from CheckStart so they can
      # be passed back in here later.
      keysets, country, ip, user_email = _CacheKeys(request, start_time)

      self._AuxCheckEnd(
          keysets,
          'Rate Limit Cost Threshold Exceeded: %s, %s, %s' % (
              country, ip, user_email),
          penalty)

  def _AuxCheckEnd(self, keysets, log_str, penalty):
    """Add penalty points to the newest bucket of every keyset."""
    self.cost_thresh_exceeded.increment()
    for keys in keysets:
      logging.info(log_str)

      # Only update the latest *time* bucket for each prefix (reverse chron).
      k = keys[0]
      memcache.add(k, 0, time=EXPIRE_AFTER_SECS)
      memcache.incr(k, delta=penalty, initial_value=0)
+
+
class ApiRateLimiter(RateLimiter):
  """Rate limiter for API requests, keyed by client id/email rather than
  by signed-in user or IP (see _CreateApiCacheKeys)."""

  # Same metric shapes as RateLimiter, under apiratelimiter names.
  blocked_requests = ts_mon.CounterMetric(
      'monorail/apiratelimiter/blocked_request',
      'Count of requests that exceeded the rate limit and were blocked.',
      None)
  limit_exceeded = ts_mon.CounterMetric(
      'monorail/apiratelimiter/rate_exceeded',
      'Count of requests that exceeded the rate limit.',
      None)
  cost_thresh_exceeded = ts_mon.CounterMetric(
      'monorail/apiratelimiter/cost_thresh_exceeded',
      'Count of requests that were expensive to process',
      None)
  checks = ts_mon.CounterMetric(
      'monorail/apiratelimiter/check',
      'Count of checks done, by fail/success type.',
      [ts_mon.StringField('type')])

  #pylint: disable=arguments-differ
  def CheckStart(self, client_id, client_email, now=None):
    """Raise ApiRateLimitExceeded if the client is over its QPM limit."""
    if now is None:
      now = time.time()

    keysets = _CreateApiCacheKeys(client_id, client_email, now)
    qpm_limit = client_config_svc.GetQPMDict().get(
        client_email, DEFAULT_API_QPM)
    # A configured QPM can only raise the limit: values below the default
    # are clamped up to DEFAULT_API_QPM.
    if qpm_limit < DEFAULT_API_QPM:
      qpm_limit = DEFAULT_API_QPM
    window_limit = qpm_limit * N_MINUTES
    self._AuxCheckStart(
        keysets, window_limit,
        settings.api_ratelimiting_enabled,
        ApiRateLimitExceeded(client_id, client_email))

  #pylint: disable=arguments-differ
  def CheckEnd(self, client_id, client_email, now, start_time):
    """Charge extra penalty points when the request was expensive.

    now and start_time are float seconds; keys are derived from start_time
    so the same buckets touched by CheckStart are updated.
    """
    elapsed_ms = int((now - start_time) * 1000)
    penalty = elapsed_ms // settings.ratelimiting_ms_per_count - 1

    if penalty >= 1:
      keysets = _CreateApiCacheKeys(client_id, client_email, start_time)
      self._AuxCheckEnd(
          keysets,
          'API Rate Limit Cost Threshold Exceeded: %s, %s' % (
              client_id, client_email),
          penalty)
+
+
class RateLimitExceeded(Exception):
  """Raised when a client has exceeded its page-request rate limit."""

  def __init__(self, country=None, ip=None, user_email=None, **_kwargs):
    self.country = country
    self.ip = ip
    self.user_email = user_email
    details = (self.country, self.ip, self.user_email)
    super(RateLimitExceeded, self).__init__(
        'RateLimitExceeded: %s, %s, %s' % details)
+
+
class ApiRateLimitExceeded(Exception):
  """Raised when an API client has exceeded its request rate limit."""

  def __init__(self, client_id, client_email):
    self.client_id = client_id
    self.client_email = client_email
    # NOTE(review): the message prefix matches RateLimitExceeded's rather
    # than this class name — kept byte-identical, presumably so log-based
    # monitoring catches both; confirm before changing.
    super(ApiRateLimitExceeded, self).__init__(
        'RateLimitExceeded: %s, %s' % (self.client_id, self.client_email))
diff --git a/framework/reap.py b/framework/reap.py
new file mode 100644
index 0000000..6bc5cf0
--- /dev/null
+++ b/framework/reap.py
@@ -0,0 +1,125 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""A class to handle cron requests to expunge doomed and deletable projects."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import logging
+import time
+
+from framework import jsonfeed
+
# Budget for one reap pass; _ExpungeDeletableProjects stops expunging
# once this much time has elapsed.
RUN_DURATION_LIMIT = 50 * 60  # 50 minutes
+
+
class Reap(jsonfeed.InternalTask):
  """Look for doomed and deletable projects and delete them."""

  def HandleRequest(self, mr):
    """Update/Delete doomed and deletable projects as needed.

    Args:
      mr: common information parsed from the HTTP request.

    Returns:
      Results dictionary in JSON format. The JSON will look like this:
      {
        'doomed_project_ids': <int>,
        'expunged_project_ids': <int>
      }
      doomed_project_ids are the projects which have been marked as deletable.
      expunged_project_ids are the projects that have either been completely
      expunged or are in the midst of being expunged.
    """
    doomed_project_ids = self._MarkDoomedProjects(mr.cnxn)
    expunged_project_ids = self._ExpungeDeletableProjects(mr.cnxn)
    return {
        'doomed_project_ids': doomed_project_ids,
        'expunged_project_ids': expunged_project_ids,
    }

  def _MarkDoomedProjects(self, cnxn):
    """No longer needed projects get doomed, and this marks them deletable.

    Args:
      cnxn: connection to the SQL database.

    Returns:
      List of project_ids (up to 1000) that were marked deletable.
    """
    now = int(time.time())
    doomed_project_rows = self.services.project.project_tbl.Select(
        cnxn, cols=['project_id'],
        # We only match projects with real timestamps and not delete_time = 0.
        where=[('delete_time < %s', [now]), ('delete_time != %s', [0])],
        state='archived', limit=1000)
    doomed_project_ids = [row[0] for row in doomed_project_rows]
    for project_id in doomed_project_ids:
      # Note: We go straight to services layer because this is an internal
      # request, not a request from a user.
      self.services.project.MarkProjectDeletable(
          cnxn, project_id, self.services.config)

    return doomed_project_ids

  def _ExpungeDeletableProjects(self, cnxn):
    """Chip away at deletable projects until they are gone.

    Stops once RUN_DURATION_LIMIT has elapsed; any remaining work is
    picked up by the next cron run.

    Args:
      cnxn: connection to the SQL database.

    Returns:
      List of project_ids that were fully or partially expunged.
    """
    request_deadline = time.time() + RUN_DURATION_LIMIT

    deletable_project_rows = self.services.project.project_tbl.Select(
        cnxn, cols=['project_id'], state='deletable', limit=100)
    deletable_project_ids = [row[0] for row in deletable_project_rows]
    # expunged_project_ids will contain projects that have either been
    # completely expunged or are in the midst of being expunged.
    expunged_project_ids = set()
    for project_id in deletable_project_ids:
      for _part in self._ExpungeParts(cnxn, project_id):
        # The project counts as touched as soon as any part is expunged.
        expunged_project_ids.add(project_id)
        # Check the deadline between parts so one pass cannot overrun.
        if time.time() > request_deadline:
          return list(expunged_project_ids)

    return list(expunged_project_ids)

  def _ExpungeParts(self, cnxn, project_id):
    """Delete all data from the specified project, one part at a time.

    This method purges all data associated with the specified project. The
    following is purged:
    * All issues of the project.
    * Project config.
    * Saved queries.
    * Filter rules.
    * Former locations.
    * Local ID counters.
    * Quick edit history.
    * Item stars.
    * Project from the DB.

    Returns a generator whose return values can be either issue
    ids or the specified project id. The returned values are intended to be
    iterated over and not read.
    """
    # Purge all issues of the project.
    # NOTE(review): as written this loop always breaks after one batch of up
    # to 1000 issues; presumably later cron runs finish the rest — confirm.
    while True:
      issue_id_rows = self.services.issue.issue_tbl.Select(
          cnxn, cols=['id'], project_id=project_id, limit=1000)
      issue_ids = [row[0] for row in issue_id_rows]
      for issue_id in issue_ids:
        self.services.issue_star.ExpungeStars(cnxn, issue_id)
      self.services.issue.ExpungeIssues(cnxn, issue_ids)
      yield issue_ids
      break

    # All project purge functions are called with cnxn and project_id.
    project_purge_functions = (
        self.services.config.ExpungeConfig,
        self.services.template.ExpungeProjectTemplates,
        self.services.features.ExpungeSavedQueriesExecuteInProject,
        self.services.features.ExpungeFilterRules,
        self.services.issue.ExpungeFormerLocations,
        self.services.issue.ExpungeLocalIDCounters,
        self.services.features.ExpungeQuickEditHistory,
        self.services.project_star.ExpungeStars,
        self.services.project.ExpungeProject,
    )

    for f in project_purge_functions:
      f(cnxn, project_id)
      yield project_id
diff --git a/framework/redis_utils.py b/framework/redis_utils.py
new file mode 100644
index 0000000..440603b
--- /dev/null
+++ b/framework/redis_utils.py
@@ -0,0 +1,125 @@
+# Copyright 2020 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+"""A utility module for interfacing with Redis conveniently. """
+import json
+import logging
+import threading
+
+import redis
+
+import settings
+from protorpc import protobuf
+
# Process-wide BlockingConnectionPool, created lazily by CreateRedisClient().
connection_pool = None
+
def CreateRedisClient():
  # type: () -> redis.Redis
  """Creates a Redis object which implements Redis protocol and connection.

  Returns:
    redis.Redis object initialized with a shared connection pool.
  """
  global connection_pool
  if not connection_pool:
    # Build the process-wide pool on first use; later calls reuse it.
    pool_options = {
        'host': settings.redis_host,
        'port': settings.redis_port,
        'max_connections': 1,
        # When Redis is not available, calls hang indefinitely without these.
        'socket_connect_timeout': 2,
        'socket_timeout': 2,
    }
    connection_pool = redis.BlockingConnectionPool(**pool_options)
  return redis.Redis(connection_pool=connection_pool)
+
+
def AsyncVerifyRedisConnection():
  # type: () -> None
  """Verifies the redis connection in a separate thread.

  Note that although an exception in the thread won't kill the main thread,
  it is not risk free.

  AppEngine joins with any running threads before finishing the request.
  If this thread were to hang indefinitely, then it would cause the request
  to hit DeadlineExceeded, thus still causing a user facing failure.

  We mitigate this risk by setting socket timeouts on our connection pool.

  # TODO(crbug/monorail/8221): Remove this code during this milestone.
  """

  def _Verify():
    logging.info('AsyncVerifyRedisConnection thread started.')
    VerifyRedisConnection(CreateRedisClient())

  logging.info('Starting thread for AsyncVerifyRedisConnection.')
  worker = threading.Thread(target=_Verify)
  worker.start()
+
+
def FormatRedisKey(key, prefix=None):
  # type: (int, str) -> str
  """Converts key to a string, optionally prepending "prefix:".

  Args:
    key: Integer key.
    prefix: String to prepend to the key; a trailing ':' is added if missing.

  Returns:
    Formatted key with the format: "prefix:key" (or just "key" if no
    prefix was given).
  """
  if not prefix:
    return str(key)
  # Normalize so exactly one ':' separates the prefix from the key.
  if not prefix.endswith(':'):
    prefix += ':'
  return '%s%s' % (prefix, key)
+
def VerifyRedisConnection(redis_client, msg=None):
  # type: (redis.Redis, Optional[str]) -> bool
  """Pings the Redis server to confirm a connection can be established.

  Args:
    redis_client: client used to connect and ping the redis server; may be
      a redis or fakeRedis object.
    msg: string used in logging output.

  Returns:
    True when the connection to the server is valid.
    False when an error occurs or redis_client is None.
  """
  if not redis_client:
    logging.info('Redis client is set to None on connect in %s', msg)
    return False
  try:
    redis_client.ping()
  except redis.RedisError as identifier:
    # TODO(crbug/monorail/8224): We can downgrade this to warning once we are
    # done with the switchover from memcache. Before that, log it to ensure we
    # see it.
    logging.exception(
        'Redis error occurred while connecting to server in %s: %s', msg,
        identifier)
    return False
  logging.info('Redis client successfully connected to Redis in %s', msg)
  return True
+
+
def SerializeValue(value, pb_class=None):
  # type: (Any, Optional[type|classobj]) -> str
  """Serialize a python object for storage in Redis."""
  # int is deliberately excluded: ints are stored as JSON, not protobuf.
  use_protobuf = pb_class is not None and pb_class is not int
  if use_protobuf:
    return protobuf.encode_message(value)
  return json.dumps(value)
+
+
def DeserializeValue(value, pb_class=None):
  # type: (str, Optional[type|classobj]) -> Any
  """Deserialize a string from Redis back into a python object."""
  # Mirrors SerializeValue: int-keyed values round-trip through JSON.
  use_protobuf = pb_class is not None and pb_class is not int
  if use_protobuf:
    return protobuf.decode_message(pb_class, value)
  return json.loads(value)
diff --git a/framework/registerpages_helpers.py b/framework/registerpages_helpers.py
new file mode 100644
index 0000000..9982639
--- /dev/null
+++ b/framework/registerpages_helpers.py
@@ -0,0 +1,81 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""This file sets up all the urls for monorail pages."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+
+import httplib
+import logging
+
+import webapp2
+
+
def MakeRedirect(redirect_to_this_uri, permanent=True):
  """Return a new request handler class that redirects to the given URL."""

  class Redirect(webapp2.RequestHandler):
    """Redirect is a response handler that issues a redirect to another URI."""

    def get(self, **_kw):
      """Send the 301/302 response code and write the Location: redirect."""
      status = httplib.MOVED_PERMANENTLY if permanent else httplib.FOUND
      self.response.location = redirect_to_this_uri
      self.response.headers.add(
          'Strict-Transport-Security', 'max-age=31536000; includeSubDomains')
      self.response.status = status

  return Redirect
+
+
def MakeRedirectInScope(uri_in_scope, scope, permanent=True, keep_qs=False):
  """Redirect to a URI within a given scope, e.g., per project or user.

  Args:
    uri_in_scope: a uri within a project or user starting with a slash.
    scope: a string indicating the uri-space scope:
      p for project pages
      u for user pages
      g for group pages
    permanent: True for a HTTP 301 permanently moved response code,
      otherwise a HTTP 302 temporarily moved response will be used.
    keep_qs: set to True to make the redirect retain the query string.
      When true, permanent is ignored.

  Example:
    self._SetupProjectPage(
       redirect.MakeRedirectInScope('/newpage', 'p'), '/oldpage')

  Returns:
    A class that can be used with webapp2.
  """
  assert uri_in_scope.startswith('/')

  class RedirectInScope(webapp2.RequestHandler):
    """A handler that redirects to another URI in the same scope."""

    def get(self, **_kw):
      """Send the 301/302 response code and write the Location: redirect."""
      # Paths look like /p/<project-or-user>/..., so element 1 names the
      # entity whose scope the redirect stays inside.
      parts = self.request.path.lstrip('/').split('/')
      if len(parts) > 1:
        url = '//%s/%s/%s%s' % (
            self.request.host, scope, parts[1], uri_in_scope)
      else:
        url = '/'
      if keep_qs and self.request.query_string:
        url = '%s?%s' % (url, self.request.query_string)
      self.response.location = url

      self.response.headers.add(
          'Strict-Transport-Security', 'max-age=31536000; includeSubDomains')
      use_permanent = permanent and not keep_qs
      self.response.status = (
          httplib.MOVED_PERMANENTLY if use_permanent else httplib.FOUND)

  return RedirectInScope
diff --git a/framework/servlet.py b/framework/servlet.py
new file mode 100644
index 0000000..1ed6935
--- /dev/null
+++ b/framework/servlet.py
@@ -0,0 +1,1047 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Base classes for Monorail servlets.
+
+This base class provides HTTP get() and post() methods that
+conveniently drive the process of parsing the request, checking base
+permissions, gathering common page information, gathering
+page-specific information, and adding on-page debugging information
+(when appropriate). Subclasses can simply implement the page-specific
+logic.
+
+Summary of page classes:
+ Servlet: abstract base class for all Monorail servlets.
+ _ContextDebugItem: displays page_data elements for on-page debugging.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import gc
+import httplib
+import json
+import logging
+import os
+import random
+import time
+import urllib
+
+import ezt
+from third_party import httpagentparser
+
+from google.appengine.api import app_identity
+from google.appengine.api import modules
+from google.appengine.api import users
+from oauth2client.client import GoogleCredentials
+
+import webapp2
+
+import settings
+from businesslogic import work_env
+from features import savedqueries_helpers
+from features import features_bizobj
+from features import hotlist_views
+from framework import alerts
+from framework import exceptions
+from framework import framework_constants
+from framework import framework_helpers
+from framework import framework_views
+from framework import monorailrequest
+from framework import permissions
+from framework import ratelimiter
+from framework import servlet_helpers
+from framework import template_helpers
+from framework import urls
+from framework import xsrf
+from project import project_constants
+from proto import project_pb2
+from search import query2ast
+from tracker import tracker_views
+
+from infra_libs import ts_mon
+
# Length of the random per-response CSP script nonce.
NONCE_LENGTH = 32

# MySQLdb is unavailable in the unit-test environment; import it only when
# running for real.
if not settings.unit_test_mode:
  import MySQLdb

# Distribution of per-generation GC object counts, sampled each request.
GC_COUNT = ts_mon.NonCumulativeDistributionMetric(
    'monorail/servlet/gc_count',
    'Count of objects in each generation tracked by the GC',
    [ts_mon.IntegerField('generation')])

# Counts requests during which at least one GC collection ran.
GC_EVENT_REQUEST = ts_mon.CounterMetric(
    'monorail/servlet/gc_event_request',
    'Counts of requests that triggered at least one GC event',
    [])

# TODO(crbug/monorail:7084): Find a better home for this code.
trace_service = None
# TOD0(crbug/monorail:7082): Re-enable this once we have a solution that
# doesn't incur latency, or when we're actively using Cloud Tracing data.
# if app_identity.get_application_id() != 'testing-app':
#   logging.warning('app id: %s', app_identity.get_application_id())
#   try:
#     credentials = GoogleCredentials.get_application_default()
#     trace_service = discovery.build(
#         'cloudtrace', 'v1', credentials=credentials)
#   except Exception as e:
#     logging.warning('could not get trace service: %s', e)
+
+
class MethodNotSupportedError(NotImplementedError):
  """An exception class for indicating that the method is not supported.

  Used by GatherPageData and ProcessFormData to indicate that GET and POST,
  respectively, are not supported methods on the given Servlet.
  """
+
+
+class Servlet(webapp2.RequestHandler):
+ """Base class for all Monorail servlets.
+
+ Defines a framework of methods that build up parts of the EZT page data.
+
+ Subclasses should override GatherPageData and/or ProcessFormData to
+ handle requests.
+ """
+
+ _MAIN_TAB_MODE = None # Normally overriden in subclasses to be one of these:
+
+ MAIN_TAB_NONE = 't0'
+ MAIN_TAB_DASHBOARD = 't1'
+ MAIN_TAB_ISSUES = 't2'
+ MAIN_TAB_PEOPLE = 't3'
+ MAIN_TAB_PROCESS = 't4'
+ MAIN_TAB_UPDATES = 't5'
+ MAIN_TAB_ADMIN = 't6'
+ MAIN_TAB_DETAILS = 't7'
+ PROCESS_TAB_SUMMARY = 'st1'
+ PROCESS_TAB_STATUSES = 'st3'
+ PROCESS_TAB_LABELS = 'st4'
+ PROCESS_TAB_RULES = 'st5'
+ PROCESS_TAB_TEMPLATES = 'st6'
+ PROCESS_TAB_COMPONENTS = 'st7'
+ PROCESS_TAB_VIEWS = 'st8'
+ ADMIN_TAB_META = 'st1'
+ ADMIN_TAB_ADVANCED = 'st9'
+ HOTLIST_TAB_ISSUES = 'ht2'
+ HOTLIST_TAB_PEOPLE = 'ht3'
+ HOTLIST_TAB_DETAILS = 'ht4'
+
+ # Most forms require a security token, however if a form is really
+ # just redirecting to a search GET request without writing any data,
+ # subclass can override this to allow anonymous use.
+ CHECK_SECURITY_TOKEN = True
+
+ # Some pages might be posted to by clients outside of Monorail.
+ # ie: The issue entry page, by the issue filing wizard. In these cases,
+ # we can allow an xhr-scoped XSRF token to be used to post to the page.
+ ALLOW_XHR = False
+
+ # Most forms just ignore fields that have value "". Subclasses can override
+ # if needed.
+ KEEP_BLANK_FORM_VALUES = False
+
+ # Most forms use regular forms, but subclasses that accept attached files can
+ # override this to be True.
+ MULTIPART_POST_BODY = False
+
+ # This value should not typically be overridden.
+ _TEMPLATE_PATH = framework_constants.TEMPLATE_PATH
+
+ _PAGE_TEMPLATE = None # Normally overriden in subclasses.
+ _ELIMINATE_BLANK_LINES = False
+
+ _MISSING_PERMISSIONS_TEMPLATE = 'sitewide/403-page.ezt'
+
+ def __init__(self, request, response, services=None,
+ content_type='text/html; charset=UTF-8'):
+ """Load and parse the template, saving it for later use."""
+ super(Servlet, self).__init__(request, response)
+ if self._PAGE_TEMPLATE: # specified in subclasses
+ template_path = self._TEMPLATE_PATH + self._PAGE_TEMPLATE
+ self.template = template_helpers.GetTemplate(
+ template_path, eliminate_blank_lines=self._ELIMINATE_BLANK_LINES)
+ else:
+ self.template = None
+
+ self._missing_permissions_template = template_helpers.MonorailTemplate(
+ self._TEMPLATE_PATH + self._MISSING_PERMISSIONS_TEMPLATE)
+ self.services = services or self.app.config.get('services')
+ self.content_type = content_type
+ self.mr = None
+ self.ratelimiter = ratelimiter.RateLimiter()
+
  def dispatch(self):
    """Do common stuff then dispatch the request to get() or put() methods.

    Handles rate limiting, GC metrics, distributed cache invalidation, and
    maps well-known exceptions to HTTP error responses.
    """
    handler_start_time = time.time()

    logging.info('\n\n\nRequest handler: %r', self)
    # Sample GC object counts now so we can detect post-request drops,
    # which indicate that a collection ran during this request.
    count0, count1, count2 = gc.get_count()
    logging.info('gc counts: %d %d %d', count0, count1, count2)
    GC_COUNT.add(count0, {'generation': 0})
    GC_COUNT.add(count1, {'generation': 1})
    GC_COUNT.add(count2, {'generation': 2})

    self.mr = monorailrequest.MonorailRequest(self.services)

    self.response.headers.add('Strict-Transport-Security',
                              'max-age=31536000; includeSubDomains')

    if 'X-Cloud-Trace-Context' in self.request.headers:
      self.mr.profiler.trace_context = (
          self.request.headers.get('X-Cloud-Trace-Context'))
    # TOD0(crbug/monorail:7082): Re-enable tracing.
    # if trace_service is not None:
    #   self.mr.profiler.trace_service = trace_service

    if self.services.cache_manager:
      # TODO(jrobbins): don't do this step if invalidation_timestep was
      # passed via the request and matches our last timestep
      try:
        with self.mr.profiler.Phase('distributed invalidation'):
          self.services.cache_manager.DoDistributedInvalidation(self.mr.cnxn)

      except MySQLdb.OperationalError as e:
        # Database is unreachable: serve the maintenance page rather than
        # letting the request 500.
        logging.exception(e)
        page_data = {
            'http_response_code': httplib.SERVICE_UNAVAILABLE,
            'requested_url': self.request.url,
        }
        self.template = template_helpers.GetTemplate(
            'templates/framework/database-maintenance.ezt',
            eliminate_blank_lines=self._ELIMINATE_BLANK_LINES)
        self.template.WriteResponse(
            self.response, page_data, content_type='text/html')
        return

    try:
      # May raise RateLimitExceeded, handled below.
      self.ratelimiter.CheckStart(self.request)

      with self.mr.profiler.Phase('parsing request and doing lookups'):
        self.mr.ParseRequest(self.request, self.services)

      self.response.headers['X-Frame-Options'] = 'SAMEORIGIN'
      # Delegate to webapp2's normal get()/post() routing.
      webapp2.RequestHandler.dispatch(self)

    except exceptions.NoSuchUserException as e:
      logging.warning('Trapped NoSuchUserException %s', e)
      self.abort(404, 'user not found')

    except exceptions.NoSuchGroupException as e:
      logging.warning('Trapped NoSuchGroupException %s', e)
      self.abort(404, 'user group not found')

    except exceptions.InputException as e:
      logging.info('Rejecting invalid input: %r', e)
      self.response.status = httplib.BAD_REQUEST

    except exceptions.NoSuchProjectException as e:
      logging.info('Rejecting invalid request: %r', e)
      self.response.status = httplib.NOT_FOUND

    except xsrf.TokenIncorrect as e:
      logging.info('Bad XSRF token: %r', e.message)
      self.response.status = httplib.BAD_REQUEST

    except permissions.BannedUserException as e:
      logging.warning('The user has been banned')
      url = framework_helpers.FormatAbsoluteURL(
          self.mr, urls.BANNED, include_project=False, copy_params=False)
      self.redirect(url, abort=True)

    except ratelimiter.RateLimitExceeded as e:
      # A 400 is used because webapp2 cannot send 429; see the comment in
      # ratelimiter._AuxCheckStart.
      logging.info('RateLimitExceeded Exception %s', e)
      self.response.status = httplib.BAD_REQUEST
      self.response.body = 'Slow your roll.'

    finally:
      # Always release request resources and charge any cost penalty,
      # even when an exception escaped above.
      self.mr.CleanUp()
      self.ratelimiter.CheckEnd(self.request, time.time(), handler_start_time)

    total_processing_time = time.time() - handler_start_time
    logging.info(
        'Processed request in %d ms', int(total_processing_time * 1000))

    end_count0, end_count1, end_count2 = gc.get_count()
    logging.info('gc counts: %d %d %d', end_count0, end_count1, end_count2)
    # A drop in any generation's count means a collection ran mid-request.
    if (end_count0 < count0) or (end_count1 < count1) or (end_count2 < count2):
      GC_EVENT_REQUEST.increment()

    if settings.enable_profiler_logging:
      self.mr.profiler.LogStats()

    # TOD0(crbug/monorail:7082, crbug/monorail:7088): Re-enable this when we
    # have solved the latency, or when we really need the profiler data.
    # if self.mr.profiler.trace_context is not None:
    #   try:
    #     self.mr.profiler.ReportTrace()
    #   except Exception as ex:
    #     # We never want Cloud Tracing to cause a user-facing error.
    #     logging.warning('Ignoring exception reporting Cloud Trace %s', ex)
+
+ def _AddHelpDebugPageData(self, page_data):
+ with self.mr.profiler.Phase('help and debug data'):
+ page_data.update(self.GatherHelpData(self.mr, page_data))
+ page_data.update(self.GatherDebugData(self.mr, page_data))
+
+ # pylint: disable=unused-argument
+ def get(self, **kwargs):
+ """Collect page-specific and generic info, then render the page.
+
+ Args:
+ Any path components parsed by webapp2 will be in kwargs, but we do
+ our own parsing later anyway, so igore them for now.
+ """
+ page_data = {}
+ nonce = framework_helpers.MakeRandomKey(length=NONCE_LENGTH)
+ try:
+ csp_header = 'Content-Security-Policy'
+ csp_scheme = 'https:'
+ if settings.local_mode:
+ csp_header = 'Content-Security-Policy-Report-Only'
+ csp_scheme = 'http:'
+ user_agent_str = self.mr.request.headers.get('User-Agent', '')
+ ua = httpagentparser.detect(user_agent_str)
+ browser, browser_major_version = 'Unknown browser', 0
+ if ua.has_key('browser'):
+ browser = ua['browser']['name']
+ try:
+ browser_major_version = int(ua['browser']['version'].split('.')[0])
+ except ValueError:
+ logging.warn('Could not parse version: %r', ua['browser']['version'])
+ csp_supports_report_sample = (
+ (browser == 'Chrome' and browser_major_version >= 59) or
+ (browser == 'Opera' and browser_major_version >= 46))
+ version_base = _VersionBaseURL(self.mr.request)
+ self.response.headers.add(csp_header,
+ ("default-src %(scheme)s ; "
+ "script-src"
+ " %(rep_samp)s" # Report 40 chars of any inline violation.
+ " 'unsafe-inline'" # Only counts in browsers that lack CSP2.
+ " 'strict-dynamic'" # Allows <script nonce> to load more.
+ " %(version_base)s/static/dist/"
+ " 'self' 'nonce-%(nonce)s'; "
+ "child-src 'none'; "
+ "frame-src accounts.google.com" # All used by gapi.js auth.
+ " content-issuetracker.corp.googleapis.com"
+ " login.corp.google.com up.corp.googleapis.com"
+ # Used by Google Feedback
+ " feedback.googleusercontent.com"
+ " www.google.com; "
+ "img-src %(scheme)s data: blob: ; "
+ "style-src %(scheme)s 'unsafe-inline'; "
+ "object-src 'none'; "
+ "base-uri 'self'; " # Used by Google Feedback
+ "report-uri /csp.do" % {
+ 'nonce': nonce,
+ 'scheme': csp_scheme,
+ 'rep_samp': "'report-sample'" if csp_supports_report_sample else '',
+ 'version_base': version_base,
+ }))
+
+ page_data.update(self._GatherFlagData(self.mr))
+
+ # Page-specific work happens in this call.
+ page_data.update(self._DoPageProcessing(self.mr, nonce))
+
+ self._AddHelpDebugPageData(page_data)
+
+ with self.mr.profiler.Phase('rendering template'):
+ self._RenderResponse(page_data)
+
+ except (MethodNotSupportedError, NotImplementedError) as e:
+ # Instead of these pages throwing 500s display the 404 message and log.
+ # The motivation of this is to minimize 500s on the site to keep alerts
+ # meaningful during fuzzing. For more context see
+ # https://bugs.chromium.org/p/monorail/issues/detail?id=659
+ logging.warning('Trapped NotImplementedError %s', e)
+ self.abort(404, 'invalid page')
+ except query2ast.InvalidQueryError as e:
+ logging.warning('Trapped InvalidQueryError: %s', e)
+ logging.exception(e)
+ msg = e.message if e.message else 'invalid query'
+ self.abort(400, msg)
+ except permissions.PermissionException as e:
+ logging.warning('Trapped PermissionException %s', e)
+ logging.warning('mr.auth.user_id is %s', self.mr.auth.user_id)
+ logging.warning('mr.auth.effective_ids is %s', self.mr.auth.effective_ids)
+ logging.warning('mr.perms is %s', self.mr.perms)
+ if not self.mr.auth.user_id:
+ # If not logged in, let them log in
+ url = _SafeCreateLoginURL(self.mr)
+ self.redirect(url, abort=True)
+ else:
+ # Display the missing permissions template.
+ page_data = {
+ 'reason': e.message,
+ 'http_response_code': httplib.FORBIDDEN,
+ }
+ with self.mr.profiler.Phase('gather base data'):
+ page_data.update(self.GatherBaseData(self.mr, nonce))
+ self._AddHelpDebugPageData(page_data)
+ self._missing_permissions_template.WriteResponse(
+ self.response, page_data, content_type=self.content_type)
+
+ def SetCacheHeaders(self, response):
+ """Set headers to allow the response to be cached."""
+ headers = framework_helpers.StaticCacheHeaders()
+ for name, value in headers:
+ response.headers[name] = value
+
+ def GetTemplate(self, _page_data):
+ """Get the template to use for writing the http response.
+
+ Defaults to self.template. This method can be overwritten in subclasses
+ to allow dynamic template selection based on page_data.
+
+ Args:
+ _page_data: A dict of data for ezt rendering, containing base ezt
+ data, page data, and debug data.
+
+ Returns:
+ The template to be used for writing the http response.
+ """
+ return self.template
+
+ def _GatherFlagData(self, mr):
+ page_data = {
+ 'project_stars_enabled': ezt.boolean(
+ settings.enable_project_stars),
+ 'user_stars_enabled': ezt.boolean(settings.enable_user_stars),
+ 'can_create_project': ezt.boolean(
+ permissions.CanCreateProject(mr.perms)),
+ 'can_create_group': ezt.boolean(
+ permissions.CanCreateGroup(mr.perms)),
+ }
+
+ return page_data
+
+ def _RenderResponse(self, page_data):
+ logging.info('rendering response len(page_data) is %r', len(page_data))
+ self.GetTemplate(page_data).WriteResponse(
+ self.response, page_data, content_type=self.content_type)
+
+ def ProcessFormData(self, mr, post_data):
+ """Handle form data and redirect appropriately.
+
+ Args:
+ mr: commonly used info parsed from the request.
+ post_data: HTML form data from the request.
+
+ Returns:
+ String URL to redirect the user to, or None if response was already sent.
+ """
+ raise MethodNotSupportedError()
+
+ def post(self, **kwargs):
+ """Parse the request, check base perms, and call form-specific code."""
+ try:
+ # Page-specific work happens in this call.
+ self._DoFormProcessing(self.request, self.mr)
+
+ except permissions.PermissionException as e:
+ logging.warning('Trapped permission-related exception "%s".', e)
+ # TODO(jrobbins): can we do better than an error page? not much.
+ self.response.status = httplib.BAD_REQUEST
+
+ def _DoCommonRequestProcessing(self, request, mr):
+ """Do common processing dependent on having the user and project pbs."""
+ with mr.profiler.Phase('basic processing'):
+ self._CheckForMovedProject(mr, request)
+ self.AssertBasePermission(mr)
+
+ def _DoPageProcessing(self, mr, nonce):
+ """Do user lookups and gather page-specific ezt data."""
+ with mr.profiler.Phase('common request data'):
+ self._DoCommonRequestProcessing(self.request, mr)
+ self._MaybeRedirectToBrandedDomain(self.request, mr.project_name)
+ page_data = self.GatherBaseData(mr, nonce)
+
+ with mr.profiler.Phase('page processing'):
+ page_data.update(self.GatherPageData(mr))
+ page_data.update(mr.form_overrides)
+ template_helpers.ExpandLabels(page_data)
+ self._RecordVisitTime(mr)
+
+ return page_data
+
+ def _DoFormProcessing(self, request, mr):
+ """Do user lookups and handle form data."""
+ self._DoCommonRequestProcessing(request, mr)
+
+ if self.CHECK_SECURITY_TOKEN:
+ try:
+ xsrf.ValidateToken(
+ request.POST.get('token'), mr.auth.user_id, request.path)
+ except xsrf.TokenIncorrect as err:
+ if self.ALLOW_XHR:
+ xsrf.ValidateToken(request.POST.get('token'), mr.auth.user_id, 'xhr')
+ else:
+ raise err
+
+ redirect_url = self.ProcessFormData(mr, request.POST)
+
+ # Most forms redirect the user to a new URL on success. If no
+ # redirect_url was returned, the form handler must have already
+ # sent a response. E.g., bounced the user back to the form with
+ # invalid form fields higlighted.
+ if redirect_url:
+ self.redirect(redirect_url, abort=True)
+ else:
+ assert self.response.body
+
+ def _CheckForMovedProject(self, mr, request):
+ """If the project moved, redirect there or to an informational page."""
+ if not mr.project:
+ return # We are on a site-wide or user page.
+ if not mr.project.moved_to:
+ return # This project has not moved.
+ admin_url = '/p/%s%s' % (mr.project_name, urls.ADMIN_META)
+ if request.path.startswith(admin_url):
+ return # It moved, but we are near the page that can un-move it.
+
+ logging.info('project %s has moved: %s', mr.project.project_name,
+ mr.project.moved_to)
+
+ moved_to = mr.project.moved_to
+ if project_constants.RE_PROJECT_NAME.match(moved_to):
+ # Use the redir query parameter to avoid redirect loops.
+ if mr.redir is None:
+ url = framework_helpers.FormatMovedProjectURL(mr, moved_to)
+ if '?' in url:
+ url += '&redir=1'
+ else:
+ url += '?redir=1'
+ logging.info('trusted move to a new project on our site')
+ self.redirect(url, abort=True)
+
+ logging.info('not a trusted move, will display link to user to click')
+ # Attach the project name as a url param instead of generating a /p/
+ # link to the destination project.
+ url = framework_helpers.FormatAbsoluteURL(
+ mr, urls.PROJECT_MOVED,
+ include_project=False, copy_params=False, project=mr.project_name)
+ self.redirect(url, abort=True)
+
+ def _MaybeRedirectToBrandedDomain(self, request, project_name):
+ """If we are live and the project should be branded, check request host."""
+ if request.params.get('redir'):
+ return # Avoid any chance of a redirect loop.
+ if not project_name:
+ return
+ needed_domain = framework_helpers.GetNeededDomain(
+ project_name, request.host)
+ if not needed_domain:
+ return
+
+ url = 'https://%s%s' % (needed_domain, request.path_qs)
+ if '?' in url:
+ url += '&redir=1'
+ else:
+ url += '?redir=1'
+ logging.info('branding redirect to url %r', url)
+ self.redirect(url, abort=True)
+
+ def CheckPerm(self, mr, perm, art=None, granted_perms=None):
+ """Return True if the user can use the requested permission."""
+ return servlet_helpers.CheckPerm(
+ mr, perm, art=art, granted_perms=granted_perms)
+
+ def MakePagePerms(self, mr, art, *perm_list, **kwargs):
+ """Make an EZTItem with a set of permissions needed in a given template.
+
+ Args:
+ mr: commonly used info parsed from the request.
+ art: a project artifact, such as an issue.
+ *perm_list: any number of permission names that are referenced
+ in the EZT template.
+ **kwargs: dictionary that may include 'granted_perms' list of permissions
+ granted to the current user specifically on the current page.
+
+ Returns:
+ An EZTItem with one attribute for each permission and the value
+ of each attribute being an ezt.boolean(). True if the user
+ is permitted to do that action on the given artifact, or
+ False if not.
+ """
+ granted_perms = kwargs.get('granted_perms')
+ page_perms = template_helpers.EZTItem()
+ for perm in perm_list:
+ setattr(
+ page_perms, perm,
+ ezt.boolean(self.CheckPerm(
+ mr, perm, art=art, granted_perms=granted_perms)))
+
+ return page_perms
+
+ def AssertBasePermission(self, mr):
+ """Make sure that the logged in user has permission to view this page.
+
+ Subclasses should call super, then check additional permissions
+ and raise a PermissionException if the user is not authorized to
+ do something.
+
+ Args:
+ mr: commonly used info parsed from the request.
+
+ Raises:
+ PermissionException: If the user does not have permisssion to view
+ the current page.
+ """
+ servlet_helpers.AssertBasePermission(mr)
+
  def GatherBaseData(self, mr, nonce):
    """Return a dict of info used on almost all pages.

    Args:
      mr: commonly used info parsed from the request.
      nonce: random string for this response, used in the CSP header and
        echoed into the page so inline <script nonce> tags are allowed.

    Returns:
      Dict of ezt data common to the whole site: project/user/hotlist views,
      login/logout URLs, anti-XSRF tokens, feature flags, and display options.
    """
    project = mr.project

    # Project-derived display fields; empty/None defaults when there is
    # no current project.
    project_summary = ''
    project_alert = None
    project_read_only = False
    project_home_page = ''
    project_thumbnail_url = ''
    if project:
      project_summary = project.summary
      project_alert = _CalcProjectAlert(project)
      project_read_only = project.read_only_reason
      project_home_page = project.home_page
      project_thumbnail_url = tracker_views.LogoView(project).thumbnail_url

    with work_env.WorkEnv(mr, self.services) as we:
      is_project_starred = False
      project_view = None
      if mr.project:
        if permissions.UserCanViewProject(
            mr.auth.user_pb, mr.auth.effective_ids, mr.project):
          is_project_starred = we.IsProjectStarred(mr.project_id)
          # TODO(jrobbins): should this be a ProjectView?
          project_view = template_helpers.PBProxy(mr.project)

    # Grid axes and hotlist view, only populated when viewing a hotlist.
    grid_x_attr = None
    grid_y_attr = None
    hotlist_view = None
    if mr.hotlist:
      users_by_id = framework_views.MakeAllUserViews(
          mr.cnxn, self.services.user,
          features_bizobj.UsersInvolvedInHotlists([mr.hotlist]))
      hotlist_view = hotlist_views.HotlistView(
          mr.hotlist, mr.perms, mr.auth, mr.viewed_user_auth.user_id,
          users_by_id, self.services.hotlist_star.IsItemStarredBy(
              mr.cnxn, mr.hotlist.hotlist_id, mr.auth.user_id))
      grid_x_attr = mr.x.lower()
      grid_y_attr = mr.y.lower()

    app_version = os.environ.get('CURRENT_VERSION_ID')

    viewed_username = None
    if mr.viewed_user_auth.user_view:
      viewed_username = mr.viewed_user_auth.user_view.username

    config = None
    if mr.project_id and self.services.config:
      with mr.profiler.Phase('getting config'):
        config = self.services.config.GetProjectConfig(mr.cnxn, mr.project_id)
      # Project config defaults apply when the request did not specify axes;
      # note this overrides the hotlist-derived values computed above.
      grid_x_attr = (mr.x or config.default_x_attr).lower()
      grid_y_attr = (mr.y or config.default_y_attr).lower()

    viewing_self = mr.auth.user_id == mr.viewed_user_auth.user_id
    offer_saved_queries_subtab = (
        viewing_self or mr.auth.user_pb and mr.auth.user_pb.is_site_admin)

    login_url = _SafeCreateLoginURL(mr)
    logout_url = _SafeCreateLogoutURL(mr)
    logout_url_goto_home = users.create_logout_url('/')
    version_base = _VersionBaseURL(mr.request)

    base_data = {
        # EZT does not have constants for True and False, so we pass them in.
        'True':
            ezt.boolean(True),
        'False':
            ezt.boolean(False),
        'local_mode':
            ezt.boolean(settings.local_mode),
        'site_name':
            settings.site_name,
        'show_search_metadata':
            ezt.boolean(False),
        'page_template':
            self._PAGE_TEMPLATE,
        'main_tab_mode':
            self._MAIN_TAB_MODE,
        'project_summary':
            project_summary,
        'project_home_page':
            project_home_page,
        'project_thumbnail_url':
            project_thumbnail_url,
        'hotlist_id':
            mr.hotlist_id,
        'hotlist':
            hotlist_view,
        'hostport':
            mr.request.host,
        'absolute_base_url':
            '%s://%s' % (mr.request.scheme, mr.request.host),
        'project_home_url':
            None,
        'link_rel_canonical':
            None,  # For specifying <link rel="canonical">
        'projectname':
            mr.project_name,
        'project':
            project_view,
        'project_is_restricted':
            ezt.boolean(_ProjectIsRestricted(mr)),
        'offer_contributor_list':
            ezt.boolean(permissions.CanViewContributorList(mr, mr.project)),
        'logged_in_user':
            mr.auth.user_view,
        'form_token':
            None,  # Set to a value below iff the user is logged in.
        'form_token_path':
            None,
        'token_expires_sec':
            None,
        'xhr_token':
            None,  # Set to a value below iff the user is logged in.
        'flag_spam_token':
            None,
        'nonce':
            nonce,
        'perms':
            mr.perms,
        'warnings':
            mr.warnings,
        'errors':
            mr.errors,
        'viewed_username':
            viewed_username,
        'viewed_user':
            mr.viewed_user_auth.user_view,
        'viewed_user_pb':
            template_helpers.PBProxy(mr.viewed_user_auth.user_pb),
        'viewing_self':
            ezt.boolean(viewing_self),
        'viewed_user_id':
            mr.viewed_user_auth.user_id,
        'offer_saved_queries_subtab':
            ezt.boolean(offer_saved_queries_subtab),
        'currentPageURL':
            mr.current_page_url,
        'currentPageURLEncoded':
            mr.current_page_url_encoded,
        'login_url':
            login_url,
        'logout_url':
            logout_url,
        'logout_url_goto_home':
            logout_url_goto_home,
        'continue_issue_id':
            mr.continue_issue_id,
        'feedback_email':
            settings.feedback_email,
        'category_css':
            None,  # Used to specify a category of stylesheet
        'category2_css':
            None,  # specify a 2nd category of stylesheet if needed.
        'page_css':
            None,  # Used to add a stylesheet to a specific page.
        'can':
            mr.can,
        'query':
            mr.query,
        'colspec':
            None,
        'sortspec':
            mr.sort_spec,

        # Options for issuelist display
        'grid_x_attr':
            grid_x_attr,
        'grid_y_attr':
            grid_y_attr,
        'grid_cell_mode':
            mr.cells,
        'grid_mode':
            None,
        'list_mode':
            None,
        'chart_mode':
            None,
        'is_cross_project':
            ezt.boolean(False),

        # for project search (some also used in issue search)
        'start':
            mr.start,
        'num':
            mr.num,
        'groupby':
            mr.group_by_spec,
        'q_field_size': (min(
            framework_constants.MAX_ARTIFACT_SEARCH_FIELD_SIZE,
            max(framework_constants.MIN_ARTIFACT_SEARCH_FIELD_SIZE,
                len(mr.query) + framework_constants.AUTOSIZE_STEP))),
        'mode':
            None,  # Display mode, e.g., grid mode.
        'ajah':
            mr.ajah,
        'table_title':
            mr.table_title,
        'alerts':
            alerts.AlertsView(mr),  # For alert.ezt
        'project_alert':
            project_alert,
        'title':
            None,  # First part of page title
        'title_summary':
            None,  # Appended to title on artifact detail pages

        # TODO(jrobbins): make sure that the templates use
        # project_read_only for project-mutative actions and if any
        # uses of read_only remain.
        'project_read_only':
            ezt.boolean(project_read_only),
        'site_read_only':
            ezt.boolean(settings.read_only),
        'banner_time':
            servlet_helpers.GetBannerTime(settings.banner_time),
        'read_only':
            ezt.boolean(settings.read_only or project_read_only),
        'site_banner_message':
            settings.banner_message,
        'robots_no_index':
            None,
        'analytics_id':
            settings.analytics_id,
        'is_project_starred':
            ezt.boolean(is_project_starred),
        'version_base':
            version_base,
        'app_version':
            app_version,
        'gapi_client_id':
            settings.gapi_client_id,
        'viewing_user_page':
            ezt.boolean(False),
        'old_ui_url':
            None,
        'new_ui_url':
            None,
        'is_member':
            ezt.boolean(False),
    }

    if mr.project:
      base_data['project_home_url'] = '/p/%s' % mr.project_name

    # Always add xhr-xsrf token because even anon users need some
    # pRPC methods, e.g., autocomplete, flipper, and charts.
    base_data['token_expires_sec'] = xsrf.TokenExpiresSec()
    base_data['xhr_token'] = xsrf.GenerateToken(
        mr.auth.user_id, xsrf.XHR_SERVLET_PATH)
    # Always add other anti-xsrf tokens when the user is logged in.
    if mr.auth.user_id:
      form_token_path = self._FormHandlerURL(mr.request.path)
      base_data['form_token'] = xsrf.GenerateToken(
          mr.auth.user_id, form_token_path)
      base_data['form_token_path'] = form_token_path

    return base_data
+
+ def _FormHandlerURL(self, path):
+ """Return the form handler for the main form on a page."""
+ if path.endswith('/'):
+ return path + 'edit.do'
+ elif path.endswith('.do'):
+ return path # This happens as part of PleaseCorrect().
+ else:
+ return path + '.do'
+
+ def GatherPageData(self, mr):
+ """Return a dict of page-specific ezt data."""
+ raise MethodNotSupportedError()
+
  # pylint: disable=unused-argument
  def GatherHelpData(self, mr, page_data):
    """Return a dict of values to drive on-page user help.

    Args:
      mr: common information parsed from the HTTP request.
      page_data: Dictionary of base and page template data.

    Returns:
      A dict of values to drive on-page user help, to be added to page_data.
    """
    help_data = {
        'cue': None,  # for cues.ezt
        'account_cue': None,  # for cues.ezt
    }
    dismissed = []
    if mr.auth.user_pb:
      with work_env.WorkEnv(mr, self.services) as we:
        userprefs = we.GetUserPrefs(mr.auth.user_id)
        # A pref value of 'true' records that the user dismissed that cue.
        dismissed = [
            pv.name for pv in userprefs.prefs if pv.value == 'true']
      # Note: later assignments overwrite 'cue', so your_email_bounced
      # takes precedence over you_are_on_vacation when both apply.
      if (mr.auth.user_pb.vacation_message and
          'you_are_on_vacation' not in dismissed):
        help_data['cue'] = 'you_are_on_vacation'
      if (mr.auth.user_pb.email_bounce_timestamp and
          'your_email_bounced' not in dismissed):
        help_data['cue'] = 'your_email_bounced'
      if mr.auth.user_pb.linked_parent_id:
        # This one is not dismissable.
        help_data['account_cue'] = 'switch_to_parent_account'
        parent_email = self.services.user.LookupUserEmail(
            mr.cnxn, mr.auth.user_pb.linked_parent_id)
        help_data['parent_email'] = parent_email

    return help_data
+
+ def GatherDebugData(self, mr, page_data):
+ """Return debugging info for display at the very bottom of the page."""
+ if mr.debug_enabled:
+ debug = [_ContextDebugCollection('Page data', page_data)]
+ return {
+ 'dbg': 'on',
+ 'debug': debug,
+ 'profiler': mr.profiler,
+ }
+ else:
+ if '?' in mr.current_page_url:
+ debug_url = mr.current_page_url + '&debug=1'
+ else:
+ debug_url = mr.current_page_url + '?debug=1'
+
+ return {
+ 'debug_uri': debug_url,
+ 'dbg': 'off',
+ 'debug': [('none', 'recorded')],
+ }
+
  def PleaseCorrect(self, mr, **echo_data):
    """Show the same form again so that the user can correct their input.

    Args:
      mr: commonly used info parsed from the request.
      **echo_data: form values to echo back into the re-rendered form so
        that the user's prior input is preserved.
    """
    mr.PrepareForReentry(echo_data)
    self.get()
+
+ def _RecordVisitTime(self, mr, now=None):
+ """Record the signed in user's last visit time, if possible."""
+ now = now or int(time.time())
+ if not settings.read_only and mr.auth.user_id:
+ user_pb = mr.auth.user_pb
+ if (user_pb.last_visit_timestamp <
+ now - framework_constants.VISIT_RESOLUTION):
+ user_pb.last_visit_timestamp = now
+ self.services.user.UpdateUser(mr.cnxn, user_pb.user_id, user_pb)
+
+
def _CalcProjectAlert(project):
  """Return a string of red alert text explaining the project state."""
  project_alert = None

  if project.read_only_reason:
    project_alert = 'READ-ONLY: %s.' % project.read_only_reason
  # A moved/deleted/archived state overrides the read-only message.
  if project.moved_to:
    project_alert = 'This project has moved to: %s.' % project.moved_to
  elif project.delete_time:
    delay_days = (
        (project.delete_time - time.time()) //
        framework_constants.SECS_PER_DAY)
    if delay_days <= 0:
      project_alert = 'Scheduled for deletion today.'
    else:
      project_alert = 'Scheduled for deletion in %d %s.' % (
          delay_days, 'day' if delay_days == 1 else 'days')
  elif project.state == project_pb2.ProjectState.ARCHIVED:
    project_alert = 'Project is archived: read-only by members only.'

  return project_alert
+
+
class _ContextDebugItem(object):
  """Wrapper class to generate on-screen debugging output."""

  def __init__(self, key, val):
    """Store the key and generate a string for the value."""
    self.key = key
    if isinstance(val, list):
      self.val = '[%s]' % ', '.join(self.StringRep(item) for item in val)
    else:
      self.val = self.StringRep(val)

  def StringRep(self, val):
    """Make a useful string representation of the given value."""
    # Prefer protobuf DebugString, then the instance dict, then repr.
    try:
      return val.DebugString()
    except Exception:
      try:
        return str(val.__dict__)
      except Exception:
        return repr(val)
+
+
class _ContextDebugCollection(object):
  """Attach a title to a dictionary for exporting as a table of debug info."""

  def __init__(self, title, collection):
    self.title = title
    # Items are rendered in sorted key order for stable display.
    self.collection = [
        _ContextDebugItem(key, collection[key])
        for key in sorted(collection.keys())]
+
+
def _ProjectIsRestricted(mr):
  """Return True if the mr has a 'private' project.

  Args:
    mr: commonly used info parsed from the request.

  Returns:
    True if there is a current project and its access level is anything
    other than ANYONE; False otherwise.
  """
  # bool() normalizes the result: the previous expression leaked mr.project
  # itself (e.g., None) when there was no project, which only worked because
  # callers happened to use the value in a boolean context.
  return bool(
      mr.project and mr.project.access != project_pb2.ProjectAccess.ANYONE)
+
+
def _SafeCreateLoginURL(mr, continue_url=None):
  """Make a login URL w/ a detailed continue URL, otherwise use a short one."""
  continue_url = continue_url or mr.current_page_url
  try:
    url = users.create_login_url(continue_url)
  except users.RedirectTooLongError:
    fallback = '/p/%s' % mr.project_name if mr.project_name else '/'
    url = users.create_login_url(fallback)

  # Give the user a choice of existing accounts in their session
  # or the option to add an account, even if they are currently
  # signed in to exactly one account.
  if mr.auth.user_id:
    # Notice: this makes assumptions about the output of
    # users.create_login_url, which can change at any time.
    # See https://crbug.com/monorail/3352.
    url = url.replace('/ServiceLogin', '/AccountChooser', 1)
  return url
+
+
def _SafeCreateLogoutURL(mr):
  """Make a logout URL w/ a detailed continue URL, otherwise use a short one."""
  try:
    return users.create_logout_url(mr.current_page_url)
  except users.RedirectTooLongError:
    fallback = '/p/%s' % mr.project_name if mr.project_name else '/'
    return users.create_logout_url(fallback)
+
+
def _VersionBaseURL(request):
  """Return a version-specific URL that we use to load static assets."""
  if settings.local_mode:
    return '%s://%s' % (request.scheme, request.host)
  # In production, target the currently serving GAE version's hostname.
  return '%s://%s-dot-%s' % (
      request.scheme, modules.get_current_version_name(),
      app_identity.get_default_version_hostname())
diff --git a/framework/servlet_helpers.py b/framework/servlet_helpers.py
new file mode 100644
index 0000000..68eb0c4
--- /dev/null
+++ b/framework/servlet_helpers.py
@@ -0,0 +1,160 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Helper functions used by the Monorail servlet base class."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import calendar
+import datetime
+import logging
+import urllib
+
+from framework import framework_bizobj
+from framework import framework_helpers
+from framework import permissions
+from framework import template_helpers
+from framework import urls
+from framework import xsrf
+
# A zero offset, shared by utcoffset() and dst() below.
_ZERO = datetime.timedelta(0)


class _UTCTimeZone(datetime.tzinfo):
  """Concrete tzinfo for UTC: fixed zero offset and no DST."""

  def utcoffset(self, _dt):
    return _ZERO

  def tzname(self, _dt):
    return "UTC"

  def dst(self, _dt):
    return _ZERO


# Singleton instance used throughout this module.
_UTC = _UTCTimeZone()
+
+
def GetBannerTime(timestamp):
  """Converts a timestamp into EZT-ready data so it can appear in the banner.

  Args:
    timestamp: time fields in the following format:
        [year,month,day,hour,minute,second]
        e.g. [2009,3,20,21,45,50] represents March 20 2009 9:45:50 PM, UTC.

  Returns:
    EZT-ready data (seconds since the epoch for that UTC time) used to
    display the time inside the banner message, or None for no timestamp.
  """
  if timestamp is None:
    return None

  # calendar.timegm interprets the time tuple as UTC, so a naive datetime
  # yields the same result as one carrying a UTC tzinfo.
  when = datetime.datetime(*timestamp)
  return calendar.timegm(when.timetuple())
+
+
def AssertBasePermissionForUser(user, user_view):
  """Verify user permissions and state.

  Args:
    user: user_pb2.User protocol buffer for the user
    user_view: framework.views.UserView for the user

  Raises:
    BannedUserException: If the user is banned from the site.
  """
  if not permissions.IsBanned(user, user_view):
    return
  raise permissions.BannedUserException(
      'You have been banned from using this site')
+
+
def AssertBasePermission(mr):
  """Make sure that the logged in user can view the requested page.

  Args:
    mr: common information parsed from the HTTP request.

  Returns:
    Nothing

  Raises:
    BannedUserException: If the user is banned.
    PermissionException: If the user does not have permission to view.
  """
  AssertBasePermissionForUser(mr.auth.user_pb, mr.auth.user_view)

  if not mr.project_name:
    return  # Not a project page; no project-level VIEW check needed.
  if CheckPerm(mr, permissions.VIEW):
    return
  logging.info('your perms are %r', mr.perms)
  raise permissions.PermissionException(
      'User is not allowed to view this project')
+
+
def CheckPerm(mr, perm, art=None, granted_perms=None):
  """Convenience method that makes permission checks easier.

  Args:
    mr: common information parsed from the HTTP request.
    perm: A permission constant, defined in module framework.permissions
    art: Optional artifact pb
    granted_perms: optional set of perms granted specifically in that artifact.

  Returns:
    A boolean, whether the request can be satisfied, given the permission.
  """
  restrictions = permissions.GetRestrictions(art)
  return mr.perms.CanUsePerm(
      perm, mr.auth.effective_ids, mr.project, restrictions,
      granted_perms=granted_perms)
+
+
def CheckPermForProject(mr, perm, project, art=None):
  """Convenience method that makes permission checks for projects easier.

  Args:
    mr: common information parsed from the HTTP request.
    perm: A permission constant, defined in module framework.permissions
    project: The project to enforce permissions for.
    art: Optional artifact pb

  Returns:
    A boolean, whether the request can be satisfied, given the permission.
  """
  project_perms = permissions.GetPermissions(
      mr.auth.user_pb, mr.auth.effective_ids, project)
  restrictions = permissions.GetRestrictions(art)
  return project_perms.CanUsePerm(
      perm, mr.auth.effective_ids, project, restrictions)
+
+
def ComputeIssueEntryURL(mr, config):
  """Compute the URL to use for the "New issue" subtab.

  Args:
    mr: commonly used info parsed from the request.
    config: ProjectIssueConfig for the current project.

  Returns:
    A URL string to use.  It will be simply "entry" in the non-customized
    case.  Otherwise it will be a fully qualified URL that includes some
    query string parameters.
  """
  if not config.custom_issue_entry_url:
    return '/p/%s/issues/entry' % (mr.project_name)

  base_url = config.custom_issue_entry_url
  sep = '&' if '?' in base_url else '?'
  token = xsrf.GenerateToken(
      mr.auth.user_id, '/p/%s%s%s' % (mr.project_name, urls.ISSUE_ENTRY, '.do'))
  role_name = framework_helpers.GetRoleName(mr.auth.effective_ids, mr.project)

  continue_url = urllib.quote(framework_helpers.FormatAbsoluteURL(
      mr, urls.ISSUE_ENTRY + '.do'))

  query = 'token=%s&role=%s&continue=%s' % (
      urllib.quote(token), urllib.quote(role_name or ''), continue_url)
  return base_url + sep + query
+
+
def IssueListURL(mr, config, query_string=None):
  """Make an issue list URL for non-members or members."""
  url = '/p/%s%s' % (mr.project_name, urls.ISSUE_LIST)
  if query_string:
    return url + '?' + query_string
  # Members may get the project's configured default query pre-applied.
  if framework_bizobj.UserIsInProject(mr.project, mr.auth.effective_ids):
    if config and config.member_default_query:
      return url + '?q=' + urllib.quote_plus(config.member_default_query)
  return url
diff --git a/framework/sorting.py b/framework/sorting.py
new file mode 100644
index 0000000..558044c
--- /dev/null
+++ b/framework/sorting.py
@@ -0,0 +1,575 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Helper functions for sorting lists of project artifacts.
+
+This module exports the SortArtifacts function that does sorting of
+Monorail business objects (e.g., an issue). The sorting is done by
+extracting relevant values from the PB using a dictionary of
+accessor functions.
+
+The desired sorting directives are specified in part of the user's
+HTTP request. This sort spec consists of the names of the columns
+with optional minus signs to indicate descending sort order.
+
+The tool configuration object also affects sorting. When sorting by
+key-value labels, the well-known labels are considered to come
+before any non-well-known labels, and those well-known labels sort in
+the order in which they are defined in the tool config PB.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+from functools import total_ordering
+
+import settings
+from proto import tracker_pb2
+from services import caches
+from tracker import tracker_bizobj
+from tracker import tracker_constants
+
+
+@total_ordering
+class DescendingValue(object):
+  """A wrapper which reverses the sort order of values.
+
+  Wrapped values compare inverted relative to their natural order, so
+  sorting a list of DescendingValue wrappers ascending yields the
+  underlying values in descending order.
+  """
+
+  @classmethod
+  def MakeDescendingValue(cls, obj):
+    """Make a value that sorts in the reverse order as obj."""
+    # ints can simply be negated rather than wrapped.
+    # NOTE(review): bool is a subclass of int, so True takes this branch
+    # and becomes -1 -- confirm boolean sort keys are intended here.
+    if isinstance(obj, int):
+      return -obj
+    if obj == MAX_STRING:
+      return MIN_STRING
+    if obj == MIN_STRING:
+      return MAX_STRING
+    # Lists are reversed element-wise so multi-valued keys sort descending.
+    if isinstance(obj, list):
+      return [cls.MakeDescendingValue(item) for item in reversed(obj)]
+    return DescendingValue(obj)
+
+  def __init__(self, val):
+    # val: the wrapped ascending value; comparisons below invert its order.
+    self.val = val
+
+  def __eq__(self, other):
+    if isinstance(other, DescendingValue):
+      return self.val == other.val
+    return self.val == other
+
+  def __ne__(self, other):
+    if isinstance(other, DescendingValue):
+      return self.val != other.val
+    return self.val != other
+
+  def __lt__(self, other):
+    # Inverted comparison: self sorts before other iff the wrapped
+    # values compare the opposite way.  total_ordering fills in the rest.
+    if isinstance(other, DescendingValue):
+      return other.val < self.val
+    return other < self.val
+
+  def __repr__(self):
+    return 'DescendingValue(%r)' % self.val
+
+
+# A string that sorts after every other string, and one that sorts before them.
+MAX_STRING = '~~~'
+# MIN_STRING wraps MAX_STRING in DescendingValue, inverting it so that it
+# sorts before any ordinary string.
+MIN_STRING = DescendingValue(MAX_STRING)
+
+
+# RAMCache {issue_id: {column_name: sort_key, ...}, ...}
+# None until InitializeArtValues() is called at startup.
+art_values_cache = None
+
+
+def InitializeArtValues(services):
+  """Set up the global RAM cache of per-issue computed sort key values.
+
+  Args:
+    services: connections to backends; supplies the cache_manager that
+        registers this cache for invalidation.
+  """
+  global art_values_cache
+  art_values_cache = caches.RamCache(
+      services.cache_manager, 'issue', max_size=settings.issue_cache_max_size)
+
+
+def InvalidateArtValuesKeys(cnxn, keys):
+  """Drop cached sort key values for the given issue_ids.
+
+  NOTE(review): assumes InitializeArtValues() was already called;
+  art_values_cache is None until then.
+  """
+  art_values_cache.InvalidateKeys(cnxn, keys)
+
+
+def SortArtifacts(
+    artifacts, config, accessors, postprocessors, group_by_spec, sort_spec,
+    users_by_id=None, tie_breakers=None):
+  """Return a list of artifacts sorted by the user's sort specification.
+
+  In the following, an "accessor" is a function(art) -> [field_value, ...].
+
+  Args:
+    artifacts: an unsorted list of project artifact PBs.
+    config: Project config PB instance that defines the sort order for
+        labels and statuses in this project.
+    accessors: dict {column_name: accessor} to get values from the artifacts.
+    postprocessors: dict {column_name: postprocessor} to get user emails
+        and timestamps.
+    group_by_spec: string that lists the grouping order
+    sort_spec: string that lists the sort order
+    users_by_id: optional dictionary {user_id: user_view,...} for all users
+        who participate in the list of artifacts.
+    tie_breakers: list of column names to add to the end of the sort
+        spec if they are not already somewhere in the sort spec.
+
+  Returns:
+    A sorted list of artifacts.
+
+  Note: if postprocessors is supplied, then users_by_id should be too.
+
+  The approach to sorting is to construct a comprehensive sort key for
+  each artifact.  To create the sort key, we (a) build lists with a
+  variable number of fields to sort on, and (b) allow individual
+  fields to be sorted in descending order.  Even with the time taken
+  to build the sort keys, calling sorted() with the key seems to be
+  faster overall than doing multiple stable-sorts or doing one sort
+  using a multi-field comparison function.
+  """
+  sort_directives = ComputeSortDirectives(
+      config, group_by_spec, sort_spec, tie_breakers=tie_breakers)
+
+  # Build a list of accessors that will extract sort keys from the issues.
+  accessor_pairs = [
+      (sd, _MakeCombinedSortKeyAccessor(
+          sd, config, accessors, postprocessors, users_by_id))
+      for sd in sort_directives]
+
+  def SortKey(art):
+    """Make a sort_key for the given artifact, used by sorted() below."""
+    # Reuse per-column values computed on an earlier request, if cached.
+    if art_values_cache.HasItem(art.issue_id):
+      art_values = art_values_cache.GetItem(art.issue_id)
+    else:
+      art_values = {}
+
+    sort_key = []
+    for sd, accessor in accessor_pairs:
+      # Only compute values for columns not already in the cached dict.
+      if sd not in art_values:
+        art_values[sd] = accessor(art)
+      sort_key.append(art_values[sd])
+
+    # Store the (possibly extended) per-column values for later requests.
+    art_values_cache.CacheItem(art.issue_id, art_values)
+    return sort_key
+
+  return sorted(artifacts, key=SortKey)
+
+
+def ComputeSortDirectives(config, group_by_spec, sort_spec, tie_breakers=None):
+  """Return a list with sort directives to be used in sorting.
+
+  Args:
+    config: Project config PB instance that defines the sort order for
+        labels and statuses in this project.
+    group_by_spec: string that lists the grouping order
+    sort_spec: string that lists the sort order
+    tie_breakers: list of column names to add to the end of the sort
+        spec if they are not already somewhere in the sort spec.
+
+  Returns:
+    A list of lower-case column names, each one may have a leading
+    minus-sign.
+  """
+  # Prepend the end-user's sort spec to any project default sort spec.
+  if tie_breakers is None:
+    tie_breakers = ['id']
+  sort_spec = '%s %s %s' % (
+      group_by_spec, sort_spec, config.default_sort_spec)
+  # Sort specs can have interfering sort orders, so remove any duplicates.
+  field_names = set()
+  sort_directives = []
+  for sort_directive in sort_spec.lower().split():
+    # Compare by the bare column name so 'foo' and '-foo' count as dups.
+    field_name = sort_directive.lstrip('-')
+    if field_name not in field_names:
+      sort_directives.append(sort_directive)
+      field_names.add(field_name)
+
+  # Add in the project name so that the overall ordering is completely
+  # defined in cross-project search. Otherwise, issues jump up and
+  # down on each reload of the same query, and prev/next links get
+  # messed up. It's a no-op in single projects.
+  if 'project' not in sort_directives:
+    sort_directives.append('project')
+
+  for tie_breaker in tie_breakers:
+    if tie_breaker not in sort_directives:
+      sort_directives.append(tie_breaker)
+
+  return sort_directives
+
+
+def _MakeCombinedSortKeyAccessor(
+    sort_directive, config, accessors, postprocessors, users_by_id):
+  """Return an accessor that extracts a sort key for a UI table column.
+
+  Args:
+    sort_directive: string with column name and optional leading minus sign,
+        for combined columns, it may have slashes, e.g., "-priority/pri".
+    config: ProjectIssueConfig instance that defines the sort order for
+        labels and statuses in this project.
+    accessors: dictionary of (column_name -> accessor) to get values
+        from the artifacts.
+    postprocessors: dict {column_name: postprocessor} to get user emails
+        and timestamps.
+    users_by_id: dictionary {user_id: user_view,...} for all users
+        who participate in the list of artifacts (e.g., owners, reporters, cc).
+
+  Returns:
+    An accessor function that can be applied to an issue to extract
+    the relevant sort key value.
+
+  The strings for status and labels are converted to lower case in
+  this method so that they sort like case-insensitive enumerations.
+  Any component-specific field of the artifact is sorted according to the
+  value returned by the accessors defined in that component.  Those
+  accessor functions should lower case string values for fields where
+  case-insensitive sorting is desired.
+  """
+  # A leading minus sign requests descending order for this column.
+  if sort_directive.startswith('-'):
+    combined_col_name = sort_directive[1:]
+    descending = True
+  else:
+    combined_col_name = sort_directive
+    descending = False
+
+  wk_labels = [wkl.label for wkl in config.well_known_labels]
+  # NOTE: this local list shadows the |accessors| dict parameter from here
+  # on; only the per-column accessor functions are used below.
+  accessors = [
+      _MakeSingleSortKeyAccessor(
+          col_name, config, accessors, postprocessors, users_by_id, wk_labels)
+      for col_name in combined_col_name.split('/')]
+
+  # The most common case is that we sort on a single column, like "priority".
+  if len(accessors) == 1:
+    return _MaybeMakeDescending(accessors[0], descending)
+
+  # Less commonly, we are sorting on a combined column like "priority/pri".
+  def CombinedAccessor(art):
+    """Flatten and sort the values for each column in a combined column."""
+    key_part = []
+    for single_accessor in accessors:
+      value = single_accessor(art)
+      # Single accessors may return either a scalar or a list of values.
+      if isinstance(value, list):
+        key_part.extend(value)
+      else:
+        key_part.append(value)
+    return sorted(key_part)
+
+  return _MaybeMakeDescending(CombinedAccessor, descending)
+
+
+def _MaybeMakeDescending(accessor, descending):
+  """If descending is True, return a new function that reverses accessor.
+
+  Args:
+    accessor: an accessor function f(art) -> sort key value.
+    descending: bool; True to invert the sort order of the accessor.
+
+  Returns:
+    The accessor unchanged, or a wrapper that inverts its sort order.
+  """
+  if not descending:
+    return accessor
+
+  def DescendingAccessor(art):
+    asc_value = accessor(art)
+    # Wrap (or negate) so that ascending sorted() yields descending order.
+    return DescendingValue.MakeDescendingValue(asc_value)
+
+  return DescendingAccessor
+
+
+def _MakeSingleSortKeyAccessor(
+    col_name, config, accessors, postprocessors, users_by_id, wk_labels):
+  """Return an accessor function for a single simple UI column.
+
+  Args:
+    col_name: lowercase string name of one simple (non-combined) column.
+    config: ProjectIssueConfig PB for the current project.
+    accessors: dict {column_name: accessor} for built-in columns.
+    postprocessors: dict {column_name: postprocessor} for user/time columns.
+    users_by_id: dict {user_id: user_view} for artifact participants.
+    wk_labels: list of well-known label strings from the config.
+
+  Returns:
+    An accessor function f(art) -> sort key value.
+  """
+  # Case 1. Handle built-in fields: status, component.
+  if col_name == 'status':
+    wk_statuses = [wks.status for wks in config.well_known_statuses]
+    return _IndexOrLexical(wk_statuses, accessors[col_name])
+
+  if col_name == 'component':
+    # Order components by their path so sub-components follow parents.
+    comp_defs = sorted(config.component_defs, key=lambda cd: cd.path.lower())
+    comp_ids = [cd.component_id for cd in comp_defs]
+    return _IndexListAccessor(comp_ids, accessors[col_name])
+
+  # Case 2. Any other defined accessor functions.
+  if col_name in accessors:
+    if postprocessors and col_name in postprocessors:
+      # sort users by email address or timestamp rather than user ids.
+      return _MakeAccessorWithPostProcessor(
+          users_by_id, accessors[col_name], postprocessors[col_name])
+    else:
+      return accessors[col_name]
+
+  # Case 3. Anything else is assumed to be a label prefix or custom field.
+  return _IndexOrLexicalList(
+      wk_labels, config.field_defs, col_name, users_by_id)
+
+
+IGNORABLE_INDICATOR = -1
+
+
+def _PrecomputeSortIndexes(values, col_name):
+  """Precompute indexes of strings in the values list for fast lookup later.
+
+  Args:
+    values: list of well-known strings in their configured order.
+    col_name: lowercase column name; when non-empty, only values starting
+        with 'col_name-' get a real index, others get IGNORABLE_INDICATOR.
+
+  Returns:
+    A dict mapping each value (and its lowercase form) to its index in
+    |values|, or to IGNORABLE_INDICATOR for irrelevant values.
+  """
+  # Make a dictionary that immediately gives us the index of any value
+  # in the list, and also add the same values in all-lower letters.  In
+  # the case where two values differ only by case, the later value wins,
+  # which is fine.
+  indexes = {}
+  if col_name:
+    prefix = col_name + '-'
+  else:
+    # An empty prefix matches every value (used for status sorting).
+    prefix = ''
+  for idx, val in enumerate(values):
+    if val.lower().startswith(prefix):
+      indexes[val] = idx
+      indexes[val.lower()] = idx
+    else:
+      indexes[val] = IGNORABLE_INDICATOR
+      indexes[val.lower()] = IGNORABLE_INDICATOR
+
+  return indexes
+
+
+def _MakeAccessorWithPostProcessor(users_by_id, base_accessor, postprocessor):
+  """Make an accessor that returns a list of user_view properties for sorting.
+
+  Args:
+    users_by_id: dictionary {user_id: user_view, ...} for all participants
+        in the entire list of artifacts.
+    base_accessor: an accessor function f(artifact) -> user_id.
+    postprocessor: function f(user_view) -> single sortable value.
+
+  Returns:
+    An accessor f(artifact) -> value that can be used in sorting
+    the decorated list.
+  """
+
+  def Accessor(art):
+    """Return a user edit name for the given artifact's base_accessor."""
+    # base_accessor may yield either one user_id or a list of them.
+    # NOTE(review): an id missing from users_by_id would raise KeyError
+    # here -- presumably callers pass a complete dict; verify.
+    id_or_id_list = base_accessor(art)
+    if isinstance(id_or_id_list, list):
+      values = [postprocessor(users_by_id[user_id])
+                for user_id in id_or_id_list]
+    else:
+      values = [postprocessor(users_by_id[id_or_id_list])]
+
+    # An empty result sorts to the end of the list via MAX_STRING.
+    return sorted(values) or MAX_STRING
+
+  return Accessor
+
+
+def _MakeColumnAccessor(col_name):
+  """Make an accessor for an issue's labels that have col_name as a prefix.
+
+  Args:
+    col_name: string column name.
+
+  Returns:
+    An accessor that can be applied to an artifact to return a list of
+    labels that have col_name as a prefix.
+
+  For example, _MakeColumnAccessor('priority')(issue) could result in
+  [], or ['priority-high'], or a longer list for multi-valued labels.
+  """
+  prefix = col_name + '-'
+
+  def Accessor(art):
+    """Return a list of label values on the given artifact."""
+    # Labels are lowercased so they sort case-insensitively.
+    result = [label.lower() for label in tracker_bizobj.GetLabels(art)
+              if label.lower().startswith(prefix)]
+    return result
+
+  return Accessor
+
+
+def _IndexOrLexical(wk_values, base_accessor):
+  """Return an accessor to score an artifact based on a user-defined ordering.
+
+  Args:
+    wk_values: a list of well-known status values from the config.
+    base_accessor: function that gets a field from a given issue.
+
+  Returns:
+    An accessor that can be applied to an issue to return a suitable
+    sort key.
+
+  For example, when used to sort issue statuses, these accessors return an
+  integer for well-known statuses, a string for odd-ball statuses, and an
+  extreme value key for issues with no status.  That causes issues to appear
+  in the expected order with odd-ball issues sorted lexicographically after
+  the ones with well-known status values, and issues with no defined status at
+  the very end.
+  """
+  well_known_value_indexes = _PrecomputeSortIndexes(wk_values, '')
+
+  def Accessor(art):
+    """Custom-made function to return a specific value of any issue."""
+    value = base_accessor(art)
+    if not value:
+      # Undefined values sort last.
+      return MAX_STRING
+
+    try:
+      # Well-known values sort by index.  Ascending sorting has positive ints
+      # in well_known_value_indexes.
+      return well_known_value_indexes[value]
+    except KeyError:
+      # Odd-ball values after well-known and lexicographically.
+      return value.lower()
+
+  return Accessor
+
+
+def _IndexListAccessor(wk_values, base_accessor):
+  """Return an accessor to score an artifact based on a user-defined ordering.
+
+  Args:
+    wk_values: a list of well-known values from the config.
+    base_accessor: function that gets a field from a given issue.
+
+  Returns:
+    An accessor that can be applied to an issue to return a suitable
+    sort key: MAX_STRING for no values, else a sorted list of indexes.
+  """
+  well_known_value_indexes = {
+      val: idx for idx, val in enumerate(wk_values)}
+
+  def Accessor(art):
+    """Custom-made function to return a specific value of any issue."""
+    values = base_accessor(art)
+    if not values:
+      # Undefined values sort last.
+      return MAX_STRING
+
+    # Unknown values fall back to MAX_STRING so they sort after known ones.
+    # NOTE(review): this mixes ints and strings within one key list, which
+    # relies on Python 2 cross-type comparison rules -- would raise on Py3.
+    indexes = [well_known_value_indexes.get(val, MAX_STRING) for val in values]
+    return sorted(indexes)
+
+  return Accessor
+
+
+def _IndexOrLexicalList(wk_values, full_fd_list, col_name, users_by_id):
+  """Return an accessor to score an artifact based on a user-defined ordering.
+
+  Args:
+    wk_values: A list of well-known labels from the config.
+    full_fd_list: list of FieldDef PBs that belong to the config.
+    col_name: lowercase string name of the column that will be sorted on.
+    users_by_id: A dictionary {user_id: user_view}.
+
+  Returns:
+    An accessor that can be applied to an issue to return a suitable
+    sort key.
+  """
+  well_known_value_indexes = _PrecomputeSortIndexes(wk_values, col_name)
+
+  if col_name.endswith(tracker_constants.APPROVER_COL_SUFFIX):
+    # Custom field names cannot end with the APPROVER_COL_SUFFIX. So the only
+    # possible relevant values are approvers for an APPROVAL_TYPE named
+    # field_name and any values from labels with the key 'field_name-approvers'.
+    field_name = col_name[:-len(tracker_constants.APPROVER_COL_SUFFIX)]
+    approval_fds = [fd for fd in full_fd_list
+                    if (fd.field_name.lower() == field_name and
+                        fd.field_type == tracker_pb2.FieldTypes.APPROVAL_TYPE)]
+
+    def ApproverAccessor(art):
+      """Custom-made function to return a sort value or an issue's approvers."""
+      idx_or_lex_list = (
+          _SortableApprovalApproverValues(art, approval_fds, users_by_id) +
+          _SortableLabelValues(art, col_name, well_known_value_indexes))
+      if not idx_or_lex_list:
+        return MAX_STRING  # issues with no value sort to the end of the list.
+      return sorted(idx_or_lex_list)
+
+    return ApproverAccessor
+
+  # Column name does not end with APPROVER_COL_SUFFIX, so relevant values
+  # are Approval statuses or Field Values for fields named col_name and
+  # values from labels with the key equal to col_name.
+  field_name = col_name
+  phase_name = None
+  # A dotted column name like "phase.field" scopes the field to one phase.
+  if '.' in col_name:
+    phase_name, field_name = col_name.split('.', 1)
+
+  fd_list = [fd for fd in full_fd_list
+             if (fd.field_name.lower() == field_name and
+                 fd.field_type != tracker_pb2.FieldTypes.ENUM_TYPE and
+                 bool(phase_name) == fd.is_phase_field)]
+  # Approval fields are only relevant when no phase is specified.
+  approval_fds = []
+  if not phase_name:
+    approval_fds = [fd for fd in fd_list if
+                    fd.field_type == tracker_pb2.FieldTypes.APPROVAL_TYPE]
+
+  def Accessor(art):
+    """Custom-made function to return a sort value for any issue."""
+    idx_or_lex_list = (
+        _SortableApprovalStatusValues(art, approval_fds) +
+        _SortableFieldValues(art, fd_list, users_by_id, phase_name) +
+        _SortableLabelValues(art, col_name, well_known_value_indexes))
+    if not idx_or_lex_list:
+      return MAX_STRING  # issues with no value sort to the end of the list.
+    return sorted(idx_or_lex_list)
+
+  return Accessor
+
+
+def _SortableApprovalStatusValues(art, fd_list):
+  """Return a list of approval statuses relevant to one UI table column.
+
+  Args:
+    art: issue artifact PB with approval_values.
+    fd_list: list of APPROVAL_TYPE FieldDef PBs for the column.
+
+  Returns:
+    A list of ints ordered by approval life cycle.
+  """
+  sortable_value_list = []
+  for fd in fd_list:
+    for av in art.approval_values:
+      if av.approval_id == fd.field_id:
+        # Order approval statuses by life cycle.
+        # NOT_SET == 8 but should be before all other statuses.
+        sortable_value_list.append(
+            0 if av.status.number == 8 else av.status.number)
+
+  return sortable_value_list
+
+
+def _SortableApprovalApproverValues(art, fd_list, users_by_id):
+  """Return a list of approval approvers relevant to one UI table column.
+
+  Args:
+    art: issue artifact PB with approval_values.
+    fd_list: list of APPROVAL_TYPE FieldDef PBs for the column.
+    users_by_id: dict {user_id: user_view} used to look up approver emails.
+
+  Returns:
+    A list of approver email strings; approvers missing from users_by_id
+    are silently skipped.
+  """
+  sortable_value_list = []
+  for fd in fd_list:
+    for av in art.approval_values:
+      if av.approval_id == fd.field_id:
+        sortable_value_list.extend(
+            [users_by_id.get(approver_id).email
+             for approver_id in av.approver_ids
+             if users_by_id.get(approver_id)])
+
+  return sortable_value_list
+
+
+def _SortableFieldValues(art, fd_list, users_by_id, phase_name):
+  """Return a list of field values relevant to one UI table column.
+
+  Args:
+    art: issue artifact PB with field_values and phases.
+    fd_list: list of FieldDef PBs for the column.
+    users_by_id: dict {user_id: user_view} for resolving user field values.
+    phase_name: lowercase phase name, or None for non-phase fields.
+
+  Returns:
+    A list of sortable field values.
+  """
+  # Resolve the phase name to a phase_id; None matches non-phase values.
+  phase_id = None
+  if phase_name:
+    phase_id = next((
+        phase.phase_id for phase in art.phases
+        if phase.name.lower() == phase_name), None)
+  sortable_value_list = []
+  for fd in fd_list:
+    for fv in art.field_values:
+      if fv.field_id == fd.field_id and fv.phase_id == phase_id:
+        sortable_value_list.append(
+            tracker_bizobj.GetFieldValue(fv, users_by_id))
+
+  return sortable_value_list
+
+
+def _SortableLabelValues(art, col_name, well_known_value_indexes):
+  """Return a list of ints and strings for labels relevant to one UI column.
+
+  Args:
+    art: issue artifact PB whose labels are examined.
+    col_name: lowercase column name being sorted on.
+    well_known_value_indexes: dict from _PrecomputeSortIndexes().  NOTE:
+        this dict is mutated here to memoize decisions about odd-ball
+        labels, speeding up later calls that share the same dict.
+
+  Returns:
+    A list of ints (well-known ordering) and strings (odd-ball values).
+  """
+  col_name_dash = col_name + '-'
+  sortable_value_list = []
+  for label in tracker_bizobj.GetLabels(art):
+    idx_or_lex = well_known_value_indexes.get(label)
+    if idx_or_lex == IGNORABLE_INDICATOR:
+      continue  # Label is known to not have the desired prefix.
+    if idx_or_lex is None:
+      # First time we see this label: classify it and memoize the result.
+      if '-' not in label:
+        # Skip an irrelevant OneWord label and remember to ignore it later.
+        well_known_value_indexes[label] = IGNORABLE_INDICATOR
+        continue
+      label_lower = label.lower()
+      if label_lower.startswith(col_name_dash):
+        # Label is a key-value label with an odd-ball value, remember it
+        value = label_lower[len(col_name_dash):]
+        idx_or_lex = value
+        well_known_value_indexes[label] = value
+      else:
+        # Label was a key-value label that is not relevant to this column.
+        # Remember to ignore it later.
+        well_known_value_indexes[label] = IGNORABLE_INDICATOR
+        continue
+
+    sortable_value_list.append(idx_or_lex)
+
+  return sortable_value_list
diff --git a/framework/sql.py b/framework/sql.py
new file mode 100644
index 0000000..d99b045
--- /dev/null
+++ b/framework/sql.py
@@ -0,0 +1,1048 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""A set of classes for interacting with tables in SQL."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import logging
+import random
+import re
+import sys
+import time
+
+from six import string_types
+
+import settings
+
+if not settings.unit_test_mode:
+ import MySQLdb
+
+from framework import exceptions
+from framework import framework_helpers
+
+from infra_libs import ts_mon
+
+from Queue import Queue
+
+
+class ConnectionPool(object):
+  """Manage a set of database connections such that they may be re-used.
+  """
+
+  def __init__(self, poolsize=1):
+    # Maximum number of idle connections kept per instance/database pair.
+    self.poolsize = poolsize
+    # {instance + '/' + database: Queue of idle connections}
+    self.queues = {}
+
+  def get(self, instance, database):
+    """Return a database connection, or throw an exception if none can
+    be made.
+    """
+    key = instance + '/' + database
+
+    if not key in self.queues:
+      queue = Queue(self.poolsize)
+      self.queues[key] = queue
+
+    queue = self.queues[key]
+
+    if queue.empty():
+      # No idle connection available; open a fresh one.
+      cnxn = cnxn_ctor(instance, database)
+    else:
+      cnxn = queue.get()
+      # Make sure the connection is still good.
+      cnxn.ping()
+      cnxn.commit()
+
+    return cnxn
+
+  def release(self, cnxn):
+    """Return a connection to its pool, closing it if the pool is full.
+
+    Raises:
+      BaseException: if the connection's pool_key is not known to this pool.
+          NOTE(review): raising BaseException directly is unconventional;
+          a ValueError seems intended -- confirm before changing.
+    """
+    if not cnxn.pool_key in self.queues:
+      raise BaseException('unknown pool key: %s' % cnxn.pool_key)
+
+    q = self.queues[cnxn.pool_key]
+    if q.full():
+      cnxn.close()
+    else:
+      q.put(cnxn)
+
+
+@framework_helpers.retry(1, delay=1, backoff=2)
+def cnxn_ctor(instance, database):
+  """Open a new MySQL connection to the given instance and database.
+
+  Retried once with backoff by the decorator if it raises.  Records
+  connection latency and success/failure counts in ts_mon metrics.
+
+  Args:
+    instance: Cloud SQL instance name (used for the unix socket path
+        in production; ignored in local mode).
+    database: name of the database to connect to.
+
+  Returns:
+    A MySQLdb connection, annotated with .pool_key and .is_bad attributes
+    used by ConnectionPool and shard tracking.
+
+  Raises:
+    ValueError: in unit test mode, which must never open real connections.
+    MySQLdb.OperationalError: if the connection cannot be established.
+  """
+  logging.info('About to connect to SQL instance %r db %r', instance, database)
+  if settings.unit_test_mode:
+    raise ValueError('unit tests should not need real database connections')
+  try:
+    if settings.local_mode:
+      start_time = time.time()
+      cnxn = MySQLdb.connect(
+          host='127.0.0.1', port=3306, db=database, user='root', charset='utf8')
+    else:
+      start_time = time.time()
+      cnxn = MySQLdb.connect(
+          unix_socket='/cloudsql/' + instance, db=database, user='root',
+          charset='utf8')
+    duration = int((time.time() - start_time) * 1000)
+    DB_CNXN_LATENCY.add(duration)
+    CONNECTION_COUNT.increment({'success': True})
+  except MySQLdb.OperationalError:
+    CONNECTION_COUNT.increment({'success': False})
+    raise
+  # pool_key identifies which ConnectionPool queue this cnxn belongs to.
+  cnxn.pool_key = instance + '/' + database
+  cnxn.is_bad = False
+  return cnxn
+
+
+# One connection pool per database instance (primary, replicas are each an
+# instance). We'll have four connections per instance because we fetch
+# issue comments, stars, spam verdicts and spam verdict history in parallel
+# with promises.
+cnxn_pool = ConnectionPool(settings.db_cnxn_pool_size)
+
+# MonorailConnection maintains a dictionary of connections to SQL databases.
+# Each is identified by an int shard ID.
+# And there is one connection to the primary DB identified by key PRIMARY_CNXN.
+PRIMARY_CNXN = 'primary_cnxn'
+
+# When one replica is temporarily unresponsive, we can use a different one.
+BAD_SHARD_AVOIDANCE_SEC = 45
+
+
+# ts_mon metrics tracking DB connection, query, and transaction health.
+CONNECTION_COUNT = ts_mon.CounterMetric(
+    'monorail/sql/connection_count',
+    'Count of connections made to the SQL database.',
+    [ts_mon.BooleanField('success')])
+
+DB_CNXN_LATENCY = ts_mon.CumulativeDistributionMetric(
+    'monorail/sql/db_cnxn_latency',
+    'Time needed to establish a DB connection.',
+    None)
+
+DB_QUERY_LATENCY = ts_mon.CumulativeDistributionMetric(
+    'monorail/sql/db_query_latency',
+    'Time needed to make a DB query.',
+    [ts_mon.StringField('type')])
+
+DB_COMMIT_LATENCY = ts_mon.CumulativeDistributionMetric(
+    'monorail/sql/db_commit_latency',
+    'Time needed to make a DB commit.',
+    None)
+
+DB_ROLLBACK_LATENCY = ts_mon.CumulativeDistributionMetric(
+    'monorail/sql/db_rollback_latency',
+    'Time needed to make a DB rollback.',
+    None)
+
+DB_RETRY_COUNT = ts_mon.CounterMetric(
+    'monorail/sql/db_retry_count',
+    'Count of queries retried.',
+    None)
+
+DB_QUERY_COUNT = ts_mon.CounterMetric(
+    'monorail/sql/db_query_count',
+    'Count of queries sent to the DB.',
+    [ts_mon.StringField('type')])
+
+DB_COMMIT_COUNT = ts_mon.CounterMetric(
+    'monorail/sql/db_commit_count',
+    'Count of commits sent to the DB.',
+    None)
+
+DB_ROLLBACK_COUNT = ts_mon.CounterMetric(
+    'monorail/sql/db_rollback_count',
+    'Count of rollbacks sent to the DB.',
+    None)
+
+DB_RESULT_ROWS = ts_mon.CumulativeDistributionMetric(
+    'monorail/sql/db_result_rows',
+    'Number of results returned by a DB query.',
+    None)
+
+
+def RandomShardID():
+  """Return a random shard ID to load balance across replicas."""
+  return random.randint(0, settings.num_logical_shards - 1)
+
+
+class MonorailConnection(object):
+  """Create and manage connections to the SQL servers.
+
+  We only store connections in the context of a single user request, not
+  across user requests. The main purpose of this class is to make using
+  sharded tables easier.
+  """
+  # Class-level attribute: shared across all instances so that knowledge of
+  # a bad replica persists beyond a single request.
+  unavailable_shards = {}  # {shard_id: timestamp of failed attempt}
+
+  def __init__(self):
+    self.sql_cnxns = {}  # {PRIMARY_CNXN: cnxn, shard_id: cnxn, ...}
+
+  @framework_helpers.retry(1, delay=0.1, backoff=2)
+  def GetPrimaryConnection(self):
+    """Return a connection to the primary SQL DB."""
+    if PRIMARY_CNXN not in self.sql_cnxns:
+      self.sql_cnxns[PRIMARY_CNXN] = cnxn_pool.get(
+          settings.db_instance, settings.db_database_name)
+      logging.info(
+          'created a primary connection %r', self.sql_cnxns[PRIMARY_CNXN])
+
+    return self.sql_cnxns[PRIMARY_CNXN]
+
+  @framework_helpers.retry(1, delay=0.1, backoff=2)
+  def GetConnectionForShard(self, shard_id):
+    """Return a connection to the DB replica that will be used for shard_id."""
+    if shard_id not in self.sql_cnxns:
+      physical_shard_id = shard_id % settings.num_logical_shards
+
+      replica_name = settings.db_replica_names[
+          physical_shard_id % len(settings.db_replica_names)]
+      shard_instance_name = (
+          settings.physical_db_name_format % replica_name)
+      # Pessimistically mark the shard bad before connecting; if cnxn_pool
+      # raises, the mark remains and Execute() avoids this shard for a while.
+      self.unavailable_shards[shard_id] = int(time.time())
+      self.sql_cnxns[shard_id] = cnxn_pool.get(
+          shard_instance_name, settings.db_database_name)
+      del self.unavailable_shards[shard_id]
+      logging.info('created a replica connection for shard %d', shard_id)
+
+    return self.sql_cnxns[shard_id]
+
+  def Execute(self, stmt_str, stmt_args, shard_id=None, commit=True, retries=2):
+    """Execute the given SQL statement on one of the relevant databases."""
+    if shard_id is None:
+      # No shard was specified, so hit the primary.
+      sql_cnxn = self.GetPrimaryConnection()
+    else:
+      if shard_id in self.unavailable_shards:
+        bad_age_sec = int(time.time()) - self.unavailable_shards[shard_id]
+        if bad_age_sec < BAD_SHARD_AVOIDANCE_SEC:
+          # Recently-failed replica: route to the next logical shard instead.
+          logging.info('Avoiding bad replica %r, age %r', shard_id, bad_age_sec)
+          shard_id = (shard_id + 1) % settings.num_logical_shards
+      sql_cnxn = self.GetConnectionForShard(shard_id)
+
+    try:
+      return self._ExecuteWithSQLConnection(
+          sql_cnxn, stmt_str, stmt_args, commit=commit)
+    except MySQLdb.OperationalError as e:
+      logging.exception(e)
+      logging.info('retries: %r', retries)
+      if retries > 0:
+        DB_RETRY_COUNT.increment()
+        self.sql_cnxns = {}  # Drop all old mysql connections and make new.
+        return self.Execute(
+            stmt_str, stmt_args, shard_id=shard_id, commit=commit,
+            retries=retries - 1)
+      else:
+        raise e
+
+  def _ExecuteWithSQLConnection(
+      self, sql_cnxn, stmt_str, stmt_args, commit=True):
+    """Execute a statement on the given database and return a cursor."""
+
+    start_time = time.time()
+    cursor = sql_cnxn.cursor()
+    cursor.execute('SET NAMES utf8mb4')
+    if stmt_str.startswith('INSERT') or stmt_str.startswith('REPLACE'):
+      # For inserts, stmt_args is a sequence of row-value tuples.
+      cursor.executemany(stmt_str, stmt_args)
+      duration = (time.time() - start_time) * 1000
+      DB_QUERY_LATENCY.add(duration, {'type': 'write'})
+      DB_QUERY_COUNT.increment({'type': 'write'})
+    else:
+      cursor.execute(stmt_str, args=stmt_args)
+      duration = (time.time() - start_time) * 1000
+      DB_QUERY_LATENCY.add(duration, {'type': 'read'})
+      DB_QUERY_COUNT.increment({'type': 'read'})
+    DB_RESULT_ROWS.add(cursor.rowcount)
+
+    # Build a human-readable statement for the log only; the DB already
+    # received the parameterized form above.
+    if stmt_str.startswith('INSERT') or stmt_str.startswith('REPLACE'):
+      formatted_statement = '%s %s' % (stmt_str, stmt_args)
+    else:
+      formatted_statement = stmt_str % tuple(stmt_args)
+    logging.info(
+        '%d rows in %d ms: %s', cursor.rowcount, int(duration),
+        formatted_statement.replace('\n', ' '))
+
+    # SELECTs have nothing to commit; everything else commits unless the
+    # caller is batching several statements into one transaction.
+    if commit and not stmt_str.startswith('SELECT'):
+      try:
+        sql_cnxn.commit()
+        duration = (time.time() - start_time) * 1000
+        DB_COMMIT_LATENCY.add(duration)
+        DB_COMMIT_COUNT.increment()
+      except MySQLdb.DatabaseError:
+        sql_cnxn.rollback()
+        duration = (time.time() - start_time) * 1000
+        DB_ROLLBACK_LATENCY.add(duration)
+        DB_ROLLBACK_COUNT.increment()
+
+    return cursor
+
+  def Commit(self):
+    """Explicitly commit any pending txns.  Normally done automatically."""
+    sql_cnxn = self.GetPrimaryConnection()
+    try:
+      sql_cnxn.commit()
+    except MySQLdb.DatabaseError:
+      logging.exception('Commit failed for cnxn, rolling back')
+      sql_cnxn.rollback()
+
+  def Close(self):
+    """Safely close any connections that are still open."""
+    for sql_cnxn in self.sql_cnxns.values():
+      try:
+        sql_cnxn.rollback()  # Abandon any uncommitted changes.
+        cnxn_pool.release(sql_cnxn)
+      except MySQLdb.DatabaseError:
+        # This might happen if the cnxn is somehow already closed.
+        # NOTE(review): the message says ProgrammingError but the handler
+        # catches DatabaseError (its superclass) -- confirm which is meant.
+        logging.exception('ProgrammingError when trying to close cnxn')
+
+
+class SQLTableManager(object):
+ """Helper class to make it easier to deal with an SQL table."""
+
+  def __init__(self, table_name):
+    # table_name: name of the SQL table that this instance wraps.
+    self.table_name = table_name
+
+  def Select(
+      self, cnxn, distinct=False, cols=None, left_joins=None,
+      joins=None, where=None, or_where_conds=False, group_by=None,
+      order_by=None, limit=None, offset=None, shard_id=None, use_clause=None,
+      having=None, **kwargs):
+    """Compose and execute an SQL SELECT statement on this table.
+
+    Args:
+      cnxn: MonorailConnection to the databases.
+      distinct: If True, add DISTINCT keyword.
+      cols: List of columns to retrieve, defaults to '*'.
+      left_joins: List of LEFT JOIN (str, args) pairs.
+      joins: List of regular JOIN (str, args) pairs.
+      where: List of (str, args) for WHERE clause.
+      or_where_conds: Set to True to use OR in the WHERE conds.
+      group_by: List of strings for GROUP BY clause.
+      order_by: List of (str, args) for ORDER BY clause.
+      limit: Optional LIMIT on the number of rows returned.
+      offset: Optional OFFSET when using LIMIT.
+      shard_id: Int ID of the shard to query.
+      use_clause: Optional string USE clause to tell the DB which index to use.
+      having: List of (str, args) for Optional HAVING clause
+      **kwargs: WHERE-clause equality and set-membership conditions.
+
+    Keyword args are used to build up more WHERE conditions that compare
+    column values to constants.  Key word Argument foo='bar' translates to 'foo
+    = "bar"', and foo=[3, 4, 5] translates to 'foo IN (3, 4, 5)'.
+
+    Returns:
+      A list of rows, each row is a tuple of values for the requested cols.
+    """
+    cols = cols or ['*']  # If columns not specified, retrieve all columns.
+    stmt = Statement.MakeSelect(
+        self.table_name, cols, distinct=distinct,
+        or_where_conds=or_where_conds)
+    if use_clause:
+      stmt.AddUseClause(use_clause)
+    if having:
+      stmt.AddHavingTerms(having)
+    stmt.AddJoinClauses(left_joins or [], left=True)
+    stmt.AddJoinClauses(joins or [])
+    stmt.AddWhereTerms(where or [], **kwargs)
+    stmt.AddGroupByTerms(group_by or [])
+    stmt.AddOrderByTerms(order_by or [])
+    stmt.SetLimitAndOffset(limit, offset)
+    stmt_str, stmt_args = stmt.Generate()
+
+    cursor = cnxn.Execute(stmt_str, stmt_args, shard_id=shard_id)
+    # Materialize all rows before closing so the cursor is not leaked.
+    rows = cursor.fetchall()
+    cursor.close()
+    return rows
+
+ def SelectRow(
+ self, cnxn, cols=None, default=None, where=None, **kwargs):
+ """Run a query that is expected to return just one row."""
+ rows = self.Select(cnxn, distinct=True, cols=cols, where=where, **kwargs)
+ if len(rows) == 1:
+ return rows[0]
+ elif not rows:
+ logging.info('SelectRow got 0 results, so using default %r', default)
+ return default
+ else:
+ raise ValueError('SelectRow got %d results, expected only 1', len(rows))
+
+ def SelectValue(self, cnxn, col, default=None, where=None, **kwargs):
+ """Run a query that is expected to return just one row w/ one value."""
+ row = self.SelectRow(
+ cnxn, cols=[col], default=[default], where=where, **kwargs)
+ return row[0]
+
  def InsertRows(
      self, cnxn, cols, row_values, replace=False, ignore=False,
      commit=True, return_generated_ids=False):
    """Insert all the given rows.

    Args:
      cnxn: MonorailConnection object.
      cols: List of column names to set.
      row_values: List of lists with values to store. The length of each
        nested list should be equal to len(cols).
      replace: Set to True if inserted values should replace existing DB rows
        that have the same DB keys.
      ignore: Set to True to ignore rows that would duplicate existing DB keys.
      commit: Set to False if this operation is part of a series of operations
        that should not be committed until the final one is done.
      return_generated_ids: Set to True to return a list of generated
        autoincrement IDs for inserted rows. This requires us to insert rows
        one at a time.

    Returns:
      If return_generated_ids is set to True, this method returns a list of the
      auto-increment IDs generated by the DB. Otherwise, [] is returned.
      Note: if row_values is empty, None is returned without running any SQL.
    """
    if not row_values:
      return None  # Nothing to insert

    generated_ids = []
    if return_generated_ids:
      # We must insert the rows one-at-a-time to know the generated IDs.
      for row_value in row_values:
        stmt = Statement.MakeInsert(
            self.table_name, cols, [row_value], replace=replace, ignore=ignore)
        stmt_str, stmt_args = stmt.Generate()
        cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit)
        # lastrowid is falsy when the row generated no autoincrement ID
        # (e.g., it was ignored); such rows add nothing to the result.
        if cursor.lastrowid:
          generated_ids.append(cursor.lastrowid)
        cursor.close()
      return generated_ids

    # Fast path: one multi-row INSERT statement, no IDs collected.
    stmt = Statement.MakeInsert(
        self.table_name, cols, row_values, replace=replace, ignore=ignore)
    stmt_str, stmt_args = stmt.Generate()
    cnxn.Execute(stmt_str, stmt_args, commit=commit)
    return []
+
+
+ def InsertRow(
+ self, cnxn, replace=False, ignore=False, commit=True, **kwargs):
+ """Insert a single row into the table.
+
+ Args:
+ cnxn: MonorailConnection object.
+ replace: Set to True if inserted values should replace existing DB rows
+ that have the same DB keys.
+ ignore: Set to True to ignore rows that would duplicate existing DB keys.
+ commit: Set to False if this operation is part of a series of operations
+ that should not be committed until the final one is done.
+ **kwargs: column=value assignments to specify what to store in the DB.
+
+ Returns:
+ The generated autoincrement ID of the key column if one was generated.
+ Otherwise, return None.
+ """
+ cols = sorted(kwargs.keys())
+ row = tuple(kwargs[col] for col in cols)
+ generated_ids = self.InsertRows(
+ cnxn, cols, [row], replace=replace, ignore=ignore,
+ commit=commit, return_generated_ids=True)
+ if generated_ids:
+ return generated_ids[0]
+ else:
+ return None
+
+ def Update(self, cnxn, delta, where=None, commit=True, limit=None, **kwargs):
+ """Update one or more rows.
+
+ Args:
+ cnxn: MonorailConnection object.
+ delta: Dictionary of {column: new_value} assignments.
+ where: Optional list of WHERE conditions saying which rows to update.
+ commit: Set to False if this operation is part of a series of operations
+ that should not be committed until the final one is done.
+ limit: Optional LIMIT on the number of rows updated.
+ **kwargs: WHERE-clause equality and set-membership conditions.
+
+ Returns:
+ Int number of rows updated.
+ """
+ if not delta:
+ return 0 # Nothing is being changed
+
+ stmt = Statement.MakeUpdate(self.table_name, delta)
+ stmt.AddWhereTerms(where, **kwargs)
+ stmt.SetLimitAndOffset(limit, None)
+ stmt_str, stmt_args = stmt.Generate()
+
+ cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit)
+ result = cursor.rowcount
+ cursor.close()
+ return result
+
  def IncrementCounterValue(self, cnxn, col_name, where=None, **kwargs):
    """Atomically increment a counter stored in MySQL, return new value.

    Args:
      cnxn: MonorailConnection object.
      col_name: int column to increment.
      where: Optional list of WHERE conditions saying which rows to update.
      **kwargs: WHERE-clause equality and set-membership conditions. The
        where and kwargs together should narrow the update down to exactly
        one row.

    Returns:
      The new, post-increment value of the counter.
    """
    # MakeIncrement generates "SET col = LAST_INSERT_ID(col + %s)", so
    # MySQL reports the incremented value back via cursor.lastrowid.
    stmt = Statement.MakeIncrement(self.table_name, col_name)
    stmt.AddWhereTerms(where, **kwargs)
    stmt_str, stmt_args = stmt.Generate()

    cursor = cnxn.Execute(stmt_str, stmt_args)
    # Exactly one row must have matched: zero means the counter row is
    # missing, more than one means the conditions were too broad.
    assert cursor.rowcount == 1, (
        'missing or ambiguous counter: %r' % cursor.rowcount)
    result = cursor.lastrowid
    cursor.close()
    return result
+
+ def Delete(self, cnxn, where=None, or_where_conds=False, commit=True,
+ limit=None, **kwargs):
+ """Delete the specified table rows.
+
+ Args:
+ cnxn: MonorailConnection object.
+ where: Optional list of WHERE conditions saying which rows to update.
+ or_where_conds: Set to True to use OR in the WHERE conds.
+ commit: Set to False if this operation is part of a series of operations
+ that should not be committed until the final one is done.
+ limit: Optional LIMIT on the number of rows deleted.
+ **kwargs: WHERE-clause equality and set-membership conditions.
+
+ Returns:
+ Int number of rows updated.
+ """
+ # Deleting the whole table is never intended in Monorail.
+ assert where or kwargs
+
+ stmt = Statement.MakeDelete(self.table_name, or_where_conds=or_where_conds)
+ stmt.AddWhereTerms(where, **kwargs)
+ stmt.SetLimitAndOffset(limit, None)
+ stmt_str, stmt_args = stmt.Generate()
+
+ cursor = cnxn.Execute(stmt_str, stmt_args, commit=commit)
+ result = cursor.rowcount
+ cursor.close()
+ return result
+
+
class Statement(object):
  """A class to help build complex SQL statements w/ full escaping.

  Start with a Make*() method, then fill in additional clauses as needed,
  then call Generate() to return the SQL string and argument list. We pass
  the string and args to MySQLdb separately so that it can do escaping on
  the arg values as appropriate to prevent SQL-injection attacks.

  The only values that are not escaped by MySQLdb are the table names
  and column names, and bits of SQL syntax, all of which is hard-coded
  in our application.
  """

  @classmethod
  def MakeSelect(cls, table_name, cols, distinct=False, or_where_conds=False):
    """Construct a SELECT statement."""
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in cols)
    main_clause = 'SELECT%s %s FROM %s' % (
        (' DISTINCT' if distinct else ''), ', '.join(cols), table_name)
    return cls(main_clause, or_where_conds=or_where_conds)

  @classmethod
  def MakeInsert(
      cls, table_name, cols, new_values, replace=False, ignore=False):
    """Construct an INSERT statement.

    Args:
      table_name: Name of the table to insert into.
      cols: List of column names to set.
      new_values: List of rows (sequences of values) to insert.
      replace: If True, delegate to MakeReplace() to build an
        INSERT ... ON DUPLICATE KEY UPDATE statement instead.
      ignore: If True, emit INSERT IGNORE to skip duplicate-key rows.

    Returns:
      A new Statement instance.
    """
    if replace:  # Idiomatic truth test; "== True" was redundant.
      return cls.MakeReplace(table_name, cols, new_values, ignore)
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in cols)
    ignore_word = ' IGNORE' if ignore else ''
    main_clause = 'INSERT%s INTO %s (%s)' % (
        ignore_word, table_name, ', '.join(cols))
    return cls(main_clause, insert_args=new_values)

  @classmethod
  def MakeReplace(
      cls, table_name, cols, new_values, ignore=False):
    """Construct an INSERT...ON DUPLICATE KEY UPDATE... statement.

    Uses the INSERT/UPDATE syntax because REPLACE is literally a DELETE
    followed by an INSERT, which doesn't play well with foreign keys.
    INSERT/UPDATE is an atomic check of whether the primary key exists,
    followed by an INSERT if it doesn't or an UPDATE if it does.
    """
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in cols)
    ignore_word = ' IGNORE' if ignore else ''
    main_clause = 'INSERT%s INTO %s (%s)' % (
        ignore_word, table_name, ', '.join(cols))
    return cls(main_clause, insert_args=new_values, duplicate_update_cols=cols)

  @classmethod
  def MakeUpdate(cls, table_name, delta):
    """Construct an UPDATE statement from a {column: new_value} dict."""
    assert _IsValidTableName(table_name)
    assert all(_IsValidColumnName(col) for col in delta.keys())
    update_strs = []
    update_args = []
    for col, val in delta.items():
      update_strs.append(col + '=%s')
      update_args.append(val)

    main_clause = 'UPDATE %s SET %s' % (
        table_name, ', '.join(update_strs))
    return cls(main_clause, update_args=update_args)

  @classmethod
  def MakeIncrement(cls, table_name, col_name, step=1):
    """Construct an UPDATE statement that increments and returns a counter.

    LAST_INSERT_ID(expr) makes MySQL report expr as the connection's last
    insert ID, so the caller can read back the post-increment value via
    cursor.lastrowid.
    """
    assert _IsValidTableName(table_name)
    assert _IsValidColumnName(col_name)

    main_clause = (
        'UPDATE %s SET %s = LAST_INSERT_ID(%s + %%s)' % (
            table_name, col_name, col_name))
    update_args = [step]
    return cls(main_clause, update_args=update_args)

  @classmethod
  def MakeDelete(cls, table_name, or_where_conds=False):
    """Construct a DELETE statement."""
    assert _IsValidTableName(table_name)
    main_clause = 'DELETE FROM %s' % table_name
    return cls(main_clause, or_where_conds=or_where_conds)

  def __init__(
      self, main_clause, insert_args=None, update_args=None,
      duplicate_update_cols=None, or_where_conds=False):
    """Store the main clause and validate all values up front."""
    self.main_clause = main_clause  # E.g., SELECT or DELETE
    self.or_where_conds = or_where_conds
    self.insert_args = insert_args or []  # For INSERT statements
    for row_value in self.insert_args:
      if not all(_IsValidDBValue(val) for val in row_value):
        raise exceptions.InputException('Invalid DB value %r' % (row_value,))
    self.update_args = update_args or []  # For UPDATEs
    for val in self.update_args:
      if not _IsValidDBValue(val):
        raise exceptions.InputException('Invalid DB value %r' % val)
    self.duplicate_update_cols = duplicate_update_cols or []  # For REPLACE-ish

    self.use_clauses = []
    self.join_clauses, self.join_args = [], []
    self.where_conds, self.where_args = [], []
    self.having_conds, self.having_args = [], []
    self.group_by_terms, self.group_by_args = [], []
    self.order_by_terms, self.order_by_args = [], []
    self.limit, self.offset = None, None

  def Generate(self):
    """Return an SQL string having %s placeholders and args to fill them in."""
    clauses = [self.main_clause] + self.use_clauses + self.join_clauses
    if self.where_conds:
      if self.or_where_conds:
        clauses.append('WHERE ' + '\n OR '.join(self.where_conds))
      else:
        clauses.append('WHERE ' + '\n AND '.join(self.where_conds))
    if self.group_by_terms:
      clauses.append('GROUP BY ' + ', '.join(self.group_by_terms))
    if self.having_conds:
      assert self.group_by_terms  # HAVING is only valid with GROUP BY.
      clauses.append('HAVING %s' % ','.join(self.having_conds))
    if self.order_by_terms:
      clauses.append('ORDER BY ' + ', '.join(self.order_by_terms))

    if self.limit and self.offset:
      clauses.append('LIMIT %d OFFSET %d' % (self.limit, self.offset))
    elif self.limit:
      clauses.append('LIMIT %d' % self.limit)
    elif self.offset:
      # MySQL requires a LIMIT in order to use OFFSET, so use a huge one.
      # Bug fix: sys.maxint does not exist on Python 3 (AttributeError);
      # sys.maxsize works on both Python 2 and 3.
      clauses.append('LIMIT %d OFFSET %d' % (sys.maxsize, self.offset))

    if self.insert_args:
      clauses.append('VALUES (' + PlaceHolders(self.insert_args[0]) + ')')
      args = self.insert_args
      if self.duplicate_update_cols:
        clauses.append('ON DUPLICATE KEY UPDATE %s' % (
            ', '.join(['%s=VALUES(%s)' % (col, col)
                       for col in self.duplicate_update_cols])))
      # An INSERT statement never carries args for any other clause.
      assert not (self.join_args + self.update_args + self.where_args +
                  self.group_by_args + self.order_by_args + self.having_args)
    else:
      # Order matters: it must match the order of the generated clauses.
      args = (self.join_args + self.update_args + self.where_args +
              self.group_by_args + self.having_args + self.order_by_args)
      assert not (self.insert_args + self.duplicate_update_cols)

    args = _BoolsToInts(args)
    stmt_str = '\n'.join(clause for clause in clauses if clause)

    assert _IsValidStatement(stmt_str), stmt_str
    return stmt_str, args

  def AddUseClause(self, use_clause):
    """Add a USE clause (giving the DB a hint about which indexes to use)."""
    assert _IsValidUseClause(use_clause), use_clause
    self.use_clauses.append(use_clause)

  def AddJoinClauses(self, join_pairs, left=False):
    """Save JOIN clauses based on the given list of (join_str, args) pairs."""
    for join, args in join_pairs:
      assert _IsValidJoin(join), join
      assert join.count('%s') == len(args), join
      self.join_clauses.append(
          ' %sJOIN %s' % (('LEFT ' if left else ''), join))
      self.join_args.extend(args)

  def AddGroupByTerms(self, group_by_term_list):
    """Save info needed to generate the GROUP BY clause."""
    assert all(_IsValidGroupByTerm(term) for term in group_by_term_list)
    self.group_by_terms.extend(group_by_term_list)

  def AddOrderByTerms(self, order_by_pairs):
    """Save info needed to generate the ORDER BY clause."""
    for term, args in order_by_pairs:
      assert _IsValidOrderByTerm(term), term
      assert term.count('%s') == len(args), term
      self.order_by_terms.append(term)
      self.order_by_args.extend(args)

  def SetLimitAndOffset(self, limit, offset):
    """Save info needed to generate the LIMIT OFFSET clause."""
    self.limit = limit
    self.offset = offset

  def AddWhereTerms(self, where_cond_pairs, **kwargs):
    """Generate a WHERE clause.

    Args:
      where_cond_pairs: Optional list of (condition_str, args) pairs.
      **kwargs: column=value equality conditions.  A "_not" suffix on a
        column name negates the comparison; list/set values become IN (...)
        tests; None becomes an IS NULL test.
    """
    where_cond_pairs = where_cond_pairs or []

    for cond, args in where_cond_pairs:
      assert _IsValidWhereCond(cond), cond
      assert cond.count('%s') == len(args), cond
      self.where_conds.append(cond)
      self.where_args.extend(args)

    for col, val in sorted(kwargs.items()):
      assert _IsValidColumnName(col), col
      eq = True
      if col.endswith('_not'):
        col = col[:-4]
        eq = False

      if isinstance(val, set):
        val = list(val)  # MySQL interface cannot handle sets.

      if val is None or val == []:
        if val == [] and self.main_clause and self.main_clause.startswith(
            'UPDATE'):
          # https://crbug.com/monorail/6735: Avoid empty arrays for UPDATE.
          raise exceptions.InputException('Invalid update DB value %r' % col)
        op = 'IS' if eq else 'IS NOT'
        self.where_conds.append(col + ' ' + op + ' NULL')
      elif isinstance(val, list):
        op = 'IN' if eq else 'NOT IN'
        # Sadly, MySQLdb cannot escape lists, so we flatten to multiple "%s"s
        self.where_conds.append(
            col + ' ' + op + ' (' + PlaceHolders(val) + ')')
        self.where_args.extend(val)
      else:
        op = '=' if eq else '!='
        self.where_conds.append(col + ' ' + op + ' %s')
        self.where_args.append(val)

  def AddHavingTerms(self, having_cond_pairs):
    """Generate a HAVING clause from (condition_str, args) pairs."""
    for cond, args in having_cond_pairs:
      assert _IsValidHavingCond(cond), cond
      assert cond.count('%s') == len(args), cond
      self.having_conds.append(cond)
      self.having_args.extend(args)
+
+
def PlaceHolders(sql_args):
  """Return a comma-separated list of %s placeholders for the given args."""
  markers = ['%s' for _ in sql_args]
  return ','.join(markers)
+
+
# Regex fragments for the pieces of legal SQL identifiers and operators.
TABLE_PAT = '[A-Z][_a-zA-Z0-9]+'
COLUMN_PAT = '[a-z][_a-z]+'
COMPARE_OP_PAT = '(<|>|=|!=|>=|<=|LIKE|NOT LIKE)'
# Shorthand names that _MakeRE() expands (via str.format) into the full
# regex snippets below, so the validation regexes stay readable.
SHORTHAND = {
    'table': TABLE_PAT,
    'column': COLUMN_PAT,
    'tab_col': r'(%s\.)?%s' % (TABLE_PAT, COLUMN_PAT),
    'placeholder': '%s',  # That's a literal %s that gets passed to MySQLdb
    'multi_placeholder': '%s(, ?%s)*',
    'compare_op': COMPARE_OP_PAT,
    'opt_asc_desc': '( ASC| DESC)?',
    'opt_alias': '( AS %s)?' % TABLE_PAT,
    # The *_cond shorthands match the compound ON/WHERE conditions that the
    # search pipeline builds for user emails, hotlist names, phase names,
    # and approval statuses, respectively.
    'email_cond': (r'\(?'
                   r'('
                   r'(LOWER\(Spare\d+\.email\) IS NULL OR )?'
                   r'LOWER\(Spare\d+\.email\) '
                   r'(%s %%s|IN \(%%s(, ?%%s)*\))'
                   r'( (AND|OR) )?'
                   r')+'
                   r'\)?' % COMPARE_OP_PAT),
    'hotlist_cond': (r'\(?'
                     r'('
                     r'(LOWER\(Cond\d+\.name\) IS NULL OR )?'
                     r'LOWER\(Cond\d+\.name\) '
                     r'(%s %%s|IN \(%%s(, ?%%s)*\))'
                     r'( (AND|OR) )?'
                     r')+'
                     r'\)?' % COMPARE_OP_PAT),
    'phase_cond': (r'\(?'
                   r'('
                   r'(LOWER\(Phase\d+\.name\) IS NULL OR )?'
                   r'LOWER\(Phase\d+\.name\) '
                   r'(%s %%s|IN \(%%s(, ?%%s)*\))?'
                   r'( (AND|OR) )?'
                   r')+'
                   r'\)?' % COMPARE_OP_PAT),
    'approval_cond': (r'\(?'
                      r'('
                      r'(LOWER\(Cond\d+\.status\) IS NULL OR )?'
                      r'LOWER\(Cond\d+\.status\) '
                      r'(%s %%s|IN \(%%s(, ?%%s)*\))'
                      r'( (AND|OR) )?'
                      r')+'
                      r'\)?' % COMPARE_OP_PAT),
    }
+
+
def _MakeRE(regex_str):
  """Return a regular expression object, expanding our shorthand as needed."""
  expanded = regex_str.format(**SHORTHAND)
  return re.compile(expanded)
+
+
# Allowed forms for table names, column references, USE clauses, and
# HAVING conditions.  Each _IsValid*() helper below matches against these.
TABLE_RE = _MakeRE('^{table}$')
TAB_COL_RE = _MakeRE('^{tab_col}$')
USE_CLAUSE_RE = _MakeRE(
    r'^USE INDEX \({column}\) USE INDEX FOR ORDER BY \({column}\)$')
HAVING_RE_LIST = [
    _MakeRE(r'^COUNT\(\*\) {compare_op} {placeholder}$')]
# Allowed column expressions in SELECT statements: plain columns, '*',
# and a small set of aggregate functions.
COLUMN_RE_LIST = [
    TAB_COL_RE,
    _MakeRE(r'\*'),
    _MakeRE(r'COUNT\(\*\)'),
    _MakeRE(r'COUNT\({tab_col}\)'),
    _MakeRE(r'COUNT\(DISTINCT\({tab_col}\)\)'),
    _MakeRE(r'MAX\({tab_col}\)'),
    _MakeRE(r'MIN\({tab_col}\)'),
    _MakeRE(r'GROUP_CONCAT\((DISTINCT )?{tab_col}( ORDER BY {tab_col})?' \
            r'( SEPARATOR \'.*\')?\)'),
    ]
# Every JOIN clause passed to Statement.AddJoinClauses() must match one of
# these shapes; anything else fails an assertion before reaching the DB.
JOIN_RE_LIST = [
    TABLE_RE,
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} IN \({multi_placeholder}\))?$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} = {placeholder})?'
        r'( AND {tab_col} IN \({multi_placeholder}\))?'
        r'( AND {tab_col} = {tab_col})?$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} = {placeholder})?'
        r'( AND {tab_col} IN \({multi_placeholder}\))?'
        r'( AND {tab_col} IS NULL)?'
        r'( AND \({tab_col} IS NULL'
        r' OR {tab_col} NOT IN \({multi_placeholder}\)\))?$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} = {placeholder})?'
        r' AND \(?{tab_col} {compare_op} {placeholder}\)?'
        r'( AND {tab_col} = {tab_col})?$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} = {placeholder})?'
        r' AND {tab_col} = {tab_col}$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r'( AND {tab_col} = {tab_col})?'
        r'( AND {tab_col} = {placeholder})?'
        r' AND \({tab_col} IS NULL OR'
        r' {tab_col} {compare_op} {placeholder}\)$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r' AND \({tab_col} IS NOT NULL AND {tab_col} != {placeholder}\)'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col}'
        r' AND LOWER\({tab_col}\) = LOWER\({placeholder}\)'),
    _MakeRE(
        r'^{table}{opt_alias} ON {tab_col} = {tab_col} AND {email_cond}$'),
    _MakeRE(
        r'^{table}{opt_alias} ON {email_cond}$'),
    _MakeRE(
        r'^{table}{opt_alias} ON '
        r'\({tab_col} = {tab_col} OR {tab_col} = {tab_col}\)$'),
    # The remaining shapes are parenthesized sub-joins used by the issue
    # search pipeline (user emails, hotlists, phases, and approvals).
    _MakeRE(
        r'^\({table} AS {table} JOIN User AS {table} '
        r'ON {tab_col} = {tab_col} AND {email_cond}\) '
        r'ON Issue(Snapshot)?.id = {tab_col}'
        r'( AND {tab_col} IS NULL)?'),
    _MakeRE(
        r'^\({table} JOIN Hotlist AS {table} '
        r'ON {tab_col} = {tab_col} AND {hotlist_cond}\) '
        r'ON Issue.id = {tab_col}?'),
    _MakeRE(
        r'^\({table} AS {table} JOIN IssuePhaseDef AS {table} '
        r'ON {tab_col} = {tab_col} AND {phase_cond}\) '
        r'ON Issue.id = {tab_col}?'),
    _MakeRE(
        r'^IssuePhaseDef AS {table} ON {phase_cond}'),
    _MakeRE(
        r'^Issue2ApprovalValue AS {table} ON {tab_col} = {tab_col} '
        r'AND {tab_col} = {placeholder} AND {approval_cond}'),
    _MakeRE(
        r'^{table} AS {table} ON {tab_col} = {tab_col} '
        r'LEFT JOIN {table} AS {table} ON {tab_col} = {tab_col}'),
    ]
# Allowed ORDER BY terms, each with an optional trailing ASC/DESC.
ORDER_BY_RE_LIST = [
    _MakeRE(r'^{tab_col}{opt_asc_desc}$'),
    _MakeRE(r'^LOWER\({tab_col}\){opt_asc_desc}$'),
    _MakeRE(r'^ISNULL\({tab_col}\){opt_asc_desc}$'),
    _MakeRE(r'^\(ISNULL\({tab_col}\) AND ISNULL\({tab_col}\)\){opt_asc_desc}$'),
    _MakeRE(r'^FIELD\({tab_col}, {multi_placeholder}\){opt_asc_desc}$'),
    _MakeRE(r'^FIELD\(IF\(ISNULL\({tab_col}\), {tab_col}, {tab_col}\), '
            r'{multi_placeholder}\){opt_asc_desc}$'),
    _MakeRE(r'^CONCAT\({tab_col}, {tab_col}\){opt_asc_desc}$'),
    ]
# GROUP BY terms may only be plain (optionally table-qualified) columns.
GROUP_BY_RE_LIST = [
    TAB_COL_RE,
    ]
# Allowed atomic WHERE conditions; _IsValidWhereCond() also accepts
# NOT-prefixed, parenthesized, and AND/OR combinations of these.
WHERE_COND_RE_LIST = [
    _MakeRE(r'^TRUE$'),
    _MakeRE(r'^FALSE$'),
    _MakeRE(r'^{tab_col} IS NULL$'),
    _MakeRE(r'^{tab_col} IS NOT NULL$'),
    _MakeRE(r'^{tab_col} {compare_op} {tab_col}$'),
    _MakeRE(r'^{tab_col} {compare_op} {placeholder}$'),
    _MakeRE(r'^{tab_col} %% {placeholder} = {placeholder}$'),
    _MakeRE(r'^{tab_col} IN \({multi_placeholder}\)$'),
    _MakeRE(r'^{tab_col} NOT IN \({multi_placeholder}\)$'),
    _MakeRE(r'^LOWER\({tab_col}\) IS NULL$'),
    _MakeRE(r'^LOWER\({tab_col}\) IS NOT NULL$'),
    _MakeRE(r'^LOWER\({tab_col}\) {compare_op} {placeholder}$'),
    _MakeRE(r'^LOWER\({tab_col}\) IN \({multi_placeholder}\)$'),
    _MakeRE(r'^LOWER\({tab_col}\) NOT IN \({multi_placeholder}\)$'),
    _MakeRE(r'^LOWER\({tab_col}\) LIKE {placeholder}$'),
    _MakeRE(r'^LOWER\({tab_col}\) NOT LIKE {placeholder}$'),
    # Special-case condition used by the Invalidate cache table.
    _MakeRE(r'^timestep < \(SELECT MAX\(j.timestep\) FROM Invalidate AS j '
            r'WHERE j.kind = %s '
            r'AND j.cache_key = Invalidate.cache_key\)$'),
    _MakeRE(r'^\({tab_col} IS NULL OR {tab_col} {compare_op} {placeholder}\) '
            'AND \({tab_col} IS NULL OR {tab_col} {compare_op} {placeholder}'
            '\)$'),
    _MakeRE(r'^\({tab_col} IS NOT NULL AND {tab_col} {compare_op} '
            '{placeholder}\) OR \({tab_col} IS NOT NULL AND {tab_col} '
            '{compare_op} {placeholder}\)$'),
    ]
+
# Note: We never use ';' for multiple statements, '@' for SQL variables, or
# any quoted strings in stmt_str (quotes are put in by MySQLdb for args).
STMT_STR_RE = re.compile(
    r'\A(SELECT|UPDATE|DELETE|INSERT|REPLACE) [\'-+=!<>%*.,()\w\s]+\Z',
    re.MULTILINE)
+
+
def _IsValidDBValue(val):
  """Return False only for strings containing an embedded NUL byte."""
  if not isinstance(val, string_types):
    return True
  return '\x00' not in val
+
+
def _IsValidTableName(table_name):
  """Return a truthy match object if table_name is a well-formed table name."""
  return TABLE_RE.match(table_name)
+
+
def _IsValidColumnName(column_expr):
  """Return True if column_expr is one of the allowed column expressions."""
  for allowed in COLUMN_RE_LIST:
    if allowed.match(column_expr):
      return True
  return False
+
+
def _IsValidUseClause(use_clause):
  """Return a truthy match object if use_clause is a recognized USE clause."""
  return USE_CLAUSE_RE.match(use_clause)
+
def _IsValidHavingCond(cond):
  """Return True if cond is a HAVING condition we are willing to generate."""
  # Strip one level of surrounding parentheses, if present.
  if cond.startswith('(') and cond.endswith(')'):
    cond = cond[1:-1]

  # Recursively validate each side of a connective; OR is checked before
  # AND, matching how the conditions were originally composed.
  for connective in (' OR ', ' AND '):
    if connective in cond:
      return all(
          _IsValidHavingCond(part) for part in cond.split(connective))

  return any(regex.match(cond) for regex in HAVING_RE_LIST)
+
+
def _IsValidJoin(join):
  """Return True if join matches one of the allowed JOIN clause shapes."""
  for allowed in JOIN_RE_LIST:
    if allowed.match(join):
      return True
  return False
+
+
def _IsValidOrderByTerm(term):
  """Return True if term is one of the allowed ORDER BY terms."""
  for allowed in ORDER_BY_RE_LIST:
    if allowed.match(term):
      return True
  return False
+
+
def _IsValidGroupByTerm(term):
  """Return True if term is one of the allowed GROUP BY terms."""
  for allowed in GROUP_BY_RE_LIST:
    if allowed.match(term):
      return True
  return False
+
+
def _IsValidWhereCond(cond):
  """Return True if cond is a WHERE condition we are willing to generate."""
  # Strip an optional leading NOT and one level of parentheses.
  if cond.startswith('NOT '):
    cond = cond[4:]
  if cond.startswith('(') and cond.endswith(')'):
    cond = cond[1:-1]

  # Atomic conditions are checked first, so that a condition containing
  # " OR " or " AND " inside an allowed shape is not mis-split.
  if any(regex.match(cond) for regex in WHERE_COND_RE_LIST):
    return True

  for connective in (' OR ', ' AND '):
    if connective in cond:
      return all(
          _IsValidWhereCond(part) for part in cond.split(connective))

  return False
+
+
def _IsValidStatement(stmt_str):
  """Final check to make sure there is no funny junk sneaking in somehow."""
  match = STMT_STR_RE.match(stmt_str)
  if not match:
    return match
  # Also reject SQL comment markers anywhere in the statement.
  return '--' not in stmt_str
+
+
+def _BoolsToInts(arg_list):
+ """Convert any True values to 1s and Falses to 0s.
+
+ Google's copy of MySQLdb has bool-to-int conversion disabled,
+ and yet it seems to be needed otherwise they are converted
+ to strings and always interpreted as 0 (which is FALSE).
+
+ Args:
+ arg_list: (nested) list of SQL statment argument values, which may
+ include some boolean values.
+
+ Returns:
+ The same list, but with True replaced by 1 and False replaced by 0.
+ """
+ result = []
+ for arg in arg_list:
+ if isinstance(arg, (list, tuple)):
+ result.append(_BoolsToInts(arg))
+ elif arg is True:
+ result.append(1)
+ elif arg is False:
+ result.append(0)
+ else:
+ result.append(arg)
+
+ return result
diff --git a/framework/table_view_helpers.py b/framework/table_view_helpers.py
new file mode 100644
index 0000000..3fa07c2
--- /dev/null
+++ b/framework/table_view_helpers.py
@@ -0,0 +1,793 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Classes and functions for displaying lists of project artifacts.
+
+This file exports classes TableRow and TableCell that help
+represent HTML table rows and cells. These classes make rendering
+HTML tables that list project artifacts much easier to do with EZT.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import collections
+import itertools
+import logging
+
+from functools import total_ordering
+
+import ezt
+
+from framework import framework_constants
+from framework import template_helpers
+from framework import timestr
+from proto import tracker_pb2
+from tracker import tracker_bizobj
+from tracker import tracker_constants
+
+
def ComputeUnshownColumns(results, shown_columns, config, built_in_cols):
  """Return a list of unshown columns that the user could add.

  Args:
    results: list of search result PBs. Each must have labels.
    shown_columns: list of column names to be used in results table.
    config: harmonized config for the issue search, including all
        well known labels and custom fields.
    built_in_cols: list of other column names that are built into the tool.
      E.g., star count, or creation date.

  Returns:
    List of column names to append to the "..." menu.
  """
  shown_lower = {col.lower() for col in shown_columns}
  unshown = []             # Column names in original case, in discovery order.
  unshown_lower = set()    # Lowercase names, for O(1) dedup checks.
  seen_label_keys = set()  # Lowercase whole labels already examined.

  def _AddIfNew(display_name, lower_key):
    """Queue display_name unless that column is already shown or queued."""
    if lower_key not in shown_lower and lower_key not in unshown_lower:
      unshown.append(display_name)
      unshown_lower.add(lower_key)

  def _MaybeAddCol(col):
    _AddIfNew(col, col.lower())

  def _MaybeAddLabel(label_name):
    """Add the key part of the given Key-Value label, at most once."""
    label_key = label_name.lower()
    if label_key in seen_label_keys:
      return
    seen_label_keys.add(label_key)
    if '-' in label_name:
      prefix, _unused_value = label_name.split('-', 1)
      _MaybeAddCol(prefix)

  # The user can always add any of the default or built-in columns.
  for col in itertools.chain(config.default_col_spec.split(), built_in_cols):
    _MaybeAddCol(col)

  # The user can add a column for any well-known label prefix.
  for wkl in config.well_known_labels:
    _MaybeAddLabel(wkl.label)

  phase_names = {
      phase.name.lower() for result in results for phase in result.phases}

  # The user can add a column for any custom field.  Phase fields get one
  # column per phase name, and approval fields also get an -approver column.
  seen_field_ids = set()
  for fd in config.field_defs:
    field_lower = fd.field_name.lower()
    seen_field_ids.add(fd.field_id)
    if fd.is_phase_field:
      for phase_name in phase_names:
        phase_col = phase_name + '.' + field_lower
        _AddIfNew(phase_col, phase_col)
    else:
      _AddIfNew(fd.field_name, field_lower)

    if fd.field_type == tracker_pb2.FieldTypes.APPROVAL_TYPE:
      _AddIfNew(
          fd.field_name + tracker_constants.APPROVER_COL_SUFFIX,
          field_lower + tracker_constants.APPROVER_COL_SUFFIX)

  # The user can add a column for any key-value label or field in the results.
  for result in results:
    for label_name in tracker_bizobj.GetLabels(result):
      _MaybeAddLabel(label_name)
    for field_value in result.field_values:
      if field_value.field_id in seen_field_ids:
        continue
      seen_field_ids.add(field_value.field_id)
      fd = tracker_bizobj.FindFieldDefByID(field_value.field_id, config)
      if fd:  # Could be None for a foreign field, which we don't display.
        _AddIfNew(fd.field_name, fd.field_name.lower())

  return sorted(unshown)
+
+
+def ExtractUniqueValues(columns, artifact_list, users_by_id,
+ config, related_issues, hotlist_context_dict=None):
+ """Build a nested list of unique values so the user can auto-filter.
+
+ Args:
+ columns: a list of lowercase column name strings, which may contain
+ combined columns like "priority/pri".
+ artifact_list: a list of artifacts in the complete set of search results.
+ users_by_id: dict mapping user_ids to UserViews.
+ config: ProjectIssueConfig PB for the current project.
+ related_issues: dict {issue_id: issue} of pre-fetched related issues.
+ hotlist_context_dict: dict for building a hotlist grid table
+
+ Returns:
+ [EZTItem(col1, colname1, [val11, val12,...]), ...]
+ A list of EZTItems, each of which has a col_index, column_name,
+ and a list of unique values that appear in that column.
+ """
+ column_values = {col_name: {} for col_name in columns}
+
+ # For each combined column "a/b/c", add entries that point from "a" back
+ # to "a/b/c", from "b" back to "a/b/c", and from "c" back to "a/b/c".
+ combined_column_parts = collections.defaultdict(list)
+ for col in columns:
+ if '/' in col:
+ for col_part in col.split('/'):
+ combined_column_parts[col_part].append(col)
+
+ unique_labels = set()
+ for art in artifact_list:
+ unique_labels.update(tracker_bizobj.GetLabels(art))
+
+ for label in unique_labels:
+ if '-' in label:
+ col, val = label.split('-', 1)
+ col = col.lower()
+ if col in column_values:
+ column_values[col][val.lower()] = val
+ if col in combined_column_parts:
+ for combined_column in combined_column_parts[col]:
+ column_values[combined_column][val.lower()] = val
+ else:
+ if 'summary' in column_values:
+ column_values['summary'][label.lower()] = label
+
+ # TODO(jrobbins): Consider refacting some of this to tracker_bizobj
+ # or a new builtins.py to reduce duplication.
+ if 'reporter' in column_values:
+ for art in artifact_list:
+ reporter_id = art.reporter_id
+ if reporter_id and reporter_id in users_by_id:
+ reporter_username = users_by_id[reporter_id].display_name
+ column_values['reporter'][reporter_username] = reporter_username
+
+ if 'owner' in column_values:
+ for art in artifact_list:
+ owner_id = tracker_bizobj.GetOwnerId(art)
+ if owner_id and owner_id in users_by_id:
+ owner_username = users_by_id[owner_id].display_name
+ column_values['owner'][owner_username] = owner_username
+
+ if 'cc' in column_values:
+ for art in artifact_list:
+ cc_ids = tracker_bizobj.GetCcIds(art)
+ for cc_id in cc_ids:
+ if cc_id and cc_id in users_by_id:
+ cc_username = users_by_id[cc_id].display_name
+ column_values['cc'][cc_username] = cc_username
+
+ if 'component' in column_values:
+ for art in artifact_list:
+ all_comp_ids = list(art.component_ids) + list(art.derived_component_ids)
+ for component_id in all_comp_ids:
+ cd = tracker_bizobj.FindComponentDefByID(component_id, config)
+ if cd:
+ column_values['component'][cd.path] = cd.path
+
+ if 'stars' in column_values:
+ for art in artifact_list:
+ star_count = art.star_count
+ column_values['stars'][star_count] = star_count
+
+ if 'status' in column_values:
+ for art in artifact_list:
+ status = tracker_bizobj.GetStatus(art)
+ if status:
+ column_values['status'][status.lower()] = status
+
+ if 'project' in column_values:
+ for art in artifact_list:
+ project_name = art.project_name
+ column_values['project'][project_name] = project_name
+
+ if 'mergedinto' in column_values:
+ for art in artifact_list:
+ if art.merged_into and art.merged_into != 0:
+ merged_issue = related_issues[art.merged_into]
+ merged_issue_ref = tracker_bizobj.FormatIssueRef((
+ merged_issue.project_name, merged_issue.local_id))
+ column_values['mergedinto'][merged_issue_ref] = merged_issue_ref
+
+ if 'blocked' in column_values:
+ for art in artifact_list:
+ if art.blocked_on_iids:
+ column_values['blocked']['is_blocked'] = 'Yes'
+ else:
+ column_values['blocked']['is_not_blocked'] = 'No'
+
+ if 'blockedon' in column_values:
+ for art in artifact_list:
+ if art.blocked_on_iids:
+ for blocked_on_iid in art.blocked_on_iids:
+ blocked_on_issue = related_issues[blocked_on_iid]
+ blocked_on_ref = tracker_bizobj.FormatIssueRef((
+ blocked_on_issue.project_name, blocked_on_issue.local_id))
+ column_values['blockedon'][blocked_on_ref] = blocked_on_ref
+
+ if 'blocking' in column_values:
+ for art in artifact_list:
+ if art.blocking_iids:
+ for blocking_iid in art.blocking_iids:
+ blocking_issue = related_issues[blocking_iid]
+ blocking_ref = tracker_bizobj.FormatIssueRef((
+ blocking_issue.project_name, blocking_issue.local_id))
+ column_values['blocking'][blocking_ref] = blocking_ref
+
+ if 'added' in column_values:
+ for art in artifact_list:
+ if hotlist_context_dict and hotlist_context_dict[art.issue_id]:
+ issue_dict = hotlist_context_dict[art.issue_id]
+ date_added = issue_dict['date_added']
+ column_values['added'][date_added] = date_added
+
+ if 'adder' in column_values:
+ for art in artifact_list:
+ if hotlist_context_dict and hotlist_context_dict[art.issue_id]:
+ issue_dict = hotlist_context_dict[art.issue_id]
+ adder_id = issue_dict['adder_id']
+ adder = users_by_id[adder_id].display_name
+ column_values['adder'][adder] = adder
+
+ if 'note' in column_values:
+ for art in artifact_list:
+ if hotlist_context_dict and hotlist_context_dict[art.issue_id]:
+ issue_dict = hotlist_context_dict[art.issue_id]
+ note = issue_dict['note']
+ if issue_dict['note']:
+ column_values['note'][note] = note
+
+ if 'attachments' in column_values:
+ for art in artifact_list:
+ attachment_count = art.attachment_count
+ column_values['attachments'][attachment_count] = attachment_count
+
+ # Add all custom field values if the custom field name is a shown column.
+ field_id_to_col = {}
+ for art in artifact_list:
+ for fv in art.field_values:
+ field_col, field_type = field_id_to_col.get(fv.field_id, (None, None))
+ if field_col == 'NOT_SHOWN':
+ continue
+ if field_col is None:
+ fd = tracker_bizobj.FindFieldDefByID(fv.field_id, config)
+ if not fd:
+ field_id_to_col[fv.field_id] = 'NOT_SHOWN', None
+ continue
+ field_col = fd.field_name.lower()
+ field_type = fd.field_type
+ if field_col not in column_values:
+ field_id_to_col[fv.field_id] = 'NOT_SHOWN', None
+ continue
+ field_id_to_col[fv.field_id] = field_col, field_type
+
+ if field_type == tracker_pb2.FieldTypes.ENUM_TYPE:
+ continue # Already handled by label parsing
+ elif field_type == tracker_pb2.FieldTypes.INT_TYPE:
+ val = fv.int_value
+ elif field_type == tracker_pb2.FieldTypes.STR_TYPE:
+ val = fv.str_value
+ elif field_type == tracker_pb2.FieldTypes.USER_TYPE:
+ user = users_by_id.get(fv.user_id)
+ val = user.email if user else framework_constants.NO_USER_NAME
+ elif field_type == tracker_pb2.FieldTypes.DATE_TYPE:
+ val = fv.int_value # TODO(jrobbins): convert to date
+ elif field_type == tracker_pb2.FieldTypes.BOOL_TYPE:
+ val = 'Yes' if fv.int_value else 'No'
+
+ column_values[field_col][val] = val
+
+ # TODO(jrobbins): make the capitalization of well-known unique label and
+ # status values match the way it is written in the issue config.
+
+ # Return EZTItems for each column in left-to-right display order.
+ result = []
+ for i, col_name in enumerate(columns):
+ # TODO(jrobbins): sort each set of column values top-to-bottom, by the
+ # order specified in the project artifact config. For now, just sort
+ # lexicographically to make expected output defined.
+ sorted_col_values = sorted(column_values[col_name].values())
+ result.append(template_helpers.EZTItem(
+ col_index=i, column_name=col_name, filter_values=sorted_col_values))
+
+ return result
+
+
def MakeTableData(
    visible_results, starred_items, lower_columns, lower_group_by,
    users_by_id, cell_factories, id_accessor, related_issues,
    viewable_iids_set, config, context_for_all_issues=None):
  """Return a list of list row objects for display by EZT.

  Args:
    visible_results: list of artifacts to display on one pagination page.
    starred_items: list of IDs/names of items in the current project
        that the signed in user has starred.
    lower_columns: list of column names to display, all lowercase.  These can
        be combined column names, e.g., 'priority/pri'.
    lower_group_by: list of column names that define row groups, all lowercase.
    users_by_id: dict mapping user IDs to UserViews.
    cell_factories: dict of functions that each create TableCell objects.
    id_accessor: function that maps from an artifact to the ID/name that might
        be in the starred items list.
    related_issues: dict {issue_id: issue} of pre-fetched related issues.
    viewable_iids_set: set of issue ids that can be viewed by the user.
    config: ProjectIssueConfig PB for the current project.
    context_for_all_issues: A dictionary of dictionaries containing values
        passed in to cell factory functions to create TableCells. Dictionary
        form: {issue_id: {'rank': issue_rank, 'issue_info': info_value, ..},
        issue_id: {'rank': issue_rank}, ..}

  Returns:
    A list of TableRow objects, one for each visible result.
  """
  table_data = []

  # strip('-') removes any leading '-' from a group-by spec (presumably a
  # descending-sort marker — confirm against the sort-spec parser) so the
  # cell factory is chosen for the bare column name.
  group_cell_factories = [
      ChooseCellFactory(group.strip('-'), cell_factories, config)
      for group in lower_group_by]

  # Make a list of cell factories, one for each column.
  factories_to_use = [
      ChooseCellFactory(col, cell_factories, config) for col in lower_columns]

  current_group = None
  for idx, art in enumerate(visible_results):
    row = MakeRowData(
        art, lower_columns, users_by_id, factories_to_use, related_issues,
        viewable_iids_set, config, context_for_all_issues)
    row.starred = ezt.boolean(id_accessor(art) in starred_items)
    row.idx = idx  # EZT does not have loop counters, so add idx.
    table_data.append(row)
    row.group = None  # Overwritten below for the first row of each group.

    # Also include group information for the first row in each group.
    # TODO(jrobbins): This seems like more overhead than we need for the
    # common case where no new group heading row is to be inserted.
    group = MakeRowData(
        art, [group_name.strip('-') for group_name in lower_group_by],
        users_by_id, group_cell_factories, related_issues, viewable_iids_set,
        config, context_for_all_issues)
    for cell, group_name in zip(group.cells, lower_group_by):
      cell.group_name = group_name
    if group == current_group:
      # Same group key as the previous row: just grow the running count.
      current_group.rows_in_group += 1
    else:
      # Start of a new group: attach the heading row object to this row.
      row.group = group
      current_group = group
      current_group.rows_in_group = 1

  return table_data
+
+
def MakeRowData(
    art, columns, users_by_id, cell_factory_list, related_issues,
    viewable_iids_set, config, context_for_all_issues):
  """Build one TableRow for the given artifact for rendering by EZT.

  Args:
    art: a project artifact PB.
    columns: list of lower-case column names.
    users_by_id: dict {user_id: UserView}, each with a "display_name".
    cell_factory_list: cell factory callables, one per entry in columns.
    related_issues: dict {issue_id: issue} of pre-fetched related issues.
    viewable_iids_set: set of issue ids that can be viewed by the user.
    context_for_all_issues: dict {issue_id: {key: value, ...}} of extra
        keyword arguments passed through to each cell factory, or None.
    config: ProjectIssueConfig PB for the current project.

  Returns:
    A TableRow object for use by EZT to render a table of results.
  """
  if context_for_all_issues is None:
    context_for_all_issues = {}

  # A combined column such as 'priority/pri' matches labels for either part.
  flattened_columns = {
      part for col in columns for part in col.split('/')}

  # Group all "Key-Value" labels by key, and separate the "OneWord" labels.
  non_col_labels = []
  label_values = collections.defaultdict(list)
  _AccumulateLabelValues(
      art.labels, flattened_columns, label_values, non_col_labels)
  _AccumulateLabelValues(
      art.derived_labels, flattened_columns, label_values,
      non_col_labels, is_derived=True)

  # Build up a list of TableCell objects for this row.
  issue_context = context_for_all_issues.get(art.issue_id, {})
  cells = []
  for index, col in enumerate(columns):
    factory_kwargs = {
        'col': col,
        'users_by_id': users_by_id,
        'non_col_labels': non_col_labels,
        'label_values': label_values,
        'related_issues': related_issues,
        'viewable_iids_set': viewable_iids_set,
        'config': config,
    }
    factory_kwargs.update(issue_context)
    cell = cell_factory_list[index](art, **factory_kwargs)
    cell.col_index = index
    cells.append(cell)

  return TableRow(cells)
+
+
def _AccumulateLabelValues(
    labels, columns, label_values, non_col_labels, is_derived=False):
  """Parse OneWord and Key-Value labels for display in a list page.

  Args:
    labels: a list of label strings.
    columns: a list of column names.
    label_values: mutable dictionary {key: [(value, is_derived), ...]} of
        label values seen so far.
    non_col_labels: mutable list of OneWord labels seen so far.
    is_derived: true if these labels were derived via rules.

  Returns:
    Nothing.  The given label_values dictionary grows to hold the values of
    the key-value labels passed in, and the non_col_labels list grows to
    hold the OneWord labels passed in.  These are shown in label columns,
    and in the summary column, respectively.
  """
  for label in labels:
    if '-' not in label:
      non_col_labels.append((label, is_derived))
      continue
    # A multi-dash label like 'OS-Mac-10' is offered under every possible
    # key split: ('os', 'Mac-10'), ('os-mac', '10'), ...; only splits whose
    # key is a shown column are recorded.
    parts = label.split('-')
    for split_point in range(1, len(parts)):
      key = '-'.join(parts[:split_point]).lower()
      if key in columns:
        label_values[key].append(('-'.join(parts[split_point:]), is_derived))
+
+
@total_ordering
class TableRow(object):
  """A tiny auxiliary class to represent a row in an HTML table."""

  def __init__(self, cells):
    """Initialize the table row with the given list of TableCell objects."""
    self.cells = cells
    # The following are set by MakeTableData for layout.
    self.idx = None
    self.group = None
    self.rows_in_group = None
    self.starred = None

  def __eq__(self, other):
    """A row is == if each cell is == to the cells in the other row."""
    return bool(other) and self.cells == other.cells

  def __ne__(self, other):
    # Must be the exact negation of __eq__.  The previous implementation
    # (`not other and self.cells != other.cells`) always returned False for
    # any truthy `other`, so unequal rows compared as "not unequal".
    return not self.__eq__(other)

  def __lt__(self, other):
    # total_ordering fills in the remaining comparisons from __eq__ + __lt__.
    return bool(other) and self.cells < other.cells

  def DebugString(self):
    """Return a string that is useful for on-page debugging."""
    return 'TR(%s)' % self.cells
+
+
# TODO(jrobbins): also add unsortable... or change this to a list of operations
# that can be done.
# Tags stored in TableCell.type, presumably switched on by the EZT list
# templates when deciding how to render each cell — confirm in the templates.
CELL_TYPE_ID = 'ID'
CELL_TYPE_SUMMARY = 'summary'
CELL_TYPE_ATTR = 'attr'
CELL_TYPE_UNFILTERABLE = 'unfilterable'
CELL_TYPE_NOTE = 'note'
CELL_TYPE_PROJECT = 'project'
CELL_TYPE_URL = 'url'
CELL_TYPE_ISSUES = 'issues'
+
+
@total_ordering
class TableCell(object):
  """Helper class to represent a table cell when rendering using EZT."""

  # Should instances of this class be rendered with whitespace:nowrap?
  # Subclasses can override this constant.
  NOWRAP = ezt.boolean(True)

  def __init__(self, cell_type, explicit_values,
               derived_values=None, non_column_labels=None, align='',
               sort_values=True):
    """Store all the given data for later access by EZT.

    Args:
      cell_type: one of the CELL_TYPE_* constants above.
      explicit_values: values set directly on the artifact.
      derived_values: values derived via rules; rendered after the
          explicit ones and flagged as derived.
      non_column_labels: (label, is_derived) pairs for OneWord labels.
      align: CSS alignment hint, e.g., 'right'.
      sort_values: if True, sort each value group before display.
    """
    self.type = cell_type
    self.align = align
    self.col_index = 0  # Is set afterward
    self.values = []
    if non_column_labels:
      self.non_column_labels = [
          template_helpers.EZTItem(value=v, is_derived=ezt.boolean(d))
          for v, d in non_column_labels]
    else:
      self.non_column_labels = []

    for v in (sorted(explicit_values) if sort_values else explicit_values):
      self.values.append(CellItem(v))

    if derived_values:
      for v in (sorted(derived_values) if sort_values else derived_values):
        self.values.append(CellItem(v, is_derived=True))

  def __eq__(self, other):
    """A cell is == if its values are == to the other cell's values."""
    return bool(other) and self.values == other.values

  def __ne__(self, other):
    # Must be the exact negation of __eq__.  The previous implementation
    # (`not other and ...`) always returned False for any truthy `other`.
    return not self.__eq__(other)

  def __lt__(self, other):
    # total_ordering fills in the remaining comparisons from __eq__ + __lt__.
    return bool(other) and self.values < other.values

  def DebugString(self):
    return 'TC(%r, %r, %r)' % (
        self.type,
        [v.DebugString() for v in self.values],
        self.non_column_labels)
+
+
def CompositeFactoryTableCell(factory_col_list_arg):
  """Return a TableCell class that concatenates several sub-cells.

  Args:
    factory_col_list_arg: list of (cell_factory, column_name) pairs.

  Returns:
    A TableCell subclass whose instances hold the merged values and
    non-column labels of every sub-cell, with the unfilterable cell type.
  """

  class FactoryClass(TableCell):
    factory_col_list = factory_col_list_arg

    def __init__(self, art, **kw):
      TableCell.__init__(self, CELL_TYPE_UNFILTERABLE, [])
      for make_sub_cell, sub_col in self.factory_col_list:
        kw['col'] = sub_col
        sub_cell = make_sub_cell(art, **kw)
        self.non_column_labels.extend(sub_cell.non_column_labels)
        self.values.extend(sub_cell.values)

  return FactoryClass
+
+
def CompositeColTableCell(columns_to_combine, cell_factories, config):
  """Return a cell factory that merges the cells of several columns."""
  factory_col_pairs = [
      (ChooseCellFactory(sub_col, cell_factories, config), sub_col)
      for sub_col in columns_to_combine]
  return CompositeFactoryTableCell(factory_col_pairs)
+
+
@total_ordering
class CellItem(object):
  """Simple class to display one part of a table cell's value, with style."""

  def __init__(self, item, is_derived=False):
    self.item = item
    # Stored as an EZT-style bool for direct use in templates.
    self.is_derived = ezt.boolean(is_derived)

  def __eq__(self, other):
    """Two items are == if they hold equal values."""
    return bool(other) and self.item == other.item

  def __ne__(self, other):
    # Must be the exact negation of __eq__.  The previous implementation
    # (`not other and self.item != other.item`) always returned False for
    # any truthy `other`, so unequal items compared as "not unequal".
    return not self.__eq__(other)

  def __lt__(self, other):
    # total_ordering fills in the remaining comparisons from __eq__ + __lt__.
    return bool(other) and self.item < other.item

  def DebugString(self):
    if self.is_derived:
      return 'CI(derived: %r)' % self.item
    else:
      return 'CI(%r)' % self.item
+
+
class TableCellKeyLabels(TableCell):
  """TableCell subclass specifically for showing user-defined label values."""

  def __init__(self, _art, col=None, label_values=None, **_kw):
    pairs = label_values.get(col, [])
    explicit = [val for val, is_derived in pairs if not is_derived]
    derived = [val for val, is_derived in pairs if is_derived]
    TableCell.__init__(
        self, CELL_TYPE_ATTR, explicit, derived_values=derived)
+
+
class TableCellProject(TableCell):
  """TableCell subclass for showing an artifact's project name."""

  def __init__(self, art, **_kw):
    super(TableCellProject, self).__init__(
        CELL_TYPE_PROJECT, [art.project_name])
+
+
class TableCellStars(TableCell):
  """TableCell subclass for showing an artifact's star count."""

  def __init__(self, art, **_kw):
    super(TableCellStars, self).__init__(
        CELL_TYPE_ATTR, [art.star_count], align='right')
+
+
class TableCellSummary(TableCell):
  """TableCell subclass for showing an artifact's summary."""

  def __init__(self, art, non_col_labels=None, **_kw):
    super(TableCellSummary, self).__init__(
        CELL_TYPE_SUMMARY, [art.summary],
        non_column_labels=non_col_labels)
+
+
class TableCellDate(TableCell):
  """TableCell subclass for showing any kind of date timestamp."""

  # Keep each date on a single line (whitespace:nowrap).
  NOWRAP = ezt.boolean(True)

  def __init__(self, timestamp, days_only=False):
    if timestamp:
      # Prefer a relative date ("3 days ago"); fall back to an absolute
      # date when the relative formatter returns an empty string.
      date_str = (
          timestr.FormatRelativeDate(timestamp, days_only=days_only)
          or timestr.FormatAbsoluteDate(timestamp))
      values = [date_str]
    else:
      values = []

    TableCell.__init__(self, CELL_TYPE_UNFILTERABLE, values)
+
+
class TableCellCustom(TableCell):
  """Abstract TableCell subclass specifically for showing custom fields."""

  def __init__(self, art, col=None, users_by_id=None, config=None, **_kw):
    explicit_values = []
    derived_values = []
    cell_type = CELL_TYPE_ATTR
    # Map phase_id -> lower-cased phase name so phase-specific field values
    # can be matched against a "<phase>.<field>" style column name.
    phase_names_by_id = {
        phase.phase_id: phase.name.lower() for phase in art.phases}
    phase_name = None
    # Check if col represents a phase field value in the form <phase>.<field>
    if '.' in col:
      phase_name, col = col.split('.', 1)
    for fv in art.field_values:
      # TODO(jrobbins): for cross-project search this could be a list.
      fd = tracker_bizobj.FindFieldDefByID(fv.field_id, config)
      if not fd:
        # TODO(jrobbins): This can happen if an issue with a custom
        # field value is moved to a different project.
        logging.warn('Issue ID %r has undefined field value %r',
                     art.issue_id, fv)
      elif fd.field_name.lower() == col and (
          phase_names_by_id.get(fv.phase_id) == phase_name):
        # For non-phase fields, both sides of the phase comparison are None.
        if fd.field_type == tracker_pb2.FieldTypes.URL_TYPE:
          cell_type = CELL_TYPE_URL
        if fd.field_type == tracker_pb2.FieldTypes.STR_TYPE:
          self.NOWRAP = ezt.boolean(False)  # Overrides the class-level value.
        val = tracker_bizobj.GetFieldValue(fv, users_by_id)
        if fv.derived:
          derived_values.append(val)
        else:
          explicit_values.append(val)

    TableCell.__init__(self, cell_type, explicit_values,
                       derived_values=derived_values)

  def ExtractValue(self, fv, _users_by_id):
    # Placeholder hook for subclasses; not used by __init__ above.
    return 'field-id-%d-not-implemented-yet' % fv.field_id
+
class TableCellApprovalStatus(TableCell):
  """TableCell subclass specifically for showing approval statuses.

  Shows the status name of the approval value whose approval_id matches
  the field definition named by the column.
  """

  def __init__(self, art, col=None, config=None, **_kw):
    explicit_values = []
    # The definitions depend only on the column name, not on each approval
    # value, so look them up once instead of once per approval value.
    fd = tracker_bizobj.FindFieldDef(col, config)
    ad = tracker_bizobj.FindApprovalDef(col, config)
    for av in art.approval_values:
      if not (ad and fd):
        logging.warn('Issue ID %r has undefined field value %r',
                     art.issue_id, av)
      elif av.approval_id == fd.field_id:
        explicit_values.append(av.status.name)
        break

    TableCell.__init__(self, CELL_TYPE_ATTR, explicit_values)
+
+
class TableCellApprovalApprover(TableCell):
  """TableCell subclass specifically for showing approval approvers.

  Shows the display names of the approvers of the approval named by the
  'FieldName-approver' style column.
  """

  def __init__(self, art, col=None, config=None, users_by_id=None, **_kw):
    explicit_values = []
    approval_name = col[:-len(tracker_constants.APPROVER_COL_SUFFIX)]
    # The definitions depend only on the approval name, not on each approval
    # value, so look them up once instead of once per approval value.
    fd = tracker_bizobj.FindFieldDef(approval_name, config)
    ad = tracker_bizobj.FindApprovalDef(approval_name, config)
    for av in art.approval_values:
      if not (ad and fd):
        logging.warn('Issue ID %r has undefined field value %r',
                     art.issue_id, av)
      elif av.approval_id == fd.field_id:
        # Look each approver up only once, skipping ids with no UserView.
        approver_views = (
            users_by_id.get(approver_id) for approver_id in av.approver_ids)
        explicit_values = [uv.display_name for uv in approver_views if uv]
        break

    TableCell.__init__(self, CELL_TYPE_ATTR, explicit_values)
+
def ChooseCellFactory(col, cell_factories, config):
  """Return the CellFactory to use for the given column."""
  if col in cell_factories:
    return cell_factories[col]

  if '/' in col:
    return CompositeColTableCell(col.split('/'), cell_factories, config)

  is_approver_col = col.endswith(tracker_constants.APPROVER_COL_SUFFIX)
  if is_approver_col:
    possible_field_name = col[:-len(tracker_constants.APPROVER_COL_SUFFIX)]
  elif '.' in col:
    # col may name a phase-specific field value: "<phase>.<field>".
    possible_field_name = col.split('.')[-1]
  else:
    possible_field_name = col

  fd = tracker_bizobj.FindFieldDef(possible_field_name, config)
  if not fd:
    return TableCellKeyLabels

  # We cannot assume that non-enum_type field defs do not share their
  # names with label prefixes.  So we need to pair them with
  # TableCellKeyLabels to make sure we catch appropriate label values.
  if fd.field_type == tracker_pb2.FieldTypes.APPROVAL_TYPE:
    if is_approver_col:
      # Combined cell for 'FieldName-approver' to hold approvers
      # belonging to FieldName and values belonging to labels with
      # 'FieldName-approver' as the key.
      return CompositeFactoryTableCell(
          [(TableCellApprovalApprover, col), (TableCellKeyLabels, col)])
    return CompositeFactoryTableCell(
        [(TableCellApprovalStatus, col), (TableCellKeyLabels, col)])

  if fd.field_type != tracker_pb2.FieldTypes.ENUM_TYPE:
    return CompositeFactoryTableCell(
        [(TableCellCustom, col), (TableCellKeyLabels, col)])

  # Enum fields are already handled by label parsing.
  return TableCellKeyLabels
diff --git a/framework/template_helpers.py b/framework/template_helpers.py
new file mode 100644
index 0000000..5f383c3
--- /dev/null
+++ b/framework/template_helpers.py
@@ -0,0 +1,326 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Some utility classes for interacting with templates."""
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+
+import cgi
+import cStringIO
+import httplib
+import logging
+import time
+import types
+
+import ezt
+import six
+
+from protorpc import messages
+
+import settings
+from framework import framework_constants
+
+
+_DISPLAY_VALUE_TRAILING_CHARS = 8
+_DISPLAY_VALUE_TIP_CHARS = 120
+
+
class PBProxy(object):
  """Wraps a Protocol Buffer so it is easy to access from a template."""

  def __init__(self, pb):
    # Name-mangled to _PBProxy__pb, so it cannot collide with PB field names.
    self.__pb = pb

  def __getattr__(self, name):
    """Make the getters template friendly.

    Pseudo-hack alert: When attributes end with _bool, they are converted in
    to EZT style bools. I.e., if false return None, if true return True.

    Args:
      name: the name of the attribute to get.

    Returns:
      The value of that attribute (as an EZT bool if the name ends with _bool).
    """
    # NOTE: __getattr__ only runs when normal attribute lookup fails, so the
    # __dict__ checks below mainly matter for 'foo_bool' lookups where 'foo'
    # is a local attribute on this proxy.
    if name.endswith('_bool'):
      bool_name = name
      name = name[0:-5]  # Strip the '_bool' suffix.
    else:
      bool_name = None

    # Make it possible for a PBProxy-local attribute to override the protocol
    # buffer field, or even to allow attributes to be added to the PBProxy that
    # the protocol buffer does not even have.
    if name in self.__dict__:
      if callable(self.__dict__[name]):
        val = self.__dict__[name]()
      else:
        val = self.__dict__[name]

      if bool_name:
        return ezt.boolean(val)
      return val

    if bool_name:
      # return an ezt.boolean for the named field.
      return ezt.boolean(getattr(self.__pb, name))

    val = getattr(self.__pb, name)

    if isinstance(val, messages.Enum):
      return int(val)  # TODO(jrobbins): use str() instead

    if isinstance(val, messages.Message):
      return PBProxy(val)

    # Return a list of values whose Message entries
    # have been wrapped in PBProxies.
    if isinstance(val, (list, messages.FieldList)):
      list_to_return = []
      for v in val:
        if isinstance(v, messages.Message):
          list_to_return.append(PBProxy(v))
        else:
          list_to_return.append(v)
      return list_to_return

    return val

  def DebugString(self):
    """Return a string representation that is useful in debugging."""
    return 'PBProxy(%s)' % self.__pb

  def __eq__(self, other):
    # Disable warning about accessing other.__pb.
    # pylint: disable=protected-access
    return isinstance(other, PBProxy) and self.__pb == other.__pb

  def __ne__(self, other):
    # Needed under Python 2, where != does not delegate to __eq__; without
    # this, two proxies wrapping equal PBs still compared as !=.
    return not self.__eq__(other)
+
+
_templates = {}


def GetTemplate(
    template_path, compress_whitespace=True, eliminate_blank_lines=False,
    base_format=ezt.FORMAT_HTML):
  """Make a MonorailTemplate if needed, or reuse one if possible.

  Args:
    template_path: path of the EZT template file.
    compress_whitespace: whether EZT should compress runs of whitespace.
    eliminate_blank_lines: whether blank lines are dropped from the output.
    base_format: EZT output format, e.g., ezt.FORMAT_HTML.

  Returns:
    A MonorailTemplate, cached per unique combination of arguments.
  """
  # The cache key must cover every argument that affects rendering.  It
  # previously omitted eliminate_blank_lines, so two calls differing only
  # in that flag incorrectly shared one cached template.
  key = (template_path, compress_whitespace, eliminate_blank_lines,
         base_format)
  if key in _templates:
    return _templates[key]

  template = MonorailTemplate(
      template_path, compress_whitespace=compress_whitespace,
      eliminate_blank_lines=eliminate_blank_lines, base_format=base_format)
  _templates[key] = template
  return template
+
+
class cStringIOUnicodeWrapper(object):
  """Wrapper on cStringIO.StringIO that encodes unicode as UTF-8 as it goes."""

  def __init__(self):
    self.buffer = cStringIO.StringIO()

  def write(self, s):
    """Write s to the buffer, UTF-8-encoding unicode strings first."""
    if isinstance(s, six.text_type):
      self.buffer.write(s.encode('utf-8'))
    else:
      self.buffer.write(s)

  def getvalue(self):
    """Return everything written so far as a single byte string."""
    return self.buffer.getvalue()
+
+
# Byte patterns replaced in outgoing pages when 'prevent_sniffing' is set in
# the template data (see MonorailTemplate.WriteResponse), so a browser cannot
# content-sniff the page as another type (e.g., a PDF).
SNIFFABLE_PATTERNS = {
  '%PDF-': '%NoNoNo-',
}
+
+
class MonorailTemplate(object):
  """A template with additional functionality."""

  def __init__(self, template_path, compress_whitespace=True,
               eliminate_blank_lines=False, base_format=ezt.FORMAT_HTML):
    """Store the template path and rendering options; parsing is lazy.

    Args:
      template_path: path of the EZT template file.
      compress_whitespace: whether EZT should compress runs of whitespace.
      eliminate_blank_lines: whether to drop blank lines from the output.
      base_format: EZT output format, e.g., ezt.FORMAT_HTML.
    """
    self.template_path = template_path
    self.template = None  # Parsed on first use by GetTemplate().
    self.compress_whitespace = compress_whitespace
    self.base_format = base_format
    self.eliminate_blank_lines = eliminate_blank_lines

  def WriteResponse(self, response, data, content_type=None):
    """Write the parsed and filled in template to http server."""
    if content_type:
      response.content_type = content_type

    # The template data may override the default 200 OK status code.
    response.status = data.get('http_response_code', httplib.OK)
    whole_page = self.GetResponse(data)
    if data.get('prevent_sniffing'):
      # Mangle byte patterns (e.g., a leading '%PDF-') that could cause a
      # browser to content-sniff the page as a different content type.
      for sniff_pattern, sniff_replacement in SNIFFABLE_PATTERNS.items():
        whole_page = whole_page.replace(sniff_pattern, sniff_replacement)
    start = time.time()
    response.write(whole_page)
    logging.info('wrote response in %dms', int((time.time() - start) * 1000))

  def GetResponse(self, data):
    """Generate the text from the template and return it as a string."""
    template = self.GetTemplate()
    start = time.time()
    buf = cStringIOUnicodeWrapper()
    template.generate(buf, data)
    whole_page = buf.getvalue()
    logging.info('rendering took %dms', int((time.time() - start) * 1000))
    logging.info('whole_page len is %r', len(whole_page))
    if self.eliminate_blank_lines:
      # NOTE(review): this strips every whitespace-only line, including any
      # intentional ones in the output — confirm acceptable for these pages.
      lines = whole_page.split('\n')
      whole_page = '\n'.join(line for line in lines if line.strip())
      logging.info('smaller whole_page len is %r', len(whole_page))
      logging.info('smaller rendering took %dms',
                   int((time.time() - start) * 1000))
    return whole_page

  def GetTemplate(self):
    """Parse the EZT template, or return an already parsed one."""
    # We don't operate directly on self.template to avoid races.
    template = self.template

    if template is None or settings.local_mode:
      # In local_mode the file is re-parsed on every call, presumably so
      # template edits take effect without a server restart — confirm.
      start = time.time()
      template = ezt.Template(
          fname=self.template_path,
          compress_whitespace=self.compress_whitespace,
          base_format=self.base_format)
      logging.info('parsed in %dms', int((time.time() - start) * 1000))
      self.template = template

    return template

  def GetTemplatePath(self):
    """Accessor for the template path specified in the constructor.

    Returns:
      The string path for the template file provided to the constructor.
    """
    return self.template_path
+
+
class EZTError(object):
  """This class is a helper class to pass errors to EZT.

  This class is used to hold information that will be passed to EZT but might
  be unset. All unset values return None (ie EZT False)
  Example: page errors
  """

  def __getattr__(self, _name):
    """This is the EZT retrieval function."""
    # Only invoked when the attribute was never set, so every unset error
    # name reads as None (EZT false).
    return None

  def AnyErrors(self):
    # True if any error attribute has ever been set on this instance.
    return len(self.__dict__) != 0

  def DebugString(self):
    return 'EZTError(%s)' % self.__dict__

  def SetError(self, name, value):
    """Set the named error attribute to the given value."""
    self.__setattr__(name, value)

  def SetCustomFieldError(self, field_id, value):
    """Append an EZTItem recording an error on one custom field."""
    # This access works because of the custom __getattr__.
    # pylint: disable=access-member-before-definition
    # pylint: disable=attribute-defined-outside-init
    if self.custom_fields is None:
      self.custom_fields = []
    self.custom_fields.append(EZTItem(field_id=field_id, message=value))

  # Expose AnyErrors as a read-only property named any_errors.
  any_errors = property(AnyErrors, None)
+
def FitUnsafeText(text, length):
  """Trim some unsafe (unescaped) text to a specific length.

  Three periods are appended if trimming occurs.  Note that we cannot use
  the ellipsis character (&hellip) because this is unescaped text.

  Args:
    text: the string to fit (ASCII or unicode), possibly empty or None.
    length: the length to trim to.

  Returns:
    An ASCII or unicode string fitted to the given length.
  """
  if not text:
    return ""

  return text if len(text) <= length else text[:length] + '...'
+
+
def BytesKbOrMb(num_bytes):
  """Return a human-readable string representation of a number of bytes."""
  kib = num_bytes / 1024.0
  mib = kib / 1024.0
  if num_bytes < 1024:
    return '%d bytes' % num_bytes  # e.g., 128 bytes
  elif num_bytes < 99 * 1024:
    return '%.1f KB' % kib  # e.g. 23.4 KB
  elif num_bytes < 1024 * 1024:
    return '%d KB' % kib  # e.g., 219 KB
  elif num_bytes < 99 * 1024 * 1024:
    return '%.1f MB' % mib  # e.g., 21.9 MB
  else:
    return '%d MB' % mib  # e.g., 100 MB
+
+
class EZTItem(object):
  """A class that makes a collection of fields easily accessible in EZT."""

  def __init__(self, **kwargs):
    """Store all the given key-value pairs as fields of this object."""
    vars(self).update(kwargs)

  def __repr__(self):
    fields = ', '.join('%r: %r' % (k, v) for k, v in
                       sorted(vars(self).items()))
    return '%s({%s})' % (self.__class__.__name__, fields)

  def __eq__(self, other):
    """Two items are equal when the other object has the same fields."""
    # getattr() instead of other.__dict__ so that comparing against an
    # object with no __dict__ (e.g., None or an int) returns False rather
    # than raising AttributeError.
    return self.__dict__ == getattr(other, '__dict__', None)

  def __ne__(self, other):
    # Needed under Python 2, where != does not delegate to __eq__.
    return not self.__eq__(other)
+
+
def ExpandLabels(page_data):
  """If page_data has a 'labels' list, expand it into 'label1', etc.

  Args:
    page_data: Template data which may include a 'labels' field.
  """
  label_list = page_data.get('labels', [])
  # Accept a comma-separated string as well as a list.  six.string_types
  # works under both Python 2 and 3, whereas the previous
  # types.StringTypes exists only under Python 2 (this file already uses
  # six and __future__ imports for 2/3 compatibility).
  if isinstance(label_list, six.string_types):
    label_list = [label.strip() for label in page_data['labels'].split(',')]

  for i, label in enumerate(label_list):
    page_data['label%d' % i] = label
  # Pad the remaining slots, up to the fixed maximum, with empty strings.
  for i in range(len(label_list), framework_constants.MAX_LABELS):
    page_data['label%d' % i] = ''
+
+
class TextRun(object):
  """A fragment of user-entered text that needs to be safely displayed."""

  def __init__(self, content, tag=None, href=None):
    self.content = content
    self.tag = tag
    self.href = href
    self.title = None
    self.css_class = None

  def FormatForHTMLEmail(self):
    """Return a string that can be used in an HTML email body."""
    escaped_content = cgi.escape(self.content, quote=True)
    if self.tag == 'a' and self.href:
      escaped_href = cgi.escape(self.href, quote=True)
      return '<a href="%s">%s</a>' % (escaped_href, escaped_content)

    return escaped_content
diff --git a/framework/test/__init__.py b/framework/test/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/framework/test/__init__.py
diff --git a/framework/test/alerts_test.py b/framework/test/alerts_test.py
new file mode 100644
index 0000000..0c398c1
--- /dev/null
+++ b/framework/test/alerts_test.py
@@ -0,0 +1,43 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for alert display helpers."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import time
+import unittest
+
+import ezt
+
+from framework import alerts
+from testing import fake
+from testing import testing_helpers
+
+
class AlertsViewTest(unittest.TestCase):

  def testTimestamp(self):
    """Tests that alerts are only shown when the timestamp is valid."""
    project = fake.Project(project_name='testproj')

    # A link stamped with the current time shows the alert.
    now = int(time.time())
    mr = testing_helpers.MakeMonorailRequest(
        path='/p/testproj/?updated=10&ts=%s' % now, project=project)
    alerts_view = alerts.AlertsView(mr)
    self.assertEqual(10, alerts_view.updated)
    self.assertEqual(ezt.boolean(True), alerts_view.show)

    # A link stamped 10s in the past is older than the 8s expiration
    # window (alerts._LINK_EXPIRATION_SEC), so the alert is suppressed.
    now -= 10
    mr = testing_helpers.MakeMonorailRequest(
        path='/p/testproj/?updated=10&ts=%s' % now, project=project)
    alerts_view = alerts.AlertsView(mr)
    self.assertEqual(ezt.boolean(False), alerts_view.show)

    # A link with no ts= parameter at all never shows the alert.
    mr = testing_helpers.MakeMonorailRequest(
        path='/p/testproj/?updated=10', project=project)
    alerts_view = alerts.AlertsView(mr)
    self.assertEqual(ezt.boolean(False), alerts_view.show)
diff --git a/framework/test/authdata_test.py b/framework/test/authdata_test.py
new file mode 100644
index 0000000..a0e7313
--- /dev/null
+++ b/framework/test/authdata_test.py
@@ -0,0 +1,55 @@
+# Copyright 2017 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for the authdata module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import mock
+import unittest
+
+from google.appengine.api import users
+
+from framework import authdata
+from services import service_manager
+from testing import fake
+
+
class AuthDataTest(unittest.TestCase):
  """Tests for the AuthData factory methods."""

  def setUp(self):
    self.cnxn = fake.MonorailConnection()
    self.services = service_manager.Services(
        user=fake.UserService(),
        usergroup=fake.UserGroupService())
    self.user_1 = self.services.user.TestAddUser('test@example.com', 111)

  def testFromRequest(self):
    """AuthData.FromRequest identifies the signed-in GAE user by email."""

    class FakeUser(object):
      # Mimics google.appengine.api.users.User.email().
      email = lambda _: self.user_1.email

    with mock.patch.object(users, 'get_current_user',
                           autospec=True) as mock_get_current_user:
      mock_get_current_user.return_value = FakeUser()
      auth = authdata.AuthData.FromRequest(self.cnxn, self.services)
    self.assertEqual(auth.user_id, 111)

  def testFromEmail(self):
    """AuthData.FromEmail resolves an email address to a user."""
    auth = authdata.AuthData.FromEmail(
        self.cnxn, self.user_1.email, self.services)
    self.assertEqual(auth.user_id, 111)
    self.assertEqual(auth.user_pb.email, self.user_1.email)

  def testFromUserID(self):
    # Renamed from testFromuserId for consistency with the method under
    # test (AuthData.FromUserID) and the file's naming convention.
    """AuthData.FromUserID resolves a numeric user ID to a user."""
    auth = authdata.AuthData.FromUserID(self.cnxn, 111, self.services)
    self.assertEqual(auth.user_id, 111)
    self.assertEqual(auth.user_pb.email, self.user_1.email)

  def testFromUser(self):
    """AuthData.FromUser wraps an already-loaded user PB."""
    auth = authdata.AuthData.FromUser(self.cnxn, self.user_1, self.services)
    self.assertEqual(auth.user_id, 111)
    self.assertEqual(auth.user_pb.email, self.user_1.email)
diff --git a/framework/test/banned_test.py b/framework/test/banned_test.py
new file mode 100644
index 0000000..73b9f03
--- /dev/null
+++ b/framework/test/banned_test.py
@@ -0,0 +1,58 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unittests for monorail.framework.banned."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+import webapp2
+
+from framework import banned
+from framework import monorailrequest
+from services import service_manager
+from testing import testing_helpers
+
+
class BannedTest(unittest.TestCase):
  """Tests for the Banned servlet."""

  def setUp(self):
    self.services = service_manager.Services()

  def testAssertBasePermission(self):
    """Only banned users may view the banned page; everyone else gets 404."""
    servlet = banned.Banned('request', 'response', services=self.services)
    mr = monorailrequest.MonorailRequest(self.services)

    # An anonymous visitor cannot see the banned page.
    mr.auth.user_id = 0
    with self.assertRaises(webapp2.HTTPException) as cm:
      servlet.AssertBasePermission(mr)
    self.assertEqual(404, cm.exception.code)

    # A signed-in user who is not banned cannot see it either.
    mr.auth.user_id = 111
    with self.assertRaises(webapp2.HTTPException) as cm:
      servlet.AssertBasePermission(mr)
    self.assertEqual(404, cm.exception.code)

    # A banned user passes the permission check without raising.
    mr.auth.user_pb.banned = 'spammer'
    servlet.AssertBasePermission(mr)

  def testGatherPageData(self):
    """Page data reports whether the banned email is a plus-address."""
    servlet = banned.Banned('request', 'response', services=self.services)
    self.assertNotEqual(servlet.template, None)

    _request, mr = testing_helpers.GetRequestObjects()
    page_data = servlet.GatherPageData(mr)
    self.assertFalse(page_data['is_plus_address'])
    self.assertEqual(None, page_data['currentPageURLEncoded'])

    mr.auth.user_pb.email = 'user+shadystuff@example.com'
    page_data = servlet.GatherPageData(mr)
    self.assertTrue(page_data['is_plus_address'])
    self.assertEqual(None, page_data['currentPageURLEncoded'])
diff --git a/framework/test/cloud_tasks_helpers_test.py b/framework/test/cloud_tasks_helpers_test.py
new file mode 100644
index 0000000..09ad2cd
--- /dev/null
+++ b/framework/test/cloud_tasks_helpers_test.py
@@ -0,0 +1,88 @@
+# Copyright 2020 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+"""Tests for the cloud tasks helper module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from google.api_core import exceptions
+
+import mock
+import unittest
+
+from framework import cloud_tasks_helpers
+import settings
+
+
class CloudTasksHelpersTest(unittest.TestCase):
  """Tests for the thin wrappers around the Cloud Tasks client."""

  @mock.patch('framework.cloud_tasks_helpers._get_client')
  def test_create_task(self, get_client_mock):
    """create_task builds the queue path and forwards the task body."""
    queue = 'somequeue'
    task = {
        'app_engine_http_request':
            {
                'http_method': 'GET',
                'relative_uri': '/some_url'
            }
    }
    cloud_tasks_helpers.create_task(task, queue=queue)

    get_client_mock().queue_path.assert_called_with(
        settings.app_id, settings.CLOUD_TASKS_REGION, queue)
    get_client_mock().create_task.assert_called_once()
    ((_parent, sent_task), _kwargs) = get_client_mock().create_task.call_args
    self.assertEqual(sent_task, task)

  @mock.patch('framework.cloud_tasks_helpers._get_client')
  def test_create_task_raises(self, get_client_mock):
    """API errors raised by the client propagate to the caller."""
    task = {'app_engine_http_request': {}}
    get_client_mock().create_task.side_effect = exceptions.GoogleAPICallError(
        'oh no!')

    with self.assertRaises(exceptions.GoogleAPICallError):
      cloud_tasks_helpers.create_task(task)

  @mock.patch('framework.cloud_tasks_helpers._get_client')
  def test_create_task_retries(self, get_client_mock):
    """create_task passes the module's default retry policy."""
    cloud_tasks_helpers.create_task({'app_engine_http_request': {}})

    (_args, kwargs) = get_client_mock().create_task.call_args
    self.assertEqual(kwargs.get('retry'), cloud_tasks_helpers._DEFAULT_RETRY)

  def test_generate_simple_task(self):
    """generate_simple_task urlencodes params into a form-encoded task."""
    # Each case: (params to encode, expected urlencoded body).
    for params, body in (({'a': 'a', 'b': 'b'}, 'a=a&b=b'), ({}, '')):
      expected = {
          'app_engine_http_request':
              {
                  'relative_uri': '/alphabet/letters',
                  'body': body,
                  'headers': {
                      'Content-type': 'application/x-www-form-urlencoded'
                  }
              }
      }
      actual = cloud_tasks_helpers.generate_simple_task(
          '/alphabet/letters', params)
      self.assertEqual(actual, expected)
diff --git a/framework/test/csv_helpers_test.py b/framework/test/csv_helpers_test.py
new file mode 100644
index 0000000..19c89c5
--- /dev/null
+++ b/framework/test/csv_helpers_test.py
@@ -0,0 +1,61 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for csv_helpers functions."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+from framework import csv_helpers
+
+
class IssueListCSVFunctionsTest(unittest.TestCase):
  """Tests for CSV export helper functions."""

  def testRewriteColspec(self):
    """Time-valued columns gain a companion *Timestamp column."""
    cases = [
        ('', ''),
        ('a B c', 'a B c'),
        ('a summary B opened c',
         'a Summary AllLabels B Opened OpenedTimestamp c'),
        ('Closed Modified',
         'Closed ClosedTimestamp Modified ModifiedTimestamp'),
        ('OwnerModified', 'OwnerModified OwnerModifiedTimestamp'),
    ]
    for colspec, expected in cases:
      self.assertEqual(expected, csv_helpers.RewriteColspec(colspec))

  def testReformatRowsForCSV(self):
    # TODO(jojwang): write this test
    pass

  def testEscapeCSV(self):
    """EscapeCSV strips, doubles quotes, and defuses formula injection."""
    # Trivial and whitespace-stripping cases.
    for raw, expected in [
        (None, ''), (0, 0), ('', ''), ('hello', 'hello'),
        (' hello ', 'hello')]:
      self.assertEqual(expected, csv_helpers.EscapeCSV(raw))

    # Double quotes are escaped as two double quotes.
    self.assertEqual("say 'hello'", csv_helpers.EscapeCSV("say 'hello'"))
    self.assertEqual('say ""hello""', csv_helpers.EscapeCSV('say "hello"'))

    # Things that look like formulas are prefixed with a single quote because
    # some formula functions can have side-effects. See:
    # https://www.contextis.com/resources/blog/comma-separated-vulnerabilities/
    # Some spreadsheets also treat leading plus, minus, and at-signs as
    # formula starts.
    for raw in ('=2+2', '=CMD| del *.*', '+2+2', '-2+2', '@2+2'):
      self.assertEqual("'" + raw, csv_helpers.EscapeCSV(raw))

    # Non-ASCII text passes through unchanged.
    self.assertEqual(
        u'division\xc3\xb7sign',
        csv_helpers.EscapeCSV(u'division\xc3\xb7sign'))
diff --git a/framework/test/deleteusers_test.py b/framework/test/deleteusers_test.py
new file mode 100644
index 0000000..4cadbbd
--- /dev/null
+++ b/framework/test/deleteusers_test.py
@@ -0,0 +1,214 @@
+# Copyright 2019 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for deleteusers classes."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import logging
+import mock
+import unittest
+import urllib
+
+from framework import cloud_tasks_helpers
+from framework import deleteusers
+from framework import framework_constants
+from framework import urls
+from services import service_manager
+from testing import fake
+from testing import testing_helpers
+
class TestWipeoutSyncCron(unittest.TestCase):
  """Tests for the cron job that fans out wipeout sync tasks."""

  def setUp(self):
    self.services = service_manager.Services(user=fake.UserService())
    self.task = deleteusers.WipeoutSyncCron(
        request=None, response=None, services=self.services)
    self.user_1 = self.services.user.TestAddUser('user1@example.com', 111)
    self.user_2 = self.services.user.TestAddUser('user2@example.com', 222)
    self.user_3 = self.services.user.TestAddUser('user3@example.com', 333)

  def generate_simple_task(self, url, body):
    """Builds the Cloud Tasks payload the cron is expected to enqueue."""
    return {
        'app_engine_http_request':
            {
                'relative_uri': url,
                'body': body,
                'headers': {
                    'Content-type': 'application/x-www-form-urlencoded'
                }
            }
    }

  @mock.patch('framework.cloud_tasks_helpers._get_client')
  def testHandleRequest(self, get_client_mock):
    """Users are batched by batchsize, plus one final deletion task."""
    mr = testing_helpers.MakeMonorailRequest(
        path='url/url?batchsize=2',
        services=self.services)
    self.task.HandleRequest(mr)

    self.assertEqual(get_client_mock().create_task.call_count, 3)

    # Two send-list batches covering the three users, then the delete task.
    expected = [
        (urls.SEND_WIPEOUT_USER_LISTS_TASK + '.do', 'limit=2&offset=0'),
        (urls.SEND_WIPEOUT_USER_LISTS_TASK + '.do', 'limit=2&offset=2'),
        (urls.DELETE_WIPEOUT_USERS_TASK + '.do', ''),
    ]
    for url, body in expected:
      get_client_mock().create_task.assert_any_call(
          get_client_mock().queue_path(),
          self.generate_simple_task(url, body),
          retry=cloud_tasks_helpers._DEFAULT_RETRY)

  @mock.patch('framework.cloud_tasks_helpers._get_client')
  def testHandleRequest_NoBatchSizeParam(self, get_client_mock):
    """Without a batchsize param, the maximum batch size is used."""
    mr = testing_helpers.MakeMonorailRequest(services=self.services)
    self.task.HandleRequest(mr)

    expected_task = self.generate_simple_task(
        urls.SEND_WIPEOUT_USER_LISTS_TASK + '.do',
        'limit={}&offset=0'.format(deleteusers.MAX_BATCH_SIZE))
    get_client_mock().create_task.assert_any_call(
        get_client_mock().queue_path(),
        expected_task,
        retry=cloud_tasks_helpers._DEFAULT_RETRY)

  @mock.patch('framework.cloud_tasks_helpers._get_client')
  def testHandleRequest_NoUsers(self, get_client_mock):
    """No tasks are enqueued when there are no users at all."""
    mr = testing_helpers.MakeMonorailRequest()
    self.services.user.users_by_id = {}
    self.task.HandleRequest(mr)

    self.assertEqual(get_client_mock().create_task.call_count, 0)
+
+
class SendWipeoutUserListsTaskTest(unittest.TestCase):
  """Tests for the task that sends user lists to the wipeout service."""

  def setUp(self):
    self.services = service_manager.Services(user=fake.UserService())
    self.task = deleteusers.SendWipeoutUserListsTask(
        request=None, response=None, services=self.services)
    self.task.sendUserLists = mock.Mock()
    # Bug fix: the original assigned deleteusers.authorize = mock.Mock(...),
    # permanently replacing the module attribute with no restore, leaking
    # the mock into every later test in the process.  Patch it instead and
    # register the cleanup so the real function is restored per test.
    patcher = mock.patch.object(
        deleteusers, 'authorize', return_value='service')
    patcher.start()
    self.addCleanup(patcher.stop)
    self.user_1 = self.services.user.TestAddUser('user1@example.com', 111)
    self.user_2 = self.services.user.TestAddUser('user2@example.com', 222)
    self.user_3 = self.services.user.TestAddUser('user3@example.com', 333)

  def testHandleRequest_NoBatchSizeParam(self):
    """Users starting at offset are sent, honoring the limit param."""
    mr = testing_helpers.MakeMonorailRequest(path='url/url?limit=2&offset=1')
    self.task.HandleRequest(mr)
    deleteusers.authorize.assert_called_once_with()
    self.task.sendUserLists.assert_called_once_with(
        'service', [
            {'id': self.user_2.email},
            {'id': self.user_3.email}])

  def testHandleRequest_NoLimit(self):
    """A missing limit param is an assertion error."""
    mr = testing_helpers.MakeMonorailRequest()
    self.services.user.users_by_id = {}
    with self.assertRaisesRegexp(AssertionError, 'Missing param limit'):
      self.task.HandleRequest(mr)

  def testHandleRequest_NoOffset(self):
    """A missing offset param is an assertion error."""
    mr = testing_helpers.MakeMonorailRequest(path='url/url?limit=3')
    self.services.user.users_by_id = {}
    with self.assertRaisesRegexp(AssertionError, 'Missing param offset'):
      self.task.HandleRequest(mr)

  def testHandleRequest_ZeroOffset(self):
    """An explicit zero offset starts from the first user."""
    mr = testing_helpers.MakeMonorailRequest(path='url/url?limit=2&offset=0')
    self.task.HandleRequest(mr)
    self.task.sendUserLists.assert_called_once_with(
        'service', [
            {'id': self.user_1.email},
            {'id': self.user_2.email}])
+
+
class DeleteWipeoutUsersTaskTest(unittest.TestCase):
  """Tests for the task that deletes wiped-out users in batches."""

  def setUp(self):
    self.services = service_manager.Services()
    # Bug fix: the original assigned deleteusers.authorize = mock.Mock(...),
    # permanently replacing the module attribute with no restore, leaking
    # the mock into every later test in the process.  Patch it instead and
    # register the cleanup so the real function is restored per test.
    patcher = mock.patch.object(
        deleteusers, 'authorize', return_value='service')
    patcher.start()
    self.addCleanup(patcher.stop)
    self.task = deleteusers.DeleteWipeoutUsersTask(
        request=None, response=None, services=self.services)
    deleted_users = [
        {'id': 'user1@gmail.com'}, {'id': 'user2@gmail.com'},
        {'id': 'user3@gmail.com'}, {'id': 'user4@gmail.com'}]
    self.task.fetchDeletedUsers = mock.Mock(return_value=deleted_users)

  def generate_simple_task(self, url, body):
    """Builds the Cloud Tasks payload expected for one deletion batch."""
    return {
        'app_engine_http_request':
            {
                'relative_uri': url,
                'body': body,
                'headers': {
                    'Content-type': 'application/x-www-form-urlencoded'
                }
            }
    }

  @mock.patch('framework.cloud_tasks_helpers._get_client')
  def testHandleRequest(self, get_client_mock):
    """Deleted users are split into batches of the given limit."""
    mr = testing_helpers.MakeMonorailRequest(path='url/url?limit=3')
    self.task.HandleRequest(mr)

    deleteusers.authorize.assert_called_once_with()
    self.task.fetchDeletedUsers.assert_called_once_with('service')
    ((_app_id, _region, queue),
     _kwargs) = get_client_mock().queue_path.call_args
    self.assertEqual(queue, framework_constants.QUEUE_DELETE_USERS)

    self.assertEqual(get_client_mock().create_task.call_count, 2)

    # First batch holds three emails, the second holds the remainder.
    query = urllib.urlencode(
        {'emails': 'user1@gmail.com,user2@gmail.com,user3@gmail.com'})
    expected_task = self.generate_simple_task(
        urls.DELETE_USERS_TASK + '.do', query)

    get_client_mock().create_task.assert_any_call(
        get_client_mock().queue_path(),
        expected_task,
        retry=cloud_tasks_helpers._DEFAULT_RETRY)

    query = urllib.urlencode({'emails': 'user4@gmail.com'})
    expected_task = self.generate_simple_task(
        urls.DELETE_USERS_TASK + '.do', query)

    get_client_mock().create_task.assert_any_call(
        get_client_mock().queue_path(),
        expected_task,
        retry=cloud_tasks_helpers._DEFAULT_RETRY)

  @mock.patch('framework.cloud_tasks_helpers._get_client')
  def testHandleRequest_DefaultMax(self, get_client_mock):
    """Without a limit param, all users go into one batch."""
    mr = testing_helpers.MakeMonorailRequest(path='url/url')
    self.task.HandleRequest(mr)

    deleteusers.authorize.assert_called_once_with()
    self.task.fetchDeletedUsers.assert_called_once_with('service')
    self.assertEqual(get_client_mock().create_task.call_count, 1)

    emails = 'user1@gmail.com,user2@gmail.com,user3@gmail.com,user4@gmail.com'
    query = urllib.urlencode({'emails': emails})
    expected_task = self.generate_simple_task(
        urls.DELETE_USERS_TASK + '.do', query)

    get_client_mock().create_task.assert_any_call(
        get_client_mock().queue_path(),
        expected_task,
        retry=cloud_tasks_helpers._DEFAULT_RETRY)
diff --git a/framework/test/emailfmt_test.py b/framework/test/emailfmt_test.py
new file mode 100644
index 0000000..dd7cca3
--- /dev/null
+++ b/framework/test/emailfmt_test.py
@@ -0,0 +1,821 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for monorail.framework.emailfmt."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import mock
+import unittest
+
+from google.appengine.ext import testbed
+
+import settings
+from framework import emailfmt
+from framework import framework_views
+from proto import project_pb2
+from testing import testing_helpers
+
+from google.appengine.api import apiproxy_stub_map
+
+
class EmailFmtTest(unittest.TestCase):
  """Tests for parsing and validating inbound email in emailfmt."""

  def setUp(self):
    # GAE testbed provides the datastore/memcache stubs emailfmt relies on.
    self.testbed = testbed.Testbed()
    self.testbed.activate()
    self.testbed.init_datastore_v3_stub()
    self.testbed.init_memcache_stub()

  def tearDown(self):
    self.testbed.deactivate()

  def testValidateReferencesHeader(self):
    """References must match the sender, subject, and project address."""
    project = project_pb2.Project()
    project.project_name = 'open-open'
    subject = 'slipped disk'
    expected = emailfmt.MakeMessageID(
        'jrobbins@gmail.com', subject,
        '%s@%s' % (project.project_name, emailfmt.MailDomain()))
    self.assertTrue(
        emailfmt.ValidateReferencesHeader(
            expected, project, 'jrobbins@gmail.com', subject))

    # A different subject invalidates the header.
    self.assertFalse(
        emailfmt.ValidateReferencesHeader(
            expected, project, 'jrobbins@gmail.com', 'something else'))

    # A different sender invalidates the header.
    self.assertFalse(
        emailfmt.ValidateReferencesHeader(
            expected, project, 'someoneelse@gmail.com', subject))

    # A different project invalidates the header.
    project.project_name = 'other-project'
    self.assertFalse(
        emailfmt.ValidateReferencesHeader(
            expected, project, 'jrobbins@gmail.com', subject))

  def testParseEmailMessage(self):
    """ParseEmailMessage extracts addresses, references, subject, body."""
    msg = testing_helpers.MakeMessage(testing_helpers.HEADER_LINES, 'awesome!')

    (from_addr, to_addrs, cc_addrs, references, incident_id,
     subject, body) = emailfmt.ParseEmailMessage(msg)

    self.assertEqual('user@example.com', from_addr)
    self.assertEqual(['proj@monorail.example.com'], to_addrs)
    self.assertEqual(['ningerso@chromium.org'], cc_addrs)
    # Expected msg-id was generated from a previous known-good test run.
    self.assertEqual(['<0=969704940193871313=13442892928193434663='
                      'proj@monorail.example.com>'],
                     references)
    self.assertEqual('', incident_id)
    self.assertEqual('Issue 123 in proj: broken link', subject)
    self.assertEqual('awesome!', body)

    # An explicit References header is merged with the generated msg-id.
    references_header = ('References', '<1234@foo.com> <5678@bar.com>')
    msg = testing_helpers.MakeMessage(
        testing_helpers.HEADER_LINES + [references_header], 'awesome!')
    (from_addr, to_addrs, cc_addrs, references, incident_id, subject,
     body) = emailfmt.ParseEmailMessage(msg)
    self.assertItemsEqual(
        ['<5678@bar.com>',
         '<0=969704940193871313=13442892928193434663='
         'proj@monorail.example.com>',
         '<1234@foo.com>'],
        references)

  def testParseEmailMessage_Bulk(self):
    """Bulk/junk (e.g. vacation auto-reply) messages parse to empties."""
    for precedence in ['Bulk', 'Junk']:
      msg = testing_helpers.MakeMessage(
          testing_helpers.HEADER_LINES + [('Precedence', precedence)],
          'I am on vacation!')

      (from_addr, to_addrs, cc_addrs, references, incident_id, subject,
       body) = emailfmt.ParseEmailMessage(msg)

      self.assertEqual('', from_addr)
      self.assertEqual([], to_addrs)
      self.assertEqual([], cc_addrs)
      self.assertEqual('', references)
      self.assertEqual('', incident_id)
      self.assertEqual('', subject)
      self.assertEqual('', body)

  def testExtractAddrs(self):
    """_ExtractAddrs pulls bare addresses out of RFC 2822 header values."""
    header_val = ''
    self.assertEqual(
        [], emailfmt._ExtractAddrs(header_val))

    header_val = 'J. Robbins <a@b.com>, c@d.com,\n Nick "Name" Dude <e@f.com>'
    self.assertEqual(
        ['a@b.com', 'c@d.com', 'e@f.com'],
        emailfmt._ExtractAddrs(header_val))

    header_val = ('hot: J. O\'Robbins <a@b.com>; '
                  'cool: "friendly" <e.g-h@i-j.k-L.com>')
    self.assertEqual(
        ['a@b.com', 'e.g-h@i-j.k-L.com'],
        emailfmt._ExtractAddrs(header_val))

  def CheckIdentifiedValues(
      self, project_addr, subject, expected_project_name, expected_local_id,
      expected_verb=None, expected_label=None):
    """Testing helper function to check 3 results against expected values."""
    project_name, verb, label = emailfmt.IdentifyProjectVerbAndLabel(
        project_addr)
    local_id = emailfmt.IdentifyIssue(project_name, subject)
    self.assertEqual(expected_project_name, project_name)
    self.assertEqual(expected_local_id, local_id)
    self.assertEqual(expected_verb, verb)
    self.assertEqual(expected_label, label)

  def testIdentifyProjectAndIssues_Normal(self):
    """Parse normal issue notification subject lines."""
    self.CheckIdentifiedValues(
        'proj@monorail.example.com',
        'Issue 123 in proj: the dogs wont eat the dogfood',
        'proj', 123)

    # Address matching is case-insensitive.
    self.CheckIdentifiedValues(
        'Proj@MonoRail.Example.Com',
        'Issue 123 in proj: the dogs wont eat the dogfood',
        'proj', 123)

    self.CheckIdentifiedValues(
        'proj-4-u@test-example3.com',
        'Issue 123 in proj-4-u: this one goes to: 11',
        'proj-4-u', 123)

    # The issue is resolved in the addressed project, not the subject's.
    self.CheckIdentifiedValues(
        'night@monorail.example.com',
        'Issue 451 in day: something is fishy',
        'night', None)

  def testIdentifyProjectAndIssues_Compact(self):
    """Parse compact subject lines."""
    self.CheckIdentifiedValues(
        'proj@monorail.example.com',
        'proj:123: the dogs wont eat the dogfood',
        'proj', 123)

    self.CheckIdentifiedValues(
        'Proj@MonoRail.Example.Com',
        'proj:123: the dogs wont eat the dogfood',
        'proj', 123)

    self.CheckIdentifiedValues(
        'proj-4-u@test-example3.com',
        'proj-4-u:123: this one goes to: 11',
        'proj-4-u', 123)

    self.CheckIdentifiedValues(
        'night@monorail.example.com',
        'day:451: something is fishy',
        'night', None)

  def testIdentifyProjectAndIssues_NotAMatch(self):
    """These subject lines do not match the ones we send."""
    self.CheckIdentifiedValues(
        'no_reply@chromium.org',
        'Issue 234 in project foo: ignore this one',
        None, None)

    self.CheckIdentifiedValues(
        'no_reply@chromium.org',
        'foo-234: ignore this one',
        None, None)

  def testStripSubjectPrefixes(self):
    """Re:/Fwd:/aw: prefixes are removed; other text is untouched."""
    self.assertEqual(
        '',
        emailfmt._StripSubjectPrefixes(''))

    self.assertEqual(
        'this is it',
        emailfmt._StripSubjectPrefixes('this is it'))

    self.assertEqual(
        'this is it',
        emailfmt._StripSubjectPrefixes('re: this is it'))

    self.assertEqual(
        'this is it',
        emailfmt._StripSubjectPrefixes('Re: Fwd: aw:this is it'))

    self.assertEqual(
        'This - . IS it',
        emailfmt._StripSubjectPrefixes('This - . IS it'))
+
+
class MailDomainTest(unittest.TestCase):
  """Tests for the mail domain in the testbed environment."""

  def testTrivialCases(self):
    """The testbed app gets the standard appspotmail domain."""
    self.assertEqual(
        'testbed-test.appspotmail.com', emailfmt.MailDomain())
+
+
class NoReplyAddressTest(unittest.TestCase):
  """Tests for the no-reply From: address."""

  def testNoCommenter(self):
    """Without a commenter, the plain no_reply address is used."""
    self.assertEqual(
        'no_reply@testbed-test.appspotmail.com',
        emailfmt.NoReplyAddress())

  def testWithCommenter(self):
    """A revealed commenter appears with the full username."""
    commenter_view = framework_views.StuffUserView(
        111, 'user@example.com', True)
    actual = emailfmt.NoReplyAddress(
        commenter_view=commenter_view, reveal_addr=True)
    self.assertEqual(
        'user via monorail '
        '<no_reply+v2.111@testbed-test.appspotmail.com>',
        actual)

  def testObscuredCommenter(self):
    """An obscured commenter appears with a truncated username."""
    commenter_view = framework_views.StuffUserView(
        111, 'user@example.com', True)
    actual = emailfmt.NoReplyAddress(
        commenter_view=commenter_view, reveal_addr=False)
    self.assertEqual(
        u'u\u2026 via monorail '
        '<no_reply+v2.111@testbed-test.appspotmail.com>',
        actual)
+
+
class FormatFromAddrTest(unittest.TestCase):
  """Tests for From: address formatting, including branded domains."""

  def setUp(self):
    self.project = project_pb2.Project(project_name='monorail')
    # Save the settings overridden below so tearDown can restore them.
    self.old_send_email_as_format = settings.send_email_as_format
    settings.send_email_as_format = 'monorail@%(domain)s'
    self.old_send_noreply_email_as_format = (
        settings.send_noreply_email_as_format)
    settings.send_noreply_email_as_format = 'monorail+noreply@%(domain)s'

  def tearDown(self):
    # Bug fix: the original tearDown re-saved the overridden values
    # (self.old_... = settings...) instead of restoring the originals,
    # so the setUp overrides leaked into every test run afterwards.
    settings.send_email_as_format = self.old_send_email_as_format
    settings.send_noreply_email_as_format = (
        self.old_send_noreply_email_as_format)

  def testNoCommenter(self):
    """Without a commenter, the plain project address is used."""
    self.assertEqual('monorail@chromium.org',
                     emailfmt.FormatFromAddr(self.project))

  @mock.patch('settings.branded_domains',
              {'monorail': 'bugs.branded.com', '*': 'bugs.chromium.org'})
  def testNoCommenter_Branded(self):
    """Branded projects use the branded domain for the From: address."""
    self.assertEqual('monorail@branded.com',
                     emailfmt.FormatFromAddr(self.project))

  def testNoCommenterWithNoReply(self):
    """can_reply_to=False selects the noreply address format."""
    self.assertEqual('monorail+noreply@chromium.org',
                     emailfmt.FormatFromAddr(self.project, can_reply_to=False))

  @mock.patch('settings.branded_domains',
              {'monorail': 'bugs.branded.com', '*': 'bugs.chromium.org'})
  def testNoCommenterWithNoReply_Branded(self):
    """Branded noreply addresses also use the branded domain."""
    self.assertEqual('monorail+noreply@branded.com',
                     emailfmt.FormatFromAddr(self.project, can_reply_to=False))

  def testWithCommenter(self):
    """A revealed commenter appears with the full username."""
    commenter_view = framework_views.StuffUserView(
        111, 'user@example.com', True)
    self.assertEqual(
        u'user via monorail <monorail+v2.111@chromium.org>',
        emailfmt.FormatFromAddr(
            self.project, commenter_view=commenter_view, reveal_addr=True))

  @mock.patch('settings.branded_domains',
              {'monorail': 'bugs.branded.com', '*': 'bugs.chromium.org'})
  def testWithCommenter_Branded(self):
    """A revealed commenter gets the branded domain when configured."""
    commenter_view = framework_views.StuffUserView(
        111, 'user@example.com', True)
    self.assertEqual(
        u'user via monorail <monorail+v2.111@branded.com>',
        emailfmt.FormatFromAddr(
            self.project, commenter_view=commenter_view, reveal_addr=True))

  def testObscuredCommenter(self):
    """An obscured commenter appears with a truncated username."""
    commenter_view = framework_views.StuffUserView(
        111, 'user@example.com', True)
    self.assertEqual(
        u'u\u2026 via monorail <monorail+v2.111@chromium.org>',
        emailfmt.FormatFromAddr(
            self.project, commenter_view=commenter_view, reveal_addr=False))

  def testServiceAccountCommenter(self):
    """Service-account emails are shortened to a friendly username."""
    johndoe_bot = '123456789@developer.gserviceaccount.com'
    commenter_view = framework_views.StuffUserView(
        111, johndoe_bot, True)
    self.assertEqual(
        ('johndoe via monorail <monorail+v2.111@chromium.org>'),
        emailfmt.FormatFromAddr(
            self.project, commenter_view=commenter_view, reveal_addr=False))
+
+
class NormalizeHeaderWhitespaceTest(unittest.TestCase):
  """Tests for header whitespace normalization and truncation."""

  def testTrivialCases(self):
    """Leading/trailing whitespace is stripped; empties stay empty."""
    cases = [
        ('', ''),
        (' \t\n', ''),
        ('a', 'a'),
        (' a b ', 'a b'),
    ]
    for raw, expected in cases:
      self.assertEqual(expected, emailfmt.NormalizeHeader(raw))

  def testLongSummary(self):
    """Overlong headers are truncated to MAX_HEADER_CHARS_CONSIDERED."""
    big_string = 'x' * 500
    self.assertEqual(
        big_string[:emailfmt.MAX_HEADER_CHARS_CONSIDERED],
        emailfmt.NormalizeHeader(big_string))

    big_string = 'x y ' * 500
    self.assertEqual(
        big_string[:emailfmt.MAX_HEADER_CHARS_CONSIDERED],
        emailfmt.NormalizeHeader(big_string))

    # Short enough to survive intact, minus the trailing space.
    big_string = 'x ' * 100
    self.assertEqual(
        'x ' * 99 + 'x',
        emailfmt.NormalizeHeader(big_string))

  def testNormalCase(self):
    """Internal runs of whitespace collapse to single spaces."""
    self.assertEqual(
        '[a] b: c d',
        emailfmt.NormalizeHeader('[a] b:\tc\n\td'))
+
+
class MakeMessageIDTest(unittest.TestCase):
  """Tests for Message-ID generation."""

  def setUp(self):
    self.testbed = testbed.Testbed()
    self.testbed.activate()
    self.testbed.init_datastore_v3_stub()
    self.testbed.init_memcache_stub()
    # Bug fix: testMakeMessageIDTest assigns settings.mail_domain = None
    # and the original never restored it, leaking the override into any
    # test that ran later in the process.  Save it here and restore it
    # in tearDown.
    self.old_mail_domain = settings.mail_domain

  def tearDown(self):
    settings.mail_domain = self.old_mail_domain
    self.testbed.deactivate()

  def testMakeMessageIDTest(self):
    """Message-IDs are domain-qualified and whitespace-insensitive."""
    message_id = emailfmt.MakeMessageID(
        'to@to.com', 'subject', 'from@from.com')
    self.assertTrue(message_id.startswith('<0='))
    self.assertEqual('testbed-test.appspotmail.com>',
                     message_id.split('@')[-1])

    # With no configured mail domain, the testbed default is still used.
    settings.mail_domain = None
    message_id = emailfmt.MakeMessageID(
        'to@to.com', 'subject', 'from@from.com')
    self.assertTrue(message_id.startswith('<0='))
    self.assertEqual('testbed-test.appspotmail.com>',
                     message_id.split('@')[-1])

    message_id = emailfmt.MakeMessageID(
        'to@to.com', 'subject', 'from@from.com')
    self.assertTrue(message_id.startswith('<0='))
    self.assertEqual('testbed-test.appspotmail.com>',
                     message_id.split('@')[-1])

    # Wordwrapping done by mail clients must not change the ID.
    message_id_ws_1 = emailfmt.MakeMessageID(
        'to@to.com',
        'this is a very long subject that is sure to be wordwrapped by gmail',
        'from@from.com')
    message_id_ws_2 = emailfmt.MakeMessageID(
        'to@to.com',
        'this is a very long subject that \n\tis sure to be '
        'wordwrapped \t\tby gmail',
        'from@from.com')
    self.assertEqual(message_id_ws_1, message_id_ws_2)
+
+
class GetReferencesTest(unittest.TestCase):
  """Tests for building the References: header value."""

  def setUp(self):
    self.testbed = testbed.Testbed()
    self.testbed.activate()
    self.testbed.init_datastore_v3_stub()
    self.testbed.init_memcache_stub()

  def tearDown(self):
    self.testbed.deactivate()

  def testNotPartOfThread(self):
    """With no sequence number there is nothing to reference."""
    refs = emailfmt.GetReferences(
        'a@a.com', 'hi', None, emailfmt.NoReplyAddress())
    self.assertEqual(0, len(refs))

  def testAnywhereInThread(self):
    """Any sequence number yields a non-empty root message ID."""
    refs = emailfmt.GetReferences(
        'a@a.com', 'hi', 0, emailfmt.NoReplyAddress())
    self.assertTrue(len(refs))
    self.assertTrue(refs.startswith('<0='))
+
+class StripQuotedTextTest(unittest.TestCase):
+
+ def CheckExpected(self, expected_output, test_input):
+ actual_output = emailfmt.StripQuotedText(test_input)
+ self.assertEqual(expected_output, actual_output)
+
+ def testAllNewText(self):
+ self.CheckExpected('', '')
+ self.CheckExpected('', '\n')
+ self.CheckExpected('', '\n\n')
+ self.CheckExpected('new', 'new')
+ self.CheckExpected('new', '\nnew\n')
+ self.CheckExpected('new\ntext', '\nnew\ntext\n')
+ self.CheckExpected('new\n\ntext', '\nnew\n\ntext\n')
+
+ def testQuotedLines(self):
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '\n'
+ '> something you said\n'
+ '> that took two lines'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '\n'
+ '> something you said\n'
+ '> that took two lines'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('> something you said\n'
+ '> that took two lines\n'
+ 'new\n'
+ 'text\n'
+ '\n'))
+
+ self.CheckExpected(
+ ('newtext'),
+ ('> something you said\n'
+ '> that took two lines\n'
+ 'newtext'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ 'On Mon, Jan 1, 2023, So-and-so <so@and-so.com> Wrote:\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ 'On Mon, Jan 1, 2023, So-and-so <so@and-so.com> Wrote:\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ 'On Mon, Jan 1, 2023, user@example.com via Monorail\n'
+ '<monorail@chromium.com> Wrote:\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ 'On Jan 14, 2016 6:19 AM, "user@example.com via Monorail" <\n'
+ 'monorail@chromium.com> Wrote:\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ 'On Jan 14, 2016 6:19 AM, "user@example.com via Monorail" <\n'
+ 'monorail@monorail-prod.appspotmail.com> wrote:\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ 'On Mon, Jan 1, 2023, So-and-so so@and-so.com wrote:\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ 'On Wed, Sep 8, 2010 at 6:56 PM, So =AND= <so@gmail.com>wrote:\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ 'On Mon, Jan 1, 2023, So-and-so <so@and-so.com> Wrote:\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ 'project-name@testbed-test.appspotmail.com wrote:\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ 'project-name@testbed-test.appspotmail.com a \xc3\xa9crit :\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ 'project.domain.com@testbed-test.appspotmail.com a \xc3\xa9crit :\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ '2023/01/4 <so@and-so.com>\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ '\n'
+ 'text'),
+ ('new\n'
+ '2023/01/4 <so-and@so.com>\n'
+ '\n'
+ '> something you said\n'
+ '> > in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ def testBoundaryLines(self):
+
+ self.CheckExpected(
+ ('new'),
+ ('new\n'
+ '---- forwarded message ======\n'
+ '\n'
+ 'something you said\n'
+ '> in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new'),
+ ('new\n'
+ '-----Original Message-----\n'
+ '\n'
+ 'something you said\n'
+ '> in response to some other junk\n'
+ '\n'
+ 'text\n'))
+
+ self.CheckExpected(
+ ('new'),
+ ('new\n'
+ '\n'
+ 'Updates:\n'
+ '\tStatus: Fixed\n'
+ '\n'
+ 'notification text\n'))
+
+ self.CheckExpected(
+ ('new'),
+ ('new\n'
+ '\n'
+ 'Comment #1 on issue 9 by username: Is there ...'
+ 'notification text\n'))
+
+ def testSignatures(self):
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '-- \n'
+ 'Name\n'
+ 'phone\n'
+ 'funny quote, or legal disclaimers\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '--\n'
+ 'Name\n'
+ 'phone\n'
+ 'funny quote, or legal disclaimers\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '--\n'
+ 'Name\n'
+ 'ginormous signature\n'
+ 'phone\n'
+ 'address\n'
+ 'address\n'
+ 'address\n'
+ 'homepage\n'
+ 'social network A\n'
+ 'social network B\n'
+ 'social network C\n'
+ 'funny quote\n'
+ '4 lines about why email should be short\n'
+ 'legal disclaimers\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '_______________\n'
+ 'Name\n'
+ 'phone\n'
+ 'funny quote, or legal disclaimers\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '\n'
+ 'Thanks,\n'
+ 'Name\n'
+ '\n'
+ '_______________\n'
+ 'Name\n'
+ 'phone\n'
+ 'funny quote, or legal disclaimers\n'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '\n'
+ 'Thanks,\n'
+ 'Name'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '\n'
+ 'Cheers,\n'
+ 'Name'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '\n'
+ 'Regards\n'
+ 'Name'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '\n'
+ 'best regards'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '\n'
+ 'THX'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '\n'
+ 'Thank you,\n'
+ 'Name'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '\n'
+ 'Sent from my iPhone'))
+
+ self.CheckExpected(
+ ('new\n'
+ 'text'),
+ ('new\n'
+ 'text\n'
+ '\n'
+ 'Sent from my iPod'))
diff --git a/framework/test/exceptions_test.py b/framework/test/exceptions_test.py
new file mode 100644
index 0000000..8fe2295
--- /dev/null
+++ b/framework/test/exceptions_test.py
@@ -0,0 +1,64 @@
+# Copyright 2020 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+"""Unittest for the exceptions module."""
+
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+from framework import exceptions
+from framework import permissions
+
+
class ErrorsManagerTest(unittest.TestCase):
  """Tests for exceptions.ErrorAggregator's collect-then-raise behavior."""

  def testRaiseIfErrors_Errors(self):
    """We raise the given exception if there are errors."""
    err_aggregator = exceptions.ErrorAggregator(exceptions.InputException)

    err_aggregator.AddErrorMessage('The chickens are missing.')
    err_aggregator.AddErrorMessage('The foxes are free.')
    # All accumulated messages are joined (newline-separated) into the
    # single raised exception's message.
    with self.assertRaisesRegexp(
        exceptions.InputException,
        'The chickens are missing.\nThe foxes are free.'):
      err_aggregator.RaiseIfErrors()

  def testErrorsManager_NoErrors(self):
    """We don't raise exceptions if there are no errors."""
    err_aggregator = exceptions.ErrorAggregator(exceptions.InputException)
    err_aggregator.RaiseIfErrors()

  def testWithinContext_ExceptionPassedIn(self):
    """We do not suppress exceptions raised within wrapped code."""

    with self.assertRaisesRegexp(exceptions.InputException,
                                 'We should raise this'):
      with exceptions.ErrorAggregator(exceptions.InputException) as errors:
        # The explicitly raised exception wins; the queued message is dropped.
        errors.AddErrorMessage('We should ignore this error.')
        raise exceptions.InputException('We should raise this')

  def testWithinContext_NoExceptionPassedIn(self):
    """We raise an exception for any errors if no exceptions are passed in."""
    with self.assertRaisesRegexp(exceptions.InputException,
                                 'We can raise this now.'):
      with exceptions.ErrorAggregator(exceptions.InputException) as errors:
        errors.AddErrorMessage('We can raise this now.')
        # Even an early return exits the context manager, which then raises
        # for the queued error, so this return never completes.
        return True

  def testAddErrorMessage(self):
    """We properly handle string formatting when needed."""
    err_aggregator = exceptions.ErrorAggregator(exceptions.InputException)
    # Extra positional/keyword args without placeholders are ignored.
    err_aggregator.AddErrorMessage('No args')
    err_aggregator.AddErrorMessage('No args2', 'unused', unused2=1)
    err_aggregator.AddErrorMessage('{}', 'One arg')
    err_aggregator.AddErrorMessage('{}, {two}', '1', two='2')

    # Verify exceptions formatting a message don't clear the earlier messages.
    # ('{}' with no args makes str.format raise IndexError.)
    with self.assertRaises(IndexError):
      err_aggregator.AddErrorMessage('{}')

    expected = ['No args', 'No args2', 'One arg', '1, 2']
    self.assertEqual(err_aggregator.error_messages, expected)
diff --git a/framework/test/filecontent_test.py b/framework/test/filecontent_test.py
new file mode 100644
index 0000000..4843b47
--- /dev/null
+++ b/framework/test/filecontent_test.py
@@ -0,0 +1,188 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for the filecontent module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+from framework import filecontent
+
+
class MimeTest(unittest.TestCase):
  """Test methods for the mime module."""

  # Sample extensions and the content types they should map to.
  _TEST_EXTENSIONS_TO_CTYPES = {
      'html': 'text/plain',
      'htm': 'text/plain',
      'jpg': 'image/jpeg',
      'jpeg': 'image/jpeg',
      'pdf': 'application/pdf',
  }

  # Source-code extensions that must always be served as plain text.
  _CODE_EXTENSIONS = [
      'py', 'java', 'mf', 'bat', 'sh', 'php', 'vb', 'pl', 'sql',
      'patch', 'diff',
  ]

  def testCommonExtensions(self):
    """Common extensions map to their expected content types."""
    for extension, expected_ctype in self._TEST_EXTENSIONS_TO_CTYPES.items():
      actual = filecontent.GuessContentTypeFromFilename('file.' + extension)
      self.assertEqual(expected_ctype, actual)

  def testCaseDoesNotMatter(self):
    """The upper/lower case of an extension does not matter."""
    for extension, expected_ctype in self._TEST_EXTENSIONS_TO_CTYPES.items():
      upper_name = 'file.' + extension.upper()
      self.assertEqual(
          expected_ctype, filecontent.GuessContentTypeFromFilename(upper_name))

    for extension in self._CODE_EXTENSIONS:
      upper_name = 'code.' + extension.upper()
      self.assertEqual(
          'text/plain', filecontent.GuessContentTypeFromFilename(upper_name))

  def testCodeIsText(self):
    """Source-code extensions are served as text/plain."""
    for extension in self._CODE_EXTENSIONS:
      self.assertEqual(
          'text/plain',
          filecontent.GuessContentTypeFromFilename('code.' + extension))

  def testNoExtensionIsText(self):
    """A filename without an extension is served as text/plain."""
    self.assertEqual(
        'text/plain', filecontent.GuessContentTypeFromFilename('noextension'))

  def testUnknownExtension(self):
    """An obviously unknown extension is served as binary."""
    self.assertEqual(
        'application/octet-stream',
        filecontent.GuessContentTypeFromFilename('f.madeupextension'))

  def testNoShockwaveFlash(self):
    """Shockwave files will NOT be served with their native content type."""
    self.assertEqual(
        'application/octet-stream',
        filecontent.GuessContentTypeFromFilename('bad.swf'))
+
+
class DecodeFileContentsTest(unittest.TestCase):
  """Tests for filecontent.DecodeFileContents() text-vs-binary detection."""

  def IsBinary(self, contents):
    # Helper: return just the is_binary flag from DecodeFileContents.
    _contents, is_binary, _is_long = (
        filecontent.DecodeFileContents(contents))
    return is_binary

  def testFileIsBinaryEmpty(self):
    self.assertFalse(self.IsBinary(''))

  def testFileIsBinaryShortText(self):
    self.assertFalse(self.IsBinary('This is some plain text.'))

  def testLineLengthDetection(self):
    """Files are judged binary by the ratio of over-long lines."""
    # NOTE(review): this file predates py3 migration; these are byte strings
    # (iso-8859-1 encoded), which is what DecodeFileContents examines.
    unicode_str = (
        u'Some non-ascii chars - '
        u'\xa2\xfa\xb6\xe7\xfc\xea\xd0\xf4\xe6\xf0\xce\xf6\xbe')
    short_line = unicode_str.encode('iso-8859-1')
    # One character past the "lower" line-length threshold.
    long_line = (unicode_str * 100)[:filecontent._MAX_SOURCE_LINE_LEN_LOWER+1]
    long_line = long_line.encode('iso-8859-1')

    lines = [short_line] * 100
    lines.append(long_line)

    # High lower ratio - text
    self.assertFalse(self.IsBinary('\n'.join(lines)))

    lines.extend([long_line] * 99)

    # 50/50 lower/upper ratio - binary
    self.assertTrue(self.IsBinary('\n'.join(lines)))

    # Single line too long - binary
    lines = [short_line] * 100
    lines.append(short_line * 100)  # Very long line
    self.assertTrue(self.IsBinary('\n'.join(lines)))

  def testFileIsBinaryLongText(self):
    self.assertFalse(self.IsBinary('This is plain text. \n' * 100))
    # long utf-8 lines are OK
    self.assertFalse(self.IsBinary('This one long line. ' * 100))

  def testFileIsBinaryLongBinary(self):
    # Bytes 122..251 include many non-ASCII values.
    bin_string = ''.join([chr(c) for c in range(122, 252)])
    self.assertTrue(self.IsBinary(bin_string * 100))

  def testFileIsTextByPath(self):
    """A recognized text-file path overrides content-based detection."""
    bin_string = ''.join([chr(c) for c in range(122, 252)] * 100)
    unicode_str = (
        u'Some non-ascii chars - '
        u'\xa2\xfa\xb6\xe7\xfc\xea\xd0\xf4\xe6\xf0\xce\xf6\xbe')
    long_line = (unicode_str * 100)[:filecontent._MAX_SOURCE_LINE_LEN_LOWER+1]
    long_line = long_line.encode('iso-8859-1')

    for contents in [bin_string, long_line]:
      # Unknown or binary-looking paths: content decides (binary here).
      self.assertTrue(filecontent.DecodeFileContents(contents, path=None)[1])
      self.assertTrue(filecontent.DecodeFileContents(contents, path='')[1])
      self.assertTrue(filecontent.DecodeFileContents(contents, path='foo')[1])
      self.assertTrue(
          filecontent.DecodeFileContents(contents, path='foo.bin')[1])
      self.assertTrue(
          filecontent.DecodeFileContents(contents, path='foo.zzz')[1])
      # Known text paths are treated as text regardless of content.
      for path in ['a/b/Makefile.in', 'README', 'a/file.js', 'b.txt']:
        self.assertFalse(
            filecontent.DecodeFileContents(contents, path=path)[1])

  def testFileIsBinaryByCommonExtensions(self):
    """A recognized binary extension decides without examining content."""
    contents = 'this is not examined'
    self.assertTrue(filecontent.DecodeFileContents(
        contents, path='junk.zip')[1])
    self.assertTrue(filecontent.DecodeFileContents(
        contents, path='JUNK.ZIP')[1])
    self.assertTrue(filecontent.DecodeFileContents(
        contents, path='/build/HelloWorld.o')[1])
    self.assertTrue(filecontent.DecodeFileContents(
        contents, path='/build/Hello.class')[1])
    self.assertTrue(filecontent.DecodeFileContents(
        contents, path='/trunk/libs.old/swing.jar')[1])

    self.assertFalse(filecontent.DecodeFileContents(
        contents, path='HelloWorld.cc')[1])
    self.assertFalse(filecontent.DecodeFileContents(
        contents, path='Hello.java')[1])
    self.assertFalse(filecontent.DecodeFileContents(
        contents, path='README')[1])
    self.assertFalse(filecontent.DecodeFileContents(
        contents, path='READ.ME')[1])
    self.assertFalse(filecontent.DecodeFileContents(
        contents, path='README.txt')[1])
    self.assertFalse(filecontent.DecodeFileContents(
        contents, path='README.TXT')[1])
    self.assertFalse(filecontent.DecodeFileContents(
        contents, path='/trunk/src/com/monorail/Hello.java')[1])
    self.assertFalse(filecontent.DecodeFileContents(
        contents, path='/branches/1.2/resource.el')[1])
    self.assertFalse(filecontent.DecodeFileContents(
        contents, path='/wiki/PageName.wiki')[1])

  def testUnreasonablyLongFile(self):
    """Files over SOURCE_FILE_MAX_LINES are flagged is_long (not binary)."""
    contents = '\n' * (filecontent.SOURCE_FILE_MAX_LINES + 2)
    _contents, is_binary, is_long = filecontent.DecodeFileContents(
        contents)
    self.assertFalse(is_binary)
    self.assertTrue(is_long)

    contents = '\n' * 100
    _contents, is_binary, is_long = filecontent.DecodeFileContents(
        contents)
    self.assertFalse(is_binary)
    self.assertFalse(is_long)
diff --git a/framework/test/framework_bizobj_test.py b/framework/test/framework_bizobj_test.py
new file mode 100644
index 0000000..131ebb5
--- /dev/null
+++ b/framework/test/framework_bizobj_test.py
@@ -0,0 +1,696 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for monorail.framework.framework_bizobj."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+import mock
+
+import settings
+from framework import authdata
+from framework import framework_bizobj
+from framework import framework_constants
+from proto import project_pb2
+from proto import tracker_pb2
+from proto import user_pb2
+from services import service_manager
+from services import client_config_svc
+from testing import fake
+from testing import testing_helpers
+from tracker import tracker_bizobj
+
+
class CreateUserDisplayNamesAndEmailsTest(unittest.TestCase):
  """Tests for framework_bizobj.CreateUserDisplayNamesAndEmails."""

  def setUp(self):
    self.cnxn = fake.MonorailConnection()
    self.services = service_manager.Services(
        project=fake.ProjectService(),
        user=fake.UserService(),
        usergroup=fake.UserGroupService())

    # Users 1 and 3 have obscured emails; users 2 and 4 do not.
    self.user_1 = self.services.user.TestAddUser(
        'user_1@test.com', 111, obscure_email=True)
    self.user_2 = self.services.user.TestAddUser(
        'user_2@test.com', 222, obscure_email=False)
    self.user_3 = self.services.user.TestAddUser(
        'user_3@test.com', 333, obscure_email=True)
    self.user_4 = self.services.user.TestAddUser(
        'user_4@test.com', 444, obscure_email=False)
    self.service_account = self.services.user.TestAddUser(
        'service@account.com', 999, obscure_email=True)
    self.user_deleted = self.services.user.TestAddUser(
        '', framework_constants.DELETED_USER_ID)
    # The requester is the viewer on whose behalf names are computed.
    self.requester = self.services.user.TestAddUser('user_5@test.com', 555)
    self.user_auth = authdata.AuthData(
        user_id=self.requester.user_id, email=self.requester.email)
    self.project = self.services.project.TestAddProject(
        'proj',
        project_id=789,
        owner_ids=[self.user_1.user_id],
        committer_ids=[self.user_2.user_id, self.service_account.user_id])

  @mock.patch('services.client_config_svc.GetServiceAccountMap')
  def testUserCreateDisplayNamesAndEmails_NonProjectMembers(
      self, fake_account_map):
    """Non-members see obscured emails for users who requested obscuring."""
    fake_account_map.return_value = {'service@account.com': 'Service'}
    users = [self.user_1, self.user_2, self.user_3, self.user_4,
             self.service_account, self.user_deleted]
    (display_names_by_id,
     display_emails_by_id) = framework_bizobj.CreateUserDisplayNamesAndEmails(
         self.cnxn, self.services, self.user_auth, users)
    expected_display_names = {
        self.user_1.user_id: testing_helpers.ObscuredEmail(self.user_1.email),
        self.user_2.user_id: self.user_2.email,
        self.user_3.user_id: testing_helpers.ObscuredEmail(self.user_3.email),
        self.user_4.user_id: self.user_4.email,
        # Service accounts show their mapped display name instead.
        self.service_account.user_id: 'Service',
        self.user_deleted.user_id: framework_constants.DELETED_USER_NAME}
    expected_display_emails = {
        self.user_1.user_id:
            testing_helpers.ObscuredEmail(self.user_1.email),
        self.user_2.user_id:
            self.user_2.email,
        self.user_3.user_id:
            testing_helpers.ObscuredEmail(self.user_3.email),
        self.user_4.user_id:
            self.user_4.email,
        self.service_account.user_id:
            testing_helpers.ObscuredEmail(self.service_account.email),
        self.user_deleted.user_id: '',
    }
    self.assertEqual(display_names_by_id, expected_display_names)
    self.assertEqual(display_emails_by_id, expected_display_emails)

  @mock.patch('services.client_config_svc.GetServiceAccountMap')
  def testUserCreateDisplayNamesAndEmails_ProjectMember(self, fake_account_map):
    """Project members see unobscured emails of other members."""
    fake_account_map.return_value = {'service@account.com': 'Service'}
    users = [self.user_1, self.user_2, self.user_3, self.user_4,
             self.service_account, self.user_deleted]
    # Make the requester a committer so member emails become visible.
    self.project.committer_ids.append(self.requester.user_id)
    (display_names_by_id,
     display_emails_by_id) = framework_bizobj.CreateUserDisplayNamesAndEmails(
         self.cnxn, self.services, self.user_auth, users)
    expected_display_names = {
        self.user_1.user_id: self.user_1.email,  # Project member
        self.user_2.user_id: self.user_2.email,  # Project member and unobscured
        self.user_3.user_id: testing_helpers.ObscuredEmail(self.user_3.email),
        self.user_4.user_id: self.user_4.email,  # Unobscured email
        self.service_account.user_id: 'Service',
        self.user_deleted.user_id: framework_constants.DELETED_USER_NAME
    }
    expected_display_emails = {
        self.user_1.user_id: self.user_1.email,  # Project member
        self.user_2.user_id: self.user_2.email,  # Project member and unobscured
        self.user_3.user_id: testing_helpers.ObscuredEmail(self.user_3.email),
        self.user_4.user_id: self.user_4.email,  # Unobscured email
        self.service_account.user_id: self.service_account.email,
        self.user_deleted.user_id: ''
    }
    self.assertEqual(display_names_by_id, expected_display_names)
    self.assertEqual(display_emails_by_id, expected_display_emails)

  @mock.patch('services.client_config_svc.GetServiceAccountMap')
  def testUserCreateDisplayNamesAndEmails_Admin(self, fake_account_map):
    """Site admins see all users' unobscured emails."""
    fake_account_map.return_value = {'service@account.com': 'Service'}
    users = [self.user_1, self.user_2, self.user_3, self.user_4,
             self.service_account, self.user_deleted]
    self.user_auth.user_pb.is_site_admin = True
    (display_names_by_id,
     display_emails_by_id) = framework_bizobj.CreateUserDisplayNamesAndEmails(
         self.cnxn, self.services, self.user_auth, users)
    expected_display_names = {
        self.user_1.user_id: self.user_1.email,
        self.user_2.user_id: self.user_2.email,
        self.user_3.user_id: self.user_3.email,
        self.user_4.user_id: self.user_4.email,
        self.service_account.user_id: 'Service',
        self.user_deleted.user_id: framework_constants.DELETED_USER_NAME}
    expected_display_emails = {
        self.user_1.user_id: self.user_1.email,
        self.user_2.user_id: self.user_2.email,
        self.user_3.user_id: self.user_3.email,
        self.user_4.user_id: self.user_4.email,
        self.service_account.user_id: self.service_account.email,
        self.user_deleted.user_id: ''
    }

    self.assertEqual(display_names_by_id, expected_display_names)
    self.assertEqual(display_emails_by_id, expected_display_emails)
+
+
class ParseAndObscureAddressTest(unittest.TestCase):
  """Tests for framework_bizobj.ParseAndObscureAddress."""

  def testParseAndObscureAddress(self):
    """The address splits into parts; the obscured form keeps 5 chars."""
    parts = framework_bizobj.ParseAndObscureAddress('sir.chicken@farm.test')
    username, user_domain, obscured_username, obscured_email = parts

    self.assertEqual('sir.chicken', username)
    self.assertEqual('farm.test', user_domain)
    self.assertEqual('sir.c', obscured_username)
    self.assertEqual('sir.c...@farm.test', obscured_email)
+
+
class FilterViewableEmailsTest(unittest.TestCase):
  """Tests for framework_bizobj.FilterViewableEmails."""

  def setUp(self):
    self.cnxn = fake.MonorailConnection()
    self.services = service_manager.Services(
        project=fake.ProjectService(),
        user=fake.UserService(),
        usergroup=fake.UserGroupService())
    self.user_1 = self.services.user.TestAddUser(
        'user_1@test.com', 111, obscure_email=True)
    self.user_2 = self.services.user.TestAddUser(
        'user_2@test.com', 222, obscure_email=False)
    # The requester is the viewer whose permissions are being checked.
    self.requester = self.services.user.TestAddUser(
        'user_5@test.com', 555, obscure_email=True)
    self.user_auth = authdata.AuthData(
        user_id=self.requester.user_id, email=self.requester.email)
    self.user_auth.user_pb.email = self.user_auth.email
    self.project = self.services.project.TestAddProject(
        'proj', project_id=789, owner_ids=[111], committer_ids=[222])

  def testFilterViewableEmail_Anon(self):
    """Anonymous visitors may not view any user's email."""
    anon = authdata.AuthData()
    other_users = [self.user_1, self.user_2]
    filtered_users = framework_bizobj.FilterViewableEmails(
        self.cnxn, self.services, anon, other_users)
    self.assertEqual(filtered_users, [])

  def testFilterViewableEmail_Self(self):
    """A user may always view their own email."""
    filtered_users = framework_bizobj.FilterViewableEmails(
        self.cnxn, self.services, self.user_auth, [self.user_auth.user_pb])
    self.assertEqual(filtered_users, [self.user_auth.user_pb])

  def testFilterViewableEmail_SiteAdmin(self):
    """Site admins may view all emails."""
    self.user_auth.user_pb.is_site_admin = True
    other_users = [self.user_1, self.user_2]
    filtered_users = framework_bizobj.FilterViewableEmails(
        self.cnxn, self.services, self.user_auth, other_users)
    self.assertEqual(filtered_users, other_users)

  def testFilterViewableEmail_InDisplayNameGroup(self):
    """Members of a full-emails permission group may view all emails."""
    display_name_group_id = 666
    self.services.usergroup.TestAddGroupSettings(
        display_name_group_id, 'display-perm-perm@email.com')
    # Fix: the original test overwrote this module-level setting without
    # restoring it, leaking the override into later test cases.  Save the
    # previous value and register a cleanup to restore it.
    original_perm_groups = settings.full_emails_perm_groups
    self.addCleanup(
        setattr, settings, 'full_emails_perm_groups', original_perm_groups)
    settings.full_emails_perm_groups = ['display-perm-perm@email.com']
    self.user_auth.effective_ids.add(display_name_group_id)

    other_users = [self.user_1, self.user_2]
    filtered_users = framework_bizobj.FilterViewableEmails(
        self.cnxn, self.services, self.user_auth, other_users)
    self.assertEqual(filtered_users, other_users)

  def testFilterViewableEmail_NonMember(self):
    """Non-members may not view other users' emails."""
    other_users = [self.user_1, self.user_2]
    filtered_users = framework_bizobj.FilterViewableEmails(
        self.cnxn, self.services, self.user_auth, other_users)
    self.assertEqual(filtered_users, [])

  def testFilterViewableEmail_ProjectMember(self):
    """Project members may view other users' emails."""
    self.project.committer_ids.append(self.requester.user_id)
    other_users = [self.user_1, self.user_2]
    filtered_users = framework_bizobj.FilterViewableEmails(
        self.cnxn, self.services, self.user_auth, other_users)
    self.assertEqual(filtered_users, other_users)
+
+
+# TODO(https://crbug.com/monorail/8192): Remove deprecated tests.
class DeprecatedShouldRevealEmailTest(unittest.TestCase):
  """Tests for deprecated framework_bizobj.DeprecatedShouldRevealEmail."""

  def setUp(self):
    self.cnxn = fake.MonorailConnection()
    self.services = service_manager.Services(
        project=fake.ProjectService(),
        user=fake.UserService(),
        usergroup=fake.UserGroupService())
    self.user_1 = self.services.user.TestAddUser(
        'user_1@test.com', 111, obscure_email=True)
    self.user_2 = self.services.user.TestAddUser(
        'user_2@test.com', 222, obscure_email=False)
    self.requester = self.services.user.TestAddUser(
        'user_5@test.com', 555, obscure_email=True)
    self.user_auth = authdata.AuthData(
        user_id=self.requester.user_id, email=self.requester.email)
    self.user_auth.user_pb.email = self.user_auth.email
    self.project = self.services.project.TestAddProject(
        'proj', project_id=789, owner_ids=[111], committer_ids=[222])

  def _Reveal(self, viewer_auth, email):
    """Shorthand: may viewer_auth see this email within self.project?"""
    return framework_bizobj.DeprecatedShouldRevealEmail(
        viewer_auth, self.project, email)

  def testDeprecatedShouldRevealEmail_Anon(self):
    """Anonymous visitors never see full emails."""
    anon = authdata.AuthData()
    self.assertFalse(self._Reveal(anon, self.user_1.email))
    self.assertFalse(self._Reveal(anon, self.user_2.email))

  def testDeprecatedShouldRevealEmail_Self(self):
    """A user always sees their own email."""
    self.assertTrue(
        self._Reveal(self.user_auth, self.user_auth.user_pb.email))

  def testDeprecatedShouldRevealEmail_SiteAdmin(self):
    """Site admins see every email."""
    self.user_auth.user_pb.is_site_admin = True
    self.assertTrue(self._Reveal(self.user_auth, self.user_1.email))
    self.assertTrue(self._Reveal(self.user_auth, self.user_2.email))

  def testDeprecatedShouldRevealEmail_ProjectMember(self):
    """Project members see other users' emails."""
    self.project.committer_ids.append(self.requester.user_id)
    self.assertTrue(self._Reveal(self.user_auth, self.user_1.email))
    self.assertTrue(self._Reveal(self.user_auth, self.user_2.email))

  def testDeprecatedShouldRevealEmail_NonMember(self):
    """Non-members do not see other users' emails."""
    self.assertFalse(self._Reveal(self.user_auth, self.user_1.email))
    self.assertFalse(self._Reveal(self.user_auth, self.user_2.email))
+
+
+class ArtifactTest(unittest.TestCase):
+
  def setUp(self):
    """Build the default issue config (project 789) used by every test."""
    # No custom fields. Exclusive prefixes: Type, Priority, Milestone.
    self.config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+
  def testMergeLabels_Labels(self):
    """MergeLabels merges plain labels, honoring exclusive prefixes."""
    # Empty case.
    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
        [], [], [], self.config)
    self.assertEqual(merged_labels, [])
    self.assertEqual(update_add, [])
    self.assertEqual(update_remove, [])

    # No-op case.
    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
        ['a', 'b'], [], [], self.config)
    self.assertEqual(merged_labels, ['a', 'b'])
    self.assertEqual(update_add, [])
    self.assertEqual(update_remove, [])

    # Adding and removing at the same time.
    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
        ['a', 'b', 'd'], ['c'], ['d'], self.config)
    self.assertEqual(merged_labels, ['a', 'b', 'c'])
    self.assertEqual(update_add, ['c'])
    self.assertEqual(update_remove, ['d'])

    # Removing a non-matching label has no effect.
    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
        ['a', 'b', 'd'], ['d'], ['e'], self.config)
    self.assertEqual(merged_labels, ['a', 'b', 'd'])
    self.assertEqual(update_add, [])  # d was already there.
    self.assertEqual(update_remove, [])  # there was no e.

    # We can add and remove at the same time.
    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
        ['Priority-Medium', 'OpSys-OSX'], ['Hot'], ['OpSys-OSX'], self.config)
    self.assertEqual(merged_labels, ['Priority-Medium', 'Hot'])
    self.assertEqual(update_add, ['Hot'])
    self.assertEqual(update_remove, ['OpSys-OSX'])

    # Adding Priority-High replaces Priority-Medium.
    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
        ['Priority-Medium', 'OpSys-OSX'], ['Priority-High', 'OpSys-Win'], [],
        self.config)
    self.assertEqual(merged_labels, ['OpSys-OSX', 'Priority-High', 'OpSys-Win'])
    self.assertEqual(update_add, ['Priority-High', 'OpSys-Win'])
    self.assertEqual(update_remove, [])

    # Adding Priority-High and Priority-Low replaces with High only.
    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
        ['Priority-Medium', 'OpSys-OSX'],
        ['Priority-High', 'Priority-Low'], [], self.config)
    self.assertEqual(merged_labels, ['OpSys-OSX', 'Priority-High'])
    self.assertEqual(update_add, ['Priority-High'])
    self.assertEqual(update_remove, [])

    # Removing a mix of matching and non-matching labels only does matching.
    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
        ['Priority-Medium', 'OpSys-OSX'], [], ['Priority-Medium', 'OpSys-Win'],
        self.config)
    self.assertEqual(merged_labels, ['OpSys-OSX'])
    self.assertEqual(update_add, [])
    self.assertEqual(update_remove, ['Priority-Medium'])

    # Multi-part labels work as expected.
    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
        ['Priority-Medium', 'OpSys-OSX-11'],
        ['Priority-Medium-Rare', 'OpSys-OSX-13'], [], self.config)
    self.assertEqual(
        merged_labels, ['OpSys-OSX-11', 'Priority-Medium-Rare', 'OpSys-OSX-13'])
    self.assertEqual(update_add, ['Priority-Medium-Rare', 'OpSys-OSX-13'])
    self.assertEqual(update_remove, [])

    # Multi-part exclusive prefixes only filter labels that match whole prefix.
    self.config.exclusive_label_prefixes.append('Branch-Name')
    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
        ['Branch-Name-xyz'],
        ['Branch-Prediction', 'Branch-Name-Beta'], [], self.config)
    self.assertEqual(merged_labels, ['Branch-Prediction', 'Branch-Name-Beta'])
    self.assertEqual(update_add, ['Branch-Prediction', 'Branch-Name-Beta'])
    self.assertEqual(update_remove, [])
+
  def testMergeLabels_SingleValuedEnums(self):
    """Single-valued enum fields act like exclusive label prefixes."""
    # NOTE(review): both FieldDefs below reuse field_id=1 — presumably
    # harmless for this fake config, but verify if field IDs matter here.
    self.config.field_defs.append(tracker_pb2.FieldDef(
        field_id=1, field_name='Size',
        field_type=tracker_pb2.FieldTypes.ENUM_TYPE,
        is_multivalued=False))
    self.config.field_defs.append(tracker_pb2.FieldDef(
        field_id=1, field_name='Branch-Name',
        field_type=tracker_pb2.FieldTypes.ENUM_TYPE,
        is_multivalued=False))

    # We can add a label for a single-valued enum.
    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
        ['Priority-Medium', 'OpSys-OSX'], ['Size-L'], [], self.config)
    self.assertEqual(merged_labels, ['Priority-Medium', 'OpSys-OSX', 'Size-L'])
    self.assertEqual(update_add, ['Size-L'])
    self.assertEqual(update_remove, [])

    # Adding and removing the same label adds it.
    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
        ['Priority-Medium'], ['Size-M'], ['Size-M'], self.config)
    self.assertEqual(merged_labels, ['Priority-Medium', 'Size-M'])
    self.assertEqual(update_add, ['Size-M'])
    self.assertEqual(update_remove, [])

    # Adding Size-L replaces Size-M.
    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
        ['Priority-Medium', 'Size-M'], ['Size-L', 'OpSys-Win'], [],
        self.config)
    self.assertEqual(merged_labels, ['Priority-Medium', 'Size-L', 'OpSys-Win'])
    self.assertEqual(update_add, ['Size-L', 'OpSys-Win'])
    self.assertEqual(update_remove, [])

    # Adding Size-L and Size-XL replaces with L only.
    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
        ['Size-M', 'OpSys-OSX'], ['Size-L', 'Size-XL'], [], self.config)
    self.assertEqual(merged_labels, ['OpSys-OSX', 'Size-L'])
    self.assertEqual(update_add, ['Size-L'])
    self.assertEqual(update_remove, [])

    # Multi-part labels work as expected.
    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
        ['Size-M', 'OpSys-OSX'], ['Size-M-USA'], [], self.config)
    self.assertEqual(merged_labels, ['OpSys-OSX', 'Size-M-USA'])
    self.assertEqual(update_add, ['Size-M-USA'])
    self.assertEqual(update_remove, [])

    # Multi-part enum names only filter labels that match whole name.
    (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
        ['Branch-Name-xyz'],
        ['Branch-Prediction', 'Branch-Name-Beta'], [], self.config)
    self.assertEqual(merged_labels, ['Branch-Prediction', 'Branch-Name-Beta'])
    self.assertEqual(update_add, ['Branch-Prediction', 'Branch-Name-Beta'])
    self.assertEqual(update_remove, [])
+
+ def testMergeLabels_MultiValuedEnums(self):
+ self.config.field_defs.append(tracker_pb2.FieldDef(
+ field_id=1, field_name='OpSys',
+ field_type=tracker_pb2.FieldTypes.ENUM_TYPE,
+ is_multivalued=True))
+ self.config.field_defs.append(tracker_pb2.FieldDef(
+ field_id=1, field_name='Branch-Name',
+ field_type=tracker_pb2.FieldTypes.ENUM_TYPE,
+ is_multivalued=True))
+
+ # We can add a label for a multi-valued enum.
+ (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+ ['Priority-Medium'], ['OpSys-Win'], [], self.config)
+ self.assertEqual(merged_labels, ['Priority-Medium', 'OpSys-Win'])
+ self.assertEqual(update_add, ['OpSys-Win'])
+ self.assertEqual(update_remove, [])
+
+ # We can remove a matching label for a multi-valued enum.
+ (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+ ['Priority-Medium', 'OpSys-Win'], [], ['OpSys-Win'], self.config)
+ self.assertEqual(merged_labels, ['Priority-Medium'])
+ self.assertEqual(update_add, [])
+ self.assertEqual(update_remove, ['OpSys-Win'])
+
+ # We can remove a non-matching label and it is a no-op.
+ (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+ ['Priority-Medium', 'OpSys-OSX'], [], ['OpSys-Win'], self.config)
+ self.assertEqual(merged_labels, ['Priority-Medium', 'OpSys-OSX'])
+ self.assertEqual(update_add, [])
+ self.assertEqual(update_remove, [])
+
+ # Adding and removing the same label adds it.
+ (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+ ['Priority-Medium'], ['OpSys-Win'], ['OpSys-Win'], self.config)
+ self.assertEqual(merged_labels, ['Priority-Medium', 'OpSys-Win'])
+ self.assertEqual(update_add, ['OpSys-Win'])
+ self.assertEqual(update_remove, [])
+
+ # We can add a label for a multi-valued enum, even if matching exists.
+ (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+ ['Priority-Medium', 'OpSys-OSX'], ['OpSys-Win'], [], self.config)
+ self.assertEqual(
+ merged_labels, ['Priority-Medium', 'OpSys-OSX', 'OpSys-Win'])
+ self.assertEqual(update_add, ['OpSys-Win'])
+ self.assertEqual(update_remove, [])
+
+ # Adding two at the same time is fine.
+ (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+ ['Size-M', 'OpSys-OSX'], ['OpSys-Win', 'OpSys-Vax'], [], self.config)
+ self.assertEqual(
+ merged_labels, ['Size-M', 'OpSys-OSX', 'OpSys-Win', 'OpSys-Vax'])
+ self.assertEqual(update_add, ['OpSys-Win', 'OpSys-Vax'])
+ self.assertEqual(update_remove, [])
+
+ # Multi-part labels work as expected.
+ (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+ ['Size-M', 'OpSys-OSX'], ['OpSys-Win-10'], [], self.config)
+ self.assertEqual(merged_labels, ['Size-M', 'OpSys-OSX', 'OpSys-Win-10'])
+ self.assertEqual(update_add, ['OpSys-Win-10'])
+ self.assertEqual(update_remove, [])
+
+    # Multi-part enum names do not interfere with multi-valued label merging.
+ (merged_labels, update_add, update_remove) = framework_bizobj.MergeLabels(
+ ['Branch-Name-xyz'],
+ ['Branch-Prediction', 'Branch-Name-Beta'], [], self.config)
+ self.assertEqual(
+ merged_labels,
+ ['Branch-Name-xyz', 'Branch-Prediction', 'Branch-Name-Beta'])
+ self.assertEqual(update_add, ['Branch-Prediction', 'Branch-Name-Beta'])
+ self.assertEqual(update_remove, [])
+
+
+class CanonicalizeLabelTest(unittest.TestCase):
+
+ def testCanonicalizeLabel(self):
+ self.assertEqual(None, framework_bizobj.CanonicalizeLabel(None))
+ self.assertEqual('FooBar', framework_bizobj.CanonicalizeLabel('Foo Bar '))
+ self.assertEqual('Foo.Bar',
+ framework_bizobj.CanonicalizeLabel('Foo . Bar '))
+ self.assertEqual('Foo-Bar',
+ framework_bizobj.CanonicalizeLabel('Foo - Bar '))
+
+
+class UserIsInProjectTest(unittest.TestCase):
+
+ def testUserIsInProject(self):
+ p = project_pb2.Project()
+ self.assertFalse(framework_bizobj.UserIsInProject(p, {10}))
+ self.assertFalse(framework_bizobj.UserIsInProject(p, set()))
+
+ p.owner_ids.extend([1, 2, 3])
+ p.committer_ids.extend([4, 5, 6])
+ p.contributor_ids.extend([7, 8, 9])
+ self.assertTrue(framework_bizobj.UserIsInProject(p, {1}))
+ self.assertTrue(framework_bizobj.UserIsInProject(p, {4}))
+ self.assertTrue(framework_bizobj.UserIsInProject(p, {7}))
+ self.assertFalse(framework_bizobj.UserIsInProject(p, {10}))
+
+ # Membership via group membership
+ self.assertTrue(framework_bizobj.UserIsInProject(p, {10, 4}))
+
+ # Membership via several group memberships
+ self.assertTrue(framework_bizobj.UserIsInProject(p, {1, 4}))
+
+ # Several irrelevant group memberships
+ self.assertFalse(framework_bizobj.UserIsInProject(p, {10, 11, 12}))
+
+
+class IsValidColumnSpecTest(unittest.TestCase):
+
+ def testIsValidColumnSpec(self):
+ self.assertTrue(
+ framework_bizobj.IsValidColumnSpec('some columns hey-honk hay.honk'))
+
+ self.assertTrue(framework_bizobj.IsValidColumnSpec('some'))
+
+ self.assertTrue(framework_bizobj.IsValidColumnSpec(''))
+
+ def testIsValidColumnSpec_NotValid(self):
+ self.assertFalse(
+ framework_bizobj.IsValidColumnSpec('some columns hey-honk hay.'))
+
+ self.assertFalse(framework_bizobj.IsValidColumnSpec('some columns hey-'))
+
+ self.assertFalse(framework_bizobj.IsValidColumnSpec('-some columns hey'))
+
+ self.assertFalse(framework_bizobj.IsValidColumnSpec('some .columns hey'))
+
+
+class ValidatePrefTest(unittest.TestCase):
+
+ def testUnknown(self):
+ msg = framework_bizobj.ValidatePref('shoe_size', 'true')
+ self.assertIn('shoe_size', msg)
+ self.assertIn('Unknown', msg)
+
+ msg = framework_bizobj.ValidatePref('', 'true')
+ self.assertIn('Unknown', msg)
+
+ def testTooLong(self):
+ msg = framework_bizobj.ValidatePref('code_font', 'x' * 100)
+ self.assertIn('code_font', msg)
+ self.assertIn('too long', msg)
+
+ def testKnownValid(self):
+ self.assertIsNone(framework_bizobj.ValidatePref('code_font', 'true'))
+ self.assertIsNone(framework_bizobj.ValidatePref('code_font', 'false'))
+
+ def testKnownInvalid(self):
+ msg = framework_bizobj.ValidatePref('code_font', '')
+ self.assertIn('Invalid', msg)
+
+ msg = framework_bizobj.ValidatePref('code_font', 'sometimes')
+ self.assertIn('Invalid', msg)
+
+
+class IsRestrictNewIssuesUserTest(unittest.TestCase):
+
+ def setUp(self):
+ self.cnxn = fake.MonorailConnection()
+ self.services = service_manager.Services(
+ user=fake.UserService(),
+ usergroup=fake.UserGroupService())
+ self.services.user.TestAddUser('corp_user@example.com', 111)
+ self.services.user.TestAddUser('corp_group@example.com', 888)
+ self.services.usergroup.TestAddGroupSettings(888, 'corp_group@example.com')
+
+ @mock.patch(
+ 'settings.restrict_new_issues_user_groups', ['corp_group@example.com'])
+ def testNonRestrictNewIssuesUser(self):
+ """We detect when a user is not part of a corp user group."""
+ self.assertFalse(
+ framework_bizobj.IsRestrictNewIssuesUser(self.cnxn, self.services, 111))
+
+ @mock.patch(
+ 'settings.restrict_new_issues_user_groups', ['corp_group@example.com'])
+ def testRestrictNewIssuesUser(self):
+ """We detect when a user is a member of such a group."""
+ self.services.usergroup.TestAddMembers(888, [111, 222])
+ self.assertTrue(
+ framework_bizobj.IsRestrictNewIssuesUser(self.cnxn, self.services, 111))
+
+
+class IsPublicIssueNoticeUserTest(unittest.TestCase):
+
+ def setUp(self):
+ self.cnxn = fake.MonorailConnection()
+ self.services = service_manager.Services(
+ user=fake.UserService(), usergroup=fake.UserGroupService())
+ self.services.user.TestAddUser('corp_user@example.com', 111)
+ self.services.user.TestAddUser('corp_group@example.com', 888)
+ self.services.usergroup.TestAddGroupSettings(888, 'corp_group@example.com')
+
+ @mock.patch(
+ 'settings.public_issue_notice_user_groups', ['corp_group@example.com'])
+ def testNonPublicIssueNoticeUser(self):
+ """We detect when a user is not part of a corp user group."""
+ self.assertFalse(
+ framework_bizobj.IsPublicIssueNoticeUser(self.cnxn, self.services, 111))
+
+ @mock.patch(
+ 'settings.public_issue_notice_user_groups', ['corp_group@example.com'])
+ def testPublicIssueNoticeUser(self):
+ """We detect when a user is a member of such a group."""
+ self.services.usergroup.TestAddMembers(888, [111, 222])
+ self.assertTrue(
+ framework_bizobj.IsPublicIssueNoticeUser(self.cnxn, self.services, 111))
+
+
+class GetEffectiveIdsTest(unittest.TestCase):
+
+ def setUp(self):
+ self.cnxn = fake.MonorailConnection()
+ self.services = service_manager.Services(
+ user=fake.UserService(), usergroup=fake.UserGroupService())
+ self.services.user.TestAddUser('test@example.com', 111)
+
+ def testNoMemberships(self):
+ """No user groups means effective_ids == {user_id}."""
+ effective_ids = framework_bizobj.GetEffectiveIds(
+ self.cnxn, self.services, [111])
+ self.assertEqual(effective_ids, {111: {111}})
+
+ def testNormalMemberships(self):
+ """effective_ids should be {user_id, group_id...}."""
+ self.services.usergroup.TestAddMembers(888, [111])
+ self.services.usergroup.TestAddMembers(999, [111])
+ effective_ids = framework_bizobj.GetEffectiveIds(
+ self.cnxn, self.services, [111])
+ self.assertEqual(effective_ids, {111: {111, 888, 999}})
+
+ def testComputedUserGroup(self):
+ """effective_ids should be {user_id, group_id...}."""
+ self.services.usergroup.TestAddGroupSettings(888, 'everyone@example.com')
+ effective_ids = framework_bizobj.GetEffectiveIds(
+ self.cnxn, self.services, [111])
+ self.assertEqual(effective_ids, {111: {111, 888}})
+
+ def testAccountHasParent(self):
+ """The parent's effective_ids are added to child's."""
+ child = self.services.user.TestAddUser('child@example.com', 111)
+ child.linked_parent_id = 222
+ parent = self.services.user.TestAddUser('parent@example.com', 222)
+ parent.linked_child_ids = [111]
+ effective_ids = framework_bizobj.GetEffectiveIds(
+ self.cnxn, self.services, [111])
+ self.assertEqual(effective_ids, {111: {111, 222}})
+
+ self.services.usergroup.TestAddMembers(888, [111])
+ self.services.usergroup.TestAddMembers(999, [222])
+ effective_ids = framework_bizobj.GetEffectiveIds(
+ self.cnxn, self.services, [111])
+ self.assertEqual(effective_ids, {111: {111, 222, 888, 999}})
+
+ def testAccountHasChildren(self):
+ """All linked child effective_ids are added to parent's."""
+ child1 = self.services.user.TestAddUser('child1@example.com', 111)
+ child1.linked_parent_id = 333
+ child2 = self.services.user.TestAddUser('child3@example.com', 222)
+ child2.linked_parent_id = 333
+ parent = self.services.user.TestAddUser('parent@example.com', 333)
+ parent.linked_child_ids = [111, 222]
+
+ effective_ids = framework_bizobj.GetEffectiveIds(
+ self.cnxn, self.services, [333])
+ self.assertEqual(effective_ids, {333: {111, 222, 333}})
+
+ self.services.usergroup.TestAddMembers(888, [111])
+ self.services.usergroup.TestAddMembers(999, [222])
+ effective_ids = framework_bizobj.GetEffectiveIds(
+ self.cnxn, self.services, [333])
+ self.assertEqual(effective_ids, {333: {111, 222, 333, 888, 999}})
diff --git a/framework/test/framework_helpers_test.py b/framework/test/framework_helpers_test.py
new file mode 100644
index 0000000..1d0146c
--- /dev/null
+++ b/framework/test/framework_helpers_test.py
@@ -0,0 +1,563 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for the framework_helpers module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import mock
+import unittest
+
+import mox
+import time
+
+from businesslogic import work_env
+from framework import framework_helpers
+from framework import framework_views
+from proto import features_pb2
+from proto import project_pb2
+from proto import user_pb2
+from services import service_manager
+from testing import fake
+from testing import testing_helpers
+
+
+class HelperFunctionsTest(unittest.TestCase):
+
+ def setUp(self):
+ self.mox = mox.Mox()
+ self.time = self.mox.CreateMock(framework_helpers.time)
+ framework_helpers.time = self.time # Point to a mocked out time module.
+
+ def tearDown(self):
+ framework_helpers.time = time # Point back to the time module.
+ self.mox.UnsetStubs()
+ self.mox.ResetAll()
+
+ def testRetryDecorator_ExceedFailures(self):
+ class Tracker(object):
+ func_called = 0
+ tracker = Tracker()
+
+ # Use a function that always fails.
+ @framework_helpers.retry(2, delay=1, backoff=2)
+ def testFunc(tracker):
+ tracker.func_called += 1
+ raise Exception('Failed')
+
+ self.time.sleep(1).AndReturn(None)
+ self.time.sleep(2).AndReturn(None)
+ self.mox.ReplayAll()
+ with self.assertRaises(Exception):
+ testFunc(tracker)
+ self.mox.VerifyAll()
+ self.assertEqual(3, tracker.func_called)
+
+ def testRetryDecorator_EventuallySucceed(self):
+ class Tracker(object):
+ func_called = 0
+ tracker = Tracker()
+
+ # Use a function that succeeds on the 2nd attempt.
+ @framework_helpers.retry(2, delay=1, backoff=2)
+ def testFunc(tracker):
+ tracker.func_called += 1
+ if tracker.func_called < 2:
+ raise Exception('Failed')
+
+ self.time.sleep(1).AndReturn(None)
+ self.mox.ReplayAll()
+ testFunc(tracker)
+ self.mox.VerifyAll()
+ self.assertEqual(2, tracker.func_called)
+
+ def testGetRoleName(self):
+ proj = project_pb2.Project()
+ proj.owner_ids.append(111)
+ proj.committer_ids.append(222)
+ proj.contributor_ids.append(333)
+
+ self.assertEqual(None, framework_helpers.GetRoleName(set(), proj))
+
+ self.assertEqual('Owner', framework_helpers.GetRoleName({111}, proj))
+ self.assertEqual('Committer', framework_helpers.GetRoleName({222}, proj))
+ self.assertEqual('Contributor', framework_helpers.GetRoleName({333}, proj))
+
+ self.assertEqual(
+ 'Owner', framework_helpers.GetRoleName({111, 222, 999}, proj))
+ self.assertEqual(
+ 'Committer', framework_helpers.GetRoleName({222, 333, 999}, proj))
+ self.assertEqual(
+ 'Contributor', framework_helpers.GetRoleName({333, 999}, proj))
+
+ def testGetHotlistRoleName(self):
+ hotlist = features_pb2.Hotlist()
+ hotlist.owner_ids.append(111)
+ hotlist.editor_ids.append(222)
+ hotlist.follower_ids.append(333)
+
+ self.assertEqual(None, framework_helpers.GetHotlistRoleName(set(), hotlist))
+
+ self.assertEqual(
+ 'Owner', framework_helpers.GetHotlistRoleName({111}, hotlist))
+ self.assertEqual(
+ 'Editor', framework_helpers.GetHotlistRoleName({222}, hotlist))
+ self.assertEqual(
+ 'Follower', framework_helpers.GetHotlistRoleName({333}, hotlist))
+
+ self.assertEqual(
+ 'Owner', framework_helpers.GetHotlistRoleName({111, 222, 999}, hotlist))
+ self.assertEqual(
+ 'Editor', framework_helpers.GetHotlistRoleName(
+ {222, 333, 999}, hotlist))
+ self.assertEqual(
+ 'Follower', framework_helpers.GetHotlistRoleName({333, 999}, hotlist))
+
+
+class UrlFormattingTest(unittest.TestCase):
+ """Tests for URL formatting."""
+
+ def setUp(self):
+ self.services = service_manager.Services(user=fake.UserService())
+
+ def testFormatMovedProjectURL(self):
+ """Project foo has been moved to bar. User is visiting /p/foo/..."""
+ mr = testing_helpers.MakeMonorailRequest()
+ mr.current_page_url = '/p/foo/'
+ self.assertEqual(
+ '/p/bar/',
+ framework_helpers.FormatMovedProjectURL(mr, 'bar'))
+
+ mr.current_page_url = '/p/foo/issues/list'
+ self.assertEqual(
+ '/p/bar/issues/list',
+ framework_helpers.FormatMovedProjectURL(mr, 'bar'))
+
+ mr.current_page_url = '/p/foo/issues/detail?id=123'
+ self.assertEqual(
+ '/p/bar/issues/detail?id=123',
+ framework_helpers.FormatMovedProjectURL(mr, 'bar'))
+
+ mr.current_page_url = '/p/foo/issues/detail?id=123#c7'
+ self.assertEqual(
+ '/p/bar/issues/detail?id=123#c7',
+ framework_helpers.FormatMovedProjectURL(mr, 'bar'))
+
+ def testFormatURL(self):
+ mr = testing_helpers.MakeMonorailRequest()
+ path = '/dude/wheres/my/car'
+ recognized_params = [(name, mr.GetParam(name)) for name in
+ framework_helpers.RECOGNIZED_PARAMS]
+ url = framework_helpers.FormatURL(recognized_params, path)
+ self.assertEqual(path, url)
+
+ def testFormatURLWithRecognizedParams(self):
+ params = {}
+ query = []
+ for name in framework_helpers.RECOGNIZED_PARAMS:
+ params[name] = name
+ query.append('%s=%s' % (name, 123))
+ path = '/dude/wheres/my/car'
+ expected = '%s?%s' % (path, '&'.join(query))
+ mr = testing_helpers.MakeMonorailRequest(path=expected)
+ recognized_params = [(name, mr.GetParam(name)) for name in
+ framework_helpers.RECOGNIZED_PARAMS]
+ # No added params.
+ url = framework_helpers.FormatURL(recognized_params, path)
+ self.assertEqual(expected, url)
+
+ def testFormatURLWithKeywordArgs(self):
+ params = {}
+ query_pairs = []
+ for name in framework_helpers.RECOGNIZED_PARAMS:
+ params[name] = name
+ if name != 'can' and name != 'start':
+ query_pairs.append('%s=%s' % (name, 123))
+ path = '/dude/wheres/my/car'
+ mr = testing_helpers.MakeMonorailRequest(
+ path='%s?%s' % (path, '&'.join(query_pairs)))
+ query_pairs.append('can=yep')
+ query_pairs.append('start=486')
+ query_string = '&'.join(query_pairs)
+ expected = '%s?%s' % (path, query_string)
+ recognized_params = [(name, mr.GetParam(name)) for name in
+ framework_helpers.RECOGNIZED_PARAMS]
+ url = framework_helpers.FormatURL(
+ recognized_params, path, can='yep', start=486)
+ self.assertEqual(expected, url)
+
+ def testFormatURLWithKeywordArgsAndID(self):
+ params = {}
+ query_pairs = []
+ query_pairs.append('id=200') # id should be the first parameter.
+ for name in framework_helpers.RECOGNIZED_PARAMS:
+ params[name] = name
+ if name != 'can' and name != 'start':
+ query_pairs.append('%s=%s' % (name, 123))
+ path = '/dude/wheres/my/car'
+ mr = testing_helpers.MakeMonorailRequest(
+ path='%s?%s' % (path, '&'.join(query_pairs)))
+ query_pairs.append('can=yep')
+ query_pairs.append('start=486')
+ query_string = '&'.join(query_pairs)
+ expected = '%s?%s' % (path, query_string)
+ recognized_params = [(name, mr.GetParam(name)) for name in
+ framework_helpers.RECOGNIZED_PARAMS]
+ url = framework_helpers.FormatURL(
+ recognized_params, path, can='yep', start=486, id=200)
+ self.assertEqual(expected, url)
+
+ def testFormatURLWithStrangeParams(self):
+ mr = testing_helpers.MakeMonorailRequest(path='/foo?start=0')
+ recognized_params = [(name, mr.GetParam(name)) for name in
+ framework_helpers.RECOGNIZED_PARAMS]
+ url = framework_helpers.FormatURL(
+ recognized_params, '/foo',
+ r=0, path='/foo/bar', sketchy='/foo/ bar baz ')
+ self.assertEqual(
+ '/foo?start=0&path=/foo/bar&r=0&sketchy=/foo/%20bar%20baz%20',
+ url)
+
+ def testFormatAbsoluteURL(self):
+ _request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/some-path',
+ headers={'Host': 'www.test.com'})
+ self.assertEqual(
+ 'http://www.test.com/p/proj/some/path',
+ framework_helpers.FormatAbsoluteURL(mr, '/some/path'))
+
+ def testFormatAbsoluteURL_CommonRequestParams(self):
+ _request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/some-path?foo=bar&can=1',
+ headers={'Host': 'www.test.com'})
+ self.assertEqual(
+ 'http://www.test.com/p/proj/some/path?can=1',
+ framework_helpers.FormatAbsoluteURL(mr, '/some/path'))
+ self.assertEqual(
+ 'http://www.test.com/p/proj/some/path',
+ framework_helpers.FormatAbsoluteURL(
+ mr, '/some/path', copy_params=False))
+
+ def testFormatAbsoluteURL_NoProject(self):
+ path = '/some/path'
+ _request, mr = testing_helpers.GetRequestObjects(
+ headers={'Host': 'www.test.com'}, path=path)
+ url = framework_helpers.FormatAbsoluteURL(mr, path, include_project=False)
+ self.assertEqual(url, 'http://www.test.com/some/path')
+
+ def testGetHostPort_Local(self):
+ """We use testing-app.appspot.com when running locally."""
+ self.assertEqual('testing-app.appspot.com',
+ framework_helpers.GetHostPort())
+ self.assertEqual('testing-app.appspot.com',
+ framework_helpers.GetHostPort(project_name='proj'))
+
+ @mock.patch('settings.preferred_domains',
+ {'testing-app.appspot.com': 'example.com'})
+ def testGetHostPort_PreferredDomain(self):
+ """A prod server can have a preferred domain."""
+ self.assertEqual('example.com',
+ framework_helpers.GetHostPort())
+ self.assertEqual('example.com',
+ framework_helpers.GetHostPort(project_name='proj'))
+
+ @mock.patch('settings.branded_domains',
+ {'proj': 'branded.com', '*': 'unbranded.com'})
+ @mock.patch('settings.preferred_domains',
+ {'testing-app.appspot.com': 'example.com'})
+ def testGetHostPort_BrandedDomain(self):
+    """A prod server can serve a per-project branded domain."""
+ self.assertEqual('example.com',
+ framework_helpers.GetHostPort())
+ self.assertEqual('branded.com',
+ framework_helpers.GetHostPort(project_name='proj'))
+ self.assertEqual('unbranded.com',
+ framework_helpers.GetHostPort(project_name='other-proj'))
+
+ def testIssueCommentURL(self):
+ hostport = 'port.someplex.com'
+ proj = project_pb2.Project()
+ proj.project_name = 'proj'
+
+ url = 'https://port.someplex.com/p/proj/issues/detail?id=2'
+ actual_url = framework_helpers.IssueCommentURL(
+ hostport, proj, 2)
+ self.assertEqual(actual_url, url)
+
+ url = 'https://port.someplex.com/p/proj/issues/detail?id=2#c2'
+ actual_url = framework_helpers.IssueCommentURL(
+ hostport, proj, 2, seq_num=2)
+ self.assertEqual(actual_url, url)
+
+
+class WordWrapSuperLongLinesTest(unittest.TestCase):
+
+ def testEmptyLogMessage(self):
+ msg = ''
+ wrapped_msg = framework_helpers.WordWrapSuperLongLines(msg)
+ self.assertEqual(wrapped_msg, '')
+
+ def testShortLines(self):
+ msg = 'one\ntwo\nthree\n'
+ wrapped_msg = framework_helpers.WordWrapSuperLongLines(msg)
+ expected = 'one\ntwo\nthree\n'
+ self.assertEqual(wrapped_msg, expected)
+
+ def testOneLongLine(self):
+ msg = ('This is a super long line that just goes on and on '
+ 'and it seems like it will never stop because it is '
+ 'super long and it was entered by a user who had no '
+ 'familiarity with the return key.')
+ wrapped_msg = framework_helpers.WordWrapSuperLongLines(msg)
+ expected = ('This is a super long line that just goes on and on and it '
+ 'seems like it will never stop because it\n'
+ 'is super long and it was entered by a user who had no '
+ 'familiarity with the return key.')
+ self.assertEqual(wrapped_msg, expected)
+
+ msg2 = ('This is a super long line that just goes on and on '
+ 'and it seems like it will never stop because it is '
+ 'super long and it was entered by a user who had no '
+ 'familiarity with the return key. '
+ 'This is a super long line that just goes on and on '
+ 'and it seems like it will never stop because it is '
+ 'super long and it was entered by a user who had no '
+ 'familiarity with the return key.')
+ wrapped_msg2 = framework_helpers.WordWrapSuperLongLines(msg2)
+ expected2 = ('This is a super long line that just goes on and on and it '
+ 'seems like it will never stop because it\n'
+ 'is super long and it was entered by a user who had no '
+ 'familiarity with the return key. This is a\n'
+ 'super long line that just goes on and on and it seems like '
+ 'it will never stop because it is super\n'
+ 'long and it was entered by a user who had no familiarity '
+ 'with the return key.')
+ self.assertEqual(wrapped_msg2, expected2)
+
+ def testMixOfShortAndLong(self):
+ msg = ('[Author: mpcomplete]\n'
+ '\n'
+ # Description on one long line
+ 'Fix a memory leak in JsArray and JsObject for the IE and NPAPI '
+ 'ports. Each time you call GetElement* or GetProperty* to '
+ 'retrieve string or object token, the token would be leaked. '
+ 'I added a JsScopedToken to ensure that the right thing is '
+ 'done when the object leaves scope, depending on the platform.\n'
+ '\n'
+ 'R=zork\n'
+ 'CC=google-gears-eng@googlegroups.com\n'
+ 'DELTA=108 (52 added, 36 deleted, 20 changed)\n'
+ 'OCL=5932446\n'
+ 'SCL=5933728\n')
+ wrapped_msg = framework_helpers.WordWrapSuperLongLines(msg)
+ expected = (
+ '[Author: mpcomplete]\n'
+ '\n'
+ 'Fix a memory leak in JsArray and JsObject for the IE and NPAPI '
+ 'ports. Each time you call\n'
+ 'GetElement* or GetProperty* to retrieve string or object token, the '
+ 'token would be leaked. I added\n'
+ 'a JsScopedToken to ensure that the right thing is done when the '
+ 'object leaves scope, depending on\n'
+ 'the platform.\n'
+ '\n'
+ 'R=zork\n'
+ 'CC=google-gears-eng@googlegroups.com\n'
+ 'DELTA=108 (52 added, 36 deleted, 20 changed)\n'
+ 'OCL=5932446\n'
+ 'SCL=5933728\n')
+ self.assertEqual(wrapped_msg, expected)
+
+
+class ComputeListDeltasTest(unittest.TestCase):
+
+ def DoOne(self, old=None, new=None, added=None, removed=None):
+ """Run one call to the target method and check expected results."""
+ actual_added, actual_removed = framework_helpers.ComputeListDeltas(
+ old, new)
+ self.assertItemsEqual(added, actual_added)
+ self.assertItemsEqual(removed, actual_removed)
+
+ def testEmptyLists(self):
+ self.DoOne(old=[], new=[], added=[], removed=[])
+ self.DoOne(old=[1, 2], new=[], added=[], removed=[1, 2])
+ self.DoOne(old=[], new=[1, 2], added=[1, 2], removed=[])
+
+ def testUnchanged(self):
+ self.DoOne(old=[1], new=[1], added=[], removed=[])
+ self.DoOne(old=[1, 2], new=[1, 2], added=[], removed=[])
+ self.DoOne(old=[1, 2], new=[2, 1], added=[], removed=[])
+
+ def testCompleteChange(self):
+ self.DoOne(old=[1, 2], new=[3, 4], added=[3, 4], removed=[1, 2])
+
+ def testGeneralChange(self):
+ self.DoOne(old=[1, 2], new=[2], added=[], removed=[1])
+ self.DoOne(old=[1], new=[1, 2], added=[2], removed=[])
+ self.DoOne(old=[1, 2], new=[2, 3], added=[3], removed=[1])
+
+
+class UserSettingsTest(unittest.TestCase):
+
+ def setUp(self):
+ self.mr = testing_helpers.MakeMonorailRequest()
+ self.cnxn = 'cnxn'
+ self.services = service_manager.Services(
+ user=fake.UserService(),
+ usergroup=fake.UserGroupService())
+
+ def testGatherUnifiedSettingsPageData(self):
+ mr = self.mr
+ mr.auth.user_view = framework_views.StuffUserView(100, 'user@invalid', True)
+ mr.auth.user_view.profile_url = '/u/profile/url'
+ userprefs = user_pb2.UserPrefs(
+ prefs=[user_pb2.UserPrefValue(name='public_issue_notice', value='true')])
+ page_data = framework_helpers.UserSettings.GatherUnifiedSettingsPageData(
+ mr.auth.user_id, mr.auth.user_view, mr.auth.user_pb, userprefs)
+
+ expected_keys = [
+ 'settings_user',
+ 'settings_user_pb',
+ 'settings_user_is_banned',
+ 'self',
+ 'profile_url_fragment',
+ 'preview_on_hover',
+ 'settings_user_prefs',
+ ]
+ self.assertItemsEqual(expected_keys, list(page_data.keys()))
+
+ self.assertEqual('profile/url', page_data['profile_url_fragment'])
+ self.assertTrue(page_data['settings_user_prefs'].public_issue_notice)
+ self.assertFalse(page_data['settings_user_prefs'].restrict_new_issues)
+
+ def testGatherUnifiedSettingsPageData_NoUserPrefs(self):
+ """If UserPrefs were not loaded, consider them all false."""
+ mr = self.mr
+ mr.auth.user_view = framework_views.StuffUserView(100, 'user@invalid', True)
+ userprefs = None
+
+ page_data = framework_helpers.UserSettings.GatherUnifiedSettingsPageData(
+ mr.auth.user_id, mr.auth.user_view, mr.auth.user_pb, userprefs)
+
+ self.assertFalse(page_data['settings_user_prefs'].public_issue_notice)
+ self.assertFalse(page_data['settings_user_prefs'].restrict_new_issues)
+
+ def testProcessBanForm(self):
+ """We can ban and unban users."""
+ user = self.services.user.TestAddUser('one@example.com', 111)
+ post_data = {'banned': 1, 'banned_reason': 'rude'}
+ framework_helpers.UserSettings.ProcessBanForm(
+ self.cnxn, self.services.user, post_data, 111, user)
+ self.assertEqual('rude', user.banned)
+
+ post_data = {} # not banned
+ framework_helpers.UserSettings.ProcessBanForm(
+ self.cnxn, self.services.user, post_data, 111, user)
+ self.assertEqual('', user.banned)
+
+ def testProcessSettingsForm_OldStylePrefs(self):
+ """We can set prefs that are stored in the User PB."""
+ user = self.services.user.TestAddUser('one@example.com', 111)
+ post_data = {'obscure_email': 1, 'notify': 1}
+ with work_env.WorkEnv(self.mr, self.services) as we:
+ framework_helpers.UserSettings.ProcessSettingsForm(
+ we, post_data, user)
+
+ self.assertTrue(user.obscure_email)
+ self.assertTrue(user.notify_issue_change)
+ self.assertFalse(user.notify_starred_ping)
+
+ def testProcessSettingsForm_NewStylePrefs(self):
+ """We can set prefs that are stored in the UserPrefs PB."""
+ user = self.services.user.TestAddUser('one@example.com', 111)
+ post_data = {'restrict_new_issues': 1}
+ with work_env.WorkEnv(self.mr, self.services) as we:
+ framework_helpers.UserSettings.ProcessSettingsForm(
+ we, post_data, user)
+ userprefs = we.GetUserPrefs(111)
+
+ actual = {upv.name: upv.value
+ for upv in userprefs.prefs}
+ expected = {
+ 'restrict_new_issues': 'true',
+ 'public_issue_notice': 'false',
+ }
+ self.assertEqual(expected, actual)
+
+
+class MurmurHash3Test(unittest.TestCase):
+
+ def testMurmurHash(self):
+ test_data = [
+ ('', 0),
+ ('agable@chromium.org', 4092810879),
+ (u'jrobbins@chromium.org', 904770043),
+ ('seanmccullough%google.com@gtempaccount.com', 1301269279),
+ ('rmistry+monorail@chromium.org', 4186878788),
+ ('jparent+foo@', 2923900874),
+ ('@example.com', 3043483168),
+ ]
+ hashes = [framework_helpers.MurmurHash3_x86_32(x)
+ for (x, _) in test_data]
+ self.assertListEqual(hashes, [e for (_, e) in test_data])
+
+ def testMurmurHashWithSeed(self):
+ test_data = [
+ ('', 1113155926, 2270882445),
+ ('agable@chromium.org', 772936925, 3995066671),
+ (u'jrobbins@chromium.org', 1519359761, 1273489513),
+ ('seanmccullough%google.com@gtempaccount.com', 49913829, 1202521153),
+ ('rmistry+monorail@chromium.org', 314860298, 3636123309),
+ ('jparent+foo@', 195791379, 332453977),
+ ('@example.com', 521490555, 257496459),
+ ]
+ hashes = [framework_helpers.MurmurHash3_x86_32(x, s)
+ for (x, s, _) in test_data]
+ self.assertListEqual(hashes, [e for (_, _, e) in test_data])
+
+
+class MakeRandomKeyTest(unittest.TestCase):
+
+ def testMakeRandomKey_Normal(self):
+ key1 = framework_helpers.MakeRandomKey()
+ key2 = framework_helpers.MakeRandomKey()
+ self.assertEqual(128, len(key1))
+ self.assertEqual(128, len(key2))
+ self.assertNotEqual(key1, key2)
+
+ def testMakeRandomKey_Length(self):
+ key = framework_helpers.MakeRandomKey()
+ self.assertEqual(128, len(key))
+ key16 = framework_helpers.MakeRandomKey(length=16)
+ self.assertEqual(16, len(key16))
+
+ def testMakeRandomKey_Chars(self):
+ key = framework_helpers.MakeRandomKey(chars='a', length=4)
+ self.assertEqual('aaaa', key)
+
+
+class IsServiceAccountTest(unittest.TestCase):
+
+ def testIsServiceAccount(self):
+ appspot = 'abc@appspot.gserviceaccount.com'
+ developer = '@developer.gserviceaccount.com'
+ bugdroid = 'bugdroid1@chromium.org'
+ user = 'test@example.com'
+
+ self.assertTrue(framework_helpers.IsServiceAccount(appspot))
+ self.assertTrue(framework_helpers.IsServiceAccount(developer))
+ self.assertTrue(framework_helpers.IsServiceAccount(bugdroid))
+ self.assertFalse(framework_helpers.IsServiceAccount(user))
+
+ client_emails = set([appspot, developer, bugdroid])
+ self.assertTrue(framework_helpers.IsServiceAccount(
+ appspot, client_emails=client_emails))
+ self.assertTrue(framework_helpers.IsServiceAccount(
+ developer, client_emails=client_emails))
+ self.assertTrue(framework_helpers.IsServiceAccount(
+ bugdroid, client_emails=client_emails))
+ self.assertFalse(framework_helpers.IsServiceAccount(
+ user, client_emails=client_emails))
diff --git a/framework/test/framework_views_test.py b/framework/test/framework_views_test.py
new file mode 100644
index 0000000..57f9fd1
--- /dev/null
+++ b/framework/test/framework_views_test.py
@@ -0,0 +1,326 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for framework_views classes."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import time
+import unittest
+
+from framework import framework_constants
+from framework import framework_views
+from framework import monorailrequest
+from proto import project_pb2
+from proto import tracker_pb2
+from proto import user_pb2
+import settings
+from services import service_manager
+from testing import fake
+
+
+LONG_STR = 'VeryLongStringThatCertainlyWillNotFit'
+LONG_PART_STR = 'OnePartThatWillNotFit-OneShort'
+
+
class LabelViewTest(unittest.TestCase):
  """Tests for framework_views.LabelView."""

  def testLabelView(self):
    # An empty label produces an empty view.
    label_view = framework_views.LabelView('', None)
    self.assertEqual('', label_view.name)

    # A well-formed Prefix-Value label is split into its two parts.
    label_view = framework_views.LabelView('Priority-High', None)
    self.assertEqual('Priority-High', label_view.name)
    self.assertIsNone(label_view.is_restrict)
    self.assertEqual('', label_view.docstring)
    self.assertEqual('Priority', label_view.prefix)
    self.assertEqual('High', label_view.value)

    # Very long prefixes and values are still split on the dash.
    long_label = '%s-%s' % (LONG_STR, LONG_STR)
    label_view = framework_views.LabelView(long_label, None)
    self.assertEqual(long_label, label_view.name)
    self.assertEqual('', label_view.docstring)
    self.assertEqual(LONG_STR, label_view.prefix)
    self.assertEqual(LONG_STR, label_view.value)

    label_view = framework_views.LabelView(LONG_PART_STR, None)
    self.assertEqual(LONG_PART_STR, label_view.name)
    self.assertEqual('', label_view.docstring)
    self.assertEqual('OnePartThatWillNotFit', label_view.prefix)
    self.assertEqual('OneShort', label_view.value)

    # A config supplies docstrings for well-known labels.
    config = tracker_pb2.ProjectIssueConfig()
    config.well_known_labels.append(tracker_pb2.LabelDef(
        label='Priority-High', label_docstring='Must ship in this milestone'))

    label_view = framework_views.LabelView('Priority-High', config)
    self.assertEqual('Must ship in this milestone', label_view.docstring)

    # Labels not in the config get no docstring.
    label_view = framework_views.LabelView('Priority-Foo', config)
    self.assertEqual('', label_view.docstring)

    # Restrict-* labels are flagged as restriction labels.
    label_view = framework_views.LabelView('Restrict-View-Commit', None)
    self.assertTrue(label_view.is_restrict)
+
+
class StatusViewTest(unittest.TestCase):
  """Tests for framework_views.StatusView."""

  def testStatusView(self):
    # An empty status produces an empty view.
    status_view = framework_views.StatusView('', None)
    self.assertEqual('', status_view.name)

    # Statuses unknown to the config default to open with no docstring.
    status_view = framework_views.StatusView('Accepted', None)
    self.assertEqual('Accepted', status_view.name)
    self.assertEqual('', status_view.docstring)
    self.assertEqual('yes', status_view.means_open)

    status_view = framework_views.StatusView(LONG_STR, None)
    self.assertEqual(LONG_STR, status_view.name)
    self.assertEqual('', status_view.docstring)
    self.assertEqual('yes', status_view.means_open)

    # A config supplies docstrings and open/closed semantics for
    # well-known statuses.
    config = tracker_pb2.ProjectIssueConfig()
    config.well_known_statuses.append(tracker_pb2.StatusDef(
        status='SlamDunk', status_docstring='Code fixed and taught a lesson',
        means_open=False))

    status_view = framework_views.StatusView('SlamDunk', config)
    self.assertEqual('Code fixed and taught a lesson', status_view.docstring)
    self.assertFalse(status_view.means_open)

    status_view = framework_views.StatusView('SlammedBack', config)
    self.assertEqual('', status_view.docstring)
+
+
class UserViewTest(unittest.TestCase):
  """Tests for the availability message and state of framework_views.UserView.

  Fixes: the original test method names had typos ('Availablity',
  'Visitied'); unittest discovers any 'test*' method, so renaming is
  discovery-compatible. Also uses the more specific assertIsNone /
  assertGreaterEqual / assertLess for clearer failure messages.
  """

  def setUp(self):
    self.user = user_pb2.User(user_id=111)

  def testGetAvailability_Anon(self):
    """Anonymous users (user_id 0) get no availability message or state."""
    self.user.user_id = 0
    user_view = framework_views.UserView(self.user)
    self.assertIsNone(user_view.avail_message)
    self.assertIsNone(user_view.avail_state)

  def testGetAvailability_Banned(self):
    """Banned users are reported as 'Banned'."""
    self.user.banned = 'spamming'
    user_view = framework_views.UserView(self.user)
    self.assertEqual('Banned', user_view.avail_message)
    self.assertEqual('banned', user_view.avail_state)

  def testGetAvailability_Vacation(self):
    """A vacation message is surfaced; long messages get a short form."""
    self.user.vacation_message = 'gone fishing'
    user_view = framework_views.UserView(self.user)
    self.assertEqual('gone fishing', user_view.avail_message)
    self.assertEqual('none', user_view.avail_state)

    self.user.vacation_message = (
        'Gone fishing as really long time with lots of friends and reading '
        'a long novel by a famous author. I wont have internet access but '
        'If you urgently need anything you can call Alice or Bob for most '
        'things otherwise call Charlie. Wish me luck! ')
    user_view = framework_views.UserView(self.user)
    # Long messages keep the full text but also get a <50 char short form.
    self.assertGreaterEqual(len(user_view.avail_message), 50)
    self.assertLess(len(user_view.avail_message_short), 50)
    self.assertEqual('none', user_view.avail_state)

  def testGetAvailability_Bouncing(self):
    """Users whose email bounced are flagged; short form equals long form."""
    self.user.email_bounce_timestamp = 1234567890
    user_view = framework_views.UserView(self.user)
    self.assertEqual('Email to this user bounced', user_view.avail_message)
    self.assertEqual(user_view.avail_message_short, user_view.avail_message)
    self.assertEqual('none', user_view.avail_state)

  def testGetAvailability_Groups(self):
    """User groups get no availability message or state."""
    user_view = framework_views.UserView(self.user, is_group=True)
    self.assertIsNone(user_view.avail_message)
    self.assertIsNone(user_view.avail_state)

    # An address that looks like a mailing list is also treated as a group.
    self.user.email = 'likely-user-group@example.com'
    user_view = framework_views.UserView(self.user)
    self.assertIsNone(user_view.avail_message)
    self.assertIsNone(user_view.avail_state)

  def testGetAvailability_NeverVisited(self):
    """A last_visit_timestamp of 0 means the user never visited."""
    self.user.last_visit_timestamp = 0
    user_view = framework_views.UserView(self.user)
    self.assertEqual('User never visited', user_view.avail_message)
    self.assertEqual('never', user_view.avail_state)

  def testGetAvailability_NotRecent(self):
    """Visits within the reporting window are shown in days."""
    now = int(time.time())
    self.user.last_visit_timestamp = now - 20 * framework_constants.SECS_PER_DAY
    user_view = framework_views.UserView(self.user)
    self.assertEqual('Last visit 20 days ago', user_view.avail_message)
    self.assertEqual('unsure', user_view.avail_state)

  def testGetAvailability_ReallyLongTime(self):
    """Visits older than 30 days are reported as '> 30 days ago'."""
    now = int(time.time())
    self.user.last_visit_timestamp = now - 99 * framework_constants.SECS_PER_DAY
    user_view = framework_views.UserView(self.user)
    self.assertEqual('Last visit > 30 days ago', user_view.avail_message)
    self.assertEqual('none', user_view.avail_state)

  def testDeletedUser(self):
    """A deleted user gets the placeholder name and empty email/URL fields."""
    deleted_user = user_pb2.User(user_id=1)
    user_view = framework_views.UserView(deleted_user)
    self.assertEqual(
        user_view.display_name, framework_constants.DELETED_USER_NAME)
    self.assertEqual(user_view.email, '')
    self.assertEqual(user_view.obscure_email, '')
    self.assertEqual(user_view.profile_url, '')
+
class RevealEmailsToMembersTest(unittest.TestCase):
  """Tests for framework_views.RevealAllEmailsToMembers.

  The viewed user's email should be unobscured only for requesters who
  are project members, the viewed user themself, or site admins.
  """

  def setUp(self):
    self.cnxn = fake.MonorailConnection()
    self.services = service_manager.Services(
        project=fake.ProjectService(),
        user=fake.UserService(),
        usergroup=fake.UserGroupService())
    # Project 789 has owner 111, committer 222, and contributors 333, 888.
    self.mr = monorailrequest.MonorailRequest(None)
    self.mr.project = self.services.project.TestAddProject(
        'proj',
        project_id=789,
        owner_ids=[111],
        committer_ids=[222],
        contrib_ids=[333, 888])
    user = self.services.user.TestAddUser('test@example.com', 1000)
    self.mr.auth.user_pb = user

  def CheckRevealAllToMember(
      self, logged_in_user_id, expected, viewed_user_id=333, group_id=None):
    """Assert whether user 333's email is revealed to logged_in_user_id.

    Args:
      logged_in_user_id: user ID to simulate as the requester.
      expected: True if the email should be revealed after the call.
      viewed_user_id: ID used to build the viewed user's profile URL.
      group_id: currently unused until group support is re-implemented.
    """
    user_view = framework_views.StuffUserView(
        viewed_user_id, 'user@example.com', True)

    if group_id:
      pass  # xxx re-implement groups

    users_by_id = {333: user_view}
    self.mr.auth.user_id = logged_in_user_id
    self.mr.auth.effective_ids = {logged_in_user_id}
    # Assert display name is obscured before the reveal.
    self.assertEqual('u...@example.com', user_view.display_name)
    # Assert profile url contains user ID before the reveal.
    self.assertEqual('/u/%s/' % viewed_user_id, user_view.profile_url)
    framework_views.RevealAllEmailsToMembers(
        self.cnxn, self.services, self.mr.auth, users_by_id)
    self.assertEqual(expected, not user_view.obscure_email)
    if expected:
      # Assert display name is now revealed.
      self.assertEqual('user@example.com', user_view.display_name)
      # Assert profile url contains the email.
      self.assertEqual('/u/user@example.com/', user_view.profile_url)
    else:
      # Assert display name is still hidden.
      self.assertEqual('u...@example.com', user_view.display_name)
      # Assert profile url still contains user ID.
      self.assertEqual('/u/%s/' % viewed_user_id, user_view.profile_url)

  # TODO(https://crbug.com/monorail/8192): Remove this method and related test.
  def DeprecatedCheckRevealAllToMember(
      self, logged_in_user_id, expected, viewed_user_id=333, group_id=None):
    """Like CheckRevealAllToMember, but passes the deprecated project arg."""
    user_view = framework_views.StuffUserView(
        viewed_user_id, 'user@example.com', True)

    if group_id:
      pass  # xxx re-implement groups

    users_by_id = {333: user_view}
    self.mr.auth.user_id = logged_in_user_id
    self.mr.auth.effective_ids = {logged_in_user_id}
    # Assert display name is obscured before the reveal.
    self.assertEqual('u...@example.com', user_view.display_name)
    # Assert profile url contains user ID before the reveal.
    self.assertEqual('/u/%s/' % viewed_user_id, user_view.profile_url)
    # Deprecated call signature: an explicit project is passed in.
    framework_views.RevealAllEmailsToMembers(
        self.cnxn, self.services, self.mr.auth, users_by_id, self.mr.project)
    self.assertEqual(expected, not user_view.obscure_email)
    if expected:
      # Assert display name is now revealed.
      self.assertEqual('user@example.com', user_view.display_name)
      # Assert profile url contains the email.
      self.assertEqual('/u/user@example.com/', user_view.profile_url)
    else:
      # Assert display name is still hidden.
      self.assertEqual('u...@example.com', user_view.display_name)
      # Assert profile url still contains user ID.
      self.assertEqual('/u/%s/' % viewed_user_id, user_view.profile_url)

  def testDontRevealEmailsToPriviledgedDomain(self):
    """We no longer give this advantage based on email address domain."""
    # NOTE: 'priviledged' spelling matches the settings attribute name.
    for priviledged_user_domain in settings.priviledged_user_domains:
      self.mr.auth.user_pb.email = 'test@' + priviledged_user_domain
      self.CheckRevealAllToMember(100001, False)

  def testRevealEmailToSelf(self):
    """A user can always see their own email unobscured."""
    logged_in_user = self.services.user.TestAddUser('user@example.com', 333)
    self.mr.auth.user_pb = logged_in_user
    self.CheckRevealAllToMember(333, True)

  def testRevealAllEmailsToMembers_Collaborators(self):
    """Project members see emails; anonymous and non-members do not."""
    self.CheckRevealAllToMember(0, False)
    self.CheckRevealAllToMember(111, True)
    self.CheckRevealAllToMember(222, True)
    self.CheckRevealAllToMember(333, True)
    self.CheckRevealAllToMember(444, False)

    # Viewed user has indirect role in the project via a group.
    self.CheckRevealAllToMember(0, False, group_id=888)
    self.CheckRevealAllToMember(111, True, group_id=888)
    # xxx re-implement
    # self.CheckRevealAllToMember(
    #     111, True, viewed_user_id=444, group_id=888)

    # Logged in user has indirect role in the project via a group.
    self.CheckRevealAllToMember(888, True)

  def testDeprecatedRevealAllEmailsToMembers_Collaborators(self):
    """The deprecated code path gives the same answers for members."""
    self.DeprecatedCheckRevealAllToMember(0, False)
    self.DeprecatedCheckRevealAllToMember(111, True)
    self.DeprecatedCheckRevealAllToMember(222, True)
    self.DeprecatedCheckRevealAllToMember(333, True)
    self.DeprecatedCheckRevealAllToMember(444, False)

    # Viewed user has indirect role in the project via a group.
    self.DeprecatedCheckRevealAllToMember(0, False, group_id=888)
    self.DeprecatedCheckRevealAllToMember(111, True, group_id=888)

    # Logged in user has indirect role in the project via a group.
    self.DeprecatedCheckRevealAllToMember(888, True)

  def testRevealAllEmailsToMembers_Admins(self):
    """Site admins see emails even without a project role."""
    self.CheckRevealAllToMember(555, False)
    self.mr.auth.user_pb.is_site_admin = True
    self.CheckRevealAllToMember(555, True)
+
+
class RevealAllEmailsTest(unittest.TestCase):
  """Tests for framework_views.RevealAllEmails."""

  def testRevealAllEmail(self):
    emails_by_id = {
        111: 'a@a.com',
        222: 'b@b.com',
        333: 'c@c.com',
        999: 'z@z.com',
    }
    users_by_id = {
        user_id: framework_views.StuffUserView(user_id, email, True)
        for user_id, email in emails_by_id.items()}

    # Before the reveal, every display name is obscured.
    expected_obscured = {
        111: 'a...@a.com',
        222: 'b...@b.com',
        333: 'c...@c.com',
        999: 'z...@z.com',
    }
    for user_id, obscured_name in expected_obscured.items():
      self.assertEqual(obscured_name, users_by_id[user_id].display_name)

    framework_views.RevealAllEmails(users_by_id)

    # After the reveal, obscuring is off and full emails are shown.
    for user_id, email in emails_by_id.items():
      revealed_view = users_by_id[user_id]
      self.assertFalse(revealed_view.obscure_email)
      self.assertEqual(email, revealed_view.display_name)
diff --git a/framework/test/gcs_helpers_test.py b/framework/test/gcs_helpers_test.py
new file mode 100644
index 0000000..3500e40
--- /dev/null
+++ b/framework/test/gcs_helpers_test.py
@@ -0,0 +1,185 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for the framework_helpers module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import mock
+import unittest
+import uuid
+
+import mox
+
+from google.appengine.api import app_identity
+from google.appengine.api import images
+from google.appengine.api import urlfetch
+from google.appengine.ext import testbed
+from third_party import cloudstorage
+
+from framework import filecontent
+from framework import gcs_helpers
+from testing import fake
+from testing import testing_helpers
+
+
class GcsHelpersTest(unittest.TestCase):
  """Tests for gcs_helpers, stubbing GCS and App Engine APIs with mox."""

  def setUp(self):
    self.mox = mox.Mox()
    # Activate an App Engine testbed with a memcache stub for use by the
    # code under test.
    self.testbed = testbed.Testbed()
    self.testbed.activate()
    self.testbed.init_memcache_stub()

  def tearDown(self):
    self.mox.UnsetStubs()
    self.mox.ResetAll()
    self.testbed.deactivate()

  def testDeleteObjectFromGCS(self):
    """DeleteObjectFromGCS composes the bucket path and deletes the object."""
    object_id = 'aaaaa'
    bucket_name = 'test_bucket'
    object_path = '/' + bucket_name + object_id

    self.mox.StubOutWithMock(app_identity, 'get_default_gcs_bucket_name')
    app_identity.get_default_gcs_bucket_name().AndReturn(bucket_name)

    # Expect exactly one delete call with the composed path.
    self.mox.StubOutWithMock(cloudstorage, 'delete')
    cloudstorage.delete(object_path)

    self.mox.ReplayAll()

    gcs_helpers.DeleteObjectFromGCS(object_id)
    self.mox.VerifyAll()

  def testStoreObjectInGCS_ResizableMimeType(self):
    """Resizable images are stored along with a generated thumbnail."""
    guid = 'aaaaa'
    project_id = 100
    object_id = '/%s/attachments/%s' % (project_id, guid)
    bucket_name = 'test_bucket'
    object_path = '/' + bucket_name + object_id
    mime_type = 'image/png'
    content = 'content'
    thumb_content = 'thumb_content'

    self.mox.StubOutWithMock(app_identity, 'get_default_gcs_bucket_name')
    app_identity.get_default_gcs_bucket_name().AndReturn(bucket_name)

    # The object name includes a uuid4; pin it so the path is predictable.
    self.mox.StubOutWithMock(uuid, 'uuid4')
    uuid.uuid4().AndReturn(guid)

    # Expect writes of both the full image and its thumbnail.
    self.mox.StubOutWithMock(cloudstorage, 'open')
    cloudstorage.open(
        object_path, 'w', mime_type, options={}
    ).AndReturn(fake.FakeFile())
    cloudstorage.open(object_path + '-thumbnail', 'w', mime_type).AndReturn(
        fake.FakeFile())

    self.mox.StubOutWithMock(images, 'resize')
    images.resize(content, gcs_helpers.DEFAULT_THUMB_WIDTH,
        gcs_helpers.DEFAULT_THUMB_HEIGHT).AndReturn(thumb_content)

    self.mox.ReplayAll()

    ret_id = gcs_helpers.StoreObjectInGCS(
        content, mime_type, project_id, gcs_helpers.DEFAULT_THUMB_WIDTH,
        gcs_helpers.DEFAULT_THUMB_HEIGHT)
    self.mox.VerifyAll()
    self.assertEqual(object_id, ret_id)

  def testStoreObjectInGCS_NotResizableMimeType(self):
    """Non-resizable types get a Content-Disposition header, no thumbnail."""
    guid = 'aaaaa'
    project_id = 100
    object_id = '/%s/attachments/%s' % (project_id, guid)
    bucket_name = 'test_bucket'
    object_path = '/' + bucket_name + object_id
    mime_type = 'not_resizable_mime_type'
    content = 'content'

    self.mox.StubOutWithMock(app_identity, 'get_default_gcs_bucket_name')
    app_identity.get_default_gcs_bucket_name().AndReturn(bucket_name)

    self.mox.StubOutWithMock(uuid, 'uuid4')
    uuid.uuid4().AndReturn(guid)

    # Only the original object is written; the filename is carried in the
    # Content-Disposition option.
    self.mox.StubOutWithMock(cloudstorage, 'open')
    options = {'Content-Disposition': 'inline; filename="file.ext"'}
    cloudstorage.open(
        object_path, 'w', mime_type, options=options
    ).AndReturn(fake.FakeFile())

    self.mox.ReplayAll()

    ret_id = gcs_helpers.StoreObjectInGCS(
        content, mime_type, project_id, gcs_helpers.DEFAULT_THUMB_WIDTH,
        gcs_helpers.DEFAULT_THUMB_HEIGHT, filename='file.ext')
    self.mox.VerifyAll()
    self.assertEqual(object_id, ret_id)

  def testCheckMemeTypeResizable(self):
    """Only mime types in RESIZABLE_MIME_TYPES pass; others raise."""
    for resizable_mime_type in gcs_helpers.RESIZABLE_MIME_TYPES:
      gcs_helpers.CheckMimeTypeResizable(resizable_mime_type)

    with self.assertRaises(gcs_helpers.UnsupportedMimeType):
      gcs_helpers.CheckMimeTypeResizable('not_resizable_mime_type')

  def testStoreLogoInGCS(self):
    """StoreLogoInGCS guesses the mime type and uses logo thumbnail sizes."""
    file_name = 'test_file.png'
    mime_type = 'image/png'
    content = 'test content'
    project_id = 100
    object_id = 123

    self.mox.StubOutWithMock(filecontent, 'GuessContentTypeFromFilename')
    filecontent.GuessContentTypeFromFilename(file_name).AndReturn(mime_type)

    self.mox.StubOutWithMock(gcs_helpers, 'StoreObjectInGCS')
    gcs_helpers.StoreObjectInGCS(
        content, mime_type, project_id,
        thumb_width=gcs_helpers.LOGO_THUMB_WIDTH,
        thumb_height=gcs_helpers.LOGO_THUMB_HEIGHT).AndReturn(object_id)

    self.mox.ReplayAll()

    ret_id = gcs_helpers.StoreLogoInGCS(file_name, content, project_id)
    self.mox.VerifyAll()
    self.assertEqual(object_id, ret_id)

  @mock.patch('google.appengine.api.urlfetch.fetch')
  def testFetchSignedURL_Success(self, mock_fetch):
    """_FetchSignedURL returns the Location header without redirecting."""
    mock_fetch.return_value = testing_helpers.Blank(
        headers={'Location': 'signed url'})
    actual = gcs_helpers._FetchSignedURL('signing req url')
    mock_fetch.assert_called_with('signing req url', follow_redirects=False)
    self.assertEqual('signed url', actual)

  @mock.patch('google.appengine.api.urlfetch.fetch')
  def testFetchSignedURL_UnderpopulatedResult(self, mock_fetch):
    """A response with no Location header raises KeyError."""
    mock_fetch.return_value = testing_helpers.Blank(headers={})
    self.assertRaises(
        KeyError, gcs_helpers._FetchSignedURL, 'signing req url')

  @mock.patch('google.appengine.api.urlfetch.fetch')
  def testFetchSignedURL_DownloadError(self, mock_fetch):
    """urlfetch.DownloadError propagates out of _FetchSignedURL."""
    mock_fetch.side_effect = urlfetch.DownloadError
    self.assertRaises(
        urlfetch.DownloadError,
        gcs_helpers._FetchSignedURL, 'signing req url')

  @mock.patch('framework.gcs_helpers._FetchSignedURL')
  def testSignUrl_Success(self, mock_FetchSignedURL):
    """SignUrl returns the URL produced by _FetchSignedURL."""
    with mock.patch(
        'google.appengine.api.app_identity.get_access_token') as gat:
      gat.return_value = ['token']
      mock_FetchSignedURL.return_value = 'signed url'
      signed_url = gcs_helpers.SignUrl('bucket', '/object')
      self.assertEqual('signed url', signed_url)

  @mock.patch('framework.gcs_helpers._FetchSignedURL')
  def testSignUrl_DownloadError(self, mock_FetchSignedURL):
    """SignUrl falls back to a placeholder path on download errors."""
    mock_FetchSignedURL.side_effect = urlfetch.DownloadError
    self.assertEqual(
        '/missing-gcs-url', gcs_helpers.SignUrl('bucket', '/object'))
diff --git a/framework/test/grid_view_helpers_test.py b/framework/test/grid_view_helpers_test.py
new file mode 100644
index 0000000..df3ecc6
--- /dev/null
+++ b/framework/test/grid_view_helpers_test.py
@@ -0,0 +1,201 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for grid_view_helpers classes and functions."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+from framework import framework_constants
+from framework import framework_views
+from framework import grid_view_helpers
+from proto import tracker_pb2
+from testing import fake
+from tracker import tracker_bizobj
+
+
class GridViewHelpersTest(unittest.TestCase):
  """Tests for grid_view_helpers heading sorting and artifact attributes."""

  def setUp(self):
    self.default_cols = 'a b c'
    self.builtin_cols = 'a b x y z'
    self.config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)

    # art1 carries only derived (rule-set) values; art2 carries explicit ones.
    self.art1 = fake.MakeTestIssue(
        789, 1, 'a summary', '', 0, derived_owner_id=111, star_count=12,
        derived_labels='Priority-Medium Hot Mstone-1 Mstone-2',
        derived_status='Overdue')
    self.art2 = fake.MakeTestIssue(
        789, 1, 'a summary', 'New', 111, star_count=12, merged_into=200001,
        labels='Priority-Medium Type-DEFECT Hot Mstone-1 Mstone-2')
    self.users_by_id = {
        111: framework_views.StuffUserView(111, 'foo@example.com', True),
        }

  def testSortGridHeadings(self):
    """Headings sort in config-defined order; unknown values sort after."""
    config = fake.MakeTestConfig(
        789, labels=('Priority-High Priority-Medium Priority-Low Hot Cold '
                     'Milestone-Near Milestone-Far '
                     'Day-Sun Day-Mon Day-Tue Day-Wed Day-Thu Day-Fri Day-Sat'),
        statuses=('New Accepted Started Fixed WontFix Invalid Duplicate'))
    config.field_defs = [
        tracker_pb2.FieldDef(field_id=1, project_id=789, field_name='Day',
                             field_type=tracker_pb2.FieldTypes.ENUM_TYPE)]
    # These accessors should never be invoked by SortGridHeadings; strings
    # stand in for real functions so any call would fail loudly.
    asc_accessors = {
        'id': 'some function that is not called',
        'reporter': 'some function that is not called',
        'opened': 'some function that is not called',
        'modified': 'some function that is not called',
    }

    # Verify that status headings are sorted according to the status
    # values defined in the config.
    col_name = 'status'
    headings = ['Duplicate', 'Limbo', 'New', 'OnHold', 'Accepted', 'Fixed']
    sorted_headings = grid_view_helpers.SortGridHeadings(
        col_name, headings, self.users_by_id, config, asc_accessors)
    self.assertEqual(
        sorted_headings,
        ['New', 'Accepted', 'Fixed', 'Duplicate', 'Limbo', 'OnHold'])

    # Verify that special columns are sorted alphabetically or numerically.
    col_name = 'id'
    headings = [1, 2, 5, 3, 4]
    sorted_headings = grid_view_helpers.SortGridHeadings(
        col_name, headings, self.users_by_id, config, asc_accessors)
    self.assertEqual(sorted_headings,
                     [1, 2, 3, 4, 5])

    # Verify that label value headings are sorted according to the labels
    # values defined in the config.
    col_name = 'priority'
    headings = ['Medium', 'High', 'Low', 'dont-care']
    sorted_headings = grid_view_helpers.SortGridHeadings(
        col_name, headings, self.users_by_id, config, asc_accessors)
    self.assertEqual(sorted_headings,
                     ['High', 'Medium', 'Low', 'dont-care'])

    # Verify that enum headings are sorted according to the labels
    # values defined in the config.
    col_name = 'day'
    headings = ['Tue', 'Fri', 'Sun', 'Dogday', 'Wed', 'Caturday', 'Low']
    sorted_headings = grid_view_helpers.SortGridHeadings(
        col_name, headings, self.users_by_id, config, asc_accessors)
    self.assertEqual(sorted_headings,
                     ['Sun', 'Tue', 'Wed', 'Fri',
                      'Caturday', 'Dogday', 'Low'])

  def testGetArtifactAttr_Explicit(self):
    """GetArtifactAttr returns the explicit field values of an issue."""
    label_values = grid_view_helpers.MakeLabelValuesDict(self.art2)

    id_vals = grid_view_helpers.GetArtifactAttr(
        self.art2, 'id', self.users_by_id, label_values, self.config, {})
    self.assertEqual([1], id_vals)
    summary_vals = grid_view_helpers.GetArtifactAttr(
        self.art2, 'summary', self.users_by_id, label_values, self.config, {})
    self.assertEqual(['a summary'], summary_vals)
    status_vals = grid_view_helpers.GetArtifactAttr(
        self.art2, 'status', self.users_by_id, label_values, self.config, {})
    self.assertEqual(['New'], status_vals)
    stars_vals = grid_view_helpers.GetArtifactAttr(
        self.art2, 'stars', self.users_by_id, label_values, self.config, {})
    self.assertEqual([12], stars_vals)
    # Owner emails come back obscured.
    owner_vals = grid_view_helpers.GetArtifactAttr(
        self.art2, 'owner', self.users_by_id, label_values, self.config, {})
    self.assertEqual(['f...@example.com'], owner_vals)
    priority_vals = grid_view_helpers.GetArtifactAttr(
        self.art2, 'priority', self.users_by_id, label_values, self.config, {})
    self.assertEqual(['Medium'], priority_vals)
    # Multi-valued labels yield one entry per value.
    mstone_vals = grid_view_helpers.GetArtifactAttr(
        self.art2, 'mstone', self.users_by_id, label_values, self.config, {})
    self.assertEqual(['1', '2'], mstone_vals)
    # Unknown attributes fall back to the NO_VALUES sentinel.
    foo_vals = grid_view_helpers.GetArtifactAttr(
        self.art2, 'foo', self.users_by_id, label_values, self.config, {})
    self.assertEqual([framework_constants.NO_VALUES], foo_vals)
    # mergedinto resolves through related_issues to 'project:local_id'.
    art3 = fake.MakeTestIssue(
        987, 5, 'unecessary summary', 'New', 111, star_count=12,
        issue_id=200001, project_name='other-project')
    related_issues = {200001: art3}
    merged_into_vals = grid_view_helpers.GetArtifactAttr(
        self.art2, 'mergedinto', self.users_by_id, label_values,
        self.config, related_issues)
    self.assertEqual(['other-project:5'], merged_into_vals)

  def testGetArtifactAttr_Derived(self):
    """GetArtifactAttr also surfaces derived status, owner, and labels."""
    label_values = grid_view_helpers.MakeLabelValuesDict(self.art1)
    status_vals = grid_view_helpers.GetArtifactAttr(
        self.art1, 'status', self.users_by_id, label_values, self.config, {})
    self.assertEqual(['Overdue'], status_vals)
    owner_vals = grid_view_helpers.GetArtifactAttr(
        self.art1, 'owner', self.users_by_id, label_values, self.config, {})
    self.assertEqual(['f...@example.com'], owner_vals)
    priority_vals = grid_view_helpers.GetArtifactAttr(
        self.art1, 'priority', self.users_by_id, label_values, self.config, {})
    self.assertEqual(['Medium'], priority_vals)
    mstone_vals = grid_view_helpers.GetArtifactAttr(
        self.art1, 'mstone', self.users_by_id, label_values, self.config, {})
    self.assertEqual(['1', '2'], mstone_vals)

  def testMakeLabelValuesDict_Empty(self):
    """An issue with no labels yields an empty dict."""
    art = fake.MakeTestIssue(
        789, 1, 'a summary', '', 0, derived_owner_id=111, star_count=12)
    label_values = grid_view_helpers.MakeLabelValuesDict(art)
    self.assertEqual({}, label_values)

  def testMakeLabelValuesDict(self):
    """Prefix-Value labels group by lowercased prefix; one-word labels drop."""
    art = fake.MakeTestIssue(
        789, 1, 'a summary', '', 0, derived_owner_id=111, star_count=12,
        labels=['Priority-Medium', 'Hot', 'Mstone-1', 'Mstone-2'])
    label_values = grid_view_helpers.MakeLabelValuesDict(art)
    self.assertEqual(
        {'priority': ['Medium'], 'mstone': ['1', '2']},
        label_values)

    # Derived labels are merged in with explicit ones.
    art = fake.MakeTestIssue(
        789, 1, 'a summary', '', 0, derived_owner_id=111, star_count=12,
        labels='Priority-Medium Hot Mstone-1'.split(),
        derived_labels=['Mstone-2'])
    label_values = grid_view_helpers.MakeLabelValuesDict(art)
    self.assertEqual(
        {'priority': ['Medium'], 'mstone': ['1', '2']},
        label_values)

  def testMakeDrillDownSearch(self):
    """'----' (no value) yields a -has: term; otherwise field=value."""
    self.assertEqual('-has:milestone ',
                     grid_view_helpers.MakeDrillDownSearch('milestone', '----'))
    self.assertEqual('milestone=22 ',
                     grid_view_helpers.MakeDrillDownSearch('milestone', '22'))
    self.assertEqual(
        'owner=a@example.com ',
        grid_view_helpers.MakeDrillDownSearch('owner', 'a@example.com'))

  def testAnyArtifactHasNoAttr_Empty(self):
    """With no artifacts, nothing can be missing the attribute."""
    artifacts = []
    all_label_values = {}
    self.assertFalse(grid_view_helpers.AnyArtifactHasNoAttr(
        artifacts, 'milestone', self.users_by_id, all_label_values,
        self.config, {}))

  def testAnyArtifactHasNoAttr(self):
    """True exactly when some artifact lacks a value for the column."""
    artifacts = [self.art1]
    all_label_values = {
        self.art1.local_id: grid_view_helpers.MakeLabelValuesDict(self.art1),
    }
    self.assertFalse(grid_view_helpers.AnyArtifactHasNoAttr(
        artifacts, 'mstone', self.users_by_id, all_label_values,
        self.config, {}))
    self.assertTrue(grid_view_helpers.AnyArtifactHasNoAttr(
        artifacts, 'milestone', self.users_by_id, all_label_values,
        self.config, {}))

  def testGetGridViewData(self):
    # TODO(jojwang): write this test
    pass

  def testPrepareForMakeGridData(self):
    # TODO(jojwang): write this test
    pass
diff --git a/framework/test/jsonfeed_test.py b/framework/test/jsonfeed_test.py
new file mode 100644
index 0000000..0a569e2
--- /dev/null
+++ b/framework/test/jsonfeed_test.py
@@ -0,0 +1,141 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for jsonfeed module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import httplib
+import logging
+import unittest
+
+from google.appengine.api import app_identity
+
+from framework import jsonfeed
+from framework import servlet
+from framework import xsrf
+from services import service_manager
+from testing import testing_helpers
+
+
class JsonFeedTest(unittest.TestCase):
  """Tests for jsonfeed.JsonFeed request handling and access checks."""

  def setUp(self):
    self.cnxn = 'fake cnxn'

  def testGet(self):
    """Tests handling of GET requests."""
    feed = TestableJsonFeed()

    # all expected args are present + a bonus arg that should be ignored
    feed.mr = testing_helpers.MakeMonorailRequest(
        path='/foo/bar/wee?sna=foo', method='POST',
        params={'a': '123', 'z': 'zebra'})
    feed.get()

    self.assertEqual(True, feed.handle_request_called)
    self.assertEqual(1, len(feed.json_data))

  def testPost(self):
    """Tests handling of POST requests."""
    feed = TestableJsonFeed()
    feed.mr = testing_helpers.MakeMonorailRequest(
        path='/foo/bar/wee?sna=foo', method='POST',
        params={'a': '123', 'z': 'zebra'})

    feed.post()

    self.assertEqual(True, feed.handle_request_called)
    self.assertEqual(1, len(feed.json_data))

  def testSecurityTokenChecked_BadToken(self):
    """A missing or wrong XSRF token makes both GET and POST raise."""
    feed = TestableJsonFeed()
    feed.mr = testing_helpers.MakeMonorailRequest(
        user_info={'user_id': 555})
    # Note that feed.mr has no token set.
    self.assertRaises(xsrf.TokenIncorrect, feed.get)
    self.assertRaises(xsrf.TokenIncorrect, feed.post)

    feed.mr.token = 'bad token'
    self.assertRaises(xsrf.TokenIncorrect, feed.get)
    self.assertRaises(xsrf.TokenIncorrect, feed.post)

  def testSecurityTokenChecked_HandlerDoesNotNeedToken(self):
    """Handlers with CHECK_SECURITY_TOKEN off accept tokenless requests."""
    feed = TestableJsonFeed()
    feed.mr = testing_helpers.MakeMonorailRequest(
        user_info={'user_id': 555})
    # Note that feed.mr has no token set.
    feed.CHECK_SECURITY_TOKEN = False
    feed.get()
    feed.post()

  def testSecurityTokenChecked_AnonUserDoesNotNeedToken(self):
    """Anonymous requests are not required to carry an XSRF token."""
    feed = TestableJsonFeed()
    feed.mr = testing_helpers.MakeMonorailRequest()
    # Note that feed.mr has no token set, but also no auth.user_id.
    feed.get()
    feed.post()

  def testSameAppOnly_ExternallyAccessible(self):
    """By default, feeds accept requests from outside this app."""
    feed = TestableJsonFeed()
    feed.mr = testing_helpers.MakeMonorailRequest()
    # Note that request has no X-Appengine-Inbound-Appid set.
    feed.get()
    feed.post()

  def testSameAppOnly_InternalOnlyCalledFromSameApp(self):
    """CHECK_SAME_APP feeds accept requests from this same app."""
    feed = TestableJsonFeed()
    feed.CHECK_SAME_APP = True
    feed.mr = testing_helpers.MakeMonorailRequest()
    app_id = app_identity.get_application_id()
    feed.mr.request.headers['X-Appengine-Inbound-Appid'] = app_id
    feed.get()
    feed.post()

  def testSameAppOnly_InternalOnlyCalledExternally(self):
    """CHECK_SAME_APP feeds reject external requests with 403."""
    feed = TestableJsonFeed()
    feed.CHECK_SAME_APP = True
    feed.mr = testing_helpers.MakeMonorailRequest()
    # Note that request has no X-Appengine-Inbound-Appid set.
    self.assertIsNone(feed.get())
    self.assertFalse(feed.handle_request_called)
    self.assertEqual(httplib.FORBIDDEN, feed.response.status)
    self.assertIsNone(feed.post())
    self.assertFalse(feed.handle_request_called)
    self.assertEqual(httplib.FORBIDDEN, feed.response.status)

  def testSameAppOnly_InternalOnlyCalledFromWrongApp(self):
    """CHECK_SAME_APP feeds reject requests from other apps with 403."""
    feed = TestableJsonFeed()
    feed.CHECK_SAME_APP = True
    feed.mr = testing_helpers.MakeMonorailRequest()
    feed.mr.request.headers['X-Appengine-Inbound-Appid'] = 'wrong'
    self.assertIsNone(feed.get())
    self.assertFalse(feed.handle_request_called)
    self.assertEqual(httplib.FORBIDDEN, feed.response.status)
    self.assertIsNone(feed.post())
    self.assertFalse(feed.handle_request_called)
    self.assertEqual(httplib.FORBIDDEN, feed.response.status)
+
+
class TestableJsonFeed(jsonfeed.JsonFeed):
  """A JsonFeed test double that records handling instead of rendering."""

  def __init__(self, request=None):
    response = testing_helpers.Blank()
    super(TestableJsonFeed, self).__init__(
        request or 'req', response, services=service_manager.Services())

    # Set by the overridden methods below so tests can inspect outcomes.
    self.response_data = None
    self.handle_request_called = False
    self.json_data = None

  def HandleRequest(self, mr):
    self.handle_request_called = True
    return {'a': mr.GetParam('a')}

  # The output chain is hard to double, so we skip that phase
  # but save the response data for inspection.
  def _RenderJsonResponse(self, json_data):
    self.json_data = json_data
diff --git a/framework/test/monitoring_test.py b/framework/test/monitoring_test.py
new file mode 100644
index 0000000..edbd15d
--- /dev/null
+++ b/framework/test/monitoring_test.py
@@ -0,0 +1,86 @@
+# Copyright 2020 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+
+"""Unit tests for the monitoring module."""
+
+import unittest
+from framework import monitoring
+
+# Shared metric-field dict reused by the duration/status/bytes tests below.
+COMMON_TEST_FIELDS = monitoring.GetCommonFields(200, 'monorail.v3.MethodName')
+
+
+class MonitoringTest(unittest.TestCase):
+  """Tests for the ts_mon metric helpers in framework.monitoring."""
+
+  def testIncrementAPIRequestsCount(self):
+    # Non-service account email gets hidden.
+    # (Hidden addresses are recorded under the constant 'user@email.com'.)
+    monitoring.IncrementAPIRequestsCount(
+        'v3', 'monorail-prod', client_email='client-email@chicken.com')
+    self.assertEqual(
+        1,
+        monitoring.API_REQUESTS_COUNT.get(
+            fields={
+                'client_id': 'monorail-prod',
+                'client_email': 'user@email.com',
+                'version': 'v3'
+            }))
+
+    # None email address gets replaced by 'anonymous'.
+    monitoring.IncrementAPIRequestsCount('v3', 'monorail-prod')
+    self.assertEqual(
+        1,
+        monitoring.API_REQUESTS_COUNT.get(
+            fields={
+                'client_id': 'monorail-prod',
+                'client_email': 'anonymous',
+                'version': 'v3'
+            }))
+
+    # Service account email is not hidden
+    monitoring.IncrementAPIRequestsCount(
+        'endpoints',
+        'monorail-prod',
+        client_email='123456789@developer.gserviceaccount.com')
+    self.assertEqual(
+        1,
+        monitoring.API_REQUESTS_COUNT.get(
+            fields={
+                'client_id': 'monorail-prod',
+                'client_email': '123456789@developer.gserviceaccount.com',
+                'version': 'endpoints'
+            }))
+
+  def testGetCommonFields(self):
+    fields = monitoring.GetCommonFields(200, 'monorail.v3.TestName')
+    self.assertEqual(
+        {
+            'status': 200,
+            'name': 'monorail.v3.TestName',
+            'is_robot': False
+        }, fields)
+
+  def testAddServerDurations(self):
+    # Metric is unset until AddServerDurations records a sample.
+    self.assertIsNone(
+        monitoring.SERVER_DURATIONS.get(fields=COMMON_TEST_FIELDS))
+    monitoring.AddServerDurations(500, COMMON_TEST_FIELDS)
+    self.assertIsNotNone(
+        monitoring.SERVER_DURATIONS.get(fields=COMMON_TEST_FIELDS))
+
+  def testIncrementServerResponseStatusCount(self):
+    monitoring.IncrementServerResponseStatusCount(COMMON_TEST_FIELDS)
+    self.assertEqual(
+        1, monitoring.SERVER_RESPONSE_STATUS.get(fields=COMMON_TEST_FIELDS))
+
+  # NOTE: "Requeste" mirrors the (misspelled) production name
+  # monitoring.AddServerRequesteBytes; the test name matches it on purpose.
+  def testAddServerRequesteBytes(self):
+    self.assertIsNone(
+        monitoring.SERVER_REQUEST_BYTES.get(fields=COMMON_TEST_FIELDS))
+    monitoring.AddServerRequesteBytes(1234, COMMON_TEST_FIELDS)
+    self.assertIsNotNone(
+        monitoring.SERVER_REQUEST_BYTES.get(fields=COMMON_TEST_FIELDS))
+
+  def testAddServerResponseBytes(self):
+    self.assertIsNone(
+        monitoring.SERVER_RESPONSE_BYTES.get(fields=COMMON_TEST_FIELDS))
+    monitoring.AddServerResponseBytes(9876, COMMON_TEST_FIELDS)
+    self.assertIsNotNone(
+        monitoring.SERVER_RESPONSE_BYTES.get(fields=COMMON_TEST_FIELDS))
diff --git a/framework/test/monorailcontext_test.py b/framework/test/monorailcontext_test.py
new file mode 100644
index 0000000..ed93920
--- /dev/null
+++ b/framework/test/monorailcontext_test.py
@@ -0,0 +1,89 @@
+# Copyright 2018 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for MonorailContext."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+import mox
+
+from framework import authdata
+from framework import monorailcontext
+from framework import permissions
+from framework import profiler
+from framework import template_helpers
+from framework import sql
+from services import service_manager
+from testing import fake
+
+
+class MonorailContextTest(unittest.TestCase):
+
+ def setUp(self):
+ self.mox = mox.Mox()
+ self.cnxn = fake.MonorailConnection()
+ self.services = service_manager.Services(
+ user=fake.UserService(),
+ usergroup=fake.UserGroupService(),
+ project=fake.ProjectService())
+ self.project = self.services.project.TestAddProject(
+ 'proj', project_id=789, owner_ids=[111])
+ self.user = self.services.user.TestAddUser('owner@example.com', 111)
+
+ def tearDown(self):
+ self.mox.UnsetStubs()
+ self.mox.ResetAll()
+
+ def testConstructor_PassingAuthAndPerms(self):
+ """We can easily make an mc for testing."""
+ auth = authdata.AuthData(user_id=111, email='owner@example.com')
+ mc = monorailcontext.MonorailContext(
+ None, cnxn=self.cnxn, auth=auth, perms=permissions.USER_PERMISSIONSET)
+ self.assertEqual(self.cnxn, mc.cnxn)
+ self.assertEqual(auth, mc.auth)
+ self.assertEqual(permissions.USER_PERMISSIONSET, mc.perms)
+ self.assertTrue(isinstance(mc.profiler, profiler.Profiler))
+ self.assertEqual([], mc.warnings)
+ self.assertTrue(isinstance(mc.errors, template_helpers.EZTError))
+
+ mc.CleanUp()
+ self.assertIsNone(mc.cnxn)
+
+ def testConstructor_AsUsedInApp(self):
+ """We can make an mc like it is done in the app or a test."""
+ self.mox.StubOutClassWithMocks(sql, 'MonorailConnection')
+ mock_cnxn = sql.MonorailConnection()
+ mock_cnxn.Close()
+ requester = 'new-user@example.com'
+ self.mox.ReplayAll()
+
+ mc = monorailcontext.MonorailContext(self.services, requester=requester)
+ mc.LookupLoggedInUserPerms(self.project)
+ self.assertEqual(mock_cnxn, mc.cnxn)
+ self.assertEqual(requester, mc.auth.email)
+ self.assertEqual(permissions.USER_PERMISSIONSET, mc.perms)
+ self.assertTrue(isinstance(mc.profiler, profiler.Profiler))
+ self.assertEqual([], mc.warnings)
+ self.assertTrue(isinstance(mc.errors, template_helpers.EZTError))
+
+ mc.CleanUp()
+ self.assertIsNone(mc.cnxn)
+
+ # Double Cleanup or Cleanup with no cnxn is not a crash.
+ mc.CleanUp()
+ self.assertIsNone(mc.cnxn)
+
+ def testRepr(self):
+ """We get nice debugging strings."""
+ auth = authdata.AuthData(user_id=111, email='owner@example.com')
+ mc = monorailcontext.MonorailContext(
+ None, cnxn=self.cnxn, auth=auth, perms=permissions.USER_PERMISSIONSET)
+ repr_str = '%r' % mc
+ self.assertTrue(repr_str.startswith('MonorailContext('))
+ self.assertIn('owner@example.com', repr_str)
+ self.assertIn('view', repr_str)
diff --git a/framework/test/monorailrequest_test.py b/framework/test/monorailrequest_test.py
new file mode 100644
index 0000000..fcd30c3
--- /dev/null
+++ b/framework/test/monorailrequest_test.py
@@ -0,0 +1,613 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for the monorailrequest module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import endpoints
+import mock
+import re
+import unittest
+
+import mox
+import six
+
+from google.appengine.api import oauth
+from google.appengine.api import users
+
+import webapp2
+
+from framework import exceptions
+from framework import monorailrequest
+from framework import permissions
+from proto import project_pb2
+from proto import tracker_pb2
+from services import service_manager
+from testing import fake
+from testing import testing_helpers
+from tracker import tracker_constants
+
+
+class HostportReTest(unittest.TestCase):
+  """Validates the regex used to sanitize Host header values."""
+
+  def testGood(self):
+    """Legitimate host[:port] strings must be accepted."""
+    test_data = [
+        'localhost:8080',
+        'app.appspot.com',
+        'bugs-staging.chromium.org',
+        'vers10n-h3x-dot-app-id.appspot.com',
+        ]
+    for hostport in test_data:
+      self.assertTrue(monorailrequest._HOSTPORT_RE.match(hostport),
+                      msg='Incorrectly rejected %r' % hostport)
+
+  def testBad(self):
+    """Empty, whitespace, quoted, and header-injection strings are rejected."""
+    test_data = [
+        '',
+        ' ',
+        '\t',
+        '\n',
+        '\'',
+        '"',
+        'version"cruft-dot-app-id.appspot.com',
+        '\nother header',
+        'version&cruft-dot-app-id.appspot.com',
+        ]
+    for hostport in test_data:
+      self.assertFalse(monorailrequest._HOSTPORT_RE.match(hostport),
+                       msg='Incorrectly accepted %r' % hostport)
+
+
+class MonorailApiRequestUnitTest(unittest.TestCase):
+  """Tests for MonorailApiRequest built from endpoints API requests."""
+
+  def setUp(self):
+    self.cnxn = 'fake cnxn'
+    self.services = service_manager.Services(
+        config=fake.ConfigService(),
+        issue=fake.IssueService(),
+        project=fake.ProjectService(),
+        user=fake.UserService(),
+        usergroup=fake.UserGroupService())
+    self.project = self.services.project.TestAddProject(
+        'proj', project_id=789)
+    self.services.user.TestAddUser('requester@example.com', 111)
+    self.issue = fake.MakeTestIssue(
+        789, 1, 'sum', 'New', 111)
+    self.services.issue.TestAddIssue(self.issue)
+
+    # Make every request appear OAuth-authenticated as the requester.
+    self.patcher_1 = mock.patch('endpoints.get_current_user')
+    self.mock_endpoints_gcu = self.patcher_1.start()
+    self.mock_endpoints_gcu.return_value = None
+    self.patcher_2 = mock.patch('google.appengine.api.oauth.get_current_user')
+    self.mock_oauth_gcu = self.patcher_2.start()
+    self.mock_oauth_gcu.return_value = testing_helpers.Blank(
+        email=lambda: 'requester@example.com')
+
+  def tearDown(self):
+    mock.patch.stopall()
+
+  def testInit_NoProjectIssueOrViewedUser(self):
+    request = testing_helpers.Blank()
+    mar = monorailrequest.MonorailApiRequest(
+        request, self.services, cnxn=self.cnxn)
+    self.assertIsNone(mar.project)
+    self.assertIsNone(mar.issue)
+
+  def testInit_WithProject(self):
+    request = testing_helpers.Blank(projectId='proj')
+    mar = monorailrequest.MonorailApiRequest(
+        request, self.services, cnxn=self.cnxn)
+    self.assertEqual(self.project, mar.project)
+    self.assertIsNone(mar.issue)
+
+  def testInit_WithProjectAndIssue(self):
+    request = testing_helpers.Blank(
+        projectId='proj', issueId=1)
+    mar = monorailrequest.MonorailApiRequest(
+        request, self.services, cnxn=self.cnxn)
+    self.assertEqual(self.project, mar.project)
+    self.assertEqual(self.issue, mar.issue)
+
+  def testGetParam_Normal(self):
+    # 'num' defaults to 100 unless the request carries maxResults.
+    request = testing_helpers.Blank(q='owner:me')
+    mar = monorailrequest.MonorailApiRequest(
+        request, self.services, cnxn=self.cnxn)
+    self.assertEqual(None, mar.GetParam('unknown'))
+    self.assertEqual(100, mar.GetParam('num'))
+    self.assertEqual('owner:me', mar.GetParam('q'))
+
+    request = testing_helpers.Blank(q='owner:me', maxResults=200)
+    mar = monorailrequest.MonorailApiRequest(
+        request, self.services, cnxn=self.cnxn)
+    self.assertEqual(200, mar.GetParam('num'))
+
+
+class MonorailRequestUnitTest(unittest.TestCase):
+  """Tests for MonorailRequest parsing of servlet URLs and query params."""
+
+  def setUp(self):
+    self.services = service_manager.Services(
+        project=fake.ProjectService(),
+        user=fake.UserService(),
+        usergroup=fake.UserGroupService(),
+        features=fake.FeaturesService())
+    self.project = self.services.project.TestAddProject('proj')
+    self.hotlist = self.services.features.TestAddHotlist(
+        'TestHotlist', owner_ids=[111])
+    self.services.user.TestAddUser('jrobbins@example.com', 111)
+
+    # Parse requests as the anonymous user unless a test re-stubs this.
+    self.mox = mox.Mox()
+    self.mox.StubOutWithMock(users, 'get_current_user')
+    users.get_current_user().AndReturn(None)
+    self.mox.ReplayAll()
+
+  def tearDown(self):
+    self.mox.UnsetStubs()
+
+  def testGetIntParam_ConvertsQueryParamToInt(self):
+    notice_id = 12345
+    mr = testing_helpers.MakeMonorailRequest(
+        path='/foo?notice=%s' % notice_id)
+
+    value = mr.GetIntParam('notice')
+    self.assertTrue(isinstance(value, int))
+    self.assertEqual(notice_id, value)
+
+  def testGetIntParam_ConvertsQueryParamToLong(self):
+    # A value too large for a Python 2 int should still parse.
+    notice_id = 12345678901234567890
+    mr = testing_helpers.MakeMonorailRequest(
+        path='/foo?notice=%s' % notice_id)
+
+    value = mr.GetIntParam('notice')
+    self.assertTrue(isinstance(value, six.integer_types))
+    self.assertEqual(notice_id, value)
+
+  def testGetIntListParam_NoParam(self):
+    mr = monorailrequest.MonorailRequest(self.services)
+    mr.ParseRequest(webapp2.Request.blank('servlet'), self.services)
+    self.assertEqual(mr.GetIntListParam('ids'), None)
+    self.assertEqual(mr.GetIntListParam('ids', default_value=['test']),
+                     ['test'])
+
+  def testGetIntListParam_OneValue(self):
+    mr = monorailrequest.MonorailRequest(self.services)
+    mr.ParseRequest(webapp2.Request.blank('servlet?ids=11'), self.services)
+    self.assertEqual(mr.GetIntListParam('ids'), [11])
+    self.assertEqual(mr.GetIntListParam('ids', default_value=['test']),
+                     [11])
+
+  def testGetIntListParam_MultiValue(self):
+    mr = monorailrequest.MonorailRequest(self.services)
+    mr.ParseRequest(
+        webapp2.Request.blank('servlet?ids=21,22,23'), self.services)
+    self.assertEqual(mr.GetIntListParam('ids'), [21, 22, 23])
+    self.assertEqual(mr.GetIntListParam('ids', default_value=['test']),
+                     [21, 22, 23])
+
+  def testGetIntListParam_BogusValue(self):
+    # Non-numeric values are rejected at parse time, not lookup time.
+    mr = monorailrequest.MonorailRequest(self.services)
+    with self.assertRaises(exceptions.InputException):
+      mr.ParseRequest(
+          webapp2.Request.blank('servlet?ids=not_an_int'), self.services)
+
+  def testGetIntListParam_Malformed(self):
+    mr = monorailrequest.MonorailRequest(self.services)
+    with self.assertRaises(exceptions.InputException):
+      mr.ParseRequest(
+          webapp2.Request.blank('servlet?ids=31,32,,'), self.services)
+
+  def testDefaultValuesNoUrl(self):
+    """If request has no param, default param values should be used."""
+    mr = monorailrequest.MonorailRequest(self.services)
+    mr.ParseRequest(webapp2.Request.blank('servlet'), self.services)
+    self.assertEqual(mr.GetParam('r', 3), 3)
+    self.assertEqual(mr.GetIntParam('r', 3), 3)
+    self.assertEqual(mr.GetPositiveIntParam('r', 3), 3)
+    self.assertEqual(mr.GetIntListParam('r', [3, 4]), [3, 4])
+
+  def _MRWithMockRequest(
+      self, path, headers=None, *mr_args, **mr_kwargs):
+    """Return a MonorailRequest that has parsed the given path and headers."""
+    request = webapp2.Request.blank(path, headers=headers)
+    mr = monorailrequest.MonorailRequest(self.services, *mr_args, **mr_kwargs)
+    mr.ParseRequest(request, self.services)
+    return mr
+
+  def testParseQueryParameters(self):
+    mr = self._MRWithMockRequest(
+        '/p/proj/issues/list?q=foo+OR+bar&num=50')
+    self.assertEqual('foo OR bar', mr.query)
+    self.assertEqual(50, mr.num)
+
+  def testParseQueryParameters_ModeMissing(self):
+    mr = self._MRWithMockRequest(
+        '/p/proj/issues/list?q=foo+OR+bar&num=50')
+    self.assertEqual('list', mr.mode)
+
+  def testParseQueryParameters_ModeList(self):
+    # An empty mode= value falls back to the default 'list' mode.
+    mr = self._MRWithMockRequest(
+        '/p/proj/issues/list?q=foo+OR+bar&num=50&mode=')
+    self.assertEqual('list', mr.mode)
+
+  def testParseQueryParameters_ModeGrid(self):
+    mr = self._MRWithMockRequest(
+        '/p/proj/issues/list?q=foo+OR+bar&num=50&mode=grid')
+    self.assertEqual('grid', mr.mode)
+
+  def testParseQueryParameters_ModeChart(self):
+    mr = self._MRWithMockRequest(
+        '/p/proj/issues/list?q=foo+OR+bar&num=50&mode=chart')
+    self.assertEqual('chart', mr.mode)
+
+  def testParseRequest_Scheme(self):
+    mr = self._MRWithMockRequest('/p/proj/')
+    self.assertEqual('http', mr.request.scheme)
+
+  def testParseRequest_HostportAndCurrentPageURL(self):
+    mr = self._MRWithMockRequest('/p/proj/', headers={
+        'Host': 'example.com',
+        'Cookie': 'asdf',
+        })
+    self.assertEqual('http', mr.request.scheme)
+    self.assertEqual('example.com', mr.request.host)
+    self.assertEqual('http://example.com/p/proj/', mr.current_page_url)
+
+  def testParseRequest_ProjectFound(self):
+    mr = self._MRWithMockRequest('/p/proj/')
+    self.assertEqual(mr.project, self.project)
+
+  def testParseRequest_ProjectNotFound(self):
+    with self.assertRaises(exceptions.NoSuchProjectException):
+      self._MRWithMockRequest('/p/no-such-proj/')
+
+  def testViewedUser_WithEmail(self):
+    mr = self._MRWithMockRequest('/u/jrobbins@example.com/')
+    self.assertEqual('jrobbins@example.com', mr.viewed_username)
+    self.assertEqual(111, mr.viewed_user_auth.user_id)
+    self.assertEqual(
+        self.services.user.GetUser('fake cnxn', 111),
+        mr.viewed_user_auth.user_pb)
+
+  def testViewedUser_WithUserID(self):
+    mr = self._MRWithMockRequest('/u/111/')
+    self.assertEqual('jrobbins@example.com', mr.viewed_username)
+    self.assertEqual(111, mr.viewed_user_auth.user_id)
+    self.assertEqual(
+        self.services.user.GetUser('fake cnxn', 111),
+        mr.viewed_user_auth.user_pb)
+
+  def testViewedUser_NoSuchEmail(self):
+    # Unknown emails surface as an HTTP 404 rather than a project exception.
+    with self.assertRaises(webapp2.HTTPException) as cm:
+      self._MRWithMockRequest('/u/unknownuser@example.com/')
+    self.assertEqual(404, cm.exception.code)
+
+  def testViewedUser_NoSuchUserID(self):
+    with self.assertRaises(exceptions.NoSuchUserException):
+      self._MRWithMockRequest('/u/234521111/')
+
+  def testGetParam(self):
+    mr = testing_helpers.MakeMonorailRequest(
+        path='/foo?syn=error!&a=a&empty=',
+        params=dict(over1='over_value1', over2='over_value2'))
+
+    # test tampering
+    self.assertRaises(exceptions.InputException, mr.GetParam, 'a',
+                      antitamper_re=re.compile(r'^$'))
+    self.assertRaises(exceptions.InputException, mr.GetParam,
+                      'undefined', default_value='default',
+                      antitamper_re=re.compile(r'^$'))
+
+    # test empty value
+    self.assertEqual('', mr.GetParam(
+        'empty', default_value='default', antitamper_re=re.compile(r'^$')))
+
+    # test default
+    self.assertEqual('default', mr.GetParam(
+        'undefined', default_value='default'))
+
+  def testComputeColSpec(self):
+    # No config passed, and nothing in URL
+    mr = testing_helpers.MakeMonorailRequest(
+        path='/p/proj/issues/detail?id=123')
+    mr.ComputeColSpec(None)
+    self.assertEqual(tracker_constants.DEFAULT_COL_SPEC, mr.col_spec)
+
+    # No config passed, but set in URL
+    mr = testing_helpers.MakeMonorailRequest(
+        path='/p/proj/issues/detail?id=123&colspec=a b C')
+    mr.ComputeColSpec(None)
+    self.assertEqual('a b C', mr.col_spec)
+
+    config = tracker_pb2.ProjectIssueConfig()
+
+    # No default in the config, and nothing in URL
+    mr = testing_helpers.MakeMonorailRequest(
+        path='/p/proj/issues/detail?id=123')
+    mr.ComputeColSpec(config)
+    self.assertEqual(tracker_constants.DEFAULT_COL_SPEC, mr.col_spec)
+
+    # No default in the config, but set in URL
+    mr = testing_helpers.MakeMonorailRequest(
+        path='/p/proj/issues/detail?id=123&colspec=a b C')
+    mr.ComputeColSpec(config)
+    self.assertEqual('a b C', mr.col_spec)
+
+    config.default_col_spec = 'd e f'
+
+    # Default in the config, and nothing in URL
+    mr = testing_helpers.MakeMonorailRequest(
+        path='/p/proj/issues/detail?id=123')
+    mr.ComputeColSpec(config)
+    self.assertEqual('d e f', mr.col_spec)
+
+    # Default in the config, but overridden via URL
+    mr = testing_helpers.MakeMonorailRequest(
+        path='/p/proj/issues/detail?id=123&colspec=a b C')
+    mr.ComputeColSpec(config)
+    self.assertEqual('a b C', mr.col_spec)
+
+    # project colspec contains hotlist columns
+    mr = testing_helpers.MakeMonorailRequest(
+        path='p/proj/issues/detail?id=123&colspec=Rank Adder Adder Owner')
+    mr.ComputeColSpec(None)
+    self.assertEqual(tracker_constants.DEFAULT_COL_SPEC, mr.col_spec)
+
+    # hotlist columns are not deleted when page is a hotlist page
+    mr = testing_helpers.MakeMonorailRequest(
+        path='u/jrobbins@example.com/hotlists/TestHotlist?colspec=Rank Adder',
+        hotlist=self.hotlist)
+    mr.ComputeColSpec(None)
+    self.assertEqual('Rank Adder', mr.col_spec)
+
+  def testComputeColSpec_XSS(self):
+    # A script-injection colspec must be rejected whether it arrives via the
+    # URL, via the stored project config, or both.
+    config_1 = tracker_pb2.ProjectIssueConfig()
+    config_2 = tracker_pb2.ProjectIssueConfig()
+    config_2.default_col_spec = "id '+alert(1)+'"
+    mr_1 = testing_helpers.MakeMonorailRequest(
+        path='/p/proj/issues/detail?id=123')
+    mr_2 = testing_helpers.MakeMonorailRequest(
+        path="/p/proj/issues/detail?id=123&colspec=id '+alert(1)+'")
+
+    # Normal colspec in config but malicious request
+    self.assertRaises(
+        exceptions.InputException,
+        mr_2.ComputeColSpec, config_1)
+
+    # Malicious colspec in config but normal request
+    self.assertRaises(
+        exceptions.InputException,
+        mr_1.ComputeColSpec, config_2)
+
+    # Malicious colspec in config and malicious request
+    self.assertRaises(
+        exceptions.InputException,
+        mr_2.ComputeColSpec, config_2)
+
+
<br>
+class CalcDefaultQueryTest(unittest.TestCase):
+  """Tests for MonorailRequest._CalcDefaultQuery member-default queries."""
+
+  def setUp(self):
+    self.project = project_pb2.Project()
+    self.project.project_name = 'proj'
+    self.project.owner_ids = [111]
+    self.config = tracker_pb2.ProjectIssueConfig()
+
+  def testIssueListURL_NotDefaultCan(self):
+    # The member default query only applies to can=2 (open issues).
+    mr = monorailrequest.MonorailRequest(None)
+    mr.query = None
+    mr.can = 1
+    self.assertEqual('', mr._CalcDefaultQuery())
+
+  def testIssueListURL_NoProject(self):
+    mr = monorailrequest.MonorailRequest(None)
+    mr.query = None
+    mr.can = 2
+    self.assertEqual('', mr._CalcDefaultQuery())
+
+  def testIssueListURL_NoConfig(self):
+    mr = monorailrequest.MonorailRequest(None)
+    mr.query = None
+    mr.can = 2
+    mr.project = self.project
+    self.assertEqual('', mr._CalcDefaultQuery())
+
+  def testIssueListURL_NotCustomized(self):
+    mr = monorailrequest.MonorailRequest(None)
+    mr.query = None
+    mr.can = 2
+    mr.project = self.project
+    mr.config = self.config
+    self.assertEqual('', mr._CalcDefaultQuery())
+
+  def testIssueListURL_Customized_Nonmember(self):
+    # Anonymous users and non-members never get the member default query.
+    mr = monorailrequest.MonorailRequest(None)
+    mr.query = None
+    mr.can = 2
+    mr.project = self.project
+    mr.config = self.config
+    mr.config.member_default_query = 'owner:me'
+    self.assertEqual('', mr._CalcDefaultQuery())
+
+    mr.auth = testing_helpers.Blank(effective_ids=set())
+    self.assertEqual('', mr._CalcDefaultQuery())
+
+    mr.auth = testing_helpers.Blank(effective_ids={999})
+    self.assertEqual('', mr._CalcDefaultQuery())
+
+  def testIssueListURL_Customized_Member(self):
+    mr = monorailrequest.MonorailRequest(None)
+    mr.query = None
+    mr.can = 2
+    mr.project = self.project
+    mr.config = self.config
+    mr.config.member_default_query = 'owner:me'
+    mr.auth = testing_helpers.Blank(effective_ids={111})
+    self.assertEqual('owner:me', mr._CalcDefaultQuery())
+
+
+class TestMonorailRequestFunctions(unittest.TestCase):
+  """Tests for module-level URL-parsing helpers in monorailrequest."""
+
+  def testExtractPathIdentifiers_ProjectOnly(self):
+    (username, project_name, hotlist_id,
+     hotlist_name) = monorailrequest._ParsePathIdentifiers(
+         '/p/proj/issues/list?q=foo+OR+bar&ts=1234')
+    self.assertIsNone(username)
+    self.assertIsNone(hotlist_id)
+    self.assertIsNone(hotlist_name)
+    self.assertEqual('proj', project_name)
+
+  def testExtractPathIdentifiers_ViewedUserOnly(self):
+    (username, project_name, hotlist_id,
+     hotlist_name) = monorailrequest._ParsePathIdentifiers(
+         '/u/jrobbins@example.com/')
+    self.assertEqual('jrobbins@example.com', username)
+    self.assertIsNone(project_name)
+    self.assertIsNone(hotlist_id)
+    self.assertIsNone(hotlist_name)
+
+  def testExtractPathIdentifiers_ViewedUserURLSpace(self):
+    (username, project_name, hotlist_id,
+     hotlist_name) = monorailrequest._ParsePathIdentifiers(
+         '/u/jrobbins@example.com/updates')
+    self.assertEqual('jrobbins@example.com', username)
+    self.assertIsNone(project_name)
+    self.assertIsNone(hotlist_id)
+    self.assertIsNone(hotlist_name)
+
+  def testExtractPathIdentifiers_ViewedGroupURLSpace(self):
+    # /g/ group pages are parsed like user pages.
+    (username, project_name, hotlist_id,
+     hotlist_name) = monorailrequest._ParsePathIdentifiers(
+         '/g/user-group@example.com/updates')
+    self.assertEqual('user-group@example.com', username)
+    self.assertIsNone(project_name)
+    self.assertIsNone(hotlist_id)
+    self.assertIsNone(hotlist_name)
+
+  def testExtractPathIdentifiers_HotlistIssuesURLSpaceById(self):
+    # Numeric hotlist references populate hotlist_id, not hotlist_name.
+    (username, project_name, hotlist_id,
+     hotlist_name) = monorailrequest._ParsePathIdentifiers(
+         '/u/jrobbins@example.com/hotlists/13124?q=stuff&ts=more')
+    self.assertIsNone(hotlist_name)
+    self.assertIsNone(project_name)
+    self.assertEqual('jrobbins@example.com', username)
+    self.assertEqual(13124, hotlist_id)
+
+  def testExtractPathIdentifiers_HotlistIssuesURLSpaceByName(self):
+    (username, project_name, hotlist_id,
+     hotlist_name) = monorailrequest._ParsePathIdentifiers(
+         '/u/jrobbins@example.com/hotlists/testname?q=stuff&ts=more')
+    self.assertIsNone(project_name)
+    self.assertIsNone(hotlist_id)
+    self.assertEqual('jrobbins@example.com', username)
+    self.assertEqual('testname', hotlist_name)
+
+  def testParseColSpec(self):
+    parse = monorailrequest.ParseColSpec
+    self.assertEqual(['PageName', 'Summary', 'Changed', 'ChangedBy'],
+                     parse(u'PageName Summary Changed ChangedBy'))
+    self.assertEqual(['Foo-Bar', 'Foo-Bar-Baz', 'Release-1.2', 'Hey', 'There'],
+                     parse('Foo-Bar Foo-Bar-Baz Release-1.2 Hey!There'))
+    # Python 2 only: UTF-8 byte strings decoded to unicode column names.
+    self.assertEqual(
+        ['\xe7\xaa\xbf\xe8\x8b\xa5\xe7\xb9\xb9'.decode('utf-8'),
+         '\xe5\x9f\xba\xe5\x9c\xb0\xe3\x81\xaf'.decode('utf-8')],
+        parse('\xe7\xaa\xbf\xe8\x8b\xa5\xe7\xb9\xb9 '
+              '\xe5\x9f\xba\xe5\x9c\xb0\xe3\x81\xaf'.decode('utf-8')))
+
+  def testParseColSpec_Dedup(self):
+    """An attacker cannot inflate response size by repeating a column."""
+    parse = monorailrequest.ParseColSpec
+    self.assertEqual([], parse(''))
+    self.assertEqual(
+        ['Aa', 'b', 'c/d'],
+        parse(u'Aa Aa AA AA AA b Aa aa c/d d c aA b aa B C/D D/aa/c'))
+    self.assertEqual(
+        ['A', 'b', 'c/d', 'e', 'f'],
+        parse(u'A b c/d e f g h i j a/k l m/c/a n/o'))
+
+  def testParseColSpec_Huge(self):
+    """An attacker cannot inflate response size with a huge column name."""
+    parse = monorailrequest.ParseColSpec
+    self.assertEqual(
+        ['Aa', 'b', 'c/d'],
+        parse(u'Aa Aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa b c/d'))
+
+  def testParseColSpec_Ignore(self):
+    """We ignore groupby and grid axes that would be useless."""
+    parse = monorailrequest.ParseColSpec
+    self.assertEqual(
+        ['Aa', 'b', 'c/d'],
+        parse(u'Aa AllLabels alllabels Id b opened/summary c/d',
+              ignore=tracker_constants.NOT_USED_IN_GRID_AXES))
+
+
+class TestPermissionLookup(unittest.TestCase):
+  """End-to-end tests: parsing a request yields the right PermissionSet."""
+
+  # Well-known user IDs registered in setUp.
+  OWNER_ID = 1
+  OTHER_USER_ID = 2
+
+  def setUp(self):
+    self.services = service_manager.Services(
+        project=fake.ProjectService(),
+        user=fake.UserService(),
+        usergroup=fake.UserGroupService())
+    self.services.user.TestAddUser('owner@gmail.com', self.OWNER_ID)
+    self.services.user.TestAddUser('user@gmail.com', self.OTHER_USER_ID)
+    # One project per access/state combination exercised below.
+    self.live_project = self.services.project.TestAddProject(
+        'live', owner_ids=[self.OWNER_ID])
+    self.archived_project = self.services.project.TestAddProject(
+        'archived', owner_ids=[self.OWNER_ID],
+        state=project_pb2.ProjectState.ARCHIVED)
+    self.members_only_project = self.services.project.TestAddProject(
+        'members-only', owner_ids=[self.OWNER_ID],
+        access=project_pb2.ProjectAccess.MEMBERS_ONLY)
+
+    self.mox = mox.Mox()
+
+  def tearDown(self):
+    self.mox.UnsetStubs()
+
+  def CheckPermissions(self, perms, expect_view, expect_commit, expect_edit):
+    """Assert the view/commit/edit-project bits of a PermissionSet."""
+    may_view = perms.HasPerm(permissions.VIEW, None, None)
+    self.assertEqual(expect_view, may_view)
+    may_commit = perms.HasPerm(permissions.COMMIT, None, None)
+    self.assertEqual(expect_commit, may_commit)
+    may_edit = perms.HasPerm(permissions.EDIT_PROJECT, None, None)
+    self.assertEqual(expect_edit, may_edit)
+
+  def MakeRequestAsUser(self, project_name, email):
+    """Parse a /p/<project_name> request as if signed in as email."""
+    self.mox.StubOutWithMock(users, 'get_current_user')
+    users.get_current_user().AndReturn(testing_helpers.Blank(
+        email=lambda: email))
+    self.mox.ReplayAll()
+
+    request = webapp2.Request.blank('/p/' + project_name)
+    mr = monorailrequest.MonorailRequest(self.services)
+    with mr.profiler.Phase('parse user info'):
+      mr.ParseRequest(request, self.services)
+      print('mr.auth is %r' % mr.auth)
+    return mr
+
+  def testOwnerPermissions_Live(self):
+    mr = self.MakeRequestAsUser('live', 'owner@gmail.com')
+    self.CheckPermissions(mr.perms, True, True, True)
+
+  def testOwnerPermissions_Archived(self):
+    # Owners of archived projects may view and edit, but not commit.
+    mr = self.MakeRequestAsUser('archived', 'owner@gmail.com')
+    self.CheckPermissions(mr.perms, True, False, True)
+
+  def testOwnerPermissions_MembersOnly(self):
+    mr = self.MakeRequestAsUser('members-only', 'owner@gmail.com')
+    self.CheckPermissions(mr.perms, True, True, True)
+
+  def testExternalUserPermissions_Live(self):
+    mr = self.MakeRequestAsUser('live', 'user@gmail.com')
+    self.CheckPermissions(mr.perms, True, False, False)
+
+  def testExternalUserPermissions_Archived(self):
+    mr = self.MakeRequestAsUser('archived', 'user@gmail.com')
+    self.CheckPermissions(mr.perms, False, False, False)
+
+  def testExternalUserPermissions_MembersOnly(self):
+    mr = self.MakeRequestAsUser('members-only', 'user@gmail.com')
+    self.CheckPermissions(mr.perms, False, False, False)
diff --git a/framework/test/paginate_test.py b/framework/test/paginate_test.py
new file mode 100644
index 0000000..99adaa9
--- /dev/null
+++ b/framework/test/paginate_test.py
@@ -0,0 +1,145 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for pagination classes."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+from google.appengine.ext import testbed
+
+from framework import exceptions
+from framework import paginate
+from testing import testing_helpers
+from proto import secrets_pb2
+
+
+class PageTokenTest(unittest.TestCase):
+  """Tests for signed page-token generation and validation."""
+
+  def setUp(self):
+    # Token signing uses datastore/memcache-backed secrets, so stub them.
+    self.testbed = testbed.Testbed()
+    self.testbed.activate()
+    self.testbed.init_memcache_stub()
+    self.testbed.init_datastore_v3_stub()
+
+  def testGeneratePageToken_DiffRequests(self):
+    # Tokens are bound to the request contents: a different page_size
+    # must yield a different token even for the same start index.
+    request_cont_1 = secrets_pb2.ListRequestContents(
+        parent='same', page_size=1, order_by='same', query='same')
+    request_cont_2 = secrets_pb2.ListRequestContents(
+        parent='same', page_size=2, order_by='same', query='same')
+    start = 10
+    self.assertNotEqual(
+        paginate.GeneratePageToken(request_cont_1, start),
+        paginate.GeneratePageToken(request_cont_2, start))
+
+  def testValidateAndParsePageToken(self):
+    request_cont_1 = secrets_pb2.ListRequestContents(
+        parent='projects/chicken', page_size=1, order_by='boks', query='hay')
+    start = 2
+    token = paginate.GeneratePageToken(request_cont_1, start)
+    self.assertEqual(
+        start,
+        paginate.ValidateAndParsePageToken(token, request_cont_1))
+
+  def testValidateAndParsePageToken_InvalidContents(self):
+    # A token issued for one request must not validate for another.
+    request_cont_1 = secrets_pb2.ListRequestContents(
+        parent='projects/chicken', page_size=1, order_by='boks', query='hay')
+    start = 2
+    token = paginate.GeneratePageToken(request_cont_1, start)
+
+    request_cont_diff = secrets_pb2.ListRequestContents(
+        parent='projects/goose', page_size=1, order_by='boks', query='hay')
+    with self.assertRaises(exceptions.PageTokenException):
+      paginate.ValidateAndParsePageToken(token, request_cont_diff)
+
+  def testValidateAndParsePageToken_InvalidSerializedToken(self):
+    request_cont = secrets_pb2.ListRequestContents()
+    with self.assertRaises(exceptions.PageTokenException):
+      paginate.ValidateAndParsePageToken('sldkfj87', request_cont)
+
+  def testValidateAndParsePageToken_InvalidTokenFormat(self):
+    request_cont = secrets_pb2.ListRequestContents()
+    with self.assertRaises(exceptions.PageTokenException):
+      paginate.ValidateAndParsePageToken('///sldkfj87', request_cont)
+
+
+class PaginateTest(unittest.TestCase):
+  """Tests for VirtualPagination start/last/num arithmetic and URLs."""
+
+  def testVirtualPagination(self):
+    # Paginating 0 results on a page that can hold 100.
+    # (The mr objects below document the equivalent request; VirtualPagination
+    # itself is driven by the explicit arguments.)
+    mr = testing_helpers.MakeMonorailRequest(path='/issues/list')
+    total_count = 0
+    items_per_page = 100
+    start = 0
+    vp = paginate.VirtualPagination(total_count, items_per_page, start)
+    self.assertEqual(vp.num, 100)
+    self.assertEqual(vp.start, 1)
+    self.assertEqual(vp.last, 0)
+    self.assertFalse(vp.visible)
+
+    # Paginating 12 results on a page that can hold 100.
+    mr = testing_helpers.MakeMonorailRequest(path='/issues/list')
+    vp = paginate.VirtualPagination(12, 100, 0)
+    self.assertEqual(vp.num, 100)
+    self.assertEqual(vp.start, 1)
+    self.assertEqual(vp.last, 12)
+    self.assertTrue(vp.visible)
+
+    # Paginating 12 results on a page that can hold 10.
+    mr = testing_helpers.MakeMonorailRequest(path='/issues/list?num=10')
+    vp = paginate.VirtualPagination(12, 10, 0)
+    self.assertEqual(vp.num, 10)
+    self.assertEqual(vp.start, 1)
+    self.assertEqual(vp.last, 10)
+    self.assertTrue(vp.visible)
+
+    # Paginating 12 results starting at 5 on page that can hold 10.
+    mr = testing_helpers.MakeMonorailRequest(
+        path='/issues/list?start=5&num=10')
+    vp = paginate.VirtualPagination(12, 10, 5)
+    self.assertEqual(vp.num, 10)
+    self.assertEqual(vp.start, 6)
+    self.assertEqual(vp.last, 12)
+    self.assertTrue(vp.visible)
+
+    # Paginating 123 results on a page that can hold 100.
+    mr = testing_helpers.MakeMonorailRequest(path='/issues/list')
+    vp = paginate.VirtualPagination(123, 100, 0)
+    self.assertEqual(vp.num, 100)
+    self.assertEqual(vp.start, 1)
+    self.assertEqual(vp.last, 100)
+    self.assertTrue(vp.visible)
+
+    # Paginating 123 results on second page that can hold 100.
+    mr = testing_helpers.MakeMonorailRequest(path='/issues/list?start=100')
+    vp = paginate.VirtualPagination(123, 100, 100)
+    self.assertEqual(vp.num, 100)
+    self.assertEqual(vp.start, 101)
+    self.assertEqual(vp.last, 123)
+    self.assertTrue(vp.visible)
+
+    # Paginating a huge number of objects will show at most 1000 per page.
+    mr = testing_helpers.MakeMonorailRequest(path='/issues/list?num=9999')
+    vp = paginate.VirtualPagination(12345, 9999, 0)
+    self.assertEqual(vp.num, 1000)
+    self.assertEqual(vp.start, 1)
+    self.assertEqual(vp.last, 1000)
+    self.assertTrue(vp.visible)
+
+    # Test urls for a hotlist pagination
+    mr = testing_helpers.MakeMonorailRequest(
+        path='/u/hotlists/17?num=5&start=4')
+    mr.hotlist_id = 17
+    mr.auth.user_id = 112
+    vp = paginate.VirtualPagination(12, 5, 4,
+        list_page_url='/u/112/hotlists/17')
+    self.assertEqual(vp.num, 5)
+    self.assertEqual(vp.start, 5)
+    self.assertEqual(vp.last, 9)
+    self.assertTrue(vp.visible)
+    self.assertEqual('/u/112/hotlists/17?num=5&start=9', vp.next_url)
+    self.assertEqual('/u/112/hotlists/17?num=5&start=0', vp.prev_url)
diff --git a/framework/test/permissions_test.py b/framework/test/permissions_test.py
new file mode 100644
index 0000000..0917b53
--- /dev/null
+++ b/framework/test/permissions_test.py
@@ -0,0 +1,1860 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for permissions.py."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import time
+import unittest
+
+import mox
+
+import settings
+from framework import authdata
+from framework import framework_constants
+from framework import framework_views
+from framework import permissions
+from proto import features_pb2
+from proto import project_pb2
+from proto import site_pb2
+from proto import tracker_pb2
+from proto import user_pb2
+from proto import usergroup_pb2
+from testing import fake
+from testing import testing_helpers
+from tracker import tracker_bizobj
+
+
class PermissionSetTest(unittest.TestCase):
  """Unit tests for permissions.PermissionSet lookup and perm checks."""

  def setUp(self):
    # Base permission set; perm-name lookups are case-insensitive.
    self.perms = permissions.PermissionSet(['A', 'b', 'Cc'])
    self.proj = project_pb2.Project()
    self.proj.contributor_ids.append(111)
    self.proj.contributor_ids.append(222)
    self.proj.extra_perms.append(project_pb2.Project.ExtraPerms(
        member_id=111, perms=['Cc', 'D', 'e', 'Ff']))
    self.proj.extra_perms.append(project_pb2.Project.ExtraPerms(
        member_id=222, perms=['G', 'H']))
    # user 3 used to be a member and had extra perms, but no longer in project.
    self.proj.extra_perms.append(project_pb2.Project.ExtraPerms(
        member_id=333, perms=['G', 'H']))

  def testGetAttr(self):
    """Perms are exposed as boolean attributes, case-insensitively."""
    self.assertTrue(self.perms.a)
    self.assertTrue(self.perms.A)
    self.assertTrue(self.perms.b)
    self.assertTrue(self.perms.Cc)
    self.assertTrue(self.perms.CC)

    self.assertFalse(self.perms.z)
    self.assertFalse(self.perms.Z)

  def testCanUsePerm_Anonymous(self):
    """Anonymous users get only the base perms, no extra perms."""
    effective_ids = set()
    self.assertTrue(self.perms.CanUsePerm('A', effective_ids, self.proj, []))
    self.assertFalse(self.perms.CanUsePerm('D', effective_ids, self.proj, []))
    self.assertFalse(self.perms.CanUsePerm('Z', effective_ids, self.proj, []))

  def testCanUsePerm_SignedInNoGroups(self):
    """A member gets base perms plus their own extra perms only."""
    effective_ids = {111}
    self.assertTrue(self.perms.CanUsePerm('A', effective_ids, self.proj, []))
    self.assertTrue(self.perms.CanUsePerm('D', effective_ids, self.proj, []))
    self.assertTrue(self.perms.CanUsePerm(
        'D', effective_ids, self.proj, ['Restrict-D-A']))
    self.assertFalse(self.perms.CanUsePerm('G', effective_ids, self.proj, []))
    self.assertFalse(self.perms.CanUsePerm('Z', effective_ids, self.proj, []))

    effective_ids = {222}
    self.assertTrue(self.perms.CanUsePerm('A', effective_ids, self.proj, []))
    self.assertFalse(self.perms.CanUsePerm('D', effective_ids, self.proj, []))
    self.assertTrue(self.perms.CanUsePerm('G', effective_ids, self.proj, []))
    self.assertFalse(self.perms.CanUsePerm('Z', effective_ids, self.proj, []))
    self.assertFalse(self.perms.CanUsePerm(
        'Z', effective_ids, self.proj, ['Restrict-Z-A']))

  def testCanUsePerm_SignedInWithGroups(self):
    """A user in groups gets the union of all member extra perms."""
    effective_ids = {111, 222, 333}
    self.assertTrue(self.perms.CanUsePerm('A', effective_ids, self.proj, []))
    self.assertTrue(self.perms.CanUsePerm('D', effective_ids, self.proj, []))
    self.assertTrue(self.perms.CanUsePerm('G', effective_ids, self.proj, []))
    self.assertTrue(self.perms.CanUsePerm(
        'G', effective_ids, self.proj, ['Restrict-G-D']))
    self.assertFalse(self.perms.CanUsePerm('Z', effective_ids, self.proj, []))
    self.assertFalse(self.perms.CanUsePerm(
        'G', effective_ids, self.proj, ['Restrict-G-Z']))

  def testCanUsePerm_FormerMember(self):
    """Left-over extra perms of a former member do not count."""
    effective_ids = {333}
    self.assertTrue(self.perms.CanUsePerm('A', effective_ids, self.proj, []))
    self.assertFalse(self.perms.CanUsePerm('D', effective_ids, self.proj, []))
    self.assertFalse(self.perms.CanUsePerm('G', effective_ids, self.proj, []))
    self.assertFalse(self.perms.CanUsePerm('Z', effective_ids, self.proj, []))

  def testHasPerm_InPermSet(self):
    """Base perms hold regardless of user id or project."""
    self.assertTrue(self.perms.HasPerm('a', 0, None))
    self.assertTrue(self.perms.HasPerm('a', 0, self.proj))
    self.assertTrue(self.perms.HasPerm('A', 0, None))
    self.assertTrue(self.perms.HasPerm('A', 0, self.proj))
    self.assertFalse(self.perms.HasPerm('Z', 0, None))
    self.assertFalse(self.perms.HasPerm('Z', 0, self.proj))

  def testHasPerm_InExtraPerms(self):
    """Extra perms apply only to the member they were granted to."""
    self.assertTrue(self.perms.HasPerm('d', 111, self.proj))
    self.assertTrue(self.perms.HasPerm('D', 111, self.proj))
    self.assertTrue(self.perms.HasPerm('Cc', 111, self.proj))
    self.assertTrue(self.perms.HasPerm('CC', 111, self.proj))
    self.assertFalse(self.perms.HasPerm('Z', 111, self.proj))

    self.assertFalse(self.perms.HasPerm('d', 222, self.proj))
    self.assertFalse(self.perms.HasPerm('D', 222, self.proj))

    # Only current members can have extra permissions
    self.proj.contributor_ids = []
    self.assertFalse(self.perms.HasPerm('d', 111, self.proj))

    # TODO(jrobbins): also test consider_restrictions=False and
    # restriction labels directly in this class.

  def testHasPerm_OverrideExtraPerms(self):
    """An explicit extra-perms list overrides the project's own."""
    # D is an extra perm for 111...
    self.assertTrue(self.perms.HasPerm('d', 111, self.proj))
    self.assertTrue(self.perms.HasPerm('D', 111, self.proj))
    # ...unless we tell HasPerm it isn't.
    self.assertFalse(self.perms.HasPerm('d', 111, self.proj, []))
    self.assertFalse(self.perms.HasPerm('D', 111, self.proj, []))
    # Perms in self.perms are still considered
    self.assertTrue(self.perms.HasPerm('Cc', 111, self.proj, []))
    self.assertTrue(self.perms.HasPerm('CC', 111, self.proj, []))
    # Z is not an extra perm...
    self.assertFalse(self.perms.HasPerm('Z', 111, self.proj))
    # ...unless we tell HasPerm it is.
    self.assertTrue(self.perms.HasPerm('Z', 111, self.proj, ['z']))

  def testHasPerm_GrantedPerms(self):
    """Granted perms count like base perms, case-insensitively."""
    self.assertTrue(self.perms.CanUsePerm(
        'A', {111}, self.proj, [], granted_perms=['z']))
    self.assertTrue(self.perms.CanUsePerm(
        'a', {111}, self.proj, [], granted_perms=['z']))
    self.assertTrue(self.perms.CanUsePerm(
        'a', {111}, self.proj, [], granted_perms=['a']))
    self.assertTrue(self.perms.CanUsePerm(
        'Z', {111}, self.proj, [], granted_perms=['y', 'z']))
    self.assertTrue(self.perms.CanUsePerm(
        'z', {111}, self.proj, [], granted_perms=['y', 'z']))
    self.assertFalse(self.perms.CanUsePerm(
        'z', {111}, self.proj, [], granted_perms=['y']))

  def testDebugString(self):
    """DebugString lists the lowercased perm names."""
    self.assertEqual('PermissionSet()',
                     permissions.PermissionSet([]).DebugString())
    self.assertEqual('PermissionSet(a)',
                     permissions.PermissionSet(['A']).DebugString())
    self.assertEqual('PermissionSet(a, b, cc)', self.perms.DebugString())

  def testRepr(self):
    """repr() exposes the underlying frozenset of lowercased perms."""
    # NOTE(review): these expected strings match Python 2's frozenset repr;
    # under Python 3 it would be "frozenset()" / "frozenset({'a'})" — confirm
    # the target runtime before porting.
    self.assertEqual('PermissionSet(frozenset([]))',
                     permissions.PermissionSet([]).__repr__())
    self.assertEqual('PermissionSet(frozenset([\'a\']))',
                     permissions.PermissionSet(['A']).__repr__())
+
+
+class PermissionsTest(unittest.TestCase):
+
+ NOW = 1277762224 # Any timestamp will do, we only compare it to itself +/- 1
+ COMMITTER_USER_ID = 111
+ OWNER_USER_ID = 222
+ CONTRIB_USER_ID = 333
+ SITE_ADMIN_USER_ID = 444
+
  def MakeProject(self, project_name, state, add_members=True, access=None):
    """Build a fake Project in the given state, optionally with members."""
    args = dict(project_name=project_name, state=state)
    if add_members:
      # Each role gets exactly one of the class-level test user IDs.
      args.update(owner_ids=[self.OWNER_USER_ID],
                  committer_ids=[self.COMMITTER_USER_ID],
                  contributor_ids=[self.CONTRIB_USER_ID])

    if access:
      args.update(access=access)

    return fake.Project(**args)
+
  def setUp(self):
    # Projects covering the interesting state/access combinations.
    self.live_project = self.MakeProject('live', project_pb2.ProjectState.LIVE)
    self.archived_project = self.MakeProject(
        'archived', project_pb2.ProjectState.ARCHIVED)
    self.other_live_project = self.MakeProject(
        'other_live', project_pb2.ProjectState.LIVE, add_members=False)
    self.members_only_project = self.MakeProject(
        's3kr3t', project_pb2.ProjectState.LIVE,
        access=project_pb2.ProjectAccess.MEMBERS_ONLY)

    # Users in each role; role comes from project membership, not the User.
    self.nonmember = user_pb2.User()
    self.member = user_pb2.User()
    self.owner = user_pb2.User()
    self.contrib = user_pb2.User()
    self.site_admin = user_pb2.User()
    self.site_admin.is_site_admin = True
    self.borg_user = user_pb2.User(email=settings.borg_service_account)

    self.normal_artifact = tracker_pb2.Issue()
    self.normal_artifact.labels.extend(['hot', 'Key-Value'])
    self.normal_artifact.reporter_id = 111

    # Two PermissionSets w/ permissions outside of any project.
    self.normal_user_perms = permissions.GetPermissions(
        None, {111}, None)
    self.admin_perms = permissions.PermissionSet(
        [permissions.ADMINISTER_SITE,
         permissions.CREATE_PROJECT])

    self.mox = mox.Mox()
+
  def tearDown(self):
    # Undo any stubs installed via self.mox during a test.
    self.mox.UnsetStubs()
+
  def testGetPermissions_Admin(self):
    """Site admins get the full admin permission set."""
    self.assertEqual(
        permissions.ADMIN_PERMISSIONSET,
        permissions.GetPermissions(self.site_admin, None, None))
+
  def testGetPermissions_BorgServiceAccount(self):
    """The borg service account gets the group-import permission set."""
    self.assertEqual(
        permissions.GROUP_IMPORT_BORG_PERMISSIONSET,
        permissions.GetPermissions(self.borg_user, None, None))
+
  def CheckPermissions(self, perms, expected_list):
    """Assert view/commit/edit-project perms match [view, commit, edit]."""
    expect_view, expect_commit, expect_edit_project = expected_list
    self.assertEqual(
        expect_view, perms.HasPerm(permissions.VIEW, None, None))
    self.assertEqual(
        expect_commit, perms.HasPerm(permissions.COMMIT, None, None))
    self.assertEqual(
        expect_edit_project,
        perms.HasPerm(permissions.EDIT_PROJECT, None, None))
+
  def testAnonPermissions(self):
    """Anon users can view public projects but not members-only ones."""
    perms = permissions.GetPermissions(None, set(), self.live_project)
    self.CheckPermissions(perms, [True, False, False])

    perms = permissions.GetPermissions(None, set(), self.members_only_project)
    self.CheckPermissions(perms, [False, False, False])
+
  def testNonmemberPermissions(self):
    """Signed-in non-members can view public projects only."""
    perms = permissions.GetPermissions(
        self.nonmember, {123}, self.live_project)
    self.CheckPermissions(perms, [True, False, False])

    perms = permissions.GetPermissions(
        self.nonmember, {123}, self.members_only_project)
    self.CheckPermissions(perms, [False, False, False])
+
  def testMemberPermissions(self):
    """Committers can view and commit in their own projects only."""
    perms = permissions.GetPermissions(
        self.member, {self.COMMITTER_USER_ID}, self.live_project)
    self.CheckPermissions(perms, [True, True, False])

    perms = permissions.GetPermissions(
        self.member, {self.COMMITTER_USER_ID}, self.other_live_project)
    self.CheckPermissions(perms, [True, False, False])

    perms = permissions.GetPermissions(
        self.member, {self.COMMITTER_USER_ID}, self.members_only_project)
    self.CheckPermissions(perms, [True, True, False])
+
  def testOwnerPermissions(self):
    """Owners can also edit their own projects, but not other projects."""
    perms = permissions.GetPermissions(
        self.owner, {self.OWNER_USER_ID}, self.live_project)
    self.CheckPermissions(perms, [True, True, True])

    perms = permissions.GetPermissions(
        self.owner, {self.OWNER_USER_ID}, self.other_live_project)
    self.CheckPermissions(perms, [True, False, False])

    perms = permissions.GetPermissions(
        self.owner, {self.OWNER_USER_ID}, self.members_only_project)
    self.CheckPermissions(perms, [True, True, True])
+
  def testContributorPermissions(self):
    """Contributors can view, including members-only projects they are in."""
    perms = permissions.GetPermissions(
        self.contrib, {self.CONTRIB_USER_ID}, self.live_project)
    self.CheckPermissions(perms, [True, False, False])

    perms = permissions.GetPermissions(
        self.contrib, {self.CONTRIB_USER_ID}, self.other_live_project)
    self.CheckPermissions(perms, [True, False, False])

    perms = permissions.GetPermissions(
        self.contrib, {self.CONTRIB_USER_ID}, self.members_only_project)
    self.CheckPermissions(perms, [True, False, False])
+
  def testLookupPermset_ExactMatch(self):
    """An exact (role, state, access) key finds its permission set."""
    self.assertEqual(
        permissions.USER_PERMISSIONSET,
        permissions._LookupPermset(
            permissions.USER_ROLE, project_pb2.ProjectState.LIVE,
            project_pb2.ProjectAccess.ANYONE))
+
  def testLookupPermset_WildcardAccess(self):
    """A key with no exact access match falls back to the wildcard entry."""
    self.assertEqual(
        permissions.OWNER_ACTIVE_PERMISSIONSET,
        permissions._LookupPermset(
            permissions.OWNER_ROLE, project_pb2.ProjectState.LIVE,
            project_pb2.ProjectAccess.MEMBERS_ONLY))
+
  def testGetPermissionKey_AnonUser(self):
    """Anon users map to ANON_ROLE; no project means undefined state/access."""
    self.assertEqual(
        (permissions.ANON_ROLE, permissions.UNDEFINED_STATUS,
         permissions.UNDEFINED_ACCESS),
        permissions._GetPermissionKey(None, None))
    self.assertEqual(
        (permissions.ANON_ROLE, project_pb2.ProjectState.LIVE,
         project_pb2.ProjectAccess.ANYONE),
        permissions._GetPermissionKey(None, self.live_project))
+
  def testGetPermissionKey_ExpiredProject(self):
    """After a project expires, member roles are demoted to plain user."""
    self.archived_project.delete_time = self.NOW
    # In an expired project, the user's committer role does not count.
    self.assertEqual(
        (permissions.USER_ROLE, project_pb2.ProjectState.ARCHIVED,
         project_pb2.ProjectAccess.ANYONE),
        permissions._GetPermissionKey(
            self.COMMITTER_USER_ID, self.archived_project,
            expired_before=self.NOW + 1))
    # If not expired yet, the user's committer role still counts.
    self.assertEqual(
        (permissions.COMMITTER_ROLE, project_pb2.ProjectState.ARCHIVED,
         project_pb2.ProjectAccess.ANYONE),
        permissions._GetPermissionKey(
            self.COMMITTER_USER_ID, self.archived_project,
            expired_before=self.NOW - 1))
+
  def testGetPermissionKey_DefinedRoles(self):
    """Each membership list maps to its corresponding role constant."""
    self.assertEqual(
        (permissions.OWNER_ROLE, project_pb2.ProjectState.LIVE,
         project_pb2.ProjectAccess.ANYONE),
        permissions._GetPermissionKey(
            self.OWNER_USER_ID, self.live_project))
    self.assertEqual(
        (permissions.COMMITTER_ROLE, project_pb2.ProjectState.LIVE,
         project_pb2.ProjectAccess.ANYONE),
        permissions._GetPermissionKey(
            self.COMMITTER_USER_ID, self.live_project))
    self.assertEqual(
        (permissions.CONTRIBUTOR_ROLE, project_pb2.ProjectState.LIVE,
         project_pb2.ProjectAccess.ANYONE),
        permissions._GetPermissionKey(
            self.CONTRIB_USER_ID, self.live_project))
+
  def testGetPermissionKey_Nonmember(self):
    """A signed-in user who is not in the project maps to USER_ROLE."""
    self.assertEqual(
        (permissions.USER_ROLE, project_pb2.ProjectState.LIVE,
         project_pb2.ProjectAccess.ANYONE),
        permissions._GetPermissionKey(
            999, self.live_project))
+
  def testPermissionsImmutable(self):
    """Built-in permission sets are backed by immutable frozensets."""
    self.assertTrue(isinstance(
        permissions.EMPTY_PERMISSIONSET.perm_names, frozenset))
    self.assertTrue(isinstance(
        permissions.READ_ONLY_PERMISSIONSET.perm_names, frozenset))
    self.assertTrue(isinstance(
        permissions.COMMITTER_ACTIVE_PERMISSIONSET.perm_names, frozenset))
    self.assertTrue(isinstance(
        permissions.OWNER_ACTIVE_PERMISSIONSET.perm_names, frozenset))
+
  def testGetExtraPerms(self):
    """Extra perms are returned only for current project members."""
    project = project_pb2.Project()
    project.committer_ids.append(222)
    # User 1 is a former member with left-over extra perms that don't count.
    project.extra_perms.append(project_pb2.Project.ExtraPerms(
        member_id=111, perms=['a', 'b', 'c']))
    project.extra_perms.append(project_pb2.Project.ExtraPerms(
        member_id=222, perms=['a', 'b', 'c']))

    self.assertListEqual(
        [],
        permissions.GetExtraPerms(project, 111))
    self.assertListEqual(
        ['a', 'b', 'c'],
        permissions.GetExtraPerms(project, 222))
    self.assertListEqual(
        [],
        permissions.GetExtraPerms(project, 333))
+
  def testCanDeleteComment_NoPermissionSet(self):
    """Test that if no PermissionSet is given, we can't delete comments."""
    comment = tracker_pb2.IssueComment()
    commenter = user_pb2.User()
    # If no PermissionSet is given, the user cannot delete the comment.
    self.assertFalse(permissions.CanDeleteComment(
        comment, commenter, 111, None))
    # Same, with no user specified.
    self.assertFalse(permissions.CanDeleteComment(
        comment, commenter, framework_constants.NO_USER_SPECIFIED, None))
+
  def testCanDeleteComment_AnonUsersCannotDelete(self):
    """Test that anon users can't delete comments."""
    comment = tracker_pb2.IssueComment()
    commenter = user_pb2.User()
    perms = permissions.PermissionSet([permissions.DELETE_ANY])

    # No logged in user, even with perms from somewhere.
    self.assertFalse(permissions.CanDeleteComment(
        comment, commenter, framework_constants.NO_USER_SPECIFIED, perms))

    # No logged in user, even if artifact was already deleted.
    comment.deleted_by = 111
    self.assertFalse(permissions.CanDeleteComment(
        comment, commenter, framework_constants.NO_USER_SPECIFIED, perms))
+
  def testCanDeleteComment_DeleteAny(self):
    """Test that users with DeleteAny permission can delete any comment.

    Except for spam comments or comments by banned users.
    """
    comment = tracker_pb2.IssueComment(user_id=111)
    commenter = user_pb2.User()
    perms = permissions.PermissionSet([permissions.DELETE_ANY])

    # Users with DeleteAny permission can delete their own comments.
    self.assertTrue(permissions.CanDeleteComment(
        comment, commenter, 111, perms))

    # And also comments by other users
    comment.user_id = 999
    self.assertTrue(permissions.CanDeleteComment(
        comment, commenter, 111, perms))

    # As well as undelete comments they deleted.
    comment.deleted_by = 111
    self.assertTrue(permissions.CanDeleteComment(
        comment, commenter, 111, perms))

    # Or that other users deleted.
    comment.deleted_by = 222
    self.assertTrue(permissions.CanDeleteComment(
        comment, commenter, 111, perms))
+
  def testCanDeleteComment_DeleteOwn(self):
    """Test that users with DeleteOwn permission can delete own comments.

    But not comments of others, spam comments, or comments by banned users.
    """
    comment = tracker_pb2.IssueComment(user_id=111)
    commenter = user_pb2.User()
    perms = permissions.PermissionSet([permissions.DELETE_OWN])

    # Users with DeleteOwn permission can delete their own comments.
    self.assertTrue(permissions.CanDeleteComment(
        comment, commenter, 111, perms))

    # But not comments by other users
    comment.user_id = 999
    self.assertFalse(permissions.CanDeleteComment(
        comment, commenter, 111, perms))

    # They can undelete comments they deleted.
    comment.user_id = 111
    comment.deleted_by = 111
    self.assertTrue(permissions.CanDeleteComment(
        comment, commenter, 111, perms))

    # But not comments that other users deleted.
    comment.deleted_by = 222
    self.assertFalse(permissions.CanDeleteComment(
        comment, commenter, 111, perms))
+
  def testCanDeleteComment_CannotDeleteSpamComments(self):
    """Test that nobody can (un)delete comments marked as spam."""
    comment = tracker_pb2.IssueComment(user_id=111, is_spam=True)
    commenter = user_pb2.User()

    # Nobody can delete comments marked as spam.
    self.assertFalse(permissions.CanDeleteComment(
        comment, commenter, 111,
        permissions.PermissionSet([permissions.DELETE_OWN])))
    self.assertFalse(permissions.CanDeleteComment(
        comment, commenter, 222,
        permissions.PermissionSet([permissions.DELETE_ANY])))

    # Nobody can undelete comments marked as spam.
    comment.deleted_by = 222
    self.assertFalse(permissions.CanDeleteComment(
        comment, commenter, 111,
        permissions.PermissionSet([permissions.DELETE_OWN])))
    self.assertFalse(permissions.CanDeleteComment(
        comment, commenter, 222,
        permissions.PermissionSet([permissions.DELETE_ANY])))
+
  def testCanDeleteComment_CannotDeleteCommentsByBannedUser(self):
    """Test that nobody can (un)delete comments by banned users."""
    comment = tracker_pb2.IssueComment(user_id=111)
    commenter = user_pb2.User(banned='Some reason')

    # Nobody can delete comments by banned users.
    self.assertFalse(permissions.CanDeleteComment(
        comment, commenter, 111,
        permissions.PermissionSet([permissions.DELETE_OWN])))
    self.assertFalse(permissions.CanDeleteComment(
        comment, commenter, 222,
        permissions.PermissionSet([permissions.DELETE_ANY])))

    # Nobody can undelete comments by banned users.
    comment.deleted_by = 222
    self.assertFalse(permissions.CanDeleteComment(
        comment, commenter, 111,
        permissions.PermissionSet([permissions.DELETE_OWN])))
    self.assertFalse(permissions.CanDeleteComment(
        comment, commenter, 222,
        permissions.PermissionSet([permissions.DELETE_ANY])))
+
  def testCanFlagComment_FlagSpamCanReport(self):
    """Test that users with FlagSpam permissions can report comments."""
    comment = tracker_pb2.IssueComment()
    commenter = user_pb2.User()

    # CanFlagComment returns (can_flag, is_flagged_by_this_user).
    can_flag, is_flagged = permissions.CanFlagComment(
        comment, commenter, [], 111,
        permissions.PermissionSet([permissions.FLAG_SPAM]))

    self.assertTrue(can_flag)
    self.assertFalse(is_flagged)
+
  def testCanFlagComment_FlagSpamCanUnReportOwn(self):
    """Test that users with FlagSpam permission can un-report comments they
    previously reported."""
    comment = tracker_pb2.IssueComment()
    commenter = user_pb2.User()

    # Reporter 111 appears in the list of users who flagged this comment.
    can_flag, is_flagged = permissions.CanFlagComment(
        comment, commenter, [111], 111,
        permissions.PermissionSet([permissions.FLAG_SPAM]))

    self.assertTrue(can_flag)
    self.assertTrue(is_flagged)
+
  def testCanFlagComment_FlagSpamCannotUnReportOthers(self):
    """Test that users with FlagSpam permission don't learn whether other
    users have reported a comment as spam."""
    comment = tracker_pb2.IssueComment()
    commenter = user_pb2.User()

    # Only user 222 flagged the comment; user 111 should not see that.
    can_flag, is_flagged = permissions.CanFlagComment(
        comment, commenter, [222], 111,
        permissions.PermissionSet([permissions.FLAG_SPAM]))

    self.assertTrue(can_flag)
    self.assertFalse(is_flagged)
+
  def testCanFlagComment_FlagSpamCannotUnFlag(self):
    """Test that FlagSpam alone cannot un-flag a comment already marked spam."""
    comment = tracker_pb2.IssueComment(is_spam=True)
    commenter = user_pb2.User()

    can_flag, is_flagged = permissions.CanFlagComment(
        comment, commenter, [111], 111,
        permissions.PermissionSet([permissions.FLAG_SPAM]))

    self.assertFalse(can_flag)
    self.assertTrue(is_flagged)
+
  def testCanFlagComment_VerdictSpamCanFlag(self):
    """Test that users with VerdictSpam permissions can flag comments."""
    comment = tracker_pb2.IssueComment()
    commenter = user_pb2.User()

    can_flag, is_flagged = permissions.CanFlagComment(
        comment, commenter, [], 111,
        permissions.PermissionSet([permissions.VERDICT_SPAM]))

    self.assertTrue(can_flag)
    self.assertFalse(is_flagged)
+
  def testCanFlagComment_VerdictSpamCanUnFlag(self):
    """Test that users with VerdictSpam permissions can un-flag comments."""
    comment = tracker_pb2.IssueComment(is_spam=True)
    commenter = user_pb2.User()

    can_flag, is_flagged = permissions.CanFlagComment(
        comment, commenter, [], 111,
        permissions.PermissionSet([permissions.VERDICT_SPAM]))

    self.assertTrue(can_flag)
    self.assertTrue(is_flagged)
+
  def testCanFlagComment_CannotFlagNoPermission(self):
    """Test that users without permission cannot flag comments."""
    comment = tracker_pb2.IssueComment()
    commenter = user_pb2.User()

    # DeleteAny grants no spam-flagging ability.
    can_flag, is_flagged = permissions.CanFlagComment(
        comment, commenter, [], 111,
        permissions.PermissionSet([permissions.DELETE_ANY]))

    self.assertFalse(can_flag)
    self.assertFalse(is_flagged)
+
  def testCanFlagComment_CannotUnFlagNoPermission(self):
    """Test that users without permission cannot un-flag comments."""
    comment = tracker_pb2.IssueComment(is_spam=True)
    commenter = user_pb2.User()

    can_flag, is_flagged = permissions.CanFlagComment(
        comment, commenter, [], 111,
        # Users need the VerdictSpam permission to be able to un-flag comments.
        permissions.PermissionSet([
            permissions.DELETE_ANY, permissions.FLAG_SPAM]))

    self.assertFalse(can_flag)
    self.assertTrue(is_flagged)
+
  def testCanFlagComment_CannotFlagCommentByBannedUser(self):
    """Test that nobody can flag comments by banned users."""
    comment = tracker_pb2.IssueComment()
    commenter = user_pb2.User(banned='Some reason')

    can_flag, is_flagged = permissions.CanFlagComment(
        comment, commenter, [], 111,
        permissions.PermissionSet([
            permissions.FLAG_SPAM, permissions.VERDICT_SPAM]))

    self.assertFalse(can_flag)
    self.assertFalse(is_flagged)
+
  def testCanFlagComment_CannotUnFlagCommentByBannedUser(self):
    """Test that nobody can un-flag comments by banned users."""
    comment = tracker_pb2.IssueComment(is_spam=True)
    commenter = user_pb2.User(banned='Some reason')

    can_flag, is_flagged = permissions.CanFlagComment(
        comment, commenter, [], 111,
        permissions.PermissionSet([
            permissions.FLAG_SPAM, permissions.VERDICT_SPAM]))

    self.assertFalse(can_flag)
    self.assertTrue(is_flagged)
+
  def testCanFlagComment_CanUnFlagDeletedSpamComment(self):
    """Test that we can un-flag a deleted comment that is spam."""
    comment = tracker_pb2.IssueComment(is_spam=True, deleted_by=111)
    commenter = user_pb2.User()

    can_flag, is_flagged = permissions.CanFlagComment(
        comment, commenter, [], 222,
        permissions.PermissionSet([permissions.VERDICT_SPAM]))

    self.assertTrue(can_flag)
    self.assertTrue(is_flagged)
+
  def testCanFlagComment_CannotFlagDeletedComment(self):
    """Test that nobody can flag a deleted comment that is not spam."""
    comment = tracker_pb2.IssueComment(deleted_by=111)
    commenter = user_pb2.User()

    # Even the full set of spam/delete perms does not allow it.
    can_flag, is_flagged = permissions.CanFlagComment(
        comment, commenter, [], 111,
        permissions.PermissionSet([
            permissions.FLAG_SPAM, permissions.VERDICT_SPAM,
            permissions.DELETE_ANY, permissions.DELETE_OWN]))

    self.assertFalse(can_flag)
    self.assertFalse(is_flagged)
+
  def testCanViewComment_Normal(self):
    """Test that we can view comments."""
    comment = tracker_pb2.IssueComment()
    commenter = user_pb2.User()
    # We assume that CanViewIssue was already called. There are no further
    # restrictions to view this comment.
    self.assertTrue(permissions.CanViewComment(
        comment, commenter, 111, None))
+
  def testCanViewComment_CannotViewCommentsByBannedUser(self):
    """Test that nobody can view comments by banned users."""
    comment = tracker_pb2.IssueComment(user_id=111)
    commenter = user_pb2.User(banned='Some reason')

    # Nobody can view comments by banned users, not even site admins.
    self.assertFalse(permissions.CanViewComment(
        comment, commenter, 111, permissions.ADMIN_PERMISSIONSET))
+
  def testCanViewComment_OnlyModeratorsCanViewSpamComments(self):
    """Test that only users with VerdictSpam can view spam comments."""
    comment = tracker_pb2.IssueComment(user_id=111, is_spam=True)
    commenter = user_pb2.User()

    # Users with VerdictSpam permission can view comments marked as spam.
    self.assertTrue(permissions.CanViewComment(
        comment, commenter, 222,
        permissions.PermissionSet([permissions.VERDICT_SPAM])))

    # Other users cannot view comments marked as spam, even if it is their own
    # comment.
    self.assertFalse(permissions.CanViewComment(
        comment, commenter, 111,
        permissions.PermissionSet([
            permissions.FLAG_SPAM, permissions.DELETE_ANY,
            permissions.DELETE_OWN])))
+
  def testCanViewComment_DeletedComment(self):
    """Test that for deleted comments, only the users that can undelete it can
    view it.
    """
    comment = tracker_pb2.IssueComment(user_id=111, deleted_by=222)
    commenter = user_pb2.User()

    # Users with DeleteAny permission can view all deleted comments.
    self.assertTrue(permissions.CanViewComment(
        comment, commenter, 333,
        permissions.PermissionSet([permissions.DELETE_ANY])))

    # Users with DeleteOwn permissions can only see their own comments if they
    # deleted them.
    comment.user_id = comment.deleted_by = 333
    self.assertTrue(permissions.CanViewComment(
        comment, commenter, 333,
        permissions.PermissionSet([permissions.DELETE_OWN])))

    # But not comments they didn't delete.
    comment.deleted_by = 111
    self.assertFalse(permissions.CanViewComment(
        comment, commenter, 333,
        permissions.PermissionSet([permissions.DELETE_OWN])))
+
  def testCanViewInboundMessage(self):
    """Only the author or ViewInboundMessages holders may view the message."""
    comment = tracker_pb2.IssueComment(user_id=111)

    # Users can view their own inbound messages
    self.assertTrue(permissions.CanViewInboundMessage(
        comment, 111, permissions.EMPTY_PERMISSIONSET))

    # Users with the ViewInboundMessages permissions can view inbound messages.
    self.assertTrue(permissions.CanViewInboundMessage(
        comment, 333,
        permissions.PermissionSet([permissions.VIEW_INBOUND_MESSAGES])))

    # Other users cannot view inbound messages.
    self.assertFalse(permissions.CanViewInboundMessage(
        comment, 333,
        permissions.PermissionSet([permissions.VIEW])))
+
  # NOTE(review): method name has a typo ("Arifact" -> "Artifact"); left
  # as-is because renaming would change the test's public identifier.
  def testCanViewNormalArifact(self):
    """Anyone can view a non-restricted artifact."""
    self.assertTrue(permissions.CanView(
        {111}, permissions.READ_ONLY_PERMISSIONSET,
        self.live_project, []))
+
  def testCanCreateProject_NoPerms(self):
    """Signed out users cannot create projects."""
    self.assertFalse(permissions.CanCreateProject(
        permissions.EMPTY_PERMISSIONSET))

    self.assertFalse(permissions.CanCreateProject(
        permissions.READ_ONLY_PERMISSIONSET))
+
  def testCanCreateProject_Admin(self):
    """Site admins can create projects."""
    self.assertTrue(permissions.CanCreateProject(
        permissions.ADMIN_PERMISSIONSET))
+
+ def testCanCreateProject_RegularUser(self):
+ """Signed in non-admins can create a project if settings allow ANYONE."""
+ try:
+ orig_restriction = settings.project_creation_restriction
+ ANYONE = site_pb2.UserTypeRestriction.ANYONE
+ ADMIN_ONLY = site_pb2.UserTypeRestriction.ADMIN_ONLY
+ NO_ONE = site_pb2.UserTypeRestriction.NO_ONE
+ perms = permissions.PermissionSet([permissions.CREATE_PROJECT])
+
+ settings.project_creation_restriction = ANYONE
+ self.assertTrue(permissions.CanCreateProject(perms))
+
+ settings.project_creation_restriction = ADMIN_ONLY
+ self.assertFalse(permissions.CanCreateProject(perms))
+
+ settings.project_creation_restriction = NO_ONE
+ self.assertFalse(permissions.CanCreateProject(perms))
+ self.assertFalse(permissions.CanCreateProject(
+ permissions.ADMIN_PERMISSIONSET))
+ finally:
+ settings.project_creation_restriction = orig_restriction
+
  def testCanCreateGroup_AnyoneWithCreateGroup(self):
    """With ANYONE restriction, the CreateGroup perm is still required."""
    orig_setting = settings.group_creation_restriction
    try:
      settings.group_creation_restriction = site_pb2.UserTypeRestriction.ANYONE
      self.assertTrue(permissions.CanCreateGroup(
          permissions.PermissionSet([permissions.CREATE_GROUP])))
      self.assertFalse(permissions.CanCreateGroup(
          permissions.PermissionSet([])))
    finally:
      settings.group_creation_restriction = orig_setting
+
  def testCanCreateGroup_AdminOnly(self):
    """With ADMIN_ONLY restriction, only AdministerSite holders qualify."""
    orig_setting = settings.group_creation_restriction
    try:
      ADMIN_ONLY = site_pb2.UserTypeRestriction.ADMIN_ONLY
      settings.group_creation_restriction = ADMIN_ONLY
      self.assertTrue(permissions.CanCreateGroup(
          permissions.PermissionSet([permissions.ADMINISTER_SITE])))
      self.assertFalse(permissions.CanCreateGroup(
          permissions.PermissionSet([permissions.CREATE_GROUP])))
      self.assertFalse(permissions.CanCreateGroup(
          permissions.PermissionSet([])))
    finally:
      settings.group_creation_restriction = orig_setting
+
  def testCanCreateGroup_UnspecifiedSetting(self):
    """An unset restriction setting denies group creation to everyone."""
    orig_setting = settings.group_creation_restriction
    try:
      settings.group_creation_restriction = None
      self.assertFalse(permissions.CanCreateGroup(
          permissions.PermissionSet([permissions.ADMINISTER_SITE])))
      self.assertFalse(permissions.CanCreateGroup(
          permissions.PermissionSet([permissions.CREATE_GROUP])))
      self.assertFalse(permissions.CanCreateGroup(
          permissions.PermissionSet([])))
    finally:
      settings.group_creation_restriction = orig_setting
+
  def testCanEditGroup_HasPerm(self):
    """The EditGroup perm allows editing regardless of ownership."""
    self.assertTrue(permissions.CanEditGroup(
        permissions.PermissionSet([permissions.EDIT_GROUP]), None, None))
+
  def testCanEditGroup_IsOwner(self):
    """A group owner can edit the group even without the EditGroup perm."""
    self.assertTrue(permissions.CanEditGroup(
        permissions.PermissionSet([]), {111}, {111}))
+
  def testCanEditGroup_Otherwise(self):
    """Non-owners without the EditGroup perm cannot edit the group."""
    self.assertFalse(permissions.CanEditGroup(
        permissions.PermissionSet([]), {111}, {222}))
+
  def testCanViewGroupMembers_HasPerm(self):
    """The ViewGroup perm allows viewing members unconditionally."""
    self.assertTrue(permissions.CanViewGroupMembers(
        permissions.PermissionSet([permissions.VIEW_GROUP]),
        None, None, None, None, None))
+
  def testCanViewGroupMembers_IsMemberOfFriendProject(self):
    """Membership in a listed friend project grants visibility."""
    group_settings = usergroup_pb2.MakeSettings('owners', friend_projects=[890])
    # Project 789 is not a friend project; 890 is.
    self.assertFalse(permissions.CanViewGroupMembers(
        permissions.PermissionSet([]),
        {111}, group_settings, {222}, {333}, {789}))
    self.assertTrue(permissions.CanViewGroupMembers(
        permissions.PermissionSet([]),
        {111}, group_settings, {222}, {333}, {789, 890}))
+
  def testCanViewGroupMembers_VisibleToOwner(self):
    """With 'owners' visibility, only the group owners can view members."""
    group_settings = usergroup_pb2.MakeSettings('owners')
    # Args are (perms, effective_ids, settings, member_ids, owner_ids,
    # project_ids); only {333} is in the owner set here.
    self.assertFalse(permissions.CanViewGroupMembers(
        permissions.PermissionSet([]),
        {111}, group_settings, {222}, {333}, {789}))
    self.assertFalse(permissions.CanViewGroupMembers(
        permissions.PermissionSet([]),
        {222}, group_settings, {222}, {333}, {789}))
    self.assertTrue(permissions.CanViewGroupMembers(
        permissions.PermissionSet([]),
        {333}, group_settings, {222}, {333}, {789}))
+
  def testCanViewGroupMembers_IsVisibleToMember(self):
    # With 'members' visibility, members and owners see the member list.
    group_settings = usergroup_pb2.MakeSettings('members')
    # Outsider (111): denied.
    self.assertFalse(permissions.CanViewGroupMembers(
        permissions.PermissionSet([]),
        {111}, group_settings, {222}, {333}, {789}))
    # Member (222): allowed.
    self.assertTrue(permissions.CanViewGroupMembers(
        permissions.PermissionSet([]),
        {222}, group_settings, {222}, {333}, {789}))
    # Owner (333): allowed.
    self.assertTrue(permissions.CanViewGroupMembers(
        permissions.PermissionSet([]),
        {333}, group_settings, {222}, {333}, {789}))
+
  def testCanViewGroupMembers_AnyoneCanView(self):
    # With 'anyone' visibility, even an unrelated user sees the members.
    group_settings = usergroup_pb2.MakeSettings('anyone')
    self.assertTrue(permissions.CanViewGroupMembers(
        permissions.PermissionSet([]),
        {111}, group_settings, {222}, {333}, {789}))
+
  def testIsBanned_AnonUser(self):
    # An anonymous visitor (no user PB) is never considered banned.
    user_view = framework_views.StuffUserView(None, None, True)
    self.assertFalse(permissions.IsBanned(None, user_view))
+
  def testIsBanned_NormalUser(self):
    # A default user PB has no banned reason, so the user is not banned.
    user = user_pb2.User()
    user_view = framework_views.StuffUserView(None, None, True)
    self.assertFalse(permissions.IsBanned(user, user_view))
+
  def testIsBanned_BannedUser(self):
    # A non-empty banned field (the ban reason) marks the user as banned.
    user = user_pb2.User()
    user.banned = 'spammer'
    user_view = framework_views.StuffUserView(None, None, True)
    self.assertTrue(permissions.IsBanned(user, user_view))
+
+ def testIsBanned_BadDomainUser(self):
+ user = user_pb2.User()
+ self.assertFalse(permissions.IsBanned(user, None))
+
+ user_view = framework_views.StuffUserView(None, None, True)
+ user_view.domain = 'spammer.com'
+ self.assertFalse(permissions.IsBanned(user, user_view))
+
+ orig_banned_user_domains = settings.banned_user_domains
+ settings.banned_user_domains = ['spammer.com', 'phisher.com']
+ self.assertTrue(permissions.IsBanned(user, user_view))
+ settings.banned_user_domains = orig_banned_user_domains
+
+ def testIsBanned_PlusAddressUser(self):
+ """We don't allow users who have + in their email address."""
+ user = user_pb2.User(email='user@example.com')
+ self.assertFalse(permissions.IsBanned(user, None))
+
+ user.email = 'user+shadystuff@example.com'
+ self.assertTrue(permissions.IsBanned(user, None))
+
  def testCanExpungeUser_Admin(self):
    # Site admins hold the permission required to expunge users.
    mr = testing_helpers.MakeMonorailRequest()
    mr.perms = permissions.ADMIN_PERMISSIONSET
    self.assertTrue(permissions.CanExpungeUsers(mr))
+
  def testGetCustomPermissions(self):
    # A project with no ExtraPerms yields an empty list.
    project = project_pb2.Project()
    self.assertListEqual([], permissions.GetCustomPermissions(project))

    project.extra_perms.append(project_pb2.Project.ExtraPerms(
        perms=['Core', 'Elite', 'Gold']))
    self.assertListEqual(['Core', 'Elite', 'Gold'],
                         permissions.GetCustomPermissions(project))

    # Perms from all ExtraPerms entries are merged, deduped, and sorted.
    project.extra_perms.append(project_pb2.Project.ExtraPerms(
        perms=['Silver', 'Gold', 'Bronze']))
    self.assertListEqual(['Bronze', 'Core', 'Elite', 'Gold', 'Silver'],
                         permissions.GetCustomPermissions(project))

    # View is not returned because it is a standard permission.
    project.extra_perms.append(project_pb2.Project.ExtraPerms(
        perms=['Bronze', permissions.VIEW]))
    self.assertListEqual(['Bronze', 'Core', 'Elite', 'Gold', 'Silver'],
                         permissions.GetCustomPermissions(project))
+
  def testUserCanViewProject(self):
    # Stub time.time(); UserCanViewProject compares delete_time to "now".
    # Eight calls are recorded because each check below reads the clock.
    self.mox.StubOutWithMock(time, 'time')
    for _ in range(8):
      time.time().AndReturn(self.NOW)
    self.mox.ReplayAll()

    # A live project is visible to members and to anonymous users alike.
    self.assertTrue(permissions.UserCanViewProject(
        self.member, {self.COMMITTER_USER_ID}, self.live_project))
    self.assertTrue(permissions.UserCanViewProject(
        None, None, self.live_project))

    # Archived but not yet past delete_time: hidden from anon, still
    # visible to the project owner and to site admins.
    self.archived_project.delete_time = self.NOW + 1
    self.assertFalse(permissions.UserCanViewProject(
        None, None, self.archived_project))
    self.assertTrue(permissions.UserCanViewProject(
        self.owner, {self.OWNER_USER_ID}, self.archived_project))
    self.assertTrue(permissions.UserCanViewProject(
        self.site_admin, {self.SITE_ADMIN_USER_ID},
        self.archived_project))

    # Past delete_time: even the owner loses access; only site admins see it.
    self.archived_project.delete_time = self.NOW - 1
    self.assertFalse(permissions.UserCanViewProject(
        None, None, self.archived_project))
    self.assertFalse(permissions.UserCanViewProject(
        self.owner, {self.OWNER_USER_ID}, self.archived_project))
    self.assertTrue(permissions.UserCanViewProject(
        self.site_admin, {self.SITE_ADMIN_USER_ID},
        self.archived_project))

    self.mox.VerifyAll()
+
  def CheckExpired(self, state, expected_to_be_reapable):
    """Helper: exercise IsExpired() for a project in the given state.

    Args:
      state: ProjectState value to set on a fresh project.
      expected_to_be_reapable: expected IsExpired() result once the
          project's delete_time is in the past.
    """
    proj = project_pb2.Project()
    proj.state = state
    # Not expired while delete_time is still in the future.
    proj.delete_time = self.NOW + 1
    self.assertFalse(permissions.IsExpired(proj))

    # Once delete_time has passed, expiry depends on the project state.
    proj.delete_time = self.NOW - 1
    self.assertEqual(expected_to_be_reapable, permissions.IsExpired(proj))

    # An expired_before cutoff earlier than delete_time means the project
    # is not yet considered expired relative to that cutoff.
    proj.delete_time = self.NOW - 1
    self.assertFalse(permissions.IsExpired(proj, expired_before=self.NOW - 2))
+
  def testIsExpired_Live(self):
    # A LIVE project is never reapable, even past its delete_time.
    self.CheckExpired(project_pb2.ProjectState.LIVE, False)
+
  def testIsExpired_Archived(self):
    # Stub time.time(); IsExpired() reads the clock when no explicit
    # expired_before cutoff is given (two such calls in CheckExpired).
    self.mox.StubOutWithMock(time, 'time')
    for _ in range(2):
      time.time().AndReturn(self.NOW)
    self.mox.ReplayAll()

    # An ARCHIVED project past its delete_time is reapable.
    self.CheckExpired(project_pb2.ProjectState.ARCHIVED, True)

    self.mox.VerifyAll()
+
+
class PermissionsCheckTest(unittest.TestCase):
  """Tests for PermissionSet.CanUsePerm() and Restrict-X-Y handling."""

  def setUp(self):
    # The user under test holds perms a, b, and c directly.
    self.perms = permissions.PermissionSet(['a', 'b', 'c'])

    # User 111 is a project committer with the extra perm 'd'.
    self.proj = project_pb2.Project()
    self.proj.committer_ids.append(111)
    self.proj.extra_perms.append(project_pb2.Project.ExtraPerms(
        member_id=111, perms=['d']))

    # Note: z is an example of a perm that the user does not have.
    # Note: q is an example of an irrelevant perm that the user does not have.

  def DoCanUsePerm(self, perm, project='default', user_id=None, restrict=''):
    """Wrapper function to call CanUsePerm()."""
    if project == 'default':
      project = self.proj
    # restrict is a space-separated string of Restrict-Action-NeededPerm
    # labels; split it into the list form CanUsePerm() expects.
    return self.perms.CanUsePerm(
        perm, {user_id or 111}, project, restrict.split())

  def testHasPermNoRestrictions(self):
    # Permission names are matched case-insensitively.
    self.assertTrue(self.DoCanUsePerm('a'))
    self.assertTrue(self.DoCanUsePerm('A'))
    self.assertFalse(self.DoCanUsePerm('z'))
    # 'd' comes from ExtraPerms, so it applies only to user 111 in self.proj.
    self.assertTrue(self.DoCanUsePerm('d'))
    self.assertFalse(self.DoCanUsePerm('d', user_id=222))
    self.assertFalse(self.DoCanUsePerm('d', project=project_pb2.Project()))

  def testHasPermOperationRestrictions(self):
    # A restriction is satisfied when the user holds the needed perm,
    # whether it is a direct perm or an ExtraPerms grant.
    self.assertTrue(self.DoCanUsePerm('a', restrict='Restrict-a-b'))
    self.assertTrue(self.DoCanUsePerm('a', restrict='Restrict-b-z'))
    self.assertTrue(self.DoCanUsePerm('a', restrict='Restrict-a-d'))
    self.assertTrue(self.DoCanUsePerm('d', restrict='Restrict-d-a'))
    self.assertTrue(self.DoCanUsePerm(
        'd', restrict='Restrict-q-z Restrict-q-d Restrict-d-a'))

    # Denied when any applicable restriction needs a perm the user lacks.
    self.assertFalse(self.DoCanUsePerm('a', restrict='Restrict-a-z'))
    self.assertFalse(self.DoCanUsePerm('d', restrict='Restrict-d-z'))
    self.assertFalse(self.DoCanUsePerm(
        'd', restrict='Restrict-d-a Restrict-d-z'))

  def testHasPermOutsideProjectScope(self):
    # Without a project, only direct perms count toward restrictions.
    self.assertTrue(self.DoCanUsePerm('a', project=None))
    self.assertTrue(self.DoCanUsePerm(
        'a', project=None, restrict='Restrict-a-c'))
    self.assertTrue(self.DoCanUsePerm(
        'a', project=None, restrict='Restrict-q-z'))

    self.assertFalse(self.DoCanUsePerm('z', project=None))
    # 'd' was an ExtraPerms grant, unavailable outside the project.
    self.assertFalse(self.DoCanUsePerm(
        'a', project=None, restrict='Restrict-a-d'))
+
+
class CanViewProjectContributorListTest(unittest.TestCase):
  """Tests for permissions.CanViewContributorList()."""

  def testCanViewProjectContributorList_NoProject(self):
    # With no project in the request there is no contributor list to view.
    mr = testing_helpers.MakeMonorailRequest(path='/')
    self.assertFalse(permissions.CanViewContributorList(mr, mr.project))

  def testCanViewProjectContributorList_NormalProject(self):
    # By default, any viewer of a project can see its contributor list.
    project = project_pb2.Project()
    mr = testing_helpers.MakeMonorailRequest(
        path='/p/proj/', project=project)
    self.assertTrue(permissions.CanViewContributorList(mr, mr.project))

  def testCanViewProjectContributorList_ProjectWithOptionSet(self):
    # When only_owners_see_contributors is set, read-only users and
    # contributors are denied, while committers and above are allowed.
    project = project_pb2.Project()
    project.only_owners_see_contributors = True

    denied_perm_sets = (
        permissions.READ_ONLY_PERMISSIONSET,
        permissions.CONTRIBUTOR_ACTIVE_PERMISSIONSET,
        permissions.CONTRIBUTOR_INACTIVE_PERMISSIONSET)
    for perm_set in denied_perm_sets:
      mr = testing_helpers.MakeMonorailRequest(
          path='/p/proj/', project=project, perms=perm_set)
      self.assertFalse(permissions.CanViewContributorList(mr, mr.project))

    allowed_perm_sets = (
        permissions.COMMITTER_ACTIVE_PERMISSIONSET,
        permissions.COMMITTER_INACTIVE_PERMISSIONSET,
        permissions.OWNER_ACTIVE_PERMISSIONSET,
        permissions.OWNER_INACTIVE_PERMISSIONSET,
        permissions.ADMIN_PERMISSIONSET)
    for perm_set in allowed_perm_sets:
      mr = testing_helpers.MakeMonorailRequest(
          path='/p/proj/', project=project, perms=perm_set)
      self.assertTrue(permissions.CanViewContributorList(mr, mr.project))
+
+
class ShouldCheckForAbandonmentTest(unittest.TestCase):
  """Tests for permissions.ShouldCheckForAbandonment()."""

  def setUp(self):
    # A minimal request object with a fresh project and empty auth data.
    self.mr = testing_helpers.Blank(
        project=project_pb2.Project(),
        auth=authdata.AuthData())

  def testOwner(self):
    # Project owners are the ones nagged about abandoned projects.
    self.mr.auth.effective_ids = {111}
    self.mr.perms = permissions.OWNER_ACTIVE_PERMISSIONSET
    self.assertTrue(permissions.ShouldCheckForAbandonment(self.mr))

  def testNonOwner(self):
    # Committers, contributors, signed-in users, and anon users are not.
    self.mr.auth.effective_ids = {222}
    self.mr.perms = permissions.COMMITTER_ACTIVE_PERMISSIONSET
    self.assertFalse(permissions.ShouldCheckForAbandonment(self.mr))
    self.mr.perms = permissions.CONTRIBUTOR_ACTIVE_PERMISSIONSET
    self.assertFalse(permissions.ShouldCheckForAbandonment(self.mr))
    self.mr.perms = permissions.USER_PERMISSIONSET
    self.assertFalse(permissions.ShouldCheckForAbandonment(self.mr))
    self.mr.perms = permissions.EMPTY_PERMISSIONSET
    self.assertFalse(permissions.ShouldCheckForAbandonment(self.mr))

  def testSiteAdmin(self):
    # Site admins are exempt even when they could act as owners.
    self.mr.auth.effective_ids = {111}
    self.mr.perms = permissions.ADMIN_PERMISSIONSET
    self.assertFalse(permissions.ShouldCheckForAbandonment(self.mr))
+
+
class RestrictionLabelsTest(unittest.TestCase):
  """Tests for IsRestrictLabel() and GetRestrictions()."""

  ORIG_SUMMARY = 'this is the orginal summary'
  ORIG_LABELS = ['one', 'two']

  def testIsRestrictLabel(self):
    self.assertFalse(permissions.IsRestrictLabel('Usability'))
    self.assertTrue(permissions.IsRestrictLabel('Restrict-View-CoreTeam'))
    # Doing it again will test the cached results.
    self.assertFalse(permissions.IsRestrictLabel('Usability'))
    self.assertTrue(permissions.IsRestrictLabel('Restrict-View-CoreTeam'))

    # With perm= given, only Restrict-<perm>-* labels match.
    self.assertFalse(permissions.IsRestrictLabel('Usability', perm='View'))
    self.assertTrue(permissions.IsRestrictLabel(
        'Restrict-View-CoreTeam', perm='View'))

    # This one is a restriction label, but not the kind that we want.
    self.assertFalse(permissions.IsRestrictLabel(
        'Restrict-View-CoreTeam', perm='Delete'))

  def testGetRestrictions_NoIssue(self):
    # No issue means no restriction labels.
    self.assertEqual([], permissions.GetRestrictions(None))

  def testGetRestrictions_PermSpecified(self):
    """We can return restriction labels related to the given perm."""
    art = fake.MakeTestIssue(
        789, 1, self.ORIG_SUMMARY, 'New', 0, labels=self.ORIG_LABELS)
    self.assertEqual([], permissions.GetRestrictions(art, perm='view'))

    # Only labels restricting the requested perm are returned, lowercased;
    # the perm name is matched case-insensitively.
    art = fake.MakeTestIssue(
        789, 1, self.ORIG_SUMMARY, 'New', 0,
        labels=['Restrict-View-Core', 'Hot',
                'Restrict-EditIssue-Commit', 'Restrict-EditIssue-Core'])
    self.assertEqual(
        ['restrict-view-core'],
        permissions.GetRestrictions(art, perm='view'))
    self.assertEqual(
        ['restrict-view-core'],
        permissions.GetRestrictions(art, perm='View'))
    self.assertEqual(
        ['restrict-editissue-commit', 'restrict-editissue-core'],
        permissions.GetRestrictions(art, perm='EditIssue'))

  def testGetRestrictions_NoPerm(self):
    art = fake.MakeTestIssue(
        789, 1, self.ORIG_SUMMARY, 'New', 0, labels=self.ORIG_LABELS)
    self.assertEqual([], permissions.GetRestrictions(art))

    # A Restrict- label without all three parts is not a restriction.
    art = fake.MakeTestIssue(
        789, 1, self.ORIG_SUMMARY, 'New', 0,
        labels=['Restrict-MissingThirdPart', 'Hot'])
    self.assertEqual([], permissions.GetRestrictions(art))

    art = fake.MakeTestIssue(
        789, 1, self.ORIG_SUMMARY, 'New', 0,
        labels=['Restrict-View-Core', 'Hot'])
    self.assertEqual(['restrict-view-core'], permissions.GetRestrictions(art))

    # Derived labels also contribute restrictions.
    art = fake.MakeTestIssue(
        789, 1, self.ORIG_SUMMARY, 'New', 0,
        labels=['Restrict-View-Core', 'Hot'],
        derived_labels=['Color-Red', 'Restrict-EditIssue-GoldMembers'])
    self.assertEqual(
        ['restrict-view-core', 'restrict-editissue-goldmembers'],
        permissions.GetRestrictions(art))

    # Label matching is case-insensitive; results are lowercased.
    art = fake.MakeTestIssue(
        789, 1, self.ORIG_SUMMARY, 'New', 0,
        labels=['restrict-view-core', 'hot'],
        derived_labels=['Color-Red', 'RESTRICT-EDITISSUE-GOLDMEMBERS'])
    self.assertEqual(
        ['restrict-view-core', 'restrict-editissue-goldmembers'],
        permissions.GetRestrictions(art))
+
+
# User IDs used as issue-role fixtures by IssuePermissionsTest below.
REPORTER_ID = 111
OWNER_ID = 222
CC_ID = 333
OTHER_ID = 444
APPROVER_ID = 555
+
+
class IssuePermissionsTest(unittest.TestCase):
  """Tests for per-issue permission checks and restriction labels."""

  # An ordinary issue with only a reporter set.
  REGULAR_ISSUE = tracker_pb2.Issue()
  REGULAR_ISSUE.reporter_id = REPORTER_ID

  # A deleted issue: most permissions are dropped on deleted issues.
  DELETED_ISSUE = tracker_pb2.Issue()
  DELETED_ISSUE.deleted = True
  DELETED_ISSUE.reporter_id = REPORTER_ID

  # Viewing restricted to users with the Commit perm; has every role filled
  # (reporter, owner, CC, approver) so role-based bypasses can be tested.
  RESTRICTED_ISSUE = tracker_pb2.Issue()
  RESTRICTED_ISSUE.reporter_id = REPORTER_ID
  RESTRICTED_ISSUE.owner_id = OWNER_ID
  RESTRICTED_ISSUE.cc_ids.append(CC_ID)
  RESTRICTED_ISSUE.approval_values.append(
      tracker_pb2.ApprovalValue(approver_ids=[APPROVER_ID])
  )
  RESTRICTED_ISSUE.labels.append('Restrict-View-Commit')

  RESTRICTED_ISSUE2 = tracker_pb2.Issue()
  RESTRICTED_ISSUE2.reporter_id = REPORTER_ID
  # RESTRICTED_ISSUE2 has no owner
  RESTRICTED_ISSUE2.cc_ids.append(CC_ID)
  RESTRICTED_ISSUE2.labels.append('Restrict-View-Commit')

  RESTRICTED_ISSUE3 = tracker_pb2.Issue()
  RESTRICTED_ISSUE3.reporter_id = REPORTER_ID
  RESTRICTED_ISSUE3.owner_id = OWNER_ID
  # Restrict to a permission that no one has.
  RESTRICTED_ISSUE3.labels.append('Restrict-EditIssue-Foo')

  PROJECT = project_pb2.Project()

  ADMIN_PERMS = permissions.ADMIN_PERMISSIONSET
  PERMS = permissions.EMPTY_PERMISSIONSET
+
  def testUpdateIssuePermissions_Normal(self):
    # An unrestricted issue leaves the committer perm set essentially intact.
    perms = permissions.UpdateIssuePermissions(
        permissions.COMMITTER_ACTIVE_PERMISSIONSET, self.PROJECT,
        self.REGULAR_ISSUE, {})

    self.assertEqual(
        ['addissuecomment',
         'commit',
         'createissue',
         'deleteown',
         'editissue',
         'flagspam',
         'setstar',
         'verdictspam',
         'view',
         'viewcontributorlist',
         'viewinboundmessages',
         'viewquota'],
        sorted(perms.perm_names))
+
  def testUpdateIssuePermissions_FromConfig(self):
    # A field def with grants_perm grants that perm to users named in the
    # issue's matching field value.
    config = tracker_pb2.ProjectIssueConfig(
        field_defs=[tracker_pb2.FieldDef(field_id=123, grants_perm='Granted')])
    issue = tracker_pb2.Issue(
        field_values=[tracker_pb2.FieldValue(field_id=123, user_id=111)])
    perms = permissions.UpdateIssuePermissions(
        permissions.USER_PERMISSIONSET, self.PROJECT, issue, {111},
        config=config)
    self.assertIn('granted', perms.perm_names)
+
  def testUpdateIssuePermissions_ExtraPerms(self):
    # A project member's ExtraPerms are folded into the issue perms.
    project = project_pb2.Project()
    project.committer_ids.append(999)
    project.extra_perms.append(
        project_pb2.Project.ExtraPerms(member_id=999, perms=['EditIssue']))
    perms = permissions.UpdateIssuePermissions(
        permissions.USER_PERMISSIONSET, project,
        self.REGULAR_ISSUE, {999})
    self.assertIn('editissue', perms.perm_names)
+
  def testUpdateIssuePermissions_ExtraPermsAreSubjectToRestrictions(self):
    # Unlike granted_perms, ExtraPerms do not bypass Restrict- labels.
    project = project_pb2.Project()
    project.committer_ids.append(999)
    project.extra_perms.append(
        project_pb2.Project.ExtraPerms(member_id=999, perms=['EditIssue']))
    perms = permissions.UpdateIssuePermissions(
        permissions.USER_PERMISSIONSET, project,
        self.RESTRICTED_ISSUE3, {999})
    self.assertNotIn('editissue', perms.perm_names)
+
  def testUpdateIssuePermissions_GrantedPermsAreNotSubjectToRestrictions(self):
    # Perms granted directly on the issue bypass Restrict- labels.
    perms = permissions.UpdateIssuePermissions(
        permissions.USER_PERMISSIONSET, self.PROJECT, self.RESTRICTED_ISSUE3,
        {}, granted_perms=['EditIssue'])
    self.assertIn('editissue', perms.perm_names)
+
  def testUpdateIssuePermissions_RespectConsiderRestrictions(self):
    # The admin perm set is configured to ignore restriction labels,
    # so EditIssue survives even on a restricted issue.
    perms = permissions.UpdateIssuePermissions(
        permissions.ADMIN_PERMISSIONSET, self.PROJECT, self.RESTRICTED_ISSUE3,
        {})
    self.assertIn('editissue', perms.perm_names)
+
  def testUpdateIssuePermissions_RestrictionsAreConsideredIndividually(self):
    # Each Restrict-X-Y label is evaluated on its own: Perm1 survives
    # (user has Perm2), but Perm2 is dropped (user lacks Perm3).
    issue = tracker_pb2.Issue(
        labels=[
            'Restrict-Perm1-Perm2',
            'Restrict-Perm2-Perm3'])
    perms = permissions.UpdateIssuePermissions(
        permissions.PermissionSet(['Perm1', 'Perm2', 'View']),
        self.PROJECT, issue, {})
    self.assertIn('perm1', perms.perm_names)
    self.assertNotIn('perm2', perms.perm_names)
+
  def testUpdateIssuePermissions_DeletedNoPermissions(self):
    # A deleted, view-restricted issue leaves a committer with nothing.
    issue = tracker_pb2.Issue(
        labels=['Restrict-View-Foo'],
        deleted=True)
    perms = permissions.UpdateIssuePermissions(
        permissions.COMMITTER_ACTIVE_PERMISSIONSET, self.PROJECT, issue, {})
    self.assertEqual([], sorted(perms.perm_names))
+
  def testUpdateIssuePermissions_ViewDeleted(self):
    # Committers retain only View on a deleted (unrestricted) issue.
    perms = permissions.UpdateIssuePermissions(
        permissions.COMMITTER_ACTIVE_PERMISSIONSET, self.PROJECT,
        self.DELETED_ISSUE, {})
    self.assertEqual(['view'], sorted(perms.perm_names))
+
  def testUpdateIssuePermissions_ViewAndDeleteDeleted(self):
    # Project owners additionally keep DeleteIssue (to undelete/purge).
    perms = permissions.UpdateIssuePermissions(
        permissions.OWNER_ACTIVE_PERMISSIONSET, self.PROJECT,
        self.DELETED_ISSUE, {})
    self.assertEqual(['deleteissue', 'view'], sorted(perms.perm_names))
+
  def testUpdateIssuePermissions_ViewRestrictions(self):
    # A non-member loses View on a Restrict-View-Commit issue.
    perms = permissions.UpdateIssuePermissions(
        permissions.USER_PERMISSIONSET, self.PROJECT, self.RESTRICTED_ISSUE, {})
    self.assertNotIn('view', perms.perm_names)
+
  def testUpdateIssuePermissions_RolesBypassViewRestrictions(self):
    # Owner, reporter, CC'd user, and approver can always view.
    for role in {OWNER_ID, REPORTER_ID, CC_ID, APPROVER_ID}:
      perms = permissions.UpdateIssuePermissions(
          permissions.USER_PERMISSIONSET, self.PROJECT, self.RESTRICTED_ISSUE,
          {role})
      self.assertIn('view', perms.perm_names)
+
  def testUpdateIssuePermissions_RolesAllowViewingDeleted(self):
    # Issue roles keep View even when the issue is deleted and restricted.
    issue = tracker_pb2.Issue(
        reporter_id=REPORTER_ID,
        owner_id=OWNER_ID,
        cc_ids=[CC_ID],
        approval_values=[tracker_pb2.ApprovalValue(approver_ids=[APPROVER_ID])],
        labels=['Restrict-View-Foo'],
        deleted=True)
    for role in {OWNER_ID, REPORTER_ID, CC_ID, APPROVER_ID}:
      perms = permissions.UpdateIssuePermissions(
          permissions.USER_PERMISSIONSET, self.PROJECT, issue, {role})
      self.assertIn('view', perms.perm_names)
+
  def testUpdateIssuePermissions_GrantedViewPermission(self):
    # Being granted 'commit' on the issue satisfies Restrict-View-Commit.
    perms = permissions.UpdateIssuePermissions(
        permissions.USER_PERMISSIONSET, self.PROJECT, self.RESTRICTED_ISSUE,
        {}, ['commit'])
    self.assertIn('view', perms.perm_names)
+
  def testUpdateIssuePermissions_EditRestrictions(self):
    # Reporter/CC/approver roles do not bypass an EditIssue restriction.
    perms = permissions.UpdateIssuePermissions(
        permissions.COMMITTER_ACTIVE_PERMISSIONSET, self.PROJECT,
        self.RESTRICTED_ISSUE3, {REPORTER_ID, CC_ID, APPROVER_ID})
    self.assertNotIn('editissue', perms.perm_names)
+
  def testUpdateIssuePermissions_OwnerBypassEditRestrictions(self):
    # The issue owner, however, does bypass the EditIssue restriction.
    perms = permissions.UpdateIssuePermissions(
        permissions.COMMITTER_ACTIVE_PERMISSIONSET, self.PROJECT,
        self.RESTRICTED_ISSUE3, {OWNER_ID})
    self.assertIn('editissue', perms.perm_names)
+
  def testUpdateIssuePermissions_CustomPermissionGrantsEditPermission(self):
    # Holding the custom perm 'Foo' satisfies Restrict-EditIssue-Foo.
    project = project_pb2.Project()
    project.committer_ids.append(999)
    project.extra_perms.append(
        project_pb2.Project.ExtraPerms(member_id=999, perms=['Foo']))
    perms = permissions.UpdateIssuePermissions(
        permissions.COMMITTER_ACTIVE_PERMISSIONSET, project,
        self.RESTRICTED_ISSUE3, {999})
    self.assertIn('editissue', perms.perm_names)
+
  def testCanViewIssue_Deleted(self):
    # Deleted issues are hidden unless allow_viewing_deleted is passed.
    self.assertFalse(permissions.CanViewIssue(
        {REPORTER_ID}, permissions.OWNER_ACTIVE_PERMISSIONSET,
        self.PROJECT, self.DELETED_ISSUE))
    self.assertTrue(permissions.CanViewIssue(
        {REPORTER_ID}, permissions.OWNER_ACTIVE_PERMISSIONSET,
        self.PROJECT, self.DELETED_ISSUE, allow_viewing_deleted=True))
    self.assertTrue(permissions.CanViewIssue(
        {REPORTER_ID}, permissions.OWNER_ACTIVE_PERMISSIONSET,
        self.PROJECT, self.REGULAR_ISSUE))
+
  def testCanViewIssue_Regular(self):
    # Every role, down to anonymous read-only, can view a regular issue.
    self.assertTrue(permissions.CanViewIssue(
        {REPORTER_ID}, permissions.OWNER_ACTIVE_PERMISSIONSET,
        self.PROJECT, self.REGULAR_ISSUE))
    self.assertTrue(permissions.CanViewIssue(
        {REPORTER_ID}, permissions.COMMITTER_ACTIVE_PERMISSIONSET,
        self.PROJECT, self.REGULAR_ISSUE))
    self.assertTrue(permissions.CanViewIssue(
        {REPORTER_ID},
        permissions.CONTRIBUTOR_ACTIVE_PERMISSIONSET,
        self.PROJECT, self.REGULAR_ISSUE))
    self.assertTrue(permissions.CanViewIssue(
        {REPORTER_ID}, permissions.USER_PERMISSIONSET,
        self.PROJECT, self.REGULAR_ISSUE))
    self.assertTrue(permissions.CanViewIssue(
        {REPORTER_ID}, permissions.READ_ONLY_PERMISSIONSET,
        self.PROJECT, self.REGULAR_ISSUE))
    self.assertTrue(permissions.CanViewIssue(
        set(), permissions.READ_ONLY_PERMISSIONSET,
        self.PROJECT, self.REGULAR_ISSUE))
+
  def testCanViewIssue_Restricted(self):
    # Project owner can always view issue.
    self.assertTrue(permissions.CanViewIssue(
        {OTHER_ID}, permissions.OWNER_ACTIVE_PERMISSIONSET,
        self.PROJECT, self.RESTRICTED_ISSUE))
    # Member can view because they have Commit perm.
    self.assertTrue(permissions.CanViewIssue(
        {OTHER_ID}, permissions.COMMITTER_ACTIVE_PERMISSIONSET,
        self.PROJECT, self.RESTRICTED_ISSUE))
    # Contributors normally do not have Commit perm.
    self.assertFalse(permissions.CanViewIssue(
        {OTHER_ID}, permissions.CONTRIBUTOR_ACTIVE_PERMISSIONSET,
        self.PROJECT, self.RESTRICTED_ISSUE))
    # Non-members do not have Commit perm.
    self.assertFalse(permissions.CanViewIssue(
        {OTHER_ID}, permissions.USER_PERMISSIONSET,
        self.PROJECT, self.RESTRICTED_ISSUE))
    # Anon users do not have Commit perm.
    self.assertFalse(permissions.CanViewIssue(
        set(), permissions.READ_ONLY_PERMISSIONSET,
        self.PROJECT, self.RESTRICTED_ISSUE))
+
  def testCanViewIssue_RestrictedParticipants(self):
    # Reporter can always view issue
    self.assertTrue(permissions.CanViewIssue(
        {REPORTER_ID}, permissions.READ_ONLY_PERMISSIONSET,
        self.PROJECT, self.RESTRICTED_ISSUE))
    # Issue owner can always view issue
    self.assertTrue(permissions.CanViewIssue(
        {OWNER_ID}, permissions.READ_ONLY_PERMISSIONSET,
        self.PROJECT, self.RESTRICTED_ISSUE))
    # CC'd user can always view issue
    self.assertTrue(permissions.CanViewIssue(
        {CC_ID}, permissions.READ_ONLY_PERMISSIONSET,
        self.PROJECT, self.RESTRICTED_ISSUE))
    # Non-participants cannot view issue if they don't have the needed perm.
    self.assertFalse(permissions.CanViewIssue(
        {OTHER_ID}, permissions.READ_ONLY_PERMISSIONSET,
        self.PROJECT, self.RESTRICTED_ISSUE))
    # Anon users do not have Commit perm.
    self.assertFalse(permissions.CanViewIssue(
        set(), permissions.READ_ONLY_PERMISSIONSET,
        self.PROJECT, self.RESTRICTED_ISSUE))
    # Anon users cannot match owner 0 (RESTRICTED_ISSUE2 has no owner).
    self.assertFalse(permissions.CanViewIssue(
        set(), permissions.READ_ONLY_PERMISSIONSET,
        self.PROJECT, self.RESTRICTED_ISSUE2))
    # Approvers can always view issue
    self.assertTrue(permissions.CanViewIssue(
        {APPROVER_ID}, permissions.READ_ONLY_PERMISSIONSET,
        self.PROJECT, self.RESTRICTED_ISSUE))
+
+ def testCannotViewIssueIfCannotViewProject(self):
+ """Cross-project search should not be a backdoor to viewing issues."""
+ # Reporter cannot view issue if they not long have access to the project.
+ self.assertFalse(permissions.CanViewIssue(
+ {REPORTER_ID}, permissions.EMPTY_PERMISSIONSET,
+ self.PROJECT, self.REGULAR_ISSUE))
+ # Issue owner cannot always view issue
+ self.assertFalse(permissions.CanViewIssue(
+ {OWNER_ID}, permissions.EMPTY_PERMISSIONSET,
+ self.PROJECT, self.REGULAR_ISSUE))
+ # CC'd user cannot always view issue
+ self.assertFalse(permissions.CanViewIssue(
+ {CC_ID}, permissions.EMPTY_PERMISSIONSET,
+ self.PROJECT, self.REGULAR_ISSUE))
+ # Non-participants cannot view issue if they don't have the needed perm.
+ self.assertFalse(permissions.CanViewIssue(
+ {OTHER_ID}, permissions.EMPTY_PERMISSIONSET,
+ self.PROJECT, self.REGULAR_ISSUE))
+ # Anon user's do not have Commit perm.
+ self.assertFalse(permissions.CanViewIssue(
+ set(), permissions.EMPTY_PERMISSIONSET, self.PROJECT,
+ self.REGULAR_ISSUE))
+ # Anon user's cannot match owner 0.
+ self.assertFalse(permissions.CanViewIssue(
+ set(), permissions.EMPTY_PERMISSIONSET, self.PROJECT,
+ self.REGULAR_ISSUE))
+
  def testCanEditIssue(self):
    # Anon users cannot edit issues.
    self.assertFalse(permissions.CanEditIssue(
        {}, permissions.READ_ONLY_PERMISSIONSET,
        self.PROJECT, self.REGULAR_ISSUE))

    # Non-members and contributors cannot edit issues,
    # even if they reported them.
    self.assertFalse(permissions.CanEditIssue(
        {REPORTER_ID}, permissions.READ_ONLY_PERMISSIONSET,
        self.PROJECT, self.REGULAR_ISSUE))
    self.assertFalse(permissions.CanEditIssue(
        {REPORTER_ID}, permissions.CONTRIBUTOR_ACTIVE_PERMISSIONSET,
        self.PROJECT, self.REGULAR_ISSUE))

    # Project committers and project owners can edit issues, regardless
    # of their role in the issue.
    self.assertTrue(permissions.CanEditIssue(
        {REPORTER_ID}, permissions.COMMITTER_ACTIVE_PERMISSIONSET,
        self.PROJECT, self.REGULAR_ISSUE))
    self.assertTrue(permissions.CanEditIssue(
        {REPORTER_ID}, permissions.OWNER_ACTIVE_PERMISSIONSET,
        self.PROJECT, self.REGULAR_ISSUE))
    self.assertTrue(permissions.CanEditIssue(
        {OWNER_ID}, permissions.COMMITTER_ACTIVE_PERMISSIONSET,
        self.PROJECT, self.REGULAR_ISSUE))
    self.assertTrue(permissions.CanEditIssue(
        {OWNER_ID}, permissions.OWNER_ACTIVE_PERMISSIONSET,
        self.PROJECT, self.REGULAR_ISSUE))
    self.assertTrue(permissions.CanEditIssue(
        {OTHER_ID}, permissions.COMMITTER_ACTIVE_PERMISSIONSET,
        self.PROJECT, self.REGULAR_ISSUE))
    self.assertTrue(permissions.CanEditIssue(
        {OTHER_ID}, permissions.OWNER_ACTIVE_PERMISSIONSET,
        self.PROJECT, self.REGULAR_ISSUE))
+
  def testCanEditIssue_Restricted(self):
    # Anon users cannot edit restricted issues.
    self.assertFalse(permissions.CanEditIssue(
        {}, permissions.COMMITTER_ACTIVE_PERMISSIONSET,
        self.PROJECT, self.RESTRICTED_ISSUE3))

    # Project committers cannot edit issues with a restriction to a custom
    # permission that they don't have.
    self.assertFalse(permissions.CanEditIssue(
        {OTHER_ID}, permissions.COMMITTER_ACTIVE_PERMISSIONSET,
        self.PROJECT, self.RESTRICTED_ISSUE3))

    # *Issue* owners can always edit the issues that they own, even if
    # those issues are restricted to perms that they don't have.
    self.assertTrue(permissions.CanEditIssue(
        {OWNER_ID}, permissions.CONTRIBUTOR_ACTIVE_PERMISSIONSET,
        self.PROJECT, self.RESTRICTED_ISSUE3))

    # Project owners can always edit, they cannot lock themselves out.
    self.assertTrue(permissions.CanEditIssue(
        {OTHER_ID}, permissions.OWNER_ACTIVE_PERMISSIONSET,
        self.PROJECT, self.RESTRICTED_ISSUE3))

    # A committer with edit permission but not view permission
    # should not be able to edit the issue.
    self.assertFalse(permissions.CanEditIssue(
        {OTHER_ID}, permissions.CONTRIBUTOR_ACTIVE_PERMISSIONSET,
        self.PROJECT, self.RESTRICTED_ISSUE2))
+
  def testCanCommentIssue_HasPerm(self):
    # AddIssueComment in the perm set is sufficient; an empty set is not.
    self.assertTrue(permissions.CanCommentIssue(
        {111}, permissions.PermissionSet([permissions.ADD_ISSUE_COMMENT]),
        None, None))
    self.assertFalse(permissions.CanCommentIssue(
        {111}, permissions.PermissionSet([]),
        None, None))
+
  def testCanCommentIssue_HasExtraPerm(self):
    # An ExtraPerms grant of AddIssueComment applies only to that member.
    project = project_pb2.Project()
    project.committer_ids.append(111)
    extra_perm = project_pb2.Project.ExtraPerms(
        member_id=111, perms=[permissions.ADD_ISSUE_COMMENT])
    project.extra_perms.append(extra_perm)
    self.assertTrue(permissions.CanCommentIssue(
        {111}, permissions.PermissionSet([]),
        project, None))
    self.assertFalse(permissions.CanCommentIssue(
        {222}, permissions.PermissionSet([]),
        project, None))
+
  def testCanCommentIssue_Restricted(self):
    issue = tracker_pb2.Issue(labels=['Restrict-AddIssueComment-CoreTeam'])
    # User is granted exactly the perm they need specifically in this issue.
    self.assertTrue(permissions.CanCommentIssue(
        {111}, permissions.PermissionSet([]),
        None, issue, granted_perms=['addissuecomment']))
    # User is granted CoreTeam, which satisfies the restriction, and allows
    # them to use the AddIssueComment permission that they have and would
    # normally be able to use in an unrestricted issue.
    self.assertTrue(permissions.CanCommentIssue(
        {111}, permissions.PermissionSet([permissions.ADD_ISSUE_COMMENT]),
        None, issue, granted_perms=['coreteam']))
    # User was granted CoreTeam, but never had AddIssueComment.
    self.assertFalse(permissions.CanCommentIssue(
        {111}, permissions.PermissionSet([]),
        None, issue, granted_perms=['coreteam']))
    # User has AddIssueComment, but cannot satisfy restriction.
    self.assertFalse(permissions.CanCommentIssue(
        {111}, permissions.PermissionSet([permissions.ADD_ISSUE_COMMENT]),
        None, issue))
+
  def testCanCommentIssue_Granted(self):
    # A per-issue granted perm works even with an empty base perm set.
    self.assertTrue(permissions.CanCommentIssue(
        {111}, permissions.PermissionSet([]),
        None, None, granted_perms=['addissuecomment']))
    self.assertFalse(permissions.CanCommentIssue(
        {111}, permissions.PermissionSet([]),
        None, None))
+
  def testCanUpdateApprovalStatus_Approver(self):
    # An approver (222 in both sets) may set any status value.
    # restricted status
    self.assertTrue(permissions.CanUpdateApprovalStatus(
        {111, 222}, permissions.PermissionSet([]), self.PROJECT,
        [222], tracker_pb2.ApprovalStatus.APPROVED))

    # non-restricted status
    self.assertTrue(permissions.CanUpdateApprovalStatus(
        {111, 222}, permissions.PermissionSet([]), self.PROJECT,
        [222], tracker_pb2.ApprovalStatus.NEEDS_REVIEW))
+
  def testCanUpdateApprovalStatus_SiteAdmin(self):
    # EDIT_ISSUE_APPROVAL lets a non-approver set any status value.
    # restricted status
    self.assertTrue(permissions.CanUpdateApprovalStatus(
        {444}, permissions.PermissionSet([permissions.EDIT_ISSUE_APPROVAL]),
        self.PROJECT, [222], tracker_pb2.ApprovalStatus.NOT_APPROVED))

    # non-restricted status
    self.assertTrue(permissions.CanUpdateApprovalStatus(
        {444}, permissions.PermissionSet([permissions.EDIT_ISSUE_APPROVAL]),
        self.PROJECT, [222], tracker_pb2.ApprovalStatus.NEEDS_REVIEW))
+
  def testCanUpdateApprovalStatus_NonApprover(self):
    # Non-approvers may set non-restricted statuses only.
    # non-restricted status
    self.assertTrue(permissions.CanUpdateApprovalStatus(
        {111, 222}, permissions.PermissionSet([]), self.PROJECT,
        [333], tracker_pb2.ApprovalStatus.NEED_INFO))

    # restricted status
    self.assertFalse(permissions.CanUpdateApprovalStatus(
        {111, 222}, permissions.PermissionSet([]), self.PROJECT,
        [333], tracker_pb2.ApprovalStatus.NA))
+
  def testCanUpdateApprovers_Approver(self):
    # A current approver may change the approver list.
    self.assertTrue(permissions.CanUpdateApprovers(
        {111, 222}, permissions.PermissionSet([]), self.PROJECT,
        [222]))
+
  def testCanUpdateApprovers_SiteAdmins(self):
    # EDIT_ISSUE_APPROVAL lets a non-approver change the approver list.
    self.assertTrue(permissions.CanUpdateApprovers(
        {444}, permissions.PermissionSet([permissions.EDIT_ISSUE_APPROVAL]),
        self.PROJECT, [222]))
+
  def testCanUpdateApprovers_NonApprover(self):
    # Without approver role or the perm, the approver list is locked.
    self.assertFalse(permissions.CanUpdateApprovers(
        {111, 222}, permissions.PermissionSet([]), self.PROJECT,
        [333]))
+
+ def testCanViewComponentDef_ComponentAdmin(self):
+ cd = tracker_pb2.ComponentDef(admin_ids=[111])
+ perms = permissions.PermissionSet([])
+ self.assertTrue(permissions.CanViewComponentDef(
+ {111}, perms, None, cd))
+ self.assertFalse(permissions.CanViewComponentDef(
+ {999}, perms, None, cd))
+
  def testCanViewComponentDef_NormalUser(self):
    # Non-admins need the project View permission.
    cd = tracker_pb2.ComponentDef()
    self.assertTrue(permissions.CanViewComponentDef(
        {111}, permissions.PermissionSet([permissions.VIEW]),
        None, cd))
    self.assertFalse(permissions.CanViewComponentDef(
        {111}, permissions.PermissionSet([]),
        None, cd))
+
  def testCanEditComponentDef_ComponentAdmin(self):
    # Admins of a component, or of any of its ancestors, may edit it.
    cd = tracker_pb2.ComponentDef(admin_ids=[111], path='Whole')
    sub_cd = tracker_pb2.ComponentDef(admin_ids=[222], path='Whole>Part')
    config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
    config.component_defs.append(cd)
    config.component_defs.append(sub_cd)
    perms = permissions.PermissionSet([])
    # Top-level component: only its own admin (111) may edit.
    self.assertTrue(permissions.CanEditComponentDef(
        {111}, perms, None, cd, config))
    self.assertFalse(permissions.CanEditComponentDef(
        {222}, perms, None, cd, config))
    self.assertFalse(permissions.CanEditComponentDef(
        {999}, perms, None, cd, config))
    # Subcomponent: its own admin (222) and the parent's admin (111) may edit.
    self.assertTrue(permissions.CanEditComponentDef(
        {111}, perms, None, sub_cd, config))
    self.assertTrue(permissions.CanEditComponentDef(
        {222}, perms, None, sub_cd, config))
    self.assertFalse(permissions.CanEditComponentDef(
        {999}, perms, None, sub_cd, config))
+
+ def testCanEditComponentDef_ProjectOwners(self):
+ cd = tracker_pb2.ComponentDef(path='Whole')
+ config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+ config.component_defs.append(cd)
+ self.assertTrue(permissions.CanEditComponentDef(
+ {111}, permissions.PermissionSet([permissions.EDIT_PROJECT]),
+ None, cd, config))
+ self.assertFalse(permissions.CanEditComponentDef(
+ {111}, permissions.PermissionSet([]),
+ None, cd, config))
+
+ def testCanViewFieldDef_FieldAdmin(self):
+ fd = tracker_pb2.FieldDef(admin_ids=[111])
+ perms = permissions.PermissionSet([])
+ self.assertTrue(permissions.CanViewFieldDef(
+ {111}, perms, None, fd))
+ self.assertFalse(permissions.CanViewFieldDef(
+ {999}, perms, None, fd))
+
+ def testCanViewFieldDef_NormalUser(self):
+ fd = tracker_pb2.FieldDef()
+ self.assertTrue(permissions.CanViewFieldDef(
+ {111}, permissions.PermissionSet([permissions.VIEW]),
+ None, fd))
+ self.assertFalse(permissions.CanViewFieldDef(
+ {111}, permissions.PermissionSet([]),
+ None, fd))
+
+ def testCanEditFieldDef_FieldAdmin(self):
+ fd = tracker_pb2.FieldDef(admin_ids=[111])
+ perms = permissions.PermissionSet([])
+ self.assertTrue(permissions.CanEditFieldDef(
+ {111}, perms, None, fd))
+ self.assertFalse(permissions.CanEditFieldDef(
+ {999}, perms, None, fd))
+
+ def testCanEditFieldDef_ProjectOwners(self):
+ fd = tracker_pb2.FieldDef()
+ self.assertTrue(permissions.CanEditFieldDef(
+ {111}, permissions.PermissionSet([permissions.EDIT_PROJECT]),
+ None, fd))
+ self.assertFalse(permissions.CanEditFieldDef(
+ {111}, permissions.PermissionSet([]),
+ None, fd))
+
+ def testCanEditValueForFieldDef_NotRestrictedField(self):
+ fd = tracker_pb2.FieldDef()
+ perms = permissions.PermissionSet([])
+ self.assertTrue(permissions.CanEditValueForFieldDef({111}, perms, None, fd))
+
+ def testCanEditValueForFieldDef_RestrictedFieldEditor(self):
+ fd = tracker_pb2.FieldDef(is_restricted_field=True, editor_ids=[111])
+ perms = permissions.PermissionSet([])
+ self.assertTrue(permissions.CanEditValueForFieldDef({111}, perms, None, fd))
+ self.assertFalse(
+ permissions.CanEditValueForFieldDef({999}, perms, None, fd))
+
+ def testCanEditValueForFieldDef_RestrictedFieldAdmin(self):
+ fd = tracker_pb2.FieldDef(is_restricted_field=True, admin_ids=[111])
+ perms = permissions.PermissionSet([])
+ self.assertTrue(permissions.CanEditValueForFieldDef({111}, perms, None, fd))
+ self.assertFalse(
+ permissions.CanEditValueForFieldDef({999}, perms, None, fd))
+
+ def testCanEditValueForFieldDef_ProjectOwners(self):
+ fd = tracker_pb2.FieldDef(is_restricted_field=True)
+ self.assertTrue(
+ permissions.CanEditValueForFieldDef(
+ {111}, permissions.PermissionSet([permissions.EDIT_PROJECT]), None,
+ fd))
+ self.assertFalse(
+ permissions.CanEditValueForFieldDef(
+ {111}, permissions.PermissionSet([]), None, fd))
+
+ def testCanViewTemplate_TemplateAdmin(self):
+ td = tracker_pb2.TemplateDef(admin_ids=[111])
+ perms = permissions.PermissionSet([])
+ self.assertTrue(permissions.CanViewTemplate(
+ {111}, perms, None, td))
+ self.assertFalse(permissions.CanViewTemplate(
+ {999}, perms, None, td))
+
+ def testCanViewTemplate_MembersOnly(self):
+ td = tracker_pb2.TemplateDef(members_only=True)
+ project = project_pb2.Project(committer_ids=[111])
+ self.assertTrue(permissions.CanViewTemplate(
+ {111}, permissions.PermissionSet([]),
+ project, td))
+ self.assertFalse(permissions.CanViewTemplate(
+ {999}, permissions.PermissionSet([]),
+ project, td))
+
+ def testCanViewTemplate_AnyoneWhoCanViewProject(self):
+ td = tracker_pb2.TemplateDef()
+ self.assertTrue(permissions.CanViewTemplate(
+ {111}, permissions.PermissionSet([permissions.VIEW]),
+ None, td))
+ self.assertFalse(permissions.CanViewTemplate(
+ {111}, permissions.PermissionSet([]),
+ None, td))
+
+ def testCanEditTemplate_TemplateAdmin(self):
+ td = tracker_pb2.TemplateDef(admin_ids=[111])
+ perms = permissions.PermissionSet([])
+ self.assertTrue(permissions.CanEditTemplate(
+ {111}, perms, None, td))
+ self.assertFalse(permissions.CanEditTemplate(
+ {999}, perms, None, td))
+
+ def testCanEditTemplate_ProjectOwners(self):
+ td = tracker_pb2.TemplateDef()
+ self.assertTrue(permissions.CanEditTemplate(
+ {111}, permissions.PermissionSet([permissions.EDIT_PROJECT]),
+ None, td))
+ self.assertFalse(permissions.CanEditTemplate(
+ {111}, permissions.PermissionSet([]),
+ None, td))
+
+ def testCanViewHotlist_Private(self):
+ hotlist = features_pb2.Hotlist()
+ hotlist.is_private = True
+ hotlist.owner_ids.append(111)
+ hotlist.editor_ids.append(222)
+
+ self.assertTrue(permissions.CanViewHotlist({222}, self.PERMS, hotlist))
+ self.assertTrue(permissions.CanViewHotlist({111, 333}, self.PERMS, hotlist))
+ self.assertTrue(
+ permissions.CanViewHotlist({111, 333}, self.ADMIN_PERMS, hotlist))
+ self.assertFalse(
+ permissions.CanViewHotlist({333, 444}, self.PERMS, hotlist))
+ self.assertTrue(
+ permissions.CanViewHotlist({333, 444}, self.ADMIN_PERMS, hotlist))
+
+ def testCanViewHotlist_Public(self):
+ hotlist = features_pb2.Hotlist()
+ hotlist.is_private = False
+ hotlist.owner_ids.append(111)
+ hotlist.editor_ids.append(222)
+
+ self.assertTrue(permissions.CanViewHotlist({222}, self.PERMS, hotlist))
+ self.assertTrue(permissions.CanViewHotlist({111, 333}, self.PERMS, hotlist))
+ self.assertTrue(permissions.CanViewHotlist({333, 444}, self.PERMS, hotlist))
+ self.assertTrue(
+ permissions.CanViewHotlist({333, 444}, self.ADMIN_PERMS, hotlist))
+
+ def testCanEditHotlist(self):
+ hotlist = features_pb2.Hotlist()
+ hotlist.owner_ids.append(111)
+ hotlist.editor_ids.append(222)
+
+ self.assertTrue(permissions.CanEditHotlist({222}, self.PERMS, hotlist))
+ self.assertTrue(permissions.CanEditHotlist({111, 333}, self.PERMS, hotlist))
+ self.assertTrue(
+ permissions.CanEditHotlist({111, 333}, self.ADMIN_PERMS, hotlist))
+ self.assertFalse(
+ permissions.CanEditHotlist({333, 444}, self.PERMS, hotlist))
+ self.assertTrue(
+ permissions.CanEditHotlist({333, 444}, self.ADMIN_PERMS, hotlist))
+
+ def testCanAdministerHotlist(self):
+ hotlist = features_pb2.Hotlist()
+ hotlist.owner_ids.append(111)
+ hotlist.editor_ids.append(222)
+
+ self.assertFalse(
+ permissions.CanAdministerHotlist({222}, self.PERMS, hotlist))
+ self.assertTrue(
+ permissions.CanAdministerHotlist({111, 333}, self.PERMS, hotlist))
+ self.assertTrue(
+ permissions.CanAdministerHotlist({111, 333}, self.ADMIN_PERMS, hotlist))
+ self.assertFalse(
+ permissions.CanAdministerHotlist({333, 444}, self.PERMS, hotlist))
+ self.assertTrue(
+ permissions.CanAdministerHotlist({333, 444}, self.ADMIN_PERMS, hotlist))
diff --git a/framework/test/profiler_test.py b/framework/test/profiler_test.py
new file mode 100644
index 0000000..3cc7e85
--- /dev/null
+++ b/framework/test/profiler_test.py
@@ -0,0 +1,138 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Test for monorail.framework.profiler."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+from framework import profiler
+
+
+class MockPatchResponse(object):
+ def execute(self):
+ pass
+
+
+class MockCloudTraceProjects(object):
+ def __init__(self):
+ self.patch_response = MockPatchResponse()
+ self.project_id = None
+ self.body = None
+
+ def patchTraces(self, projectId, body):
+ self.project_id = projectId
+ self.body = body
+ return self.patch_response
+
+
+class MockCloudTraceApi(object):
+ def __init__(self):
+ self.mock_projects = MockCloudTraceProjects()
+
+ def projects(self):
+ return self.mock_projects
+
+
+class ProfilerTest(unittest.TestCase):
+
+ def testTopLevelPhase(self):
+ prof = profiler.Profiler()
+ self.assertEqual(prof.current_phase.name, 'overall profile')
+ self.assertEqual(prof.current_phase.parent, None)
+ self.assertEqual(prof.current_phase, prof.top_phase)
+ self.assertEqual(prof.next_color, 0)
+
+ def testSinglePhase(self):
+ prof = profiler.Profiler()
+ self.assertEqual(prof.current_phase.name, 'overall profile')
+ with prof.Phase('test'):
+ self.assertEqual(prof.current_phase.name, 'test')
+ self.assertEqual(prof.current_phase.parent.name, 'overall profile')
+ self.assertEqual(prof.current_phase.name, 'overall profile')
+ self.assertEqual(prof.next_color, 1)
+
+ def testSinglePhase_SuperLongName(self):
+ prof = profiler.Profiler()
+ self.assertEqual(prof.current_phase.name, 'overall profile')
+ long_name = 'x' * 1000
+ with prof.Phase(long_name):
+ self.assertEqual(
+ 'x' * profiler.MAX_PHASE_NAME_LENGTH, prof.current_phase.name)
+
+  def testSubphaseException(self):
+ prof = profiler.Profiler()
+ try:
+ with prof.Phase('foo'):
+ with prof.Phase('bar'):
+ pass
+ with prof.Phase('baz'):
+ raise Exception('whoops')
+ except Exception as e:
+      self.assertEqual(str(e), 'whoops')
+ finally:
+ self.assertEqual(prof.current_phase.name, 'overall profile')
+ self.assertEqual(prof.top_phase.subphases[0].subphases[1].name, 'baz')
+
+ def testSpanJson(self):
+ mock_trace_api = MockCloudTraceApi()
+ mock_trace_context = '1234/5678;xxxxx'
+
+ prof = profiler.Profiler(mock_trace_context, mock_trace_api)
+ with prof.Phase('foo'):
+ with prof.Phase('bar'):
+ pass
+ with prof.Phase('baz'):
+ pass
+
+    # TODO: consider having the profiler end the current phase automatically.
+ prof.current_phase.End()
+
+ self.assertEqual(prof.current_phase.name, 'overall profile')
+ self.assertEqual(prof.top_phase.subphases[0].subphases[1].name, 'baz')
+ span_json = prof.top_phase.SpanJson()
+ self.assertEqual(len(span_json), 4)
+
+ for span in span_json:
+ self.assertTrue(span['endTime'] > span['startTime'])
+
+ # pylint: disable=unbalanced-tuple-unpacking
+ span1, span2, span3, span4 = span_json
+
+ self.assertEqual(span1['name'], 'overall profile')
+ self.assertEqual(span2['name'], 'foo')
+ self.assertEqual(span3['name'], 'bar')
+ self.assertEqual(span4['name'], 'baz')
+
+ self.assertTrue(span1['startTime'] < span2['startTime'])
+ self.assertTrue(span1['startTime'] < span3['startTime'])
+ self.assertTrue(span1['startTime'] < span4['startTime'])
+
+ self.assertTrue(span1['endTime'] > span2['endTime'])
+ self.assertTrue(span1['endTime'] > span3['endTime'])
+ self.assertTrue(span1['endTime'] > span4['endTime'])
+
+
+ def testReportCloudTrace(self):
+ mock_trace_api = MockCloudTraceApi()
+ mock_trace_context = '1234/5678;xxxxx'
+
+ prof = profiler.Profiler(mock_trace_context, mock_trace_api)
+ with prof.Phase('foo'):
+ with prof.Phase('bar'):
+ pass
+ with prof.Phase('baz'):
+ pass
+
+    # TODO: consider having the profiler end the current phase automatically.
+ prof.current_phase.End()
+
+ self.assertEqual(prof.current_phase.name, 'overall profile')
+ self.assertEqual(prof.top_phase.subphases[0].subphases[1].name, 'baz')
+
+ prof.ReportTrace()
+ self.assertEqual(mock_trace_api.mock_projects.project_id, 'testing-app')
diff --git a/framework/test/ratelimiter_test.py b/framework/test/ratelimiter_test.py
new file mode 100644
index 0000000..b351f8c
--- /dev/null
+++ b/framework/test/ratelimiter_test.py
@@ -0,0 +1,398 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for RateLimiter.
+"""
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+
+import unittest
+
+from google.appengine.api import memcache
+from google.appengine.ext import testbed
+
+import mox
+import os
+import settings
+
+from framework import ratelimiter
+from services import service_manager
+from services import client_config_svc
+from testing import fake
+from testing import testing_helpers
+
+
+class RateLimiterTest(unittest.TestCase):
+ def setUp(self):
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_memcache_stub()
+ self.testbed.init_user_stub()
+
+ self.mox = mox.Mox()
+ self.services = service_manager.Services(
+ config=fake.ConfigService(),
+ issue=fake.IssueService(),
+ user=fake.UserService(),
+ project=fake.ProjectService(),
+ )
+ self.project = self.services.project.TestAddProject('proj', project_id=987)
+
+ self.ratelimiter = ratelimiter.RateLimiter()
+ ratelimiter.COUNTRY_LIMITS = {}
+ os.environ['USER_EMAIL'] = ''
+ settings.ratelimiting_enabled = True
+ ratelimiter.DEFAULT_LIMIT = 10
+
+ def tearDown(self):
+ self.testbed.deactivate()
+ self.mox.UnsetStubs()
+ self.mox.ResetAll()
+ # settings.ratelimiting_enabled = True
+
+ def testCheckStart_pass(self):
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.headers['X-AppEngine-Country'] = 'US'
+ request.remote_addr = '192.168.1.0'
+ self.ratelimiter.CheckStart(request)
+ # Should not throw an exception.
+
+ def testCheckStart_fail(self):
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.headers['X-AppEngine-Country'] = 'US'
+ request.remote_addr = '192.168.1.0'
+ now = 0.0
+ cachekeysets, _, _, _ = ratelimiter._CacheKeys(request, now)
+ values = [{key: ratelimiter.DEFAULT_LIMIT for key in cachekeys} for
+ cachekeys in cachekeysets]
+ for value in values:
+ memcache.add_multi(value)
+ with self.assertRaises(ratelimiter.RateLimitExceeded):
+ self.ratelimiter.CheckStart(request, now)
+
+ def testCheckStart_expiredEntries(self):
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.headers['X-AppEngine-Country'] = 'US'
+ request.remote_addr = '192.168.1.0'
+ now = 0.0
+ cachekeysets, _, _, _ = ratelimiter._CacheKeys(request, now)
+ values = [{key: ratelimiter.DEFAULT_LIMIT for key in cachekeys} for
+ cachekeys in cachekeysets]
+ for value in values:
+ memcache.add_multi(value)
+
+ now = now + 2 * ratelimiter.EXPIRE_AFTER_SECS
+ self.ratelimiter.CheckStart(request, now)
+ # Should not throw an exception.
+
+ def testCheckStart_repeatedCalls(self):
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.headers['X-AppEngine-Country'] = 'US'
+ request.remote_addr = '192.168.1.0'
+ now = 0.0
+
+ # Call CheckStart once every minute. Should be ok.
+ for _ in range(ratelimiter.N_MINUTES):
+ self.ratelimiter.CheckStart(request, now)
+ now = now + 120.0
+
+ # Call CheckStart more than DEFAULT_LIMIT times in the same minute.
+ with self.assertRaises(ratelimiter.RateLimitExceeded):
+ for _ in range(ratelimiter.DEFAULT_LIMIT + 2): # pragma: no branch
+ now = now + 0.001
+ self.ratelimiter.CheckStart(request, now)
+
+ def testCheckStart_differentIPs(self):
+ now = 0.0
+
+ ratelimiter.COUNTRY_LIMITS = {}
+ # Exceed DEFAULT_LIMIT calls, but vary remote_addr so different
+ # remote addresses aren't ratelimited together.
+ for m in range(ratelimiter.DEFAULT_LIMIT * 2):
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.headers['X-AppEngine-Country'] = 'US'
+ request.remote_addr = '192.168.1.%d' % (m % 16)
+ ratelimiter._CacheKeys(request, now)
+ self.ratelimiter.CheckStart(request, now)
+ now = now + 0.001
+
+ # Exceed the limit, but only for one IP address. The
+ # others should be fine.
+ with self.assertRaises(ratelimiter.RateLimitExceeded):
+ for m in range(ratelimiter.DEFAULT_LIMIT): # pragma: no branch
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.headers['X-AppEngine-Country'] = 'US'
+ request.remote_addr = '192.168.1.0'
+ ratelimiter._CacheKeys(request, now)
+ self.ratelimiter.CheckStart(request, now)
+ now = now + 0.001
+
+ # Now proceed to make requests for all of the other IP
+ # addresses besides .0.
+ for m in range(ratelimiter.DEFAULT_LIMIT * 2):
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.headers['X-AppEngine-Country'] = 'US'
+ # Skip .0 since it's already exceeded the limit.
+ request.remote_addr = '192.168.1.%d' % (m + 1)
+ ratelimiter._CacheKeys(request, now)
+ self.ratelimiter.CheckStart(request, now)
+ now = now + 0.001
+
+ def testCheckStart_sameIPDifferentUserIDs(self):
+ # Behind a NAT, e.g.
+ now = 0.0
+
+ # Exceed DEFAULT_LIMIT calls, but vary user_id so different
+ # users behind the same IP aren't ratelimited together.
+ for m in range(ratelimiter.DEFAULT_LIMIT * 2):
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.remote_addr = '192.168.1.0'
+ os.environ['USER_EMAIL'] = '%s@example.com' % m
+ request.headers['X-AppEngine-Country'] = 'US'
+ ratelimiter._CacheKeys(request, now)
+ self.ratelimiter.CheckStart(request, now)
+ now = now + 0.001
+
+ # Exceed the limit, but only for one userID+IP address. The
+ # others should be fine.
+ with self.assertRaises(ratelimiter.RateLimitExceeded):
+ for m in range(ratelimiter.DEFAULT_LIMIT + 2): # pragma: no branch
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.headers['X-AppEngine-Country'] = 'US'
+ request.remote_addr = '192.168.1.0'
+ os.environ['USER_EMAIL'] = '42@example.com'
+ ratelimiter._CacheKeys(request, now)
+ self.ratelimiter.CheckStart(request, now)
+ now = now + 0.001
+
+ # Now proceed to make requests for other user IDs
+ # besides 42.
+ for m in range(ratelimiter.DEFAULT_LIMIT * 2):
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.headers['X-AppEngine-Country'] = 'US'
+      # Use the same IP, but user IDs other than 42, which exceeded the limit.
+ request.remote_addr = '192.168.1.0'
+ os.environ['USER_EMAIL'] = '%s@example.com' % (43 + m)
+ ratelimiter._CacheKeys(request, now)
+ self.ratelimiter.CheckStart(request, now)
+ now = now + 0.001
+
+ def testCheckStart_ratelimitingDisabled(self):
+ settings.ratelimiting_enabled = False
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.headers['X-AppEngine-Country'] = 'US'
+ request.remote_addr = '192.168.1.0'
+ now = 0.0
+
+ # Call CheckStart a lot. Should be ok.
+ for _ in range(ratelimiter.DEFAULT_LIMIT):
+ self.ratelimiter.CheckStart(request, now)
+ now = now + 0.001
+
+ def testCheckStart_perCountryLoggedOutLimit(self):
+ ratelimiter.COUNTRY_LIMITS['US'] = 10
+
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.headers[ratelimiter.COUNTRY_HEADER] = 'US'
+ request.remote_addr = '192.168.1.1'
+ now = 0.0
+
+ with self.assertRaises(ratelimiter.RateLimitExceeded):
+ for m in range(ratelimiter.DEFAULT_LIMIT + 2): # pragma: no branch
+ self.ratelimiter.CheckStart(request, now)
+ # Vary remote address to make sure the limit covers
+ # the whole country, regardless of IP.
+ request.remote_addr = '192.168.1.%d' % m
+ now = now + 0.001
+
+ # CheckStart for a country that isn't covered by a country-specific limit.
+ request.headers['X-AppEngine-Country'] = 'UK'
+ for m in range(11):
+ self.ratelimiter.CheckStart(request, now)
+ # Vary remote address to make sure the limit covers
+ # the whole country, regardless of IP.
+ request.remote_addr = '192.168.1.%d' % m
+ now = now + 0.001
+
+ # And regular rate limits work per-IP.
+ request.remote_addr = '192.168.1.1'
+ with self.assertRaises(ratelimiter.RateLimitExceeded):
+ for m in range(ratelimiter.DEFAULT_LIMIT): # pragma: no branch
+ self.ratelimiter.CheckStart(request, now)
+ # Vary remote address to make sure the limit covers
+ # the whole country, regardless of IP.
+ now = now + 0.001
+
+ def testCheckEnd_SlowRequest(self):
+ """We count one request for each 1000ms."""
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.headers[ratelimiter.COUNTRY_HEADER] = 'US'
+ request.remote_addr = '192.168.1.1'
+ start_time = 0.0
+
+ # Send some requests, all under the limit.
+ for _ in range(ratelimiter.DEFAULT_LIMIT-1):
+ start_time = start_time + 0.001
+ self.ratelimiter.CheckStart(request, start_time)
+ now = start_time + 0.010
+ self.ratelimiter.CheckEnd(request, now, start_time)
+
+ # Now issue some more request, this time taking long
+ # enough to get the cost threshold penalty.
+ # Fast forward enough to impact a later bucket than the
+ # previous requests.
+ start_time = now + 120.0
+ self.ratelimiter.CheckStart(request, start_time)
+
+ # Take longer than the threshold to process the request.
+ elapsed_ms = settings.ratelimiting_ms_per_count * 2
+ now = start_time + elapsed_ms / 1000
+
+ # The request finished, taking long enough to count as two.
+ self.ratelimiter.CheckEnd(request, now, start_time)
+
+ with self.assertRaises(ratelimiter.RateLimitExceeded):
+ # One more request after the expensive query should
+      # throw an exception.
+ self.ratelimiter.CheckStart(request, start_time)
+
+ def testCheckEnd_FastRequest(self):
+ request, _ = testing_helpers.GetRequestObjects(
+ project=self.project)
+ request.headers[ratelimiter.COUNTRY_HEADER] = 'asdasd'
+ request.remote_addr = '192.168.1.1'
+ start_time = 0.0
+
+ # Send some requests, all under the limit.
+ for _ in range(ratelimiter.DEFAULT_LIMIT):
+ self.ratelimiter.CheckStart(request, start_time)
+ now = start_time + 0.01
+ self.ratelimiter.CheckEnd(request, now, start_time)
+ start_time = now + 0.01
+
+
+class ApiRateLimiterTest(unittest.TestCase):
+
+ def setUp(self):
+ settings.ratelimiting_enabled = True
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_memcache_stub()
+
+ self.services = service_manager.Services(
+ config=fake.ConfigService(),
+ issue=fake.IssueService(),
+ user=fake.UserService(),
+ project=fake.ProjectService(),
+ )
+
+ self.client_id = '123456789'
+ self.client_email = 'test@example.com'
+
+ self.ratelimiter = ratelimiter.ApiRateLimiter()
+ settings.api_ratelimiting_enabled = True
+
+ def tearDown(self):
+ self.testbed.deactivate()
+
+ def testCheckStart_Allowed(self):
+ now = 0.0
+ self.ratelimiter.CheckStart(self.client_id, self.client_email, now)
+ self.ratelimiter.CheckStart(self.client_id, None, now)
+ self.ratelimiter.CheckStart(None, None, now)
+ self.ratelimiter.CheckStart('anonymous', None, now)
+
+ def testCheckStart_Rejected(self):
+ now = 0.0
+ keysets = ratelimiter._CreateApiCacheKeys(
+ self.client_id, self.client_email, now)
+ values = [{key: ratelimiter.DEFAULT_API_QPM + 1 for key in keyset} for
+ keyset in keysets]
+ for value in values:
+ memcache.add_multi(value)
+ with self.assertRaises(ratelimiter.ApiRateLimitExceeded):
+ self.ratelimiter.CheckStart(self.client_id, self.client_email, now)
+
+ def testCheckStart_Allowed_HigherQPMSpecified(self):
+ """Client goes over the default, but has a higher QPM set."""
+ now = 0.0
+ keysets = ratelimiter._CreateApiCacheKeys(
+ self.client_id, self.client_email, now)
+ qpm_dict = client_config_svc.GetQPMDict()
+ qpm_dict[self.client_email] = ratelimiter.DEFAULT_API_QPM + 10
+    # The client used 1 request more than the default limit in each of the
+    # 5 minutes in our 5-minute sample window, so it is 5 over in total.
+ values = [{key: ratelimiter.DEFAULT_API_QPM + 1 for key in keyset} for
+ keyset in keysets]
+ for value in values:
+ memcache.add_multi(value)
+ self.ratelimiter.CheckStart(self.client_id, self.client_email, now)
+ del qpm_dict[self.client_email]
+
+ def testCheckStart_Allowed_LowQPMIgnored(self):
+ """Client specifies a QPM lower than the default and default is used."""
+ now = 0.0
+ keysets = ratelimiter._CreateApiCacheKeys(
+ self.client_id, self.client_email, now)
+ qpm_dict = client_config_svc.GetQPMDict()
+ qpm_dict[self.client_email] = ratelimiter.DEFAULT_API_QPM - 10
+ values = [{key: ratelimiter.DEFAULT_API_QPM for key in keyset} for
+ keyset in keysets]
+ for value in values:
+ memcache.add_multi(value)
+ self.ratelimiter.CheckStart(self.client_id, self.client_email, now)
+ del qpm_dict[self.client_email]
+
+ def testCheckStart_Rejected_LowQPMIgnored(self):
+ """Client specifies a QPM lower than the default and default is used."""
+ now = 0.0
+ keysets = ratelimiter._CreateApiCacheKeys(
+ self.client_id, self.client_email, now)
+ qpm_dict = client_config_svc.GetQPMDict()
+ qpm_dict[self.client_email] = ratelimiter.DEFAULT_API_QPM - 10
+ values = [{key: ratelimiter.DEFAULT_API_QPM + 1 for key in keyset} for
+ keyset in keysets]
+ for value in values:
+ memcache.add_multi(value)
+ with self.assertRaises(ratelimiter.ApiRateLimitExceeded):
+ self.ratelimiter.CheckStart(self.client_id, self.client_email, now)
+ del qpm_dict[self.client_email]
+
+ def testCheckEnd(self):
+ start_time = 0.0
+ keysets = ratelimiter._CreateApiCacheKeys(
+ self.client_id, self.client_email, start_time)
+
+ now = 0.1
+ self.ratelimiter.CheckEnd(
+ self.client_id, self.client_email, now, start_time)
+ counters = memcache.get_multi(keysets[0])
+ count = sum(counters.values())
+ # No extra cost charged
+ self.assertEqual(0, count)
+
+ elapsed_ms = settings.ratelimiting_ms_per_count * 2
+ now = start_time + elapsed_ms / 1000
+ self.ratelimiter.CheckEnd(
+ self.client_id, self.client_email, now, start_time)
+ counters = memcache.get_multi(keysets[0])
+ count = sum(counters.values())
+ # Extra cost charged
+ self.assertEqual(1, count)
diff --git a/framework/test/reap_test.py b/framework/test/reap_test.py
new file mode 100644
index 0000000..f1a907d
--- /dev/null
+++ b/framework/test/reap_test.py
@@ -0,0 +1,131 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for the reap module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+import mock
+import mox
+
+from mock import Mock
+
+from framework import reap
+from framework import sql
+from proto import project_pb2
+from services import service_manager
+from services import template_svc
+from testing import fake
+from testing import testing_helpers
+
+
+class ReapTest(unittest.TestCase):
+
+ def setUp(self):
+ self.project_service = fake.ProjectService()
+ self.issue_service = fake.IssueService()
+ self.issue_star_service = fake.IssueStarService()
+ self.config_service = fake.ConfigService()
+ self.features_service = fake.FeaturesService()
+ self.project_star_service = fake.ProjectStarService()
+ self.services = service_manager.Services(
+ project=self.project_service,
+ issue=self.issue_service,
+ issue_star=self.issue_star_service,
+ config=self.config_service,
+ features=self.features_service,
+ project_star=self.project_star_service,
+ template=Mock(spec=template_svc.TemplateService),
+ user=fake.UserService(),
+ usergroup=fake.UserGroupService())
+
+ self.proj1_id = 1001
+ self.proj1_issue_id = 111
+ self.proj1 = self.project_service.TestAddProject(
+ name='proj1', project_id=self.proj1_id)
+ self.proj2_id = 1002
+ self.proj2_issue_id = 112
+ self.proj2 = self.project_service.TestAddProject(
+ name='proj2', project_id=self.proj2_id)
+
+ self.mox = mox.Mox()
+ self.cnxn = self.mox.CreateMock(sql.MonorailConnection)
+ self.project_service.project_tbl = self.mox.CreateMock(sql.SQLTableManager)
+ self.issue_service.issue_tbl = self.mox.CreateMock(sql.SQLTableManager)
+
+ def tearDown(self):
+ self.mox.UnsetStubs()
+ self.mox.ResetAll()
+
+ def setUpMarkDoomedProjects(self):
+ self.project_service.project_tbl.Select(
+ self.cnxn, cols=['project_id'], limit=1000, state='archived',
+ where=mox.IgnoreArg()).AndReturn([[self.proj1_id]])
+
+ def testMarkDoomedProjects(self):
+ self.setUpMarkDoomedProjects()
+ reaper = reap.Reap('req', 'resp', services=self.services)
+
+ self.mox.ReplayAll()
+ doomed_project_ids = reaper._MarkDoomedProjects(self.cnxn)
+ self.mox.VerifyAll()
+
+ self.assertEqual([self.proj1_id], doomed_project_ids)
+ self.assertEqual(project_pb2.ProjectState.DELETABLE, self.proj1.state)
+ self.assertEqual('DELETABLE_%s' % self.proj1_id, self.proj1.project_name)
+
+ def setUpExpungeParts(self):
+ self.project_service.project_tbl.Select(
+ self.cnxn, cols=['project_id'], limit=100,
+ state='deletable').AndReturn([[self.proj1_id], [self.proj2_id]])
+ self.issue_service.issue_tbl.Select(
+ self.cnxn, cols=['id'], limit=1000,
+ project_id=self.proj1_id).AndReturn([[self.proj1_issue_id]])
+ self.issue_service.issue_tbl.Select(
+ self.cnxn, cols=['id'], limit=1000,
+ project_id=self.proj2_id).AndReturn([[self.proj2_issue_id]])
+
+ def testExpungeDeletableProjects(self):
+ self.setUpExpungeParts()
+ reaper = reap.Reap('req', 'resp', services=self.services)
+
+ self.mox.ReplayAll()
+ expunged_project_ids = reaper._ExpungeDeletableProjects(self.cnxn)
+ self.mox.VerifyAll()
+
+ self.assertEqual([self.proj1_id, self.proj2_id], expunged_project_ids)
+ # Verify all expected expunge methods were called.
+ self.assertEqual(
+ [self.proj1_issue_id, self.proj2_issue_id],
+ self.services.issue_star.expunged_item_ids)
+ self.assertEqual(
+ [self.proj1_issue_id, self.proj2_issue_id],
+ self.services.issue.expunged_issues)
+ self.assertEqual(
+ [self.proj1_id, self.proj2_id], self.services.config.expunged_configs)
+ self.assertEqual(
+ [self.proj1_id, self.proj2_id],
+ self.services.features.expunged_saved_queries)
+ self.assertEqual(
+ [self.proj1_id, self.proj2_id],
+ self.services.features.expunged_filter_rules)
+ self.assertEqual(
+ [self.proj1_id, self.proj2_id],
+ self.services.issue.expunged_former_locations)
+ self.assertEqual(
+ [self.proj1_id, self.proj2_id], self.services.issue.expunged_local_ids)
+ self.assertEqual(
+ [self.proj1_id, self.proj2_id],
+ self.services.features.expunged_quick_edit)
+ self.assertEqual(
+ [self.proj1_id, self.proj2_id],
+ self.services.project_star.expunged_item_ids)
+ self.assertEqual(0, len(self.services.project.test_projects))
+ self.services.template.ExpungeProjectTemplates.assert_has_calls([
+ mock.call(self.cnxn, 1001),
+ mock.call(self.cnxn, 1002)])
diff --git a/framework/test/redis_utils_test.py b/framework/test/redis_utils_test.py
new file mode 100644
index 0000000..a4128ce
--- /dev/null
+++ b/framework/test/redis_utils_test.py
@@ -0,0 +1,64 @@
+# Copyright 2020 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+"""Tests for the Redis utility module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import fakeredis
+import unittest
+
+from framework import redis_utils
+from proto import features_pb2
+
+
+class RedisHelperTest(unittest.TestCase):
+
+ def testFormatRedisKey(self):
+ redis_key = redis_utils.FormatRedisKey(111)
+ self.assertEqual('111', redis_key)
+ redis_key = redis_utils.FormatRedisKey(222, prefix='foo:')
+ self.assertEqual('foo:222', redis_key)
+ redis_key = redis_utils.FormatRedisKey(333, prefix='bar')
+ self.assertEqual('bar:333', redis_key)
+
+ def testCreateRedisClient(self):
+ self.assertIsNone(redis_utils.connection_pool)
+ redis_client_1 = redis_utils.CreateRedisClient()
+ self.assertIsNotNone(redis_client_1)
+ self.assertIsNotNone(redis_utils.connection_pool)
+ redis_client_2 = redis_utils.CreateRedisClient()
+ self.assertIsNotNone(redis_client_2)
+ self.assertIsNot(redis_client_1, redis_client_2)
+
+ def testConnectionVerification(self):
+ server = fakeredis.FakeServer()
+ client = None
+ self.assertFalse(redis_utils.VerifyRedisConnection(client))
+ server.connected = True
+ client = fakeredis.FakeRedis(server=server)
+ self.assertTrue(redis_utils.VerifyRedisConnection(client))
+ server.connected = False
+ self.assertFalse(redis_utils.VerifyRedisConnection(client))
+
+ def testSerializeDeserializeInt(self):
+ serialized_int = redis_utils.SerializeValue(123)
+ self.assertEqual('123', serialized_int)
+    self.assertEqual(123, redis_utils.DeserializeValue(serialized_int))
+
+ def testSerializeDeserializeStr(self):
+ serialized = redis_utils.SerializeValue('123')
+ self.assertEqual('"123"', serialized)
+    self.assertEqual('123', redis_utils.DeserializeValue(serialized))
+
+ def testSerializeDeserializePB(self):
+ features = features_pb2.Hotlist.HotlistItem(
+ issue_id=7949, rank=0, adder_id=333, date_added=1525)
+ serialized = redis_utils.SerializeValue(
+ features, pb_class=features_pb2.Hotlist.HotlistItem)
+ self.assertIsInstance(serialized, str)
+ deserialized = redis_utils.DeserializeValue(
+ serialized, pb_class=features_pb2.Hotlist.HotlistItem)
+ self.assertIsInstance(deserialized, features_pb2.Hotlist.HotlistItem)
+    self.assertEqual(deserialized, features)
diff --git a/framework/test/registerpages_helpers_test.py b/framework/test/registerpages_helpers_test.py
new file mode 100644
index 0000000..61c489e
--- /dev/null
+++ b/framework/test/registerpages_helpers_test.py
@@ -0,0 +1,59 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for URL handler registration helper functions."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+import webapp2
+
+from framework import registerpages_helpers
+
+
+class SendRedirectInScopeTest(unittest.TestCase):
+
+ def testMakeRedirectInScope_Error(self):
+ self.assertRaises(
+ AssertionError,
+ registerpages_helpers.MakeRedirectInScope, 'no/initial/slash', 'p')
+ self.assertRaises(
+ AssertionError,
+ registerpages_helpers.MakeRedirectInScope, '', 'p')
+
+ def testMakeRedirectInScope_Normal(self):
+ factory = registerpages_helpers.MakeRedirectInScope('/', 'p')
+ # Non-dasher, normal case
+ request = webapp2.Request.blank(
+ path='/p/foo', headers={'Host': 'example.com'})
+ response = webapp2.Response()
+ redirector = factory(request, response)
+ redirector.get()
+ self.assertEqual(response.location, '//example.com/p/foo/')
+ self.assertEqual(response.status, '301 Moved Permanently')
+
+ def testMakeRedirectInScope_Temporary(self):
+ factory = registerpages_helpers.MakeRedirectInScope(
+ '/', 'p', permanent=False)
+ request = webapp2.Request.blank(
+ path='/p/foo', headers={'Host': 'example.com'})
+ response = webapp2.Response()
+ redirector = factory(request, response)
+ redirector.get()
+ self.assertEqual(response.location, '//example.com/p/foo/')
+ self.assertEqual(response.status, '302 Moved Temporarily')
+
+ def testMakeRedirectInScope_KeepQueryString(self):
+ factory = registerpages_helpers.MakeRedirectInScope(
+ '/', 'p', keep_qs=True)
+ request = webapp2.Request.blank(
+ path='/p/foo?q=1', headers={'Host': 'example.com'})
+ response = webapp2.Response()
+ redirector = factory(request, response)
+ redirector.get()
+ self.assertEqual(response.location, '//example.com/p/foo/?q=1')
+ self.assertEqual(response.status, '302 Moved Temporarily')
diff --git a/framework/test/servlet_helpers_test.py b/framework/test/servlet_helpers_test.py
new file mode 100644
index 0000000..a2fe687
--- /dev/null
+++ b/framework/test/servlet_helpers_test.py
@@ -0,0 +1,168 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for servlet base class helper functions."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+from google.appengine.ext import testbed
+
+
+from framework import permissions
+from framework import servlet_helpers
+from proto import project_pb2
+from proto import tracker_pb2
+from testing import testing_helpers
+
+
+class EztDataTest(unittest.TestCase):
+
+ def testGetBannerTime(self):
+ """Tests GetBannerTime method."""
+ timestamp = [2019, 6, 13, 18, 30]
+
+ banner_time = servlet_helpers.GetBannerTime(timestamp)
+ self.assertEqual(1560450600, banner_time)
+
+
+class AssertBasePermissionTest(unittest.TestCase):
+
+ def testAccessGranted(self):
+ _, mr = testing_helpers.GetRequestObjects(path='/hosting')
+ # No exceptions should be raised.
+ servlet_helpers.AssertBasePermission(mr)
+
+ mr.auth.user_id = 123
+ # No exceptions should be raised.
+ servlet_helpers.AssertBasePermission(mr)
+ servlet_helpers.AssertBasePermissionForUser(
+ mr.auth.user_pb, mr.auth.user_view)
+
+ def testBanned(self):
+ _, mr = testing_helpers.GetRequestObjects(path='/hosting')
+ mr.auth.user_pb.banned = 'spammer'
+ self.assertRaises(
+ permissions.BannedUserException,
+ servlet_helpers.AssertBasePermissionForUser,
+ mr.auth.user_pb, mr.auth.user_view)
+ self.assertRaises(
+ permissions.BannedUserException,
+ servlet_helpers.AssertBasePermission, mr)
+
+ def testPlusAddressAccount(self):
+ _, mr = testing_helpers.GetRequestObjects(path='/hosting')
+ mr.auth.user_pb.email = 'mailinglist+spammer@chromium.org'
+ self.assertRaises(
+ permissions.BannedUserException,
+ servlet_helpers.AssertBasePermissionForUser,
+ mr.auth.user_pb, mr.auth.user_view)
+ self.assertRaises(
+ permissions.BannedUserException,
+ servlet_helpers.AssertBasePermission, mr)
+
+ def testNoAccessToProject(self):
+ project = project_pb2.Project()
+ project.project_name = 'proj'
+ project.access = project_pb2.ProjectAccess.MEMBERS_ONLY
+ _, mr = testing_helpers.GetRequestObjects(path='/p/proj/', project=project)
+ mr.perms = permissions.EMPTY_PERMISSIONSET
+ self.assertRaises(
+ permissions.PermissionException,
+ servlet_helpers.AssertBasePermission, mr)
+
+
+FORM_URL = 'http://example.com/issues/form.php'
+
+
+class ComputeIssueEntryURLTest(unittest.TestCase):
+
+ def setUp(self):
+ self.project = project_pb2.Project()
+ self.project.project_name = 'proj'
+ self.config = tracker_pb2.ProjectIssueConfig()
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_user_stub()
+ self.testbed.init_memcache_stub()
+ self.testbed.init_datastore_v3_stub()
+
+ def tearDown(self):
+ self.testbed.deactivate()
+
+ def testComputeIssueEntryURL_Normal(self):
+ _request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/issues/detail?id=123&q=term',
+ project=self.project)
+
+ url = servlet_helpers.ComputeIssueEntryURL(mr, self.config)
+ self.assertEqual('/p/proj/issues/entry', url)
+
+ def testComputeIssueEntryURL_Customized(self):
+ _request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/issues/detail?id=123&q=term',
+ project=self.project)
+ mr.auth.user_id = 111
+ self.config.custom_issue_entry_url = FORM_URL
+
+ url = servlet_helpers.ComputeIssueEntryURL(mr, self.config)
+ self.assertTrue(url.startswith(FORM_URL))
+ self.assertIn('token=', url)
+ self.assertIn('role=', url)
+ self.assertIn('continue=', url)
+
+class IssueListURLTest(unittest.TestCase):
+
+ def setUp(self):
+ self.project = project_pb2.Project()
+ self.project.project_name = 'proj'
+ self.project.owner_ids = [111]
+ self.config = tracker_pb2.ProjectIssueConfig()
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_user_stub()
+ self.testbed.init_memcache_stub()
+ self.testbed.init_datastore_v3_stub()
+
+ def tearDown(self):
+ self.testbed.deactivate()
+
+ def testIssueListURL_NotCustomized(self):
+ _request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/issues', project=self.project)
+
+ url = servlet_helpers.IssueListURL(mr, self.config)
+ self.assertEqual('/p/proj/issues/list', url)
+
+ def testIssueListURL_Customized_Nonmember(self):
+ _request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/issues', project=self.project)
+ self.config.member_default_query = 'owner:me'
+
+ url = servlet_helpers.IssueListURL(mr, self.config)
+ self.assertEqual('/p/proj/issues/list', url)
+
+ def testIssueListURL_Customized_Member(self):
+ _request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/issues', project=self.project,
+ user_info={'effective_ids': {111}})
+ self.config.member_default_query = 'owner:me'
+
+ url = servlet_helpers.IssueListURL(mr, self.config)
+ self.assertEqual('/p/proj/issues/list?q=owner%3Ame', url)
+
+ def testIssueListURL_Customized_RetainQS(self):
+ _request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/issues', project=self.project,
+ user_info={'effective_ids': {111}})
+ self.config.member_default_query = 'owner:me'
+
+ url = servlet_helpers.IssueListURL(mr, self.config, query_string='')
+ self.assertEqual('/p/proj/issues/list?q=owner%3Ame', url)
+
+ url = servlet_helpers.IssueListURL(mr, self.config, query_string='q=Pri=1')
+ self.assertEqual('/p/proj/issues/list?q=Pri=1', url)
diff --git a/framework/test/servlet_test.py b/framework/test/servlet_test.py
new file mode 100644
index 0000000..40d5ed2
--- /dev/null
+++ b/framework/test/servlet_test.py
@@ -0,0 +1,474 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for servlet base class module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import time
+import mock
+import unittest
+
+from google.appengine.api import app_identity
+from google.appengine.ext import testbed
+
+import webapp2
+
+from framework import framework_constants
+from framework import servlet
+from framework import xsrf
+from proto import project_pb2
+from proto import tracker_pb2
+from proto import user_pb2
+from services import service_manager
+from testing import fake
+from testing import testing_helpers
+
+
+class TestableServlet(servlet.Servlet):
+ """A tiny concrete subclass of abstract class Servlet."""
+
+ def __init__(self, request, response, services=None, do_post_redirect=True):
+ super(TestableServlet, self).__init__(request, response, services=services)
+ self.do_post_redirect = do_post_redirect
+ self.seen_post_data = None
+
+ def ProcessFormData(self, _mr, post_data):
+ self.seen_post_data = post_data
+ if self.do_post_redirect:
+ return '/This/Is?The=Next#Page'
+ else:
+ self.response.write('sending raw data to browser')
+
+
+class ServletTest(unittest.TestCase):
+
+ def setUp(self):
+ services = service_manager.Services(
+ project=fake.ProjectService(),
+ project_star=fake.ProjectStarService(),
+ user=fake.UserService(),
+ usergroup=fake.UserGroupService())
+ services.user.TestAddUser('user@example.com', 111)
+ self.page_class = TestableServlet(
+ webapp2.Request.blank('/'), webapp2.Response(), services=services)
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_user_stub()
+ self.testbed.init_memcache_stub()
+ self.testbed.init_datastore_v3_stub()
+
+ def tearDown(self):
+ self.testbed.deactivate()
+
+ def testDefaultValues(self):
+ self.assertEqual(None, self.page_class._MAIN_TAB_MODE)
+ self.assertTrue(self.page_class._TEMPLATE_PATH.endswith('/templates/'))
+ self.assertEqual(None, self.page_class._PAGE_TEMPLATE)
+
+ def testGatherBaseData(self):
+ project = self.page_class.services.project.TestAddProject(
+ 'testproj', state=project_pb2.ProjectState.LIVE)
+ project.cached_content_timestamp = 12345
+
+ (_request, mr) = testing_helpers.GetRequestObjects(
+ path='/p/testproj/feeds', project=project)
+ nonce = '1a2b3c4d5e6f7g'
+
+ base_data = self.page_class.GatherBaseData(mr, nonce)
+
+ self.assertEqual(base_data['nonce'], nonce)
+ self.assertEqual(base_data['projectname'], 'testproj')
+ self.assertEqual(base_data['project'].cached_content_timestamp, 12345)
+ self.assertEqual(base_data['project_alert'], None)
+
+ self.assertTrue(base_data['currentPageURL'].endswith('/p/testproj/feeds'))
+ self.assertTrue(
+ base_data['currentPageURLEncoded'].endswith('%2Fp%2Ftestproj%2Ffeeds'))
+
+ def testFormHandlerURL(self):
+ self.assertEqual('/edit.do', self.page_class._FormHandlerURL('/'))
+ self.assertEqual(
+ '/something/edit.do',
+ self.page_class._FormHandlerURL('/something/'))
+ self.assertEqual(
+ '/something/edit.do',
+ self.page_class._FormHandlerURL('/something/edit.do'))
+ self.assertEqual(
+ '/something/detail_ezt.do',
+ self.page_class._FormHandlerURL('/something/detail_ezt'))
+
+ def testProcessForm_BadToken(self):
+ user_id = 111
+ token = 'no soup for you'
+
+ request, mr = testing_helpers.GetRequestObjects(
+ path='/we/we/we?so=excited',
+ params={
+ 'yesterday': 'thursday',
+ 'today': 'friday',
+ 'token': token
+ },
+ user_info={'user_id': user_id},
+ method='POST',
+ )
+ self.assertRaises(
+ xsrf.TokenIncorrect, self.page_class._DoFormProcessing, request, mr)
+ self.assertEqual(None, self.page_class.seen_post_data)
+
+ def testProcessForm_XhrAllowed_BadToken(self):
+ user_id = 111
+ token = 'no soup for you'
+
+ self.page_class.ALLOW_XHR = True
+
+ request, mr = testing_helpers.GetRequestObjects(
+ path='/we/we/we?so=excited',
+ params={
+ 'yesterday': 'thursday',
+ 'today': 'friday',
+ 'token': token
+ },
+ user_info={'user_id': user_id},
+ method='POST',
+ )
+ self.assertRaises(
+ xsrf.TokenIncorrect, self.page_class._DoFormProcessing, request, mr)
+ self.assertEqual(None, self.page_class.seen_post_data)
+
+ def testProcessForm_XhrAllowed_AcceptsPathToken(self):
+ user_id = 111
+ token = xsrf.GenerateToken(user_id, '/we/we/we')
+
+ self.page_class.ALLOW_XHR = True
+
+ request, mr = testing_helpers.GetRequestObjects(
+ path='/we/we/we?so=excited',
+ params={
+ 'yesterday': 'thursday',
+ 'today': 'friday',
+ 'token': token
+ },
+ user_info={'user_id': user_id},
+ method='POST',
+ )
+ with self.assertRaises(webapp2.HTTPException) as cm:
+ self.page_class._DoFormProcessing(request, mr)
+ self.assertEqual(302, cm.exception.code) # forms redirect on success
+
+ self.assertDictEqual(
+ {
+ 'yesterday': 'thursday',
+ 'today': 'friday',
+ 'token': token
+ }, dict(self.page_class.seen_post_data))
+
+ def testProcessForm_XhrAllowed_AcceptsXhrToken(self):
+ user_id = 111
+ token = xsrf.GenerateToken(user_id, 'xhr')
+
+ self.page_class.ALLOW_XHR = True
+
+ request, mr = testing_helpers.GetRequestObjects(
+ path='/we/we/we?so=excited',
+ params={'yesterday': 'thursday', 'today': 'friday', 'token': token},
+ user_info={'user_id': user_id},
+ method='POST',
+ )
+ with self.assertRaises(webapp2.HTTPException) as cm:
+ self.page_class._DoFormProcessing(request, mr)
+ self.assertEqual(302, cm.exception.code) # forms redirect on success
+
+ self.assertDictEqual(
+ {
+ 'yesterday': 'thursday',
+ 'today': 'friday',
+ 'token': token
+ }, dict(self.page_class.seen_post_data))
+
+ def testProcessForm_RawResponse(self):
+ user_id = 111
+ token = xsrf.GenerateToken(user_id, '/we/we/we')
+
+ request, mr = testing_helpers.GetRequestObjects(
+ path='/we/we/we?so=excited',
+ params={'yesterday': 'thursday', 'today': 'friday', 'token': token},
+ user_info={'user_id': user_id},
+ method='POST',
+ )
+ self.page_class.do_post_redirect = False
+ self.page_class._DoFormProcessing(request, mr)
+ self.assertEqual(
+ 'sending raw data to browser',
+ self.page_class.response.body)
+
+ def testProcessForm_Normal(self):
+ user_id = 111
+ token = xsrf.GenerateToken(user_id, '/we/we/we')
+
+ request, mr = testing_helpers.GetRequestObjects(
+ path='/we/we/we?so=excited',
+ params={'yesterday': 'thursday', 'today': 'friday', 'token': token},
+ user_info={'user_id': user_id},
+ method='POST',
+ )
+ with self.assertRaises(webapp2.HTTPException) as cm:
+ self.page_class._DoFormProcessing(request, mr)
+ self.assertEqual(302, cm.exception.code) # forms redirect on success
+
+ self.assertDictEqual(
+ {'yesterday': 'thursday', 'today': 'friday', 'token': token},
+ dict(self.page_class.seen_post_data))
+
+ def testCalcProjectAlert(self):
+ project = fake.Project(
+ project_name='alerttest', state=project_pb2.ProjectState.LIVE)
+
+ project_alert = servlet._CalcProjectAlert(project)
+ self.assertEqual(project_alert, None)
+
+ project.state = project_pb2.ProjectState.ARCHIVED
+ project_alert = servlet._CalcProjectAlert(project)
+ self.assertEqual(
+ project_alert,
+ 'Project is archived: read-only by members only.')
+
+ delete_time = int(time.time() + framework_constants.SECS_PER_DAY * 1.5)
+ project.delete_time = delete_time
+ project_alert = servlet._CalcProjectAlert(project)
+ self.assertEqual(project_alert, 'Scheduled for deletion in 1 day.')
+
+ delete_time = int(time.time() + framework_constants.SECS_PER_DAY * 2.5)
+ project.delete_time = delete_time
+ project_alert = servlet._CalcProjectAlert(project)
+ self.assertEqual(project_alert, 'Scheduled for deletion in 2 days.')
+
+ def testCheckForMovedProject_NoRedirect(self):
+ project = fake.Project(
+ project_name='proj', state=project_pb2.ProjectState.LIVE)
+ request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj', project=project)
+ self.page_class._CheckForMovedProject(mr, request)
+
+ request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/source/browse/p/adminAdvanced', project=project)
+ self.page_class._CheckForMovedProject(mr, request)
+
+ def testCheckForMovedProject_Redirect(self):
+ project = fake.Project(project_name='proj', moved_to='http://example.com')
+ request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj', project=project)
+ with self.assertRaises(webapp2.HTTPException) as cm:
+ self.page_class._CheckForMovedProject(mr, request)
+ self.assertEqual(302, cm.exception.code) # redirect because project moved
+
+ request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/source/browse/p/adminAdvanced', project=project)
+ with self.assertRaises(webapp2.HTTPException) as cm:
+ self.page_class._CheckForMovedProject(mr, request)
+ self.assertEqual(302, cm.exception.code) # redirect because project moved
+
+ def testCheckForMovedProject_AdminAdvanced(self):
+ """We do not redirect away from the page that edits project state."""
+ project = fake.Project(project_name='proj', moved_to='http://example.com')
+ request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/adminAdvanced', project=project)
+ self.page_class._CheckForMovedProject(mr, request)
+
+ request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/adminAdvanced?ts=123234', project=project)
+ self.page_class._CheckForMovedProject(mr, request)
+
+ request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/adminAdvanced.do', project=project)
+ self.page_class._CheckForMovedProject(mr, request)
+
+ @mock.patch('settings.branded_domains',
+ {'proj': 'branded.example.com', '*': 'bugs.chromium.org'})
+ def testMaybeRedirectToBrandedDomain_RedirBrandedProject(self):
+ """We redirect for a branded project if the user typed a different host."""
+ project = fake.Project(project_name='proj')
+ request, _mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/path', project=project)
+ with self.assertRaises(webapp2.HTTPException) as cm:
+ self.page_class._MaybeRedirectToBrandedDomain(request, 'proj')
+ self.assertEqual(302, cm.exception.code) # forms redirect on success
+ self.assertEqual('https://branded.example.com/p/proj/path?redir=1',
+ cm.exception.location)
+
+ request, _mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/path?query', project=project)
+ with self.assertRaises(webapp2.HTTPException) as cm:
+ self.page_class._MaybeRedirectToBrandedDomain(request, 'proj')
+ self.assertEqual(302, cm.exception.code) # forms redirect on success
+ self.assertEqual('https://branded.example.com/p/proj/path?query&redir=1',
+ cm.exception.location)
+
+ @mock.patch('settings.branded_domains',
+ {'proj': 'branded.example.com', '*': 'bugs.chromium.org'})
+ def testMaybeRedirectToBrandedDomain_AvoidRedirLoops(self):
+ """Don't redirect for a branded project if already redirected."""
+ project = fake.Project(project_name='proj')
+ request, _mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/path?redir=1', project=project)
+ # No redirect happens.
+ self.page_class._MaybeRedirectToBrandedDomain(request, 'proj')
+
+ @mock.patch('settings.branded_domains',
+ {'proj': 'branded.example.com', '*': 'bugs.chromium.org'})
+ def testMaybeRedirectToBrandedDomain_NonProjectPage(self):
+ """Don't redirect for a branded project if not in any project."""
+ request, _mr = testing_helpers.GetRequestObjects(
+ path='/u/user@example.com')
+ # No redirect happens.
+ self.page_class._MaybeRedirectToBrandedDomain(request, None)
+
+ @mock.patch('settings.branded_domains',
+ {'proj': 'branded.example.com', '*': 'bugs.chromium.org'})
+ def testMaybeRedirectToBrandedDomain_AlreadyOnBrandedHost(self):
+ """Don't redirect for a branded project if already on branded domain."""
+ project = fake.Project(project_name='proj')
+ request, _mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/path', project=project)
+ request.host = 'branded.example.com'
+ # No redirect happens.
+ self.page_class._MaybeRedirectToBrandedDomain(request, 'proj')
+
+ @mock.patch('settings.branded_domains',
+ {'proj': 'branded.example.com', '*': 'bugs.chromium.org'})
+ def testMaybeRedirectToBrandedDomain_Localhost(self):
+ """Don't redirect for a branded project on localhost."""
+ project = fake.Project(project_name='proj')
+ request, _mr = testing_helpers.GetRequestObjects(
+ path='/p/proj/path', project=project)
+ request.host = 'localhost:8080'
+ # No redirect happens.
+ self.page_class._MaybeRedirectToBrandedDomain(request, 'proj')
+
+ request.host = '0.0.0.0:8080'
+ # No redirect happens.
+ self.page_class._MaybeRedirectToBrandedDomain(request, 'proj')
+
+ @mock.patch('settings.branded_domains',
+ {'proj': 'branded.example.com', '*': 'bugs.chromium.org'})
+ def testMaybeRedirectToBrandedDomain_NotBranded(self):
+ """Don't redirect for a non-branded project."""
+ project = fake.Project(project_name='other')
+ request, _mr = testing_helpers.GetRequestObjects(
+ path='/p/other/path?query', project=project)
+ request.host = 'branded.example.com' # But other project is unbranded.
+
+ with self.assertRaises(webapp2.HTTPException) as cm:
+ self.page_class._MaybeRedirectToBrandedDomain(request, 'other')
+ self.assertEqual(302, cm.exception.code) # forms redirect on success
+ self.assertEqual('https://bugs.chromium.org/p/other/path?query&redir=1',
+ cm.exception.location)
+
+ def testGatherHelpData_Normal(self):
+ project = fake.Project(project_name='proj')
+ _request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj', project=project)
+ help_data = self.page_class.GatherHelpData(mr, {})
+ self.assertEqual(None, help_data['cue'])
+ self.assertEqual(None, help_data['account_cue'])
+
+ def testGatherHelpData_VacationReminder(self):
+ project = fake.Project(project_name='proj')
+ _request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj', project=project)
+ mr.auth.user_id = 111
+ mr.auth.user_pb.vacation_message = 'Gone skiing'
+ help_data = self.page_class.GatherHelpData(mr, {})
+ self.assertEqual('you_are_on_vacation', help_data['cue'])
+
+ self.page_class.services.user.SetUserPrefs(
+ 'cnxn', 111,
+ [user_pb2.UserPrefValue(name='you_are_on_vacation', value='true')])
+ help_data = self.page_class.GatherHelpData(mr, {})
+ self.assertEqual(None, help_data['cue'])
+ self.assertEqual(None, help_data['account_cue'])
+
+ def testGatherHelpData_YouAreBouncing(self):
+ project = fake.Project(project_name='proj')
+ _request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj', project=project)
+ mr.auth.user_id = 111
+ mr.auth.user_pb.email_bounce_timestamp = 1497647529
+ help_data = self.page_class.GatherHelpData(mr, {})
+ self.assertEqual('your_email_bounced', help_data['cue'])
+
+ self.page_class.services.user.SetUserPrefs(
+ 'cnxn', 111,
+ [user_pb2.UserPrefValue(name='your_email_bounced', value='true')])
+ help_data = self.page_class.GatherHelpData(mr, {})
+ self.assertEqual(None, help_data['cue'])
+ self.assertEqual(None, help_data['account_cue'])
+
+ def testGatherHelpData_ChildAccount(self):
+ """Display a warning when user is signed in to a child account."""
+ project = fake.Project(project_name='proj')
+ _request, mr = testing_helpers.GetRequestObjects(
+ path='/p/proj', project=project)
+ mr.auth.user_pb.linked_parent_id = 111
+ help_data = self.page_class.GatherHelpData(mr, {})
+ self.assertEqual(None, help_data['cue'])
+ self.assertEqual('switch_to_parent_account', help_data['account_cue'])
+ self.assertEqual('user@example.com', help_data['parent_email'])
+
+ def testGatherDebugData_Visibility(self):
+ project = fake.Project(
+ project_name='testtest', state=project_pb2.ProjectState.LIVE)
+ _request, mr = testing_helpers.GetRequestObjects(
+ path='/p/foo/servlet_path', project=project)
+ debug_data = self.page_class.GatherDebugData(mr, {})
+ self.assertEqual('off', debug_data['dbg'])
+
+ _request, mr = testing_helpers.GetRequestObjects(
+ path='/p/foo/servlet_path?debug=1', project=project)
+ debug_data = self.page_class.GatherDebugData(mr, {})
+ self.assertEqual('on', debug_data['dbg'])
+
+
+class ProjectIsRestrictedTest(unittest.TestCase):
+
+ def testNonRestrictedProject(self):
+ proj = project_pb2.Project()
+ mr = testing_helpers.MakeMonorailRequest()
+ mr.project = proj
+
+ proj.access = project_pb2.ProjectAccess.ANYONE
+ proj.state = project_pb2.ProjectState.LIVE
+ self.assertFalse(servlet._ProjectIsRestricted(mr))
+
+ proj.state = project_pb2.ProjectState.ARCHIVED
+ self.assertFalse(servlet._ProjectIsRestricted(mr))
+
+ def testRestrictedProject(self):
+ proj = project_pb2.Project()
+ mr = testing_helpers.MakeMonorailRequest()
+ mr.project = proj
+
+ proj.state = project_pb2.ProjectState.LIVE
+ proj.access = project_pb2.ProjectAccess.MEMBERS_ONLY
+ self.assertTrue(servlet._ProjectIsRestricted(mr))
+
+class VersionBaseTest(unittest.TestCase):
+
+ @mock.patch('settings.local_mode', True)
+ def testLocalhost(self):
+ request = webapp2.Request.blank('/', base_url='http://localhost:8080')
+ actual = servlet._VersionBaseURL(request)
+ expected = 'http://localhost:8080'
+ self.assertEqual(expected, actual)
+
+ @mock.patch('settings.local_mode', False)
+ @mock.patch('google.appengine.api.app_identity.get_default_version_hostname')
+ def testProd(self, mock_gdvh):
+ mock_gdvh.return_value = 'monorail-prod.appspot.com'
+ request = webapp2.Request.blank('/', base_url='https://bugs.chromium.org')
+ actual = servlet._VersionBaseURL(request)
+ expected = 'https://test-dot-monorail-prod.appspot.com'
+ self.assertEqual(expected, actual)
diff --git a/framework/test/sorting_test.py b/framework/test/sorting_test.py
new file mode 100644
index 0000000..4b1feb3
--- /dev/null
+++ b/framework/test/sorting_test.py
@@ -0,0 +1,360 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for sorting.py functions."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+# For convenient debugging
+import logging
+
+import mox
+
+from framework import sorting
+from framework import framework_views
+from proto import tracker_pb2
+from testing import fake
+from testing import testing_helpers
+from tracker import tracker_bizobj
+
+
def MakeDescending(accessor):
  """Return a copy of accessor that sorts values in descending order."""
  return sorting._MaybeMakeDescending(accessor, True)
+
+
class DescendingValueTest(unittest.TestCase):
  """Tests for sorting.DescendingValue and sorting._MaybeMakeDescending()."""

  def testMinString(self):
    """When sorting desc, a min string will sort last instead of first."""
    actual = sorting.DescendingValue.MakeDescendingValue(sorting.MIN_STRING)
    self.assertEqual(sorting.MAX_STRING, actual)

  def testMaxString(self):
    """When sorting desc, a max string will sort first instead of last."""
    actual = sorting.DescendingValue.MakeDescendingValue(sorting.MAX_STRING)
    self.assertEqual(sorting.MIN_STRING, actual)

  def testDescValues(self):
    """The point of DescendingValue is to reverse the sort order."""
    anti_a = sorting.DescendingValue.MakeDescendingValue('a')
    anti_b = sorting.DescendingValue.MakeDescendingValue('b')
    self.assertTrue(anti_a > anti_b)

  def testMaybeMakeDescending(self):
    """It returns an accessor that makes DescendingValue iff arg is True."""
    # Fixed: the original asserted `asc_value is 'a'`, an identity check
    # against a string literal.  That only passed via CPython string
    # interning and raises a SyntaxWarning on modern Python; equality is
    # the intended assertion.  A leftover debug print() was also removed.
    asc_accessor = sorting._MaybeMakeDescending(lambda issue: 'a', False)
    asc_value = asc_accessor('fake issue')
    self.assertEqual('a', asc_value)

    desc_accessor = sorting._MaybeMakeDescending(lambda issue: 'a', True)
    desc_value = desc_accessor('fake issue')
    self.assertTrue(isinstance(desc_value, sorting.DescendingValue))
+
+
class SortingTest(unittest.TestCase):
  """Tests for the sort-key accessor factories in the sorting module."""

  def setUp(self):
    self.mox = mox.Mox()
    self.default_cols = 'a b c'
    self.builtin_cols = 'a b x y z'
    self.config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
    self.config.component_defs.append(tracker_bizobj.MakeComponentDef(
        11, 789, 'Database', 'doc', False, [], [], 0, 0))
    self.config.component_defs.append(tracker_bizobj.MakeComponentDef(
        22, 789, 'User Interface', 'doc', True, [], [], 0, 0))
    self.config.component_defs.append(tracker_bizobj.MakeComponentDef(
        33, 789, 'Installer', 'doc', False, [], [], 0, 0))

  def tearDown(self):
    self.mox.UnsetStubs()
    self.mox.ResetAll()

  def testMakeSingleSortKeyAccessor_Status(self):
    """Sorting by status should create an accessor for that column."""
    self.mox.StubOutWithMock(sorting, '_IndexOrLexical')
    status_names = [wks.status for wks in self.config.well_known_statuses]
    sorting._IndexOrLexical(status_names, 'status accessor')
    self.mox.ReplayAll()

    sorting._MakeSingleSortKeyAccessor(
        'status', self.config, {'status': 'status accessor'}, [], {}, [])
    self.mox.VerifyAll()

  def testMakeSingleSortKeyAccessor_Component(self):
    """Sorting by component should create an accessor for that column."""
    self.mox.StubOutWithMock(sorting, '_IndexListAccessor')
    component_ids = [11, 33, 22]
    sorting._IndexListAccessor(component_ids, 'component accessor')
    self.mox.ReplayAll()

    sorting._MakeSingleSortKeyAccessor(
        'component', self.config, {'component': 'component accessor'},
        [], {}, [])
    self.mox.VerifyAll()

  def testMakeSingleSortKeyAccessor_OtherBuiltInColunms(self):
    """Sorting a built-in column should create an accessor for that column."""
    accessor = sorting._MakeSingleSortKeyAccessor(
        'buildincol', self.config, {'buildincol': 'accessor'}, [], {}, [])
    self.assertEqual('accessor', accessor)

  def testMakeSingleSortKeyAccessor_WithPostProcessor(self):
    """Sorting a built-in user column should create a user accessor."""
    self.mox.StubOutWithMock(sorting, '_MakeAccessorWithPostProcessor')
    users_by_id = {111: 'fake user'}
    sorting._MakeAccessorWithPostProcessor(
        users_by_id, 'mock owner accessor', 'mock postprocessor')
    self.mox.ReplayAll()

    sorting._MakeSingleSortKeyAccessor(
        'owner', self.config, {'owner': 'mock owner accessor'},
        {'owner': 'mock postprocessor'}, users_by_id, [])
    self.mox.VerifyAll()

  def testIndexOrLexical(self):
    """Single values sort by well-known index, else lexically; desc flips."""
    well_known_values = ['x-a', 'x-b', 'x-c', 'x-d']
    art = 'this is a fake artifact'

    # Case 1: accessor generates no values.
    base_accessor = lambda art: None
    accessor = sorting._IndexOrLexical(well_known_values, base_accessor)
    self.assertEqual(sorting.MAX_STRING, accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual(sorting.DescendingValue(sorting.MAX_STRING),
                     neg_accessor(art))

    # Case 2: accessor generates a value, but it is an empty value.
    base_accessor = lambda art: ''
    accessor = sorting._IndexOrLexical(well_known_values, base_accessor)
    self.assertEqual(sorting.MAX_STRING, accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual(sorting.DescendingValue(sorting.MAX_STRING),
                     neg_accessor(art))

    # Case 3: A single well-known value
    base_accessor = lambda art: 'x-c'
    accessor = sorting._IndexOrLexical(well_known_values, base_accessor)
    self.assertEqual(2, accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual(-2, neg_accessor(art))

    # Case 4: A single odd-ball value
    base_accessor = lambda art: 'x-zzz'
    accessor = sorting._IndexOrLexical(well_known_values, base_accessor)
    self.assertEqual('x-zzz', accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual(
        sorting.DescendingValue('x-zzz'), neg_accessor(art))

  def testIndexListAccessor_SomeWellKnownValues(self):
    """Values sort according to their position in the well-known list."""
    well_known_values = [11, 33, 22]  # These represent component IDs.
    art = fake.MakeTestIssue(789, 1, 'sum 1', 'New', 111)
    base_accessor = lambda issue: issue.component_ids
    accessor = sorting._IndexListAccessor(well_known_values, base_accessor)

    # Case 1: accessor generates no values.
    self.assertEqual(sorting.MAX_STRING, accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual(sorting.MAX_STRING, neg_accessor(art))

    # Case 2: A single well-known value
    art.component_ids = [33]
    self.assertEqual([1], accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual([-1], neg_accessor(art))

    # Case 3: Multiple well-known and odd-ball values
    art.component_ids = [33, 11, 99]
    self.assertEqual([0, 1, sorting.MAX_STRING], accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual([sorting.MAX_STRING, -1, 0],
                     neg_accessor(art))

  def testIndexListAccessor_NoWellKnownValues(self):
    """When there are no well-known values, all values sort last."""
    well_known_values = []  # Nothing pre-defined, so everything is oddball
    art = fake.MakeTestIssue(789, 1, 'sum 1', 'New', 111)
    base_accessor = lambda issue: issue.component_ids
    accessor = sorting._IndexListAccessor(well_known_values, base_accessor)

    # Case 1: accessor generates no values.
    self.assertEqual(sorting.MAX_STRING, accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual(sorting.MAX_STRING, neg_accessor(art))

    # Case 2: A single oddball value
    art.component_ids = [33]
    self.assertEqual([sorting.MAX_STRING], accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual([sorting.MAX_STRING], neg_accessor(art))

    # Case 3: Multiple odd-ball values
    art.component_ids = [33, 11, 99]
    self.assertEqual(
        [sorting.MAX_STRING, sorting.MAX_STRING, sorting.MAX_STRING],
        accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual(
        [sorting.MAX_STRING, sorting.MAX_STRING, sorting.MAX_STRING],
        neg_accessor(art))

  def testIndexOrLexicalList(self):
    """Label lists sort by well-known index, then lexically by suffix."""
    well_known_values = ['Pri-High', 'Pri-Med', 'Pri-Low']
    art = fake.MakeTestIssue(789, 1, 'sum 1', 'New', 111, merged_into=200001)

    # Case 1: accessor generates no values.
    accessor = sorting._IndexOrLexicalList(well_known_values, [], 'pri', {})
    self.assertEqual(sorting.MAX_STRING, accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual(sorting.MAX_STRING, neg_accessor(art))

    # Case 2: A single well-known value
    art.labels = ['Pri-Med']
    accessor = sorting._IndexOrLexicalList(well_known_values, [], 'pri', {})
    self.assertEqual([1], accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual([-1], neg_accessor(art))

    # Case 3: Multiple well-known and odd-ball values
    art.labels = ['Pri-zzz', 'Pri-Med', 'yyy', 'Pri-High']
    accessor = sorting._IndexOrLexicalList(well_known_values, [], 'pri', {})
    self.assertEqual([0, 1, 'zzz'], accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual([sorting.DescendingValue('zzz'), -1, 0],
                     neg_accessor(art))

    # Case 4: Multi-part prefix.
    well_known_values.extend(['X-Y-Header', 'X-Y-Footer'])
    art.labels = ['X-Y-Footer', 'X-Y-Zone', 'X-Y-Header', 'X-Y-Area']
    accessor = sorting._IndexOrLexicalList(well_known_values, [], 'x-y', {})
    self.assertEqual([3, 4, 'area', 'zone'], accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual([sorting.DescendingValue('zone'),
                      sorting.DescendingValue('area'), -4, -3],
                     neg_accessor(art))

  def testIndexOrLexicalList_CustomFields(self):
    """Custom field values and same-named labels are combined when sorting."""
    art = fake.MakeTestIssue(789, 1, 'sum 2', 'New', 111)
    art.labels = ['samename-value1']
    art.field_values = [tracker_bizobj.MakeFieldValue(
        3, 6078, None, None, None, None, False)]

    all_field_defs = [
        tracker_bizobj.MakeFieldDef(
            3, 789, 'samename', tracker_pb2.FieldTypes.INT_TYPE,
            None, None, False, False, False, None, None, None, False, None,
            None, None, None, 'cow spots', False),
        tracker_bizobj.MakeFieldDef(
            4, 788, 'samename', tracker_pb2.FieldTypes.APPROVAL_TYPE,
            None, None, False, False, False, None, None, None, False, None,
            None, None, None, 'cow spots', False),
        tracker_bizobj.MakeFieldDef(
            4, 788, 'notsamename', tracker_pb2.FieldTypes.APPROVAL_TYPE,
            None, None, False, False, False, None, None, None, False, None,
            None, None, None, 'should get filtered out', False)
    ]

    accessor = sorting._IndexOrLexicalList([], all_field_defs, 'samename', {})
    self.assertEqual([6078, 'value1'], accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual(
        [sorting.DescendingValue('value1'), -6078], neg_accessor(art))

  def testIndexOrLexicalList_PhaseCustomFields(self):
    """A phase-qualified column only picks up values from that phase."""
    art = fake.MakeTestIssue(789, 1, 'sum 2', 'New', 111)
    art.labels = ['summer.goats-value1']
    art.field_values = [
        tracker_bizobj.MakeFieldValue(
            3, 33, None, None, None, None, False, phase_id=77),
        tracker_bizobj.MakeFieldValue(
            3, 34, None, None, None, None, False, phase_id=77),
        tracker_bizobj.MakeFieldValue(
            3, 1000, None, None, None, None, False, phase_id=78)]
    art.phases = [tracker_pb2.Phase(phase_id=77, name='summer'),
                  tracker_pb2.Phase(phase_id=78, name='winter')]

    all_field_defs = [
        tracker_bizobj.MakeFieldDef(
            3, 789, 'goats', tracker_pb2.FieldTypes.INT_TYPE,
            None, None, False, False, True, None, None, None, False, None,
            None, None, None, 'goats love mineral', False, is_phase_field=True),
        tracker_bizobj.MakeFieldDef(
            4, 788, 'boo', tracker_pb2.FieldTypes.APPROVAL_TYPE,
            None, None, False, False, False, None, None, None, False, None,
            None, None, None, 'ahh', False),
    ]

    accessor = sorting._IndexOrLexicalList(
        [], all_field_defs, 'summer.goats', {})
    self.assertEqual([33, 34, 'value1'], accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual(
        [sorting.DescendingValue('value1'), -34, -33], neg_accessor(art))

  def testIndexOrLexicalList_ApprovalStatus(self):
    """Sorting by an approval field column uses the approval's status."""
    art = fake.MakeTestIssue(789, 1, 'sum 2', 'New', 111)
    art.labels = ['samename-value1']
    art.approval_values = [tracker_pb2.ApprovalValue(approval_id=4)]

    all_field_defs = [
        tracker_bizobj.MakeFieldDef(
            3, 789, 'samename', tracker_pb2.FieldTypes.INT_TYPE,
            None, None, False, False, False, None, None, None, False, None,
            None, None, None, 'cow spots', False),
        tracker_bizobj.MakeFieldDef(
            4, 788, 'samename', tracker_pb2.FieldTypes.APPROVAL_TYPE,
            None, None, False, False, False, None, None, None, False, None,
            None, None, None, 'cow spots', False)
    ]

    accessor = sorting._IndexOrLexicalList([], all_field_defs, 'samename', {})
    self.assertEqual([0, 'value1'], accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual([sorting.DescendingValue('value1'),
                      sorting.DescendingValue(0)],
                     neg_accessor(art))

  def testIndexOrLexicalList_ApprovalApprover(self):
    """Sorting by an approval's -approver column uses approver emails."""
    # Fixed: the original had a duplicated assignment `art = art = ...`.
    art = fake.MakeTestIssue(789, 1, 'sum 2', 'New', 111)
    art.labels = ['samename-approver-value1']
    art.approval_values = [
        tracker_pb2.ApprovalValue(approval_id=4, approver_ids=[333])]

    all_field_defs = [
        tracker_bizobj.MakeFieldDef(
            4, 788, 'samename', tracker_pb2.FieldTypes.APPROVAL_TYPE,
            None, None, False, False, False, None, None, None, False, None,
            None, None, None, 'cow spots', False)
    ]
    users_by_id = {333: framework_views.StuffUserView(333, 'a@test.com', True)}

    accessor = sorting._IndexOrLexicalList(
        [], all_field_defs, 'samename-approver', users_by_id)
    self.assertEqual(['a@test.com', 'value1'], accessor(art))
    neg_accessor = MakeDescending(accessor)
    self.assertEqual([sorting.DescendingValue('value1'),
                      sorting.DescendingValue('a@test.com')],
                     neg_accessor(art))

  def testComputeSortDirectives(self):
    """User spec, default spec, then project/id are merged without dupes."""
    config = tracker_pb2.ProjectIssueConfig()
    self.assertEqual(
        ['project', 'id'], sorting.ComputeSortDirectives(config, '', ''))

    self.assertEqual(
        ['a', 'b', 'c', 'project', 'id'],
        sorting.ComputeSortDirectives(config, '', 'a b C'))

    config.default_sort_spec = 'id -reporter Owner'
    self.assertEqual(
        ['id', '-reporter', 'owner', 'project'],
        sorting.ComputeSortDirectives(config, '', ''))

    self.assertEqual(
        ['x', '-b', 'a', 'c', '-owner', 'id', '-reporter', 'project'],
        sorting.ComputeSortDirectives(config, 'x -b', 'A -b c -owner'))
diff --git a/framework/test/sql_test.py b/framework/test/sql_test.py
new file mode 100644
index 0000000..f073e24
--- /dev/null
+++ b/framework/test/sql_test.py
@@ -0,0 +1,681 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for the sql module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import logging
+import mock
+import time
+import unittest
+
+import settings
+from framework import exceptions
+from framework import sql
+
+
class MockSQLCnxn(object):
  """This class mocks the connection and cursor classes.

  It records the last statement executed (with args substituted) so tests
  can assert on the exact SQL text, and serves canned rows from
  result_rows.  It is installed as sql.cnxn_ctor for this whole module.
  """

  def __init__(self, instance, database):
    self.instance = instance
    self.database = database
    # Last fully-formatted statement string passed to execute().
    self.last_executed = None
    # Raw args tuple recorded by executemany().
    self.last_executed_args = None
    # Canned rows that fetchall() should return; set by each test.
    self.result_rows = None
    self.rowcount = 0
    self.lastrowid = None
    self.pool_key = instance + '/' + database
    # When True, ping() raises to simulate a dead connection.
    self.is_bad = False
    # Tracks whether a write happened without a commit()/rollback().
    self.has_uncommitted = False

  def execute(self, stmt_str, args=None):
    # Substitute args into the statement so tests can compare final SQL.
    self.last_executed = stmt_str % tuple(args or [])
    if not stmt_str.startswith(('SET', 'SELECT')):
      self.has_uncommitted = True

  def executemany(self, stmt_str, args):
    # We cannot format the string because args has many values for each %s.
    self.last_executed = stmt_str
    self.last_executed_args = tuple(args)

    # sql.py only calls executemany() for INSERT.
    assert stmt_str.startswith('INSERT')
    self.lastrowid = 123

  def fetchall(self):
    return self.result_rows

  def cursor(self):
    # The mock acts as both connection and cursor.
    return self

  def commit(self):
    self.has_uncommitted = False

  def close(self):
    # Closing with uncommitted work is a bug in the code under test.
    assert not self.has_uncommitted

  def rollback(self):
    self.has_uncommitted = False

  def ping(self):
    # BaseException is deliberate: tests assertRaises(BaseException) on it.
    if self.is_bad:
      raise BaseException('connection error!')
+
+
# Install the mock connection constructor module-wide so that no test in
# this file ever opens a real MySQL connection.
sql.cnxn_ctor = MockSQLCnxn
+
+
class ConnectionPoolingTest(unittest.TestCase):
  """Tests for sql.ConnectionPool.

  NOTE: the original assertions used assertIs() to compare integers and
  boolean call results; those only passed because CPython caches small
  ints and True/False are singletons.  They are replaced here with
  assertEqual/assertTrue/assertFalse, which check the intended values.
  """

  def testGet(self):
    """Connections are pooled per database and reused after release."""
    pool_size = 2
    num_dbs = 2
    p = sql.ConnectionPool(pool_size)

    for i in range(num_dbs):
      for _ in range(pool_size):
        c = p.get('test', 'db%d' % i)
        self.assertIsNotNone(c)
        p.release(c)

    cnxn1 = p.get('test', 'db0')
    q = p.queues[cnxn1.pool_key]
    self.assertEqual(0, q.qsize())

    p.release(cnxn1)
    self.assertEqual(pool_size - 1, q.qsize())
    self.assertFalse(q.full())
    self.assertFalse(q.empty())

    cnxn2 = p.get('test', 'db0')
    q = p.queues[cnxn2.pool_key]
    self.assertEqual(0, q.qsize())
    self.assertFalse(q.full())
    self.assertTrue(q.empty())

  def testGetAndReturnPooledCnxn(self):
    """A bad connection is discarded instead of being handed out again."""
    p = sql.ConnectionPool(2)

    cnxn1 = p.get('test', 'db1')
    self.assertEqual(1, len(p.queues))

    cnxn2 = p.get('test', 'db2')
    self.assertEqual(2, len(p.queues))

    # Should use the existing pool.
    cnxn3 = p.get('test', 'db1')
    self.assertEqual(2, len(p.queues))

    p.release(cnxn3)
    p.release(cnxn2)

    cnxn1.is_bad = True
    p.release(cnxn1)
    # cnxn1 should not be returned from the pool if we
    # ask for a connection to its database.

    cnxn4 = p.get('test', 'db1')

    self.assertIsNot(cnxn1, cnxn4)
    self.assertEqual(2, len(p.queues))
    self.assertFalse(cnxn4.is_bad)

  def testGetAndReturnPooledCnxn_badCnxn(self):
    """Fetching a released bad connection raises; other pools are intact."""
    p = sql.ConnectionPool(2)

    cnxn1 = p.get('test', 'db1')
    cnxn2 = p.get('test', 'db2')
    cnxn3 = p.get('test', 'db1')

    cnxn3.is_bad = True

    p.release(cnxn3)
    q = p.queues[cnxn3.pool_key]
    self.assertEqual(1, q.qsize())

    # ping() on the bad connection raises when the pool hands it back.
    with self.assertRaises(BaseException):
      cnxn3 = p.get('test', 'db1')

    q = p.queues[cnxn2.pool_key]
    self.assertEqual(0, q.qsize())
    p.release(cnxn2)
    self.assertEqual(1, q.qsize())

    p.release(cnxn1)
    q = p.queues[cnxn1.pool_key]
    self.assertEqual(1, q.qsize())
+
+
class MonorailConnectionTest(unittest.TestCase):
  """Tests for sql.MonorailConnection connection caching and shard failover."""

  def setUp(self):
    self.cnxn = sql.MonorailConnection()
    # Save the global settings overridden below so tearDown can restore them.
    self.orig_local_mode = settings.local_mode
    self.orig_num_logical_shards = settings.num_logical_shards
    settings.local_mode = False

  def tearDown(self):
    settings.local_mode = self.orig_local_mode
    settings.num_logical_shards = self.orig_num_logical_shards

  def testGetPrimaryConnection(self):
    """The primary connection uses the configured instance and is cached."""
    sql_cnxn = self.cnxn.GetPrimaryConnection()
    self.assertEqual(settings.db_instance, sql_cnxn.instance)
    self.assertEqual(settings.db_database_name, sql_cnxn.database)

    # A second call returns the same cached connection object.
    sql_cnxn2 = self.cnxn.GetPrimaryConnection()
    self.assertIs(sql_cnxn2, sql_cnxn)

  def testGetConnectionForShard(self):
    """A shard connection targets the replica selected by shard ID."""
    sql_cnxn = self.cnxn.GetConnectionForShard(1)
    replica_name = settings.db_replica_names[
        1 % len(settings.db_replica_names)]
    self.assertEqual(settings.physical_db_name_format % replica_name,
                     sql_cnxn.instance)
    self.assertEqual(settings.db_database_name, sql_cnxn.database)

    # A second call for the same shard returns the same cached connection.
    sql_cnxn2 = self.cnxn.GetConnectionForShard(1)
    self.assertIs(sql_cnxn2, sql_cnxn)

  def testClose(self):
    """After Close(), the underlying connection has no uncommitted work."""
    sql_cnxn = self.cnxn.GetPrimaryConnection()
    self.cnxn.Close()
    self.assertFalse(sql_cnxn.has_uncommitted)

  def testExecute_Primary(self):
    """Execute() with no shard passes the statement to the primary sql cnxn."""
    sql_cnxn = self.cnxn.GetPrimaryConnection()
    with mock.patch.object(self.cnxn, '_ExecuteWithSQLConnection') as ewsc:
      ewsc.return_value = 'db result'
      actual_result = self.cnxn.Execute('statement', [])
      self.assertEqual('db result', actual_result)
      ewsc.assert_called_once_with(sql_cnxn, 'statement', [], commit=True)

  def testExecute_Shard(self):
    """Execute() with a shard passes the statement to the shard sql cnxn."""
    shard_id = 1
    sql_cnxn_1 = self.cnxn.GetConnectionForShard(shard_id)
    with mock.patch.object(self.cnxn, '_ExecuteWithSQLConnection') as ewsc:
      ewsc.return_value = 'db result'
      actual_result = self.cnxn.Execute('statement', [], shard_id=shard_id)
      self.assertEqual('db result', actual_result)
      ewsc.assert_called_once_with(sql_cnxn_1, 'statement', [], commit=True)

  def testExecute_Shard_Unavailable(self):
    """If a shard is unavailable, we try the next one."""
    shard_id = 1
    sql_cnxn_1 = self.cnxn.GetConnectionForShard(shard_id)
    sql_cnxn_2 = self.cnxn.GetConnectionForShard(shard_id + 1)

    # Simulate a recent failure on shard 1.
    self.cnxn.unavailable_shards[1] = int(time.time()) - 3

    with mock.patch.object(self.cnxn, '_ExecuteWithSQLConnection') as ewsc:
      ewsc.return_value = 'db result'
      actual_result = self.cnxn.Execute('statement', [], shard_id=shard_id)
      self.assertEqual('db result', actual_result)
      ewsc.assert_called_once_with(sql_cnxn_2, 'statement', [], commit=True)

    # Even a new MonorailConnection instance shares the same state.
    other_cnxn = sql.MonorailConnection()
    other_sql_cnxn_2 = other_cnxn.GetConnectionForShard(shard_id + 1)

    with mock.patch.object(other_cnxn, '_ExecuteWithSQLConnection') as ewsc:
      ewsc.return_value = 'db result'
      actual_result = other_cnxn.Execute('statement', [], shard_id=shard_id)
      self.assertEqual('db result', actual_result)
      ewsc.assert_called_once_with(
          other_sql_cnxn_2, 'statement', [], commit=True)

    # Simulate an old failure on shard 1, allowing us to try using it again.
    self.cnxn.unavailable_shards[1] = (
        int(time.time()) - sql.BAD_SHARD_AVOIDANCE_SEC - 2)

    with mock.patch.object(self.cnxn, '_ExecuteWithSQLConnection') as ewsc:
      ewsc.return_value = 'db result'
      actual_result = self.cnxn.Execute('statement', [], shard_id=shard_id)
      self.assertEqual('db result', actual_result)
      ewsc.assert_called_once_with(sql_cnxn_1, 'statement', [], commit=True)
+
+
class TableManagerTest(unittest.TestCase):
  """Tests for sql.SQLTableManager SQL generation and result handling."""

  def setUp(self):
    self.emp_tbl = sql.SQLTableManager('Employee')
    self.cnxn = sql.MonorailConnection()
    # Grab the underlying MockSQLCnxn so tests can inspect the exact SQL
    # executed (last_executed) and inject canned rows (result_rows).
    self.primary_cnxn = self.cnxn.GetPrimaryConnection()

  def testSelect_Trivial(self):
    """Select() with no args generates SELECT * and returns all rows."""
    self.primary_cnxn.result_rows = [(111, True), (222, False)]
    rows = self.emp_tbl.Select(self.cnxn)
    self.assertEqual('SELECT * FROM Employee', self.primary_cnxn.last_executed)
    self.assertEqual([(111, True), (222, False)], rows)

  def testSelect_Conditions(self):
    """Keyword args become WHERE terms; lists become IN clauses."""
    self.primary_cnxn.result_rows = [(111,)]
    rows = self.emp_tbl.Select(
        self.cnxn, cols=['emp_id'], fulltime=True, dept_id=[10, 20])
    self.assertEqual(
        'SELECT emp_id FROM Employee'
        '\nWHERE dept_id IN (10,20)'
        '\n  AND fulltime = 1', self.primary_cnxn.last_executed)
    self.assertEqual([(111,)], rows)

  def testSelectRow(self):
    """SelectRow() generates SELECT DISTINCT and returns a single row."""
    self.primary_cnxn.result_rows = [(111,)]
    row = self.emp_tbl.SelectRow(
        self.cnxn, cols=['emp_id'], fulltime=True, dept_id=[10, 20])
    self.assertEqual(
        'SELECT DISTINCT emp_id FROM Employee'
        '\nWHERE dept_id IN (10,20)'
        '\n  AND fulltime = 1', self.primary_cnxn.last_executed)
    self.assertEqual((111,), row)

  def testSelectRow_NoMatches(self):
    """SelectRow() returns None, or the given default, when nothing matches."""
    self.primary_cnxn.result_rows = []
    row = self.emp_tbl.SelectRow(
        self.cnxn, cols=['emp_id'], fulltime=True, dept_id=[99])
    self.assertEqual(
        'SELECT DISTINCT emp_id FROM Employee'
        '\nWHERE dept_id IN (99)'
        '\n  AND fulltime = 1', self.primary_cnxn.last_executed)
    self.assertEqual(None, row)

    row = self.emp_tbl.SelectRow(
        self.cnxn, cols=['emp_id'], fulltime=True, dept_id=[99],
        default=(-1,))
    self.assertEqual((-1,), row)

  def testSelectValue(self):
    """SelectValue() returns the single value of the single matching row."""
    self.primary_cnxn.result_rows = [(111,)]
    val = self.emp_tbl.SelectValue(
        self.cnxn, 'emp_id', fulltime=True, dept_id=[10, 20])
    self.assertEqual(
        'SELECT DISTINCT emp_id FROM Employee'
        '\nWHERE dept_id IN (10,20)'
        '\n  AND fulltime = 1', self.primary_cnxn.last_executed)
    self.assertEqual(111, val)

  def testSelectValue_NoMatches(self):
    """SelectValue() returns None, or the given default, on no matches."""
    self.primary_cnxn.result_rows = []
    val = self.emp_tbl.SelectValue(
        self.cnxn, 'emp_id', fulltime=True, dept_id=[99])
    self.assertEqual(
        'SELECT DISTINCT emp_id FROM Employee'
        '\nWHERE dept_id IN (99)'
        '\n  AND fulltime = 1', self.primary_cnxn.last_executed)
    self.assertEqual(None, val)

    val = self.emp_tbl.SelectValue(
        self.cnxn, 'emp_id', fulltime=True, dept_id=[99],
        default=-1)
    self.assertEqual(-1, val)

  def testInsertRow(self):
    """InsertRow() generates a parameterized INSERT and returns lastrowid."""
    self.primary_cnxn.rowcount = 1
    generated_id = self.emp_tbl.InsertRow(self.cnxn, emp_id=111, fulltime=True)
    self.assertEqual(
        'INSERT INTO Employee (emp_id, fulltime)'
        '\nVALUES (%s,%s)', self.primary_cnxn.last_executed)
    self.assertEqual(([111, 1],), self.primary_cnxn.last_executed_args)
    # 123 is the lastrowid hard-coded in MockSQLCnxn.executemany().
    self.assertEqual(123, generated_id)

  def testInsertRows_Empty(self):
    """Inserting zero rows executes nothing and returns None."""
    generated_id = self.emp_tbl.InsertRows(
        self.cnxn, ['emp_id', 'fulltime'], [])
    self.assertIsNone(self.primary_cnxn.last_executed)
    self.assertIsNone(self.primary_cnxn.last_executed_args)
    self.assertEqual(None, generated_id)

  def testInsertRows(self):
    """InsertRows() batches all rows into one executemany() call."""
    self.primary_cnxn.rowcount = 2
    generated_ids = self.emp_tbl.InsertRows(
        self.cnxn, ['emp_id', 'fulltime'], [(111, True), (222, False)])
    self.assertEqual(
        'INSERT INTO Employee (emp_id, fulltime)'
        '\nVALUES (%s,%s)', self.primary_cnxn.last_executed)
    self.assertEqual(([111, 1], [222, 0]), self.primary_cnxn.last_executed_args)
    self.assertEqual([], generated_ids)

  def testUpdate(self):
    """Update() generates UPDATE ... WHERE and returns the rowcount."""
    self.primary_cnxn.rowcount = 2
    rowcount = self.emp_tbl.Update(
        self.cnxn, {'fulltime': True}, emp_id=[111, 222])
    self.assertEqual(
        'UPDATE Employee SET fulltime=1'
        '\nWHERE emp_id IN (111,222)', self.primary_cnxn.last_executed)
    self.assertEqual(2, rowcount)

  def testUpdate_Limit(self):
    """A limit kwarg appends a LIMIT clause to the UPDATE."""
    self.emp_tbl.Update(
        self.cnxn, {'fulltime': True}, limit=8, emp_id=[111, 222])
    self.assertEqual(
        'UPDATE Employee SET fulltime=1'
        '\nWHERE emp_id IN (111,222)'
        '\nLIMIT 8', self.primary_cnxn.last_executed)

  def testIncrementCounterValue(self):
    """IncrementCounterValue() uses LAST_INSERT_ID to read back the count."""
    self.primary_cnxn.rowcount = 1
    self.primary_cnxn.lastrowid = 9
    new_counter_val = self.emp_tbl.IncrementCounterValue(
        self.cnxn, 'years_worked', emp_id=111)
    self.assertEqual(
        'UPDATE Employee SET years_worked = LAST_INSERT_ID(years_worked + 1)'
        '\nWHERE emp_id = 111', self.primary_cnxn.last_executed)
    self.assertEqual(9, new_counter_val)

  def testDelete(self):
    """Delete() generates DELETE ... WHERE and returns the rowcount."""
    self.primary_cnxn.rowcount = 1
    rowcount = self.emp_tbl.Delete(self.cnxn, fulltime=True)
    self.assertEqual(
        'DELETE FROM Employee'
        '\nWHERE fulltime = 1', self.primary_cnxn.last_executed)
    self.assertEqual(1, rowcount)

  def testDelete_Limit(self):
    """A limit kwarg appends a LIMIT clause to the DELETE."""
    self.emp_tbl.Delete(self.cnxn, fulltime=True, limit=3)
    self.assertEqual(
        'DELETE FROM Employee'
        '\nWHERE fulltime = 1'
        '\nLIMIT 3', self.primary_cnxn.last_executed)
+
+
+class StatementTest(unittest.TestCase):
+
  def testMakeSelect(self):
    """MakeSelect generates a SELECT, with DISTINCT when requested."""
    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'SELECT emp_id, fulltime FROM Employee',
        stmt_str)
    self.assertEqual([], args)

    stmt = sql.Statement.MakeSelect(
        'Employee', ['emp_id', 'fulltime'], distinct=True)
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'SELECT DISTINCT emp_id, fulltime FROM Employee',
        stmt_str)
    self.assertEqual([], args)
+
  def testMakeInsert(self):
    """MakeInsert generates INSERT, with upsert (replace) and IGNORE forms."""
    stmt = sql.Statement.MakeInsert(
        'Employee', ['emp_id', 'fulltime'], [(111, True), (222, False)])
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'INSERT INTO Employee (emp_id, fulltime)'
        '\nVALUES (%s,%s)',
        stmt_str)
    # Booleans are converted to 1/0 for MySQL.
    self.assertEqual([[111, 1], [222, 0]], args)

    # replace=True emits ON DUPLICATE KEY UPDATE for every column.
    stmt = sql.Statement.MakeInsert(
        'Employee', ['emp_id', 'fulltime'], [(111, False)], replace=True)
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'INSERT INTO Employee (emp_id, fulltime)'
        '\nVALUES (%s,%s)'
        '\nON DUPLICATE KEY UPDATE '
        'emp_id=VALUES(emp_id), fulltime=VALUES(fulltime)',
        stmt_str)
    self.assertEqual([[111, 0]], args)

    stmt = sql.Statement.MakeInsert(
        'Employee', ['emp_id', 'fulltime'], [(111, False)], ignore=True)
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'INSERT IGNORE INTO Employee (emp_id, fulltime)'
        '\nVALUES (%s,%s)',
        stmt_str)
    self.assertEqual([[111, 0]], args)
+
  def testMakeInsert_InvalidString(self):
    """Inserting a string containing a NUL byte raises InputException."""
    with self.assertRaises(exceptions.InputException):
      sql.Statement.MakeInsert(
          'Employee', ['emp_id', 'name'], [(111, 'First \x00 Last')])
+
  def testMakeUpdate(self):
    """MakeUpdate generates a parameterized UPDATE ... SET statement."""
    stmt = sql.Statement.MakeUpdate('Employee', {'fulltime': True})
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'UPDATE Employee SET fulltime=%s',
        stmt_str)
    self.assertEqual([1], args)
+
  def testMakeUpdate_InvalidString(self):
    """Updating to a string containing a NUL byte raises InputException."""
    with self.assertRaises(exceptions.InputException):
      sql.Statement.MakeUpdate('Employee', {'name': 'First \x00 Last'})
+
  def testMakeIncrement(self):
    """MakeIncrement generates a LAST_INSERT_ID counter update, step=1 default."""
    stmt = sql.Statement.MakeIncrement('Employee', 'years_worked')
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'UPDATE Employee SET years_worked = LAST_INSERT_ID(years_worked + %s)',
        stmt_str)
    self.assertEqual([1], args)

    stmt = sql.Statement.MakeIncrement('Employee', 'years_worked', step=5)
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'UPDATE Employee SET years_worked = LAST_INSERT_ID(years_worked + %s)',
        stmt_str)
    self.assertEqual([5], args)
+
  def testMakeDelete(self):
    """MakeDelete generates a bare DELETE FROM statement."""
    stmt = sql.Statement.MakeDelete('Employee')
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'DELETE FROM Employee',
        stmt_str)
    self.assertEqual([], args)
+
  def testAddUseClause(self):
    """AddUseClause inserts index hints between FROM and ORDER BY."""
    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
    stmt.AddUseClause('USE INDEX (emp_id) USE INDEX FOR ORDER BY (emp_id)')
    stmt.AddOrderByTerms([('emp_id', [])])
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'SELECT emp_id, fulltime FROM Employee'
        '\nUSE INDEX (emp_id) USE INDEX FOR ORDER BY (emp_id)'
        '\nORDER BY emp_id',
        stmt_str)
    self.assertEqual([], args)
+
  def testAddJoinClause_Empty(self):
    """Adding an empty list of join clauses leaves the statement unchanged."""
    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
    stmt.AddJoinClauses([])
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'SELECT emp_id, fulltime FROM Employee',
        stmt_str)
    self.assertEqual([], args)
+
  def testAddJoinClause(self):
    """AddJoinClauses appends JOIN clauses; left=True makes LEFT JOINs."""
    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
    stmt.AddJoinClauses([('CorporateHoliday', [])])
    stmt.AddJoinClauses(
        [('Product ON Project.inventor_id = emp_id', [])], left=True)
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'SELECT emp_id, fulltime FROM Employee'
        '\n  JOIN CorporateHoliday'
        '\n  LEFT JOIN Product ON Project.inventor_id = emp_id',
        stmt_str)
    self.assertEqual([], args)
+
  def testAddGroupByTerms_Empty(self):
    """Adding no GROUP BY terms leaves the statement unchanged."""
    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
    stmt.AddGroupByTerms([])
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'SELECT emp_id, fulltime FROM Employee',
        stmt_str)
    self.assertEqual([], args)
+
  def testAddGroupByTerms(self):
    """AddGroupByTerms appends a comma-separated GROUP BY clause."""
    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
    stmt.AddGroupByTerms(['dept_id', 'location_id'])
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'SELECT emp_id, fulltime FROM Employee'
        '\nGROUP BY dept_id, location_id',
        stmt_str)
    self.assertEqual([], args)
+
  def testAddOrderByTerms_Empty(self):
    """Adding no ORDER BY terms leaves the statement unchanged."""
    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
    stmt.AddOrderByTerms([])
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'SELECT emp_id, fulltime FROM Employee',
        stmt_str)
    self.assertEqual([], args)
+
  def testAddOrderByTerms(self):
    """AddOrderByTerms appends an ORDER BY clause, preserving direction."""
    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
    stmt.AddOrderByTerms([('dept_id', []), ('emp_id DESC', [])])
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'SELECT emp_id, fulltime FROM Employee'
        '\nORDER BY dept_id, emp_id DESC',
        stmt_str)
    self.assertEqual([], args)
+
  def testSetLimitAndOffset(self):
    """SetLimitAndOffset appends LIMIT, and OFFSET only when nonzero."""
    stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
    stmt.SetLimitAndOffset(100, 0)
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'SELECT emp_id, fulltime FROM Employee'
        '\nLIMIT 100',
        stmt_str)
    self.assertEqual([], args)

    # Calling it again replaces the previous limit/offset.
    stmt.SetLimitAndOffset(100, 500)
    stmt_str, args = stmt.Generate()
    self.assertEqual(
        'SELECT emp_id, fulltime FROM Employee'
        '\nLIMIT 100 OFFSET 500',
        stmt_str)
    self.assertEqual([], args)
+
+ def testAddWhereTerms_Select(self):
+ stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+ stmt.AddWhereTerms([], emp_id=[111, 222])
+ stmt_str, args = stmt.Generate()
+ self.assertEqual(
+ 'SELECT emp_id, fulltime FROM Employee'
+ '\nWHERE emp_id IN (%s,%s)',
+ stmt_str)
+ self.assertEqual([111, 222], args)
+
+ def testAddWhereTerms_Update(self):
+ stmt = sql.Statement.MakeUpdate('Employee', {'fulltime': True})
+ stmt.AddWhereTerms([], emp_id=[111, 222])
+ stmt_str, args = stmt.Generate()
+ self.assertEqual(
+ 'UPDATE Employee SET fulltime=%s'
+ '\nWHERE emp_id IN (%s,%s)',
+ stmt_str)
+ self.assertEqual([1, 111, 222], args)
+
+ def testAddWhereTerms_Delete(self):
+ stmt = sql.Statement.MakeDelete('Employee')
+ stmt.AddWhereTerms([], emp_id=[111, 222])
+ stmt_str, args = stmt.Generate()
+ self.assertEqual(
+ 'DELETE FROM Employee'
+ '\nWHERE emp_id IN (%s,%s)',
+ stmt_str)
+ self.assertEqual([111, 222], args)
+
+ def testAddWhereTerms_Empty(self):
+ """Add empty terms should have no effect."""
+ stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+ stmt.AddWhereTerms([])
+ stmt_str, args = stmt.Generate()
+ self.assertEqual(
+ 'SELECT emp_id, fulltime FROM Employee',
+ stmt_str)
+ self.assertEqual([], args)
+
+ def testAddWhereTerms_UpdateEmptyArray(self):
+ """Add empty array should throw an exception."""
+ stmt = sql.Statement.MakeUpdate('SpamVerdict', {'user_id': 1})
+ # See https://crbug.com/monorail/6735.
+ with self.assertRaises(exceptions.InputException):
+ stmt.AddWhereTerms([], user_id=[])
+ mock_log.assert_called_once_with('Invalid update DB value %r', 'user_id')
+
+ def testAddWhereTerms_MulitpleTerms(self):
+ stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+ stmt.AddWhereTerms(
+ [('emp_id %% %s = %s', [2, 0])], fulltime=True, emp_id_not=222)
+ stmt_str, args = stmt.Generate()
+ self.assertEqual(
+ 'SELECT emp_id, fulltime FROM Employee'
+ '\nWHERE emp_id %% %s = %s'
+ '\n AND emp_id != %s'
+ '\n AND fulltime = %s',
+ stmt_str)
+ self.assertEqual([2, 0, 222, 1], args)
+
+ def testAddHavingTerms_NoGroupBy(self):
+ stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+ stmt.AddHavingTerms([('COUNT(*) > %s', [10])])
+ self.assertRaises(AssertionError, stmt.Generate)
+
+ def testAddHavingTerms_WithGroupBy(self):
+ stmt = sql.Statement.MakeSelect('Employee', ['emp_id', 'fulltime'])
+ stmt.AddGroupByTerms(['dept_id', 'location_id'])
+ stmt.AddHavingTerms([('COUNT(*) > %s', [10])])
+ stmt_str, args = stmt.Generate()
+ self.assertEqual(
+ 'SELECT emp_id, fulltime FROM Employee'
+ '\nGROUP BY dept_id, location_id'
+ '\nHAVING COUNT(*) > %s',
+ stmt_str)
+ self.assertEqual([10], args)
+
+
class FunctionsTest(unittest.TestCase):
  """Tests for the module-level helper functions in sql.py."""

  def testIsValidDBValue_NonString(self):
    """Non-string values are always considered valid for the DB."""
    for value in (12, True, False, None):
      self.assertTrue(sql._IsValidDBValue(value))

  def testIsValidDBValue_String(self):
    """Strings are valid unless they contain a NUL byte."""
    for value in ('', 'hello', u'hello'):
      self.assertTrue(sql._IsValidDBValue(value))
    self.assertFalse(sql._IsValidDBValue('null \x00 byte'))

  def testBoolsToInts_NoChanges(self):
    """Non-bool values pass through; tuples are returned as lists."""
    self.assertEqual(['hello'], sql._BoolsToInts(['hello']))
    self.assertEqual([['hello']], sql._BoolsToInts([['hello']]))
    self.assertEqual([['hello']], sql._BoolsToInts([('hello',)]))
    self.assertEqual([12], sql._BoolsToInts([12]))
    self.assertEqual([[12]], sql._BoolsToInts([[12]]))
    self.assertEqual([[12]], sql._BoolsToInts([(12,)]))
    self.assertEqual(
        [12, 13, 'hi', [99, 'yo']],
        sql._BoolsToInts([12, 13, 'hi', [99, 'yo']]))

  def testBoolsToInts_WithChanges(self):
    """True/False become 1/0, including inside nested sequences."""
    self.assertEqual([1, 0], sql._BoolsToInts([True, False]))
    self.assertEqual([[1, 0]], sql._BoolsToInts([[True, False]]))
    self.assertEqual([[1, 0]], sql._BoolsToInts([(True, False)]))
    self.assertEqual(
        [12, 1, 'hi', [0, 'yo']],
        sql._BoolsToInts([12, True, 'hi', [False, 'yo']]))

  def testRandomShardID(self):
    """A random shard ID must always be a valid shard ID."""
    shard_id = sql.RandomShardID()
    self.assertTrue(0 <= shard_id < settings.num_logical_shards)
diff --git a/framework/test/table_view_helpers_test.py b/framework/test/table_view_helpers_test.py
new file mode 100644
index 0000000..0260308
--- /dev/null
+++ b/framework/test/table_view_helpers_test.py
@@ -0,0 +1,753 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for table_view_helpers classes and functions."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import collections
+import unittest
+import logging
+
+from framework import framework_views
+from framework import table_view_helpers
+from proto import tracker_pb2
+from testing import fake
+from testing import testing_helpers
+from tracker import tracker_bizobj
+
+
# Canned search results shared by the ComputeUnshownColumns and
# ExtractUniqueValues tests below.
EMPTY_SEARCH_RESULTS = []

# Four fake issues in project 789 with key-value labels; the label prefixes
# (Priority, Mstone, Visibility) are expected to surface as columns.
SEARCH_RESULTS_WITH_LABELS = [
    fake.MakeTestIssue(
        789, 1, 'sum 1', 'New', 111, labels='Priority-High Mstone-1',
        merged_into=200001, star_count=1),
    fake.MakeTestIssue(
        789, 2, 'sum 2', 'New', 111, labels='Priority-High Mstone-1',
        merged_into=1, star_count=1),
    fake.MakeTestIssue(
        789, 3, 'sum 3', 'New', 111, labels='Priority-Low Mstone-1.1',
        merged_into=1, star_count=1),
    # 'Visibility-Super-High' tests that only first dash counts
    fake.MakeTestIssue(
        789, 4, 'sum 4', 'New', 111, labels='Visibility-Super-High',
        star_count=1),
    ]
+
+
def MakeTestIssue(local_id, issue_id, summary):
  """Return a bare Issue protobuf with only identity fields populated."""
  result = tracker_pb2.Issue()
  result.summary = summary
  result.issue_id = issue_id
  result.local_id = local_id
  return result
+
+
class TableCellTest(unittest.TestCase):
  """Tests for the TableCell* classes that render issue-list cells."""

  USERS_BY_ID = {}

  def setUp(self):
    self.config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
    # Field 1 is a per-phase int field, field 2 a plain string field, and
    # field 3 an approval field (also registered in approval_defs below).
    self.config.field_defs = [
        tracker_bizobj.MakeFieldDef(
            1, 789, 'Goats', tracker_pb2.FieldTypes.INT_TYPE, None, None,
            False, False, False, None, None, None, False, None, None, None,
            None, 'Num of Goats in the season', False, is_phase_field=True),
        tracker_bizobj.MakeFieldDef(
            2, 789, 'DogNames', tracker_pb2.FieldTypes.STR_TYPE, None, None,
            False, False, False, None, None, None, False, None, None, None,
            None, 'good dog names', False),
        tracker_bizobj.MakeFieldDef(
            3, 789, 'Approval', tracker_pb2.FieldTypes.APPROVAL_TYPE,
            None, None, False, False, False, None, None, None, False, None,
            None, None, None, 'Tracks review from cows', False)
    ]
    self.config.approval_defs = [tracker_pb2.ApprovalDef(approval_id=3)]
    self.issue = MakeTestIssue(
        local_id=1, issue_id=100001, summary='One')
    # Goats has one value per phase (winter=34, summer=35); DogNames has one.
    self.issue.field_values = [
        tracker_bizobj.MakeFieldValue(
            1, 34, None, None, None, None, False, phase_id=23),
        tracker_bizobj.MakeFieldValue(
            1, 35, None, None, None, None, False, phase_id=24),
        tracker_bizobj.MakeFieldValue(
            2, None, 'Waffles', None, None, None, False),
    ]
    self.issue.phases = [
        tracker_pb2.Phase(phase_id=23, name='winter'),
        tracker_pb2.Phase(phase_id=24, name='summer')]
    self.issue.approval_values = [
        tracker_pb2.ApprovalValue(
            approval_id=3, approver_ids=[111, 222, 333])]
    # User 222 has the obscure-email flag set; user 333 is intentionally
    # absent from this dict.
    self.users_by_id = {
        111: framework_views.StuffUserView(111, 'foo@example.com', False),
        222: framework_views.StuffUserView(222, 'foo2@example.com', True),
    }

    # Common keyword arguments for constructing TableCellSummary.
    self.summary_table_cell_kws = {
        'col': None,
        'users_by_id': {},
        'non_col_labels': [('lab', False)],
        'label_values': {},
        'related_issues': {},
        'config': 'fake_config',
    }

  def testTableCellSummary(self):
    """TableCellSummary stores the data given to it."""
    cell = table_view_helpers.TableCellSummary(
        MakeTestIssue(4, 4, 'Lame default summary.'),
        **self.summary_table_cell_kws)
    self.assertEqual(cell.type, table_view_helpers.CELL_TYPE_SUMMARY)
    self.assertEqual(cell.values[0].item, 'Lame default summary.')
    self.assertEqual(cell.non_column_labels[0].value, 'lab')

  def testTableCellSummary_NoPythonEscaping(self):
    """TableCellSummary stores the summary without escaping it in python."""
    cell = table_view_helpers.TableCellSummary(
        MakeTestIssue(4, 4, '<b>bold</b> "summary".'),
        **self.summary_table_cell_kws)
    self.assertEqual(cell.values[0].item,'<b>bold</b> "summary".')

  def testTableCellCustom_normal(self):
    """TableCellCustom stores the value of a custom FieldValue."""
    cell_dognames = table_view_helpers.TableCellCustom(
        self.issue, col='dognames', config=self.config)
    self.assertEqual(cell_dognames.type, table_view_helpers.CELL_TYPE_ATTR)
    self.assertEqual(cell_dognames.values[0].item, 'Waffles')

  def testTableCellCustom_phasefields(self):
    """TableCellCustom resolves phase-qualified columns like winter.goats."""
    cell_winter = table_view_helpers.TableCellCustom(
        self.issue, col='winter.goats', config=self.config)
    self.assertEqual(cell_winter.type, table_view_helpers.CELL_TYPE_ATTR)
    self.assertEqual(cell_winter.values[0].item, 34)

    cell_summer = table_view_helpers.TableCellCustom(
        self.issue, col='summer.goats', config=self.config)
    self.assertEqual(cell_summer.type, table_view_helpers.CELL_TYPE_ATTR)
    self.assertEqual(cell_summer.values[0].item, 35)

  def testTableCellApprovalStatus(self):
    """TableCellApprovalStatus stores the status of an ApprovalValue."""
    cell = table_view_helpers.TableCellApprovalStatus(
        self.issue, col='Approval', config=self.config)
    self.assertEqual(cell.type, table_view_helpers.CELL_TYPE_ATTR)
    self.assertEqual(cell.values[0].item, 'NOT_SET')

  def testTableCellApprovalApprover(self):
    """TableCellApprovalApprover stores the approvers of an ApprovalValue."""
    cell = table_view_helpers.TableCellApprovalApprover(
        self.issue, col='Approval-approver', config=self.config,
        users_by_id=self.users_by_id)
    self.assertEqual(cell.type, table_view_helpers.CELL_TYPE_ATTR)
    # Presumably only approvers present in users_by_id yield values (333 is
    # absent), and 222's email is obscured — confirm against the helper.
    self.assertEqual(len(cell.values), 2)
    # NOTE(review): assertItemsEqual is Python 2-only unittest API; under
    # Python 3 this would need assertCountEqual.
    self.assertItemsEqual([cell.values[0].item, cell.values[1].item],
                          ['foo@example.com', 'f...@example.com'])

  # TODO(jrobbins): TableCellProject, TableCellStars
+
+
+
class TableViewHelpersTest(unittest.TestCase):
  """Tests for the module-level helpers: unshown-column computation,
  unique filter values, table/row construction, and cell-factory choice."""

  def setUp(self):
    # Shared fixtures: a default column spec, builtin columns, and a config.
    self.default_cols = 'a b c'
    self.builtin_cols = ['a', 'b', 'x', 'y', 'z']
    self.config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)

  def testComputeUnshownColumns_CommonCase(self):
    """Unshown columns = label prefixes + builtins, minus the shown ones."""
    shown_cols = ['a', 'b', 'c']
    config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
    config.default_col_spec = self.default_cols
    config.well_known_labels = []

    unshown = table_view_helpers.ComputeUnshownColumns(
        EMPTY_SEARCH_RESULTS, shown_cols, config, self.builtin_cols)
    self.assertEqual(unshown, ['x', 'y', 'z'])

    unshown = table_view_helpers.ComputeUnshownColumns(
        SEARCH_RESULTS_WITH_LABELS, shown_cols, config, self.builtin_cols)
    self.assertEqual(
        unshown, ['Mstone', 'Priority', 'Visibility', 'x', 'y', 'z'])

  def testComputeUnshownColumns_MoreBuiltins(self):
    """Builtin columns already shown drop out of the unshown list."""
    shown_cols = ['a', 'b', 'c', 'x', 'y']
    config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
    config.default_col_spec = self.default_cols
    config.well_known_labels = []

    unshown = table_view_helpers.ComputeUnshownColumns(
        EMPTY_SEARCH_RESULTS, shown_cols, config, self.builtin_cols)
    self.assertEqual(unshown, ['z'])

    unshown = table_view_helpers.ComputeUnshownColumns(
        SEARCH_RESULTS_WITH_LABELS, shown_cols, config, self.builtin_cols)
    self.assertEqual(unshown, ['Mstone', 'Priority', 'Visibility', 'z'])

  def testComputeUnshownColumns_NotAllDefaults(self):
    """Default columns that are not shown appear in the unshown list."""
    shown_cols = ['a', 'b']
    config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
    config.default_col_spec = self.default_cols
    config.well_known_labels = []

    unshown = table_view_helpers.ComputeUnshownColumns(
        EMPTY_SEARCH_RESULTS, shown_cols, config, self.builtin_cols)
    self.assertEqual(unshown, ['c', 'x', 'y', 'z'])

    unshown = table_view_helpers.ComputeUnshownColumns(
        SEARCH_RESULTS_WITH_LABELS, shown_cols, config, self.builtin_cols)
    self.assertEqual(
        unshown, ['Mstone', 'Priority', 'Visibility', 'c', 'x', 'y', 'z'])

  def testComputeUnshownColumns_ExtraNonDefaults(self):
    """Shown columns outside defaults/builtins do not affect the result."""
    shown_cols = ['a', 'b', 'c', 'd', 'e', 'f']
    config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
    config.default_col_spec = self.default_cols
    config.well_known_labels = []

    unshown = table_view_helpers.ComputeUnshownColumns(
        EMPTY_SEARCH_RESULTS, shown_cols, config, self.builtin_cols)
    self.assertEqual(unshown, ['x', 'y', 'z'])

    unshown = table_view_helpers.ComputeUnshownColumns(
        SEARCH_RESULTS_WITH_LABELS, shown_cols, config, self.builtin_cols)
    self.assertEqual(
        unshown, ['Mstone', 'Priority', 'Visibility', 'x', 'y', 'z'])

  def testComputeUnshownColumns_UserColumnsShown(self):
    """A shown label-prefix column is not suggested again."""
    shown_cols = ['a', 'b', 'c', 'Priority']
    config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
    config.default_col_spec = self.default_cols
    config.well_known_labels = []

    unshown = table_view_helpers.ComputeUnshownColumns(
        EMPTY_SEARCH_RESULTS, shown_cols, config, self.builtin_cols)
    self.assertEqual(unshown, ['x', 'y', 'z'])

    unshown = table_view_helpers.ComputeUnshownColumns(
        SEARCH_RESULTS_WITH_LABELS, shown_cols, config, self.builtin_cols)
    self.assertEqual(unshown, ['Mstone', 'Visibility', 'x', 'y', 'z'])

  def testComputeUnshownColumns_EverythingShown(self):
    """No columns are suggested when every candidate is already shown."""
    shown_cols = [
        'a', 'b', 'c', 'x', 'y', 'z', 'Priority', 'Mstone', 'Visibility']
    config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
    config.default_col_spec = self.default_cols
    config.well_known_labels = []

    unshown = table_view_helpers.ComputeUnshownColumns(
        EMPTY_SEARCH_RESULTS, shown_cols, config, self.builtin_cols)
    self.assertEqual(unshown, [])

    unshown = table_view_helpers.ComputeUnshownColumns(
        SEARCH_RESULTS_WITH_LABELS, shown_cols, config, self.builtin_cols)
    self.assertEqual(unshown, [])

  def testComputeUnshownColumns_NothingShown(self):
    """With no shown columns, every candidate column is unshown."""
    shown_cols = []
    config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
    config.default_col_spec = self.default_cols
    config.well_known_labels = []

    unshown = table_view_helpers.ComputeUnshownColumns(
        EMPTY_SEARCH_RESULTS, shown_cols, config, self.builtin_cols)
    self.assertEqual(unshown, ['a', 'b', 'c', 'x', 'y', 'z'])

    unshown = table_view_helpers.ComputeUnshownColumns(
        SEARCH_RESULTS_WITH_LABELS, shown_cols, config, self.builtin_cols)
    self.assertEqual(
        unshown,
        ['Mstone', 'Priority', 'Visibility', 'a', 'b', 'c', 'x', 'y', 'z'])

  def testComputeUnshownColumns_NoBuiltins(self):
    """Only label prefixes are suggested when there are no builtins."""
    shown_cols = ['a', 'b', 'c']
    config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
    config.default_col_spec = ''
    config.well_known_labels = []
    builtin_cols = []

    unshown = table_view_helpers.ComputeUnshownColumns(
        EMPTY_SEARCH_RESULTS, shown_cols, config, builtin_cols)
    self.assertEqual(unshown, [])

    unshown = table_view_helpers.ComputeUnshownColumns(
        SEARCH_RESULTS_WITH_LABELS, shown_cols, config, builtin_cols)
    self.assertEqual(unshown, ['Mstone', 'Priority', 'Visibility'])

  def testComputeUnshownColumns_FieldDefs(self):
    """Custom fields, approvals, and phase fields contribute columns."""
    # Issues with phase field values; phase names are lowercased for columns.
    search_results = [
        fake.MakeTestIssue(
            789, 1, 'sum 1', 'New', 111,
            field_values=[
                tracker_bizobj.MakeFieldValue(
                    5, 74, None, None, None, None, False, phase_id=4),
                tracker_bizobj.MakeFieldValue(
                    6, 78, None, None, None, None, False, phase_id=5)],
            phases=[
                tracker_pb2.Phase(phase_id=4, name='goats'),
                tracker_pb2.Phase(phase_id=5, name='sheep')]),
        fake.MakeTestIssue(
            789, 2, 'sum 2', 'New', 111,
            field_values=[
                tracker_bizobj.MakeFieldValue(
                    5, 74, None, None, None, None, False, phase_id=3),
                tracker_bizobj.MakeFieldValue(
                    6, 77, None, None, None, None, False, phase_id=3)],
            phases=[
                tracker_pb2.Phase(phase_id=3, name='Goats'),
                tracker_pb2.Phase(phase_id=3, name='Goats-Exp')]),
        ]

    shown_cols = ['a', 'b', 'a1', 'a2-approver', 'f3', 'goats.g1', 'sheep.g2']
    config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
    config.default_col_spec = ''
    config.well_known_labels = []
    # a1/a2 are approvals, f3/f4 plain fields, g1/g2 phase fields.
    config.field_defs = [
      tracker_bizobj.MakeFieldDef(
          1, 789, 'a1', tracker_pb2.FieldTypes.APPROVAL_TYPE,
          None, None, False, False, False, None, None, None, False, None,
          None, None, None, 'Tracks review from cows', False),
      tracker_bizobj.MakeFieldDef(
          2, 789, 'a2', tracker_pb2.FieldTypes.APPROVAL_TYPE,
          None, None, False, False, False, None, None, None, False, None,
          None, None, None, 'Tracks review from chickens', False),
      tracker_bizobj.MakeFieldDef(
          3, 789, 'f3', tracker_pb2.FieldTypes.STR_TYPE,
          None, None, False, False, False, None, None, None, False, None,
          None, None, None, 'cow names', False),
      tracker_bizobj.MakeFieldDef(
          4, 789, 'f4', tracker_pb2.FieldTypes.INT_TYPE,
          None, None, False, False, False, None, None, None, False, None,
          None, None, None, 'chicken gobbles', False),
      tracker_bizobj.MakeFieldDef(
          5, 789, 'g1', tracker_pb2.FieldTypes.INT_TYPE,
          None, None, False, False, False, None, None, None, False, None,
          None, None, None, 'fluff', False, is_phase_field=True),
      tracker_bizobj.MakeFieldDef(
          6, 789, 'g2', tracker_pb2.FieldTypes.INT_TYPE,
          None, None, False, False, False, None, None, None, False, None,
          None, None, None, 'poof', False, is_phase_field=True),
    ]
    builtin_cols = []

    unshown = table_view_helpers.ComputeUnshownColumns(
        search_results, shown_cols, config, builtin_cols)
    self.assertEqual(unshown, [
        'a1-approver', 'a2', 'f4',
        'goats-exp.g1', 'goats-exp.g2', 'goats.g2', 'sheep.g1'])

  def testExtractUniqueValues_NoColumns(self):
    """No requested columns means no column-value objects."""
    column_values = table_view_helpers.ExtractUniqueValues(
        [], SEARCH_RESULTS_WITH_LABELS, {}, self.config, {})
    self.assertEqual([], column_values)

  def testExtractUniqueValues_NoResults(self):
    """Each requested column is present, with empty filter values."""
    cols = ['type', 'priority', 'owner', 'status', 'stars', 'attachments']
    column_values = table_view_helpers.ExtractUniqueValues(
        cols, EMPTY_SEARCH_RESULTS, {}, self.config, {})
    self.assertEqual(6, len(column_values))
    for index, col in enumerate(cols):
      self.assertEqual(index, column_values[index].col_index)
      self.assertEqual(col, column_values[index].column_name)
      self.assertEqual([], column_values[index].filter_values)

  def testExtractUniqueValues_ExplicitResults(self):
    """Filter values are gathered from labels, owners, status, and stars."""
    cols = ['priority', 'owner', 'status', 'stars', 'mstone', 'foo']
    users_by_id = {
        111: framework_views.StuffUserView(111, 'foo@example.com', True),
        }
    column_values = table_view_helpers.ExtractUniqueValues(
        cols, SEARCH_RESULTS_WITH_LABELS, users_by_id, self.config, {})
    self.assertEqual(len(cols), len(column_values))

    self.assertEqual('priority', column_values[0].column_name)
    self.assertEqual(['High', 'Low'], column_values[0].filter_values)

    self.assertEqual('owner', column_values[1].column_name)
    self.assertEqual(['f...@example.com'], column_values[1].filter_values)

    self.assertEqual('status', column_values[2].column_name)
    self.assertEqual(['New'], column_values[2].filter_values)

    self.assertEqual('stars', column_values[3].column_name)
    self.assertEqual([1], column_values[3].filter_values)

    self.assertEqual('mstone', column_values[4].column_name)
    self.assertEqual(['1', '1.1'], column_values[4].filter_values)

    # 'foo' matches no labels or fields, so it has no filter values.
    self.assertEqual('foo', column_values[5].column_name)
    self.assertEqual([], column_values[5].filter_values)

    # self.assertEquals('mergedinto', column_values[6].column_name)
    # self.assertEquals(
    #     ['1', 'other-project:1'], column_values[6].filter_values)

  def testExtractUniqueValues_CombinedColumns(self):
    """Combined columns like priority/pri merge values of both prefixes."""
    cols = ['priority/pri', 'owner', 'status', 'stars', 'mstone/milestone']
    users_by_id = {
        111: framework_views.StuffUserView(111, 'foo@example.com', True),
        }
    issue = fake.MakeTestIssue(
        789, 5, 'sum 5', 'New', 111, merged_into=200001,
        labels='Priority-High Pri-0 Milestone-1.0 mstone-1',
        star_count=15)

    column_values = table_view_helpers.ExtractUniqueValues(
        cols, SEARCH_RESULTS_WITH_LABELS + [issue], users_by_id,
        self.config, {})
    self.assertEqual(5, len(column_values))

    self.assertEqual('priority/pri', column_values[0].column_name)
    self.assertEqual(['0', 'High', 'Low'], column_values[0].filter_values)

    self.assertEqual('owner', column_values[1].column_name)
    self.assertEqual(['f...@example.com'], column_values[1].filter_values)

    self.assertEqual('status', column_values[2].column_name)
    self.assertEqual(['New'], column_values[2].filter_values)

    self.assertEqual('stars', column_values[3].column_name)
    self.assertEqual([1, 15], column_values[3].filter_values)

    self.assertEqual('mstone/milestone', column_values[4].column_name)
    self.assertEqual(['1', '1.0', '1.1'], column_values[4].filter_values)

  def testExtractUniqueValues_DerivedValues(self):
    """Derived labels, status, and owners also contribute filter values."""
    cols = ['priority', 'milestone', 'owner', 'status']
    users_by_id = {
        111: framework_views.StuffUserView(111, 'foo@example.com', True),
        222: framework_views.StuffUserView(222, 'bar@example.com', True),
        333: framework_views.StuffUserView(333, 'lol@example.com', True),
        }
    search_results = [
        fake.MakeTestIssue(
            789, 1, 'sum 1', '', 111, labels='Priority-High Milestone-1.0',
            derived_labels='Milestone-2.0 Foo', derived_status='Started'),
        fake.MakeTestIssue(
            789, 2, 'sum 2', 'New', 111, labels='Priority-High Milestone-1.0',
            derived_owner_id=333),  # Not seen because of owner_id
        fake.MakeTestIssue(
            789, 3, 'sum 3', 'New', 0, labels='Priority-Low Milestone-1.1',
            derived_owner_id=222),
        ]

    column_values = table_view_helpers.ExtractUniqueValues(
        cols, search_results, users_by_id, self.config, {})
    self.assertEqual(4, len(column_values))

    self.assertEqual('priority', column_values[0].column_name)
    self.assertEqual(['High', 'Low'], column_values[0].filter_values)

    self.assertEqual('milestone', column_values[1].column_name)
    self.assertEqual(['1.0', '1.1', '2.0'], column_values[1].filter_values)

    self.assertEqual('owner', column_values[2].column_name)
    self.assertEqual(
        ['b...@example.com', 'f...@example.com'],
        column_values[2].filter_values)

    self.assertEqual('status', column_values[3].column_name)
    self.assertEqual(['New', 'Started'], column_values[3].filter_values)

  def testExtractUniqueValues_ColumnsRobustness(self):
    """An issue with no fields set still yields sane column values."""
    cols = ['reporter', 'cc', 'owner', 'status', 'attachments']
    search_results = [
        tracker_pb2.Issue(),
        ]
    column_values = table_view_helpers.ExtractUniqueValues(
        cols, search_results, {}, self.config, {})

    self.assertEqual(5, len(column_values))
    for col_val in column_values:
      if col_val.column_name == 'attachments':
        self.assertEqual([0], col_val.filter_values)
      else:
        self.assertEqual([], col_val.filter_values)

  def testMakeTableData_Empty(self):
    """Empty result lists produce an empty table."""
    visible_results = []
    lower_columns = []
    cell_factories = {}
    # NOTE(review): this first call passes cell_factories and [] in a
    # different positional order than the later calls — it only works here
    # because visible_results is empty; confirm against MakeTableData's
    # signature.
    table_data = table_view_helpers.MakeTableData(
        visible_results, [], lower_columns, lower_columns,
        cell_factories, [], 'unused function', {}, set(), self.config)
    self.assertEqual([], table_data)

    lower_columns = ['type', 'priority', 'summary', 'stars']
    cell_factories = {
        'summary': table_view_helpers.TableCellSummary,
        'stars': table_view_helpers.TableCellStars,
        }

    table_data = table_view_helpers.MakeTableData(
        visible_results, [], lower_columns, [], {},
        cell_factories, 'unused function', {}, set(), self.config)
    self.assertEqual([], table_data)

  def testMakeTableData_Normal(self):
    """One row per result, one cell per column."""
    art = fake.MakeTestIssue(
        789, 1, 'sum 1', 'New', 111, labels='Type-Defect Priority-Medium')
    visible_results = [art]
    lower_columns = ['type', 'priority', 'summary', 'stars']
    cell_factories = {
        'summary': table_view_helpers.TableCellSummary,
        'stars': table_view_helpers.TableCellStars,
        }

    table_data = table_view_helpers.MakeTableData(
        visible_results, [], lower_columns, lower_columns, {},
        cell_factories, lambda art: 'id', {}, set(), self.config)
    self.assertEqual(1, len(table_data))
    row = table_data[0]
    self.assertEqual(4, len(row.cells))
    self.assertEqual('Defect', row.cells[0].values[0].item)

  def testMakeTableData_Groups(self):
    """Group-by columns produce group header cells on each row."""
    art = fake.MakeTestIssue(
        789, 1, 'sum 1', 'New', 111, labels='Type-Defect Priority-Medium')
    visible_results = [art]
    lower_columns = ['type', 'priority', 'summary', 'stars']
    lower_group_by = ['priority']
    cell_factories = {
        'summary': table_view_helpers.TableCellSummary,
        'stars': table_view_helpers.TableCellStars,
        }

    table_data = table_view_helpers.MakeTableData(
        visible_results, [], lower_columns, lower_group_by, {},
        cell_factories, lambda art: 'id', {}, set(), self.config)
    self.assertEqual(1, len(table_data))
    row = table_data[0]
    self.assertEqual(1, len(row.group.cells))
    self.assertEqual('Medium', row.group.cells[0].values[0].item)

  def testMakeRowData(self):
    """Each cell gets its column index, type, and non-derived value."""
    art = fake.MakeTestIssue(
        789, 1, 'sum 1', 'New', 111, labels='Type-Defect Priority-Medium',
        star_count=1)
    columns = ['type', 'priority', 'summary', 'stars']

    cell_factories = [table_view_helpers.TableCellKeyLabels,
                      table_view_helpers.TableCellKeyLabels,
                      table_view_helpers.TableCellSummary,
                      table_view_helpers.TableCellStars]

    # a result is a table_view_helpers.TableRow object with a "cells" field
    # containing a list of table_view_helpers.TableCell objects.
    result = table_view_helpers.MakeRowData(
        art, columns, {}, cell_factories, {}, set(), self.config, {})

    self.assertEqual(len(columns), len(result.cells))

    for i in range(len(columns)):
      cell = result.cells[i]
      self.assertEqual(i, cell.col_index)

    self.assertEqual(table_view_helpers.CELL_TYPE_ATTR, result.cells[0].type)
    self.assertEqual('Defect', result.cells[0].values[0].item)
    self.assertFalse(result.cells[0].values[0].is_derived)

    self.assertEqual(table_view_helpers.CELL_TYPE_ATTR, result.cells[1].type)
    self.assertEqual('Medium', result.cells[1].values[0].item)
    self.assertFalse(result.cells[1].values[0].is_derived)

    self.assertEqual(
        table_view_helpers.CELL_TYPE_SUMMARY, result.cells[2].type)
    self.assertEqual('sum 1', result.cells[2].values[0].item)
    self.assertFalse(result.cells[2].values[0].is_derived)

    self.assertEqual(table_view_helpers.CELL_TYPE_ATTR, result.cells[3].type)
    self.assertEqual(1, result.cells[3].values[0].item)
    self.assertFalse(result.cells[3].values[0].is_derived)

  def testAccumulateLabelValues_Empty(self):
    """No labels accumulate nothing, with or without column names."""
    label_values, non_col_labels = collections.defaultdict(list), []
    table_view_helpers._AccumulateLabelValues(
        [], [], label_values, non_col_labels)
    self.assertEqual({}, label_values)
    self.assertEqual([], non_col_labels)

    label_values, non_col_labels = collections.defaultdict(list), []
    table_view_helpers._AccumulateLabelValues(
        [], ['Type', 'Priority'], label_values, non_col_labels)
    self.assertEqual({}, label_values)
    self.assertEqual([], non_col_labels)

  def testAccumulateLabelValues_OneWordLabels(self):
    """Labels without a dash go to non_col_labels with the derived flag."""
    label_values, non_col_labels = collections.defaultdict(list), []
    table_view_helpers._AccumulateLabelValues(
        ['HelloThere'], [], label_values, non_col_labels)
    self.assertEqual({}, label_values)
    self.assertEqual([('HelloThere', False)], non_col_labels)

    label_values, non_col_labels = collections.defaultdict(list), []
    table_view_helpers._AccumulateLabelValues(
        ['HelloThere'], [], label_values, non_col_labels, is_derived=True)
    self.assertEqual({}, label_values)
    self.assertEqual([('HelloThere', True)], non_col_labels)

  def testAccumulateLabelValues_KeyValueLabels(self):
    """Key-Value labels are split on the dash and grouped by column."""
    label_values, non_col_labels = collections.defaultdict(list), []
    table_view_helpers._AccumulateLabelValues(
        ['Type-Defect', 'Milestone-Soon'], ['type', 'milestone'],
        label_values, non_col_labels)
    self.assertEqual(
        {'type': [('Defect', False)],
         'milestone': [('Soon', False)]},
        label_values)
    self.assertEqual([], non_col_labels)

  def testAccumulateLabelValues_MultiValueLabels(self):
    """Several labels with the same prefix accumulate under one column."""
    label_values, non_col_labels = collections.defaultdict(list), []
    table_view_helpers._AccumulateLabelValues(
        ['OS-Mac', 'OS-Linux'], ['os', 'arch'],
        label_values, non_col_labels)
    self.assertEqual(
        {'os': [('Mac', False), ('Linux', False)]},
        label_values)
    self.assertEqual([], non_col_labels)

  def testAccumulateLabelValues_MultiPartLabels(self):
    """Multi-dash labels contribute to every matching column prefix."""
    label_values, non_col_labels = collections.defaultdict(list), []
    table_view_helpers._AccumulateLabelValues(
        ['OS-Mac-Server', 'OS-Mac-Laptop'], ['os', 'os-mac'],
        label_values, non_col_labels)
    self.assertEqual(
        {'os': [('Mac-Server', False), ('Mac-Laptop', False)],
         'os-mac': [('Server', False), ('Laptop', False)],
         },
        label_values)
    self.assertEqual([], non_col_labels)

  def testChooseCellFactory(self):
    """We choose the right kind of table cell for the specified column."""
    cell_factories = {
        'summary': table_view_helpers.TableCellSummary,
        'stars': table_view_helpers.TableCellStars,
        }
    os_fd = tracker_bizobj.MakeFieldDef(
        1, 789, 'os', tracker_pb2.FieldTypes.ENUM_TYPE, None, None, False,
        False, False, None, None, None, False, None, None, None, None,
        'Operating system', False)
    deadline_fd = tracker_bizobj.MakeFieldDef(
        2, 789, 'deadline', tracker_pb2.FieldTypes.DATE_TYPE, None, None, False,
        False, False, None, None, None, False, None, None, None, None,
        'Deadline to resolve issue', False)
    approval_fd = tracker_bizobj.MakeFieldDef(
        3, 789, 'CowApproval', tracker_pb2.FieldTypes.APPROVAL_TYPE,
        None, None, False,
        False, False, None, None, None, False, None, None, None, None,
        'Tracks reviews from cows', False)
    goats_fd = tracker_bizobj.MakeFieldDef(
        4, 789, 'goats', tracker_pb2.FieldTypes.INT_TYPE, None, None, False,
        False, False, None, None, None, False, None, None, None, None,
        'Num goats in each phase', False, is_phase_field=True)
    self.config.field_defs = [os_fd, deadline_fd, approval_fd, goats_fd]

    # The column is defined in cell_factories.
    actual = table_view_helpers.ChooseCellFactory(
        'summary', cell_factories, self.config)
    self.assertEqual(table_view_helpers.TableCellSummary, actual)

    # The column is a composite column.
    actual = table_view_helpers.ChooseCellFactory(
        'summary/stars', cell_factories, self.config)
    self.assertEqual('FactoryClass', actual.__name__)

    # The column is a enum custom field, so it is treated like a label.
    actual = table_view_helpers.ChooseCellFactory(
        'os', cell_factories, self.config)
    self.assertEqual(table_view_helpers.TableCellKeyLabels, actual)

    # The column is a non-enum custom field.
    actual = table_view_helpers.ChooseCellFactory(
        'deadline', cell_factories, self.config)
    self.assertEqual(
        [(table_view_helpers.TableCellCustom, 'deadline'),
         (table_view_helpers.TableCellKeyLabels, 'deadline')],
        actual.factory_col_list)

    # The column is an approval custom field.
    actual = table_view_helpers.ChooseCellFactory(
        'CowApproval', cell_factories, self.config)
    self.assertEqual(
        [(table_view_helpers.TableCellApprovalStatus, 'CowApproval'),
         (table_view_helpers.TableCellKeyLabels, 'CowApproval')],
        actual.factory_col_list)

    # The column is an approval custom field with '-approver'.
    actual = table_view_helpers.ChooseCellFactory(
        'CowApproval-approver', cell_factories, self.config)
    self.assertEqual(
        [(table_view_helpers.TableCellApprovalApprover, 'CowApproval-approver'),
         (table_view_helpers.TableCellKeyLabels, 'CowApproval-approver')],
        actual.factory_col_list)

    # The column specifies a phase custom field.
    actual = table_view_helpers.ChooseCellFactory(
        'winter.goats', cell_factories, self.config)
    self.assertEqual(
        [(table_view_helpers.TableCellCustom, 'winter.goats'),
         (table_view_helpers.TableCellKeyLabels, 'winter.goats')],
        actual.factory_col_list)


    # Columns that don't match one of the other cases are assumed to be labels.
    actual = table_view_helpers.ChooseCellFactory(
        'reward', cell_factories, self.config)
    self.assertEqual(table_view_helpers.TableCellKeyLabels, actual)

  def testCompositeFactoryTableCell_Empty(self):
    """If we made a composite of zero columns, it would have no values."""
    composite = table_view_helpers.CompositeFactoryTableCell([])
    cell = composite('artifact')
    self.assertEqual([], cell.values)

  def testCompositeFactoryTableCell_Normal(self):
    """If we make a composite, it has values from each of the sub cells."""
    composite = table_view_helpers.CompositeFactoryTableCell(
        [(sub_factory_1, 'col1'),
         (sub_factory_2, 'col2')])

    cell = composite('artifact')
    self.assertEqual(
        ['sub_cell_1_col1',
         'sub_cell_2_col2'],
        cell.values)

  def testCompositeColTableCell_Empty(self):
    """If we made a composite of zero columns, it would have no values."""
    composite = table_view_helpers.CompositeColTableCell([], {}, self.config)
    cell = composite('artifact')
    self.assertEqual([], cell.values)


  def testCompositeColTableCell_Normal(self):
    """If we make a composite, it has values from each of the sub cells."""
    composite = table_view_helpers.CompositeColTableCell(
        ['col1', 'col2'],
        {'col1': sub_factory_1, 'col2': sub_factory_2},
        self.config)
    cell = composite('artifact')
    self.assertEqual(
        ['sub_cell_1_col1',
         'sub_cell_2_col2'],
        cell.values)
+
+
def sub_factory_1(_art, **kw):
  """Stub cell factory #1; records which column it was invoked for."""
  column = kw['col']
  return testing_helpers.Blank(
      values=['sub_cell_1_%s' % column],
      non_column_labels=[])
+
+
def sub_factory_2(_art, **kw):
  """Stub cell factory #2; records which column it was invoked for."""
  column = kw['col']
  return testing_helpers.Blank(
      values=['sub_cell_2_%s' % column],
      non_column_labels=[])
diff --git a/framework/test/template_helpers_test.py b/framework/test/template_helpers_test.py
new file mode 100644
index 0000000..85296fa
--- /dev/null
+++ b/framework/test/template_helpers_test.py
@@ -0,0 +1,216 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for template_helpers module."""
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+
+import unittest
+
+from framework import pbproxy_test_pb2
+from framework import template_helpers
+
+
+class HelpersUnitTest(unittest.TestCase):
+
+ def testDictionaryProxy(self):
+
+ # basic in 'n out test
+ item = template_helpers.EZTItem(label='foo', group_name='bar')
+
+ self.assertEqual('foo', item.label)
+ self.assertEqual('bar', item.group_name)
+
+ # be sure the __str__ returns the fields
+ self.assertEqual(
+ "EZTItem({'group_name': 'bar', 'label': 'foo'})", str(item))
+
+ def testPBProxy(self):
+ """Checks that PBProxy wraps protobuf objects as expected."""
+ # check that protobuf fields are accessible in ".attribute" form
+ pbe = pbproxy_test_pb2.PBProxyExample()
+ pbe.nickname = 'foo'
+ pbe.invited = False
+ pbep = template_helpers.PBProxy(pbe)
+ self.assertEqual(pbep.nickname, 'foo')
+ # _bool suffix converts protobuf field 'bar' to None (EZT boolean false)
+ self.assertEqual(pbep.invited_bool, None)
+
+ # check that a new field can be added to the PBProxy
+ pbep.baz = 'bif'
+ self.assertEqual(pbep.baz, 'bif')
+
+ # check that a PBProxy-local field can hide a protobuf field
+ pbep.nickname = 'local foo'
+ self.assertEqual(pbep.nickname, 'local foo')
+
+ # check that a nested protobuf is recursively wrapped with a PBProxy
+ pbn = pbproxy_test_pb2.PBProxyNested()
+ pbn.nested = pbproxy_test_pb2.PBProxyExample()
+ pbn.nested.nickname = 'bar'
+ pbn.nested.invited = True
+ pbnp = template_helpers.PBProxy(pbn)
+ self.assertEqual(pbnp.nested.nickname, 'bar')
+ # _bool suffix converts protobuf field 'bar' to 'yes' (EZT boolean true)
+ self.assertEqual(pbnp.nested.invited_bool, 'yes')
+
+ # check that 'repeated' lists of items produce a list of strings
+ pbn.multiple_strings.append('1')
+ pbn.multiple_strings.append('2')
+ self.assertEqual(pbnp.multiple_strings, ['1', '2'])
+
+ # check that 'repeated' messages produce lists of PBProxy instances
+ pbe1 = pbproxy_test_pb2.PBProxyExample()
+ pbn.multiple_pbes.append(pbe1)
+ pbe1.nickname = '1'
+ pbe1.invited = True
+ pbe2 = pbproxy_test_pb2.PBProxyExample()
+ pbn.multiple_pbes.append(pbe2)
+ pbe2.nickname = '2'
+ pbe2.invited = False
+ self.assertEqual(pbnp.multiple_pbes[0].nickname, '1')
+ self.assertEqual(pbnp.multiple_pbes[0].invited_bool, 'yes')
+ self.assertEqual(pbnp.multiple_pbes[1].nickname, '2')
+ self.assertEqual(pbnp.multiple_pbes[1].invited_bool, None)
+
+ def testFitTextMethods(self):
+    """Tests FitUnsafeText with an eye on i18n."""
+ # pylint: disable=anomalous-unicode-escape-in-string
+ test_data = (
+ u'This is a short string.',
+
+ u'This is a much longer string. '
+ u'This is a much longer string. '
+ u'This is a much longer string. '
+ u'This is a much longer string. '
+ u'This is a much longer string. '
+ u'This is a much longer string. '
+ u'This is a much longer string. '
+ u'This is a much longer string. '
+ u'This is a much longer string. '
+ u'This is a much longer string. ',
+
+ # This is a short escaped i18n string
+ '\xd5\xa1\xd5\xba\xd5\xa1\xd5\xaf\xd5\xab'.decode('utf-8'),
+
+ # This is a longer i18n string
+ '\xd5\xa1\xd5\xba\xd5\xa1\xd5\xaf\xd5\xab '
+ '\xe6\x88\x91\xe8\x83\xbd\xe5\x90\x9e '
+ '\xd5\xa1\xd5\xba\xd5\xa1\xd5\xaf\xd5\xab '
+ '\xe6\x88\x91\xe8\x83\xbd\xe5\x90\x9e '
+ '\xd5\xa1\xd5\xba\xd5\xa1\xd5\xaf\xd5\xab '
+ '\xe6\x88\x91\xe8\x83\xbd\xe5\x90\x9e '
+ '\xd5\xa1\xd5\xba\xd5\xa1\xd5\xaf\xd5\xab '
+ '\xe6\x88\x91\xe8\x83\xbd\xe5\x90\x9e '.decode('utf-8'),
+
+ # This is a longer i18n string that was causing trouble.
+ '\u041d\u0430 \u0431\u0435\u0440\u0435\u0433\u0443'
+ ' \u043f\u0443\u0441\u0442\u044b\u043d\u043d\u044b\u0445'
+ ' \u0432\u043e\u043b\u043d \u0421\u0442\u043e\u044f\u043b'
+ ' \u043e\u043d, \u0434\u0443\u043c'
+ ' \u0432\u0435\u043b\u0438\u043a\u0438\u0445'
+ ' \u043f\u043e\u043b\u043d, \u0418'
+ ' \u0432\u0434\u0430\u043b\u044c'
+ ' \u0433\u043b\u044f\u0434\u0435\u043b.'
+ ' \u041f\u0440\u0435\u0434 \u043d\u0438\u043c'
+ ' \u0448\u0438\u0440\u043e\u043a\u043e'
+ ' \u0420\u0435\u043a\u0430'
+ ' \u043d\u0435\u0441\u043b\u0430\u0441\u044f;'
+ ' \u0431\u0435\u0434\u043d\u044b\u0439'
+ ' \u0447\u0451\u043b\u043d \u041f\u043e'
+ ' \u043d\u0435\u0439'
+ ' \u0441\u0442\u0440\u0435\u043c\u0438\u043b\u0441\u044f'
+ ' \u043e\u0434\u0438\u043d\u043e\u043a\u043e.'
+ ' \u041f\u043e \u043c\u0448\u0438\u0441\u0442\u044b\u043c,'
+ ' \u0442\u043e\u043f\u043a\u0438\u043c'
+ ' \u0431\u0435\u0440\u0435\u0433\u0430\u043c'
+ ' \u0427\u0435\u0440\u043d\u0435\u043b\u0438'
+ ' \u0438\u0437\u0431\u044b \u0437\u0434\u0435\u0441\u044c'
+ ' \u0438 \u0442\u0430\u043c, \u041f\u0440\u0438\u044e\u0442'
+ ' \u0443\u0431\u043e\u0433\u043e\u0433\u043e'
+ ' \u0447\u0443\u0445\u043e\u043d\u0446\u0430;'
+ ' \u0418 \u043b\u0435\u0441,'
+ ' \u043d\u0435\u0432\u0435\u0434\u043e\u043c\u044b\u0439'
+ ' \u043b\u0443\u0447\u0430\u043c \u0412'
+ ' \u0442\u0443\u043c\u0430\u043d\u0435'
+ ' \u0441\u043f\u0440\u044f\u0442\u0430\u043d\u043d\u043e'
+ '\u0433\u043e \u0441\u043e\u043b\u043d\u0446\u0430,'
+ ' \u041a\u0440\u0443\u0433\u043e\u043c'
+ ' \u0448\u0443\u043c\u0435\u043b.'.decode('utf-8'))
+
+ for unicode_s in test_data:
+ # Get the length in characters, not bytes.
+ length = len(unicode_s)
+
+ # Test the FitUnsafeText method at the length boundary.
+ fitted_unsafe_text = template_helpers.FitUnsafeText(unicode_s, length)
+ self.assertEqual(fitted_unsafe_text, unicode_s)
+
+ # Set some values that test FitString well.
+ available_space = length // 2
+ max_trailing = length // 4
+ # Break the string at various places - symmetric range around 0
+ for i in range(1-max_trailing, max_trailing):
+ # Test the FitUnsafeText method.
+ fitted_unsafe_text = template_helpers.FitUnsafeText(
+ unicode_s, available_space - i)
+ self.assertEqual(fitted_unsafe_text[:available_space - i],
+ unicode_s[:available_space - i])
+
+ # Test a string that is already unicode
+ u_string = u'This is already unicode'
+ fitted_unsafe_text = template_helpers.FitUnsafeText(u_string, 100)
+ self.assertEqual(u_string, fitted_unsafe_text)
+
+ # Test a string that is already unicode, and has non-ascii in it.
+ u_string = u'This is already unicode este\\u0301tico'
+ fitted_unsafe_text = template_helpers.FitUnsafeText(u_string, 100)
+ self.assertEqual(u_string, fitted_unsafe_text)
+
+ def testEZTError(self):
+ errors = template_helpers.EZTError()
+ self.assertFalse(errors.AnyErrors())
+
+ errors.error_a = 'A'
+ self.assertTrue(errors.AnyErrors())
+ self.assertEqual('A', errors.error_a)
+
+ errors.SetError('error_b', 'B')
+ self.assertTrue(errors.AnyErrors())
+ self.assertEqual('A', errors.error_a)
+ self.assertEqual('B', errors.error_b)
+
+ def testBytesKbOrMb(self):
+ self.assertEqual('1023 bytes', template_helpers.BytesKbOrMb(1023))
+ self.assertEqual('1.0 KB', template_helpers.BytesKbOrMb(1024))
+ self.assertEqual('1023 KB', template_helpers.BytesKbOrMb(1024 * 1023))
+ self.assertEqual('1.0 MB', template_helpers.BytesKbOrMb(1024 * 1024))
+ self.assertEqual('98.0 MB', template_helpers.BytesKbOrMb(98 * 1024 * 1024))
+ self.assertEqual('99 MB', template_helpers.BytesKbOrMb(99 * 1024 * 1024))
+
+
+class TextRunTest(unittest.TestCase):
+
+ def testLink(self):
+ run = template_helpers.TextRun(
+ 'content', tag='a', href='http://example.com')
+ expected = '<a href="http://example.com">content</a>'
+ self.assertEqual(expected, run.FormatForHTMLEmail())
+
+ run = template_helpers.TextRun(
+ 'con<tent>', tag='a', href='http://exa"mple.com')
+    expected = '<a href="http://exa&quot;mple.com">con&lt;tent&gt;</a>'
+ self.assertEqual(expected, run.FormatForHTMLEmail())
+
+ def testText(self):
+ run = template_helpers.TextRun('content')
+ expected = 'content'
+ self.assertEqual(expected, run.FormatForHTMLEmail())
+
+ run = template_helpers.TextRun('con<tent>')
+    expected = 'con&lt;tent&gt;'
+ self.assertEqual(expected, run.FormatForHTMLEmail())
diff --git a/framework/test/timestr_test.py b/framework/test/timestr_test.py
new file mode 100644
index 0000000..ad11249
--- /dev/null
+++ b/framework/test/timestr_test.py
@@ -0,0 +1,95 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unittest for timestr module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import calendar
+import datetime
+import time
+import unittest
+
+from framework import timestr
+
+
+class TimeStrTest(unittest.TestCase):
+ """Unit tests for timestr routines."""
+
+ def testFormatAbsoluteDate(self):
+ now = datetime.datetime(2008, 1, 1)
+
+ def GetDate(*args):
+ date = datetime.datetime(*args)
+ return timestr.FormatAbsoluteDate(
+ calendar.timegm(date.utctimetuple()), clock=lambda: now)
+
+ self.assertEqual(GetDate(2008, 1, 1), 'Today')
+ self.assertEqual(GetDate(2007, 12, 31), 'Yesterday')
+ self.assertEqual(GetDate(2007, 12, 30), 'Dec 30')
+ self.assertEqual(GetDate(2007, 1, 1), 'Jan 2007')
+ self.assertEqual(GetDate(2007, 1, 2), 'Jan 2007')
+ self.assertEqual(GetDate(2007, 12, 31), 'Yesterday')
+ self.assertEqual(GetDate(2006, 12, 31), 'Dec 2006')
+ self.assertEqual(GetDate(2007, 7, 1), 'Jul 1')
+ self.assertEqual(GetDate(2007, 6, 30), 'Jun 2007')
+ self.assertEqual(GetDate(2008, 1, 3), 'Jan 2008')
+
+ # Leap year fun
+ now = datetime.datetime(2008, 3, 1)
+ self.assertEqual(GetDate(2008, 2, 29), 'Yesterday')
+
+ # Clock skew
+ now = datetime.datetime(2008, 1, 1, 23, 59, 59)
+ self.assertEqual(GetDate(2008, 1, 2), 'Today')
+ now = datetime.datetime(2007, 12, 31, 23, 59, 59)
+ self.assertEqual(GetDate(2008, 1, 1), 'Today')
+ self.assertEqual(GetDate(2008, 1, 2), 'Jan 2008')
+
+ def testFormatRelativeDate(self):
+ now = time.mktime(datetime.datetime(2008, 1, 1).timetuple())
+
+ def TestSecsAgo(secs_ago, expected, expected_days_only):
+ test_time = now - secs_ago
+ actual = timestr.FormatRelativeDate(
+ test_time, clock=lambda: now)
+ self.assertEqual(actual, expected)
+ actual_days_only = timestr.FormatRelativeDate(
+ test_time, clock=lambda: now, days_only=True)
+ self.assertEqual(actual_days_only, expected_days_only)
+
+ TestSecsAgo(10 * 24 * 60 * 60, '', '10 days ago')
+ TestSecsAgo(5 * 24 * 60 * 60 - 1, '4 days ago', '4 days ago')
+ TestSecsAgo(5 * 60 * 60 - 1, '4 hours ago', '')
+ TestSecsAgo(5 * 60 - 1, '4 minutes ago', '')
+ TestSecsAgo(2 * 60 - 1, '1 minute ago', '')
+ TestSecsAgo(60 - 1, 'moments ago', '')
+ TestSecsAgo(0, 'moments ago', '')
+ TestSecsAgo(-10, 'moments ago', '')
+ TestSecsAgo(-100, '', '')
+
+ def testGetHumanScaleDate(self):
+ """Tests GetHumanScaleDate()."""
+ now = time.mktime(datetime.datetime(2008, 4, 10, 20, 50, 30).timetuple())
+
+ def GetDate(*args):
+ date = datetime.datetime(*args)
+ timestamp = time.mktime(date.timetuple())
+ return timestr.GetHumanScaleDate(timestamp, now=now)
+
+ self.assertEqual(GetDate(2008, 4, 10, 15), ('Today', '5 hours ago'))
+ self.assertEqual(GetDate(2008, 4, 10, 19, 55), ('Today', '55 min ago'))
+ self.assertEqual(GetDate(2008, 4, 10, 20, 48, 35), ('Today', '1 min ago'))
+ self.assertEqual(GetDate(2008, 4, 10, 20, 49, 35), ('Today', 'moments ago'))
+ self.assertEqual(GetDate(2008, 4, 10, 20, 50, 55), ('Today', 'moments ago'))
+ self.assertEqual(GetDate(2008, 4, 9, 15), ('Yesterday', '29 hours ago'))
+ self.assertEqual(GetDate(2008, 4, 5, 15), ('Last 7 days', 'Apr 05, 2008'))
+ self.assertEqual(GetDate(2008, 3, 22, 15), ('Last 30 days', 'Mar 22, 2008'))
+ self.assertEqual(
+ GetDate(2008, 1, 2, 15), ('Earlier this year', 'Jan 02, 2008'))
+ self.assertEqual(
+ GetDate(2007, 12, 31, 15), ('Before this year', 'Dec 31, 2007'))
+ self.assertEqual(GetDate(2008, 4, 11, 20, 49, 35), ('Future', 'Later'))
diff --git a/framework/test/ts_mon_js_test.py b/framework/test/ts_mon_js_test.py
new file mode 100644
index 0000000..bcd4060
--- /dev/null
+++ b/framework/test/ts_mon_js_test.py
@@ -0,0 +1,73 @@
+# Copyright 2018 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for MonorailTSMonJSHandler."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import json
+import unittest
+from mock import patch
+
+import webapp2
+from google.appengine.ext import testbed
+
+from framework.ts_mon_js import MonorailTSMonJSHandler
+from services import service_manager
+
+
+class MonorailTSMonJSHandlerTest(unittest.TestCase):
+
+ def setUp(self):
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_user_stub()
+
+ def tearDown(self):
+ self.testbed.deactivate()
+
+ @patch('framework.xsrf.ValidateToken')
+ @patch('time.time')
+ def testSubmitMetrics(self, _mockTime, _mockValidateToken):
+ """Test normal case POSTing metrics."""
+ _mockTime.return_value = 1537821859
+ req = webapp2.Request.blank('/_/ts_mon_js')
+ req.body = json.dumps({
+ 'metrics': [{
+ 'MetricInfo': {
+ 'Name': 'monorail/frontend/issue_update_latency',
+ 'ValueType': 2,
+ },
+ 'Cells': [{
+ 'value': {
+ 'sum': 1234,
+ 'count': 4321,
+ 'buckets': {
+ 0: 123,
+ 1: 321,
+ 2: 213,
+ },
+ },
+ 'fields': {
+ 'client_id': '789',
+ 'host_name': 'rutabaga',
+ 'document_visible': True,
+ },
+ 'start_time': 1537821859 - 60,
+ }],
+ }],
+ })
+ res = webapp2.Response()
+ ts_mon_handler = MonorailTSMonJSHandler(request=req, response=res)
+ class MockApp(object):
+ def __init__(self):
+ self.config = {'services': service_manager.Services()}
+ ts_mon_handler.app = MockApp()
+
+ ts_mon_handler.post()
+
+ self.assertEqual(res.status_int, 201)
+ self.assertEqual(res.body, 'Ok.')
diff --git a/framework/test/validate_test.py b/framework/test/validate_test.py
new file mode 100644
index 0000000..9ea17fe
--- /dev/null
+++ b/framework/test/validate_test.py
@@ -0,0 +1,128 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""This file provides unit tests for Validate functions."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+from framework import validate
+
+
+class ValidateUnitTest(unittest.TestCase):
+ """Set of unit tests for validation functions."""
+
+ GOOD_EMAIL_ADDRESSES = [
+ 'user@example.com',
+ 'user@e.com',
+ 'user+tag@example.com',
+ 'u.ser@example.com',
+ 'us.er@example.com',
+ 'u.s.e.r@example.com',
+ 'user@ex-ample.com',
+ 'user@ex.ample.com',
+ 'user@e.x.ample.com',
+ 'user@exampl.e.com',
+ 'user@e-x-ample.com',
+ 'user@e-x-a-m-p-l-e.com',
+ 'user@e-x.am-ple.com',
+ 'user@e--xample.com',
+ ]
+
+ BAD_EMAIL_ADDRESSES = [
+ ' leading.whitespace@example.com',
+ 'trailing.whitespace@example.com ',
+ '(paren.quoted@example.com)',
+ '<angle.quoted@example.com>',
+ 'trailing.@example.com',
+ 'trailing.dot.@example.com',
+ '.leading@example.com',
+ '.leading.dot@example.com',
+ 'user@example.com.',
+ 'us..er@example.com',
+ 'user@ex..ample.com',
+ 'user@example..com',
+ 'user@ex-.ample.com',
+ 'user@-example.com',
+ 'user@.example.com',
+ 'user@example-.com',
+ 'user@example',
+ 'user@example.',
+ 'user@example.c',
+ 'user@example.comcomcomc',
+ 'user@example.co-m',
+ 'user@exa_mple.com',
+ 'user@exa-_mple.com',
+ 'user@example.c0m',
+ ]
+
+ def testIsValidEmail(self):
+ """Tests the Email validator class."""
+ for email in self.GOOD_EMAIL_ADDRESSES:
+ self.assertTrue(validate.IsValidEmail(email), msg='Rejected:%r' % email)
+
+ for email in self.BAD_EMAIL_ADDRESSES:
+ self.assertFalse(validate.IsValidEmail(email), msg='Accepted:%r' % email)
+
+ def testIsValidMailTo(self):
+ for email in self.GOOD_EMAIL_ADDRESSES:
+ self.assertTrue(
+ validate.IsValidMailTo('mailto:' + email),
+ msg='Rejected:%r' % ('mailto:' + email))
+
+ for email in self.BAD_EMAIL_ADDRESSES:
+ self.assertFalse(
+ validate.IsValidMailTo('mailto:' + email),
+ msg='Accepted:%r' % ('mailto:' + email))
+
+ GOOD_URLS = [
+ 'http://google.com',
+ 'http://maps.google.com/',
+ 'https://secure.protocol.com',
+ 'https://dash-domain.com',
+ 'http://www.google.com/search?q=foo&hl=en',
+ 'https://a.very.long.domain.name.net/with/a/long/path/inf0/too',
+ 'http://funny.ws/',
+ 'http://we.love.anchors.info/page.html#anchor',
+ 'http://redundant-slashes.com//in/path//info',
+ 'http://trailingslashe.com/in/path/info/',
+ 'http://domain.with.port.com:8080',
+ 'http://domain.with.port.com:8080/path/info',
+ 'ftp://ftp.gnu.org',
+ 'ftp://some.server.some.place.com',
+ 'http://b/123456',
+ 'http://cl/123456/',
+ ]
+
+ BAD_URLS = [
+ ' http://leading.whitespace.com',
+ 'http://trailing.domain.whitespace.com ',
+ 'http://trailing.whitespace.com/after/path/info ',
+ 'http://underscore_domain.com/',
+ 'http://space in domain.com',
+ 'http://user@example.com', # standard, but we purposely don't accept it.
+ 'http://user:pass@ex.com', # standard, but we purposely don't accept it.
+ 'http://:password@ex.com', # standard, but we purposely don't accept it.
+ 'missing-http.com',
+ 'http:missing-slashes.com',
+ 'http:/only-one-slash.com',
+ 'http://trailing.dot.',
+ 'mailto:bad.scheme',
+ 'javascript:attempt-to-inject',
+ 'http://short-with-no-final-slash',
+ 'http:///',
+ 'http:///no.host.name',
+ 'http://:8080/',
+ 'http://badport.com:808a0/ ',
+ ]
+
+ def testURL(self):
+ for url in self.GOOD_URLS:
+ self.assertTrue(validate.IsValidURL(url), msg='Rejected:%r' % url)
+
+ for url in self.BAD_URLS:
+ self.assertFalse(validate.IsValidURL(url), msg='Accepted:%r' % url)
diff --git a/framework/test/warmup_test.py b/framework/test/warmup_test.py
new file mode 100644
index 0000000..d8ddb65
--- /dev/null
+++ b/framework/test/warmup_test.py
@@ -0,0 +1,36 @@
+# Copyright 2017 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for the warmup servlet."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+from testing import testing_helpers
+
+from framework import sql
+from framework import warmup
+from services import service_manager
+
+
+class WarmupTest(unittest.TestCase):
+
+ def setUp(self):
+ #self.cache_manager = cachemanager_svc.CacheManager()
+ #self.services = service_manager.Services(
+ # cache_manager=self.cache_manager)
+ self.services = service_manager.Services()
+ self.servlet = warmup.Warmup(
+ 'req', 'res', services=self.services)
+
+
+ def testHandleRequest_NothingToDo(self):
+ mr = testing_helpers.MakeMonorailRequest()
+ actual_json_data = self.servlet.HandleRequest(mr)
+ self.assertEqual(
+ {'success': 1},
+ actual_json_data)
diff --git a/framework/test/xsrf_test.py b/framework/test/xsrf_test.py
new file mode 100644
index 0000000..aa04570
--- /dev/null
+++ b/framework/test/xsrf_test.py
@@ -0,0 +1,113 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for XSRF utility functions."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import time
+import unittest
+
+from mock import patch
+
+from google.appengine.ext import testbed
+
+import settings
+from framework import xsrf
+
+
+class XsrfTest(unittest.TestCase):
+ """Set of unit tests for blocking XSRF attacks."""
+
+ def setUp(self):
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_memcache_stub()
+ self.testbed.init_datastore_v3_stub()
+
+ def tearDown(self):
+ self.testbed.deactivate()
+
+ def testGenerateToken_AnonUserGetsAToken(self):
+ self.assertNotEqual('', xsrf.GenerateToken(0, '/path'))
+
+ def testGenerateToken_DifferentUsersGetDifferentTokens(self):
+ self.assertNotEqual(
+ xsrf.GenerateToken(111, '/path'),
+ xsrf.GenerateToken(222, '/path'))
+
+ self.assertNotEqual(
+ xsrf.GenerateToken(111, '/path'),
+ xsrf.GenerateToken(0, '/path'))
+
+ def testGenerateToken_DifferentPathsGetDifferentTokens(self):
+ self.assertNotEqual(
+ xsrf.GenerateToken(111, '/path/one'),
+ xsrf.GenerateToken(111, '/path/two'))
+
+ def testValidToken(self):
+ token = xsrf.GenerateToken(111, '/path')
+ xsrf.ValidateToken(token, 111, '/path') # no exception raised
+
+ def testMalformedToken(self):
+ self.assertRaises(
+ xsrf.TokenIncorrect,
+ xsrf.ValidateToken, 'bad', 111, '/path')
+ self.assertRaises(
+ xsrf.TokenIncorrect,
+ xsrf.ValidateToken, '', 111, '/path')
+
+ self.assertRaises(
+ xsrf.TokenIncorrect,
+ xsrf.ValidateToken, '098a08fe08b08c08a05e:9721973123', 111, '/path')
+
+ def testWrongUser(self):
+ token = xsrf.GenerateToken(111, '/path')
+ self.assertRaises(
+ xsrf.TokenIncorrect,
+ xsrf.ValidateToken, token, 222, '/path')
+
+ def testWrongPath(self):
+ token = xsrf.GenerateToken(111, '/path/one')
+ self.assertRaises(
+ xsrf.TokenIncorrect,
+ xsrf.ValidateToken, token, 111, '/path/two')
+
+ @patch('time.time')
+ def testValidateToken_Expiration(self, mockTime):
+ test_time = 1526671379
+ mockTime.return_value = test_time
+ token = xsrf.GenerateToken(111, '/path')
+ xsrf.ValidateToken(token, 111, '/path')
+
+ mockTime.return_value = test_time + 1
+ xsrf.ValidateToken(token, 111, '/path')
+
+ mockTime.return_value = test_time + xsrf.TOKEN_TIMEOUT_SEC
+ xsrf.ValidateToken(token, 111, '/path')
+
+ mockTime.return_value = test_time + xsrf.TOKEN_TIMEOUT_SEC + 1
+ self.assertRaises(
+ xsrf.TokenIncorrect,
+        xsrf.ValidateToken, token, 111, '/path')
+
+ @patch('time.time')
+ def testValidateToken_Future(self, mockTime):
+ """We reject tokens from the future."""
+ test_time = 1526671379
+ mockTime.return_value = test_time
+ token = xsrf.GenerateToken(111, '/path')
+ xsrf.ValidateToken(token, 111, '/path')
+
+ # The clock of the GAE instance doing the checking might be slightly slow.
+ mockTime.return_value = test_time - 1
+ xsrf.ValidateToken(token, 111, '/path')
+
+ # But, if the difference is too much, someone is trying to fake a token.
+ mockTime.return_value = test_time - xsrf.CLOCK_SKEW_SEC - 1
+ self.assertRaises(
+ xsrf.TokenIncorrect,
+ xsrf.ValidateToken, token, 111, '/path')
diff --git a/framework/timestr.py b/framework/timestr.py
new file mode 100644
index 0000000..2b32e8c
--- /dev/null
+++ b/framework/timestr.py
@@ -0,0 +1,188 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Time-to-string and time-from-string routines."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import calendar
+import datetime
+import time
+
+
+class Error(Exception):
+ """Exception used to indicate problems with time routines."""
+ pass
+
+
+HTML_TIME_FMT = '%a, %d %b %Y %H:%M:%S GMT'
+HTML_DATE_WIDGET_FORMAT = '%Y-%m-%d'
+
+MONTH_YEAR_FMT = '%b %Y'
+MONTH_DAY_FMT = '%b %d'
+MONTH_DAY_YEAR_FMT = '%b %d %Y'
+
+# We assume that all server clocks are synchronized within this amount.
+MAX_CLOCK_SKEW_SEC = 30
+
+
+def TimeForHTMLHeader(when=None):
+ """Return the given time (or now) in HTML header format."""
+ if when is None:
+ when = int(time.time())
+ return time.strftime(HTML_TIME_FMT, time.gmtime(when))
+
+
+def TimestampToDateWidgetStr(when):
+ """Format a timestamp int for use by HTML <input type="date">."""
+ return time.strftime(HTML_DATE_WIDGET_FORMAT, time.gmtime(when))
+
+
+def DateWidgetStrToTimestamp(val_str):
+ """Parse the HTML <input type="date"> string into a timestamp int."""
+ return int(calendar.timegm(time.strptime(val_str, HTML_DATE_WIDGET_FORMAT)))
+
+
+def FormatAbsoluteDate(
+ timestamp, clock=datetime.datetime.utcnow,
+ recent_format=MONTH_DAY_FMT, old_format=MONTH_YEAR_FMT):
+ """Format timestamp like 'Sep 5', or 'Yesterday', or 'Today'.
+
+ Args:
+ timestamp: Seconds since the epoch in UTC.
+ clock: callable that returns a datetime.datetime object when called with no
+ arguments, giving the current time to use when computing what to display.
+ recent_format: Format string to pass to strftime to present dates between
+ six months ago and yesterday.
+ old_format: Format string to pass to strftime to present dates older than
+ six months or more than skew_tolerance in the future.
+
+ Returns:
+ If timestamp's date is today, "Today". If timestamp's date is yesterday,
+ "Yesterday". If timestamp is within six months before today, return the
+ time as formatted by recent_format. Otherwise, return the time as formatted
+ by old_format.
+ """
+ ts = datetime.datetime.utcfromtimestamp(timestamp)
+ now = clock()
+ month_delta = 12 * now.year + now.month - (12 * ts.year + ts.month)
+ delta = now - ts
+
+ if ts > now:
+ # If the time is slightly in the future due to clock skew, treat as today.
+ skew_tolerance = datetime.timedelta(seconds=MAX_CLOCK_SKEW_SEC)
+ if -delta <= skew_tolerance:
+ return 'Today'
+ # Otherwise treat it like an old date.
+ else:
+ fmt = old_format
+ elif month_delta > 6 or delta.days >= 365:
+ fmt = old_format
+ elif delta.days == 1:
+ return 'Yesterday'
+ elif delta.days == 0:
+ return 'Today'
+ else:
+ fmt = recent_format
+
+ return time.strftime(fmt, time.gmtime(timestamp)).replace(' 0', ' ')
+
+
+def FormatRelativeDate(timestamp, days_only=False, clock=None):
+ """Return a short string that makes timestamp more meaningful to the user.
+
+ Describe the timestamp relative to the current time, e.g., '4
+ hours ago'. In cases where the timestamp is more than 6 days ago,
+ we return '' so that an alternative display can be used instead.
+
+ Args:
+ timestamp: Seconds since the epoch in UTC.
+ days_only: If True, return 'N days ago' even for more than 6 days.
+ clock: optional function to return an int time, like int(time.time()).
+
+ Returns:
+ String describing relative time.
+ """
+ if clock:
+ now = clock()
+ else:
+ now = int(time.time())
+
+ # TODO(jrobbins): i18n of date strings
+ delta = int(now - timestamp)
+ d_minutes = delta // 60
+ d_hours = d_minutes // 60
+ d_days = d_hours // 24
+ if days_only:
+ if d_days > 1:
+ return '%s days ago' % d_days
+ else:
+ return ''
+
+ if d_days > 6:
+ return ''
+ if d_days > 1:
+ return '%s days ago' % d_days # starts at 2 days
+ if d_hours > 1:
+ return '%s hours ago' % d_hours # starts at 2 hours
+ if d_minutes > 1:
+ return '%s minutes ago' % d_minutes
+ if d_minutes > 0:
+ return '1 minute ago'
+ if delta > -MAX_CLOCK_SKEW_SEC:
+ return 'moments ago'
+ return ''
+
+
+def GetHumanScaleDate(timestamp, now=None):
+  """Formats a timestamp to a coarse-grained and fine-grained time phrase.
+
+ Args:
+ timestamp: Seconds since the epoch in UTC.
+ now: Current time in seconds since the epoch in UTC.
+
+ Returns:
+    A pair (coarse_grain, fine_grain) where coarse_grain is a string
+    such as 'Today', 'Yesterday', etc.; and fine_grain is a string describing
+    relative hours for Today and Yesterday, or an exact date for longer ago.
+ """
+ if now is None:
+ now = int(time.time())
+
+ now_year = datetime.datetime.fromtimestamp(now).year
+ then_year = datetime.datetime.fromtimestamp(timestamp).year
+ delta = int(now - timestamp)
+ delta_minutes = delta // 60
+ delta_hours = delta_minutes // 60
+ delta_days = delta_hours // 24
+
+ if 0 <= delta_hours < 24:
+ if delta_hours > 1:
+ return 'Today', '%s hours ago' % delta_hours
+ if delta_minutes > 1:
+ return 'Today', '%s min ago' % delta_minutes
+ if delta_minutes > 0:
+ return 'Today', '1 min ago'
+ if delta > 0:
+ return 'Today', 'moments ago'
+ if 0 <= delta_hours < 48:
+ return 'Yesterday', '%s hours ago' % delta_hours
+ if 0 <= delta_days < 7:
+ return 'Last 7 days', time.strftime(
+ '%b %d, %Y', (time.localtime(timestamp)))
+ if 0 <= delta_days < 30:
+ return 'Last 30 days', time.strftime(
+ '%b %d, %Y', (time.localtime(timestamp)))
+ if delta > 0:
+ if now_year == then_year:
+ return 'Earlier this year', time.strftime(
+ '%b %d, %Y', (time.localtime(timestamp)))
+ return ('Before this year',
+ time.strftime('%b %d, %Y', (time.localtime(timestamp))))
+ if delta > -MAX_CLOCK_SKEW_SEC:
+ return 'Today', 'moments ago'
+ # Only say something is in the future if it is more than just clock skew.
+ return 'Future', 'Later'
diff --git a/framework/trimvisitedpages.py b/framework/trimvisitedpages.py
new file mode 100644
index 0000000..8d6ec23
--- /dev/null
+++ b/framework/trimvisitedpages.py
@@ -0,0 +1,19 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Classes to handle cron requests to trim users' hotlists/issues visited."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+from framework import jsonfeed
+
+class TrimVisitedPages(jsonfeed.InternalTask):
+
+ """Look for users with more than 10 visited hotlists and deletes extras."""
+
+ def HandleRequest(self, mr):
+ """Delete old RecentHotlist2User rows when there are too many"""
+ self.services.user.TrimUserVisitedHotlists(mr.cnxn)
diff --git a/framework/ts_mon_js.py b/framework/ts_mon_js.py
new file mode 100644
index 0000000..61be1a8
--- /dev/null
+++ b/framework/ts_mon_js.py
@@ -0,0 +1,110 @@
+# Copyright 2018 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""ts_mon JavaScript proxy handler."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+from framework import authdata
+from framework import sql
+from framework import xsrf
+
+from gae_ts_mon.handlers import TSMonJSHandler
+
+from google.appengine.api import users
+
+from infra_libs import ts_mon
+
+
# Metric fields attached to every metric reported by the frontend JS.
STANDARD_FIELDS = [
    ts_mon.StringField('client_id'),
    ts_mon.StringField('host_name'),
    ts_mon.BooleanField('document_visible'),
]


# User action metrics.
ISSUE_CREATE_LATENCY_METRIC = ts_mon.CumulativeDistributionMetric(
    'monorail/frontend/issue_create_latency', (
        'Latency between Issue Entry form submission and page load of '
        'the subsequent issue page.'
    ), field_spec=STANDARD_FIELDS,
    units=ts_mon.MetricsDataUnits.MILLISECONDS)
ISSUE_UPDATE_LATENCY_METRIC = ts_mon.CumulativeDistributionMetric(
    'monorail/frontend/issue_update_latency', (
        'Latency between Issue Update form submission and page load of '
        'the subsequent issue page.'
    ), field_spec=STANDARD_FIELDS,
    units=ts_mon.MetricsDataUnits.MILLISECONDS)
AUTOCOMPLETE_POPULATE_LATENCY_METRIC = ts_mon.CumulativeDistributionMetric(
    'monorail/frontend/autocomplete_populate_latency', (
        'Latency between page load and autocomplete options loading.'
    ), field_spec=STANDARD_FIELDS,
    units=ts_mon.MetricsDataUnits.MILLISECONDS)
CHARTS_SWITCH_DATE_RANGE_METRIC = ts_mon.CounterMetric(
    'monorail/frontend/charts/switch_date_range', (
        'Number of times user clicks frequency button.'
    ), field_spec=STANDARD_FIELDS + [ts_mon.IntegerField('date_range')])

# Page load metrics.
ISSUE_COMMENTS_LOAD_EXTRA_FIELDS = [
    ts_mon.StringField('template_name'),
    ts_mon.BooleanField('full_app_load'),
]
ISSUE_COMMENTS_LOAD_LATENCY_METRIC = ts_mon.CumulativeDistributionMetric(
    'monorail/frontend/issue_comments_load_latency', (
        'Time from navigation or click to issue comments loaded.'
    ), field_spec=STANDARD_FIELDS + ISSUE_COMMENTS_LOAD_EXTRA_FIELDS,
    units=ts_mon.MetricsDataUnits.MILLISECONDS)
DOM_CONTENT_LOADED_EXTRA_FIELDS = [
    ts_mon.StringField('template_name')]
# NOTE(review): unlike every other metric here, this name has no 'monorail/'
# prefix -- confirm whether that is intentional. Do not rename casually:
# changing a metric name breaks existing dashboards and alerts.
DOM_CONTENT_LOADED_METRIC = ts_mon.CumulativeDistributionMetric(
    'frontend/dom_content_loaded', (
        'domContentLoaded performance timing.'
    ), field_spec=STANDARD_FIELDS + DOM_CONTENT_LOADED_EXTRA_FIELDS,
    units=ts_mon.MetricsDataUnits.MILLISECONDS)


ISSUE_LIST_LOAD_EXTRA_FIELDS = [
    ts_mon.StringField('template_name'),
    ts_mon.BooleanField('full_app_load'),
]
ISSUE_LIST_LOAD_LATENCY_METRIC = ts_mon.CumulativeDistributionMetric(
    'monorail/frontend/issue_list_load_latency', (
        'Time from navigation or click to search issues list loaded.'
    ), field_spec=STANDARD_FIELDS + ISSUE_LIST_LOAD_EXTRA_FIELDS,
    units=ts_mon.MetricsDataUnits.MILLISECONDS)
+
+
class MonorailTSMonJSHandler(TSMonJSHandler):
  """Proxy endpoint that accepts ts_mon metrics posted by Monorail's JS."""

  def __init__(self, request=None, response=None):
    super(MonorailTSMonJSHandler, self).__init__(request, response)
    # Make every metric the JS client may report known to the base handler.
    self.register_metrics([
        ISSUE_CREATE_LATENCY_METRIC,
        ISSUE_UPDATE_LATENCY_METRIC,
        AUTOCOMPLETE_POPULATE_LATENCY_METRIC,
        CHARTS_SWITCH_DATE_RANGE_METRIC,
        ISSUE_COMMENTS_LOAD_LATENCY_METRIC,
        DOM_CONTENT_LOADED_METRIC,
        ISSUE_LIST_LOAD_LATENCY_METRIC])

  def xsrf_is_valid(self, body):
    """Return True iff the XSRF token in the request body is valid.

    This method expects the body dictionary to include two fields:
    `token` and `user_id`.
    """
    cnxn = sql.MonorailConnection()
    token = body.get('token')
    # NOTE(review): identity actually comes from the signed-in GAE user;
    # body['user_id'] is never read here despite the docstring -- confirm
    # which one the JS client relies on.
    user = users.get_current_user()
    email = user.email() if user else None

    services = self.app.config.get('services')
    auth = authdata.AuthData.FromEmail(cnxn, email, services, autocreate=False)
    try:
      xsrf.ValidateToken(token, auth.user_id, xsrf.XHR_SERVLET_PATH)
      return True
    except xsrf.TokenIncorrect:
      return False
diff --git a/framework/urls.py b/framework/urls.py
new file mode 100644
index 0000000..d7e5e3a
--- /dev/null
+++ b/framework/urls.py
@@ -0,0 +1,157 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Constants that define the Monorail URL space."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
# URLs of site-wide Monorail pages
HOSTING_HOME = '/hosting_old/'
PROJECT_CREATE = '/hosting/createProject'
USER_SETTINGS = '/hosting/settings'
PROJECT_MOVED = '/hosting/moved'
GROUP_LIST = '/g/'
GROUP_CREATE = '/hosting/createGroup'
GROUP_DELETE = '/hosting/deleteGroup'

# URLs of project pages
SUMMARY = '/'  # Now just a redirect to /issues/list
UPDATES_LIST = '/updates/list'
PEOPLE_LIST = '/people/list'
PEOPLE_DETAIL = '/people/detail'
ADMIN_META = '/admin'
ADMIN_ADVANCED = '/adminAdvanced'

# URLs of user pages, relative to either /u/userid or /u/username
# TODO(jrobbins): Add /u/userid as the canonical URL in metadata.
USER_PROFILE = '/'
USER_PROFILE_POLYMER = '/polymer'
USER_CLEAR_BOUNCING = '/clearBouncing'
BAN_USER = '/ban'
BAN_SPAMMER = '/banSpammer'

# URLs for User Updates pages
USER_UPDATES_PROJECTS = '/updates/projects'
USER_UPDATES_DEVELOPERS = '/updates/developers'
USER_UPDATES_MINE = '/updates'

# URLs of user group pages, relative to /g/groupname.
GROUP_DETAIL = '/'
GROUP_ADMIN = '/groupadmin'

# URLs of issue tracker backend request handlers. Called from the frontends.
BACKEND_SEARCH = '/_backend/search'
BACKEND_NONVIEWABLE = '/_backend/nonviewable'

# URLs of task queue request handlers. Called asynchronously from frontends.
RECOMPUTE_DERIVED_FIELDS_TASK = '/_task/recomputeDerivedFields'
NOTIFY_ISSUE_CHANGE_TASK = '/_task/notifyIssueChange'
NOTIFY_BLOCKING_CHANGE_TASK = '/_task/notifyBlockingChange'
NOTIFY_BULK_CHANGE_TASK = '/_task/notifyBulkEdit'
NOTIFY_APPROVAL_CHANGE_TASK = '/_task/notifyApprovalChange'
NOTIFY_RULES_DELETED_TASK = '/_task/notifyRulesDeleted'
OUTBOUND_EMAIL_TASK = '/_task/outboundEmail'
SPAM_DATA_EXPORT_TASK = '/_task/spamDataExport'
BAN_SPAMMER_TASK = '/_task/banSpammer'
ISSUE_DATE_ACTION_TASK = '/_task/issueDateAction'
COMPONENT_DATA_EXPORT_TASK = '/_task/componentDataExportTask'
SEND_WIPEOUT_USER_LISTS_TASK = '/_task/sendWipeoutUserListsTask'
DELETE_WIPEOUT_USERS_TASK = '/_task/deleteWipeoutUsersTask'
DELETE_USERS_TASK = '/_task/deleteUsersTask'

# URL for publishing issue changes to a pubsub topic.
PUBLISH_PUBSUB_ISSUE_CHANGE_TASK = '/_task/publishPubsubIssueChange'

# URL for manually triggered FLT launch issue conversion job.
FLT_ISSUE_CONVERSION_TASK = '/_task/fltConversionTask'

# URLs of cron job request handlers. Called from GAE via cron.yaml.
REINDEX_QUEUE_CRON = '/_cron/reindexQueue'
RAMCACHE_CONSOLIDATE_CRON = '/_cron/ramCacheConsolidate'
REAP_CRON = '/_cron/reap'
SPAM_DATA_EXPORT_CRON = '/_cron/spamDataExport'
LOAD_API_CLIENT_CONFIGS_CRON = '/_cron/loadApiClientConfigs'
TRIM_VISITED_PAGES_CRON = '/_cron/trimVisitedPages'
DATE_ACTION_CRON = '/_cron/dateAction'
SPAM_TRAINING_CRON = '/_cron/spamTraining'
COMPONENT_DATA_EXPORT_CRON = '/_cron/componentDataExport'
WIPEOUT_SYNC_CRON = '/_cron/wipeoutSync'

# URLs of handlers needed for GAE instance management.
WARMUP = '/_ah/warmup'
START = '/_ah/start'
STOP = '/_ah/stop'

# URLs of User pages
SAVED_QUERIES = '/queries'
DASHBOARD = '/dashboard'
HOTLISTS = '/hotlists'

# URLS of User hotlist pages
HOTLIST_ISSUES = ''
HOTLIST_ISSUES_CSV = '/csv'
HOTLIST_PEOPLE = '/people'
HOTLIST_DETAIL = '/details'
HOTLIST_RERANK_JSON = '/rerank'

# URLs of issue tracker project pages
ISSUE_APPROVAL = '/issues/approval'
ISSUE_LIST = '/issues/list'
ISSUE_LIST_NEW_TEMP = '/issues/list_new'
ISSUE_DETAIL = '/issues/detail'
ISSUE_DETAIL_LEGACY = '/issues/detail_ezt'
ISSUE_DETAIL_FLIPPER_NEXT = '/issues/detail/next'
ISSUE_DETAIL_FLIPPER_PREV = '/issues/detail/previous'
ISSUE_DETAIL_FLIPPER_LIST = '/issues/detail/list'
ISSUE_DETAIL_FLIPPER_INDEX = '/issues/detail/flipper'
ISSUE_WIZARD = '/issues/wizard'
ISSUE_ENTRY = '/issues/entry'
ISSUE_ENTRY_NEW = '/issues/entry_new'
ISSUE_ENTRY_AFTER_LOGIN = '/issues/entryafterlogin'
ISSUE_BULK_EDIT = '/issues/bulkedit'
ISSUE_ADVSEARCH = '/issues/advsearch'
ISSUE_TIPS = '/issues/searchtips'
ISSUE_ATTACHMENT = '/issues/attachment'
ISSUE_ATTACHMENT_TEXT = '/issues/attachmentText'
ISSUE_LIST_CSV = '/issues/csv'
COMPONENT_CREATE = '/components/create'
COMPONENT_DETAIL = '/components/detail'
FIELD_CREATE = '/fields/create'
FIELD_DETAIL = '/fields/detail'
# PEP 8 fix: was `TEMPLATE_CREATE ='...'` (missing space around `=`).
TEMPLATE_CREATE = '/templates/create'
TEMPLATE_DETAIL = '/templates/detail'
WIKI_LIST = '/w/list'  # Wiki urls are just redirects to project.docs_url
WIKI_PAGE = '/wiki/<wiki_page:.*>'
SOURCE_PAGE = '/source/<source_page:.*>'
ADMIN_INTRO = '/adminIntro'
# TODO(jrobbins): move some editing from /admin to /adminIntro.
ADMIN_COMPONENTS = '/adminComponents'
ADMIN_LABELS = '/adminLabels'
ADMIN_RULES = '/adminRules'
ADMIN_TEMPLATES = '/adminTemplates'
ADMIN_STATUSES = '/adminStatuses'
ADMIN_VIEWS = '/adminViews'
ADMIN_EXPORT = '/projectExport'
ADMIN_EXPORT_JSON = '/projectExport/json'
ISSUE_ORIGINAL = '/issues/original'
ISSUE_REINDEX = '/issues/reindex'
ISSUE_EXPORT = '/issues/export'
ISSUE_EXPORT_JSON = '/issues/export/json'
ISSUE_IMPORT = '/issues/import'

# URLs for hotlist features
HOTLIST_CREATE = '/hosting/createHotlist'

# URLs of site-wide pages referenced from the framework directory.
CAPTCHA_QUESTION = '/hosting/captcha'
EXCESSIVE_ACTIVITY = '/hosting/excessiveActivity'
BANNED = '/hosting/noAccess'
CLIENT_MON = '/_/clientmon'
TS_MON_JS = '/_/jstsmon'

CSP_REPORT = '/csp'

SPAM_MODERATION_QUEUE = '/spamqueue'
diff --git a/framework/validate.py b/framework/validate.py
new file mode 100644
index 0000000..ee26396
--- /dev/null
+++ b/framework/validate.py
@@ -0,0 +1,112 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""A set of Python input field validators."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import re
+
# RFC 2821-compliant email address regex
#
# Please see sections "4.1.2 Command Argument Syntax" and
# "4.1.3 Address Literals" of: http://www.faqs.org/rfcs/rfc2821.html
#
# The following implementation is still a subset of RFC 2821. Fully
# double-quoted <user> parts are not supported (since the RFC discourages
# their use anyway), and using the backslash to escape other characters
# that are normally invalid, such as commas, is not supported.
#
# The groups in this regular expression are:
#
# <user>: all of the valid non-quoted portion of the email address before
# the @ sign (not including the @ sign)
#
# <domain>: all of the domain name between the @ sign (but not including it)
# and the dot before the TLD (but not including that final dot)
#
# <tld>: the top-level domain after the last dot (but not including that
# final dot)
#
_RFC_2821_EMAIL_REGEX = r"""(?x)
  (?P<user>
    # Part of the username that comes before any dots that may occur in it.
    # At least one of the listed non-dot characters is required before the
    # first dot.
    [-a-zA-Z0-9!#$%&'*+/=?^_`{|}~]+

    # Remaining part of the username that starts with the dot and
    # which may have other dots, if such a part exists. Only one dot
    # is permitted between each "Atom", and a trailing dot is not permitted.
    (?:[.][-a-zA-Z0-9!#$%&'*+/=?^_`{|}~]+)*
  )

  # Domain name, where subdomains are allowed. Also, dashes are allowed
  # given that they are preceded and followed by at least one character.
  @(?P<domain>
    (?:[0-9a-zA-Z] # at least one non-dash
     (?:[-]* # plus zero or more dashes
      [0-9a-zA-Z]+ # plus at least one non-dash
     )* # zero or more of dashes followed by non-dashes
    ) # one required domain part (may be a sub-domain)

    (?:\. # dot separator before additional sub-domain part
     [0-9a-zA-Z] # at least one non-dash
     (?:[-]* # plus zero or more dashes
      [0-9a-zA-Z]+ # plus at least one non-dash
     )* # zero or more of dashes followed by non-dashes
    )* # at least one sub-domain part and a dot
  )
  \. # dot separator before TLD

  # TLD, the part after 'usernames@domain.' which can consist of 2-9
  # letters.
  (?P<tld>[a-zA-Z]{2,9})
  """

# object used with <re>.search() or <re>.sub() to find email addresses
# within a string (or with <re>.match() to find email addresses at the
# beginning of a string that may be followed by trailing characters,
# since <re>.match() implicitly anchors at the beginning of the string)
RE_EMAIL_SEARCH = re.compile(_RFC_2821_EMAIL_REGEX)

# object used with <re>.match to find strings that contain *only* a single
# email address (by adding the end-of-string anchor $)
RE_EMAIL_ONLY = re.compile('^%s$' % _RFC_2821_EMAIL_REGEX)

# Scheme prefix accepted by the URL validators below: http, https, or ftp.
_SCHEME_PATTERN = r'(?:https?|ftp)://'
_SHORT_HOST_PATTERN = (
    r'(?=[a-zA-Z])[-a-zA-Z0-9]*[a-zA-Z0-9](:[0-9]+)?'
    r'/' # Slash is mandatory for short host names.
    r'[^\s]*'
    )
_DOTTED_HOST_PATTERN = (
    r'[-a-zA-Z0-9.]+\.[a-zA-Z]{2,9}(:[0-9]+)?'
    r'(/[^\s]*)?'
    )
_URL_REGEX = r'%s(%s|%s)' % (
    _SCHEME_PATTERN, _SHORT_HOST_PATTERN, _DOTTED_HOST_PATTERN)

# A more complete URL regular expression based on a combination of the
# existing _URL_REGEX and the pattern found for URI regular expressions
# found in the URL RFC document. It's detailed here:
# http://www.ietf.org/rfc/rfc2396.txt
RE_COMPLEX_URL = re.compile(r'^%s(\?([^# ]*))?(#(.*))?$' % _URL_REGEX)
+
+
def IsValidEmail(s):
  """Return a truthy value iff the string is a well-formed email address."""
  match = RE_EMAIL_ONLY.match(s)
  return match
+
+
def IsValidMailTo(s):
  """Return a truthy value iff the string is a well-formed mailto: link."""
  prefix = 'mailto:'
  return s.startswith(prefix) and RE_EMAIL_ONLY.match(s[len(prefix):])
+
+
def IsValidURL(s):
  """Return a truthy value iff the string is a well-formed web or ftp URL."""
  match = RE_COMPLEX_URL.match(s)
  return match
diff --git a/framework/warmup.py b/framework/warmup.py
new file mode 100644
index 0000000..ef8a53d
--- /dev/null
+++ b/framework/warmup.py
@@ -0,0 +1,51 @@
+# Copyright 2017 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""A class to handle the initial warmup request from AppEngine."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import logging
+
+from framework import jsonfeed
+
+
class Warmup(jsonfeed.InternalTask):
  """Placeholder for warmup work. Used only to enable min_idle_instances."""

  def HandleRequest(self, _mr):
    """Don't do anything that could cause a jam when many instances start.

    Args:
      _mr: commonly used info parsed from the request (unused).

    Returns:
      JSON dict indicating success.
    """
    # Bug fix: the message previously said '/_ah/startup', but this handler
    # is mapped to /_ah/warmup (see WARMUP in framework/urls.py).
    logging.info('/_ah/warmup does nothing in Monorail.')
    logging.info('However it is needed for min_idle_instances in app.yaml.')

    return {
        'success': 1,
    }
+
class Start(jsonfeed.InternalTask):
  """Placeholder for start work. Used only to enable manual_scaling."""

  def HandleRequest(self, _mr):
    """Deliberately do no real work; many instances may start at once."""
    logging.info('/_ah/start does nothing in Monorail.')
    logging.info('However it is needed for manual_scaling in app.yaml.')
    response = {
        'success': 1,
    }
    return response
+
+
class Stop(jsonfeed.InternalTask):
  """Placeholder for stop work. Used only to enable manual_scaling."""

  def HandleRequest(self, _mr):
    """Deliberately do no real work; instances may be cycling en masse."""
    logging.info('/_ah/stop does nothing in Monorail.')
    logging.info('However it is needed for manual_scaling in app.yaml.')
    response = {
        'success': 1,
    }
    return response
diff --git a/framework/xsrf.py b/framework/xsrf.py
new file mode 100644
index 0000000..75581ef
--- /dev/null
+++ b/framework/xsrf.py
@@ -0,0 +1,138 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Utility routines for avoiding cross-site-request-forgery."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import base64
+import hmac
+import logging
+import time
+
+# This is a file in the top-level directory that you must edit before deploying
+import settings
+from framework import framework_constants
+from services import secrets_svc
+
# This is how long tokens are valid.
TOKEN_TIMEOUT_SEC = 2 * framework_constants.SECS_PER_HOUR

# The token refresh servlet accepts old tokens to generate new ones, but
# we still impose a limit on how old they can be.
REFRESH_TOKEN_TIMEOUT_SEC = 10 * framework_constants.SECS_PER_DAY

# When the JS on a page decides whether or not it needs to refresh the
# XSRF token before submitting a form, there could be some clock skew,
# so we subtract a little time to avoid having the JS use an existing
# token that the server might consider expired already.
TOKEN_TIMEOUT_MARGIN_SEC = 5 * framework_constants.SECS_PER_MINUTE

# When checking that the token is not from the future, allow a little
# margin for the possibility that the clock of the GAE instance that
# generated the token could be a little ahead of the one checking.
CLOCK_SKEW_SEC = 5

# Form tokens and issue stars are limited to only work with the specific
# servlet path for the servlet that processes them. There are several
# XHR handlers that mainly read data without making changes, so we just
# use 'xhr' with all of them.
XHR_SERVLET_PATH = 'xhr'


# Separator between the HMAC digest and the plain-text timestamp in a token.
DELIMITER = ':'
+
+
def GenerateToken(user_id, servlet_path, token_time=None):
  """Return a security token specifically for the given user.

  Args:
    user_id: int user ID of the user viewing an HTML form.
    servlet_path: string URI path to limit the use of the token.
    token_time: Time at which the token is generated in seconds since the epoch.

  Returns:
    A url-safe security token. The token is a string with the digest
    the user_id and time, followed by plain-text copy of the time that is
    used in validation.

  Raises:
    ValueError: if the XSRF secret was not configured.
  """
  token_time = token_time or int(time.time())
  # NOTE(review): hmac.new() with no explicit digestmod defaults to MD5 on
  # Python 2 and raises on Python 3. Changing the digest would invalidate
  # all outstanding tokens, so plan a rollout before upgrading.
  digester = hmac.new(secrets_svc.GetXSRFKey())
  # The MAC covers user_id, servlet_path, and the timestamp, separated by
  # DELIMITER, so a token is bound to one user, one path, and one time.
  digester.update(str(user_id))
  digester.update(DELIMITER)
  digester.update(servlet_path)
  digester.update(DELIMITER)
  digester.update(str(token_time))
  digest = digester.digest()

  # The raw digest is binary; b64encode it so it is safe to embed in URLs.
  # The timestamp rides along in plain text for use during validation.
  token = base64.urlsafe_b64encode('%s%s%d' % (digest, DELIMITER, token_time))
  return token
+
+
def ValidateToken(
    token, user_id, servlet_path, timeout=TOKEN_TIMEOUT_SEC):
  """Check that the given token is valid for the given scope.

  Returns None if the token is valid; otherwise raises.

  Args:
    token: String token that was presented by the user.
    user_id: int user ID.
    servlet_path: string URI path to limit the use of the token.
    timeout: int max age of a valid token, in seconds.

  Raises:
    TokenIncorrect: if the token is missing, malformed, wrong, or expired.
  """
  if not token:
    raise TokenIncorrect('missing token')

  try:
    decoded = base64.urlsafe_b64decode(str(token))
    token_time = int(decoded.split(DELIMITER)[-1])
  except (TypeError, ValueError):
    raise TokenIncorrect('could not decode token')
  now = int(time.time())

  # The given token should match the one we would generate for the same
  # user, path, and embedded timestamp.
  expected_token = GenerateToken(user_id, servlet_path, token_time=token_time)
  if len(token) != len(expected_token):
    raise TokenIncorrect('presented token is wrong size')

  # Perform constant time comparison to avoid timing attacks.
  different = 0
  for x, y in zip(token, expected_token):
    different |= ord(x) ^ ord(y)
  if different:
    # Security fix: never echo expected_token in the message. If the text
    # of this exception ever reaches the client, it would hand the caller
    # a valid token for the requested user_id/servlet_path.
    raise TokenIncorrect('presented token does not match expected token')

  # We reject tokens from the future.
  if token_time > now + CLOCK_SKEW_SEC:
    raise TokenIncorrect('token is from future')

  # We check expiration last so that we only raise the expiration error
  # if the token would have otherwise been valid.
  if now - token_time > timeout:
    raise TokenIncorrect('token has expired')
+
+
def TokenExpiresSec():
  """Return timestamp when current tokens will expire, minus a safety margin."""
  return int(time.time()) + TOKEN_TIMEOUT_SEC - TOKEN_TIMEOUT_MARGIN_SEC
+
+
class Error(Exception):
  """Base class for errors from this module."""


# Caught separately in servlet.py
class TokenIncorrect(Error):
  """The POST body has an incorrect URL Command Attack token."""