Project import generated by Copybara.
GitOrigin-RevId: d9e9e3fb4e31372ec1fb43b178994ca78fa8fe70
diff --git a/services/__init__.py b/services/__init__.py
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/services/__init__.py
@@ -0,0 +1 @@
+
diff --git a/services/api_pb2_v1_helpers.py b/services/api_pb2_v1_helpers.py
new file mode 100644
index 0000000..dcdea66
--- /dev/null
+++ b/services/api_pb2_v1_helpers.py
@@ -0,0 +1,628 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Convert Monorail PB objects to API PB objects"""
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+
+import datetime
+import logging
+import time
+
+from six import string_types
+
+from businesslogic import work_env
+from framework import exceptions
+from framework import framework_constants
+from framework import framework_helpers
+from framework import framework_views
+from framework import permissions
+from framework import timestr
+from proto import api_pb2_v1
+from proto import project_pb2
+from proto import tracker_pb2
+from services import project_svc
+from tracker import field_helpers
+from tracker import tracker_bizobj
+from tracker import tracker_helpers
+
+
+def convert_project(project, config, role, templates):
+  """Build the API ProjectWrapper message for a Monorail project.
+
+  Args:
+    project: Monorail Project PB being exported.
+    config: the project's ProjectIssueConfig PB.
+    role: value copied into the wrapper's `role` field (presumably the
+        requester's role in the project; see caller).
+    templates: TemplateDef PBs exposed as issue-entry prompts.
+
+  Returns:
+    An api_pb2_v1.ProjectWrapper describing the project and its issue config.
+  """
+  name = project.project_name
+  issues_config = convert_project_config(config, templates)
+  return api_pb2_v1.ProjectWrapper(
+      kind='monorail#project',
+      name=name,
+      externalId=name,
+      htmlLink='/p/%s/' % name,
+      summary=project.summary,
+      description=project.description,
+      role=role,
+      issuesConfig=issues_config)
+
+
+def convert_project_config(config, templates):
+  """Build the API ProjectIssueConfig message from a Monorail config.
+
+  Args:
+    config: Monorail ProjectIssueConfig PB.
+    templates: TemplateDef PBs, each converted into an API Prompt.
+
+  Returns:
+    An api_pb2_v1.ProjectIssueConfig.
+  """
+  # The default column and sort specs are split on whitespace into lists.
+  columns = config.default_col_spec.split()
+  sorting = config.default_sort_spec.split()
+  statuses = [convert_status(sd) for sd in config.well_known_statuses]
+  labels = [convert_label(ld) for ld in config.well_known_labels]
+  prompts = [convert_template(td) for td in templates]
+  return api_pb2_v1.ProjectIssueConfig(
+      kind='monorail#projectIssueConfig',
+      restrictToKnown=config.restrict_to_known,
+      defaultColumns=columns,
+      defaultSorting=sorting,
+      statuses=statuses,
+      labels=labels,
+      prompts=prompts,
+      defaultPromptForMembers=config.default_template_for_developers,
+      defaultPromptForNonMembers=config.default_template_for_users)
+
+
+def convert_status(status):
+  """Map a Monorail StatusDef PB onto an API Status message.
+
+  Args:
+    status: StatusDef PB; its status_docstring becomes the API description.
+
+  Returns:
+    An api_pb2_v1.Status.
+  """
+  api_status = api_pb2_v1.Status(
+      status=status.status,
+      meansOpen=status.means_open,
+      description=status.status_docstring)
+  return api_status
+
+
+def convert_label(label):
+  """Map a Monorail LabelDef PB onto an API Label message.
+
+  Args:
+    label: LabelDef PB; its label_docstring becomes the API description.
+
+  Returns:
+    An api_pb2_v1.Label.
+  """
+  api_label = api_pb2_v1.Label(
+      label=label.label, description=label.label_docstring)
+  return api_label
+
+
+def convert_template(template):
+  """Convert a Monorail TemplateDef PB to an API Prompt PB.
+
+  The template's summary becomes the prompt title and its content becomes
+  the prompt description; the remaining fields are copied one-to-one.
+  """
+  prompt_fields = {
+      'name': template.name,
+      'title': template.summary,
+      'description': template.content,
+      'titleMustBeEdited': template.summary_must_be_edited,
+      'status': template.status,
+      'labels': template.labels,
+      'membersOnly': template.members_only,
+      'defaultToMember': template.owner_defaults_to_member,
+      'componentRequired': template.component_required,
+  }
+  return api_pb2_v1.Prompt(**prompt_fields)
+
+
+def convert_person(user_id, cnxn, services, trap_exception=False):
+  """Convert a user id to an API AtomPerson PB.
+
+  Args:
+    user_id: Monorail user id. May be None/0 for optional user values such
+        as issue.owner_id, in which case None is returned.
+    cnxn: connection to the SQL database.
+    services: Services object used to look up the user.
+    trap_exception: if True, a missing user is logged and None is returned
+        instead of propagating NoSuchUserException.
+
+  Returns:
+    An api_pb2_v1.AtomPerson, or None if user_id is falsy, or if the user
+    does not exist and trap_exception is True.
+
+  Raises:
+    exceptions.NoSuchUserException: if the user does not exist and
+        trap_exception is False.
+  """
+  if not user_id:
+    return None
+  if user_id == framework_constants.DELETED_USER_ID:
+    # The deleted-user sentinel is reported with a placeholder name instead
+    # of being looked up.
+    return api_pb2_v1.AtomPerson(
+        kind='monorail#issuePerson',
+        name=framework_constants.DELETED_USER_NAME)
+  try:
+    user = services.user.GetUser(cnxn, user_id)
+  except exceptions.NoSuchUserException as ex:
+    if not trap_exception:
+      # A bare raise keeps the original traceback; `raise ex` drops it on
+      # Python 2.
+      raise
+    logging.warning(str(ex))
+    return None
+
+  days_ago = None
+  if user.last_visit_timestamp:
+    secs_ago = int(time.time()) - user.last_visit_timestamp
+    days_ago = secs_ago // framework_constants.SECS_PER_DAY
+  return api_pb2_v1.AtomPerson(
+      kind='monorail#issuePerson',
+      name=user.email,
+      htmlLink='https://%s/u/%d' % (framework_helpers.GetHostPort(), user_id),
+      last_visit_days_ago=days_ago,
+      email_bouncing=bool(user.email_bounce_timestamp),
+      vacation_message=user.vacation_message)
+
+
+def convert_issue_ids(issue_ids, mar, services):
+  """Convert global issue ids to API IssueRef PBs.
+
+  Ids that services.issue.GetIssues does not return are dropped, so the
+  result may be shorter than issue_ids.
+
+  Args:
+    issue_ids: global issue ids.
+    mar: the API request, providing cnxn.
+    services: Services object.
+
+  Returns:
+    A list of api_pb2_v1.IssueRef.
+  """
+  found_issues = services.issue.GetIssues(mar.cnxn, issue_ids)
+  return [
+      api_pb2_v1.IssueRef(
+          issueId=found.local_id,
+          projectId=found.project_name,
+          kind='monorail#issueRef')
+      for found in found_issues]
+
+
+def convert_issueref_pbs(issueref_pbs, mar, services):
+  """Convert API IssueRef PBs to global issue ids.
+
+  Args:
+    issueref_pbs: API IssueRef PBs. A ref without projectId is resolved in
+        the request's project (mar.project_id).
+    mar: the API request, providing cnxn and project_id.
+    services: Services object.
+
+  Returns:
+    A list of global issue ids for the refs that resolve to existing issues
+    (missing issues are logged and skipped), or None when issueref_pbs is
+    empty or None.
+  """
+  if not issueref_pbs:
+    return None
+
+  global_ids = []
+  for ref in issueref_pbs:
+    target_project_id = mar.project_id
+    if ref.projectId:
+      named_project = services.project.GetProjectByName(
+          mar.cnxn, ref.projectId)
+      # NOTE(review): an unknown projectId silently falls back to
+      # mar.project_id.
+      if named_project:
+        target_project_id = named_project.project_id
+    try:
+      found = services.issue.GetIssueByLocalID(
+          mar.cnxn, target_project_id, ref.issueId)
+    except exceptions.NoSuchIssueException:
+      logging.warning(
+          'Issue (%s:%d) does not exist.' % (ref.projectId, ref.issueId))
+    else:
+      global_ids.append(found.issue_id)
+  return global_ids
+
+
+def convert_approvals(cnxn, approval_values, services, config, phases):
+  """Convert an issue's Monorail ApprovalValue PBs to API Approval PBs.
+
+  Args:
+    cnxn: connection to the SQL database.
+    approval_values: the issue's ApprovalValue PBs.
+    services: Services object used to look up approvers and setters.
+    config: ProjectIssueConfig PB holding the approval field definitions.
+    phases: the issue's Phase PBs, used to resolve phase names.
+
+  Returns:
+    A list of api_pb2_v1.Approval. Values whose approval field is missing
+    from config, or is not an APPROVAL_TYPE field, are logged and skipped.
+  """
+  fds_by_id = {fd.field_id: fd for fd in config.field_defs}
+  phases_by_id = {phase.phase_id: phase for phase in phases}
+  approvals = []
+  for av in approval_values:
+    approval_fd = fds_by_id.get(av.approval_id)
+    if approval_fd is None:
+      logging.warning(
+          'Approval (%d) does not exist' % av.approval_id)
+      continue
+    # Compare enum values by equality (as elsewhere in this module) rather
+    # than by object identity.
+    if approval_fd.field_type != tracker_pb2.FieldTypes.APPROVAL_TYPE:
+      logging.warning(
+          'field %s has unexpected field_type: %s' % (
+              approval_fd.field_name, approval_fd.field_type.name))
+      continue
+
+    approval = api_pb2_v1.Approval()
+    approval.approvalName = approval_fd.field_name
+    approvers = [convert_person(approver_id, cnxn, services)
+                 for approver_id in av.approver_ids]
+    # convert_person returns None for falsy ids; drop those entries.
+    approval.approvers = [approver for approver in approvers if approver]
+
+    approval.status = api_pb2_v1.ApprovalStatus(av.status.number)
+    if av.setter_id:
+      approval.setter = convert_person(av.setter_id, cnxn, services)
+    if av.set_on:
+      approval.setOn = datetime.datetime.fromtimestamp(av.set_on)
+    if av.phase_id:
+      try:
+        approval.phaseName = phases_by_id[av.phase_id].name
+      except KeyError:
+        logging.warning('phase %d not found in given phases list' % av.phase_id)
+    approvals.append(approval)
+  return approvals
+
+
+def convert_phases(phases):
+  """Convert an issue's Monorail Phase PBs to API Phase PBs.
+
+  Phases without a name are logged and skipped.
+
+  Args:
+    phases: Monorail Phase PBs.
+
+  Returns:
+    A list of api_pb2_v1.Phase for the named phases, in input order.
+  """
+  converted_phases = []
+  for position, phase in enumerate(phases):
+    if phase.name:
+      converted_phases.append(
+          api_pb2_v1.Phase(phaseName=phase.name, rank=phase.rank))
+      continue
+    try:
+      logging.warning(
+          'Phase %d has no name, skipping conversion.' % phase.phase_id)
+    except TypeError:
+      # phase_id could not be formatted with %d (e.g. it is None).
+      logging.warning(
+          'Phase #%d (%s) has no name or id, skipping conversion.' % (
+              position, phase))
+  return converted_phases
+
+
+def convert_issue(cls, issue, mar, services):
+  """Convert a Monorail Issue PB to an API issue response message.
+
+  Args:
+    cls: API message class to instantiate (e.g. IssuesGetInsertResponse).
+    issue: Monorail Issue PB.
+    mar: the API request; supplies cnxn, auth and perms.
+    services: Services object.
+
+  Returns:
+    An instance of cls populated from the issue.
+  """
+  config = services.config.GetProjectConfig(mar.cnxn, issue.project_id)
+  granted_perms = tracker_bizobj.GetGrantedPerms(
+      issue, mar.auth.effective_ids, config)
+  issue_project = services.project.GetProject(mar.cnxn, issue.project_id)
+  component_list = []
+  for cd in config.component_defs:
+    cid = cd.component_id
+    if cid in issue.component_ids:
+      component_list.append(cd.path)
+  cc_list = [convert_person(p, mar.cnxn, services) for p in issue.cc_ids]
+  cc_list = [p for p in cc_list if p is not None]
+  field_values_list = []
+  fds_by_id = {fd.field_id: fd for fd in config.field_defs}
+  phases_by_id = {phase.phase_id: phase for phase in issue.phases}
+  for fv in issue.field_values:
+    fd = fds_by_id.get(fv.field_id)
+    if not fd:
+      logging.warning('Custom field %d of project %s does not exist',
+                      fv.field_id, issue_project.project_name)
+      continue
+    if fv.user_id:
+      val = _get_user_email(services.user, mar.cnxn, fv.user_id)
+    else:
+      val = tracker_bizobj.GetFieldValue(fv, {})
+    # The API FieldValue carries strings only.
+    if not isinstance(val, string_types):
+      val = str(val)
+    new_fv = api_pb2_v1.FieldValue(
+        fieldName=fd.field_name,
+        fieldValue=val,
+        derived=fv.derived)
+    if fd.approval_id:  # Attach parent approval name
+      approval_fd = fds_by_id.get(fd.approval_id)
+      if not approval_fd:
+        logging.warning('Parent approval field %d of field %s does not exist',
+                        fd.approval_id, fd.field_name)
+      else:
+        new_fv.approvalName = approval_fd.field_name
+    elif fv.phase_id:  # Attach phase name
+      phase = phases_by_id.get(fv.phase_id)
+      if not phase:
+        logging.warning('Phase %d for field %s does not exist',
+                        fv.phase_id, fd.field_name)
+      else:
+        new_fv.phaseName = phase.name
+    field_values_list.append(new_fv)
+  approval_values_list = convert_approvals(
+      mar.cnxn, issue.approval_values, services, config, issue.phases)
+  phases_list = convert_phases(issue.phases)
+  with work_env.WorkEnv(mar, services) as we:
+    starred = we.IsIssueStarred(issue)
+  resp = cls(
+      kind='monorail#issue',
+      id=issue.local_id,
+      title=issue.summary,
+      summary=issue.summary,
+      projectId=issue_project.project_name,
+      stars=issue.star_count,
+      starred=starred,
+      status=issue.status,
+      state=(api_pb2_v1.IssueState.open if
+             tracker_helpers.MeansOpenInProject(
+                 tracker_bizobj.GetStatus(issue), config)
+             else api_pb2_v1.IssueState.closed),
+      labels=issue.labels,
+      components=component_list,
+      author=convert_person(issue.reporter_id, mar.cnxn, services),
+      owner=convert_person(issue.owner_id, mar.cnxn, services),
+      cc=cc_list,
+      updated=datetime.datetime.fromtimestamp(issue.modified_timestamp),
+      published=datetime.datetime.fromtimestamp(issue.opened_timestamp),
+      blockedOn=convert_issue_ids(issue.blocked_on_iids, mar, services),
+      blocking=convert_issue_ids(issue.blocking_iids, mar, services),
+      canComment=permissions.CanCommentIssue(
+          mar.auth.effective_ids, mar.perms, issue_project, issue,
+          granted_perms=granted_perms),
+      canEdit=permissions.CanEditIssue(
+          mar.auth.effective_ids, mar.perms, issue_project, issue,
+          granted_perms=granted_perms),
+      fieldValues=field_values_list,
+      approvalValues=approval_values_list,
+      phases=phases_list
+      )
+  if issue.closed_timestamp > 0:
+    resp.closed = datetime.datetime.fromtimestamp(issue.closed_timestamp)
+  if issue.merged_into:
+    # convert_issue_ids drops ids it cannot find, so the target of the merge
+    # may be missing (e.g. deleted); indexing [0] blindly raised IndexError.
+    merged_refs = convert_issue_ids([issue.merged_into], mar, services)
+    if merged_refs:
+      resp.mergedInto = merged_refs[0]
+    else:
+      logging.warning('Merged-into issue %d of issue %d was not found',
+                      issue.merged_into, issue.local_id)
+  if issue.owner_modified_timestamp:
+    resp.owner_modified = datetime.datetime.fromtimestamp(
+        issue.owner_modified_timestamp)
+  if issue.status_modified_timestamp:
+    resp.status_modified = datetime.datetime.fromtimestamp(
+        issue.status_modified_timestamp)
+  if issue.component_modified_timestamp:
+    resp.component_modified = datetime.datetime.fromtimestamp(
+        issue.component_modified_timestamp)
+  return resp
+
+
+def convert_comment(issue, comment, mar, services, granted_perms):
+  """Convert a Monorail IssueComment PB to an API IssueCommentWrapper.
+
+  Args:
+    issue: the Issue PB the comment belongs to.
+    comment: the IssueComment PB to convert.
+    mar: the API request; supplies cnxn, perms, project and auth.
+    services: Services object.
+    granted_perms: extra permissions granted on this issue, passed to
+        permissions.UpdateIssuePermissions.
+
+  Returns:
+    An api_pb2_v1.IssueCommentWrapper. Author and deletedBy are converted
+    with trap_exception=True, so unknown users become None there.
+  """
+  effective_perms = permissions.UpdateIssuePermissions(
+      mar.perms, mar.project, issue, mar.auth.effective_ids,
+      granted_perms=granted_perms)
+  commenter = services.user.GetUser(mar.cnxn, comment.user_id)
+  deletable = permissions.CanDeleteComment(
+      comment, commenter, mar.auth.user_id, effective_perms)
+
+  attachments = [convert_attachment(att) for att in comment.attachments]
+  author = convert_person(
+      comment.user_id, mar.cnxn, services, trap_exception=True)
+  deleted_by = convert_person(
+      comment.deleted_by, mar.cnxn, services, trap_exception=True)
+  published = datetime.datetime.fromtimestamp(comment.timestamp)
+  updates = convert_amendments(issue, comment.amendments, mar, services)
+  return api_pb2_v1.IssueCommentWrapper(
+      attachments=attachments,
+      author=author,
+      canDelete=deletable,
+      content=comment.content,
+      deletedBy=deleted_by,
+      id=comment.sequence,
+      published=published,
+      updates=updates,
+      kind='monorail#issueComment',
+      is_description=comment.is_description)
+
+def convert_approval_comment(issue, comment, mar, services, granted_perms):
+  """Convert an approval's IssueComment PB to an API ApprovalCommentWrapper.
+
+  Args:
+    issue: the Issue PB the comment belongs to.
+    comment: the IssueComment PB to convert.
+    mar: the API request; supplies cnxn, perms, project and auth.
+    services: Services object.
+    granted_perms: extra permissions granted on this issue, passed to
+        permissions.UpdateIssuePermissions.
+
+  Returns:
+    An api_pb2_v1.ApprovalCommentWrapper whose approvalUpdates come from
+    convert_approval_amendments.
+  """
+  effective_perms = permissions.UpdateIssuePermissions(
+      mar.perms, mar.project, issue, mar.auth.effective_ids,
+      granted_perms=granted_perms)
+  commenter = services.user.GetUser(mar.cnxn, comment.user_id)
+  deletable = permissions.CanDeleteComment(
+      comment, commenter, mar.auth.user_id, effective_perms)
+
+  attachments = [convert_attachment(att) for att in comment.attachments]
+  author = convert_person(
+      comment.user_id, mar.cnxn, services, trap_exception=True)
+  deleted_by = convert_person(
+      comment.deleted_by, mar.cnxn, services, trap_exception=True)
+  published = datetime.datetime.fromtimestamp(comment.timestamp)
+  approval_updates = convert_approval_amendments(
+      comment.amendments, mar, services)
+  return api_pb2_v1.ApprovalCommentWrapper(
+      attachments=attachments,
+      author=author,
+      canDelete=deletable,
+      content=comment.content,
+      deletedBy=deleted_by,
+      id=comment.sequence,
+      published=published,
+      approvalUpdates=approval_updates,
+      kind='monorail#approvalComment',
+      is_description=comment.is_description)
+
+
+def convert_attachment(attachment):
+  """Copy a Monorail Attachment PB's metadata into an API Attachment PB.
+
+  Args:
+    attachment: Monorail Attachment PB.
+
+  Returns:
+    An api_pb2_v1.Attachment; isDeleted mirrors attachment.deleted.
+  """
+  api_attachment = api_pb2_v1.Attachment()
+  api_attachment.attachmentId = attachment.attachment_id
+  api_attachment.fileName = attachment.filename
+  api_attachment.fileSize = attachment.filesize
+  api_attachment.mimetype = attachment.mimetype
+  api_attachment.isDeleted = attachment.deleted
+  return api_attachment
+
+
+def convert_amendments(issue, amendments, mar, services):
+  """Fold a list of Monorail Amendment PBs into one API Update PB.
+
+  Args:
+    issue: the Issue PB being amended; its project_name qualifies bare ids
+        in blockedOn/blocking amendments.
+    amendments: Monorail Amendment PBs from one comment.
+    mar: the API request.
+    services: Services object.
+
+  Returns:
+    An api_pb2_v1.Update. Scalar fields take the value of the last matching
+    amendment; CC changes and custom field values accumulate. Removed CC
+    users are listed with a leading '-'.
+  """
+  involved_ids = tracker_bizobj.UsersInvolvedInAmendments(amendments)
+  users_by_id = framework_views.MakeAllUserViews(
+      mar.cnxn, services.user, involved_ids)
+  framework_views.RevealAllEmailsToMembers(
+      mar.cnxn, services, mar.auth, users_by_id, mar.project)
+
+  field_ids = tracker_pb2.FieldID
+  update = api_pb2_v1.Update(kind='monorail#issueCommentUpdate')
+  for amendment in amendments:
+    field = amendment.field
+    if field == field_ids.SUMMARY:
+      update.summary = amendment.newvalue
+    elif field == field_ids.STATUS:
+      update.status = amendment.newvalue
+    elif field == field_ids.OWNER:
+      if amendment.added_user_ids:
+        update.owner = _get_user_email(
+            services.user, mar.cnxn, amendment.added_user_ids[0])
+      else:
+        update.owner = framework_constants.NO_USER_NAME
+    elif field == field_ids.LABELS:
+      update.labels = amendment.newvalue.split()
+    elif field == field_ids.CC:
+      for added_id in amendment.added_user_ids:
+        update.cc.append(_get_user_email(services.user, mar.cnxn, added_id))
+      for removed_id in amendment.removed_user_ids:
+        update.cc.append(
+            '-%s' % _get_user_email(services.user, mar.cnxn, removed_id))
+    elif field == field_ids.BLOCKEDON:
+      update.blockedOn = _append_project(
+          amendment.newvalue, issue.project_name)
+    elif field == field_ids.BLOCKING:
+      update.blocking = _append_project(
+          amendment.newvalue, issue.project_name)
+    elif field == field_ids.MERGEDINTO:
+      update.mergedInto = amendment.newvalue
+    elif field == field_ids.COMPONENTS:
+      update.components = amendment.newvalue.split()
+    elif field == field_ids.CUSTOM:
+      custom_fv = api_pb2_v1.FieldValue()
+      custom_fv.fieldName = amendment.custom_field_name
+      custom_fv.fieldValue = tracker_bizobj.AmendmentString(
+          amendment, users_by_id)
+      update.fieldValues.append(custom_fv)
+
+  return update
+
+
+def convert_approval_amendments(amendments, mar, services):
+  """Fold a list of Monorail Amendment PBs into one API ApprovalUpdate PB.
+
+  Only CUSTOM amendments are considered. The pseudo-fields 'Status' and
+  'Approvers' map to the update's status and approvers (removed approvers
+  get a leading '-'); any other custom field becomes a FieldValue.
+
+  Args:
+    amendments: Monorail Amendment PBs from one approval comment.
+    mar: the API request.
+    services: Services object.
+
+  Returns:
+    An api_pb2_v1.ApprovalUpdate.
+  """
+  involved_ids = tracker_bizobj.UsersInvolvedInAmendments(amendments)
+  users_by_id = framework_views.MakeAllUserViews(
+      mar.cnxn, services.user, involved_ids)
+  framework_views.RevealAllEmailsToMembers(
+      mar.cnxn, services, mar.auth, users_by_id, mar.project)
+
+  update = api_pb2_v1.ApprovalUpdate(kind='monorail#approvalCommentUpdate')
+  custom_amendments = [
+      amd for amd in amendments
+      if amd.field == tracker_pb2.FieldID.CUSTOM]
+  for amendment in custom_amendments:
+    name = amendment.custom_field_name
+    if name == 'Status':
+      # Translate the Monorail enum to the API enum via its number.
+      monorail_status = tracker_pb2.ApprovalStatus(amendment.newvalue.upper())
+      update.status = api_pb2_v1.ApprovalStatus(monorail_status.number).name
+    elif name == 'Approvers':
+      for added_id in amendment.added_user_ids:
+        update.approvers.append(
+            _get_user_email(services.user, mar.cnxn, added_id))
+      for removed_id in amendment.removed_user_ids:
+        update.approvers.append(
+            '-%s' % _get_user_email(services.user, mar.cnxn, removed_id))
+    else:
+      custom_fv = api_pb2_v1.FieldValue()
+      custom_fv.fieldName = name
+      custom_fv.fieldValue = tracker_bizobj.AmendmentString(
+          amendment, users_by_id)
+      # TODO(jojwang): monorail:4229, add approvalName field to FieldValue
+      update.fieldValues.append(custom_fv)
+
+  return update
+
+
+def _get_user_email(user_service, cnxn, user_id):
+  """Return the email for user_id, or a placeholder name.
+
+  Returns DELETED_USER_NAME for the deleted-user sentinel, NO_USER_NAME for a
+  falsy user_id (e.g. an unset issue owner), and USER_NOT_FOUND_NAME when the
+  lookup raises NoSuchUserException.
+  """
+  if user_id == framework_constants.DELETED_USER_ID:
+    return framework_constants.DELETED_USER_NAME
+  if not user_id:
+    return framework_constants.NO_USER_NAME
+  try:
+    return user_service.LookupUserEmail(cnxn, user_id)
+  except exceptions.NoSuchUserException:
+    return framework_constants.USER_NOT_FOUND_NAME
+
+
+def _append_project(issue_ids, project_name):
+  """Qualify bare issue ids with project_name.
+
+  Args:
+    issue_ids: whitespace-separated issue refs, e.g. '12 -13 other:14'.
+    project_name: project used for refs that have no ':' in them.
+
+  Returns:
+    A list of refs in '<project>:<id>' form. A leading '-' (removal) stays
+    in front of the project name; refs containing ':' pass through as-is.
+  """
+  def _qualify(ref):
+    if ':' in ref:
+      return ref
+    if ref.startswith('-'):
+      return '-%s:%s' % (project_name, ref[1:])
+    return '%s:%s' % (project_name, ref)
+
+  return [_qualify(ref) for ref in issue_ids.split()]
+
+
+def split_remove_add(item_list):
+  """Partition items into the ones to add and the ones to remove.
+
+  Items starting with '-' are removals and are returned with that '-'
+  stripped; every other item is an addition. Input order is preserved.
+
+  Returns:
+    A (list_to_add, list_to_remove) tuple.
+  """
+  to_add = []
+  to_remove = []
+  for item in item_list:
+    is_removal = item.startswith('-')
+    if is_removal:
+      to_remove.append(item[1:])
+      continue
+    to_add.append(item)
+  return to_add, to_remove
+
+
+# TODO(sheyang): batch the SQL queries to fetch projects/issues.
+def issue_global_ids(project_local_id_pairs, project_id, mar, services):
+  """Resolve issue refs to global issue ids.
+
+  Args:
+    project_local_id_pairs: strings of the form '<project_name>:<local_id>'
+        or a bare '<local_id>'.
+    project_id: project id used for bare local ids.
+    mar: the API request, providing cnxn.
+    services: Services object.
+
+  Returns:
+    A list of global issue ids, in input order.
+
+  Raises:
+    exceptions.NoSuchProjectException: if a named project does not exist.
+    ValueError: if a local id is not an integer.
+  """
+  global_ids = []
+  for pair in project_local_id_pairs:
+    if ':' not in pair:
+      issue_project_id = project_id
+      local_id = int(pair)
+    else:
+      parts = pair.split(':')
+      project_name = parts[0]
+      local_id = int(parts[1])
+      project = services.project.GetProjectByName(mar.cnxn, project_name)
+      if not project:
+        raise exceptions.NoSuchProjectException(
+            'Project %s does not exist' % project_name)
+      issue_project_id = project.project_id
+    global_ids.append(
+        services.issue.LookupIssueID(mar.cnxn, issue_project_id, local_id))
+  return global_ids
+
+
+def convert_group_settings(group_name, setting):
+  """Wrap a group's UserGroupSettings PB in an API UserGroupSettingsWrapper.
+
+  Args:
+    group_name: value for the wrapper's groupName field.
+    setting: UserGroupSettings PB whose visibility, external group type and
+        last sync time are copied.
+
+  Returns:
+    An api_pb2_v1.UserGroupSettingsWrapper.
+  """
+  wrapper = api_pb2_v1.UserGroupSettingsWrapper(groupName=group_name)
+  wrapper.who_can_view_members = setting.who_can_view_members
+  wrapper.ext_group_type = setting.ext_group_type
+  wrapper.last_sync_time = setting.last_sync_time
+  return wrapper
+
+
+def convert_component_def(cd, mar, services):
+  """Convert a Monorail ComponentDef PB to an API Component PB.
+
+  All referenced users (admins, CCs, creator, modifier) are resolved to
+  emails with a single LookupUserEmails call.
+
+  Args:
+    cd: ComponentDef PB.
+    mar: the API request, providing cnxn.
+    services: Services object.
+
+  Returns:
+    An api_pb2_v1.Component with sorted admin/cc email lists. The
+    created/creator and modified/modifier fields are set only when
+    cd.created / cd.modified are truthy.
+  """
+  project_names = services.project.LookupProjectNames(
+      mar.cnxn, [cd.project_id])
+  project_name = project_names[cd.project_id]
+  involved_ids = set(
+      cd.admin_ids + cd.cc_ids + [cd.creator_id] + [cd.modifier_id])
+  emails = services.user.LookupUserEmails(mar.cnxn, list(involved_ids))
+  component = api_pb2_v1.Component(
+      componentId=cd.component_id,
+      projectName=project_name,
+      componentPath=cd.path,
+      description=cd.docstring,
+      admin=sorted(emails[uid] for uid in cd.admin_ids),
+      cc=sorted(emails[uid] for uid in cd.cc_ids),
+      deprecated=cd.deprecated)
+  if cd.created:
+    component.created = datetime.datetime.fromtimestamp(cd.created)
+    component.creator = emails[cd.creator_id]
+  if cd.modified:
+    component.modified = datetime.datetime.fromtimestamp(cd.modified)
+    component.modifier = emails[cd.modifier_id]
+  return component
+
+
+def convert_component_ids(config, component_names):
+  """Convert component paths to component ids, ignoring case.
+
+  Args:
+    config: ProjectIssueConfig PB whose component_defs are searched.
+    component_names: iterable of component paths.
+
+  Returns:
+    Ids of the matching component defs, in config.component_defs order.
+    Names that match no component are ignored.
+  """
+  # A set gives O(1) membership tests instead of scanning a list per def.
+  wanted_paths = {name.lower() for name in component_names}
+  return [cd.component_id for cd in config.component_defs
+          if cd.path.lower() in wanted_paths]
+
+
+def convert_field_values(field_values, mar, services):
+  """Convert API FieldValue PBs into Monorail field value changes.
+
+  Enum fields are stored as labels ('<field name>-<value>'), so they are
+  returned as label additions/removals rather than FieldValue PBs. Values
+  naming an unknown field are logged and skipped.
+
+  Args:
+    field_values: API FieldValue PBs with fieldName, fieldValue and an
+        operator of add, remove or clear.
+    mar: the API request; mar.config supplies the project's field defs.
+    services: Services object used to parse user-valued fields.
+
+  Returns:
+    A tuple (fv_list_add, fv_list_remove, fv_list_clear, label_list_add,
+    label_list_remove); fv_list_clear holds the field ids to clear.
+  """
+  fv_list_add = []
+  fv_list_remove = []
+  fv_list_clear = []
+  label_list_add = []
+  label_list_remove = []
+  field_name_dict = {fd.field_name: fd for fd in mar.config.field_defs}
+
+  for fv in field_values:
+    field_def = field_name_dict.get(fv.fieldName)
+    if not field_def:
+      logging.warning('Custom field %s does not exist', fv.fieldName)
+      continue
+
+    if fv.operator == api_pb2_v1.FieldValueOperator.clear:
+      fv_list_clear.append(field_def.field_id)
+      continue
+
+    # Enum fields are stored as labels
+    if field_def.field_type == tracker_pb2.FieldTypes.ENUM_TYPE:
+      raw_val = '%s-%s' % (fv.fieldName, fv.fieldValue)
+      if fv.operator == api_pb2_v1.FieldValueOperator.remove:
+        label_list_remove.append(raw_val)
+      elif fv.operator == api_pb2_v1.FieldValueOperator.add:
+        label_list_add.append(raw_val)
+      else:  # pragma: no cover
+        logging.warning('Unsupported field value operator %s', fv.operator)
+    else:
+      new_fv = field_helpers.ParseOneFieldValue(
+          mar.cnxn, services.user, field_def, fv.fieldValue)
+      if fv.operator == api_pb2_v1.FieldValueOperator.remove:
+        fv_list_remove.append(new_fv)
+      elif fv.operator == api_pb2_v1.FieldValueOperator.add:
+        fv_list_add.append(new_fv)
+      else:  # pragma: no cover
+        logging.warning('Unsupported field value operator %s', fv.operator)
+
+  return (fv_list_add, fv_list_remove, fv_list_clear,
+          label_list_add, label_list_remove)
diff --git a/services/api_svc_v1.py b/services/api_svc_v1.py
new file mode 100644
index 0000000..20a9c8b
--- /dev/null
+++ b/services/api_svc_v1.py
@@ -0,0 +1,1511 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""API service.
+
+To manually test this API locally, use the following steps:
+1. Start the development server via 'make serve'.
+2. Start a new Chrome session via the command-line:
+ PATH_TO_CHROME --user-data-dir=/tmp/test \
+ --unsafely-treat-insecure-origin-as-secure=http://localhost:8080
+3. Visit http://localhost:8080/_ah/api/explorer
+4. Click shield icon in the omnibar and allow unsafe scripts.
+5. Click on the "Services" menu item in the API Explorer.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import calendar
+import datetime
+import endpoints
+import functools
+import logging
+import re
+import time
+from google.appengine.api import oauth
+from protorpc import message_types
+from protorpc import protojson
+from protorpc import remote
+
+import settings
+from businesslogic import work_env
+from features import filterrules_helpers
+from features import send_notifications
+from framework import authdata
+from framework import exceptions
+from framework import framework_constants
+from framework import framework_helpers
+from framework import framework_views
+from framework import monitoring
+from framework import monorailrequest
+from framework import permissions
+from framework import ratelimiter
+from framework import sql
+from project import project_helpers
+from proto import api_pb2_v1
+from proto import project_pb2
+from proto import tracker_pb2
+from search import frontendsearchpipeline
+from services import api_pb2_v1_helpers
+from services import client_config_svc
+from services import service_manager
+from services import tracker_fulltext
+from sitewide import sitewide_helpers
+from tracker import field_helpers
+from tracker import issuedetailezt
+from tracker import tracker_bizobj
+from tracker import tracker_constants
+from tracker import tracker_helpers
+
+from infra_libs import ts_mon
+
+
+# Name the API is registered under with @endpoints.api; also the prefix of
+# the method identifiers reported by _RecordMonitoringStats.
+ENDPOINTS_API_NAME = 'monorail'
+# Documentation link passed to @endpoints.api(documentation=...).
+DOC_URL = (
+    'https://chromium.googlesource.com/infra/infra/+/main/'
+    'appengine/monorail/doc/api.md')
+
+
+def monorail_api_method(
+    request_message, response_message, **kwargs):
+  """Extends endpoints.method by performing base checks.
+
+  Each call to a decorated handler:
+    * is rejected in read-only mode unless http_method is GET,
+    * passes api_base_checks (client/user allowlisting, bans, project and
+      issue visibility),
+    * is rate limited and counted in monitoring,
+    * has Monorail exceptions mapped to endpoints HTTP exceptions,
+    * records latency and payload-size stats when it finishes.
+
+  Args:
+    request_message: protorpc request message class for the method.
+    response_message: protorpc response message class for the method.
+    **kwargs: passed to endpoints.method, except `time_fn` (clock used for
+        timing, default time.time) which is consumed here.
+
+  Returns:
+    A decorator for MonorailApi handlers, which are invoked as
+    func(self, mar, request, ...).
+  """
+  time_fn = kwargs.pop('time_fn', time.time)
+  method_name = kwargs.get('name', '')
+  method_path = kwargs.get('path', '')
+  http_method = kwargs.get('http_method', '')
+
+  def new_decorator(func):
+    @endpoints.method(request_message, response_message, **kwargs)
+    @functools.wraps(func)
+    def wrapper(self, *args, **kwargs):
+      start_time = time_fn()
+      approximate_http_status = 200
+      request = args[0]
+      ret = None
+      c_id = None
+      c_email = None
+      mar = None
+      # Bound before the try: the 403 handler logs these, and a
+      # PermissionException can be raised (e.g. by the read-only check)
+      # before the client config is loaded, which used to surface as an
+      # UnboundLocalError instead of a 403.
+      auth_client_ids = None
+      auth_emails = None
+      try:
+        if settings.read_only and http_method.lower() != 'get':
+          raise permissions.PermissionException(
+              'This request is not allowed in read-only mode')
+        requester = endpoints.get_current_user()
+        logging.info('requester is %r', requester)
+        logging.info('args is %r', args)
+        logging.info('kwargs is %r', kwargs)
+        auth_client_ids, auth_emails = (
+            client_config_svc.GetClientConfigSvc().GetClientIDEmails())
+        if settings.local_mode:
+          # Build a new list rather than appending in place, in case the
+          # config service hands back a list it reuses across requests.
+          auth_client_ids = list(auth_client_ids) + [
+              endpoints.API_EXPLORER_CLIENT_ID]
+        if self._services is None:
+          self._set_services(service_manager.set_up_services())
+        cnxn = sql.MonorailConnection()
+        c_id, c_email = api_base_checks(
+            request, requester, self._services, cnxn,
+            auth_client_ids, auth_emails)
+        mar = self.mar_factory(request, cnxn)
+        self.ratelimiter.CheckStart(c_id, c_email, start_time)
+        monitoring.IncrementAPIRequestsCount(
+            'endpoints', c_id, client_email=c_email)
+        ret = func(self, mar, *args, **kwargs)
+      except exceptions.NoSuchUserException as e:
+        approximate_http_status = 404
+        raise endpoints.NotFoundException(
+            'The user does not exist: %s' % str(e))
+      except (exceptions.NoSuchProjectException,
+              exceptions.NoSuchIssueException,
+              exceptions.NoSuchComponentException) as e:
+        approximate_http_status = 404
+        raise endpoints.NotFoundException(str(e))
+      except (permissions.BannedUserException,
+              permissions.PermissionException) as e:
+        approximate_http_status = 403
+        logging.info('Allowlist ID %r email %r', auth_client_ids, auth_emails)
+        raise endpoints.ForbiddenException(str(e))
+      except endpoints.BadRequestException:
+        approximate_http_status = 400
+        raise
+      except endpoints.UnauthorizedException:
+        approximate_http_status = 401
+        # Client will refresh token and retry.
+        raise
+      except oauth.InvalidOAuthTokenError:
+        approximate_http_status = 401
+        # Client will refresh token and retry.
+        raise endpoints.UnauthorizedException(
+            'Auth error: InvalidOAuthTokenError')
+      except (exceptions.GroupExistsException,
+              exceptions.InvalidComponentNameException,
+              ratelimiter.ApiRateLimitExceeded) as e:
+        approximate_http_status = 400
+        raise endpoints.BadRequestException(str(e))
+      except Exception:
+        approximate_http_status = 500
+        logging.exception('Unexpected error in monorail API')
+        raise
+      finally:
+        if mar:
+          mar.CleanUp()
+        now = time_fn()
+        if c_id and c_email:
+          self.ratelimiter.CheckEnd(c_id, c_email, now, start_time)
+        _RecordMonitoringStats(
+            start_time, request, ret, (method_name or func.__name__),
+            (method_path or func.__name__), approximate_http_status, now)
+
+      return ret
+
+    return wrapper
+  return new_decorator
+
+
+def _RecordMonitoringStats(
+    start_time,
+    request,
+    response,
+    method_name,
+    method_path,
+    approximate_http_status,
+    now=None):
+  """Report latency, status and payload sizes of one API call to monitoring.
+
+  Args:
+    start_time: call start time in seconds, on the same clock as `now`.
+    request: protorpc request message; its JSON-encoded length is recorded.
+    response: protorpc response message; a falsy value records length 0.
+    method_name: API method name, part of the metric identifier.
+    method_path: API method path, part of the metric identifier.
+    approximate_http_status: HTTP status to report.
+    now: end time in seconds; time.time() is used when it is falsy.
+  """
+  end_time = now or time.time()
+  elapsed_ms = int((end_time - start_time) * 1000)
+  # Identify calls by API method, not request path, so the set of metric
+  # field values stays small.
+  method_identifier = ''.join(
+      (ENDPOINTS_API_NAME, '.endpoints.', method_name, '/', method_path))
+
+  # Endpoints APIs don't return the full set of http status values.
+  fields = monitoring.GetCommonFields(
+      approximate_http_status, method_identifier)
+  monitoring.AddServerDurations(elapsed_ms, fields)
+  monitoring.IncrementServerResponseStatusCount(fields)
+  monitoring.AddServerRequesteBytes(
+      len(protojson.encode_message(request)), fields)
+  response_length = (
+      len(protojson.encode_message(response)) if response else 0)
+  monitoring.AddServerResponseBytes(response_length, fields)
+
+
+def _is_requester_in_allowed_domains(requester):
+  """Return True if the requester may use the API based on email domain.
+
+  The requester's email must end with settings.api_allowed_email_domains
+  (a suffix, or tuple of suffixes, as accepted by str.endswith), and the
+  OAuth token must be authorized for framework_constants.MONORAIL_SCOPE.
+  """
+  if not requester.email().endswith(settings.api_allowed_email_domains):
+    return False
+  granted_scopes = oauth.get_authorized_scopes(
+      framework_constants.MONORAIL_SCOPE)
+  if framework_constants.MONORAIL_SCOPE not in granted_scopes:
+    logging.info("User is not authenticated with monorail scope")
+    return False
+  return True
+
+
+def api_base_checks(request, requester, services, cnxn,
+                    auth_client_ids, auth_emails):
+  """Base checks for API users.
+
+  Authenticates the caller, then verifies that the requester exists in
+  Monorail, is not banned, and may view the requested project and issue
+  (when the request names them).
+
+  A caller is authenticated only when both an OAuth client id and a
+  requester are known and at least one of these holds: the client id is in
+  auth_client_ids, the requester's email is in auth_emails, or
+  _is_requester_in_allowed_domains(requester) is true.
+
+  Args:
+    request: The HTTP request from Cloud Endpoints. Its optional projectId
+      and issueId attributes select the project/issue to check.
+    requester: The user who sends the request, or None; when falsy, the user
+      is taken from the OAuth token instead.
+    services: Services object.
+    cnxn: connection to the SQL database.
+    auth_client_ids: authorized client ids.
+    auth_emails: authorized emails when client is anonymous.
+
+  Returns:
+    Client ID and client email.
+
+  Raises:
+    endpoints.UnauthorizedException: If the requester is anonymous, or
+      neither the client nor the requester is allowed.
+    exceptions.NoSuchUserException: If the requester does not exist in Monorail.
+    exceptions.NoSuchProjectException: If the project does not exist in
+      Monorail.
+    exceptions.NoSuchIssueException: If the requested issue does not exist.
+    permissions.BannedUserException: If the requester is banned.
+    permissions.PermissionException: If the project is not live, or the
+      requester does not have permission to view the project or issue.
+  """
+  valid_user = False
+  # Reason reported in the UnauthorizedException if authentication fails.
+  auth_err = ''
+  client_id = None
+
+  try:
+    client_id = oauth.get_client_id(framework_constants.OAUTH_SCOPE)
+    logging.info('Oauth client ID %s', client_id)
+  except oauth.Error as ex:
+    auth_err = 'oauth.Error: %s' % ex
+
+  if not requester:
+    # Fall back to the user identified by the OAuth token.
+    try:
+      requester = oauth.get_current_user(framework_constants.OAUTH_SCOPE)
+      logging.info('Oauth requester %s', requester.email())
+    except oauth.Error as ex:
+      logging.info('Got oauth error: %r', ex)
+      auth_err = 'oauth.Error: %s' % ex
+
+  if client_id and requester:
+    if client_id in auth_client_ids:
+      # An allowlisted client app can make requests for any user or anon.
+      logging.info('Client ID %r is allowlisted', client_id)
+      valid_user = True
+    elif requester.email() in auth_emails:
+      # An allowlisted user account can make requests via any client app.
+      logging.info('Client email %r is allowlisted', requester.email())
+      valid_user = True
+    elif _is_requester_in_allowed_domains(requester):
+      # A user with an allowed-domain email and authenticated with the
+      # monorail scope is allowed to make requests via any client app.
+      logging.info(
+          'User email %r is within the allowed domains', requester.email())
+      valid_user = True
+    else:
+      auth_err = (
+          'Neither client ID %r nor email %r is allowlisted' %
+          (client_id, requester.email()))
+
+  if not valid_user:
+    raise endpoints.UnauthorizedException('Auth error: %s' % auth_err)
+  else:
+    logging.info('API request from user %s:%s', client_id, requester.email())
+
+  # Not every request message has projectId/issueId fields.
+  project_name = None
+  if hasattr(request, 'projectId'):
+    project_name = request.projectId
+  issue_local_id = None
+  if hasattr(request, 'issueId'):
+    issue_local_id = request.issueId
+  # This could raise exceptions.NoSuchUserException
+  requester_id = services.user.LookupUserID(cnxn, requester.email())
+  auth = authdata.AuthData.FromUserID(cnxn, requester_id, services)
+  if permissions.IsBanned(auth.user_pb, auth.user_view):
+    raise permissions.BannedUserException(
+        'The user %s has been banned from using Monorail' %
+        requester.email())
+  if project_name:
+    project = services.project.GetProjectByName(
+        cnxn, project_name)
+    if not project:
+      raise exceptions.NoSuchProjectException(
+          'Project %s does not exist' % project_name)
+    if project.state != project_pb2.ProjectState.LIVE:
+      raise permissions.PermissionException(
+          'API may not access project %s because it is not live'
+          % project_name)
+    if not permissions.UserCanViewProject(
+        auth.user_pb, auth.effective_ids, project):
+      raise permissions.PermissionException(
+          'The user %s has no permission for project %s' %
+          (requester.email(), project_name))
+    # The issue check needs the project, so it only runs when one is named.
+    if issue_local_id:
+      # This may raise a NoSuchIssueException.
+      issue = services.issue.GetIssueByLocalID(
+          cnxn, project.project_id, issue_local_id)
+      perms = permissions.GetPermissions(
+          auth.user_pb, auth.effective_ids, project)
+      config = services.config.GetProjectConfig(cnxn, project.project_id)
+      granted_perms = tracker_bizobj.GetGrantedPerms(
+          issue, auth.effective_ids, config)
+      if not permissions.CanViewIssue(
+          auth.effective_ids, perms, project, issue,
+          granted_perms=granted_perms):
+        raise permissions.PermissionException(
+            'User is not allowed to view this issue %s:%d' %
+            (project_name, issue_local_id))
+
+  return client_id, requester.email()
+
+
+@endpoints.api(name=ENDPOINTS_API_NAME, version='v1',
+ description='Monorail API to manage issues.',
+ auth_level=endpoints.AUTH_LEVEL.NONE,
+ allowed_client_ids=endpoints.SKIP_CLIENT_ID_CHECK,
+ documentation=DOC_URL)
+class MonorailApi(remote.Service):
+
+ # Class variables. Handy to mock.
+ _services = None
+ _mar = None
+
+ ratelimiter = ratelimiter.ApiRateLimiter()
+
+ @classmethod
+ def _set_services(cls, services):
+ cls._services = services
+
+ def mar_factory(self, request, cnxn):
+ if not self._mar:
+ self._mar = monorailrequest.MonorailApiRequest(
+ request, self._services, cnxn=cnxn)
+ return self._mar
+
+ def aux_delete_comment(self, mar, request, delete=True):
+ action_name = 'delete' if delete else 'undelete'
+
+ with work_env.WorkEnv(mar, self._services) as we:
+ issue = we.GetIssueByLocalID(
+ mar.project_id, request.issueId, use_cache=False)
+ all_comments = we.ListIssueComments(issue)
+ try:
+ issue_comment = all_comments[request.commentId]
+ except IndexError:
+ raise exceptions.NoSuchIssueException(
+ 'The issue %s:%d does not have comment %d.' %
+ (mar.project_name, request.issueId, request.commentId))
+
+ issue_perms = permissions.UpdateIssuePermissions(
+ mar.perms, mar.project, issue, mar.auth.effective_ids,
+ granted_perms=mar.granted_perms)
+ commenter = we.GetUser(issue_comment.user_id)
+
+ if not permissions.CanDeleteComment(
+ issue_comment, commenter, mar.auth.user_id, issue_perms):
+ raise permissions.PermissionException(
+ 'User is not allowed to %s the comment %d of issue %s:%d' %
+ (action_name, request.commentId, mar.project_name,
+ request.issueId))
+
+ we.DeleteComment(issue, issue_comment, delete=delete)
+ return api_pb2_v1.IssuesCommentsDeleteResponse()
+
+ @monorail_api_method(
+ api_pb2_v1.ISSUES_COMMENTS_DELETE_REQUEST_RESOURCE_CONTAINER,
+ api_pb2_v1.IssuesCommentsDeleteResponse,
+ path='projects/{projectId}/issues/{issueId}/comments/{commentId}',
+ http_method='DELETE',
+ name='issues.comments.delete')
+ def issues_comments_delete(self, mar, request):
+ """Delete a comment."""
+ return self.aux_delete_comment(mar, request, True)
+
+ def parse_imported_reporter(self, mar, request):
+ """Handle the case where an API client is importing issues for users.
+
+ Args:
+ mar: monorail API request object including auth and perms.
+ request: A request PB that defines author and published fields.
+
+ Returns:
+ A pair (reporter_id, timestamp) with the user ID of the user to
+ attribute the comment to and timestamp of the original comment.
+ If the author field is not set, this is not an import request
+ and the comment is attributed to the API client as per normal.
+ An API client that is attempting to post on behalf of other
+ users must have the ImportComment permission in the current
+ project.
+ """
+ reporter_id = mar.auth.user_id
+ timestamp = None
+ if (request.author and request.author.name and
+ request.author.name != mar.auth.email):
+ if not mar.perms.HasPerm(
+ permissions.IMPORT_COMMENT, mar.auth.user_id, mar.project):
+ logging.info('name is %r', request.author.name)
+ raise permissions.PermissionException(
+ 'User is not allowed to attribue comments to others')
+ reporter_id = self._services.user.LookupUserID(
+ mar.cnxn, request.author.name, autocreate=True)
+ logging.info('Importing issue or comment.')
+ if request.published:
+ timestamp = calendar.timegm(request.published.utctimetuple())
+
+ return reporter_id, timestamp
+
+ @monorail_api_method(
+ api_pb2_v1.ISSUES_COMMENTS_INSERT_REQUEST_RESOURCE_CONTAINER,
+ api_pb2_v1.IssuesCommentsInsertResponse,
+ path='projects/{projectId}/issues/{issueId}/comments',
+ http_method='POST',
+ name='issues.comments.insert')
+ def issues_comments_insert(self, mar, request):
+ # type (...) -> proto.api_pb2_v1.IssuesCommentsInsertResponse
+ """Add a comment."""
+ # Because we will modify issues, load from DB rather than cache.
+ issue = self._services.issue.GetIssueByLocalID(
+ mar.cnxn, mar.project_id, request.issueId, use_cache=False)
+ old_owner_id = tracker_bizobj.GetOwnerId(issue)
+ if not permissions.CanCommentIssue(
+ mar.auth.effective_ids, mar.perms, mar.project, issue,
+ mar.granted_perms):
+ raise permissions.PermissionException(
+ 'User is not allowed to comment this issue (%s, %d)' %
+ (request.projectId, request.issueId))
+
+ # Temporary block on updating approval subfields.
+ if request.updates and request.updates.fieldValues:
+ fds_by_name = {fd.field_name.lower():fd for fd in mar.config.field_defs}
+ for fv in request.updates.fieldValues:
+ # Checking for fv.approvalName is unreliable since it can be removed.
+ fd = fds_by_name.get(fv.fieldName.lower())
+ if fd and fd.approval_id:
+ raise exceptions.ActionNotSupported(
+ 'No API support for approval field changes: (approval %s owns %s)'
+ % (fd.approval_id, fd.field_name))
+ # if fd was None, that gets dealt with later.
+
+ if request.content and len(
+ request.content) > tracker_constants.MAX_COMMENT_CHARS:
+ raise endpoints.BadRequestException(
+ 'Comment is too long on this issue (%s, %d' %
+ (request.projectId, request.issueId))
+
+ updates_dict = {}
+ move_to_project = None
+ if request.updates:
+ if not permissions.CanEditIssue(
+ mar.auth.effective_ids, mar.perms, mar.project, issue,
+ mar.granted_perms):
+ raise permissions.PermissionException(
+ 'User is not allowed to edit this issue (%s, %d)' %
+ (request.projectId, request.issueId))
+ if request.updates.moveToProject:
+ move_to = request.updates.moveToProject.lower()
+ move_to_project = issuedetailezt.CheckMoveIssueRequest(
+ self._services, mar, issue, True, move_to, mar.errors)
+ if mar.errors.AnyErrors():
+ raise endpoints.BadRequestException(mar.errors.move_to)
+
+ updates_dict['summary'] = request.updates.summary
+ updates_dict['status'] = request.updates.status
+ updates_dict['is_description'] = request.updates.is_description
+ if request.updates.owner:
+ # A current issue owner can be removed via the API with a
+ # NO_USER_NAME('----') input.
+ if request.updates.owner == framework_constants.NO_USER_NAME:
+ updates_dict['owner'] = framework_constants.NO_USER_SPECIFIED
+ else:
+ new_owner_id = self._services.user.LookupUserID(
+ mar.cnxn, request.updates.owner)
+ valid, msg = tracker_helpers.IsValidIssueOwner(
+ mar.cnxn, mar.project, new_owner_id, self._services)
+ if not valid:
+ raise endpoints.BadRequestException(msg)
+ updates_dict['owner'] = new_owner_id
+ updates_dict['cc_add'], updates_dict['cc_remove'] = (
+ api_pb2_v1_helpers.split_remove_add(request.updates.cc))
+ updates_dict['cc_add'] = list(self._services.user.LookupUserIDs(
+ mar.cnxn, updates_dict['cc_add'], autocreate=True).values())
+ updates_dict['cc_remove'] = list(self._services.user.LookupUserIDs(
+ mar.cnxn, updates_dict['cc_remove']).values())
+ updates_dict['labels_add'], updates_dict['labels_remove'] = (
+ api_pb2_v1_helpers.split_remove_add(request.updates.labels))
+ blocked_on_add_strs, blocked_on_remove_strs = (
+ api_pb2_v1_helpers.split_remove_add(request.updates.blockedOn))
+ updates_dict['blocked_on_add'] = api_pb2_v1_helpers.issue_global_ids(
+ blocked_on_add_strs, issue.project_id, mar,
+ self._services)
+ updates_dict['blocked_on_remove'] = api_pb2_v1_helpers.issue_global_ids(
+ blocked_on_remove_strs, issue.project_id, mar,
+ self._services)
+ blocking_add_strs, blocking_remove_strs = (
+ api_pb2_v1_helpers.split_remove_add(request.updates.blocking))
+ updates_dict['blocking_add'] = api_pb2_v1_helpers.issue_global_ids(
+ blocking_add_strs, issue.project_id, mar,
+ self._services)
+ updates_dict['blocking_remove'] = api_pb2_v1_helpers.issue_global_ids(
+ blocking_remove_strs, issue.project_id, mar,
+ self._services)
+ components_add_strs, components_remove_strs = (
+ api_pb2_v1_helpers.split_remove_add(request.updates.components))
+ updates_dict['components_add'] = (
+ api_pb2_v1_helpers.convert_component_ids(
+ mar.config, components_add_strs))
+ updates_dict['components_remove'] = (
+ api_pb2_v1_helpers.convert_component_ids(
+ mar.config, components_remove_strs))
+ if request.updates.mergedInto:
+ merge_project_name, merge_local_id = tracker_bizobj.ParseIssueRef(
+ request.updates.mergedInto)
+ merge_into_project = self._services.project.GetProjectByName(
+ mar.cnxn, merge_project_name or issue.project_name)
+ # Because we will modify issues, load from DB rather than cache.
+ merge_into_issue = self._services.issue.GetIssueByLocalID(
+ mar.cnxn, merge_into_project.project_id, merge_local_id,
+ use_cache=False)
+ merge_allowed = tracker_helpers.IsMergeAllowed(
+ merge_into_issue, mar, self._services)
+ if not merge_allowed:
+ raise permissions.PermissionException(
+ 'User is not allowed to merge into issue %s:%s' %
+ (merge_into_issue.project_name, merge_into_issue.local_id))
+ updates_dict['merged_into'] = merge_into_issue.issue_id
+ (updates_dict['field_vals_add'], updates_dict['field_vals_remove'],
+ updates_dict['fields_clear'], updates_dict['fields_labels_add'],
+ updates_dict['fields_labels_remove']) = (
+ api_pb2_v1_helpers.convert_field_values(
+ request.updates.fieldValues, mar, self._services))
+
+ field_helpers.ValidateCustomFields(
+ mar.cnxn, self._services,
+ (updates_dict.get('field_vals_add', []) +
+ updates_dict.get('field_vals_remove', [])),
+ mar.config, mar.project, ezt_errors=mar.errors)
+ if mar.errors.AnyErrors():
+ raise endpoints.BadRequestException(
+ 'Invalid field values: %s' % mar.errors.custom_fields)
+
+ updates_dict['labels_add'] = (
+ updates_dict.get('labels_add', []) +
+ updates_dict.get('fields_labels_add', []))
+ updates_dict['labels_remove'] = (
+ updates_dict.get('labels_remove', []) +
+ updates_dict.get('fields_labels_remove', []))
+
+ # TODO(jrobbins): Stop using updates_dict in the first place.
+ delta = tracker_bizobj.MakeIssueDelta(
+ updates_dict.get('status'),
+ updates_dict.get('owner'),
+ updates_dict.get('cc_add', []),
+ updates_dict.get('cc_remove', []),
+ updates_dict.get('components_add', []),
+ updates_dict.get('components_remove', []),
+ (updates_dict.get('labels_add', []) +
+ updates_dict.get('fields_labels_add', [])),
+ (updates_dict.get('labels_remove', []) +
+ updates_dict.get('fields_labels_remove', [])),
+ updates_dict.get('field_vals_add', []),
+ updates_dict.get('field_vals_remove', []),
+ updates_dict.get('fields_clear', []),
+ updates_dict.get('blocked_on_add', []),
+ updates_dict.get('blocked_on_remove', []),
+ updates_dict.get('blocking_add', []),
+ updates_dict.get('blocking_remove', []),
+ updates_dict.get('merged_into'),
+ updates_dict.get('summary'))
+
+ importer_id = None
+ reporter_id, timestamp = self.parse_imported_reporter(mar, request)
+ if reporter_id != mar.auth.user_id:
+ importer_id = mar.auth.user_id
+
+ # TODO(jrobbins): Finish refactoring to make everything go through work_env.
+ _, comment = self._services.issue.DeltaUpdateIssue(
+ cnxn=mar.cnxn, services=self._services,
+ reporter_id=reporter_id, project_id=mar.project_id, config=mar.config,
+ issue=issue, delta=delta, index_now=False, comment=request.content,
+ is_description=updates_dict.get('is_description'),
+ timestamp=timestamp, importer_id=importer_id)
+
+ move_comment = None
+ if move_to_project:
+ old_text_ref = 'issue %s:%s' % (issue.project_name, issue.local_id)
+ tracker_fulltext.UnindexIssues([issue.issue_id])
+ moved_back_iids = self._services.issue.MoveIssues(
+ mar.cnxn, move_to_project, [issue], self._services.user)
+ new_text_ref = 'issue %s:%s' % (issue.project_name, issue.local_id)
+ if issue.issue_id in moved_back_iids:
+ content = 'Moved %s back to %s again.' % (old_text_ref, new_text_ref)
+ else:
+ content = 'Moved %s to now be %s.' % (old_text_ref, new_text_ref)
+ move_comment = self._services.issue.CreateIssueComment(
+ mar.cnxn, issue, mar.auth.user_id, content, amendments=[
+ tracker_bizobj.MakeProjectAmendment(move_to_project.project_name)])
+
+ if 'merged_into' in updates_dict:
+ new_starrers = tracker_helpers.GetNewIssueStarrers(
+ mar.cnxn, self._services, [issue.issue_id], merge_into_issue.issue_id)
+ tracker_helpers.AddIssueStarrers(
+ mar.cnxn, self._services, mar,
+ merge_into_issue.issue_id, merge_into_project, new_starrers)
+ # Load target issue again to get the updated star count.
+ merge_into_issue = self._services.issue.GetIssue(
+ mar.cnxn, merge_into_issue.issue_id, use_cache=False)
+ merge_comment_pb = tracker_helpers.MergeCCsAndAddComment(
+ self._services, mar, issue, merge_into_issue)
+ hostport = framework_helpers.GetHostPort(
+ project_name=merge_into_issue.project_name)
+ send_notifications.PrepareAndSendIssueChangeNotification(
+ merge_into_issue.issue_id, hostport,
+ mar.auth.user_id, send_email=True, comment_id=merge_comment_pb.id)
+
+ tracker_fulltext.IndexIssues(
+ mar.cnxn, [issue], self._services.user, self._services.issue,
+ self._services.config)
+
+ comment = comment or move_comment
+ if comment is None:
+ return api_pb2_v1.IssuesCommentsInsertResponse()
+
+ cmnts = self._services.issue.GetCommentsForIssue(mar.cnxn, issue.issue_id)
+ seq = len(cmnts) - 1
+
+ if request.sendEmail:
+ hostport = framework_helpers.GetHostPort(project_name=issue.project_name)
+ send_notifications.PrepareAndSendIssueChangeNotification(
+ issue.issue_id, hostport, comment.user_id, send_email=True,
+ old_owner_id=old_owner_id, comment_id=comment.id)
+
+ issue_perms = permissions.UpdateIssuePermissions(
+ mar.perms, mar.project, issue, mar.auth.effective_ids,
+ granted_perms=mar.granted_perms)
+ commenter = self._services.user.GetUser(mar.cnxn, comment.user_id)
+ can_delete = permissions.CanDeleteComment(
+ comment, commenter, mar.auth.user_id, issue_perms)
+ return api_pb2_v1.IssuesCommentsInsertResponse(
+ id=seq,
+ kind='monorail#issueComment',
+ author=api_pb2_v1_helpers.convert_person(
+ comment.user_id, mar.cnxn, self._services),
+ content=comment.content,
+ published=datetime.datetime.fromtimestamp(comment.timestamp),
+ updates=api_pb2_v1_helpers.convert_amendments(
+ issue, comment.amendments, mar, self._services),
+ canDelete=can_delete)
+
+ @monorail_api_method(
+ api_pb2_v1.ISSUES_COMMENTS_LIST_REQUEST_RESOURCE_CONTAINER,
+ api_pb2_v1.IssuesCommentsListResponse,
+ path='projects/{projectId}/issues/{issueId}/comments',
+ http_method='GET',
+ name='issues.comments.list')
+ def issues_comments_list(self, mar, request):
+ """List all comments for an issue."""
+ issue = self._services.issue.GetIssueByLocalID(
+ mar.cnxn, mar.project_id, request.issueId)
+ comments = self._services.issue.GetCommentsForIssue(
+ mar.cnxn, issue.issue_id)
+ comments = [comment for comment in comments if not comment.approval_id]
+ visible_comments = []
+ for comment in comments[
+ request.startIndex:(request.startIndex + request.maxResults)]:
+ visible_comments.append(
+ api_pb2_v1_helpers.convert_comment(
+ issue, comment, mar, self._services, mar.granted_perms))
+
+ return api_pb2_v1.IssuesCommentsListResponse(
+ kind='monorail#issueCommentList',
+ totalResults=len(comments),
+ items=visible_comments)
+
+ @monorail_api_method(
+ api_pb2_v1.ISSUES_COMMENTS_DELETE_REQUEST_RESOURCE_CONTAINER,
+ api_pb2_v1.IssuesCommentsDeleteResponse,
+ path='projects/{projectId}/issues/{issueId}/comments/{commentId}',
+ http_method='POST',
+ name='issues.comments.undelete')
+ def issues_comments_undelete(self, mar, request):
+ """Restore a deleted comment."""
+ return self.aux_delete_comment(mar, request, False)
+
+ @monorail_api_method(
+ api_pb2_v1.APPROVALS_COMMENTS_LIST_REQUEST_RESOURCE_CONTAINER,
+ api_pb2_v1.ApprovalsCommentsListResponse,
+ path='projects/{projectId}/issues/{issueId}/'
+ 'approvals/{approvalName}/comments',
+ http_method='GET',
+ name='approvals.comments.list')
+ def approvals_comments_list(self, mar, request):
+ """List all comments for an issue approval."""
+ issue = self._services.issue.GetIssueByLocalID(
+ mar.cnxn, mar.project_id, request.issueId)
+ if not permissions.CanViewIssue(
+ mar.auth.effective_ids, mar.perms, mar.project, issue,
+ mar.granted_perms):
+ raise permissions.PermissionException(
+ 'User is not allowed to view this issue (%s, %d)' %
+ (request.projectId, request.issueId))
+ config = self._services.config.GetProjectConfig(mar.cnxn, issue.project_id)
+ approval_fd = tracker_bizobj.FindFieldDef(request.approvalName, config)
+ if not approval_fd:
+ raise endpoints.BadRequestException(
+ 'Field definition for %s not found in project config' %
+ request.approvalName)
+ comments = self._services.issue.GetCommentsForIssue(
+ mar.cnxn, issue.issue_id)
+ comments = [comment for comment in comments
+ if comment.approval_id == approval_fd.field_id]
+ visible_comments = []
+ for comment in comments[
+ request.startIndex:(request.startIndex + request.maxResults)]:
+ visible_comments.append(
+ api_pb2_v1_helpers.convert_approval_comment(
+ issue, comment, mar, self._services, mar.granted_perms))
+
+ return api_pb2_v1.ApprovalsCommentsListResponse(
+ kind='monorail#approvalCommentList',
+ totalResults=len(comments),
+ items=visible_comments)
+
+ @monorail_api_method(
+ api_pb2_v1.APPROVALS_COMMENTS_INSERT_REQUEST_RESOURCE_CONTAINER,
+ api_pb2_v1.ApprovalsCommentsInsertResponse,
+ path=("projects/{projectId}/issues/{issueId}/"
+ "approvals/{approvalName}/comments"),
+ http_method='POST',
+ name='approvals.comments.insert')
+ def approvals_comments_insert(self, mar, request):
+ # type (...) -> proto.api_pb2_v1.ApprovalsCommentsInsertResponse
+ """Add an approval comment."""
+ approval_fd = tracker_bizobj.FindFieldDef(
+ request.approvalName, mar.config)
+ if not approval_fd or (
+ approval_fd.field_type != tracker_pb2.FieldTypes.APPROVAL_TYPE):
+ raise endpoints.BadRequestException(
+ 'Field definition for %s not found in project config' %
+ request.approvalName)
+ try:
+ issue = self._services.issue.GetIssueByLocalID(
+ mar.cnxn, mar.project_id, request.issueId)
+ except exceptions.NoSuchIssueException:
+ raise endpoints.BadRequestException(
+ 'Issue %s:%s not found' % (request.projectId, request.issueId))
+ approval = tracker_bizobj.FindApprovalValueByID(
+ approval_fd.field_id, issue.approval_values)
+ if not approval:
+ raise endpoints.BadRequestException(
+ 'Approval %s not found in issue.' % request.approvalName)
+
+ if not permissions.CanCommentIssue(
+ mar.auth.effective_ids, mar.perms, mar.project, issue,
+ mar.granted_perms):
+ raise permissions.PermissionException(
+ 'User is not allowed to comment on this issue (%s, %d)' %
+ (request.projectId, request.issueId))
+
+ if request.content and len(
+ request.content) > tracker_constants.MAX_COMMENT_CHARS:
+ raise endpoints.BadRequestException(
+ 'Comment is too long on this issue (%s, %d' %
+ (request.projectId, request.issueId))
+
+ updates_dict = {}
+ if request.approvalUpdates:
+ if request.approvalUpdates.fieldValues:
+ # Block updating field values that don't belong to the approval.
+ approvals_fds_by_name = {
+ fd.field_name.lower():fd for fd in mar.config.field_defs
+ if fd.approval_id == approval_fd.field_id}
+ for fv in request.approvalUpdates.fieldValues:
+ if approvals_fds_by_name.get(fv.fieldName.lower()) is None:
+ raise endpoints.BadRequestException(
+ 'Field defition for %s not found in %s subfields.' %
+ (fv.fieldName, request.approvalName))
+ (updates_dict['field_vals_add'], updates_dict['field_vals_remove'],
+ updates_dict['fields_clear'], updates_dict['fields_labels_add'],
+ updates_dict['fields_labels_remove']) = (
+ api_pb2_v1_helpers.convert_field_values(
+ request.approvalUpdates.fieldValues, mar, self._services))
+ if request.approvalUpdates.approvers:
+ if not permissions.CanUpdateApprovers(
+ mar.auth.effective_ids, mar.perms, mar.project,
+ approval.approver_ids):
+ raise permissions.PermissionException(
+ 'User is not allowed to update approvers')
+ approvers_add, approvers_remove = api_pb2_v1_helpers.split_remove_add(
+ request.approvalUpdates.approvers)
+ updates_dict['approver_ids_add'] = list(
+ self._services.user.LookupUserIDs(mar.cnxn, approvers_add,
+ autocreate=True).values())
+ updates_dict['approver_ids_remove'] = list(
+ self._services.user.LookupUserIDs(mar.cnxn, approvers_remove,
+ autocreate=True).values())
+ if request.approvalUpdates.status:
+ status = tracker_pb2.ApprovalStatus(
+ api_pb2_v1.ApprovalStatus(request.approvalUpdates.status).number)
+ if not permissions.CanUpdateApprovalStatus(
+ mar.auth.effective_ids, mar.perms, mar.project,
+ approval.approver_ids, status):
+ raise permissions.PermissionException(
+ 'User is not allowed to make this status change')
+ updates_dict['status'] = status
+ logging.info(time.time)
+ approval_delta = tracker_bizobj.MakeApprovalDelta(
+ updates_dict.get('status'), mar.auth.user_id,
+ updates_dict.get('approver_ids_add', []),
+ updates_dict.get('approver_ids_remove', []),
+ updates_dict.get('field_vals_add', []),
+ updates_dict.get('field_vals_remove', []),
+ updates_dict.get('fields_clear', []),
+ updates_dict.get('fields_labels_add', []),
+ updates_dict.get('fields_labels_remove', []))
+ comment = self._services.issue.DeltaUpdateIssueApproval(
+ mar.cnxn, mar.auth.user_id, mar.config, issue, approval, approval_delta,
+ comment_content=request.content,
+ is_description=request.is_description)
+
+ cmnts = self._services.issue.GetCommentsForIssue(mar.cnxn, issue.issue_id)
+ seq = len(cmnts) - 1
+
+ if request.sendEmail:
+ hostport = framework_helpers.GetHostPort(project_name=issue.project_name)
+ send_notifications.PrepareAndSendApprovalChangeNotification(
+ issue.issue_id, approval.approval_id,
+ hostport, comment.id, send_email=True)
+
+ issue_perms = permissions.UpdateIssuePermissions(
+ mar.perms, mar.project, issue, mar.auth.effective_ids,
+ granted_perms=mar.granted_perms)
+ commenter = self._services.user.GetUser(mar.cnxn, comment.user_id)
+ can_delete = permissions.CanDeleteComment(
+ comment, commenter, mar.auth.user_id, issue_perms)
+ return api_pb2_v1.ApprovalsCommentsInsertResponse(
+ id=seq,
+ kind='monorail#approvalComment',
+ author=api_pb2_v1_helpers.convert_person(
+ comment.user_id, mar.cnxn, self._services),
+ content=comment.content,
+ published=datetime.datetime.fromtimestamp(comment.timestamp),
+ approvalUpdates=api_pb2_v1_helpers.convert_approval_amendments(
+ comment.amendments, mar, self._services),
+ canDelete=can_delete)
+
+ @monorail_api_method(
+ api_pb2_v1.USERS_GET_REQUEST_RESOURCE_CONTAINER,
+ api_pb2_v1.UsersGetResponse,
+ path='users/{userId}',
+ http_method='GET',
+ name='users.get')
+ def users_get(self, mar, request):
+ """Get a user."""
+ owner_project_only = request.ownerProjectsOnly
+ with work_env.WorkEnv(mar, self._services) as we:
+ (visible_ownership, visible_deleted, visible_membership,
+ visible_contrib) = we.GetUserProjects(
+ mar.viewed_user_auth.effective_ids)
+
+ project_list = []
+ for proj in (visible_ownership + visible_deleted):
+ config = self._services.config.GetProjectConfig(
+ mar.cnxn, proj.project_id)
+ templates = self._services.template.GetProjectTemplates(
+ mar.cnxn, config.project_id)
+ proj_result = api_pb2_v1_helpers.convert_project(
+ proj, config, api_pb2_v1.Role.owner, templates)
+ project_list.append(proj_result)
+ if not owner_project_only:
+ for proj in visible_membership:
+ config = self._services.config.GetProjectConfig(
+ mar.cnxn, proj.project_id)
+ templates = self._services.template.GetProjectTemplates(
+ mar.cnxn, config.project_id)
+ proj_result = api_pb2_v1_helpers.convert_project(
+ proj, config, api_pb2_v1.Role.member, templates)
+ project_list.append(proj_result)
+ for proj in visible_contrib:
+ config = self._services.config.GetProjectConfig(
+ mar.cnxn, proj.project_id)
+ templates = self._services.template.GetProjectTemplates(
+ mar.cnxn, config.project_id)
+ proj_result = api_pb2_v1_helpers.convert_project(
+ proj, config, api_pb2_v1.Role.contributor, templates)
+ project_list.append(proj_result)
+
+ return api_pb2_v1.UsersGetResponse(
+ id=str(mar.viewed_user_auth.user_id),
+ kind='monorail#user',
+ projects=project_list,
+ )
+
+ @monorail_api_method(
+ api_pb2_v1.ISSUES_GET_REQUEST_RESOURCE_CONTAINER,
+ api_pb2_v1.IssuesGetInsertResponse,
+ path='projects/{projectId}/issues/{issueId}',
+ http_method='GET',
+ name='issues.get')
+ def issues_get(self, mar, request):
+ """Get an issue."""
+ issue = self._services.issue.GetIssueByLocalID(
+ mar.cnxn, mar.project_id, request.issueId)
+
+ return api_pb2_v1_helpers.convert_issue(
+ api_pb2_v1.IssuesGetInsertResponse, issue, mar, self._services)
+
+ @monorail_api_method(
+ api_pb2_v1.ISSUES_INSERT_REQUEST_RESOURCE_CONTAINER,
+ api_pb2_v1.IssuesGetInsertResponse,
+ path='projects/{projectId}/issues',
+ http_method='POST',
+ name='issues.insert')
+ def issues_insert(self, mar, request):
+ """Add a new issue."""
+ if not mar.perms.CanUsePerm(
+ permissions.CREATE_ISSUE, mar.auth.effective_ids, mar.project, []):
+ raise permissions.PermissionException(
+ 'The requester %s is not allowed to create issues for project %s.' %
+ (mar.auth.email, mar.project_name))
+
+ with work_env.WorkEnv(mar, self._services) as we:
+ owner_id = framework_constants.NO_USER_SPECIFIED
+ if request.owner and request.owner.name:
+ try:
+ owner_id = self._services.user.LookupUserID(
+ mar.cnxn, request.owner.name)
+ except exceptions.NoSuchUserException:
+ raise endpoints.BadRequestException(
+ 'The specified owner %s does not exist.' % request.owner.name)
+
+ cc_ids = []
+ request.cc = [cc for cc in request.cc if cc]
+ if request.cc:
+ cc_ids = list(self._services.user.LookupUserIDs(
+ mar.cnxn, [ap.name for ap in request.cc],
+ autocreate=True).values())
+ comp_ids = api_pb2_v1_helpers.convert_component_ids(
+ mar.config, request.components)
+ fields_add, _, _, fields_labels, _ = (
+ api_pb2_v1_helpers.convert_field_values(
+ request.fieldValues, mar, self._services))
+ field_helpers.ValidateCustomFields(
+ mar.cnxn, self._services, fields_add, mar.config, mar.project,
+ ezt_errors=mar.errors)
+ if mar.errors.AnyErrors():
+ raise endpoints.BadRequestException(
+ 'Invalid field values: %s' % mar.errors.custom_fields)
+
+ logging.info('request.author is %r', request.author)
+ reporter_id, timestamp = self.parse_imported_reporter(mar, request)
+ # To preserve previous behavior, do not raise filter rule errors.
+ try:
+ new_issue, _ = we.CreateIssue(
+ mar.project_id,
+ request.summary,
+ request.status,
+ owner_id,
+ cc_ids,
+ request.labels + fields_labels,
+ fields_add,
+ comp_ids,
+ request.description,
+ blocked_on=api_pb2_v1_helpers.convert_issueref_pbs(
+ request.blockedOn, mar, self._services),
+ blocking=api_pb2_v1_helpers.convert_issueref_pbs(
+ request.blocking, mar, self._services),
+ reporter_id=reporter_id,
+ timestamp=timestamp,
+ send_email=request.sendEmail,
+ raise_filter_errors=False)
+ we.StarIssue(new_issue, True)
+ except exceptions.InputException as e:
+ raise endpoints.BadRequestException(str(e))
+
+ return api_pb2_v1_helpers.convert_issue(
+ api_pb2_v1.IssuesGetInsertResponse, new_issue, mar, self._services)
+
+ @monorail_api_method(
+ api_pb2_v1.ISSUES_LIST_REQUEST_RESOURCE_CONTAINER,
+ api_pb2_v1.IssuesListResponse,
+ path='projects/{projectId}/issues',
+ http_method='GET',
+ name='issues.list')
+ def issues_list(self, mar, request):
+ """List issues for projects."""
+ if request.additionalProject:
+ for project_name in request.additionalProject:
+ project = self._services.project.GetProjectByName(
+ mar.cnxn, project_name)
+ if project and not permissions.UserCanViewProject(
+ mar.auth.user_pb, mar.auth.effective_ids, project):
+ raise permissions.PermissionException(
+ 'The user %s has no permission for project %s' %
+ (mar.auth.email, project_name))
+ # TODO(jrobbins): This should go through work_env.
+ pipeline = frontendsearchpipeline.FrontendSearchPipeline(
+ mar.cnxn,
+ self._services,
+ mar.auth, [mar.me_user_id],
+ mar.query,
+ mar.query_project_names,
+ mar.num,
+ mar.start,
+ mar.can,
+ mar.group_by_spec,
+ mar.sort_spec,
+ mar.warnings,
+ mar.errors,
+ mar.use_cached_searches,
+ mar.profiler,
+ project=mar.project)
+ if not mar.errors.AnyErrors():
+ pipeline.SearchForIIDs()
+ pipeline.MergeAndSortIssues()
+ pipeline.Paginate()
+ else:
+ raise endpoints.BadRequestException(mar.errors.query)
+
+ issue_list = [
+ api_pb2_v1_helpers.convert_issue(
+ api_pb2_v1.IssueWrapper, r, mar, self._services)
+ for r in pipeline.visible_results]
+ return api_pb2_v1.IssuesListResponse(
+ kind='monorail#issueList',
+ totalResults=pipeline.total_count,
+ items=issue_list)
+
+ @monorail_api_method(
+ api_pb2_v1.GROUPS_SETTINGS_LIST_REQUEST_RESOURCE_CONTAINER,
+ api_pb2_v1.GroupsSettingsListResponse,
+ path='groupsettings',
+ http_method='GET',
+ name='groups.settings.list')
+ def groups_settings_list(self, mar, request):
+ """List all group settings."""
+ all_groups = self._services.usergroup.GetAllUserGroupsInfo(mar.cnxn)
+ group_settings = []
+ for g in all_groups:
+ setting = g[2]
+ wrapper = api_pb2_v1_helpers.convert_group_settings(g[0], setting)
+ if not request.importedGroupsOnly or wrapper.ext_group_type:
+ group_settings.append(wrapper)
+ return api_pb2_v1.GroupsSettingsListResponse(
+ groupSettings=group_settings)
+
+ @monorail_api_method(
+ api_pb2_v1.GROUPS_CREATE_REQUEST_RESOURCE_CONTAINER,
+ api_pb2_v1.GroupsCreateResponse,
+ path='groups',
+ http_method='POST',
+ name='groups.create')
+ def groups_create(self, mar, request):
+ """Create a new user group."""
+ if not permissions.CanCreateGroup(mar.perms):
+ raise permissions.PermissionException(
+ 'The user is not allowed to create groups.')
+
+ user_dict = self._services.user.LookupExistingUserIDs(
+ mar.cnxn, [request.groupName])
+ if request.groupName.lower() in user_dict:
+ raise exceptions.GroupExistsException(
+ 'group %s already exists' % request.groupName)
+
+ if request.ext_group_type:
+ ext_group_type = str(request.ext_group_type).lower()
+ else:
+ ext_group_type = None
+ group_id = self._services.usergroup.CreateGroup(
+ mar.cnxn, self._services, request.groupName,
+ str(request.who_can_view_members).lower(),
+ ext_group_type)
+
+ return api_pb2_v1.GroupsCreateResponse(
+ groupID=group_id)
+
@monorail_api_method(
    api_pb2_v1.GROUPS_GET_REQUEST_RESOURCE_CONTAINER,
    api_pb2_v1.GroupsGetResponse,
    path='groups/{groupName}',
    http_method='GET',
    name='groups.get')
def groups_get(self, mar, request):
  """Return a group's settings along with its owner and member emails.

  Args:
    mar: request info parsed for this API call; mar.viewed_user_auth
      identifies the group named in the URL.
    request: GROUPS_GET request carrying groupName.

  Returns:
    GroupsGetResponse with the group ID, converted settings, and the email
    addresses of the owners/members reported by LookupMembers.

  Raises:
    NoSuchUserException: nothing matches groupName.
    PermissionException: the requester may not view this group's members.
  """
  group_auth = mar.viewed_user_auth
  if not group_auth:
    raise exceptions.NoSuchUserException(request.groupName)
  group_id = group_auth.user_id
  usergroup_svc = self._services.usergroup

  group_settings = usergroup_svc.GetGroupSettings(mar.cnxn, group_id)
  all_member_ids, all_owner_ids = usergroup_svc.LookupAllMembers(
      mar.cnxn, [group_id])

  # Every project in which the requester holds any role is passed to the
  # visibility check below.
  owned, membered, contributed = (
      self._services.project.GetUserRolesInAllProjects(
          mar.cnxn, mar.auth.effective_ids))
  requester_project_ids = owned.union(membered).union(contributed)

  allowed = permissions.CanViewGroupMembers(
      mar.perms, mar.auth.effective_ids, group_settings,
      all_member_ids[group_id], all_owner_ids[group_id],
      requester_project_ids)
  if not allowed:
    raise permissions.PermissionException(
        'The user is not allowed to view this group.')

  # The permission check uses LookupAllMembers while the response lists
  # LookupMembers (presumably direct membership only -- TODO confirm).
  direct_member_ids, direct_owner_ids = usergroup_svc.LookupMembers(
      mar.cnxn, [group_id])
  user_svc = self._services.user
  member_emails = list(user_svc.LookupUserEmails(
      mar.cnxn, direct_member_ids[group_id]).values())
  owner_emails = list(user_svc.LookupUserEmails(
      mar.cnxn, direct_owner_ids[group_id]).values())

  converted_settings = api_pb2_v1_helpers.convert_group_settings(
      request.groupName, group_settings)
  return api_pb2_v1.GroupsGetResponse(
      groupID=group_id,
      groupSettings=converted_settings,
      groupOwners=owner_emails,
      groupMembers=member_emails)
+
@monorail_api_method(
    api_pb2_v1.GROUPS_UPDATE_REQUEST_RESOURCE_CONTAINER,
    api_pb2_v1.GroupsUpdateResponse,
    path='groups/{groupName}',
    http_method='POST',
    name='groups.update')
def groups_update(self, mar, request):
  """Update a group's settings and/or replace its owners and members.

  Settings are only rewritten when at least one of who_can_view_members,
  ext_group_type, last_sync_time or friend_projects is set. Each field
  falls back to its stored value when omitted. Passing
  framework_constants.NO_VALUES in friend_projects clears the friend
  projects. If groupOwners or groupMembers is given, every existing owner
  and member is removed and replaced by the requested lists. Addresses are
  looked up with autocreate=True.

  Args:
    mar: request info parsed for this API call; mar.viewed_user_auth
      identifies the group named in the URL.
    request: GROUPS_UPDATE request message.

  Returns:
    An empty GroupsUpdateResponse.

  Raises:
    NoSuchUserException: nothing matches groupName.
    PermissionException: the requester may not edit this group.
  """
  # Same guard as groups_get. Without it, an unknown group name failed with
  # an AttributeError on None instead of a "no such user" error.
  if not mar.viewed_user_auth:
    raise exceptions.NoSuchUserException(request.groupName)
  group_id = mar.viewed_user_auth.user_id
  member_ids_dict, owner_ids_dict = self._services.usergroup.LookupMembers(
      mar.cnxn, [group_id])
  owner_ids = owner_ids_dict.get(group_id, [])
  member_ids = member_ids_dict.get(group_id, [])
  if not permissions.CanEditGroup(
      mar.perms, mar.auth.effective_ids, owner_ids):
    raise permissions.PermissionException(
        'The user is not allowed to edit this group.')

  group_settings = self._services.usergroup.GetGroupSettings(
      mar.cnxn, group_id)
  if (request.who_can_view_members or request.ext_group_type
      or request.last_sync_time or request.friend_projects):
    group_settings.who_can_view_members = (
        request.who_can_view_members or group_settings.who_can_view_members)
    group_settings.ext_group_type = (
        request.ext_group_type or group_settings.ext_group_type)
    group_settings.last_sync_time = (
        request.last_sync_time or group_settings.last_sync_time)
    if framework_constants.NO_VALUES in request.friend_projects:
      group_settings.friend_projects = []
    else:
      id_dict = self._services.project.LookupProjectIDs(
          mar.cnxn, request.friend_projects)
      # Names that resolve to nothing leave the stored list untouched.
      group_settings.friend_projects = (
          list(id_dict.values()) or group_settings.friend_projects)
    self._services.usergroup.UpdateSettings(
        mar.cnxn, group_id, group_settings)

  if request.groupOwners or request.groupMembers:
    # Membership is replaced wholesale rather than merged.
    self._services.usergroup.RemoveMembers(
        mar.cnxn, group_id, owner_ids + member_ids)
    owners_dict = self._services.user.LookupUserIDs(
        mar.cnxn, request.groupOwners, autocreate=True)
    self._services.usergroup.UpdateMembers(
        mar.cnxn, group_id, list(owners_dict.values()), 'owner')
    members_dict = self._services.user.LookupUserIDs(
        mar.cnxn, request.groupMembers, autocreate=True)
    self._services.usergroup.UpdateMembers(
        mar.cnxn, group_id, list(members_dict.values()), 'member')

  return api_pb2_v1.GroupsUpdateResponse()
+
@monorail_api_method(
    api_pb2_v1.COMPONENTS_LIST_REQUEST_RESOURCE_CONTAINER,
    api_pb2_v1.ComponentsListResponse,
    path='projects/{projectId}/components',
    http_method='GET',
    name='components.list')
def components_list(self, mar, _request):
  """Return every component defined in the requested project.

  Args:
    mar: request info parsed for this API call (cnxn, project_id).
    _request: unused request message.

  Returns:
    ComponentsListResponse with one converted entry per component
    definition, in the order of config.component_defs.
  """
  project_config = self._services.config.GetProjectConfig(
      mar.cnxn, mar.project_id)
  converted = []
  for component_def in project_config.component_defs:
    converted.append(api_pb2_v1_helpers.convert_component_def(
        component_def, mar, self._services))
  return api_pb2_v1.ComponentsListResponse(components=converted)
+
@monorail_api_method(
    api_pb2_v1.COMPONENTS_CREATE_REQUEST_RESOURCE_CONTAINER,
    api_pb2_v1.Component,
    path='projects/{projectId}/components',
    http_method='POST',
    name='components.create')
def components_create(self, mar, request):
  """Create a component, optionally nested under an existing parent.

  Args:
    mar: request info parsed for this API call.
    request: COMPONENTS_CREATE request carrying componentName, optional
      parentPath, description, deprecated, and admin/cc email lists.

  Returns:
    An api_pb2_v1.Component describing the newly created component.

  Raises:
    PermissionException: the requester lacks EDIT_PROJECT, or may not edit
      the parent component.
    InvalidComponentNameException: the leaf name is malformed or the full
      path is already in use.
    NoSuchComponentException: parentPath names no existing component.
    NoSuchUserException: an admin or cc address is not a known user.
  """
  if not mar.perms.CanUsePerm(
      permissions.EDIT_PROJECT, mar.auth.effective_ids, mar.project, []):
    raise permissions.PermissionException(
        'User is not allowed to create components for this project')

  config = self._services.config.GetProjectConfig(mar.cnxn, mar.project_id)
  leaf_name = request.componentName
  if not tracker_constants.COMPONENT_NAME_RE.match(leaf_name):
    raise exceptions.InvalidComponentNameException(
        'The component name %s is invalid.' % leaf_name)

  parent_path = request.parentPath
  if parent_path:
    parent_def = tracker_bizobj.FindComponentDef(parent_path, config)
    if not parent_def:
      raise exceptions.NoSuchComponentException(
          'Parent component %s does not exist.' % parent_path)
    if not permissions.CanEditComponentDef(
        mar.auth.effective_ids, mar.perms, mar.project, parent_def, config):
      raise permissions.PermissionException(
          'User is not allowed to add a subcomponent to component %s' %
          parent_path)

    path = '%s>%s' % (parent_path, leaf_name)
  else:
    path = leaf_name

  if tracker_bizobj.FindComponentDef(path, config):
    raise exceptions.InvalidComponentNameException(
        'The name %s is already in use.' % path)

  created = int(time.time())
  # Blank entries are dropped up front so that '' is neither looked up nor
  # echoed back in the response.
  request.admin = [admin for admin in request.admin if admin]
  request.cc = [cc for cc in request.cc if cc]
  user_emails = set([mar.auth.email] + request.admin + request.cc)
  user_ids_dict = self._services.user.LookupUserIDs(
      mar.cnxn, list(user_emails), autocreate=False)
  # With autocreate=False an unknown address may be absent from the result
  # (TODO confirm against user_svc). Report such addresses as unknown users
  # instead of failing below with a bare KeyError.
  unknown_emails = sorted(
      {email for email in request.admin + request.cc
       if email not in user_ids_dict})
  if unknown_emails:
    raise exceptions.NoSuchUserException(
        'Unknown users: %s' % ', '.join(unknown_emails))
  admin_ids = [user_ids_dict[uname] for uname in request.admin]
  cc_ids = [user_ids_dict[uname] for uname in request.cc]
  label_ids = []  # TODO(jrobbins): allow API clients to specify this too.

  component_id = self._services.config.CreateComponentDef(
      mar.cnxn, mar.project_id, path, request.description, request.deprecated,
      admin_ids, cc_ids, created, user_ids_dict[mar.auth.email], label_ids)

  return api_pb2_v1.Component(
      componentId=component_id,
      projectName=request.projectId,
      componentPath=path,
      description=request.description,
      admin=request.admin,
      cc=request.cc,
      deprecated=request.deprecated,
      created=datetime.datetime.fromtimestamp(created),
      creator=mar.auth.email)
+
@monorail_api_method(
    api_pb2_v1.COMPONENTS_DELETE_REQUEST_RESOURCE_CONTAINER,
    message_types.VoidMessage,
    path='projects/{projectId}/components/{componentPath}',
    http_method='DELETE',
    name='components.delete')
def components_delete(self, mar, request):
  """Delete a component that has no subcomponents.

  The issue service's DeleteComponentReferences is called for the component
  before its definition is removed from the project config.

  Raises:
    NoSuchComponentException: componentPath does not name a component.
    PermissionException: the requester cannot view or edit the component,
      or the component still has subcomponents.
  """
  project_config = self._services.config.GetProjectConfig(
      mar.cnxn, mar.project_id)
  target_path = request.componentPath
  target_def = tracker_bizobj.FindComponentDef(target_path, project_config)
  if not target_def:
    raise exceptions.NoSuchComponentException(
        'The component %s does not exist.' % target_path)

  requester_ids = mar.auth.effective_ids
  if not permissions.CanViewComponentDef(
      requester_ids, mar.perms, mar.project, target_def):
    raise permissions.PermissionException(
        'User is not allowed to view this component %s' % target_path)
  if not permissions.CanEditComponentDef(
      requester_ids, mar.perms, mar.project, target_def, project_config):
    raise permissions.PermissionException(
        'User is not allowed to delete this component %s' % target_path)

  # Only leaf components may be deleted.
  if tracker_bizobj.FindDescendantComponents(project_config, target_def):
    raise permissions.PermissionException(
        'User tried to delete component that had subcomponents')

  target_id = target_def.component_id
  self._services.issue.DeleteComponentReferences(mar.cnxn, target_id)
  self._services.config.DeleteComponentDef(
      mar.cnxn, mar.project_id, target_id)
  return message_types.VoidMessage()
+
@monorail_api_method(
    api_pb2_v1.COMPONENTS_UPDATE_REQUEST_RESOURCE_CONTAINER,
    message_types.VoidMessage,
    path='projects/{projectId}/components/{componentPath}',
    http_method='POST',
    name='components.update')
def components_update(self, mar, request):
  """Apply the field updates listed in request.updates to one component.

  Supported fields are the leaf name, description, admins, CCs and the
  deprecated flag. Unknown fields are logged and ignored. Renaming a
  component also rewrites the paths of its subcomponents. Changing the leaf
  name or the CC list triggers RecomputeAllDerivedFields for the project.

  Raises:
    NoSuchComponentException: componentPath does not name a component.
    PermissionException: the requester cannot view or edit the component.
    InvalidComponentNameException: a new leaf name is malformed or the
      resulting path already belongs to another component.
  """
  project_config = self._services.config.GetProjectConfig(
      mar.cnxn, mar.project_id)
  component_path = request.componentPath
  component_def = tracker_bizobj.FindComponentDef(
      component_path, project_config)
  if not component_def:
    raise exceptions.NoSuchComponentException(
        'The component %s does not exist.' % component_path)
  if not permissions.CanViewComponentDef(
      mar.auth.effective_ids, mar.perms, mar.project, component_def):
    raise permissions.PermissionException(
        'User is not allowed to view this component %s' % component_path)
  if not permissions.CanEditComponentDef(
      mar.auth.effective_ids, mar.perms, mar.project, component_def,
      project_config):
    raise permissions.PermissionException(
        'User is not allowed to edit this component %s' % component_path)

  # Start from the stored values; each update overrides one of them.
  original_path = component_def.path
  new_path = original_path
  new_docstring = component_def.docstring
  new_deprecated = component_def.deprecated
  new_admin_ids = component_def.admin_ids
  new_cc_ids = component_def.cc_ids
  needs_recompute = False

  field_ids = api_pb2_v1.ComponentUpdateFieldID
  for update in request.updates:
    field = update.field
    if field == field_ids.LEAF_NAME:
      leaf_name = update.leafName
      if not tracker_constants.COMPONENT_NAME_RE.match(leaf_name):
        raise exceptions.InvalidComponentNameException(
            'The component name %s is invalid.' % leaf_name)
      # Keep the parent prefix (everything before the last '>'), if any.
      parent_path, separator, _ = original_path.rpartition('>')
      if separator:
        new_path = '%s>%s' % (parent_path, leaf_name)
      else:
        new_path = leaf_name
      conflict = tracker_bizobj.FindComponentDef(new_path, project_config)
      if conflict and conflict.component_id != component_def.component_id:
        raise exceptions.InvalidComponentNameException(
            'The name %s is already in use.' % new_path)
      needs_recompute = True
    elif field == field_ids.DESCRIPTION:
      new_docstring = update.description
    elif field == field_ids.ADMIN:
      admin_dict = self._services.user.LookupUserIDs(
          mar.cnxn, list(update.admin), autocreate=True)
      new_admin_ids = list(set(admin_dict.values()))
    elif field == field_ids.CC:
      cc_dict = self._services.user.LookupUserIDs(
          mar.cnxn, list(update.cc), autocreate=True)
      new_cc_ids = list(set(cc_dict.values()))
      needs_recompute = True
    elif field == field_ids.DEPRECATED:
      new_deprecated = update.deprecated
    else:
      logging.error('Unknown component field %r', field)

  modified_at = int(time.time())
  modifier_id = self._services.user.LookupUserID(
      mar.cnxn, mar.auth.email, autocreate=False)
  logging.info(
      'Updating component id %d: path-%s, docstring-%s, deprecated-%s,'
      ' admin_ids-%s, cc_ids-%s modified by %s', component_def.component_id,
      new_path, new_docstring, new_deprecated, new_admin_ids, new_cc_ids,
      modifier_id)
  self._services.config.UpdateComponentDef(
      mar.cnxn, mar.project_id, component_def.component_id,
      path=new_path, docstring=new_docstring, deprecated=new_deprecated,
      admin_ids=new_admin_ids, cc_ids=new_cc_ids, modified=modified_at,
      modifier_id=modifier_id)

  # TODO(sheyang): reuse the code in componentdetails
  if new_path != original_path:
    # A rename moves the whole subtree, so each descendant's path prefix is
    # rewritten as well. The component itself is skipped because it was
    # already updated above.
    for sub_id in tracker_bizobj.FindMatchingComponentIDs(
        original_path, project_config, exact=False):
      if sub_id == component_def.component_id:
        continue
      sub_def = tracker_bizobj.FindComponentDefByID(sub_id, project_config)
      self._services.config.UpdateComponentDef(
          mar.cnxn, mar.project_id, sub_def.component_id,
          path=sub_def.path.replace(original_path, new_path, 1))

  if needs_recompute:
    filterrules_helpers.RecomputeAllDerivedFields(
        mar.cnxn, self._services, mar.project, project_config)

  return message_types.VoidMessage()
+
+
@endpoints.api(name='monorail_client_configs', version='v1',
               description='Monorail API client configs.')
class ClientConfigApi(remote.Service):
  """Endpoints service that applies API client configs to Monorail data.

  client_configs_update reads the client configs. For every client with an
  email address it:
    * ensures a user account exists;
    * sets that user's 'api_request' action limits;
    * grants the configured project roles and extra permissions.
  """

  # Class variables. Handy to mock.
  _services = None
  _mar = None

  @classmethod
  def _set_services(cls, services):
    """Install the services object shared by all instances of this class."""
    cls._services = services

  def mar_factory(self, request, cnxn):
    """Return this instance's MonorailApiRequest, building it on first use."""
    if not self._mar:
      self._mar = monorailrequest.MonorailApiRequest(
          request, self._services, cnxn=cnxn)
    return self._mar

  @endpoints.method(
      message_types.VoidMessage,
      message_types.VoidMessage,
      path='client_configs',
      http_method='POST',
      name='client_configs.update')
  def client_configs_update(self, request):
    """Sync users, rate limits and project permissions from client configs.

    Raises:
      PermissionException: the requester lacks ADMINISTER_SITE.
      endpoints.InternalServerErrorException: the configs could not be
        fetched.
    """
    if self._services is None:
      self._set_services(service_manager.set_up_services())
    mar = self.mar_factory(request, sql.MonorailConnection())
    try:
      if not mar.perms.HasPerm(permissions.ADMINISTER_SITE, None, None):
        raise permissions.PermissionException(
            'The requester %s is not allowed to update client configs.' %
            mar.auth.email)

      # Maps the numeric role in a project permission config to a role name.
      role_dict = {
          1: permissions.COMMITTER_ROLE,
          2: permissions.CONTRIBUTOR_ROLE,
      }

      client_config = client_config_svc.GetClientConfigSvc()

      cfg = client_config.GetConfigs()
      if not cfg:
        msg = 'Failed to fetch client configs.'
        logging.error(msg)
        raise endpoints.InternalServerErrorException(msg)

      for client in cfg.clients:
        if not client.client_email:
          continue
        # 1: create the user if non-existent
        user_id = self._services.user.LookupUserID(
            mar.cnxn, client.client_email, autocreate=True)
        user_pb = self._services.user.GetUser(mar.cnxn, user_id)

        logging.info('User ID %d for email %s', user_id, client.client_email)

        # 2: set period and lifetime limit
        # new_soft_limit, new_hard_limit, new_lifetime_limit
        new_limit_tuple = (
            client.period_limit, client.period_limit, client.lifetime_limit)
        action_limit_updates = {'api_request': new_limit_tuple}
        self._services.user.UpdateUserSettings(
            mar.cnxn, user_id, user_pb,
            action_limit_updates=action_limit_updates)

        logging.info('Updated api request limit %r', new_limit_tuple)

        # 3: Update project role and extra perms
        projects_dict = self._services.project.GetAllProjects(mar.cnxn)
        project_name_to_ids = {
            p.project_name: p.project_id for p in projects_dict.values()}

        # Set project role and extra perms
        for perm in client.project_permissions:
          project_ids = self._GetProjectIDs(perm.project, project_name_to_ids)
          logging.info('Matching projects %r for name %s',
                       project_ids, perm.project)

          role = role_dict[perm.role]
          for p_id in project_ids:
            project = projects_dict[p_id]
            people_list = []
            if role == 'owner':
              people_list = project.owner_ids
            elif role == 'committer':
              people_list = project.committer_ids
            elif role == 'contributor':
              people_list = project.contributor_ids
            # Only update role/extra perms iff changed
            if user_id not in people_list:
              logging.info('Update project %s role %s for user %s',
                           project.project_name, role, client.client_email)
              owner_ids, committer_ids, contributor_ids = (
                  project_helpers.MembersWithGivenIDs(
                      project, {user_id}, role))
              self._services.project.UpdateProjectRoles(
                  mar.cnxn, p_id, owner_ids, committer_ids,
                  contributor_ids)
            if perm.extra_permissions:
              logging.info('Update project %s extra perm %s for user %s',
                           project.project_name, perm.extra_permissions,
                           client.client_email)
              self._services.project.UpdateExtraPerms(
                  mar.cnxn, p_id, user_id, list(perm.extra_permissions))
    finally:
      # Always release the request's resources, including when the
      # permission or config checks above raise.
      mar.CleanUp()
    return message_types.VoidMessage()

  def _GetProjectIDs(self, project_str, project_name_to_ids):
    """Resolve a project name or pattern from a client config to project IDs.

    If project_str contains any of '*', '+', '?' or '.', it is compiled as a
    regular expression and tested with re.match (anchored at the start)
    against every project name. Otherwise it is looked up as an exact name.

    Args:
      project_str: project name or pattern.
      project_name_to_ids: dict mapping project name to project ID.

    Returns:
      List of matching project IDs; falsy IDs are skipped. A warning is
      logged when the list is empty.
    """
    result = []
    if any(ch in project_str for ch in ['*', '+', '?', '.']):
      pattern = re.compile(project_str)
      for p_name in project_name_to_ids.keys():
        if pattern.match(p_name):
          project_id = project_name_to_ids.get(p_name)
          if project_id:
            result.append(project_id)
    else:
      project_id = project_name_to_ids.get(project_str)
      if project_id:
        result.append(project_id)

    if not result:
      logging.warning('Cannot find projects for specified name %s',
                      project_str)
    return result
diff --git a/services/cachemanager_svc.py b/services/cachemanager_svc.py
new file mode 100644
index 0000000..8dc5753
--- /dev/null
+++ b/services/cachemanager_svc.py
@@ -0,0 +1,166 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""A simple in-RAM cache with distributed invalidation.
+
+Here's how it works:
+ + Each frontend or backend job has one CacheManager which
+ owns a set of RamCache objects, which are basically dictionaries.
+ + Each job can put objects in its own local cache, and retrieve them.
+ + When an item is modified, the item at the corresponding cache key
+ is invalidated, which means two things: (a) it is dropped from the
+ local RAM cache, and (b) the key is written to the Invalidate table.
+ + On each incoming request, the job checks the Invalidate table for
+ any entries added since the last time that it checked. If it finds
+ any, it drops all RamCache entries for the corresponding key.
+ + There is also a cron task that truncates old Invalidate entries
+ when the table is too large. If a frontend job sees more than the
+ max Invalidate rows, it will drop everything from all caches,
+ because it does not know what it missed due to truncation.
+ + The special key 0 means to drop all cache entries.
+
This approach ensures that, when a job begins processing a request, its
cached values are not stale. It does not keep them fresh afterwards: some
other job may modify an item while the request is being processed, so a
cached entry can become stale during the lifetime of that same request.
+
+TODO(jrobbins): Listener hook so that client code can register its own
+handler for invalidation events. E.g., the sorting code has a cache that
+is correctly invalidated on each issue change, but needs to be completely
+dropped when a config is modified.
+
+TODO(jrobbins): If this part of the system becomes a bottleneck, consider
+some optimizations: (a) splitting the table into multiple tables by
+kind, or (b) sharding the table by cache_key. Or, maybe leverage memcache
+to avoid even hitting the DB in the frequent case where nothing has changed.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import collections
+import logging
+
+from framework import jsonfeed
+from framework import sql
+
+
# Shared table to which every job writes invalidations and which it polls.
INVALIDATE_TABLE_NAME = 'Invalidate'
# Columns read back from the Invalidate table, in this order.
INVALIDATE_COLS = ['timestep', 'kind', 'cache_key']
# Kinds of cache that may be registered with, or invalidated through, a
# CacheManager.
# Note: *_id invalidations should happen only when there's a change
# in one of the values used to look up the internal ID number.
# E.g. hotlist_id_2lc should only be invalidated when the hotlist
# name or owner changes.
INVALIDATE_KIND_VALUES = [
    'user', 'usergroup', 'project', 'project_id', 'issue', 'issue_id',
    'hotlist', 'hotlist_id', 'comment', 'template'
]
# Special cache_key meaning "drop every entry of this kind".
INVALIDATE_ALL_KEYS = 0
# A job reads at most this many new Invalidate rows per check. Seeing that
# many means some rows may have been missed, so it drops every cache.
MAX_INVALIDATE_ROWS_TO_CONSIDER = 1000


class CacheManager(object):
  """Service class to manage RAM caches and shared Invalidate table."""

  def __init__(self):
    # kind -> list of caches registered for that kind.
    self.cache_registry = collections.defaultdict(list)
    # Highest Invalidate.timestep already applied to the local caches.
    self.processed_invalidations_up_to = 0
    self.invalidate_tbl = sql.SQLTableManager(INVALIDATE_TABLE_NAME)

  def RegisterCache(self, cache, kind):
    """Register a cache to be notified of future invalidations."""
    assert kind in INVALIDATE_KIND_VALUES
    self.cache_registry[kind].append(cache)

  def _InvalidateAllCaches(self):
    """Invalidate all cache entries."""
    for cache_list in self.cache_registry.values():
      for cache in cache_list:
        cache.LocalInvalidateAll()

  def _ProcessInvalidationRows(self, rows):
    """Invalidate cache entries indicated by database rows.

    Args:
      rows: iterable of (timestep, kind, cache_key) tuples. Duplicate
        (kind, cache_key) pairs are applied only once.
    """
    already_done = set()
    for timestep, kind, key in rows:
      self.processed_invalidations_up_to = max(
          self.processed_invalidations_up_to, timestep)
      if (kind, key) in already_done:
        continue
      already_done.add((kind, key))
      for cache in self.cache_registry[kind]:
        if key == INVALIDATE_ALL_KEYS:
          cache.LocalInvalidateAll()
        else:
          cache.LocalInvalidate(key)

  def DoDistributedInvalidation(self, cnxn):
    """Drop any cache entries that were invalidated by other jobs."""
    # Only consider a reasonable number of rows so that we can never
    # get bogged down on this step. If there are too many rows to
    # process, just invalidate all caches, and process the last group
    # of rows to update processed_invalidations_up_to.
    rows = self.invalidate_tbl.Select(
        cnxn, cols=INVALIDATE_COLS,
        where=[('timestep > %s', [self.processed_invalidations_up_to])],
        order_by=[('timestep DESC', [])],
        limit=MAX_INVALIDATE_ROWS_TO_CONSIDER)

    cnxn.Commit()

    if len(rows) == MAX_INVALIDATE_ROWS_TO_CONSIDER:
      logging.info('Invalidating all caches: there are too many invalidations')
      self._InvalidateAllCaches()

    logging.info('Saw %d invalidation rows', len(rows))
    self._ProcessInvalidationRows(rows)

  def StoreInvalidateRows(self, cnxn, kind, keys):
    """Store rows to let all jobs know to invalidate the given keys."""
    assert kind in INVALIDATE_KIND_VALUES
    self.invalidate_tbl.InsertRows(
        cnxn, ['kind', 'cache_key'], [(kind, key) for key in keys])

  def StoreInvalidateAll(self, cnxn, kind):
    """Store a value to tell all jobs to invalidate all items of this kind."""
    # Validate the kind exactly as RegisterCache and StoreInvalidateRows do.
    assert kind in INVALIDATE_KIND_VALUES
    last_timestep = self.invalidate_tbl.InsertRow(
        cnxn, kind=kind, cache_key=INVALIDATE_ALL_KEYS)
    # Older rows of this kind are superseded by the "all" marker just added.
    self.invalidate_tbl.Delete(
        cnxn, kind=kind, where=[('timestep < %s', [last_timestep])])
+
+
class RamCacheConsolidate(jsonfeed.InternalTask):
  """Cron task that trims the Invalidate table once it grows too large."""

  def HandleRequest(self, mr):
    """Keep only the newest Invalidate rows and report the row counts.

    Args:
      mr: common information parsed from the HTTP request.

    Returns:
      A dict with 'old_count' and 'new_count': the table's row count before
      and after trimming. These stats are only for debugging; no other part
      of the system reads them.
    """
    invalidate_tbl = self.services.cache_manager.invalidate_tbl
    count_before = invalidate_tbl.SelectValue(mr.cnxn, 'COUNT(*)')

    # Jobs never read past the newest MAX_INVALIDATE_ROWS_TO_CONSIDER rows.
    # A job that sees that many new rows drops all caches of all types, just
    # as if it had read INVALIDATE_ALL_KEYS entries, so older rows are dead
    # weight.
    if count_before > MAX_INVALIDATE_ROWS_TO_CONSIDER:
      newest_rows = invalidate_tbl.Select(
          mr.cnxn, ['timestep'],
          order_by=[('timestep DESC', [])],
          limit=MAX_INVALIDATE_ROWS_TO_CONSIDER)
      oldest_kept_timestep = newest_rows[-1][0]
      invalidate_tbl.Delete(
          mr.cnxn, where=[('timestep < %s', [oldest_kept_timestep])])

    count_after = invalidate_tbl.SelectValue(mr.cnxn, 'COUNT(*)')
    return {'old_count': count_before, 'new_count': count_after}
diff --git a/services/caches.py b/services/caches.py
new file mode 100644
index 0000000..07702bf
--- /dev/null
+++ b/services/caches.py
@@ -0,0 +1,514 @@
+# Copyright 2020 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+"""Classes to manage cached values.
+
+Monorail makes full use of the RAM of GAE frontends to reduce latency
+and load on the database.
+
+Even though these caches do invalidation, there are rare race conditions
+that can cause a somewhat stale object to be retrieved from memcache and
+then put into a RAM cache and used by a given GAE instance for some time.
+So, we only use these caches for operations that can tolerate somewhat
+stale data. For example, displaying issues in a list or displaying brief
+info about related issues. We never use the cache to load objects as
+part of a read-modify-save sequence because that could cause stored data
+to revert to a previous state.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import logging
+import redis
+
+from protorpc import protobuf
+
+from google.appengine.api import memcache
+
+import settings
+from framework import framework_constants
+from framework import redis_utils
+from proto import tracker_pb2
+
+
# Capacity (number of entries) a RamCache uses when created without an
# explicit max_size.
DEFAULT_MAX_SIZE = 10000


class RamCache(object):
  """An in-RAM cache with distributed invalidation.

  Entries are kept in a plain dict (self.cache). The Local* methods affect
  only this process. Invalidate, InvalidateKeys and InvalidateAll also
  record the invalidation through cache_manager, when one is set, so other
  jobs drop the entries too.
  """

  def __init__(self, cache_manager, kind, max_size=None):
    self.cache_manager = cache_manager
    self.kind = kind
    self.cache = {}
    self.max_size = max_size or DEFAULT_MAX_SIZE
    cache_manager.RegisterCache(self, kind)

  def CacheItem(self, key, item):
    """Store item at key, first evicting one entry if the cache is full.

    Overwriting a key that is already cached does not grow the cache, so
    nothing is evicted in that case.
    """
    if key not in self.cache and len(self.cache) >= self.max_size:
      # dict.popitem() chooses the victim; this is not an LRU policy.
      self.cache.popitem()

    self.cache[key] = item

  def CacheAll(self, new_item_dict):
    """Cache all items in the given dict, dropping old items if needed.

    If new_item_dict alone holds at least max_size items, the cache is
    emptied and then every new item is stored. The cache may then hold more
    than max_size entries.
    """
    if len(new_item_dict) >= self.max_size:
      # logging.warn is a deprecated alias, removed in Python 3.13.
      logging.warning('Dumping the entire cache! %s', self.kind)
      self.cache = {}
    else:
      while len(self.cache) + len(new_item_dict) > self.max_size:
        self.cache.popitem()

    self.cache.update(new_item_dict)

  def GetItem(self, key):
    """Return the cached item if present, otherwise None."""
    return self.cache.get(key)

  def HasItem(self, key):
    """Return True if there is a value cached at the given key."""
    return key in self.cache

  def GetAll(self, keys):
    """Look up the given keys.

    Args:
      keys: a list of cache keys to look up.

    Returns:
      A pair: (hits_dict, misses_list) where hits_dict is a dictionary of
      all the given keys and the values that were found in the cache, and
      misses_list is a list of given keys that were not in the cache.
    """
    hits, misses = {}, []
    for key in keys:
      try:
        hits[key] = self.cache[key]
      except KeyError:
        misses.append(key)

    return hits, misses

  def LocalInvalidate(self, key):
    """Drop the given key from this cache, without distributed notification."""
    if key in self.cache:
      logging.info('Locally invalidating %r in kind=%r', key, self.kind)
    self.cache.pop(key, None)

  def Invalidate(self, cnxn, key):
    """Drop key locally, and append it to the Invalidate DB table."""
    self.InvalidateKeys(cnxn, [key])

  def InvalidateKeys(self, cnxn, keys):
    """Drop keys locally, and append them to the Invalidate DB table."""
    for key in keys:
      self.LocalInvalidate(key)
    if self.cache_manager:
      self.cache_manager.StoreInvalidateRows(cnxn, self.kind, keys)

  def LocalInvalidateAll(self):
    """Invalidate all keys locally: just start over with an empty dict."""
    logging.info('Locally invalidating all in kind=%r', self.kind)
    self.cache = {}

  def InvalidateAll(self, cnxn):
    """Invalidate all keys in this cache."""
    self.LocalInvalidateAll()
    if self.cache_manager:
      self.cache_manager.StoreInvalidateAll(cnxn, self.kind)
+
+
class ShardedRamCache(RamCache):
  """RamCache variant whose keys are (main_key, shard_id) pairs.

  Values are stored in parts under keys such as (project_id, shard_id).
  Invalidating a main key drops every shard of it. For example, with the
  default num_shards=10, invalidating project_id 16 drops (16, 0), (16, 1),
  ..., (16, 9).
  """

  def __init__(self, cache_manager, kind, max_size=None, num_shards=10):
    super(ShardedRamCache, self).__init__(
        cache_manager, kind, max_size=max_size)
    self.num_shards = num_shards

  def LocalInvalidate(self, key):
    """Drop every shard of the given main key from the local cache only."""
    shard_keys = [(key, shard_id) for shard_id in range(self.num_shards)]
    present = [shard_key for shard_key in shard_keys
               if shard_key in self.cache]
    logging.info('About to invalidate shared RAM keys %r', present)
    for shard_key in shard_keys:
      self.cache.pop(shard_key, None)
+
+
class ValueCentricRamCache(RamCache):
  """RamCache variant that records cached values in the Invalidate table.

  This is useful for caches that have non-integer keys: the values, not
  the keys, are what gets written to the Invalidate table and broadcast.
  """

  def LocalInvalidate(self, value):
    """Drop, from the local cache only, every key currently mapped to value."""
    # Collect first so the dict is not mutated while being iterated.
    keys_to_drop = [k for k, v in self.cache.items() if v == value]
    for k in keys_to_drop:
      self.cache.pop(k, None)

  def InvalidateKeys(self, cnxn, keys):
    """Drop keys locally, and append their values to the Invalidate DB table.

    If any key is not currently cached its value is unknown, so the whole
    cache is invalidated instead.
    """
    # dict.has_key() no longer exists in Python 3; use the `in` operator.
    values = [self.cache[key] for key in keys if key in self.cache]
    if len(values) == len(keys):
      for value in values:
        self.LocalInvalidate(value)
      if self.cache_manager:
        self.cache_manager.StoreInvalidateRows(cnxn, self.kind, values)
    else:
      # If a value is not found in the cache then invalidate the whole cache.
      # This is done to ensure that we are not in an inconsistent state or in a
      # race condition.
      self.InvalidateAll(cnxn)
+
+
+class AbstractTwoLevelCache(object):
+ """A class to manage both RAM and secondary-caching layer to retrieve objects.
+
+ Subclasses must implement the FetchItems() method to get objects from
+ the database when both caches miss.
+ """
+
+ # When loading a huge number of issues from the database, do it in chunks
+ # so as to avoid timeouts.
+ _FETCH_BATCH_SIZE = 10000
+
+ def __init__(
+ self,
+ cache_manager,
+ kind,
+ prefix,
+ pb_class,
+ max_size=None,
+ use_redis=False,
+ redis_client=None):
+
+ self.cache = self._MakeCache(cache_manager, kind, max_size=max_size)
+ self.prefix = prefix
+ self.pb_class = pb_class
+
+ if use_redis:
+ self.redis_client = redis_client or redis_utils.CreateRedisClient()
+ self.use_redis = redis_utils.VerifyRedisConnection(
+ self.redis_client, msg=kind)
+ else:
+ self.redis_client = None
+ self.use_redis = False
+
def _MakeCache(self, cache_manager, kind, max_size=None):
  """Build the RAM-layer cache; its constructor registers it with cache_manager."""
  ram_cache = RamCache(cache_manager, kind, max_size=max_size)
  return ram_cache
+
+ def CacheItem(self, key, value):
+ """Add the given key-value pair to RAM and L2 cache."""
+ self.cache.CacheItem(key, value)
+ self._WriteToCache({key: value})
+
+ def HasItem(self, key):
+ """Return True if the given key is in the RAM cache."""
+ return self.cache.HasItem(key)
+
+ def GetAnyOnHandItem(self, keys, start=None, end=None):
+ """Try to find one of the specified items in RAM."""
+ if start is None:
+ start = 0
+ if end is None:
+ end = len(keys)
+ for i in range(start, end):
+ key = keys[i]
+ if self.cache.HasItem(key):
+ return self.cache.GetItem(key)
+
+ # Note: We could check L2 here too, but the round-trips to L2
+ # are kind of slow. And, getting too many hits from L2 actually
+ # fills our RAM cache too quickly and could lead to thrashing.
+
+ return None
+
+ def GetAll(self, cnxn, keys, use_cache=True, **kwargs):
+ """Get values for the given keys from RAM, the L2 cache, or the DB.
+
+ Args:
+ cnxn: connection to the database.
+ keys: list of integer keys to look up.
+ use_cache: set to False to always hit the database.
+ **kwargs: any additional keywords are passed to FetchItems().
+
+ Returns:
+ A pair: hits, misses. Where hits is {key: value} and misses is
+ a list of any keys that were not found anywhere.
+ """
+ if use_cache:
+ result_dict, missed_keys = self.cache.GetAll(keys)
+ else:
+ result_dict, missed_keys = {}, list(keys)
+
+ if missed_keys:
+ if use_cache:
+ cache_hits, missed_keys = self._ReadFromCache(missed_keys)
+ result_dict.update(cache_hits)
+ self.cache.CacheAll(cache_hits)
+
+ while missed_keys:
+ missed_batch = missed_keys[:self._FETCH_BATCH_SIZE]
+ missed_keys = missed_keys[self._FETCH_BATCH_SIZE:]
+ retrieved_dict = self.FetchItems(cnxn, missed_batch, **kwargs)
+ result_dict.update(retrieved_dict)
+ if use_cache:
+ self.cache.CacheAll(retrieved_dict)
+ self._WriteToCache(retrieved_dict)
+
+ still_missing_keys = [key for key in keys if key not in result_dict]
+ return result_dict, still_missing_keys
+
+ def LocalInvalidateAll(self):
+ self.cache.LocalInvalidateAll()
+
+ def LocalInvalidate(self, key):
+ self.cache.LocalInvalidate(key)
+
def InvalidateKeys(self, cnxn, keys):
  """Drop the given keys from both RAM and L2 cache.

  Args:
    cnxn: connection to the database, forwarded to the RAM cache.
    keys: keys to drop.
  """
  ram_cache = self.cache
  ram_cache.InvalidateKeys(cnxn, keys)
  self._DeleteFromCache(keys)

def InvalidateAllKeys(self, cnxn, keys):
  """Drop the given keys from L2 cache and invalidate all keys in RAM.

  Useful for avoiding inserting many rows into the Invalidate table when
  invalidating a large group of keys all at once. Only use when necessary.

  Args:
    cnxn: connection to the database, forwarded to the RAM cache.
    keys: keys to drop from the L2 cache.
  """
  ram_cache = self.cache
  ram_cache.InvalidateAll(cnxn)
  self._DeleteFromCache(keys)
+
def GetAllAlreadyInRam(self, keys):
  """Probe only the RAM cache.

  Returns:
    A pair (hits, misses) as reported by the RAM cache's GetAll().
  """
  hits, misses = self.cache.GetAll(keys)
  return hits, misses

def InvalidateAllRamEntries(self, cnxn):
  """Drop all RAM cache entries. It will refill as needed from L2 cache."""
  ram_cache = self.cache
  ram_cache.InvalidateAll(cnxn)
+
def FetchItems(self, cnxn, keys, **kwargs):
  """Load values for keys that missed both RAM and L2 from the database.

  Subclasses must override this; GetAll() calls it with batches of keys and
  expects a {key: value} dict back.

  Raises:
    NotImplementedError: always, in this base implementation.
  """
  raise NotImplementedError
+
+ def _ReadFromCache(self, keys):
+ # type: (Sequence[int]) -> Mapping[str, Any], Sequence[int]
+ """Reads a list of keys from secondary caching service.
+
+ Redis will be used if Redis is enabled and connection is valid;
+ otherwise, memcache will be used.
+
+ Args:
+ keys: List of integer keys to look up in L2 cache.
+
+ Returns:
+ A pair: hits, misses. Where hits is {key: value} and misses is
+ a list of any keys that were not found anywhere.
+ """
+ if self.use_redis:
+ return self._ReadFromRedis(keys)
+ else:
+ return self._ReadFromMemcache(keys)
+
+ def _WriteToCache(self, retrieved_dict):
+ # type: (Mapping[int, Any]) -> None
+ """Writes a set of key-value pairs to secondary caching service.
+
+ Redis will be used if Redis is enabled and connection is valid;
+ otherwise, memcache will be used.
+
+ Args:
+ retrieved_dict: Dictionary contains pairs of key-values to write to cache.
+ """
+ if self.use_redis:
+ return self._WriteToRedis(retrieved_dict)
+ else:
+ return self._WriteToMemcache(retrieved_dict)
+
+ def _DeleteFromCache(self, keys):
+ # type: (Sequence[int]) -> None
+ """Selects which cache to delete from.
+
+ Redis will be used if Redis is enabled and connection is valid;
+ otherwise, memcache will be used.
+
+ Args:
+ keys: List of integer keys to delete from cache.
+ """
+ if self.use_redis:
+ return self._DeleteFromRedis(keys)
+ else:
+ return self._DeleteFromMemcache(keys)
+
def _ReadFromMemcache(self, keys):
  # type: (Sequence[int]) -> Tuple[Mapping[int, Any], List[int]]
  """Look keys up in memcache, copying every hit into the RAM cache.

  Args:
    keys: keys to look up; each is converted with _KeyToStr.

  Returns:
    A pair (hits, misses): hits is {key: value} with values decoded by
    _StrToValue; misses lists, in input order, the keys not in memcache.
  """
  memcache_keys = [self._KeyToStr(key) for key in keys]
  found = memcache.get_multi(
      memcache_keys,
      key_prefix=self.prefix,
      namespace=settings.memcache_namespace)

  hits = {}
  for key_str, serialized in found.items():
    value = self._StrToValue(serialized)
    key = self._StrToKey(key_str)
    hits[key] = value
    self.cache.CacheItem(key, value)

  misses = [key for key in keys if key not in hits]
  return hits, misses
+
def _WriteToMemcache(self, retrieved_dict):
  # type: (Mapping[int, Any]) -> None
  """Add each key-value pair to memcache, serializing values with _ValueToStr.

  Keys are converted with _KeyToStr and entries expire after
  framework_constants.CACHE_EXPIRATION. If memcache rejects the batch with
  ValueError, the same keys are deleted so that no stale values are left.

  Args:
    retrieved_dict: {key: value} pairs to store.
  """
  strs_to_cache = {
      self._KeyToStr(key): self._ValueToStr(value)
      for key, value in retrieved_dict.items()}

  try:
    memcache.add_multi(
        strs_to_cache,
        key_prefix=self.prefix,
        time=framework_constants.CACHE_EXPIRATION,
        namespace=settings.memcache_namespace)
  except ValueError as identifier:
    # If memcache does not accept the values, ensure that no stale
    # values are left, then bail out.
    logging.error('Got memcache error: %r', identifier)
    # Pass the original keys: _DeleteFromMemcache applies _KeyToStr itself.
    # Passing already-converted strings would convert them twice, which
    # breaks if a subclass's _KeyToStr does not accept strings.
    self._DeleteFromMemcache(list(retrieved_dict.keys()))
+
def _DeleteFromMemcache(self, keys):
  # type: (Sequence[int]) -> None
  """Delete the given keys (converted with _KeyToStr) from memcache.

  seconds=5 is passed to memcache.delete_multi; per the App Engine memcache
  API this keeps the deleted keys locked against add() for that long.
  """
  str_keys = [self._KeyToStr(key) for key in keys]
  memcache.delete_multi(
      str_keys,
      seconds=5,
      key_prefix=self.prefix,
      namespace=settings.memcache_namespace)
+
def _WriteToRedis(self, retrieved_dict):
  # type: (Mapping[int, Any]) -> None
  """Store each key-value pair in Redis with SETEX, serializing values.

  Every entry expires after framework_constants.CACHE_EXPIRATION. If any
  write raises RedisError, all of the given keys are deleted from Redis
  and the success message is not logged.

  Args:
    retrieved_dict: Dictionary of key-value pairs to write to Redis.
  """
  try:
    for key, value in retrieved_dict.items():
      self.redis_client.setex(
          redis_utils.FormatRedisKey(key, prefix=self.prefix),
          framework_constants.CACHE_EXPIRATION,
          self._ValueToStr(value))
  except redis.RedisError as identifier:
    logging.error(
        'Redis error occurred during write operation: %s', identifier)
    self._DeleteFromRedis(list(retrieved_dict.keys()))
  else:
    logging.info(
        'cached batch of %d values in redis %s', len(retrieved_dict),
        self.prefix)
+
def _ReadFromRedis(self, keys):
  # type: (Sequence[int]) -> Tuple[Mapping[int, Any], List[int]]
  """Fetch keys from Redis with one MGET, copying hits into the RAM cache.

  A RedisError is logged and every key is then reported as a miss. Empty
  (falsy) values returned by Redis are also treated as misses.

  Args:
    keys: List of integer keys to read from Redis.

  Returns:
    A pair (hits, misses): hits is {key: value} and misses lists the keys
    that were not found.
  """
  hits = {}
  misses = []
  try:
    serialized_values = self.redis_client.mget(
        [redis_utils.FormatRedisKey(key, prefix=self.prefix) for key in keys])
  except redis.RedisError as identifier:
    logging.error(
        'Redis error occurred during read operation: %s', identifier)
    serialized_values = [None] * len(keys)

  for key, serialized in zip(keys, serialized_values):
    if not serialized:
      misses.append(key)
      continue
    value = self._StrToValue(serialized)
    hits[key] = value
    self.cache.CacheItem(key, value)

  logging.info(
      'decoded %d values from redis %s, missing %d', len(hits),
      self.prefix, len(misses))
  return hits, misses
+
+ def _DeleteFromRedis(self, keys):
+ # type: (Sequence[int]) -> None
+ """Delete key-values from redis.
+
+ Args:
+ keys: List of integer keys to delete.
+ """
+ try:
+ self.redis_client.delete(
+ *[
+ redis_utils.FormatRedisKey(key, prefix=self.prefix)
+ for key in keys
+ ])
+ except redis.RedisError as identifier:
+ logging.error(
+ 'Redis error occurred during delete operation %s', identifier)
+
+ def _KeyToStr(self, key):
+ # type: (int) -> str
+ """Convert our int IDs to strings for use as memcache keys."""
+ return str(key)
+
+ def _StrToKey(self, key_str):
+ # type: (str) -> int
+ """Convert memcache keys back to the ints that we use as IDs."""
+ return int(key_str)
+
+ def _ValueToStr(self, value):
+ # type: (Any) -> str
+ """Serialize an application object so that it can be stored in L2 cache."""
+ if self.use_redis:
+ return redis_utils.SerializeValue(value, pb_class=self.pb_class)
+ else:
+ if not self.pb_class:
+ return value
+ elif self.pb_class == int:
+ return str(value)
+ else:
+ return protobuf.encode_message(value)
+
+ def _StrToValue(self, serialized_value):
+ # type: (str) -> Any
+ """Deserialize L2 cache string into an application object."""
+ if self.use_redis:
+ return redis_utils.DeserializeValue(
+ serialized_value, pb_class=self.pb_class)
+ else:
+ if not self.pb_class:
+ return serialized_value
+ elif self.pb_class == int:
+ return int(serialized_value)
+ else:
+ return protobuf.decode_message(self.pb_class, serialized_value)
diff --git a/services/chart_svc.py b/services/chart_svc.py
new file mode 100644
index 0000000..49ccb51
--- /dev/null
+++ b/services/chart_svc.py
@@ -0,0 +1,411 @@
+# Copyright 2018 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""A service for querying data for charts.
+
+Functions for querying the IssueSnapshot table and associated join tables.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import logging
+import settings
+import time
+
+from features import hotlist_helpers
+from framework import framework_helpers
+from framework import sql
+from search import search_helpers
+from tracker import tracker_bizobj
+from tracker import tracker_helpers
+from search import query2ast
+from search import ast2select
+from search import ast2ast
+
+
# Names of the snapshot table and its join tables.
ISSUESNAPSHOT_TABLE_NAME = 'IssueSnapshot'
ISSUESNAPSHOT2CC_TABLE_NAME = 'IssueSnapshot2Cc'
ISSUESNAPSHOT2COMPONENT_TABLE_NAME = 'IssueSnapshot2Component'
ISSUESNAPSHOT2LABEL_TABLE_NAME = 'IssueSnapshot2Label'

# Column order matters: StoreIssueSnapshots inserts rows using
# ISSUESNAPSHOT_COLS[1:], i.e. every column except the generated 'id'.
ISSUESNAPSHOT_COLS = [
    'id', 'issue_id', 'shard', 'project_id', 'local_id', 'reporter_id',
    'owner_id', 'status_id', 'period_start', 'period_end', 'is_open',
]
ISSUESNAPSHOT2CC_COLS = ['issuesnapshot_id', 'cc_id']
ISSUESNAPSHOT2COMPONENT_COLS = ['issuesnapshot_id', 'component_id']
ISSUESNAPSHOT2LABEL_COLS = ['issuesnapshot_id', 'label_id']
+
+
class ChartService(object):
  """Class for querying chart data."""

  def __init__(self, config_service):
    """Constructor for ChartService.

    Args:
      config_service (ConfigService): An instance of ConfigService.
    """
    self.config_service = config_service

    # Set up SQL table objects.
    self.issuesnapshot_tbl = sql.SQLTableManager(ISSUESNAPSHOT_TABLE_NAME)
    self.issuesnapshot2cc_tbl = sql.SQLTableManager(
        ISSUESNAPSHOT2CC_TABLE_NAME)
    self.issuesnapshot2component_tbl = sql.SQLTableManager(
        ISSUESNAPSHOT2COMPONENT_TABLE_NAME)
    self.issuesnapshot2label_tbl = sql.SQLTableManager(
        ISSUESNAPSHOT2LABEL_TABLE_NAME)

  def QueryIssueSnapshots(self, cnxn, services, unixtime, effective_ids,
                          project, perms, group_by=None, label_prefix=None,
                          query=None, canned_query=None, hotlist=None):
    """Queries historical issue counts grouped by label or component.

    One query is built per logical shard, each run through a
    framework_helpers.Promise, and the per-shard counts are summed.

    Args:
      cnxn: A MonorailConnection instance.
      services: A Services instance.
      unixtime: An integer representing the Unix time in seconds.
      effective_ids: The effective User IDs associated with the current user.
      project: A project object representing the current project.
      perms: A permissions object associated with the current user.
      group_by (str, optional): Which dimension to group by. Values can
        be 'label', 'component', 'open', 'status', 'owner', or None, in
        which case no grouping will be applied.
      label_prefix: Required when group_by is 'label.' Will limit the query to
        only labels with the specified prefix (for example 'Pri').
      query (str, optional): A query string from the request to apply to
        the snapshot query.
      canned_query (str, optional): Parsed canned query applied to the query
        scope.
      hotlist (Hotlist, optional): Hotlist to search under (in lieu of project).

    Returns:
      1. A dict of {'2nd dimension or "total"': number of occurrences}.
      2. A list of field names used in unsupported query conditions, or
         ['Invalid query.'] if the query cannot match anything.
      3. A boolean that is true if any results were capped.

    Raises:
      ValueError: if group_by is 'label' without a label_prefix, or if
        group_by is not one of the supported values.
    """
    if hotlist:
      # TODO(jeffcarp): Get project_ids in a more efficient manner. We can
      # query for "SELECT DISTINCT(project_id)" for all issues in hotlist.
      issues_list = services.issue.GetIssues(
          cnxn, [hotlist_issue.issue_id for hotlist_issue in hotlist.items])
      hotlist_issues_project_ids = hotlist_helpers.GetAllProjectsOfIssues(
          list(issues_list))
      config_list = hotlist_helpers.GetAllConfigsOfProjects(
          cnxn, hotlist_issues_project_ids, services)
      project_config = tracker_bizobj.HarmonizeConfigs(config_list)
    else:
      project_config = services.config.GetProjectConfig(
          cnxn, project.project_id)

    if project:
      project_ids = [project.project_id]
    else:
      project_ids = hotlist_issues_project_ids

    try:
      query_left_joins, query_where, unsupported_conds = self._QueryToWhere(
          cnxn, services, project_config, query, canned_query, project_ids)
    except ast2select.NoPossibleResults:
      return {}, ['Invalid query.'], False

    restricted_label_ids = search_helpers.GetPersonalAtRiskLabelIDs(
        cnxn, None, self.config_service, effective_ids, project, perms)

    left_joins = [
        ('Issue ON IssueSnapshot.issue_id = Issue.id', []),
    ]

    if restricted_label_ids:
      left_joins.append(
          (('Issue2Label AS Forbidden_label'
            ' ON Issue.id = Forbidden_label.issue_id'
            ' AND Forbidden_label.label_id IN (%s)' % (
                sql.PlaceHolders(restricted_label_ids)
            )), restricted_label_ids))

    if effective_ids:
      left_joins.append(
          ('Issue2Cc AS I2cc'
           ' ON Issue.id = I2cc.issue_id'
           ' AND I2cc.cc_id IN (%s)' % sql.PlaceHolders(effective_ids),
           effective_ids))

    # TODO(jeffcarp): Handle case where there are issues with no labels.
    where = [
        ('IssueSnapshot.period_start <= %s', [unixtime]),
        ('IssueSnapshot.period_end > %s', [unixtime]),
        ('Issue.is_spam = %s', [False]),
        ('Issue.deleted = %s', [False]),
    ]
    if project_ids:
      where.append(
          ('IssueSnapshot.project_id IN (%s)' % sql.PlaceHolders(project_ids),
           project_ids))

    # Visibility restrictions only apply when the user has restricted labels.
    # Without any, every issue is visible, and the Forbidden_label alias is
    # not joined, so it must not be referenced.
    if restricted_label_ids:
      if effective_ids:
        # Issues the user reported, owns, or is CC'd on stay visible even
        # when they carry a restricted label.
        where.append(
            ('(Issue.reporter_id IN (%s)'
             ' OR Issue.owner_id IN (%s)'
             ' OR I2cc.cc_id IS NOT NULL'
             ' OR Forbidden_label.label_id IS NULL)' % (
                 sql.PlaceHolders(effective_ids),
                 sql.PlaceHolders(effective_ids)),
             list(effective_ids) + list(effective_ids)))
      else:
        where.append(('Forbidden_label.label_id IS NULL', []))

    # group_by_cols holds the SQL GROUP BY terms (None means no grouping).
    group_by_cols = None
    if group_by == 'component':
      cols = ['Comp.path', 'COUNT(IssueSnapshot.issue_id)']
      left_joins.extend([
          (('IssueSnapshot2Component AS Is2c ON'
            ' Is2c.issuesnapshot_id = IssueSnapshot.id'), []),
          ('ComponentDef AS Comp ON Comp.id = Is2c.component_id', []),
      ])
      group_by_cols = ['Comp.path']
    elif group_by == 'label':
      cols = ['Lab.label', 'COUNT(IssueSnapshot.issue_id)']
      left_joins.extend([
          (('IssueSnapshot2Label AS Is2l'
            ' ON Is2l.issuesnapshot_id = IssueSnapshot.id'), []),
          ('LabelDef AS Lab ON Lab.id = Is2l.label_id', []),
      ])

      if not label_prefix:
        raise ValueError('`label_prefix` required when grouping by label.')

      # TODO(jeffcarp): If LookupIDsOfLabelsMatching() is called on output,
      # ensure regex is case-insensitive.
      where.append(('LOWER(Lab.label) LIKE %s', [label_prefix.lower() + '-%']))
      group_by_cols = ['Lab.label']
    elif group_by == 'open':
      cols = ['IssueSnapshot.is_open',
              'COUNT(IssueSnapshot.issue_id) AS issue_count']
      group_by_cols = ['IssueSnapshot.is_open']
    elif group_by == 'status':
      left_joins.append(('StatusDef AS Stats ON '
                         'Stats.id = IssueSnapshot.status_id', []))
      cols = ['Stats.status', 'COUNT(IssueSnapshot.issue_id)']
      group_by_cols = ['Stats.status']
    elif group_by == 'owner':
      cols = ['IssueSnapshot.owner_id', 'COUNT(IssueSnapshot.issue_id)']
      group_by_cols = ['IssueSnapshot.owner_id']
    elif not group_by:
      cols = ['IssueSnapshot.issue_id']
    else:
      raise ValueError('`group_by` must be label, component, '
                       'open, status, owner or None.')

    if query_left_joins:
      left_joins.extend(query_left_joins)

    if query_where:
      where.extend(query_where)

    if hotlist:
      left_joins.extend([
          (('IssueSnapshot2Hotlist AS Is2h'
            ' ON Is2h.issuesnapshot_id = IssueSnapshot.id'
            ' AND Is2h.hotlist_id = %s'), [hotlist.hotlist_id]),
      ])
      where.append(
          ('Is2h.hotlist_id = %s', [hotlist.hotlist_id]))

    promises = []
    for shard_id in range(settings.num_logical_shards):
      count_stmt, stmt_args = self._BuildSnapshotQuery(
          cols=cols, where=where, joins=left_joins, group_by=group_by_cols,
          shard_id=shard_id)
      promises.append(framework_helpers.Promise(
          cnxn.Execute, count_stmt, stmt_args, shard_id=shard_id))

    shard_values_dict = {}
    search_limit_reached = False

    for promise in promises:
      # Wait for each query to complete and add it to the dict.
      shard_values = list(promise.WaitAndGetValue())

      if not shard_values:
        continue
      if group_by_cols:
        for name, count in shard_values:
          if count >= settings.chart_query_max_rows:
            search_limit_reached = True
          shard_values_dict[name] = shard_values_dict.get(name, 0) + count
      else:
        shard_total = shard_values[0][0]
        if shard_total >= settings.chart_query_max_rows:
          search_limit_reached = True
        shard_values_dict['total'] = (
            shard_values_dict.get('total', 0) + shard_total)

    unsupported_field_names = list({
        field.field_name
        for cond in unsupported_conds
        for field in cond.field_defs})

    return shard_values_dict, unsupported_field_names, search_limit_reached

  def StoreIssueSnapshots(self, cnxn, issues, commit=True):
    """Adds an IssueSnapshot and updates the previous one for each issue.

    For each issue, the snapshot whose period_end equals
    settings.maximum_snapshot_period_end is closed at the current time. A new
    snapshot is then inserted with its label, CC, component and hotlist rows.

    Args:
      cnxn: connection to SQL database.
      issues: list of Issue PBs to snapshot.
      commit: set to False to skip the DB commits and do them in a caller.
    """
    for issue in issues:
      right_now = self._currentTime()

      # Update previous snapshot of current issue's end time to right now.
      self.issuesnapshot_tbl.Update(
          cnxn,
          delta={'period_end': right_now},
          where=[('IssueSnapshot.issue_id = %s', [issue.issue_id]),
                 ('IssueSnapshot.period_end = %s',
                  [settings.maximum_snapshot_period_end])],
          commit=commit)

      config = self.config_service.GetProjectConfig(cnxn, issue.project_id)
      period_end = settings.maximum_snapshot_period_end
      status = tracker_bizobj.GetStatus(issue)
      is_open = tracker_helpers.MeansOpenInProject(status, config)
      shard = issue.issue_id % settings.num_logical_shards
      status_id = self.config_service.LookupStatusID(
          cnxn, issue.project_id, status) or None
      owner_id = tracker_bizobj.GetOwnerId(issue) or None

      issuesnapshot_rows = [(issue.issue_id, shard, issue.project_id,
                             issue.local_id, issue.reporter_id, owner_id,
                             status_id, right_now, period_end, is_open)]

      # The 'id' column is generated by the DB, so it is omitted here.
      ids = self.issuesnapshot_tbl.InsertRows(
          cnxn, ISSUESNAPSHOT_COLS[1:],
          issuesnapshot_rows,
          replace=True, commit=commit,
          return_generated_ids=True)
      issuesnapshot_id = ids[0]

      # Add all labels to IssueSnapshot2Label.
      label_rows = [
          (issuesnapshot_id,
           self.config_service.LookupLabelID(cnxn, issue.project_id, label))
          for label in tracker_bizobj.GetLabels(issue)
      ]
      self.issuesnapshot2label_tbl.InsertRows(
          cnxn, ISSUESNAPSHOT2LABEL_COLS,
          label_rows, replace=True, commit=commit)

      # Add all CCs to IssueSnapshot2Cc.
      cc_rows = [
          (issuesnapshot_id, cc_id)
          for cc_id in tracker_bizobj.GetCcIds(issue)
      ]
      self.issuesnapshot2cc_tbl.InsertRows(
          cnxn, ISSUESNAPSHOT2CC_COLS,
          cc_rows,
          replace=True, commit=commit)

      # Add all components to IssueSnapshot2Component.
      component_rows = [
          (issuesnapshot_id, component_id)
          for component_id in issue.component_ids
      ]
      self.issuesnapshot2component_tbl.InsertRows(
          cnxn, ISSUESNAPSHOT2COMPONENT_COLS,
          component_rows,
          replace=True, commit=commit)

      # Add all hotlists to IssueSnapshot2Hotlist.
      # This is raw SQL to obviate passing FeaturesService down through
      # the call stack wherever this function is called.
      # TODO(jrobbins): sort out dependencies between service classes.
      # Honor `commit` like the other writes above, so callers can batch.
      cnxn.Execute('''
        INSERT INTO IssueSnapshot2Hotlist (issuesnapshot_id, hotlist_id)
        SELECT %s, hotlist_id FROM Hotlist2Issue WHERE issue_id = %s
      ''', [issuesnapshot_id, issue.issue_id], commit=commit)

  def ExpungeHotlistsFromIssueSnapshots(self, cnxn, hotlist_ids, commit=True):
    """Expunge the existence of hotlists from issue snapshots.

    Args:
      cnxn: connection to SQL database.
      hotlist_ids: list of hotlist_ids for hotlists we want to delete. An
        empty list is a no-op.
      commit: set to False to skip the DB commit and do it in a caller.
    """
    if not hotlist_ids:
      # "IN ()" is not valid SQL, and there is nothing to delete anyway.
      return
    vals_ph = sql.PlaceHolders(hotlist_ids)
    cnxn.Execute(
        'DELETE FROM IssueSnapshot2Hotlist '
        'WHERE hotlist_id IN ({vals_ph})'.format(vals_ph=vals_ph),
        hotlist_ids,
        commit=commit)

  def _currentTime(self):
    """This is a separate method so it can be mocked by tests."""
    return time.time()

  def _QueryToWhere(self, cnxn, services, project_config, query, canned_query,
                    project_ids):
    """Parses a query string into LEFT JOIN and WHERE conditions.

    Args:
      cnxn: A MonorailConnection instance.
      services: A Services instance.
      project_config: The configuration for the given project.
      query (string): The query to parse.
      canned_query (string): The supplied canned query.
      project_ids: The current project ID(s).

    Returns:
      1. A list of LEFT JOIN clauses for the SQL query.
      2. A list of WHERE clauses for the SQL query.
      3. A list of query conditions that are unsupported with snapshots.
    """
    if not (query or canned_query):
      return [], [], []

    query = query or ''
    scope = canned_query or ''

    query_ast = query2ast.ParseUserQuery(
        query, scope, query2ast.BUILTIN_ISSUE_FIELDS, project_config)
    query_ast = ast2ast.PreprocessAST(
        cnxn, query_ast, project_ids, services, project_config)
    left_joins, where, unsupported = ast2select.BuildSQLQuery(
        query_ast, snapshot_mode=True)

    return left_joins, where, unsupported

  def _BuildSnapshotQuery(self, cols, where, joins, group_by, shard_id):
    """Builds (but does not execute) a snapshot query for one shard.

    Args:
      cols: columns to select from IssueSnapshot (with joins).
      where: list of (clause, args) WHERE terms.
      joins: list of (clause, args) LEFT JOIN terms.
      group_by: list of GROUP BY terms, or None for an ungrouped count.
      shard_id: logical shard to restrict the query to.

    Returns:
      A pair (stmt_str, stmt_args). Ungrouped queries are wrapped in an
      outer COUNT; grouping by IssueSnapshot.is_open maps the flag to
      "Opened"/"Closed".
    """
    stmt = sql.Statement.MakeSelect('IssueSnapshot', cols, distinct=True)
    stmt.AddJoinClauses(joins, left=True)
    stmt.AddWhereTerms(where + [('IssueSnapshot.shard = %s', [shard_id])])
    if group_by:
      stmt.AddGroupByTerms(group_by)
    stmt.SetLimitAndOffset(limit=settings.chart_query_max_rows, offset=0)
    stmt_str, stmt_args = stmt.Generate()
    if group_by:
      if group_by[0] == 'IssueSnapshot.is_open':
        count_stmt = ('SELECT IF(results.is_open = 1, "Opened", "Closed") '
                      'AS bool_open, results.issue_count '
                      'FROM (%s) AS results' % stmt_str)
      else:
        count_stmt = stmt_str
    else:
      count_stmt = 'SELECT COUNT(results.issue_id) FROM (%s) AS results' % (
          stmt_str)
    return count_stmt, stmt_args
diff --git a/services/client_config_svc.py b/services/client_config_svc.py
new file mode 100644
index 0000000..c0acf03
--- /dev/null
+++ b/services/client_config_svc.py
@@ -0,0 +1,236 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import base64
+import json
+import logging
+import os
+import time
+import urllib
+import webapp2
+
+from google.appengine.api import app_identity
+from google.appengine.api import urlfetch
+from google.appengine.ext import db
+from google.protobuf import text_format
+
+from infra_libs import ts_mon
+
+import settings
+from framework import framework_constants
+from proto import api_clients_config_pb2
+
+
# Local copy of the client config, read in local and unit-test modes.
CONFIG_FILE_PATH = os.path.join(
    os.path.dirname(os.path.dirname(os.path.realpath(__file__))),
    'testing',
    'api_clients.cfg',
)
# luci-config endpoint that serves the production client config.
LUCI_CONFIG_URL = (
    'https://luci-config.appspot.com/_ah/api/config/v1/'
    'config_sets/services/monorail-prod/config/api_clients.cfg')


# Lazily-populated module-level singletons and derived lookups; see the
# Get*() accessors at the bottom of this module.
client_config_svc = None
service_account_map = None
qpm_dict = None
allowed_origins_set = None
+
+
class ClientConfig(db.Model):
  """Datastore entity holding the raw text of the API client config.

  LoadApiClientConfigs stores it under the key name 'api_client_configs',
  and ClientConfigService._ReadFromDatastore parses it as a ClientCfg proto.
  """
  # Text-format ClientCfg proto.
  configs = db.TextProperty()
+
+
# Note: The cron job must have hit the servlet before this will work.
class LoadApiClientConfigs(webapp2.RequestHandler):
  """Fetches api_clients.cfg from luci-config and saves it to Datastore."""

  config_loads = ts_mon.CounterMetric(
      'monorail/client_config_svc/loads',
      'Results of fetches from luci-config.',
      [ts_mon.BooleanField('success'), ts_mon.StringField('type')])

  def get(self):
    """Fetch, validate, and store the client configs.

    Aborts with HTTP 500 if luci-config returns a non-200 status or the
    response cannot be decoded into a valid ClientCfg. On success, the
    config text is stored as the ClientConfig entity 'api_client_configs'
    and this module's derived lookups are reset.
    """
    global service_account_map
    global qpm_dict
    global allowed_origins_set
    authorization_token, _ = app_identity.get_access_token(
        framework_constants.OAUTH_SCOPE)
    response = urlfetch.fetch(
        LUCI_CONFIG_URL,
        method=urlfetch.GET,
        follow_redirects=False,
        headers={'Content-Type': 'application/json; charset=UTF-8',
                 'Authorization': 'Bearer ' + authorization_token})

    if response.status_code != 200:
      logging.error('Invalid response from luci-config: %r', response)
      self.config_loads.increment({'success': False, 'type': 'luci-cfg-error'})
      self.abort(500, 'Invalid response from luci-config')

    try:
      content_text = self._process_response(response)
    except Exception as e:
      self.abort(500, str(e))

    logging.info('luci-config content decoded: %r.', content_text)
    configs = ClientConfig(configs=content_text,
                           key_name='api_client_configs')
    configs.put()
    # Reset every derived lookup (service accounts, QPM limits, allowed
    # origins) so each is recomputed on next access.
    service_account_map = None
    qpm_dict = None
    allowed_origins_set = None
    self.config_loads.increment({'success': True, 'type': 'success'})

  def _process_response(self, response):
    """Decode and validate a luci-config response.

    Each failure is logged and counted in config_loads before the original
    exception is re-raised.

    Args:
      response: urlfetch response whose JSON body has a base64-encoded
        'content' field.

    Returns:
      The decoded config text, after checking that it parses as a ClientCfg.
    """
    try:
      content = json.loads(response.content)
    except ValueError:
      logging.error('Response was not JSON: %r', response.content)
      self.config_loads.increment({'success': False, 'type': 'json-load-error'})
      raise

    try:
      config_content = content['content']
    except KeyError:
      logging.error('JSON contained no content: %r', content)
      self.config_loads.increment({'success': False, 'type': 'json-key-error'})
      raise

    try:
      content_text = base64.b64decode(config_content)
    except (TypeError, ValueError):
      # Python 2 raises TypeError for malformed base64; Python 3 raises
      # binascii.Error, which subclasses ValueError.
      logging.error('Content was not b64: %r', config_content)
      self.config_loads.increment({'success': False,
                                   'type': 'b64-decode-error'})
      raise

    try:
      cfg = api_clients_config_pb2.ClientCfg()
      text_format.Merge(content_text, cfg)
    except Exception:
      logging.error('Content was not a valid ClientCfg proto: %r', content_text)
      self.config_loads.increment({'success': False,
                                   'type': 'proto-load-error'})
      raise

    return content_text
+
+
class ClientConfigService(object):
  """Loads and caches the API client configs (a ClientCfg proto).

  Configs are read from CONFIG_FILE_PATH in local and unit-test modes, and
  from the Datastore ClientConfig entity otherwise.
  """

  # Reload no more than once every 15 minutes.
  # Different GAE instances can load it at different times,
  # so clients may get inconsistent responses shortly after allowlisting.
  EXPIRES_IN = 15 * framework_constants.SECS_PER_MINUTE

  def __init__(self):
    # Parsed ClientCfg proto; None until the first successful load.
    self.client_configs = None
    # int(time.time()) at the last successful load.
    self.load_time = 0

  def GetConfigs(self, use_cache=True, cur_time=None):
    """Return the client configs, reloading them when needed.

    A reload happens when nothing is loaded yet, when use_cache is False,
    or when more than EXPIRES_IN seconds have passed since the last load.

    Args:
      use_cache: set to False to force a reload.
      cur_time: optional current Unix time in seconds.

    Returns:
      The ClientCfg proto, or None if no load has succeeded yet.
    """
    now = cur_time or int(time.time())
    needs_load = (
        not self.client_configs
        or not use_cache
        or now - self.load_time > self.EXPIRES_IN)

    if needs_load:
      if settings.local_mode or settings.unit_test_mode:
        self._ReadFromFilesystem()
      else:
        self._ReadFromDatastore()

    return self.client_configs

  def _ReadFromFilesystem(self):
    """Parse CONFIG_FILE_PATH into self.client_configs; log any failure."""
    try:
      with open(CONFIG_FILE_PATH, 'r') as config_file:
        text = config_file.read()
      logging.info('Read client configs from local file.')
      parsed = api_clients_config_pb2.ClientCfg()
      text_format.Merge(text, parsed)
      self.client_configs = parsed
      self.load_time = int(time.time())
    except Exception as e:
      logging.exception('Failed to read client configs: %s', e)

  def _ReadFromDatastore(self):
    """Parse the 'api_client_configs' ClientConfig entity, if present."""
    entity = ClientConfig.get_by_key_name('api_client_configs')
    if not entity:
      logging.error('Failed to get api client configs from datastore.')
      return
    parsed = api_clients_config_pb2.ClientCfg()
    text_format.Merge(entity.configs, parsed)
    self.client_configs = parsed
    self.load_time = int(time.time())

  def GetClientIDEmails(self):
    """Return (client_ids, client_emails), one entry per configured client."""
    self.GetConfigs(use_cache=True)
    clients = self.client_configs.clients
    return ([c.client_id for c in clients],
            [c.client_email for c in clients])

  def GetDisplayNames(self):
    """Return {client_email: display_name} for clients with a display name."""
    self.GetConfigs(use_cache=True)
    return {
        client.client_email: client.display_name
        for client in self.client_configs.clients
        if client.display_name}

  def GetQPM(self):
    """Return {client_email: qpm_limit} for clients that set qpm_limit."""
    self.GetConfigs(use_cache=True)
    return {
        client.client_email: client.qpm_limit
        for client in self.client_configs.clients
        if client.HasField('qpm_limit')}

  def GetAllowedOriginsSet(self):
    """Return the union of every client's allowed_origins."""
    self.GetConfigs(use_cache=True)
    return {
        origin
        for client in self.client_configs.clients
        for origin in client.allowed_origins}
+
+
def GetClientConfigSvc():
  """Return the module-wide ClientConfigService, creating it on first use."""
  global client_config_svc
  if client_config_svc is not None:
    return client_config_svc
  client_config_svc = ClientConfigService()
  return client_config_svc
+
+
def GetServiceAccountMap():
  # type: () -> Mapping[str, str]
  """Returns only service accounts that have specified display_names.

  The {client_email: display_name} map is computed once and cached at
  module level until LoadApiClientConfigs resets it.
  """
  global service_account_map
  if service_account_map is None:
    service_account_map = GetClientConfigSvc().GetDisplayNames()
  return service_account_map
+
+
def GetQPMDict():
  """Return {client_email: qpm_limit}, cached at module level until reset."""
  global qpm_dict
  if qpm_dict is not None:
    return qpm_dict
  qpm_dict = GetClientConfigSvc().GetQPM()
  return qpm_dict
+
+
def GetAllowedOriginsSet():
  """Return the set of allowed origins, cached at module level until reset."""
  global allowed_origins_set
  if allowed_origins_set is not None:
    return allowed_origins_set
  allowed_origins_set = GetClientConfigSvc().GetAllowedOriginsSet()
  return allowed_origins_set
diff --git a/services/config_svc.py b/services/config_svc.py
new file mode 100644
index 0000000..27c1d3a
--- /dev/null
+++ b/services/config_svc.py
@@ -0,0 +1,1499 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Classes and functions for persistence of issue tracker configuration.
+
+This module provides functions to get, update, create, and (in some
+cases) delete each type of business object. It provides a logical
+persistence layer on top of an SQL database.
+
+Business objects are described in tracker_pb2.py and tracker_bizobj.py.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import collections
+import logging
+
+from google.appengine.api import memcache
+
+import settings
+from framework import exceptions
+from framework import framework_constants
+from framework import sql
+from proto import tracker_pb2
+from services import caches
+from services import project_svc
+from tracker import tracker_bizobj
+from tracker import tracker_constants
+
+
# Names of the SQL tables managed by ConfigService.
PROJECTISSUECONFIG_TABLE_NAME = 'ProjectIssueConfig'
LABELDEF_TABLE_NAME = 'LabelDef'
FIELDDEF_TABLE_NAME = 'FieldDef'
FIELDDEF2ADMIN_TABLE_NAME = 'FieldDef2Admin'
FIELDDEF2EDITOR_TABLE_NAME = 'FieldDef2Editor'
COMPONENTDEF_TABLE_NAME = 'ComponentDef'
COMPONENT2ADMIN_TABLE_NAME = 'Component2Admin'
COMPONENT2CC_TABLE_NAME = 'Component2Cc'
COMPONENT2LABEL_TABLE_NAME = 'Component2Label'
STATUSDEF_TABLE_NAME = 'StatusDef'
APPROVALDEF2APPROVER_TABLE_NAME = 'ApprovalDef2Approver'
APPROVALDEF2SURVEY_TABLE_NAME = 'ApprovalDef2Survey'

# Column lists for each table. Order matters: result rows are unpacked
# positionally in exactly this order.
PROJECTISSUECONFIG_COLS = [
    'project_id',
    'statuses_offer_merge',
    'exclusive_label_prefixes',
    'default_template_for_developers',
    'default_template_for_users',
    'default_col_spec',
    'default_sort_spec',
    'default_x_attr',
    'default_y_attr',
    'member_default_query',
    'custom_issue_entry_url',
]
STATUSDEF_COLS = [
    'id',
    'project_id',
    'rank',
    'status',
    'means_open',
    'docstring',
    'deprecated',
]
LABELDEF_COLS = [
    'id',
    'project_id',
    'rank',
    'label',
    'docstring',
    'deprecated',
]
FIELDDEF_COLS = [
    'id',
    'project_id',
    'rank',
    'field_name',
    'field_type',
    'applicable_type',
    'applicable_predicate',
    'is_required',
    'is_niche',
    'is_multivalued',
    'min_value',
    'max_value',
    'regex',
    'needs_member',
    'needs_perm',
    'grants_perm',
    'notify_on',
    'date_action',
    'docstring',
    'is_deleted',
    'approval_id',
    'is_phase_field',
    'is_restricted_field',
]
FIELDDEF2ADMIN_COLS = ['field_id', 'admin_id']
FIELDDEF2EDITOR_COLS = ['field_id', 'editor_id']
COMPONENTDEF_COLS = [
    'id',
    'project_id',
    'path',
    'docstring',
    'deprecated',
    'created',
    'creator_id',
    'modified',
    'modifier_id',
]
COMPONENT2ADMIN_COLS = ['component_id', 'admin_id']
COMPONENT2CC_COLS = ['component_id', 'cc_id']
COMPONENT2LABEL_COLS = ['component_id', 'label_id']
APPROVALDEF2APPROVER_COLS = ['approval_id', 'approver_id', 'project_id']
APPROVALDEF2SURVEY_COLS = ['approval_id', 'survey', 'project_id']

# String values of enum-like DB columns.
NOTIFY_ON_ENUM = ['never', 'any_comment']
DATE_ACTION_ENUM = ['no_action', 'ping_owner_only', 'ping_participants']

# Some projects have tons of label rows, so we retrieve them in shards
# to avoid huge DB results or exceeding the memcache size limit.
LABEL_ROW_SHARDS = 10
+
+
class LabelRowTwoLevelCache(caches.AbstractTwoLevelCache):
  """Class to manage RAM and memcache for label rows.

  Label rows exist for every label used in a project, even those labels
  that were added to issues in an ad hoc way without being defined in the
  config ahead of time.

  The set of all labels in a project can be very large, so the rows are
  sharded into LABEL_ROW_SHARDS parts so that each part can be cached in
  memcache with < 1MB. Keys are (project_id, shard_id) pairs, where
  shard_id = label_id % LABEL_ROW_SHARDS.
  """

  def __init__(self, cache_manager, config_service):
    super(LabelRowTwoLevelCache, self).__init__(
        cache_manager, 'project', 'label_rows:', None)
    self.config_service = config_service

  def _MakeCache(self, cache_manager, kind, max_size=None):
    """Make the RAM cache and register it with the cache_manager."""
    return caches.ShardedRamCache(
        cache_manager, kind, max_size=max_size, num_shards=LABEL_ROW_SHARDS)

  @staticmethod
  def _RowSortKey(row):
    """Sort key for a label row: (has_rank, rank, label name).

    row[2] is rank and row[3] is the label name. Rank can be NULL (None),
    and None is not orderable against ints on Python 3. Unranked rows are
    therefore keyed as lower than every ranked row, which is the same order
    Python 2 produced for (None, ...) tuples.
    """
    rank = row[2]
    return (rank is not None, rank if rank is not None else 0, row[3])

  def _DeserializeLabelRows(self, label_def_rows):
    """Convert DB result rows into {(project_id, shard_id): [row, ...]}."""
    result_dict = collections.defaultdict(list)
    for label_id, project_id, rank, label, docstr, deprecated in label_def_rows:
      shard_id = label_id % LABEL_ROW_SHARDS
      result_dict[(project_id, shard_id)].append(
          (label_id, project_id, rank, label, docstr, deprecated))

    return result_dict

  def FetchItems(self, cnxn, keys):
    """On RAM and memcache miss, hit the database.

    Args:
      cnxn: connection to SQL database.
      keys: list of (project_id, shard_id) pairs.

    Returns:
      A dict {(project_id, shard_id): [row, ...]} with an entry (possibly an
      empty list) for every requested key. Each list is sorted by rank and
      then label name, both descending, with unranked rows last.
    """
    # Make sure that every requested shard is represented in the result.
    label_rows_dict = {key: [] for key in keys}

    for project_id, shard_id in keys:
      # '%%' survives the SQL layer's own %-formatting as a literal modulo.
      shard_clause = [('id %% %s = %s', [LABEL_ROW_SHARDS, shard_id])]
      label_def_rows = self.config_service.labeldef_tbl.Select(
          cnxn, cols=LABELDEF_COLS, project_id=project_id,
          where=shard_clause)
      label_rows_dict.update(self._DeserializeLabelRows(label_def_rows))

    for rows_in_shard in label_rows_dict.values():
      rows_in_shard.sort(key=self._RowSortKey, reverse=True)

    return label_rows_dict

  def _DeleteProjectShardsFromMemcache(self, project_ids):
    """Delete the memcache entry of every shard of the given projects."""
    memcache.delete_multi(
        [
            self._KeyToStr((project_id, shard_id))
            for project_id in project_ids
            for shard_id in range(LABEL_ROW_SHARDS)
        ],
        seconds=5,
        key_prefix=self.prefix,
        namespace=settings.memcache_namespace)

  def InvalidateKeys(self, cnxn, project_ids):
    """Invalidate the label rows of the given projects in RAM and memcache."""
    # NOTE(review): the RAM cache is given bare project IDs although its keys
    # are (project_id, shard_id); presumably ShardedRamCache expands them to
    # every shard -- confirm in services/caches.py.
    self.cache.InvalidateKeys(cnxn, project_ids)
    self._DeleteProjectShardsFromMemcache(project_ids)

  def InvalidateAllKeys(self, cnxn, project_ids):
    """Drop the given keys from memcache and invalidate all keys in RAM.

    Useful for avoiding inserting many rows into the Invalidate table when
    invalidating a large group of keys all at once. Only use when necessary.
    """
    self.cache.InvalidateAll(cnxn)
    self._DeleteProjectShardsFromMemcache(project_ids)

  def _KeyToStr(self, key):
    """Convert our tuple IDs to strings for use as memcache keys."""
    project_id, shard_id = key
    return '%d-%d' % (project_id, shard_id)

  def _StrToKey(self, key_str):
    """Convert memcache keys back to the tuples that we use as IDs."""
    project_id_str, shard_id_str = key_str.split('-')
    return int(project_id_str), int(shard_id_str)
+
+
class StatusRowTwoLevelCache(caches.AbstractTwoLevelCache):
  """Class to manage RAM and memcache for status rows.

  Entries are keyed by project_id. Each value is the list of StatusDef rows
  for that project.
  """

  def __init__(self, cache_manager, config_service):
    super(StatusRowTwoLevelCache, self).__init__(
        cache_manager, 'project', 'status_rows:', None)
    self.config_service = config_service

  def _DeserializeStatusRows(self, def_rows):
    """Group StatusDef rows by project: {project_id: [row, ...]}.

    Each row is re-packed as a 7-tuple in STATUSDEF_COLS order.
    """
    grouped = collections.defaultdict(list)
    for record in def_rows:
      (status_id, project_id, rank, status,
       means_open, docstr, deprecated) = record
      grouped[project_id].append(
          (status_id, project_id, rank, status, means_open, docstr, deprecated))
    return grouped

  def FetchItems(self, cnxn, keys):
    """On cache miss, load StatusDef rows for the given projects from the DB.

    Every requested project ID appears in the result. A project with no
    status definitions maps to an empty list.
    """
    rows = self.config_service.statusdef_tbl.Select(
        cnxn, cols=STATUSDEF_COLS, project_id=keys,
        order_by=[('rank DESC', []), ('status DESC', [])])
    by_project = self._DeserializeStatusRows(rows)
    for pid in keys:
      by_project.setdefault(pid, [])
    return by_project
+
+
class FieldRowTwoLevelCache(caches.AbstractTwoLevelCache):
  """Class to manage RAM and memcache for field rows.

  Field rows exist for every field used in a project, since they cannot be
  created through ad-hoc means. Entries are keyed by project_id.
  """

  def __init__(self, cache_manager, config_service):
    super(FieldRowTwoLevelCache, self).__init__(
        cache_manager, 'project', 'field_rows:', None)
    self.config_service = config_service

  def _DeserializeFieldRows(self, field_def_rows):
    """Group FieldDef rows by project: {project_id: [row, ...]}.

    Only (field_id, project_id, rank, field_name, docstring) is kept from
    each full FIELDDEF_COLS row.
    """
    grouped = collections.defaultdict(list)
    # TODO: Actually process the rest of the items.
    for record in field_def_rows:
      (field_id, project_id, rank, field_name, _field_type, _applicable_type,
       _applicable_predicate, _is_required, _is_niche, _is_multivalued,
       _min_value, _max_value, _regex, _needs_member, _needs_perm,
       _grants_perm, _notify_on, _date_action, docstring, _is_deleted,
       _approval_id, _is_phase_field, _is_restricted_field) = record
      grouped[project_id].append(
          (field_id, project_id, rank, field_name, docstring))
    return grouped

  def FetchItems(self, cnxn, keys):
    """On RAM and memcache miss, load FieldDef rows from the database.

    Every requested project ID appears in the result. A project with no
    field definitions maps to an empty list.
    """
    rows = self.config_service.fielddef_tbl.Select(
        cnxn, cols=FIELDDEF_COLS, project_id=keys,
        order_by=[('rank DESC', []), ('field_name DESC', [])])
    by_project = self._DeserializeFieldRows(rows)
    for pid in keys:
      by_project.setdefault(pid, [])
    return by_project
+
+
class ConfigTwoLevelCache(caches.AbstractTwoLevelCache):
  """Class to manage RAM and memcache for ProjectIssueConfig PBs.

  Entries are keyed by project_id. On a miss, each config is assembled from
  the ProjectIssueConfig table and the status, label, field, component, and
  approval definition tables.
  """

  def __init__(self, cache_manager, config_service):
    super(ConfigTwoLevelCache, self).__init__(
        cache_manager, 'project', 'config:', tracker_pb2.ProjectIssueConfig)
    self.config_service = config_service

  def _UnpackProjectIssueConfig(self, config_row):
    """Partially construct a config object using info from a DB row.

    Args:
      config_row: a row tuple in PROJECTISSUECONFIG_COLS order.

    Returns:
      A ProjectIssueConfig PB without statuses, labels, fields, components,
      or approvals; _DeserializeIssueConfigs fills those in.
    """
    (project_id, statuses_offer_merge, exclusive_label_prefixes,
     default_template_for_developers, default_template_for_users,
     default_col_spec, default_sort_spec, default_x_attr, default_y_attr,
     member_default_query, custom_issue_entry_url) = config_row
    config = tracker_pb2.ProjectIssueConfig()
    config.project_id = project_id
    # These list fields are stored space-separated (see StoreConfig).
    config.statuses_offer_merge.extend(statuses_offer_merge.split())
    config.exclusive_label_prefixes.extend(exclusive_label_prefixes.split())
    config.default_template_for_developers = default_template_for_developers
    config.default_template_for_users = default_template_for_users
    config.default_col_spec = default_col_spec
    config.default_sort_spec = default_sort_spec
    config.default_x_attr = default_x_attr
    config.default_y_attr = default_y_attr
    config.member_default_query = member_default_query
    if custom_issue_entry_url is not None:
      config.custom_issue_entry_url = custom_issue_entry_url

    return config

  def _UnpackFieldDef(self, fielddef_row):
    """Partially construct a FieldDef object using info from a DB row.

    Args:
      fielddef_row: a row tuple in FIELDDEF_COLS order.

    Returns:
      A FieldDef PB without admin_ids/editor_ids; those are added by
      _DeserializeIssueConfigs.
    """
    (
        field_id, project_id, _rank, field_name, field_type, applic_type,
        applic_pred, is_required, is_niche, is_multivalued, min_value,
        max_value, regex, needs_member, needs_perm, grants_perm, notify_on_str,
        date_action_str, docstring, is_deleted, approval_id, is_phase_field,
        is_restricted_field) = fielddef_row
    if notify_on_str == 'any_comment':
      notify_on = tracker_pb2.NotifyTriggers.ANY_COMMENT
    else:
      notify_on = tracker_pb2.NotifyTriggers.NEVER
    # Unknown date_action strings fall back to 'no_action'.
    try:
      date_action = DATE_ACTION_ENUM.index(date_action_str)
    except ValueError:
      date_action = DATE_ACTION_ENUM.index('no_action')

    return tracker_bizobj.MakeFieldDef(
        field_id, project_id, field_name,
        tracker_pb2.FieldTypes(field_type.upper()), applic_type, applic_pred,
        is_required, is_niche, is_multivalued, min_value, max_value, regex,
        needs_member, needs_perm, grants_perm, notify_on, date_action,
        docstring, is_deleted, approval_id, is_phase_field, is_restricted_field)

  def _UnpackComponentDef(
      self, cd_row, component2admin_rows, component2cc_rows,
      component2label_rows):
    """Construct a ComponentDef object using info from DB rows.

    Args:
      cd_row: a row tuple in COMPONENTDEF_COLS order.
      component2admin_rows: (component_id, admin_id) rows.
      component2cc_rows: (component_id, cc_id) rows.
      component2label_rows: (component_id, label_id) rows.
        Join rows for other components are ignored, so callers may pass
        either all rows or only the rows for this component.

    Returns:
      A ComponentDef PB.
    """
    (component_id, project_id, path, docstring, deprecated, created,
     creator_id, modified, modifier_id) = cd_row
    cd = tracker_bizobj.MakeComponentDef(
        component_id, project_id, path, docstring, deprecated,
        [admin_id for comp_id, admin_id in component2admin_rows
         if comp_id == component_id],
        [cc_id for comp_id, cc_id in component2cc_rows
         if comp_id == component_id],
        created, creator_id,
        modified=modified, modifier_id=modifier_id,
        label_ids=[label_id for comp_id, label_id in component2label_rows
                   if comp_id == component_id])

    return cd

  @staticmethod
  def _GroupRowsByFirstColumn(rows):
    """Return {row[0]: [row, ...]}, keeping rows in their original order."""
    grouped = collections.defaultdict(list)
    for row in rows:
      grouped[row[0]].append(row)
    return grouped

  def _DeserializeIssueConfigs(
      self, config_rows, statusdef_rows, labeldef_rows, fielddef_rows,
      fielddef2admin_rows, fielddef2editor_rows, componentdef_rows,
      component2admin_rows, component2cc_rows, component2label_rows,
      approvaldef2approver_rows, approvaldef2survey_rows):
    """Convert the given row tuples into a dict of ProjectIssueConfig PBs.

    Returns:
      {project_id: ProjectIssueConfig} for each project that has a row in
      config_rows. Status, label, and approval rows for other projects are
      skipped.
    """
    result_dict = {}
    fielddef_dict = {}
    approvaldef_dict = {}

    for config_row in config_rows:
      config = self._UnpackProjectIssueConfig(config_row)
      result_dict[config.project_id] = config

    for statusdef_row in statusdef_rows:
      (_, project_id, _rank, status,
       means_open, docstring, deprecated) = statusdef_row
      if project_id in result_dict:
        wks = tracker_pb2.StatusDef(
            status=status, means_open=bool(means_open),
            status_docstring=docstring or '', deprecated=bool(deprecated))
        result_dict[project_id].well_known_statuses.append(wks)

    for labeldef_row in labeldef_rows:
      _, project_id, _rank, label, docstring, deprecated = labeldef_row
      if project_id in result_dict:
        wkl = tracker_pb2.LabelDef(
            label=label, label_docstring=docstring or '',
            deprecated=bool(deprecated))
        result_dict[project_id].well_known_labels.append(wkl)

    # An ApprovalDef is created by whichever of its approver or survey rows
    # is seen first; later rows reuse it via approvaldef_dict.
    for approver_row in approvaldef2approver_rows:
      approval_id, approver_id, project_id = approver_row
      if project_id in result_dict:
        approval_def = approvaldef_dict.get(approval_id)
        if approval_def is None:
          approval_def = tracker_pb2.ApprovalDef(
              approval_id=approval_id)
          result_dict[project_id].approval_defs.append(approval_def)
          approvaldef_dict[approval_id] = approval_def
        approval_def.approver_ids.append(approver_id)

    for survey_row in approvaldef2survey_rows:
      approval_id, survey, project_id = survey_row
      if project_id in result_dict:
        approval_def = approvaldef_dict.get(approval_id)
        if approval_def is None:
          approval_def = tracker_pb2.ApprovalDef(
              approval_id=approval_id)
          result_dict[project_id].approval_defs.append(approval_def)
          approvaldef_dict[approval_id] = approval_def
        approval_def.survey = survey

    # NOTE(review): unlike statuses/labels/approvals, field and component
    # rows are not filtered by project; a project with such rows but no
    # config row raises KeyError here.
    for fd_row in fielddef_rows:
      fd = self._UnpackFieldDef(fd_row)
      result_dict[fd.project_id].field_defs.append(fd)
      fielddef_dict[fd.field_id] = fd

    for fd2admin_row in fielddef2admin_rows:
      field_id, admin_id = fd2admin_row
      fd = fielddef_dict.get(field_id)
      if fd:
        fd.admin_ids.append(admin_id)

    for fd2editor_row in fielddef2editor_rows:
      field_id, editor_id = fd2editor_row
      fd = fielddef_dict.get(field_id)
      if fd:
        fd.editor_ids.append(editor_id)

    # Index the component join rows by component ID once. Without this,
    # building each ComponentDef rescans every join row, which is quadratic
    # in the number of components.
    admins_by_cid = self._GroupRowsByFirstColumn(component2admin_rows)
    ccs_by_cid = self._GroupRowsByFirstColumn(component2cc_rows)
    labels_by_cid = self._GroupRowsByFirstColumn(component2label_rows)
    for cd_row in componentdef_rows:
      component_id = cd_row[0]
      cd = self._UnpackComponentDef(
          cd_row, admins_by_cid.get(component_id, []),
          ccs_by_cid.get(component_id, []),
          labels_by_cid.get(component_id, []))
      result_dict[cd.project_id].component_defs.append(cd)

    return result_dict

  def _FetchConfigs(self, cnxn, project_ids):
    """On RAM and memcache miss, hit the database.

    Returns:
      {project_id: ProjectIssueConfig} for projects that have a stored
      ProjectIssueConfig row.
    """
    config_rows = self.config_service.projectissueconfig_tbl.Select(
        cnxn, cols=PROJECTISSUECONFIG_COLS, project_id=project_ids)
    statusdef_rows = self.config_service.statusdef_tbl.Select(
        cnxn, cols=STATUSDEF_COLS, project_id=project_ids,
        where=[('rank IS NOT NULL', [])], order_by=[('rank', [])])

    labeldef_rows = self.config_service.labeldef_tbl.Select(
        cnxn, cols=LABELDEF_COLS, project_id=project_ids,
        where=[('rank IS NOT NULL', [])], order_by=[('rank', [])])

    approver_rows = self.config_service.approvaldef2approver_tbl.Select(
        cnxn, cols=APPROVALDEF2APPROVER_COLS, project_id=project_ids)
    survey_rows = self.config_service.approvaldef2survey_tbl.Select(
        cnxn, cols=APPROVALDEF2SURVEY_COLS, project_id=project_ids)

    # TODO(jrobbins): For now, sort by field name, but someday allow admins
    # to adjust the rank to group and order field definitions logically.
    fielddef_rows = self.config_service.fielddef_tbl.Select(
        cnxn, cols=FIELDDEF_COLS, project_id=project_ids,
        order_by=[('field_name', [])])
    field_ids = [row[0] for row in fielddef_rows]
    fielddef2admin_rows = []
    fielddef2editor_rows = []
    if field_ids:
      fielddef2admin_rows = self.config_service.fielddef2admin_tbl.Select(
          cnxn, cols=FIELDDEF2ADMIN_COLS, field_id=field_ids)
      fielddef2editor_rows = self.config_service.fielddef2editor_tbl.Select(
          cnxn, cols=FIELDDEF2EDITOR_COLS, field_id=field_ids)

    componentdef_rows = self.config_service.componentdef_tbl.Select(
        cnxn, cols=COMPONENTDEF_COLS, project_id=project_ids,
        is_deleted=False, order_by=[('path', [])])
    component_ids = [cd_row[0] for cd_row in componentdef_rows]
    component2admin_rows = []
    component2cc_rows = []
    component2label_rows = []
    if component_ids:
      component2admin_rows = self.config_service.component2admin_tbl.Select(
          cnxn, cols=COMPONENT2ADMIN_COLS, component_id=component_ids)
      component2cc_rows = self.config_service.component2cc_tbl.Select(
          cnxn, cols=COMPONENT2CC_COLS, component_id=component_ids)
      component2label_rows = self.config_service.component2label_tbl.Select(
          cnxn, cols=COMPONENT2LABEL_COLS, component_id=component_ids)

    retrieved_dict = self._DeserializeIssueConfigs(
        config_rows, statusdef_rows, labeldef_rows, fielddef_rows,
        fielddef2admin_rows, fielddef2editor_rows, componentdef_rows,
        component2admin_rows, component2cc_rows, component2label_rows,
        approver_rows, survey_rows)
    return retrieved_dict

  def FetchItems(self, cnxn, keys):
    """On RAM and memcache miss, hit the database."""
    retrieved_dict = self._FetchConfigs(cnxn, keys)

    # Any projects which don't have stored configs should use a default
    # config instead.
    for project_id in keys:
      if project_id not in retrieved_dict:
        config = tracker_bizobj.MakeDefaultProjectIssueConfig(project_id)
        retrieved_dict[project_id] = config

    return retrieved_dict
+
+
+class ConfigService(object):
+ """The persistence layer for Monorail's issue tracker configuration data."""
+
+ def __init__(self, cache_manager):
+ """Initialize this object so that it is ready to use.
+
+ Args:
+ cache_manager: manages local caches with distributed invalidation.
+ """
+ self.projectissueconfig_tbl = sql.SQLTableManager(
+ PROJECTISSUECONFIG_TABLE_NAME)
+ self.statusdef_tbl = sql.SQLTableManager(STATUSDEF_TABLE_NAME)
+ self.labeldef_tbl = sql.SQLTableManager(LABELDEF_TABLE_NAME)
+ self.fielddef_tbl = sql.SQLTableManager(FIELDDEF_TABLE_NAME)
+ self.fielddef2admin_tbl = sql.SQLTableManager(FIELDDEF2ADMIN_TABLE_NAME)
+ self.fielddef2editor_tbl = sql.SQLTableManager(FIELDDEF2EDITOR_TABLE_NAME)
+ self.componentdef_tbl = sql.SQLTableManager(COMPONENTDEF_TABLE_NAME)
+ self.component2admin_tbl = sql.SQLTableManager(COMPONENT2ADMIN_TABLE_NAME)
+ self.component2cc_tbl = sql.SQLTableManager(COMPONENT2CC_TABLE_NAME)
+ self.component2label_tbl = sql.SQLTableManager(COMPONENT2LABEL_TABLE_NAME)
+ self.approvaldef2approver_tbl = sql.SQLTableManager(
+ APPROVALDEF2APPROVER_TABLE_NAME)
+ self.approvaldef2survey_tbl = sql.SQLTableManager(
+ APPROVALDEF2SURVEY_TABLE_NAME)
+
+ self.config_2lc = ConfigTwoLevelCache(cache_manager, self)
+ self.label_row_2lc = LabelRowTwoLevelCache(cache_manager, self)
+ self.label_cache = caches.RamCache(cache_manager, 'project')
+ self.status_row_2lc = StatusRowTwoLevelCache(cache_manager, self)
+ self.status_cache = caches.RamCache(cache_manager, 'project')
+ self.field_row_2lc = FieldRowTwoLevelCache(cache_manager, self)
+ self.field_cache = caches.RamCache(cache_manager, 'project')
+
+ ### Label lookups
+
+ def GetLabelDefRows(self, cnxn, project_id, use_cache=True):
+ """Get SQL result rows for all labels used in the specified project."""
+ result = []
+ for shard_id in range(0, LABEL_ROW_SHARDS):
+ key = (project_id, shard_id)
+ pids_to_label_rows_shard, _misses = self.label_row_2lc.GetAll(
+ cnxn, [key], use_cache=use_cache)
+ result.extend(pids_to_label_rows_shard[key])
+ # Sort in python to reduce DB load and integrate results from shards.
+ # row[2] is rank, row[3] is label name.
+ result.sort(key=lambda row: (row[2], row[3]), reverse=True)
+ return result
+
+ def GetLabelDefRowsAnyProject(self, cnxn, where=None):
+ """Get all LabelDef rows for the whole site. Used in whole-site search."""
+ # TODO(jrobbins): maybe add caching for these too.
+ label_def_rows = self.labeldef_tbl.Select(
+ cnxn, cols=LABELDEF_COLS, where=where,
+ order_by=[('rank DESC', []), ('label DESC', [])])
+ return label_def_rows
+
+ def _DeserializeLabels(self, def_rows):
+ """Convert label defs into bi-directional mappings of names and IDs."""
+ label_id_to_name = {
+ label_id: label for
+ label_id, _pid, _rank, label, _doc, _deprecated
+ in def_rows}
+ label_name_to_id = {
+ label.lower(): label_id
+ for label_id, label in label_id_to_name.items()}
+
+ return label_id_to_name, label_name_to_id
+
+ def _EnsureLabelCacheEntry(self, cnxn, project_id, use_cache=True):
+ """Make sure that self.label_cache has an entry for project_id."""
+ if not use_cache or not self.label_cache.HasItem(project_id):
+ def_rows = self.GetLabelDefRows(cnxn, project_id, use_cache=use_cache)
+ self.label_cache.CacheItem(project_id, self._DeserializeLabels(def_rows))
+
+ def LookupLabel(self, cnxn, project_id, label_id):
+ """Lookup a label string given the label_id.
+
+ Args:
+ cnxn: connection to SQL database.
+ project_id: int ID of the project where the label is defined or used.
+ label_id: int label ID.
+
+ Returns:
+ Label name string for the given label_id, or None.
+ """
+ self._EnsureLabelCacheEntry(cnxn, project_id)
+ label_id_to_name, _label_name_to_id = self.label_cache.GetItem(
+ project_id)
+ if label_id in label_id_to_name:
+ return label_id_to_name[label_id]
+
+ logging.info('Label %r not found. Getting fresh from DB.', label_id)
+ self._EnsureLabelCacheEntry(cnxn, project_id, use_cache=False)
+ label_id_to_name, _label_name_to_id = self.label_cache.GetItem(
+ project_id)
+ return label_id_to_name.get(label_id)
+
+ def LookupLabelID(self, cnxn, project_id, label, autocreate=True):
+ """Look up a label ID, optionally interning it.
+
+ Args:
+ cnxn: connection to SQL database.
+ project_id: int ID of the project where the statuses are defined.
+ label: label string.
+ autocreate: if not already in the DB, store it and generate a new ID.
+
+ Returns:
+ The label ID for the given label string.
+ """
+ self._EnsureLabelCacheEntry(cnxn, project_id)
+ _label_id_to_name, label_name_to_id = self.label_cache.GetItem(
+ project_id)
+ if label.lower() in label_name_to_id:
+ return label_name_to_id[label.lower()]
+
+ # Double check that the label does not already exist in the DB.
+ rows = self.labeldef_tbl.Select(
+ cnxn, cols=['id'], project_id=project_id,
+ where=[('LOWER(label) = %s', [label.lower()])],
+ limit=1)
+ logging.info('Double checking for %r gave %r', label, rows)
+ if rows:
+ self.label_row_2lc.cache.LocalInvalidate(project_id)
+ self.label_cache.LocalInvalidate(project_id)
+ return rows[0][0]
+
+ if autocreate:
+ logging.info('No label %r is known in project %d, so intern it.',
+ label, project_id)
+ label_id = self.labeldef_tbl.InsertRow(
+ cnxn, project_id=project_id, label=label)
+ self.label_row_2lc.InvalidateKeys(cnxn, [project_id])
+ self.label_cache.Invalidate(cnxn, project_id)
+ return label_id
+
+ return None # It was not found and we don't want to create it.
+
+ def LookupLabelIDs(self, cnxn, project_id, labels, autocreate=False):
+ """Look up several label IDs.
+
+ Args:
+ cnxn: connection to SQL database.
+ project_id: int ID of the project where the statuses are defined.
+ labels: list of label strings.
+ autocreate: if not already in the DB, store it and generate a new ID.
+
+ Returns:
+ Returns a list of int label IDs for the given label strings.
+ """
+ result = []
+ for lab in labels:
+ label_id = self.LookupLabelID(
+ cnxn, project_id, lab, autocreate=autocreate)
+ if label_id is not None:
+ result.append(label_id)
+
+ return result
+
+ def LookupIDsOfLabelsMatching(self, cnxn, project_id, regex):
+ """Look up the IDs of all labels in a project that match the regex.
+
+ Args:
+ cnxn: connection to SQL database.
+ project_id: int ID of the project where the statuses are defined.
+ regex: regular expression object to match against the label strings.
+
+ Returns:
+ List of label IDs for labels that match the regex.
+ """
+ self._EnsureLabelCacheEntry(cnxn, project_id)
+ label_id_to_name, _label_name_to_id = self.label_cache.GetItem(
+ project_id)
+ result = [label_id for label_id, label in label_id_to_name.items()
+ if regex.match(label)]
+
+ return result
+
+ def LookupLabelIDsAnyProject(self, cnxn, label):
+ """Return the IDs of labels with the given name in any project.
+
+ Args:
+ cnxn: connection to SQL database.
+ label: string label to look up. Case sensitive.
+
+ Returns:
+ A list of int label IDs of all labels matching the given string.
+ """
+ # TODO(jrobbins): maybe add caching for these too.
+ label_id_rows = self.labeldef_tbl.Select(
+ cnxn, cols=['id'], label=label)
+ label_ids = [row[0] for row in label_id_rows]
+ return label_ids
+
+ def LookupIDsOfLabelsMatchingAnyProject(self, cnxn, regex):
+ """Return the IDs of matching labels in any project."""
+ label_rows = self.labeldef_tbl.Select(
+ cnxn, cols=['id', 'label'])
+ matching_ids = [
+ label_id for label_id, label in label_rows if regex.match(label)]
+ return matching_ids
+
+ ### Status lookups
+
+ def GetStatusDefRows(self, cnxn, project_id):
+ """Return a list of status definition rows for the specified project."""
+ pids_to_status_rows, misses = self.status_row_2lc.GetAll(
+ cnxn, [project_id])
+ assert not misses
+ return pids_to_status_rows[project_id]
+
+ def GetStatusDefRowsAnyProject(self, cnxn):
+ """Return all status definition rows on the whole site."""
+ # TODO(jrobbins): maybe add caching for these too.
+ status_def_rows = self.statusdef_tbl.Select(
+ cnxn, cols=STATUSDEF_COLS,
+ order_by=[('rank DESC', []), ('status DESC', [])])
+ return status_def_rows
+
+ def _DeserializeStatuses(self, def_rows):
+ """Convert status defs into bi-directional mappings of names and IDs."""
+ status_id_to_name = {
+ status_id: status
+ for (status_id, _pid, _rank, status, _means_open,
+ _doc, _deprecated) in def_rows}
+ status_name_to_id = {
+ status.lower(): status_id
+ for status_id, status in status_id_to_name.items()}
+ closed_status_ids = [
+ status_id
+ for (status_id, _pid, _rank, _status, means_open,
+ _doc, _deprecated) in def_rows
+ if means_open == 0] # Only 0 means closed. NULL/None means open.
+
+ return status_id_to_name, status_name_to_id, closed_status_ids
+
+ def _EnsureStatusCacheEntry(self, cnxn, project_id):
+ """Make sure that self.status_cache has an entry for project_id."""
+ if not self.status_cache.HasItem(project_id):
+ def_rows = self.GetStatusDefRows(cnxn, project_id)
+ self.status_cache.CacheItem(
+ project_id, self._DeserializeStatuses(def_rows))
+
+ def LookupStatus(self, cnxn, project_id, status_id):
+ """Look up a status string for the given status ID.
+
+ Args:
+ cnxn: connection to SQL database.
+ project_id: int ID of the project where the statuses are defined.
+ status_id: int ID of the status value.
+
+ Returns:
+ A status string, or None.
+ """
+ if status_id == 0:
+ return ''
+
+ self._EnsureStatusCacheEntry(cnxn, project_id)
+ (status_id_to_name, _status_name_to_id,
+ _closed_status_ids) = self.status_cache.GetItem(project_id)
+
+ return status_id_to_name.get(status_id)
+
+ def LookupStatusID(self, cnxn, project_id, status, autocreate=True):
+ """Look up a status ID for the given status string.
+
+ Args:
+ cnxn: connection to SQL database.
+ project_id: int ID of the project where the statuses are defined.
+ status: status string.
+ autocreate: if not already in the DB, store it and generate a new ID.
+
+ Returns:
+ The status ID for the given status string, or None.
+ """
+ if not status:
+ return None
+
+ self._EnsureStatusCacheEntry(cnxn, project_id)
+ (_status_id_to_name, status_name_to_id,
+ _closed_status_ids) = self.status_cache.GetItem(project_id)
+ if status.lower() in status_name_to_id:
+ return status_name_to_id[status.lower()]
+
+ if autocreate:
+ logging.info('No status %r is known in project %d, so intern it.',
+ status, project_id)
+ status_id = self.statusdef_tbl.InsertRow(
+ cnxn, project_id=project_id, status=status)
+ self.status_row_2lc.InvalidateKeys(cnxn, [project_id])
+ self.status_cache.Invalidate(cnxn, project_id)
+ return status_id
+
+ return None # It was not found and we don't want to create it.
+
+ def LookupStatusIDs(self, cnxn, project_id, statuses):
+ """Look up several status IDs for the given status strings.
+
+ Args:
+ cnxn: connection to SQL database.
+ project_id: int ID of the project where the statuses are defined.
+ statuses: list of status strings.
+
+ Returns:
+ A list of int status IDs.
+ """
+ result = []
+ for stat in statuses:
+ status_id = self.LookupStatusID(cnxn, project_id, stat, autocreate=False)
+ if status_id:
+ result.append(status_id)
+
+ return result
+
+ def LookupClosedStatusIDs(self, cnxn, project_id):
+ """Return the IDs of closed statuses defined in the given project."""
+ self._EnsureStatusCacheEntry(cnxn, project_id)
+ (_status_id_to_name, _status_name_to_id,
+ closed_status_ids) = self.status_cache.GetItem(project_id)
+
+ return closed_status_ids
+
+ def LookupClosedStatusIDsAnyProject(self, cnxn):
+ """Return the IDs of closed statuses defined in any project."""
+ status_id_rows = self.statusdef_tbl.Select(
+ cnxn, cols=['id'], means_open=False)
+ status_ids = [row[0] for row in status_id_rows]
+ return status_ids
+
+ def LookupStatusIDsAnyProject(self, cnxn, status):
+ """Return the IDs of statues with the given name in any project."""
+ status_id_rows = self.statusdef_tbl.Select(
+ cnxn, cols=['id'], status=status)
+ status_ids = [row[0] for row in status_id_rows]
+ return status_ids
+
+ # TODO(jrobbins): regex matching for status values.
+
+ ### Issue tracker configuration objects
+
+ def GetProjectConfigs(self, cnxn, project_ids, use_cache=True):
+ # type: (MonorailConnection, Collection[int], Optional[bool])
+ # -> Mapping[int, ProjectConfig]
+ """Get several project issue config objects."""
+ config_dict, missed_ids = self.config_2lc.GetAll(
+ cnxn, project_ids, use_cache=use_cache)
+ if missed_ids:
+ raise exceptions.NoSuchProjectException()
+ return config_dict
+
+ def GetProjectConfig(self, cnxn, project_id, use_cache=True):
+ """Load a ProjectIssueConfig for the specified project from the database.
+
+ Args:
+ cnxn: connection to SQL database.
+ project_id: int ID of the current project.
+ use_cache: if False, always hit the database.
+
+ Returns:
+ A ProjectIssueConfig describing how the issue tracker in the specified
+ project is configured. Projects only have a stored ProjectIssueConfig if
+ a project owner has edited the configuration. Other projects use a
+ default configuration.
+ """
+ config_dict = self.GetProjectConfigs(
+ cnxn, [project_id], use_cache=use_cache)
+ return config_dict[project_id]
+
+  def StoreConfig(self, cnxn, config):
+    """Update an issue config in the database.
+
+    Replaces the project's ProjectIssueConfig row, then rewrites its
+    well-known labels, well-known statuses and approval definitions. Every
+    write is issued with commit=False and cnxn.Commit() is called once at the
+    end, so it is not reached if one of the update helpers raises.
+
+    Args:
+      cnxn: connection to SQL database.
+      config: ProjectIssueConfig PB to update.
+    """
+    # TODO(jrobbins): Convert default template index values into foreign
+    # key references. Updating an entire config might require (1) adding
+    # new templates, (2) updating the config with new foreign key values,
+    # and finally (3) deleting only the specific templates that should be
+    # deleted.
+    config_columns = {
+        'project_id': config.project_id,
+        'statuses_offer_merge': ' '.join(config.statuses_offer_merge),
+        'exclusive_label_prefixes': ' '.join(config.exclusive_label_prefixes),
+        'default_template_for_developers':
+            config.default_template_for_developers,
+        'default_template_for_users': config.default_template_for_users,
+        'default_col_spec': config.default_col_spec,
+        'default_sort_spec': config.default_sort_spec,
+        'default_x_attr': config.default_x_attr,
+        'default_y_attr': config.default_y_attr,
+        'member_default_query': config.member_default_query,
+        'custom_issue_entry_url': config.custom_issue_entry_url,
+    }
+    self.projectissueconfig_tbl.InsertRow(
+        cnxn, replace=True, commit=False, **config_columns)
+
+    for rewrite_child_rows in (self._UpdateWellKnownLabels,
+                               self._UpdateWellKnownStatuses,
+                               self._UpdateApprovals):
+      rewrite_child_rows(cnxn, config)
+    cnxn.Commit()
+
+  def _UpdateWellKnownLabels(self, cnxn, config):
+    """Rewrite the LabelDef rows for a project's well-known labels.
+
+    The rank of every label in the project is first reset to None. Then each
+    well-known label is written with its position in
+    config.well_known_labels as its rank: labels that already have an ID are
+    replaced in place (keeping that ID) and the rest are inserted as new rows.
+
+    Args:
+      cnxn: connection to SQL database.
+      config: ProjectIssueConfig PB to update in the DB.
+
+    Raises:
+      InputException: if the same label name appears twice in the config.
+    """
+    project_id = config.project_id
+    existing_rows = []
+    brand_new_rows = []
+    seen_names = set()
+    for rank, wkl in enumerate(config.well_known_labels):
+      name = wkl.label
+      # Prevent duplicate key errors from a repeated label name.
+      if name in seen_names:
+        raise exceptions.InputException('Defined label "%s" twice' % name)
+      seen_names.add(name)
+      # The ID must be supplied when replacing; otherwise a new ID is made.
+      label_id = self.LookupLabelID(cnxn, project_id, name, autocreate=False)
+      row = (project_id, rank, name, wkl.label_docstring, wkl.deprecated)
+      if label_id:
+        existing_rows.append((label_id,) + row)
+      else:
+        brand_new_rows.append(row)
+
+    self.labeldef_tbl.Update(
+        cnxn, {'rank': None}, project_id=project_id, commit=False)
+    self.labeldef_tbl.InsertRows(
+        cnxn, LABELDEF_COLS, existing_rows, replace=True, commit=False)
+    self.labeldef_tbl.InsertRows(
+        cnxn, LABELDEF_COLS[1:], brand_new_rows, commit=False)
+    self.label_row_2lc.InvalidateKeys(cnxn, [project_id])
+    self.label_cache.Invalidate(cnxn, project_id)
+
+  def _UpdateWellKnownStatuses(self, cnxn, config):
+    """Update the status part of a project's issue configuration.
+
+    The rank of every status in the project is first reset to None. Then
+    each well-known status is written with its position in
+    config.well_known_statuses as its rank: statuses that already have an ID
+    are replaced in place (keeping that ID) and the rest are inserted.
+
+    Args:
+      cnxn: connection to SQL database.
+      config: ProjectIssueConfig PB to update in the DB.
+
+    Raises:
+      InputException: if the same status name is defined twice, matching the
+        duplicate-label check in _UpdateWellKnownLabels.
+    """
+    project_id = config.project_id
+    update_statusdef_rows = []
+    new_statusdef_rows = []
+    statuses_seen = set()
+    for rank, wks in enumerate(config.well_known_statuses):
+      # Reject repeated names before any write. Two new rows with the same
+      # name would likely hit a duplicate key error (the reason labels are
+      # checked), and two rows for an existing status would share one ID.
+      if wks.status in statuses_seen:
+        raise exceptions.InputException(
+            'Defined status "%s" twice' % wks.status)
+      statuses_seen.add(wks.status)
+      # We must specify the status ID when replacing, otherwise a new ID is
+      # made.
+      status_id = self.LookupStatusID(
+          cnxn, project_id, wks.status, autocreate=False)
+      row = (project_id, rank, wks.status, bool(wks.means_open),
+             wks.status_docstring, wks.deprecated)
+      if status_id is not None:
+        update_statusdef_rows.append((status_id,) + row)
+      else:
+        new_statusdef_rows.append(row)
+
+    self.statusdef_tbl.Update(
+        cnxn, {'rank': None}, project_id=project_id, commit=False)
+    self.statusdef_tbl.InsertRows(
+        cnxn, STATUSDEF_COLS, update_statusdef_rows, replace=True,
+        commit=False)
+    self.statusdef_tbl.InsertRows(
+        cnxn, STATUSDEF_COLS[1:], new_statusdef_rows, commit=False)
+    self.status_row_2lc.InvalidateKeys(cnxn, [project_id])
+    self.status_cache.Invalidate(cnxn, project_id)
+
+  def _UpdateApprovals(self, cnxn, config):
+    """Rewrite the approver and survey rows of each approval definition.
+
+    Args:
+      cnxn: connection to SQL database.
+      config: ProjectIssueConfig PB to update in the DB.
+
+    Raises:
+      NoSuchFieldDefException: if an approval def has no field def with the
+        same ID in config.field_defs.
+      InvalidFieldTypeException: if that field def is not APPROVAL_TYPE.
+    """
+    field_defs_by_id = {fd.field_id: fd for fd in config.field_defs}
+    for approval_def in config.approval_defs:
+      approval_id = approval_def.approval_id
+      approval_fd = field_defs_by_id.get(approval_id)
+      if approval_fd is None:
+        raise exceptions.NoSuchFieldDefException()
+      if approval_fd.field_type != tracker_pb2.FieldTypes.APPROVAL_TYPE:
+        raise exceptions.InvalidFieldTypeException()
+
+      # Replace this approval's approver rows wholesale.
+      self.approvaldef2approver_tbl.Delete(
+          cnxn, approval_id=approval_id, commit=False)
+      approver_rows = [
+          (approval_id, approver_id, config.project_id)
+          for approver_id in approval_def.approver_ids]
+      self.approvaldef2approver_tbl.InsertRows(
+          cnxn, APPROVALDEF2APPROVER_COLS, approver_rows, commit=False)
+
+      # Likewise replace its survey row.
+      self.approvaldef2survey_tbl.Delete(
+          cnxn, approval_id=approval_id, commit=False)
+      self.approvaldef2survey_tbl.InsertRow(
+          cnxn, approval_id=approval_id, survey=approval_def.survey,
+          project_id=config.project_id, commit=False)
+
+  def UpdateConfig(
+      self, cnxn, project, well_known_statuses=None,
+      statuses_offer_merge=None, well_known_labels=None,
+      excl_label_prefixes=None, default_template_for_developers=None,
+      default_template_for_users=None, list_prefs=None, restrict_to_known=None,
+      approval_defs=None):
+    """Update project's issue tracker configuration with the given info.
+
+    The config is re-read from the database (use_cache=False) before changes
+    are applied. Arguments left as None (or, for list_prefs, any falsy value)
+    keep the current setting.
+
+    Args:
+      cnxn: connection to SQL database.
+      project: the project in which to update the issue tracker config.
+      well_known_statuses: [(status_name, docstring, means_open, deprecated),..]
+      statuses_offer_merge: list of status values that trigger UI to merge.
+      well_known_labels: [(label_name, docstring, deprecated),...]
+      excl_label_prefixes: list of prefix strings. Each issue should
+          have only one label with each of these prefixed.
+      default_template_for_developers: int ID of template to use for devs.
+      default_template_for_users: int ID of template to use for non-members.
+      list_prefs: 5-tuple (default_col_spec, default_sort_spec,
+          default_x_attr, default_y_attr, member_default_query).
+      restrict_to_known: optional bool to allow project owners
+          to limit issue status and label values to only the well-known ones.
+      approval_defs: [(approval_id, approver_ids, survey), ..]
+
+    Returns:
+      The updated ProjectIssueConfig PB.
+
+    Raises:
+      NoSuchProjectException: propagated from GetProjectConfig.
+    """
+    project_id = project.project_id
+    project_config = self.GetProjectConfig(cnxn, project_id, use_cache=False)
+
+    if well_known_statuses is not None:
+      tracker_bizobj.SetConfigStatuses(project_config, well_known_statuses)
+
+    if statuses_offer_merge is not None:
+      project_config.statuses_offer_merge = statuses_offer_merge
+
+    if well_known_labels is not None:
+      tracker_bizobj.SetConfigLabels(project_config, well_known_labels)
+
+    if excl_label_prefixes is not None:
+      project_config.exclusive_label_prefixes = excl_label_prefixes
+
+    if approval_defs is not None:
+      tracker_bizobj.SetConfigApprovals(project_config, approval_defs)
+
+    if default_template_for_developers is not None:
+      project_config.default_template_for_developers = (
+          default_template_for_developers)
+    if default_template_for_users is not None:
+      project_config.default_template_for_users = default_template_for_users
+
+    if list_prefs:
+      (default_col_spec, default_sort_spec, default_x_attr, default_y_attr,
+       member_default_query) = list_prefs
+      project_config.default_col_spec = default_col_spec
+      project_config.default_sort_spec = default_sort_spec
+      project_config.default_x_attr = default_x_attr
+      project_config.default_y_attr = default_y_attr
+      project_config.member_default_query = member_default_query
+
+    if restrict_to_known is not None:
+      # NOTE(review): StoreConfig's InsertRow does not include a
+      # restrict_to_known column; confirm this value is persisted elsewhere.
+      project_config.restrict_to_known = restrict_to_known
+
+    self.StoreConfig(cnxn, project_config)
+    self.config_2lc.InvalidateKeys(cnxn, [project_id])
+    self.InvalidateMemcacheForEntireProject(project_id)
+    # Invalidate all issue caches in all frontends to clear out
+    # sorting.art_values_cache which now has wrong sort orders.
+    cache_manager = self.config_2lc.cache.cache_manager
+    cache_manager.StoreInvalidateAll(cnxn, 'issue')
+
+    return project_config
+
+  def ExpungeConfig(self, cnxn, project_id):
+    """Completely delete a project's status, label and config rows.
+
+    The deletes are issued without commit=False. The project's config_2lc
+    entry is invalidated afterwards.
+    """
+    logging.info('expunging the config for %r', project_id)
+    for table in (self.statusdef_tbl, self.labeldef_tbl,
+                  self.projectissueconfig_tbl):
+      table.Delete(cnxn, project_id=project_id)
+
+    self.config_2lc.InvalidateKeys(cnxn, [project_id])
+
+  def ExpungeUsersInConfigs(self, cnxn, user_ids, limit=None):
+    """Wipes specified users from the configs system.
+
+    Component admin/cc rows, field-def admin/editor rows and approval
+    approver rows naming the users are deleted. Component creator/modifier
+    references are rewritten to framework_constants.DELETED_USER_ID.
+
+    This method will not commit the operation. This method will
+    not make changes to in-memory data.
+
+    Args:
+      cnxn: connection to SQL database.
+      user_ids: list of user IDs to remove.
+      limit: optional row limit passed through to every statement.
+    """
+    # Component membership rows that reference the users.
+    for table, column in ((self.component2admin_tbl, 'admin_id'),
+                          (self.component2cc_tbl, 'cc_id')):
+      table.Delete(cnxn, commit=False, limit=limit, **{column: user_ids})
+
+    # Components are kept but re-attributed to the deleted-user ID.
+    for column in ('creator_id', 'modifier_id'):
+      self.componentdef_tbl.Update(
+          cnxn, {column: framework_constants.DELETED_USER_ID},
+          commit=False, limit=limit, **{column: user_ids})
+
+    # Field-def and approval-def rows that reference the users.
+    for table, column in ((self.fielddef2admin_tbl, 'admin_id'),
+                          (self.fielddef2editor_tbl, 'editor_id'),
+                          (self.approvaldef2approver_tbl, 'approver_id')):
+      table.Delete(cnxn, commit=False, limit=limit, **{column: user_ids})
+
+ ### Custom field definitions
+
+  def CreateFieldDef(
+      self,
+      cnxn,
+      project_id,
+      field_name,
+      field_type_str,
+      applic_type,
+      applic_pred,
+      is_required,
+      is_niche,
+      is_multivalued,
+      min_value,
+      max_value,
+      regex,
+      needs_member,
+      needs_perm,
+      grants_perm,
+      notify_on,
+      date_action_str,
+      docstring,
+      admin_ids,
+      editor_ids,
+      approval_id=None,
+      is_phase_field=False,
+      is_restricted_field=False):
+    """Create a new field definition with the given info.
+
+    The FieldDef row and its admin/editor rows are written with commit=False
+    and committed together. After that, the project's config_2lc and
+    field_row_2lc entries and its memcache entries are invalidated.
+
+    Args:
+      cnxn: connection to SQL database.
+      project_id: int ID of the current project.
+      field_name: name of the new custom field.
+      field_type_str: string identifying the type of the custom field.
+      applic_type: string specifying issue type the field is applicable to.
+      applic_pred: string condition to test if the field is applicable.
+      is_required: True if the field should be required on issues.
+      is_niche: True if the field is not initially offered for editing, so users
+          must click to reveal such special-purpose or experimental fields.
+      is_multivalued: True if the field can occur multiple times on one issue.
+      min_value: optional validation for int_type fields.
+      max_value: optional validation for int_type fields.
+      regex: optional validation for str_type fields.
+      needs_member: optional validation for user_type fields.
+      needs_perm: optional validation for user_type fields.
+      grants_perm: optional string for perm to grant any user named in field.
+      notify_on: key looked up in NOTIFY_ON_ENUM to get the stored value.
+      date_action_str: string saying who to notify when a date arrives.
+      docstring: string describing this field.
+      admin_ids: list of additional user IDs who can edit this field def.
+      editor_ids: list of additional user IDs
+          who can edit a restricted field value.
+      approval_id: field_id of approval field this field belongs to.
+      is_phase_field: True if field should only be associated with issue phases.
+      is_restricted_field: True if editing this field's values is restricted.
+
+    Returns:
+      Integer field_id of the new field definition.
+    """
+    fielddef_columns = dict(
+        project_id=project_id,
+        field_name=field_name,
+        field_type=field_type_str,
+        applicable_type=applic_type,
+        applicable_predicate=applic_pred,
+        is_required=is_required,
+        is_niche=is_niche,
+        is_multivalued=is_multivalued,
+        min_value=min_value,
+        max_value=max_value,
+        regex=regex,
+        needs_member=needs_member,
+        needs_perm=needs_perm,
+        grants_perm=grants_perm,
+        notify_on=NOTIFY_ON_ENUM[notify_on],
+        date_action=date_action_str,
+        docstring=docstring,
+        approval_id=approval_id,
+        is_phase_field=is_phase_field,
+        is_restricted_field=is_restricted_field)
+    field_id = self.fielddef_tbl.InsertRow(
+        cnxn, commit=False, **fielddef_columns)
+
+    admin_rows = [(field_id, admin_id) for admin_id in admin_ids]
+    self.fielddef2admin_tbl.InsertRows(
+        cnxn, FIELDDEF2ADMIN_COLS, admin_rows, commit=False)
+    editor_rows = [(field_id, editor_id) for editor_id in editor_ids]
+    self.fielddef2editor_tbl.InsertRows(
+        cnxn, FIELDDEF2EDITOR_COLS, editor_rows, commit=False)
+    cnxn.Commit()
+
+    for stale_cache in (self.config_2lc, self.field_row_2lc):
+      stale_cache.InvalidateKeys(cnxn, [project_id])
+    self.InvalidateMemcacheForEntireProject(project_id)
+    return field_id
+
+  def _DeserializeFields(self, def_rows):
+    """Build ID->name and lowercased-name->ID lookups from field def rows.
+
+    Args:
+      def_rows: iterable of 5-tuples whose first and fourth items are the
+        field ID and field name; the other items are ignored.
+
+    Returns:
+      A pair (field_id_to_name, field_name_to_id). Keys of the second dict
+      are lowercased names, so lookups in it can be case-insensitive.
+    """
+    field_id_to_name = {}
+    for def_row in def_rows:
+      field_id, _pid, _rank, field, _doc = def_row
+      field_id_to_name[field_id] = field
+
+    field_name_to_id = dict(
+        (field.lower(), field_id)
+        for field_id, field in field_id_to_name.items())
+    return field_id_to_name, field_name_to_id
+
+  def GetFieldDefRows(self, cnxn, project_id):
+    """Get SQL result rows for all fields used in the specified project.
+
+    The rows come from field_row_2lc. If GetAll reports the project ID as
+    missed, the assert fires.
+    """
+    rows_by_project, missed_ids = self.field_row_2lc.GetAll(
+        cnxn, [project_id])
+    assert not missed_ids
+    return rows_by_project[project_id]
+
+  def _EnsureFieldCacheEntry(self, cnxn, project_id):
+    """Populate self.field_cache for project_id from field def rows if absent."""
+    if self.field_cache.HasItem(project_id):
+      return
+    lookups = self._DeserializeFields(self.GetFieldDefRows(cnxn, project_id))
+    self.field_cache.CacheItem(project_id, lookups)
+
+  def LookupField(self, cnxn, project_id, field_id):
+    """Lookup a field name given its field_id.
+
+    Args:
+      cnxn: connection to SQL database.
+      project_id: int ID of the project where the field is defined.
+      field_id: int field ID.
+
+    Returns:
+      Field name string for the given field_id, or None if it is unknown.
+    """
+    self._EnsureFieldCacheEntry(cnxn, project_id)
+    field_id_to_name = self.field_cache.GetItem(project_id)[0]
+    return field_id_to_name.get(field_id)
+
+  def LookupFieldID(self, cnxn, project_id, field):
+    """Look up a field ID by name, ignoring case.
+
+    Args:
+      cnxn: connection to SQL database.
+      project_id: int ID of the project where the fields are defined.
+      field: field name string.
+
+    Returns:
+      The field ID for the given field name, or None if it is unknown.
+    """
+    self._EnsureFieldCacheEntry(cnxn, project_id)
+    field_name_to_id = self.field_cache.GetItem(project_id)[1]
+    return field_name_to_id.get(field.lower())
+
+  def SoftDeleteFieldDefs(self, cnxn, project_id, field_ids):
+    """Mark the specified fields as deleted, they will be reaped later.
+
+    Args:
+      cnxn: connection to SQL database.
+      project_id: int ID of the project that owns the fields.
+      field_ids: IDs of the field definitions to mark deleted.
+    """
+    self.fielddef_tbl.Update(cnxn, {'is_deleted': True}, id=field_ids)
+    # Invalidate the same caches as CreateFieldDef, so cached FieldDef rows
+    # for this project are refetched after the change.
+    self.config_2lc.InvalidateKeys(cnxn, [project_id])
+    self.field_row_2lc.InvalidateKeys(cnxn, [project_id])
+    self.InvalidateMemcacheForEntireProject(project_id)
+
+ # TODO(jrobbins): GC deleted field defs after field values are gone.
+
+  def UpdateFieldDef(
+      self,
+      cnxn,
+      project_id,
+      field_id,
+      field_name=None,
+      applicable_type=None,
+      applicable_predicate=None,
+      is_required=None,
+      is_niche=None,
+      is_multivalued=None,
+      min_value=None,
+      max_value=None,
+      regex=None,
+      needs_member=None,
+      needs_perm=None,
+      grants_perm=None,
+      notify_on=None,
+      date_action=None,
+      docstring=None,
+      admin_ids=None,
+      editor_ids=None,
+      is_restricted_field=None):
+    """Update the specified field definition.
+
+    Only arguments that are not None are written. When admin_ids or
+    editor_ids is given, it replaces the field's existing rows of that kind.
+    All writes are committed together.
+
+    Args:
+      cnxn: connection to SQL database.
+      project_id: int ID of the project that owns the field.
+      field_id: int ID of the field definition to update.
+      notify_on: key looked up in NOTIFY_ON_ENUM to get the stored value.
+      is_required, is_niche, is_multivalued, is_restricted_field: flags that
+          are stored as bools.
+      admin_ids: list of user IDs who can edit this field def.
+      editor_ids: list of user IDs who can edit a restricted field value.
+      Other arguments: new values for the FieldDef column of the same name.
+    """
+    new_values = {}
+    # Columns stored exactly as given.
+    for column, value in (
+        ('field_name', field_name),
+        ('applicable_type', applicable_type),
+        ('applicable_predicate', applicable_predicate),
+        ('min_value', min_value),
+        ('max_value', max_value),
+        ('regex', regex),
+        ('needs_member', needs_member),
+        ('needs_perm', needs_perm),
+        ('grants_perm', grants_perm),
+        ('date_action', date_action),
+        ('docstring', docstring)):
+      if value is not None:
+        new_values[column] = value
+    # Flag columns are normalized to real bools.
+    for column, value in (
+        ('is_required', is_required),
+        ('is_niche', is_niche),
+        ('is_multivalued', is_multivalued),
+        ('is_restricted_field', is_restricted_field)):
+      if value is not None:
+        new_values[column] = bool(value)
+    if notify_on is not None:
+      new_values['notify_on'] = NOTIFY_ON_ENUM[notify_on]
+
+    # When only the admin/editor lists change, don't hand
+    # fielddef_tbl.Update an empty delta.
+    if new_values:
+      self.fielddef_tbl.Update(cnxn, new_values, id=field_id, commit=False)
+    if admin_ids is not None:
+      self.fielddef2admin_tbl.Delete(cnxn, field_id=field_id, commit=False)
+      self.fielddef2admin_tbl.InsertRows(
+          cnxn,
+          FIELDDEF2ADMIN_COLS, [(field_id, admin_id) for admin_id in admin_ids],
+          commit=False)
+    if editor_ids is not None:
+      self.fielddef2editor_tbl.Delete(cnxn, field_id=field_id, commit=False)
+      self.fielddef2editor_tbl.InsertRows(
+          cnxn,
+          FIELDDEF2EDITOR_COLS,
+          [(field_id, editor_id) for editor_id in editor_ids],
+          commit=False)
+    cnxn.Commit()
+    # Invalidate the same caches as CreateFieldDef: a rename changes the
+    # FieldDef rows that LookupField/LookupFieldID are built from.
+    self.config_2lc.InvalidateKeys(cnxn, [project_id])
+    self.field_row_2lc.InvalidateKeys(cnxn, [project_id])
+    self.InvalidateMemcacheForEntireProject(project_id)
+
+ ### Component definitions
+
+  def FindMatchingComponentIDsAnyProject(self, cnxn, path_list, exact=True):
+    """Look up component IDs across projects.
+
+    Args:
+      cnxn: connection to SQL database.
+      path_list: list of component path prefixes.
+      exact: set to False to include all components which have one of the
+        given paths as their ancestor, instead of exact matches.
+
+    Returns:
+      A list of IDs of components whose paths match path_list. An empty
+      path_list matches nothing.
+    """
+    if not path_list:
+      # An empty OR-group would be rendered as "()", which is not valid SQL.
+      return []
+
+    or_terms = []
+    args = []
+    for path in path_list:
+      or_terms.append('path = %s')
+      args.append(path)
+
+    if not exact:
+      for path in path_list:
+        # Escape LIKE metacharacters so a literal '\', '%' or '_' in a path
+        # is matched literally; only the appended '%' acts as a wildcard.
+        # Relies on MySQL's default LIKE escape character (backslash).
+        escaped_path = (
+            path.replace('\\', '\\\\').replace('%', '\\%').replace('_', '\\_'))
+        or_terms.append('path LIKE %s')
+        args.append(escaped_path + '>%')
+
+    cond_str = '(' + ' OR '.join(or_terms) + ')'
+    rows = self.componentdef_tbl.Select(
+        cnxn, cols=['id'], where=[(cond_str, args)])
+    return [row[0] for row in rows]
+
+  def CreateComponentDef(
+      self, cnxn, project_id, path, docstring, deprecated, admin_ids, cc_ids,
+      created, creator_id, label_ids):
+    """Create a new component definition with the given info.
+
+    The ComponentDef row and its admin, cc and label rows are written with
+    commit=False and committed together.
+
+    Args:
+      cnxn: connection to SQL database.
+      project_id: int ID of the current project.
+      path: string pathname of the new component.
+      docstring: string describing this component.
+      deprecated: whether or not this should be autocompleted
+      admin_ids: list of int IDs of users who can administer.
+      cc_ids: list of int IDs of users to notify when an issue in
+        this component is updated.
+      created: timestamp this component was created at.
+      creator_id: int ID of user who created this component.
+      label_ids: list of int IDs of labels to add when an issue is
+        in this component.
+
+    Returns:
+      Integer component_id of the new component definition.
+    """
+    component_id = self.componentdef_tbl.InsertRow(
+        cnxn, project_id=project_id, path=path, docstring=docstring,
+        deprecated=deprecated, created=created, creator_id=creator_id,
+        commit=False)
+    for member_tbl, member_cols, member_ids in (
+        (self.component2admin_tbl, COMPONENT2ADMIN_COLS, admin_ids),
+        (self.component2cc_tbl, COMPONENT2CC_COLS, cc_ids),
+        (self.component2label_tbl, COMPONENT2LABEL_COLS, label_ids)):
+      member_tbl.InsertRows(
+          cnxn, member_cols,
+          [(component_id, member_id) for member_id in member_ids],
+          commit=False)
+    cnxn.Commit()
+    self.config_2lc.InvalidateKeys(cnxn, [project_id])
+    self.InvalidateMemcacheForEntireProject(project_id)
+    return component_id
+
+  def UpdateComponentDef(
+      self, cnxn, project_id, component_id, path=None, docstring=None,
+      deprecated=None, admin_ids=None, cc_ids=None, created=None,
+      creator_id=None, modified=None, modifier_id=None,
+      label_ids=None):
+    """Update the specified component definition.
+
+    Only arguments that are not None are written. When admin_ids, cc_ids or
+    label_ids is given, it replaces the component's existing rows of that
+    kind. All writes are committed together at the end.
+
+    Args:
+      cnxn: connection to SQL database.
+      project_id: int ID of the project that owns the component.
+      component_id: int ID of the component definition to update.
+      path: new, non-empty path string for the component.
+      docstring: new description of the component.
+      deprecated: new deprecated flag.
+      admin_ids: list of int IDs of users who can administer.
+      cc_ids: list of int IDs of users to notify.
+      created: creation timestamp to store.
+      creator_id: int ID of the creating user.
+      modified: modification timestamp to store.
+      modifier_id: int ID of the modifying user.
+      label_ids: list of int IDs of labels for the component.
+
+    Raises:
+      InputException: if path is given but empty.
+    """
+    # Explicit check instead of assert, which is stripped under python -O.
+    if path is not None and not path:
+      raise exceptions.InputException('Component path must not be empty.')
+
+    new_values = {}
+    for column, value in (
+        ('path', path),
+        ('docstring', docstring),
+        ('deprecated', deprecated),
+        ('created', created),
+        ('creator_id', creator_id),
+        ('modified', modified),
+        ('modifier_id', modifier_id)):
+      if value is not None:
+        new_values[column] = value
+
+    for member_tbl, member_cols, member_ids in (
+        (self.component2admin_tbl, COMPONENT2ADMIN_COLS, admin_ids),
+        (self.component2cc_tbl, COMPONENT2CC_COLS, cc_ids),
+        (self.component2label_tbl, COMPONENT2LABEL_COLS, label_ids)):
+      if member_ids is None:
+        continue
+      member_tbl.Delete(cnxn, component_id=component_id, commit=False)
+      member_tbl.InsertRows(
+          cnxn, member_cols,
+          [(component_id, member_id) for member_id in member_ids],
+          commit=False)
+
+    # Only update ComponentDef when one of its columns changes, rather than
+    # handing componentdef_tbl.Update an empty delta.
+    if new_values:
+      self.componentdef_tbl.Update(
+          cnxn, new_values, id=component_id, commit=False)
+    cnxn.Commit()
+    self.config_2lc.InvalidateKeys(cnxn, [project_id])
+    self.InvalidateMemcacheForEntireProject(project_id)
+
+  def DeleteComponentDef(self, cnxn, project_id, component_id):
+    """Delete a component definition by setting its is_deleted flag.
+
+    The row itself is kept. After the commit, the project's config_2lc and
+    memcache entries are invalidated.
+    """
+    self.componentdef_tbl.Update(
+        cnxn, {'is_deleted': True}, id=component_id, commit=False)
+    cnxn.Commit()
+
+    self.config_2lc.InvalidateKeys(cnxn, [project_id])
+    self.InvalidateMemcacheForEntireProject(project_id)
+
+ ### Memcache management
+
+  def InvalidateMemcache(self, issues, key_prefix=''):
+    """Delete the memcache entries for issues and their project-shard pairs.
+
+    Args:
+      issues: issue PBs; `issues` is iterated twice.
+      key_prefix: memcache key prefix used for the project-shard entries.
+    """
+    issue_keys = [str(issue.issue_id) for issue in issues]
+    memcache.delete_multi(
+        issue_keys, key_prefix='issue:', seconds=5,
+        namespace=settings.memcache_namespace)
+
+    num_shards = settings.num_logical_shards
+    project_shards = {
+        (issue.project_id, issue.issue_id % num_shards) for issue in issues}
+    self._InvalidateMemcacheShards(project_shards, key_prefix=key_prefix)
+
+  def _InvalidateMemcacheShards(self, project_shards, key_prefix=''):
+    """Delete the memcache entries for the given project-shard pairs.
+
+    Deleting these rows does not delete the actual cached search results
+    but it does mean that they will be considered stale and thus not used.
+
+    Args:
+      project_shards: collection of (pid, sid) pairs; iterated twice.
+      key_prefix: string to pass as memcache key prefix.
+    """
+    shard_keys = ['%d;%d' % pair for pair in project_shards]
+    # Site-wide searches use 'all;<sid>' entries, so every shard touched
+    # here also has its site-wide entry invalidated.
+    touched_shard_ids = set(sid for _pid, sid in project_shards)
+    shard_keys.extend('all;%d' % sid for sid in touched_shard_ids)
+
+    memcache.delete_multi(
+        shard_keys, key_prefix=key_prefix,
+        namespace=settings.memcache_namespace)
+
+  def InvalidateMemcacheForEntireProject(self, project_id):
+    """Delete the memcache entries for all searches in a project.
+
+    Also deletes the project's 'config:', 'label_rows:', 'status_rows:' and
+    'field_rows:' memcache entries, each keyed by str(project_id).
+    """
+    every_shard = set(
+        (project_id, shard_id)
+        for shard_id in range(settings.num_logical_shards))
+    self._InvalidateMemcacheShards(every_shard)
+    for prefix in ('config:', 'label_rows:', 'status_rows:', 'field_rows:'):
+      memcache.delete_multi(
+          [str(project_id)], key_prefix=prefix,
+          namespace=settings.memcache_namespace)
+
+  def UsersInvolvedInConfig(self, config, project_templates):
+    """Return a set of all user IDs referenced in the ProjectIssueConfig.
+
+    Covers users named by each template (per
+    tracker_bizobj.UsersInvolvedInTemplate) and each field def's admin_ids
+    and editor_ids.
+    """
+    involved = set()
+    for template in project_templates:
+      involved.update(tracker_bizobj.UsersInvolvedInTemplate(template))
+    for fd in config.field_defs:
+      involved.update(fd.admin_ids, fd.editor_ids)
+    # TODO(jrobbins): add component owners, auto-cc, and admins.
+    return involved
diff --git a/services/features_svc.py b/services/features_svc.py
new file mode 100644
index 0000000..471a513
--- /dev/null
+++ b/services/features_svc.py
@@ -0,0 +1,1381 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""A class that provides persistence for Monorail's additional features.
+
+Business objects are described in tracker_pb2.py, features_pb2.py, and
+tracker_bizobj.py.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import collections
+import logging
+import re
+import time
+
+import settings
+
+from features import features_constants
+from features import filterrules_helpers
+from framework import exceptions
+from framework import framework_bizobj
+from framework import framework_constants
+from framework import sql
+from proto import features_pb2
+from services import caches
+from services import config_svc
+from tracker import tracker_bizobj
+from tracker import tracker_constants
+
+QUICKEDITHISTORY_TABLE_NAME = 'QuickEditHistory'
+QUICKEDITMOSTRECENT_TABLE_NAME = 'QuickEditMostRecent'
+SAVEDQUERY_TABLE_NAME = 'SavedQuery'
+PROJECT2SAVEDQUERY_TABLE_NAME = 'Project2SavedQuery'
+SAVEDQUERYEXECUTESINPROJECT_TABLE_NAME = 'SavedQueryExecutesInProject'
+USER2SAVEDQUERY_TABLE_NAME = 'User2SavedQuery'
+FILTERRULE_TABLE_NAME = 'FilterRule'
+HOTLIST_TABLE_NAME = 'Hotlist'
+HOTLIST2ISSUE_TABLE_NAME = 'Hotlist2Issue'
+HOTLIST2USER_TABLE_NAME = 'Hotlist2User'
+
+
+# Column lists for the tables above.
+QUICKEDITHISTORY_COLS = [
+    'user_id', 'project_id', 'slot_num', 'command', 'comment',
+]
+QUICKEDITMOSTRECENT_COLS = [
+    'user_id', 'project_id', 'slot_num',
+]
+SAVEDQUERY_COLS = ['id', 'name', 'base_query_id', 'query']
+PROJECT2SAVEDQUERY_COLS = ['project_id', 'rank', 'query_id']
+SAVEDQUERYEXECUTESINPROJECT_COLS = ['query_id', 'project_id']
+USER2SAVEDQUERY_COLS = [
+    'user_id', 'rank', 'query_id', 'subscription_mode',
+]
+FILTERRULE_COLS = ['project_id', 'rank', 'predicate', 'consequence']
+HOTLIST_COLS = [
+    'id', 'name', 'summary', 'description', 'is_private', 'default_col_spec',
+]
+# Subset of HOTLIST_COLS without description and default_col_spec.
+HOTLIST_ABBR_COLS = ['id', 'name', 'summary', 'is_private']
+HOTLIST2ISSUE_COLS = [
+    'hotlist_id', 'issue_id', 'rank', 'adder_id', 'added', 'note',
+]
+HOTLIST2USER_COLS = ['hotlist_id', 'user_id', 'role_name']
+
+
+# Regex for parsing one action in the filter rule consequence storage syntax.
+# Each alternative is "<action>:<value>" and captures the value in a named
+# group with the same name as the action.
+CONSEQUENCE_RE = re.compile('|'.join([
+    r'(default_status:(?P<default_status>[-.\w]+))',
+    r'(default_owner_id:(?P<default_owner_id>\d+))',
+    r'(add_cc_id:(?P<add_cc_id>\d+))',
+    r'(add_label:(?P<add_label>[-.\w]+))',
+    r'(add_notify:(?P<add_notify>[-.@\w]+))',
+    r'(warning:(?P<warning>.+))',  # Warnings consume the rest of the string.
+    r'(error:(?P<error>.+))',  # Errors consume the rest of the string.
+]))
+
+class HotlistTwoLevelCache(caches.AbstractTwoLevelCache):
+  """Class to manage both RAM and memcache for Hotlist PBs.
+
+  Keys are hotlist IDs. Misses are filled from the Hotlist, Hotlist2Issue
+  and Hotlist2User tables through features_service.
+  """
+
+  def __init__(self, cachemanager, features_service):
+    super(HotlistTwoLevelCache, self).__init__(
+        cachemanager, 'hotlist', 'hotlist:', features_pb2.Hotlist)
+    self.features_service = features_service
+
+  def _DeserializeHotlists(
+      self, hotlist_rows, issue_rows, role_rows):
+    """Convert database rows into a dictionary of Hotlist PB keyed by ID.
+
+    Issue and role rows whose hotlist_id has no hotlist row are logged and
+    skipped. Role rows with an unrecognized role name are logged and ignored.
+
+    Args:
+      hotlist_rows: a list of hotlist rows from HOTLIST_TABLE_NAME.
+      issue_rows: a list of issue rows from HOTLIST2ISSUE_TABLE_NAME,
+        ordered by rank DESC, issue_id.
+      role_rows: a list of role rows from HOTLIST2USER_TABLE_NAME.
+
+    Returns:
+      a dict mapping hotlist_id to hotlist PB
+    """
+    hotlist_dict = {}
+
+    for hotlist_row in hotlist_rows:
+      (hotlist_id, hotlist_name, summary, description, is_private,
+       default_col_spec) = hotlist_row
+      hotlist = features_pb2.MakeHotlist(
+          hotlist_name, hotlist_id=hotlist_id, summary=summary,
+          description=description, is_private=bool(is_private),
+          default_col_spec=default_col_spec)
+      hotlist_dict[hotlist_id] = hotlist
+
+    for (hotlist_id, issue_id, rank, adder_id, added, note) in issue_rows:
+      hotlist = hotlist_dict.get(hotlist_id)
+      if hotlist:
+        hotlist.items.append(
+            features_pb2.MakeHotlistItem(
+                issue_id=issue_id, rank=rank, adder_id=adder_id,
+                date_added=added, note=note))
+      else:
+        # logging.warning, not the deprecated logging.warn alias.
+        logging.warning('hotlist %d not found', hotlist_id)
+
+    for (hotlist_id, user_id, role_name) in role_rows:
+      hotlist = hotlist_dict.get(hotlist_id)
+      if not hotlist:
+        logging.warning('hotlist %d not found', hotlist_id)
+      elif role_name == 'owner':
+        hotlist.owner_ids.append(user_id)
+      elif role_name == 'editor':
+        hotlist.editor_ids.append(user_id)
+      elif role_name == 'follower':
+        hotlist.follower_ids.append(user_id)
+      else:
+        logging.info('unknown role name %s', role_name)
+
+    return hotlist_dict
+
+  def FetchItems(self, cnxn, keys):
+    """On RAM and memcache miss, hit the database to get missing hotlists.
+
+    Deleted hotlists are excluded (is_deleted=False), so their IDs are
+    absent from the returned dict.
+    """
+    hotlist_rows = self.features_service.hotlist_tbl.Select(
+        cnxn, cols=HOTLIST_COLS, is_deleted=False, id=keys)
+    issue_rows = self.features_service.hotlist2issue_tbl.Select(
+        cnxn, cols=HOTLIST2ISSUE_COLS, hotlist_id=keys,
+        order_by=[('rank DESC', []), ('issue_id', [])])
+    role_rows = self.features_service.hotlist2user_tbl.Select(
+        cnxn, cols=HOTLIST2USER_COLS, hotlist_id=keys)
+    retrieved_dict = self._DeserializeHotlists(
+        hotlist_rows, issue_rows, role_rows)
+    return retrieved_dict
+
+
+class HotlistIDTwoLevelCache(caches.AbstractTwoLevelCache):
+ """Class to manage both RAM and memcache for hotlist_ids.
+
+ Keys for this cache are tuples (hotlist_name.lower(), owner_id).
+ This cache should be used to fetch hotlist_ids owned by users or
+ to check if a user owns a hotlist with a certain name, so the
+ hotlist_names in keys will always be in lowercase.
+ """
+
+ def __init__(self, cachemanager, features_service):
+ super(HotlistIDTwoLevelCache, self).__init__(
+ cachemanager, 'hotlist_id', 'hotlist_id:', int,
+ max_size=settings.issue_cache_max_size)
+ self.features_service = features_service
+
+ def _MakeCache(self, cache_manager, kind, max_size=None):
+ """Override normal RamCache creation with ValueCentricRamCache."""
+ return caches.ValueCentricRamCache(cache_manager, kind, max_size=max_size)
+
+ def _KeyToStr(self, key):
+ """This cache uses pairs of (str, int) as keys. Convert them to strings."""
+ return '%s,%d' % key
+
+ def _StrToKey(self, key_str):
+ """This cache uses pairs of (str, int) as keys.
+ Convert them from strings.
+ """
+ hotlist_name_str, owner_id_str = key_str.split(',')
+ return (hotlist_name_str, int(owner_id_str))
+
+  def _DeserializeHotlistIDs(
+      self, hotlist_rows, owner_rows, wanted_names_for_owners):
+    """Convert DB rows into a dict of hotlist IDs keyed by (name, owner_id).
+
+    Args:
+      hotlist_rows: a list of [id, name] rows from HOTLIST for hotlists
+          whose names we are interested in.
+      owner_rows: a list of [hotlist_id, user_id] rows from HOTLIST2USER for
+          owners that we are interested in.
+      wanted_names_for_owners: a dict of
+          {owner_id: [hotlist_name.lower(), ...], ...}
+          so we know which (hotlist_name, owner_id) keys to return.
+
+    Returns:
+      A dict mapping (hotlist_name.lower(), owner_id) keys to hotlist_id values.
+    """
+    hotlist_ids_dict = {}
+    if not hotlist_rows or not owner_rows:
+      return hotlist_ids_dict
+
+    # Note: owner_rows contains hotlist owners that we are interested in, but
+    # they may not own hotlists with names we are interested in.
+    hotlist_to_owner_id = {}
+    for (hotlist_id, user_id) in owner_rows:
+      found_owner_id = hotlist_to_owner_id.get(hotlist_id)
+      if found_owner_id is not None:
+        logging.warning(
+            'hotlist %d has more than one owner: %d, %d',
+            hotlist_id, user_id, found_owner_id)
+      # When a hotlist has several owner rows, the last one seen wins.
+      hotlist_to_owner_id[hotlist_id] = user_id
+
+    # Note: hotlist_rows holds hotlists found in owner_rows that have names
+    # we're interested in. wanted_names_for_owners filters out rows whose
+    # (hotlist_name, owner_id) pair was not requested.
+    for (hotlist_id, hotlist_name) in hotlist_rows:
+      owner_id = hotlist_to_owner_id.get(hotlist_id)
+      if owner_id is None:
+        continue
+      lower_name = hotlist_name.lower()
+      if lower_name in wanted_names_for_owners.get(owner_id, []):
+        hotlist_ids_dict[(lower_name, owner_id)] = hotlist_id
+
+    return hotlist_ids_dict
+
+  def FetchItems(self, cnxn, keys):
+    """On RAM and memcache miss, hit the database.
+
+    Args:
+      cnxn: connection to SQL database.
+      keys: list of (hotlist_name, owner_id) pairs to look up.
+
+    Returns:
+      A dict mapping (hotlist_name.lower(), owner_id) to hotlist_id for the
+      keys that were found.
+    """
+    if not keys:
+      return {}
+
+    # Keys may contain [(name1, user1), (name1, user2)], so de-duplicate the
+    # lowercased names to make sure each name is only queried once.
+    lower_names = sorted({hotlist_name.lower() for hotlist_name, _ in keys})
+    # Pass this dict to _DeserializeHotlistIDs so it knows what hotlist names
+    # we're interested in for each owner.
+    wanted_names_for_owner = collections.defaultdict(list)
+    for hotlist_name, owner_id in keys:
+      wanted_names_for_owner[owner_id].append(hotlist_name.lower())
+
+    role_rows = self.features_service.hotlist2user_tbl.Select(
+        cnxn, cols=['hotlist_id', 'user_id'],
+        user_id=list(wanted_names_for_owner.keys()), role_name='owner')
+    if not role_rows:
+      # None of the requested users own any hotlist; skip the second query.
+      return {}
+
+    hotlist_ids = [row[0] for row in role_rows]
+    hotlist_rows = self.features_service.hotlist_tbl.Select(
+        cnxn, cols=['id', 'name'], id=hotlist_ids, is_deleted=False,
+        where=[('LOWER(name) IN (%s)' % sql.PlaceHolders(lower_names),
+                lower_names)])
+
+    return self._DeserializeHotlistIDs(
+        hotlist_rows, role_rows, wanted_names_for_owner)
+
+
+class FeaturesService(object):
+ """The persistence layer for servlets in the features directory."""
+
+  def __init__(self, cache_manager, config_service):
+    """Create the table managers and caches this service relies on.
+
+    Args:
+      cache_manager: local cache with distributed invalidation.
+      config_service: an instance of ConfigService.
+    """
+    make_table = sql.SQLTableManager
+
+    # QuickEdit command history.
+    self.quickedithistory_tbl = make_table(QUICKEDITHISTORY_TABLE_NAME)
+    self.quickeditmostrecent_tbl = make_table(QUICKEDITMOSTRECENT_TABLE_NAME)
+
+    # Saved (per-user) and canned (per-project) queries.
+    self.savedquery_tbl = make_table(SAVEDQUERY_TABLE_NAME)
+    self.project2savedquery_tbl = make_table(PROJECT2SAVEDQUERY_TABLE_NAME)
+    self.savedqueryexecutesinproject_tbl = make_table(
+        SAVEDQUERYEXECUTESINPROJECT_TABLE_NAME)
+    self.user2savedquery_tbl = make_table(USER2SAVEDQUERY_TABLE_NAME)
+
+    # Filter rules.
+    self.filterrule_tbl = make_table(FILTERRULE_TABLE_NAME)
+
+    # Hotlists.
+    self.hotlist_tbl = make_table(HOTLIST_TABLE_NAME)
+    self.hotlist2issue_tbl = make_table(HOTLIST2ISSUE_TABLE_NAME)
+    self.hotlist2user_tbl = make_table(HOTLIST2USER_TABLE_NAME)
+
+    # RAM caches: saved queries keyed by user ID, canned queries by project ID.
+    self.saved_query_cache = caches.RamCache(
+        cache_manager, 'user', max_size=1000)
+    self.canned_query_cache = caches.RamCache(
+        cache_manager, 'project', max_size=1000)
+
+    # Hotlists by ID, and hotlist IDs by (lowercase name, owner_id).
+    self.hotlist_2lc = HotlistTwoLevelCache(cache_manager, self)
+    self.hotlist_id_2lc = HotlistIDTwoLevelCache(cache_manager, self)
+    self.hotlist_user_to_ids = caches.RamCache(cache_manager, 'hotlist')
+
+    self.config_service = config_service
+
+ ### QuickEdit command history
+
+  def GetRecentCommands(self, cnxn, user_id, project_id):
+    """Return recent command items for the "Redo" menu.
+
+    Args:
+      cnxn: Connection to SQL database.
+      user_id: int ID of the current user.
+      project_id: int ID of the current project.
+
+    Returns:
+      A pair (cmd_slots, recent_slot_num). cmd_slots is a list of
+      (slot_num, command, comment) 3-tuples with 1-based slot numbers that
+      can be used to populate the "Redo" menu of the quick-edit dialog.
+      recent_slot_num indicates which of those slots should initially
+      populate the command and comment fields.
+    """
+    # Always start with the standard commands.
+    history = tracker_constants.DEFAULT_RECENT_COMMANDS[:]
+    # If the user has modified any, then overwrite some standard ones.
+    history_rows = self.quickedithistory_tbl.Select(
+        cnxn, cols=['slot_num', 'command', 'comment'],
+        user_id=user_id, project_id=project_id)
+    for slot_num, command, comment in history_rows:
+      # Slot numbers are 1-based, so valid slots are 1..len(history). This
+      # includes the last slot and keeps slot 0 from writing history[-1].
+      if 1 <= slot_num <= len(history):
+        history[slot_num - 1] = (command, comment)
+
+    slots = [
+        (idx + 1, command, comment)
+        for idx, (command, comment) in enumerate(history)]
+
+    recent_slot_num = self.quickeditmostrecent_tbl.SelectValue(
+        cnxn, 'slot_num', default=1, user_id=user_id, project_id=project_id)
+
+    return slots, recent_slot_num
+
+  def StoreRecentCommand(
+      self, cnxn, user_id, project_id, slot_num, command, comment):
+    """Save a command/comment pair into one slot of the user's history.
+
+    The slot is also recorded as the user's most recently used slot in the
+    project. Both rows are written with replace=True.
+    """
+    slot_key = dict(user_id=user_id, project_id=project_id, slot_num=slot_num)
+    self.quickedithistory_tbl.InsertRow(
+        cnxn, replace=True, command=command, comment=comment, **slot_key)
+    self.quickeditmostrecent_tbl.InsertRow(cnxn, replace=True, **slot_key)
+
+  def ExpungeQuickEditHistory(self, cnxn, project_id):
+    """Delete all users' quick edit history rows for the given project."""
+    for table in (self.quickeditmostrecent_tbl, self.quickedithistory_tbl):
+      table.Delete(cnxn, project_id=project_id)
+
+  def ExpungeQuickEditsByUsers(self, cnxn, user_ids, limit=None):
+    """Delete every quick edit row belonging to the given users.
+
+    Nothing is committed here and no in-memory data is changed; the caller
+    is responsible for committing.
+    """
+    for table in (self.quickeditmostrecent_tbl, self.quickedithistory_tbl):
+      table.Delete(cnxn, user_id=user_ids, commit=False, limit=limit)
+
+ ### Saved User and Project Queries
+
+  def GetSavedQueries(self, cnxn, query_ids):
+    """Return a dict {query_id: SavedQuery PB} for the given query IDs."""
+    # TODO(jrobbins): RAM cache
+    if not query_ids:
+      return {}
+
+    savedquery_rows = self.savedquery_tbl.Select(
+        cnxn, cols=SAVEDQUERY_COLS, id=query_ids)
+    saved_queries = {
+        qid: tracker_bizobj.MakeSavedQuery(qid, name, base_id, query)
+        for qid, name, base_id, query in savedquery_rows}
+
+    # Attach the project IDs from SavedQueryExecutesInProject rows.
+    sqeip_rows = self.savedqueryexecutesinproject_tbl.Select(
+        cnxn, cols=SAVEDQUERYEXECUTESINPROJECT_COLS, query_id=query_ids)
+    for query_id, project_id in sqeip_rows:
+      saved_queries[query_id].executes_in_project_ids.append(project_id)
+
+    return saved_queries
+
+  def GetSavedQuery(self, cnxn, query_id):
+    """Return the SavedQuery PB with the given ID, or None if not found."""
+    return self.GetSavedQueries(cnxn, [query_id]).get(query_id)
+
+  def _GetUsersSavedQueriesDict(self, cnxn, user_ids):
+    """Return a dict {user_id: [SavedQuery, ...]} for the given users.
+
+    Users found in saved_query_cache are served from it. The rest are loaded
+    from the database and cached. Users with no saved queries are cached
+    with an empty list, so later lookups for them can skip the database.
+
+    Args:
+      cnxn: connection to SQL database.
+      user_ids: list of int user IDs.
+
+    Returns:
+      A dict mapping user IDs to lists of SavedQuery PBs ordered by rank.
+    """
+    results_dict, missed_uids = self.saved_query_cache.GetAll(user_ids)
+    if not missed_uids:
+      return results_dict
+
+    fetched = {uid: [] for uid in missed_uids}
+    savedquery_rows = self.user2savedquery_tbl.Select(
+        cnxn, cols=SAVEDQUERY_COLS + ['user_id', 'subscription_mode'],
+        left_joins=[('SavedQuery ON query_id = id', [])],
+        order_by=[('rank', [])], user_id=missed_uids)
+
+    # Map each query ID to its SavedQueryExecutesInProject project IDs.
+    sqeip_dict = {}
+    if savedquery_rows:
+      query_ids = {row[0] for row in savedquery_rows}
+      sqeip_rows = self.savedqueryexecutesinproject_tbl.Select(
+          cnxn, cols=SAVEDQUERYEXECUTESINPROJECT_COLS, query_id=query_ids)
+      for qid, pid in sqeip_rows:
+        sqeip_dict.setdefault(qid, []).append(pid)
+
+    for query_id, name, base_id, query, uid, sub_mode in savedquery_rows:
+      sq = tracker_bizobj.MakeSavedQuery(
+          query_id, name, base_id, query, subscription_mode=sub_mode,
+          executes_in_project_ids=sqeip_dict.get(query_id, []))
+      fetched.setdefault(uid, []).append(sq)
+
+    self.saved_query_cache.CacheAll(fetched)
+    results_dict.update(fetched)
+    return results_dict
+
+  # TODO(jrobbins): change this terminology to "canned query" rather than
+  # "saved" throughout the application.
+  def GetSavedQueriesByUserID(self, cnxn, user_id):
+    """Return a fresh copy of the list of SavedQuery PBs for user_id."""
+    queries_by_user = self._GetUsersSavedQueriesDict(cnxn, [user_id])
+    return queries_by_user.get(user_id, [])[:]
+
+  def GetCannedQueriesForProjects(self, cnxn, project_ids):
+    """Return a dict {project_id: [saved_query]} for the specified projects.
+
+    Projects found in canned_query_cache are served from it. Only the
+    projects that missed the cache are loaded from the database. Those are
+    then cached, with an empty list for projects that have no canned queries.
+
+    Args:
+      cnxn: connection to SQL database.
+      project_ids: list of int project IDs.
+
+    Returns:
+      A dict mapping project IDs to lists of SavedQuery PBs ordered by rank.
+    """
+    results_dict, missed_pids = self.canned_query_cache.GetAll(project_ids)
+    if not missed_pids:
+      return results_dict
+
+    # Query only the missed projects. Querying every project_id would append
+    # duplicate rows onto the lists of projects that were cache hits (which
+    # may be the cache's own list objects).
+    fetched = {pid: [] for pid in missed_pids}
+    cannedquery_rows = self.project2savedquery_tbl.Select(
+        cnxn, cols=['project_id'] + SAVEDQUERY_COLS,
+        left_joins=[('SavedQuery ON query_id = id', [])],
+        order_by=[('rank', [])], project_id=missed_pids)
+
+    for cq_row in cannedquery_rows:
+      project_id = cq_row[0]
+      canned_query_tuple = cq_row[1:]
+      fetched.setdefault(project_id, []).append(
+          tracker_bizobj.MakeSavedQuery(*canned_query_tuple))
+
+    self.canned_query_cache.CacheAll(fetched)
+    results_dict.update(fetched)
+    return results_dict
+
+  def GetCannedQueriesByProjectID(self, cnxn, project_id):
+    """Return the canned SavedQuery PBs for one project ([] if none)."""
+    by_project = self.GetCannedQueriesForProjects(cnxn, [project_id])
+    return by_project.get(project_id, [])
+
+  def _UpdateSavedQueries(self, cnxn, saved_queries, commit=True):
+    """Write the given SavedQuery PBs to the SavedQuery table.
+
+    Queries that already have a query_id are deleted and re-inserted with
+    that same ID. Queries without one get query_id set from the IDs that the
+    insert generated.
+    """
+    existing_query_ids = [sq.query_id for sq in saved_queries if sq.query_id]
+    if existing_query_ids:
+      self.savedquery_tbl.Delete(cnxn, id=existing_query_ids, commit=commit)
+
+    savedquery_rows = [
+        (sq.query_id or None, sq.name, sq.base_query_id, sq.query)
+        for sq in saved_queries]
+    generated_ids = self.savedquery_tbl.InsertRows(
+        cnxn, SAVEDQUERY_COLS, savedquery_rows, commit=commit,
+        return_generated_ids=True)
+    if not generated_ids:
+      return
+
+    logging.info('generated_ids are %r', generated_ids)
+    for sq in saved_queries:
+      # Assumes InsertRows returns one ID per inserted row, in row order.
+      generated_id = generated_ids.pop(0)
+      if not sq.query_id:
+        sq.query_id = generated_id
+
+  def UpdateCannedQueries(self, cnxn, project_id, canned_queries):
+    """Replace a project's canned queries in a single commit.
+
+    Args:
+      cnxn: connection to SQL database.
+      project_id: int project ID of the project that contains these queries.
+      canned_queries: list of SavedQuery PBs to update; their list position
+          becomes their rank.
+    """
+    self.project2savedquery_tbl.Delete(
+        cnxn, project_id=project_id, commit=False)
+    # This assigns query_id to any canned query that does not have one yet,
+    # so the rows below must be built afterwards.
+    self._UpdateSavedQueries(cnxn, canned_queries, commit=False)
+    ranked_rows = [
+        (project_id, rank, sq.query_id)
+        for rank, sq in enumerate(canned_queries)]
+    self.project2savedquery_tbl.InsertRows(
+        cnxn, PROJECT2SAVEDQUERY_COLS, ranked_rows, commit=False)
+    cnxn.Commit()
+
+    self.canned_query_cache.Invalidate(cnxn, project_id)
+
+  def UpdateUserSavedQueries(self, cnxn, user_id, saved_queries):
+    """Replace all saved queries of user_id, committing once at the end."""
+    saved_query_ids = [sq.query_id for sq in saved_queries if sq.query_id]
+    self.savedqueryexecutesinproject_tbl.Delete(
+        cnxn, query_id=saved_query_ids, commit=False)
+    self.user2savedquery_tbl.Delete(cnxn, user_id=user_id, commit=False)
+
+    # This assigns query_id to new queries, so build the rows afterwards.
+    self._UpdateSavedQueries(cnxn, saved_queries, commit=False)
+
+    user2savedquery_rows = [
+        (user_id, rank, sq.query_id, sq.subscription_mode or 'noemail')
+        for rank, sq in enumerate(saved_queries)]
+    self.user2savedquery_tbl.InsertRows(
+        cnxn, USER2SAVEDQUERY_COLS, user2savedquery_rows, commit=False)
+
+    sqeip_rows = [
+        (sq.query_id, pid)
+        for sq in saved_queries
+        for pid in sq.executes_in_project_ids]
+    self.savedqueryexecutesinproject_tbl.InsertRows(
+        cnxn, SAVEDQUERYEXECUTESINPROJECT_COLS, sqeip_rows, commit=False)
+    cnxn.Commit()
+
+    self.saved_query_cache.Invalidate(cnxn, user_id)
+
+ ### Subscriptions
+
+  def GetSubscriptionsInProjects(self, cnxn, project_ids):
+    """Return the saved queries of users subscribed in any given project.
+
+    Args:
+      cnxn: Connection to SQL database.
+      project_ids: list of int project IDs that contain the modified issues.
+
+    Returns:
+      A dict {user_id: all_saved_queries, ...} for all users that have any
+      'immediate' subscription in any of the specified projects.
+    """
+    absence_threshold = int(time.time()) - settings.subscription_timeout_secs
+    joins = [
+        ('SavedQueryExecutesInProject ON '
+         'SavedQueryExecutesInProject.query_id = User2SavedQuery.query_id',
+         []),
+        ('User ON User.user_id = User2SavedQuery.user_id', []),
+    ]
+    # Skip banned users, users absent longer than the subscription timeout,
+    # and users with a recorded email bounce.
+    where = [
+        ('(User.banned IS NULL OR User.banned = %s)', ['']),
+        ('User.last_visit_timestamp >= %s', [absence_threshold]),
+        ('(User.email_bounce_timestamp IS NULL OR '
+         'User.email_bounce_timestamp = %s)', [0]),
+    ]
+    # TODO(jrobbins): cache this since it rarely changes.
+    subscriber_rows = self.user2savedquery_tbl.Select(
+        cnxn, cols=['User2SavedQuery.user_id'], distinct=True,
+        joins=joins, subscription_mode='immediate', project_id=project_ids,
+        where=where)
+    subscriber_ids = [row[0] for row in subscriber_rows]
+    logging.info('subscribers relevant to projects %r are %r',
+                 project_ids, subscriber_ids)
+    return self._GetUsersSavedQueriesDict(cnxn, subscriber_ids)
+
+  def ExpungeSavedQueriesExecuteInProject(self, cnxn, project_id):
+    """Drop saved-query references to a project and its canned queries."""
+    self.savedqueryexecutesinproject_tbl.Delete(cnxn, project_id=project_id)
+
+    canned_rows = self.project2savedquery_tbl.Select(
+        cnxn, cols=['query_id'], project_id=project_id)
+    canned_query_ids = [query_id for (query_id,) in canned_rows]
+    self.project2savedquery_tbl.Delete(cnxn, project_id=project_id)
+    self.savedquery_tbl.Delete(cnxn, id=canned_query_ids)
+
+  def ExpungeSavedQueriesByUsers(self, cnxn, user_ids, limit=None):
+    """Delete every saved query owned by the given users.
+
+    Nothing is committed here and no in-memory data is changed; the caller
+    is responsible for committing.
+    """
+    owned_rows = self.user2savedquery_tbl.Select(
+        cnxn, cols=['query_id'], user_id=user_ids, limit=limit)
+    savedquery_ids = [query_id for (query_id,) in owned_rows]
+    # Remove the rows that refer to the queries before the queries themselves.
+    self.user2savedquery_tbl.Delete(
+        cnxn, query_id=savedquery_ids, commit=False)
+    self.savedqueryexecutesinproject_tbl.Delete(
+        cnxn, query_id=savedquery_ids, commit=False)
+    self.savedquery_tbl.Delete(cnxn, id=savedquery_ids, commit=False)
+
+
+ ### Filter rules
+
+  def _DeserializeFilterRules(self, filterrule_rows):
+    """Build {project_id: [FilterRule, ...]} from FilterRule table rows.
+
+    Rows are sorted first, so each project's rules come out ordered by rank.
+    The result is a defaultdict(list).
+    """
+    rules_by_project = collections.defaultdict(list)
+    for project_id, _rank, predicate, consequence in sorted(filterrule_rows):
+      (default_status, default_owner_id, add_cc_ids, add_labels,
+       add_notify, warning, error) = self._DeserializeRuleConsequence(
+           consequence)
+      rules_by_project[project_id].append(filterrules_helpers.MakeRule(
+          predicate, default_status=default_status,
+          default_owner_id=default_owner_id, add_cc_ids=add_cc_ids,
+          add_labels=add_labels, add_notify=add_notify, warning=warning,
+          error=error))
+    return rules_by_project
+
+  def _DeserializeRuleConsequence(self, consequence):
+    """Parse the THEN-part string of a filter rule into its actions.
+
+    Returns:
+      A 7-tuple (default_status, default_owner_id, add_cc_ids, add_labels,
+      add_notify, warning, error). Scalar fields stay None and list fields
+      stay empty when the consequence does not mention them; a later match
+      overwrites an earlier scalar value.
+    """
+    default_status = default_owner_id = warning = error = None
+    add_cc_ids = []
+    add_labels = []
+    add_notify = []
+    for match in CONSEQUENCE_RE.finditer(consequence):
+      # Only the first non-empty named group, in this order, is used.
+      value = match.group('default_status')
+      if value:
+        default_status = value
+        continue
+      value = match.group('default_owner_id')
+      if value:
+        default_owner_id = int(value)
+        continue
+      value = match.group('add_cc_id')
+      if value:
+        add_cc_ids.append(int(value))
+        continue
+      value = match.group('add_label')
+      if value:
+        add_labels.append(value)
+        continue
+      value = match.group('add_notify')
+      if value:
+        add_notify.append(value)
+        continue
+      value = match.group('warning')
+      if value:
+        warning = value
+        continue
+      value = match.group('error')
+      if value:
+        error = value
+
+    return (default_status, default_owner_id, add_cc_ids, add_labels,
+            add_notify, warning, error)
+
+  def _GetFilterRulesByProjectIDs(self, cnxn, project_ids):
+    """Load {project_id: [FilterRule, ...]} for the given projects."""
+    # TODO(jrobbins): caching
+    return self._DeserializeFilterRules(self.filterrule_tbl.Select(
+        cnxn, cols=FILTERRULE_COLS, project_id=project_ids))
+
+  def GetFilterRules(self, cnxn, project_id):
+    """Return the list of FilterRule PBs defined for one project.
+
+    Indexes the dict from _GetFilterRulesByProjectIDs directly; that dict
+    comes from _DeserializeFilterRules, which returns a defaultdict(list).
+    """
+    return self._GetFilterRulesByProjectIDs(cnxn, [project_id])[project_id]
+
+  def _SerializeRuleConsequence(self, rule):
+    """Encode every action of a filter rule as one space-separated string.
+
+    Tokens have the form 'key:value' and are emitted in a fixed order:
+    labels, status, owner, CCs, notify addresses, warning, error.
+    """
+    parts = ['add_label:%s' % label for label in rule.add_labels]
+    if rule.default_status:
+      parts.append('default_status:%s' % rule.default_status)
+    if rule.default_owner_id:
+      parts.append('default_owner_id:%d' % rule.default_owner_id)
+    parts.extend('add_cc_id:%d' % cc_id for cc_id in rule.add_cc_ids)
+    parts.extend('add_notify:%s' % addr for addr in rule.add_notify_addrs)
+    if rule.warning:
+      parts.append('warning:%s' % rule.warning)
+    if rule.error:
+      parts.append('error:%s' % rule.error)
+    return ' '.join(parts)
+
+  def UpdateFilterRules(self, cnxn, project_id, rules):
+    """Replace the filter rules part of a project's issue configuration.
+
+    The old rules are deleted and the new ones inserted under a single
+    commit, so the delete is never committed on its own before the insert.
+
+    Args:
+      cnxn: connection to SQL database.
+      project_id: int ID of the current project.
+      rules: a list of FilterRule PBs.
+    """
+    rows = []
+    for rank, rule in enumerate(rules):
+      predicate = rule.predicate
+      consequence = self._SerializeRuleConsequence(rule)
+      # Rules missing a predicate or any action are dropped; rank still
+      # reflects the rule's position in the input list.
+      if predicate and consequence:
+        rows.append((project_id, rank, predicate, consequence))
+
+    self.filterrule_tbl.Delete(cnxn, project_id=project_id, commit=False)
+    self.filterrule_tbl.InsertRows(cnxn, FILTERRULE_COLS, rows, commit=False)
+    cnxn.Commit()
+
+  def ExpungeFilterRules(self, cnxn, project_id):
+    """Delete every filter rule row stored for the given project."""
+    rule_table = self.filterrule_tbl
+    rule_table.Delete(cnxn, project_id=project_id)
+
+  def ExpungeFilterRulesByUser(self, cnxn, user_ids_by_email):
+    """Wipes any Filter Rules containing the given users.
+
+    This method will not commit the operation. This method will not make
+    changes to in-memory data.
+
+    Args:
+      cnxn: connection to SQL database.
+      user_ids_by_email: dict of {email: user_id ..} of all users we want to
+          expunge.
+
+    Returns:
+      A defaultdict of {project_id: [FilterRule, ...]} for the filter rules
+      that will be deleted for mentioning the given users.
+    """
+    if not user_ids_by_email:
+      return collections.defaultdict(list)
+
+    emails = list(user_ids_by_email.keys())
+    # Consequences are space-separated 'key:value' tokens (see
+    # _SerializeRuleConsequence). Compare whole tokens so that expunging
+    # user 1 does not also match 'add_cc_id:12'.
+    user_tokens = set()
+    for email, user_id in user_ids_by_email.items():
+      user_tokens.add('add_notify:%s' % email)
+      user_tokens.add('add_cc_id:%s' % user_id)
+      user_tokens.add('default_owner_id:%s' % user_id)
+
+    all_rules_rows = self.filterrule_tbl.Select(cnxn, FILTERRULE_COLS)
+    logging.info('Fetched all filter rules: %s', all_rules_rows)
+    deleted_rows = []
+    for rule_row in all_rules_rows:
+      _project_id, _rank, predicate, consequence = rule_row
+      # NOTE(review): predicates are still matched by substring, so an email
+      # like 'a@example.com' also matches 'ba@example.com'.
+      if (any(email in predicate for email in emails) or
+          user_tokens.intersection(consequence.split())):
+        deleted_rows.append(rule_row)
+
+    for project_id, rank, predicate, consequence in deleted_rows:
+      self.filterrule_tbl.Delete(
+          cnxn, project_id=project_id, rank=rank, predicate=predicate,
+          consequence=consequence, commit=False)
+    return self._DeserializeFilterRules(deleted_rows)
+
+ ### Creating hotlists
+
+  def CreateHotlist(
+      self, cnxn, name, summary, description, owner_ids, editor_ids,
+      issue_ids=None, is_private=None, default_col_spec=None, ts=None):
+    # type: (MonorailConnection, str, str, str, Sequence[int],
+    #     Collection[int], Optional[Collection[int]], Optional[bool],
+    #     Optional[str], Optional[int]) -> features_pb2.Hotlist
+    """Create and store a Hotlist with the given attributes.
+
+    Args:
+      cnxn: connection to SQL database.
+      name: a valid hotlist name.
+      summary: one-line explanation of the hotlist.
+      description: one-page explanation of the hotlist.
+      owner_ids: a list of user IDs for the hotlist owners. It must be
+          indexable: the first owner is recorded as the adder of issue_ids.
+      editor_ids: a list of user IDs for the hotlist editors.
+      issue_ids: a list of issue IDs for the hotlist issues.
+      is_private: True if the hotlist can only be viewed by owners and editors.
+      default_col_spec: the default columns that show in list view.
+      ts: a timestamp for when this hotlist was created.
+
+    Returns:
+      The new Hotlist PB, with hotlist_id set to the ID it was stored under.
+
+    Raises:
+      InputException: if the hotlist name is invalid.
+      HotlistAlreadyExists: if any of the owners already own a hotlist with
+        the same name.
+      UnownedHotlistException: if owner_ids is empty.
+    """
+    # TODO(crbug.com/monorail/7677): These checks should be done in the
+    # the business layer.
+    # Remove when calls from non-business layer code are removed.
+    if not owner_ids:  # Should never happen.
+      logging.error('Attempt to create unowned Hotlist: name:%r', name)
+      raise UnownedHotlistException()
+    if not framework_bizobj.IsValidHotlistName(name):
+      raise exceptions.InputException(
+          '%s is not a valid name for a Hotlist' % name)
+    if self.LookupHotlistIDs(cnxn, [name], owner_ids):
+      raise HotlistAlreadyExists()
+    # TODO(crbug.com/monorail/7677): We are not setting a
+    # default default_col_spec in v3.
+    if default_col_spec is None:
+      default_col_spec = features_constants.DEFAULT_COL_SPEC
+
+    # Initial items are ranked 0, 100, 200, ... in the given order and
+    # attributed to the first owner.
+    hotlist_item_fields = [
+        (issue_id, rank*100, owner_ids[0], ts, '') for
+        rank, issue_id in enumerate(issue_ids or [])]
+    hotlist = features_pb2.MakeHotlist(
+        name, hotlist_item_fields=hotlist_item_fields, summary=summary,
+        description=description, is_private=is_private, owner_ids=owner_ids,
+        editor_ids=editor_ids, default_col_spec=default_col_spec)
+    hotlist.hotlist_id = self._InsertHotlist(cnxn, hotlist)
+    return hotlist
+
+  def UpdateHotlist(
+      self, cnxn, hotlist_id, name=None, summary=None, description=None,
+      is_private=None, default_col_spec=None, owner_id=None,
+      add_editor_ids=None):
+    """Update the DB and the in-memory PB with the given hotlist information.
+
+    Note: if an argument is None, that field does not get changed to None;
+    it just does not get updated.
+
+    Args:
+      cnxn: connection to SQL database.
+      hotlist_id: int ID of the hotlist to update.
+      name: optional new hotlist name.
+      summary: optional new one-line summary.
+      description: optional new description.
+      is_private: optional new privacy setting.
+      default_col_spec: optional new default column spec.
+      owner_id: optional user ID that replaces all current owners.
+      add_editor_ids: optional list of user IDs to add as editors.
+
+    Raises:
+      NoSuchHotlistException: if the hotlist is not found.
+    """
+    hotlist = self.GetHotlist(cnxn, hotlist_id, use_cache=False)
+    if not hotlist:
+      raise NoSuchHotlistException()
+
+    delta = {}
+    if name is not None:
+      delta['name'] = name
+    if summary is not None:
+      delta['summary'] = summary
+    if description is not None:
+      delta['description'] = description
+    if is_private is not None:
+      delta['is_private'] = is_private
+    if default_col_spec is not None:
+      delta['default_col_spec'] = default_col_spec
+    if delta:
+      self.hotlist_tbl.Update(cnxn, delta, id=hotlist_id, commit=False)
+
+    insert_rows = []
+    if owner_id is not None:
+      insert_rows.append((hotlist_id, owner_id, 'owner'))
+      # Hotlist2User roles live in the 'role_name' column (the column used to
+      # select owners elsewhere in this file); filtering on 'role' would not
+      # match the existing owner rows.
+      self.hotlist2user_tbl.Delete(
+          cnxn, hotlist_id=hotlist_id, role_name='owner', commit=False)
+    if add_editor_ids:
+      insert_rows.extend(
+          (hotlist_id, user_id, 'editor') for user_id in add_editor_ids)
+    if insert_rows:
+      self.hotlist2user_tbl.InsertRows(
+          cnxn, HOTLIST2USER_COLS, insert_rows, commit=False)
+
+    cnxn.Commit()
+
+    self.hotlist_2lc.InvalidateKeys(cnxn, [hotlist_id])
+    # Invalidate the (name, owner) -> ID entries for the pre-update name and
+    # owners. The comprehension variable must not be named owner_id: under
+    # Python 2 it would leak and overwrite the owner_id argument used below.
+    if not hotlist.owner_ids:  # Should never happen.
+      logging.warning('Modifying unowned Hotlist: id:%r, name:%r',
+                      hotlist_id, hotlist.name)
+    elif hotlist.name:
+      self.hotlist_id_2lc.InvalidateKeys(
+          cnxn, [(hotlist.name.lower(), old_owner_id) for
+                 old_owner_id in hotlist.owner_ids])
+
+    # Update the hotlist PB in RAM
+    if name is not None:
+      hotlist.name = name
+    if summary is not None:
+      hotlist.summary = summary
+    if description is not None:
+      hotlist.description = description
+    if is_private is not None:
+      hotlist.is_private = is_private
+    if default_col_spec is not None:
+      hotlist.default_col_spec = default_col_spec
+    if owner_id is not None:
+      hotlist.owner_ids = [owner_id]
+    if add_editor_ids:
+      hotlist.editor_ids.extend(add_editor_ids)
+
+  def RemoveHotlistEditors(self, cnxn, hotlist_id, remove_editor_ids):
+    # type: (MonorailConnection, int, Collection[int]) -> None
+    """Remove given editors from the specified hotlist.
+
+    Only rows with the 'editor' role are deleted. Passing the ID of an owner
+    or follower therefore does not strip that user's other role.
+
+    Args:
+      cnxn: MonorailConnection object.
+      hotlist_id: int ID of the Hotlist we want to update.
+      remove_editor_ids: collection of existing hotlist editor User IDs
+        that we want to remove from the hotlist.
+
+    Raises:
+      NoSuchHotlistException: if the hotlist is not found.
+      InputException: if there are no editors to remove.
+    """
+    if not remove_editor_ids:
+      raise exceptions.InputException('No editors to remove')
+    hotlist = self.GetHotlist(cnxn, hotlist_id, use_cache=False)
+    if not hotlist:
+      raise NoSuchHotlistException()
+
+    remove_editor_ids = list(remove_editor_ids)
+    self.hotlist2user_tbl.Delete(
+        cnxn, hotlist_id=hotlist_id, user_id=remove_editor_ids,
+        role_name='editor')
+    self.hotlist_2lc.InvalidateKeys(cnxn, [hotlist_id])
+
+    # Update in-memory data. Skip IDs that are not editors, so that a stray
+    # ID cannot raise ValueError after the DB delete has already run.
+    for remove_id in remove_editor_ids:
+      if remove_id in hotlist.editor_ids:
+        hotlist.editor_ids.remove(remove_id)
+
+  def UpdateHotlistIssues(
+      self,
+      cnxn,  # type: sql.MonorailConnection
+      hotlist_id,  # type: int
+      updated_items,  # type: Collection[features_pb2.HotlistItem]
+      remove_issue_ids,  # type: Collection[int]
+      issue_svc,  # type: issue_svc.IssueService
+      chart_svc,  # type: chart_svc.ChartService
+      commit=True  # type: Optional[bool]
+  ):
+    # type: (...) -> None
+    """Update the Issues in a Hotlist.
+
+    This method removes the given remove_issue_ids from a Hotlist, then
+    updates or adds the HotlistItems found in updated_items. HotlistItems
+    in updated_items may already exist in the hotlist and just need to be
+    updated, or they may be new items that should be added to the Hotlist.
+
+    Args:
+      cnxn: MonorailConnection object.
+      hotlist_id: int ID of the Hotlist to update.
+      updated_items: Collection of HotlistItems that either already exist in
+        the hotlist and need to be updated or need to be added to the hotlist.
+      remove_issue_ids: Collection of Issue IDs that should be removed from the
+        hotlist.
+      issue_svc: IssueService object.
+      chart_svc: ChartService object.
+      commit: if False, the caller is responsible for committing. Issue
+        snapshots are stored with the same commit flag.
+
+    Raises:
+      NoSuchHotlistException if a hotlist with the given ID is not found.
+      InputException if no changes were given.
+    """
+    if not updated_items and not remove_issue_ids:
+      raise exceptions.InputException('No changes to make')
+
+    hotlist = self.GetHotlist(cnxn, hotlist_id, use_cache=False)
+    if not hotlist:
+      raise NoSuchHotlistException()
+
+    # Used to hold the updated Hotlist.items to use when updating
+    # the in-memory hotlist. Lists (not filter() iterators) are used so that
+    # .extend() below also works on Python 3.
+    all_hotlist_items = list(hotlist.items)
+
+    # Used to hold ids of issues affected by this change for storing
+    # Issue Snapshots.
+    affected_issue_ids = set()
+
+    if remove_issue_ids:
+      affected_issue_ids.update(remove_issue_ids)
+      self.hotlist2issue_tbl.Delete(
+          cnxn, hotlist_id=hotlist_id, issue_id=remove_issue_ids, commit=False)
+      remove_set = set(remove_issue_ids)
+      all_hotlist_items = [
+          item for item in all_hotlist_items if item.issue_id not in remove_set]
+
+    if updated_items:
+      updated_issue_ids = [item.issue_id for item in updated_items]
+      affected_issue_ids.update(updated_issue_ids)
+      # Replace existing rows for these issues: delete, then insert.
+      self.hotlist2issue_tbl.Delete(
+          cnxn, hotlist_id=hotlist_id, issue_id=updated_issue_ids, commit=False)
+      insert_rows = [
+          (hotlist_id, item.issue_id, item.rank, item.adder_id,
+           item.date_added, item.note)
+          for item in updated_items]
+      self.hotlist2issue_tbl.InsertRows(
+          cnxn, cols=HOTLIST2ISSUE_COLS, row_values=insert_rows, commit=False)
+      updated_set = set(updated_issue_ids)
+      all_hotlist_items = [
+          item for item in all_hotlist_items
+          if item.issue_id not in updated_set]
+      all_hotlist_items.extend(updated_items)
+
+    if commit:
+      cnxn.Commit()
+    self.hotlist_2lc.InvalidateKeys(cnxn, [hotlist_id])
+
+    # Update in-memory hotlist items.
+    hotlist.items = sorted(all_hotlist_items, key=lambda item: item.rank)
+
+    issues = issue_svc.GetIssues(cnxn, list(affected_issue_ids))
+    chart_svc.StoreIssueSnapshots(cnxn, issues, commit=commit)
+
+  # TODO(crbug/monorail/7104): AddIssuesToHotlists and RemoveIssuesFromHotlists
+  # both call UpdateHotlistItems to add/remove issues from a hotlist.
+  # UpdateHotlistItemsFields is called by methods for reranking existing issues
+  # and updating HotlistItem notes.
+  # (1) We are removing notes from HotlistItems. crbug/monorail/####
+  # (2) Our v3 AddHotlistItems will allow inserting new issues at
+  # non-last ranks of a hotlist, so there could be some shared code
+  # between the reranking path and the adding-issues path.
+  # UpdateHotlistIssues will handle adding, removing, and reranking issues.
+  # AddIssueToHotlists, AddIssuesToHotlists, RemoveIssuesFromHotlists,
+  # UpdateHotlistItems and UpdateHotlistItemsFields should be removed once all
+  # callers are updated to call UpdateHotlistIssues.
+
+  def AddIssueToHotlists(self, cnxn, hotlist_ids, issue_tuple, issue_svc,
+                         chart_svc, commit=True):
+    """Add one issue to each of the given hotlists.
+
+    This is a single-issue wrapper around AddIssuesToHotlists.
+
+    Args:
+      cnxn: connection to SQL database.
+      hotlist_ids: a list of hotlist_ids to add the issue to.
+      issue_tuple: (issue_id, user_id, ts, note) of the issue to be added.
+      issue_svc: an instance of IssueService.
+      chart_svc: an instance of ChartService.
+      commit: passed through to AddIssuesToHotlists.
+    """
+    single_issue = [issue_tuple]
+    self.AddIssuesToHotlists(
+        cnxn, hotlist_ids, single_issue, issue_svc, chart_svc, commit=commit)
+
+  def AddIssuesToHotlists(self, cnxn, hotlist_ids, added_tuples, issue_svc,
+                          chart_svc, commit=True):
+    """Add every issue in added_tuples to every one of the given hotlists.
+
+    After all hotlists are updated, snapshots are stored once for the
+    added issues.
+
+    Args:
+      cnxn: connection to SQL database.
+      hotlist_ids: a list of hotlist_ids to add the issues to.
+      added_tuples: a list of (issue_id, user_id, ts, note)
+          for issues to be added.
+      issue_svc: an instance of IssueService.
+      chart_svc: an instance of ChartService.
+      commit: passed through to UpdateHotlistItems and StoreIssueSnapshots.
+    """
+    for hotlist_id in hotlist_ids:
+      self.UpdateHotlistItems(cnxn, hotlist_id, [], added_tuples, commit=commit)
+
+    added_issue_ids = [added[0] for added in added_tuples]
+    issues = issue_svc.GetIssues(cnxn, added_issue_ids)
+    chart_svc.StoreIssueSnapshots(cnxn, issues, commit=commit)
+
+  def RemoveIssuesFromHotlists(self, cnxn, hotlist_ids, issue_ids, issue_svc,
+                               chart_svc, commit=True):
+    """Take the given issues out of every one of the given hotlists.
+
+    After all hotlists are updated, snapshots are stored once for the
+    removed issues.
+
+    Args:
+      cnxn: connection to SQL database.
+      hotlist_ids: a list of hotlist ids to remove the issues from.
+      issue_ids: a list of issue_ids to be removed.
+      issue_svc: an instance of IssueService.
+      chart_svc: an instance of ChartService.
+      commit: passed through to UpdateHotlistItems and StoreIssueSnapshots.
+    """
+    for hotlist_id in hotlist_ids:
+      self.UpdateHotlistItems(cnxn, hotlist_id, issue_ids, [], commit=commit)
+
+    chart_svc.StoreIssueSnapshots(
+        cnxn, issue_svc.GetIssues(cnxn, issue_ids), commit=commit)
+
+  def UpdateHotlistItems(
+      self, cnxn, hotlist_id, remove, added_tuples, commit=True):
+    """Updates a hotlist's list of hotlistissues.
+
+    Args:
+      cnxn: connection to SQL database.
+      hotlist_id: the ID of the hotlist to update.
+      remove: a list of issue_ids to be removed.
+      added_tuples: a list of (issue_id, user_id, ts, note)
+          for issues to be added. Issues already in the hotlist are ignored.
+      commit: set to False to skip the DB commit and do it in the caller.
+
+    Raises:
+      NoSuchHotlistException: if the hotlist is not found.
+    """
+    hotlist = self.GetHotlist(cnxn, hotlist_id, use_cache=False)
+    if not hotlist:
+      raise NoSuchHotlistException()
+
+    current_issues_ids = {item.issue_id for item in hotlist.items}
+    remove_set = set(remove)
+
+    # Only issues that are actually in the hotlist need a DB delete.
+    remove_from_db = [
+        remove_id for remove_id in remove if remove_id in current_issues_ids]
+    if remove_from_db:
+      self.hotlist2issue_tbl.Delete(
+          cnxn, hotlist_id=hotlist_id, issue_id=remove_from_db, commit=False)
+
+    # New items go after the current highest rank, spaced 10 apart.
+    if hotlist.items:
+      rank_base = max(item.rank for item in hotlist.items) + 10
+    else:
+      rank_base = 1
+    # Adding new HotlistItems, ignoring issues already in the hotlist.
+    insert_rows = [
+        (hotlist_id, issue_id, rank*10 + rank_base, user_id, ts, note)
+        for (rank, (issue_id, user_id, ts, note)) in enumerate(added_tuples)
+        if issue_id not in current_issues_ids]
+    if insert_rows:
+      self.hotlist2issue_tbl.InsertRows(
+          cnxn, cols=HOTLIST2ISSUE_COLS, row_values=insert_rows, commit=False)
+    # Commit explicitly instead of relying on InsertRows to also commit the
+    # delete above.
+    if commit:
+      cnxn.Commit()
+    self.hotlist_2lc.InvalidateKeys(cnxn, [hotlist_id])
+
+    # Removing an issue that was never in the hotlist is harmless.
+    items = [item for item in hotlist.items if item.issue_id not in remove_set]
+    items.extend(
+        features_pb2.MakeHotlistItem(issue_id, rank, user_id, ts, note)
+        for (_hid, issue_id, rank, user_id, ts, note) in insert_rows)
+    hotlist.items = items
+
+  def UpdateHotlistItemsFields(
+      self, cnxn, hotlist_id, new_ranks=None, new_notes=None, commit=True):
+    """Updates rankings or notes of hotlistissues.
+
+    Args:
+      cnxn: connection to SQL database.
+      hotlist_id: the ID of the hotlist to update.
+      new_ranks: This should be a dictionary of {issue_id: rank}.
+      new_notes: This should be a dictionary of {issue_id: note}.
+      commit: set to False to skip the DB commit and do it in the caller.
+
+    Raises:
+      NoSuchHotlistException: if the hotlist is not found.
+    """
+    hotlist = self.GetHotlist(cnxn, hotlist_id, use_cache=False)
+    if not hotlist:
+      raise NoSuchHotlistException()
+    if new_ranks is None:
+      new_ranks = {}
+    if new_notes is None:
+      new_notes = {}
+    issue_ids = []
+    insert_rows = []
+
+    # Update the hotlist PB in RAM, and build a replacement DB row for every
+    # item whose rank or note changed.
+    for hotlist_item in hotlist.items:
+      issue_id = hotlist_item.issue_id
+      if issue_id not in new_ranks and issue_id not in new_notes:
+        continue
+      if issue_id in new_ranks:
+        hotlist_item.rank = new_ranks[issue_id]
+      if issue_id in new_notes:
+        hotlist_item.note = new_notes[issue_id]
+      issue_ids.append(issue_id)
+      insert_rows.append((
+          hotlist_id, issue_id, hotlist_item.rank,
+          hotlist_item.adder_id, hotlist_item.date_added, hotlist_item.note))
+    hotlist.items = sorted(hotlist.items, key=lambda item: item.rank)
+
+    if issue_ids:
+      # Replace the changed rows: delete them, then re-insert the new values.
+      self.hotlist2issue_tbl.Delete(
+          cnxn, hotlist_id=hotlist_id, issue_id=issue_ids, commit=False)
+      self.hotlist2issue_tbl.InsertRows(
+          cnxn, cols=HOTLIST2ISSUE_COLS, row_values=insert_rows, commit=False)
+    # Commit explicitly instead of relying on InsertRows to commit the delete.
+    if commit:
+      cnxn.Commit()
+    self.hotlist_2lc.InvalidateKeys(cnxn, [hotlist_id])
+
+ def _InsertHotlist(self, cnxn, hotlist):
+ """Insert the given hotlist into the database."""
+ hotlist_id = self.hotlist_tbl.InsertRow(
+ cnxn, name=hotlist.name, summary=hotlist.summary,
+ description=hotlist.description, is_private=hotlist.is_private,
+ default_col_spec=hotlist.default_col_spec)
+ logging.info('stored hotlist was given id %d', hotlist_id)
+
+ self.hotlist2issue_tbl.InsertRows(
+ cnxn, HOTLIST2ISSUE_COLS,
+ [(hotlist_id, issue.issue_id, issue.rank,
+ issue.adder_id, issue.date_added, issue.note)
+ for issue in hotlist.items],
+ commit=False)
+ self.hotlist2user_tbl.InsertRows(
+ cnxn, HOTLIST2USER_COLS,
+ [(hotlist_id, user_id, 'owner')
+ for user_id in hotlist.owner_ids] +
+ [(hotlist_id, user_id, 'editor')
+ for user_id in hotlist.editor_ids] +
+ [(hotlist_id, user_id, 'follower')
+ for user_id in hotlist.follower_ids])
+
+ self.hotlist_user_to_ids.InvalidateKeys(cnxn, hotlist.owner_ids)
+
+ return hotlist_id
+
+ def TransferHotlistOwnership(
+ self, cnxn, hotlist, new_owner_id, remain_editor, commit=True):
+ """Transfers ownership of a hotlist to a new owner."""
+ new_editor_ids = hotlist.editor_ids
+ if remain_editor:
+ new_editor_ids.extend(hotlist.owner_ids)
+ if new_owner_id in new_editor_ids:
+ new_editor_ids.remove(new_owner_id)
+ new_follower_ids = hotlist.follower_ids
+ if new_owner_id in new_follower_ids:
+ new_follower_ids.remove(new_owner_id)
+ self.UpdateHotlistRoles(
+ cnxn, hotlist.hotlist_id, [new_owner_id], new_editor_ids,
+ new_follower_ids, commit=commit)
+
+ ### Lookup hotlist IDs
+
+ def LookupHotlistIDs(self, cnxn, hotlist_names, owner_ids):
+ """Return a dict of (name, owner_id) mapped to hotlist_id for all hotlists
+ with one of the given names and any of the given owners. Hotlists that
+ match multiple owners will be in the dict multiple times."""
+ id_dict, _missed_keys = self.hotlist_id_2lc.GetAll(
+ cnxn, [(name.lower(), owner_id)
+ for name in hotlist_names for owner_id in owner_ids])
+ return id_dict
+
+ def LookupUserHotlists(self, cnxn, user_ids):
+ """Return a dict of {user_id: [hotlist_id,...]} for all user_ids."""
+ id_dict, missed_ids = self.hotlist_user_to_ids.GetAll(user_ids)
+ if missed_ids:
+ retrieved_dict = {user_id: [] for user_id in missed_ids}
+ id_rows = self.hotlist2user_tbl.Select(
+ cnxn, cols=['user_id', 'hotlist_id'], user_id=user_ids,
+ left_joins=[('Hotlist ON hotlist_id = id', [])],
+ where=[('Hotlist.is_deleted = %s', [False])])
+ for (user_id, hotlist_id) in id_rows:
+ retrieved_dict[user_id].append(hotlist_id)
+ self.hotlist_user_to_ids.CacheAll(retrieved_dict)
+ id_dict.update(retrieved_dict)
+
+ return id_dict
+
+ def LookupIssueHotlists(self, cnxn, issue_ids):
+ """Return a dict of {issue_id: [hotlist_id,...]} for all issue_ids."""
+ # TODO(jojwang): create hotlist_issue_to_ids cache
+ retrieved_dict = {issue_id: [] for issue_id in issue_ids}
+ id_rows = self.hotlist2issue_tbl.Select(
+ cnxn, cols=['hotlist_id', 'issue_id'], issue_id=issue_ids,
+ left_joins=[('Hotlist ON hotlist_id = id', [])],
+ where=[('Hotlist.is_deleted = %s', [False])])
+ for hotlist_id, issue_id in id_rows:
+ retrieved_dict[issue_id].append(hotlist_id)
+ return retrieved_dict
+
+ def GetProjectIDsFromHotlist(self, cnxn, hotlist_id):
+ project_id_rows = self.hotlist2issue_tbl.Select(cnxn,
+ cols=['Issue.project_id'], hotlist_id=hotlist_id, distinct=True,
+ left_joins=[('Issue ON issue_id = id', [])])
+ return [row[0] for row in project_id_rows]
+
+ ### Get hotlists
+ def GetHotlists(self, cnxn, hotlist_ids, use_cache=True):
+ """Returns dict of {hotlist_id: hotlist PB}."""
+ hotlists_dict, missed_ids = self.hotlist_2lc.GetAll(
+ cnxn, hotlist_ids, use_cache=use_cache)
+
+ if missed_ids:
+ raise NoSuchHotlistException()
+
+ return hotlists_dict
+
+ def GetHotlistsByUserID(self, cnxn, user_id, use_cache=True):
+ """Get a list of hotlist PBs for a given user."""
+ hotlist_id_dict = self.LookupUserHotlists(cnxn, [user_id])
+ hotlists = self.GetHotlists(
+ cnxn, hotlist_id_dict.get(user_id, []), use_cache=use_cache)
+ return list(hotlists.values())
+
+ def GetHotlistsByIssueID(self, cnxn, issue_id, use_cache=True):
+ """Get a list of hotlist PBs for a given issue."""
+ hotlist_id_dict = self.LookupIssueHotlists(cnxn, [issue_id])
+ hotlists = self.GetHotlists(
+ cnxn, hotlist_id_dict.get(issue_id, []), use_cache=use_cache)
+ return list(hotlists.values())
+
+ def GetHotlist(self, cnxn, hotlist_id, use_cache=True):
+ """Returns hotlist PB."""
+ hotlist_dict = self.GetHotlists(cnxn, [hotlist_id], use_cache=use_cache)
+ return hotlist_dict[hotlist_id]
+
+ def GetHotlistsByID(self, cnxn, hotlist_ids, use_cache=True):
+ """Load all the Hotlist PBs for the given hotlists.
+
+ Args:
+ cnxn: connection to SQL database.
+ hotlist_ids: list of hotlist ids.
+ use_cache: specifiy False to force database query.
+
+ Returns:
+ A dict mapping ids to the corresponding Hotlist protocol buffers and
+ a list of any hotlist_ids that were not found.
+ """
+ hotlists_dict, missed_ids = self.hotlist_2lc.GetAll(
+ cnxn, hotlist_ids, use_cache=use_cache)
+ return hotlists_dict, missed_ids
+
+ def GetHotlistByID(self, cnxn, hotlist_id, use_cache=True):
+ """Load the specified hotlist from the database, None if does not exist."""
+ hotlist_dict, _ = self.GetHotlistsByID(
+ cnxn, [hotlist_id], use_cache=use_cache)
+ return hotlist_dict.get(hotlist_id)
+
+ def UpdateHotlistRoles(
+ self, cnxn, hotlist_id, owner_ids, editor_ids, follower_ids, commit=True):
+ """"Store the hotlist's roles in the DB."""
+ # This will be a newly contructed object, not from the cache and not
+ # shared with any other thread.
+ hotlist = self.GetHotlist(cnxn, hotlist_id, use_cache=False)
+ if not hotlist:
+ raise NoSuchHotlistException()
+
+ self.hotlist2user_tbl.Delete(
+ cnxn, hotlist_id=hotlist_id, commit=False)
+
+ insert_rows = [(hotlist_id, user_id, 'owner') for user_id in owner_ids]
+ insert_rows.extend(
+ [(hotlist_id, user_id, 'editor') for user_id in editor_ids])
+ insert_rows.extend(
+ [(hotlist_id, user_id, 'follower') for user_id in follower_ids])
+ self.hotlist2user_tbl.InsertRows(
+ cnxn, HOTLIST2USER_COLS, insert_rows, commit=False)
+
+ if commit:
+ cnxn.Commit()
+ self.hotlist_2lc.InvalidateKeys(cnxn, [hotlist_id])
+ self.hotlist_user_to_ids.InvalidateKeys(cnxn, hotlist.owner_ids)
+ hotlist.owner_ids = owner_ids
+ hotlist.editor_ids = editor_ids
+ hotlist.follower_ids = follower_ids
+
+ def DeleteHotlist(self, cnxn, hotlist_id, commit=True):
+ hotlist = self.GetHotlist(cnxn, hotlist_id, use_cache=False)
+ if not hotlist:
+ raise NoSuchHotlistException()
+
+ # Fetch all associated project IDs in order to invalidate their cache.
+ project_ids = self.GetProjectIDsFromHotlist(cnxn, hotlist_id)
+
+ delta = {'is_deleted': True}
+ self.hotlist_tbl.Update(cnxn, delta, id=hotlist_id, commit=commit)
+
+ self.hotlist_2lc.InvalidateKeys(cnxn, [hotlist_id])
+ self.hotlist_user_to_ids.InvalidateKeys(cnxn, hotlist.owner_ids)
+ self.hotlist_user_to_ids.InvalidateKeys(cnxn, hotlist.editor_ids)
+ if not hotlist.owner_ids: # Should never happen.
+ logging.warn('Soft-deleting unowned Hotlist: id:%r, name:%r',
+ hotlist_id, hotlist.name)
+ elif hotlist.name:
+ self.hotlist_id_2lc.InvalidateKeys(
+ cnxn, [(hotlist.name.lower(), owner_id) for
+ owner_id in hotlist.owner_ids])
+
+ for project_id in project_ids:
+ self.config_service.InvalidateMemcacheForEntireProject(project_id)
+
+ def ExpungeHotlists(
+ self, cnxn, hotlist_ids, star_svc, user_svc, chart_svc, commit=True):
+ """Wipes the given hotlists from the DB tables.
+
+ This method will only do cache invalidation if commit is set to True.
+
+ Args:
+ cnxn: connection to SQL database.
+ hotlist_ids: the ID of the hotlists to Expunge.
+ star_svc: an instance of a HotlistStarService.
+ user_svc: an instance of a UserService.
+ chart_svc: an instance of a ChartService.
+ commit: set to False to skip the DB commit and do it in the caller.
+ """
+
+ hotlists_by_id = self.GetHotlists(cnxn, hotlist_ids)
+
+ for hotlist_id in hotlist_ids:
+ star_svc.ExpungeStars(cnxn, hotlist_id, commit=commit)
+ chart_svc.ExpungeHotlistsFromIssueSnapshots(
+ cnxn, hotlist_ids, commit=commit)
+ user_svc.ExpungeHotlistsFromHistory(cnxn, hotlist_ids, commit=commit)
+ self.hotlist2user_tbl.Delete(cnxn, hotlist_id=hotlist_ids, commit=commit)
+ self.hotlist2issue_tbl.Delete(cnxn, hotlist_id=hotlist_ids, commit=commit)
+ self.hotlist_tbl.Delete(cnxn, id=hotlist_ids, commit=commit)
+
+ # Invalidate cache for deleted hotlists.
+ self.hotlist_2lc.InvalidateKeys(cnxn, hotlist_ids)
+ users_to_invalidate = set()
+ for hotlist in hotlists_by_id.values():
+ users_to_invalidate.update(
+ hotlist.owner_ids + hotlist.editor_ids + hotlist.follower_ids)
+ self.hotlist_id_2lc.InvalidateKeys(
+ cnxn, [(hotlist.name, owner_id) for owner_id in hotlist.owner_ids])
+ self.hotlist_user_to_ids.InvalidateKeys(cnxn, list(users_to_invalidate))
+ hotlist_project_ids = set()
+ for hotlist_id in hotlist_ids:
+ hotlist_project_ids.update(self.GetProjectIDsFromHotlist(
+ cnxn, hotlist_id))
+ for project_id in hotlist_project_ids:
+ self.config_service.InvalidateMemcacheForEntireProject(project_id)
+
+ def ExpungeUsersInHotlists(
+ self, cnxn, user_ids, star_svc, user_svc, chart_svc):
+ """Wipes the given users and any hotlists they owned from the
+ hotlists system.
+
+ This method will not commit the operation. This method will not make
+ changes to in-memory data.
+ """
+ # Transfer hotlist ownership to editors, if possible.
+ hotlist_ids_by_user_id = self.LookupUserHotlists(cnxn, user_ids)
+ hotlist_ids = [hotlist_id for hotlist_ids in hotlist_ids_by_user_id.values()
+ for hotlist_id in hotlist_ids]
+ hotlists_by_id, missed = self.GetHotlistsByID(
+ cnxn, list(set(hotlist_ids)), use_cache=False)
+ logging.info('Missed hotlists: %s', missed)
+
+ hotlists_to_delete = []
+ for hotlist_id, hotlist in hotlists_by_id.items():
+ # One of the users to be deleted is an owner of hotlist.
+ if not set(hotlist.owner_ids).isdisjoint(user_ids):
+ hotlists_to_delete.append(hotlist_id)
+ candidate_new_owners = [user_id for user_id in hotlist.editor_ids
+ if user_id not in user_ids]
+ for candidate_id in candidate_new_owners:
+ if not self.LookupHotlistIDs(cnxn, [hotlist.name], [candidate_id]):
+ self.TransferHotlistOwnership(
+ cnxn, hotlist, candidate_id, False, commit=False)
+ # Hotlist transferred successfully. No need to delete it.
+ hotlists_to_delete.remove(hotlist_id)
+ break
+
+ # Delete users
+ self.hotlist2user_tbl.Delete(cnxn, user_id=user_ids, commit=False)
+ self.hotlist2issue_tbl.Update(
+ cnxn, {'adder_id': framework_constants.DELETED_USER_ID},
+ adder_id=user_ids, commit=False)
+ user_svc.ExpungeUsersHotlistsHistory(cnxn, user_ids, commit=False)
+ # Delete hotlists
+ if hotlists_to_delete:
+ self.ExpungeHotlists(
+ cnxn, hotlists_to_delete, star_svc, user_svc, chart_svc, commit=False)
+
+
+class HotlistAlreadyExists(Exception):
+ """Tried to create a hotlist with the same name as another hotlist
+ with the same owner."""
+ pass
+
+
+class NoSuchHotlistException(Exception):
+ """The requested hotlist was not found."""
+ pass
+
+
+class UnownedHotlistException(Exception):
+ """Tried to create a hotlist with no owner."""
+ pass
diff --git a/services/fulltext_helpers.py b/services/fulltext_helpers.py
new file mode 100644
index 0000000..80d4264
--- /dev/null
+++ b/services/fulltext_helpers.py
@@ -0,0 +1,126 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""A set of helpers functions for fulltext search."""
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+
+import logging
+
+from google.appengine.api import search
+
+import settings
+from proto import ast_pb2
+from proto import tracker_pb2
+from search import query2ast
+
+# Page size for each GAE search request. The GAE search API can only respond
+# with 500 results per call, so ComprehensiveSearch follows result cursors to
+# fetch further pages.
+_SEARCH_RESULT_CHUNK_SIZE = 500
+
+
+def BuildFTSQuery(query_ast_conj, fulltext_fields):
+ """Convert a Monorail query AST into a GAE search query string.
+
+ Args:
+ query_ast_conj: a Conjunction PB with a list of Comparison PBs that each
+ have operator, field definitions, string values, and int values.
+ All Conditions should be AND'd together.
+ fulltext_fields: a list of string names of fields that may exist in the
+ fulltext documents. E.g., issue fulltext documents have a "summary"
+ field.
+
+ Returns:
+ A string that can be passed to AppEngine's search API. Or, None if there
+ were no fulltext conditions, so no fulltext search should be done.
+ """
+ fulltext_parts = [
+ _BuildFTSCondition(cond, fulltext_fields)
+ for cond in query_ast_conj.conds]
+ if any(fulltext_parts):
+ return ' '.join(fulltext_parts)
+ else:
+ return None
+
+
+def _BuildFTSCondition(cond, fulltext_fields):
+ """Convert one query AST condition into a GAE search query string."""
+ if cond.op == ast_pb2.QueryOp.NOT_TEXT_HAS:
+ neg = 'NOT '
+ elif cond.op == ast_pb2.QueryOp.TEXT_HAS:
+ neg = ''
+ else:
+ return '' # FTS only looks at TEXT_HAS and NOT_TEXT_HAS
+
+ parts = []
+
+ for fd in cond.field_defs:
+ if fd.field_name in fulltext_fields:
+ pattern = fd.field_name + ':"%s"'
+ elif fd.field_name == ast_pb2.ANY_FIELD:
+ pattern = '"%s"'
+ elif fd.field_id and fd.field_type == tracker_pb2.FieldTypes.STR_TYPE:
+ pattern = 'custom_' + str(fd.field_id) + ':"%s"'
+ else:
+ pattern = 'pylint does not handle else-continue'
+ continue # This issue field is searched via SQL.
+
+ for value in cond.str_values:
+ # Strip out quotes around the value.
+ value = value.strip('"')
+ special_prefixes_match = any(
+ value.startswith(p) for p in query2ast.NON_OP_PREFIXES)
+ if not special_prefixes_match:
+ value = value.replace(':', ' ')
+ assert ('"' not in value), 'Value %r has a quote in it' % value
+ parts.append(pattern % value)
+
+ if parts:
+ return neg + '(%s)' % ' OR '.join(parts)
+ else:
+ return '' # None of the fields were fulltext fields.
+
+
+def ComprehensiveSearch(fulltext_query, index_name):
+ """Call the GAE search API, and keep calling it to get all results.
+
+ Args:
+ fulltext_query: string in the GAE search API query language.
+ index_name: string name of the GAE fulltext index to hit.
+
+ Returns:
+ A list of integer issue IIDs or project IDs.
+ """
+ search_index = search.Index(name=index_name)
+
+ try:
+ response = search_index.search(search.Query(
+ fulltext_query,
+ options=search.QueryOptions(
+ limit=_SEARCH_RESULT_CHUNK_SIZE, returned_fields=[], ids_only=True,
+ cursor=search.Cursor())))
+ except ValueError as e:
+ raise query2ast.InvalidQueryError(e.message)
+
+ logging.info('got %d initial results', len(response.results))
+ ids = [int(result.doc_id) for result in response]
+
+ remaining_iterations = int(
+ (settings.fulltext_limit_per_shard - 1) // _SEARCH_RESULT_CHUNK_SIZE)
+ for _ in range(remaining_iterations):
+ if not response.cursor:
+ break
+ response = search_index.search(search.Query(
+ fulltext_query,
+ options=search.QueryOptions(
+ limit=_SEARCH_RESULT_CHUNK_SIZE, returned_fields=[], ids_only=True,
+ cursor=response.cursor)))
+ logging.info(
+ 'got %d more results: %r', len(response.results), response.results)
+ ids.extend(int(result.doc_id) for result in response)
+
+ logging.info('FTS result ids %d', len(ids))
+ return ids
diff --git a/services/issue_svc.py b/services/issue_svc.py
new file mode 100644
index 0000000..eab85ab
--- /dev/null
+++ b/services/issue_svc.py
@@ -0,0 +1,2901 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""A set of functions that provide persistence for Monorail issue tracking.
+
+This module provides functions to get, update, create, and (in some
+cases) delete each type of business object. It provides a logical
+persistence layer on top of an SQL database.
+
+Business objects are described in tracker_pb2.py and tracker_bizobj.py.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import collections
+import json
+import logging
+import os
+import time
+import uuid
+
+from google.appengine.api import app_identity
+from google.appengine.api import images
+from third_party import cloudstorage
+
+import settings
+from features import filterrules_helpers
+from framework import authdata
+from framework import exceptions
+from framework import framework_bizobj
+from framework import framework_constants
+from framework import framework_helpers
+from framework import gcs_helpers
+from framework import permissions
+from framework import sql
+from infra_libs import ts_mon
+from proto import project_pb2
+from proto import tracker_pb2
+from services import caches
+from services import tracker_fulltext
+from tracker import tracker_bizobj
+from tracker import tracker_helpers
+
+# TODO(jojwang): monorail:4693, remove this after all 'stable-full'
+# gates have been renamed to 'stable'.
+# Maps each of the two equivalent gate names to the other one.
+FLT_EQUIVALENT_GATES = {'stable-full': 'stable',
+ 'stable': 'stable-full'}
+
+# SQL table names used in this module.
+ISSUE_TABLE_NAME = 'Issue'
+ISSUESUMMARY_TABLE_NAME = 'IssueSummary'
+ISSUE2LABEL_TABLE_NAME = 'Issue2Label'
+ISSUE2COMPONENT_TABLE_NAME = 'Issue2Component'
+ISSUE2CC_TABLE_NAME = 'Issue2Cc'
+ISSUE2NOTIFY_TABLE_NAME = 'Issue2Notify'
+ISSUE2FIELDVALUE_TABLE_NAME = 'Issue2FieldValue'
+COMMENT_TABLE_NAME = 'Comment'
+COMMENTCONTENT_TABLE_NAME = 'CommentContent'
+COMMENTIMPORTER_TABLE_NAME = 'CommentImporter'
+ATTACHMENT_TABLE_NAME = 'Attachment'
+ISSUERELATION_TABLE_NAME = 'IssueRelation'
+DANGLINGRELATION_TABLE_NAME = 'DanglingIssueRelation'
+ISSUEUPDATE_TABLE_NAME = 'IssueUpdate'
+ISSUEFORMERLOCATIONS_TABLE_NAME = 'IssueFormerLocations'
+REINDEXQUEUE_TABLE_NAME = 'ReindexQueue'
+LOCALIDCOUNTER_TABLE_NAME = 'LocalIDCounter'
+ISSUESNAPSHOT_TABLE_NAME = 'IssueSnapshot'
+ISSUESNAPSHOT2CC_TABLE_NAME = 'IssueSnapshot2Cc'
+ISSUESNAPSHOT2COMPONENT_TABLE_NAME = 'IssueSnapshot2Component'
+ISSUESNAPSHOT2LABEL_TABLE_NAME = 'IssueSnapshot2Label'
+ISSUEPHASEDEF_TABLE_NAME = 'IssuePhaseDef'
+ISSUE2APPROVALVALUE_TABLE_NAME = 'Issue2ApprovalValue'
+ISSUEAPPROVAL2APPROVER_TABLE_NAME = 'IssueApproval2Approver'
+ISSUEAPPROVAL2COMMENT_TABLE_NAME = 'IssueApproval2Comment'
+
+
+# Column lists for the tables above. Rows selected with one of these lists
+# are unpacked positionally (e.g. IssueTwoLevelCache._UnpackIssue unpacks
+# rows selected with ISSUE_COLS), so the column order matters.
+ISSUE_COLS = [
+ 'id', 'project_id', 'local_id', 'status_id', 'owner_id', 'reporter_id',
+ 'opened', 'closed', 'modified',
+ 'owner_modified', 'status_modified', 'component_modified',
+ 'derived_owner_id', 'derived_status_id',
+ 'deleted', 'star_count', 'attachment_count', 'is_spam']
+ISSUESUMMARY_COLS = ['issue_id', 'summary']
+ISSUE2LABEL_COLS = ['issue_id', 'label_id', 'derived']
+ISSUE2COMPONENT_COLS = ['issue_id', 'component_id', 'derived']
+ISSUE2CC_COLS = ['issue_id', 'cc_id', 'derived']
+ISSUE2NOTIFY_COLS = ['issue_id', 'email']
+ISSUE2FIELDVALUE_COLS = [
+ 'issue_id', 'field_id', 'int_value', 'str_value', 'user_id', 'date_value',
+ 'url_value', 'derived', 'phase_id']
+# Explicitly specify column 'Comment.id' to allow joins on other tables that
+# have an 'id' column.
+COMMENT_COLS = [
+ 'Comment.id', 'issue_id', 'created', 'Comment.project_id', 'commenter_id',
+ 'deleted_by', 'Comment.is_spam', 'is_description',
+ 'commentcontent_id'] # Note: commentcontent_id must be last.
+COMMENTCONTENT_COLS = [
+ 'CommentContent.id', 'content', 'inbound_message']
+COMMENTIMPORTER_COLS = ['comment_id', 'importer_id']
+ABBR_COMMENT_COLS = ['Comment.id', 'commenter_id', 'deleted_by',
+ 'is_description']
+ATTACHMENT_COLS = [
+ 'id', 'issue_id', 'comment_id', 'filename', 'filesize', 'mimetype',
+ 'deleted', 'gcs_object_id']
+ISSUERELATION_COLS = ['issue_id', 'dst_issue_id', 'kind', 'rank']
+ABBR_ISSUERELATION_COLS = ['dst_issue_id', 'rank']
+DANGLINGRELATION_COLS = [
+ 'issue_id', 'dst_issue_project', 'dst_issue_local_id',
+ 'ext_issue_identifier', 'kind']
+ISSUEUPDATE_COLS = [
+ 'id', 'issue_id', 'comment_id', 'field', 'old_value', 'new_value',
+ 'added_user_id', 'removed_user_id', 'custom_field_name']
+ISSUEFORMERLOCATIONS_COLS = ['issue_id', 'project_id', 'local_id']
+REINDEXQUEUE_COLS = ['issue_id', 'created']
+ISSUESNAPSHOT_COLS = ['id', 'issue_id', 'shard', 'project_id', 'local_id',
+ 'reporter_id', 'owner_id', 'status_id', 'period_start', 'period_end',
+ 'is_open']
+ISSUESNAPSHOT2CC_COLS = ['issuesnapshot_id', 'cc_id']
+ISSUESNAPSHOT2COMPONENT_COLS = ['issuesnapshot_id', 'component_id']
+ISSUESNAPSHOT2LABEL_COLS = ['issuesnapshot_id', 'label_id']
+ISSUEPHASEDEF_COLS = ['id', 'name', 'rank']
+ISSUE2APPROVALVALUE_COLS = ['approval_id', 'issue_id', 'phase_id',
+ 'status', 'setter_id', 'set_on']
+ISSUEAPPROVAL2APPROVER_COLS = ['approval_id', 'approver_id', 'issue_id']
+ISSUEAPPROVAL2COMMENT_COLS = ['approval_id', 'comment_id']
+
+# NOTE(review): presumably the batch size for chunked DB operations later in
+# this module; its uses are not visible here, so confirm against them.
+CHUNK_SIZE = 1000
+
+
+class IssueIDTwoLevelCache(caches.AbstractTwoLevelCache):
+ """Class to manage RAM and memcache for Issue IDs."""
+
+ def __init__(self, cache_manager, issue_service):
+ super(IssueIDTwoLevelCache, self).__init__(
+ cache_manager, 'issue_id', 'issue_id:', int,
+ max_size=settings.issue_cache_max_size)
+ self.issue_service = issue_service
+
+ def _MakeCache(self, cache_manager, kind, max_size=None):
+ """Override normal RamCache creation with ValueCentricRamCache."""
+ return caches.ValueCentricRamCache(cache_manager, kind, max_size=max_size)
+
+ def _DeserializeIssueIDs(self, project_local_issue_ids):
+ """Convert database rows into a dict {(project_id, local_id): issue_id}."""
+ return {(project_id, local_id): issue_id
+ for (project_id, local_id, issue_id) in project_local_issue_ids}
+
+ def FetchItems(self, cnxn, keys):
+ """On RAM and memcache miss, hit the database."""
+ local_ids_by_pid = collections.defaultdict(list)
+ for project_id, local_id in keys:
+ local_ids_by_pid[project_id].append(local_id)
+
+ where = [] # We OR per-project pairs of conditions together.
+ for project_id, local_ids_in_project in local_ids_by_pid.items():
+ term_str = ('(Issue.project_id = %%s AND Issue.local_id IN (%s))' %
+ sql.PlaceHolders(local_ids_in_project))
+ where.append((term_str, [project_id] + local_ids_in_project))
+
+ rows = self.issue_service.issue_tbl.Select(
+ cnxn, cols=['project_id', 'local_id', 'id'],
+ where=where, or_where_conds=True)
+ return self._DeserializeIssueIDs(rows)
+
+ def _KeyToStr(self, key):
+ """This cache uses pairs of ints as keys. Convert them to strings."""
+ return '%d,%d' % key
+
+ def _StrToKey(self, key_str):
+ """This cache uses pairs of ints as keys. Convert them from strings."""
+ project_id_str, local_id_str = key_str.split(',')
+ return int(project_id_str), int(local_id_str)
+
+
+class IssueTwoLevelCache(caches.AbstractTwoLevelCache):
+ """Class to manage RAM and memcache for Issue PBs."""
+
+ def __init__(
+ self, cache_manager, issue_service, project_service, config_service):
+ super(IssueTwoLevelCache, self).__init__(
+ cache_manager, 'issue', 'issue:', tracker_pb2.Issue,
+ max_size=settings.issue_cache_max_size)
+ self.issue_service = issue_service
+ self.project_service = project_service
+ self.config_service = config_service
+
+ def _UnpackIssue(self, cnxn, issue_row):
+ """Partially construct an issue object using info from a DB row."""
+ (issue_id, project_id, local_id, status_id, owner_id, reporter_id,
+ opened, closed, modified, owner_modified, status_modified,
+ component_modified, derived_owner_id, derived_status_id,
+ deleted, star_count, attachment_count, is_spam) = issue_row
+
+ issue = tracker_pb2.Issue()
+ project = self.project_service.GetProject(cnxn, project_id)
+ issue.project_name = project.project_name
+ issue.issue_id = issue_id
+ issue.project_id = project_id
+ issue.local_id = local_id
+ if status_id is not None:
+ status = self.config_service.LookupStatus(cnxn, project_id, status_id)
+ issue.status = status
+ issue.owner_id = owner_id or 0
+ issue.reporter_id = reporter_id or 0
+ issue.derived_owner_id = derived_owner_id or 0
+ if derived_status_id is not None:
+ derived_status = self.config_service.LookupStatus(
+ cnxn, project_id, derived_status_id)
+ issue.derived_status = derived_status
+ issue.deleted = bool(deleted)
+ if opened:
+ issue.opened_timestamp = opened
+ if closed:
+ issue.closed_timestamp = closed
+ if modified:
+ issue.modified_timestamp = modified
+ if owner_modified:
+ issue.owner_modified_timestamp = owner_modified
+ if status_modified:
+ issue.status_modified_timestamp = status_modified
+ if component_modified:
+ issue.component_modified_timestamp = component_modified
+ issue.star_count = star_count
+ issue.attachment_count = attachment_count
+ issue.is_spam = bool(is_spam)
+ return issue
+
+ def _UnpackFieldValue(self, fv_row):
+ """Construct a field value object from a DB row."""
+ (issue_id, field_id, int_value, str_value, user_id, date_value, url_value,
+ derived, phase_id) = fv_row
+ fv = tracker_bizobj.MakeFieldValue(
+ field_id, int_value, str_value, user_id, date_value, url_value,
+ bool(derived), phase_id=phase_id)
+ return fv, issue_id
+
+ def _UnpackApprovalValue(self, av_row):
+ """Contruct an ApprovalValue PB from a DB row."""
+ (approval_id, issue_id, phase_id, status, setter_id, set_on) = av_row
+ if status:
+ status_enum = tracker_pb2.ApprovalStatus(status.upper())
+ else:
+ status_enum = tracker_pb2.ApprovalStatus.NOT_SET
+ av = tracker_pb2.ApprovalValue(
+ approval_id=approval_id, setter_id=setter_id, set_on=set_on,
+ status=status_enum, phase_id=phase_id)
+ return av, issue_id
+
+ def _UnpackPhase(self, phase_row):
+ """Construct a Phase PB from a DB row."""
+ (phase_id, name, rank) = phase_row
+ phase = tracker_pb2.Phase(
+ phase_id=phase_id, name=name, rank=rank)
+ return phase
+
+ def _DeserializeIssues(
+ self, cnxn, issue_rows, summary_rows, label_rows, component_rows,
+ cc_rows, notify_rows, fieldvalue_rows, relation_rows,
+ dangling_relation_rows, phase_rows, approvalvalue_rows,
+ av_approver_rows):
+ """Convert the given DB rows into a dict of Issue PBs."""
+ results_dict = {}
+ for issue_row in issue_rows:
+ issue = self._UnpackIssue(cnxn, issue_row)
+ results_dict[issue.issue_id] = issue
+
+ for issue_id, summary in summary_rows:
+ results_dict[issue_id].summary = summary
+
+ # TODO(jrobbins): it would be nice to order labels by rank and name.
+ for issue_id, label_id, derived in label_rows:
+ issue = results_dict.get(issue_id)
+ if not issue:
+ logging.info('Got label for an unknown issue: %r %r',
+ label_rows, issue_rows)
+ continue
+ label = self.config_service.LookupLabel(cnxn, issue.project_id, label_id)
+ assert label, ('Label ID %r on IID %r not found in project %r' %
+ (label_id, issue_id, issue.project_id))
+ if derived:
+ results_dict[issue_id].derived_labels.append(label)
+ else:
+ results_dict[issue_id].labels.append(label)
+
+ for issue_id, component_id, derived in component_rows:
+ if derived:
+ results_dict[issue_id].derived_component_ids.append(component_id)
+ else:
+ results_dict[issue_id].component_ids.append(component_id)
+
+ for issue_id, user_id, derived in cc_rows:
+ if derived:
+ results_dict[issue_id].derived_cc_ids.append(user_id)
+ else:
+ results_dict[issue_id].cc_ids.append(user_id)
+
+ for issue_id, email in notify_rows:
+ results_dict[issue_id].derived_notify_addrs.append(email)
+
+ for fv_row in fieldvalue_rows:
+ fv, issue_id = self._UnpackFieldValue(fv_row)
+ results_dict[issue_id].field_values.append(fv)
+
+ phases_by_id = {}
+ for phase_row in phase_rows:
+ phase = self._UnpackPhase(phase_row)
+ phases_by_id[phase.phase_id] = phase
+
+ approvers_dict = collections.defaultdict(list)
+ for approver_row in av_approver_rows:
+ approval_id, approver_id, issue_id = approver_row
+ approvers_dict[approval_id, issue_id].append(approver_id)
+
+ for av_row in approvalvalue_rows:
+ av, issue_id = self._UnpackApprovalValue(av_row)
+ av.approver_ids = approvers_dict[av.approval_id, issue_id]
+ results_dict[issue_id].approval_values.append(av)
+ if av.phase_id:
+ phase = phases_by_id[av.phase_id]
+ issue_phases = results_dict[issue_id].phases
+ if phase not in issue_phases:
+ issue_phases.append(phase)
+ # Order issue phases
+ for issue in results_dict.values():
+ if issue.phases:
+ issue.phases.sort(key=lambda phase: phase.rank)
+
+ for issue_id, dst_issue_id, kind, rank in relation_rows:
+ src_issue = results_dict.get(issue_id)
+ dst_issue = results_dict.get(dst_issue_id)
+ assert src_issue or dst_issue, (
+ 'Neither source issue %r nor dest issue %r was found' %
+ (issue_id, dst_issue_id))
+ if src_issue:
+ if kind == 'blockedon':
+ src_issue.blocked_on_iids.append(dst_issue_id)
+ src_issue.blocked_on_ranks.append(rank)
+ elif kind == 'mergedinto':
+ src_issue.merged_into = dst_issue_id
+ else:
+ logging.info('unknown relation kind %r', kind)
+ continue
+
+ if dst_issue:
+ if kind == 'blockedon':
+ dst_issue.blocking_iids.append(issue_id)
+
+ for row in dangling_relation_rows:
+ issue_id, dst_issue_proj, dst_issue_id, ext_id, kind = row
+ src_issue = results_dict.get(issue_id)
+ if kind == 'blockedon':
+ src_issue.dangling_blocked_on_refs.append(
+ tracker_bizobj.MakeDanglingIssueRef(dst_issue_proj,
+ dst_issue_id, ext_id))
+ elif kind == 'blocking':
+ src_issue.dangling_blocking_refs.append(
+ tracker_bizobj.MakeDanglingIssueRef(dst_issue_proj, dst_issue_id,
+ ext_id))
+ elif kind == 'mergedinto':
+ src_issue.merged_into_external = ext_id
+ else:
+ logging.warn('unhandled danging relation kind %r', kind)
+ continue
+
+ return results_dict
+
  # Note: sharding is used here to allow us to load issues from the replicas
  # without placing load on the primary DB. Writes are not sharded.
  # pylint: disable=arguments-differ
  def FetchItems(self, cnxn, issue_ids, shard_id=None):
    """Retrieve and deserialize issues.

    Loads the rows for the requested issues from the Issue table and each of
    its satellite tables (summary, labels, components, cc, notify, field
    values, approvals, phases, approvers, relations, dangling relations) and
    hands them all to self._DeserializeIssues.

    Args:
      cnxn: connection to SQL database.
      issue_ids: list of int global issue IDs to load.
      shard_id: optional int shard ID used to choose a replica.

    Returns:
      Whatever self._DeserializeIssues builds; presumably a dict
      {issue_id: Issue PB} -- confirm against _DeserializeIssues.
    """
    issue_rows = self.issue_service.issue_tbl.Select(
        cnxn, cols=ISSUE_COLS, id=issue_ids, shard_id=shard_id)

    summary_rows = self.issue_service.issuesummary_tbl.Select(
        cnxn, cols=ISSUESUMMARY_COLS, shard_id=shard_id, issue_id=issue_ids)
    label_rows = self.issue_service.issue2label_tbl.Select(
        cnxn, cols=ISSUE2LABEL_COLS, shard_id=shard_id, issue_id=issue_ids)
    component_rows = self.issue_service.issue2component_tbl.Select(
        cnxn, cols=ISSUE2COMPONENT_COLS, shard_id=shard_id, issue_id=issue_ids)
    cc_rows = self.issue_service.issue2cc_tbl.Select(
        cnxn, cols=ISSUE2CC_COLS, shard_id=shard_id, issue_id=issue_ids)
    notify_rows = self.issue_service.issue2notify_tbl.Select(
        cnxn, cols=ISSUE2NOTIFY_COLS, shard_id=shard_id, issue_id=issue_ids)
    fieldvalue_rows = self.issue_service.issue2fieldvalue_tbl.Select(
        cnxn, cols=ISSUE2FIELDVALUE_COLS, shard_id=shard_id,
        issue_id=issue_ids)
    # NOTE(review): the approval-value, phase-def and approver queries do not
    # pass shard_id; presumably those tables are not sharded -- confirm.
    approvalvalue_rows = self.issue_service.issue2approvalvalue_tbl.Select(
        cnxn, cols=ISSUE2APPROVALVALUE_COLS, issue_id=issue_ids)
    # Assumes column 2 of ISSUE2APPROVALVALUE_COLS is the phase ID -- confirm.
    phase_ids = [av_row[2] for av_row in approvalvalue_rows]
    phase_rows = []
    if phase_ids:
      # De-duplicate: many approval values can share one phase.
      phase_rows = self.issue_service.issuephasedef_tbl.Select(
          cnxn, cols=ISSUEPHASEDEF_COLS, id=list(set(phase_ids)))
    av_approver_rows = self.issue_service.issueapproval2approver_tbl.Select(
        cnxn, cols=ISSUEAPPROVAL2APPROVER_COLS, issue_id=issue_ids)
    # The merge query below builds an explicit IN (...) clause from
    # issue_ids, so all relation queries are skipped when there are no IDs.
    if issue_ids:
      ph = sql.PlaceHolders(issue_ids)
      # Relations where a requested issue is blocked on another issue,
      # highest rank first within each issue.
      blocked_on_rows = self.issue_service.issuerelation_tbl.Select(
          cnxn, cols=ISSUERELATION_COLS, issue_id=issue_ids, kind='blockedon',
          order_by=[('issue_id', []), ('rank DESC', []), ('dst_issue_id', [])])
      # Relations where another issue is blocked on a requested issue.
      blocking_rows = self.issue_service.issuerelation_tbl.Select(
          cnxn, cols=ISSUERELATION_COLS, dst_issue_id=issue_ids,
          kind='blockedon', order_by=[('issue_id', []), ('dst_issue_id', [])])
      # When both ends of a relation are in issue_ids, the same row comes
      # back from both queries; keep only the copy in blocked_on_rows.
      # NOTE(review): this is a linear scan per row (quadratic overall).
      unique_blocking = tuple(
          row for row in blocking_rows if row not in blocked_on_rows)
      # Every non-'blockedon' relation (e.g. 'mergedinto') in either
      # direction.
      merge_rows = self.issue_service.issuerelation_tbl.Select(
          cnxn, cols=ISSUERELATION_COLS,
          where=[('(issue_id IN (%s) OR dst_issue_id IN (%s))' % (ph, ph),
                  issue_ids + issue_ids),
                 ('kind != %s', ['blockedon'])])
      # NOTE(review): concatenating with the tuple above relies on Select
      # returning tuples -- confirm.
      relation_rows = blocked_on_rows + unique_blocking + merge_rows
      dangling_relation_rows = self.issue_service.danglingrelation_tbl.Select(
          cnxn, cols=DANGLINGRELATION_COLS, issue_id=issue_ids)
    else:
      relation_rows = []
      dangling_relation_rows = []

    issue_dict = self._DeserializeIssues(
        cnxn, issue_rows, summary_rows, label_rows, component_rows, cc_rows,
        notify_rows, fieldvalue_rows, relation_rows, dangling_relation_rows,
        phase_rows, approvalvalue_rows, av_approver_rows)
    # NOTE(review): logs the repr of every fetched issue at INFO level, which
    # can be large.
    logging.info('IssueTwoLevelCache.FetchItems returning: %r', issue_dict)
    return issue_dict
+
+
class CommentTwoLevelCache(caches.AbstractTwoLevelCache):
  """Class to manage RAM and memcache for IssueComment PBs."""

  def __init__(self, cache_manager, issue_svc):
    """Create the comment cache.

    Args:
      cache_manager: local cache with distributed invalidation.
      issue_svc: the IssueService whose comment tables are read on a miss.
    """
    super(CommentTwoLevelCache, self).__init__(
        cache_manager, 'comment', 'comment:', tracker_pb2.IssueComment,
        max_size=settings.comment_cache_max_size)
    self.issue_svc = issue_svc

  # pylint: disable=arguments-differ
  def FetchItems(self, cnxn, keys, shard_id=None):
    """Load and deserialize the comments with the given IDs.

    Args:
      cnxn: connection to SQL database.
      keys: list of int comment IDs to load.
      shard_id: optional int shard ID used to read from a replica. When the
          replica returns fewer comment rows than requested, the query is
          retried with shard_id=None (the primary DB), and the remaining
          reads also use the primary DB.

    Returns:
      A dict {comment_id: IssueComment PB} for the comments that were found.
    """
    comment_rows = self.issue_svc.comment_tbl.Select(
        cnxn, cols=COMMENT_COLS, id=keys, shard_id=shard_id)

    if len(comment_rows) < len(keys):
      # A lagging replica may not have the newest comments yet.
      self.issue_svc.replication_lag_retries.increment()
      # shard_id may be None, so it must not be formatted with %d.
      logging.info('issue3755: expected %d, but got %d rows from shard %r',
                   len(keys), len(comment_rows), shard_id)
      shard_id = None  # Will use Primary DB.
      comment_rows = self.issue_svc.comment_tbl.Select(
          cnxn, cols=COMMENT_COLS, id=keys, shard_id=None)
      logging.info(
          'Retry got %d comment rows from the primary DB', len(comment_rows))

    cids = [row[0] for row in comment_rows]
    # Assumes the last column of COMMENT_COLS is the comment content ID.
    commentcontent_ids = [row[-1] for row in comment_rows]
    content_rows = self.issue_svc.commentcontent_tbl.Select(
        cnxn, cols=COMMENTCONTENT_COLS, id=commentcontent_ids,
        shard_id=shard_id)
    approval_rows = self.issue_svc.issueapproval2comment_tbl.Select(
        cnxn, cols=ISSUEAPPROVAL2COMMENT_COLS, comment_id=cids)
    amendment_rows = self.issue_svc.issueupdate_tbl.Select(
        cnxn, cols=ISSUEUPDATE_COLS, comment_id=cids, shard_id=shard_id)
    attachment_rows = self.issue_svc.attachment_tbl.Select(
        cnxn, cols=ATTACHMENT_COLS, comment_id=cids, shard_id=shard_id)
    importer_rows = self.issue_svc.commentimporter_tbl.Select(
        cnxn, cols=COMMENTIMPORTER_COLS, comment_id=cids, shard_id=shard_id)

    comments = self.issue_svc._DeserializeComments(
        comment_rows, content_rows, amendment_rows, attachment_rows,
        approval_rows, importer_rows)

    return {comment.id: comment for comment in comments}
+
+
class IssueService(object):
  """The persistence layer for Monorail's issues, comments, and attachments."""
  # Incremented by CreateIssue with type='spam' or type='ham', according to
  # the spam classifier's verdict for the new issue.
  spam_labels = ts_mon.CounterMetric(
      'monorail/issue_svc/spam_label',
      'Issues created, broken down by spam label.',
      [ts_mon.StringField('type')])
  # Incremented by CommentTwoLevelCache.FetchItems when a replica returned
  # fewer comment rows than requested and the read was retried on the
  # primary DB.
  replication_lag_retries = ts_mon.CounterMetric(
      'monorail/issue_svc/replication_lag_retries',
      'Counts times that loading comments from a replica failed',
      [])
  # Incremented once per CreateIssue call, just before the issue is inserted.
  issue_creations = ts_mon.CounterMetric(
      'monorail/issue_svc/issue_creations',
      'Counts times that issues were created',
      [])
  # Comment-creation counter; where it is incremented is not visible in this
  # part of the file.
  comment_creations = ts_mon.CounterMetric(
      'monorail/issue_svc/comment_creations',
      'Counts times that comments were created',
      [])
+
+ def __init__(self, project_service, config_service, cache_manager,
+ chart_service):
+ """Initialize this object so that it is ready to use.
+
+ Args:
+ project_service: services object for project info.
+ config_service: services object for tracker configuration info.
+ cache_manager: local cache with distributed invalidation.
+ chart_service (ChartService): An instance of ChartService.
+ """
+ # Tables that represent issue data.
+ self.issue_tbl = sql.SQLTableManager(ISSUE_TABLE_NAME)
+ self.issuesummary_tbl = sql.SQLTableManager(ISSUESUMMARY_TABLE_NAME)
+ self.issue2label_tbl = sql.SQLTableManager(ISSUE2LABEL_TABLE_NAME)
+ self.issue2component_tbl = sql.SQLTableManager(ISSUE2COMPONENT_TABLE_NAME)
+ self.issue2cc_tbl = sql.SQLTableManager(ISSUE2CC_TABLE_NAME)
+ self.issue2notify_tbl = sql.SQLTableManager(ISSUE2NOTIFY_TABLE_NAME)
+ self.issue2fieldvalue_tbl = sql.SQLTableManager(ISSUE2FIELDVALUE_TABLE_NAME)
+ self.issuerelation_tbl = sql.SQLTableManager(ISSUERELATION_TABLE_NAME)
+ self.danglingrelation_tbl = sql.SQLTableManager(DANGLINGRELATION_TABLE_NAME)
+ self.issueformerlocations_tbl = sql.SQLTableManager(
+ ISSUEFORMERLOCATIONS_TABLE_NAME)
+ self.issuesnapshot_tbl = sql.SQLTableManager(ISSUESNAPSHOT_TABLE_NAME)
+ self.issuesnapshot2cc_tbl = sql.SQLTableManager(
+ ISSUESNAPSHOT2CC_TABLE_NAME)
+ self.issuesnapshot2component_tbl = sql.SQLTableManager(
+ ISSUESNAPSHOT2COMPONENT_TABLE_NAME)
+ self.issuesnapshot2label_tbl = sql.SQLTableManager(
+ ISSUESNAPSHOT2LABEL_TABLE_NAME)
+ self.issuephasedef_tbl = sql.SQLTableManager(ISSUEPHASEDEF_TABLE_NAME)
+ self.issue2approvalvalue_tbl = sql.SQLTableManager(
+ ISSUE2APPROVALVALUE_TABLE_NAME)
+ self.issueapproval2approver_tbl = sql.SQLTableManager(
+ ISSUEAPPROVAL2APPROVER_TABLE_NAME)
+ self.issueapproval2comment_tbl = sql.SQLTableManager(
+ ISSUEAPPROVAL2COMMENT_TABLE_NAME)
+
+ # Tables that represent comments.
+ self.comment_tbl = sql.SQLTableManager(COMMENT_TABLE_NAME)
+ self.commentcontent_tbl = sql.SQLTableManager(COMMENTCONTENT_TABLE_NAME)
+ self.commentimporter_tbl = sql.SQLTableManager(COMMENTIMPORTER_TABLE_NAME)
+ self.issueupdate_tbl = sql.SQLTableManager(ISSUEUPDATE_TABLE_NAME)
+ self.attachment_tbl = sql.SQLTableManager(ATTACHMENT_TABLE_NAME)
+
+ # Tables for cron tasks.
+ self.reindexqueue_tbl = sql.SQLTableManager(REINDEXQUEUE_TABLE_NAME)
+
+ # Tables for generating sequences of local IDs.
+ self.localidcounter_tbl = sql.SQLTableManager(LOCALIDCOUNTER_TABLE_NAME)
+
+ # Like a dictionary {(project_id, local_id): issue_id}
+ # Use value centric cache here because we cannot store a tuple in the
+ # Invalidate table.
+ self.issue_id_2lc = IssueIDTwoLevelCache(cache_manager, self)
+ # Like a dictionary {issue_id: issue}
+ self.issue_2lc = IssueTwoLevelCache(
+ cache_manager, self, project_service, config_service)
+
+ # Like a dictionary {comment_id: comment)
+ self.comment_2lc = CommentTwoLevelCache(
+ cache_manager, self)
+
+ self._config_service = config_service
+ self.chart_service = chart_service
+
+ ### Issue ID lookups
+
+ def LookupIssueIDsFollowMoves(self, cnxn, project_local_id_pairs):
+ # type: (MonorailConnection, Sequence[Tuple(int, int)]) ->
+ # (Sequence[int], Sequence[Tuple(int, int)])
+ """Find the global issue IDs given the project ID and local ID of each.
+
+ If any (project_id, local_id) pairs refer to an issue that has been moved,
+ the issue ID will still be returned.
+
+ Args:
+ cnxn: Monorail connection.
+ project_local_id_pairs: (project_id, local_id) pairs to look up.
+
+ Returns:
+ A tuple of two items.
+ 1. A sequence of global issue IDs in the `project_local_id_pairs` order.
+ 2. A sequence of (project_id, local_id) containing each pair provided
+ for which no matching issue is found.
+ """
+
+ issue_id_dict, misses = self.issue_id_2lc.GetAll(
+ cnxn, project_local_id_pairs)
+ for miss in misses:
+ project_id, local_id = miss
+ issue_id = int(
+ self.issueformerlocations_tbl.SelectValue(
+ cnxn,
+ 'issue_id',
+ default=0,
+ project_id=project_id,
+ local_id=local_id))
+ if issue_id:
+ misses.remove(miss)
+ issue_id_dict[miss] = issue_id
+ # Put the Issue IDs in the order specified by project_local_id_pairs
+ issue_ids = [
+ issue_id_dict[pair]
+ for pair in project_local_id_pairs
+ if pair in issue_id_dict
+ ]
+
+ return issue_ids, misses
+
+ def LookupIssueIDs(self, cnxn, project_local_id_pairs):
+ """Find the global issue IDs given the project ID and local ID of each."""
+ issue_id_dict, misses = self.issue_id_2lc.GetAll(
+ cnxn, project_local_id_pairs)
+
+ # Put the Issue IDs in the order specified by project_local_id_pairs
+ issue_ids = [issue_id_dict[pair] for pair in project_local_id_pairs
+ if pair in issue_id_dict]
+
+ return issue_ids, misses
+
+ def LookupIssueID(self, cnxn, project_id, local_id):
+ """Find the global issue ID given the project ID and local ID."""
+ issue_ids, _misses = self.LookupIssueIDs(cnxn, [(project_id, local_id)])
+ try:
+ return issue_ids[0]
+ except IndexError:
+ raise exceptions.NoSuchIssueException()
+
+ def ResolveIssueRefs(
+ self, cnxn, ref_projects, default_project_name, refs):
+ """Look up all the referenced issues and return their issue_ids.
+
+ Args:
+ cnxn: connection to SQL database.
+ ref_projects: pre-fetched dict {project_name: project} of all projects
+ mentioned in the refs as well as the default project.
+ default_project_name: string name of the current project, this is used
+ when the project_name in a ref is None.
+ refs: list of (project_name, local_id) pairs. These are parsed from
+ textual references in issue descriptions, comments, and the input
+ in the blocked-on field.
+
+ Returns:
+ A list of issue_ids for all the referenced issues. References to issues
+ in deleted projects and any issues not found are simply ignored.
+ """
+ if not refs:
+ return [], []
+
+ project_local_id_pairs = []
+ for project_name, local_id in refs:
+ project = ref_projects.get(project_name or default_project_name)
+ if not project or project.state == project_pb2.ProjectState.DELETABLE:
+ continue # ignore any refs to issues in deleted projects
+ project_local_id_pairs.append((project.project_id, local_id))
+
+ return self.LookupIssueIDs(cnxn, project_local_id_pairs) # tuple
+
+ def LookupIssueRefs(self, cnxn, issue_ids):
+ """Return {issue_id: (project_name, local_id)} for each issue_id."""
+ issue_dict, _misses = self.GetIssuesDict(cnxn, issue_ids)
+ return {
+ issue_id: (issue.project_name, issue.local_id)
+ for issue_id, issue in issue_dict.items()}
+
+ ### Issue objects
+
+ def CreateIssue(
+ self,
+ cnxn,
+ services,
+ issue,
+ marked_description,
+ attachments=None,
+ index_now=False,
+ importer_id=None):
+ """Create and store a new issue with all the given information.
+
+ Args:
+ cnxn: connection to SQL database.
+ services: persistence layer for users, issues, and projects.
+ issue: Issue PB to create.
+ marked_description: issue description with initial HTML markup.
+ attachments: [(filename, contents, mimetype),...] attachments uploaded at
+ the time the comment was made.
+ index_now: True if the issue should be updated in the full text index.
+ importer_id: optional user ID of API client importing issues for users.
+
+ Returns:
+ A tuple (the newly created Issue PB and Comment PB for the
+ issue description).
+ """
+ project_id = issue.project_id
+ reporter_id = issue.reporter_id
+ timestamp = issue.opened_timestamp
+ config = self._config_service.GetProjectConfig(cnxn, project_id)
+
+ iids_to_invalidate = set()
+ if len(issue.blocked_on_iids) != 0:
+ iids_to_invalidate.update(issue.blocked_on_iids)
+ if len(issue.blocking_iids) != 0:
+ iids_to_invalidate.update(issue.blocking_iids)
+
+ comment = self._MakeIssueComment(
+ project_id, reporter_id, marked_description,
+ attachments=attachments, timestamp=timestamp,
+ is_description=True, importer_id=importer_id)
+
+ reporter = services.user.GetUser(cnxn, reporter_id)
+ project = services.project.GetProject(cnxn, project_id)
+ reporter_auth = authdata.AuthData.FromUserID(cnxn, reporter_id, services)
+ is_project_member = framework_bizobj.UserIsInProject(
+ project, reporter_auth.effective_ids)
+ classification = services.spam.ClassifyIssue(
+ issue, comment, reporter, is_project_member)
+
+ if classification['confidence_is_spam'] > settings.classifier_spam_thresh:
+ issue.is_spam = True
+ predicted_label = 'spam'
+ else:
+ predicted_label = 'ham'
+
+ logging.info('classified new issue as %s' % predicted_label)
+ self.spam_labels.increment({'type': predicted_label})
+
+ # Create approval surveys
+ approval_comments = []
+ if len(issue.approval_values) != 0:
+ approval_defs_by_id = {ad.approval_id: ad for ad in config.approval_defs}
+ for av in issue.approval_values:
+ ad = approval_defs_by_id.get(av.approval_id)
+ if ad:
+ survey = ''
+ if ad.survey:
+ questions = ad.survey.split('\n')
+ survey = '\n'.join(['<b>' + q + '</b>' for q in questions])
+ approval_comments.append(self._MakeIssueComment(
+ project_id, reporter_id, survey, timestamp=timestamp,
+ is_description=True, approval_id=ad.approval_id))
+ else:
+ logging.info('Could not find ApprovalDef with approval_id %r',
+ av.approval_id)
+
+ issue.local_id = self.AllocateNextLocalID(cnxn, project_id)
+ self.issue_creations.increment()
+ issue_id = self.InsertIssue(cnxn, issue)
+ comment.issue_id = issue_id
+ self.InsertComment(cnxn, comment)
+ for approval_comment in approval_comments:
+ approval_comment.issue_id = issue_id
+ self.InsertComment(cnxn, approval_comment)
+
+ issue.issue_id = issue_id
+
+ # ClassifyIssue only returns confidence_is_spam, but
+ # RecordClassifierIssueVerdict records confidence of
+ # ham or spam. Therefore if ham, invert score.
+ confidence = classification['confidence_is_spam']
+ if not issue.is_spam:
+ confidence = 1.0 - confidence
+
+ services.spam.RecordClassifierIssueVerdict(
+ cnxn, issue, predicted_label=='spam',
+ confidence, classification['failed_open'])
+
+ if permissions.HasRestrictions(issue, 'view'):
+ self._config_service.InvalidateMemcache(
+ [issue], key_prefix='nonviewable:')
+
+ # Add a comment to existing issues saying they are now blocking or
+ # blocked on this issue.
+ blocked_add_issues = self.GetIssues(cnxn, issue.blocked_on_iids)
+ for add_issue in blocked_add_issues:
+ self.CreateIssueComment(
+ cnxn, add_issue, reporter_id, content='',
+ amendments=[tracker_bizobj.MakeBlockingAmendment(
+ [(issue.project_name, issue.local_id)], [],
+ default_project_name=add_issue.project_name)])
+ blocking_add_issues = self.GetIssues(cnxn, issue.blocking_iids)
+ for add_issue in blocking_add_issues:
+ self.CreateIssueComment(
+ cnxn, add_issue, reporter_id, content='',
+ amendments=[tracker_bizobj.MakeBlockedOnAmendment(
+ [(issue.project_name, issue.local_id)], [],
+ default_project_name=add_issue.project_name)])
+
+ self._UpdateIssuesModified(
+ cnxn, iids_to_invalidate, modified_timestamp=timestamp)
+
+ if index_now:
+ tracker_fulltext.IndexIssues(
+ cnxn, [issue], services.user, self, self._config_service)
+ else:
+ self.EnqueueIssuesForIndexing(cnxn, [issue.issue_id])
+
+ return issue, comment
+
+ def AllocateNewLocalIDs(self, cnxn, issues):
+ # Filter to just the issues that need new local IDs.
+ issues = [issue for issue in issues if issue.local_id < 0]
+
+ for issue in issues:
+ if issue.local_id < 0:
+ issue.local_id = self.AllocateNextLocalID(cnxn, issue.project_id)
+
+ self.UpdateIssues(cnxn, issues)
+
+ logging.info("AllocateNewLocalIDs")
+
+ def GetAllIssuesInProject(
+ self, cnxn, project_id, min_local_id=None, use_cache=True):
+ """Special query to efficiently get ALL issues in a project.
+
+ This is not done while the user is waiting, only by backround tasks.
+
+ Args:
+ cnxn: connection to SQL database.
+ project_id: the ID of the project.
+ min_local_id: optional int to start at.
+ use_cache: optional boolean to turn off using the cache.
+
+ Returns:
+ A list of Issue protocol buffers for all issues.
+ """
+ all_local_ids = self.GetAllLocalIDsInProject(
+ cnxn, project_id, min_local_id=min_local_id)
+ return self.GetIssuesByLocalIDs(
+ cnxn, project_id, all_local_ids, use_cache=use_cache)
+
+ def GetAnyOnHandIssue(self, issue_ids, start=None, end=None):
+ """Get any one issue from RAM or memcache, otherwise return None."""
+ return self.issue_2lc.GetAnyOnHandItem(issue_ids, start=start, end=end)
+
+ def GetIssuesDict(self, cnxn, issue_ids, use_cache=True, shard_id=None):
+ # type: (MonorailConnection, Collection[int], Optional[Boolean],
+ # Optional[int]) -> (Dict[int, Issue], Sequence[int])
+ """Get a dict {iid: issue} from the DB or cache.
+
+ Returns:
+ A dict {iid: issue} from the DB or cache.
+ A sequence of iid that could not be found.
+ """
+ issue_dict, missed_iids = self.issue_2lc.GetAll(
+ cnxn, issue_ids, use_cache=use_cache, shard_id=shard_id)
+ if not use_cache:
+ for issue in issue_dict.values():
+ issue.assume_stale = False
+ return issue_dict, missed_iids
+
+ def GetIssues(self, cnxn, issue_ids, use_cache=True, shard_id=None):
+ # type: (MonorailConnection, Sequence[int], Optional[Boolean],
+ # Optional[int]) -> (Sequence[int])
+ """Get a list of Issue PBs from the DB or cache.
+
+ Args:
+ cnxn: connection to SQL database.
+ issue_ids: integer global issue IDs of the issues.
+ use_cache: optional boolean to turn off using the cache.
+ shard_id: optional int shard_id to limit retrieval.
+
+ Returns:
+ A list of Issue PBs in the same order as the given issue_ids.
+ """
+ issue_dict, _misses = self.GetIssuesDict(
+ cnxn, issue_ids, use_cache=use_cache, shard_id=shard_id)
+
+ # Return a list that is ordered the same as the given issue_ids.
+ issue_list = [issue_dict[issue_id] for issue_id in issue_ids
+ if issue_id in issue_dict]
+
+ return issue_list
+
+ def GetIssue(self, cnxn, issue_id, use_cache=True):
+ """Get one Issue PB from the DB.
+
+ Args:
+ cnxn: connection to SQL database.
+ issue_id: integer global issue ID of the issue.
+ use_cache: optional boolean to turn off using the cache.
+
+ Returns:
+ The requested Issue protocol buffer.
+
+ Raises:
+ NoSuchIssueException: the issue was not found.
+ """
+ issues = self.GetIssues(cnxn, [issue_id], use_cache=use_cache)
+ try:
+ return issues[0]
+ except IndexError:
+ raise exceptions.NoSuchIssueException()
+
+ def GetIssuesByLocalIDs(
+ self, cnxn, project_id, local_id_list, use_cache=True, shard_id=None):
+ """Get all the requested issues.
+
+ Args:
+ cnxn: connection to SQL database.
+ project_id: int ID of the project to which the issues belong.
+ local_id_list: list of integer local IDs for the requested issues.
+ use_cache: optional boolean to turn off using the cache.
+ shard_id: optional int shard_id to choose a replica.
+
+ Returns:
+ List of Issue PBs for the requested issues. The result Issues
+ will be ordered in the same order as local_id_list.
+ """
+ issue_ids_to_fetch, _misses = self.LookupIssueIDs(
+ cnxn, [(project_id, local_id) for local_id in local_id_list])
+ issues = self.GetIssues(
+ cnxn, issue_ids_to_fetch, use_cache=use_cache, shard_id=shard_id)
+ return issues
+
+ def GetIssueByLocalID(self, cnxn, project_id, local_id, use_cache=True):
+ """Get one Issue PB from the DB.
+
+ Args:
+ cnxn: connection to SQL database.
+ project_id: the ID of the project to which the issue belongs.
+ local_id: integer local ID of the issue.
+ use_cache: optional boolean to turn off using the cache.
+
+ Returns:
+ The requested Issue protocol buffer.
+ """
+ issues = self.GetIssuesByLocalIDs(
+ cnxn, project_id, [local_id], use_cache=use_cache)
+ try:
+ return issues[0]
+ except IndexError:
+ raise exceptions.NoSuchIssueException(
+ 'The issue %s:%d does not exist.' % (project_id, local_id))
+
+ def GetOpenAndClosedIssues(self, cnxn, issue_ids):
+ """Return the requested issues in separate open and closed lists.
+
+ Args:
+ cnxn: connection to SQL database.
+ issue_ids: list of int issue issue_ids.
+
+ Returns:
+ A pair of lists, the first with open issues, second with closed issues.
+ """
+ if not issue_ids:
+ return [], [] # make one common case efficient
+
+ issues = self.GetIssues(cnxn, issue_ids)
+ project_ids = {issue.project_id for issue in issues}
+ configs = self._config_service.GetProjectConfigs(cnxn, project_ids)
+ open_issues = []
+ closed_issues = []
+ for issue in issues:
+ config = configs[issue.project_id]
+ if tracker_helpers.MeansOpenInProject(
+ tracker_bizobj.GetStatus(issue), config):
+ open_issues.append(issue)
+ else:
+ closed_issues.append(issue)
+
+ return open_issues, closed_issues
+
+ # TODO(crbug.com/monorail/7822): Delete this method when V0 API retired.
+ def GetCurrentLocationOfMovedIssue(self, cnxn, project_id, local_id):
+ """Return the current location of a moved issue based on old location."""
+ issue_id = int(self.issueformerlocations_tbl.SelectValue(
+ cnxn, 'issue_id', default=0, project_id=project_id, local_id=local_id))
+ if not issue_id:
+ return None, None
+ project_id, local_id = self.issue_tbl.SelectRow(
+ cnxn, cols=['project_id', 'local_id'], id=issue_id)
+ return project_id, local_id
+
+ def GetPreviousLocations(self, cnxn, issue):
+ """Get all the previous locations of an issue."""
+ location_rows = self.issueformerlocations_tbl.Select(
+ cnxn, cols=['project_id', 'local_id'], issue_id=issue.issue_id)
+ locations = [(pid, local_id) for (pid, local_id) in location_rows
+ if pid != issue.project_id or local_id != issue.local_id]
+ return locations
+
+ def GetCommentsByUser(self, cnxn, user_id):
+ """Get all comments created by a user"""
+ comments = self.GetComments(cnxn, commenter_id=user_id,
+ is_description=False, limit=10000)
+ return comments
+
+ def GetIssueActivity(self, cnxn, num=50, before=None, after=None,
+ project_ids=None, user_ids=None, ascending=False):
+
+ if project_ids:
+ use_clause = (
+ 'USE INDEX (project_id) USE INDEX FOR ORDER BY (project_id)')
+ elif user_ids:
+ use_clause = (
+ 'USE INDEX (commenter_id) USE INDEX FOR ORDER BY (commenter_id)')
+ else:
+ use_clause = ''
+
+ # TODO(jrobbins): make this into a persist method.
+ # TODO(jrobbins): this really needs permission checking in SQL, which
+ # will be slow.
+ where_conds = [('Issue.id = Comment.issue_id', [])]
+ if project_ids is not None:
+ cond_str = 'Comment.project_id IN (%s)' % sql.PlaceHolders(project_ids)
+ where_conds.append((cond_str, project_ids))
+ if user_ids is not None:
+ cond_str = 'Comment.commenter_id IN (%s)' % sql.PlaceHolders(user_ids)
+ where_conds.append((cond_str, user_ids))
+
+ if before:
+ where_conds.append(('created < %s', [before]))
+ if after:
+ where_conds.append(('created > %s', [after]))
+ if ascending:
+ order_by = [('created', [])]
+ else:
+ order_by = [('created DESC', [])]
+
+ comments = self.GetComments(
+ cnxn, joins=[('Issue', [])], deleted_by=None, where=where_conds,
+ use_clause=use_clause, order_by=order_by, limit=num + 1)
+ return comments
+
+ def GetIssueIDsReportedByUser(self, cnxn, user_id):
+ """Get all issue IDs created by a user"""
+ rows = self.issue_tbl.Select(cnxn, cols=['id'], reporter_id=user_id,
+ limit=10000)
+ return [row[0] for row in rows]
+
  def InsertIssue(self, cnxn, issue):
    """Store the given issue in SQL.

    Inserts the Issue row, records its shard, writes all satellite tables and
    issue snapshots, then commits once and invalidates memcache.

    Args:
      cnxn: connection to SQL database.
      issue: Issue PB to insert into the database. Its issue_id is set to the
          newly generated ID.

    Returns:
      The int issue_id of the newly created issue.
    """
    status_id = self._config_service.LookupStatusID(
        cnxn, issue.project_id, issue.status)
    # NOTE(review): the tuple order must match ISSUE_COLS[1:] -- keep in sync.
    row = (issue.project_id, issue.local_id, status_id,
           issue.owner_id or None,
           issue.reporter_id,
           issue.opened_timestamp,
           issue.closed_timestamp,
           issue.modified_timestamp,
           issue.owner_modified_timestamp,
           issue.status_modified_timestamp,
           issue.component_modified_timestamp,
           issue.derived_owner_id or None,
           self._config_service.LookupStatusID(
               cnxn, issue.project_id, issue.derived_status),
           bool(issue.deleted),
           issue.star_count, issue.attachment_count,
           issue.is_spam)
    # ISSUE_COLs[1:] to skip setting the ID
    # Insert into the Primary DB.
    generated_ids = self.issue_tbl.InsertRows(
        cnxn, ISSUE_COLS[1:], [row], commit=False, return_generated_ids=True)
    issue_id = generated_ids[0]
    issue.issue_id = issue_id
    # The shard can only be computed once the generated ID is known.
    self.issue_tbl.Update(
        cnxn, {'shard': issue_id % settings.num_logical_shards},
        id=issue.issue_id, commit=False)

    # All writes below use commit=False so they land in one transaction with
    # the Issue row; the single commit follows.
    self._UpdateIssuesSummary(cnxn, [issue], commit=False)
    self._UpdateIssuesLabels(cnxn, [issue], commit=False)
    self._UpdateIssuesFields(cnxn, [issue], commit=False)
    self._UpdateIssuesComponents(cnxn, [issue], commit=False)
    self._UpdateIssuesCc(cnxn, [issue], commit=False)
    self._UpdateIssuesNotify(cnxn, [issue], commit=False)
    self._UpdateIssuesRelation(cnxn, [issue], commit=False)
    self._UpdateIssuesApprovals(cnxn, issue, commit=False)
    self.chart_service.StoreIssueSnapshots(cnxn, [issue], commit=False)
    cnxn.Commit()
    self._config_service.InvalidateMemcache([issue])

    return issue_id
+
  def UpdateIssues(
      self, cnxn, issues, update_cols=None, just_derived=False, commit=True,
      invalidate=True):
    """Update the given issues in SQL.

    Args:
      cnxn: connection to SQL database.
      issues: list of issues to update, these must have been loaded with
          use_cache=False so that issue.assume_stale is False.
      update_cols: optional list of just the Issue-table column names to
          update. When it is falsy, the label, cc, field, component and
          notify tables are also rewritten.
      just_derived: set to True when only updating derived fields; skips the
          summary and relation tables.
      commit: set to False to skip the DB commit and do it in the caller.
      invalidate: set to False to leave cache invalidation to the caller.
    """
    if not issues:
      return

    for issue in issues:  # slow, but mysql will not allow REPLACE rows.
      # Refuse to overwrite the DB with a possibly stale cached copy.
      assert not issue.assume_stale, (
          'issue2514: Storing issue that might be stale: %r' % issue)
      delta = {
          'project_id': issue.project_id,
          'local_id': issue.local_id,
          'owner_id': issue.owner_id or None,
          'status_id': self._config_service.LookupStatusID(
              cnxn, issue.project_id, issue.status) or None,
          'opened': issue.opened_timestamp,
          'closed': issue.closed_timestamp,
          'modified': issue.modified_timestamp,
          'owner_modified': issue.owner_modified_timestamp,
          'status_modified': issue.status_modified_timestamp,
          'component_modified': issue.component_modified_timestamp,
          'derived_owner_id': issue.derived_owner_id or None,
          'derived_status_id': self._config_service.LookupStatusID(
              cnxn, issue.project_id, issue.derived_status) or None,
          'deleted': bool(issue.deleted),
          'star_count': issue.star_count,
          'attachment_count': issue.attachment_count,
          'is_spam': issue.is_spam,
          }
      # NOTE(review): update_cols=[] filters delta down to {} here, yet the
      # `not update_cols` check below still rewrites the satellite tables.
      if update_cols is not None:
        delta = {key: val for key, val in delta.items()
                 if key in update_cols}
      self.issue_tbl.Update(cnxn, delta, id=issue.issue_id, commit=False)

    if not update_cols:
      self._UpdateIssuesLabels(cnxn, issues, commit=False)
      self._UpdateIssuesCc(cnxn, issues, commit=False)
      self._UpdateIssuesFields(cnxn, issues, commit=False)
      self._UpdateIssuesComponents(cnxn, issues, commit=False)
      self._UpdateIssuesNotify(cnxn, issues, commit=False)
      if not just_derived:
        self._UpdateIssuesSummary(cnxn, issues, commit=False)
        self._UpdateIssuesRelation(cnxn, issues, commit=False)

    self.chart_service.StoreIssueSnapshots(cnxn, issues, commit=False)

    iids_to_invalidate = [issue.issue_id for issue in issues]
    # NOTE(review): the issue cache is invalidated before the commit below.
    if just_derived and invalidate:
      self.issue_2lc.InvalidateAllKeys(cnxn, iids_to_invalidate)
    elif invalidate:
      self.issue_2lc.InvalidateKeys(cnxn, iids_to_invalidate)
    if commit:
      cnxn.Commit()
    if invalidate:
      self._config_service.InvalidateMemcache(issues)
+
+ def UpdateIssue(
+ self, cnxn, issue, update_cols=None, just_derived=False, commit=True,
+ invalidate=True):
+ """Update the given issue in SQL.
+
+ Args:
+ cnxn: connection to SQL database.
+ issue: the issue to update.
+ update_cols: optional list of just the field names to update.
+ just_derived: set to True when only updating derived fields.
+ commit: set to False to skip the DB commit and do it in the caller.
+ invalidate: set to False to leave cache invalidatation to the caller.
+ """
+ self.UpdateIssues(
+ cnxn, [issue], update_cols=update_cols, just_derived=just_derived,
+ commit=commit, invalidate=invalidate)
+
+ def _UpdateIssuesSummary(self, cnxn, issues, commit=True):
+ """Update the IssueSummary table rows for the given issues."""
+ self.issuesummary_tbl.InsertRows(
+ cnxn, ISSUESUMMARY_COLS,
+ [(issue.issue_id, issue.summary) for issue in issues],
+ replace=True, commit=commit)
+
+ def _UpdateIssuesLabels(self, cnxn, issues, commit=True):
+ """Update the Issue2Label table rows for the given issues."""
+ label_rows = []
+ for issue in issues:
+ issue_shard = issue.issue_id % settings.num_logical_shards
+ # TODO(jrobbins): If the user adds many novel labels in one issue update,
+ # that could be slow. Solution is to add all new labels in a batch first.
+ label_rows.extend(
+ (issue.issue_id,
+ self._config_service.LookupLabelID(cnxn, issue.project_id, label),
+ False,
+ issue_shard)
+ for label in issue.labels)
+ label_rows.extend(
+ (issue.issue_id,
+ self._config_service.LookupLabelID(cnxn, issue.project_id, label),
+ True,
+ issue_shard)
+ for label in issue.derived_labels)
+
+ self.issue2label_tbl.Delete(
+ cnxn, issue_id=[issue.issue_id for issue in issues],
+ commit=False)
+ self.issue2label_tbl.InsertRows(
+ cnxn, ISSUE2LABEL_COLS + ['issue_shard'],
+ label_rows, ignore=True, commit=commit)
+
+ def _UpdateIssuesFields(self, cnxn, issues, commit=True):
+ """Update the Issue2FieldValue table rows for the given issues."""
+ fieldvalue_rows = []
+ for issue in issues:
+ issue_shard = issue.issue_id % settings.num_logical_shards
+ for fv in issue.field_values:
+ fieldvalue_rows.append(
+ (issue.issue_id, fv.field_id, fv.int_value, fv.str_value,
+ fv.user_id or None, fv.date_value, fv.url_value, fv.derived,
+ fv.phase_id or None, issue_shard))
+
+ self.issue2fieldvalue_tbl.Delete(
+ cnxn, issue_id=[issue.issue_id for issue in issues], commit=False)
+ self.issue2fieldvalue_tbl.InsertRows(
+ cnxn, ISSUE2FIELDVALUE_COLS + ['issue_shard'],
+ fieldvalue_rows, commit=commit)
+
+ def _UpdateIssuesComponents(self, cnxn, issues, commit=True):
+ """Update the Issue2Component table rows for the given issues."""
+ issue2component_rows = []
+ for issue in issues:
+ issue_shard = issue.issue_id % settings.num_logical_shards
+ issue2component_rows.extend(
+ (issue.issue_id, component_id, False, issue_shard)
+ for component_id in issue.component_ids)
+ issue2component_rows.extend(
+ (issue.issue_id, component_id, True, issue_shard)
+ for component_id in issue.derived_component_ids)
+
+ self.issue2component_tbl.Delete(
+ cnxn, issue_id=[issue.issue_id for issue in issues], commit=False)
+ self.issue2component_tbl.InsertRows(
+ cnxn, ISSUE2COMPONENT_COLS + ['issue_shard'],
+ issue2component_rows, ignore=True, commit=commit)
+
+ def _UpdateIssuesCc(self, cnxn, issues, commit=True):
+ """Update the Issue2Cc table rows for the given issues."""
+ cc_rows = []
+ for issue in issues:
+ issue_shard = issue.issue_id % settings.num_logical_shards
+ cc_rows.extend(
+ (issue.issue_id, cc_id, False, issue_shard)
+ for cc_id in issue.cc_ids)
+ cc_rows.extend(
+ (issue.issue_id, cc_id, True, issue_shard)
+ for cc_id in issue.derived_cc_ids)
+
+ self.issue2cc_tbl.Delete(
+ cnxn, issue_id=[issue.issue_id for issue in issues], commit=False)
+ self.issue2cc_tbl.InsertRows(
+ cnxn, ISSUE2CC_COLS + ['issue_shard'],
+ cc_rows, ignore=True, commit=commit)
+
+ def _UpdateIssuesNotify(self, cnxn, issues, commit=True):
+ """Update the Issue2Notify table rows for the given issues."""
+ notify_rows = []
+ for issue in issues:
+ derived_rows = [[issue.issue_id, email]
+ for email in issue.derived_notify_addrs]
+ notify_rows.extend(derived_rows)
+
+ self.issue2notify_tbl.Delete(
+ cnxn, issue_id=[issue.issue_id for issue in issues], commit=False)
+ self.issue2notify_tbl.InsertRows(
+ cnxn, ISSUE2NOTIFY_COLS, notify_rows, ignore=True, commit=commit)
+
+ def _UpdateIssuesRelation(self, cnxn, issues, commit=True):
+ """Update the IssueRelation table rows for the given issues."""
+ relation_rows = []
+ blocking_rows = []
+ dangling_relation_rows = []
+ for issue in issues:
+ for i, dst_issue_id in enumerate(issue.blocked_on_iids):
+ rank = issue.blocked_on_ranks[i]
+ relation_rows.append((issue.issue_id, dst_issue_id, 'blockedon', rank))
+ for dst_issue_id in issue.blocking_iids:
+ blocking_rows.append((dst_issue_id, issue.issue_id, 'blockedon'))
+ for dst_ref in issue.dangling_blocked_on_refs:
+ if dst_ref.ext_issue_identifier:
+ dangling_relation_rows.append((
+ issue.issue_id, None, None,
+ dst_ref.ext_issue_identifier, 'blockedon'))
+ else:
+ dangling_relation_rows.append((
+ issue.issue_id, dst_ref.project, dst_ref.issue_id,
+ None, 'blockedon'))
+ for dst_ref in issue.dangling_blocking_refs:
+ if dst_ref.ext_issue_identifier:
+ dangling_relation_rows.append((
+ issue.issue_id, None, None,
+ dst_ref.ext_issue_identifier, 'blocking'))
+ else:
+ dangling_relation_rows.append((
+ issue.issue_id, dst_ref.project, dst_ref.issue_id,
+ dst_ref.ext_issue_identifier, 'blocking'))
+ if issue.merged_into:
+ relation_rows.append((
+ issue.issue_id, issue.merged_into, 'mergedinto', None))
+ if issue.merged_into_external:
+ dangling_relation_rows.append((
+ issue.issue_id, None, None,
+ issue.merged_into_external, 'mergedinto'))
+
+ old_blocking = self.issuerelation_tbl.Select(
+ cnxn, cols=ISSUERELATION_COLS[:-1],
+ dst_issue_id=[issue.issue_id for issue in issues], kind='blockedon')
+ relation_rows.extend([
+ (row + (0,)) for row in blocking_rows if row not in old_blocking])
+ delete_rows = [row for row in old_blocking if row not in blocking_rows]
+
+ for issue_id, dst_issue_id, kind in delete_rows:
+ self.issuerelation_tbl.Delete(cnxn, issue_id=issue_id,
+ dst_issue_id=dst_issue_id, kind=kind, commit=False)
+ self.issuerelation_tbl.Delete(
+ cnxn, issue_id=[issue.issue_id for issue in issues], commit=False)
+ self.issuerelation_tbl.InsertRows(
+ cnxn, ISSUERELATION_COLS, relation_rows, ignore=True, commit=commit)
+ self.danglingrelation_tbl.Delete(
+ cnxn, issue_id=[issue.issue_id for issue in issues], commit=False)
+ self.danglingrelation_tbl.InsertRows(
+ cnxn, DANGLINGRELATION_COLS, dangling_relation_rows, ignore=True,
+ commit=commit)
+
+ def _UpdateIssuesModified(
+ self, cnxn, iids, modified_timestamp=None, invalidate=True):
+ """Store a modified timestamp for each of the specified issues."""
+ if not iids:
+ return
+ delta = {'modified': modified_timestamp or int(time.time())}
+ self.issue_tbl.Update(cnxn, delta, id=iids, commit=False)
+ if invalidate:
+ self.InvalidateIIDs(cnxn, iids)
+
+ def _UpdateIssuesApprovals(self, cnxn, issue, commit=True):
+ """Update the Issue2ApprovalValue table rows for the given issue."""
+ self.issue2approvalvalue_tbl.Delete(
+ cnxn, issue_id=issue.issue_id, commit=commit)
+ av_rows = [(av.approval_id, issue.issue_id, av.phase_id,
+ av.status.name.lower(), av.setter_id, av.set_on) for
+ av in issue.approval_values]
+ self.issue2approvalvalue_tbl.InsertRows(
+ cnxn, ISSUE2APPROVALVALUE_COLS, av_rows, commit=commit)
+
+ approver_rows = []
+ for av in issue.approval_values:
+ approver_rows.extend([(av.approval_id, approver_id, issue.issue_id)
+ for approver_id in av.approver_ids])
+ self.issueapproval2approver_tbl.Delete(
+ cnxn, issue_id=issue.issue_id, commit=commit)
+ self.issueapproval2approver_tbl.InsertRows(
+ cnxn, ISSUEAPPROVAL2APPROVER_COLS, approver_rows, commit=commit)
+
+  def UpdateIssueStructure(self, cnxn, config, issue, template, reporter_id,
+                           comment_content, commit=True, invalidate=True):
+    """Converts the phases and approvals structure of the issue into the
+    structure of the given template.
+
+    The issue's approval values are rebuilt from the template's approval
+    values:
+      - Approvals already on the issue keep their approvers, status and
+        setter, but take the template's phase_id.
+      - New approvals use the template's status and the ApprovalDef's
+        approvers.
+      - Approvals that are on the issue but not in the template are dropped.
+
+    A fresh survey description comment is inserted for every template
+    approval that has an ApprovalDef.  Phase-scoped field values are remapped
+    to the template phase with the same (case-insensitive) name or dropped.
+    A comment with an approval-structure amendment is then created.
+
+    Args:
+      cnxn: connection to SQL database.
+      config: ProjectIssueConfig PB supplying approval_defs and field_defs.
+      issue: Issue PB to convert; modified in place.
+      template: template PB whose approval_values and phases are adopted.
+      reporter_id: user ID recorded on the new comments.
+      comment_content: text of the comment that records the change.
+      commit: if True, commit after all writes.
+      invalidate: if True, invalidate the issue after writing.
+
+    Returns:
+      The comment PB returned by CreateIssueComment().
+    """
+    # TODO(jojwang): Remove Field defs that belong to any removed approvals.
+    approval_defs_by_id = {ad.approval_id: ad for ad in config.approval_defs}
+    issue_avs_by_id = {av.approval_id: av for av in issue.approval_values}
+
+    new_approval_surveys = []
+    new_issue_approvals = []
+
+    # Only the template's approvals survive; issue approvals missing from the
+    # template are not carried over.
+    for template_av in template.approval_values:
+      existing_issue_av = issue_avs_by_id.get(template_av.approval_id)
+
+      # Update all approval surveys so latest ApprovalDef survey changes
+      # appear in the converted issue's approval values.
+      ad = approval_defs_by_id.get(template_av.approval_id)
+      new_av_approver_ids = []
+      if ad:
+        new_av_approver_ids = ad.approver_ids
+        new_approval_surveys.append(
+            self._MakeIssueComment(
+                issue.project_id, reporter_id, ad.survey,
+                is_description=True, approval_id=ad.approval_id))
+      else:
+        logging.info('ApprovalDef not found for approval %r', template_av)
+
+      # Keep the approval value as-is if it exists in both the issue and the
+      # template; only its phase_id follows the template.
+      if existing_issue_av:
+        new_av = tracker_bizobj.MakeApprovalValue(
+            existing_issue_av.approval_id,
+            approver_ids=existing_issue_av.approver_ids,
+            status=existing_issue_av.status,
+            setter_id=existing_issue_av.setter_id,
+            set_on=existing_issue_av.set_on,
+            phase_id=template_av.phase_id)
+        new_issue_approvals.append(new_av)
+      else:
+        # New approval: template status plus the ApprovalDef's approvers
+        # (empty if no ApprovalDef was found).
+        new_av = tracker_bizobj.MakeApprovalValue(
+            template_av.approval_id, approver_ids=new_av_approver_ids,
+            status=template_av.status, phase_id=template_av.phase_id)
+        new_issue_approvals.append(new_av)
+
+    template_phase_by_name = {
+        phase.name.lower(): phase for phase in template.phases}
+    issue_phase_by_id = {phase.phase_id: phase for phase in issue.phases}
+    updated_fvs = []
+    # Trim issue FieldValues or update FieldValue phase_ids
+    for fv in issue.field_values:
+      # If a fv's phase has the same name as a template's phase, update
+      # the fv's phase_id to that of the template phase's. Otherwise,
+      # remove the fv.  Note that the matching fv PBs are mutated in place.
+      if fv.phase_id:
+        issue_phase = issue_phase_by_id.get(fv.phase_id)
+        if issue_phase and issue_phase.name:
+          template_phase = template_phase_by_name.get(issue_phase.name.lower())
+          # TODO(jojwang): monorail:4693, remove this after all 'stable-full'
+          # gates have been renamed to 'stable'.
+          if not template_phase:
+            template_phase = template_phase_by_name.get(
+                FLT_EQUIVALENT_GATES.get(issue_phase.name.lower()))
+          if template_phase:
+            fv.phase_id = template_phase.phase_id
+            updated_fvs.append(fv)
+      # keep all fvs that do not belong to phases.
+      else:
+        updated_fvs.append(fv)
+
+    # The amendment lists approval field names after vs. before the change.
+    fd_names_by_id = {fd.field_id: fd.field_name for fd in config.field_defs}
+    amendment = tracker_bizobj.MakeApprovalStructureAmendment(
+        [fd_names_by_id.get(av.approval_id) for av in new_issue_approvals],
+        [fd_names_by_id.get(av.approval_id) for av in issue.approval_values])
+
+    # Update issue structure in RAM.
+    issue.approval_values = new_issue_approvals
+    # NOTE(review): this assigns the template's own phase list to the issue;
+    # confirm nothing later mutates it through either object.
+    issue.phases = template.phases
+    issue.field_values = updated_fvs
+
+    # Update issue structure in DB.  All writes use commit=False so they are
+    # committed together below.
+    for survey in new_approval_surveys:
+      survey.issue_id = issue.issue_id
+      self.InsertComment(cnxn, survey, commit=False)
+    self._UpdateIssuesApprovals(cnxn, issue, commit=False)
+    self._UpdateIssuesFields(cnxn, [issue], commit=False)
+    comment_pb = self.CreateIssueComment(
+        cnxn, issue, reporter_id, comment_content,
+        amendments=[amendment], commit=False)
+
+    if commit:
+      cnxn.Commit()
+
+    if invalidate:
+      self.InvalidateIIDs(cnxn, [issue.issue_id])
+
+    return comment_pb
+
+  def DeltaUpdateIssue(
+      self, cnxn, services, reporter_id, project_id,
+      config, issue, delta, index_now=False, comment=None, attachments=None,
+      iids_to_invalidate=None, rules=None, predicate_asts=None,
+      is_description=False, timestamp=None, kept_attachments=None,
+      importer_id=None, inbound_message=None):
+    """Apply an IssueDelta to an issue, store it, and record a comment.
+
+    The steps are:
+      1. Apply the delta to the issue PB in place and re-run filter rules.
+      2. Write the issue row and create a comment carrying the amendments.
+      3. Bump the `modified` timestamp of every affected issue.
+      4. Add blocking/blocked-on amendment comments to the issues on the
+         other side of any relation changes.
+      5. Index the issue (immediately, or by enqueueing it).
+    If the delta yields no amendments and there is no non-blank comment and
+    no attachments, nothing is stored and ([], None) is returned.
+
+    Args:
+      cnxn: connection to SQL database.
+      services: connections to persistence layer.
+      reporter_id: user ID of the user making this change.
+      project_id: int ID for the current project.
+      config: ProjectIssueConfig PB for this project.
+      issue: Issue PB of issue to update.
+      delta: IssueDelta object of fields to update.
+      index_now: True if the issue should be updated in the full text index.
+      comment: This should be the content of the comment
+          corresponding to this change.
+      attachments: List [(filename, contents, mimetype),...] of attachments.
+      iids_to_invalidate: optional set of issue IDs that need to be invalidated.
+          If provided, affected issues will be accumulated here and, the caller
+          must call InvalidateIIDs() afterwards.
+      rules: optional list of preloaded FilterRule PBs for this project.
+      predicate_asts: optional list of QueryASTs for the rules. If rules are
+          provided, then predicate_asts should also be provided.
+      is_description: True if the comment is a new description for the issue.
+      timestamp: int timestamp set during testing, otherwise defaults to
+          int(time.time()).
+      kept_attachments: This should be a list of int attachment ids for
+          attachments kept from previous descriptions, if the comment is
+          a change to the issue description
+      importer_id: optional ID of user ID for an API client that is importing
+          issues and attributing them to other users.
+      inbound_message: optional string full text of an email that caused
+          this comment to be added.
+
+    Returns:
+      A tuple (amendments, comment_pb) with a list of Amendment PBs that
+      describe the set of metadata updates that the user made, and the
+      resulting IssueComment (or None if no comment was created).
+    """
+    timestamp = timestamp or int(time.time())
+    # Snapshot effective values so we can detect below whether the delta or
+    # a filter rule changed them.
+    old_effective_owner = tracker_bizobj.GetOwnerId(issue)
+    old_effective_status = tracker_bizobj.GetStatus(issue)
+    old_components = set(issue.component_ids)
+
+    logging.info(
+        'Bulk edit to project_id %s issue.local_id %s, comment %r',
+        project_id, issue.local_id, comment)
+    # invalidate is True only when this call owns the invalidation set.
+    if iids_to_invalidate is None:
+      iids_to_invalidate = set([issue.issue_id])
+      invalidate = True
+    else:
+      iids_to_invalidate.add(issue.issue_id)
+      invalidate = False  # Caller will do it.
+
+    # Store each updated value in the issue PB, and compute Update PBs
+    amendments, impacted_iids = tracker_bizobj.ApplyIssueDelta(
+        cnxn, self, issue, delta, config)
+    iids_to_invalidate.update(impacted_iids)
+
+    # If this was a no-op with no comment, bail out and don't save,
+    # invalidate, or re-index anything.
+    if (not amendments and (not comment or not comment.strip()) and
+        not attachments):
+      logging.info('No amendments, comment, attachments: this is a no-op.')
+      return [], None
+
+    # Note: no need to check for collisions when the user is doing a delta.
+
+    # update the modified_timestamp for any comment added, even if it was
+    # just a text comment with no issue fields changed.
+    issue.modified_timestamp = timestamp
+
+    # Update the closed timestamp before filter rules so that rules
+    # can test for closed_timestamp, and also after filter rules
+    # so that closed_timestamp will be set if the issue is closed by the rule.
+    tracker_helpers.UpdateClosedTimestamp(config, issue, old_effective_status)
+    if rules is None:
+      logging.info('Rules were not given')
+      rules = services.features.GetFilterRules(cnxn, config.project_id)
+      predicate_asts = filterrules_helpers.ParsePredicateASTs(
+          rules, config, [])
+
+    filterrules_helpers.ApplyGivenRules(
+        cnxn, services, issue, config, rules, predicate_asts)
+    tracker_helpers.UpdateClosedTimestamp(config, issue, old_effective_status)
+    # Record per-field modification times for whatever actually changed.
+    if old_effective_owner != tracker_bizobj.GetOwnerId(issue):
+      issue.owner_modified_timestamp = timestamp
+    if old_effective_status != tracker_bizobj.GetStatus(issue):
+      issue.status_modified_timestamp = timestamp
+    if old_components != set(issue.component_ids):
+      issue.component_modified_timestamp = timestamp
+
+    # Store the issue in SQL.
+    self.UpdateIssue(cnxn, issue, commit=False, invalidate=False)
+
+    comment_pb = self.CreateIssueComment(
+        cnxn, issue, reporter_id, comment, amendments=amendments,
+        is_description=is_description, attachments=attachments, commit=False,
+        kept_attachments=kept_attachments, timestamp=timestamp,
+        importer_id=importer_id, inbound_message=inbound_message)
+    self._UpdateIssuesModified(
+        cnxn, iids_to_invalidate, modified_timestamp=issue.modified_timestamp,
+        invalidate=invalidate)
+
+    # Add a comment to the newly added issues saying they are now blocking
+    # this issue.
+    for add_issue in self.GetIssues(cnxn, delta.blocked_on_add):
+      self.CreateIssueComment(
+          cnxn, add_issue, reporter_id, content='',
+          amendments=[tracker_bizobj.MakeBlockingAmendment(
+              [(issue.project_name, issue.local_id)], [],
+              default_project_name=add_issue.project_name)],
+          timestamp=timestamp, importer_id=importer_id)
+    # Add a comment to the newly removed issues saying they are no longer
+    # blocking this issue.
+    for remove_issue in self.GetIssues(cnxn, delta.blocked_on_remove):
+      self.CreateIssueComment(
+          cnxn, remove_issue, reporter_id, content='',
+          amendments=[tracker_bizobj.MakeBlockingAmendment(
+              [], [(issue.project_name, issue.local_id)],
+              default_project_name=remove_issue.project_name)],
+          timestamp=timestamp, importer_id=importer_id)
+
+    # Add a comment to the newly added issues saying they are now blocked on
+    # this issue.
+    for add_issue in self.GetIssues(cnxn, delta.blocking_add):
+      self.CreateIssueComment(
+          cnxn, add_issue, reporter_id, content='',
+          amendments=[tracker_bizobj.MakeBlockedOnAmendment(
+              [(issue.project_name, issue.local_id)], [],
+              default_project_name=add_issue.project_name)],
+          timestamp=timestamp, importer_id=importer_id)
+    # Add a comment to the newly removed issues saying they are no longer
+    # blocked on this issue.
+    for remove_issue in self.GetIssues(cnxn, delta.blocking_remove):
+      self.CreateIssueComment(
+          cnxn, remove_issue, reporter_id, content='',
+          amendments=[tracker_bizobj.MakeBlockedOnAmendment(
+              [], [(issue.project_name, issue.local_id)],
+              default_project_name=remove_issue.project_name)],
+          timestamp=timestamp, importer_id=importer_id)
+
+    # NOTE(review): an explicit Commit() happens only when the caller owns
+    # invalidation; otherwise the writes above appear to rely on a commit
+    # made elsewhere (e.g. by the default-commit CreateIssueComment calls or
+    # during InvalidateIIDs) -- confirm.
+    if not invalidate:
+      cnxn.Commit()
+
+    if index_now:
+      tracker_fulltext.IndexIssues(
+          cnxn, [issue], services.user_service, self, self._config_service)
+    else:
+      self.EnqueueIssuesForIndexing(cnxn, [issue.issue_id])
+
+    return amendments, comment_pb
+
+ def InvalidateIIDs(self, cnxn, iids_to_invalidate):
+ """Invalidate the specified issues in the Invalidate table and memcache."""
+ issues_to_invalidate = self.GetIssues(cnxn, iids_to_invalidate)
+ self.InvalidateIssues(cnxn, issues_to_invalidate)
+
+ def InvalidateIssues(self, cnxn, issues):
+ """Invalidate the specified issues in the Invalidate table and memcache."""
+ iids = [issue.issue_id for issue in issues]
+ self.issue_2lc.InvalidateKeys(cnxn, iids)
+ self._config_service.InvalidateMemcache(issues)
+
+ def RelateIssues(self, cnxn, issue_relation_dict, commit=True):
+ """Update the IssueRelation table rows for the given relationships.
+
+ issue_relation_dict is a mapping of 'source' issues to 'destination' issues,
+ paired with the kind of relationship connecting the two.
+ """
+ relation_rows = []
+ for src_iid, dests in issue_relation_dict.items():
+ for dst_iid, kind in dests:
+ if kind == 'blocking':
+ relation_rows.append((dst_iid, src_iid, 'blockedon', 0))
+ elif kind == 'blockedon':
+ relation_rows.append((src_iid, dst_iid, 'blockedon', 0))
+ elif kind == 'mergedinto':
+ relation_rows.append((src_iid, dst_iid, 'mergedinto', None))
+
+ self.issuerelation_tbl.InsertRows(
+ cnxn, ISSUERELATION_COLS, relation_rows, ignore=True, commit=commit)
+
+ def CopyIssues(self, cnxn, dest_project, issues, user_service, copier_id):
+ """Copy the given issues into the destination project."""
+ created_issues = []
+ iids_to_invalidate = set()
+
+ for target_issue in issues:
+ assert not target_issue.assume_stale, (
+ 'issue2514: Copying issue that might be stale: %r' % target_issue)
+ new_issue = tracker_pb2.Issue()
+ new_issue.project_id = dest_project.project_id
+ new_issue.project_name = dest_project.project_name
+ new_issue.summary = target_issue.summary
+ new_issue.labels.extend(target_issue.labels)
+ new_issue.field_values.extend(target_issue.field_values)
+ new_issue.reporter_id = copier_id
+
+ timestamp = int(time.time())
+ new_issue.opened_timestamp = timestamp
+ new_issue.modified_timestamp = timestamp
+
+ target_comments = self.GetCommentsForIssue(cnxn, target_issue.issue_id)
+ initial_summary_comment = target_comments[0]
+
+ # Note that blocking and merge_into are not copied.
+ if target_issue.blocked_on_iids:
+ blocked_on = target_issue.blocked_on_iids
+ iids_to_invalidate.update(blocked_on)
+ new_issue.blocked_on_iids = blocked_on
+
+ # Gather list of attachments from the target issue's summary comment.
+ # MakeIssueComments expects a list of [(filename, contents, mimetype),...]
+ attachments = []
+ for attachment in initial_summary_comment.attachments:
+ object_path = ('/' + app_identity.get_default_gcs_bucket_name() +
+ attachment.gcs_object_id)
+ with cloudstorage.open(object_path, 'r') as f:
+ content = f.read()
+ attachments.append(
+ [attachment.filename, content, attachment.mimetype])
+
+ if attachments:
+ new_issue.attachment_count = len(attachments)
+
+ # Create the same summary comment as the target issue.
+ comment = self._MakeIssueComment(
+ dest_project.project_id, copier_id, initial_summary_comment.content,
+ attachments=attachments, timestamp=timestamp, is_description=True)
+
+ new_issue.local_id = self.AllocateNextLocalID(
+ cnxn, dest_project.project_id)
+ issue_id = self.InsertIssue(cnxn, new_issue)
+ comment.issue_id = issue_id
+ self.InsertComment(cnxn, comment)
+
+ if permissions.HasRestrictions(new_issue, 'view'):
+ self._config_service.InvalidateMemcache(
+ [new_issue], key_prefix='nonviewable:')
+
+ tracker_fulltext.IndexIssues(
+ cnxn, [new_issue], user_service, self, self._config_service)
+ created_issues.append(new_issue)
+
+ # The referenced issues are all modified when the relationship is added.
+ self._UpdateIssuesModified(
+ cnxn, iids_to_invalidate, modified_timestamp=timestamp)
+
+ return created_issues
+
+  def MoveIssues(self, cnxn, dest_project, issues, user_service):
+    """Move the given issues into the destination project.
+
+    An issue that previously lived in dest_project gets its former local ID
+    back.  Any other issue gets a newly allocated local ID.  Comments are
+    re-pointed at the new project, the old (issue_id, project_id, local_id)
+    locations are recorded, and the moved issues are re-indexed.
+
+    Args:
+      cnxn: connection to SQL database.
+      dest_project: Project PB that receives the issues.
+      issues: list of Issue PBs to move; modified in place.
+      user_service: persistence layer for users, passed to the indexer.
+
+    Returns:
+      A set of the global issue IDs that moved back to a former location.
+    """
+    # Capture current locations before the PBs are modified below.
+    old_location_rows = [
+        (issue.issue_id, issue.project_id, issue.local_id)
+        for issue in issues]
+    moved_back_iids = set()
+
+    former_locations_in_project = self.issueformerlocations_tbl.Select(
+        cnxn, cols=ISSUEFORMERLOCATIONS_COLS,
+        project_id=dest_project.project_id,
+        issue_id=[issue.issue_id for issue in issues])
+    # Maps issue_id -> the local_id it used to have in dest_project.
+    former_locations = {
+        issue_id: local_id
+        for issue_id, project_id, local_id in former_locations_in_project}
+
+    # Remove the issue id from issue_id_2lc so that it does not stay
+    # around in cache and memcache.
+    # The Key of IssueIDTwoLevelCache is (project_id, local_id).
+    self.issue_id_2lc.InvalidateKeys(
+        cnxn, [(issue.project_id, issue.local_id) for issue in issues])
+    self.InvalidateIssues(cnxn, issues)
+
+    for issue in issues:
+      if issue.issue_id in former_locations:
+        dest_id = former_locations[issue.issue_id]
+        moved_back_iids.add(issue.issue_id)
+      else:
+        dest_id = self.AllocateNextLocalID(cnxn, dest_project.project_id)
+
+      issue.local_id = dest_id
+      issue.project_id = dest_project.project_id
+      issue.project_name = dest_project.project_name
+
+    # Rewrite each whole issue so that status and label IDs are looked up
+    # in the context of the destination project.
+    self.UpdateIssues(cnxn, issues)
+
+    # Comments also have the project_id because it is needed for an index.
+    self.comment_tbl.Update(
+        cnxn, {'project_id': dest_project.project_id},
+        issue_id=[issue.issue_id for issue in issues], commit=False)
+
+    # Record old locations so that we can offer links if the user looks there.
+    self.issueformerlocations_tbl.InsertRows(
+        cnxn, ISSUEFORMERLOCATIONS_COLS, old_location_rows, ignore=True,
+        commit=False)
+    cnxn.Commit()
+
+    tracker_fulltext.IndexIssues(
+        cnxn, issues, user_service, self, self._config_service)
+
+    return moved_back_iids
+
+ def ExpungeFormerLocations(self, cnxn, project_id):
+ """Delete history of issues that were in this project but moved out."""
+ self.issueformerlocations_tbl.Delete(cnxn, project_id=project_id)
+
+ def ExpungeIssues(self, cnxn, issue_ids):
+ """Completely delete the specified issues from the database."""
+ logging.info('expunging the issues %r', issue_ids)
+ tracker_fulltext.UnindexIssues(issue_ids)
+
+ remaining_iids = issue_ids[:]
+
+ # Note: these are purposely not done in a transaction to allow
+ # incremental progress in what might be a very large change.
+ # We are not concerned about non-atomic deletes because all
+ # this data will be gone eventually anyway.
+ while remaining_iids:
+ iids_in_chunk = remaining_iids[:CHUNK_SIZE]
+ remaining_iids = remaining_iids[CHUNK_SIZE:]
+ self.issuesummary_tbl.Delete(cnxn, issue_id=iids_in_chunk)
+ self.issue2label_tbl.Delete(cnxn, issue_id=iids_in_chunk)
+ self.issue2component_tbl.Delete(cnxn, issue_id=iids_in_chunk)
+ self.issue2cc_tbl.Delete(cnxn, issue_id=iids_in_chunk)
+ self.issue2notify_tbl.Delete(cnxn, issue_id=iids_in_chunk)
+ self.issueupdate_tbl.Delete(cnxn, issue_id=iids_in_chunk)
+ self.attachment_tbl.Delete(cnxn, issue_id=iids_in_chunk)
+ self.comment_tbl.Delete(cnxn, issue_id=iids_in_chunk)
+ self.issuerelation_tbl.Delete(cnxn, issue_id=iids_in_chunk)
+ self.issuerelation_tbl.Delete(cnxn, dst_issue_id=iids_in_chunk)
+ self.danglingrelation_tbl.Delete(cnxn, issue_id=iids_in_chunk)
+ self.issueformerlocations_tbl.Delete(cnxn, issue_id=iids_in_chunk)
+ self.reindexqueue_tbl.Delete(cnxn, issue_id=iids_in_chunk)
+ self.issue_tbl.Delete(cnxn, id=iids_in_chunk)
+
+ def SoftDeleteIssue(self, cnxn, project_id, local_id, deleted, user_service):
+ """Set the deleted boolean on the indicated issue and store it.
+
+ Args:
+ cnxn: connection to SQL database.
+ project_id: int project ID for the current project.
+ local_id: int local ID of the issue to freeze/unfreeze.
+ deleted: boolean, True to soft-delete, False to undelete.
+ user_service: persistence layer for users, used to lookup user IDs.
+ """
+ issue = self.GetIssueByLocalID(cnxn, project_id, local_id, use_cache=False)
+ issue.deleted = deleted
+ self.UpdateIssue(cnxn, issue, update_cols=['deleted'])
+ tracker_fulltext.IndexIssues(
+ cnxn, [issue], user_service, self, self._config_service)
+
+ def DeleteComponentReferences(self, cnxn, component_id):
+ """Delete any references to the specified component."""
+ # TODO(jrobbins): add tasks to re-index any affected issues.
+ # Note: if this call fails, some data could be left
+ # behind, but it would not be displayed, and it could always be
+ # GC'd from the DB later.
+ self.issue2component_tbl.Delete(cnxn, component_id=component_id)
+
+ ### Local ID generation
+
+ def InitializeLocalID(self, cnxn, project_id):
+ """Initialize the local ID counter for the specified project to zero.
+
+ Args:
+ cnxn: connection to SQL database.
+ project_id: int ID of the project.
+ """
+ self.localidcounter_tbl.InsertRow(
+ cnxn, project_id=project_id, used_local_id=0, used_spam_id=0)
+
+ def SetUsedLocalID(self, cnxn, project_id):
+ """Set the local ID counter based on existing issues.
+
+ Args:
+ cnxn: connection to SQL database.
+ project_id: int ID of the project.
+ """
+ highest_id = self.GetHighestLocalID(cnxn, project_id)
+ self.localidcounter_tbl.InsertRow(
+ cnxn, replace=True, used_local_id=highest_id, project_id=project_id)
+ return highest_id
+
+ def AllocateNextLocalID(self, cnxn, project_id):
+ """Return the next available issue ID in the specified project.
+
+ Args:
+ cnxn: connection to SQL database.
+ project_id: int ID of the project.
+
+ Returns:
+ The next local ID.
+ """
+ try:
+ next_local_id = self.localidcounter_tbl.IncrementCounterValue(
+ cnxn, 'used_local_id', project_id=project_id)
+ except AssertionError as e:
+ logging.info('exception incrementing local_id counter: %s', e)
+ next_local_id = self.SetUsedLocalID(cnxn, project_id) + 1
+ return next_local_id
+
+ def GetHighestLocalID(self, cnxn, project_id):
+ """Return the highest used issue ID in the specified project.
+
+ Args:
+ cnxn: connection to SQL database.
+ project_id: int ID of the project.
+
+ Returns:
+ The highest local ID for an active or moved issues.
+ """
+ highest = self.issue_tbl.SelectValue(
+ cnxn, 'MAX(local_id)', project_id=project_id)
+ highest = highest or 0 # It will be None if the project has no issues.
+ highest_former = self.issueformerlocations_tbl.SelectValue(
+ cnxn, 'MAX(local_id)', project_id=project_id)
+ highest_former = highest_former or 0
+ return max(highest, highest_former)
+
+ def GetAllLocalIDsInProject(self, cnxn, project_id, min_local_id=None):
+ """Return the list of local IDs only, not the actual issues.
+
+ Args:
+ cnxn: connection to SQL database.
+ project_id: the ID of the project to which the issue belongs.
+ min_local_id: point to start at.
+
+ Returns:
+ A range object of local IDs from 1 to N, or from min_local_id to N. It
+ may be the case that some of those local IDs are no longer used, e.g.,
+ if some issues were moved out of this project.
+ """
+ if not min_local_id:
+ min_local_id = 1
+ highest_local_id = self.GetHighestLocalID(cnxn, project_id)
+ return list(range(min_local_id, highest_local_id + 1))
+
+ def ExpungeLocalIDCounters(self, cnxn, project_id):
+ """Delete history of local ids that were in this project."""
+ self.localidcounter_tbl.Delete(cnxn, project_id=project_id)
+
+ ### Comments
+
+ def _UnpackComment(
+ self, comment_row, content_dict, inbound_message_dict, approval_dict,
+ importer_dict):
+ """Partially construct a Comment PB from a DB row."""
+ (comment_id, issue_id, created, project_id, commenter_id,
+ deleted_by, is_spam, is_description, commentcontent_id) = comment_row
+ comment = tracker_pb2.IssueComment()
+ comment.id = comment_id
+ comment.issue_id = issue_id
+ comment.timestamp = created
+ comment.project_id = project_id
+ comment.user_id = commenter_id
+ comment.content = content_dict.get(commentcontent_id, '')
+ comment.inbound_message = inbound_message_dict.get(commentcontent_id, '')
+ comment.deleted_by = deleted_by or 0
+ comment.is_spam = bool(is_spam)
+ comment.is_description = bool(is_description)
+ comment.approval_id = approval_dict.get(comment_id)
+ comment.importer_id = importer_dict.get(comment_id)
+ return comment
+
+ def _UnpackAmendment(self, amendment_row):
+ """Construct an Amendment PB from a DB row."""
+ (_id, _issue_id, comment_id, field_name,
+ old_value, new_value, added_user_id, removed_user_id,
+ custom_field_name) = amendment_row
+ amendment = tracker_pb2.Amendment()
+ field_enum = tracker_pb2.FieldID(field_name.upper())
+ amendment.field = field_enum
+
+ # TODO(jrobbins): display old values in more cases.
+ if new_value is not None:
+ amendment.newvalue = new_value
+ if old_value is not None:
+ amendment.oldvalue = old_value
+ if added_user_id:
+ amendment.added_user_ids.append(added_user_id)
+ if removed_user_id:
+ amendment.removed_user_ids.append(removed_user_id)
+ if custom_field_name:
+ amendment.custom_field_name = custom_field_name
+ return amendment, comment_id
+
+ def _ConsolidateAmendments(self, amendments):
+ """Consoliodate amendments of the same field in one comment into one
+ amendment PB."""
+
+ fields_dict = {}
+ result = []
+
+ for amendment in amendments:
+ key = amendment.field, amendment.custom_field_name
+ fields_dict.setdefault(key, []).append(amendment)
+ for (field, _custom_name), sorted_amendments in sorted(fields_dict.items()):
+ new_amendment = tracker_pb2.Amendment()
+ new_amendment.field = field
+ for amendment in sorted_amendments:
+ if amendment.newvalue is not None:
+ if new_amendment.newvalue is not None:
+ # NOTE: see crbug/monorail/8272. BLOCKEDON and BLOCKING changes
+ # are all stored in newvalue e.g. (newvalue = -b/123 b/124) and
+ # external bugs and monorail bugs are stored in separate amendments.
+ # Without this, the values of external bug amendments and monorail
+ # blocker bug amendments may overwrite each other.
+ new_amendment.newvalue += (' ' + amendment.newvalue)
+ else:
+ new_amendment.newvalue = amendment.newvalue
+ if amendment.oldvalue is not None:
+ new_amendment.oldvalue = amendment.oldvalue
+ if amendment.added_user_ids:
+ new_amendment.added_user_ids.extend(amendment.added_user_ids)
+ if amendment.removed_user_ids:
+ new_amendment.removed_user_ids.extend(amendment.removed_user_ids)
+ if amendment.custom_field_name:
+ new_amendment.custom_field_name = amendment.custom_field_name
+ result.append(new_amendment)
+ return result
+
+ def _UnpackAttachment(self, attachment_row):
+ """Construct an Attachment PB from a DB row."""
+ (attachment_id, _issue_id, comment_id, filename, filesize, mimetype,
+ deleted, gcs_object_id) = attachment_row
+ attach = tracker_pb2.Attachment()
+ attach.attachment_id = attachment_id
+ attach.filename = filename
+ attach.filesize = filesize
+ attach.mimetype = mimetype
+ attach.deleted = bool(deleted)
+ attach.gcs_object_id = gcs_object_id
+ return attach, comment_id
+
+ def _DeserializeComments(
+ self, comment_rows, commentcontent_rows, amendment_rows, attachment_rows,
+ approval_rows, importer_rows):
+ """Turn rows into IssueComment PBs."""
+ results = [] # keep objects in the same order as the rows
+ results_dict = {} # for fast access when joining.
+
+ content_dict = dict(
+ (commentcontent_id, content) for
+ commentcontent_id, content, _ in commentcontent_rows)
+ inbound_message_dict = dict(
+ (commentcontent_id, inbound_message) for
+ commentcontent_id, _, inbound_message in commentcontent_rows)
+ approval_dict = dict(
+ (comment_id, approval_id) for approval_id, comment_id in
+ approval_rows)
+ importer_dict = dict(importer_rows)
+
+ for comment_row in comment_rows:
+ comment = self._UnpackComment(
+ comment_row, content_dict, inbound_message_dict, approval_dict,
+ importer_dict)
+ results.append(comment)
+ results_dict[comment.id] = comment
+
+ for amendment_row in amendment_rows:
+ amendment, comment_id = self._UnpackAmendment(amendment_row)
+ try:
+ results_dict[comment_id].amendments.extend([amendment])
+ except KeyError:
+ logging.error('Found amendment for missing comment: %r', comment_id)
+
+ for attachment_row in attachment_rows:
+ attach, comment_id = self._UnpackAttachment(attachment_row)
+ try:
+ results_dict[comment_id].attachments.append(attach)
+ except KeyError:
+ logging.error('Found attachment for missing comment: %r', comment_id)
+
+ for c in results:
+ c.amendments = self._ConsolidateAmendments(c.amendments)
+
+ return results
+
+ # TODO(jrobbins): make this a private method and expose just the interface
+ # needed by activities.py.
+ def GetComments(
+ self, cnxn, where=None, order_by=None, content_only=False, **kwargs):
+ """Retrieve comments from SQL."""
+ shard_id = sql.RandomShardID()
+ order_by = order_by or [('created', [])]
+ comment_rows = self.comment_tbl.Select(
+ cnxn, cols=COMMENT_COLS, where=where,
+ order_by=order_by, shard_id=shard_id, **kwargs)
+ cids = [row[0] for row in comment_rows]
+ commentcontent_ids = [row[-1] for row in comment_rows]
+ content_rows = self.commentcontent_tbl.Select(
+ cnxn, cols=COMMENTCONTENT_COLS, id=commentcontent_ids,
+ shard_id=shard_id)
+ approval_rows = self.issueapproval2comment_tbl.Select(
+ cnxn, cols=ISSUEAPPROVAL2COMMENT_COLS, comment_id=cids)
+ amendment_rows = []
+ attachment_rows = []
+ importer_rows = []
+ if not content_only:
+ amendment_rows = self.issueupdate_tbl.Select(
+ cnxn, cols=ISSUEUPDATE_COLS, comment_id=cids, shard_id=shard_id)
+ attachment_rows = self.attachment_tbl.Select(
+ cnxn, cols=ATTACHMENT_COLS, comment_id=cids, shard_id=shard_id)
+ importer_rows = self.commentimporter_tbl.Select(
+ cnxn, cols=COMMENTIMPORTER_COLS, comment_id=cids, shard_id=shard_id)
+
+ comments = self._DeserializeComments(
+ comment_rows, content_rows, amendment_rows, attachment_rows,
+ approval_rows, importer_rows)
+ return comments
+
+ def GetComment(self, cnxn, comment_id):
+ """Get the requested comment, or raise an exception."""
+ comments = self.GetComments(cnxn, id=comment_id)
+ try:
+ return comments[0]
+ except IndexError:
+ raise exceptions.NoSuchCommentException()
+
+ def GetCommentsForIssue(self, cnxn, issue_id):
+ """Return all IssueComment PBs for the specified issue.
+
+ Args:
+ cnxn: connection to SQL database.
+ issue_id: int global ID of the issue.
+
+ Returns:
+ A list of the IssueComment protocol buffers for the description
+ and comments on this issue.
+ """
+ comments = self.GetComments(cnxn, issue_id=[issue_id])
+ for i, comment in enumerate(comments):
+ comment.sequence = i
+
+ return comments
+
+
+ def GetCommentsByID(self, cnxn, comment_ids, sequences, use_cache=True,
+ shard_id=None):
+ """Return all IssueComment PBs by comment ids.
+
+ Args:
+ cnxn: connection to SQL database.
+ comment_ids: a list of comment ids.
+ sequences: sequence of the comments.
+ use_cache: optional boolean to enable the cache.
+ shard_id: optional int shard_id to limit retrieval.
+
+ Returns:
+ A list of the IssueComment protocol buffers for comment_ids.
+ """
+ # Try loading issue comments from a random shard to reduce load on
+ # primary DB.
+ if shard_id is None:
+ shard_id = sql.RandomShardID()
+
+ comment_dict, _missed_comments = self.comment_2lc.GetAll(cnxn, comment_ids,
+ use_cache=use_cache, shard_id=shard_id)
+
+ comments = sorted(list(comment_dict.values()), key=lambda x: x.timestamp)
+
+ for i in range(len(comment_ids)):
+ comments[i].sequence = sequences[i]
+
+ return comments
+
+ # TODO(jrobbins): remove this method because it is too slow when an issue
+ # has a huge number of comments.
+ def GetCommentsForIssues(self, cnxn, issue_ids, content_only=False):
+ """Return all IssueComment PBs for each issue ID in the given list.
+
+ Args:
+ cnxn: connection to SQL database.
+ issue_ids: list of integer global issue IDs.
+ content_only: optional boolean, set true for faster loading of
+ comment content without attachments and amendments.
+
+ Returns:
+ Dict {issue_id: [IssueComment, ...]} with IssueComment protocol
+ buffers for the description and comments on each issue.
+ """
+ comments = self.GetComments(
+ cnxn, issue_id=issue_ids, content_only=content_only)
+
+ comments_dict = collections.defaultdict(list)
+ for comment in comments:
+ comment.sequence = len(comments_dict[comment.issue_id])
+ comments_dict[comment.issue_id].append(comment)
+
+ return comments_dict
+
+ def InsertComment(self, cnxn, comment, commit=True):
+ """Store the given issue comment in SQL.
+
+ Args:
+ cnxn: connection to SQL database.
+ comment: IssueComment PB to insert into the database.
+ commit: set to False to avoid doing the commit for now.
+ """
+ commentcontent_id = self.commentcontent_tbl.InsertRow(
+ cnxn, content=comment.content,
+ inbound_message=comment.inbound_message, commit=False)
+ comment_id = self.comment_tbl.InsertRow(
+ cnxn, issue_id=comment.issue_id, created=comment.timestamp,
+ project_id=comment.project_id,
+ commenter_id=comment.user_id,
+ deleted_by=comment.deleted_by or None,
+ is_spam=comment.is_spam, is_description=comment.is_description,
+ commentcontent_id=commentcontent_id,
+ commit=False)
+ comment.id = comment_id
+ if comment.importer_id:
+ self.commentimporter_tbl.InsertRow(
+ cnxn, comment_id=comment_id, importer_id=comment.importer_id)
+
+ amendment_rows = []
+ for amendment in comment.amendments:
+ field_enum = str(amendment.field).lower()
+ if (amendment.get_assigned_value('newvalue') is not None and
+ not amendment.added_user_ids and not amendment.removed_user_ids):
+ amendment_rows.append((
+ comment.issue_id, comment_id, field_enum,
+ amendment.oldvalue, amendment.newvalue,
+ None, None, amendment.custom_field_name))
+ for added_user_id in amendment.added_user_ids:
+ amendment_rows.append((
+ comment.issue_id, comment_id, field_enum, None, None,
+ added_user_id, None, amendment.custom_field_name))
+ for removed_user_id in amendment.removed_user_ids:
+ amendment_rows.append((
+ comment.issue_id, comment_id, field_enum, None, None,
+ None, removed_user_id, amendment.custom_field_name))
+ # ISSUEUPDATE_COLS[1:] to skip id column.
+ self.issueupdate_tbl.InsertRows(
+ cnxn, ISSUEUPDATE_COLS[1:], amendment_rows, commit=False)
+
+ attachment_rows = []
+ for attach in comment.attachments:
+ attachment_rows.append([
+ comment.issue_id, comment.id, attach.filename, attach.filesize,
+ attach.mimetype, attach.deleted, attach.gcs_object_id])
+ self.attachment_tbl.InsertRows(
+ cnxn, ATTACHMENT_COLS[1:], attachment_rows, commit=False)
+
+ if comment.approval_id:
+ self.issueapproval2comment_tbl.InsertRows(
+ cnxn, ISSUEAPPROVAL2COMMENT_COLS,
+ [(comment.approval_id, comment_id)], commit=False)
+
+ if commit:
+ cnxn.Commit()
+
+ def _UpdateComment(self, cnxn, comment, update_cols=None):
+ """Update the given issue comment in SQL.
+
+ Args:
+ cnxn: connection to SQL database.
+ comment: IssueComment PB to update in the database.
+ update_cols: optional list of just the field names to update.
+ """
+ delta = {
+ 'commenter_id': comment.user_id,
+ 'deleted_by': comment.deleted_by or None,
+ 'is_spam': comment.is_spam,
+ }
+ if update_cols is not None:
+ delta = {key: val for key, val in delta.items()
+ if key in update_cols}
+
+ self.comment_tbl.Update(cnxn, delta, id=comment.id)
+ self.comment_2lc.InvalidateKeys(cnxn, [comment.id])
+
+ def _MakeIssueComment(
+ self, project_id, user_id, content, inbound_message=None,
+ amendments=None, attachments=None, kept_attachments=None, timestamp=None,
+ is_spam=False, is_description=False, approval_id=None, importer_id=None):
+ """Create in IssueComment protocol buffer in RAM.
+
+ Args:
+ project_id: Project with the issue.
+ user_id: the user ID of the user who entered the comment.
+ content: string body of the comment.
+ inbound_message: optional string full text of an email that
+ caused this comment to be added.
+ amendments: list of Amendment PBs describing the
+ metadata changes that the user made along w/ comment.
+ attachments: [(filename, contents, mimetype),...] attachments uploaded at
+ the time the comment was made.
+ kept_attachments: list of Attachment PBs for attachments kept from
+ previous descriptions, if the comment is a description
+ timestamp: time at which the comment was made, defaults to now.
+ is_spam: True if the comment was classified as spam.
+ is_description: True if the comment is a description for the issue.
+ approval_id: id, if any, of the APPROVAL_TYPE FieldDef this comment
+ belongs to.
+ importer_id: optional User ID of script that imported the comment on
+ behalf of a user.
+
+ Returns:
+ The new IssueComment protocol buffer.
+
+ The content may have some markup done during input processing.
+
+ Any attachments are immediately stored.
+ """
+ comment = tracker_pb2.IssueComment()
+ comment.project_id = project_id
+ comment.user_id = user_id
+ comment.content = content or ''
+ comment.is_spam = is_spam
+ comment.is_description = is_description
+ if not timestamp:
+ timestamp = int(time.time())
+ comment.timestamp = int(timestamp)
+ if inbound_message:
+ comment.inbound_message = inbound_message
+ if amendments:
+ logging.info('amendments is %r', amendments)
+ comment.amendments.extend(amendments)
+ if approval_id:
+ comment.approval_id = approval_id
+
+ if attachments:
+ for filename, body, mimetype in attachments:
+ gcs_object_id = gcs_helpers.StoreObjectInGCS(
+ body, mimetype, project_id, filename=filename)
+ attach = tracker_pb2.Attachment()
+ # attachment id is determined later by the SQL DB.
+ attach.filename = filename
+ attach.filesize = len(body)
+ attach.mimetype = mimetype
+ attach.gcs_object_id = gcs_object_id
+ comment.attachments.extend([attach])
+ logging.info("Save attachment with object_id: %s" % gcs_object_id)
+
+ if kept_attachments:
+ for kept_attach in kept_attachments:
+ (filename, filesize, mimetype, deleted,
+ gcs_object_id) = kept_attach[3:]
+ new_attach = tracker_pb2.Attachment(
+ filename=filename, filesize=filesize, mimetype=mimetype,
+ deleted=bool(deleted), gcs_object_id=gcs_object_id)
+ comment.attachments.append(new_attach)
+ logging.info("Copy attachment with object_id: %s" % gcs_object_id)
+
+ if importer_id:
+ comment.importer_id = importer_id
+
+ return comment
+
+ def CreateIssueComment(
+ self, cnxn, issue, user_id, content, inbound_message=None,
+ amendments=None, attachments=None, kept_attachments=None, timestamp=None,
+ is_spam=False, is_description=False, approval_id=None, commit=True,
+ importer_id=None):
+ """Create and store a new comment on the specified issue.
+
+ Args:
+ cnxn: connection to SQL database.
+ issue: the issue on which to add the comment, must be loaded from
+ database with use_cache=False so that assume_stale == False.
+ user_id: the user ID of the user who entered the comment.
+ content: string body of the comment.
+ inbound_message: optional string full text of an email that caused
+ this comment to be added.
+ amendments: list of Amendment PBs describing the
+ metadata changes that the user made along w/ comment.
+ attachments: [(filename, contents, mimetype),...] attachments uploaded at
+ the time the comment was made.
+ kept_attachments: list of attachment ids for attachments kept from
+ previous descriptions, if the comment is an update to the description
+ timestamp: time at which the comment was made, defaults to now.
+ is_spam: True if the comment is classified as spam.
+ is_description: True if the comment is a description for the issue.
+ approval_id: id, if any, of the APPROVAL_TYPE FieldDef this comment
+ belongs to.
+ commit: set to False to not commit to DB yet.
+ importer_id: user ID of an API client that is importing issues.
+
+ Returns:
+ The new IssueComment protocol buffer.
+
+ Note that we assume that the content is safe to echo out
+ again. The content may have some markup done during input
+ processing.
+ """
+ if is_description:
+ kept_attachments = self.GetAttachmentsByID(cnxn, kept_attachments)
+ else:
+ kept_attachments = []
+
+ comment = self._MakeIssueComment(
+ issue.project_id, user_id, content, amendments=amendments,
+ inbound_message=inbound_message, attachments=attachments,
+ timestamp=timestamp, is_spam=is_spam, is_description=is_description,
+ kept_attachments=kept_attachments, approval_id=approval_id,
+ importer_id=importer_id)
+ comment.issue_id = issue.issue_id
+
+ if attachments or kept_attachments:
+ issue.attachment_count = (
+ issue.attachment_count + len(attachments) + len(kept_attachments))
+ self.UpdateIssue(cnxn, issue, update_cols=['attachment_count'])
+
+ self.comment_creations.increment()
+ self.InsertComment(cnxn, comment, commit=commit)
+
+ return comment
+
+ def SoftDeleteComment(
+ self, cnxn, issue, issue_comment, deleted_by_user_id,
+ user_service, delete=True, reindex=False, is_spam=False):
+ """Mark comment as un/deleted, which shows/hides it from average users."""
+ # Update number of attachments
+ attachments = 0
+ if issue_comment.attachments:
+ for attachment in issue_comment.attachments:
+ if not attachment.deleted:
+ attachments += 1
+
+ # Delete only if it's not in deleted state
+ if delete:
+ if not issue_comment.deleted_by:
+ issue_comment.deleted_by = deleted_by_user_id
+ issue.attachment_count = issue.attachment_count - attachments
+
+ # Undelete only if it's in deleted state
+ elif issue_comment.deleted_by:
+ issue_comment.deleted_by = 0
+ issue.attachment_count = issue.attachment_count + attachments
+
+ issue_comment.is_spam = is_spam
+ self._UpdateComment(
+ cnxn, issue_comment, update_cols=['deleted_by', 'is_spam'])
+ self.UpdateIssue(cnxn, issue, update_cols=['attachment_count'])
+
+ # Reindex the issue to take the comment deletion/undeletion into account.
+ if reindex:
+ tracker_fulltext.IndexIssues(
+ cnxn, [issue], user_service, self, self._config_service)
+ else:
+ self.EnqueueIssuesForIndexing(cnxn, [issue.issue_id])
+
+ ### Approvals
+
+ def GetIssueApproval(self, cnxn, issue_id, approval_id, use_cache=True):
+ """Retrieve the specified approval for the specified issue."""
+ issue = self.GetIssue(cnxn, issue_id, use_cache=use_cache)
+ approval = tracker_bizobj.FindApprovalValueByID(
+ approval_id, issue.approval_values)
+ if approval:
+ return issue, approval
+ raise exceptions.NoSuchIssueApprovalException()
+
+  def DeltaUpdateIssueApproval(
+      self, cnxn, modifier_id, config, issue, approval, approval_delta,
+      comment_content=None, is_description=False, attachments=None,
+      commit=True, kept_attachments=None):
+    """Update the issue's approval in the database.
+
+    Applies approval_delta to the approval and issue in RAM and writes the
+    changes to SQL. All changes are recorded as amendments on a new comment
+    attached to the approval.
+
+    Args:
+      cnxn: connection to SQL database.
+      modifier_id: user ID of the user making the change. It is recorded as
+          the approval setter when the status changes, and as the comment
+          author.
+      config: project config passed to the field-value and label helpers.
+      issue: Issue PB that owns the approval; modified in place.
+      approval: approval value PB to update; modified in place.
+      approval_delta: delta object providing status, set_on,
+          approver_ids_add/remove, subfield_vals_add/remove, subfields_clear,
+          and labels_add/remove.
+      comment_content: optional string body of the comment.
+      is_description: True if the comment is a description for the approval.
+      attachments: optional [(filename, contents, mimetype), ...] to attach.
+      commit: set to False to skip the DB commit here and leave it to the
+          caller.
+      kept_attachments: optional list of attachment IDs kept from previous
+          descriptions.
+
+    Returns:
+      The IssueComment PB created to record the change.
+    """
+    amendments = []
+
+    # Update status in RAM and DB and create status amendment.
+    if approval_delta.status:
+      approval.status = approval_delta.status
+      approval.set_on = approval_delta.set_on or int(time.time())
+      approval.setter_id = modifier_id
+      status_amendment = tracker_bizobj.MakeApprovalStatusAmendment(
+          approval_delta.status)
+      amendments.append(status_amendment)
+
+      self._UpdateIssueApprovalStatus(
+          cnxn, issue.issue_id, approval.approval_id, approval.status,
+          approval.setter_id, approval.set_on)
+
+    # Update approver_ids in RAM and DB and create approver amendment.
+    # Only approvers that actually change membership are considered, so the
+    # amendment reflects real additions and removals.
+    approvers_add = [approver for approver in approval_delta.approver_ids_add
+                     if approver not in approval.approver_ids]
+    approvers_remove = [approver for approver in
+                        approval_delta.approver_ids_remove
+                        if approver in approval.approver_ids]
+    if approvers_add or approvers_remove:
+      approver_ids = [approver for approver in
+                      list(approval.approver_ids) + approvers_add
+                      if approver not in approvers_remove]
+      approval.approver_ids = approver_ids
+      approvers_amendment = tracker_bizobj.MakeApprovalApproversAmendment(
+          approvers_add, approvers_remove)
+      amendments.append(approvers_amendment)
+
+      self._UpdateIssueApprovalApprovers(
+          cnxn, issue.issue_id, approval.approval_id, approver_ids)
+
+    # Approval subfields are issue field values; persist them only if changed.
+    fv_amendments = tracker_bizobj.ApplyFieldValueChanges(
+        issue, config, approval_delta.subfield_vals_add,
+        approval_delta.subfield_vals_remove, approval_delta.subfields_clear)
+    amendments.extend(fv_amendments)
+    if fv_amendments:
+      self._UpdateIssuesFields(cnxn, [issue], commit=False)
+
+    label_amendment = tracker_bizobj.ApplyLabelChanges(
+        issue, config, approval_delta.labels_add, approval_delta.labels_remove)
+    if label_amendment:
+      amendments.append(label_amendment)
+      self._UpdateIssuesLabels(cnxn, [issue], commit=False)
+
+    # The comment is written with commit=False so it lands in the same
+    # transaction as the approval changes above.
+    comment_pb = self.CreateIssueComment(
+        cnxn, issue, modifier_id, comment_content, amendments=amendments,
+        approval_id=approval.approval_id, is_description=is_description,
+        attachments=attachments, commit=False,
+        kept_attachments=kept_attachments)
+
+    if commit:
+      cnxn.Commit()
+    self.issue_2lc.InvalidateKeys(cnxn, [issue.issue_id])
+
+    return comment_pb
+
+ def _UpdateIssueApprovalStatus(
+ self, cnxn, issue_id, approval_id, status, setter_id, set_on):
+ """Update the approvalvalue for the given issue_id's issue."""
+ set_on = set_on or int(time.time())
+ delta = {
+ 'status': status.name.lower(),
+ 'setter_id': setter_id,
+ 'set_on': set_on,
+ }
+ self.issue2approvalvalue_tbl.Update(
+ cnxn, delta, approval_id=approval_id, issue_id=issue_id,
+ commit=False)
+
+ def _UpdateIssueApprovalApprovers(
+ self, cnxn, issue_id, approval_id, approver_ids):
+ """Update the list of approvers allowed to approve an issue's approval."""
+ self.issueapproval2approver_tbl.Delete(
+ cnxn, issue_id=issue_id, approval_id=approval_id, commit=False)
+ self.issueapproval2approver_tbl.InsertRows(
+ cnxn, ISSUEAPPROVAL2APPROVER_COLS, [(approval_id, approver_id, issue_id)
+ for approver_id in approver_ids],
+ commit=False)
+
+ ### Attachments
+
+ def GetAttachmentAndContext(self, cnxn, attachment_id):
+ """Load a IssueAttachment from database, and its comment ID and IID.
+
+ Args:
+ cnxn: connection to SQL database.
+ attachment_id: long integer unique ID of desired issue attachment.
+
+ Returns:
+ An Attachment protocol buffer that contains metadata about the attached
+ file, or None if it doesn't exist. Also, the comment ID and issue IID
+ of the comment and issue that contain this attachment.
+
+ Raises:
+ NoSuchAttachmentException: the attachment was not found.
+ """
+ if attachment_id is None:
+ raise exceptions.NoSuchAttachmentException()
+
+ attachment_row = self.attachment_tbl.SelectRow(
+ cnxn, cols=ATTACHMENT_COLS, id=attachment_id)
+ if attachment_row:
+ (attach_id, issue_id, comment_id, filename, filesize, mimetype,
+ deleted, gcs_object_id) = attachment_row
+ if not deleted:
+ attachment = tracker_pb2.Attachment(
+ attachment_id=attach_id, filename=filename, filesize=filesize,
+ mimetype=mimetype, deleted=bool(deleted),
+ gcs_object_id=gcs_object_id)
+ return attachment, comment_id, issue_id
+
+ raise exceptions.NoSuchAttachmentException()
+
+ def GetAttachmentsByID(self, cnxn, attachment_ids):
+ """Return all Attachment PBs by attachment ids.
+
+ Args:
+ cnxn: connection to SQL database.
+ attachment_ids: a list of comment ids.
+
+ Returns:
+ A list of the Attachment protocol buffers for the attachments with
+ these ids.
+ """
+ attachment_rows = self.attachment_tbl.Select(
+ cnxn, cols=ATTACHMENT_COLS, id=attachment_ids)
+
+ return attachment_rows
+
+ def _UpdateAttachment(self, cnxn, comment, attach, update_cols=None):
+ """Update attachment metadata in the DB.
+
+ Args:
+ cnxn: connection to SQL database.
+ comment: IssueComment PB to invalidate in the cache.
+ attach: IssueAttachment PB to update in the DB.
+ update_cols: optional list of just the field names to update.
+ """
+ delta = {
+ 'filename': attach.filename,
+ 'filesize': attach.filesize,
+ 'mimetype': attach.mimetype,
+ 'deleted': bool(attach.deleted),
+ }
+ if update_cols is not None:
+ delta = {key: val for key, val in delta.items()
+ if key in update_cols}
+
+ self.attachment_tbl.Update(cnxn, delta, id=attach.attachment_id)
+ self.comment_2lc.InvalidateKeys(cnxn, [comment.id])
+
+ def SoftDeleteAttachment(
+ self, cnxn, issue, issue_comment, attach_id, user_service, delete=True,
+ index_now=False):
+ """Mark attachment as un/deleted, which shows/hides it from avg users."""
+ attachment = None
+ for attach in issue_comment.attachments:
+ if attach.attachment_id == attach_id:
+ attachment = attach
+
+ if not attachment:
+ logging.warning(
+ 'Tried to (un)delete non-existent attachment #%s in project '
+ '%s issue %s', attach_id, issue.project_id, issue.local_id)
+ return
+
+ if not issue_comment.deleted_by:
+ # Decrement attachment count only if it's not in deleted state
+ if delete:
+ if not attachment.deleted:
+ issue.attachment_count = issue.attachment_count - 1
+
+ # Increment attachment count only if it's in deleted state
+ elif attachment.deleted:
+ issue.attachment_count = issue.attachment_count + 1
+
+ logging.info('attachment.deleted was %s', attachment.deleted)
+
+ attachment.deleted = delete
+
+ logging.info('attachment.deleted is %s', attachment.deleted)
+
+ self._UpdateAttachment(
+ cnxn, issue_comment, attachment, update_cols=['deleted'])
+ self.UpdateIssue(cnxn, issue, update_cols=['attachment_count'])
+
+ if index_now:
+ tracker_fulltext.IndexIssues(
+ cnxn, [issue], user_service, self, self._config_service)
+ else:
+ self.EnqueueIssuesForIndexing(cnxn, [issue.issue_id])
+
+ ### Reindex queue
+
+ def EnqueueIssuesForIndexing(self, cnxn, issue_ids, commit=True):
+ # type: (MonorailConnection, Collection[int], Optional[bool]) -> None
+ """Add the given issue IDs to the ReindexQueue table."""
+ reindex_rows = [(issue_id,) for issue_id in issue_ids]
+ self.reindexqueue_tbl.InsertRows(
+ cnxn, ['issue_id'], reindex_rows, ignore=True, commit=commit)
+
+ def ReindexIssues(self, cnxn, num_to_reindex, user_service):
+ """Reindex some issues specified in the IndexQueue table."""
+ rows = self.reindexqueue_tbl.Select(
+ cnxn, order_by=[('created', [])], limit=num_to_reindex)
+ issue_ids = [row[0] for row in rows]
+
+ if issue_ids:
+ issues = self.GetIssues(cnxn, issue_ids)
+ tracker_fulltext.IndexIssues(
+ cnxn, issues, user_service, self, self._config_service)
+ self.reindexqueue_tbl.Delete(cnxn, issue_id=issue_ids)
+
+ return len(issue_ids)
+
+ ### Search functions
+
+ def RunIssueQuery(
+ self, cnxn, left_joins, where, order_by, shard_id=None, limit=None):
+ """Run a SQL query to find matching issue IDs.
+
+ Args:
+ cnxn: connection to SQL database.
+ left_joins: list of SQL LEFT JOIN clauses.
+ where: list of SQL WHERE clauses.
+ order_by: list of SQL ORDER BY clauses.
+ shard_id: int shard ID to focus the search.
+ limit: int maximum number of results, defaults to
+ settings.search_limit_per_shard.
+
+ Returns:
+ (issue_ids, capped) where issue_ids is a list of the result issue IDs,
+ and capped is True if the number of results reached the limit.
+ """
+ limit = limit or settings.search_limit_per_shard
+ where = where + [('Issue.deleted = %s', [False])]
+ rows = self.issue_tbl.Select(
+ cnxn, shard_id=shard_id, distinct=True, cols=['Issue.id'],
+ left_joins=left_joins, where=where, order_by=order_by,
+ limit=limit)
+ issue_ids = [row[0] for row in rows]
+ capped = len(issue_ids) >= limit
+ return issue_ids, capped
+
+ def GetIIDsByLabelIDs(self, cnxn, label_ids, project_id, shard_id):
+ """Return a list of IIDs for issues with any of the given label IDs."""
+ if not label_ids:
+ return []
+ where = []
+ if shard_id is not None:
+ slice_term = ('shard = %s', [shard_id])
+ where.append(slice_term)
+
+ rows = self.issue_tbl.Select(
+ cnxn, shard_id=shard_id, cols=['id'],
+ left_joins=[('Issue2Label ON Issue.id = Issue2Label.issue_id', [])],
+ label_id=label_ids, project_id=project_id, where=where)
+ return [row[0] for row in rows]
+
+ def GetIIDsByParticipant(self, cnxn, user_ids, project_ids, shard_id):
+ """Return IIDs for issues where any of the given users participate."""
+ iids = []
+ where = []
+ if shard_id is not None:
+ where.append(('shard = %s', [shard_id]))
+ if project_ids:
+ cond_str = 'Issue.project_id IN (%s)' % sql.PlaceHolders(project_ids)
+ where.append((cond_str, project_ids))
+
+ # TODO(jrobbins): Combine these 3 queries into one with ORs. It currently
+ # is not the bottleneck.
+ rows = self.issue_tbl.Select(
+ cnxn, cols=['id'], reporter_id=user_ids,
+ where=where, shard_id=shard_id)
+ for row in rows:
+ iids.append(row[0])
+
+ rows = self.issue_tbl.Select(
+ cnxn, cols=['id'], owner_id=user_ids,
+ where=where, shard_id=shard_id)
+ for row in rows:
+ iids.append(row[0])
+
+ rows = self.issue_tbl.Select(
+ cnxn, cols=['id'], derived_owner_id=user_ids,
+ where=where, shard_id=shard_id)
+ for row in rows:
+ iids.append(row[0])
+
+ rows = self.issue_tbl.Select(
+ cnxn, cols=['id'],
+ left_joins=[('Issue2Cc ON Issue2Cc.issue_id = Issue.id', [])],
+ cc_id=user_ids,
+ where=where + [('cc_id IS NOT NULL', [])],
+ shard_id=shard_id)
+ for row in rows:
+ iids.append(row[0])
+
+ rows = self.issue_tbl.Select(
+ cnxn, cols=['Issue.id'],
+ left_joins=[
+ ('Issue2FieldValue ON Issue.id = Issue2FieldValue.issue_id', []),
+ ('FieldDef ON Issue2FieldValue.field_id = FieldDef.id', [])],
+ user_id=user_ids, grants_perm='View',
+ where=where + [('user_id IS NOT NULL', [])],
+ shard_id=shard_id)
+ for row in rows:
+ iids.append(row[0])
+
+ return iids
+
+ ### Issue Dependency Rankings
+
+ def SortBlockedOn(self, cnxn, issue, blocked_on_iids):
+ """Sort blocked_on dependencies by rank and dst_issue_id.
+
+ Args:
+ cnxn: connection to SQL database.
+ issue: the issue being blocked.
+ blocked_on_iids: the iids of all the issue's blockers
+
+ Returns:
+ a tuple (ids, ranks), where ids is the sorted list of
+ blocked_on_iids and ranks is the list of corresponding ranks
+ """
+ rows = self.issuerelation_tbl.Select(
+ cnxn, cols=ISSUERELATION_COLS, issue_id=issue.issue_id,
+ dst_issue_id=blocked_on_iids, kind='blockedon',
+ order_by=[('rank DESC', []), ('dst_issue_id', [])])
+ ids = [row[1] for row in rows]
+ ids.extend([iid for iid in blocked_on_iids if iid not in ids])
+ ranks = [row[3] for row in rows]
+ ranks.extend([0] * (len(blocked_on_iids) - len(ranks)))
+ return ids, ranks
+
+ def ApplyIssueRerank(
+ self, cnxn, parent_id, relations_to_change, commit=True, invalidate=True):
+ """Updates rankings of blocked on issue relations to new values
+
+ Args:
+ cnxn: connection to SQL database.
+ parent_id: the global ID of the blocked issue to update
+ relations_to_change: This should be a list of
+ [(blocker_id, new_rank),...] of relations that need to be changed
+ commit: set to False to skip the DB commit and do it in the caller.
+ invalidate: set to False to leave cache invalidatation to the caller.
+ """
+ blocker_ids = [blocker for (blocker, rank) in relations_to_change]
+ self.issuerelation_tbl.Delete(
+ cnxn, issue_id=parent_id, dst_issue_id=blocker_ids, commit=False)
+ insert_rows = [(parent_id, blocker, 'blockedon', rank)
+ for (blocker, rank) in relations_to_change]
+ self.issuerelation_tbl.InsertRows(
+ cnxn, cols=ISSUERELATION_COLS, row_values=insert_rows, commit=commit)
+ if invalidate:
+ self.InvalidateIIDs(cnxn, [parent_id])
+
+  # Expunge Users from Issues system.
+  def ExpungeUsersInIssues(self, cnxn, user_ids_by_email, limit=None):
+    """Removes all references to given users from issue DB tables.
+
+    Depending on the table, a reference is either reassigned to
+    framework_constants.DELETED_USER_ID, deleted, or cleared to NULL:
+      - Reassigned: comment commenter and deleted_by, approval setter, issue
+        reporter, IssueUpdate added/removed user IDs, and snapshot
+        owner/reporter.
+      - Deleted: user field values, approvers, Ccs, notify addresses, and
+        snapshot Ccs.
+      - Cleared to NULL: issue owner and derived owner.
+    The inbound_message of the users' comments is also cleared.
+
+    This method will not commit the operations. This method will
+    not make changes to in-memory data.
+
+    Args:
+      cnxn: connection to SQL database.
+      user_ids_by_email: dict of {email: user_id} of all users we want
+        to expunge.
+      limit: Optional, the limit for each operation. It is not applied to
+        the IssueUpdate and Issue2Notify changes.
+
+    Returns:
+      A list of issue_ids that need to be reindexed, without duplicates.
+    """
+    # Every write below passes commit=False; the caller must commit.
+    commit = False
+    user_ids = list(user_ids_by_email.values())
+    user_emails = list(user_ids_by_email.keys())
+    # Track issue_ids for issues that will have different search documents
+    # as a result of removing users.
+    affected_issue_ids = []
+
+    # Reassign commenter_id and delete inbound_messages.
+    shard_id = sql.RandomShardID()
+    comment_content_id_rows = self.comment_tbl.Select(
+        cnxn, cols=['Comment.id', 'Comment.issue_id', 'commentcontent_id'],
+        commenter_id=user_ids, shard_id=shard_id, limit=limit)
+    comment_ids = [row[0] for row in comment_content_id_rows]
+    commentcontent_ids = [row[2] for row in comment_content_id_rows]
+    if commentcontent_ids:
+      self.commentcontent_tbl.Update(
+          cnxn, {'inbound_message': None}, id=commentcontent_ids, commit=commit)
+    if comment_ids:
+      self.comment_tbl.Update(
+          cnxn, {'commenter_id': framework_constants.DELETED_USER_ID},
+          id=comment_ids,
+          commit=commit)
+    affected_issue_ids.extend([row[1] for row in comment_content_id_rows])
+
+    # Reassign deleted_by on comments that these users deleted.
+    self.comment_tbl.Update(
+        cnxn,
+        {'deleted_by': framework_constants.DELETED_USER_ID},
+        deleted_by=user_ids,
+        commit=commit, limit=limit)
+
+    # Remove users in field values.
+    fv_issue_id_rows = self.issue2fieldvalue_tbl.Select(
+        cnxn, cols=['issue_id'], user_id=user_ids, limit=limit)
+    fv_issue_ids = [row[0] for row in fv_issue_id_rows]
+    self.issue2fieldvalue_tbl.Delete(
+        cnxn, user_id=user_ids, limit=limit, commit=commit)
+    affected_issue_ids.extend(fv_issue_ids)
+
+    # Remove users in approval values.
+    self.issueapproval2approver_tbl.Delete(
+        cnxn, approver_id=user_ids, commit=commit, limit=limit)
+    self.issue2approvalvalue_tbl.Update(
+        cnxn,
+        {'setter_id': framework_constants.DELETED_USER_ID},
+        setter_id=user_ids,
+        commit=commit, limit=limit)
+
+    # Remove users in issue Ccs.
+    cc_issue_id_rows = self.issue2cc_tbl.Select(
+        cnxn, cols=['issue_id'], cc_id=user_ids, limit=limit)
+    cc_issue_ids = [row[0] for row in cc_issue_id_rows]
+    self.issue2cc_tbl.Delete(
+        cnxn, cc_id=user_ids, limit=limit, commit=commit)
+    affected_issue_ids.extend(cc_issue_ids)
+
+    # Remove users in issue owners. Owners are cleared to NULL rather than
+    # reassigned to the deleted-user placeholder.
+    owner_issue_id_rows = self.issue_tbl.Select(
+        cnxn, cols=['id'], owner_id=user_ids, limit=limit)
+    owner_issue_ids = [row[0] for row in owner_issue_id_rows]
+    if owner_issue_ids:
+      self.issue_tbl.Update(
+          cnxn, {'owner_id': None}, id=owner_issue_ids, commit=commit)
+    affected_issue_ids.extend(owner_issue_ids)
+    derived_owner_issue_id_rows = self.issue_tbl.Select(
+        cnxn, cols=['id'], derived_owner_id=user_ids, limit=limit)
+    derived_owner_issue_ids = [row[0] for row in derived_owner_issue_id_rows]
+    if derived_owner_issue_ids:
+      self.issue_tbl.Update(
+          cnxn, {'derived_owner_id': None},
+          id=derived_owner_issue_ids,
+          commit=commit)
+    affected_issue_ids.extend(derived_owner_issue_ids)
+
+    # Remove users in issue reporters.
+    reporter_issue_id_rows = self.issue_tbl.Select(
+        cnxn, cols=['id'], reporter_id=user_ids, limit=limit)
+    reporter_issue_ids = [row[0] for row in reporter_issue_id_rows]
+    if reporter_issue_ids:
+      self.issue_tbl.Update(
+          cnxn, {'reporter_id': framework_constants.DELETED_USER_ID},
+          id=reporter_issue_ids,
+          commit=commit)
+    affected_issue_ids.extend(reporter_issue_ids)
+
+    # Note: issueupdate_tbl's and issue2notify's user_id columns do not
+    # reference the User table. So all values need to updated here before
+    # User rows can be deleted safely. No limit will be applied.
+
+    # Remove users in issue updates.
+    self.issueupdate_tbl.Update(
+        cnxn,
+        {'added_user_id': framework_constants.DELETED_USER_ID},
+        added_user_id=user_ids,
+        commit=commit)
+    self.issueupdate_tbl.Update(
+        cnxn,
+        {'removed_user_id': framework_constants.DELETED_USER_ID},
+        removed_user_id=user_ids,
+        commit=commit)
+
+    # Remove users in issue notify (keyed by email address, not user ID).
+    self.issue2notify_tbl.Delete(
+        cnxn, email=user_emails, commit=commit)
+
+    # Remove users in issue snapshots.
+    self.issuesnapshot_tbl.Update(
+        cnxn,
+        {'owner_id': framework_constants.DELETED_USER_ID},
+        owner_id=user_ids,
+        commit=commit, limit=limit)
+    self.issuesnapshot_tbl.Update(
+        cnxn,
+        {'reporter_id': framework_constants.DELETED_USER_ID},
+        reporter_id=user_ids,
+        commit=commit, limit=limit)
+    self.issuesnapshot2cc_tbl.Delete(
+        cnxn, cc_id=user_ids, commit=commit, limit=limit)
+
+    # An issue may be affected in several ways; return each ID once.
+    return list(set(affected_issue_ids))
diff --git a/services/ml_helpers.py b/services/ml_helpers.py
new file mode 100644
index 0000000..c4650b4
--- /dev/null
+++ b/services/ml_helpers.py
@@ -0,0 +1,181 @@
+# Copyright 2018 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""
+Helper functions for spam and component classification. These are mostly for
+feature extraction, so that the serving code and training code both use the same
+set of features.
+"""
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+
+import csv
+import hashlib
+import httplib2
+import logging
+import re
+import sys
+
+from six import text_type
+
+from apiclient.discovery import build
+from apiclient.errors import Error as ApiClientError
+from oauth2client.client import GoogleCredentials
+from oauth2client.client import Error as Oauth2ClientError
+
+
# Column layout of spam training CSV rows. spam_from_file() drops the
# trailing 'email' column before the rows are used.
SPAM_COLUMNS = ['verdict', 'subject', 'content', 'email']
# Column layout of older spam training CSV files without the email column.
LEGACY_CSV_COLUMNS = ['verdict', 'subject', 'content']
# Regex fragments that _SpamHashFeatures joins with '|' to tokenize text.
# Raw strings keep the backslashes literal instead of relying on Python passing
# unknown escape sequences through unchanged (deprecated; warns on 3.12+).
DELIMITERS = [r'\s', r'\,', r'\.', r'\?', '!', r'\:', r'\(', r'\)']

# Must be identical to settings.spam_feature_hashes.
SPAM_FEATURE_HASHES = 500
# Must be identical to settings.component_features.
COMPONENT_FEATURES = 5000
+
+
+def _ComponentFeatures(content, num_features, top_words):
+ """
+ This uses the most common words in the entire dataset as features.
+ The count of common words in the issue comments makes up the features.
+ """
+
+ features = [0] * num_features
+ for blob in content:
+ words = blob.split()
+ for word in words:
+ if word in top_words:
+ features[top_words[word]] += 1
+
+ return features
+
+
def _SpamHashFeatures(content, num_features):
  """Turn text into a normalized vector of hashed token counts.

  Feature hashing is a fast and compact way to turn a string of text into a
  vector of feature values for classification and training. See also:
  https://en.wikipedia.org/wiki/Feature_hashing
  This is a simple implementation that doesn't try to minimize collisions.

  Each blob is split on DELIMITERS. Every resulting token, including empty
  tokens between adjacent delimiters, is SHA-1 hashed into one of
  num_features buckets. The bucket counts are then divided by the total
  number of tokens.

  Args:
    content: iterable of text blobs (str or bytes).
    num_features: number of hash buckets.

  Returns:
    A list of num_features values. It holds float fractions when any token
    was seen; otherwise it is all integer zeros.
  """
  splitter = '|'.join(DELIMITERS)
  buckets = [0] * num_features
  token_count = 0.0
  for text in content:
    for token in re.split(splitter, text):
      # Real unicode strings must become bytes before they can be hashed.
      if isinstance(token, text_type):
        token = token.encode('utf-8')
      digest = hashlib.sha1(token).hexdigest()
      buckets[int(digest, 16) % num_features] += 1.0
      token_count += 1.0

  if token_count > 0:
    buckets = [count / token_count for count in buckets]

  return buckets
+
+
def GenerateFeaturesRaw(content, num_features, top_words=None):
  """Generates a vector of features for a given issue or comment.

  Args:
    content: The content of the issue's description and comments, as a list
      of text blobs.
    num_features: The number of features to generate.
    top_words: optional dict mapping a common word to its feature index. When
      it is non-empty, word-count features are produced instead of hashed
      ones.

  Returns:
    {'word_features': [...]} when top_words is non-empty, otherwise
    {'word_hashes': [...]}.
  """
  if not top_words:
    return {'word_hashes': _SpamHashFeatures(content, num_features)}
  word_counts = _ComponentFeatures(content, num_features, top_words)
  return {'word_features': word_counts}
+
+
def transform_spam_csv_to_features(csv_training_data):
  """Convert spam training rows into (features, labels) for training.

  Args:
    csv_training_data: list of rows, each [verdict, subject, content] or
      [verdict, subject, content, email]. The list may also arrive wrapped
      in one extra outer list, which is detected heuristically.

  Returns:
    A pair (X, y). X holds one GenerateFeaturesRaw() dict per row, hashed
    into SPAM_FEATURE_HASHES buckets. y holds 1 for rows whose verdict is
    'spam' and 0 otherwise.
  """
  # A first element longer than any valid row means the real rows are nested
  # one level deeper. NOTE(review): a double-wrapped list holding 4 or fewer
  # rows is not detected by this check.
  if csv_training_data and len(csv_training_data[0]) > 4:
    csv_training_data = csv_training_data[0]

  features = []
  labels = []
  for row in csv_training_data:
    if len(row) == 4:
      verdict, subject, content, _ = row
    else:
      verdict, subject, content = row
    features.append(
        GenerateFeaturesRaw([str(subject), str(content)], SPAM_FEATURE_HASHES))
    labels.append(int(verdict == 'spam'))
  return features, labels
+
+
def transform_component_csv_to_features(csv_training_data, top_list):
  """Convert component training rows into features and integer labels.

  Args:
    csv_training_data: iterable of (component, content) rows. Only the first
      entry of a comma-separated component field is used as the label.
    top_list: sequence of common words; a word's position is its feature
      index.

  Returns:
    A triple (X, y, index_to_component). X holds one GenerateFeaturesRaw()
    dict per row. y holds the integer label of each row. index_to_component
    maps each label back to its component name. Labels are assigned in order
    of first appearance.
  """
  top_words = {word: position for position, word in enumerate(top_list)}

  component_to_index = {}
  index_to_component = {}
  features = []
  labels = []
  for component_field, content in csv_training_data:
    component = str(component_field).split(",")[0]
    label = component_to_index.get(component)
    if label is None:
      label = len(component_to_index)
      component_to_index[component] = label
      index_to_component[label] = component

    features.append(
        GenerateFeaturesRaw([content], COMPONENT_FEATURES, top_words))
    labels.append(label)

  return features, labels, index_to_component
+
+
def spam_from_file(f):
  """Read spam training rows from a CSV file object.

  Rows as wide as SPAM_COLUMNS have their trailing email field dropped. Rows
  as wide as LEGACY_CSV_COLUMNS are kept unchanged. Rows of any other width
  are skipped.

  Returns:
    A pair (rows, skipped_rows): the list of kept rows and the int number of
    rows that were skipped.
  """
  kept = []
  skipped = 0
  for row in csv.reader(f):
    width = len(row)
    if width == len(SPAM_COLUMNS):
      kept.append(row[:3])  # Throw out email field.
    elif width == len(LEGACY_CSV_COLUMNS):
      kept.append(row)
    else:
      skipped += 1
  return kept, skipped
+
+
def component_from_file(f):
  """Read all rows of a component training CSV file object.

  Training content fields can be very large, so this raises the csv module's
  process-wide field size limit as far as the platform allows before reading.

  Returns:
    A list of rows, each a list of strings.
  """
  # csv.field_size_limit() stores the value in a C long. Where that type is
  # narrower than sys.maxsize (e.g. 64-bit Windows), it raises OverflowError,
  # so step down until the platform accepts the value.
  limit = sys.maxsize
  while True:
    try:
      csv.field_size_limit(limit)
      break
    except OverflowError:
      limit //= 2

  return list(csv.reader(f))
+
+
def setup_ml_engine():
  """Sets up an instance of the ML Engine ('ml', 'v1') API client.

  Uses the application-default Google credentials.

  Returns:
    The API client built by apiclient's build(), or None if the credentials
    or the client could not be set up. In that case the error is logged.
  """
  try:
    credentials = GoogleCredentials.get_application_default()
    return build('ml', 'v1', http=httplib2.Http(), credentials=credentials)
  except (Oauth2ClientError, ApiClientError):
    # Best effort: callers receive None. logging.exception also records the
    # traceback, which the exception type alone does not convey.
    logging.exception(
        "Error setting up ML Engine API: %s", sys.exc_info()[0])
    return None
diff --git a/services/project_svc.py b/services/project_svc.py
new file mode 100644
index 0000000..e92f6a9
--- /dev/null
+++ b/services/project_svc.py
@@ -0,0 +1,799 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""A set of functions that provide persistence for projects.
+
+This module provides functions to get, update, create, and (in some
+cases) delete each type of project business object. It provides
+a logical persistence layer on top of the database.
+
+Business objects are described in project_pb2.py.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import collections
+import logging
+import time
+
+import settings
+from framework import exceptions
+from framework import framework_constants
+from framework import framework_helpers
+from framework import permissions
+from framework import sql
+from services import caches
+from project import project_helpers
+from proto import project_pb2
+
+
# Names of the SQL tables that ProjectService wraps in SQLTableManagers.
PROJECT_TABLE_NAME = 'Project'
USER2PROJECT_TABLE_NAME = 'User2Project'
EXTRAPERM_TABLE_NAME = 'ExtraPerm'
MEMBERNOTES_TABLE_NAME = 'MemberNotes'
USERGROUPPROJECTS_TABLE_NAME = 'Group2Project'
AUTOCOMPLETEEXCLUSION_TABLE_NAME = 'AutocompleteExclusion'

# Column order matters: ProjectTwoLevelCache._DeserializeProjects unpacks
# each selected Project row positionally in exactly this order.
PROJECT_COLS = [
  'project_id', 'project_name', 'summary', 'description', 'state', 'access',
  'read_only_reason', 'state_reason', 'delete_time', 'issue_notify_address',
  'attachment_bytes_used', 'attachment_quota', 'cached_content_timestamp',
  'recent_activity_timestamp', 'moved_to', 'process_inbound_email',
  'only_owners_remove_restrictions', 'only_owners_see_contributors',
  'revision_url_format', 'home_page', 'docs_url', 'source_url', 'logo_gcs_id',
  'logo_file_name', 'issue_notify_always_detailed'
]
USER2PROJECT_COLS = ['project_id', 'user_id', 'role_name']
# Selected and inserted as (project_id, user_id, perm) tuples.
EXTRAPERM_COLS = ['project_id', 'user_id', 'perm']
# Inserted as (project_id, member_id, notes) tuples.
MEMBERNOTES_COLS = ['project_id', 'user_id', 'notes']
AUTOCOMPLETEEXCLUSION_COLS = [
  'project_id', 'user_id', 'ac_exclude', 'no_expand']

# UpdateRecentActivity() only writes a new recent_activity timestamp when the
# stored one is older than this many seconds.
RECENT_ACTIVITY_THRESHOLD = framework_constants.SECS_PER_HOUR
+
+
class ProjectTwoLevelCache(caches.AbstractTwoLevelCache):
  """RAM and memcache cache of Project PBs, keyed by int project_id."""

  def __init__(self, cachemanager, project_service):
    super(ProjectTwoLevelCache, self).__init__(
        cachemanager, 'project', 'project:', project_pb2.Project)
    # Supplies the project, user2project and extraperm table managers that
    # FetchItems() reads from.
    self.project_service = project_service

  def _DeserializeProjects(
      self, project_rows, role_rows, extraperm_rows):
    """Build Project PBs from raw database rows.

    Args:
      project_rows: Project table rows with columns in PROJECT_COLS order.
      role_rows: (project_id, user_id, role_name) rows from User2Project.
      extraperm_rows: (project_id, user_id, perm) rows from ExtraPerm.

    Returns:
      A dict {project_id: Project PB}.
    """
    project_dict = {}
    for row in project_rows:
      (project_id, project_name, summary, description, state_name,
       access_name, read_only_reason, state_reason, delete_time,
       issue_notify_address, attachment_bytes_used, attachment_quota,
       cached_content_ts, recent_activity_ts, moved_to,
       process_inbound_email, owners_remove_restrictions,
       owners_see_contributors, revision_url_format, home_page, docs_url,
       source_url, logo_gcs_id, logo_file_name,
       issue_notify_always_detailed) = row
      # NULL columns become the PB's empty value ('' / 0 / False).
      project_dict[project_id] = project_pb2.Project(
          project_id=project_id,
          project_name=project_name,
          summary=summary,
          description=description,
          state=project_pb2.ProjectState(state_name.upper()),
          state_reason=state_reason or '',
          access=project_pb2.ProjectAccess(access_name.upper()),
          read_only_reason=read_only_reason or '',
          issue_notify_address=issue_notify_address or '',
          attachment_bytes_used=attachment_bytes_used or 0,
          attachment_quota=attachment_quota,
          recent_activity=recent_activity_ts or 0,
          cached_content_timestamp=cached_content_ts or 0,
          delete_time=delete_time or 0,
          moved_to=moved_to or '',
          process_inbound_email=bool(process_inbound_email),
          only_owners_remove_restrictions=bool(owners_remove_restrictions),
          only_owners_see_contributors=bool(owners_see_contributors),
          revision_url_format=revision_url_format or '',
          home_page=home_page or '',
          docs_url=docs_url or '',
          source_url=source_url or '',
          logo_gcs_id=logo_gcs_id or '',
          logo_file_name=logo_file_name or '',
          issue_notify_always_detailed=bool(issue_notify_always_detailed))

    # Role names map onto the repeated member-id field they populate; any
    # other role name is ignored.
    role_fields = {
        'owner': 'owner_ids',
        'committer': 'committer_ids',
        'contributor': 'contributor_ids',
    }
    for project_id, user_id, role_name in role_rows:
      project = project_dict[project_id]
      field_name = role_fields.get(role_name)
      if field_name:
        getattr(project, field_name).append(user_id)

    # Group perms as {project_id: {user_id: [perm, ...]}}.
    grouped_perms = collections.defaultdict(
        lambda: collections.defaultdict(list))
    for project_id, user_id, perm in extraperm_rows:
      grouped_perms[project_id][user_id].append(perm)

    for project_id, perms_by_user in grouped_perms.items():
      project = project_dict[project_id]
      # ExtraPerms entries are kept sorted by member id.
      for user_id in sorted(perms_by_user):
        project.extra_perms.append(project_pb2.Project.ExtraPerms(
            member_id=user_id, perms=perms_by_user[user_id]))

    return project_dict

  def FetchItems(self, cnxn, keys):
    """On RAM and memcache miss, load the requested projects from the DB."""
    service = self.project_service
    return self._DeserializeProjects(
        service.project_tbl.Select(
            cnxn, cols=PROJECT_COLS, project_id=keys),
        service.user2project_tbl.Select(
            cnxn, cols=['project_id', 'user_id', 'role_name'],
            project_id=keys),
        service.extraperm_tbl.Select(
            cnxn, cols=EXTRAPERM_COLS, project_id=keys))
+
+
class UserToProjectIdTwoLevelCache(caches.AbstractTwoLevelCache):
  """Class to manage both RAM and memcache for project_ids.

  Keys for this cache are int user_ids, which might correspond to a group.
  This cache should be used to fetch a set of project_ids that the user_id
  is a member of.
  """

  # Minimum number of seconds between two full reads of User2Project.
  FULL_FETCH_INTERVAL_SECS = 60

  def __init__(self, cachemanager, project_service):
    # type: cachemanager_svc.CacheManager, ProjectService -> None
    super(UserToProjectIdTwoLevelCache, self).__init__(
        cachemanager, 'project_id', 'project_id:', pb_class=None)
    self.project_service = project_service

    # Store the last time the table was fetched for rate limit purposes.
    self.last_fetched = 0

  def FetchItems(self, cnxn, keys):
    # type MonorailConnection, Collection[int] -> Mapping[int, Collection[int]]
    """On RAM and memcache miss, hit the database to get missing user_ids.

    Unlike other caches, this one reads and returns the entire User2Project
    table, but at most once every FULL_FETCH_INTERVAL_SECS seconds. Between
    full reads, only the rows of the requested users are queried. Returning
    empty sets for them instead would report users whose entries were
    invalidated since the last full read (e.g. newly added members) as
    belonging to no project.

    Args:
      cnxn: connection to SQL database.
      keys: int user_ids that missed the RAM and memcache layers.

    Returns:
      A dict {user_id: set of project_ids}. Every requested user_id is
      present; users without User2Project rows map to an empty set.
    """
    now = self._GetCurrentTime()
    result_dict = collections.defaultdict(set)
    user2project_tbl = self.project_service.user2project_tbl

    if (now - self.last_fetched) > self.FULL_FETCH_INTERVAL_SECS:
      # Cache the whole User2Project table.
      project_to_user_rows = user2project_tbl.Select(
          cnxn, cols=['project_id', 'user_id'])
      self.last_fetched = now
    elif keys:
      # Rate limited: only look up the users that were actually requested.
      project_to_user_rows = user2project_tbl.Select(
          cnxn, cols=['project_id', 'user_id'], user_id=list(keys))
    else:
      project_to_user_rows = []

    for project_id, user_id in project_to_user_rows:
      result_dict[user_id].add(project_id)

    # Assume any requested user missing from result is not in any project.
    result_dict.update(
        (user_id, set()) for user_id in keys if user_id not in result_dict)

    return result_dict

  def _GetCurrentTime(self):
    """ Returns the current time. We made a separate method for this to make it
    easier to unit test. This was a better solution than @mock.patch because
    the test had several unrelated time.time() calls. Modifying those calls
    would be more onerous, having to fix calls for this test.
    """
    return time.time()
+
+
class ProjectService(object):
  """The persistence layer for project data."""

  def __init__(self, cache_manager):
    """Initialize this module so that it is ready to use.

    Args:
      cache_manager: local cache with distributed invalidation.
    """
    self.project_tbl = sql.SQLTableManager(PROJECT_TABLE_NAME)
    self.user2project_tbl = sql.SQLTableManager(USER2PROJECT_TABLE_NAME)
    self.extraperm_tbl = sql.SQLTableManager(EXTRAPERM_TABLE_NAME)
    self.membernotes_tbl = sql.SQLTableManager(MEMBERNOTES_TABLE_NAME)
    self.usergroupprojects_tbl = sql.SQLTableManager(
        USERGROUPPROJECTS_TABLE_NAME)
    self.acexclusion_tbl = sql.SQLTableManager(
        AUTOCOMPLETEEXCLUSION_TABLE_NAME)

    # Like a dictionary {project_id: project}
    self.project_2lc = ProjectTwoLevelCache(cache_manager, self)
    # A dictionary of user_id to a set of project ids.
    # Mapping[int, Collection[int]]
    self.user_to_project_2lc = UserToProjectIdTwoLevelCache(cache_manager, self)

    # The project name to ID cache can never be invalidated by individual
    # project changes because it is keyed by strings instead of ints. In
    # the case of rare operations like deleting a project (or a future
    # project renaming feature), we just InvalidateAll().
    self.project_names_to_ids = caches.RamCache(cache_manager, 'project')

  ### Creating projects

  def CreateProject(
      self, cnxn, project_name, owner_ids, committer_ids, contributor_ids,
      summary, description, state=project_pb2.ProjectState.LIVE,
      access=None, read_only_reason=None, home_page=None, docs_url=None,
      source_url=None, logo_gcs_id=None, logo_file_name=None):
    """Create and store a Project with the given attributes.

    Args:
      cnxn: connection to SQL database.
      project_name: a valid project name, all lower case.
      owner_ids: a list of user IDs for the project owners.
      committer_ids: a list of user IDs for the project members.
      contributor_ids: a list of user IDs for the project contributors.
      summary: one-line explanation of the project.
      description: one-page explanation of the project.
      state: a project state enum defined in project_pb2.
      access: optional project access enum defined in project.proto.
      read_only_reason: if given, provides a status message and marks
        the project as read-only.
      home_page: home page of the project
      docs_url: url to redirect to for wiki/documentation links
      source_url: url to redirect to for source browser links
      logo_gcs_id: google storage object id of the project's logo
      logo_file_name: uploaded file name of the project's logo

    Returns:
      The int project_id of the new project.

    Raises:
      ProjectAlreadyExists: if a project with that name already exists.
    """
    assert project_helpers.IsValidProjectName(project_name)
    if self.LookupProjectIDs(cnxn, [project_name]):
      raise exceptions.ProjectAlreadyExists()

    project = project_pb2.MakeProject(
        project_name, state=state, access=access,
        description=description, summary=summary,
        owner_ids=owner_ids, committer_ids=committer_ids,
        contributor_ids=contributor_ids, read_only_reason=read_only_reason,
        home_page=home_page, docs_url=docs_url, source_url=source_url,
        logo_gcs_id=logo_gcs_id, logo_file_name=logo_file_name)

    project.project_id = self._InsertProject(cnxn, project)
    return project.project_id

  def _InsertProject(self, cnxn, project):
    """Insert the given project and its member roles into the database.

    Returns:
      The int project_id assigned by the database.
    """
    # Note: project_id is not specified because it is auto_increment.
    project_id = self.project_tbl.InsertRow(
        cnxn, project_name=project.project_name,
        summary=project.summary, description=project.description,
        state=str(project.state), access=str(project.access),
        home_page=project.home_page, docs_url=project.docs_url,
        source_url=project.source_url,
        logo_gcs_id=project.logo_gcs_id, logo_file_name=project.logo_file_name)
    logging.info('stored project was given project_id %d', project_id)

    self.user2project_tbl.InsertRows(
        cnxn, ['project_id', 'user_id', 'role_name'],
        [(project_id, user_id, 'owner')
         for user_id in project.owner_ids] +
        [(project_id, user_id, 'committer')
         for user_id in project.committer_ids] +
        [(project_id, user_id, 'contributor')
         for user_id in project.contributor_ids])

    return project_id

  ### Lookup project names and IDs

  def LookupProjectIDs(self, cnxn, project_names):
    """Return a dict {project_name: project_id} for the given names.

    Names that match no project are absent from the result.
    """
    id_dict, missed_names = self.project_names_to_ids.GetAll(project_names)
    if missed_names:
      rows = self.project_tbl.Select(
          cnxn, cols=['project_name', 'project_id'], project_name=missed_names)
      retrieved_dict = dict(rows)
      self.project_names_to_ids.CacheAll(retrieved_dict)
      id_dict.update(retrieved_dict)

    return id_dict

  def LookupProjectNames(self, cnxn, project_ids):
    """Return a dict {project_id: project_name} for the given IDs."""
    projects_dict = self.GetProjects(cnxn, project_ids)
    return {p.project_id: p.project_name
            for p in projects_dict.values()}

  ### Retrieving projects

  def GetAllProjects(self, cnxn, use_cache=True):
    """Return a dict mapping IDs to all live project PBs."""
    project_rows = self.project_tbl.Select(
        cnxn, cols=['project_id'], state=project_pb2.ProjectState.LIVE)
    project_ids = [row[0] for row in project_rows]
    projects_dict = self.GetProjects(cnxn, project_ids, use_cache=use_cache)

    return projects_dict

  def GetVisibleLiveProjects(
      self, cnxn, logged_in_user, effective_ids, domain=None, use_cache=True):
    """Return all user visible live project ids.

    Args:
      cnxn: connection to SQL database.
      logged_in_user: protocol buffer of the logged in user. Can be None.
      effective_ids: set of user IDs for this user. Can be None.
      domain: optional string with HTTP request hostname.
      use_cache: pass False to force database query to find Project protocol
          buffers.

    Returns:
      A list of project ids of user visible live projects sorted by the names
      of the projects. Projects for which framework_helpers.GetNeededDomain()
      reports a needed domain given `domain` are excluded.
    """
    project_rows = self.project_tbl.Select(
        cnxn, cols=['project_id'], state=project_pb2.ProjectState.LIVE)
    project_ids = [row[0] for row in project_rows]
    projects_dict = self.GetProjects(cnxn, project_ids, use_cache=use_cache)
    projects_on_host = {
        project_id: project for project_id, project in projects_dict.items()
        if not framework_helpers.GetNeededDomain(project.project_name, domain)}
    visible_projects = []
    for project in projects_on_host.values():
      if permissions.UserCanViewProject(logged_in_user, effective_ids, project):
        visible_projects.append(project)
    visible_projects.sort(key=lambda p: p.project_name)

    return [project.project_id for project in visible_projects]

  def GetProjects(self, cnxn, project_ids, use_cache=True):
    """Load all the Project PBs for the given projects.

    Args:
      cnxn: connection to SQL database.
      project_ids: list of int project IDs
      use_cache: pass False to force database query.

    Returns:
      A dict mapping IDs to the corresponding Project protocol buffers.

    Raises:
      NoSuchProjectException: if any of the projects was not found.
    """
    project_dict, missed_ids = self.project_2lc.GetAll(
        cnxn, project_ids, use_cache=use_cache)

    # Also, update the project name cache.
    self.project_names_to_ids.CacheAll(
        {p.project_name: p.project_id for p in project_dict.values()})

    if missed_ids:
      raise exceptions.NoSuchProjectException()

    return project_dict

  def GetProject(self, cnxn, project_id, use_cache=True):
    """Load the specified project from the database.

    Raises:
      NoSuchProjectException: if the project was not found.
    """
    project_id_dict = self.GetProjects(cnxn, [project_id], use_cache=use_cache)
    return project_id_dict[project_id]

  def GetProjectsByName(self, cnxn, project_names, use_cache=True):
    """Load all the Project PBs for the given projects.

    Args:
      cnxn: connection to SQL database.
      project_names: list of project names.
      use_cache: specify False to force database query.

    Returns:
      A dict mapping names to the corresponding Project protocol buffers.
      Names that match no project are absent.
    """
    project_ids = list(self.LookupProjectIDs(cnxn, project_names).values())
    projects = self.GetProjects(cnxn, project_ids, use_cache=use_cache)
    return {p.project_name: p for p in projects.values()}

  def GetProjectByName(self, cnxn, project_name, use_cache=True):
    """Load the specified project from the database, None if does not exist."""
    project_dict = self.GetProjectsByName(
        cnxn, [project_name], use_cache=use_cache)
    return project_dict.get(project_name)

  ### Deleting projects

  def ExpungeProject(self, cnxn, project_id):
    """Wipes a project and its member, perm, notes and exclusion rows."""
    logging.info('expunging project %r', project_id)
    self.user2project_tbl.Delete(cnxn, project_id=project_id)
    self.usergroupprojects_tbl.Delete(cnxn, project_id=project_id)
    self.extraperm_tbl.Delete(cnxn, project_id=project_id)
    self.membernotes_tbl.Delete(cnxn, project_id=project_id)
    self.acexclusion_tbl.Delete(cnxn, project_id=project_id)
    self.project_tbl.Delete(cnxn, project_id=project_id)

  ### Updating projects

  def UpdateProject(
      self,
      cnxn,
      project_id,
      summary=None,
      description=None,
      state=None,
      state_reason=None,
      access=None,
      issue_notify_address=None,
      attachment_bytes_used=None,
      attachment_quota=None,
      moved_to=None,
      process_inbound_email=None,
      only_owners_remove_restrictions=None,
      read_only_reason=None,
      cached_content_timestamp=None,
      only_owners_see_contributors=None,
      delete_time=None,
      recent_activity=None,
      revision_url_format=None,
      home_page=None,
      docs_url=None,
      source_url=None,
      logo_gcs_id=None,
      logo_file_name=None,
      issue_notify_always_detailed=None,
      commit=True):
    """Update the DB with the given project information.

    Only arguments that are not None are written. The project's cache entry
    is invalidated.

    Raises:
      NoSuchProjectException: if the project does not exist.
    """
    exists = self.project_tbl.SelectValue(
        cnxn, 'project_name', project_id=project_id)
    if not exists:
      raise exceptions.NoSuchProjectException()

    delta = {}
    if summary is not None:
      delta['summary'] = summary
    if description is not None:
      delta['description'] = description
    if state is not None:
      delta['state'] = str(state).lower()
      # A state change always replaces the reason; state_reason=None clears it.
      delta['state_reason'] = state_reason
    if access is not None:
      delta['access'] = str(access).lower()
    if read_only_reason is not None:
      delta['read_only_reason'] = read_only_reason
    if issue_notify_address is not None:
      delta['issue_notify_address'] = issue_notify_address
    if attachment_bytes_used is not None:
      delta['attachment_bytes_used'] = attachment_bytes_used
    if attachment_quota is not None:
      delta['attachment_quota'] = attachment_quota
    if moved_to is not None:
      delta['moved_to'] = moved_to
    if process_inbound_email is not None:
      delta['process_inbound_email'] = process_inbound_email
    if only_owners_remove_restrictions is not None:
      delta['only_owners_remove_restrictions'] = (
          only_owners_remove_restrictions)
    if only_owners_see_contributors is not None:
      delta['only_owners_see_contributors'] = only_owners_see_contributors
    if delete_time is not None:
      delta['delete_time'] = delete_time
    if recent_activity is not None:
      delta['recent_activity_timestamp'] = recent_activity
    if revision_url_format is not None:
      delta['revision_url_format'] = revision_url_format
    if home_page is not None:
      delta['home_page'] = home_page
    if docs_url is not None:
      delta['docs_url'] = docs_url
    if source_url is not None:
      delta['source_url'] = source_url
    if logo_gcs_id is not None:
      delta['logo_gcs_id'] = logo_gcs_id
    if logo_file_name is not None:
      delta['logo_file_name'] = logo_file_name
    if issue_notify_always_detailed is not None:
      delta['issue_notify_always_detailed'] = issue_notify_always_detailed
    if cached_content_timestamp is not None:
      delta['cached_content_timestamp'] = cached_content_timestamp
    self.project_tbl.Update(cnxn, delta, project_id=project_id, commit=False)
    self.project_2lc.InvalidateKeys(cnxn, [project_id])
    if commit:
      cnxn.Commit()

  def UpdateCachedContentTimestamp(self, cnxn, project_id, now=None):
    """Set the project's cached_content_timestamp without committing.

    Args:
      cnxn: connection to SQL database.
      project_id: int ID of the project.
      now: optional int timestamp; defaults to int(time.time()).

    Returns:
      The int timestamp that was written.
    """
    now = now or int(time.time())
    self.project_tbl.Update(
        cnxn, {'cached_content_timestamp': now},
        project_id=project_id, commit=False)
    return now

  def UpdateProjectRoles(
      self, cnxn, project_id, owner_ids, committer_ids, contributor_ids,
      now=None):
    """Store the project's roles in the DB and set cached_content_timestamp.

    All owner, committer and contributor rows of the project are replaced
    and the change is committed. The user-to-project cache is invalidated for
    both the previous and the new members, so users who lost their role do
    not keep a stale membership in the cache.

    Raises:
      NoSuchProjectException: if the project does not exist.
    """
    exists = self.project_tbl.SelectValue(
        cnxn, 'project_name', project_id=project_id)
    if not exists:
      raise exceptions.NoSuchProjectException()

    # Remember the current members before their rows are replaced.
    old_member_rows = self.user2project_tbl.Select(
        cnxn, cols=['user_id'], project_id=project_id)
    old_member_ids = {row[0] for row in old_member_rows}

    self.UpdateCachedContentTimestamp(cnxn, project_id, now=now)

    self.user2project_tbl.Delete(
        cnxn, project_id=project_id, role_name='owner', commit=False)
    self.user2project_tbl.Delete(
        cnxn, project_id=project_id, role_name='committer', commit=False)
    self.user2project_tbl.Delete(
        cnxn, project_id=project_id, role_name='contributor', commit=False)

    self.user2project_tbl.InsertRows(
        cnxn, ['project_id', 'user_id', 'role_name'],
        [(project_id, user_id, 'owner') for user_id in owner_ids],
        commit=False)
    self.user2project_tbl.InsertRows(
        cnxn, ['project_id', 'user_id', 'role_name'],
        [(project_id, user_id, 'committer')
         for user_id in committer_ids], commit=False)

    self.user2project_tbl.InsertRows(
        cnxn, ['project_id', 'user_id', 'role_name'],
        [(project_id, user_id, 'contributor')
         for user_id in contributor_ids], commit=False)

    cnxn.Commit()
    self.project_2lc.InvalidateKeys(cnxn, [project_id])
    updated_user_ids = (
        old_member_ids | set(owner_ids) | set(committer_ids) |
        set(contributor_ids))
    self.user_to_project_2lc.InvalidateKeys(cnxn, sorted(updated_user_ids))

  def MarkProjectDeletable(self, cnxn, project_id, config_service):
    """Update the project's state to make it DELETABLE and free up the name.

    Args:
      cnxn: connection to SQL database.
      project_id: int ID of the project that will be deleted soon.
      config_service: issue tracker configuration persistence service, needed
        to invalidate cached issue tracker results.
    """
    generated_name = 'DELETABLE_%d' % project_id
    delta = {'project_name': generated_name, 'state': 'deletable'}
    self.project_tbl.Update(cnxn, delta, project_id=project_id)

    self.project_2lc.InvalidateKeys(cnxn, [project_id])
    # We cannot invalidate a specific part of the name->proj cache by name,
    # So, tell every job to just drop the whole cache. It should refill
    # efficiently and incrementally from memcache.
    # NOTE(review): this relies on the invalidation reaching
    # project_names_to_ids, which shares the 'project' kind with
    # project_2lc; confirm in the caches module.
    self.project_2lc.InvalidateAllRamEntries(cnxn)
    self.user_to_project_2lc.InvalidateAllRamEntries(cnxn)
    config_service.InvalidateMemcacheForEntireProject(project_id)

  def UpdateRecentActivity(self, cnxn, project_id, now=None):
    """Set the project's recent_activity to the current time.

    The write is skipped unless the stored value is more than
    RECENT_ACTIVITY_THRESHOLD seconds older than now.
    """
    now = now or int(time.time())
    project = self.GetProject(cnxn, project_id)
    if now > project.recent_activity + RECENT_ACTIVITY_THRESHOLD:
      self.UpdateProject(cnxn, project_id, recent_activity=now)

  ### Roles, memberships, and extra perms

  def GetUserRolesInAllProjects(self, cnxn, effective_ids):
    """Return three sets of project IDs where the user has a role.

    Returns:
      A triple (owned, membered, contributed) of sets of project IDs for the
      'owner', 'committer' and 'contributor' roles respectively.
    """
    owned_project_ids = set()
    membered_project_ids = set()
    contrib_project_ids = set()

    rows = []
    if effective_ids:
      rows = self.user2project_tbl.Select(
          cnxn, cols=['project_id', 'role_name'], user_id=effective_ids)

    for project_id, role_name in rows:
      if role_name == 'owner':
        owned_project_ids.add(project_id)
      elif role_name == 'committer':
        membered_project_ids.add(project_id)
      elif role_name == 'contributor':
        contrib_project_ids.add(project_id)
      else:
        logging.warning('Unexpected role name %r', role_name)

    return owned_project_ids, membered_project_ids, contrib_project_ids

  def GetProjectMemberships(self, cnxn, effective_ids, use_cache=True):
    # type: MonorailConnection, Collection[int], Optional[bool] ->
    #   Mapping[int, Collection[int]]
    """Return a dict {user_id: set of project IDs the user is a member of}."""
    project_id_dict, missed_ids = self.user_to_project_2lc.GetAll(
        cnxn, effective_ids, use_cache=use_cache)

    # Users that were missed are assumed to not have any projects.
    assert not missed_ids

    return project_id_dict

  def UpdateExtraPerms(
      self, cnxn, project_id, member_id, extra_perms, now=None):
    """Load the project, update the member's extra perms, and store.

    Args:
      cnxn: connection to SQL database.
      project_id: int ID of the current project.
      member_id: int user id of the user that was edited.
      extra_perms: list of strings for perms that the member
          should have over-and-above what their role gives them.
      now: fake int(time.time()) value passed in during unit testing.
    """
    # This will be a newly constructed object, not from the cache and not
    # shared with any other thread.
    project = self.GetProject(cnxn, project_id, use_cache=False)

    idx, member_extra_perms = permissions.FindExtraPerms(project, member_id)
    if not member_extra_perms and not extra_perms:
      return
    if member_extra_perms and list(member_extra_perms.perms) == extra_perms:
      return
    # Either project is None or member_id is not a member of the project.
    if idx is None:
      return

    if member_extra_perms:
      member_extra_perms.perms = extra_perms
    else:
      member_extra_perms = project_pb2.Project.ExtraPerms(
          member_id=member_id, perms=extra_perms)
      # Keep the list of extra_perms sorted by member id.
      project.extra_perms.insert(idx, member_extra_perms)

    self.extraperm_tbl.Delete(
        cnxn, project_id=project_id, user_id=member_id, commit=False)
    self.extraperm_tbl.InsertRows(
        cnxn, EXTRAPERM_COLS,
        [(project_id, member_id, perm) for perm in extra_perms],
        commit=False)
    project.cached_content_timestamp = self.UpdateCachedContentTimestamp(
        cnxn, project_id, now=now)
    cnxn.Commit()

    self.project_2lc.InvalidateKeys(cnxn, [project_id])

  ### Project Commitments

  def GetProjectCommitments(self, cnxn, project_id):
    """Get the project commitments (notes) from the DB.

    Args:
      cnxn: connection to SQL database.
      project_id: int project ID.

    Returns:
      The specified project's ProjectCommitments instance, or an empty one
      if the project doesn't exist or has not documented member commitments.
    """
    # Get the notes. Don't get the project_id column
    # since we already know that value.
    notes_rows = self.membernotes_tbl.Select(
        cnxn, cols=['user_id', 'notes'], project_id=project_id)
    notes_dict = dict(notes_rows)

    project_commitments = project_pb2.ProjectCommitments()
    project_commitments.project_id = project_id
    for user_id, notes in notes_dict.items():
      commitment = project_pb2.ProjectCommitments.MemberCommitment(
          member_id=user_id,
          notes=notes)
      project_commitments.commitments.append(commitment)

    return project_commitments

  def _StoreProjectCommitments(self, cnxn, project_commitments):
    """Store an updated set of project commitments in the DB.

    Args:
      cnxn: connection to SQL database.
      project_commitments: ProjectCommitments PB
    """
    project_id = project_commitments.project_id
    notes_rows = []
    for commitment in project_commitments.commitments:
      notes_rows.append(
          (project_id, commitment.member_id, commitment.notes))

    # TODO(jrobbins): this should be in a transaction.
    self.membernotes_tbl.Delete(cnxn, project_id=project_id)
    self.membernotes_tbl.InsertRows(
        cnxn, MEMBERNOTES_COLS, notes_rows, ignore=True)

  def UpdateCommitments(self, cnxn, project_id, member_id, notes):
    """Update the member's commitments in the specified project.

    Nothing is written when the member's notes are unchanged.

    Args:
      cnxn: connection to SQL database.
      project_id: int ID of the current project.
      member_id: int user ID of the user that was edited.
      notes: further notes on the member's expected involvement
        in the project.
    """
    project_commitments = self.GetProjectCommitments(cnxn, project_id)

    for c in project_commitments.commitments:
      if c.member_id == member_id:
        commitment = c
        break
    else:
      commitment = project_pb2.ProjectCommitments.MemberCommitment(
          member_id=member_id)
      project_commitments.commitments.append(commitment)

    if commitment.notes != notes:
      commitment.notes = notes
      self._StoreProjectCommitments(cnxn, project_commitments)

  def GetProjectAutocompleteExclusion(self, cnxn, project_id):
    """Get user ids who are excluded from autocomplete list.

    Args:
      cnxn: connection to SQL database.
      project_id: int ID of the current project.

    Returns:
      A pair containing: a list of user IDs who are excluded from the
      autocomplete list for given project, and a list of group IDs to
      not expand.
    """
    ac_exclusion_rows = self.acexclusion_tbl.Select(
        cnxn, cols=['user_id'], project_id=project_id, ac_exclude=True)
    ac_exclusion_ids = [row[0] for row in ac_exclusion_rows]
    no_expand_rows = self.acexclusion_tbl.Select(
        cnxn, cols=['user_id'], project_id=project_id, no_expand=True)
    no_expand_ids = [row[0] for row in no_expand_rows]
    return ac_exclusion_ids, no_expand_ids

  def UpdateProjectAutocompleteExclusion(
      self, cnxn, project_id, member_id, ac_exclude, no_expand):
    """Update autocomplete exclusion for given user.

    Args:
      cnxn: connection to SQL database.
      project_id: int ID of the current project.
      member_id: int user ID of the user that was edited.
      ac_exclude: Whether this user should be excluded.
      no_expand: Whether this group should not be expanded.
    """
    if ac_exclude or no_expand:
      self.acexclusion_tbl.InsertRows(
          cnxn, AUTOCOMPLETEEXCLUSION_COLS,
          [(project_id, member_id, ac_exclude, no_expand)],
          replace=True)
    else:
      # Neither flag is set, so the row is no longer needed.
      self.acexclusion_tbl.Delete(
          cnxn, project_id=project_id, user_id=member_id)

    self.UpdateCachedContentTimestamp(cnxn, project_id)
    cnxn.Commit()

    self.project_2lc.InvalidateKeys(cnxn, [project_id])

  def ExpungeUsersInProjects(self, cnxn, user_ids, limit=None):
    """Wipes the given users from the projects system.

    This method will not commit the operation. This method will
    not make changes to in-memory data.
    """
    self.extraperm_tbl.Delete(cnxn, user_id=user_ids, limit=limit, commit=False)
    self.acexclusion_tbl.Delete(
        cnxn, user_id=user_ids, limit=limit, commit=False)
    self.membernotes_tbl.Delete(
        cnxn, user_id=user_ids, limit=limit, commit=False)
    self.user2project_tbl.Delete(
        cnxn, user_id=user_ids, limit=limit, commit=False)
diff --git a/services/secrets_svc.py b/services/secrets_svc.py
new file mode 100644
index 0000000..7b861ce
--- /dev/null
+++ b/services/secrets_svc.py
@@ -0,0 +1,87 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""A set of functions that provide persistence for secret keys.
+
+These keys are used in generating XSRF and pagination tokens, and in
+validating that inbound emails are replies to notifications that we
+sent.
+
+Unlike other data stored in Monorail, this is kept in the GAE
+datastore rather than SQL because (1) it never needs to be used in
+combination with other SQL data, and (2) we may want to replicate
+issue content for various off-line reporting functionality, but we
+will never want to do that with these keys. A copy is also kept in
+memcache for faster access.
+
+When no secrets are found, a new Secrets entity is created and initialized
+with randomly generated values for the XSRF, email, and pagination keys.
+
+If these secret values ever need to change:
+(1) Make the change on the Google Cloud Console in the Cloud Datastore tab.
+(2) Flush memcache.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import logging
+
+from google.appengine.api import memcache
+from google.appengine.ext import ndb
+
+import settings
+from framework import framework_helpers
+
+
# Shared identifier: the datastore id of the single Secrets entity and the
# memcache key under which a copy of it is kept.
GLOBAL_KEY = 'secrets_singleton_key'
+
+
class Secrets(ndb.Model):
  """Datastore entity holding the app-wide secret key strings.

  Instances are built by MakeSecrets() and fetched through GetSecrets().
  """
  # Secret used when generating XSRF tokens.
  xsrf_key = ndb.StringProperty(required=True)
  # Secret used when generating email tokens.
  email_key = ndb.StringProperty(required=True)
  # Secret used when generating pagination tokens.
  pagination_key = ndb.StringProperty(required=True)
+
+
def MakeSecrets():
  """Build (but do not store) a Secrets entity filled with fresh random keys.

  The entity gets the id GLOBAL_KEY. Keys are generated in the order
  xsrf, email, pagination.
  """
  return Secrets(
      id=GLOBAL_KEY,
      xsrf_key=framework_helpers.MakeRandomKey(),
      email_key=framework_helpers.MakeRandomKey(),
      pagination_key=framework_helpers.MakeRandomKey())
+
+
def GetSecrets():
  """Get secret keys from memcache or datastore. Or, make new ones.

  Lookup order is memcache first, then the datastore. When neither holds
  the singleton, a new entity with random keys is created through
  Secrets.get_or_insert(). That call is transactional, so concurrent
  requests that race to create the singleton all end up with the same
  stored keys. A plain put() would let a later request silently replace
  keys that an earlier request may already have used to mint tokens.

  Returns:
    The Secrets entity. It is also written back to memcache.
  """
  secrets = memcache.get(GLOBAL_KEY)
  if secrets:
    return secrets

  secrets = Secrets.get_by_id(GLOBAL_KEY)
  if not secrets:
    # Only used for its randomly generated values; get_or_insert returns the
    # already-stored entity if another request created one first.
    fresh = MakeSecrets()
    secrets = Secrets.get_or_insert(
        GLOBAL_KEY,
        xsrf_key=fresh.xsrf_key,
        email_key=fresh.email_key,
        pagination_key=fresh.pagination_key)

  memcache.set(GLOBAL_KEY, secrets)
  return secrets
+
+
def GetXSRFKey():
  """Look up the secret string that XSRF tokens are generated with."""
  secrets = GetSecrets()
  return secrets.xsrf_key
+
+
def GetEmailKey():
  """Look up the secret string that email tokens are generated with."""
  secrets = GetSecrets()
  return secrets.email_key
+
+
def GetPaginationKey():
  """Look up the secret string that pagination tokens are generated with."""
  secrets = GetSecrets()
  return secrets.pagination_key
+
diff --git a/services/service_manager.py b/services/service_manager.py
new file mode 100644
index 0000000..1cb886a
--- /dev/null
+++ b/services/service_manager.py
@@ -0,0 +1,84 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Service manager to initialize all services."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+from features import autolink
+from services import cachemanager_svc
+from services import chart_svc
+from services import config_svc
+from services import features_svc
+from services import issue_svc
+from services import project_svc
+from services import spam_svc
+from services import star_svc
+from services import template_svc
+from services import user_svc
+from services import usergroup_svc
+
+
# Module-wide Services container; stays None until set_up_services() builds it.
svcs = None
+
+
class Services(object):
  """Plain holder that bundles the shared service objects together.

  Every constructor argument is stored on an attribute of the same name,
  except autolink_obj, which is exposed as `autolink`. Unspecified services
  are None.
  """

  def __init__(
      self, project=None, user=None, issue=None, config=None,
      usergroup=None, cache_manager=None, autolink_obj=None,
      user_star=None, project_star=None, issue_star=None, features=None,
      spam=None, hotlist_star=None, chart=None, template=None):
    # Persistence layers for core entities.
    self.project, self.user, self.usergroup = project, user, usergroup
    self.issue, self.config = issue, config
    self.features, self.template = features, template

    # Star persistence, one service per starrable kind.
    self.user_star = user_star
    self.project_star = project_star
    self.hotlist_star = hotlist_star
    self.issue_star = issue_star

    # Everything else.
    self.cache_manager = cache_manager
    self.autolink = autolink_obj
    self.spam, self.chart = spam, chart
+
+
def set_up_services():
  """Build the module-wide Services container once and return it.

  The first call constructs every service and stores the container in the
  global `svcs`; later calls return that same object.
  """
  global svcs
  if svcs is not None:
    return svcs

  # Each service is built after the ones it takes as constructor arguments:
  # the cache manager first, then services that need it (and config), then
  # chart and issue, which need config/project, and finally the rest.
  cache_mgr = cachemanager_svc.CacheManager()
  config_service = config_svc.ConfigService(cache_mgr)
  features_service = features_svc.FeaturesService(cache_mgr, config_service)
  hotlist_star_service = star_svc.HotlistStarService(cache_mgr)
  issue_star_service = star_svc.IssueStarService(cache_mgr)
  project_service = project_svc.ProjectService(cache_mgr)
  project_star_service = star_svc.ProjectStarService(cache_mgr)
  user_service = user_svc.UserService(cache_mgr)
  user_star_service = star_svc.UserStarService(cache_mgr)
  usergroup_service = usergroup_svc.UserGroupService(cache_mgr)
  chart_service = chart_svc.ChartService(config_service)
  issue_service = issue_svc.IssueService(
      project_service, config_service, cache_mgr, chart_service)
  autolinker = autolink.Autolink()
  spam_service = spam_svc.SpamService()
  template_service = template_svc.TemplateService(cache_mgr)

  svcs = Services(
      cache_manager=cache_mgr,
      config=config_service,
      features=features_service,
      issue_star=issue_star_service,
      project=project_service,
      project_star=project_star_service,
      user=user_service,
      user_star=user_star_service,
      usergroup=usergroup_service,
      issue=issue_service,
      autolink_obj=autolinker,
      spam=spam_service,
      hotlist_star=hotlist_star_service,
      chart=chart_service,
      template=template_service)
  return svcs
diff --git a/services/spam_svc.py b/services/spam_svc.py
new file mode 100644
index 0000000..9a62cb9
--- /dev/null
+++ b/services/spam_svc.py
@@ -0,0 +1,697 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Set of functions for dealing with spam reports.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import collections
+import logging
+import settings
+import sys
+
+from collections import defaultdict
+from features import filterrules_helpers
+from framework import sql
+from framework import framework_constants
+from infra_libs import ts_mon
+from services import ml_helpers
+
+
# SQL tables used by SpamService.
SPAMREPORT_TABLE_NAME = 'SpamReport'
SPAMVERDICT_TABLE_NAME = 'SpamVerdict'
ISSUE_TABLE = 'Issue'

# Values written to the `reason` column of SpamVerdict rows.
REASON_MANUAL = 'manual'
REASON_THRESHOLD = 'threshold'
REASON_CLASSIFIER = 'classifier'
REASON_FAIL_OPEN = 'fail_open'
# Class label that the ML model uses for "spam" in its predictions.
SPAM_CLASS_LABEL = '1'

# Column lists for row inserts; row tuples must follow the same order.
SPAMREPORT_ISSUE_COLS = ['issue_id', 'reported_user_id', 'user_id']
SPAMVERDICT_ISSUE_COL = [
    'created', 'content_created', 'user_id', 'reported_user_id',
    'comment_id', 'issue_id']
MANUALVERDICT_ISSUE_COLS = [
    'user_id', 'issue_id', 'is_spam', 'reason', 'project_id']
THRESHVERDICT_ISSUE_COLS = ['issue_id', 'is_spam', 'reason', 'project_id']

SPAMREPORT_COMMENT_COLS = ['comment_id', 'reported_user_id', 'user_id']
MANUALVERDICT_COMMENT_COLS = [
    'user_id', 'comment_id', 'is_spam', 'reason', 'project_id']
THRESHVERDICT_COMMENT_COLS = ['comment_id', 'is_spam', 'reason', 'project_id']
+
+
class SpamService(object):
  """The persistence layer for spam reports.

  User flags are stored in the SpamReport table; threshold, classifier and
  manual decisions are stored as verdicts in the SpamVerdict table.
  """
  issue_actions = ts_mon.CounterMetric(
      'monorail/spam_svc/issue', 'Count of things that happen to issues.', [
          ts_mon.StringField('type'),
          ts_mon.StringField('reporter_id'),
          ts_mon.StringField('issue')
      ])
  comment_actions = ts_mon.CounterMetric(
      'monorail/spam_svc/comment', 'Count of things that happen to comments.', [
          ts_mon.StringField('type'),
          ts_mon.StringField('reporter_id'),
          ts_mon.StringField('issue'),
          ts_mon.StringField('comment_id')
      ])
  ml_engine_failures = ts_mon.CounterMetric(
      'monorail/spam_svc/ml_engine_failure',
      'Failures calling the ML Engine API',
      None)

  def __init__(self):
    self.report_tbl = sql.SQLTableManager(SPAMREPORT_TABLE_NAME)
    self.verdict_tbl = sql.SQLTableManager(SPAMVERDICT_TABLE_NAME)
    self.issue_tbl = sql.SQLTableManager(ISSUE_TABLE)

    # ML Engine library is lazy loaded below.
    self.ml_engine = None

  def LookupIssuesFlaggers(self, cnxn, issue_ids):
    """Returns users who've reported the issues or their comments as spam.

    Returns a dictionary {issue_id: (issue_reporters, comment_reporters)}.
    issue_reporters is a list of users who flagged the issue;
    comment_reporters is a dictionary {comment_id: [user_ids]} where
    user_ids are the users who flagged that comment. The result is a
    defaultdict, so issues without reports map to ([], {}).
    """
    rows = self.report_tbl.Select(
        cnxn, cols=['issue_id', 'user_id', 'comment_id'],
        issue_id=issue_ids)

    reporters = collections.defaultdict(
        # Return a tuple of (issue_reporters, comment_reporters) as described
        # above.
        lambda: ([], collections.defaultdict(list)))

    for row in rows:
      issue_id = int(row[0])
      user_id = row[1]
      if row[2]:
        comment_id = row[2]
        reporters[issue_id][1][comment_id].append(user_id)
      else:
        reporters[issue_id][0].append(user_id)

    return reporters

  def LookupIssueFlaggers(self, cnxn, issue_id):
    """Returns users who've reported the issue or its comments as spam.

    Returns a tuple. First element is a list of users who flagged the issue;
    second element is a dictionary of comment id to a list of users who flagged
    that comment.
    """
    return self.LookupIssuesFlaggers(cnxn, [issue_id])[issue_id]

  def LookupIssueFlagCounts(self, cnxn, issue_ids):
    """Returns a map of issue_id to flag counts.

    The rows are not filtered on comment_id, so reports on an issue's
    comments are included in that issue's count. Issues with no reports
    are absent from the map.
    """
    rows = self.report_tbl.Select(cnxn, cols=['issue_id', 'COUNT(*)'],
                                  issue_id=issue_ids, group_by=['issue_id'])
    return {int(row[0]): row[1] for row in rows}

  def LookupIssueVerdicts(self, cnxn, issue_ids):
    """Returns a map of issue_id to the reason of its most recent verdict.

    Only issue-level verdicts (comment_id IS NULL) are considered. Issues
    with no verdicts are absent from the map.
    """
    # Selecting a non-aggregated `reason` next to MAX(created) in a GROUP BY
    # does not return the reason of the newest row, so fetch the verdicts
    # oldest-first and let later rows overwrite earlier ones.
    rows = self.verdict_tbl.Select(
        cnxn, cols=['issue_id', 'reason', 'created'],
        issue_id=issue_ids, comment_id=None,
        order_by=[('created', [])])
    latest_reasons = {}
    for row in rows:
      latest_reasons[int(row[0])] = row[1]
    return latest_reasons

  def LookupIssueVerdictHistory(self, cnxn, issue_ids):
    """Returns every spam verdict recorded for the given issues.

    Returns:
      A list of dicts, ordered by issue_id and then creation time, each with
      keys 'issue_id', 'reason', 'created', 'is_spam',
      'classifier_confidence', 'user_id' and 'overruled'.
    """
    rows = self.verdict_tbl.Select(cnxn, cols=[
        'issue_id', 'reason', 'created', 'is_spam', 'classifier_confidence',
        'user_id', 'overruled'],
        issue_id=issue_ids, order_by=[('issue_id', []), ('created', [])])

    # TODO: group by issue_id, make class instead of dict for verdict.
    verdicts = []
    for row in rows:
      verdicts.append({
          'issue_id': row[0],
          'reason': row[1],
          'created': row[2],
          'is_spam': row[3],
          'classifier_confidence': row[4],
          'user_id': row[5],
          'overruled': row[6],
      })

    return verdicts

  def LookupCommentVerdictHistory(self, cnxn, comment_ids):
    """Returns every spam verdict recorded for the given comments.

    Returns:
      A list of dicts, ordered by comment_id and then creation time, each
      with keys 'comment_id', 'reason', 'created', 'is_spam',
      'classifier_confidence', 'user_id' and 'overruled'.
    """
    rows = self.verdict_tbl.Select(cnxn, cols=[
        'comment_id', 'reason', 'created', 'is_spam', 'classifier_confidence',
        'user_id', 'overruled'],
        comment_id=comment_ids, order_by=[('comment_id', []), ('created', [])])

    # TODO: group by comment_id, make class instead of dict for verdict.
    verdicts = []
    for row in rows:
      verdicts.append({
          'comment_id': row[0],
          'reason': row[1],
          'created': row[2],
          'is_spam': row[3],
          'classifier_confidence': row[4],
          'user_id': row[5],
          'overruled': row[6],
      })

    return verdicts

  def FlagIssues(self, cnxn, issue_service, issues, reporting_user_id,
                 flagged_spam):
    """Creates or deletes a spam report on an issue.

    When flagging, every issue whose report count reaches
    settings.spam_flag_thresh gets a threshold verdict and is marked as
    spam, unless its latest verdict is a manual one.

    Args:
      cnxn: connection to SQL database.
      issue_service: used to persist the updated is_spam bits.
      issues: Issue PBs being flagged or unflagged. May span projects.
      reporting_user_id: int user ID of the user doing the flagging.
      flagged_spam: True to add the user's report, False to remove it.
    """
    if not issues:
      return

    issue_ids = [issue.issue_id for issue in issues]
    if flagged_spam:
      rows = [(issue.issue_id, issue.reporter_id, reporting_user_id)
              for issue in issues]
      self.report_tbl.InsertRows(cnxn, SPAMREPORT_ISSUE_COLS, rows,
                                 ignore=True)
    else:
      self.report_tbl.Delete(
          cnxn, issue_id=issue_ids, user_id=reporting_user_id,
          comment_id=None)

    # Each verdict must carry the project of its own issue.
    project_ids = {issue.issue_id: issue.project_id for issue in issues}

    # Now record new verdicts and update issue.is_spam, if they've changed.
    counts = self.LookupIssueFlagCounts(cnxn, issue_ids)
    previous_verdicts = self.LookupIssueVerdicts(cnxn, issue_ids)

    verdict_updates = []
    for issue_id, count in counts.items():
      # No number of user spam flags can overturn an admin's verdict.
      if previous_verdicts.get(issue_id) == REASON_MANUAL:
        continue

      # If enough spam flags come in, mark the issue as spam.
      if flagged_spam and count >= settings.spam_flag_thresh:
        verdict_updates.append(issue_id)

    if not verdict_updates:
      return

    # Some of the issues may have exceeded the flag threshold, so issue
    # verdicts and mark as spam in those cases.
    rows = [(issue_id, flagged_spam, REASON_THRESHOLD, project_ids[issue_id])
            for issue_id in verdict_updates]
    self.verdict_tbl.InsertRows(cnxn, THRESHVERDICT_ISSUE_COLS, rows,
                                ignore=True)
    updated_ids = set(verdict_updates)
    update_issues = []
    for issue in issues:
      if issue.issue_id in updated_ids:
        issue.is_spam = flagged_spam
        update_issues.append(issue)

    if flagged_spam:
      for issue in update_issues:
        issue_ref = '%s:%s' % (issue.project_name, issue.local_id)
        self.issue_actions.increment(
            {
                'type': 'flag',
                'reporter_id': str(reporting_user_id),
                'issue': issue_ref
            })

    issue_service.UpdateIssues(cnxn, update_issues, update_cols=['is_spam'])

  def FlagComment(
      self, cnxn, issue, comment_id, reported_user_id, reporting_user_id,
      flagged_spam):
    """Creates or deletes a spam report on a comment."""
    # TODO(seanmccullough): Bulk comment flagging? There's no UI for that.
    if flagged_spam:
      self.report_tbl.InsertRow(
          cnxn,
          ignore=True,
          issue_id=issue.issue_id,
          comment_id=comment_id,
          reported_user_id=reported_user_id,
          user_id=reporting_user_id)
      issue_ref = '%s:%s' % (issue.project_name, issue.local_id)
      self.comment_actions.increment(
          {
              'type': 'flag',
              'reporter_id': str(reporting_user_id),
              'issue': issue_ref,
              'comment_id': str(comment_id)
          })
    else:
      self.report_tbl.Delete(
          cnxn,
          issue_id=issue.issue_id,
          comment_id=comment_id,
          user_id=reporting_user_id)

  def RecordClassifierIssueVerdict(self, cnxn, issue, is_spam, confidence,
                                   fail_open):
    """Stores the classifier's verdict on a newly created issue.

    The reason is REASON_FAIL_OPEN when the classifier could not be reached,
    otherwise REASON_CLASSIFIER.
    """
    reason = REASON_FAIL_OPEN if fail_open else REASON_CLASSIFIER
    self.verdict_tbl.InsertRow(cnxn, issue_id=issue.issue_id, is_spam=is_spam,
                               reason=reason, classifier_confidence=confidence,
                               project_id=issue.project_id)
    if is_spam:
      issue_ref = '%s:%s' % (issue.project_name, issue.local_id)
      self.issue_actions.increment(
          {
              'type': 'classifier',
              'reporter_id': 'classifier',
              'issue': issue_ref
          })
    # This is called at issue creation time, so there's nothing else to do here.

  def RecordManualIssueVerdicts(self, cnxn, issue_service, issues, user_id,
                                is_spam):
    """Records a human spam/ham verdict on issues, overruling earlier ones.

    Issues judged not spam are given newly allocated local IDs. The
    transaction is committed by issue_service.UpdateIssues().
    """
    if not issues:
      # Nothing to record; also avoids generating an invalid `IN ()` clause.
      return

    rows = [(user_id, issue.issue_id, is_spam, REASON_MANUAL, issue.project_id)
            for issue in issues]
    issue_ids = [issue.issue_id for issue in issues]

    # Overrule all previous verdicts.
    self.verdict_tbl.Update(cnxn, {'overruled': True}, [
        ('issue_id IN (%s)' % sql.PlaceHolders(issue_ids), issue_ids)
    ], commit=False)

    self.verdict_tbl.InsertRows(cnxn, MANUALVERDICT_ISSUE_COLS, rows,
                                ignore=True)

    for issue in issues:
      issue.is_spam = is_spam

    if is_spam:
      for issue in issues:
        issue_ref = '%s:%s' % (issue.project_name, issue.local_id)
        self.issue_actions.increment(
            {
                'type': 'manual',
                'reporter_id': str(user_id),
                'issue': issue_ref
            })
    else:
      issue_service.AllocateNewLocalIDs(cnxn, issues)

    # This will commit the transaction.
    issue_service.UpdateIssues(cnxn, issues, update_cols=['is_spam'])

  def RecordManualCommentVerdict(self, cnxn, issue_service, user_service,
                                 comment_id, user_id, is_spam):
    """Records a human spam/ham verdict on one comment, overruling earlier ones.

    The comment is then soft-deleted or restored according to is_spam.
    """
    # TODO(seanmccullough): Bulk comment verdicts? There's no UI for that.
    # Load the comment first so the verdict can carry its project_id, as the
    # other verdict rows (and MANUALVERDICT_COMMENT_COLS) do.
    comment = issue_service.GetComment(cnxn, comment_id)

    # Overrule earlier verdicts, as RecordManualIssueVerdicts does for issues,
    # so the comment leaves the classifier moderation queue.
    self.verdict_tbl.Update(cnxn, {'overruled': True}, [
        ('comment_id = %s', [comment_id])
    ], commit=False)

    self.verdict_tbl.InsertRow(cnxn, ignore=True,
        user_id=user_id, comment_id=comment_id, is_spam=is_spam,
        reason=REASON_MANUAL, project_id=comment.project_id)
    comment.is_spam = is_spam
    issue = issue_service.GetIssue(cnxn, comment.issue_id, use_cache=False)
    # NOTE(review): the trailing positional flags are presumably
    # (delete, reindex, is_spam); confirm against SoftDeleteComment.
    issue_service.SoftDeleteComment(
        cnxn, issue, comment, user_id, user_service, is_spam, True, is_spam)
    if is_spam:
      issue_ref = '%s:%s' % (issue.project_name, issue.local_id)
      self.comment_actions.increment(
          {
              'type': 'manual',
              'reporter_id': str(user_id),
              'issue': issue_ref,
              'comment_id': str(comment_id)
          })

  def RecordClassifierCommentVerdict(
      self, cnxn, issue_service, comment, is_spam, confidence, fail_open):
    """Stores the classifier's verdict on a comment."""
    reason = REASON_FAIL_OPEN if fail_open else REASON_CLASSIFIER
    self.verdict_tbl.InsertRow(cnxn, comment_id=comment.id, is_spam=is_spam,
                               reason=reason, classifier_confidence=confidence,
                               project_id=comment.project_id)
    if is_spam:
      issue = issue_service.GetIssue(cnxn, comment.issue_id, use_cache=False)
      issue_ref = '%s:%s' % (issue.project_name, issue.local_id)
      self.comment_actions.increment(
          {
              'type': 'classifier',
              'reporter_id': 'classifier',
              'issue': issue_ref,
              'comment_id': str(comment.id)
          })

  def _predict(self, instance):
    """Requests a prediction from the ML Engine API.

    Sample API response:
      {'predictions': [{
        'classes': ['0', '1'],
        'scores': [0.4986788034439087, 0.5013211965560913]
      }]}

    This hits the default model.

    Returns:
      A floating point number representing the confidence
      the instance is spam.

    Raises:
      Exception: if the prediction has no score for SPAM_CLASS_LABEL.
    """
    model_name = 'projects/%s/models/%s' % (
        settings.classifier_project_id, settings.spam_model_name)
    body = {'instances': [{"inputs": instance["word_hashes"]}]}

    if not self.ml_engine:
      self.ml_engine = ml_helpers.setup_ml_engine()

    request = self.ml_engine.projects().predict(name=model_name, body=body)
    response = request.execute()
    logging.info('ML Engine API response: %r', response)
    prediction = response['predictions'][0]

    # Return the confidence for the spam label, not the ham label. Label
    # order (and count) in the response is not guaranteed, so search for it.
    for label, score in zip(prediction['classes'], prediction['scores']):
      if label == SPAM_CLASS_LABEL:
        return score
    raise Exception('No predicted classes found.')

  def _IsExempt(self, author, is_project_member):
    """Return True if the user is exempt from spam checking."""
    if author.email is not None and author.email.endswith(
        settings.spam_allowlisted_suffixes):
      logging.info('%s allowlisted from spam filtering', author.email)
      return True

    if is_project_member:
      logging.info('%s is a project member, assuming ham', author.email)
      return True

    return False

  def ClassifyIssue(self, issue, firstComment, reporter, is_project_member):
    """Classify an issue as either spam or ham.

    Args:
      issue: the Issue.
      firstComment: the first Comment on issue.
      reporter: User PB for the Issue reporter.
      is_project_member: True if reporter is a member of issue's project.

    Returns a dict with 'confidence_is_spam' and 'failed_open' keys.
    """
    instance = ml_helpers.GenerateFeaturesRaw(
        [issue.summary, firstComment.content],
        settings.spam_feature_hashes)
    return self._classify(instance, reporter, is_project_member)

  def ClassifyComment(self, comment_content, commenter, is_project_member=True):
    """Classify a comment as either spam or ham.

    Args:
      comment_content: the comment text.
      commenter: User PB for the user who authored the comment.
      is_project_member: True if commenter is a member of the project.

    Returns a dict with 'confidence_is_spam' and 'failed_open' keys.
    """
    instance = ml_helpers.GenerateFeaturesRaw(
        ['', comment_content],
        settings.spam_feature_hashes)
    return self._classify(instance, commenter, is_project_member)

  def _classify(self, instance, author, is_project_member):
    """Runs the ML model on instance, failing open (as ham) on errors.

    Exempt authors are never sent to the model. The API call is tried up to
    three times; if all attempts fail, or ML Engine cannot be set up, the
    result is ham with 'failed_open' set to True.
    """
    # Fail-safe: not spam.
    result = self.ham_classification()

    if self._IsExempt(author, is_project_member):
      return result

    if not self.ml_engine:
      self.ml_engine = ml_helpers.setup_ml_engine()

    # If setup_ml_engine returns None, it failed to init.
    if not self.ml_engine:
      logging.error("ML Engine not initialized.")
      self.ml_engine_failures.increment()
      result['failed_open'] = True
      return result

    remaining_retries = 3
    while remaining_retries > 0:
      try:
        result['confidence_is_spam'] = self._predict(instance)
        result['failed_open'] = False
        return result
      except Exception as ex:
        remaining_retries = remaining_retries - 1
        self.ml_engine_failures.increment()
        logging.error('Error calling ML Engine API: %s', ex)

    result['failed_open'] = True
    return result

  def ham_classification(self):
    """Returns a fresh classification result meaning "not spam"."""
    return {'confidence_is_spam': 0.0,
            'failed_open': False}

  def GetIssueClassifierQueue(
      self, cnxn, _issue_service, project_id, offset=0, limit=10):
    """Returns list of recent issues with spam verdicts,
    ranked in ascending order of confidence (so uncertain items are first).

    Returns:
      A (list of ModerationItem, total matching verdict count) pair.
    """
    # TODO(seanmccullough): Optimize pagination. This query probably gets
    # slower as the number of SpamVerdicts grows, regardless of offset
    # and limit values used here.  Using offset,limit in general may not
    # be the best way to do this.
    issue_results = self.verdict_tbl.Select(
        cnxn,
        cols=[
            'issue_id', 'is_spam', 'reason', 'classifier_confidence', 'created'
        ],
        where=[
            ('project_id = %s', [project_id]),
            (
                'classifier_confidence <= %s',
                [settings.classifier_moderation_thresh]),
            ('overruled = %s', [False]),
            ('issue_id IS NOT NULL', []),
        ],
        order_by=[
            ('classifier_confidence ASC', []),
            ('created ASC', []),
        ],
        group_by=['issue_id'],
        offset=offset,
        limit=limit,
    )

    ret = []
    for row in issue_results:
      ret.append(
          ModerationItem(
              issue_id=int(row[0]),
              is_spam=row[1] == 1,
              reason=row[2],
              classifier_confidence=row[3],
              verdict_time='%s' % row[4],
          ))

    count = self.verdict_tbl.SelectValue(
        cnxn,
        col='COUNT(*)',
        where=[
            ('project_id = %s', [project_id]),
            (
                'classifier_confidence <= %s',
                [settings.classifier_moderation_thresh]),
            ('overruled = %s', [False]),
            ('issue_id IS NOT NULL', []),
        ])

    return ret, count

  def GetIssueFlagQueue(
      self, cnxn, _issue_service, project_id, offset=0, limit=10):
    """Returns list of recent issues that have been flagged by users.

    Returns:
      A (list of ModerationItem, number of distinct flagged issues) pair,
      both restricted to the given project.
    """
    # The report table is SpamReport (not aliased), so columns must be
    # qualified with that name. Join clauses are (clause, args) pairs.
    issue_join = ('Issue ON Issue.id = SpamReport.issue_id', [])
    issue_flags = self.report_tbl.Select(
        cnxn,
        cols=[
            'Issue.project_id', 'SpamReport.issue_id', 'count(*) as count',
            'max(SpamReport.created) as latest',
            'count(distinct SpamReport.user_id) as users'
        ],
        left_joins=[issue_join],
        where=[
            ('SpamReport.issue_id IS NOT NULL', []),
            ('Issue.project_id = %s', [project_id])
        ],
        order_by=[('count DESC', [])],
        group_by=['SpamReport.issue_id'],
        offset=offset,
        limit=limit)
    ret = []
    for row in issue_flags:
      ret.append(
          ModerationItem(
              project_id=row[0],
              issue_id=row[1],
              count=row[2],
              latest_report=row[3],
              num_users=row[4],
          ))

    # Count from the report table itself, which the join clause refers to.
    count = self.report_tbl.SelectValue(
        cnxn,
        col='COUNT(DISTINCT SpamReport.issue_id)',
        where=[('Issue.project_id = %s', [project_id])],
        left_joins=[issue_join])
    return ret, count

  def GetCommentClassifierQueue(
      self, cnxn, _issue_service, project_id, offset=0, limit=10):
    """Returns list of recent comments with spam verdicts,
    ranked in ascending order of confidence (so uncertain items are first).

    Returns:
      A (list of ModerationItem, total matching verdict count) pair.
    """
    # TODO(seanmccullough): Optimize pagination. This query probably gets
    # slower as the number of SpamVerdicts grows, regardless of offset
    # and limit values used here.  Using offset,limit in general may not
    # be the best way to do this.
    comment_results = self.verdict_tbl.Select(
        cnxn,
        cols=[
            # Comment verdicts are stored without an issue_id, so select the
            # comment_id that the items below are keyed by.
            'comment_id', 'is_spam', 'reason', 'classifier_confidence',
            'created'
        ],
        where=[
            ('project_id = %s', [project_id]),
            (
                'classifier_confidence <= %s',
                [settings.classifier_moderation_thresh]),
            ('overruled = %s', [False]),
            ('comment_id IS NOT NULL', []),
        ],
        order_by=[
            ('classifier_confidence ASC', []),
            ('created ASC', []),
        ],
        group_by=['comment_id'],
        offset=offset,
        limit=limit,
    )

    ret = []
    for row in comment_results:
      ret.append(
          ModerationItem(
              comment_id=int(row[0]),
              is_spam=row[1] == 1,
              reason=row[2],
              classifier_confidence=row[3],
              verdict_time='%s' % row[4],
          ))

    count = self.verdict_tbl.SelectValue(
        cnxn,
        col='COUNT(*)',
        where=[
            ('project_id = %s', [project_id]),
            (
                'classifier_confidence <= %s',
                [settings.classifier_moderation_thresh]),
            ('overruled = %s', [False]),
            ('comment_id IS NOT NULL', []),
        ])

    return ret, count

  def GetTrainingIssues(self, cnxn, issue_service, since, offset=0, limit=100):
    """Returns list of recent issues with human-labeled spam/ham verdicts.

    Returns:
      An (issues, {issue_id: first comment content}, verdict count) triple.
      Issues without comments map to "[Empty]".
    """

    # Get the non-overruled manual issue verdicts created after `since`.
    results = self.verdict_tbl.Select(cnxn,
        cols=['issue_id'],
        where=[
            ('overruled = %s', [False]),
            ('reason = %s', ['manual']),
            ('issue_id IS NOT NULL', []),
            ('created > %s', [since.isoformat()]),
        ],
        offset=offset,
        limit=limit,
        )

    issue_ids = [int(row[0]) for row in results if row[0]]
    issues = issue_service.GetIssues(cnxn, issue_ids)
    comments = issue_service.GetCommentsForIssues(cnxn, issue_ids)
    first_comments = {}
    for issue in issues:
      issue_comments = comments.get(issue.issue_id)
      # Guard against both a missing entry and an empty comment list.
      first_comments[issue.issue_id] = (
          issue_comments[0].content if issue_comments else "[Empty]")

    count = self.verdict_tbl.SelectValue(cnxn,
        col='COUNT(*)',
        where=[
            ('overruled = %s', [False]),
            ('reason = %s', ['manual']),
            ('issue_id IS NOT NULL', []),
            ('created > %s', [since.isoformat()]),
        ])

    return issues, first_comments, count

  def GetTrainingComments(self, cnxn, issue_service, since, offset=0,
                          limit=100):
    """Returns list of recent comments with human-labeled spam/ham verdicts.
    """

    # Get the distinct comments with non-overruled manual verdicts created
    # after `since`.
    results = self.verdict_tbl.Select(
        cnxn,
        distinct=True,
        cols=['comment_id'],
        where=[
            ('overruled = %s', [False]),
            ('reason = %s', ['manual']),
            ('comment_id IS NOT NULL', []),
            ('created > %s', [since.isoformat()]),
        ],
        offset=offset,
        limit=limit,
        )

    comment_ids = [int(row[0]) for row in results if row[0]]
    # Don't care about sequence numbers in this context yet.
    comments = issue_service.GetCommentsByID(cnxn, comment_ids,
        defaultdict(int))
    return comments

  def ExpungeUsersInSpam(self, cnxn, user_ids):
    """Removes all references to given users from Spam DB tables.

    This method will not commit the operations. This method will
    not make changes to in-memory data.
    """
    commit = False
    self.report_tbl.Delete(cnxn, reported_user_id=user_ids, commit=commit)
    self.report_tbl.Delete(cnxn, user_id=user_ids, commit=commit)
    self.verdict_tbl.Delete(cnxn, user_id=user_ids, commit=commit)
+
+
class ModerationItem(object):
  """A simple record for one entry in a spam moderation queue.

  Every keyword argument given to the constructor becomes an attribute.
  Declared as a new-style class like the other classes in this module.
  """

  def __init__(self, **kwargs):
    self.__dict__.update(kwargs)

  def __repr__(self):
    # Sorted so the representation is stable across runs and Python versions.
    fields = ', '.join(
        '%s=%r' % (name, value)
        for name, value in sorted(self.__dict__.items()))
    return '%s(%s)' % (type(self).__name__, fields)
diff --git a/services/star_svc.py b/services/star_svc.py
new file mode 100644
index 0000000..bb92e73
--- /dev/null
+++ b/services/star_svc.py
@@ -0,0 +1,264 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""A set of functions that provide persistence for stars.
+
+Stars can be on users, projects, or issues.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import logging
+
+import settings
+from features import filterrules_helpers
+from framework import sql
+from services import caches
+
+
+# Names of the SQL tables that hold each kind of star.
+HOTLISTSTAR_TABLE_NAME = 'HotlistStar'
+ISSUESTAR_TABLE_NAME = 'IssueStar'
+PROJECTSTAR_TABLE_NAME = 'ProjectStar'
+USERSTAR_TABLE_NAME = 'UserStar'
+
+# TODO(jrobbins): Consider adding memcache here if performance testing shows
+# that stars are a bottleneck. Keep in mind that issue star counts are
+# already denormalized and stored in the Issue, which is cached in memcache.
+
+
+class AbstractStarService(object):
+  """The persistence layer for any kind of star data.
+
+  Each instance is bound to one SQL table whose rows pair an item ID column
+  (item_col) with the ID column of the user who starred that item
+  (user_col).
+  """
+
+  def __init__(self, cache_manager, tbl, item_col, user_col, cache_kind):
+    """Constructor.
+
+    Args:
+      cache_manager: local cache with distributed invalidation.
+      tbl: SQL table that stores star data.
+      item_col: string SQL column name that holds int item IDs.
+      user_col: string SQL column name that holds int user IDs
+          of the user who starred the item.
+      cache_kind: string saying the kind of RAM cache.
+    """
+    self.tbl = tbl
+    self.item_col = item_col
+    self.user_col = user_col
+
+    # Items starred by users, keyed by user who did the starring.
+    self.star_cache = caches.RamCache(cache_manager, 'user')
+    # Users that starred an item, keyed by item ID.
+    self.starrer_cache = caches.RamCache(cache_manager, cache_kind)
+    # Counts of the users that starred an item, keyed by item ID.
+    self.star_count_cache = caches.RamCache(cache_manager, cache_kind)
+
+  def ExpungeStars(self, cnxn, item_id, commit=True, limit=None):
+    """Wipes an item's stars from the system."""
+    self.tbl.Delete(
+        cnxn, commit=commit, limit=limit, **{self.item_col: item_id})
+
+  def ExpungeStarsByUsers(self, cnxn, user_ids, limit=None):
+    """Wipes the stars made by the given users from the system.
+
+    This method will not commit the operation. This method will
+    not make changes to in-memory data.
+    """
+    # Filter on the configured starrer column instead of a hard-coded
+    # 'user_id', so that any user_col passed to the constructor is honored.
+    self.tbl.Delete(
+        cnxn, commit=False, limit=limit, **{self.user_col: user_ids})
+
+  def LookupItemStarrers(self, cnxn, item_id):
+    """Returns list of users having stars on the specified item."""
+    starrer_list_dict = self.LookupItemsStarrers(cnxn, [item_id])
+    return starrer_list_dict[item_id]
+
+  def LookupItemsStarrers(self, cnxn, items_ids):
+    """Returns {item_id: [uid, ...]} of users who starred these items."""
+    starrer_list_dict, missed_ids = self.starrer_cache.GetAll(items_ids)
+
+    if missed_ids:
+      rows = self.tbl.Select(
+          cnxn, cols=[self.item_col, self.user_col],
+          **{self.item_col: missed_ids})
+      # Ensure that every requested item_id has an entry so that even
+      # zero-star items get cached.
+      retrieved_starrers = {item_id: [] for item_id in missed_ids}
+      for item_id, starrer_id in rows:
+        retrieved_starrers[item_id].append(starrer_id)
+      starrer_list_dict.update(retrieved_starrers)
+      self.starrer_cache.CacheAll(retrieved_starrers)
+
+    return starrer_list_dict
+
+  def LookupStarredItemIDs(self, cnxn, starrer_user_id):
+    """Returns list of item IDs that were starred by the specified user."""
+    if not starrer_user_id:
+      return []  # Anon user cannot star anything.
+
+    cached_item_ids = self.star_cache.GetItem(starrer_user_id)
+    if cached_item_ids is not None:
+      return cached_item_ids
+
+    # Query by the configured starrer column (see ExpungeStarsByUsers).
+    rows = self.tbl.Select(
+        cnxn, cols=[self.item_col], **{self.user_col: starrer_user_id})
+    starred_ids = [row[0] for row in rows]
+    self.star_cache.CacheItem(starrer_user_id, starred_ids)
+    return starred_ids
+
+  def IsItemStarredBy(self, cnxn, item_id, starrer_user_id):
+    """Return True if the given item is starred by the given user."""
+    starred_ids = self.LookupStarredItemIDs(cnxn, starrer_user_id)
+    return item_id in starred_ids
+
+  def CountItemStars(self, cnxn, item_id):
+    """Returns the number of stars on the specified item."""
+    count_dict = self.CountItemsStars(cnxn, [item_id])
+    return count_dict.get(item_id, 0)
+
+  def CountItemsStars(self, cnxn, item_ids):
+    """Get a dict {item_id: count} for the given items."""
+    item_count_dict, missed_ids = self.star_count_cache.GetAll(item_ids)
+
+    if missed_ids:
+      rows = self.tbl.Select(
+          cnxn, cols=[self.item_col, 'COUNT(%s)' % self.user_col],
+          group_by=[self.item_col],
+          **{self.item_col: missed_ids})
+      # Ensure that every requested item_id has an entry so that even
+      # zero-star items get cached.
+      retrieved_counts = {item_id: 0 for item_id in missed_ids}
+      retrieved_counts.update(rows)
+      item_count_dict.update(retrieved_counts)
+      self.star_count_cache.CacheAll(retrieved_counts)
+
+    return item_count_dict
+
+  def _SetStarsBatch(
+      self, cnxn, item_id, starrer_user_ids, starred, commit=True):
+    """Sets or unsets stars for the specified item and users.
+
+    Also invalidates the per-user, per-item starrer and per-item count
+    caches that the change affects.
+    """
+    if starred:
+      rows = [(item_id, user_id) for user_id in starrer_user_ids]
+      self.tbl.InsertRows(
+          cnxn, [self.item_col, self.user_col], rows, ignore=True,
+          commit=commit)
+    else:
+      self.tbl.Delete(
+          cnxn, commit=commit,
+          **{self.item_col: item_id, self.user_col: starrer_user_ids})
+
+    self.star_cache.InvalidateKeys(cnxn, starrer_user_ids)
+    self.starrer_cache.Invalidate(cnxn, item_id)
+    self.star_count_cache.Invalidate(cnxn, item_id)
+
+  def SetStarsBatch(
+      self, cnxn, item_id, starrer_user_ids, starred, commit=True):
+    """Sets or unsets stars for the specified item and users."""
+    self._SetStarsBatch(
+        cnxn, item_id, starrer_user_ids, starred, commit=commit)
+
+  def SetStar(self, cnxn, item_id, starrer_user_id, starred):
+    """Sets or unsets a star for the specified item and user."""
+    self._SetStarsBatch(cnxn, item_id, [starrer_user_id], starred)
+
+
+
+class UserStarService(AbstractStarService):
+  """Star service for stars that one user places on another user."""
+
+  def __init__(self, cache_manager):
+    super(UserStarService, self).__init__(
+        cache_manager,
+        sql.SQLTableManager(USERSTAR_TABLE_NAME),
+        item_col='starred_user_id',
+        user_col='user_id',
+        cache_kind='user')
+
+
+class ProjectStarService(AbstractStarService):
+  """Star service for stars that users place on projects."""
+
+  def __init__(self, cache_manager):
+    super(ProjectStarService, self).__init__(
+        cache_manager,
+        sql.SQLTableManager(PROJECTSTAR_TABLE_NAME),
+        item_col='project_id',
+        user_col='user_id',
+        cache_kind='project')
+
+
+class HotlistStarService(AbstractStarService):
+  """Star service for stars that users place on hotlists."""
+
+  def __init__(self, cache_manager):
+    super(HotlistStarService, self).__init__(
+        cache_manager,
+        sql.SQLTableManager(HOTLISTSTAR_TABLE_NAME),
+        item_col='hotlist_id',
+        user_col='user_id',
+        cache_kind='hotlist')
+
+
+class IssueStarService(AbstractStarService):
+  """Star service for stars on issues."""
+
+  def __init__(self, cache_manager):
+    tbl = sql.SQLTableManager(ISSUESTAR_TABLE_NAME)
+    super(IssueStarService, self).__init__(
+        cache_manager, tbl, 'issue_id', 'user_id', 'issue')
+
+  # pylint: disable=arguments-differ
+  def SetStar(
+      self, cnxn, services, config, issue_id, starrer_user_id, starred):
+    """Add or remove a star on the given issue for the given user.
+
+    Args:
+      cnxn: connection to SQL database.
+      services: connections to persistence layer.
+      config: ProjectIssueConfig PB for the project containing the issue.
+      issue_id: integer global ID of an issue.
+      starrer_user_id: user ID of the user who starred the issue.
+      starred: boolean True for adding a star, False when removing one.
+    """
+    self.SetStarsBatch(
+        cnxn, services, config, issue_id, [starrer_user_id], starred)
+
+  # pylint: disable=arguments-differ
+  def SetStarsBatch(
+      self, cnxn, services, config, issue_id, starrer_user_ids, starred):
+    """Add or remove a star on the given issue for the given users.
+
+    After changing the star rows, reloads the issue from the database,
+    recomputes its star_count, re-applies the project's filter rules and
+    saves the issue.
+
+    Args:
+      cnxn: connection to SQL database.
+      services: connections to persistence layer.
+      config: ProjectIssueConfig PB for the project containing the issue.
+      issue_id: integer global ID of an issue.
+      starrer_user_ids: list of user IDs of the users whose star on the
+          issue is being added or removed.
+      starred: boolean True for adding a star, False when removing one.
+    """
+    logging.info(
+        'SetStarsBatch:%r, %r, %r', issue_id, starrer_user_ids, starred)
+    super(IssueStarService, self).SetStarsBatch(
+        cnxn, issue_id, starrer_user_ids, starred)
+
+    # Because we will modify issues, load from DB rather than cache.
+    issue = services.issue.GetIssue(cnxn, issue_id, use_cache=False)
+    issue.star_count = self.CountItemStars(cnxn, issue_id)
+    filterrules_helpers.ApplyFilterRules(cnxn, services, issue, config)
+    # Note: only star_count could change due to the starring, but any
+    # field could have changed as a result of filter rules.
+    services.issue.UpdateIssue(cnxn, issue)
+
+    # The base-class call above already invalidated these; they are
+    # invalidated again after the issue update.
+    self.star_cache.InvalidateKeys(cnxn, starrer_user_ids)
+    self.starrer_cache.Invalidate(cnxn, issue_id)
+
+  # TODO(crbug.com/monorail/8098): This method should replace SetStarsBatch.
+  # New code should be calling SetStarsBatch_SkipIssueUpdate.
+  # SetStarsBatch, does issue.star_count updating that should be done
+  # in the business logic layer instead. E.g. We can create a
+  # WorkEnv.BatchSetStars() that includes the star_count updating work.
+  def SetStarsBatch_SkipIssueUpdate(
+      self, cnxn, issue_id, starrer_user_ids, starred, commit=True):
+    # type: (MonorailConnection, int, Sequence[int], bool, Optional[bool])
+    #   -> None
+    """Add or remove a star on the given issue for the given users.
+
+    Note: unlike SetStarsBatch above, does not make any updates to the
+    issue itself e.g. updating issue.star_count.
+
+    Args:
+      cnxn: connection to SQL database.
+      issue_id: integer global ID of an issue.
+      starrer_user_ids: list of user IDs whose star is added or removed.
+      starred: boolean True for adding a star, False when removing one.
+      commit: whether the star row changes are committed immediately.
+    """
+    # Log under this method's own name so it is distinguishable from
+    # SetStarsBatch in the logs.
+    logging.info(
+        'SetStarsBatch_SkipIssueUpdate:%r, %r, %r',
+        issue_id, starrer_user_ids, starred)
+    super(IssueStarService, self).SetStarsBatch(
+        cnxn, issue_id, starrer_user_ids, starred, commit=commit)
+
+    self.star_cache.InvalidateKeys(cnxn, starrer_user_ids)
+    self.starrer_cache.Invalidate(cnxn, issue_id)
diff --git a/services/template_svc.py b/services/template_svc.py
new file mode 100644
index 0000000..edfde05
--- /dev/null
+++ b/services/template_svc.py
@@ -0,0 +1,550 @@
+# Copyright 2018 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""The TemplateService class providing methods for template persistence."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import collections
+import logging
+
+import settings
+
+from framework import exceptions
+from framework import sql
+from proto import tracker_pb2
+from services import caches
+from services import project_svc
+from tracker import tracker_bizobj
+from tracker import tracker_constants
+
+
+# Each table name is followed by the column list this module reads or
+# writes for that table.
+TEMPLATE_TABLE_NAME = 'Template'
+TEMPLATE_COLS = [
+    'id', 'project_id', 'name', 'content', 'summary', 'summary_must_be_edited',
+    'owner_id', 'status', 'members_only', 'owner_defaults_to_member',
+    'component_required']
+
+TEMPLATE2LABEL_TABLE_NAME = 'Template2Label'
+TEMPLATE2LABEL_COLS = ['template_id', 'label']
+
+TEMPLATE2ADMIN_TABLE_NAME = 'Template2Admin'
+TEMPLATE2ADMIN_COLS = ['template_id', 'admin_id']
+
+TEMPLATE2COMPONENT_TABLE_NAME = 'Template2Component'
+TEMPLATE2COMPONENT_COLS = ['template_id', 'component_id']
+
+TEMPLATE2FIELDVALUE_TABLE_NAME = 'Template2FieldValue'
+TEMPLATE2FIELDVALUE_COLS = [
+    'template_id', 'field_id', 'int_value', 'str_value', 'user_id',
+    'date_value', 'url_value']
+
+ISSUEPHASEDEF_TABLE_NAME = 'IssuePhaseDef'
+ISSUEPHASEDEF_COLS = ['id', 'name', 'rank']
+
+TEMPLATE2APPROVALVALUE_TABLE_NAME = 'Template2ApprovalValue'
+TEMPLATE2APPROVALVALUE_COLS = [
+    'approval_id', 'template_id', 'phase_id', 'status']
+
+
+class TemplateSetTwoLevelCache(caches.AbstractTwoLevelCache):
+  """RAM + memcache layer for each project's list of templates.
+
+  Keys are project IDs; each value is a list of
+  (template_id, template_name, members_only) tuples, ordered by name.
+  """
+
+  def __init__(self, cache_manager, template_service):
+    super(TemplateSetTwoLevelCache, self).__init__(
+        cache_manager, 'project', prefix='templateset:', pb_class=None)
+    self.template_service = template_service
+
+  def _MakeCache(self, cache_manager, kind, max_size=None):
+    """Create the RAM-level cache, registered with cache_manager."""
+    return caches.RamCache(cache_manager, kind, max_size=max_size)
+
+  def FetchItems(self, cnxn, keys):
+    """Load template sets from the database for keys missed by both caches."""
+    template_tbl = self.template_service.template_tbl
+    sets_by_project = {}
+
+    for project_id in keys:
+      project_templates = sets_by_project.setdefault(project_id, [])
+      rows = template_tbl.Select(
+          cnxn, cols=TEMPLATE_COLS, project_id=project_id,
+          order_by=[('name', [])])
+      for (tpl_id, _pid, tpl_name, _content, _summary, _must_edit,
+           _owner_id, _status, tpl_members_only, _owner_default,
+           _comp_required) in rows:
+        project_templates.append((tpl_id, tpl_name, tpl_members_only))
+
+    return sets_by_project
+
+
+class TemplateDefTwoLevelCache(caches.AbstractTwoLevelCache):
+  """Class to manage RAM and memcache for individual TemplateDef.
+
+  Holds a dictionary of {template_id: TemplateDef} key value pairs.
+  """
+  def __init__(self, cache_manager, template_service):
+    super(TemplateDefTwoLevelCache, self).__init__(
+        cache_manager,
+        'template',
+        prefix='templatedef:',
+        pb_class=tracker_pb2.TemplateDef)
+    self.template_service = template_service
+
+  def _MakeCache(self, cache_manager, kind, max_size=None):
+    """Make the RAM cache and register it with the cache_manager."""
+    return caches.RamCache(cache_manager, kind, max_size=max_size)
+
+  def FetchItems(self, cnxn, keys):
+    """On RAM and memcache miss, hit the database.
+
+    Args:
+      cnxn: A MonorailConnection.
+      keys: A list of template IDs (ints).
+
+    Returns:
+      A dict of {template_id: TemplateDef}. IDs with no Template row are
+      absent from the dict.
+    """
+    svc = self.template_service
+    template_dict = {}
+
+    # Fetch template rows and relations.
+    template_rows = svc.template_tbl.Select(
+        cnxn, cols=TEMPLATE_COLS, id=keys, order_by=[('name', [])])
+    template2label_rows = svc.template2label_tbl.Select(
+        cnxn, cols=TEMPLATE2LABEL_COLS, template_id=keys)
+    template2component_rows = svc.template2component_tbl.Select(
+        cnxn, cols=TEMPLATE2COMPONENT_COLS, template_id=keys)
+    template2admin_rows = svc.template2admin_tbl.Select(
+        cnxn, cols=TEMPLATE2ADMIN_COLS, template_id=keys)
+    template2fieldvalue_rows = svc.template2fieldvalue_tbl.Select(
+        cnxn, cols=TEMPLATE2FIELDVALUE_COLS, template_id=keys)
+    template2approvalvalue_rows = svc.template2approvalvalue_tbl.Select(
+        cnxn, cols=TEMPLATE2APPROVALVALUE_COLS, template_id=keys)
+
+    # Approval values whose temporary phase_id had no matching phase are
+    # stored with a NULL phase_id; there is no IssuePhaseDef row to fetch
+    # for those, so keep None out of the IN-list.
+    phase_ids = {
+        av_row[2] for av_row in template2approvalvalue_rows
+        if av_row[2] is not None}
+    phase_rows = []
+    if phase_ids:
+      phase_rows = svc.issuephasedef_tbl.Select(
+          cnxn, cols=ISSUEPHASEDEF_COLS, id=sorted(phase_ids))
+
+    # Build TemplateDef with all related data.
+    for template_row in template_rows:
+      template = UnpackTemplate(template_row)
+      template_dict[template.template_id] = template
+
+    for template_id, label in template2label_rows:
+      template = template_dict.get(template_id)
+      if template:
+        template.labels.append(label)
+
+    for template_id, component_id in template2component_rows:
+      template = template_dict.get(template_id)
+      if template:
+        template.component_ids.append(component_id)
+
+    for template_id, admin_id in template2admin_rows:
+      template = template_dict.get(template_id)
+      if template:
+        template.admin_ids.append(admin_id)
+
+    for fv_row in template2fieldvalue_rows:
+      (template_id, field_id, int_value, str_value, user_id,
+       date_value, url_value) = fv_row
+      fv = tracker_bizobj.MakeFieldValue(
+          field_id, int_value, str_value, user_id, date_value, url_value,
+          False)
+      template = template_dict.get(template_id)
+      if template:
+        template.field_values.append(fv)
+
+    phases_by_id = {}
+    for phase_id, name, rank in phase_rows:
+      phases_by_id[phase_id] = tracker_pb2.Phase(
+          phase_id=phase_id, name=name, rank=rank)
+
+    # Note: there is no templateapproval2approver_tbl.
+    for av_row in template2approvalvalue_rows:
+      (approval_id, template_id, phase_id, status) = av_row
+      approval_value = tracker_pb2.ApprovalValue(
+          approval_id=approval_id, phase_id=phase_id,
+          status=tracker_pb2.ApprovalStatus(status.upper()))
+      template = template_dict.get(template_id)
+      if template:
+        template.approval_values.append(approval_value)
+        phase = phases_by_id.get(phase_id)
+        # Several approvals can share one phase; attach each phase once.
+        if phase and phase not in template.phases:
+          template.phases.append(phase)
+
+    return template_dict
+
+
+class TemplateService(object):
+  """Persistence layer for issue templates and the rows that hang off them.
+
+  Template data lives in the Template table plus per-template relation
+  tables (labels, components, admins, field values, approval values).
+  Phases are IssuePhaseDef rows referenced from the approval-value table.
+  """
+
+  def __init__(self, cache_manager):
+    self.template_tbl = sql.SQLTableManager(TEMPLATE_TABLE_NAME)
+    self.template2label_tbl = sql.SQLTableManager(TEMPLATE2LABEL_TABLE_NAME)
+    self.template2component_tbl = sql.SQLTableManager(
+        TEMPLATE2COMPONENT_TABLE_NAME)
+    self.template2admin_tbl = sql.SQLTableManager(TEMPLATE2ADMIN_TABLE_NAME)
+    self.template2fieldvalue_tbl = sql.SQLTableManager(
+        TEMPLATE2FIELDVALUE_TABLE_NAME)
+    self.issuephasedef_tbl = sql.SQLTableManager(
+        ISSUEPHASEDEF_TABLE_NAME)
+    self.template2approvalvalue_tbl = sql.SQLTableManager(
+        TEMPLATE2APPROVALVALUE_TABLE_NAME)
+
+    # {project_id: [(template_id, name, members_only), ...]}
+    self.template_set_2lc = TemplateSetTwoLevelCache(cache_manager, self)
+    # {template_id: TemplateDef}
+    self.template_def_2lc = TemplateDefTwoLevelCache(cache_manager, self)
+
+  def CreateDefaultProjectTemplates(self, cnxn, project_id):
+    """Create the default templates for a project.
+
+    Used only when creating a new project.
+
+    Args:
+      cnxn: A MonorailConnection instance.
+      project_id: The project ID under which to create the templates.
+    """
+    for tpl_dict in tracker_constants.DEFAULT_TEMPLATES:
+      tpl = tracker_bizobj.ConvertDictToTemplate(tpl_dict)
+      self.CreateIssueTemplateDef(
+          cnxn, project_id, tpl.name, tpl.content, tpl.summary,
+          tpl.summary_must_be_edited, tpl.status, tpl.members_only,
+          tpl.owner_defaults_to_member, tpl.component_required, tpl.owner_id,
+          tpl.labels, tpl.component_ids, tpl.admin_ids, tpl.field_values,
+          tpl.phases)
+
+  def GetTemplateByName(self, cnxn, template_name, project_id):
+    """Retrieves a template by name and project_id.
+
+    Args:
+      template_name (string): name of template.
+      project_id (int): ID of project template is under.
+
+    Returns:
+      A TemplateDef PB if found, otherwise None.
+    """
+    template_set = self.GetTemplateSetForProject(cnxn, project_id)
+    for tpl_id, name, _members_only in template_set:
+      if template_name == name:
+        return self.GetTemplateById(cnxn, tpl_id)
+    return None
+
+  def GetTemplateById(self, cnxn, template_id):
+    """Retrieves one template.
+
+    Args:
+      template_id (int): ID of the template.
+
+    Returns:
+      A TemplateDef PB if found, otherwise None.
+    """
+    result_dict, _ = self.template_def_2lc.GetAll(cnxn, [template_id])
+    return result_dict.get(template_id)
+
+  def GetTemplatesById(self, cnxn, template_ids):
+    """Retrieves one or more templates by ID.
+
+    Args:
+      template_ids (list<int>): IDs of the templates.
+
+    Returns:
+      A list containing any found TemplateDef PBs.
+    """
+    result_dict, _ = self.template_def_2lc.GetAll(cnxn, template_ids)
+    return list(result_dict.values())
+
+  def GetTemplateSetForProject(self, cnxn, project_id):
+    """Get the TemplateSet for a project.
+
+    Returns:
+      A list of (template_id, name, members_only) tuples.
+    """
+    result_dict, _ = self.template_set_2lc.GetAll(cnxn, [project_id])
+    return result_dict[project_id]
+
+  def GetProjectTemplates(self, cnxn, project_id):
+    """Gets all templates in a given project.
+
+    Args:
+      cnxn: A MonorailConnection instance.
+      project_id: All templates for this project will be returned.
+
+    Returns:
+      A list of TemplateDefs.
+    """
+    template_set = self.GetTemplateSetForProject(cnxn, project_id)
+    template_ids = [row[0] for row in template_set]
+    return self.GetTemplatesById(cnxn, template_ids)
+
+  def TemplatesWithComponent(self, cnxn, component_id):
+    """Returns all templates with the specified component.
+
+    Args:
+      cnxn: connection to SQL database.
+      component_id: int component id.
+
+    Returns:
+      A list of TemplateDefs.
+    """
+    template2component_rows = self.template2component_tbl.Select(
+        cnxn, cols=['template_id'], component_id=component_id)
+    template_ids = [r[0] for r in template2component_rows]
+    return self.GetTemplatesById(cnxn, template_ids)
+
+  def CreateIssueTemplateDef(
+      self, cnxn, project_id, name, content, summary, summary_must_be_edited,
+      status, members_only, owner_defaults_to_member, component_required,
+      owner_id=None, labels=None, component_ids=None, admin_ids=None,
+      field_values=None, phases=None, approval_values=None):
+    """Create a new issue template definition with the given info.
+
+    Args:
+      cnxn: connection to SQL database.
+      project_id: int ID of the current project.
+      name: name of the new issue template.
+      content: string content of the issue template.
+      summary: string summary of the issue template.
+      summary_must_be_edited: True if the summary must be edited when this
+          issue template is used to make a new issue.
+      status: string default status of a new issue created with this template.
+      members_only: True if only members can view this issue template.
+      owner_defaults_to_member: True if issue owner should be set to member
+          creating the issue.
+      component_required: True if a component is required.
+      owner_id: user_id of default owner, if any.
+      labels: list of string labels for the new issue, if any.
+      component_ids: list of component_ids, if any.
+      admin_ids: list of admin_ids, if any.
+      field_values: list of FieldValue PBs, if any.
+      phases: list of Phase PBs, if any.
+      approval_values: list of ApprovalValue PBs, if any.
+
+    Returns:
+      Integer template_id of the new issue template definition.
+    """
+    template_id = self.template_tbl.InsertRow(
+        cnxn, project_id=project_id, name=name, content=content,
+        summary=summary, summary_must_be_edited=summary_must_be_edited,
+        owner_id=owner_id, status=status, members_only=members_only,
+        owner_defaults_to_member=owner_defaults_to_member,
+        component_required=component_required, commit=False)
+
+    if labels:
+      self.template2label_tbl.InsertRows(
+          cnxn, TEMPLATE2LABEL_COLS, [(template_id, label) for label in labels],
+          commit=False)
+    if component_ids:
+      self.template2component_tbl.InsertRows(
+          cnxn, TEMPLATE2COMPONENT_COLS,
+          [(template_id, c_id) for c_id in component_ids], commit=False)
+    if admin_ids:
+      self.template2admin_tbl.InsertRows(
+          cnxn, TEMPLATE2ADMIN_COLS,
+          [(template_id, admin_id) for admin_id in admin_ids], commit=False)
+    if field_values:
+      self.template2fieldvalue_tbl.InsertRows(
+          cnxn, TEMPLATE2FIELDVALUE_COLS, [
+              (template_id, fv.field_id, fv.int_value, fv.str_value,
+               fv.user_id, fv.date_value, fv.url_value)
+              for fv in field_values],
+          commit=False)
+
+    # current phase_ids in approval_values and phases are temporary and were
+    # assigned based on the order of the phases. These temporary phase_ids are
+    # used to keep track of which approvals belong to which phases and are
+    # updated once all phases have their real phase_ids returned from InsertRow.
+    phase_id_by_tmp = {}
+    if phases:
+      for phase in phases:
+        phase_id = self.issuephasedef_tbl.InsertRow(
+            cnxn, name=phase.name, rank=phase.rank, commit=False)
+        phase_id_by_tmp[phase.phase_id] = phase_id
+
+    if approval_values:
+      self.template2approvalvalue_tbl.InsertRows(
+          cnxn, TEMPLATE2APPROVALVALUE_COLS,
+          [(av.approval_id, template_id,
+            phase_id_by_tmp.get(av.phase_id), av.status.name.lower())
+           for av in approval_values],
+          commit=False)
+
+    cnxn.Commit()
+    self.template_set_2lc.InvalidateKeys(cnxn, [project_id])
+    return template_id
+
+  def UpdateIssueTemplateDef(
+      self, cnxn, project_id, template_id, name=None, content=None,
+      summary=None, summary_must_be_edited=None, status=None, members_only=None,
+      owner_defaults_to_member=None, component_required=None, owner_id=None,
+      labels=None, component_ids=None, admin_ids=None, field_values=None,
+      phases=None, approval_values=None):
+    """Update an existing issue template definition with the given info.
+
+    Arguments left as None are not changed. Approval values are only
+    rewritten when phases is not None; if phases is given and
+    approval_values is None, the template ends up with no approval values.
+
+    Args:
+      cnxn: connection to SQL database.
+      project_id: int ID of the current project.
+      template_id: int ID of the issue template to update.
+      name: updated name of the new issue template.
+      content: updated string content of the issue template.
+      summary: updated string summary of the issue template.
+      summary_must_be_edited: True if the summary must be edited when this
+          issue template is used to make a new issue.
+      status: updated string default status of a new issue created with this
+          template.
+      members_only: True if only members can view this issue template.
+      owner_defaults_to_member: True if issue owner should be set to member
+          creating the issue.
+      component_required: True if a component is required.
+      owner_id: updated user_id of default owner, if any.
+      labels: updated list of string labels for the new issue, if any.
+      component_ids: updated list of component_ids, if any.
+      admin_ids: updated list of admin_ids, if any.
+      field_values: updated list of FieldValue PBs, if any.
+      phases: updated list of Phase PBs, if any.
+      approval_values: updated list of ApprovalValue PBs, if any.
+    """
+    new_values = {}
+    if name is not None:
+      new_values['name'] = name
+    if content is not None:
+      new_values['content'] = content
+    if summary is not None:
+      new_values['summary'] = summary
+    if summary_must_be_edited is not None:
+      new_values['summary_must_be_edited'] = bool(summary_must_be_edited)
+    if status is not None:
+      new_values['status'] = status
+    if members_only is not None:
+      new_values['members_only'] = bool(members_only)
+    if owner_defaults_to_member is not None:
+      new_values['owner_defaults_to_member'] = bool(owner_defaults_to_member)
+    if component_required is not None:
+      new_values['component_required'] = bool(component_required)
+    if owner_id is not None:
+      new_values['owner_id'] = owner_id
+
+    self.template_tbl.Update(cnxn, new_values, id=template_id, commit=False)
+
+    if labels is not None:
+      self.template2label_tbl.Delete(
+          cnxn, template_id=template_id, commit=False)
+      self.template2label_tbl.InsertRows(
+          cnxn, TEMPLATE2LABEL_COLS, [(template_id, label) for label in labels],
+          commit=False)
+    if component_ids is not None:
+      self.template2component_tbl.Delete(
+          cnxn, template_id=template_id, commit=False)
+      self.template2component_tbl.InsertRows(
+          cnxn, TEMPLATE2COMPONENT_COLS,
+          [(template_id, c_id) for c_id in component_ids],
+          commit=False)
+    if admin_ids is not None:
+      self.template2admin_tbl.Delete(
+          cnxn, template_id=template_id, commit=False)
+      self.template2admin_tbl.InsertRows(
+          cnxn, TEMPLATE2ADMIN_COLS,
+          [(template_id, admin_id) for admin_id in admin_ids],
+          commit=False)
+    if field_values is not None:
+      self.template2fieldvalue_tbl.Delete(
+          cnxn, template_id=template_id, commit=False)
+      self.template2fieldvalue_tbl.InsertRows(
+          cnxn, TEMPLATE2FIELDVALUE_COLS, [
+              (template_id, fv.field_id, fv.int_value, fv.str_value,
+               fv.user_id, fv.date_value, fv.url_value)
+              for fv in field_values],
+          commit=False)
+
+    # we need to keep track of tmp phase_ids created at the servlet.
+    # Old IssuePhaseDef rows are left in place; only the template's approval
+    # values are replaced.
+    phase_id_by_tmp = {}
+    if phases is not None:
+      self.template2approvalvalue_tbl.Delete(
+          cnxn, template_id=template_id, commit=False)
+      for phase in phases:
+        phase_id = self.issuephasedef_tbl.InsertRow(
+            cnxn, name=phase.name, rank=phase.rank, commit=False)
+        phase_id_by_tmp[phase.phase_id] = phase_id
+
+      # approval_values may be None even when phases is given; treat that as
+      # an empty list instead of failing while iterating over None.
+      self.template2approvalvalue_tbl.InsertRows(
+          cnxn, TEMPLATE2APPROVALVALUE_COLS,
+          [(av.approval_id, template_id,
+            phase_id_by_tmp.get(av.phase_id), av.status.name.lower())
+           for av in approval_values or []], commit=False)
+
+    cnxn.Commit()
+    self.template_set_2lc.InvalidateKeys(cnxn, [project_id])
+    self.template_def_2lc.InvalidateKeys(cnxn, [template_id])
+
+  def DeleteIssueTemplateDef(self, cnxn, project_id, template_id):
+    """Delete the specified issue template definition."""
+    self.template2label_tbl.Delete(cnxn, template_id=template_id, commit=False)
+    self.template2component_tbl.Delete(
+        cnxn, template_id=template_id, commit=False)
+    self.template2admin_tbl.Delete(cnxn, template_id=template_id, commit=False)
+    self.template2fieldvalue_tbl.Delete(
+        cnxn, template_id=template_id, commit=False)
+    self.template2approvalvalue_tbl.Delete(
+        cnxn, template_id=template_id, commit=False)
+    # We do not delete issuephasedef rows because these rows will be used by
+    # issues that were created with this template. template2approvalvalue rows
+    # can be deleted because those rows are copied over to issue2approvalvalue
+    # during issue creation.
+    self.template_tbl.Delete(cnxn, id=template_id, commit=False)
+
+    cnxn.Commit()
+    self.template_set_2lc.InvalidateKeys(cnxn, [project_id])
+    self.template_def_2lc.InvalidateKeys(cnxn, [template_id])
+
+  def ExpungeProjectTemplates(self, cnxn, project_id):
+    """Deletes every template of a project together with its relation rows.
+
+    No commit argument is passed, so each Delete uses SQLTableManager's
+    default. IssuePhaseDef rows are not deleted and caches are not
+    invalidated.
+    """
+    template_id_rows = self.template_tbl.Select(
+        cnxn, cols=['id'], project_id=project_id)
+    template_ids = [row[0] for row in template_id_rows]
+    # Remove the rows that point at these templates before the templates
+    # themselves; this is the same set of tables DeleteIssueTemplateDef clears.
+    self.template2label_tbl.Delete(cnxn, template_id=template_ids)
+    self.template2component_tbl.Delete(cnxn, template_id=template_ids)
+    self.template2admin_tbl.Delete(cnxn, template_id=template_ids)
+    self.template2fieldvalue_tbl.Delete(cnxn, template_id=template_ids)
+    self.template2approvalvalue_tbl.Delete(cnxn, template_id=template_ids)
+    self.template_tbl.Delete(cnxn, project_id=project_id)
+
+  def ExpungeUsersInTemplates(self, cnxn, user_ids, limit=None):
+    """Wipes a user from the templates system.
+
+    This method will not commit the operation. This method will
+    not make changes to in-memory data.
+    """
+    self.template2admin_tbl.Delete(
+        cnxn, admin_id=user_ids, commit=False, limit=limit)
+    self.template2fieldvalue_tbl.Delete(
+        cnxn, user_id=user_ids, commit=False, limit=limit)
+    # template_tbl's owner_id does not reference User. All appropriate rows
+    # should be deleted before rows can be safely deleted from User. No limit
+    # will be applied.
+    self.template_tbl.Update(
+        cnxn, {'owner_id': None}, owner_id=user_ids, commit=False)
+
+
+def UnpackTemplate(template_row):
+  """Build a partially populated TemplateDef from one Template table row.
+
+  Only the scalar columns are copied; labels, components, admins, field
+  values, phases and approval values are not set here.
+
+  Args:
+    template_row: tuple of values in TEMPLATE_COLS order.
+
+  Returns:
+    A tracker_pb2.TemplateDef.
+  """
+  (template_id, _project_id, name, content, summary,
+   summary_must_be_edited, owner_id, status,
+   members_only, owner_defaults_to_member, component_required) = template_row
+  return tracker_pb2.TemplateDef(
+      template_id=template_id,
+      name=name,
+      content=content,
+      summary=summary,
+      summary_must_be_edited=bool(summary_must_be_edited),
+      # NULL owner columns become 0.
+      owner_id=owner_id or 0,
+      status=status,
+      members_only=bool(members_only),
+      owner_defaults_to_member=bool(owner_defaults_to_member),
+      component_required=bool(component_required))
diff --git a/services/test/__init__.py b/services/test/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/services/test/__init__.py
diff --git a/services/test/api_pb2_v1_helpers_test.py b/services/test/api_pb2_v1_helpers_test.py
new file mode 100644
index 0000000..460f5c3
--- /dev/null
+++ b/services/test/api_pb2_v1_helpers_test.py
@@ -0,0 +1,786 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for the API v1 helpers."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import datetime
+import mock
+import unittest
+
+from framework import framework_constants
+from framework import permissions
+from framework import profiler
+from services import api_pb2_v1_helpers
+from services import service_manager
+from proto import api_pb2_v1
+from proto import project_pb2
+from proto import tracker_pb2
+from proto import usergroup_pb2
+from testing import fake
+from tracker import tracker_bizobj
+
+
+def MakeTemplate(prefix):
+ return tracker_pb2.TemplateDef(
+ name='%s-template' % prefix,
+ content='%s-content' % prefix,
+ summary='%s-summary' % prefix,
+ summary_must_be_edited=True,
+ status='New',
+ labels=['%s-label1' % prefix, '%s-label2' % prefix],
+ members_only=True,
+ owner_defaults_to_member=True,
+ component_required=True,
+ )
+
+
+def MakeLabel(prefix):
+ return tracker_pb2.LabelDef(
+ label='%s-label' % prefix,
+ label_docstring='%s-description' % prefix
+ )
+
+
+def MakeStatus(prefix):
+ return tracker_pb2.StatusDef(
+ status='%s-New' % prefix,
+ means_open=True,
+ status_docstring='%s-status' % prefix
+ )
+
+
+def MakeProjectIssueConfig(prefix):
+ return tracker_pb2.ProjectIssueConfig(
+ restrict_to_known=True,
+ default_col_spec='ID Type Priority Summary',
+ default_sort_spec='ID Priority',
+ well_known_statuses=[
+ MakeStatus('%s-status1' % prefix),
+ MakeStatus('%s-status2' % prefix),
+ ],
+ well_known_labels=[
+ MakeLabel('%s-label1' % prefix),
+ MakeLabel('%s-label2' % prefix),
+ ],
+ default_template_for_developers=1,
+ default_template_for_users=2
+ )
+
+
+def MakeProject(prefix):
+ return project_pb2.MakeProject(
+ project_name='%s-project' % prefix,
+ summary='%s-summary' % prefix,
+ description='%s-description' % prefix,
+ )
+
+
+class ApiV1HelpersTest(unittest.TestCase):
+
+ def setUp(self):
+ self.services = service_manager.Services(
+ user=fake.UserService(),
+ issue=fake.IssueService(),
+ project=fake.ProjectService(),
+ config=fake.ConfigService(),
+ issue_star=fake.IssueStarService())
+ self.services.user.TestAddUser('user@example.com', 111)
+ self.person_1 = api_pb2_v1_helpers.convert_person(111, None, self.services)
+
+ def testConvertTemplate(self):
+ """Test convert_template."""
+ template = MakeTemplate('test')
+ prompt = api_pb2_v1_helpers.convert_template(template)
+ self.assertEqual(template.name, prompt.name)
+ self.assertEqual(template.summary, prompt.title)
+ self.assertEqual(template.content, prompt.description)
+ self.assertEqual(template.summary_must_be_edited, prompt.titleMustBeEdited)
+ self.assertEqual(template.status, prompt.status)
+ self.assertEqual(template.labels, prompt.labels)
+ self.assertEqual(template.members_only, prompt.membersOnly)
+ self.assertEqual(template.owner_defaults_to_member, prompt.defaultToMember)
+ self.assertEqual(template.component_required, prompt.componentRequired)
+
+ def testConvertLabel(self):
+ """Test convert_label."""
+ labeldef = MakeLabel('test')
+ label = api_pb2_v1_helpers.convert_label(labeldef)
+ self.assertEqual(labeldef.label, label.label)
+ self.assertEqual(labeldef.label_docstring, label.description)
+
+ def testConvertStatus(self):
+ """Test convert_status."""
+ statusdef = MakeStatus('test')
+ status = api_pb2_v1_helpers.convert_status(statusdef)
+ self.assertEqual(statusdef.status, status.status)
+ self.assertEqual(statusdef.means_open, status.meansOpen)
+ self.assertEqual(statusdef.status_docstring, status.description)
+
+ def testConvertProjectIssueConfig(self):
+ """Test convert_project_config."""
+ prefix = 'test'
+ config = MakeProjectIssueConfig(prefix)
+ templates = [
+ MakeTemplate('%s-template1' % prefix),
+ MakeTemplate('%s-template2' % prefix),
+ ]
+ config_api = api_pb2_v1_helpers.convert_project_config(config, templates)
+ self.assertEqual(config.restrict_to_known, config_api.restrictToKnown)
+ self.assertEqual(config.default_col_spec.split(), config_api.defaultColumns)
+ self.assertEqual(
+ config.default_sort_spec.split(), config_api.defaultSorting)
+ self.assertEqual(2, len(config_api.statuses))
+ self.assertEqual(2, len(config_api.labels))
+ self.assertEqual(2, len(config_api.prompts))
+ self.assertEqual(
+ config.default_template_for_developers,
+ config_api.defaultPromptForMembers)
+ self.assertEqual(
+ config.default_template_for_users,
+ config_api.defaultPromptForNonMembers)
+
+ def testConvertProject(self):
+ """Test convert_project."""
+ project = MakeProject('testprj')
+ prefix = 'testconfig'
+ config = MakeProjectIssueConfig(prefix)
+ role = api_pb2_v1.Role.owner
+ templates = [
+ MakeTemplate('%s-template1' % prefix),
+ MakeTemplate('%s-template2' % prefix),
+ ]
+ project_api = api_pb2_v1_helpers.convert_project(project, config, role,
+ templates)
+ self.assertEqual(project.project_name, project_api.name)
+ self.assertEqual(project.project_name, project_api.externalId)
+ self.assertEqual('/p/%s/' % project.project_name, project_api.htmlLink)
+ self.assertEqual(project.summary, project_api.summary)
+ self.assertEqual(project.description, project_api.description)
+ self.assertEqual(role, project_api.role)
+ self.assertIsInstance(
+ project_api.issuesConfig, api_pb2_v1.ProjectIssueConfig)
+
+ def testConvertPerson(self):
+ """Test convert_person."""
+ result = api_pb2_v1_helpers.convert_person(111, None, self.services)
+ self.assertIsInstance(result, api_pb2_v1.AtomPerson)
+ self.assertEqual('user@example.com', result.name)
+
+ none_user = api_pb2_v1_helpers.convert_person(None, '', self.services)
+ self.assertIsNone(none_user)
+
+ deleted_user = api_pb2_v1_helpers.convert_person(
+ framework_constants.DELETED_USER_ID, '', self.services)
+ self.assertEqual(
+ deleted_user,
+ api_pb2_v1.AtomPerson(
+ kind='monorail#issuePerson',
+ name=framework_constants.DELETED_USER_NAME))
+
+ def testConvertIssueIDs(self):
+ """Test convert_issue_ids."""
+ issue1 = fake.MakeTestIssue(789, 1, 'one', 'New', 111)
+ self.services.issue.TestAddIssue(issue1)
+ issue_ids = [100001]
+ mar = mock.Mock()
+ mar.cnxn = None
+ mar.project_name = 'test-project'
+ result = api_pb2_v1_helpers.convert_issue_ids(issue_ids, mar, self.services)
+ self.assertEqual(1, len(result))
+ self.assertEqual(1, result[0].issueId)
+
+ def testConvertIssueRef(self):
+ """Test convert_issueref_pbs."""
+ issue1 = fake.MakeTestIssue(12345, 1, 'one', 'New', 111)
+ self.services.issue.TestAddIssue(issue1)
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[2],
+ project_id=12345)
+ mar = mock.Mock()
+ mar.cnxn = None
+ mar.project_name = 'test-project'
+ mar.project_id = 12345
+ ir = api_pb2_v1.IssueRef(
+ issueId=1,
+ projectId='test-project'
+ )
+ result = api_pb2_v1_helpers.convert_issueref_pbs([ir], mar, self.services)
+ self.assertEqual(1, len(result))
+ self.assertEqual(100001, result[0])
+
+ def testConvertIssue(self):
+ """Convert an internal Issue PB to an IssueWrapper API PB."""
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[2], project_id=12345)
+ self.services.user.TestAddUser('user@example.com', 111)
+
+ mar = mock.Mock()
+ mar.cnxn = None
+ mar.project_name = 'test-project'
+ mar.project_id = 12345
+ mar.auth.effective_ids = {111}
+ mar.perms = permissions.READ_ONLY_PERMISSIONSET
+ mar.profiler = profiler.Profiler()
+ mar.config = tracker_bizobj.MakeDefaultProjectIssueConfig(12345)
+ mar.config.field_defs = [
+ tracker_bizobj.MakeFieldDef(
+ 1, 12345, 'EstDays', tracker_pb2.FieldTypes.INT_TYPE, None, None,
+ False, False, False, None, None, None, False, None, None, None,
+ None, 'doc', False, approval_id=2),
+ tracker_bizobj.MakeFieldDef(
+ 2, 12345, 'DesignReview', tracker_pb2.FieldTypes.APPROVAL_TYPE,
+ None, None, False, False, False, None, None, None, False, None,
+ None, None, None, 'doc', False),
+ tracker_bizobj.MakeFieldDef(
+ 3, 12345, 'StringField', tracker_pb2.FieldTypes.STR_TYPE, None,
+ None, False, False, False, None, None, None, False, None, None,
+ None, None, 'doc', False),
+ tracker_bizobj.MakeFieldDef(
+ 4, 12345, 'DressReview', tracker_pb2.FieldTypes.APPROVAL_TYPE,
+ None, None, False, False, False, None, None, None, False, None,
+ None, None, None, 'doc', False),
+ ]
+ self.services.config.StoreConfig(mar.cnxn, mar.config)
+
+ now = 1472067725
+ now_dt = datetime.datetime.fromtimestamp(now)
+
+ fvs = [
+ tracker_bizobj.MakeFieldValue(
+ 1, 4, None, None, None, None, False, phase_id=4),
+ tracker_bizobj.MakeFieldValue(
+ 3, None, 'string', None, None, None, False, phase_id=4),
+ # missing phase
+ tracker_bizobj.MakeFieldValue(
+ 3, None, u'\xe2\x9d\xa4\xef\xb8\x8f', None, None, None, False,
+ phase_id=2),
+ ]
+ phases = [
+ tracker_pb2.Phase(phase_id=3, name="JustAPhase", rank=4),
+ tracker_pb2.Phase(phase_id=4, name="NotAPhase", rank=9)
+ ]
+ approval_values = [
+ tracker_pb2.ApprovalValue(
+ approval_id=2, phase_id=3, approver_ids=[111]),
+ tracker_pb2.ApprovalValue(approval_id=4, approver_ids=[111])
+ ]
+ issue = fake.MakeTestIssue(
+ 12345, 1, 'one', 'New', 111, field_values=fvs,
+ approval_values=approval_values, phases=phases)
+ issue.opened_timestamp = now
+ issue.owner_modified_timestamp = now
+ issue.status_modified_timestamp = now
+ issue.component_modified_timestamp = now
+ # TODO(jrobbins): set up a lot more fields.
+
+ for cls in [api_pb2_v1.IssueWrapper, api_pb2_v1.IssuesGetInsertResponse]:
+ result = api_pb2_v1_helpers.convert_issue(cls, issue, mar, self.services)
+ self.assertEqual(1, result.id)
+ self.assertEqual('one', result.title)
+ self.assertEqual('one', result.summary)
+ self.assertEqual(now_dt, result.published)
+ self.assertEqual(now_dt, result.owner_modified)
+ self.assertEqual(now_dt, result.status_modified)
+ self.assertEqual(now_dt, result.component_modified)
+ self.assertEqual(
+ result.fieldValues, [
+ api_pb2_v1.FieldValue(
+ fieldName='EstDays',
+ fieldValue='4',
+ approvalName='DesignReview',
+ derived=False),
+ api_pb2_v1.FieldValue(
+ fieldName='StringField',
+ fieldValue='string',
+ phaseName="NotAPhase",
+ derived=False),
+ api_pb2_v1.FieldValue(
+ fieldName='StringField',
+ fieldValue=u'\xe2\x9d\xa4\xef\xb8\x8f',
+ derived=False),
+ ])
+ self.assertEqual(
+ result.approvalValues,
+ [api_pb2_v1.Approval(
+ approvalName="DesignReview",
+ approvers=[self.person_1],
+ status=api_pb2_v1.ApprovalStatus.notSet,
+ phaseName="JustAPhase",
+ ),
+ api_pb2_v1.Approval(
+ approvalName="DressReview",
+ approvers=[self.person_1],
+ status=api_pb2_v1.ApprovalStatus.notSet,
+ )]
+ )
+ self.assertEqual(
+ result.phases,
+ [api_pb2_v1.Phase(phaseName="JustAPhase", rank=4),
+ api_pb2_v1.Phase(phaseName="NotAPhase", rank=9)
+ ])
+
+ # TODO(jrobbins): check a lot more fields.
+
+ def testConvertAttachment(self):
+ """Test convert_attachment."""
+
+ attachment = tracker_pb2.Attachment(
+ attachment_id=1,
+ filename='stats.txt',
+ filesize=12345,
+ mimetype='text/plain',
+ deleted=False)
+
+ result = api_pb2_v1_helpers.convert_attachment(attachment)
+ self.assertEqual(attachment.attachment_id, result.attachmentId)
+ self.assertEqual(attachment.filename, result.fileName)
+ self.assertEqual(attachment.filesize, result.fileSize)
+ self.assertEqual(attachment.mimetype, result.mimetype)
+ self.assertEqual(attachment.deleted, result.isDeleted)
+
+ def testConvertAmendments(self):
+ """Test convert_amendments."""
+ self.services.user.TestAddUser('user2@example.com', 222)
+ mar = mock.Mock()
+ mar.cnxn = None
+ issue = mock.Mock()
+ issue.project_name = 'test-project'
+
+ amendment_summary = tracker_pb2.Amendment(
+ field=tracker_pb2.FieldID.SUMMARY,
+ newvalue='new summary')
+ amendment_status = tracker_pb2.Amendment(
+ field=tracker_pb2.FieldID.STATUS,
+ newvalue='new status')
+ amendment_owner = tracker_pb2.Amendment(
+ field=tracker_pb2.FieldID.OWNER,
+ added_user_ids=[111])
+ amendment_labels = tracker_pb2.Amendment(
+ field=tracker_pb2.FieldID.LABELS,
+ newvalue='label1 -label2')
+ amendment_cc_add = tracker_pb2.Amendment(
+ field=tracker_pb2.FieldID.CC,
+ added_user_ids=[111])
+ amendment_cc_remove = tracker_pb2.Amendment(
+ field=tracker_pb2.FieldID.CC,
+ removed_user_ids=[222])
+ amendment_blockedon = tracker_pb2.Amendment(
+ field=tracker_pb2.FieldID.BLOCKEDON,
+ newvalue='1')
+ amendment_blocking = tracker_pb2.Amendment(
+ field=tracker_pb2.FieldID.BLOCKING,
+ newvalue='other:2 -3')
+ amendment_mergedinto = tracker_pb2.Amendment(
+ field=tracker_pb2.FieldID.MERGEDINTO,
+ newvalue='4')
+ amendments = [
+ amendment_summary, amendment_status, amendment_owner,
+ amendment_labels, amendment_cc_add, amendment_cc_remove,
+ amendment_blockedon, amendment_blocking, amendment_mergedinto]
+
+ result = api_pb2_v1_helpers.convert_amendments(
+ issue, amendments, mar, self.services)
+ self.assertEqual(amendment_summary.newvalue, result.summary)
+ self.assertEqual(amendment_status.newvalue, result.status)
+ self.assertEqual('user@example.com', result.owner)
+ self.assertEqual(['label1', '-label2'], result.labels)
+ self.assertEqual(['user@example.com', '-user2@example.com'], result.cc)
+ self.assertEqual(['test-project:1'], result.blockedOn)
+ self.assertEqual(['other:2', '-test-project:3'], result.blocking)
+ self.assertEqual(amendment_mergedinto.newvalue, result.mergedInto)
+
+ def testConvertApprovalAmendments(self):
+ """Test convert_approval_comment."""
+ self.services.user.TestAddUser('user1@example.com', 111)
+ self.services.user.TestAddUser('user2@example.com', 222)
+ self.services.user.TestAddUser('user3@example.com', 333)
+ mar = mock.Mock()
+ mar.cnxn = None
+ amendment_status = tracker_bizobj.MakeApprovalStatusAmendment(
+ tracker_pb2.ApprovalStatus.APPROVED)
+ amendment_approvers = tracker_bizobj.MakeApprovalApproversAmendment(
+ [111, 222], [333])
+ amendments = [amendment_status, amendment_approvers]
+ result = api_pb2_v1_helpers.convert_approval_amendments(
+ amendments, mar, self.services)
+ self.assertEqual(amendment_status.newvalue, result.status)
+ self.assertEqual(
+ ['user1@example.com', 'user2@example.com', '-user3@example.com'],
+ result.approvers)
+
+ def testConvertComment(self):
+ """Test convert_comment."""
+ mar = mock.Mock()
+ mar.cnxn = None
+ mar.perms = permissions.PermissionSet([])
+ issue = fake.MakeTestIssue(project_id=12345, local_id=1, summary='sum',
+ status='New', owner_id=1001)
+
+ comment = tracker_pb2.IssueComment(
+ user_id=111,
+ content='test content',
+ sequence=1,
+ deleted_by=111,
+ timestamp=1437700000,
+ )
+ result = api_pb2_v1_helpers.convert_comment(
+ issue, comment, mar, self.services, None)
+ self.assertEqual('user@example.com', result.author.name)
+ self.assertEqual(comment.content, result.content)
+ self.assertEqual('user@example.com', result.deletedBy.name)
+ self.assertEqual(1, result.id)
+ # Ensure that the published timestamp falls in a timestamp range to account
+ # for the test being run in different timezones.
+ # Using "Fri, 23 Jul 2015 00:00:00" and "Fri, 25 Jul 2015 00:00:00".
+ self.assertTrue(
+ datetime.datetime(2015, 7, 23, 0, 0, 0) <= result.published <=
+ datetime.datetime(2015, 7, 25, 0, 0, 0))
+ self.assertEqual(result.kind, 'monorail#issueComment')
+
+ def testConvertApprovalComment(self):
+ """Test convert_approval_comment."""
+ mar = mock.Mock()
+ mar.cnxn = None
+ mar.perms = permissions.PermissionSet([])
+ issue = fake.MakeTestIssue(project_id=12345, local_id=1, summary='sum',
+ status='New', owner_id=1001)
+ comment = tracker_pb2.IssueComment(
+ user_id=111,
+ content='test content',
+ sequence=1,
+ deleted_by=111,
+ timestamp=1437700000,
+ )
+ result = api_pb2_v1_helpers.convert_approval_comment(
+ issue, comment, mar, self.services, None)
+ self.assertEqual('user@example.com', result.author.name)
+ self.assertEqual(comment.content, result.content)
+ self.assertEqual('user@example.com', result.deletedBy.name)
+ self.assertEqual(1, result.id)
+ # Ensure that the published timestamp falls in a timestamp range to account
+ # for the test being run in different timezones.
+ # Using "Fri, 23 Jul 2015 00:00:00" and "Fri, 25 Jul 2015 00:00:00".
+ self.assertTrue(
+ datetime.datetime(2015, 7, 23, 0, 0, 0) <= result.published <=
+ datetime.datetime(2015, 7, 25, 0, 0, 0))
+ self.assertEqual(result.kind, 'monorail#approvalComment')
+
+
+ def testGetUserEmail(self):
+ email = api_pb2_v1_helpers._get_user_email(self.services.user, '', 111)
+ self.assertEqual('user@example.com', email)
+
+ no_user_found = api_pb2_v1_helpers._get_user_email(
+ self.services.user, '', 222)
+ self.assertEqual(framework_constants.USER_NOT_FOUND_NAME, no_user_found)
+
+ deleted = api_pb2_v1_helpers._get_user_email(
+ self.services.user, '', framework_constants.DELETED_USER_ID)
+ self.assertEqual(framework_constants.DELETED_USER_NAME, deleted)
+
+ none_user_id = api_pb2_v1_helpers._get_user_email(
+ self.services.user, '', None)
+ self.assertEqual(framework_constants.NO_USER_NAME, none_user_id)
+
+ def testSplitRemoveAdd(self):
+ """Test split_remove_add."""
+
+ items = ['1', '-2', '-3', '4']
+ list_to_add, list_to_remove = api_pb2_v1_helpers.split_remove_add(items)
+
+ self.assertEqual(['1', '4'], list_to_add)
+ self.assertEqual(['2', '3'], list_to_remove)
+
+ def testIssueGlobalIDs(self):
+ """Test issue_global_ids."""
+ issue1 = fake.MakeTestIssue(12345, 1, 'one', 'New', 111)
+ self.services.issue.TestAddIssue(issue1)
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[2],
+ project_id=12345)
+ mar = mock.Mock()
+ mar.cnxn = None
+ mar.project_name = 'test-project'
+ mar.project_id = 12345
+ pairs = ['test-project:1']
+ result = api_pb2_v1_helpers.issue_global_ids(
+ pairs, 12345, mar, self.services)
+ self.assertEqual(100001, result[0])
+
+ def testConvertGroupSettings(self):
+ """Test convert_group_settings."""
+
+ setting = usergroup_pb2.MakeSettings('owners', 'mdb', 0)
+ result = api_pb2_v1_helpers.convert_group_settings('test-group', setting)
+ self.assertEqual('test-group', result.groupName)
+ self.assertEqual(setting.who_can_view_members, result.who_can_view_members)
+ self.assertEqual(setting.ext_group_type, result.ext_group_type)
+ self.assertEqual(setting.last_sync_time, result.last_sync_time)
+
+ def testConvertComponentDef(self):
+ pass # TODO(jrobbins): Fill in this test.
+
+ def testConvertComponentIDs(self):
+ pass # TODO(jrobbins): Fill in this test.
+
+ def testConvertFieldValues_Empty(self):
+ """The client's request might not have any field edits."""
+ mar = mock.Mock()
+ mar.config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+
+ field_values = []
+ actual = api_pb2_v1_helpers.convert_field_values(
+ field_values, mar, self.services)
+ (fv_list_add, fv_list_remove, fv_list_clear,
+ label_list_add, label_list_remove) = actual
+ self.assertEqual([], fv_list_add)
+ self.assertEqual([], fv_list_remove)
+ self.assertEqual([], fv_list_clear)
+ self.assertEqual([], label_list_add)
+ self.assertEqual([], label_list_remove)
+
+ def testConvertFieldValues_Normal(self):
+ """The client wants to edit a custom field."""
+ mar = mock.Mock()
+ mar.config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+ mar.config.field_defs = [
+ tracker_bizobj.MakeFieldDef(
+ 1, 789, 'Priority', tracker_pb2.FieldTypes.ENUM_TYPE, None, None,
+ False, False, False, None, None, None, False, None, None, None,
+ None, 'doc', False),
+ tracker_bizobj.MakeFieldDef(
+ 2, 789, 'EstDays', tracker_pb2.FieldTypes.INT_TYPE, None, None,
+ False, False, False, 0, 99, None, False, None, None, None,
+ None, 'doc', False),
+ tracker_bizobj.MakeFieldDef(
+ 3, 789, 'Nickname', tracker_pb2.FieldTypes.STR_TYPE, None, None,
+ False, False, False, None, None, None, False, None, None, None,
+ None, 'doc', False),
+ tracker_bizobj.MakeFieldDef(
+ 4, 789, 'Verifier', tracker_pb2.FieldTypes.USER_TYPE, None, None,
+ False, False, False, None, None, None, False, None, None, None,
+ None, 'doc', False),
+ tracker_bizobj.MakeFieldDef(
+ 5, 789, 'Deadline', tracker_pb2.FieldTypes.DATE_TYPE, None, None,
+ False, False, False, None, None, None, False, None, None, None,
+ None, 'doc', False),
+ tracker_bizobj.MakeFieldDef(
+ 6, 789, 'Homepage', tracker_pb2.FieldTypes.URL_TYPE, None, None,
+ False, False, False, None, None, None, False, None, None, None,
+ None, 'doc', False),
+ ]
+ field_values = [
+ api_pb2_v1.FieldValue(fieldName='Priority', fieldValue='High'),
+ api_pb2_v1.FieldValue(fieldName='EstDays', fieldValue='4'),
+ api_pb2_v1.FieldValue(fieldName='Nickname', fieldValue='Scout'),
+ api_pb2_v1.FieldValue(
+ fieldName='Verifier', fieldValue='user@example.com'),
+ api_pb2_v1.FieldValue(fieldName='Deadline', fieldValue='2017-12-06'),
+ api_pb2_v1.FieldValue(
+ fieldName='Homepage', fieldValue='http://example.com'),
+ ]
+ actual = api_pb2_v1_helpers.convert_field_values(
+ field_values, mar, self.services)
+ (fv_list_add, fv_list_remove, fv_list_clear,
+ label_list_add, label_list_remove) = actual
+ self.assertEqual(
+ [
+ tracker_bizobj.MakeFieldValue(2, 4, None, None, None, None, False),
+ tracker_bizobj.MakeFieldValue(
+ 3, None, 'Scout', None, None, None, False),
+ tracker_bizobj.MakeFieldValue(
+ 4, None, None, 111, None, None, False),
+ tracker_bizobj.MakeFieldValue(
+ 5, None, None, None, 1512518400, None, False),
+ tracker_bizobj.MakeFieldValue(
+ 6, None, None, None, None, 'http://example.com', False),
+ ], fv_list_add)
+ self.assertEqual([], fv_list_remove)
+ self.assertEqual([], fv_list_clear)
+ self.assertEqual(['Priority-High'], label_list_add)
+ self.assertEqual([], label_list_remove)
+
+ def testConvertFieldValues_ClearAndRemove(self):
+ """The client wants to clear and remove some custom fields."""
+ mar = mock.Mock()
+ mar.config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+ mar.config.field_defs = [
+ tracker_bizobj.MakeFieldDef(
+ 1, 789, 'Priority', tracker_pb2.FieldTypes.ENUM_TYPE, None, None,
+ False, False, False, None, None, None, False, None, None, None,
+ None, 'doc', False),
+ tracker_bizobj.MakeFieldDef(
+ 11, 789, 'OS', tracker_pb2.FieldTypes.ENUM_TYPE, None, None,
+ False, False, False, None, None, None, False, None, None, None,
+ None, 'doc', False),
+ tracker_bizobj.MakeFieldDef(
+ 2, 789, 'EstDays', tracker_pb2.FieldTypes.INT_TYPE, None, None,
+ False, False, False, 0, 99, None, False, None, None, None,
+ None, 'doc', False),
+ tracker_bizobj.MakeFieldDef(
+ 3, 789, 'Nickname', tracker_pb2.FieldTypes.STR_TYPE, None, None,
+ False, False, False, None, None, None, False, None, None, None,
+ None, 'doc', False),
+ ]
+ field_values = [
+ api_pb2_v1.FieldValue(
+ fieldName='Priority', fieldValue='High',
+ operator=api_pb2_v1.FieldValueOperator.remove),
+ api_pb2_v1.FieldValue(
+ fieldName='OS', operator=api_pb2_v1.FieldValueOperator.clear),
+ api_pb2_v1.FieldValue(
+ fieldName='EstDays', operator=api_pb2_v1.FieldValueOperator.clear),
+ api_pb2_v1.FieldValue(
+ fieldName='Nickname', fieldValue='Scout',
+ operator=api_pb2_v1.FieldValueOperator.remove),
+ ]
+ actual = api_pb2_v1_helpers.convert_field_values(
+ field_values, mar, self.services)
+ (fv_list_add, fv_list_remove, fv_list_clear,
+ label_list_add, label_list_remove) = actual
+ self.assertEqual([], fv_list_add)
+ self.assertEqual(
+ [
+ tracker_bizobj.MakeFieldValue(
+ 3, None, 'Scout', None, None, None, False)
+ ], fv_list_remove)
+ self.assertEqual([11, 2], fv_list_clear)
+ self.assertEqual([], label_list_add)
+ self.assertEqual(['Priority-High'], label_list_remove)
+
+ def testConvertFieldValues_Errors(self):
+ """We don't crash on bad requests."""
+ mar = mock.Mock()
+ mar.config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+ mar.config.field_defs = [
+ tracker_bizobj.MakeFieldDef(
+ 2, 789, 'EstDays', tracker_pb2.FieldTypes.INT_TYPE, None, None,
+ False, False, False, 0, 99, None, False, None, None, None,
+ None, 'doc', False),
+ ]
+ field_values = [
+ api_pb2_v1.FieldValue(
+ fieldName='Unknown', operator=api_pb2_v1.FieldValueOperator.clear),
+ ]
+ actual = api_pb2_v1_helpers.convert_field_values(
+ field_values, mar, self.services)
+ (fv_list_add, fv_list_remove, fv_list_clear,
+ label_list_add, label_list_remove) = actual
+ self.assertEqual([], fv_list_add)
+ self.assertEqual([], fv_list_remove)
+ self.assertEqual([], fv_list_clear)
+ self.assertEqual([], label_list_add)
+ self.assertEqual([], label_list_remove)
+
+ def testConvertApprovals(self):
+ """Test we can convert ApprovalValues."""
+ cnxn = None
+ config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+ config.field_defs = [
+ tracker_bizobj.MakeFieldDef(
+ 1, 789, 'DesignReview', tracker_pb2.FieldTypes.APPROVAL_TYPE, None,
+ None, False, False, False, None, None, None, False, None, None,
+ None, None, 'doc', False),
+ tracker_bizobj.MakeFieldDef(
+ 2, 789, 'PrivacyReview', tracker_pb2.FieldTypes.APPROVAL_TYPE, None,
+ None, False, False, False, 0, 99, None, False, None, None, None,
+ None, 'doc', False),
+ tracker_bizobj.MakeFieldDef(
+ 5, 789, 'UXReview', tracker_pb2.FieldTypes.APPROVAL_TYPE, None,
+ None, False, False, False, None, None, None, False, None, None,
+ None, None, 'doc', False),
+ tracker_bizobj.MakeFieldDef(
+ 6, 789, 'Homepage', tracker_pb2.FieldTypes.URL_TYPE, None, None,
+ False, False, False, None, None, None, False, None, None, None,
+ None, 'doc', False),
+ ]
+ phases = [
+ tracker_pb2.Phase(phase_id=1),
+ tracker_pb2.Phase(phase_id=2, name="JustAPhase", rank=3),
+ ]
+ ts = 1536260059
+ expected = [
+ api_pb2_v1.Approval(
+ approvalName="DesignReview",
+ approvers=[self.person_1],
+ setter=self.person_1,
+ status=api_pb2_v1.ApprovalStatus.needsReview,
+ setOn=datetime.datetime.fromtimestamp(ts),
+ ),
+ api_pb2_v1.Approval(
+ approvalName="UXReview",
+ approvers=[self.person_1],
+ status=api_pb2_v1.ApprovalStatus.notSet,
+ phaseName="JustAPhase",
+ ),
+ ]
+ avs = [
+ tracker_pb2.ApprovalValue(
+ approval_id=1, approver_ids=[111], setter_id=111,
+ status=tracker_pb2.ApprovalStatus.NEEDS_REVIEW, set_on=ts),
+ tracker_pb2.ApprovalValue(
+ approval_id=5, approver_ids=[111], phase_id=2)
+ ]
+ actual = api_pb2_v1_helpers.convert_approvals(
+ cnxn, avs, self.services, config, phases)
+
+ self.assertEqual(actual, expected)
+
+ def testConvertApprovals_errors(self):
+ """we dont crash on bad requests."""
+ cnxn = None
+ config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+ config.field_defs = [
+ tracker_bizobj.MakeFieldDef(
+ 1, 789, 'DesignReview', tracker_pb2.FieldTypes.APPROVAL_TYPE, None,
+ None, False, False, False, None, None, None, False, None, None,
+ None, None, 'doc', False),
+ tracker_bizobj.MakeFieldDef(
+ 5, 789, 'UXReview', tracker_pb2.FieldTypes.APPROVAL_TYPE, None,
+ None, False, False, False, None, None, None, False, None, None,
+ None, None, 'doc', False),
+ tracker_bizobj.MakeFieldDef(
+ 3, 789, 'DesignDoc', tracker_pb2.FieldTypes.URL_TYPE, None, None,
+ False, False, False, 0, 99, None, False, None, None, None,
+ None, 'doc', False),
+ ]
+ phases = []
+ avs = [
+ tracker_pb2.ApprovalValue(approval_id=1, approver_ids=[111]),
+ # phase does not exist
+ tracker_pb2.ApprovalValue(approval_id=2, phase_id=2),
+ tracker_pb2.ApprovalValue(approval_id=3), # field 3 is not an approval
+ tracker_pb2.ApprovalValue(approval_id=4), # field 4 does not exist
+ ]
+ expected = [
+ api_pb2_v1.Approval(
+ approvalName="DesignReview",
+ approvers=[self.person_1],
+ status=api_pb2_v1.ApprovalStatus.notSet)
+ ]
+
+ actual = api_pb2_v1_helpers.convert_approvals(
+ cnxn, avs, self.services, config, phases)
+ self.assertEqual(actual, expected)
+
+ def testConvertPhases(self):
+ """We can convert Phases."""
+ phases = [
+ tracker_pb2.Phase(name="JustAPhase", rank=1),
+ tracker_pb2.Phase(name="Can'tPhaseMe", rank=4),
+ tracker_pb2.Phase(phase_id=11, rank=5),
+ tracker_pb2.Phase(rank=3),
+ tracker_pb2.Phase(name="Phase"),
+ ]
+ expected = [
+ api_pb2_v1.Phase(phaseName="JustAPhase", rank=1),
+ api_pb2_v1.Phase(phaseName="Can'tPhaseMe", rank=4),
+ api_pb2_v1.Phase(phaseName="Phase"),
+ ]
+ actual = api_pb2_v1_helpers.convert_phases(phases)
+ self.assertEqual(actual, expected)
diff --git a/services/test/api_svc_v1_test.py b/services/test/api_svc_v1_test.py
new file mode 100644
index 0000000..b7cd9b1
--- /dev/null
+++ b/services/test/api_svc_v1_test.py
@@ -0,0 +1,1898 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for the API v1."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import datetime
+import endpoints
+import logging
+from mock import Mock, patch, ANY
+import time
+import unittest
+import webtest
+
+from google.appengine.api import oauth
+from protorpc import messages
+from protorpc import message_types
+
+from features import send_notifications
+from framework import authdata
+from framework import exceptions
+from framework import framework_constants
+from framework import permissions
+from framework import profiler
+from framework import template_helpers
+from proto import api_pb2_v1
+from proto import project_pb2
+from proto import tracker_pb2
+from search import frontendsearchpipeline
+from services import api_svc_v1
+from services import service_manager
+from services import template_svc
+from services import tracker_fulltext
+from testing import fake
+from testing import testing_helpers
+from testing_utils import testing
+from tracker import tracker_bizobj
+from tracker import tracker_constants
+
+
+def MakeFakeServiceManager():
+ return service_manager.Services(
+ user=fake.UserService(),
+ usergroup=fake.UserGroupService(),
+ project=fake.ProjectService(),
+ config=fake.ConfigService(),
+ issue=fake.IssueService(),
+ issue_star=fake.IssueStarService(),
+ features=fake.FeaturesService(),
+ template=Mock(spec=template_svc.TemplateService),
+ cache_manager=fake.CacheManager())
+
+
+class FakeMonorailApiRequest(object):
+ """Fake API request ("mar") built from a plain request dict.
+
+ Reads 'requester', 'userId' or 'groupName', optional 'projectId', and the
+ search keys ('can', 'startIndex', 'maxResults', 'q', 'sort',
+ 'additionalProject') from `request`.
+ """
+
+ def __init__(self, request, services, perms=None):
+ self.profiler = profiler.Profiler()
+ self.cnxn = None
+ self.auth = authdata.AuthData.FromEmail(
+ self.cnxn, request['requester'], services)
+ self.me_user_id = self.auth.user_id
+ self.project_name = None
+ self.project = None
+ self.viewed_username = None
+ self.viewed_user_auth = None
+ self.config = None
+ # The viewed entity is either a user ('userId') or a group ('groupName').
+ if 'userId' in request:
+ self.viewed_username = request['userId']
+ self.viewed_user_auth = authdata.AuthData.FromEmail(
+ self.cnxn, self.viewed_username, services)
+ else:
+ assert 'groupName' in request
+ self.viewed_username = request['groupName']
+ # A group name need not correspond to a user account.
+ try:
+ self.viewed_user_auth = authdata.AuthData.FromEmail(
+ self.cnxn, self.viewed_username, services)
+ except exceptions.NoSuchUserException:
+ self.viewed_user_auth = None
+ if 'projectId' in request:
+ self.project_name = request['projectId']
+ self.project = services.project.GetProjectByName(
+ self.cnxn, self.project_name)
+ # Uses the project_id property, so must run after self.project is set.
+ self.config = services.config.GetProjectConfig(
+ self.cnxn, self.project_id)
+ self.perms = perms or permissions.GetPermissions(
+ self.auth.user_pb, self.auth.effective_ids, self.project)
+ self.granted_perms = set()
+
+ # Emulates query-string params; read back via GetParam() below.
+ self.params = {
+ 'can': request.get('can', 1),
+ 'start': request.get('startIndex', 0),
+ 'num': request.get('maxResults', 100),
+ 'q': request.get('q', ''),
+ 'sort': request.get('sort', ''),
+ 'groupby': '',
+ 'projects': request.get('additionalProject', []) + [self.project_name]}
+ self.use_cached_searches = True
+ self.errors = template_helpers.EZTError()
+ self.mode = None
+
+ self.query_project_names = self.GetParam('projects')
+ self.group_by_spec = self.GetParam('groupby')
+ self.sort_spec = self.GetParam('sort')
+ self.query = self.GetParam('q')
+ self.can = self.GetParam('can')
+ self.start = self.GetParam('start')
+ self.num = self.GetParam('num')
+ self.warnings = []
+
+ def CleanUp(self):
+ """Drop the (fake, always None) DB connection."""
+ self.cnxn = None
+
+ @property
+ def project_id(self):
+ """ID of the requested project, or None when no project was found."""
+ return self.project.project_id if self.project else None
+
+ def GetParam(self, query_param_name, default_value=None,
+ _antitamper_re=None):
+ """Look up a param in self.params; _antitamper_re is accepted but ignored."""
+ return self.params.get(query_param_name, default_value)
+
+
+class FakeFrontendSearchPipeline(object):
+
+ def __init__(self):
+ issue1 = fake.MakeTestIssue(
+ project_id=12345, local_id=1, owner_id=222, status='New', summary='sum')
+ issue2 = fake.MakeTestIssue(
+ project_id=12345, local_id=2, owner_id=222, status='New', summary='sum')
+ self.allowed_results = [issue1, issue2]
+ self.visible_results = [issue1]
+ self.total_count = len(self.allowed_results)
+ self.config = None
+ self.projectId = 0
+
+ def SearchForIIDs(self):
+ pass
+
+ def MergeAndSortIssues(self):
+ pass
+
+ def Paginate(self):
+ pass
+
+
+class MonorailApiBadAuthTest(testing.EndpointsTestCase):
+
+ api_service_cls = api_svc_v1.MonorailApi
+
+ def setUp(self):
+ super(MonorailApiBadAuthTest, self).setUp()
+ self.requester = RequesterMock(email='requester@example.com')
+ self.mock(endpoints, 'get_current_user', lambda: None)
+ self.request = {'userId': 'user@example.com'}
+
+ def testUsersGet_BadOAuth(self):
+ """The requester's token is invalid, e.g., because it expired."""
+ oauth.get_current_user = Mock(
+ return_value=RequesterMock(email='test@example.com'))
+ oauth.get_current_user.side_effect = oauth.Error()
+ with self.assertRaises(webtest.AppError) as cm:
+ self.call_api('users_get', self.request)
+ self.assertTrue(cm.exception.message.startswith('Bad response: 401'))
+
+
+class MonorailApiTest(testing.EndpointsTestCase):
+
+ api_service_cls = api_svc_v1.MonorailApi
+
+ def setUp(self):
+ super(MonorailApiTest, self).setUp()
+ # Load queue.yaml.
+ # NOTE(review): nothing below loads queue.yaml; this presumably refers to
+ # the base-class setUp above — confirm or drop the comment.
+ self.requester = RequesterMock(email='requester@example.com')
+ self.mock(endpoints, 'get_current_user', lambda: self.requester)
+ self.config = None
+ self.services = MakeFakeServiceManager()
+ self.mock(api_svc_v1.MonorailApi, '_services', self.services)
+ # User ids used throughout: 111 = requester, 222 = viewed user,
+ # 123 = a user group.
+ self.services.user.TestAddUser('requester@example.com', 111)
+ self.services.user.TestAddUser('user@example.com', 222)
+ self.services.user.TestAddUser('group@example.com', 123)
+ self.services.usergroup.TestAddGroupSettings(123, 'group@example.com')
+ self.request = {
+ 'userId': 'user@example.com',
+ 'ownerProjectsOnly': False,
+ 'requester': 'requester@example.com',
+ 'projectId': 'test-project',
+ 'issueId': 1}
+ # The factory ignores its arguments and reads self.request at call time,
+ # so tests can mutate self.request before calling the API.
+ self.mock(api_svc_v1.MonorailApi, 'mar_factory',
+ lambda x, y, z: FakeMonorailApiRequest(
+ self.request, self.services))
+
+ # api_base_checks is tested in AllBaseChecksTest,
+ # so mock it to reduce noise.
+ self.mock(api_svc_v1, 'api_base_checks',
+ lambda x, y, z, u, v, w: ('id', 'email'))
+
+ # Full-text indexing is stubbed out.
+ self.mock(tracker_fulltext, 'IndexIssues', lambda x, y, z, u, v: None)
+
+ def SetUpComponents(
+ self, project_id, component_id, component_name, component_doc='doc',
+ deprecated=False, admin_ids=None, cc_ids=None, created=100000,
+ creator=111):
+ admin_ids = admin_ids or []
+ cc_ids = cc_ids or []
+ self.config = self.services.config.GetProjectConfig(
+ 'fake cnxn', project_id)
+ self.services.config.StoreConfig('fake cnxn', self.config)
+ cd = tracker_bizobj.MakeComponentDef(
+ component_id, project_id, component_name, component_doc, deprecated,
+ admin_ids, cc_ids, created, creator, modifier_id=creator)
+ self.config.component_defs.append(cd)
+
+ def SetUpFieldDefs(
+ self, field_id, project_id, field_name, field_type_int,
+ min_value=0, max_value=100, needs_member=False, docstring='doc',
+ approval_id=None, is_phase_field=False):
+ self.config = self.services.config.GetProjectConfig(
+ 'fake cnxn', project_id)
+ self.services.config.StoreConfig('fake cnxn', self.config)
+ fd = tracker_bizobj.MakeFieldDef(
+ field_id, project_id, field_name, field_type_int, '',
+ '', False, False, False, min_value, max_value, None, needs_member,
+ None, '', tracker_pb2.NotifyTriggers.NEVER, 'no_action', docstring,
+ False, approval_id=approval_id, is_phase_field=is_phase_field)
+ self.config.field_defs.append(fd)
+
+ def testUsersGet_NoProject(self):
+ """The viewed user has no projects."""
+
+ # The only project belongs to the requester (111), not the viewed user
+ # (222), so the response has no 'projects' key.
+ self.services.project.TestAddProject(
+ 'public-project', owner_ids=[111])
+ resp = self.call_api('users_get', self.request).json_body
+ expected = {
+ 'id': '222',
+ 'kind': 'monorail#user'}
+ self.assertEqual(expected, resp)
+
+ def testUsersGet_PublicProject(self):
+ """The viewed user has one public project."""
+ # Stub templates: project conversion for the response takes the
+ # project's templates.
+ self.services.template.GetProjectTemplates.return_value = \
+ testing_helpers.DefaultTemplates()
+ self.services.project.TestAddProject(
+ 'public-project', owner_ids=[222])
+ resp = self.call_api('users_get', self.request).json_body
+
+ self.assertEqual(1, len(resp['projects']))
+ self.assertEqual('public-project', resp['projects'][0]['name'])
+
+ def testUsersGet_PrivateProject(self):
+ """The viewed user has one project but the requester cannot view."""
+
+ # MEMBERS_ONLY project and the requester (111) is not a member.
+ self.services.project.TestAddProject(
+ 'private-project', owner_ids=[222],
+ access=project_pb2.ProjectAccess.MEMBERS_ONLY)
+ resp = self.call_api('users_get', self.request).json_body
+ self.assertNotIn('projects', resp)
+
+ def testUsersGet_OwnerProjectOnly(self):
+ """The viewed user has different roles of projects."""
+ self.services.template.GetProjectTemplates.return_value = \
+ testing_helpers.DefaultTemplates()
+ # 222 owns one project and is a committer in another.
+ self.services.project.TestAddProject(
+ 'owner-project', owner_ids=[222])
+ self.services.project.TestAddProject(
+ 'member-project', owner_ids=[111], committer_ids=[222])
+ resp = self.call_api('users_get', self.request).json_body
+ self.assertEqual(2, len(resp['projects']))
+
+ # With ownerProjectsOnly set, only the owned project is returned.
+ self.request['ownerProjectsOnly'] = True
+ resp = self.call_api('users_get', self.request).json_body
+ self.assertEqual(1, len(resp['projects']))
+ self.assertEqual('owner-project', resp['projects'][0]['name'])
+
+ def testIssuesGet_GetIssue(self):
+ """Get the requested issue."""
+
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222],
+ project_id=12345)
+ self.SetUpComponents(12345, 1, 'API')
+ self.SetUpFieldDefs(1, 12345, 'Field1', tracker_pb2.FieldTypes.INT_TYPE)
+
+ fv = tracker_pb2.FieldValue(
+ field_id=1,
+ int_value=11)
+ issue1 = fake.MakeTestIssue(
+ project_id=12345, local_id=1, owner_id=222, reporter_id=111,
+ status='New', summary='sum', component_ids=[1], field_values=[fv])
+ self.services.issue.TestAddIssue(issue1)
+
+ resp = self.call_api('issues_get', self.request).json_body
+ self.assertEqual(1, resp['id'])
+ self.assertEqual('New', resp['status'])
+ self.assertEqual('open', resp['state'])
+ # The requester (111) is not a member of test-project (owned by 222):
+ # it may comment but not edit.
+ self.assertFalse(resp['canEdit'])
+ self.assertTrue(resp['canComment'])
+ self.assertEqual('requester@example.com', resp['author']['name'])
+ self.assertEqual('user@example.com', resp['owner']['name'])
+ self.assertEqual('API', resp['components'][0])
+ self.assertEqual('Field1', resp['fieldValues'][0]['fieldName'])
+ # Int field values are returned as strings.
+ self.assertEqual('11', resp['fieldValues'][0]['fieldValue'])
+
+ def testIssuesInsert_BadRequest(self):
+ """The request does not specify summary or status."""
+
+ # Case 1: no summary/status (and no project exists yet).
+ with self.assertRaises(webtest.AppError):
+ self.call_api('issues_insert', self.request)
+
+ # Case 2: the owner is not a known user.
+ issue_dict = {
+ 'status': 'New',
+ 'summary': 'Test issue',
+ 'owner': {'name': 'notexist@example.com'}}
+ self.request.update(issue_dict)
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222],
+ project_id=12345)
+ with self.call_should_fail(400):
+ self.call_api('issues_insert', self.request)
+
+ # Invalid field value
+ # Case 3: 111 exceeds SetUpFieldDefs' default max_value of 100.
+ self.SetUpFieldDefs(1, 12345, 'Field1', tracker_pb2.FieldTypes.INT_TYPE)
+ issue_dict = {
+ 'status': 'New',
+ 'summary': 'Test issue',
+ 'owner': {'name': 'requester@example.com'},
+ 'fieldValues': [{'fieldName': 'Field1', 'fieldValue': '111'}]}
+ self.request.update(issue_dict)
+ with self.call_should_fail(400):
+ self.call_api('issues_insert', self.request)
+
+ def testIssuesInsert_NoPermission(self):
+ """The requester has no permission to create issues."""
+
+ issue_dict = {
+ 'status': 'New',
+ 'summary': 'Test issue'}
+ self.request.update(issue_dict)
+
+ # MEMBERS_ONLY project and the requester (111) is not a member.
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222],
+ access=project_pb2.ProjectAccess.MEMBERS_ONLY,
+ project_id=12345)
+ with self.call_should_fail(403):
+ self.call_api('issues_insert', self.request)
+
+ @patch('framework.cloud_tasks_helpers.create_task')
+ def testIssuesInsert_CreateIssue(self, _create_task_mock):
+ """Create an issue as requested."""
+
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222], committer_ids=[111], project_id=12345)
+ self.SetUpFieldDefs(1, 12345, 'Field1', tracker_pb2.FieldTypes.INT_TYPE)
+
+ # Existing issue 1 is the target of the new issue's blockedOn.
+ issue1 = fake.MakeTestIssue(
+ project_id=12345, local_id=1, owner_id=222, reporter_id=111,
+ status='New', summary='Test issue')
+ self.services.issue.TestAddIssue(issue1)
+
+ # cc mixes a real user with blank and whitespace-only names.
+ issue_dict = {
+ 'blockedOn': [{'issueId': 1}],
+ 'cc': [{'name': 'user@example.com'}, {'name': ''}, {'name': ' '}],
+ 'description': 'description',
+ 'labels': ['label1', 'label2'],
+ 'owner': {'name': 'requester@example.com'},
+ 'status': 'New',
+ 'summary': 'Test issue',
+ 'fieldValues': [{'fieldName': 'Field1', 'fieldValue': '11'}]}
+ self.request.update(issue_dict)
+
+ resp = self.call_api('issues_insert', self.request).json_body
+ self.assertEqual('New', resp['status'])
+ self.assertEqual('requester@example.com', resp['author']['name'])
+ self.assertEqual('requester@example.com', resp['owner']['name'])
+ self.assertEqual('user@example.com', resp['cc'][0]['name'])
+ self.assertEqual(1, resp['blockedOn'][0]['issueId'])
+ self.assertEqual([u'label1', u'label2'], resp['labels'])
+ self.assertEqual('Test issue', resp['summary'])
+ self.assertEqual('Field1', resp['fieldValues'][0]['fieldName'])
+ self.assertEqual('11', resp['fieldValues'][0]['fieldValue'])
+
+ new_issue = self.services.issue.GetIssueByLocalID(
+ 'fake cnxn', 12345, resp['id'])
+
+ # The reporter (requester 111) is expected to have starred the new issue.
+ starrers = self.services.issue_star.LookupItemStarrers(
+ 'fake cnxn', new_issue.issue_id)
+ self.assertIn(111, starrers)
+
+ @patch('framework.cloud_tasks_helpers.create_task')
+ def testIssuesInsert_EmptyOwnerCcNames(self, _create_task_mock):
+ """Create an issue as requested."""
+
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222],
+ project_id=12345)
+ self.SetUpFieldDefs(1, 12345, 'Field1', tracker_pb2.FieldTypes.INT_TYPE)
+
+ issue_dict = {
+ 'cc': [{'name': 'user@example.com'}, {'name': ''}],
+ 'description': 'description',
+ 'owner': {'name': ''},
+ 'status': 'New',
+ 'summary': 'Test issue'}
+ self.request.update(issue_dict)
+
+ resp = self.call_api('issues_insert', self.request).json_body
+ self.assertEqual('New', resp['status'])
+ self.assertEqual('requester@example.com', resp['author']['name'])
+ self.assertTrue('owner' not in resp)
+ self.assertEqual('user@example.com', resp['cc'][0]['name'])
+ self.assertEqual(len(resp['cc']), 1)
+ self.assertEqual('Test issue', resp['summary'])
+
+ new_issue = self.services.issue.GetIssueByLocalID(
+ 'fake cnxn', 12345, resp['id'])
+ self.assertEqual(new_issue.owner_id, 0)
+
+ def testIssuesList_NoPermission(self):
+ """No permission for additional projects."""
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222],
+ project_id=12345)
+
+ # additionalProject names a MEMBERS_ONLY project the requester (111)
+ # is not a member of.
+ self.services.project.TestAddProject(
+ 'test-project2', owner_ids=[222],
+ access=project_pb2.ProjectAccess.MEMBERS_ONLY,
+ project_id=123456)
+ self.request['additionalProject'] = ['test-project2']
+ with self.call_should_fail(403):
+ self.call_api('issues_list', self.request)
+
+ def testIssuesList_SearchIssues(self):
+ """Find issues of one project."""
+
+ # Swap in a pipeline with 2 allowed results of which only issue 1 is
+ # visible; the lambda's signature mirrors the real constructor.
+ self.mock(
+ frontendsearchpipeline,
+ 'FrontendSearchPipeline', lambda cnxn, serv, auth, me, q, q_proj_names,
+ num, start, can, group_spec, sort_spec, warnings, errors, use_cache,
+ profiler, project: FakeFrontendSearchPipeline())
+
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[111], # requester
+ access=project_pb2.ProjectAccess.MEMBERS_ONLY,
+ project_id=12345)
+ resp = self.call_api('issues_list', self.request).json_body
+ self.assertEqual(2, int(resp['totalResults']))
+ self.assertEqual(1, len(resp['items']))
+ self.assertEqual(1, resp['items'][0]['id'])
+
+ def testIssuesCommentsList_GetComments(self):
+ """Get comments of requested issue."""
+
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222],
+ project_id=12345)
+
+ issue1 = fake.MakeTestIssue(
+ project_id=12345, local_id=1, summary='test summary', status='New',
+ issue_id=10001, owner_id=222, reporter_id=111)
+ self.services.issue.TestAddIssue(issue1)
+
+ comment = tracker_pb2.IssueComment(
+ id=123, issue_id=10001,
+ project_id=12345, user_id=222,
+ content='this is a comment',
+ timestamp=1437700000)
+ self.services.issue.TestAddComment(comment, 1)
+
+ # Two comments are expected: the initial description (by reporter 111,
+ # content = summary, per the assertions below) plus the one added above.
+ resp = self.call_api('issues_comments_list', self.request).json_body
+ self.assertEqual(2, resp['totalResults'])
+ comment1 = resp['items'][0]
+ comment2 = resp['items'][1]
+ self.assertEqual('requester@example.com', comment1['author']['name'])
+ self.assertEqual('test summary', comment1['content'])
+ self.assertEqual('user@example.com', comment2['author']['name'])
+ self.assertEqual('this is a comment', comment2['content'])
+
+ def testParseImportedReporter_Normal(self):
+ """Normal attempt to post a comment under the requester's name."""
+ mar = FakeMonorailApiRequest(self.request, self.services)
+ container = api_pb2_v1.ISSUES_COMMENTS_INSERT_REQUEST_RESOURCE_CONTAINER
+ request = container.body_message_class()
+
+ # No author given: the requester (111) is the reporter, no timestamp.
+ monorail_api = self.api_service_cls()
+ monorail_api._set_services(self.services)
+ reporter_id, timestamp = monorail_api.parse_imported_reporter(mar, request)
+ self.assertEqual(111, reporter_id)
+ self.assertIsNone(timestamp)
+
+ # API users should not need to specify anything for author when posting
+ # as the signed-in user, but it is OK if they specify their own email.
+ request.author = api_pb2_v1.AtomPerson(name='requester@example.com')
+ request.published = datetime.datetime.now() # Ignored
+ monorail_api = self.api_service_cls()
+ monorail_api._set_services(self.services)
+ reporter_id, timestamp = monorail_api.parse_imported_reporter(mar, request)
+ self.assertEqual(111, reporter_id)
+ self.assertIsNone(timestamp)
+
+ def testParseImportedReporter_Import_Allowed(self):
+ """User is importing a comment posted by a different user."""
+ # Grant the requester (111) the ImportComment extra permission.
+ project = self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222], contrib_ids=[111],
+ project_id=12345)
+ project.extra_perms = [project_pb2.Project.ExtraPerms(
+ member_id=111, perms=['ImportComment'])]
+ mar = FakeMonorailApiRequest(self.request, self.services)
+ container = api_pb2_v1.ISSUES_COMMENTS_INSERT_REQUEST_RESOURCE_CONTAINER
+ request = container.body_message_class()
+ request.author = api_pb2_v1.AtomPerson(name='user@example.com')
+ NOW = 1234567890
+ request.published = datetime.datetime.utcfromtimestamp(NOW)
+ monorail_api = self.api_service_cls()
+ monorail_api._set_services(self.services)
+
+ reporter_id, timestamp = monorail_api.parse_imported_reporter(mar, request)
+
+ # The published datetime comes back as the original epoch seconds.
+ self.assertEqual(222, reporter_id) # that is user@
+ self.assertEqual(NOW, timestamp)
+
+ def testParseImportedReporter_Import_NotAllowed(self):
+ """User is importing a comment posted by a different user without perm."""
+ # No project and no ImportComment perm, so attributing to user@ fails.
+ mar = FakeMonorailApiRequest(self.request, self.services)
+ container = api_pb2_v1.ISSUES_COMMENTS_INSERT_REQUEST_RESOURCE_CONTAINER
+ request = container.body_message_class()
+ request.author = api_pb2_v1.AtomPerson(name='user@example.com')
+ NOW = 1234567890
+ # NOTE(review): fromtimestamp (local time) differs from the utc variant
+ # used in the Allowed test; harmless since the value is never compared.
+ request.published = datetime.datetime.fromtimestamp(NOW)
+ monorail_api = self.api_service_cls()
+ monorail_api._set_services(self.services)
+
+ with self.assertRaises(permissions.PermissionException):
+ monorail_api.parse_imported_reporter(mar, request)
+
+ def testIssuesCommentsInsert_ApprovalFields(self):
+ """Attempts to update approval field values are blocked."""
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222],
+ access=project_pb2.ProjectAccess.MEMBERS_ONLY,
+ project_id=12345)
+
+ issue1 = fake.MakeTestIssue(
+ 12345, 1, 'Issue 1', 'New', 2, issue_id=1234501)
+ self.services.issue.TestAddIssue(issue1)
+
+ # Field 2 is a child of approval 1; the request updates it alongside a
+ # plain int field and is expected to be rejected with 403.
+ self.SetUpFieldDefs(
+ 1, 12345, 'Field_int', tracker_pb2.FieldTypes.INT_TYPE)
+ self.SetUpFieldDefs(
+ 2, 12345, 'ApprovalChild', tracker_pb2.FieldTypes.STR_TYPE,
+ approval_id=1)
+
+ self.request['updates'] = {
+ 'fieldValues': [{'fieldName': 'Field_int', 'fieldValue': '11'},
+ {'fieldName': 'ApprovalChild', 'fieldValue': 'str'}]}
+
+ with self.call_should_fail(403):
+ self.call_api('issues_comments_insert', self.request)
+
+ def testIssuesCommentsInsert_NoCommentPermission(self):
+ """No permission to comment an issue."""
+
+ # MEMBERS_ONLY project and the requester (111) is not a member.
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222],
+ access=project_pb2.ProjectAccess.MEMBERS_ONLY,
+ project_id=12345)
+
+ issue1 = fake.MakeTestIssue(
+ 12345, 1, 'Issue 1', 'New', 2)
+ self.services.issue.TestAddIssue(issue1)
+
+ with self.call_should_fail(403):
+ self.call_api('issues_comments_insert', self.request)
+
+ def testIssuesCommentsInsert_CommentPermissionOnly(self):
+ """User has permission to comment, even though they cannot edit."""
+ # The project has no owners, so the requester holds no project role.
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[], project_id=12345)
+
+ issue1 = fake.MakeTestIssue(
+ 12345, 1, 'Issue 1', 'New', 222)
+ self.services.issue.TestAddIssue(issue1)
+
+ # Content only, no 'updates': commenting alone is allowed.
+ self.request['content'] = 'This is just a comment'
+ resp = self.call_api('issues_comments_insert', self.request).json_body
+ self.assertEqual('requester@example.com', resp['author']['name'])
+ self.assertEqual('This is just a comment', resp['content'])
+
+ def testIssuesCommentsInsert_TooLongComment(self):
+ """Too long of a comment to add."""
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[], project_id=12345)
+
+ issue1 = fake.MakeTestIssue(12345, 1, 'Issue 1', 'New', 222)
+ self.services.issue.TestAddIssue(issue1)
+
+ # MAX_COMMENT_CHARS 'c' characters plus one leading and one trailing
+ # space.
+ long_comment = ' ' + 'c' * tracker_constants.MAX_COMMENT_CHARS + ' '
+ self.request['content'] = long_comment
+ with self.call_should_fail(400):
+ self.call_api('issues_comments_insert', self.request)
+
+ def testIssuesCommentsInsert_Amendments_Normal(self):
+ """Insert comments with amendments."""
+
+ # The requester (111) owns the project, so it may edit issues.
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[111],
+ project_id=12345)
+
+ # Issues 2 and 3 are targets of the blockedOn / blocking updates.
+ issue1 = fake.MakeTestIssue(
+ 12345, 1, 'Issue 1', 'New', 222, project_name='test-project')
+ issue2 = fake.MakeTestIssue(
+ 12345, 2, 'Issue 2', 'New', 222, project_name='test-project')
+ issue3 = fake.MakeTestIssue(
+ 12345, 3, 'Issue 3', 'New', 222, project_name='test-project')
+ self.services.issue.TestAddIssue(issue1)
+ self.services.issue.TestAddIssue(issue2)
+ self.services.issue.TestAddIssue(issue3)
+
+ self.request['updates'] = {
+ 'summary': 'new summary',
+ 'status': 'Started',
+ 'owner': 'requester@example.com',
+ 'cc': ['user@example.com'],
+ 'labels': ['add_label', '-remove_label'],
+ 'blockedOn': ['2'],
+ 'blocking': ['3'],
+ }
+ resp = self.call_api('issues_comments_insert', self.request).json_body
+ self.assertEqual('requester@example.com', resp['author']['name'])
+ self.assertEqual('Started', resp['updates']['status'])
+ # No mergedInto was requested, so issue1 stays unmerged.
+ self.assertEqual(0, issue1.merged_into)
+
+ def testIssuesCommentsInsert_Amendments_NoPerms(self):
+ """Can't insert comments using account that lacks permissions."""
+
+ project1 = self.services.project.TestAddProject(
+ 'test-project', owner_ids=[], project_id=12345)
+
+ issue1 = fake.MakeTestIssue(
+ 12345, 1, 'Issue 1', 'New', 222, project_name='test-project')
+ self.services.issue.TestAddIssue(issue1)
+
+ self.request['updates'] = {
+ 'summary': 'new summary',
+ }
+ with self.call_should_fail(403):
+ self.call_api('issues_comments_insert', self.request)
+
+ project1.contributor_ids = [1] # Does not grant edit perm.
+ with self.call_should_fail(403):
+ self.call_api('issues_comments_insert', self.request)
+
+ def testIssuesCommentsInsert_Amendments_BadOwner(self):
+ """Can't set owner to someone who is not a project member."""
+
+ # Only the requester (111) is a member; user@ (222) is not.
+ _project1 = self.services.project.TestAddProject(
+ 'test-project', owner_ids=[111], project_id=12345)
+
+ issue1 = fake.MakeTestIssue(
+ 12345, 1, 'Issue 1', 'New', 222, project_name='test-project')
+ self.services.issue.TestAddIssue(issue1)
+
+ self.request['updates'] = {
+ 'owner': 'user@example.com',
+ }
+ with self.call_should_fail(400):
+ self.call_api('issues_comments_insert', self.request)
+
+ @patch('framework.cloud_tasks_helpers.create_task')
+ def testIssuesCommentsInsert_MergeInto(self, _create_task_mock):
+ """Insert comment that merges an issue into another issue."""
+
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222], committer_ids=[111],
+ project_id=12345)
+
+ issue1 = fake.MakeTestIssue(
+ 12345, 1, 'Issue 1', 'New', 222, project_name='test-project')
+ issue2 = fake.MakeTestIssue(
+ 12345, 2, 'Issue 2', 'New', 222, project_name='test-project')
+ self.services.issue.TestAddIssue(issue1)
+ self.services.issue.TestAddIssue(issue2)
+ # Source and target start with disjoint starrer sets.
+ self.services.issue_star.SetStarsBatch(
+ 'cnxn', 'service', 'config', issue1.issue_id, [111, 222, 333], True)
+ self.services.issue_star.SetStarsBatch(
+ 'cnxn', 'service', 'config', issue2.issue_id, [555], True)
+
+ self.request['updates'] = {
+ 'summary': 'new summary',
+ 'status': 'Duplicate',
+ 'owner': 'requester@example.com',
+ 'cc': ['user@example.com'],
+ 'labels': ['add_label', '-remove_label'],
+ 'mergedInto': '2',
+ }
+ resp = self.call_api('issues_comments_insert', self.request).json_body
+ self.assertEqual('requester@example.com', resp['author']['name'])
+ self.assertEqual('Duplicate', resp['updates']['status'])
+ self.assertEqual(issue2.issue_id, issue1.merged_into)
+ issue2_comments = self.services.issue.GetCommentsForIssue(
+ 'cnxn', issue2.issue_id)
+ self.assertEqual(2, len(issue2_comments)) # description and merge
+ # Merging copies the source's starrers onto the target and leaves the
+ # source's stars in place.
+ # NOTE(review): assertItemsEqual is Python-2-only (assertCountEqual in
+ # Python 3).
+ source_starrers = self.services.issue_star.LookupItemStarrers(
+ 'cnxn', issue1.issue_id)
+ self.assertItemsEqual([111, 222, 333], source_starrers)
+ target_starrers = self.services.issue_star.LookupItemStarrers(
+ 'cnxn', issue2.issue_id)
+ self.assertItemsEqual([111, 222, 333, 555], target_starrers)
+
+ def testIssuesCommentsInsert_CustomFields(self):
+ """Update custom field values."""
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[111],
+ project_id=12345)
+ issue1 = fake.MakeTestIssue(
+ 12345, 1, 'Issue 1', 'New', 222,
+ project_name='test-project')
+ self.services.issue.TestAddIssue(issue1)
+ self.SetUpFieldDefs(
+ 1, 12345, 'Field_int', tracker_pb2.FieldTypes.INT_TYPE)
+ self.SetUpFieldDefs(
+ 2, 12345, 'Field_enum', tracker_pb2.FieldTypes.ENUM_TYPE)
+
+ self.request['updates'] = {
+ 'fieldValues': [{'fieldName': 'Field_int', 'fieldValue': '11'},
+ {'fieldName': 'Field_enum', 'fieldValue': 'str'}]}
+ resp = self.call_api('issues_comments_insert', self.request).json_body
+ # Only the int field's echoed value is checked.
+ self.assertEqual(
+ {'fieldName': 'Field_int', 'fieldValue': '11'},
+ resp['updates']['fieldValues'][0])
+
+ def testIssuesCommentsInsert_IsDescription(self):
+ """Add a new issue description."""
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[111], project_id=12345)
+ issue1 = fake.MakeTestIssue(
+ 12345, 1, 'Issue 1', 'New', 222, project_name='test-project')
+ self.services.issue.TestAddIssue(issue1)
+ # Note: the initially issue description will be "Issue 1".
+
+ # is_description marks the new comment as a replacement description.
+ self.request['content'] = 'new desc'
+ self.request['updates'] = {'is_description': True}
+ resp = self.call_api('issues_comments_insert', self.request).json_body
+ self.assertEqual('new desc', resp['content'])
+ comments = self.services.issue.GetCommentsForIssue('cnxn', issue1.issue_id)
+ self.assertEqual(2, len(comments))
+ self.assertTrue(comments[1].is_description)
+ self.assertEqual('new desc', comments[1].content)
+
+ def testIssuesCommentsInsert_MoveToProject_NoPermsSrc(self):
+ """Don't move issue when user has no perms to edit issue."""
+ # Source project has no owners; destination is owned by the requester.
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[], project_id=12345)
+ issue1 = fake.MakeTestIssue(
+ 12345, 1, 'Issue 1', 'New', 222, labels=[],
+ project_name='test-project')
+ self.services.issue.TestAddIssue(issue1)
+ self.services.project.TestAddProject(
+ 'test-project2', owner_ids=[111], project_id=12346)
+
+ # The user has no permission in test-project.
+ self.request['projectId'] = 'test-project'
+ self.request['updates'] = {
+ 'moveToProject': 'test-project2'}
+ with self.call_should_fail(403):
+ self.call_api('issues_comments_insert', self.request)
+
+ def testIssuesCommentsInsert_MoveToProject_NoPermsDest(self):
+ """Don't move issue to a different project where user has no perms."""
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[111], project_id=12345)
+ issue1 = fake.MakeTestIssue(
+ 12345, 1, 'Issue 1', 'New', 222, labels=[],
+ project_name='test-project')
+ self.services.issue.TestAddIssue(issue1)
+ self.services.project.TestAddProject(
+ 'test-project2', owner_ids=[], project_id=12346)
+
+ # The user has no permission in test-project2.
+ # Unlike the source-side check above, this is reported as 400, not 403.
+ self.request['projectId'] = 'test-project'
+ self.request['updates'] = {
+ 'moveToProject': 'test-project2'}
+ with self.call_should_fail(400):
+ self.call_api('issues_comments_insert', self.request)
+
+ def testIssuesCommentsInsert_MoveToProject_NoSuchProject(self):
+ """Don't move issue to a different project that does not exist."""
+ project1 = self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222], project_id=12345)
+ issue1 = fake.MakeTestIssue(
+ 12345, 1, 'Issue 1', 'New', 222, labels=[],
+ project_name='test-project')
+ self.services.issue.TestAddIssue(issue1)
+
+ # Project doesn't exist.
+ # Make the requester (111) an owner so the failure is not a perm error.
+ project1.owner_ids = [111, 222]
+ self.request['updates'] = {
+ 'moveToProject': 'not exist'}
+ with self.call_should_fail(400):
+ self.call_api('issues_comments_insert', self.request)
+
+ def testIssuesCommentsInsert_MoveToProject_SameProject(self):
+ """Don't move issue to the project it is already in."""
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[111], project_id=12345)
+ issue1 = fake.MakeTestIssue(
+ 12345, 1, 'Issue 1', 'New', 222, labels=[],
+ project_name='test-project')
+ self.services.issue.TestAddIssue(issue1)
+
+ # The issue is already in destination
+ # (self.request['projectId'] is 'test-project' from setUp).
+ self.request['updates'] = {
+ 'moveToProject': 'test-project'}
+ with self.call_should_fail(400):
+ self.call_api('issues_comments_insert', self.request)
+
+ def testIssuesCommentsInsert_MoveToProject_Restricted(self):
+ """Don't move restricted issue to a different project."""
+ # The requester (111) owns both projects, so only the label blocks it.
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[111], project_id=12345)
+ issue1 = fake.MakeTestIssue(
+ 12345, 1, 'Issue 1', 'New', 222, labels=['Restrict-View-Google'],
+ project_name='test-project')
+ self.services.issue.TestAddIssue(issue1)
+ self.services.project.TestAddProject(
+ 'test-project2', owner_ids=[111],
+ project_id=12346)
+
+ # Issue has restrict labels, so it cannot move.
+ self.request['projectId'] = 'test-project'
+ self.request['updates'] = {
+ 'moveToProject': 'test-project2'}
+ with self.call_should_fail(400):
+ self.call_api('issues_comments_insert', self.request)
+
+ def testIssuesCommentsInsert_MoveToProject_Normal(self):
+ """Move issue."""
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[111, 222],
+ project_id=12345)
+ self.services.project.TestAddProject(
+ 'test-project2', owner_ids=[111, 222],
+ project_id=12346)
+ issue1 = fake.MakeTestIssue(
+ 12345, 1, 'Issue 1', 'New', 222, project_name='test-project')
+ self.services.issue.TestAddIssue(issue1)
+ # test-project2 already has local id 1, so the moved issue becomes #2.
+ issue2 = fake.MakeTestIssue(
+ 12346, 1, 'Issue 1', 'New', 222, project_name='test-project2')
+ self.services.issue.TestAddIssue(issue2)
+
+ self.request['updates'] = {
+ 'moveToProject': 'test-project2'}
+ resp = self.call_api('issues_comments_insert', self.request).json_body
+
+ self.assertEqual(
+ 'Moved issue test-project:1 to now be issue test-project2:2.',
+ resp['content'])
+
+ def testIssuesCommentsInsert_Import_Allowed(self):
+ """Post a comment attributed to another user, with permission."""
+ # The requester (111) holds ImportComment, so it may post as user@ (222).
+ project = self.services.project.TestAddProject(
+ 'test-project', committer_ids=[111, 222], project_id=12345)
+ project.extra_perms = [project_pb2.Project.ExtraPerms(
+ member_id=111, perms=['ImportComment'])]
+ issue1 = fake.MakeTestIssue(
+ 12345, 1, 'Issue 1', 'New', 222, project_name='test-project')
+ self.services.issue.TestAddIssue(issue1)
+
+ self.request['author'] = {'name': 'user@example.com'} # 222
+ self.request['content'] = 'a comment'
+ self.request['updates'] = {
+ 'owner': 'user@example.com',
+ }
+
+ resp = self.call_api('issues_comments_insert', self.request).json_body
+
+ self.assertEqual('a comment', resp['content'])
+ comments = self.services.issue.GetCommentsForIssue('cnxn', issue1.issue_id)
+ self.assertEqual(2, len(comments))
+ self.assertEqual(222, comments[1].user_id)
+ self.assertEqual('a comment', comments[1].content)
+
+
+ def testIssuesCommentsInsert_Import_Self(self):
+ """Specifying the comment author is OK if it is the requester."""
+ self.services.project.TestAddProject(
+ 'test-project', committer_ids=[111, 222], project_id=12345)
+ # Note: No ImportComment permission has been granted.
+ issue1 = fake.MakeTestIssue(
+ 12345, 1, 'Issue 1', 'New', 222, project_name='test-project')
+ self.services.issue.TestAddIssue(issue1)
+
+ self.request['author'] = {'name': 'requester@example.com'} # 111
+ self.request['content'] = 'a comment'
+ self.request['updates'] = {
+ 'owner': 'user@example.com',
+ }
+
+ resp = self.call_api('issues_comments_insert', self.request).json_body
+
+ self.assertEqual('a comment', resp['content'])
+ comments = self.services.issue.GetCommentsForIssue('cnxn', issue1.issue_id)
+ self.assertEqual(2, len(comments))
+ self.assertEqual(111, comments[1].user_id)  # the requester itself
+ self.assertEqual('a comment', comments[1].content)
+
+ def testIssuesCommentsInsert_Import_Denied(self):
+ """Cannot post a comment attributed to another user without permission."""
+ self.services.project.TestAddProject(
+ 'test-project', committer_ids=[111, 222], project_id=12345)
+ # Note: No ImportComment permission has been granted.
+ issue1 = fake.MakeTestIssue(
+ 12345, 1, 'Issue 1', 'New', 222, project_name='test-project')
+ self.services.issue.TestAddIssue(issue1)
+
+ self.request['author'] = {'name': 'user@example.com'} # 222
+ self.request['content'] = 'a comment'
+ self.request['updates'] = {
+ 'owner': 'user@example.com',
+ }
+
+ with self.call_should_fail(403):  # requester 111 may not post as 222
+ self.call_api('issues_comments_insert', self.request)
+
+ def testIssuesCommentsDelete_NoComment(self):  # deleting a missing comment index -> 404
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222],
+ project_id=12345)
+ issue1 = fake.MakeTestIssue(
+ project_id=12345, local_id=1, summary='test summary',
+ issue_id=10001, status='New', owner_id=222, reporter_id=222)
+ self.services.issue.TestAddIssue(issue1)
+ self.request['commentId'] = 1  # presumably only comment 0 (the description) exists
+ with self.call_should_fail(404):
+ self.call_api('issues_comments_delete', self.request)
+
+ def testIssuesCommentsDelete_NoDeletePermission(self):  # requester cannot delete others' comment -> 403
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222],
+ project_id=12345)
+ issue1 = fake.MakeTestIssue(
+ project_id=12345, local_id=1, summary='test summary',
+ issue_id=10001, status='New', owner_id=222, reporter_id=222)
+ self.services.issue.TestAddIssue(issue1)
+ self.request['commentId'] = 0  # presumably the description written by reporter 222
+ with self.call_should_fail(403):
+ self.call_api('issues_comments_delete', self.request)
+
+ def testIssuesCommentsDelete_DeleteUndelete(self):  # requester deletes, then undeletes, its own comment
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222],
+ project_id=12345)
+ issue1 = fake.MakeTestIssue(
+ project_id=12345, local_id=1, summary='test summary',
+ issue_id=10001, status='New', owner_id=222, reporter_id=111)
+ self.services.issue.TestAddIssue(issue1)
+ comment = tracker_pb2.IssueComment(
+ id=123, issue_id=10001,
+ project_id=12345, user_id=111,
+ content='this is a comment',
+ timestamp=1437700000)
+ self.services.issue.TestAddComment(comment, 1)
+ self.request['commentId'] = 1
+
+ comments = self.services.issue.GetCommentsForIssue(None, 10001)  # relies on the fake mutating these objects in place
+
+ self.call_api('issues_comments_delete', self.request)
+ self.assertEqual(111, comments[1].deleted_by)  # deleter recorded as the requester
+
+ self.call_api('issues_comments_undelete', self.request)
+ self.assertIsNone(comments[1].deleted_by)
+
+ def approvalRequest(self, approval, request_fields=None, comment=None,
+ issue_labels=None):  # build an approvals request for issue 1; returns (request, issue1)
+ request = {'userId': 'user@example.com',
+ 'requester': 'requester@example.com',
+ 'projectId': 'test-project',
+ 'issueId': 1,
+ 'approvalName': 'Legal-Review',
+ 'sendEmail': False,
+ }
+ if request_fields:
+ request.update(request_fields)  # caller-supplied fields override the defaults
+
+ self.SetUpFieldDefs(  # field def 1, 'Legal-Review', of APPROVAL_TYPE
+ 1, 12345, 'Legal-Review', tracker_pb2.FieldTypes.APPROVAL_TYPE)
+
+ issue1 = fake.MakeTestIssue(
+ 12345, 1, 'Issue 1', 'New', 222, approval_values=[approval],
+ labels=issue_labels)
+ self.services.issue.TestAddIssue(issue1)
+
+ self.services.issue.DeltaUpdateIssueApproval = Mock(return_value=comment)  # tests assert on its call args
+
+ self.mock(api_svc_v1.MonorailApi, 'mar_factory',  # API calls are served from `request`
+ lambda x, y, z: FakeMonorailApiRequest(
+ request, self.services))
+ return request, issue1
+
+ def getFakeComments(self):  # 4 comments: two on approval 1 (one a description), one on approval 2, one on none
+ return [
+ tracker_pb2.IssueComment(
+ id=123, issue_id=1234501, project_id=12345, user_id=111,
+ content='1st comment', timestamp=1437700000, approval_id=1),
+ tracker_pb2.IssueComment(
+ id=223, issue_id=1234501, project_id=12345, user_id=111,
+ content='2nd comment', timestamp=1437700000, approval_id=2),
+ tracker_pb2.IssueComment(
+ id=323, issue_id=1234501, project_id=12345, user_id=111,
+ content='3rd comment', timestamp=1437700000, approval_id=1,
+ is_description=True),
+ tracker_pb2.IssueComment(
+ id=423, issue_id=1234501, project_id=12345, user_id=111,
+ content='4th comment', timestamp=1437700000)]
+
+ def testApprovalsCommentsList_NoViewPermission(self):  # restricted issue hidden from requester -> 403
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222],
+ project_id=12345)
+
+ approval = tracker_pb2.ApprovalValue(approval_id=1)
+ request, _issue = self.approvalRequest(
+ approval, issue_labels=['Restrict-View-Google'])
+
+ with self.call_should_fail(403):
+ self.call_api('approvals_comments_list', request)
+
+ def testApprovalsCommentsList_NoApprovalFound(self):  # unknown approval name -> 400
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222],
+ project_id=12345)
+
+ approval = tracker_pb2.ApprovalValue(approval_id=1)
+ request, _issue = self.approvalRequest(approval)
+ self.config.field_defs = [] # empty field_defs of approval fd
+
+ with self.call_should_fail(400):
+ self.call_api('approvals_comments_list', request)
+
+ def testApprovalsCommentsList(self):
+ """Get comments of requested issue approval."""
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222], project_id=12345)
+ self.services.issue.GetCommentsForIssue = Mock(
+ return_value=self.getFakeComments())
+
+ approval = tracker_pb2.ApprovalValue(approval_id=1)
+ request, _issue = self.approvalRequest(approval)
+
+ response = self.call_api('approvals_comments_list', request).json_body
+ self.assertEqual(response['kind'], 'monorail#approvalCommentList')
+ self.assertEqual(response['totalResults'], 2)  # only the two approval_id=1 comments
+ self.assertEqual(len(response['items']), 2)
+
+ def testApprovalsCommentsList_MaxResults(self):
+ """Get comments of requested issue approval with maxResults."""
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222], project_id=12345)
+ self.services.issue.GetCommentsForIssue = Mock(
+ return_value=self.getFakeComments())
+
+ approval = tracker_pb2.ApprovalValue(approval_id=1)
+ request, _issue = self.approvalRequest(
+ approval, request_fields={'maxResults': 1})
+
+ response = self.call_api('approvals_comments_list', request).json_body
+ self.assertEqual(response['kind'], 'monorail#approvalCommentList')
+ self.assertEqual(response['totalResults'], 2)  # total ignores maxResults
+ self.assertEqual(len(response['items']), 1)
+ self.assertEqual(response['items'][0]['content'], '1st comment')
+
+ @patch('testing.fake.IssueService.GetCommentsForIssue')
+ def testApprovalsCommentsList_StartIndex(self, mockGetComments):
+ """Get comments of requested issue approval with startIndex."""
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222], project_id=12345)
+ mockGetComments.return_value = self.getFakeComments()
+
+ approval = tracker_pb2.ApprovalValue(approval_id=1)
+ request, _issue = self.approvalRequest(
+ approval, request_fields={'startIndex': 1})
+
+ response = self.call_api('approvals_comments_list', request).json_body
+ self.assertEqual(response['kind'], 'monorail#approvalCommentList')
+ self.assertEqual(response['totalResults'], 2)
+ self.assertEqual(len(response['items']), 1)
+ self.assertEqual(response['items'][0]['content'], '3rd comment')  # startIndex=1 skips '1st comment'
+
+ def testApprovalsCommentsInsert_NoCommentPermission(self):
+ """No permission to comment on an issue, including approvals."""
+
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222],
+ access=project_pb2.ProjectAccess.MEMBERS_ONLY,  # requester 111 is not a member
+ project_id=12345)
+
+ approval = tracker_pb2.ApprovalValue(approval_id=1)
+ request, _issue = self.approvalRequest(approval)
+
+ with self.call_should_fail(403):
+ self.call_api('approvals_comments_insert', request)
+
+ def testApprovalsCommentsInsert_TooLongComment(self):
+ """Reject an overly long comment on an approval."""
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222], project_id=12345)
+
+ approval = tracker_pb2.ApprovalValue(approval_id=1)
+ request, _issue = self.approvalRequest(approval)
+
+ long_comment = ' ' + 'c' * tracker_constants.MAX_COMMENT_CHARS + ' '  # MAX_COMMENT_CHARS + 2 chars total
+ request['content'] = long_comment
+ with self.call_should_fail(400):
+ self.call_api('approvals_comments_insert', request)
+
+ def testApprovalsCommentsInsert_NoApprovalDefFound(self):
+ """No approval with approvalName found."""
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222],
+ project_id=12345)
+
+ approval = tracker_pb2.ApprovalValue(approval_id=1)
+ request, _issue = self.approvalRequest(approval)
+ self.config.field_defs = []  # drop the 'Legal-Review' approval field def
+
+ with self.call_should_fail(400):
+ self.call_api('approvals_comments_insert', request)
+
+ # Test wrong field_type is also caught.
+ self.SetUpFieldDefs(
+ 1, 12345, 'Legal-Review', tracker_pb2.FieldTypes.STR_TYPE)
+ with self.call_should_fail(400):
+ self.call_api('approvals_comments_insert', request)
+
+ def testApprovalscommentsInsert_NoIssueFound(self):
+ """No issue found in project."""
+ request = {'userId': 'user@example.com',
+ 'requester': 'requester@example.com',
+ 'projectId': 'test-project',
+ 'issueId': 1,
+ 'approvalName': 'Legal-Review',
+ }
+ # No issue created.
+ with self.call_should_fail(400):  # this endpoint reports the missing issue as 400
+ self.call_api('approvals_comments_insert', request)
+
+ def testApprovalsCommentsInsert_NoIssueApprovalFound(self):
+ """No approval with the given name found in the issue."""
+
+ request = {'userId': 'user@example.com',
+ 'requester': 'requester@example.com',
+ 'projectId': 'test-project',
+ 'issueId': 1,
+ 'approvalName': 'Legal-Review',
+ 'sendEmail': False,
+ }
+
+ self.SetUpFieldDefs(  # the approval field def itself does exist
+ 1, 12345, 'Legal-Review', tracker_pb2.FieldTypes.APPROVAL_TYPE)
+
+ # issue 1 does not contain the Legal-Review approval.
+ issue1 = fake.MakeTestIssue(12345, 1, 'Issue 1', 'New', 222)
+ self.services.issue.TestAddIssue(issue1)
+
+ with self.call_should_fail(400):
+ self.call_api('approvals_comments_insert', request)
+
+ def testApprovalsCommentsInsert_FieldValueChanges_NotFound(self):
+ """Approval's subfield value not found."""
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222],
+ project_id=12345)
+ approval = tracker_pb2.ApprovalValue(approval_id=1)
+
+ request, _issue = self.approvalRequest(
+ approval,
+ request_fields={
+ 'approvalUpdates': {
+ 'fieldValues': [
+ {'fieldName': 'DoesNotExist', 'fieldValue': 'cow'}]  # no such field def yet
+ },
+ })
+ with self.call_should_fail(400):
+ self.call_api('approvals_comments_insert', request)
+
+ # Test field belongs to another approval
+ self.config.field_defs.append(
+ tracker_bizobj.MakeFieldDef(
+ 2, 12345, 'DoesNotExist', tracker_pb2.FieldTypes.STR_TYPE,
+ '', '', False, False, False, None, None, None, False,
+ None, '', tracker_pb2.NotifyTriggers.NEVER, 'no_action',
+ 'parent approval is wrong', False, approval_id=4))  # not approval 1
+ with self.call_should_fail(400):
+ self.call_api('approvals_comments_insert', request)
+
+ @patch('time.time')
+ def testApprovalCommentsInsert_FieldValueChanges(self, mock_time):
+ """Field value changes are properly processed."""
+ test_time = 6789
+ mock_time.return_value = test_time  # pins set_on in the expected delta
+ comment = tracker_pb2.IssueComment(
+ id=123, issue_id=10001,
+ project_id=12345, user_id=111,
+ content='cows moo',
+ timestamp=143770000)
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222], project_id=12345)
+ approval = tracker_pb2.ApprovalValue(
+ approval_id=1, approver_ids=[444])
+
+ request, issue = self.approvalRequest(
+ approval,
+ request_fields={'approvalUpdates': {
+ 'fieldValues': [
+ {'fieldName': 'CowLayerName', 'fieldValue': 'cow'},
+ {'fieldName': 'CowType', 'fieldValue': 'skim'},
+ {'fieldName': 'CowType', 'fieldValue': 'milk'},
+ {'fieldName': 'CowType', 'fieldValue': 'chocolate',
+ 'operator': 'remove'}]
+ }},
+ comment=comment)
+ self.config.field_defs.extend(  # sub-field defs of approval 1
+ [tracker_bizobj.MakeFieldDef(
+ 2, 12345, 'CowLayerName', tracker_pb2.FieldTypes.STR_TYPE,
+ '', '', False, False, False, None, None, None, False,
+ None, '', tracker_pb2.NotifyTriggers.NEVER, 'no_action',
+ 'sub field value of approval 1', False, approval_id=1),
+ tracker_bizobj.MakeFieldDef(
+ 3, 12345, 'CowType', tracker_pb2.FieldTypes.ENUM_TYPE,
+ '', '', False, False, True, None, None, None, False,
+ None, '', tracker_pb2.NotifyTriggers.NEVER, 'no_action',
+ 'enum sub field value of approval 1', False, approval_id=1)])
+
+ response = self.call_api('approvals_comments_insert', request).json_body
+ fvs_add = [tracker_bizobj.MakeFieldValue(  # string sub-field -> field value
+ 2, None, 'cow', None, None, None, False)]
+ labels_add = ['CowType-skim', 'CowType-milk']  # enum sub-field values -> labels
+ labels_remove = ['CowType-chocolate']
+ approval_delta = tracker_bizobj.MakeApprovalDelta(
+ None, 111, [], [], fvs_add, [], [],
+ labels_add, labels_remove, set_on=test_time)
+ self.services.issue.DeltaUpdateIssueApproval.assert_called_with(
+ None, 111, self.config, issue, approval, approval_delta,
+ comment_content=None, is_description=None)
+ self.assertEqual(response['content'], comment.content)
+
+ @patch('time.time')
+ def testApprovalsCommentsInsert_StatusChanges_Normal(self, mock_time):  # a non-approver may request review
+ test_time = 6789
+ mock_time.return_value = test_time
+ comment = tracker_pb2.IssueComment(
+ id=123, issue_id=10001,
+ project_id=12345, user_id=111, # requester
+ content='this is a comment',
+ timestamp=1437700000,
+ amendments=[tracker_bizobj.MakeApprovalStatusAmendment(
+ tracker_pb2.ApprovalStatus.REVIEW_REQUESTED)])
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222], project_id=12345)
+ approval = tracker_pb2.ApprovalValue(
+ approval_id=1, approver_ids=[444],  # requester 111 is not an approver
+ status=tracker_pb2.ApprovalStatus.NOT_SET)
+
+ request, issue = self.approvalRequest(
+ approval,
+ request_fields={'approvalUpdates': {'status': 'reviewRequested'}},
+ comment=comment)
+ response = self.call_api('approvals_comments_insert', request).json_body
+ approval_delta = tracker_bizobj.MakeApprovalDelta(
+ tracker_pb2.ApprovalStatus.REVIEW_REQUESTED, 111, [], [], [], [], [],
+ [], [], set_on=test_time)
+ self.services.issue.DeltaUpdateIssueApproval.assert_called_with(
+ None, 111, self.config, issue, approval, approval_delta,
+ comment_content=None, is_description=None)
+
+ self.assertEqual(response['author']['name'], 'requester@example.com')
+ self.assertEqual(response['content'], comment.content)
+ self.assertTrue(response['canDelete'])
+ self.assertEqual(response['approvalUpdates'],
+ {'kind': 'monorail#approvalCommentUpdate',
+ 'status': 'reviewRequested'})
+
+ def testApprovalsCommentsInsert_StatusChanges_NoPerms(self):  # a non-approver cannot approve -> 403
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222],
+ project_id=12345)
+ approval = tracker_pb2.ApprovalValue(
+ approval_id=1, approver_ids=[444],
+ status=tracker_pb2.ApprovalStatus.NOT_SET)
+ request, _issue = self.approvalRequest(
+ approval,
+ request_fields={'approvalUpdates': {'status': 'approved'}})
+ with self.call_should_fail(403):
+ self.call_api('approvals_comments_insert', request)
+
+ @patch('time.time')
+ def testApprovalsCommentsInsert_StatusChanges_ApproverPerms(self, mock_time):  # an approver may set notApproved
+ test_time = 6789
+ mock_time.return_value = test_time
+ comment = tracker_pb2.IssueComment(
+ id=123, issue_id=1234501,
+ project_id=12345, user_id=111,
+ content='this is a comment',
+ timestamp=1437700000,
+ amendments=[tracker_bizobj.MakeApprovalStatusAmendment(
+ tracker_pb2.ApprovalStatus.NOT_APPROVED)])
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222],
+ project_id=12345)
+ approval = tracker_pb2.ApprovalValue(
+ approval_id=1, approver_ids=[111], # requester
+ status=tracker_pb2.ApprovalStatus.NOT_SET)
+ request, issue = self.approvalRequest(
+ approval,
+ request_fields={'approvalUpdates': {'status': 'notApproved'}},
+ comment=comment)
+ response = self.call_api('approvals_comments_insert', request).json_body
+
+ approval_delta = tracker_bizobj.MakeApprovalDelta(
+ tracker_pb2.ApprovalStatus.NOT_APPROVED, 111, [], [], [], [], [],
+ [], [], set_on=test_time)
+ self.services.issue.DeltaUpdateIssueApproval.assert_called_with(
+ None, 111, self.config, issue, approval, approval_delta,
+ comment_content=None, is_description=None)
+ self.assertEqual(response['author']['name'], 'requester@example.com')
+ self.assertEqual(response['content'], comment.content)
+ self.assertTrue(response['canDelete'])
+ self.assertEqual(response['approvalUpdates'],
+ {'kind': 'monorail#approvalCommentUpdate',
+ 'status': 'notApproved'})
+
+ def testApprovalsCommentsInsert_ApproverChanges_NoPerms(self):  # a non-approver cannot edit approvers -> 403
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222],
+ project_id=12345)
+
+ approval = tracker_pb2.ApprovalValue(
+ approval_id=1, approver_ids=[444],
+ status=tracker_pb2.ApprovalStatus.NOT_SET)
+ request, _issue = self.approvalRequest(
+ approval,
+ request_fields={'approvalUpdates': {'approvers': 'someone@test.com'}})
+ with self.call_should_fail(403):
+ self.call_api('approvals_comments_insert', request)
+
+ @patch('time.time')
+ def testApprovalsCommentsInsert_ApproverChanges_ApproverPerms(
+ self, mock_time):  # an approver may add and remove approvers
+ test_time = 6789
+ mock_time.return_value = test_time
+ comment = tracker_pb2.IssueComment(
+ id=123, issue_id=1234501,
+ project_id=12345, user_id=111,
+ content='this is a comment',
+ timestamp=1437700000,
+ amendments=[tracker_bizobj.MakeApprovalApproversAmendment(
+ [222], [123])])
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222],
+ project_id=12345)
+
+ approval = tracker_pb2.ApprovalValue(
+ approval_id=1, approver_ids=[111], # requester
+ status=tracker_pb2.ApprovalStatus.NOT_SET)
+ request, issue = self.approvalRequest(
+ approval,
+ request_fields={
+ 'approvalUpdates':
+ {'approvers': ['user@example.com', '-group@example.com']}},  # '-' prefix marks removal
+ comment=comment)
+ response = self.call_api('approvals_comments_insert', request).json_body
+
+ approval_delta = tracker_bizobj.MakeApprovalDelta(  # add 222 (user@), remove 123 (group@)
+ None, 111, [222], [123], [], [], [], [], [], set_on=test_time)
+ self.services.issue.DeltaUpdateIssueApproval.assert_called_with(
+ None, 111, self.config, issue, approval, approval_delta,
+ comment_content=None, is_description=None)
+ self.assertEqual(response['author']['name'], 'requester@example.com')
+ self.assertEqual(response['content'], comment.content)
+ self.assertTrue(response['canDelete'])
+ self.assertEqual(response['approvalUpdates'],
+ {'kind': 'monorail#approvalCommentUpdate',
+ 'approvers': ['user@example.com', '-group@example.com']})
+
+ @patch('time.time')
+ def testApprovalsCommentsInsert_IsSurvey(self, mock_time):  # survey update forwards is_description=True
+ test_time = 6789
+ mock_time.return_value = test_time
+ comment = tracker_pb2.IssueComment(
+ id=123, issue_id=10001,
+ project_id=12345, user_id=111,
+ content='this is a comment',
+ timestamp=1437700000)
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222],
+ project_id=12345)
+
+ approval = tracker_pb2.ApprovalValue(
+ approval_id=1, approver_ids=[111], # requester
+ status=tracker_pb2.ApprovalStatus.NOT_SET)
+ request, issue = self.approvalRequest(
+ approval,
+ request_fields={'content': 'updated survey', 'is_description': True},
+ comment=comment)
+ response = self.call_api('approvals_comments_insert', request).json_body
+
+ approval_delta = tracker_bizobj.MakeApprovalDelta(
+ None, 111, [], [], [], [], [], [], [], set_on=test_time)
+ self.services.issue.DeltaUpdateIssueApproval.assert_called_with(
+ None, 111, self.config, issue, approval, approval_delta,
+ comment_content='updated survey', is_description=True)
+ self.assertEqual(response['author']['name'], 'requester@example.com')
+ self.assertTrue(response['canDelete'])
+
+ @patch('time.time')
+ @patch('features.send_notifications.PrepareAndSendApprovalChangeNotification')
+ def testApprovalsCommentsInsert_SendEmail(
+ self, mockPrepareAndSend, mock_time,):  # inner @patch supplies the first mock arg
+ test_time = 6789
+ mock_time.return_value = test_time
+ comment = tracker_pb2.IssueComment(
+ id=123, issue_id=10001,
+ project_id=12345, user_id=111,
+ content='this is a comment',
+ timestamp=1437700000)
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222],
+ project_id=12345)
+
+ approval = tracker_pb2.ApprovalValue(
+ approval_id=1, approver_ids=[111], # requester
+ status=tracker_pb2.ApprovalStatus.NOT_SET)
+ request, issue = self.approvalRequest(
+ approval,
+ request_fields={'content': comment.content, 'sendEmail': True},
+ comment=comment)
+
+ response = self.call_api('approvals_comments_insert', request).json_body
+
+ mockPrepareAndSend.assert_called_with(  # sendEmail=True triggers the notification
+ issue.issue_id, approval.approval_id, ANY, comment.id, send_email=True)
+
+ approval_delta = tracker_bizobj.MakeApprovalDelta(
+ None, 111, [], [], [], [], [], [], [], set_on=test_time)
+ self.services.issue.DeltaUpdateIssueApproval.assert_called_with(
+ None, 111, self.config, issue, approval, approval_delta,
+ comment_content=comment.content, is_description=None)
+ self.assertEqual(response['author']['name'], 'requester@example.com')
+ self.assertTrue(response['canDelete'])
+
+ def testGroupsSettingsList_AllSettings(self):  # fixture holds exactly one group, group@example.com
+ resp = self.call_api('groups_settings_list', self.request).json_body
+ all_settings = resp['groupSettings']
+ self.assertEqual(1, len(all_settings))
+ self.assertEqual('group@example.com', all_settings[0]['groupName'])
+
+ def testGroupsSettingsList_ImportedSettings(self):  # importedGroupsOnly keeps only externally typed groups
+ self.services.user.TestAddUser('imported@example.com', 234)
+ self.services.usergroup.TestAddGroupSettings(
+ 234, 'imported@example.com', external_group_type='mdb')
+ self.request['importedGroupsOnly'] = True
+ resp = self.call_api('groups_settings_list', self.request).json_body
+ all_settings = resp['groupSettings']
+ self.assertEqual(1, len(all_settings))
+ self.assertEqual('imported@example.com', all_settings[0]['groupName'])
+
+ def testGroupsCreate_NoPermission(self):  # default requester may not create groups -> 403
+ self.request['groupName'] = 'group'
+ with self.call_should_fail(403):
+ self.call_api('groups_create', self.request)
+
+ def SetUpGroupRequest(self, group_name, who_can_view_members='MEMBERS',
+ ext_group_type=None, perms=None,
+ requester='requester@example.com'):  # build a groups request and serve API calls from it
+ request = {
+ 'groupName': group_name,
+ 'requester': requester,
+ 'who_can_view_members': who_can_view_members,
+ 'ext_group_type': ext_group_type}
+ self.request.pop("userId", None)  # drop userId from the shared default request
+ self.mock(api_svc_v1.MonorailApi, 'mar_factory',
+ lambda x, y, z: FakeMonorailApiRequest(
+ request, self.services, perms=perms))  # perms is forwarded to the fake request
+ return request
+
+ def testGroupsCreate_Normal(self):  # an admin can create a group
+ request = self.SetUpGroupRequest('newgroup@example.com', 'MEMBERS',
+ 'MDB', permissions.ADMIN_PERMISSIONSET)  # ext_group_type, perms
+
+ resp = self.call_api('groups_create', request).json_body
+ self.assertIn('groupID', resp)
+
+ def testGroupsGet_NoPermission(self):  # without admin perms -> 403
+ request = self.SetUpGroupRequest('group@example.com')
+ with self.call_should_fail(403):
+ self.call_api('groups_get', request)
+
+ def testGroupsGet_Normal(self):  # an admin sees the group's members, owners and settings
+ request = self.SetUpGroupRequest('group@example.com',
+ perms=permissions.ADMIN_PERMISSIONSET)
+ self.services.usergroup.TestAddMembers(123, [111], 'member')  # 111 = requester@example.com
+ self.services.usergroup.TestAddMembers(123, [222], 'owner')  # 222 = user@example.com
+ resp = self.call_api('groups_get', request).json_body
+ self.assertEqual(123, resp['groupID'])
+ self.assertEqual(['requester@example.com'], resp['groupMembers'])
+ self.assertEqual(['user@example.com'], resp['groupOwners'])
+ self.assertEqual('group@example.com', resp['groupSettings']['groupName'])
+
+ def testGroupsUpdate_NoPermission(self):  # without admin perms -> 403
+ request = self.SetUpGroupRequest('group@example.com')
+ with self.call_should_fail(403):
+ self.call_api('groups_update', request)
+
+ def testGroupsUpdate_Normal(self):
+ request = self.SetUpGroupRequest('group@example.com')
+ request = self.SetUpGroupRequest('group@example.com',
+ perms=permissions.ADMIN_PERMISSIONSET)
+ request['last_sync_time'] = 123456789
+ request['groupOwners'] = ['requester@example.com']
+ request['groupMembers'] = ['user@example.com']
+ resp = self.call_api('groups_update', request).json_body
+ self.assertFalse(resp.get('error'))
+
+ def testComponentsList(self):
+ """Get components for a project."""
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222],
+ project_id=12345)
+ self.SetUpComponents(12345, 1, 'API')
+ resp = self.call_api('components_list', self.request).json_body
+
+ self.assertEqual(1, len(resp['components']))
+ cd = resp['components'][0]
+ self.assertEqual(1, cd['componentId'])
+ self.assertEqual('API', cd['componentPath'])
+ self.assertEqual(1, cd['componentId'])
+ self.assertEqual('test-project', cd['projectName'])
+
+ def testComponentsCreate_NoPermission(self):  # requester does not own the project -> 403
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222],
+ project_id=12345)
+ self.SetUpComponents(12345, 1, 'API')
+
+ cd_dict = {
+ 'componentName': 'Test'}
+ self.request.update(cd_dict)
+
+ with self.call_should_fail(403):
+ self.call_api('components_create', self.request)
+
+ def testComponentsCreate_Invalid(self):  # owner 111 still gets errors for invalid input
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[111],
+ project_id=12345)
+ self.SetUpComponents(12345, 1, 'API')
+
+ # Component with invalid name
+ cd_dict = {
+ 'componentName': 'c>d>e'}
+ self.request.update(cd_dict)
+ with self.call_should_fail(400):
+ self.call_api('components_create', self.request)
+
+ # Name already in use
+ cd_dict = {
+ 'componentName': 'API'}
+ self.request.update(cd_dict)
+ with self.call_should_fail(400):
+ self.call_api('components_create', self.request)
+
+ # Parent component does not exist
+ cd_dict = {
+ 'componentName': 'test',
+ 'parentPath': 'NotExist'}
+ self.request.update(cd_dict)
+ with self.call_should_fail(404):
+ self.call_api('components_create', self.request)
+
+
+ def testComponentsCreate_Normal(self):  # create a top-level component, then a child of API
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[111],
+ project_id=12345)
+ self.SetUpComponents(12345, 1, 'API')
+
+ cd_dict = {
+ 'componentName': 'Test',
+ 'description': 'test comp',
+ 'cc': ['requester@example.com', '']  # the blank entry is dropped (see cc assertion)
+ }
+ self.request.update(cd_dict)
+
+ resp = self.call_api('components_create', self.request).json_body
+ self.assertEqual('test comp', resp['description'])
+ self.assertEqual('requester@example.com', resp['creator'])
+ self.assertEqual([u'requester@example.com'], resp['cc'])
+ self.assertEqual('Test', resp['componentPath'])
+
+ cd_dict = {  # description/cc from the first update remain in self.request
+ 'componentName': 'TestChild',
+ 'parentPath': 'API'}
+ self.request.update(cd_dict)
+ resp = self.call_api('components_create', self.request).json_body
+
+ self.assertEqual('API>TestChild', resp['componentPath'])
+
+ def testComponentsDelete_Invalid(self):  # missing component, no permission, has children
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222],
+ project_id=12345)
+ self.SetUpComponents(12345, 1, 'API')
+
+ # Fail to delete a non-existent component
+ cd_dict = {
+ 'componentPath': 'NotExist'}
+ self.request.update(cd_dict)
+ with self.call_should_fail(404):
+ self.call_api('components_delete', self.request)
+
+ # The user has no permission to delete component
+ cd_dict = {
+ 'componentPath': 'API'}
+ self.request.update(cd_dict)
+ with self.call_should_fail(403):
+ self.call_api('components_delete', self.request)
+
+ # The user tries to delete component that had subcomponents
+ self.services.project.TestAddProject(
+ 'test-project2', owner_ids=[111],
+ project_id=123456)
+ self.SetUpComponents(123456, 1, 'Parent')
+ self.SetUpComponents(123456, 2, 'Parent>Child')
+ cd_dict = {
+ 'componentPath': 'Parent',
+ 'projectId': 'test-project2',}  # a project the requester owns
+ self.request.update(cd_dict)
+ with self.call_should_fail(403):
+ self.call_api('components_delete', self.request)
+
+ def testComponentsDelete_Normal(self):  # project owner deletes its only component
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[111],
+ project_id=12345)
+ self.SetUpComponents(12345, 1, 'API')
+
+ cd_dict = {
+ 'componentPath': 'API'}
+ self.request.update(cd_dict)
+ _ = self.call_api('components_delete', self.request).json_body
+ self.assertEqual(0, len(self.config.component_defs))
+
+ def testComponentsUpdate_Invalid(self):  # missing component, no permission, bad or duplicate name
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[222],
+ project_id=12345)
+ self.SetUpComponents(12345, 1, 'API')
+ self.SetUpComponents(12345, 2, 'Test', admin_ids=[111])  # requester administers only 'Test'
+
+ # Fail to update a non-existent component
+ cd_dict = {
+ 'componentPath': 'NotExist'}
+ self.request.update(cd_dict)
+ with self.call_should_fail(404):
+ self.call_api('components_update', self.request)
+
+ # The user has no permission to edit component
+ cd_dict = {
+ 'componentPath': 'API'}
+ self.request.update(cd_dict)
+ with self.call_should_fail(403):
+ self.call_api('components_update', self.request)
+
+ # The user tries an invalid component name
+ cd_dict = {
+ 'componentPath': 'Test',
+ 'updates': [{'field': 'LEAF_NAME', 'leafName': 'c>e'}]}
+ self.request.update(cd_dict)
+ with self.call_should_fail(400):
+ self.call_api('components_update', self.request)
+
+ # The user tries a name already in use
+ cd_dict = {
+ 'componentPath': 'Test',
+ 'updates': [{'field': 'LEAF_NAME', 'leafName': 'API'}]}
+ self.request.update(cd_dict)
+ with self.call_should_fail(400):
+ self.call_api('components_update', self.request)
+
+ def testComponentsUpdate_Normal(self):  # update fields, then rename a parent component
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[111],
+ project_id=12345)
+ self.SetUpComponents(12345, 1, 'API')
+ self.SetUpComponents(12345, 2, 'Parent')
+ self.SetUpComponents(12345, 3, 'Parent>Child')
+
+ cd_dict = {
+ 'componentPath': 'API',
+ 'updates': [
+ {'field': 'DESCRIPTION', 'description': ''},
+ {'field': 'CC', 'cc': [
+ 'requester@example.com', 'user@example.com', '', ' ']},  # blank entries are ignored
+ {'field': 'DEPRECATED', 'deprecated': True}]}
+ self.request.update(cd_dict)
+ _ = self.call_api('components_update', self.request).json_body
+ component_def = tracker_bizobj.FindComponentDef(
+ 'API', self.config)
+ self.assertIsNotNone(component_def)
+ self.assertEqual('', component_def.docstring)
+ self.assertItemsEqual([111, 222], component_def.cc_ids)
+ self.assertTrue(component_def.deprecated)
+
+ cd_dict = {
+ 'componentPath': 'Parent',
+ 'updates': [
+ {'field': 'LEAF_NAME', 'leafName': 'NewParent'}]}
+ self.request.update(cd_dict)
+ _ = self.call_api('components_update', self.request).json_body
+ cd_parent = tracker_bizobj.FindComponentDef(
+ 'NewParent', self.config)
+ cd_child = tracker_bizobj.FindComponentDef(  # child path follows the renamed parent
+ 'NewParent>Child', self.config)
+ self.assertIsNotNone(cd_parent)
+ self.assertIsNotNone(cd_child)
+
+
+class RequestMock(object):
+ """Bare request stand-in whose projectId and issueId start as None."""
+ def __init__(self):
+ self.projectId = None
+ self.issueId = None
+
+
+class RequesterMock(object):
+ """OAuth-user stand-in whose email() returns the address given at init."""
+ def __init__(self, email=None):
+ self._email = email
+
+ def email(self):
+ return self._email
+
+
+class AllBaseChecksTest(unittest.TestCase):
+
+ def setUp(self):  # fake services, two users, a members-only project, oauth stubs
+ self.services = MakeFakeServiceManager()
+ self.services.user.TestAddUser('test@example.com', 111)
+ self.user_2 = self.services.user.TestAddUser('test@google.com', 222)
+ self.services.project.TestAddProject(
+ 'test-project', owner_ids=[111], project_id=123,
+ access=project_pb2.ProjectAccess.MEMBERS_ONLY)
+ self.auth_client_ids = ['123456789.apps.googleusercontent.com']
+ oauth.get_client_id = Mock(return_value=self.auth_client_ids[0])  # NOTE(review): oauth attrs reassigned, not patched; confirm a tearDown (not visible here) restores them
+ oauth.get_current_user = Mock(
+ return_value=RequesterMock(email='test@example.com'))
+ oauth.get_authorized_scopes = Mock()  # individual tests set return_value
+
+ def testUnauthorizedRequester(self):
+ with self.assertRaises(endpoints.UnauthorizedException):
+ api_svc_v1.api_base_checks(None, None, None, None, [], [])
+
+ def testNoUser(self):
+ requester = RequesterMock(email='notexist@example.com')
+ with self.assertRaises(exceptions.NoSuchUserException):
+ api_svc_v1.api_base_checks(
+ None, requester, self.services, None, self.auth_client_ids, [])
+
+ def testAllowedDomain_MonorailScope(self):
+ oauth.get_authorized_scopes.return_value = [
+ framework_constants.MONORAIL_SCOPE]
+ oauth.get_current_user.return_value = RequesterMock(
+ email=self.user_2.email)
+ allowlisted_client_ids = []
+ allowlisted_emails = []
+ client_id, email = api_svc_v1.api_base_checks(
+ None, None, self.services, None, allowlisted_client_ids,
+ allowlisted_emails)
+ self.assertEqual(client_id, self.auth_client_ids[0])
+ self.assertEqual(email, self.user_2.email)
+
+ def testAllowedDomain_NoMonorailScope(self):
+ oauth.get_authorized_scopes.return_value = []
+ oauth.get_current_user.return_value = RequesterMock(
+ email=self.user_2.email)
+ allowlisted_client_ids = []
+ allowlisted_emails = []
+ with self.assertRaises(endpoints.UnauthorizedException):
+ api_svc_v1.api_base_checks(
+ None, None, self.services, None, allowlisted_client_ids,
+ allowlisted_emails)
+
+ def testAllowedDomain_BadEmail(self):
+ oauth.get_authorized_scopes.return_value = [
+ framework_constants.MONORAIL_SCOPE]
+ oauth.get_current_user.return_value = RequesterMock(
+ email='chicken@chicken.test')
+ allowlisted_client_ids = []
+ allowlisted_emails = []
+ self.services.user.TestAddUser('chicken@chicken.test', 333)
+ with self.assertRaises(endpoints.UnauthorizedException):
+ api_svc_v1.api_base_checks(
+ None, None, self.services, None, allowlisted_client_ids,
+ allowlisted_emails)
+
+ def testNoOauthUser(self):
+ oauth.get_current_user.side_effect = oauth.Error()
+ with self.assertRaises(endpoints.UnauthorizedException):
+ api_svc_v1.api_base_checks(
+ None, None, self.services, None, [], [])
+
+ def testBannedUser(self):
+ banned_email = 'banned@example.com'
+ self.services.user.TestAddUser(banned_email, 222, banned=True)
+ requester = RequesterMock(email=banned_email)
+ with self.assertRaises(permissions.BannedUserException):
+ api_svc_v1.api_base_checks(
+ None, requester, self.services, None, self.auth_client_ids, [])
+
+ def testNoProject(self):
+ request = RequestMock()
+ request.projectId = 'notexist-project'
+ requester = RequesterMock(email='test@example.com')
+ with self.assertRaises(exceptions.NoSuchProjectException):
+ api_svc_v1.api_base_checks(
+ request, requester, self.services, None, self.auth_client_ids, [])
+
+ def testNonLiveProject(self):
+ archived_project = 'archived-project'
+ self.services.project.TestAddProject(
+ archived_project, owner_ids=[111],
+ state=project_pb2.ProjectState.ARCHIVED)
+ request = RequestMock()
+ request.projectId = archived_project
+ requester = RequesterMock(email='test@example.com')
+ with self.assertRaises(permissions.PermissionException):
+ api_svc_v1.api_base_checks(
+ request, requester, self.services, None, self.auth_client_ids, [])
+
+ def testNoViewProjectPermission(self):
+ nonmember_email = 'nonmember@example.com'
+ self.services.user.TestAddUser(nonmember_email, 222)
+ requester = RequesterMock(email=nonmember_email)
+ request = RequestMock()
+ request.projectId = 'test-project'
+ with self.assertRaises(permissions.PermissionException):
+ api_svc_v1.api_base_checks(
+ request, requester, self.services, None, self.auth_client_ids, [])
+
+ def testAllPass(self):
+ requester = RequesterMock(email='test@example.com')
+ request = RequestMock()
+ request.projectId = 'test-project'
+ api_svc_v1.api_base_checks(
+ request, requester, self.services, None, self.auth_client_ids, [])
+
+ def testNoIssue(self):
+ requester = RequesterMock(email='test@example.com')
+ request = RequestMock()
+ request.projectId = 'test-project'
+ request.issueId = 12345
+ with self.assertRaises(exceptions.NoSuchIssueException):
+ api_svc_v1.api_base_checks(
+ request, requester, self.services, None, self.auth_client_ids, [])
+
+ def testNoViewIssuePermission(self):
+ requester = RequesterMock(email='test@example.com')
+ request = RequestMock()
+ request.projectId = 'test-project'
+ request.issueId = 1
+ issue1 = fake.MakeTestIssue(
+ project_id=123, local_id=1, summary='test summary',
+ status='New', owner_id=111, reporter_id=111)
+ issue1.deleted = True
+ self.services.issue.TestAddIssue(issue1)
+ with self.assertRaises(permissions.PermissionException):
+ api_svc_v1.api_base_checks(
+ request, requester, self.services, None, self.auth_client_ids, [])
+
+ def testAnonymousClients(self):
+ # Some clients specifically pass "anonymous" as the client ID.
+ oauth.get_client_id = Mock(return_value='anonymous')
+ requester = RequesterMock(email='test@example.com')
+ request = RequestMock()
+ request.projectId = 'test-project'
+ api_svc_v1.api_base_checks(
+ request, requester, self.services, None, [], ['test@example.com'])
+
+ # Any client_id is OK if the email is allowlisted.
+ oauth.get_client_id = Mock(return_value='anything')
+ api_svc_v1.api_base_checks(
+ request, requester, self.services, None, [], ['test@example.com'])
+
+ # Reject request when neither client ID nor email is allowlisted.
+ with self.assertRaises(endpoints.UnauthorizedException):
+ api_svc_v1.api_base_checks(
+ request, requester, self.services, None, [], [])
diff --git a/services/test/cachemanager_svc_test.py b/services/test/cachemanager_svc_test.py
new file mode 100644
index 0000000..20956e0
--- /dev/null
+++ b/services/test/cachemanager_svc_test.py
@@ -0,0 +1,205 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for the cachemanager service."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+import mox
+
+from framework import sql
+from services import cachemanager_svc
+from services import caches
+from services import service_manager
+from testing import fake
+from testing import testing_helpers
+
+
+class CacheManagerServiceTest(unittest.TestCase):
+  """Tests for cachemanager_svc.CacheManager.
+
+  The invalidate table is a mox mock, so tests that touch it record their
+  expected SQL calls before ReplayAll() and check them with VerifyAll().
+  """
+
+  def setUp(self):
+    self.mox = mox.Mox()
+    self.cnxn = fake.MonorailConnection()
+    self.cache_manager = cachemanager_svc.CacheManager()
+    self.cache_manager.invalidate_tbl = self.mox.CreateMock(
+        sql.SQLTableManager)
+
+  def tearDown(self):
+    self.mox.UnsetStubs()
+    self.mox.ResetAll()
+
+  def testRegisterCache(self):
+    ram_cache = 'fake ramcache'
+    self.cache_manager.RegisterCache(ram_cache, 'issue')
+    # assertIn shows the registry contents on failure; assertTrue(x in y)
+    # would only report "False is not true".
+    self.assertIn(ram_cache, self.cache_manager.cache_registry['issue'])
+
+  def testRegisterCache_UnknownKind(self):
+    # Registering under an unknown kind trips an assertion.
+    ram_cache = 'fake ramcache'
+    self.assertRaises(
+        AssertionError,
+        self.cache_manager.RegisterCache, ram_cache, 'foo')
+
+  def testProcessInvalidateRows_Empty(self):
+    rows = []
+    self.cache_manager._ProcessInvalidationRows(rows)
+    # Nothing was processed, so the processed timestep stays at 0.
+    self.assertEqual(0, self.cache_manager.processed_invalidations_up_to)
+
+  def testProcessInvalidateRows_Some(self):
+    ram_cache = caches.RamCache(self.cache_manager, 'issue')
+    ram_cache.CacheAll({
+        33: 'issue 33',
+        34: 'issue 34',
+        })
+    # Rows are (timestep, kind, cache_key) tuples.
+    rows = [(1, 'issue', 34),
+            (2, 'project', 789),
+            (3, 'issue', 39)]
+    self.cache_manager._ProcessInvalidationRows(rows)
+    self.assertEqual(3, self.cache_manager.processed_invalidations_up_to)
+    self.assertTrue(ram_cache.HasItem(33))
+    self.assertFalse(ram_cache.HasItem(34))
+
+  def testProcessInvalidateRows_All(self):
+    ram_cache = caches.RamCache(self.cache_manager, 'issue')
+    ram_cache.CacheAll({
+        33: 'issue 33',
+        34: 'issue 34',
+        })
+    # An INVALIDATE_ALL_KEYS row for 'issue' empties the whole cache.
+    rows = [(991, 'issue', 34),
+            (992, 'project', 789),
+            (993, 'issue', cachemanager_svc.INVALIDATE_ALL_KEYS)]
+    self.cache_manager._ProcessInvalidationRows(rows)
+    self.assertEqual(993, self.cache_manager.processed_invalidations_up_to)
+    self.assertEqual({}, ram_cache.cache)
+
+  def SetUpDoDistributedInvalidation(self, rows):
+    """Records the expected invalidate-table Select and makes it return rows.
+
+    The query asks for rows newer than timestep 0, the value of
+    processed_invalidations_up_to before anything has been processed.
+    """
+    self.cache_manager.invalidate_tbl.Select(
+        self.cnxn, cols=['timestep', 'kind', 'cache_key'],
+        where=[('timestep > %s', [0])],
+        order_by=[('timestep DESC', [])],
+        limit=cachemanager_svc.MAX_INVALIDATE_ROWS_TO_CONSIDER
+        ).AndReturn(rows)
+
+  def testDoDistributedInvalidation_Empty(self):
+    rows = []
+    self.SetUpDoDistributedInvalidation(rows)
+    self.mox.ReplayAll()
+    self.cache_manager.DoDistributedInvalidation(self.cnxn)
+    self.mox.VerifyAll()
+    self.assertEqual(0, self.cache_manager.processed_invalidations_up_to)
+
+  def testDoDistributedInvalidation_Some(self):
+    ram_cache = caches.RamCache(self.cache_manager, 'issue')
+    ram_cache.CacheAll({
+        33: 'issue 33',
+        34: 'issue 34',
+        })
+    rows = [(1, 'issue', 34),
+            (2, 'project', 789),
+            (3, 'issue', 39)]
+    self.SetUpDoDistributedInvalidation(rows)
+    self.mox.ReplayAll()
+    self.cache_manager.DoDistributedInvalidation(self.cnxn)
+    self.mox.VerifyAll()
+    self.assertEqual(3, self.cache_manager.processed_invalidations_up_to)
+    self.assertTrue(ram_cache.HasItem(33))
+    self.assertFalse(ram_cache.HasItem(34))
+
+  def testDoDistributedInvalidation_Redundant(self):
+    ram_cache = caches.RamCache(self.cache_manager, 'issue')
+    ram_cache.CacheAll({
+        33: 'issue 33',
+        34: 'issue 34',
+        })
+    # Repeated (kind, cache_key) rows at later timesteps are still consumed;
+    # the processed timestep advances to the newest one (5).
+    rows = [(1, 'issue', 34),
+            (2, 'project', 789),
+            (3, 'issue', 39),
+            (4, 'project', 789),
+            (5, 'issue', 39)]
+    self.SetUpDoDistributedInvalidation(rows)
+    self.mox.ReplayAll()
+    self.cache_manager.DoDistributedInvalidation(self.cnxn)
+    self.mox.VerifyAll()
+    self.assertEqual(5, self.cache_manager.processed_invalidations_up_to)
+    self.assertTrue(ram_cache.HasItem(33))
+    self.assertFalse(ram_cache.HasItem(34))
+
+  def testStoreInvalidateRows_UnknownKind(self):
+    self.assertRaises(
+        AssertionError,
+        self.cache_manager.StoreInvalidateRows, self.cnxn, 'foo', [1, 2])
+
+  def SetUpStoreInvalidateRows(self, rows):
+    """Records the expected InsertRows call for (kind, cache_key) rows."""
+    self.cache_manager.invalidate_tbl.InsertRows(
+        self.cnxn, ['kind', 'cache_key'], rows)
+
+  def testStoreInvalidateRows(self):
+    rows = [('issue', 1), ('issue', 2)]
+    self.SetUpStoreInvalidateRows(rows)
+    self.mox.ReplayAll()
+    self.cache_manager.StoreInvalidateRows(self.cnxn, 'issue', [1, 2])
+    self.mox.VerifyAll()
+
+  def SetUpStoreInvalidateAll(self, kind):
+    """Records an INVALIDATE_ALL_KEYS insert followed by a cleanup delete.
+
+    The insert is mocked to return timestep 44, so the delete is expected to
+    remove this kind's rows older than 44.
+    """
+    self.cache_manager.invalidate_tbl.InsertRow(
+        self.cnxn, kind=kind, cache_key=cachemanager_svc.INVALIDATE_ALL_KEYS,
+        ).AndReturn(44)
+    self.cache_manager.invalidate_tbl.Delete(
+        self.cnxn, kind=kind, where=[('timestep < %s', [44])])
+
+  def testStoreInvalidateAll(self):
+    self.SetUpStoreInvalidateAll('issue')
+    self.mox.ReplayAll()
+    self.cache_manager.StoreInvalidateAll(self.cnxn, 'issue')
+    self.mox.VerifyAll()
+
+
+class RamCacheConsolidateTest(unittest.TestCase):
+ """Tests for the cachemanager_svc.RamCacheConsolidate servlet.
+
+ The invalidate table is a mox mock; each test records the exact sequence of
+ SQL calls HandleRequest is expected to make.
+ """
+
+ def setUp(self):
+ self.mox = mox.Mox()
+ # NOTE(review): self.cnxn is not used by the tests below; they use mr.cnxn.
+ self.cnxn = 'fake connection'
+ self.cache_manager = cachemanager_svc.CacheManager()
+ self.cache_manager.invalidate_tbl = self.mox.CreateMock(
+ sql.SQLTableManager)
+ self.services = service_manager.Services(
+ cache_manager=self.cache_manager)
+ self.servlet = cachemanager_svc.RamCacheConsolidate(
+ 'req', 'res', services=self.services)
+
+ def testHandleRequest_NothingToDo(self):
+ mr = testing_helpers.MakeMonorailRequest()
+ # Only the before/after COUNT(*) queries are recorded. No Select or
+ # Delete is expected at 112 rows, so mox fails if HandleRequest truncates.
+ self.cache_manager.invalidate_tbl.SelectValue(
+ mr.cnxn, 'COUNT(*)').AndReturn(112)
+ self.cache_manager.invalidate_tbl.SelectValue(
+ mr.cnxn, 'COUNT(*)').AndReturn(112)
+ self.mox.ReplayAll()
+
+ json_data = self.servlet.HandleRequest(mr)
+ self.mox.VerifyAll()
+ self.assertEqual(json_data['old_count'], 112)
+ self.assertEqual(json_data['new_count'], 112)
+
+ def testHandleRequest_Truncate(self):
+ mr = testing_helpers.MakeMonorailRequest()
+ # Expected sequence: count rows, select the newest
+ # MAX_INVALIDATE_ROWS_TO_CONSIDER timesteps, delete rows older than the
+ # last one returned, then recount.
+ self.cache_manager.invalidate_tbl.SelectValue(
+ mr.cnxn, 'COUNT(*)').AndReturn(4012)
+ self.cache_manager.invalidate_tbl.Select(
+ mr.cnxn, ['timestep'],
+ order_by=[('timestep DESC', [])],
+ limit=cachemanager_svc.MAX_INVALIDATE_ROWS_TO_CONSIDER
+ ).AndReturn([[3012]]) # Actual would be 1000 rows ending with 3012.
+ self.cache_manager.invalidate_tbl.Delete(
+ mr.cnxn, where=[('timestep < %s', [3012])])
+ self.cache_manager.invalidate_tbl.SelectValue(
+ mr.cnxn, 'COUNT(*)').AndReturn(1000)
+ self.mox.ReplayAll()
+
+ json_data = self.servlet.HandleRequest(mr)
+ self.mox.VerifyAll()
+ self.assertEqual(json_data['old_count'], 4012)
+ self.assertEqual(json_data['new_count'], 1000)
diff --git a/services/test/caches_test.py b/services/test/caches_test.py
new file mode 100644
index 0000000..4ced369
--- /dev/null
+++ b/services/test/caches_test.py
@@ -0,0 +1,418 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for the cache classes."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import fakeredis
+import unittest
+
+from google.appengine.api import memcache
+from google.appengine.ext import testbed
+
+import settings
+from services import caches
+from testing import fake
+
+
+class RamCacheTest(unittest.TestCase):
+ """Tests for caches.RamCache with a fake cache manager.
+
+ The cache is built with max_size=3, which the eviction tests rely on.
+ """
+
+ def setUp(self):
+ self.cnxn = 'fake connection'
+ self.cache_manager = fake.CacheManager()
+ self.ram_cache = caches.RamCache(self.cache_manager, 'issue', max_size=3)
+
+ def testInit(self):
+ # Construction also registers the cache with the manager for its kind.
+ self.assertEqual('issue', self.ram_cache.kind)
+ self.assertEqual(3, self.ram_cache.max_size)
+ self.assertEqual(
+ [self.ram_cache],
+ self.cache_manager.cache_registry['issue'])
+
+ def testCacheItem(self):
+ self.ram_cache.CacheItem(123, 'foo')
+ self.assertEqual('foo', self.ram_cache.cache[123])
+
+ def testCacheItem_DropsOldItems(self):
+ # Four inserts into a max_size=3 cache force one eviction.
+ self.ram_cache.CacheItem(123, 'foo')
+ self.ram_cache.CacheItem(234, 'foo')
+ self.ram_cache.CacheItem(345, 'foo')
+ self.ram_cache.CacheItem(456, 'foo')
+ # The cache does not get bigger than its limit.
+ self.assertEqual(3, len(self.ram_cache.cache))
+ # An old value is dropped, not the newly added one.
+ self.assertIn(456, self.ram_cache.cache)
+
+ def testCacheAll(self):
+ self.ram_cache.CacheAll({123: 'foo'})
+ self.assertEqual('foo', self.ram_cache.cache[123])
+
+ def testCacheAll_DropsOldItems(self):
+ self.ram_cache.CacheAll({1: 'a', 2: 'b', 3: 'c'})
+ self.ram_cache.CacheAll({4: 'x', 5: 'y'})
+ # The cache does not get bigger than its limit.
+ self.assertEqual(3, len(self.ram_cache.cache))
+ # An old value is dropped, not the newly added one.
+ self.assertIn(4, self.ram_cache.cache)
+ self.assertIn(5, self.ram_cache.cache)
+ self.assertEqual('y', self.ram_cache.cache[5])
+
+ def testHasItem(self):
+ self.ram_cache.CacheItem(123, 'foo')
+ self.assertTrue(self.ram_cache.HasItem(123))
+ self.assertFalse(self.ram_cache.HasItem(999))
+
+ def testGetItem(self):
+ # A missing key yields None rather than raising.
+ self.ram_cache.CacheItem(123, 'foo')
+ self.assertEqual('foo', self.ram_cache.GetItem(123))
+ self.assertEqual(None, self.ram_cache.GetItem(456))
+
+ def testGetAll(self):
+ # GetAll returns (dict of hits, list of missed keys).
+ self.ram_cache.CacheItem(123, 'foo')
+ self.ram_cache.CacheItem(124, 'bar')
+ hits, misses = self.ram_cache.GetAll([123, 124, 999])
+ self.assertEqual({123: 'foo', 124: 'bar'}, hits)
+ self.assertEqual([999], misses)
+
+ def testLocalInvalidate(self):
+ self.ram_cache.CacheAll({123: 'a', 124: 'b', 125: 'c'})
+ self.ram_cache.LocalInvalidate(124)
+ self.assertEqual(2, len(self.ram_cache.cache))
+ self.assertNotIn(124, self.ram_cache.cache)
+
+ # Invalidating a key that is not cached is a no-op.
+ self.ram_cache.LocalInvalidate(999)
+ self.assertEqual(2, len(self.ram_cache.cache))
+
+ def testInvalidate(self):
+ # Invalidate drops the key and asks the manager to store an
+ # invalidation row for it.
+ self.ram_cache.CacheAll({123: 'a', 124: 'b', 125: 'c'})
+ self.ram_cache.Invalidate(self.cnxn, 124)
+ self.assertEqual(2, len(self.ram_cache.cache))
+ self.assertNotIn(124, self.ram_cache.cache)
+ self.assertEqual(self.cache_manager.last_call,
+ ('StoreInvalidateRows', self.cnxn, 'issue', [124]))
+
+ def testInvalidateKeys(self):
+ self.ram_cache.CacheAll({123: 'a', 124: 'b', 125: 'c'})
+ self.ram_cache.InvalidateKeys(self.cnxn, [124])
+ self.assertEqual(2, len(self.ram_cache.cache))
+ self.assertNotIn(124, self.ram_cache.cache)
+ self.assertEqual(self.cache_manager.last_call,
+ ('StoreInvalidateRows', self.cnxn, 'issue', [124]))
+
+ def testLocalInvalidateAll(self):
+ self.ram_cache.CacheAll({123: 'a', 124: 'b', 125: 'c'})
+ self.ram_cache.LocalInvalidateAll()
+ self.assertEqual(0, len(self.ram_cache.cache))
+
+ def testInvalidateAll(self):
+ # InvalidateAll empties the cache and records a StoreInvalidateAll call.
+ self.ram_cache.CacheAll({123: 'a', 124: 'b', 125: 'c'})
+ self.ram_cache.InvalidateAll(self.cnxn)
+ self.assertEqual(0, len(self.ram_cache.cache))
+ self.assertEqual(self.cache_manager.last_call,
+ ('StoreInvalidateAll', self.cnxn, 'issue'))
+
+
+class ShardedRamCacheTest(unittest.TestCase):
+  """Tests for caches.ShardedRamCache keyed by (key, shard) tuples."""
+
+  def setUp(self):
+    self.cnxn = 'fake connection'
+    self.cache_manager = fake.CacheManager()
+    self.sharded_ram_cache = caches.ShardedRamCache(
+        self.cache_manager, 'issue', max_size=3, num_shards=3)
+
+  def testLocalInvalidate(self):
+    """Invalidating a key drops the entries for all of its shards."""
+    # Builds {(123, 0): 'a', (123, 1): 'aa', (123, 2): 'aaa',
+    #         (124, 0): 'b', (124, 1): 'bb', (124, 2): 'bbb'}.
+    entries = {}
+    for key, letter in ((123, 'a'), (124, 'b')):
+      for shard in range(3):
+        entries[(key, shard)] = letter * (shard + 1)
+    self.sharded_ram_cache.CacheAll(entries)
+
+    self.sharded_ram_cache.LocalInvalidate(124)
+    self.assertEqual(3, len(self.sharded_ram_cache.cache))
+    for shard in range(3):
+      self.assertNotIn((124, shard), self.sharded_ram_cache.cache)
+
+    # A key that was never cached leaves the cache untouched.
+    self.sharded_ram_cache.LocalInvalidate(999)
+    self.assertEqual(3, len(self.sharded_ram_cache.cache))
+
+
+class TestableTwoLevelCache(caches.AbstractTwoLevelCache):
+  """Concrete two-level cache whose "database" is a simple rule.
+
+  Keys below 900 are treated as present in the DB and map to themselves.
+  Keys of 900 and above are treated as missing.
+  """
+
+  def __init__(
+      self, cache_manager, kind, max_size=None, use_redis=False,
+      redis_client=None):
+    # 'testable:' is the key prefix seen in the memcache tests. The None
+    # argument is presumably pb_class (tests override self.pb_class) --
+    # confirm against AbstractTwoLevelCache.
+    super(TestableTwoLevelCache, self).__init__(
+        cache_manager, kind, 'testable:', None, max_size=max_size,
+        use_redis=use_redis, redis_client=redis_client)
+
+  # pylint: disable=unused-argument
+  def FetchItems(self, cnxn, keys, **kwargs):
+    """Simulates the DB lookup done when the RAM and shared caches miss."""
+    found = {}
+    for key in keys:
+      if key < 900:
+        found[key] = key
+    return found
+
+
+class AbstractTwoLevelCacheTest_Memcache(unittest.TestCase):
+ """Tests AbstractTwoLevelCache (via TestableTwoLevelCache) over memcache.
+
+ The App Engine testbed memcache stub is activated for each test.
+ """
+
+ def setUp(self):
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_memcache_stub()
+
+ self.cnxn = 'fake connection'
+ self.cache_manager = fake.CacheManager()
+ self.testable_2lc = TestableTwoLevelCache(self.cache_manager, 'issue')
+
+ def tearDown(self):
+ self.testbed.deactivate()
+
+ def testCacheItem(self):
+ # CacheItem populates the RAM layer (self.testable_2lc.cache).
+ self.testable_2lc.CacheItem(123, 12300)
+ self.assertEqual(12300, self.testable_2lc.cache.cache[123])
+
+ def testHasItem(self):
+ # 444 would be returned by FetchItems, yet HasItem is False: HasItem
+ # does not fall through to the DB fetch.
+ self.testable_2lc.CacheItem(123, 12300)
+ self.assertTrue(self.testable_2lc.HasItem(123))
+ self.assertFalse(self.testable_2lc.HasItem(444))
+ self.assertFalse(self.testable_2lc.HasItem(999))
+
+ # The _WriteToMemcache tests round-trip values of various types through
+ # _ReadFromMemcache, which returns (hits dict, misses).
+ def testWriteToMemcache_Normal(self):
+ retrieved_dict = {123: 12300, 124: 12400}
+ self.testable_2lc._WriteToMemcache(retrieved_dict)
+ actual_123, _ = self.testable_2lc._ReadFromMemcache([123])
+ self.assertEqual(12300, actual_123[123])
+ actual_124, _ = self.testable_2lc._ReadFromMemcache([124])
+ self.assertEqual(12400, actual_124[124])
+
+ def testWriteToMemcache_String(self):
+ retrieved_dict = {123: 'foo', 124: 'bar'}
+ self.testable_2lc._WriteToMemcache(retrieved_dict)
+ actual_123, _ = self.testable_2lc._ReadFromMemcache([123])
+ self.assertEqual('foo', actual_123[123])
+ actual_124, _ = self.testable_2lc._ReadFromMemcache([124])
+ self.assertEqual('bar', actual_124[124])
+
+ def testWriteToMemcache_ProtobufInt(self):
+ # Exercises the path taken when pb_class is set (here simply int).
+ self.testable_2lc.pb_class = int
+ retrieved_dict = {123: 12300, 124: 12400}
+ self.testable_2lc._WriteToMemcache(retrieved_dict)
+ actual_123, _ = self.testable_2lc._ReadFromMemcache([123])
+ self.assertEqual(12300, actual_123[123])
+ actual_124, _ = self.testable_2lc._ReadFromMemcache([124])
+ self.assertEqual(12400, actual_124[124])
+
+ def testWriteToMemcache_List(self):
+ retrieved_dict = {123: [1, 2, 3], 124: [1, 2, 4]}
+ self.testable_2lc._WriteToMemcache(retrieved_dict)
+ actual_123, _ = self.testable_2lc._ReadFromMemcache([123])
+ self.assertEqual([1, 2, 3], actual_123[123])
+ actual_124, _ = self.testable_2lc._ReadFromMemcache([124])
+ self.assertEqual([1, 2, 4], actual_124[124])
+
+ def testWriteToMemcache_Dict(self):
+ retrieved_dict = {123: {'ham': 2, 'spam': 3}, 124: {'eggs': 2, 'bean': 4}}
+ self.testable_2lc._WriteToMemcache(retrieved_dict)
+ actual_123, _ = self.testable_2lc._ReadFromMemcache([123])
+ self.assertEqual({'ham': 2, 'spam': 3}, actual_123[123])
+ actual_124, _ = self.testable_2lc._ReadFromMemcache([124])
+ self.assertEqual({'eggs': 2, 'bean': 4}, actual_124[124])
+
+ def testWriteToMemcache_HugeValue(self):
+ """If memcache refuses to store a huge value, we don't store any."""
+ self.testable_2lc._WriteToMemcache({124: 124999}) # Gets deleted.
+ # ~1 MB string, presumably over the memcache per-value limit.
+ huge_str = 'huge' * 260000
+ retrieved_dict = {123: huge_str, 124: 12400}
+ self.testable_2lc._WriteToMemcache(retrieved_dict)
+ # Read memcache directly using the 'testable:' key prefix.
+ actual_123 = memcache.get('testable:123')
+ self.assertEqual(None, actual_123)
+ actual_124 = memcache.get('testable:124')
+ self.assertEqual(None, actual_124)
+
+ def testGetAll_FetchGetsIt(self):
+ self.testable_2lc.CacheItem(123, 12300)
+ self.testable_2lc.CacheItem(124, 12400)
+ # Clear the RAM cache so that we find items in memcache.
+ self.testable_2lc.cache.LocalInvalidateAll()
+ self.testable_2lc.CacheItem(125, 12500)
+ # 333 and 444 come from FetchItems (keys < 900 map to themselves).
+ hits, misses = self.testable_2lc.GetAll(self.cnxn, [123, 124, 333, 444])
+ self.assertEqual({123: 12300, 124: 12400, 333: 333, 444: 444}, hits)
+ self.assertEqual([], misses)
+ # The RAM cache now has items found in memcache and DB.
+ self.assertItemsEqual(
+ [123, 124, 125, 333, 444], list(self.testable_2lc.cache.cache.keys()))
+
+ def testGetAll_FetchGetsItFromDB(self):
+ self.testable_2lc.CacheItem(123, 12300)
+ self.testable_2lc.CacheItem(124, 12400)
+ hits, misses = self.testable_2lc.GetAll(self.cnxn, [123, 124, 333, 444])
+ self.assertEqual({123: 12300, 124: 12400, 333: 333, 444: 444}, hits)
+ self.assertEqual([], misses)
+
+ def testGetAll_FetchDoesNotFindIt(self):
+ # 999 is >= 900, so FetchItems omits it and it is reported as a miss.
+ self.testable_2lc.CacheItem(123, 12300)
+ self.testable_2lc.CacheItem(124, 12400)
+ hits, misses = self.testable_2lc.GetAll(self.cnxn, [123, 124, 999])
+ self.assertEqual({123: 12300, 124: 12400}, hits)
+ self.assertEqual([999], misses)
+
+ def testInvalidateKeys(self):
+ self.testable_2lc.CacheItem(123, 12300)
+ self.testable_2lc.CacheItem(124, 12400)
+ self.testable_2lc.CacheItem(125, 12500)
+ self.testable_2lc.InvalidateKeys(self.cnxn, [124])
+ self.assertEqual(2, len(self.testable_2lc.cache.cache))
+ self.assertNotIn(124, self.testable_2lc.cache.cache)
+ self.assertEqual(
+ self.cache_manager.last_call,
+ ('StoreInvalidateRows', self.cnxn, 'issue', [124]))
+
+ def testGetAllAlreadyInRam(self):
+ # Only the RAM layer is consulted: 333/444 are misses even though
+ # FetchItems would return them.
+ self.testable_2lc.CacheItem(123, 12300)
+ self.testable_2lc.CacheItem(124, 12400)
+ hits, misses = self.testable_2lc.GetAllAlreadyInRam(
+ [123, 124, 333, 444, 999])
+ self.assertEqual({123: 12300, 124: 12400}, hits)
+ self.assertEqual([333, 444, 999], misses)
+
+ def testInvalidateAllRamEntries(self):
+ self.testable_2lc.CacheItem(123, 12300)
+ self.testable_2lc.CacheItem(124, 12400)
+ self.testable_2lc.InvalidateAllRamEntries(self.cnxn)
+ self.assertFalse(self.testable_2lc.HasItem(123))
+ self.assertFalse(self.testable_2lc.HasItem(124))
+
+
+class AbstractTwoLevelCacheTest_Redis(unittest.TestCase):
+ """Tests AbstractTwoLevelCache (via TestableTwoLevelCache) over Redis.
+
+ A fakeredis client is used in place of memcache; it is flushed after
+ each test.
+ """
+
+ def setUp(self):
+ self.cnxn = 'fake connection'
+ self.cache_manager = fake.CacheManager()
+
+ self.server = fakeredis.FakeServer()
+ self.fake_redis_client = fakeredis.FakeRedis(server=self.server)
+ self.testable_2lc = TestableTwoLevelCache(
+ self.cache_manager,
+ 'issue',
+ use_redis=True,
+ redis_client=self.fake_redis_client)
+
+ def tearDown(self):
+ self.fake_redis_client.flushall()
+
+ def testCacheItem(self):
+ # CacheItem populates the RAM layer (self.testable_2lc.cache).
+ self.testable_2lc.CacheItem(123, 12300)
+ self.assertEqual(12300, self.testable_2lc.cache.cache[123])
+
+ def testHasItem(self):
+ # 444 would be returned by FetchItems, yet HasItem is False: HasItem
+ # does not fall through to the DB fetch.
+ self.testable_2lc.CacheItem(123, 12300)
+ self.assertTrue(self.testable_2lc.HasItem(123))
+ self.assertFalse(self.testable_2lc.HasItem(444))
+ self.assertFalse(self.testable_2lc.HasItem(999))
+
+ # The _WriteToRedis tests round-trip values of various types through
+ # _ReadFromRedis, which returns (hits dict, misses).
+ def testWriteToRedis_Normal(self):
+ retrieved_dict = {123: 12300, 124: 12400}
+ self.testable_2lc._WriteToRedis(retrieved_dict)
+ actual_123, _ = self.testable_2lc._ReadFromRedis([123])
+ self.assertEqual(12300, actual_123[123])
+ actual_124, _ = self.testable_2lc._ReadFromRedis([124])
+ self.assertEqual(12400, actual_124[124])
+
+ def testWriteToRedis_str(self):
+ retrieved_dict = {111: 'foo', 222: 'bar'}
+ self.testable_2lc._WriteToRedis(retrieved_dict)
+ actual_111, _ = self.testable_2lc._ReadFromRedis([111])
+ self.assertEqual('foo', actual_111[111])
+ actual_222, _ = self.testable_2lc._ReadFromRedis([222])
+ self.assertEqual('bar', actual_222[222])
+
+ def testWriteToRedis_ProtobufInt(self):
+ # Exercises the path taken when pb_class is set (here simply int).
+ self.testable_2lc.pb_class = int
+ retrieved_dict = {123: 12300, 124: 12400}
+ self.testable_2lc._WriteToRedis(retrieved_dict)
+ actual_123, _ = self.testable_2lc._ReadFromRedis([123])
+ self.assertEqual(12300, actual_123[123])
+ actual_124, _ = self.testable_2lc._ReadFromRedis([124])
+ self.assertEqual(12400, actual_124[124])
+
+ def testWriteToRedis_List(self):
+ retrieved_dict = {123: [1, 2, 3], 124: [1, 2, 4]}
+ self.testable_2lc._WriteToRedis(retrieved_dict)
+ actual_123, _ = self.testable_2lc._ReadFromRedis([123])
+ self.assertEqual([1, 2, 3], actual_123[123])
+ actual_124, _ = self.testable_2lc._ReadFromRedis([124])
+ self.assertEqual([1, 2, 4], actual_124[124])
+
+ def testWriteToRedis_Dict(self):
+ retrieved_dict = {123: {'ham': 2, 'spam': 3}, 124: {'eggs': 2, 'bean': 4}}
+ self.testable_2lc._WriteToRedis(retrieved_dict)
+ actual_123, _ = self.testable_2lc._ReadFromRedis([123])
+ self.assertEqual({'ham': 2, 'spam': 3}, actual_123[123])
+ actual_124, _ = self.testable_2lc._ReadFromRedis([124])
+ self.assertEqual({'eggs': 2, 'bean': 4}, actual_124[124])
+
+ def testGetAll_FetchGetsIt(self):
+ self.testable_2lc.CacheItem(123, 12300)
+ self.testable_2lc.CacheItem(124, 12400)
+ # Clear the RAM cache so that we find items in redis.
+ self.testable_2lc.cache.LocalInvalidateAll()
+ self.testable_2lc.CacheItem(125, 12500)
+ # 333 and 444 come from FetchItems (keys < 900 map to themselves).
+ hits, misses = self.testable_2lc.GetAll(self.cnxn, [123, 124, 333, 444])
+ self.assertEqual({123: 12300, 124: 12400, 333: 333, 444: 444}, hits)
+ self.assertEqual([], misses)
+ # The RAM cache now has items found in redis and DB.
+ self.assertItemsEqual(
+ [123, 124, 125, 333, 444], list(self.testable_2lc.cache.cache.keys()))
+
+ def testGetAll_FetchGetsItFromDB(self):
+ self.testable_2lc.CacheItem(123, 12300)
+ self.testable_2lc.CacheItem(124, 12400)
+ hits, misses = self.testable_2lc.GetAll(self.cnxn, [123, 124, 333, 444])
+ self.assertEqual({123: 12300, 124: 12400, 333: 333, 444: 444}, hits)
+ self.assertEqual([], misses)
+
+ def testGetAll_FetchDoesNotFindIt(self):
+ # 999 is >= 900, so FetchItems omits it and it is reported as a miss.
+ self.testable_2lc.CacheItem(123, 12300)
+ self.testable_2lc.CacheItem(124, 12400)
+ hits, misses = self.testable_2lc.GetAll(self.cnxn, [123, 124, 999])
+ self.assertEqual({123: 12300, 124: 12400}, hits)
+ self.assertEqual([999], misses)
+
+ def testInvalidateKeys(self):
+ self.testable_2lc.CacheItem(123, 12300)
+ self.testable_2lc.CacheItem(124, 12400)
+ self.testable_2lc.CacheItem(125, 12500)
+ self.testable_2lc.InvalidateKeys(self.cnxn, [124])
+ self.assertEqual(2, len(self.testable_2lc.cache.cache))
+ self.assertNotIn(124, self.testable_2lc.cache.cache)
+ self.assertEqual(self.cache_manager.last_call,
+ ('StoreInvalidateRows', self.cnxn, 'issue', [124]))
+
+ def testGetAllAlreadyInRam(self):
+ # Only the RAM layer is consulted: 333/444 are misses even though
+ # FetchItems would return them.
+ self.testable_2lc.CacheItem(123, 12300)
+ self.testable_2lc.CacheItem(124, 12400)
+ hits, misses = self.testable_2lc.GetAllAlreadyInRam(
+ [123, 124, 333, 444, 999])
+ self.assertEqual({123: 12300, 124: 12400}, hits)
+ self.assertEqual([333, 444, 999], misses)
+
+ def testInvalidateAllRamEntries(self):
+ self.testable_2lc.CacheItem(123, 12300)
+ self.testable_2lc.CacheItem(124, 12400)
+ self.testable_2lc.InvalidateAllRamEntries(self.cnxn)
+ self.assertFalse(self.testable_2lc.HasItem(123))
+ self.assertFalse(self.testable_2lc.HasItem(124))
diff --git a/services/test/chart_svc_test.py b/services/test/chart_svc_test.py
new file mode 100644
index 0000000..fbd87df
--- /dev/null
+++ b/services/test/chart_svc_test.py
@@ -0,0 +1,713 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for chart_svc module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import datetime
+import mox
+import re
+import settings
+import unittest
+
+from google.appengine.ext import testbed
+
+from services import chart_svc
+from services import config_svc
+from services import service_manager
+from framework import permissions
+from framework import sql
+from proto import ast_pb2
+from proto import tracker_pb2
+from search import ast2select
+from search import search_helpers
+from testing import fake
+from tracker import tracker_bizobj
+
+
+def MakeChartService(my_mox, config):
+ chart_service = chart_svc.ChartService(config)
+ for table_var in ['issuesnapshot_tbl', 'issuesnapshot2label_tbl',
+ 'issuesnapshot2component_tbl', 'issuesnapshot2cctbl', 'labeldef_tbl']:
+ setattr(chart_service, table_var, my_mox.CreateMock(sql.SQLTableManager))
+ return chart_service
+
+
+class ChartServiceTest(unittest.TestCase):
+
+ def setUp(self):
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_memcache_stub()
+
+ self.mox = mox.Mox()
+ self.cnxn = self.mox.CreateMock(sql.MonorailConnection)
+ self.services = service_manager.Services()
+ self.config_service = fake.ConfigService()
+ self.services.config = self.config_service
+ self.services.chart = MakeChartService(self.mox, self.config_service)
+ self.services.issue = fake.IssueService()
+ self.mox.StubOutWithMock(self.services.chart, '_QueryToWhere')
+ self.mox.StubOutWithMock(search_helpers, 'GetPersonalAtRiskLabelIDs')
+ self.mox.StubOutWithMock(settings, 'num_logical_shards')
+ settings.num_logical_shards = 1
+ self.mox.StubOutWithMock(self.services.chart, '_currentTime')
+
+ self.defaultLeftJoins = [
+ ('Issue ON IssueSnapshot.issue_id = Issue.id', []),
+ ('Issue2Label AS Forbidden_label'
+ ' ON Issue.id = Forbidden_label.issue_id'
+ ' AND Forbidden_label.label_id IN (%s,%s)', [91, 81]),
+ ('Issue2Cc AS I2cc'
+ ' ON Issue.id = I2cc.issue_id'
+ ' AND I2cc.cc_id IN (%s,%s)', [10, 20]),
+ ]
+ self.defaultWheres = [
+ ('IssueSnapshot.period_start <= %s', [1514764800]),
+ ('IssueSnapshot.period_end > %s', [1514764800]),
+ ('Issue.is_spam = %s', [False]),
+ ('Issue.deleted = %s', [False]),
+ ('IssueSnapshot.project_id IN (%s)', [789]),
+ ('(Issue.reporter_id IN (%s,%s)'
+ ' OR Issue.owner_id IN (%s,%s)'
+ ' OR I2cc.cc_id IS NOT NULL'
+ ' OR Forbidden_label.label_id IS NULL)',
+ [10, 20, 10, 20]
+ ),
+ ]
+
+ def tearDown(self):
+ self.testbed.deactivate()
+ self.mox.UnsetStubs()
+ self.mox.ResetAll()
+
+ def _verifySQL(self, cols, left_joins, where, group_by=None):
+ for col in cols:
+ self.assertTrue(sql._IsValidColumnName(col))
+ for join_str, _ in left_joins:
+ self.assertTrue(sql._IsValidJoin(join_str))
+ for where_str, _ in where:
+ self.assertTrue(sql._IsValidWhereCond(where_str))
+ if group_by:
+ for groupby_str in group_by:
+ self.assertTrue(sql._IsValidGroupByTerm(groupby_str))
+
+ def testQueryIssueSnapshots_InvalidGroupBy(self):
+ """Make sure the `group_by` argument is checked."""
+ project = fake.Project(project_id=789)
+ perms = permissions.USER_PERMISSIONSET
+ search_helpers.GetPersonalAtRiskLabelIDs(self.cnxn, None,
+ self.config_service, [10, 20], project,
+ perms).AndReturn([91, 81])
+ self.services.chart._QueryToWhere(mox.IgnoreArg(), mox.IgnoreArg(),
+ mox.IgnoreArg(), mox.IgnoreArg(), mox.IgnoreArg(),
+ mox.IgnoreArg()).AndReturn(([], [], []))
+
+ self.mox.ReplayAll()
+ with self.assertRaises(ValueError):
+ self.services.chart.QueryIssueSnapshots(self.cnxn, self.services,
+ unixtime=1514764800, effective_ids=[10, 20], project=project,
+ perms=perms, group_by='rutabaga', label_prefix='rutabaga')
+ self.mox.VerifyAll()
+
+ def testQueryIssueSnapshots_NoLabelPrefix(self):
+ """Make sure the `label_prefix` argument is required."""
+ project = fake.Project(project_id=789)
+ perms = permissions.USER_PERMISSIONSET
+ search_helpers.GetPersonalAtRiskLabelIDs(self.cnxn, None,
+ self.config_service, [10, 20], project,
+ perms).AndReturn([91, 81])
+ self.services.chart._QueryToWhere(mox.IgnoreArg(), mox.IgnoreArg(),
+ mox.IgnoreArg(), mox.IgnoreArg(), mox.IgnoreArg(),
+ mox.IgnoreArg()).AndReturn(([], [], []))
+
+ self.mox.ReplayAll()
+ with self.assertRaises(ValueError):
+ self.services.chart.QueryIssueSnapshots(self.cnxn, self.services,
+ unixtime=1514764800, effective_ids=[10, 20], project=project,
+ perms=perms, group_by='label')
+ self.mox.VerifyAll()
+
+ def testQueryIssueSnapshots_Impossible(self):
+ """We give an error message when a query could never have results."""
+ project = fake.Project(project_id=789)
+ perms = permissions.USER_PERMISSIONSET
+ self.services.chart._QueryToWhere(mox.IgnoreArg(), mox.IgnoreArg(),
+ mox.IgnoreArg(), mox.IgnoreArg(), mox.IgnoreArg(),
+ mox.IgnoreArg()).AndRaise(ast2select.NoPossibleResults())
+ self.mox.ReplayAll()
+ total, errors, limit_reached = self.services.chart.QueryIssueSnapshots(
+ self.cnxn, self.services,
+ unixtime=1514764800, effective_ids=[10, 20], project=project,
+ perms=perms, query='prefix=')
+ self.mox.VerifyAll()
+ self.assertEqual({}, total)
+ self.assertEqual(['Invalid query.'], errors)
+ self.assertFalse(limit_reached)
+
+ def testQueryIssueSnapshots_Components(self):
+ """Test a burndown query from a regular user grouping by component."""
+ project = fake.Project(project_id=789)
+ perms = permissions.PermissionSet(['BarPerm'])
+ search_helpers.GetPersonalAtRiskLabelIDs(self.cnxn, None,
+ self.config_service, [10, 20], project,
+ perms).AndReturn([91, 81])
+
+ cols = [
+ 'Comp.path',
+ 'COUNT(IssueSnapshot.issue_id)'
+ ]
+ left_joins = self.defaultLeftJoins + [
+ ('IssueSnapshot2Component AS Is2c'
+ ' ON Is2c.issuesnapshot_id = IssueSnapshot.id', []),
+ ('ComponentDef AS Comp ON Comp.id = Is2c.component_id', [])
+ ]
+ where = self.defaultWheres
+ group_by = ['Comp.path']
+ stmt, stmt_args = self.services.chart._BuildSnapshotQuery(cols, where,
+ left_joins, group_by, shard_id=0)
+
+ self.services.chart._QueryToWhere(mox.IgnoreArg(), mox.IgnoreArg(),
+ mox.IgnoreArg(), mox.IgnoreArg(), mox.IgnoreArg(),
+ mox.IgnoreArg()).AndReturn(([], [], []))
+ self.cnxn.Execute(stmt, stmt_args, shard_id=0).AndReturn([])
+
+ self._verifySQL(cols, left_joins, where, group_by)
+
+ self.mox.ReplayAll()
+ self.services.chart.QueryIssueSnapshots(self.cnxn, self.services,
+ unixtime=1514764800, effective_ids=[10, 20], project=project,
+ perms=perms, group_by='component')
+ self.mox.VerifyAll()
+
+ def testQueryIssueSnapshots_Labels(self):
+ """Test a burndown query from a regular user grouping by label."""
+ project = fake.Project(project_id=789)
+ perms = permissions.PermissionSet(['BarPerm'])
+ search_helpers.GetPersonalAtRiskLabelIDs(self.cnxn, None,
+ self.config_service, [10, 20], project,
+ perms).AndReturn([91, 81])
+
+ cols = [
+ 'Lab.label',
+ 'COUNT(IssueSnapshot.issue_id)',
+ ]
+ left_joins = self.defaultLeftJoins + [
+ ('IssueSnapshot2Label AS Is2l'
+ ' ON Is2l.issuesnapshot_id = IssueSnapshot.id', []),
+ ('LabelDef AS Lab ON Lab.id = Is2l.label_id', [])
+ ]
+ where = self.defaultWheres + [
+ ('LOWER(Lab.label) LIKE %s', ['foo-%']),
+ ]
+ group_by = ['Lab.label']
+ stmt, stmt_args = self.services.chart._BuildSnapshotQuery(cols, where,
+ left_joins, group_by, shard_id=0)
+
+ self.services.chart._QueryToWhere(mox.IgnoreArg(), mox.IgnoreArg(),
+ mox.IgnoreArg(), mox.IgnoreArg(), mox.IgnoreArg(),
+ mox.IgnoreArg()).AndReturn(([], [], []))
+ self.cnxn.Execute(stmt, stmt_args, shard_id=0).AndReturn([])
+
+ self._verifySQL(cols, left_joins, where, group_by)
+
+ self.mox.ReplayAll()
+ self.services.chart.QueryIssueSnapshots(self.cnxn, self.services,
+ unixtime=1514764800, effective_ids=[10, 20], project=project,
+ perms=perms, group_by='label', label_prefix='Foo')
+ self.mox.VerifyAll()
+
+ def testQueryIssueSnapshots_Open(self):
+ """Test a burndown query from a regular user grouping
+ by status is open or closed."""
+ project = fake.Project(project_id=789)
+ perms = permissions.PermissionSet(['BarPerm'])
+ search_helpers.GetPersonalAtRiskLabelIDs(self.cnxn, None,
+ self.config_service, [10, 20], project,
+ perms).AndReturn([91, 81])
+
+ cols = [
+ 'IssueSnapshot.is_open',
+ 'COUNT(IssueSnapshot.issue_id) AS issue_count',
+ ]
+
+ left_joins = self.defaultLeftJoins
+ where = self.defaultWheres
+ group_by = ['IssueSnapshot.is_open']
+ stmt, stmt_args = self.services.chart._BuildSnapshotQuery(cols, where,
+ left_joins, group_by, shard_id=0)
+
+ self.services.chart._QueryToWhere(mox.IgnoreArg(), mox.IgnoreArg(),
+ mox.IgnoreArg(), mox.IgnoreArg(), mox.IgnoreArg(),
+ mox.IgnoreArg()).AndReturn(([], [], []))
+ self.cnxn.Execute(stmt, stmt_args, shard_id=0).AndReturn([])
+
+ self._verifySQL(cols, left_joins, where, group_by)
+
+ self.mox.ReplayAll()
+ self.services.chart.QueryIssueSnapshots(self.cnxn, self.services,
+ unixtime=1514764800, effective_ids=[10, 20], project=project,
+ perms=perms, group_by='open')
+ self.mox.VerifyAll()
+
+ def testQueryIssueSnapshots_Status(self):
+ """Test a burndown query from a regular user grouping by open status."""
+ project = fake.Project(project_id=789)
+ perms = permissions.PermissionSet(['BarPerm'])
+ search_helpers.GetPersonalAtRiskLabelIDs(self.cnxn, None,
+ self.config_service, [10, 20], project,
+ perms).AndReturn([91, 81])
+
+ cols = [
+ 'Stats.status',
+ 'COUNT(IssueSnapshot.issue_id)',
+ ]
+ left_joins = self.defaultLeftJoins + [
+ ('StatusDef AS Stats ON ' \
+ 'Stats.id = IssueSnapshot.status_id', [])
+ ]
+ where = self.defaultWheres
+ group_by = ['Stats.status']
+ stmt, stmt_args = self.services.chart._BuildSnapshotQuery(cols, where,
+ left_joins, group_by, shard_id=0)
+
+ self.services.chart._QueryToWhere(mox.IgnoreArg(), mox.IgnoreArg(),
+ mox.IgnoreArg(), mox.IgnoreArg(), mox.IgnoreArg(),
+ mox.IgnoreArg()).AndReturn(([], [], []))
+ self.cnxn.Execute(stmt, stmt_args, shard_id=0).AndReturn([])
+
+ self._verifySQL(cols, left_joins, where, group_by)
+
+ self.mox.ReplayAll()
+ self.services.chart.QueryIssueSnapshots(self.cnxn, self.services,
+ unixtime=1514764800, effective_ids=[10, 20], project=project,
+ perms=perms, group_by='status')
+ self.mox.VerifyAll()
+
+ def testQueryIssueSnapshots_Hotlist(self):
+ """Test a QueryIssueSnapshots when a hotlist is passed."""
+ hotlist = fake.Hotlist('hotlist_rutabaga', 19191)
+ project = fake.Project(project_id=789)
+ perms = permissions.PermissionSet(['BarPerm'])
+ search_helpers.GetPersonalAtRiskLabelIDs(self.cnxn, None,
+ self.config_service, [10, 20], project,
+ perms).AndReturn([91, 81])
+
+ cols = [
+ 'IssueSnapshot.issue_id',
+ ]
+ left_joins = self.defaultLeftJoins + [
+ (('IssueSnapshot2Hotlist AS Is2h'
+ ' ON Is2h.issuesnapshot_id = IssueSnapshot.id'
+ ' AND Is2h.hotlist_id = %s'), [hotlist.hotlist_id]),
+ ]
+ where = self.defaultWheres + [
+ ('Is2h.hotlist_id = %s', [hotlist.hotlist_id]),
+ ]
+ group_by = []
+ stmt, stmt_args = self.services.chart._BuildSnapshotQuery(cols, where,
+ left_joins, group_by, shard_id=0)
+
+ self.services.chart._QueryToWhere(mox.IgnoreArg(), mox.IgnoreArg(),
+ mox.IgnoreArg(), mox.IgnoreArg(), mox.IgnoreArg(),
+ mox.IgnoreArg()).AndReturn(([], [], []))
+ self.cnxn.Execute(stmt, stmt_args, shard_id=0).AndReturn([])
+
+ self._verifySQL(cols, left_joins, where, group_by)
+
+ self.mox.ReplayAll()
+ self.services.chart.QueryIssueSnapshots(self.cnxn, self.services,
+ unixtime=1514764800, effective_ids=[10, 20], project=project,
+ perms=perms, hotlist=hotlist)
+ self.mox.VerifyAll()
+
+ def testQueryIssueSnapshots_Owner(self):
+ """Test a burndown query from a regular user grouping by owner."""
+ project = fake.Project(project_id=789)
+ perms = permissions.PermissionSet(['BarPerm'])
+ search_helpers.GetPersonalAtRiskLabelIDs(self.cnxn, None,
+ self.config_service, [10, 20], project,
+ perms).AndReturn([91, 81])
+ cols = [
+ 'IssueSnapshot.owner_id',
+ 'COUNT(IssueSnapshot.issue_id)',
+ ]
+ left_joins = self.defaultLeftJoins
+ where = self.defaultWheres
+ group_by = ['IssueSnapshot.owner_id']
+ stmt, stmt_args = self.services.chart._BuildSnapshotQuery(cols, where,
+ left_joins, group_by, shard_id=0)
+
+ self.services.chart._QueryToWhere(mox.IgnoreArg(), mox.IgnoreArg(),
+ mox.IgnoreArg(), mox.IgnoreArg(), mox.IgnoreArg(),
+ mox.IgnoreArg()).AndReturn(([], [], []))
+ self.cnxn.Execute(stmt, stmt_args, shard_id=0).AndReturn([])
+
+ self._verifySQL(cols, left_joins, where, group_by)
+
+ self.mox.ReplayAll()
+ self.services.chart.QueryIssueSnapshots(self.cnxn, self.services,
+ unixtime=1514764800, effective_ids=[10, 20], project=project,
+ perms=perms, group_by='owner')
+ self.mox.VerifyAll()
+
+ def testQueryIssueSnapshots_NoGroupBy(self):
+ """Test a burndown query from a regular user with no grouping."""
+ project = fake.Project(project_id=789)
+ perms = permissions.PermissionSet(['BarPerm'])
+ search_helpers.GetPersonalAtRiskLabelIDs(self.cnxn, None,
+ self.config_service, [10, 20], project,
+ perms).AndReturn([91, 81])
+
+ cols = [
+ 'IssueSnapshot.issue_id',
+ ]
+ left_joins = self.defaultLeftJoins
+ where = self.defaultWheres
+ group_by = None
+ stmt, stmt_args = self.services.chart._BuildSnapshotQuery(cols, where,
+ left_joins, group_by, shard_id=0)
+
+ self.services.chart._QueryToWhere(mox.IgnoreArg(), mox.IgnoreArg(),
+ mox.IgnoreArg(), mox.IgnoreArg(), mox.IgnoreArg(),
+ mox.IgnoreArg()).AndReturn(([], [], []))
+ self.cnxn.Execute(stmt, stmt_args, shard_id=0).AndReturn([])
+
+ self._verifySQL(cols, left_joins, where)
+
+ self.mox.ReplayAll()
+ self.services.chart.QueryIssueSnapshots(self.cnxn, self.services,
+ unixtime=1514764800, effective_ids=[10, 20], project=project,
+ perms=perms, group_by=None, label_prefix='Foo')
+ self.mox.VerifyAll()
+
+ def testQueryIssueSnapshots_LabelsNotLoggedInUser(self):
+ """Tests fetching burndown snapshot counts grouped by labels
+ for a user who is not logged in. Also no restricted labels are
+ present.
+ """
+ project = fake.Project(project_id=789)
+ perms = permissions.READ_ONLY_PERMISSIONSET
+ search_helpers.GetPersonalAtRiskLabelIDs(self.cnxn, None,
+ self.config_service, set([]), project,
+ perms).AndReturn([91, 81])
+
+ cols = [
+ 'Lab.label',
+ 'COUNT(IssueSnapshot.issue_id)',
+ ]
+ left_joins = [
+ ('Issue ON IssueSnapshot.issue_id = Issue.id', []),
+ ('Issue2Label AS Forbidden_label'
+ ' ON Issue.id = Forbidden_label.issue_id'
+ ' AND Forbidden_label.label_id IN (%s,%s)', [91, 81]),
+ ('IssueSnapshot2Label AS Is2l'
+ ' ON Is2l.issuesnapshot_id = IssueSnapshot.id', []),
+ ('LabelDef AS Lab ON Lab.id = Is2l.label_id', []),
+ ]
+ where = [
+ ('IssueSnapshot.period_start <= %s', [1514764800]),
+ ('IssueSnapshot.period_end > %s', [1514764800]),
+ ('Issue.is_spam = %s', [False]),
+ ('Issue.deleted = %s', [False]),
+ ('IssueSnapshot.project_id IN (%s)', [789]),
+ ('Forbidden_label.label_id IS NULL', []),
+ ('LOWER(Lab.label) LIKE %s', ['foo-%']),
+ ]
+ group_by = ['Lab.label']
+ stmt, stmt_args = self.services.chart._BuildSnapshotQuery(cols, where,
+ left_joins, group_by, shard_id=0)
+
+ self.services.chart._QueryToWhere(mox.IgnoreArg(), mox.IgnoreArg(),
+ mox.IgnoreArg(), mox.IgnoreArg(), mox.IgnoreArg(),
+ mox.IgnoreArg()).AndReturn(([], [], []))
+ self.cnxn.Execute(stmt, stmt_args, shard_id=0).AndReturn([])
+
+ self._verifySQL(cols, left_joins, where, group_by)
+
+ self.mox.ReplayAll()
+ self.services.chart.QueryIssueSnapshots(self.cnxn, self.services,
+ unixtime=1514764800, effective_ids=set([]), project=project,
+ perms=perms, group_by='label', label_prefix='Foo')
+ self.mox.VerifyAll()
+
+ def testQueryIssueSnapshots_NoRestrictedLabels(self):
+ """Test a label burndown query when the project has no restricted labels."""
+ project = fake.Project(project_id=789)
+ perms = permissions.USER_PERMISSIONSET
+ search_helpers.GetPersonalAtRiskLabelIDs(self.cnxn, None,
+ self.config_service, [10, 20], project,
+ perms).AndReturn([])
+
+ cols = [
+ 'Lab.label',
+ 'COUNT(IssueSnapshot.issue_id)',
+ ]
+ left_joins = [
+ ('Issue ON IssueSnapshot.issue_id = Issue.id', []),
+ ('Issue2Cc AS I2cc'
+ ' ON Issue.id = I2cc.issue_id'
+ ' AND I2cc.cc_id IN (%s,%s)', [10, 20]),
+ ('IssueSnapshot2Label AS Is2l'
+ ' ON Is2l.issuesnapshot_id = IssueSnapshot.id', []),
+ ('LabelDef AS Lab ON Lab.id = Is2l.label_id', []),
+ ]
+ where = [
+ ('IssueSnapshot.period_start <= %s', [1514764800]),
+ ('IssueSnapshot.period_end > %s', [1514764800]),
+ ('Issue.is_spam = %s', [False]),
+ ('Issue.deleted = %s', [False]),
+ ('IssueSnapshot.project_id IN (%s)', [789]),
+ ('(Issue.reporter_id IN (%s,%s)'
+ ' OR Issue.owner_id IN (%s,%s)'
+ ' OR I2cc.cc_id IS NOT NULL)',
+ [10, 20, 10, 20]
+ ),
+ ('LOWER(Lab.label) LIKE %s', ['foo-%']),
+ ]
+ group_by = ['Lab.label']
+ stmt, stmt_args = self.services.chart._BuildSnapshotQuery(cols, where,
+ left_joins, group_by, shard_id=0)
+
+ self.services.chart._QueryToWhere(mox.IgnoreArg(), mox.IgnoreArg(),
+ mox.IgnoreArg(), mox.IgnoreArg(), mox.IgnoreArg(),
+ mox.IgnoreArg()).AndReturn(([], [], []))
+ self.cnxn.Execute(stmt, stmt_args, shard_id=0).AndReturn([])
+
+ self._verifySQL(cols, left_joins, where, group_by)
+
+ self.mox.ReplayAll()
+ self.services.chart.QueryIssueSnapshots(self.cnxn, self.services,
+ unixtime=1514764800, effective_ids=[10, 20], project=project,
+ perms=perms, group_by='label', label_prefix='Foo')
+ self.mox.VerifyAll()
+
+ def SetUpStoreIssueSnapshots(self, replace_now=None,
+ project_id=789, owner_id=111,
+ component_ids=None, cc_rows=None):
+ """Set up all calls to mocks that StoreIssueSnapshots will call."""
+ now = self.services.chart._currentTime().AndReturn(replace_now or 12345678)
+
+ self.services.chart.issuesnapshot_tbl.Update(self.cnxn,
+ delta={'period_end': now},
+ where=[('IssueSnapshot.issue_id = %s', [78901]),
+ ('IssueSnapshot.period_end = %s',
+ [settings.maximum_snapshot_period_end])],
+ commit=False)
+
+ # Shard is 0 because len(shards) = 1 and 1 % 1 = 0.
+ shard = 0
+ self.services.chart.issuesnapshot_tbl.InsertRows(self.cnxn,
+ chart_svc.ISSUESNAPSHOT_COLS[1:],
+ [(78901, shard, project_id, 1, 111, owner_id, 1,
+ now, 4294967295, True)],
+ replace=True, commit=False, return_generated_ids=True).AndReturn([5678])
+
+ label_rows = [(5678, 1)]
+
+ self.services.chart.issuesnapshot2label_tbl.InsertRows(self.cnxn,
+ chart_svc.ISSUESNAPSHOT2LABEL_COLS,
+ label_rows,
+ replace=True, commit=False)
+
+ self.services.chart.issuesnapshot2cc_tbl.InsertRows(
+ self.cnxn, chart_svc.ISSUESNAPSHOT2CC_COLS,
+ [(5678, row[1]) for row in cc_rows],
+ replace=True, commit=False)
+
+ component_rows = [(5678, component_id) for component_id in component_ids]
+ self.services.chart.issuesnapshot2component_tbl.InsertRows(
+ self.cnxn, chart_svc.ISSUESNAPSHOT2COMPONENT_COLS,
+ component_rows,
+ replace=True, commit=False)
+
+ # Spacing of string must match.
+ self.cnxn.Execute((
+ '\n INSERT INTO IssueSnapshot2Hotlist '
+ '(issuesnapshot_id, hotlist_id)\n '
+ 'SELECT %s, hotlist_id FROM Hotlist2Issue '
+ 'WHERE issue_id = %s\n '
+ ), [5678, 78901])
+
+ def testStoreIssueSnapshots_NoChange(self):
+ """Test that StoreIssueSnapshots inserts and updates previous
+ issue snapshots correctly."""
+
+ now_1 = 1517599888
+ now_2 = 1517599999
+
+ issue = fake.MakeTestIssue(issue_id=78901,
+ project_id=789, local_id=1, reporter_id=111, owner_id=111,
+ summary='sum', status='Status1',
+ labels=['Type-Defect'],
+ component_ids=[11], assume_stale=False,
+ opened_timestamp=123456789, modified_timestamp=123456789,
+ star_count=12, cc_ids=[222, 333], derived_cc_ids=[888])
+
+ # Snapshot #1
+ cc_rows = [(5678, 222), (5678, 333), (5678, 888)]
+ self.SetUpStoreIssueSnapshots(replace_now=now_1,
+ component_ids=[11], cc_rows=cc_rows)
+
+ # Snapshot #2
+ self.SetUpStoreIssueSnapshots(replace_now=now_2,
+ component_ids=[11], cc_rows=cc_rows)
+
+ self.mox.ReplayAll()
+ self.services.chart.StoreIssueSnapshots(self.cnxn, [issue], commit=False)
+ self.services.chart.StoreIssueSnapshots(self.cnxn, [issue], commit=False)
+ self.mox.VerifyAll()
+
+ def testStoreIssueSnapshots_AllFieldsChanged(self):
+ """Test that StoreIssueSnapshots inserts and updates previous
+ issue snapshots correctly. This tests that all relations (labels,
+ CCs, and components) are updated."""
+
+ now_1 = 1517599888
+ now_2 = 1517599999
+
+ issue_1 = fake.MakeTestIssue(issue_id=78901,
+ project_id=789, local_id=1, reporter_id=111, owner_id=111,
+ summary='sum', status='Status1',
+ labels=['Type-Defect'],
+ component_ids=[11, 12], assume_stale=False,
+ opened_timestamp=123456789, modified_timestamp=123456789,
+ star_count=12, cc_ids=[222, 333], derived_cc_ids=[888])
+
+ issue_2 = fake.MakeTestIssue(issue_id=78901,
+ project_id=123, local_id=1, reporter_id=111, owner_id=222,
+ summary='sum', status='Status2',
+ labels=['Type-Enhancement'],
+ component_ids=[13], assume_stale=False,
+ opened_timestamp=123456789, modified_timestamp=123456789,
+ star_count=12, cc_ids=[222, 444], derived_cc_ids=[888, 999])
+
+ # Snapshot #1
+ cc_rows_1 = [(5678, 222), (5678, 333), (5678, 888)]
+ self.SetUpStoreIssueSnapshots(replace_now=now_1,
+ component_ids=[11, 12], cc_rows=cc_rows_1)
+
+ # Snapshot #2
+ cc_rows_2 = [(5678, 222), (5678, 444), (5678, 888), (5678, 999)]
+ self.SetUpStoreIssueSnapshots(replace_now=now_2,
+ project_id=123, owner_id=222, component_ids=[13],
+ cc_rows=cc_rows_2)
+
+ self.mox.ReplayAll()
+ self.services.chart.StoreIssueSnapshots(self.cnxn, [issue_1], commit=False)
+ self.services.chart.StoreIssueSnapshots(self.cnxn, [issue_2], commit=False)
+ self.mox.VerifyAll()
+
+ def testQueryIssueSnapshots_WithQueryStringAndCannedQuery(self):
+ """Test the query param is parsed and used."""
+ project = fake.Project(project_id=789)
+ perms = permissions.USER_PERMISSIONSET
+ search_helpers.GetPersonalAtRiskLabelIDs(self.cnxn, None,
+ self.config_service, [10, 20], project, perms).AndReturn([])
+
+ cols = [
+ 'Lab.label',
+ 'COUNT(IssueSnapshot.issue_id)',
+ ]
+ left_joins = [
+ ('Issue ON IssueSnapshot.issue_id = Issue.id', []),
+ ('Issue2Cc AS I2cc'
+ ' ON Issue.id = I2cc.issue_id'
+ ' AND I2cc.cc_id IN (%s,%s)', [10, 20]),
+ ('IssueSnapshot2Label AS Is2l'
+ ' ON Is2l.issuesnapshot_id = IssueSnapshot.id', []),
+ ('LabelDef AS Lab ON Lab.id = Is2l.label_id', []),
+ ('IssueSnapshot2Label AS Cond0 '
+ 'ON IssueSnapshot.id = Cond0.issuesnapshot_id '
+ 'AND Cond0.label_id = %s', [15]),
+ ]
+ where = [
+ ('IssueSnapshot.period_start <= %s', [1514764800]),
+ ('IssueSnapshot.period_end > %s', [1514764800]),
+ ('Issue.is_spam = %s', [False]),
+ ('Issue.deleted = %s', [False]),
+ ('IssueSnapshot.project_id IN (%s)', [789]),
+ ('(Issue.reporter_id IN (%s,%s)'
+ ' OR Issue.owner_id IN (%s,%s)'
+ ' OR I2cc.cc_id IS NOT NULL)',
+ [10, 20, 10, 20]
+ ),
+ ('LOWER(Lab.label) LIKE %s', ['foo-%']),
+ ('Cond0.label_id IS NULL', []),
+ ('IssueSnapshot.is_open = %s', [True]),
+ ]
+ group_by = ['Lab.label']
+
+ query_left_joins = [(
+ 'IssueSnapshot2Label AS Cond0 '
+ 'ON IssueSnapshot.id = Cond0.issuesnapshot_id '
+ 'AND Cond0.label_id = %s', [15])]
+ query_where = [
+ ('Cond0.label_id IS NULL', []),
+ ('IssueSnapshot.is_open = %s', [True]),
+ ]
+
+ unsupported_field_names = ['ownerbouncing']
+
+ unsupported_conds = [
+ ast_pb2.Condition(op=ast_pb2.QueryOp(1), field_defs=[
+ tracker_pb2.FieldDef(field_name='ownerbouncing',
+ field_type=tracker_pb2.FieldTypes.BOOL_TYPE),
+ ])
+ ]
+
+ stmt, stmt_args = self.services.chart._BuildSnapshotQuery(cols, where,
+ left_joins, group_by, shard_id=0)
+
+ self.services.chart._QueryToWhere(mox.IgnoreArg(), mox.IgnoreArg(),
+ mox.IgnoreArg(), mox.IgnoreArg(), mox.IgnoreArg(),
+ mox.IgnoreArg()).AndReturn((query_left_joins, query_where,
+ unsupported_conds))
+ self.cnxn.Execute(stmt, stmt_args, shard_id=0).AndReturn([])
+
+ self._verifySQL(cols, left_joins, where, group_by)
+
+ self.mox.ReplayAll()
+ _, unsupported, limit_reached = self.services.chart.QueryIssueSnapshots(
+ self.cnxn, self.services, unixtime=1514764800,
+ effective_ids=[10, 20], project=project, perms=perms,
+ group_by='label', label_prefix='Foo',
+ query='-label:Performance%20is:ownerbouncing', canned_query='is:open')
+ self.mox.VerifyAll()
+
+ self.assertEqual(unsupported_field_names, unsupported)
+ self.assertFalse(limit_reached)
+
+ def testQueryToWhere_AddsShardId(self):
+ """Test that shards are handled correctly."""
+ cols = []
+ where = []
+ joins = []
+ group_by = []
+ stmt, stmt_args = self.services.chart._BuildSnapshotQuery(cols=cols,
+ where=where, joins=joins, group_by=group_by, shard_id=9)
+
+ self.assertEqual(stmt, ('SELECT COUNT(results.issue_id) '
+ 'FROM (SELECT DISTINCT FROM IssueSnapshot\n'
+ 'WHERE IssueSnapshot.shard = %s\nLIMIT 10000) AS results'))
+ self.assertEqual(stmt_args, [9])
+
+ # Test that shard_id is still correct on second invocation.
+ stmt, stmt_args = self.services.chart._BuildSnapshotQuery(cols=cols,
+ where=where, joins=joins, group_by=group_by, shard_id=8)
+
+ self.assertEqual(stmt, ('SELECT COUNT(results.issue_id) '
+ 'FROM (SELECT DISTINCT FROM IssueSnapshot\n'
+ 'WHERE IssueSnapshot.shard = %s\nLIMIT 10000) AS results'))
+ self.assertEqual(stmt_args, [8])
+
+ # Test no parameters were modified.
+ self.assertEqual(cols, [])
+ self.assertEqual(where, [])
+ self.assertEqual(joins, [])
+ self.assertEqual(group_by, [])
diff --git a/services/test/client_config_svc_test.py b/services/test/client_config_svc_test.py
new file mode 100644
index 0000000..5e9b87a
--- /dev/null
+++ b/services/test/client_config_svc_test.py
@@ -0,0 +1,133 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for the client config service."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import base64
+import unittest
+
+from services import client_config_svc
+
+
+class LoadApiClientConfigsTest(unittest.TestCase):
+
+ class FakeResponse(object):
+ def __init__(self, content):
+ self.content = content
+
+ def setUp(self):
+ self.handler = client_config_svc.LoadApiClientConfigs()
+
+ def testProcessResponse_InvalidJSON(self):
+ r = self.FakeResponse('}{')
+ with self.assertRaises(ValueError):
+ self.handler._process_response(r)
+
+ def testProcessResponse_NoContent(self):
+ r = self.FakeResponse('{"wrong-key": "some-value"}')
+ with self.assertRaises(KeyError):
+ self.handler._process_response(r)
+
+ def testProcessResponse_NotB64(self):
+ # 'asd' is not a valid base64-encoded string.
+ r = self.FakeResponse('{"content": "asd"}')
+ with self.assertRaises(TypeError):
+ self.handler._process_response(r)
+
+ def testProcessResponse_NotProto(self):
+ # 'asdf' is a valid base64-encoded string.
+ r = self.FakeResponse('{"content": "asdf"}')
+ with self.assertRaises(Exception):
+ self.handler._process_response(r)
+
+ def testProcessResponse_Success(self):
+ with open(client_config_svc.CONFIG_FILE_PATH) as f:
+ r = self.FakeResponse('{"content": "%s"}' % base64.b64encode(f.read()))
+ c = self.handler._process_response(r)
+ assert '123456789.apps.googleusercontent.com' in c
+
+
+class ClientConfigServiceTest(unittest.TestCase):
+
+ def setUp(self):
+ self.client_config_svc = client_config_svc.GetClientConfigSvc()
+ self.client_email = '123456789@developer.gserviceaccount.com'
+ self.client_id = '123456789.apps.googleusercontent.com'
+ self.allowed_origins = {'chicken.test', 'cow.test', 'goat.test'}
+
+ def testGetDisplayNames(self):
+ display_names_map = self.client_config_svc.GetDisplayNames()
+ self.assertIn(self.client_email, display_names_map)
+ self.assertEqual('johndoe@example.com',
+ display_names_map[self.client_email])
+
+ def testGetQPMDict(self):
+ qpm_map = self.client_config_svc.GetQPM()
+ self.assertIn(self.client_email, qpm_map)
+ self.assertEqual(1, qpm_map[self.client_email])
+ self.assertNotIn('bugdroid1@chromium.org', qpm_map)
+
+ def testGetClientIDEmails(self):
+ auth_client_ids, auth_emails = self.client_config_svc.GetClientIDEmails()
+ self.assertIn(self.client_id, auth_client_ids)
+ self.assertIn(self.client_email, auth_emails)
+
+ def testGetAllowedOriginsSet(self):
+ origins = self.client_config_svc.GetAllowedOriginsSet()
+ self.assertEqual(self.allowed_origins, origins)
+
+ def testForceLoad(self):
+ EXPIRES_IN = client_config_svc.ClientConfigService.EXPIRES_IN
+ NOW = 1493007338
+ # First time it will always read the config
+ self.client_config_svc.load_time = NOW
+ self.client_config_svc.GetConfigs(use_cache=True)
+ self.assertNotEqual(NOW, self.client_config_svc.load_time)
+
+ # use_cache is false and it will read the config
+ self.client_config_svc.load_time = NOW
+ self.client_config_svc.GetConfigs(
+ use_cache=False, cur_time=NOW + 1)
+ self.assertNotEqual(NOW, self.client_config_svc.load_time)
+
+ # Cache expires after some time and it will read the config
+ self.client_config_svc.load_time = NOW
+ self.client_config_svc.GetConfigs(
+ use_cache=True, cur_time=NOW + EXPIRES_IN + 1)
+ self.assertNotEqual(NOW, self.client_config_svc.load_time)
+
+ # otherwise it should just use the cache
+ self.client_config_svc.load_time = NOW
+ self.client_config_svc.GetConfigs(
+ use_cache=True, cur_time=NOW + EXPIRES_IN - 1)
+ self.assertEqual(NOW, self.client_config_svc.load_time)
+
+
+class ClientConfigServiceFunctionsTest(unittest.TestCase):
+
+ def setUp(self):
+ self.client_email = '123456789@developer.gserviceaccount.com'
+ self.allowed_origins = {'chicken.test', 'cow.test', 'goat.test'}
+
+ def testGetServiceAccountMap(self):
+ service_account_map = client_config_svc.GetServiceAccountMap()
+ self.assertIn(self.client_email, service_account_map)
+ self.assertEqual(
+ 'johndoe@example.com',
+ service_account_map[self.client_email])
+ self.assertNotIn('bugdroid1@chromium.org', service_account_map)
+
+ def testGetQPMDict(self):
+ qpm_map = client_config_svc.GetQPMDict()
+ self.assertIn(self.client_email, qpm_map)
+ self.assertEqual(1, qpm_map[self.client_email])
+ self.assertNotIn('bugdroid1@chromium.org', qpm_map)
+
+ def testGetAllowedOriginsSet(self):
+ allowed_origins = client_config_svc.GetAllowedOriginsSet()
+ self.assertEqual(self.allowed_origins, allowed_origins)
diff --git a/services/test/config_svc_test.py b/services/test/config_svc_test.py
new file mode 100644
index 0000000..6d1d941
--- /dev/null
+++ b/services/test/config_svc_test.py
@@ -0,0 +1,1143 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for config_svc module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import re
+import unittest
+import logging
+import mock
+
+import mox
+
+from google.appengine.api import memcache
+from google.appengine.ext import testbed
+
+from framework import exceptions
+from framework import framework_constants
+from framework import sql
+from proto import tracker_pb2
+from services import config_svc
+from services import template_svc
+from testing import fake
+from tracker import tracker_bizobj
+from tracker import tracker_constants
+
+# Shard count config_svc uses for label rows. The tests below compute the
+# expected shard of a label row as label_id % LABEL_ROW_SHARDS.
+LABEL_ROW_SHARDS = config_svc.LABEL_ROW_SHARDS
+
+
+def MakeConfigService(cache_manager, my_mox):
+ config_service = config_svc.ConfigService(cache_manager)
+ for table_var in ['projectissueconfig_tbl', 'statusdef_tbl', 'labeldef_tbl',
+ 'fielddef_tbl', 'fielddef2admin_tbl', 'fielddef2editor_tbl',
+ 'componentdef_tbl', 'component2admin_tbl',
+ 'component2cc_tbl', 'component2label_tbl',
+ 'approvaldef2approver_tbl', 'approvaldef2survey_tbl']:
+ setattr(config_service, table_var, my_mox.CreateMock(sql.SQLTableManager))
+
+ return config_service
+
+
+class LabelRowTwoLevelCacheTest(unittest.TestCase):
+
+ def setUp(self):
+ self.mox = mox.Mox()
+ self.cnxn = 'fake connection'
+ self.cache_manager = fake.CacheManager()
+ self.config_service = MakeConfigService(self.cache_manager, self.mox)
+ self.label_row_2lc = self.config_service.label_row_2lc
+
+ self.rows = [(1, 789, 1, 'A', 'doc', False),
+ (2, 789, 2, 'B', 'doc', False),
+ (3, 678, 1, 'C', 'doc', True),
+ (4, 678, None, 'D', 'doc', False)]
+
+ def tearDown(self):
+ self.mox.UnsetStubs()
+ self.mox.ResetAll()
+
+ def testDeserializeLabelRows_Empty(self):
+ label_row_dict = self.label_row_2lc._DeserializeLabelRows([])
+ self.assertEqual({}, label_row_dict)
+
+ def testDeserializeLabelRows_Normal(self):
+ label_rows_dict = self.label_row_2lc._DeserializeLabelRows(self.rows)
+ expected = {
+ (789, 1): [(1, 789, 1, 'A', 'doc', False)],
+ (789, 2): [(2, 789, 2, 'B', 'doc', False)],
+ (678, 3): [(3, 678, 1, 'C', 'doc', True)],
+ (678, 4): [(4, 678, None, 'D', 'doc', False)],
+ }
+ self.assertEqual(expected, label_rows_dict)
+
+ def SetUpFetchItems(self, keys, rows):
+ for (project_id, shard_id) in keys:
+ sharded_rows = [row for row in rows
+ if row[0] % LABEL_ROW_SHARDS == shard_id]
+ self.config_service.labeldef_tbl.Select(
+ self.cnxn, cols=config_svc.LABELDEF_COLS, project_id=project_id,
+ where=[('id %% %s = %s', [LABEL_ROW_SHARDS, shard_id])]).AndReturn(
+ sharded_rows)
+
+ def testFetchItems(self):
+ keys = [(567, 0), (678, 0), (789, 0),
+ (567, 1), (678, 1), (789, 1),
+ (567, 2), (678, 2), (789, 2),
+ (567, 3), (678, 3), (789, 3),
+ (567, 4), (678, 4), (789, 4),
+ ]
+ self.SetUpFetchItems(keys, self.rows)
+ self.mox.ReplayAll()
+ label_rows_dict = self.label_row_2lc.FetchItems(self.cnxn, keys)
+ self.mox.VerifyAll()
+ expected = {
+ (567, 0): [],
+ (678, 0): [],
+ (789, 0): [],
+ (567, 1): [],
+ (678, 1): [],
+ (789, 1): [(1, 789, 1, 'A', 'doc', False)],
+ (567, 2): [],
+ (678, 2): [],
+ (789, 2): [(2, 789, 2, 'B', 'doc', False)],
+ (567, 3): [],
+ (678, 3): [(3, 678, 1, 'C', 'doc', True)],
+ (789, 3): [],
+ (567, 4): [],
+ (678, 4): [(4, 678, None, 'D', 'doc', False)],
+ (789, 4): [],
+ }
+ self.assertEqual(expected, label_rows_dict)
+
+
+class StatusRowTwoLevelCacheTest(unittest.TestCase):
+
+ def setUp(self):
+ self.mox = mox.Mox()
+ self.cnxn = 'fake connection'
+ self.cache_manager = fake.CacheManager()
+ self.config_service = MakeConfigService(self.cache_manager, self.mox)
+ self.status_row_2lc = self.config_service.status_row_2lc
+
+ self.rows = [(1, 789, 1, 'A', True, 'doc', False),
+ (2, 789, 2, 'B', False, 'doc', False),
+ (3, 678, 1, 'C', True, 'doc', True),
+ (4, 678, None, 'D', True, 'doc', False)]
+
+ def tearDown(self):
+ self.mox.UnsetStubs()
+ self.mox.ResetAll()
+
+ def testDeserializeStatusRows_Empty(self):
+ status_row_dict = self.status_row_2lc._DeserializeStatusRows([])
+ self.assertEqual({}, status_row_dict)
+
+ def testDeserializeStatusRows_Normal(self):
+ status_rows_dict = self.status_row_2lc._DeserializeStatusRows(self.rows)
+ expected = {
+ 678: [(3, 678, 1, 'C', True, 'doc', True),
+ (4, 678, None, 'D', True, 'doc', False)],
+ 789: [(1, 789, 1, 'A', True, 'doc', False),
+ (2, 789, 2, 'B', False, 'doc', False)],
+ }
+ self.assertEqual(expected, status_rows_dict)
+
+ def SetUpFetchItems(self, keys, rows):
+ self.config_service.statusdef_tbl.Select(
+ self.cnxn, cols=config_svc.STATUSDEF_COLS, project_id=keys,
+ order_by=[('rank DESC', []), ('status DESC', [])]).AndReturn(
+ rows)
+
+ def testFetchItems(self):
+ keys = [567, 678, 789]
+ self.SetUpFetchItems(keys, self.rows)
+ self.mox.ReplayAll()
+ status_rows_dict = self.status_row_2lc.FetchItems(self.cnxn, keys)
+ self.mox.VerifyAll()
+ expected = {
+ 567: [],
+ 678: [(3, 678, 1, 'C', True, 'doc', True),
+ (4, 678, None, 'D', True, 'doc', False)],
+ 789: [(1, 789, 1, 'A', True, 'doc', False),
+ (2, 789, 2, 'B', False, 'doc', False)],
+ }
+ self.assertEqual(expected, status_rows_dict)
+
+
+class ConfigRowTwoLevelCacheTest(unittest.TestCase):
+
+ def setUp(self):
+ self.mox = mox.Mox()
+ self.cnxn = 'fake connection'
+ self.cache_manager = fake.CacheManager()
+ self.config_service = MakeConfigService(self.cache_manager, self.mox)
+ self.config_2lc = self.config_service.config_2lc
+
+ self.config_rows = [
+ (789, 'Duplicate', 'Pri Type', 1, 2,
+ 'Type Pri Summary', '-Pri', 'Mstone', 'Owner',
+ '', None)]
+ self.statusdef_rows = [(1, 789, 1, 'New', True, 'doc', False),
+ (2, 789, 2, 'Fixed', False, 'doc', False)]
+ self.labeldef_rows = [(1, 789, 1, 'Security', 'doc', False),
+ (2, 789, 2, 'UX', 'doc', False)]
+ self.fielddef_rows = [
+ (
+ 1, 789, None, 'Field', 'INT_TYPE', 'Defect', '', False, False,
+ False, 1, 99, None, '', '', None, 'NEVER', 'no_action', 'doc',
+ False, None, False, False)
+ ]
+ self.approvaldef2approver_rows = [(2, 101, 789), (2, 102, 789)]
+ self.approvaldef2survey_rows = [(2, 'Q1\nQ2\nQ3', 789)]
+ self.fielddef2admin_rows = [(1, 111), (1, 222)]
+ self.fielddef2editor_rows = [(1, 111), (1, 222), (1, 333)]
+ self.componentdef_rows = []
+ self.component2admin_rows = []
+ self.component2cc_rows = []
+ self.component2label_rows = []
+
+ def tearDown(self):
+ self.mox.UnsetStubs()
+ self.mox.ResetAll()
+
+ def testDeserializeIssueConfigs_Empty(self):
+ config_dict = self.config_2lc._DeserializeIssueConfigs(
+ [], [], [], [], [], [], [], [], [], [], [], [])
+ self.assertEqual({}, config_dict)
+
+ def testDeserializeIssueConfigs_Normal(self):
+ config_dict = self.config_2lc._DeserializeIssueConfigs(
+ self.config_rows, self.statusdef_rows, self.labeldef_rows,
+ self.fielddef_rows, self.fielddef2admin_rows, self.fielddef2editor_rows,
+ self.componentdef_rows, self.component2admin_rows,
+ self.component2cc_rows, self.component2label_rows,
+ self.approvaldef2approver_rows, self.approvaldef2survey_rows)
+ self.assertItemsEqual([789], list(config_dict.keys()))
+ config = config_dict[789]
+ self.assertEqual(789, config.project_id)
+ self.assertEqual(['Duplicate'], config.statuses_offer_merge)
+ self.assertEqual(len(self.labeldef_rows), len(config.well_known_labels))
+ self.assertEqual(len(self.statusdef_rows), len(config.well_known_statuses))
+ self.assertEqual(len(self.fielddef_rows), len(config.field_defs))
+ self.assertEqual(len(self.componentdef_rows), len(config.component_defs))
+ self.assertEqual(
+ len(self.fielddef2admin_rows), len(config.field_defs[0].admin_ids))
+ self.assertEqual(
+ len(self.fielddef2editor_rows), len(config.field_defs[0].editor_ids))
+ self.assertEqual(len(self.approvaldef2approver_rows),
+ len(config.approval_defs[0].approver_ids))
+ self.assertEqual(config.approval_defs[0].survey, 'Q1\nQ2\nQ3')
+
+ def SetUpFetchConfigs(self, project_ids):
+ self.config_service.projectissueconfig_tbl.Select(
+ self.cnxn, cols=config_svc.PROJECTISSUECONFIG_COLS,
+ project_id=project_ids).AndReturn(self.config_rows)
+
+ self.config_service.statusdef_tbl.Select(
+ self.cnxn, cols=config_svc.STATUSDEF_COLS, project_id=project_ids,
+ where=[('rank IS NOT NULL', [])], order_by=[('rank', [])]).AndReturn(
+ self.statusdef_rows)
+
+ self.config_service.labeldef_tbl.Select(
+ self.cnxn, cols=config_svc.LABELDEF_COLS, project_id=project_ids,
+ where=[('rank IS NOT NULL', [])], order_by=[('rank', [])]).AndReturn(
+ self.labeldef_rows)
+
+ self.config_service.approvaldef2approver_tbl.Select(
+ self.cnxn, cols=config_svc.APPROVALDEF2APPROVER_COLS,
+ project_id=project_ids).AndReturn(self.approvaldef2approver_rows)
+ self.config_service.approvaldef2survey_tbl.Select(
+ self.cnxn, cols=config_svc.APPROVALDEF2SURVEY_COLS,
+ project_id=project_ids).AndReturn(self.approvaldef2survey_rows)
+
+ self.config_service.fielddef_tbl.Select(
+ self.cnxn, cols=config_svc.FIELDDEF_COLS, project_id=project_ids,
+ order_by=[('field_name', [])]).AndReturn(self.fielddef_rows)
+ field_ids = [row[0] for row in self.fielddef_rows]
+ self.config_service.fielddef2admin_tbl.Select(
+ self.cnxn, cols=config_svc.FIELDDEF2ADMIN_COLS,
+ field_id=field_ids).AndReturn(self.fielddef2admin_rows)
+ self.config_service.fielddef2editor_tbl.Select(
+ self.cnxn, cols=config_svc.FIELDDEF2EDITOR_COLS,
+ field_id=field_ids).AndReturn(self.fielddef2editor_rows)
+
+ self.config_service.componentdef_tbl.Select(
+ self.cnxn, cols=config_svc.COMPONENTDEF_COLS, project_id=project_ids,
+ is_deleted=False,
+ order_by=[('path', [])]).AndReturn(self.componentdef_rows)
+
+ def testFetchConfigs(self):
+ keys = [789]
+ self.SetUpFetchConfigs(keys)
+ self.mox.ReplayAll()
+ config_dict = self.config_2lc._FetchConfigs(self.cnxn, keys)
+ self.mox.VerifyAll()
+ self.assertItemsEqual(keys, list(config_dict.keys()))
+
+ def testFetchItems(self):
+ keys = [678, 789]
+ self.SetUpFetchConfigs(keys)
+ self.mox.ReplayAll()
+ config_dict = self.config_2lc.FetchItems(self.cnxn, keys)
+ self.mox.VerifyAll()
+ self.assertItemsEqual(keys, list(config_dict.keys()))
+
+
+class ConfigServiceTest(unittest.TestCase):
+
+ def setUp(self):
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_memcache_stub()
+
+ self.mox = mox.Mox()
+ self.cnxn = self.mox.CreateMock(sql.MonorailConnection)
+ self.cache_manager = fake.CacheManager()
+ self.config_service = MakeConfigService(self.cache_manager, self.mox)
+
+ def tearDown(self):
+ self.testbed.deactivate()
+ self.mox.UnsetStubs()
+ self.mox.ResetAll()
+
+ ### Label lookups
+
+ def testGetLabelDefRows_Hit(self):
+ self.config_service.label_row_2lc.CacheItem((789, 0), [])
+ self.config_service.label_row_2lc.CacheItem((789, 1), [])
+ self.config_service.label_row_2lc.CacheItem((789, 2), [])
+ self.config_service.label_row_2lc.CacheItem(
+ (789, 3), [(3, 678, 1, 'C', 'doc', True)])
+ self.config_service.label_row_2lc.CacheItem(
+ (789, 4), [(4, 678, None, 'D', 'doc', False)])
+ self.config_service.label_row_2lc.CacheItem((789, 5), [])
+ self.config_service.label_row_2lc.CacheItem((789, 6), [])
+ self.config_service.label_row_2lc.CacheItem((789, 7), [])
+ self.config_service.label_row_2lc.CacheItem((789, 8), [])
+ self.config_service.label_row_2lc.CacheItem((789, 9), [])
+ actual = self.config_service.GetLabelDefRows(self.cnxn, 789)
+ expected = [
+ (3, 678, 1, 'C', 'doc', True),
+ (4, 678, None, 'D', 'doc', False)]
+ self.assertEqual(expected, actual)
+
+ def SetUpGetLabelDefRowsAnyProject(self, rows):
+ self.config_service.labeldef_tbl.Select(
+ self.cnxn, cols=config_svc.LABELDEF_COLS, where=None,
+ order_by=[('rank DESC', []), ('label DESC', [])]).AndReturn(
+ rows)
+
+ def testGetLabelDefRowsAnyProject(self):
+ rows = 'foo'
+ self.SetUpGetLabelDefRowsAnyProject(rows)
+ self.mox.ReplayAll()
+ actual = self.config_service.GetLabelDefRowsAnyProject(self.cnxn)
+ self.mox.VerifyAll()
+ self.assertEqual(rows, actual)
+
+ def testDeserializeLabels(self):
+ labeldef_rows = [(1, 789, 1, 'Security', 'doc', False),
+ (2, 789, 2, 'UX', 'doc', True)]
+ id_to_name, name_to_id = self.config_service._DeserializeLabels(
+ labeldef_rows)
+ self.assertEqual({1: 'Security', 2: 'UX'}, id_to_name)
+ self.assertEqual({'security': 1, 'ux': 2}, name_to_id)
+
+ def testEnsureLabelCacheEntry_Hit(self):
+ label_dicts = 'foo'
+ self.config_service.label_cache.CacheItem(789, label_dicts)
+ # No mock calls set up because none are needed.
+ self.mox.ReplayAll()
+ self.config_service._EnsureLabelCacheEntry(self.cnxn, 789)
+ self.mox.VerifyAll()
+
+ def SetUpEnsureLabelCacheEntry_Miss(self, project_id, rows):
+ for shard_id in range(0, LABEL_ROW_SHARDS):
+ shard_rows = [row for row in rows
+ if row[0] % LABEL_ROW_SHARDS == shard_id]
+ self.config_service.labeldef_tbl.Select(
+ self.cnxn, cols=config_svc.LABELDEF_COLS, project_id=project_id,
+ where=[('id %% %s = %s', [LABEL_ROW_SHARDS, shard_id])]).AndReturn(
+ shard_rows)
+
+ def testEnsureLabelCacheEntry_Miss(self):
+ labeldef_rows = [(1, 789, 1, 'Security', 'doc', False),
+ (2, 789, 2, 'UX', 'doc', True)]
+ self.SetUpEnsureLabelCacheEntry_Miss(789, labeldef_rows)
+ self.mox.ReplayAll()
+ self.config_service._EnsureLabelCacheEntry(self.cnxn, 789)
+ self.mox.VerifyAll()
+ label_dicts = {1: 'Security', 2: 'UX'}, {'security': 1, 'ux': 2}
+ self.assertEqual(label_dicts, self.config_service.label_cache.GetItem(789))
+
+ def testLookupLabel_Hit(self):
+ label_dicts = {1: 'Security', 2: 'UX'}, {'security': 1, 'ux': 2}
+ self.config_service.label_cache.CacheItem(789, label_dicts)
+ # No mock calls set up because none are needed.
+ self.mox.ReplayAll()
+ self.assertEqual(
+ 'Security', self.config_service.LookupLabel(self.cnxn, 789, 1))
+ self.assertEqual(
+ 'UX', self.config_service.LookupLabel(self.cnxn, 789, 2))
+ self.mox.VerifyAll()
+
+ def testLookupLabelID_Hit(self):
+ label_dicts = {1: 'Security', 2: 'UX'}, {'security': 1, 'ux': 2}
+ self.config_service.label_cache.CacheItem(789, label_dicts)
+ # No mock calls set up because none are needed.
+ self.mox.ReplayAll()
+ self.assertEqual(
+ 1, self.config_service.LookupLabelID(self.cnxn, 789, 'Security'))
+ self.assertEqual(
+ 2, self.config_service.LookupLabelID(self.cnxn, 789, 'UX'))
+ self.mox.VerifyAll()
+
+ def testLookupLabelID_MissAndDoubleCheck(self):
+ label_dicts = {1: 'Security', 2: 'UX'}, {'security': 1, 'ux': 2}
+ self.config_service.label_cache.CacheItem(789, label_dicts)
+
+ self.config_service.labeldef_tbl.Select(
+ self.cnxn, cols=['id'], project_id=789,
+ where=[('LOWER(label) = %s', ['newlabel'])],
+ limit=1).AndReturn([(3,)])
+ self.mox.ReplayAll()
+ self.assertEqual(
+ 3, self.config_service.LookupLabelID(self.cnxn, 789, 'NewLabel'))
+ self.mox.VerifyAll()
+
+ def testLookupLabelID_MissAutocreate(self):
+ label_dicts = {1: 'Security', 2: 'UX'}, {'security': 1, 'ux': 2}
+ self.config_service.label_cache.CacheItem(789, label_dicts)
+
+ self.config_service.labeldef_tbl.Select(
+ self.cnxn, cols=['id'], project_id=789,
+ where=[('LOWER(label) = %s', ['newlabel'])],
+ limit=1).AndReturn([])
+ self.config_service.labeldef_tbl.InsertRow(
+ self.cnxn, project_id=789, label='NewLabel').AndReturn(3)
+ self.mox.ReplayAll()
+ self.assertEqual(
+ 3, self.config_service.LookupLabelID(self.cnxn, 789, 'NewLabel'))
+ self.mox.VerifyAll()
+
+ def testLookupLabelID_MissDontAutocreate(self):
+ label_dicts = {1: 'Security', 2: 'UX'}, {'security': 1, 'ux': 2}
+ self.config_service.label_cache.CacheItem(789, label_dicts)
+
+ self.config_service.labeldef_tbl.Select(
+ self.cnxn, cols=['id'], project_id=789,
+ where=[('LOWER(label) = %s', ['newlabel'])],
+ limit=1).AndReturn([])
+ self.mox.ReplayAll()
+ self.assertIsNone(self.config_service.LookupLabelID(
+ self.cnxn, 789, 'NewLabel', autocreate=False))
+ self.mox.VerifyAll()
+
+ def testLookupLabelIDs_Hit(self):
+ label_dicts = {1: 'Security', 2: 'UX'}, {'security': 1, 'ux': 2}
+ self.config_service.label_cache.CacheItem(789, label_dicts)
+ # No mock calls set up because none are needed.
+ self.mox.ReplayAll()
+ self.assertEqual(
+ [1, 2],
+ self.config_service.LookupLabelIDs(self.cnxn, 789, ['Security', 'UX']))
+ self.mox.VerifyAll()
+
+ def testLookupIDsOfLabelsMatching_Hit(self):
+ label_dicts = {1: 'Security', 2: 'UX'}, {'security': 1, 'ux': 2}
+ self.config_service.label_cache.CacheItem(789, label_dicts)
+ # No mock calls set up because none are needed.
+ self.mox.ReplayAll()
+ self.assertItemsEqual(
+ [1],
+ self.config_service.LookupIDsOfLabelsMatching(
+ self.cnxn, 789, re.compile('Sec.*')))
+ self.assertItemsEqual(
+ [1, 2],
+ self.config_service.LookupIDsOfLabelsMatching(
+ self.cnxn, 789, re.compile('.*')))
+ self.assertItemsEqual(
+ [],
+ self.config_service.LookupIDsOfLabelsMatching(
+ self.cnxn, 789, re.compile('Zzzzz.*')))
+ self.mox.VerifyAll()
+
+ def SetUpLookupLabelIDsAnyProject(self, label, id_rows):
+ self.config_service.labeldef_tbl.Select(
+ self.cnxn, cols=['id'], label=label).AndReturn(id_rows)
+
+ def testLookupLabelIDsAnyProject(self):
+ self.SetUpLookupLabelIDsAnyProject('Security', [(1,)])
+ self.mox.ReplayAll()
+ actual = self.config_service.LookupLabelIDsAnyProject(
+ self.cnxn, 'Security')
+ self.mox.VerifyAll()
+ self.assertEqual([1], actual)
+
+ def SetUpLookupIDsOfLabelsMatchingAnyProject(self, id_label_rows):
+ self.config_service.labeldef_tbl.Select(
+ self.cnxn, cols=['id', 'label']).AndReturn(id_label_rows)
+
+ def testLookupIDsOfLabelsMatchingAnyProject(self):
+ id_label_rows = [(1, 'Security'), (2, 'UX')]
+ self.SetUpLookupIDsOfLabelsMatchingAnyProject(id_label_rows)
+ self.mox.ReplayAll()
+ actual = self.config_service.LookupIDsOfLabelsMatchingAnyProject(
+ self.cnxn, re.compile('(Sec|Zzz).*'))
+ self.mox.VerifyAll()
+ self.assertEqual([1], actual)
+
+ ### Status lookups
+
+ def testGetStatusDefRows(self):
+ rows = 'foo'
+ self.config_service.status_row_2lc.CacheItem(789, rows)
+ actual = self.config_service.GetStatusDefRows(self.cnxn, 789)
+ self.assertEqual(rows, actual)
+
+ def SetUpGetStatusDefRowsAnyProject(self, rows):
+ self.config_service.statusdef_tbl.Select(
+ self.cnxn, cols=config_svc.STATUSDEF_COLS,
+ order_by=[('rank DESC', []), ('status DESC', [])]).AndReturn(
+ rows)
+
+ def testGetStatusDefRowsAnyProject(self):
+ rows = 'foo'
+ self.SetUpGetStatusDefRowsAnyProject(rows)
+ self.mox.ReplayAll()
+ actual = self.config_service.GetStatusDefRowsAnyProject(self.cnxn)
+ self.mox.VerifyAll()
+ self.assertEqual(rows, actual)
+
+ def testDeserializeStatuses(self):
+ statusdef_rows = [(1, 789, 1, 'New', True, 'doc', False),
+ (2, 789, 2, 'Fixed', False, 'doc', True)]
+ actual = self.config_service._DeserializeStatuses(statusdef_rows)
+ id_to_name, name_to_id, closed_ids = actual
+ self.assertEqual({1: 'New', 2: 'Fixed'}, id_to_name)
+ self.assertEqual({'new': 1, 'fixed': 2}, name_to_id)
+ self.assertEqual([2], closed_ids)
+
+ def testEnsureStatusCacheEntry_Hit(self):
+ status_dicts = 'foo'
+ self.config_service.status_cache.CacheItem(789, status_dicts)
+ # No mock calls set up because none are needed.
+ self.mox.ReplayAll()
+ self.config_service._EnsureStatusCacheEntry(self.cnxn, 789)
+ self.mox.VerifyAll()
+
+ def SetUpEnsureStatusCacheEntry_Miss(self, keys, rows):
+ self.config_service.statusdef_tbl.Select(
+ self.cnxn, cols=config_svc.STATUSDEF_COLS, project_id=keys,
+ order_by=[('rank DESC', []), ('status DESC', [])]).AndReturn(
+ rows)
+
+ def testEnsureStatusCacheEntry_Miss(self):
+ statusdef_rows = [(1, 789, 1, 'New', True, 'doc', False),
+ (2, 789, 2, 'Fixed', False, 'doc', True)]
+ self.SetUpEnsureStatusCacheEntry_Miss([789], statusdef_rows)
+ self.mox.ReplayAll()
+ self.config_service._EnsureStatusCacheEntry(self.cnxn, 789)
+ self.mox.VerifyAll()
+ status_dicts = {1: 'New', 2: 'Fixed'}, {'new': 1, 'fixed': 2}, [2]
+ self.assertEqual(
+ status_dicts, self.config_service.status_cache.GetItem(789))
+
+ def testLookupStatus_Hit(self):
+ status_dicts = {1: 'New', 2: 'Fixed'}, {'new': 1, 'fixed': 2}, [2]
+ self.config_service.status_cache.CacheItem(789, status_dicts)
+ # No mock calls set up because none are needed.
+ self.mox.ReplayAll()
+ self.assertEqual(
+ 'New', self.config_service.LookupStatus(self.cnxn, 789, 1))
+ self.assertEqual(
+ 'Fixed', self.config_service.LookupStatus(self.cnxn, 789, 2))
+ self.mox.VerifyAll()
+
+ def testLookupStatusID_Hit(self):
+ status_dicts = {1: 'New', 2: 'Fixed'}, {'new': 1, 'fixed': 2}, [2]
+ self.config_service.status_cache.CacheItem(789, status_dicts)
+ # No mock calls set up because none are needed.
+ self.mox.ReplayAll()
+ self.assertEqual(
+ 1, self.config_service.LookupStatusID(self.cnxn, 789, 'New'))
+ self.assertEqual(
+ 2, self.config_service.LookupStatusID(self.cnxn, 789, 'Fixed'))
+ self.mox.VerifyAll()
+
+ def testLookupStatusIDs_Hit(self):
+ status_dicts = {1: 'New', 2: 'Fixed'}, {'new': 1, 'fixed': 2}, [2]
+ self.config_service.status_cache.CacheItem(789, status_dicts)
+ # No mock calls set up because none are needed.
+ self.mox.ReplayAll()
+ self.assertEqual(
+ [1, 2],
+ self.config_service.LookupStatusIDs(self.cnxn, 789, ['New', 'Fixed']))
+ self.mox.VerifyAll()
+
+ def testLookupClosedStatusIDs_Hit(self):
+ status_dicts = {1: 'New', 2: 'Fixed'}, {'new': 1, 'fixed': 2}, [2]
+ self.config_service.status_cache.CacheItem(789, status_dicts)
+ # No mock calls set up because none are needed.
+ self.mox.ReplayAll()
+ self.assertEqual(
+ [2],
+ self.config_service.LookupClosedStatusIDs(self.cnxn, 789))
+ self.mox.VerifyAll()
+
+ def SetUpLookupClosedStatusIDsAnyProject(self, id_rows):
+ self.config_service.statusdef_tbl.Select(
+ self.cnxn, cols=['id'], means_open=False).AndReturn(
+ id_rows)
+
+ def testLookupClosedStatusIDsAnyProject(self):
+ self.SetUpLookupClosedStatusIDsAnyProject([(2,)])
+ self.mox.ReplayAll()
+ actual = self.config_service.LookupClosedStatusIDsAnyProject(self.cnxn)
+ self.mox.VerifyAll()
+ self.assertEqual([2], actual)
+
+ def SetUpLookupStatusIDsAnyProject(self, status, id_rows):
+ self.config_service.statusdef_tbl.Select(
+ self.cnxn, cols=['id'], status=status).AndReturn(id_rows)
+
+ def testLookupStatusIDsAnyProject(self):
+ self.SetUpLookupStatusIDsAnyProject('New', [(1,)])
+ self.mox.ReplayAll()
+ actual = self.config_service.LookupStatusIDsAnyProject(self.cnxn, 'New')
+ self.mox.VerifyAll()
+ self.assertEqual([1], actual)
+
+ ### Issue tracker configuration objects
+
+ def SetUpGetProjectConfigs(self, project_ids):
+ self.config_service.projectissueconfig_tbl.Select(
+ self.cnxn, cols=config_svc.PROJECTISSUECONFIG_COLS,
+ project_id=project_ids).AndReturn([])
+ self.config_service.statusdef_tbl.Select(
+ self.cnxn, cols=config_svc.STATUSDEF_COLS,
+ project_id=project_ids, where=[('rank IS NOT NULL', [])],
+ order_by=[('rank', [])]).AndReturn([])
+ self.config_service.labeldef_tbl.Select(
+ self.cnxn, cols=config_svc.LABELDEF_COLS,
+ project_id=project_ids, where=[('rank IS NOT NULL', [])],
+ order_by=[('rank', [])]).AndReturn([])
+ self.config_service.approvaldef2approver_tbl.Select(
+ self.cnxn, cols=config_svc.APPROVALDEF2APPROVER_COLS,
+ project_id=project_ids).AndReturn([])
+ self.config_service.approvaldef2survey_tbl.Select(
+ self.cnxn, cols=config_svc.APPROVALDEF2SURVEY_COLS,
+ project_id=project_ids).AndReturn([])
+ self.config_service.fielddef_tbl.Select(
+ self.cnxn, cols=config_svc.FIELDDEF_COLS,
+ project_id=project_ids, order_by=[('field_name', [])]).AndReturn([])
+ self.config_service.componentdef_tbl.Select(
+ self.cnxn, cols=config_svc.COMPONENTDEF_COLS,
+ is_deleted=False,
+ project_id=project_ids, order_by=[('path', [])]).AndReturn([])
+
+ def testGetProjectConfigs(self):
+ project_ids = [789, 679]
+ self.SetUpGetProjectConfigs(project_ids)
+
+ self.mox.ReplayAll()
+ config_dict = self.config_service.GetProjectConfigs(
+ self.cnxn, [789, 679], use_cache=False)
+ self.assertEqual(2, len(config_dict))
+ for pid in project_ids:
+ self.assertEqual(pid, config_dict[pid].project_id)
+ self.mox.VerifyAll()
+
+ def testGetProjectConfig_Hit(self):
+ project_id = 789
+ config = tracker_bizobj.MakeDefaultProjectIssueConfig(project_id)
+ self.config_service.config_2lc.CacheItem(project_id, config)
+
+ self.mox.ReplayAll()
+ actual = self.config_service.GetProjectConfig(self.cnxn, project_id)
+ self.assertEqual(config, actual)
+ self.mox.VerifyAll()
+
+ def testGetProjectConfig_Miss(self):
+ project_id = 789
+ self.SetUpGetProjectConfigs([project_id])
+
+ self.mox.ReplayAll()
+ config = self.config_service.GetProjectConfig(self.cnxn, project_id)
+ self.assertEqual(project_id, config.project_id)
+ self.mox.VerifyAll()
+
+ def SetUpStoreConfig_Default(self, project_id):
+ self.config_service.projectissueconfig_tbl.InsertRow(
+ self.cnxn, replace=True,
+ project_id=project_id,
+ statuses_offer_merge='Duplicate',
+ exclusive_label_prefixes='Type Priority Milestone',
+ default_template_for_developers=0,
+ default_template_for_users=0,
+ default_col_spec=tracker_constants.DEFAULT_COL_SPEC,
+ default_sort_spec='',
+ default_x_attr='',
+ default_y_attr='',
+ member_default_query='',
+ custom_issue_entry_url=None,
+ commit=False)
+
+ self.SetUpUpdateWellKnownLabels_Default(project_id)
+ self.SetUpUpdateWellKnownStatuses_Default(project_id)
+ self.cnxn.Commit()
+
+ def SetUpUpdateWellKnownLabels_JustCache(self, project_id):
+ by_id = {
+ idx + 1: label for idx, (label, _, _) in enumerate(
+ tracker_constants.DEFAULT_WELL_KNOWN_LABELS)}
+ by_name = {name.lower(): label_id
+ for label_id, name in by_id.items()}
+ label_dicts = by_id, by_name
+ self.config_service.label_cache.CacheAll({project_id: label_dicts})
+
+ def SetUpUpdateWellKnownLabels_Default(self, project_id):
+ self.SetUpUpdateWellKnownLabels_JustCache(project_id)
+ update_labeldef_rows = [
+ (idx + 1, project_id, idx, label, doc, deprecated)
+ for idx, (label, doc, deprecated) in enumerate(
+ tracker_constants.DEFAULT_WELL_KNOWN_LABELS)]
+ self.config_service.labeldef_tbl.Update(
+ self.cnxn, {'rank': None}, project_id=project_id, commit=False)
+ self.config_service.labeldef_tbl.InsertRows(
+ self.cnxn, config_svc.LABELDEF_COLS, update_labeldef_rows,
+ replace=True, commit=False)
+ self.config_service.labeldef_tbl.InsertRows(
+ self.cnxn, config_svc.LABELDEF_COLS[1:], [], commit=False)
+
+ def SetUpUpdateWellKnownStatuses_Default(self, project_id):
+ by_id = {
+ idx + 1: status for idx, (status, _, _, _) in enumerate(
+ tracker_constants.DEFAULT_WELL_KNOWN_STATUSES)}
+ by_name = {name.lower(): label_id
+ for label_id, name in by_id.items()}
+ closed_ids = [
+ idx + 1 for idx, (_, _, means_open, _) in enumerate(
+ tracker_constants.DEFAULT_WELL_KNOWN_STATUSES)
+ if not means_open]
+ status_dicts = by_id, by_name, closed_ids
+ self.config_service.status_cache.CacheAll({789: status_dicts})
+
+ update_statusdef_rows = [
+ (idx + 1, project_id, idx, status, means_open, doc, deprecated)
+ for idx, (status, doc, means_open, deprecated) in enumerate(
+ tracker_constants.DEFAULT_WELL_KNOWN_STATUSES)]
+ self.config_service.statusdef_tbl.Update(
+ self.cnxn, {'rank': None}, project_id=project_id, commit=False)
+ self.config_service.statusdef_tbl.InsertRows(
+ self.cnxn, config_svc.STATUSDEF_COLS, update_statusdef_rows,
+ replace=True, commit=False)
+ self.config_service.statusdef_tbl.InsertRows(
+ self.cnxn, config_svc.STATUSDEF_COLS[1:], [], commit=False)
+
+ def SetUpUpdateApprovals_Default(
+ self, approval_id, approver_rows, survey_row):
+ self.config_service.approvaldef2approver_tbl.Delete(
+ self.cnxn, approval_id=approval_id, commit=False)
+
+ self.config_service.approvaldef2approver_tbl.InsertRows(
+ self.cnxn,
+ config_svc.APPROVALDEF2APPROVER_COLS,
+ approver_rows,
+ commit=False)
+
+ approval_id, survey, project_id = survey_row
+ self.config_service.approvaldef2survey_tbl.Delete(
+ self.cnxn, approval_id=approval_id, commit=False)
+ self.config_service.approvaldef2survey_tbl.InsertRow(
+ self.cnxn,
+ approval_id=approval_id,
+ survey=survey,
+ project_id=project_id,
+ commit=False)
+
+ def testStoreConfig(self):
+ config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+ self.SetUpStoreConfig_Default(789)
+
+ self.mox.ReplayAll()
+ self.config_service.StoreConfig(self.cnxn, config)
+ self.mox.VerifyAll()
+
+ def testUpdateWellKnownLabels(self):
+ config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+ self.SetUpUpdateWellKnownLabels_Default(789)
+
+ self.mox.ReplayAll()
+ self.config_service._UpdateWellKnownLabels(self.cnxn, config)
+ self.mox.VerifyAll()
+
+ def testUpdateWellKnownLabels_Duplicate(self):
+ config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+ config.well_known_labels.append(config.well_known_labels[0])
+ self.SetUpUpdateWellKnownLabels_JustCache(789)
+
+ self.mox.ReplayAll()
+ with self.assertRaises(exceptions.InputException) as cm:
+ self.config_service._UpdateWellKnownLabels(self.cnxn, config)
+ self.mox.VerifyAll()
+ self.assertEqual(
+ 'Defined label "Type-Defect" twice',
+ cm.exception.message)
+
+ def testUpdateWellKnownStatuses(self):
+ config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+ self.SetUpUpdateWellKnownStatuses_Default(789)
+
+ self.mox.ReplayAll()
+ self.config_service._UpdateWellKnownStatuses(self.cnxn, config)
+ self.mox.VerifyAll()
+
+ def testUpdateApprovals(self):
+ config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+ approver_rows = [(123, 111, 789), (123, 222, 789)]
+ survey_row = (123, 'Q1\nQ2', 789)
+ first_approval = tracker_bizobj.MakeFieldDef(
+ 123, 789, 'FirstApproval', tracker_pb2.FieldTypes.APPROVAL_TYPE,
+ None, '', False, False, False, None, None, '', False, '', '',
+ tracker_pb2.NotifyTriggers.NEVER, 'no_action', 'the first one', False)
+ config.field_defs = [first_approval]
+ config.approval_defs = [tracker_pb2.ApprovalDef(
+ approval_id=123, approver_ids=[111, 222], survey='Q1\nQ2')]
+ self.SetUpUpdateApprovals_Default(123, approver_rows, survey_row)
+
+ self.mox.ReplayAll()
+ self.config_service._UpdateApprovals(self.cnxn, config)
+ self.mox.VerifyAll()
+
+  def testUpdateConfig(self):
+    # Placeholder: this test currently asserts nothing.
+    pass  # TODO(jrobbins): add a test for this
+
+ def SetUpExpungeConfig(self, project_id):
+ self.config_service.statusdef_tbl.Delete(self.cnxn, project_id=project_id)
+ self.config_service.labeldef_tbl.Delete(self.cnxn, project_id=project_id)
+ self.config_service.projectissueconfig_tbl.Delete(
+ self.cnxn, project_id=project_id)
+
+ self.config_service.config_2lc.InvalidateKeys(self.cnxn, [project_id])
+
+ def testExpungeConfig(self):
+ self.SetUpExpungeConfig(789)
+
+ self.mox.ReplayAll()
+ self.config_service.ExpungeConfig(self.cnxn, 789)
+ self.mox.VerifyAll()
+
+ def testExpungeUsersInConfigs(self):
+
+ self.config_service.component2admin_tbl.Delete = mock.Mock()
+ self.config_service.component2cc_tbl.Delete = mock.Mock()
+ self.config_service.componentdef_tbl.Update = mock.Mock()
+
+ self.config_service.fielddef2admin_tbl.Delete = mock.Mock()
+ self.config_service.fielddef2editor_tbl.Delete = mock.Mock()
+ self.config_service.approvaldef2approver_tbl.Delete = mock.Mock()
+
+ user_ids = [111, 222, 333]
+ self.config_service.ExpungeUsersInConfigs(self.cnxn, user_ids, limit=50)
+
+ self.config_service.component2admin_tbl.Delete.assert_called_once_with(
+ self.cnxn, admin_id=user_ids, commit=False, limit=50)
+ self.config_service.component2cc_tbl.Delete.assert_called_once_with(
+ self.cnxn, cc_id=user_ids, commit=False, limit=50)
+ cdef_calls = [
+ mock.call(
+ self.cnxn, {'creator_id': framework_constants.DELETED_USER_ID},
+ creator_id=user_ids, commit=False, limit=50),
+ mock.call(
+ self.cnxn, {'modifier_id': framework_constants.DELETED_USER_ID},
+ modifier_id=user_ids, commit=False, limit=50)]
+ self.config_service.componentdef_tbl.Update.assert_has_calls(cdef_calls)
+
+ self.config_service.fielddef2admin_tbl.Delete.assert_called_once_with(
+ self.cnxn, admin_id=user_ids, commit=False, limit=50)
+ self.config_service.fielddef2editor_tbl.Delete.assert_called_once_with(
+ self.cnxn, editor_id=user_ids, commit=False, limit=50)
+ self.config_service.approvaldef2approver_tbl.Delete.assert_called_once_with(
+ self.cnxn, approver_id=user_ids, commit=False, limit=50)
+
+ ### Custom field definitions
+
+ def SetUpCreateFieldDef(self, project_id):
+ self.config_service.fielddef_tbl.InsertRow(
+ self.cnxn,
+ project_id=project_id,
+ field_name='PercentDone',
+ field_type='int_type',
+ applicable_type='Defect',
+ applicable_predicate='',
+ is_required=False,
+ is_multivalued=False,
+ is_niche=False,
+ min_value=1,
+ max_value=100,
+ regex=None,
+ needs_member=None,
+ needs_perm=None,
+ grants_perm=None,
+ notify_on='never',
+ date_action='no_action',
+ docstring='doc',
+ approval_id=None,
+ is_phase_field=False,
+ is_restricted_field=True,
+ commit=False).AndReturn(1)
+ self.config_service.fielddef2admin_tbl.InsertRows(
+ self.cnxn, config_svc.FIELDDEF2ADMIN_COLS, [(1, 111)], commit=False)
+ self.config_service.fielddef2editor_tbl.InsertRows(
+ self.cnxn, config_svc.FIELDDEF2EDITOR_COLS, [(1, 222)], commit=False)
+ self.cnxn.Commit()
+
+ def testCreateFieldDef(self):
+ self.SetUpCreateFieldDef(789)
+
+ self.mox.ReplayAll()
+ field_id = self.config_service.CreateFieldDef(
+ self.cnxn,
+ 789,
+ 'PercentDone',
+ 'int_type',
+ 'Defect',
+ '',
+ False,
+ False,
+ False,
+ 1,
+ 100,
+ None,
+ None,
+ None,
+ None,
+ 0,
+ 'no_action',
+ 'doc', [111], [222],
+ is_restricted_field=True)
+ self.mox.VerifyAll()
+ self.assertEqual(1, field_id)
+
+ def SetUpSoftDeleteFieldDefs(self, field_ids):
+ self.config_service.fielddef_tbl.Update(
+ self.cnxn, {'is_deleted': True}, id=field_ids)
+
+ def testSoftDeleteFieldDefs(self):
+ self.SetUpSoftDeleteFieldDefs([1])
+
+ self.mox.ReplayAll()
+ self.config_service.SoftDeleteFieldDefs(self.cnxn, 789, [1])
+ self.mox.VerifyAll()
+
+ def SetUpUpdateFieldDef(self, field_id, new_values, admin_rows, editor_rows):
+ self.config_service.fielddef_tbl.Update(
+ self.cnxn, new_values, id=field_id, commit=False)
+ self.config_service.fielddef2admin_tbl.Delete(
+ self.cnxn, field_id=field_id, commit=False)
+ self.config_service.fielddef2admin_tbl.InsertRows(
+ self.cnxn, config_svc.FIELDDEF2ADMIN_COLS, admin_rows, commit=False)
+ self.config_service.fielddef2editor_tbl.Delete(
+ self.cnxn, field_id=field_id, commit=False)
+ self.config_service.fielddef2editor_tbl.InsertRows(
+ self.cnxn, config_svc.FIELDDEF2EDITOR_COLS, editor_rows, commit=False)
+ self.cnxn.Commit()
+
+ def testUpdateFieldDef_NoOp(self):
+ new_values = {}
+ self.SetUpUpdateFieldDef(1, new_values, [], [])
+
+ self.mox.ReplayAll()
+ self.config_service.UpdateFieldDef(
+ self.cnxn, 789, 1, admin_ids=[], editor_ids=[])
+ self.mox.VerifyAll()
+
+ def testUpdateFieldDef_Normal(self):
+ new_values = dict(
+ field_name='newname',
+ applicable_type='defect',
+ applicable_predicate='pri:1',
+ is_required=True,
+ is_niche=True,
+ is_multivalued=True,
+ min_value=32,
+ max_value=212,
+ regex='a.*b',
+ needs_member=True,
+ needs_perm='EditIssue',
+ grants_perm='DeleteIssue',
+ notify_on='any_comment',
+ docstring='new doc',
+ is_restricted_field=True)
+ self.SetUpUpdateFieldDef(1, new_values, [(1, 111)], [(1, 222)])
+
+ self.mox.ReplayAll()
+ new_values = new_values.copy()
+ new_values['notify_on'] = 1
+ self.config_service.UpdateFieldDef(
+ self.cnxn, 789, 1, admin_ids=[111], editor_ids=[222], **new_values)
+ self.mox.VerifyAll()
+
+ ### Component definitions
+
+ def SetUpFindMatchingComponentIDsAnyProject(self, _exact, rows):
+ # TODO(jrobbins): more details here.
+ self.config_service.componentdef_tbl.Select(
+ self.cnxn, cols=['id'], where=mox.IsA(list)).AndReturn(rows)
+
+ def testFindMatchingComponentIDsAnyProject_Rooted(self):
+ self.SetUpFindMatchingComponentIDsAnyProject(True, [(1,), (2,), (3,)])
+
+ self.mox.ReplayAll()
+ comp_ids = self.config_service.FindMatchingComponentIDsAnyProject(
+ self.cnxn, ['WindowManager', 'NetworkLayer'])
+ self.mox.VerifyAll()
+ self.assertItemsEqual([1, 2, 3], comp_ids)
+
+ def testFindMatchingComponentIDsAnyProject_NonRooted(self):
+ self.SetUpFindMatchingComponentIDsAnyProject(False, [(1,), (2,), (3,)])
+
+ self.mox.ReplayAll()
+ comp_ids = self.config_service.FindMatchingComponentIDsAnyProject(
+ self.cnxn, ['WindowManager', 'NetworkLayer'], exact=False)
+ self.mox.VerifyAll()
+ self.assertItemsEqual([1, 2, 3], comp_ids)
+
+ def SetUpCreateComponentDef(self, comp_id):
+ self.config_service.componentdef_tbl.InsertRow(
+ self.cnxn, project_id=789, path='WindowManager',
+ docstring='doc', deprecated=False, commit=False,
+ created=0, creator_id=0).AndReturn(comp_id)
+ self.config_service.component2admin_tbl.InsertRows(
+ self.cnxn, config_svc.COMPONENT2ADMIN_COLS, [], commit=False)
+ self.config_service.component2cc_tbl.InsertRows(
+ self.cnxn, config_svc.COMPONENT2CC_COLS, [], commit=False)
+ self.config_service.component2label_tbl.InsertRows(
+ self.cnxn, config_svc.COMPONENT2LABEL_COLS, [], commit=False)
+ self.cnxn.Commit()
+
+ def testCreateComponentDef(self):
+ self.SetUpCreateComponentDef(1)
+
+ self.mox.ReplayAll()
+ comp_id = self.config_service.CreateComponentDef(
+ self.cnxn, 789, 'WindowManager', 'doc', False, [], [], 0, 0, [])
+ self.mox.VerifyAll()
+ self.assertEqual(1, comp_id)
+
+ def SetUpUpdateComponentDef(self, component_id):
+ self.config_service.component2admin_tbl.Delete(
+ self.cnxn, component_id=component_id, commit=False)
+ self.config_service.component2admin_tbl.InsertRows(
+ self.cnxn, config_svc.COMPONENT2ADMIN_COLS, [], commit=False)
+ self.config_service.component2cc_tbl.Delete(
+ self.cnxn, component_id=component_id, commit=False)
+ self.config_service.component2cc_tbl.InsertRows(
+ self.cnxn, config_svc.COMPONENT2CC_COLS, [], commit=False)
+ self.config_service.component2label_tbl.Delete(
+ self.cnxn, component_id=component_id, commit=False)
+ self.config_service.component2label_tbl.InsertRows(
+ self.cnxn, config_svc.COMPONENT2LABEL_COLS, [], commit=False)
+
+ self.config_service.componentdef_tbl.Update(
+ self.cnxn,
+ {'path': 'DisplayManager', 'docstring': 'doc', 'deprecated': True},
+ id=component_id, commit=False)
+ self.cnxn.Commit()
+
+ def testUpdateComponentDef(self):
+ self.SetUpUpdateComponentDef(1)
+
+ self.mox.ReplayAll()
+ self.config_service.UpdateComponentDef(
+ self.cnxn, 789, 1, path='DisplayManager', docstring='doc',
+ deprecated=True, admin_ids=[], cc_ids=[], label_ids=[])
+ self.mox.VerifyAll()
+
+ def SetUpSoftDeleteComponentDef(self, component_id):
+ self.config_service.componentdef_tbl.Update(
+ self.cnxn, {'is_deleted': True}, commit=False, id=component_id)
+ self.cnxn.Commit()
+
+ def testSoftDeleteComponentDef(self):
+ self.SetUpSoftDeleteComponentDef(1)
+
+ self.mox.ReplayAll()
+ self.config_service.DeleteComponentDef(self.cnxn, 789, 1)
+ self.mox.VerifyAll()
+
+ ### Memcache management
+
+ def testInvalidateMemcache(self):
+ pass # TODO(jrobbins): write this
+
+ def testInvalidateMemcacheShards(self):
+ NOW = 1234567
+ memcache.set('789;1', NOW)
+ memcache.set('789;2', NOW - 1000)
+ memcache.set('789;3', NOW - 2000)
+ memcache.set('all;1', NOW)
+ memcache.set('all;2', NOW - 1000)
+ memcache.set('all;3', NOW - 2000)
+
+ # Delete some of them.
+ self.config_service._InvalidateMemcacheShards(
+ [(789, 1), (789, 2), (789,9)])
+
+ self.assertIsNone(memcache.get('789;1'))
+ self.assertIsNone(memcache.get('789;2'))
+ self.assertEqual(NOW - 2000, memcache.get('789;3'))
+ self.assertIsNone(memcache.get('all;1'))
+ self.assertIsNone(memcache.get('all;2'))
+ self.assertEqual(NOW - 2000, memcache.get('all;3'))
+
+ def testInvalidateMemcacheForEntireProject(self):
+ NOW = 1234567
+ memcache.set('789;1', NOW)
+ memcache.set('config:789', 'serialized config')
+ memcache.set('label_rows:789', 'serialized label rows')
+ memcache.set('status_rows:789', 'serialized status rows')
+ memcache.set('field_rows:789', 'serialized field rows')
+ memcache.set('890;1', NOW) # Other projects will not be affected.
+
+ self.config_service.InvalidateMemcacheForEntireProject(789)
+
+ self.assertIsNone(memcache.get('789;1'))
+ self.assertIsNone(memcache.get('config:789'))
+ self.assertIsNone(memcache.get('status_rows:789'))
+ self.assertIsNone(memcache.get('label_rows:789'))
+ self.assertIsNone(memcache.get('field_rows:789'))
+ self.assertEqual(NOW, memcache.get('890;1'))
+
+ def testUsersInvolvedInConfig_Empty(self):
+ templates = []
+ config = tracker_pb2.ProjectIssueConfig()
+ self.assertEqual(set(), self.config_service.UsersInvolvedInConfig(
+ config, templates))
+
+ def testUsersInvolvedInConfig_Default(self):
+ templates = [
+ tracker_bizobj.ConvertDictToTemplate(t)
+ for t in tracker_constants.DEFAULT_TEMPLATES]
+ config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+ self.assertEqual(set(), self.config_service.UsersInvolvedInConfig(
+ config, templates))
+
+ def testUsersInvolvedInConfig_Normal(self):
+ config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+ templates = [
+ tracker_bizobj.ConvertDictToTemplate(t)
+ for t in tracker_constants.DEFAULT_TEMPLATES]
+ templates[0].owner_id = 111
+ templates[0].admin_ids = [111, 222]
+ config.field_defs = [
+ tracker_pb2.FieldDef(admin_ids=[333], editor_ids=[444])
+ ]
+ actual = self.config_service.UsersInvolvedInConfig(config, templates)
+ self.assertEqual({111, 222, 333, 444}, actual)
diff --git a/services/test/features_svc_test.py b/services/test/features_svc_test.py
new file mode 100644
index 0000000..c80b819
--- /dev/null
+++ b/services/test/features_svc_test.py
@@ -0,0 +1,1431 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for features_svc module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import logging
+import mox
+import time
+import unittest
+import mock
+
+from google.appengine.api import memcache
+from google.appengine.ext import testbed
+
+import settings
+
+from features import filterrules_helpers
+from features import features_constants
+from framework import exceptions
+from framework import framework_constants
+from framework import sql
+from proto import tracker_pb2
+from proto import features_pb2
+from services import chart_svc
+from services import features_svc
+from services import star_svc
+from services import user_svc
+from testing import fake
+from tracker import tracker_bizobj
+from tracker import tracker_constants
+
+
+# NOTE: we are in the process of moving away from mox towards mock.
+# This file is a mix of both. All new tests or big test updates should make
+# use of the mock package.
+def MakeFeaturesService(cache_manager, my_mox):
+ features_service = features_svc.FeaturesService(cache_manager,
+ fake.ConfigService())
+ features_service.hotlist_tbl = my_mox.CreateMock(sql.SQLTableManager)
+ features_service.hotlist2issue_tbl = my_mox.CreateMock(sql.SQLTableManager)
+ features_service.hotlist2user_tbl = my_mox.CreateMock(sql.SQLTableManager)
+ return features_service
+
+
+class HotlistTwoLevelCacheTest(unittest.TestCase):
+
+ def setUp(self):
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_memcache_stub()
+
+ self.mox = mox.Mox()
+ self.cnxn = self.mox.CreateMock(sql.MonorailConnection)
+ self.cache_manager = fake.CacheManager()
+ self.features_service = MakeFeaturesService(self.cache_manager, self.mox)
+
+ def tearDown(self):
+ self.testbed.deactivate()
+
+ def testDeserializeHotlists(self):
+ hotlist_rows = [
+ (123, 'hot1', 'test hot 1', 'test hotlist', False, ''),
+ (234, 'hot2', 'test hot 2', 'test hotlist', False, '')]
+
+ ts = 20021111111111
+ issue_rows = [
+ (123, 567, 10, 111, ts, ''), (123, 678, 9, 111, ts, ''),
+ (234, 567, 0, 111, ts, '')]
+ role_rows = [
+ (123, 111, 'owner'), (123, 444, 'owner'),
+ (123, 222, 'editor'),
+ (123, 333, 'follower'),
+ (234, 111, 'owner')]
+ hotlist_dict = self.features_service.hotlist_2lc._DeserializeHotlists(
+ hotlist_rows, issue_rows, role_rows)
+
+ self.assertItemsEqual([123, 234], list(hotlist_dict.keys()))
+ self.assertEqual(123, hotlist_dict[123].hotlist_id)
+ self.assertEqual('hot1', hotlist_dict[123].name)
+ self.assertItemsEqual([111, 444], hotlist_dict[123].owner_ids)
+ self.assertItemsEqual([222], hotlist_dict[123].editor_ids)
+ self.assertItemsEqual([333], hotlist_dict[123].follower_ids)
+ self.assertEqual(234, hotlist_dict[234].hotlist_id)
+ self.assertItemsEqual([111], hotlist_dict[234].owner_ids)
+
+
+class HotlistIDTwoLevelCache(unittest.TestCase):
+
+ def setUp(self):
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_memcache_stub()
+
+ self.mox = mox.Mox()
+ self.cnxn = self.mox.CreateMock(sql.MonorailConnection)
+ self.cache_manager = fake.CacheManager()
+ self.features_service = MakeFeaturesService(self.cache_manager, self.mox)
+ self.hotlist_id_2lc = self.features_service.hotlist_id_2lc
+
+ def tearDown(self):
+ memcache.flush_all()
+ self.testbed.deactivate()
+ self.mox.UnsetStubs()
+ self.mox.ResetAll()
+
+ def testGetAll(self):
+ cached_keys = [('name1', 111), ('name2', 222)]
+ self.hotlist_id_2lc.CacheItem(cached_keys[0], 121)
+ self.hotlist_id_2lc.CacheItem(cached_keys[1], 122)
+
+ # Set up DB query mocks.
+ # Test that a ('name1', 222) or ('name3', 333) hotlist
+ # does not get returned by GetAll even though these hotlists
+ # exist and are returned by the DB queries.
+ from_db_keys = [
+ ('name1', 333), ('name3', 222), ('name3', 555)]
+ self.features_service.hotlist2user_tbl.Select = mock.Mock(return_value=[
+ (123, 333), # name1 hotlist
+ (124, 222), # name3 hotlist
+ (125, 222), # name1 hotlist, should be ignored
+ (126, 333), # name3 hotlist, should be ignored
+ (127, 555), # wrongname hotlist, should be ignored
+ ])
+ self.features_service.hotlist_tbl.Select = mock.Mock(
+ return_value=[(123, 'Name1'), (124, 'Name3'),
+ (125, 'Name1'), (126, 'Name3')])
+
+ hit, misses = self.hotlist_id_2lc.GetAll(
+ self.cnxn, cached_keys + from_db_keys)
+
+ # Assertions
+ self.features_service.hotlist2user_tbl.Select.assert_called_once_with(
+ self.cnxn, cols=['hotlist_id', 'user_id'], user_id=[555, 333, 222],
+ role_name='owner')
+ hotlist_ids = [123, 124, 125, 126, 127]
+ self.features_service.hotlist_tbl.Select.assert_called_once_with(
+ self.cnxn, cols=['id', 'name'], id=hotlist_ids, is_deleted=False,
+ where=[('LOWER(name) IN (%s,%s)', ['name3', 'name1'])])
+
+ self.assertEqual(hit,{
+ ('name1', 111): 121,
+ ('name2', 222): 122,
+ ('name1', 333): 123,
+ ('name3', 222): 124})
+ self.assertEqual(from_db_keys[-1:], misses)
+
+
+class FeaturesServiceTest(unittest.TestCase):
+
+ def MakeMockTable(self):
+ return self.mox.CreateMock(sql.SQLTableManager)
+
+ def setUp(self):
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_memcache_stub()
+
+ self.mox = mox.Mox()
+ self.cnxn = self.mox.CreateMock(sql.MonorailConnection)
+ self.cache_manager = fake.CacheManager()
+ self.config_service = fake.ConfigService()
+
+ self.features_service = features_svc.FeaturesService(self.cache_manager,
+ self.config_service)
+ self.issue_service = fake.IssueService()
+ self.chart_service = self.mox.CreateMock(chart_svc.ChartService)
+
+ for table_var in [
+ 'user2savedquery_tbl', 'quickedithistory_tbl',
+ 'quickeditmostrecent_tbl', 'savedquery_tbl',
+ 'savedqueryexecutesinproject_tbl', 'project2savedquery_tbl',
+ 'filterrule_tbl', 'hotlist_tbl', 'hotlist2issue_tbl',
+ 'hotlist2user_tbl']:
+ setattr(self.features_service, table_var, self.MakeMockTable())
+
+ def tearDown(self):
+ memcache.flush_all()
+ self.testbed.deactivate()
+ self.mox.UnsetStubs()
+ self.mox.ResetAll()
+
+ ### quickedit command history
+
+ def testGetRecentCommands(self):
+ self.features_service.quickedithistory_tbl.Select(
+ self.cnxn, cols=['slot_num', 'command', 'comment'],
+ user_id=1, project_id=12345).AndReturn(
+ [(1, 'status=New', 'Brand new issue')])
+ self.features_service.quickeditmostrecent_tbl.SelectValue(
+ self.cnxn, 'slot_num', default=1, user_id=1, project_id=12345
+ ).AndReturn(1)
+ self.mox.ReplayAll()
+ slots, recent_slot_num = self.features_service.GetRecentCommands(
+ self.cnxn, 1, 12345)
+ self.mox.VerifyAll()
+
+ self.assertEqual(1, recent_slot_num)
+ self.assertEqual(
+ len(tracker_constants.DEFAULT_RECENT_COMMANDS), len(slots))
+ self.assertEqual('status=New', slots[0][1])
+
+ def testStoreRecentCommand(self):
+ self.features_service.quickedithistory_tbl.InsertRow(
+ self.cnxn, replace=True, user_id=1, project_id=12345,
+ slot_num=1, command='status=New', comment='Brand new issue')
+ self.features_service.quickeditmostrecent_tbl.InsertRow(
+ self.cnxn, replace=True, user_id=1, project_id=12345,
+ slot_num=1)
+ self.mox.ReplayAll()
+ self.features_service.StoreRecentCommand(
+ self.cnxn, 1, 12345, 1, 'status=New', 'Brand new issue')
+ self.mox.VerifyAll()
+
+ def testExpungeQuickEditHistory(self):
+ self.features_service.quickeditmostrecent_tbl.Delete(
+ self.cnxn, project_id=12345)
+ self.features_service.quickedithistory_tbl.Delete(
+ self.cnxn, project_id=12345)
+ self.mox.ReplayAll()
+ self.features_service.ExpungeQuickEditHistory(
+ self.cnxn, 12345)
+ self.mox.VerifyAll()
+
+ def testExpungeQuickEditsByUsers(self):
+ user_ids = [333, 555, 777]
+ commit = False
+
+ self.features_service.quickeditmostrecent_tbl.Delete = mock.Mock()
+ self.features_service.quickedithistory_tbl.Delete = mock.Mock()
+
+ self.features_service.ExpungeQuickEditsByUsers(
+ self.cnxn, user_ids, limit=50)
+
+ self.features_service.quickeditmostrecent_tbl.Delete.\
+assert_called_once_with(self.cnxn, user_id=user_ids, commit=commit, limit=50)
+ self.features_service.quickedithistory_tbl.Delete.\
+assert_called_once_with(self.cnxn, user_id=user_ids, commit=commit, limit=50)
+
+ ### Saved User and Project Queries
+
+ def testGetSavedQuery_Valid(self):
+ self.features_service.savedquery_tbl.Select(
+ self.cnxn, cols=features_svc.SAVEDQUERY_COLS, id=[1]).AndReturn(
+ [(1, 'query1', 100, 'owner:me')])
+ self.features_service.savedqueryexecutesinproject_tbl.Select(
+ self.cnxn, cols=features_svc.SAVEDQUERYEXECUTESINPROJECT_COLS,
+ query_id=[1]).AndReturn([(1, 12345)])
+ self.mox.ReplayAll()
+ saved_query = self.features_service.GetSavedQuery(
+ self.cnxn, 1)
+ self.mox.VerifyAll()
+ self.assertEqual(1, saved_query.query_id)
+ self.assertEqual('query1', saved_query.name)
+ self.assertEqual(100, saved_query.base_query_id)
+ self.assertEqual('owner:me', saved_query.query)
+ self.assertEqual([12345], saved_query.executes_in_project_ids)
+
+ def testGetSavedQuery_Invalid(self):
+ self.features_service.savedquery_tbl.Select(
+ self.cnxn, cols=features_svc.SAVEDQUERY_COLS, id=[99]).AndReturn([])
+ self.features_service.savedqueryexecutesinproject_tbl.Select(
+ self.cnxn, cols=features_svc.SAVEDQUERYEXECUTESINPROJECT_COLS,
+ query_id=[99]).AndReturn([])
+ self.mox.ReplayAll()
+ saved_query = self.features_service.GetSavedQuery(
+ self.cnxn, 99)
+ self.mox.VerifyAll()
+ self.assertIsNone(saved_query)
+
+ def SetUpUsersSavedQueries(self, has_query_id=True):
+ query = tracker_bizobj.MakeSavedQuery(1, 'query1', 100, 'owner:me')
+ self.features_service.saved_query_cache.CacheItem(1, [query])
+
+ if has_query_id:
+ self.features_service.user2savedquery_tbl.Select(
+ self.cnxn,
+ cols=features_svc.SAVEDQUERY_COLS + ['user_id', 'subscription_mode'],
+ left_joins=[('SavedQuery ON query_id = id', [])],
+ order_by=[('rank', [])],
+ user_id=[2]).AndReturn(
+ [(2, 'query2', 100, 'status:New', 2, 'Sub_Mode')])
+ self.features_service.savedqueryexecutesinproject_tbl.Select(
+ self.cnxn,
+ cols=features_svc.SAVEDQUERYEXECUTESINPROJECT_COLS,
+ query_id=set([2])).AndReturn([(2, 12345)])
+ else:
+ self.features_service.user2savedquery_tbl.Select(
+ self.cnxn,
+ cols=features_svc.SAVEDQUERY_COLS + ['user_id', 'subscription_mode'],
+ left_joins=[('SavedQuery ON query_id = id', [])],
+ order_by=[('rank', [])],
+ user_id=[2]).AndReturn([])
+
+ def testGetUsersSavedQueriesDict(self):
+ self.SetUpUsersSavedQueries()
+ self.mox.ReplayAll()
+ results_dict = self.features_service._GetUsersSavedQueriesDict(
+ self.cnxn, [1, 2])
+ self.mox.VerifyAll()
+ self.assertIn(1, results_dict)
+ self.assertIn(2, results_dict)
+
+ def testGetUsersSavedQueriesDictWithoutSavedQueries(self):
+ self.SetUpUsersSavedQueries(False)
+ self.mox.ReplayAll()
+ results_dict = self.features_service._GetUsersSavedQueriesDict(
+ self.cnxn, [1, 2])
+ self.mox.VerifyAll()
+ self.assertIn(1, results_dict)
+ self.assertNotIn(2, results_dict)
+
+ def testGetSavedQueriesByUserID(self):
+ self.SetUpUsersSavedQueries()
+ self.mox.ReplayAll()
+ saved_queries = self.features_service.GetSavedQueriesByUserID(
+ self.cnxn, 2)
+ self.mox.VerifyAll()
+ self.assertEqual(1, len(saved_queries))
+ self.assertEqual(2, saved_queries[0].query_id)
+
+ def SetUpCannedQueriesForProjects(self, project_ids):
+ query = tracker_bizobj.MakeSavedQuery(
+ 2, 'project-query-2', 110, 'owner:goose@chaos.honk')
+ self.features_service.canned_query_cache.CacheItem(12346, [query])
+ self.features_service.canned_query_cache.CacheAll = mock.Mock()
+ self.features_service.project2savedquery_tbl.Select(
+ self.cnxn, cols=['project_id'] + features_svc.SAVEDQUERY_COLS,
+ left_joins=[('SavedQuery ON query_id = id', [])],
+ order_by=[('rank', [])], project_id=project_ids).AndReturn(
+ [(12345, 1, 'query1', 100, 'owner:me')])
+
+ def testGetCannedQueriesForProjects(self):
+ project_ids = [12345, 12346]
+ self.SetUpCannedQueriesForProjects(project_ids)
+ self.mox.ReplayAll()
+ results_dict = self.features_service.GetCannedQueriesForProjects(
+ self.cnxn, project_ids)
+ self.mox.VerifyAll()
+ self.assertIn(12345, results_dict)
+ self.assertIn(12346, results_dict)
+ self.features_service.canned_query_cache.CacheAll.assert_called_once_with(
+ results_dict)
+
+ def testGetCannedQueriesByProjectID(self):
+ project_id= 12345
+ self.SetUpCannedQueriesForProjects([project_id])
+ self.mox.ReplayAll()
+ result = self.features_service.GetCannedQueriesByProjectID(
+ self.cnxn, project_id)
+ self.mox.VerifyAll()
+ self.assertEqual(1, len(result))
+ self.assertEqual(1, result[0].query_id)
+
+ def SetUpUpdateSavedQueries(self, commit=True):
+ query1 = tracker_bizobj.MakeSavedQuery(1, 'query1', 100, 'owner:me')
+ query2 = tracker_bizobj.MakeSavedQuery(None, 'query2', 100, 'status:New')
+ saved_queries = [query1, query2]
+ savedquery_rows = [
+ (sq.query_id or None, sq.name, sq.base_query_id, sq.query)
+ for sq in saved_queries]
+ self.features_service.savedquery_tbl.Delete(
+ self.cnxn, id=[1], commit=commit)
+ self.features_service.savedquery_tbl.InsertRows(
+ self.cnxn, features_svc.SAVEDQUERY_COLS, savedquery_rows, commit=commit,
+ return_generated_ids=True).AndReturn([11, 12])
+ return saved_queries
+
+ def testUpdateSavedQueries(self):
+ saved_queries = self.SetUpUpdateSavedQueries()
+ self.mox.ReplayAll()
+ self.features_service._UpdateSavedQueries(
+ self.cnxn, saved_queries, True)
+ self.mox.VerifyAll()
+
+ def testUpdateCannedQueries(self):
+ self.features_service.project2savedquery_tbl.Delete(
+ self.cnxn, project_id=12345, commit=False)
+ canned_queries = self.SetUpUpdateSavedQueries(False)
+ project2savedquery_rows = [(12345, 0, 1), (12345, 1, 12)]
+ self.features_service.project2savedquery_tbl.InsertRows(
+ self.cnxn, features_svc.PROJECT2SAVEDQUERY_COLS,
+ project2savedquery_rows, commit=False)
+ self.features_service.canned_query_cache.Invalidate = mock.Mock()
+ self.cnxn.Commit()
+ self.mox.ReplayAll()
+ self.features_service.UpdateCannedQueries(
+ self.cnxn, 12345, canned_queries)
+ self.mox.VerifyAll()
+ self.features_service.canned_query_cache.Invalidate.assert_called_once_with(
+ self.cnxn, 12345)
+
+ def testUpdateUserSavedQueries(self):
+ saved_queries = self.SetUpUpdateSavedQueries(False)
+ self.features_service.savedqueryexecutesinproject_tbl.Delete(
+ self.cnxn, query_id=[1], commit=False)
+ self.features_service.user2savedquery_tbl.Delete(
+ self.cnxn, user_id=1, commit=False)
+ user2savedquery_rows = [
+ (1, 0, 1, 'noemail'), (1, 1, 12, 'noemail')]
+ self.features_service.user2savedquery_tbl.InsertRows(
+ self.cnxn, features_svc.USER2SAVEDQUERY_COLS,
+ user2savedquery_rows, commit=False)
+ self.features_service.savedqueryexecutesinproject_tbl.InsertRows(
+ self.cnxn, features_svc.SAVEDQUERYEXECUTESINPROJECT_COLS, [],
+ commit=False)
+ self.cnxn.Commit()
+ self.mox.ReplayAll()
+ self.features_service.UpdateUserSavedQueries(
+ self.cnxn, 1, saved_queries)
+ self.mox.VerifyAll()
+
+ ### Subscriptions
+
+ def testGetSubscriptionsInProjects(self):
+ sqeip_join_str = (
+ 'SavedQueryExecutesInProject ON '
+ 'SavedQueryExecutesInProject.query_id = User2SavedQuery.query_id')
+ user_join_str = (
+ 'User ON '
+ 'User.user_id = User2SavedQuery.user_id')
+ now = 1519418530
+ self.mox.StubOutWithMock(time, 'time')
+ time.time().MultipleTimes().AndReturn(now)
+ absence_threshold = now - settings.subscription_timeout_secs
+ where = [
+ ('(User.banned IS NULL OR User.banned = %s)', ['']),
+ ('User.last_visit_timestamp >= %s', [absence_threshold]),
+ ('(User.email_bounce_timestamp IS NULL OR '
+ 'User.email_bounce_timestamp = %s)', [0]),
+ ]
+ self.features_service.user2savedquery_tbl.Select(
+ self.cnxn, cols=['User2SavedQuery.user_id'], distinct=True,
+ joins=[(sqeip_join_str, []), (user_join_str, [])],
+ subscription_mode='immediate', project_id=12345,
+ where=where).AndReturn(
+ [(1, 'asd'), (2, 'efg')])
+ self.SetUpUsersSavedQueries()
+ self.mox.ReplayAll()
+ result = self.features_service.GetSubscriptionsInProjects(
+ self.cnxn, 12345)
+ self.mox.VerifyAll()
+ self.assertIn(1, result)
+ self.assertIn(2, result)
+
+ def testExpungeSavedQueriesExecuteInProject(self):
+ self.features_service.savedqueryexecutesinproject_tbl.Delete(
+ self.cnxn, project_id=12345)
+ self.features_service.project2savedquery_tbl.Select(
+ self.cnxn, cols=['query_id'], project_id=12345).AndReturn(
+ [(1, 'asd'), (2, 'efg')])
+ self.features_service.project2savedquery_tbl.Delete(
+ self.cnxn, project_id=12345)
+ self.features_service.savedquery_tbl.Delete(
+ self.cnxn, id=[1, 2])
+ self.mox.ReplayAll()
+ self.features_service.ExpungeSavedQueriesExecuteInProject(
+ self.cnxn, 12345)
+ self.mox.VerifyAll()
+
+ def testExpungeSavedQueriesByUsers(self):
+ user_ids = [222, 444, 666]
+ commit = False
+
+ sv_rows = [(8,), (9,)]
+ self.features_service.user2savedquery_tbl.Select = mock.Mock(
+ return_value=sv_rows)
+ self.features_service.user2savedquery_tbl.Delete = mock.Mock()
+ self.features_service.savedqueryexecutesinproject_tbl.Delete = mock.Mock()
+ self.features_service.savedquery_tbl.Delete = mock.Mock()
+
+ self.features_service.ExpungeSavedQueriesByUsers(
+ self.cnxn, user_ids, limit=50)
+
+ self.features_service.user2savedquery_tbl.Select.assert_called_once_with(
+ self.cnxn, cols=['query_id'], user_id=user_ids, limit=50)
+ self.features_service.user2savedquery_tbl.Delete.assert_called_once_with(
+ self.cnxn, query_id=[8, 9], commit=commit)
+ self.features_service.savedqueryexecutesinproject_tbl.\
+Delete.assert_called_once_with(
+ self.cnxn, query_id=[8, 9], commit=commit)
+ self.features_service.savedquery_tbl.Delete.assert_called_once_with(
+ self.cnxn, id=[8, 9], commit=commit)
+
+
+ ### Filter Rules
+
+ def testDeserializeFilterRules(self):
+ filterrule_rows = [
+ (12345, 0, 'predicate1', 'default_status:New'),
+ (12345, 1, 'predicate2', 'default_owner_id:1 add_cc_id:2'),
+ ]
+ result_dict = self.features_service._DeserializeFilterRules(
+ filterrule_rows)
+ self.assertIn(12345, result_dict)
+ self.assertEqual(2, len(result_dict[12345]))
+ self.assertEqual('New', result_dict[12345][0].default_status)
+ self.assertEqual(1, result_dict[12345][1].default_owner_id)
+ self.assertEqual([2], result_dict[12345][1].add_cc_ids)
+
+ def testDeserializeRuleConsequence_Multiple(self):
+ consequence = ('default_status:New default_owner_id:1 add_cc_id:2'
+ ' add_label:label-1 add_label:label.2'
+ ' add_notify:admin@example.com')
+ (default_status, default_owner_id, add_cc_ids, add_labels,
+ add_notify, warning, error
+ ) = self.features_service._DeserializeRuleConsequence(
+ consequence)
+ self.assertEqual('New', default_status)
+ self.assertEqual(1, default_owner_id)
+ self.assertEqual([2], add_cc_ids)
+ self.assertEqual(['label-1', 'label.2'], add_labels)
+ self.assertEqual(['admin@example.com'], add_notify)
+ self.assertEqual(None, warning)
+ self.assertEqual(None, error)
+
+ def testDeserializeRuleConsequence_Warning(self):
+ consequence = ('warning:Do not use status:New if there is an owner')
+ (_status, _owner_id, _cc_ids, _labels, _notify,
+ warning, _error) = self.features_service._DeserializeRuleConsequence(
+ consequence)
+ self.assertEqual(
+ 'Do not use status:New if there is an owner',
+ warning)
+
+ def testDeserializeRuleConsequence_Error(self):
+ consequence = ('error:Pri-0 issues require an owner')
+ (_status, _owner_id, _cc_ids, _labels, _notify,
+ _warning, error) = self.features_service._DeserializeRuleConsequence(
+ consequence)
+ self.assertEqual(
+ 'Pri-0 issues require an owner',
+ error)
+
+ def SetUpGetFilterRulesByProjectIDs(self):
+ filterrule_rows = [
+ (12345, 0, 'predicate1', 'default_status:New'),
+ (12345, 1, 'predicate2', 'default_owner_id:1 add_cc_id:2'),
+ ]
+
+ self.features_service.filterrule_tbl.Select(
+ self.cnxn, cols=features_svc.FILTERRULE_COLS,
+ project_id=[12345]).AndReturn(filterrule_rows)
+
+ def testGetFilterRulesByProjectIDs(self):
+ self.SetUpGetFilterRulesByProjectIDs()
+ self.mox.ReplayAll()
+ result = self.features_service._GetFilterRulesByProjectIDs(
+ self.cnxn, [12345])
+ self.mox.VerifyAll()
+ self.assertIn(12345, result)
+ self.assertEqual(2, len(result[12345]))
+
+ def testGetFilterRules(self):
+ self.SetUpGetFilterRulesByProjectIDs()
+ self.mox.ReplayAll()
+ result = self.features_service.GetFilterRules(
+ self.cnxn, 12345)
+ self.mox.VerifyAll()
+ self.assertEqual(2, len(result))
+
+ def testSerializeRuleConsequence(self):
+ rule = filterrules_helpers.MakeRule(
+ 'predicate', 'New', 1, [1, 2], ['label1', 'label2'], ['admin'])
+ result = self.features_service._SerializeRuleConsequence(rule)
+ self.assertEqual('add_label:label1 add_label:label2 default_status:New'
+ ' default_owner_id:1 add_cc_id:1 add_cc_id:2'
+ ' add_notify:admin', result)
+
+ def testUpdateFilterRules(self):
+ self.features_service.filterrule_tbl.Delete(self.cnxn, project_id=12345)
+ rows = [
+ (12345, 0, 'predicate1', 'add_label:label1 add_label:label2'
+ ' default_status:New default_owner_id:1'
+ ' add_cc_id:1 add_cc_id:2 add_notify:admin'),
+ (12345, 1, 'predicate2', 'add_label:label2 add_label:label3'
+ ' default_status:Fixed default_owner_id:2'
+ ' add_cc_id:1 add_cc_id:2 add_notify:admin2')
+ ]
+ self.features_service.filterrule_tbl.InsertRows(
+ self.cnxn, features_svc.FILTERRULE_COLS, rows)
+ rule1 = filterrules_helpers.MakeRule(
+ 'predicate1', 'New', 1, [1, 2], ['label1', 'label2'], ['admin'])
+ rule2 = filterrules_helpers.MakeRule(
+ 'predicate2', 'Fixed', 2, [1, 2], ['label2', 'label3'], ['admin2'])
+ self.mox.ReplayAll()
+ self.features_service.UpdateFilterRules(
+ self.cnxn, 12345, [rule1, rule2])
+ self.mox.VerifyAll()
+
+ def testExpungeFilterRules(self):
+ self.features_service.filterrule_tbl.Delete(self.cnxn, project_id=12345)
+ self.mox.ReplayAll()
+ self.features_service.ExpungeFilterRules(
+ self.cnxn, 12345)
+ self.mox.VerifyAll()
+
+ def testExpungeFilterRulesByUser(self):
+ """Rules that mention an expunged user are deleted and returned."""
+ # {email: user_id} of the users being expunged.
+ emails = {'chicken@farm.test': 333, 'cow@fart.test': 222}
+ # Rows that mention neither user, by email or by user ID.
+ project_1_keep_rows = [
+ (1, 1, 'label:no-match-here', 'add_label:should-be-deleted-inserted')]
+ project_16_keep_rows =[
+ (16, 20, 'label:no-match-here', 'add_label:should-be-deleted-inserted'),
+ (16, 21, 'owner:rainbow@test.com', 'add_label:delete-and-insert')]
+ random_row = [
+ (19, 9, 'label:no-match-in-project', 'add_label:no-DELETE-INSERTROW')]
+ # Each row references an expunged user: by email in the predicate
+ # (rows 0-1), or in the consequence via add_notify (row 2), add_cc_id
+ # (row 3) or default_owner_id (row 4).
+ rows_to_delete = [
+ (1, 45, 'owner:cow@fart.test', 'add_label:happy-cows'),
+ (1, 46, 'owner:cow@fart.test', 'add_label:balloon'),
+ (16, 47, 'label:queue-eggs', 'add_notify:chicken@farm.test'),
+ (17, 48, 'owner:farmer@farm.test', 'add_cc_id:111 add_cc_id:222'),
+ (17, 48, 'label:queue-chickens', 'default_owner_id:333'),
+ ]
+ rows = (rows_to_delete + project_1_keep_rows + project_16_keep_rows +
+ random_row)
+ self.features_service.filterrule_tbl.Select = mock.Mock(return_value=rows)
+ self.features_service.filterrule_tbl.Delete = mock.Mock()
+
+ rules_dict = self.features_service.ExpungeFilterRulesByUser(
+ self.cnxn, emails)
+ expected_dict = {
+ 1: [tracker_pb2.FilterRule(
+ predicate=rows[0][2], add_labels=['happy-cows']),
+ tracker_pb2.FilterRule(
+ predicate=rows[1][2], add_labels=['balloon'])],
+ 16: [tracker_pb2.FilterRule(
+ predicate=rows[2][2], add_notify_addrs=['chicken@farm.test'])],
+ 17: [tracker_pb2.FilterRule(
+ predicate=rows[3][2], add_cc_ids=[111, 222])],
+ }
+ # NOTE(review): assertItemsEqual on two dicts compares only their keys,
+ # so the FilterRule values are never checked (and expected_dict[17] has
+ # no entry for rows[4], which is also deleted below). Confirm the real
+ # return values, then switch to assertEqual.
+ self.assertItemsEqual(rules_dict, expected_dict)
+
+ self.features_service.filterrule_tbl.Select.assert_called_once_with(
+ self.cnxn, features_svc.FILTERRULE_COLS)
+
+ # NOTE(review): assert_has_calls only checks that these Deletes happened;
+ # it does not prove that the keep rows above were left alone.
+ calls = [mock.call(self.cnxn, project_id=project_id, rank=rank,
+ predicate=predicate, consequence=consequence,
+ commit=False)
+ for (project_id, rank, predicate, consequence) in rows_to_delete]
+ self.features_service.filterrule_tbl.Delete.assert_has_calls(
+ calls, any_order=True)
+
+ def testExpungeFilterRulesByUser_EmptyUsers(self):
+ """With no users to expunge, no rule is read or deleted."""
+ filterrule_tbl = self.features_service.filterrule_tbl
+ filterrule_tbl.Select = mock.Mock()
+ filterrule_tbl.Delete = mock.Mock()
+
+ self.assertEqual(
+ self.features_service.ExpungeFilterRulesByUser(self.cnxn, {}), {})
+ filterrule_tbl.Select.assert_not_called()
+ filterrule_tbl.Delete.assert_not_called()
+
+ def testExpungeFilterRulesByUser_NoMatch(self):
+ """Rules that never reference the expunged user are left alone."""
+ filterrule_tbl = self.features_service.filterrule_tbl
+ # 'add_cc_id: 222' has a space after the colon, so (unlike the
+ # 'add_cc_id:222' row in the test above) it must not match user 222.
+ filterrule_tbl.Select = mock.Mock(return_value=[
+ (17, 48, 'owner:farmer@farm.test', 'add_cc_id:111 add_cc_id: 222'),
+ (19, 9, 'label:no-match-in-project', 'add_label:no-DELETE-INSERTROW'),
+ ])
+ filterrule_tbl.Delete = mock.Mock()
+
+ user_ids_by_email = {'cow@fart.test': 222}
+ expunged = self.features_service.ExpungeFilterRulesByUser(
+ self.cnxn, user_ids_by_email)
+ self.assertItemsEqual(expunged, {})
+
+ filterrule_tbl.Select.assert_called_once_with(
+ self.cnxn, features_svc.FILTERRULE_COLS)
+ filterrule_tbl.Delete.assert_not_called()
+
+ ### Hotlists
+
+ def SetUpCreateHotlist(self):
+ """Records the mox SQL calls CreateHotlist makes for hotlist 'hot1'."""
+ # Check for an existing 'hot1' owned by user 567. Hotlists with that name
+ # owned by other users do not matter: the lookup only considers hotlists
+ # owned by 567, and 567 owns none.
+ self.features_service.hotlist2user_tbl.Select(
+ self.cnxn, cols=['hotlist_id', 'user_id'],
+ user_id=[567], role_name='owner').AndReturn([])
+
+ self.features_service.hotlist_tbl.Select(
+ self.cnxn, cols=['id', 'name'], id=[], is_deleted=False,
+ where =[(('LOWER(name) IN (%s)'), ['hot1'])]).AndReturn([])
+
+ # Inserting the hotlist returns the id.
+ self.features_service.hotlist_tbl.InsertRow(
+ self.cnxn, name='hot1', summary='hot 1', description='test hotlist',
+ is_private=False,
+ default_col_spec=features_constants.DEFAULT_COL_SPEC).AndReturn(123)
+
+ # Insert the issues: there are none.
+ self.features_service.hotlist2issue_tbl.InsertRows(
+ self.cnxn, features_svc.HOTLIST2ISSUE_COLS,
+ [], commit=False)
+
+ # Insert the users: there is one owner and one editor.
+ self.features_service.hotlist2user_tbl.InsertRows(
+ self.cnxn, ['hotlist_id', 'user_id', 'role_name'],
+ [(123, 567, 'owner'), (123, 678, 'editor')])
+
+ def testCreateHotlist(self):
+ """CreateHotlist makes exactly the lookup and insert calls recorded."""
+ self.SetUpCreateHotlist()
+ self.mox.ReplayAll()
+ # Owner 567, editor 678.
+ self.features_service.CreateHotlist(
+ self.cnxn, 'hot1', 'hot 1', 'test hotlist', [567], [678])
+ self.mox.VerifyAll()
+
+ def testCreateHotlist_InvalidName(self):
+ """A name that fails validation is rejected with InputException."""
+ self.assertRaises(
+ exceptions.InputException, self.features_service.CreateHotlist,
+ self.cnxn, '***Invalid name***', 'Misnamed Hotlist',
+ 'A Hotlist with an invalid name', [567], [678])
+
+ def testCreateHotlist_NoOwner(self):
+ """Creating a hotlist with no owner IDs raises UnownedHotlistException."""
+ self.assertRaises(
+ features_svc.UnownedHotlistException,
+ self.features_service.CreateHotlist,
+ self.cnxn, 'unowned-hotlist', 'Unowned Hotlist',
+ 'A Hotlist that is not owned', [], [])
+
+ def testCreateHotlist_HotlistAlreadyExists(self):
+ """An owner cannot reuse a hotlist name, even with different casing."""
+ # Prime the (lowercased name, owner_id) -> hotlist_id cache.
+ self.features_service.hotlist_id_2lc.CacheItem(('fake-hotlist', 567), 123)
+ self.assertRaises(
+ features_svc.HotlistAlreadyExists, self.features_service.CreateHotlist,
+ self.cnxn, 'Fake-Hotlist', 'Misnamed Hotlist',
+ 'This name is already in use', [567], [678])
+
+ def testTransferHotlistOwnership(self):
+ """New owner 222 takes over hotlist 123; old owner 111 becomes an editor."""
+ hotlist_id = 123
+ new_owner_id = 222
+ hotlist = fake.Hotlist(hotlist_name='unique', hotlist_id=hotlist_id,
+ owner_ids=[111], editor_ids=[222, 333],
+ follower_ids=[444])
+ # LookupHotlistIDs: the proposed new owner owns hotlists 223 and 567, but
+ # no hotlist named 'unique' comes back, so there is no name clash.
+ self.features_service.hotlist2user_tbl.Select = mock.Mock(
+ return_value=[(223, new_owner_id), (567, new_owner_id)])
+ self.features_service.hotlist_tbl.Select = mock.Mock(return_value=[])
+
+ # UpdateHotlistRoles
+ self.features_service.GetHotlist = mock.Mock(return_value=hotlist)
+ self.features_service.hotlist2user_tbl.Delete = mock.Mock()
+ self.features_service.hotlist2user_tbl.InsertRows = mock.Mock()
+
+ # NOTE(review): the final True presumably asks for the old owner to stay
+ # on as an editor (see insert_rows below); confirm against the signature.
+ self.features_service.TransferHotlistOwnership(
+ self.cnxn, hotlist, new_owner_id, True)
+
+ self.features_service.hotlist2user_tbl.Delete.assert_called_once_with(
+ self.cnxn, hotlist_id=hotlist_id, commit=False)
+
+ self.features_service.GetHotlist.assert_called_once_with(
+ self.cnxn, hotlist_id, use_cache=False)
+ # All role rows are rewritten: 222 moves from editor to owner, 111 is
+ # demoted to editor, and 333 / 444 keep their roles.
+ insert_rows = [(hotlist_id, new_owner_id, 'owner'),
+ (hotlist_id, 333, 'editor'),
+ (hotlist_id, 111, 'editor'),
+ (hotlist_id, 444, 'follower')]
+ self.features_service.hotlist2user_tbl.InsertRows.assert_called_once_with(
+ self.cnxn, features_svc.HOTLIST2USER_COLS, insert_rows, commit=False)
+
+ def testLookupHotlistIDs(self):
+ """Names match case-insensitively, and cached IDs win over DB rows."""
+ # Set up DB query mocks: user 222 owns hotlist 123, user 333 owns 125, and
+ # user 444 owns nothing.
+ self.features_service.hotlist2user_tbl.Select = mock.Mock(return_value=[
+ (123, 222), (125, 333)])
+ self.features_service.hotlist_tbl.Select = mock.Mock(
+ return_value=[(123, 'q3-TODO'), (125, 'q4-TODO')])
+
+ # The cached entry for ('q4-todo', 333) takes precedence over hotlist 125.
+ self.features_service.hotlist_id_2lc.CacheItem(
+ ('q4-todo', 333), 124)
+
+ ret = self.features_service.LookupHotlistIDs(
+ self.cnxn, ['q3-todo', 'Q4-TODO'], [222, 333, 444])
+ self.assertEqual(ret, {('q3-todo', 222) : 123, ('q4-todo', 333): 124})
+ # NOTE(review): the expected user_id order mirrors the implementation's
+ # iteration order rather than the order passed in.
+ self.features_service.hotlist2user_tbl.Select.assert_called_once_with(
+ self.cnxn, cols=['hotlist_id', 'user_id'], user_id=[444, 333, 222],
+ role_name='owner')
+ self.features_service.hotlist_tbl.Select.assert_called_once_with(
+ self.cnxn, cols=['id', 'name'], id=[123, 125], is_deleted=False,
+ where=[
+ (('LOWER(name) IN (%s,%s)'), ['q3-todo', 'q4-todo'])])
+
+ def SetUpLookupUserHotlists(self):
+ """Records the query for user 111's hotlists, answering hotlist 123."""
+ # The join on Hotlist lets the query skip deleted hotlists.
+ self.features_service.hotlist2user_tbl.Select(
+ self.cnxn, cols=['user_id', 'hotlist_id'],
+ user_id=[111], left_joins=[('Hotlist ON hotlist_id = id', [])],
+ where=[('Hotlist.is_deleted = %s', [False])]).AndReturn([(111, 123)])
+
+ def testLookupUserHotlists(self):
+ """LookupUserHotlists maps each user ID to their non-deleted hotlist IDs."""
+ self.SetUpLookupUserHotlists()
+ self.mox.ReplayAll()
+ ret = self.features_service.LookupUserHotlists(
+ self.cnxn, [111])
+ self.assertEqual(ret, {111: [123]})
+ self.mox.VerifyAll()
+
+ def SetUpLookupIssueHotlists(self):
+ """Records the query for issue 987's hotlists, answering hotlist 123."""
+ # The join on Hotlist lets the query skip deleted hotlists.
+ self.features_service.hotlist2issue_tbl.Select(
+ self.cnxn, cols=['hotlist_id', 'issue_id'],
+ issue_id=[987], left_joins=[('Hotlist ON hotlist_id = id', [])],
+ where=[('Hotlist.is_deleted = %s', [False])]).AndReturn([(123, 987)])
+
+ def testLookupIssueHotlists(self):
+ """LookupIssueHotlists maps each issue ID to its non-deleted hotlist IDs."""
+ self.SetUpLookupIssueHotlists()
+ self.mox.ReplayAll()
+ ret = self.features_service.LookupIssueHotlists(
+ self.cnxn, [987])
+ self.assertEqual(ret, {987: [123]})
+ self.mox.VerifyAll()
+
+ def SetUpGetHotlists(
+ self, hotlist_id, hotlist_rows=None, issue_rows=None, role_rows=None):
+ """Records the three SELECTs made when fetching one hotlist from SQL.
+
+ Args:
+ hotlist_id: ID of the hotlist being fetched.
+ hotlist_rows: rows returned by hotlist_tbl; defaults to one row whose
+ name is 'hotlist2' (column order presumably follows
+ features_svc.HOTLIST_COLS).
+ issue_rows: rows returned by hotlist2issue_tbl; defaults to [].
+ role_rows: (hotlist_id, user_id, role_name) rows returned by
+ hotlist2user_tbl; defaults to [].
+
+ NOTE(review): defaults are chosen by truthiness, so hotlist_rows=[] still
+ yields the default row; a missing hotlist cannot be simulated this way.
+ """
+ if not hotlist_rows:
+ hotlist_rows = [(hotlist_id, 'hotlist2', 'test hotlist 2',
+ 'test hotlist', False, '')]
+ if not issue_rows:
+ issue_rows=[]
+ if not role_rows:
+ role_rows=[]
+ self.features_service.hotlist_tbl.Select(
+ self.cnxn, cols=features_svc.HOTLIST_COLS,
+ id=[hotlist_id], is_deleted=False).AndReturn(hotlist_rows)
+ self.features_service.hotlist2user_tbl.Select(
+ self.cnxn, cols=['hotlist_id', 'user_id', 'role_name'],
+ hotlist_id=[hotlist_id]).AndReturn(role_rows)
+ # Items come back highest rank first, ties broken by issue_id.
+ self.features_service.hotlist2issue_tbl.Select(
+ self.cnxn, cols=features_svc.HOTLIST2ISSUE_COLS,
+ hotlist_id=[hotlist_id],
+ order_by=[('rank DESC', []), ('issue_id', [])]).AndReturn(issue_rows)
+
+ def SetUpUpdateHotlist(self, hotlist_id):
+ """Mocks the table methods UpdateHotlist reads from and writes to.
+
+ The hotlist exists, is owned by user 111 and holds no issues. The write
+ methods are bare Mocks so that tests can assert on their calls.
+ """
+ svc = self.features_service
+ svc.hotlist_tbl.Select = mock.Mock(return_value=[
+ (hotlist_id, 'hotlist2', 'test hotlist 2', 'test hotlist', False, '')])
+ svc.hotlist2user_tbl.Select = mock.Mock(
+ return_value=[(hotlist_id, 111, 'owner')])
+ svc.hotlist2issue_tbl.Select = mock.Mock(return_value=[])
+
+ svc.hotlist_tbl.Update = mock.Mock()
+ svc.hotlist2user_tbl.Delete = mock.Mock()
+ svc.hotlist2user_tbl.InsertRows = mock.Mock()
+
+ def testUpdateHotlist(self):
+ """Field edits go into one UPDATE; owner/editor edits rewrite role rows."""
+ hotlist_id = 456
+ self.SetUpUpdateHotlist(hotlist_id)
+
+ self.features_service.UpdateHotlist(
+ self.cnxn,
+ hotlist_id,
+ summary='A better one-line summary',
+ owner_id=333,
+ add_editor_ids=[444, 555])
+ # Only the changed column appears in the UPDATE delta.
+ delta = {'summary': 'A better one-line summary'}
+ self.features_service.hotlist_tbl.Update.assert_called_once_with(
+ self.cnxn, delta, id=hotlist_id, commit=False)
+ # Changing the owner deletes only the existing owner row(s)...
+ self.features_service.hotlist2user_tbl.Delete.assert_called_once_with(
+ self.cnxn, hotlist_id=hotlist_id, role='owner', commit=False)
+ # ...and the new owner is inserted together with the added editors.
+ add_role_rows = [
+ (hotlist_id, 333, 'owner'), (hotlist_id, 444, 'editor'),
+ (hotlist_id, 555, 'editor')
+ ]
+ self.features_service.hotlist2user_tbl.InsertRows.assert_called_once_with(
+ self.cnxn, features_svc.HOTLIST2USER_COLS, add_role_rows, commit=False)
+
+ def testUpdateHotlist_NoRoleChanges(self):
+ """A rename alone updates the Hotlist row and leaves role rows untouched."""
+ hotlist_id = 456
+ self.SetUpUpdateHotlist(hotlist_id)
+ svc = self.features_service
+
+ svc.UpdateHotlist(self.cnxn, hotlist_id, name='chicken')
+ svc.hotlist_tbl.Update.assert_called_once_with(
+ self.cnxn, {'name': 'chicken'}, id=hotlist_id, commit=False)
+ svc.hotlist2user_tbl.Delete.assert_not_called()
+ svc.hotlist2user_tbl.InsertRows.assert_not_called()
+
+ def testUpdateHotlist_NoOwnerChange(self):
+ """Adding an editor inserts just that role row; owners are not deleted."""
+ hotlist_id = 456
+ self.SetUpUpdateHotlist(hotlist_id)
+ svc = self.features_service
+
+ svc.UpdateHotlist(
+ self.cnxn, hotlist_id, name='chicken', add_editor_ids=[333])
+ svc.hotlist_tbl.Update.assert_called_once_with(
+ self.cnxn, {'name': 'chicken'}, id=hotlist_id, commit=False)
+ svc.hotlist2user_tbl.Delete.assert_not_called()
+ svc.hotlist2user_tbl.InsertRows.assert_called_once_with(
+ self.cnxn, features_svc.HOTLIST2USER_COLS,
+ [(hotlist_id, 333, 'editor')], commit=False)
+
+ def SetUpRemoveHotlistEditors(self):
+ """Stubs GetHotlist and hotlist2user_tbl.Delete; returns hotlist 456.
+
+ The hotlist is owned by 111 and has editors 222, 333 and 444.
+ """
+ self.features_service.hotlist2user_tbl.Delete = mock.Mock()
+ hotlist = fake.Hotlist(
+ hotlist_name='hotlist', hotlist_id=456, owner_ids=[111],
+ editor_ids=[222, 333, 444])
+ self.features_service.GetHotlist = mock.Mock(return_value=hotlist)
+ return hotlist
+
+ def testRemoveHotlistEditors(self):
+ """Removed editors are deleted in SQL and dropped from the hotlist object."""
+ hotlist = self.SetUpRemoveHotlistEditors()
+ editors_to_remove = [222, 333]
+
+ self.features_service.RemoveHotlistEditors(
+ self.cnxn, hotlist.hotlist_id, remove_editor_ids=editors_to_remove)
+
+ self.assertEqual([444], hotlist.editor_ids)
+ self.features_service.hotlist2user_tbl.Delete.assert_called_once_with(
+ self.cnxn, hotlist_id=hotlist.hotlist_id, user_id=editors_to_remove)
+
+ def testRemoveHotlistEditors_NoOp(self):
+ """An empty removal raises InputException and triggers no SQL calls."""
+ hotlist = self.SetUpRemoveHotlistEditors()
+ with self.assertRaises(exceptions.InputException):
+ self.features_service.RemoveHotlistEditors(
+ self.cnxn, hotlist.hotlist_id, remove_editor_ids=[])
+ # Nothing may be deleted when there is nothing to remove.
+ self.features_service.hotlist2user_tbl.Delete.assert_not_called()
+
+ def SetUpUpdateHotlistItemsFields(self, hotlist_id, issue_ids):
+ """Records the calls that re-rank three items of hotlist `hotlist_id`.
+
+ Args:
+ hotlist_id: hotlist whose items are updated; used in every row.
+ issue_ids: issue IDs whose hotlist2issue rows are deleted and then
+ re-inserted with their new ranks.
+ """
+ hotlist_rows = [(hotlist_id, 'hotlist', '', '', True, '')]
+ # Rows are (hotlist_id, issue_id, rank, adder_id, date_added, note); the
+ # re-inserted rows carry the new ranks 112 / 332 / 552.
+ insert_rows = [(hotlist_id, 11, 112, 333, 2002, ''),
+ (hotlist_id, 33, 332, 333, 2002, ''),
+ (hotlist_id, 55, 552, 333, 2002, '')]
+ # Current ranks are 1 / 3 / 3.
+ issue_rows = [(hotlist_id, 11, 1, 333, 2002, ''),
+ (hotlist_id, 33, 3, 333, 2002, ''),
+ (hotlist_id, 55, 3, 333, 2002, '')]
+ self.SetUpGetHotlists(
+ hotlist_id, hotlist_rows=hotlist_rows, issue_rows=issue_rows)
+ self.features_service.hotlist2issue_tbl.Delete(
+ self.cnxn, hotlist_id=hotlist_id,
+ issue_id=issue_ids, commit=False)
+ self.features_service.hotlist2issue_tbl.InsertRows(
+ self.cnxn, cols=features_svc.HOTLIST2ISSUE_COLS,
+ row_values=insert_rows, commit=True)
+
+ def testUpdateHotlistItemsFields_Ranks(self):
+ """New ranks are persisted by deleting and re-inserting the item rows."""
+ # (issue_id, rank, adder_id, date_added, note) per item.
+ hotlist_item_fields = [
+ (11, 1, 333, 2002, ''), (33, 3, 333, 2002, ''),
+ (55, 3, 333, 2002, '')]
+ hotlist = fake.Hotlist(hotlist_name='hotlist', hotlist_id=345,
+ hotlist_item_fields=hotlist_item_fields)
+ self.features_service.hotlist_2lc.CacheItem(345, hotlist)
+ # issue_id -> new rank.
+ relations_to_change = {11: 112, 33: 332, 55: 552}
+ issue_ids = [11, 33, 55]
+ self.SetUpUpdateHotlistItemsFields(345, issue_ids)
+ self.mox.ReplayAll()
+ self.features_service.UpdateHotlistItemsFields(
+ self.cnxn, 345, new_ranks=relations_to_change)
+ self.mox.VerifyAll()
+
+ def testUpdateHotlistItemsFields_Notes(self):
+ """Placeholder for note updates via UpdateHotlistItemsFields."""
+ # Skip explicitly so the missing coverage shows up in test reports
+ # instead of counting as a pass.
+ self.skipTest('TODO: cover note updates in UpdateHotlistItemsFields.')
+
+ def testGetHotlists(self):
+ """Cached hotlists come from the 2LC; only the miss (456) hits SQL."""
+ hotlist1 = fake.Hotlist(hotlist_name='hotlist1', hotlist_id=123)
+ self.features_service.hotlist_2lc.CacheItem(123, hotlist1)
+ self.SetUpGetHotlists(456)
+ self.mox.ReplayAll()
+ hotlist_dict = self.features_service.GetHotlists(
+ self.cnxn, [123, 456])
+ self.mox.VerifyAll()
+ self.assertItemsEqual([123, 456], list(hotlist_dict.keys()))
+ self.assertEqual('hotlist1', hotlist_dict[123].name)
+ self.assertEqual('hotlist2', hotlist_dict[456].name)
+
+ def testGetHotlistsByID(self):
+ """GetHotlistsByID reports no misses when every ID resolves."""
+ hotlist1 = fake.Hotlist(hotlist_name='hotlist1', hotlist_id=123)
+ self.features_service.hotlist_2lc.CacheItem(123, hotlist1)
+ # NOTE: The setup function must take a hotlist_id that is different
+ # from what was used in previous tests, otherwise the methods in the
+ # setup function will never get called.
+ self.SetUpGetHotlists(456)
+ self.mox.ReplayAll()
+ _, actual_missed = self.features_service.GetHotlistsByID(
+ self.cnxn, [123, 456])
+ self.mox.VerifyAll()
+ self.assertEqual(actual_missed, [])
+
+ def testGetHotlistsByUserID(self):
+ """User 111's hotlist IDs are looked up, then hotlist 123 is fetched."""
+ self.SetUpLookupUserHotlists()
+ self.SetUpGetHotlists(123)
+ self.mox.ReplayAll()
+ hotlists = self.features_service.GetHotlistsByUserID(self.cnxn, 111)
+ self.assertEqual(len(hotlists), 1)
+ self.assertEqual(hotlists[0].hotlist_id, 123)
+ self.mox.VerifyAll()
+
+ def testGetHotlistsByIssueID(self):
+ """Issue 987's hotlist IDs are looked up, then hotlist 123 is fetched."""
+ self.SetUpLookupIssueHotlists()
+ self.SetUpGetHotlists(123)
+ self.mox.ReplayAll()
+ hotlists = self.features_service.GetHotlistsByIssueID(self.cnxn, 987)
+ self.assertEqual(len(hotlists), 1)
+ self.assertEqual(hotlists[0].hotlist_id, 123)
+ self.mox.VerifyAll()
+
+ def SetUpUpdateHotlistRoles(
+ self, hotlist_id, owner_ids, editor_ids, follower_ids):
+ """Records a full rewrite of a hotlist's role rows, followed by a commit.
+
+ Args:
+ hotlist_id: hotlist whose roles are replaced.
+ owner_ids: user IDs expected as 'owner' rows (inserted first).
+ editor_ids: user IDs expected as 'editor' rows.
+ follower_ids: user IDs expected as 'follower' rows (inserted last).
+ """
+
+ # Every existing role row of the hotlist is deleted...
+ self.features_service.hotlist2user_tbl.Delete(
+ self.cnxn, hotlist_id=hotlist_id, commit=False)
+
+ # ...then owners, editors and followers are inserted in that order.
+ insert_rows = [(hotlist_id, user_id, 'owner') for user_id in owner_ids]
+ insert_rows.extend(
+ [(hotlist_id, user_id, 'editor') for user_id in editor_ids])
+ insert_rows.extend(
+ [(hotlist_id, user_id, 'follower') for user_id in follower_ids])
+ self.features_service.hotlist2user_tbl.InsertRows(
+ self.cnxn, ['hotlist_id', 'user_id', 'role_name'],
+ insert_rows, commit=False)
+
+ self.cnxn.Commit()
+
+ def testUpdateHotlistRoles(self):
+ """Owners 111/222 and editor 333 replace hotlist 456's role rows."""
+ self.SetUpGetHotlists(456)
+ self.SetUpUpdateHotlistRoles(456, [111, 222], [333], [])
+ self.mox.ReplayAll()
+ self.features_service.UpdateHotlistRoles(
+ self.cnxn, 456, [111, 222], [333], [])
+ self.mox.VerifyAll()
+
+ def SetUpUpdateHotlistIssues(self, items):
+ """Stubs GetHotlist (hotlist 456 holding `items`) and the touched tables.
+
+ hotlist2issue_tbl writes and issue_service.GetIssues become bare Mocks.
+ Returns the fake hotlist so tests can inspect its items afterwards.
+ """
+ self.features_service.hotlist2issue_tbl.Delete = mock.Mock()
+ self.features_service.hotlist2issue_tbl.InsertRows = mock.Mock()
+ self.issue_service.GetIssues = mock.Mock()
+ hotlist = fake.Hotlist(hotlist_name='hotlist', hotlist_id=456)
+ hotlist.items = items
+ self.features_service.GetHotlist = mock.Mock(return_value=hotlist)
+ return hotlist
+
+ def testUpdateHotlistIssues_ChangeIssues(self):
+ """Updated and new items are rewritten; unchanged items are kept."""
+ original_items = [
+ features_pb2.Hotlist.HotlistItem(
+ issue_id=78902, rank=11, adder_id=333, date_added=2345), # update
+ features_pb2.Hotlist.HotlistItem(
+ issue_id=78904, rank=0, adder_id=333, date_added=2345) # same
+ ]
+ hotlist = self.SetUpUpdateHotlistIssues(original_items)
+ updated_items = [
+ features_pb2.Hotlist.HotlistItem(
+ issue_id=78902, rank=13, adder_id=333, date_added=2345), # update
+ features_pb2.Hotlist.HotlistItem(
+ issue_id=78903, rank=23, adder_id=333, date_added=2345) # new
+ ]
+
+ self.features_service.UpdateHotlistIssues(
+ self.cnxn, hotlist.hotlist_id, updated_items, [], self.issue_service,
+ self.chart_service)
+
+ # Both the updated (78902) and the new (78903) item get fresh rows...
+ insert_rows = [
+ (hotlist.hotlist_id, 78902, 13, 333, 2345, ''),
+ (hotlist.hotlist_id, 78903, 23, 333, 2345, '')
+ ]
+ self.features_service.hotlist2issue_tbl.InsertRows.assert_called_once_with(
+ self.cnxn,
+ cols=features_svc.HOTLIST2ISSUE_COLS,
+ row_values=insert_rows,
+ commit=False)
+ # ...and any old rows for those issues are deleted.
+ self.features_service.hotlist2issue_tbl.Delete.assert_called_once_with(
+ self.cnxn,
+ hotlist_id=hotlist.hotlist_id,
+ issue_id=[78902, 78903],
+ commit=False)
+
+ # New hotlist items include updated_items and unchanged items.
+ expected_all_items = [
+ features_pb2.Hotlist.HotlistItem(
+ issue_id=78904, rank=0, adder_id=333, date_added=2345),
+ features_pb2.Hotlist.HotlistItem(
+ issue_id=78902, rank=13, adder_id=333, date_added=2345),
+ features_pb2.Hotlist.HotlistItem(
+ issue_id=78903, rank=23, adder_id=333, date_added=2345)
+ ]
+ self.assertEqual(hotlist.items, expected_all_items)
+
+ # Assert we're storing the new snapshots of the affected issues.
+ self.issue_service.GetIssues.assert_called_once_with(
+ self.cnxn, [78902, 78903])
+
+ def testUpdateHotlistIssues_RemoveIssues(self):
+ """Removed items are deleted; the remaining items are kept."""
+ original_items = [
+ features_pb2.Hotlist.HotlistItem(
+ issue_id=78901, rank=10, adder_id=222, date_added=2348), # remove
+ features_pb2.Hotlist.HotlistItem(
+ issue_id=78904, rank=0, adder_id=333, date_added=2345), # same
+ ]
+ hotlist = self.SetUpUpdateHotlistIssues(original_items)
+ remove_issue_ids = [78901]
+
+ self.features_service.UpdateHotlistIssues(
+ self.cnxn, hotlist.hotlist_id, [], remove_issue_ids, self.issue_service,
+ self.chart_service)
+
+ self.features_service.hotlist2issue_tbl.Delete.assert_called_once_with(
+ self.cnxn,
+ hotlist_id=hotlist.hotlist_id,
+ issue_id=remove_issue_ids,
+ commit=False)
+
+ # With no updated items, only the unchanged item remains.
+ expected_all_items = [
+ features_pb2.Hotlist.HotlistItem(
+ issue_id=78904, rank=0, adder_id=333, date_added=2345)
+ ]
+ self.assertEqual(hotlist.items, expected_all_items)
+
+ # Assert we're storing the new snapshots of the affected issues.
+ self.issue_service.GetIssues.assert_called_once_with(self.cnxn, [78901])
+
+ def testUpdateHotlistIssues_RemoveAndChange(self):
+ """An issue both removed and updated ends up present with its new values."""
+ original_items = [
+ features_pb2.Hotlist.HotlistItem(
+ issue_id=78901, rank=10, adder_id=222, date_added=2348), # remove
+ features_pb2.Hotlist.HotlistItem(
+ issue_id=78902, rank=11, adder_id=333, date_added=2345), # update
+ features_pb2.Hotlist.HotlistItem(
+ issue_id=78904, rank=0, adder_id=333, date_added=2345) # same
+ ]
+ hotlist = self.SetUpUpdateHotlistIssues(original_items)
+ # test 78902 gets added back with `updated_items`
+ remove_issue_ids = [78901, 78902]
+ updated_items = [
+ features_pb2.Hotlist.HotlistItem(
+ issue_id=78902, rank=13, adder_id=333, date_added=2345),
+ ]
+
+ self.features_service.UpdateHotlistIssues(
+ self.cnxn, hotlist.hotlist_id, updated_items, remove_issue_ids,
+ self.issue_service, self.chart_service)
+
+ # Exactly two Deletes, in this order: one for the removals, then one for
+ # the item that is re-inserted.
+ delete_calls = [
+ mock.call(
+ self.cnxn,
+ hotlist_id=hotlist.hotlist_id,
+ issue_id=remove_issue_ids,
+ commit=False),
+ mock.call(
+ self.cnxn,
+ hotlist_id=hotlist.hotlist_id,
+ issue_id=[78902],
+ commit=False)
+ ]
+ self.assertEqual(
+ self.features_service.hotlist2issue_tbl.Delete.mock_calls, delete_calls)
+
+ insert_rows = [
+ (hotlist.hotlist_id, 78902, 13, 333, 2345, ''),
+ ]
+ self.features_service.hotlist2issue_tbl.InsertRows.assert_called_once_with(
+ self.cnxn,
+ cols=features_svc.HOTLIST2ISSUE_COLS,
+ row_values=insert_rows,
+ commit=False)
+
+ # New hotlist items include updated_items and unchanged items.
+ expected_all_items = [
+ features_pb2.Hotlist.HotlistItem(
+ issue_id=78904, rank=0, adder_id=333, date_added=2345),
+ features_pb2.Hotlist.HotlistItem(
+ issue_id=78902, rank=13, adder_id=333, date_added=2345),
+ ]
+ self.assertEqual(hotlist.items, expected_all_items)
+
+ # Assert we're storing the new snapshots of the affected issues.
+ self.issue_service.GetIssues.assert_called_once_with(
+ self.cnxn, [78901, 78902])
+
+ def testUpdateHotlistIssues_NoChanges(self):
+ """Nothing to update and nothing to remove is rejected as bad input."""
+ self.assertRaises(
+ exceptions.InputException, self.features_service.UpdateHotlistIssues,
+ self.cnxn, 456, [], None, self.issue_service, self.chart_service)
+
+ def SetUpUpdateHotlistItems(self, cnxn, hotlist_id, remove, added_tuples):
+ """Records the Delete and InsertRows calls for one hotlist item update.
+
+ Args:
+ cnxn: connection expected in the recorded calls.
+ hotlist_id: hotlist being updated.
+ remove: issue IDs whose hotlist2issue rows are deleted.
+ added_tuples: (issue_id, user_id, ts, note) tuples being added; they are
+ expected to get ranks 1, 11, 21, ... in list order.
+ """
+ self.features_service.hotlist2issue_tbl.Delete(
+ cnxn, hotlist_id=hotlist_id, issue_id=remove, commit=False)
+ rank = 1
+ added_tuples_with_rank = [(issue_id, rank+10*mult, user_id, ts, note) for
+ mult, (issue_id, user_id, ts, note) in
+ enumerate(added_tuples)]
+ # NOTE: `rank` below is each row's own rank, not the base rank above.
+ insert_rows = [(hotlist_id, issue_id,
+ rank, user_id, date, note) for
+ (issue_id, rank, user_id, date, note) in
+ added_tuples_with_rank]
+ self.features_service.hotlist2issue_tbl.InsertRows(
+ cnxn, cols=features_svc.HOTLIST2ISSUE_COLS,
+ row_values=insert_rows, commit=False)
+
+ def testAddIssuesToHotlists(self):
+ """Issues are added to each hotlist, then snapshotted once for all."""
+ added_tuples = [
+ (111, None, None, ''),
+ (222, None, None, ''),
+ (333, None, None, '')]
+ issues = [
+ tracker_pb2.Issue(issue_id=issue_id)
+ for issue_id, _, _, _ in added_tuples
+ ]
+ # One fetch-and-update sequence is expected per target hotlist.
+ self.SetUpGetHotlists(456)
+ self.SetUpUpdateHotlistItems(
+ self.cnxn, 456, [], added_tuples)
+ self.SetUpGetHotlists(567)
+ self.SetUpUpdateHotlistItems(
+ self.cnxn, 567, [], added_tuples)
+
+ # The added issues are fetched and snapshotted once, not per hotlist.
+ self.mox.StubOutWithMock(self.issue_service, 'GetIssues')
+ self.issue_service.GetIssues(self.cnxn,
+ [111, 222, 333]).AndReturn(issues)
+ self.chart_service.StoreIssueSnapshots(self.cnxn, issues,
+ commit=False)
+ self.mox.ReplayAll()
+ self.features_service.AddIssuesToHotlists(
+ self.cnxn, [456, 567], added_tuples, self.issue_service,
+ self.chart_service, commit=False)
+ self.mox.VerifyAll()
+
+ def testRemoveIssuesFromHotlists(self):
+ """Issue 555 is removed from both hotlists and snapshotted once."""
+ issue_rows = [
+ (456, 555, 1, None, None, ''),
+ (456, 666, 11, None, None, ''),
+ ]
+ issues = [tracker_pb2.Issue(issue_id=issue_rows[0][1])]
+ self.SetUpGetHotlists(456, issue_rows=issue_rows)
+ self.SetUpUpdateHotlistItems(
+ self.cnxn, 456, [555], [])
+ issue_rows = [
+ (789, 555, 1, None, None, ''),
+ (789, 666, 11, None, None, ''),
+ ]
+ self.SetUpGetHotlists(789, issue_rows=issue_rows)
+ self.SetUpUpdateHotlistItems(
+ self.cnxn, 789, [555], [])
+ # The removed issue is fetched and snapshotted once for both hotlists.
+ self.mox.StubOutWithMock(self.issue_service, 'GetIssues')
+ self.issue_service.GetIssues(self.cnxn,
+ [555]).AndReturn(issues)
+ self.chart_service.StoreIssueSnapshots(self.cnxn, issues, commit=False)
+ self.mox.ReplayAll()
+ self.features_service.RemoveIssuesFromHotlists(
+ self.cnxn, [456, 789], [555], self.issue_service, self.chart_service,
+ commit=False)
+ self.mox.VerifyAll()
+
+ def testUpdateHotlistItems(self):
+ """Adding three items to hotlist 456 inserts them with ranks 1, 11, 21."""
+ added_tuples = [
+ (111, None, None, ''),
+ (222, None, None, ''),
+ (333, None, None, '')]
+ self.SetUpGetHotlists(456)
+ self.SetUpUpdateHotlistItems(self.cnxn, 456, [], added_tuples)
+ self.mox.ReplayAll()
+ self.features_service.UpdateHotlistItems(
+ self.cnxn, 456, [], added_tuples, commit=False)
+ self.mox.VerifyAll()
+
+ def SetUpDeleteHotlist(self, cnxn, hotlist_id):
+ """Records the calls DeleteHotlist makes for `hotlist_id` (owner 111).
+
+ Args:
+ cnxn: connection expected in the recorded calls.
+ hotlist_id: hotlist being soft-deleted (flagged is_deleted).
+ """
+ hotlist_rows = [(hotlist_id, 'hotlist', 'test hotlist',
+ 'test list', False, '')]
+ self.SetUpGetHotlists(hotlist_id, hotlist_rows=hotlist_rows,
+ role_rows=[(hotlist_id, 111, 'owner')])
+ # The projects of the hotlist's issues are looked up before the hotlist
+ # row is flagged as deleted.
+ self.features_service.hotlist2issue_tbl.Select(cnxn,
+ cols=['Issue.project_id'], hotlist_id=hotlist_id, distinct=True,
+ left_joins=[('Issue ON issue_id = id', [])]).AndReturn([(1,)])
+ self.features_service.hotlist_tbl.Update(cnxn, {'is_deleted': True},
+ commit=False, id=hotlist_id)
+
+ def testDeleteHotlist(self):
+ """DeleteHotlist soft-deletes hotlist 678 via the recorded calls."""
+ self.SetUpDeleteHotlist(self.cnxn, 678)
+ self.mox.ReplayAll()
+ self.features_service.DeleteHotlist(self.cnxn, 678, commit=False)
+ self.mox.VerifyAll()
+
+ def testExpungeHotlists(self):
+ """ExpungeHotlists deletes all rows of the hotlists and clears caches."""
+ hotliststar_tbl = mock.Mock()
+ star_service = star_svc.AbstractStarService(
+ self.cache_manager, hotliststar_tbl, 'hotlist_id', 'user_id', 'hotlist')
+ hotliststar_tbl.Delete = mock.Mock()
+ user_service = user_svc.UserService(self.cache_manager)
+ user_service.hotlistvisithistory_tbl.Delete = mock.Mock()
+ chart_service = chart_svc.ChartService(self.config_service)
+ self.cnxn.Execute = mock.Mock()
+
+ hotlist1 = fake.Hotlist(hotlist_name='unique', hotlist_id=678,
+ owner_ids=[111], editor_ids=[222, 333])
+ hotlist2 = fake.Hotlist(hotlist_name='unique2', hotlist_id=679,
+ owner_ids=[111])
+ hotlists_by_id = {hotlist1.hotlist_id: hotlist1,
+ hotlist2.hotlist_id: hotlist2}
+ self.features_service.GetHotlists = mock.Mock(return_value=hotlists_by_id)
+ self.features_service.hotlist2user_tbl.Delete = mock.Mock()
+ self.features_service.hotlist2issue_tbl.Delete = mock.Mock()
+ self.features_service.hotlist_tbl.Delete = mock.Mock()
+ # cache invalidation mocks
+ self.features_service.hotlist_2lc.InvalidateKeys = mock.Mock()
+ self.features_service.hotlist_id_2lc.InvalidateKeys = mock.Mock()
+ self.features_service.hotlist_user_to_ids.InvalidateKeys = mock.Mock()
+ self.config_service.InvalidateMemcacheForEntireProject = mock.Mock()
+
+ hotlists_project_id = 787
+ self.features_service.GetProjectIDsFromHotlist = mock.Mock(
+ return_value=[hotlists_project_id])
+
+ # Materialize the keys: they are indexed below, and on Python 3 a
+ # dict_keys view is neither subscriptable nor equal to a list.
+ hotlist_ids = list(hotlists_by_id.keys())
+ commit = True  # commit in ExpungeHotlists should be True by default.
+ self.features_service.ExpungeHotlists(
+ self.cnxn, hotlist_ids, star_service, user_service, chart_service)
+
+ star_calls = [
+ mock.call(
+ self.cnxn, commit=commit, limit=None, hotlist_id=hotlist_ids[0]),
+ mock.call(
+ self.cnxn, commit=commit, limit=None, hotlist_id=hotlist_ids[1])]
+ hotliststar_tbl.Delete.assert_has_calls(star_calls)
+
+ self.cnxn.Execute.assert_called_once_with(
+ 'DELETE FROM IssueSnapshot2Hotlist WHERE hotlist_id IN (%s,%s)',
+ [678, 679], commit=commit)
+ user_service.hotlistvisithistory_tbl.Delete.assert_called_once_with(
+ self.cnxn, commit=commit, hotlist_id=hotlist_ids)
+
+ self.features_service.hotlist2user_tbl.Delete.assert_called_once_with(
+ self.cnxn, hotlist_id=hotlist_ids, commit=commit)
+ self.features_service.hotlist2issue_tbl.Delete.assert_called_once_with(
+ self.cnxn, hotlist_id=hotlist_ids, commit=commit)
+ self.features_service.hotlist_tbl.Delete.assert_called_once_with(
+ self.cnxn, id=hotlist_ids, commit=commit)
+ # cache invalidation checks
+ self.features_service.hotlist_2lc.InvalidateKeys.assert_called_once_with(
+ self.cnxn, hotlist_ids)
+ invalidate_owner_calls = [
+ mock.call(self.cnxn, [(hotlist1.name, hotlist1.owner_ids[0])]),
+ mock.call(self.cnxn, [(hotlist2.name, hotlist2.owner_ids[0])])]
+ self.features_service.hotlist_id_2lc.InvalidateKeys.assert_has_calls(
+ invalidate_owner_calls)
+ user_ids_cache = self.features_service.hotlist_user_to_ids
+ user_ids_cache.InvalidateKeys.assert_called_once_with(
+ self.cnxn, [333, 222, 111])
+ invalidate_project = self.config_service.InvalidateMemcacheForEntireProject
+ invalidate_project.assert_called_once_with(hotlists_project_id)
+
+ def testExpungeUsersInHotlists(self):
+ """Hotlists owned by expunged users are transferred or deleted.
+
+ hotlist1 is transferred to editor 333; hotlist2 is deleted because 333
+ already owns a hotlist with the same name.
+ """
+ hotliststar_tbl = mock.Mock()
+ star_service = star_svc.AbstractStarService(
+ self.cache_manager, hotliststar_tbl, 'hotlist_id', 'user_id', 'hotlist')
+ user_service = user_svc.UserService(self.cache_manager)
+ chart_service = chart_svc.ChartService(self.config_service)
+ user_ids = [111, 222]
+
+ # hotlist1 will get transferred to 333
+ hotlist1 = fake.Hotlist(hotlist_name='unique', hotlist_id=123,
+ owner_ids=[111], editor_ids=[222, 333])
+ # hotlist2 will get deleted
+ hotlist2 = fake.Hotlist(hotlist_name='name', hotlist_id=223,
+ owner_ids=[222], editor_ids=[111, 333])
+ delete_hotlists = [hotlist2.hotlist_id]
+ delete_hotlist_project_id = 788
+ self.features_service.GetProjectIDsFromHotlist = mock.Mock(
+ return_value=[delete_hotlist_project_id])
+ self.config_service.InvalidateMemcacheForEntireProject = mock.Mock()
+ hotlists_by_user_id = {
+ 111: [hotlist1.hotlist_id, hotlist2.hotlist_id],
+ 222: [hotlist1.hotlist_id, hotlist2.hotlist_id],
+ 333: [hotlist1.hotlist_id, hotlist2.hotlist_id]}
+ self.features_service.LookupUserHotlists = mock.Mock(
+ return_value=hotlists_by_user_id)
+ hotlists_by_id = {hotlist1.hotlist_id: hotlist1,
+ hotlist2.hotlist_id: hotlist2}
+ self.features_service.GetHotlistsByID = mock.Mock(
+ return_value=(hotlists_by_id, []))
+
+ # User 333 already has a hotlist named 'name'.
+ def side_effect(_cnxn, hotlist_names, owner_ids):
+ if 333 in owner_ids and 'name' in hotlist_names:
+ return {('name', 333): 567}
+ return {}
+ self.features_service.LookupHotlistIDs = mock.Mock(
+ side_effect=side_effect)
+ # Called to transfer hotlist ownership
+ self.features_service.UpdateHotlistRoles = mock.Mock()
+
+ # Called to expunge users and hotlists
+ self.features_service.hotlist2user_tbl.Delete = mock.Mock()
+ self.features_service.hotlist2issue_tbl.Update = mock.Mock()
+ user_service.hotlistvisithistory_tbl.Delete = mock.Mock()
+
+ # Called to expunge hotlists. A separate dict with the same hotlists, so
+ # the GetHotlists and GetHotlistsByID mocks do not share one object.
+ self.features_service.GetHotlists = mock.Mock(
+ return_value=dict(hotlists_by_id))
+ self.features_service.hotlist2issue_tbl.Delete = mock.Mock()
+ self.features_service.hotlist_tbl.Delete = mock.Mock()
+ hotliststar_tbl.Delete = mock.Mock()
+
+ self.features_service.ExpungeUsersInHotlists(
+ self.cnxn, user_ids, star_service, user_service, chart_service)
+
+ self.features_service.UpdateHotlistRoles.assert_called_once_with(
+ self.cnxn, hotlist1.hotlist_id, [333], [222], [], commit=False)
+
+ self.features_service.hotlist2user_tbl.Delete.assert_has_calls(
+ [mock.call(self.cnxn, user_id=user_ids, commit=False),
+ mock.call(self.cnxn, hotlist_id=delete_hotlists, commit=False)])
+ self.features_service.hotlist2issue_tbl.Update.assert_called_once_with(
+ self.cnxn, {'adder_id': framework_constants.DELETED_USER_ID},
+ adder_id=user_ids, commit=False)
+ user_service.hotlistvisithistory_tbl.Delete.assert_has_calls(
+ [mock.call(self.cnxn, user_id=user_ids, commit=False),
+ mock.call(self.cnxn, hotlist_id=delete_hotlists, commit=False)])
+
+ self.features_service.hotlist2issue_tbl.Delete.assert_called_once_with(
+ self.cnxn, hotlist_id=delete_hotlists, commit=False)
+ hotliststar_tbl.Delete.assert_called_once_with(
+ self.cnxn, commit=False, limit=None, hotlist_id=delete_hotlists[0])
+ self.features_service.hotlist_tbl.Delete.assert_called_once_with(
+ self.cnxn, id=delete_hotlists, commit=False)
+
+
+ def testGetProjectIDsFromHotlist(self):
+ """Returns the distinct project IDs of the hotlist's issues, in row order."""
+ # Project IDs come from joining the hotlist's items against Issue.
+ self.features_service.hotlist2issue_tbl.Select(self.cnxn,
+ cols=['Issue.project_id'], hotlist_id=678, distinct=True,
+ left_joins=[('Issue ON issue_id = id', [])]).AndReturn(
+ [(789,), (787,), (788,)])
+
+ self.mox.ReplayAll()
+ project_ids = self.features_service.GetProjectIDsFromHotlist(self.cnxn, 678)
+ self.mox.VerifyAll()
+ self.assertEqual([789, 787, 788], project_ids)
diff --git a/services/test/fulltext_helpers_test.py b/services/test/fulltext_helpers_test.py
new file mode 100644
index 0000000..1e4f0c9
--- /dev/null
+++ b/services/test/fulltext_helpers_test.py
@@ -0,0 +1,247 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for the fulltext_helpers module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+import mox
+
+from google.appengine.api import search
+
+from proto import ast_pb2
+from proto import tracker_pb2
+from search import query2ast
+from services import fulltext_helpers
+
+
+TEXT_HAS, NOT_TEXT_HAS, GE = (  # QueryOp shorthands used by the tests below.
+    ast_pb2.QueryOp.TEXT_HAS, ast_pb2.QueryOp.NOT_TEXT_HAS,
+    ast_pb2.QueryOp.GE)
+
+
+class MockResult(object):
+  """Fake search result that only carries a doc_id attribute."""
+  def __init__(self, doc_id):
+    self.doc_id = doc_id
+
+
+class MockSearchResponse(object):
+  """Fake search response: one page of results plus an optional cursor."""
+
+  def __init__(self, results, cursor):
+    """Wrap raw document IDs as MockResult objects.
+
+    Args:
+      results: list of document ID strings for this page.
+      cursor: search.Cursor when another round-trip should fetch more
+          results, or None when this is the final page.
+    """
+    self.results = list(map(MockResult, results))
+    self.cursor = cursor
+
+  def __iter__(self):
+    """Iterate over the MockResult objects in this page."""
+    return iter(self.results)
+
+
+class FulltextHelpersTest(unittest.TestCase):
+  """Tests for fulltext_helpers FTS query building and ComprehensiveSearch."""
+  def setUp(self):
+    self.mox = mox.Mox()
+    self.any_field_fd = tracker_pb2.FieldDef(
+        field_name='any_field', field_type=tracker_pb2.FieldTypes.STR_TYPE)
+    self.summary_fd = tracker_pb2.FieldDef(
+        field_name='summary', field_type=tracker_pb2.FieldTypes.STR_TYPE)
+    self.milestone_fd = tracker_pb2.FieldDef(
+        field_name='milestone', field_type=tracker_pb2.FieldTypes.STR_TYPE,
+        field_id=123)
+    self.fulltext_fields = ['summary']
+    # Stub search.Index; SetUpComprehensiveSearch has it return mock_index.
+    self.mock_index = self.mox.CreateMockAnything()
+    self.mox.StubOutWithMock(search, 'Index')
+    self.query = None
+
+  def tearDown(self):
+    self.mox.UnsetStubs()
+    self.mox.ResetAll()
+
+  def RecordQuery(self, query):
+    self.query = query
+
+  def testBuildFTSQuery_EmptyQueryConjunction(self):
+    query_ast_conj = ast_pb2.Conjunction()
+    fulltext_query = fulltext_helpers.BuildFTSQuery(
+        query_ast_conj, self.fulltext_fields)
+    self.assertEqual(None, fulltext_query)
+
+  def testBuildFTSQuery_NoFullTextConditions(self):
+    estimated_hours_fd = tracker_pb2.FieldDef(
+        field_name='estimate', field_type=tracker_pb2.FieldTypes.INT_TYPE,
+        field_id=124)
+    query_ast_conj = ast_pb2.Conjunction(conds=[
+        ast_pb2.MakeCond(TEXT_HAS, [estimated_hours_fd], [], [40])])
+    fulltext_query = fulltext_helpers.BuildFTSQuery(
+        query_ast_conj, self.fulltext_fields)
+    self.assertEqual(None, fulltext_query)
+
+  def testBuildFTSQuery_Normal(self):
+    query_ast_conj = ast_pb2.Conjunction(conds=[
+        ast_pb2.MakeCond(TEXT_HAS, [self.summary_fd], ['needle'], []),
+        ast_pb2.MakeCond(TEXT_HAS, [self.milestone_fd], ['Q3', 'Q4'], [])])
+    fulltext_query = fulltext_helpers.BuildFTSQuery(
+        query_ast_conj, self.fulltext_fields)
+    self.assertEqual(
+        '(summary:"needle") (custom_123:"Q3" OR custom_123:"Q4")',
+        fulltext_query)
+
+  def testBuildFTSQuery_WithQuotes(self):
+    query_ast_conj = ast_pb2.Conjunction(conds=[
+        ast_pb2.MakeCond(TEXT_HAS, [self.summary_fd], ['"needle haystack"'],
+                         [])])
+    fulltext_query = fulltext_helpers.BuildFTSQuery(
+        query_ast_conj, self.fulltext_fields)
+    self.assertEqual('(summary:"needle haystack")', fulltext_query)
+
+  def testBuildFTSQuery_IgnoreColonInText(self):
+    query_ast_conj = ast_pb2.Conjunction(conds=[
+        ast_pb2.MakeCond(TEXT_HAS, [self.summary_fd], ['"needle:haystack"'],
+                         [])])
+    fulltext_query = fulltext_helpers.BuildFTSQuery(
+        query_ast_conj, self.fulltext_fields)
+    self.assertEqual('(summary:"needle haystack")', fulltext_query)
+
+  def testBuildFTSQuery_InvalidQuery(self):
+    query_ast_conj = ast_pb2.Conjunction(conds=[
+        ast_pb2.MakeCond(TEXT_HAS, [self.summary_fd], ['haystack"needle'], []),
+        ast_pb2.MakeCond(TEXT_HAS, [self.milestone_fd], ['Q3', 'Q4'], [])])
+    with self.assertRaises(AssertionError):
+      fulltext_helpers.BuildFTSQuery(
+          query_ast_conj, self.fulltext_fields)
+
+  def testBuildFTSQuery_SpecialPrefixQuery(self):
+    special_prefix = query2ast.NON_OP_PREFIXES[0]
+
+    # Test with summary field.
+    query_ast_conj = ast_pb2.Conjunction(conds=[
+        ast_pb2.MakeCond(TEXT_HAS, [self.summary_fd],
+                         ['%s//google.com' % special_prefix], []),
+        ast_pb2.MakeCond(TEXT_HAS, [self.milestone_fd], ['Q3', 'Q4'], [])])
+    fulltext_query = fulltext_helpers.BuildFTSQuery(
+        query_ast_conj, self.fulltext_fields)
+    self.assertEqual(
+        '(summary:"%s//google.com") (custom_123:"Q3" OR custom_123:"Q4")' % (
+            special_prefix),
+        fulltext_query)
+
+    # Test with any field.
+    any_fd = tracker_pb2.FieldDef(
+        field_name=ast_pb2.ANY_FIELD,
+        field_type=tracker_pb2.FieldTypes.STR_TYPE)
+    query_ast_conj = ast_pb2.Conjunction(conds=[
+        ast_pb2.MakeCond(
+            TEXT_HAS, [any_fd], ['%s//google.com' % special_prefix], []),
+        ast_pb2.MakeCond(TEXT_HAS, [self.milestone_fd], ['Q3', 'Q4'], [])])
+    fulltext_query = fulltext_helpers.BuildFTSQuery(
+        query_ast_conj, self.fulltext_fields)
+    self.assertEqual(
+        '("%s//google.com") (custom_123:"Q3" OR custom_123:"Q4")' % (
+            special_prefix),
+        fulltext_query)
+
+  def testBuildFTSCondition_IgnoredOperator(self):
+    query_cond = ast_pb2.MakeCond(
+        GE, [self.summary_fd], ['needle'], [])
+    fulltext_query_clause = fulltext_helpers._BuildFTSCondition(
+        query_cond, self.fulltext_fields)
+    self.assertEqual('', fulltext_query_clause)
+
+  def testBuildFTSCondition_BuiltinField(self):
+    query_cond = ast_pb2.MakeCond(
+        TEXT_HAS, [self.summary_fd], ['needle'], [])
+    fulltext_query_clause = fulltext_helpers._BuildFTSCondition(
+        query_cond, self.fulltext_fields)
+    self.assertEqual('(summary:"needle")', fulltext_query_clause)
+
+  def testBuildFTSCondition_NonStringField(self):
+    est_days_fd = tracker_pb2.FieldDef(
+      field_name='EstDays', field_id=123,
+      field_type=tracker_pb2.FieldTypes.INT_TYPE)
+    query_cond = ast_pb2.MakeCond(
+        TEXT_HAS, [est_days_fd], ['needle'], [])
+    fulltext_query_clause = fulltext_helpers._BuildFTSCondition(
+        query_cond, self.fulltext_fields)
+    # Ignore in FTS, this search condition is done in SQL.
+    self.assertEqual('', fulltext_query_clause)
+
+  def testBuildFTSCondition_Negation(self):
+    query_cond = ast_pb2.MakeCond(
+        NOT_TEXT_HAS, [self.summary_fd], ['needle'], [])
+    fulltext_query_clause = fulltext_helpers._BuildFTSCondition(
+        query_cond, self.fulltext_fields)
+    self.assertEqual('NOT (summary:"needle")', fulltext_query_clause)
+
+  def testBuildFTSCondition_QuickOR(self):
+    query_cond = ast_pb2.MakeCond(
+        TEXT_HAS, [self.summary_fd], ['needle', 'pin'], [])
+    fulltext_query_clause = fulltext_helpers._BuildFTSCondition(
+        query_cond, self.fulltext_fields)
+    self.assertEqual(
+        '(summary:"needle" OR summary:"pin")',
+        fulltext_query_clause)
+
+  def testBuildFTSCondition_NegatedQuickOR(self):
+    query_cond = ast_pb2.MakeCond(
+        NOT_TEXT_HAS, [self.summary_fd], ['needle', 'pin'], [])
+    fulltext_query_clause = fulltext_helpers._BuildFTSCondition(
+        query_cond, self.fulltext_fields)
+    self.assertEqual(
+        'NOT (summary:"needle" OR summary:"pin")',
+        fulltext_query_clause)
+
+  def testBuildFTSCondition_AnyField(self):
+    query_cond = ast_pb2.MakeCond(
+        TEXT_HAS, [self.any_field_fd], ['needle'], [])
+    fulltext_query_clause = fulltext_helpers._BuildFTSCondition(
+        query_cond, self.fulltext_fields)
+    self.assertEqual('("needle")', fulltext_query_clause)
+
+  def testBuildFTSCondition_NegatedAnyField(self):
+    query_cond = ast_pb2.MakeCond(
+        NOT_TEXT_HAS, [self.any_field_fd], ['needle'], [])
+    fulltext_query_clause = fulltext_helpers._BuildFTSCondition(
+        query_cond, self.fulltext_fields)
+    self.assertEqual('NOT ("needle")', fulltext_query_clause)
+
+  def testBuildFTSCondition_CrossProjectWithMultipleFieldDescriptors(self):
+    other_milestone_fd = tracker_pb2.FieldDef(
+        field_name='milestone', field_type=tracker_pb2.FieldTypes.STR_TYPE,
+        field_id=456)
+    query_cond = ast_pb2.MakeCond(
+        TEXT_HAS, [self.milestone_fd, other_milestone_fd], ['needle'], [])
+    fulltext_query_clause = fulltext_helpers._BuildFTSCondition(
+        query_cond, self.fulltext_fields)
+    self.assertEqual(
+        '(custom_123:"needle" OR custom_456:"needle")', fulltext_query_clause)
+
+  def SetUpComprehensiveSearch(self):
+    search.Index(name='search index name').AndReturn(
+        self.mock_index)
+    self.mock_index.search(mox.IgnoreArg()).WithSideEffects(
+        self.RecordQuery).AndReturn(
+            MockSearchResponse(['123', '234'], search.Cursor()))
+    self.mock_index.search(mox.IgnoreArg()).WithSideEffects(
+        self.RecordQuery).AndReturn(MockSearchResponse(['345'], None))
+
+  def testComprehensiveSearch(self):
+    self.SetUpComprehensiveSearch()
+    self.mox.ReplayAll()
+    result_ids = fulltext_helpers.ComprehensiveSearch(
+        'browser', 'search index name')
+    self.mox.VerifyAll()
+    self.assertItemsEqual([123, 234, 345], result_ids)  # Both pages, as ints.
diff --git a/services/test/issue_svc_test.py b/services/test/issue_svc_test.py
new file mode 100644
index 0000000..b6fe682
--- /dev/null
+++ b/services/test/issue_svc_test.py
@@ -0,0 +1,2754 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for issue_svc module."""
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+
+import logging
+import time
+import unittest
+from mock import patch, Mock, ANY
+
+import mox
+
+from google.appengine.api import search
+from google.appengine.ext import testbed
+
+import settings
+from framework import exceptions
+from framework import framework_constants
+from framework import sql
+from proto import tracker_pb2
+from services import caches
+from services import chart_svc
+from services import issue_svc
+from services import service_manager
+from services import spam_svc
+from services import tracker_fulltext
+from testing import fake
+from testing import testing_helpers
+from tracker import tracker_bizobj
+
+
+class MockIndex(object):
+  """Stand-in search index whose delete() does nothing."""
+  def delete(self, string_list):
+    """Accept and ignore the list of document IDs to delete."""
+
+
+def MakeIssueService(project_service, config_service, cache_manager,
+                     chart_service, my_mox):
+  """Return an IssueService whose listed *_tbl managers are mox mocks."""
+  service = issue_svc.IssueService(
+      project_service, config_service, cache_manager, chart_service)
+  for attr_name in (
+      'issue_tbl', 'issuesummary_tbl', 'issue2label_tbl',
+      'issue2component_tbl', 'issue2cc_tbl', 'issue2notify_tbl',
+      'issue2fieldvalue_tbl', 'issuerelation_tbl', 'danglingrelation_tbl',
+      'issueformerlocations_tbl', 'comment_tbl', 'commentcontent_tbl',
+      'issueupdate_tbl', 'attachment_tbl', 'reindexqueue_tbl',
+      'localidcounter_tbl', 'issuephasedef_tbl', 'issue2approvalvalue_tbl',
+      'issueapproval2approver_tbl', 'issueapproval2comment_tbl',
+      'commentimporter_tbl'):
+    setattr(service, attr_name, my_mox.CreateMock(sql.SQLTableManager))
+
+  return service
+
+
+class TestableIssueTwoLevelCache(issue_svc.IssueTwoLevelCache):
+  """IssueTwoLevelCache whose FetchItems reads from an in-memory dict."""
+  def __init__(self, issue_list):
+    fake_cache_manager = fake.CacheManager()
+    super(TestableIssueTwoLevelCache, self).__init__(
+        fake_cache_manager, None, None, None)
+    self.cache = caches.RamCache(fake_cache_manager, 'issue')
+    self.memcache_prefix = 'issue:'
+    self.pb_class = tracker_pb2.Issue
+    self.issue_dict = dict(
+        (issue.issue_id, issue) for issue in issue_list)
+
+  def FetchItems(self, cnxn, issue_ids, shard_id=None):
+    """Return {issue_id: issue} for each requested ID present in issue_dict."""
+    found = {}
+    for wanted_id in issue_ids:
+      if wanted_id in self.issue_dict:
+        found[wanted_id] = self.issue_dict[wanted_id]
+    return found
+
+
+class IssueIDTwoLevelCacheTest(unittest.TestCase):
+  """Tests for the (project_id, local_id) -> issue_id cache, issue_id_2lc."""
+  def setUp(self):
+    self.mox = mox.Mox()
+    self.cnxn = 'fake connection'
+    self.project_service = fake.ProjectService()
+    self.config_service = fake.ConfigService()
+    self.cache_manager = fake.CacheManager()
+    self.chart_service = chart_svc.ChartService(self.config_service)
+    self.issue_service = MakeIssueService(
+        self.project_service, self.config_service, self.cache_manager,
+        self.chart_service, self.mox)
+    self.issue_id_2lc = self.issue_service.issue_id_2lc
+    self.spam_service = fake.SpamService()  # NOTE(review): unused here.
+
+  def tearDown(self):
+    self.mox.UnsetStubs()
+    self.mox.ResetAll()
+
+  def testDeserializeIssueIDs_Empty(self):
+    issue_id_dict = self.issue_id_2lc._DeserializeIssueIDs([])
+    self.assertEqual({}, issue_id_dict)
+
+  def testDeserializeIssueIDs_Normal(self):
+    rows = [(789, 1, 78901), (789, 2, 78902), (789, 3, 78903)]
+    issue_id_dict = self.issue_id_2lc._DeserializeIssueIDs(rows)
+    expected = {
+        (789, 1): 78901,
+        (789, 2): 78902,
+        (789, 3): 78903,
+        }
+    self.assertEqual(expected, issue_id_dict)
+
+  def SetUpFetchItems(self):
+    where = [
+        ('(Issue.project_id = %s AND Issue.local_id IN (%s,%s,%s))',
+         [789, 1, 2, 3])]
+    rows = [(789, 1, 78901), (789, 2, 78902), (789, 3, 78903)]
+    self.issue_service.issue_tbl.Select(
+        self.cnxn, cols=['project_id', 'local_id', 'id'],
+        where=where, or_where_conds=True).AndReturn(rows)
+
+  def testFetchItems(self):
+    project_local_ids_list = [(789, 1), (789, 2), (789, 3)]
+    issue_ids = [78901, 78902, 78903]
+    self.SetUpFetchItems()
+    self.mox.ReplayAll()
+    issue_dict = self.issue_id_2lc.FetchItems(
+        self.cnxn, project_local_ids_list)
+    self.mox.VerifyAll()
+    self.assertItemsEqual(project_local_ids_list, list(issue_dict.keys()))
+    self.assertItemsEqual(issue_ids, list(issue_dict.values()))
+
+  def testKeyToStr(self):
+    self.assertEqual('789,1', self.issue_id_2lc._KeyToStr((789, 1)))
+
+  def testStrToKey(self):
+    self.assertEqual((789, 1), self.issue_id_2lc._StrToKey('789,1'))
+
+
+class IssueTwoLevelCacheTest(unittest.TestCase):
+  """Tests for issue_2lc: row unpacking, _DeserializeIssues and FetchItems."""
+  def setUp(self):
+    self.mox = mox.Mox()
+    self.cnxn = 'fake connection'
+    self.project_service = fake.ProjectService()
+    self.config_service = fake.ConfigService()
+    self.cache_manager = fake.CacheManager()
+    self.chart_service = chart_svc.ChartService(self.config_service)
+    self.issue_service = MakeIssueService(
+        self.project_service, self.config_service, self.cache_manager,
+        self.chart_service, self.mox)
+    self.issue_2lc = self.issue_service.issue_2lc
+    # Canned DB rows shared by the deserialization and FetchItems tests.
+    now = int(time.time())
+    self.project_service.TestAddProject('proj', project_id=789)
+    self.issue_rows = [
+        (78901, 789, 1, 1, 111, 222,
+         now, now, now, now, now, now,
+         0, 0, 0, 1, 0, False)]
+    self.summary_rows = [(78901, 'sum')]
+    self.label_rows = [(78901, 1, 0)]
+    self.component_rows = []
+    self.cc_rows = [(78901, 333, 0)]
+    self.notify_rows = []
+    self.fieldvalue_rows = []
+    self.blocked_on_rows = (
+        (78901, 78902, 'blockedon', 20), (78903, 78901, 'blockedon', 10))
+    self.blocking_rows = ()
+    self.merged_rows = ()
+    self.relation_rows = (
+        self.blocked_on_rows + self.blocking_rows + self.merged_rows)
+    self.dangling_relation_rows = [
+        (78901, 'codesite', 5001, None, 'blocking'),
+        (78901, 'codesite', 5002, None, 'blockedon'),
+        (78901, None, None, 'b/1234567', 'blockedon')]
+    self.phase_rows = [(1, 'Canary', 1), (2, 'Stable', 11)]
+    self.approvalvalue_rows = [(22, 78901, 2, 'not_set', None, None),
+                               (21, 78901, 1, 'needs_review', None, None),
+                               (23, 78901, 1, 'not_set', None, None)]
+    self.av_approver_rows = [
+        (21, 111, 78901), (21, 222, 78901), (21, 333, 78901)]
+
+  def tearDown(self):
+    self.mox.UnsetStubs()
+    self.mox.ResetAll()
+
+  def testUnpackApprovalValue(self):
+    row = next(
+        row for row in self.approvalvalue_rows if row[3] == 'needs_review')
+    av, issue_id = self.issue_2lc._UnpackApprovalValue(row)
+    self.assertEqual(av.status, tracker_pb2.ApprovalStatus.NEEDS_REVIEW)
+    self.assertIsNone(av.setter_id)
+    self.assertIsNone(av.set_on)
+    self.assertEqual(issue_id, 78901)
+    self.assertEqual(av.phase_id, 1)
+
+  def testUnpackApprovalValue_MissingStatus(self):
+    av, _issue_id = self.issue_2lc._UnpackApprovalValue(
+        (21, 78901, 1, '', None, None))
+    self.assertEqual(av.status, tracker_pb2.ApprovalStatus.NOT_SET)
+
+  def testUnpackPhase(self):
+    phase = self.issue_2lc._UnpackPhase(
+        self.phase_rows[0])
+    self.assertEqual(phase.name, 'Canary')
+    self.assertEqual(phase.phase_id, 1)
+    self.assertEqual(phase.rank, 1)
+
+  def testDeserializeIssues_Empty(self):
+    issue_dict = self.issue_2lc._DeserializeIssues(
+        self.cnxn, [], [], [], [], [], [], [], [], [], [], [], [])
+    self.assertEqual({}, issue_dict)
+
+  def testDeserializeIssues_Normal(self):
+    issue_dict = self.issue_2lc._DeserializeIssues(
+        self.cnxn, self.issue_rows, self.summary_rows, self.label_rows,
+        self.component_rows, self.cc_rows, self.notify_rows,
+        self.fieldvalue_rows, self.relation_rows, self.dangling_relation_rows,
+        self.phase_rows, self.approvalvalue_rows, self.av_approver_rows)
+    self.assertItemsEqual([78901], list(issue_dict.keys()))
+    issue = issue_dict[78901]
+    self.assertEqual(len(issue.phases), 2)
+    self.assertIsNotNone(tracker_bizobj.FindPhaseByID(1, issue.phases))
+    av_21 = tracker_bizobj.FindApprovalValueByID(
+        21, issue.approval_values)
+    self.assertEqual(av_21.phase_id, 1)
+    self.assertItemsEqual(av_21.approver_ids, [111, 222, 333])
+    self.assertIsNotNone(tracker_bizobj.FindPhaseByID(2, issue.phases))
+    self.assertEqual(issue.phases,
+                     [tracker_pb2.Phase(rank=1, phase_id=1, name='Canary'),
+                      tracker_pb2.Phase(rank=11, phase_id=2, name='Stable')])
+    av_22 = tracker_bizobj.FindApprovalValueByID(
+        22, issue.approval_values)
+    self.assertEqual(av_22.phase_id, 2)
+    self.assertEqual([
+        tracker_pb2.DanglingIssueRef(
+          project=row[1],
+          issue_id=row[2],
+          ext_issue_identifier=row[3])
+          for row in self.dangling_relation_rows
+          if row[4] == 'blockedon'
+        ], issue.dangling_blocked_on_refs)
+    self.assertEqual([
+        tracker_pb2.DanglingIssueRef(
+          project=row[1],
+          issue_id=row[2],
+          ext_issue_identifier=row[3])
+          for row in self.dangling_relation_rows
+          if row[4] == 'blocking'
+        ], issue.dangling_blocking_refs)
+
+  def testDeserializeIssues_UnexpectedLabel(self):
+    unexpected_label_rows = [
+      (78901, 999, 0)
+      ]
+    self.assertRaises(
+      AssertionError,
+      self.issue_2lc._DeserializeIssues,
+      self.cnxn, self.issue_rows, self.summary_rows, unexpected_label_rows,
+      self.component_rows, self.cc_rows, self.notify_rows,
+      self.fieldvalue_rows, self.relation_rows, self.dangling_relation_rows,
+      self.phase_rows, self.approvalvalue_rows, self.av_approver_rows)
+
+  def testDeserializeIssues_UnexpectedIssueRelation(self):
+    unexpected_relation_rows = [
+      (78990, 78999, 'blockedon', None)
+      ]
+    self.assertRaises(
+      AssertionError,
+      self.issue_2lc._DeserializeIssues,
+      self.cnxn, self.issue_rows, self.summary_rows, self.label_rows,
+      self.component_rows, self.cc_rows, self.notify_rows,
+      self.fieldvalue_rows, unexpected_relation_rows,
+      self.dangling_relation_rows, self.phase_rows, self.approvalvalue_rows,
+      self.av_approver_rows)
+
+  def testDeserializeIssues_ExternalMergedInto(self):
+    """_DeserializeIssues handles external mergedinto refs correctly."""
+    dangling_relation_rows = self.dangling_relation_rows + [
+        (78901, None, None, 'b/1234567', 'mergedinto')]
+    issue_dict = self.issue_2lc._DeserializeIssues(
+        self.cnxn, self.issue_rows, self.summary_rows, self.label_rows,
+        self.component_rows, self.cc_rows, self.notify_rows,
+        self.fieldvalue_rows, self.relation_rows, dangling_relation_rows,
+        self.phase_rows, self.approvalvalue_rows, self.av_approver_rows)
+    self.assertEqual('b/1234567', issue_dict[78901].merged_into_external)
+
+  def SetUpFetchItems(self, issue_ids, has_approvalvalues=True):
+    shard_id = None
+    self.issue_service.issue_tbl.Select(
+        self.cnxn, cols=issue_svc.ISSUE_COLS, id=issue_ids,
+        shard_id=shard_id).AndReturn(self.issue_rows)
+    self.issue_service.issuesummary_tbl.Select(
+        self.cnxn, cols=issue_svc.ISSUESUMMARY_COLS, shard_id=shard_id,
+        issue_id=issue_ids).AndReturn(self.summary_rows)
+    self.issue_service.issue2label_tbl.Select(
+        self.cnxn, cols=issue_svc.ISSUE2LABEL_COLS, shard_id=shard_id,
+        issue_id=issue_ids).AndReturn(self.label_rows)
+    self.issue_service.issue2component_tbl.Select(
+        self.cnxn, cols=issue_svc.ISSUE2COMPONENT_COLS, shard_id=shard_id,
+        issue_id=issue_ids).AndReturn(self.component_rows)
+    self.issue_service.issue2cc_tbl.Select(
+        self.cnxn, cols=issue_svc.ISSUE2CC_COLS, shard_id=shard_id,
+        issue_id=issue_ids).AndReturn(self.cc_rows)
+    self.issue_service.issue2notify_tbl.Select(
+        self.cnxn, cols=issue_svc.ISSUE2NOTIFY_COLS, shard_id=shard_id,
+        issue_id=issue_ids).AndReturn(self.notify_rows)
+    self.issue_service.issue2fieldvalue_tbl.Select(
+        self.cnxn, cols=issue_svc.ISSUE2FIELDVALUE_COLS, shard_id=shard_id,
+        issue_id=issue_ids).AndReturn(self.fieldvalue_rows)
+    if has_approvalvalues:  # Phase defs are expected only with approvals.
+      self.issue_service.issuephasedef_tbl.Select(
+          self.cnxn, cols=issue_svc.ISSUEPHASEDEF_COLS,
+          id=[1, 2]).AndReturn(self.phase_rows)
+      self.issue_service.issue2approvalvalue_tbl.Select(
+          self.cnxn,
+          cols=issue_svc.ISSUE2APPROVALVALUE_COLS,
+          issue_id=issue_ids).AndReturn(self.approvalvalue_rows)
+    else:
+      self.issue_service.issue2approvalvalue_tbl.Select(
+          self.cnxn,
+          cols=issue_svc.ISSUE2APPROVALVALUE_COLS,
+          issue_id=issue_ids).AndReturn([])
+    self.issue_service.issueapproval2approver_tbl.Select(
+        self.cnxn, cols=issue_svc.ISSUEAPPROVAL2APPROVER_COLS,
+        issue_id=issue_ids).AndReturn(self.av_approver_rows)
+    self.issue_service.issuerelation_tbl.Select(
+        self.cnxn, cols=issue_svc.ISSUERELATION_COLS,
+        issue_id=issue_ids, kind='blockedon',
+        order_by=[('issue_id', []), ('rank DESC', []),
+                  ('dst_issue_id', [])]).AndReturn(self.blocked_on_rows)
+    self.issue_service.issuerelation_tbl.Select(
+        self.cnxn, cols=issue_svc.ISSUERELATION_COLS,
+        dst_issue_id=issue_ids, kind='blockedon',
+        order_by=[('issue_id', []), ('dst_issue_id', [])]
+        ).AndReturn(self.blocking_rows)
+    self.issue_service.issuerelation_tbl.Select(
+        self.cnxn, cols=issue_svc.ISSUERELATION_COLS,
+        where=[('(issue_id IN (%s) OR dst_issue_id IN (%s))',
+                issue_ids + issue_ids),
+               ('kind != %s', ['blockedon'])]).AndReturn(self.merged_rows)
+    self.issue_service.danglingrelation_tbl.Select(
+        self.cnxn, cols=issue_svc.DANGLINGRELATION_COLS,  # Note: no shard
+        issue_id=issue_ids).AndReturn(self.dangling_relation_rows)
+
+  def testFetchItems(self):
+    issue_ids = [78901]
+    self.SetUpFetchItems(issue_ids)
+    self.mox.ReplayAll()
+    issue_dict = self.issue_2lc.FetchItems(self.cnxn, issue_ids)
+    self.mox.VerifyAll()
+    self.assertItemsEqual(issue_ids, list(issue_dict.keys()))
+    self.assertEqual(2, len(issue_dict[78901].phases))
+
+  def testFetchItemsNoApprovalValues(self):
+    issue_ids = [78901]
+    self.SetUpFetchItems(issue_ids, False)  # has_approvalvalues=False
+    self.mox.ReplayAll()
+    issue_dict = self.issue_2lc.FetchItems(self.cnxn, issue_ids)
+    self.mox.VerifyAll()
+    self.assertItemsEqual(issue_ids, list(issue_dict.keys()))
+    self.assertEqual([], issue_dict[78901].phases)
+
+
+class IssueServiceTest(unittest.TestCase):
+
+  def setUp(self):
+    self.testbed = testbed.Testbed()
+    self.testbed.activate()
+    self.testbed.init_memcache_stub()
+    # Mostly fakes; cnxn, spam and the issue tables are mox mocks.
+    self.mox = mox.Mox()
+    self.cnxn = self.mox.CreateMock(sql.MonorailConnection)
+    self.services = service_manager.Services()
+    self.services.user = fake.UserService()
+    self.reporter = self.services.user.TestAddUser('reporter@example.com', 111)
+    self.services.usergroup = fake.UserGroupService()
+    self.services.project = fake.ProjectService()
+    self.project = self.services.project.TestAddProject('proj', project_id=789)
+    self.services.config = fake.ConfigService()
+    self.services.features = fake.FeaturesService()
+    self.cache_manager = fake.CacheManager()
+    self.services.chart = chart_svc.ChartService(self.services.config)
+    self.services.issue = MakeIssueService(
+        self.services.project, self.services.config, self.cache_manager,
+        self.services.chart, self.mox)
+    self.services.spam = self.mox.CreateMock(spam_svc.SpamService)
+    self.now = int(time.time())
+    self.patcher = patch('services.tracker_fulltext.IndexIssues')
+    self.patcher.start()
+    self.mox.StubOutWithMock(self.services.chart, 'StoreIssueSnapshots')
+
+  def classifierResult(self, score, failed_open=False):
+    """Return a fake spam-classifier result with the given score."""
+    return {'confidence_is_spam': score, 'failed_open': failed_open}
+
+  def tearDown(self):  # Undo setUp: testbed, mox stubs, and the patcher.
+    self.testbed.deactivate()
+    self.mox.UnsetStubs()
+    self.mox.ResetAll()
+    self.patcher.stop()
+
+ ### Issue ID lookups
+
+  def testLookupIssueIDsFollowMoves(self):
+    moved_issue_id = 78901
+    moved_pair = (789, 1)
+    missing_pair = (1, 1)
+    cached_issue_id = 78902
+    cached_pair = (789, 2)
+    uncached_issue_id = 78903
+    uncached_pair = (789, 3)
+    uncached_issue_id_2 = 78904
+    uncached_pair_2 = (789, 4)
+    self.services.issue.issue_id_2lc.CacheItem(cached_pair, cached_issue_id)
+    # NOTE(review): no SelectValue is recorded for missing_pair; confirm why.
+    # Simulate rows returned in reverse order (to verify the method still
+    # returns them in the specified order).
+    uncached_rows = [
+        (uncached_pair_2[0], uncached_pair_2[1], uncached_issue_id_2),
+        (uncached_pair[0], uncached_pair[1], uncached_issue_id)
+    ]
+    self.services.issue.issue_tbl.Select(
+        self.cnxn,
+        cols=['project_id', 'local_id', 'id'],
+        or_where_conds=True,
+        where=mox.IgnoreArg()).AndReturn(uncached_rows)
+    # Moved issue is found.
+    self.services.issue.issueformerlocations_tbl.SelectValue(
+        self.cnxn,
+        'issue_id',
+        default=0,
+        project_id=moved_pair[0],
+        local_id=moved_pair[1]).AndReturn(moved_issue_id)
+
+    self.mox.ReplayAll()
+    found_ids, misses = self.services.issue.LookupIssueIDsFollowMoves(
+        self.cnxn,
+        [moved_pair, missing_pair, cached_pair, uncached_pair, uncached_pair_2])
+    self.mox.VerifyAll()
+
+    expected_found_ids = [
+        moved_issue_id, cached_issue_id, uncached_issue_id, uncached_issue_id_2
+    ]
+    self.assertListEqual(expected_found_ids, found_ids)
+    self.assertListEqual([missing_pair], misses)
+
+  def testLookupIssueIDs_Hit(self):
+    id_2lc = self.services.issue.issue_id_2lc  # Both pairs are pre-cached.
+    id_2lc.CacheItem((789, 1), 78901)
+    id_2lc.CacheItem((789, 2), 78902)
+    ids, _ = self.services.issue.LookupIssueIDs(self.cnxn, [(789, 1), (789, 2)])
+    self.assertEqual([78901, 78902], ids)
+
+  def testLookupIssueID(self):
+    self.services.issue.issue_id_2lc.CacheItem((789, 1), 78901)
+    self.assertEqual(
+        78901, self.services.issue.LookupIssueID(self.cnxn, 789, 1))
+
+  def testResolveIssueRefs(self):
+    id_2lc = self.services.issue.issue_id_2lc
+    id_2lc.CacheItem((789, 1), 78901)
+    id_2lc.CacheItem((789, 2), 78902)
+    prefetched = {'proj': fake.Project('proj', project_id=789)}
+    # A None project name resolves against the default project, 'proj'.
+    resolved, missed = self.services.issue.ResolveIssueRefs(
+        self.cnxn, prefetched, 'proj', [('proj', 1), (None, 2)])
+    self.assertEqual(([78901, 78902], []), (resolved, missed))
+
+  def testLookupIssueRefs_Empty(self):
+    """No issue IDs in, empty mapping out."""
+    self.assertEqual({}, self.services.issue.LookupIssueRefs(self.cnxn, []))
+
+  def testLookupIssueRefs_Normal(self):
+    cached_issue = fake.MakeTestIssue(
+        project_id=789, local_id=1, owner_id=111, summary='sum',
+        status='Live', issue_id=78901, project_name='proj')
+    self.services.issue.issue_2lc.CacheItem(78901, cached_issue)
+    # Each issue_id maps to its (project_name, local_id) pair.
+    expected_refs = {78901: ('proj', 1)}
+    self.assertEqual(
+        expected_refs, self.services.issue.LookupIssueRefs(self.cnxn, [78901]))
+
+ ### Issue objects
+
+  def CheckCreateIssue(self, is_project_member):  # Shared by *SpamCheck tests.
+    settings.classifier_spam_thresh = 0.9  # NOTE(review): never restored.
+    av_23 = tracker_pb2.ApprovalValue(
+        approval_id=23, phase_id=1, approver_ids=[111, 222],
+        status=tracker_pb2.ApprovalStatus.NEEDS_REVIEW)
+    av_24 = tracker_pb2.ApprovalValue(
+        approval_id=24, phase_id=1, approver_ids=[111])
+    approval_values = [av_23, av_24]
+    av_rows = [(23, 78901, 1, 'needs_review', None, None),
+               (24, 78901, 1, 'not_set', None, None)]
+    approver_rows = [(23, 111, 78901), (23, 222, 78901), (24, 111, 78901)]
+    ad_23 = tracker_pb2.ApprovalDef(
+        approval_id=23, approver_ids=[111], survey='Question?')
+    ad_24 = tracker_pb2.ApprovalDef(
+        approval_id=24, approver_ids=[111], survey='Question?')
+    config = self.services.config.GetProjectConfig(
+        self.cnxn, 789)
+    config.approval_defs.extend([ad_23, ad_24])
+    self.services.config.StoreConfig(self.cnxn, config)
+    # Record the calls CreateIssue is expected to make, then replay them.
+    self.SetUpAllocateNextLocalID(789, None, None)
+    self.SetUpInsertIssue(av_rows=av_rows, approver_rows=approver_rows)
+    self.SetUpInsertComment(7890101, is_description=True)
+    self.SetUpInsertComment(7890101, is_description=True, approval_id=23,
+        content='<b>Question?</b>')
+    self.SetUpInsertComment(7890101, is_description=True, approval_id=24,
+        content='<b>Question?</b>')
+    self.services.spam.ClassifyIssue(mox.IgnoreArg(),
+        mox.IgnoreArg(), self.reporter, is_project_member).AndReturn(
+        self.classifierResult(0.0))
+    self.services.spam.RecordClassifierIssueVerdict(self.cnxn,
+        mox.IsA(tracker_pb2.Issue), False, 1.0, False)
+    self.SetUpEnqueueIssuesForIndexing([78901])
+
+    self.mox.ReplayAll()
+    issue = fake.MakeTestIssue(
+        789,
+        1,
+        'sum',
+        'New',
+        111,
+        reporter_id=111,
+        labels=['Type-Defect'],
+        opened_timestamp=self.now,
+        modified_timestamp=self.now,
+        approval_values=approval_values)
+    created_issue, _ = self.services.issue.CreateIssue(
+        self.cnxn, self.services, issue, 'content')
+    self.mox.VerifyAll()
+    self.assertEqual(1, created_issue.local_id)
+
+  def testCreateIssue_NonmemberSpamCheck(self):
+    """A reporter outside the project gets the non-member spam check."""
+    self.CheckCreateIssue(is_project_member=False)
+
+  def testCreateIssue_DirectMemberSpamCheck(self):
+    """A reporter listed directly as a committer gets the member check."""
+    self.project.committer_ids.append(self.reporter.user_id)
+    self.CheckCreateIssue(is_project_member=True)
+
+  def testCreateIssue_ComputedUsergroupSpamCheck(self):
+    """A member of a committer computed group gets the member spam check."""
+    everyone_group_id = self.services.usergroup.CreateGroup(
+        self.cnxn, self.services, 'everyone@example.com', 'ANYONE',
+        ext_group_type='COMPUTED')
+    self.project.committer_ids.append(everyone_group_id)
+    self.CheckCreateIssue(is_project_member=True)
+
+  def testCreateIssue_EmptyStringLabels(self):
+    settings.classifier_spam_thresh = 0.9  # NOTE(review): never restored.
+    self.SetUpAllocateNextLocalID(789, None, None)
+    self.SetUpInsertIssue(label_rows=[])
+    self.SetUpInsertComment(7890101, is_description=True)
+    self.services.spam.ClassifyIssue(mox.IgnoreArg(),
+        mox.IgnoreArg(), self.reporter, False).AndReturn(
+        self.classifierResult(0.0))
+    self.services.spam.RecordClassifierIssueVerdict(self.cnxn,
+        mox.IsA(tracker_pb2.Issue), False, 1.0, False)
+    self.SetUpEnqueueIssuesForIndexing([78901])
+    # NOTE(review): no labels= is passed below; confirm '' labels are covered.
+    self.mox.ReplayAll()
+    issue = fake.MakeTestIssue(
+        789,
+        1,
+        'sum',
+        'New',
+        111,
+        reporter_id=111,
+        opened_timestamp=self.now,
+        modified_timestamp=self.now)
+    created_issue, _ = self.services.issue.CreateIssue(
+        self.cnxn, self.services, issue, 'content')
+    self.mox.VerifyAll()
+    self.assertEqual(1, created_issue.local_id)
+
+  def SetUpUpdateIssuesModified(self, iids, modified_timestamp=None):
+    new_modified = modified_timestamp or self.now  # Falsy value -> self.now.
+    self.services.issue.issue_tbl.Update(
+        self.cnxn, {'modified': new_modified}, id=iids, commit=False)
+
+  def testCreateIssue_SpamPredictionFailed(self):
+    settings.classifier_spam_thresh = 0.9  # NOTE(review): never restored.
+    self.SetUpAllocateNextLocalID(789, None, None)
+    self.SetUpInsertSpamIssue()
+    self.SetUpInsertComment(7890101, is_description=True)
+    # classifierResult(1.0, True): score 1.0 with failed_open=True.
+    self.services.spam.ClassifyIssue(mox.IsA(tracker_pb2.Issue),
+        mox.IsA(tracker_pb2.IssueComment), self.reporter, False).AndReturn(
+        self.classifierResult(1.0, True))
+    self.services.spam.RecordClassifierIssueVerdict(self.cnxn,
+      mox.IsA(tracker_pb2.Issue), True, 1.0, True)
+    self.SetUpUpdateIssuesApprovals([])
+    self.SetUpEnqueueIssuesForIndexing([78901])
+
+    self.mox.ReplayAll()
+    issue = fake.MakeTestIssue(
+        789,
+        1,
+        'sum',
+        'New',
+        111,
+        reporter_id=111,
+        labels=['Type-Defect'],
+        opened_timestamp=self.now,
+        modified_timestamp=self.now)
+    created_issue, _ = self.services.issue.CreateIssue(
+        self.cnxn, self.services, issue, 'content')
+    self.mox.VerifyAll()
+    self.assertEqual(1, created_issue.local_id)
+
+  def testCreateIssue_Spam(self):
+    settings.classifier_spam_thresh = 0.9  # NOTE(review): never restored.
+    self.SetUpAllocateNextLocalID(789, None, None)
+    self.SetUpInsertSpamIssue()
+    self.SetUpInsertComment(7890101, is_description=True)
+    # Score 1.0 is above the 0.9 spam threshold configured above.
+    self.services.spam.ClassifyIssue(mox.IsA(tracker_pb2.Issue),
+        mox.IsA(tracker_pb2.IssueComment), self.reporter, False).AndReturn(
+        self.classifierResult(1.0))
+    self.services.spam.RecordClassifierIssueVerdict(self.cnxn,
+      mox.IsA(tracker_pb2.Issue), True, 1.0, False)
+    self.SetUpUpdateIssuesApprovals([])
+    self.SetUpEnqueueIssuesForIndexing([78901])
+
+    self.mox.ReplayAll()
+    issue = fake.MakeTestIssue(
+        789,
+        1,
+        'sum',
+        'New',
+        111,
+        reporter_id=111,
+        labels=['Type-Defect'],
+        opened_timestamp=self.now,
+        modified_timestamp=self.now)
+    created_issue, _ = self.services.issue.CreateIssue(
+        self.cnxn, self.services, issue, 'content')
+    self.mox.VerifyAll()
+    self.assertEqual(1, created_issue.local_id)
+
+  def testCreateIssue_FederatedReferences(self):
+    """External blocked-on/blocking refs are stored as dangling relations.
+
+    The issue carries dangling refs by external identifier ('b/...'), and each
+    one is expected as a danglingrelation row with no project/local ID.
+    """
+    self.SetUpAllocateNextLocalID(789, None, None)
+    self.SetUpInsertIssue(dangling_relation_rows=[
+        (78901, None, None, 'b/1234', 'blockedon'),
+        (78901, None, None, 'b/5678', 'blockedon'),
+        (78901, None, None, 'b/9876', 'blocking'),
+        (78901, None, None, 'b/5432', 'blocking')])
+    self.SetUpInsertComment(7890101, is_description=True)
+    self.services.spam.ClassifyIssue(mox.IsA(tracker_pb2.Issue),
+        mox.IsA(tracker_pb2.IssueComment), self.reporter, False).AndReturn(
+        self.classifierResult(0.0))
+    # The verdict values are not under test here.
+    self.services.spam.RecordClassifierIssueVerdict(self.cnxn,
+        mox.IsA(tracker_pb2.Issue), mox.IgnoreArg(), mox.IgnoreArg(),
+        mox.IgnoreArg())
+    self.SetUpEnqueueIssuesForIndexing([78901])
+
+    self.mox.ReplayAll()
+    issue = fake.MakeTestIssue(
+        789,
+        1,
+        'sum',
+        'New',
+        111,
+        reporter_id=111,
+        labels=['Type-Defect'],
+        opened_timestamp=self.now,
+        modified_timestamp=self.now)
+    issue.dangling_blocked_on_refs = [
+        tracker_pb2.DanglingIssueRef(ext_issue_identifier=shortlink)
+        for shortlink in ['b/1234', 'b/5678']
+    ]
+    issue.dangling_blocking_refs = [
+        tracker_pb2.DanglingIssueRef(ext_issue_identifier=shortlink)
+        for shortlink in ['b/9876', 'b/5432']
+    ]
+    self.services.issue.CreateIssue(self.cnxn, self.services, issue, 'content')
+    self.mox.VerifyAll()
+
+  def testCreateIssue_Imported(self):
+    """CreateIssue with an importer_id records who imported the description.
+
+    The description comment keeps the reporter (111) as its author. A
+    commentimporter row links it to importer 222.
+    """
+    settings.classifier_spam_thresh = 0.9
+    self.SetUpAllocateNextLocalID(789, None, None)
+    # The imported issue has no labels, so no label rows are expected.
+    self.SetUpInsertIssue(label_rows=[])
+    self.SetUpInsertComment(7890101, is_description=True)
+    self.services.issue.commentimporter_tbl.InsertRow(
+        self.cnxn, comment_id=7890101, importer_id=222)
+    self.services.spam.ClassifyIssue(mox.IgnoreArg(),
+        mox.IgnoreArg(), self.reporter, False).AndReturn(
+        self.classifierResult(0.0))
+    self.services.spam.RecordClassifierIssueVerdict(self.cnxn,
+        mox.IsA(tracker_pb2.Issue), False, 1.0, False)
+    self.SetUpEnqueueIssuesForIndexing([78901])
+    self.mox.ReplayAll()
+
+    issue = fake.MakeTestIssue(
+        789,
+        1,
+        'sum',
+        'New',
+        111,
+        reporter_id=111,
+        opened_timestamp=self.now,
+        modified_timestamp=self.now)
+    created_issue, comment = self.services.issue.CreateIssue(
+        self.cnxn, self.services, issue, 'content', importer_id=222)
+
+    self.mox.VerifyAll()
+    self.assertEqual(1, created_issue.local_id)
+    self.assertEqual(111, comment.user_id)
+    self.assertEqual(222, comment.importer_id)
+    self.assertEqual(self.now, comment.timestamp)
+
+  def testGetAllIssuesInProject_NoIssues(self):
+    """GetAllIssuesInProject returns [] given SetUpGetHighestLocalID(789, None, None)."""
+    self.SetUpGetHighestLocalID(789, None, None)
+    self.mox.ReplayAll()
+    result = self.services.issue.GetAllIssuesInProject(self.cnxn, 789)
+    self.mox.VerifyAll()
+    self.assertEqual(result, [])
+
+  def testGetAnyOnHandIssue(self):
+    """A cached issue is returned even though 78903 is not on hand."""
+    self.SetUpGetIssues()
+    candidate_ids = [78901, 78902, 78903]
+    on_hand = self.services.issue.GetAnyOnHandIssue(candidate_ids)
+    self.assertEqual(78901, on_hand.issue_id)
+
+  def SetUpGetIssues(self):
+    """Caches two fake issues of project 789 ('proj') in issue_2lc.
+
+    Returns:
+      A pair: issue 78901 (local ID 1, status 'Live') and issue 78902
+      (local ID 2, status 'Fixed').
+    """
+    made = []
+    for local_id, status in ((1, 'Live'), (2, 'Fixed')):
+      issue_id = 78900 + local_id
+      fake_issue = fake.MakeTestIssue(
+          project_id=789, local_id=local_id, owner_id=111, summary='sum',
+          status=status, issue_id=issue_id)
+      fake_issue.project_name = 'proj'
+      made.append((issue_id, fake_issue))
+    for issue_id, fake_issue in made:
+      self.services.issue.issue_2lc.CacheItem(issue_id, fake_issue)
+    return tuple(fake_issue for _, fake_issue in made)
+
+  def testGetIssuesDict(self):
+    """Found issues come back keyed by ID; unknown IDs are listed as missed."""
+    cached_1, cached_2 = self.SetUpGetIssues()
+    self.services.issue.issue_2lc = TestableIssueTwoLevelCache(
+        [cached_1, cached_2])
+    found, missing = self.services.issue.GetIssuesDict(
+        self.cnxn, [78901, 78902, 78903])
+    self.assertEqual({78901: cached_1, 78902: cached_2}, found)
+    self.assertEqual([78903], missing)
+
+  def testGetIssues(self):
+    """GetIssues returns the cached issues in request order."""
+    expected = list(self.SetUpGetIssues())
+    self.assertEqual(
+        expected, self.services.issue.GetIssues(self.cnxn, [78901, 78902]))
+
+  def testGetIssue(self):
+    """GetIssue returns the cached issue for a single issue ID."""
+    expected, _ = self.SetUpGetIssues()
+    self.assertEqual(expected, self.services.issue.GetIssue(self.cnxn, 78901))
+
+  def testGetIssuesByLocalIDs(self):
+    """Local IDs resolve through issue_id_2lc to the cached issues."""
+    expected = list(self.SetUpGetIssues())
+    for local_id, issue_id in ((1, 78901), (2, 78902)):
+      self.services.issue.issue_id_2lc.CacheItem((789, local_id), issue_id)
+    self.assertEqual(
+        expected,
+        self.services.issue.GetIssuesByLocalIDs(self.cnxn, 789, [1, 2]))
+
+  def testGetIssueByLocalID(self):
+    """A single local ID resolves through issue_id_2lc to its cached issue."""
+    expected, _ = self.SetUpGetIssues()
+    self.services.issue.issue_id_2lc.CacheItem((789, 1), 78901)
+    actual_issue = self.services.issue.GetIssueByLocalID(self.cnxn, 789, 1)
+    self.assertEqual(expected, actual_issue)
+
+  def testGetOpenAndClosedIssues(self):
+    """The 'Live' issue is reported open and the 'Fixed' one closed."""
+    live_issue, fixed_issue = self.SetUpGetIssues()
+    open_issues, closed_issues = self.services.issue.GetOpenAndClosedIssues(
+        self.cnxn, [78901, 78902])
+    self.assertEqual([live_issue], open_issues)
+    self.assertEqual([fixed_issue], closed_issues)
+
+  def SetUpGetCurrentLocationOfMovedIssue(self, project_id, local_id):
+    """Expects a former-location lookup followed by a current-location read.
+
+    The moved issue's ID is faked as project_id * 100 + local_id. Its current
+    row is reported as (project_id + 1, local_id + 1).
+    """
+    moved_issue_id = project_id * 100 + local_id
+    current_location = (project_id + 1, local_id + 1)
+    svc = self.services.issue
+    svc.issueformerlocations_tbl.SelectValue(
+        self.cnxn, 'issue_id', default=0, project_id=project_id,
+        local_id=local_id).AndReturn(moved_issue_id)
+    svc.issue_tbl.SelectRow(
+        self.cnxn, cols=['project_id', 'local_id'],
+        id=moved_issue_id).AndReturn(current_location)
+
+  def testGetCurrentLocationOfMovedIssue(self):
+    """The issue formerly at (789, 1) is reported at its current location."""
+    self.SetUpGetCurrentLocationOfMovedIssue(789, 1)
+    self.mox.ReplayAll()
+    new_project_id, new_local_id = (
+        self.services.issue.GetCurrentLocationOfMovedIssue(self.cnxn, 789, 1))
+    self.mox.VerifyAll()
+    # The setup helper fakes the new location as (project + 1, local + 1).
+    self.assertEqual(790, new_project_id)
+    self.assertEqual(2, new_local_id)
+
+  def SetUpGetPreviousLocations(self, issue_id, location_rows):
+    """Expects the former-locations select for issue_id to return the rows."""
+    formerlocations_tbl = self.services.issue.issueformerlocations_tbl
+    formerlocations_tbl.Select(
+        self.cnxn, cols=['project_id', 'local_id'],
+        issue_id=issue_id).AndReturn(location_rows)
+
+  def testGetPreviousLocations(self):
+    """The issue's current (789, 1) location is left out of the result."""
+    former_rows = [(781, 1), (782, 11), (789, 1)]
+    self.SetUpGetPreviousLocations(78901, former_rows)
+    self.mox.ReplayAll()
+    moved_issue = fake.MakeTestIssue(
+        project_id=789, local_id=1, owner_id=111, summary='sum',
+        status='Live', issue_id=78901)
+    previous = self.services.issue.GetPreviousLocations(self.cnxn, moved_issue)
+    self.mox.VerifyAll()
+    self.assertEqual([(781, 1), (782, 11)], previous)
+
+  def SetUpInsertIssue(
+      self, label_rows=None, av_rows=None, approver_rows=None,
+      dangling_relation_rows=None):
+    """Records the mox expectations for inserting issue 78901 (non-spam row).
+
+    Expectations are recorded in this order:
+      1. the issue row insert;
+      2. a commit;
+      3. the shard update;
+      4. the summary, labels, fields, components, CC, notify, relation and
+         approval rewrites;
+      5. a chart snapshot.
+
+    Args:
+      label_rows: expected issue2label rows; None uses the default of
+        SetUpUpdateIssuesLabels.
+      av_rows: expected issue2approvalvalue rows (default none).
+      approver_rows: expected issueapproval2approver rows (default none).
+      dangling_relation_rows: expected danglingrelation rows (default none).
+    """
+    # Values for issue_svc.ISSUE_COLS[1:]; the first column is presumably the
+    # generated issue ID (see return_generated_ids=True). The trailing False
+    # is the only value that differs from SetUpInsertSpamIssue.
+    row = (789, 1, 1, 111, 111,
+           self.now, 0, self.now, self.now, self.now, self.now,
+           None, 0,
+           False, 0, 0, False)
+    self.services.issue.issue_tbl.InsertRows(
+        self.cnxn, issue_svc.ISSUE_COLS[1:], [row],
+        commit=False, return_generated_ids=True).AndReturn([78901])
+    self.cnxn.Commit()
+    # The shard value is computed from the generated issue ID.
+    self.services.issue.issue_tbl.Update(
+        self.cnxn, {'shard': 78901 % settings.num_logical_shards},
+        id=78901, commit=False)
+    self.SetUpUpdateIssuesSummary()
+    self.SetUpUpdateIssuesLabels(label_rows=label_rows)
+    self.SetUpUpdateIssuesFields()
+    self.SetUpUpdateIssuesComponents()
+    self.SetUpUpdateIssuesCc()
+    self.SetUpUpdateIssuesNotify()
+    self.SetUpUpdateIssuesRelation(
+        dangling_relation_rows=dangling_relation_rows)
+    self.SetUpUpdateIssuesApprovals(
+        av_rows=av_rows, approver_rows=approver_rows)
+    self.services.chart.StoreIssueSnapshots(self.cnxn, mox.IgnoreArg(),
+        commit=False)
+
+  def SetUpInsertSpamIssue(self):
+    """Records the mox expectations for inserting issue 78901 as spam.
+
+    This mirrors SetUpInsertIssue, but the issue row ends in True instead of
+    False. No approvals rewrite is recorded here; the spam tests record
+    SetUpUpdateIssuesApprovals themselves, after the classifier expectations.
+    """
+    # Values for issue_svc.ISSUE_COLS[1:]; the trailing True is presumably the
+    # spam flag.
+    row = (789, 1, 1, 111, 111,
+           self.now, 0, self.now, self.now, self.now, self.now,
+           None, 0, False, 0, 0, True)
+    self.services.issue.issue_tbl.InsertRows(
+        self.cnxn, issue_svc.ISSUE_COLS[1:], [row],
+        commit=False, return_generated_ids=True).AndReturn([78901])
+    self.cnxn.Commit()
+    self.services.issue.issue_tbl.Update(
+        self.cnxn, {'shard': 78901 % settings.num_logical_shards},
+        id=78901, commit=False)
+    self.SetUpUpdateIssuesSummary()
+    self.SetUpUpdateIssuesLabels()
+    self.SetUpUpdateIssuesFields()
+    self.SetUpUpdateIssuesComponents()
+    self.SetUpUpdateIssuesCc()
+    self.SetUpUpdateIssuesNotify()
+    self.SetUpUpdateIssuesRelation()
+    self.services.chart.StoreIssueSnapshots(self.cnxn, mox.IgnoreArg(),
+        commit=False)
+
+  def SetUpUpdateIssuesSummary(self):
+    """Expects issue 78901's summary 'sum' to be written with replace=True."""
+    summary_rows = [(78901, 'sum')]
+    summary_tbl = self.services.issue.issuesummary_tbl
+    summary_tbl.InsertRows(
+        self.cnxn, ['issue_id', 'summary'], summary_rows, replace=True,
+        commit=False)
+
+  def SetUpUpdateIssuesLabels(self, label_rows=None):
+    """Expects issue 78901's label rows to be deleted and then re-inserted.
+
+    Args:
+      label_rows: rows expected by InsertRows. None means a single row for
+        label ID 1 with derived=False and issue_shard=1.
+    """
+    expected_rows = (
+        [(78901, 1, False, 1)] if label_rows is None else label_rows)
+    label_tbl = self.services.issue.issue2label_tbl
+    label_tbl.Delete(self.cnxn, issue_id=[78901], commit=False)
+    label_tbl.InsertRows(
+        self.cnxn, ['issue_id', 'label_id', 'derived', 'issue_shard'],
+        expected_rows, ignore=True, commit=False)
+
+  def SetUpUpdateIssuesFields(self, issue2fieldvalue_rows=None):
+    """Expects issue 78901's field-value rows to be deleted and re-inserted.
+
+    Args:
+      issue2fieldvalue_rows: rows expected by InsertRows (falsy means []).
+    """
+    expected_rows = issue2fieldvalue_rows or []
+    fieldvalue_tbl = self.services.issue.issue2fieldvalue_tbl
+    fieldvalue_tbl.Delete(self.cnxn, issue_id=[78901], commit=False)
+    fieldvalue_tbl.InsertRows(
+        self.cnxn, issue_svc.ISSUE2FIELDVALUE_COLS + ['issue_shard'],
+        expected_rows, commit=False)
+
+  def SetUpUpdateIssuesComponents(self, issue2component_rows=None):
+    """Expects issue 78901's component rows to be deleted and re-inserted.
+
+    Args:
+      issue2component_rows: rows expected by InsertRows (falsy means []).
+    """
+    expected_rows = issue2component_rows or []
+    component_tbl = self.services.issue.issue2component_tbl
+    component_tbl.Delete(self.cnxn, issue_id=[78901], commit=False)
+    component_tbl.InsertRows(
+        self.cnxn, ['issue_id', 'component_id', 'derived', 'issue_shard'],
+        expected_rows, ignore=True, commit=False)
+
+  def SetUpUpdateIssuesCc(self, issue2cc_rows=None):
+    """Expects issue 78901's CC rows to be deleted and re-inserted.
+
+    Args:
+      issue2cc_rows: rows expected by InsertRows (falsy means []).
+    """
+    expected_rows = issue2cc_rows or []
+    cc_tbl = self.services.issue.issue2cc_tbl
+    cc_tbl.Delete(self.cnxn, issue_id=[78901], commit=False)
+    cc_tbl.InsertRows(
+        self.cnxn, ['issue_id', 'cc_id', 'derived', 'issue_shard'],
+        expected_rows, ignore=True, commit=False)
+
+  def SetUpUpdateIssuesNotify(self, notify_rows=None):
+    """Expects issue 78901's notify rows to be deleted and re-inserted.
+
+    Args:
+      notify_rows: rows expected by InsertRows (falsy means []).
+    """
+    expected_rows = notify_rows or []
+    notify_tbl = self.services.issue.issue2notify_tbl
+    notify_tbl.Delete(self.cnxn, issue_id=[78901], commit=False)
+    notify_tbl.InsertRows(
+        self.cnxn, issue_svc.ISSUE2NOTIFY_COLS,
+        expected_rows, ignore=True, commit=False)
+
+  def SetUpUpdateIssuesRelation(
+      self, relation_rows=None, dangling_relation_rows=None):
+    """Expects issue 78901's relation and dangling-relation rows to be rewritten.
+
+    The steps are expected in this order:
+      1. a select of 'blockedon' relations pointing at 78901, returning [];
+      2. delete and re-insert of its issuerelation rows;
+      3. delete and re-insert of its danglingrelation rows.
+
+    Args:
+      relation_rows: expected issuerelation rows (falsy means []).
+      dangling_relation_rows: expected danglingrelation rows (falsy means []).
+    """
+    expected_relations = relation_rows or []
+    expected_dangling = dangling_relation_rows or []
+    relation_tbl = self.services.issue.issuerelation_tbl
+    dangling_tbl = self.services.issue.danglingrelation_tbl
+    relation_tbl.Select(
+        self.cnxn, cols=issue_svc.ISSUERELATION_COLS[:-1],
+        dst_issue_id=[78901], kind='blockedon').AndReturn([])
+    relation_tbl.Delete(self.cnxn, issue_id=[78901], commit=False)
+    relation_tbl.InsertRows(
+        self.cnxn, issue_svc.ISSUERELATION_COLS, expected_relations,
+        ignore=True, commit=False)
+    dangling_tbl.Delete(self.cnxn, issue_id=[78901], commit=False)
+    dangling_tbl.InsertRows(
+        self.cnxn, issue_svc.DANGLINGRELATION_COLS, expected_dangling,
+        ignore=True, commit=False)
+
+  def SetUpUpdateIssuesApprovals(self, av_rows=None, approver_rows=None):
+    """Expects issue 78901's approval values and approvers to be rewritten.
+
+    Args:
+      av_rows: expected issue2approvalvalue rows (falsy means []).
+      approver_rows: expected issueapproval2approver rows (falsy means []).
+    """
+    expected_avs = av_rows or []
+    expected_approvers = approver_rows or []
+    av_tbl = self.services.issue.issue2approvalvalue_tbl
+    approver_tbl = self.services.issue.issueapproval2approver_tbl
+    av_tbl.Delete(self.cnxn, issue_id=78901, commit=False)
+    av_tbl.InsertRows(
+        self.cnxn, issue_svc.ISSUE2APPROVALVALUE_COLS, expected_avs,
+        commit=False)
+    approver_tbl.Delete(self.cnxn, issue_id=78901, commit=False)
+    approver_tbl.InsertRows(
+        self.cnxn, issue_svc.ISSUEAPPROVAL2APPROVER_COLS, expected_approvers,
+        commit=False)
+
+  def testInsertIssue(self):
+    """InsertIssue writes every issue table and returns the generated ID."""
+    self.SetUpInsertIssue()
+    self.mox.ReplayAll()
+    new_issue = fake.MakeTestIssue(
+        project_id=789, local_id=1, owner_id=111, reporter_id=111,
+        summary='sum', status='New', labels=['Type-Defect'], issue_id=78901,
+        opened_timestamp=self.now, modified_timestamp=self.now)
+    inserted_id = self.services.issue.InsertIssue(self.cnxn, new_issue)
+    self.mox.VerifyAll()
+    self.assertEqual(78901, inserted_id)
+
+  def SetUpUpdateIssues(self, given_delta=None):
+    """Records the mox expectations for updating issue 78901.
+
+    Args:
+      given_delta: dict of issue_tbl column values expected in the Update
+        call. When falsy, a default delta is used, and the per-table
+        rewrites are expected as well: labels, CCs, fields, components,
+        notify, summary and relations.
+    """
+    delta = given_delta or {
+        'project_id': 789,
+        'local_id': 1,
+        'owner_id': 111,
+        'status_id': 1,
+        'opened': 123456789,
+        'closed': 0,
+        'modified': 123456789,
+        'owner_modified': 123456789,
+        'status_modified': 123456789,
+        'component_modified': 123456789,
+        'derived_owner_id': None,
+        'derived_status_id': None,
+        'deleted': False,
+        'star_count': 12,
+        'attachment_count': 0,
+        'is_spam': False,
+    }
+    self.services.issue.issue_tbl.Update(
+        self.cnxn, delta, id=78901, commit=False)
+    if not given_delta:
+      self.SetUpUpdateIssuesLabels()
+      self.SetUpUpdateIssuesCc()
+      self.SetUpUpdateIssuesFields()
+      self.SetUpUpdateIssuesComponents()
+      self.SetUpUpdateIssuesNotify()
+      self.SetUpUpdateIssuesSummary()
+      self.SetUpUpdateIssuesRelation()
+
+    # A chart snapshot is expected on both paths, after any per-table
+    # rewrites and before the commit.
+    self.services.chart.StoreIssueSnapshots(
+        self.cnxn, mox.IgnoreArg(), commit=False)
+    self.cnxn.Commit()
+
+  def testUpdateIssues_Empty(self):
+    """UpdateIssues with an empty list must not touch the database."""
+    # Nothing is recorded before replay, so any DB call makes the test fail.
+    self.mox.ReplayAll()
+    no_issues = []
+    self.services.issue.UpdateIssues(self.cnxn, no_issues)
+    self.mox.VerifyAll()
+
+  def testUpdateIssues_Normal(self):
+    """UpdateIssues rewrites a fresh (non-stale) issue's rows."""
+    fresh_issue = fake.MakeTestIssue(
+        project_id=789, local_id=1, owner_id=111, summary='sum',
+        status='Live', labels=['Type-Defect'], issue_id=78901,
+        opened_timestamp=123456789, modified_timestamp=123456789,
+        star_count=12)
+    fresh_issue.assume_stale = False
+    self.SetUpUpdateIssues()
+    self.mox.ReplayAll()
+    self.services.issue.UpdateIssues(self.cnxn, [fresh_issue])
+    self.mox.VerifyAll()
+
+  def testUpdateIssue_Normal(self):
+    """UpdateIssue rewrites a single fresh (non-stale) issue's rows."""
+    fresh_issue = fake.MakeTestIssue(
+        project_id=789, local_id=1, owner_id=111, summary='sum',
+        status='Live', labels=['Type-Defect'], issue_id=78901,
+        opened_timestamp=123456789, modified_timestamp=123456789,
+        star_count=12)
+    fresh_issue.assume_stale = False
+    self.SetUpUpdateIssues()
+    self.mox.ReplayAll()
+    self.services.issue.UpdateIssue(self.cnxn, fresh_issue)
+    self.mox.VerifyAll()
+
+  def testUpdateIssue_Stale(self):
+    """UpdateIssue raises AssertionError unless assume_stale is cleared."""
+    stale_issue = fake.MakeTestIssue(
+        project_id=789, local_id=1, owner_id=111, summary='sum',
+        status='Live', labels=['Type-Defect'], issue_id=78901,
+        opened_timestamp=123456789, modified_timestamp=123456789,
+        star_count=12)
+    # assume_stale is deliberately left at its default, and no DB
+    # expectations are recorded because nothing should be written.
+    self.mox.ReplayAll()
+    with self.assertRaises(AssertionError):
+      self.services.issue.UpdateIssue(self.cnxn, stale_issue)
+    self.mox.VerifyAll()
+
+  def testUpdateIssuesSummary(self):
+    """_UpdateIssuesSummary upserts the issue's summary text."""
+    summarized = fake.MakeTestIssue(
+        local_id=1, issue_id=78901, owner_id=111, summary='sum', status='New',
+        project_id=789)
+    summarized.assume_stale = False
+    self.SetUpUpdateIssuesSummary()
+    self.mox.ReplayAll()
+    self.services.issue._UpdateIssuesSummary(
+        self.cnxn, [summarized], commit=False)
+    self.mox.VerifyAll()
+
+  def testUpdateIssuesLabels(self):
+    """_UpdateIssuesLabels rewrites the issue's label rows."""
+    labeled = fake.MakeTestIssue(
+        local_id=1, issue_id=78901, owner_id=111, summary='sum', status='New',
+        labels=['Type-Defect'], project_id=789)
+    self.SetUpUpdateIssuesLabels()
+    self.mox.ReplayAll()
+    self.services.issue._UpdateIssuesLabels(self.cnxn, [labeled], commit=False)
+    self.mox.VerifyAll()
+
+  def testUpdateIssuesFields_Empty(self):
+    """An issue without field values still clears its field-value rows."""
+    bare_issue = fake.MakeTestIssue(
+        local_id=1, issue_id=78901, owner_id=111, summary='sum', status='New',
+        project_id=789)
+    self.SetUpUpdateIssuesFields()
+    self.mox.ReplayAll()
+    self.services.issue._UpdateIssuesFields(
+        self.cnxn, [bare_issue], commit=False)
+    self.mox.VerifyAll()
+
+  def testUpdateIssuesFields_Some(self):
+    """Each field value becomes one issue2fieldvalue row, keeping its phase."""
+    issue = fake.MakeTestIssue(
+        local_id=1, issue_id=78901, owner_id=111, summary='sum', status='New',
+        project_id=789)
+    shard = issue.issue_id % settings.num_logical_shards
+    field_values = [
+        tracker_bizobj.MakeFieldValue(345, 679, '', 0, None, None, False),
+        tracker_bizobj.MakeFieldValue(346, 0, 'Blue', 0, None, None, True),
+        tracker_bizobj.MakeFieldValue(
+            347, 0, '', 0, 1234567890, None, True),
+        tracker_bizobj.MakeFieldValue(
+            348, 0, '', 0, None, 'www.google.com', True, phase_id=14),
+    ]
+    for fv in field_values:
+      issue.field_values.append(fv)
+    # Only the last field value is scoped to a phase.
+    expected_phase_ids = [None, None, None, 14]
+    expected_rows = [
+        (issue.issue_id, fv.field_id, fv.int_value, fv.str_value, None,
+         fv.date_value, fv.url_value, fv.derived, phase_id, shard)
+        for fv, phase_id in zip(field_values, expected_phase_ids)]
+    self.SetUpUpdateIssuesFields(issue2fieldvalue_rows=expected_rows)
+    self.mox.ReplayAll()
+    self.services.issue._UpdateIssuesFields(self.cnxn, [issue], commit=False)
+    self.mox.VerifyAll()
+
+  def testUpdateIssuesComponents_Empty(self):
+    """An issue without components still clears its component rows."""
+    bare_issue = fake.MakeTestIssue(
+        project_id=789, local_id=1, owner_id=111, summary='sum',
+        status='Live', issue_id=78901)
+    self.SetUpUpdateIssuesComponents()
+    self.mox.ReplayAll()
+    self.services.issue._UpdateIssuesComponents(
+        self.cnxn, [bare_issue], commit=False)
+    self.mox.VerifyAll()
+
+  def testUpdateIssuesCc_Empty(self):
+    """An issue without CCs still clears its CC rows."""
+    bare_issue = fake.MakeTestIssue(
+        project_id=789, local_id=1, owner_id=111, summary='sum',
+        status='Live', issue_id=78901)
+    self.SetUpUpdateIssuesCc()
+    self.mox.ReplayAll()
+    self.services.issue._UpdateIssuesCc(self.cnxn, [bare_issue], commit=False)
+    self.mox.VerifyAll()
+
+  def testUpdateIssuesCc_Some(self):
+    """Explicit CCs come first (derived=False), then derived CCs (True)."""
+    issue = fake.MakeTestIssue(
+        project_id=789, local_id=1, owner_id=111, summary='sum',
+        status='Live', issue_id=78901)
+    explicit_ccs = [222, 333]
+    derived_ccs = [888]
+    issue.cc_ids = explicit_ccs
+    issue.derived_cc_ids = derived_ccs
+    shard = issue.issue_id % settings.num_logical_shards
+    expected_rows = (
+        [(issue.issue_id, cc_id, False, shard) for cc_id in explicit_ccs] +
+        [(issue.issue_id, cc_id, True, shard) for cc_id in derived_ccs])
+    self.SetUpUpdateIssuesCc(issue2cc_rows=expected_rows)
+    self.mox.ReplayAll()
+    self.services.issue._UpdateIssuesCc(self.cnxn, [issue], commit=False)
+    self.mox.VerifyAll()
+
+  def testUpdateIssuesNotify_Empty(self):
+    """An issue without notify addresses still clears its notify rows."""
+    bare_issue = fake.MakeTestIssue(
+        project_id=789, local_id=1, owner_id=111, summary='sum',
+        status='Live', issue_id=78901)
+    self.SetUpUpdateIssuesNotify()
+    self.mox.ReplayAll()
+    self.services.issue._UpdateIssuesNotify(
+        self.cnxn, [bare_issue], commit=False)
+    self.mox.VerifyAll()
+
+  def testUpdateIssuesRelation_Empty(self):
+    """An issue without relations still clears its relation rows."""
+    bare_issue = fake.MakeTestIssue(
+        project_id=789, local_id=1, owner_id=111, summary='sum',
+        status='Live', issue_id=78901)
+    self.SetUpUpdateIssuesRelation()
+    self.mox.ReplayAll()
+    self.services.issue._UpdateIssuesRelation(
+        self.cnxn, [bare_issue], commit=False)
+    self.mox.VerifyAll()
+
+  def testUpdateIssuesRelation_MergedIntoExternal(self):
+    """An external merged-into ref is stored as a 'mergedinto' dangling row."""
+    relation_tbl = self.services.issue.issuerelation_tbl
+    dangling_tbl = self.services.issue.danglingrelation_tbl
+    relation_tbl.Select = Mock(return_value=[])
+    relation_tbl.Delete = Mock()
+    relation_tbl.InsertRows = Mock()
+    dangling_tbl.Delete = Mock()
+    dangling_tbl.InsertRows = Mock()
+
+    issue = fake.MakeTestIssue(
+        project_id=789, local_id=1, owner_id=111, summary='sum',
+        status='Live', issue_id=78901, merged_into_external='b/5678')
+
+    self.services.issue._UpdateIssuesRelation(self.cnxn, [issue])
+
+    dangling_tbl.Delete.assert_called_once_with(
+        self.cnxn, commit=False, issue_id=[78901])
+    expected_cols = ['issue_id', 'dst_issue_project', 'dst_issue_local_id',
+                     'ext_issue_identifier', 'kind']
+    # Only the final dangling-row insert is expected to commit.
+    dangling_tbl.InsertRows.assert_called_once_with(
+        self.cnxn, expected_cols,
+        [(78901, None, None, 'b/5678', 'mergedinto')],
+        ignore=True, commit=True)
+
+  @patch('time.time')
+  def testUpdateIssueStructure(self, mockTime):
+    """UpdateIssueStructure re-shapes an issue's approvals and phases to a template.
+
+    The test checks the recorded DB writes:
+      - one survey description comment per template approval;
+      - an 'Approvals' amendment comment;
+      - the approval-value and approver rows;
+      - the field-value rows.
+    It also checks the resulting in-memory approval values and phases, and
+    the returned comment.
+    """
+    # Pin time.time(); presumably so service-side timestamps equal self.now.
+    mockTime.return_value = self.now
+    reporter_id = 111
+    comment_content = 'This issue is being converted'
+    # Set up config
+    # Approval IDs double as field IDs below: 3=Cow, 4=Chicken, 7=Roo
+    # (and 6=Llama for the issue's approval 6, which has no approval def).
+    config = self.services.config.GetProjectConfig(
+        self.cnxn, 789)
+    config.approval_defs = [
+        tracker_pb2.ApprovalDef(
+            approval_id=3, survey='Question3', approver_ids=[222]),
+        tracker_pb2.ApprovalDef(
+            approval_id=4, survey='Question4', approver_ids=[444]),
+        tracker_pb2.ApprovalDef(
+            approval_id=7, survey='Question7', approver_ids=[222]),
+    ]
+    config.field_defs = [
+        tracker_pb2.FieldDef(
+            field_id=3, project_id=789, field_name='Cow'),
+        tracker_pb2.FieldDef(
+            field_id=4, project_id=789, field_name='Chicken'),
+        tracker_pb2.FieldDef(
+            field_id=6, project_id=789, field_name='Llama'),
+        tracker_pb2.FieldDef(
+            field_id=7, project_id=789, field_name='Roo'),
+        tracker_pb2.FieldDef(
+            field_id=8, project_id=789, field_name='Salmon'),
+        tracker_pb2.FieldDef(
+            field_id=9, project_id=789, field_name='Tuna', is_phase_field=True),
+        tracker_pb2.FieldDef(
+            field_id=10, project_id=789, field_name='Clown', is_phase_field=True),
+        tracker_pb2.FieldDef(
+            field_id=11, project_id=789, field_name='Dory', is_phase_field=True),
+    ]
+
+    # Set up issue
+    issue = fake.MakeTestIssue(
+        project_id=789, local_id=1, owner_id=111, summary='sum', status='Open',
+        issue_id=78901, project_name='proj')
+    issue.approval_values = [
+        tracker_pb2.ApprovalValue(
+            approval_id=3,
+            phase_id=4,
+            status=tracker_pb2.ApprovalStatus.APPROVED,
+            approver_ids=[111],  # trumps approval_def approver_ids
+        ),
+        tracker_pb2.ApprovalValue(
+            approval_id=4,
+            phase_id=5,
+            approver_ids=[111]),  # trumps approval_def approver_ids
+        tracker_pb2.ApprovalValue(approval_id=6)]
+    issue.phases = [
+        tracker_pb2.Phase(name='Expired', phase_id=4),
+        tracker_pb2.Phase(name='canarY', phase_id=3),
+        tracker_pb2.Phase(name='Stable', phase_id=2)]
+    issue.field_values = [
+        tracker_bizobj.MakeFieldValue(8, None, 'Pink', None, None, None, False),
+        tracker_bizobj.MakeFieldValue(
+            9, None, 'Silver', None, None, None, False, phase_id=3),
+        tracker_bizobj.MakeFieldValue(
+            10, None, 'Orange', None, None, None, False, phase_id=4),
+        tracker_bizobj.MakeFieldValue(
+            11, None, 'Flat', None, None, None, False, phase_id=2),
+    ]
+
+    # Set up template
+    template = testing_helpers.DefaultTemplates()[0]
+    template.approval_values = [
+        tracker_pb2.ApprovalValue(
+            approval_id=3,
+            phase_id=6),  # Different phase. Nothing else affected.
+        # No phase. Nothing else affected.
+        tracker_pb2.ApprovalValue(approval_id=4),
+        # New approval not already found in issue.
+        tracker_pb2.ApprovalValue(
+            approval_id=7,
+            phase_id=5),
+    ]  # No approval 6
+    # TODO(jojwang): monorail:4693, rename 'Stable-Full' after all
+    # 'stable-full' gates have been renamed to 'stable'.
+    template.phases = [tracker_pb2.Phase(name='Canary', phase_id=5),
+                       tracker_pb2.Phase(name='Stable-Full', phase_id=6)]
+
+    # One description comment per template approval, holding that approval
+    # def's survey text.
+    self.SetUpInsertComment(
+        7890101, is_description=True, approval_id=3,
+        content=config.approval_defs[0].survey, commit=False)
+    self.SetUpInsertComment(
+        7890101, is_description=True, approval_id=4,
+        content=config.approval_defs[1].survey, commit=False)
+    self.SetUpInsertComment(
+        7890101, is_description=True, approval_id=7,
+        content=config.approval_defs[2].survey, commit=False)
+    # The user-visible comment carries an 'Approvals' amendment: Llama (6)
+    # removed, Roo (7) added.
+    amendment_row = (
+        78901, 7890101, 'custom', None, '-Llama Roo', None, None, 'Approvals')
+    self.SetUpInsertComment(
+        7890101, content=comment_content, amendment_rows=[amendment_row],
+        commit=False)
+    # Approvals 3 and 4 keep their status and issue approvers but take the
+    # template's phase. Approval 7 is new and gets approvers from its def.
+    # Approval 6 is not in the template and gets no rows.
+    av_rows = [
+        (3, 78901, 6, 'approved', None, None),
+        (4, 78901, None, 'not_set', None, None),
+        (7, 78901, 5, 'not_set', None, None),
+    ]
+    approver_rows = [(3, 111, 78901), (4, 111, 78901), (7, 222, 78901)]
+    self.SetUpUpdateIssuesApprovals(
+        av_rows=av_rows, approver_rows=approver_rows)
+    issue_shard = issue.issue_id % settings.num_logical_shards
+    # Phase field values move to the template phase of the same name:
+    # 'canarY' (3) -> 'Canary' (5) and 'Stable' (2) -> 'Stable-Full' (6).
+    # The match appears to ignore case and the '-Full' suffix. Field 10 was
+    # on 'Expired' (4), which has no template counterpart, so no row is
+    # expected for it.
+    issue2fieldvalue_rows = [
+        (78901, 8, None, 'Pink', None, None, None, False, None, issue_shard),
+        (78901, 9, None, 'Silver', None, None, None, False, 5, issue_shard),
+        (78901, 11, None, 'Flat', None, None, None, False, 6, issue_shard),
+    ]
+    self.SetUpUpdateIssuesFields(issue2fieldvalue_rows=issue2fieldvalue_rows)
+
+    self.mox.ReplayAll()
+    comment = self.services.issue.UpdateIssueStructure(
+        self.cnxn, config, issue, template, reporter_id,
+        comment_content=comment_content, commit=False, invalidate=False)
+    self.mox.VerifyAll()
+
+    expected_avs = [
+        tracker_pb2.ApprovalValue(
+            approval_id=3,
+            phase_id=6,
+            status=tracker_pb2.ApprovalStatus.APPROVED,
+            approver_ids=[111],
+        ),
+        tracker_pb2.ApprovalValue(
+            approval_id=4,
+            status=tracker_pb2.ApprovalStatus.NOT_SET,
+            approver_ids=[111]),
+        tracker_pb2.ApprovalValue(
+            approval_id=7,
+            status=tracker_pb2.ApprovalStatus.NOT_SET,
+            phase_id=5,
+            approver_ids=[222]),
+    ]
+    self.assertEqual(issue.approval_values, expected_avs)
+    self.assertEqual(issue.phases, template.phases)
+    amendment = tracker_bizobj.MakeApprovalStructureAmendment(
+        ['Roo', 'Cow', 'Chicken'], ['Cow', 'Chicken', 'Llama'])
+    expected_comment = self.services.issue._MakeIssueComment(
+        789, reporter_id, content=comment_content, amendments=[amendment])
+    expected_comment.issue_id = 78901
+    expected_comment.id = 7890101
+    self.assertEqual(expected_comment, comment)
+
+  def testDeltaUpdateIssue(self):
+    """Placeholder for general DeltaUpdateIssue coverage; currently empty."""
+    # TODO(jrobbins): write more tests
+
+  def testDeltaUpdateIssue_NoOp(self):
+    """An empty delta with no comment text yields no amendments and no comment."""
+    issue = fake.MakeTestIssue(
+        project_id=789, local_id=1, owner_id=111, summary='sum',
+        status='Live', issue_id=78901, project_name='proj')
+    default_config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+    empty_delta = tracker_pb2.IssueDelta()
+
+    amendments, comment_pb = self.services.issue.DeltaUpdateIssue(
+        self.cnxn, self.services, 222, issue.project_id, default_config,
+        issue, empty_delta, comment='', index_now=False, timestamp=self.now)
+    self.assertEqual([], amendments)
+    self.assertIsNone(comment_pb)
+
+  def testDeltaUpdateIssue_MergedInto(self):
+    """Merging into another issue amends the source and touches both issues.
+
+    Expects a comment with a merged-into amendment on the source issue, a
+    'modified' bump for both issues, and re-indexing of the source issue.
+    """
+    commenter_id = 222
+    issue = fake.MakeTestIssue(
+        project_id=789, local_id=1, owner_id=111, summary='sum',
+        status='Live', issue_id=78901, project_name='proj')
+    target_issue = fake.MakeTestIssue(
+        project_id=789, local_id=2, owner_id=111, summary='sum sum',
+        status='Live', issue_id=78902, project_name='proj')
+    config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+
+    self.mox.StubOutWithMock(self.services.issue, 'GetIssue')
+    self.mox.StubOutWithMock(self.services.issue, 'UpdateIssue')
+    self.mox.StubOutWithMock(self.services.issue, 'CreateIssueComment')
+    self.mox.StubOutWithMock(self.services.issue, '_UpdateIssuesModified')
+
+    # Lookup of issue ID 0 fails. 0 is presumably the source issue's previous
+    # (unset) merged_into value.
+    self.services.issue.GetIssue(
+        self.cnxn, 0).AndRaise(exceptions.NoSuchIssueException)
+    self.services.issue.GetIssue(
+        self.cnxn, target_issue.issue_id).AndReturn(target_issue)
+    self.services.issue.UpdateIssue(
+        self.cnxn, issue, commit=False, invalidate=False)
+    amendments = [
+        tracker_bizobj.MakeMergedIntoAmendment(
+            [('proj', 2)], [None], default_project_name='proj')]
+    self.services.issue.CreateIssueComment(
+        self.cnxn, issue, commenter_id, 'comment text', attachments=None,
+        amendments=amendments, commit=False, is_description=False,
+        kept_attachments=None, importer_id=None, timestamp=ANY,
+        inbound_message=None)
+    self.services.issue._UpdateIssuesModified(
+        self.cnxn, {issue.issue_id, target_issue.issue_id},
+        modified_timestamp=self.now, invalidate=True)
+    self.SetUpEnqueueIssuesForIndexing([78901])
+
+    self.mox.ReplayAll()
+    delta = tracker_pb2.IssueDelta(merged_into=target_issue.issue_id)
+    self.services.issue.DeltaUpdateIssue(
+        self.cnxn, self.services, commenter_id, issue.project_id, config,
+        issue, delta, comment='comment text',
+        index_now=False, timestamp=self.now)
+    self.mox.VerifyAll()
+
+  def testDeltaUpdateIssue_BlockedOn(self):
+    """Adding a blocked-on issue amends both sides of the relation.
+
+    The edited issue gets a BlockedOn amendment and the newly blocking issue
+    gets a matching Blocking amendment. Both issues have 'modified' bumped,
+    and only the edited issue is re-indexed.
+    """
+    commenter_id = 222
+    issue = fake.MakeTestIssue(
+        project_id=789, local_id=1, owner_id=111, summary='sum',
+        status='Live', issue_id=78901, project_name='proj')
+    blockedon_issue = fake.MakeTestIssue(
+        project_id=789, local_id=2, owner_id=111, summary='sum sum',
+        status='Live', issue_id=78902, project_name='proj')
+    config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+
+    # GetIssue is stubbed with no recorded expectations, so any call to it
+    # during replay would fail the test.
+    self.mox.StubOutWithMock(self.services.issue, 'GetIssue')
+    self.mox.StubOutWithMock(self.services.issue, 'GetIssues')
+    self.mox.StubOutWithMock(self.services.issue, 'LookupIssueRefs')
+    self.mox.StubOutWithMock(self.services.issue, 'UpdateIssue')
+    self.mox.StubOutWithMock(self.services.issue, 'CreateIssueComment')
+    self.mox.StubOutWithMock(self.services.issue, '_UpdateIssuesModified')
+    self.mox.StubOutWithMock(self.services.issue, "SortBlockedOn")
+
+    # Calls in ApplyIssueDelta
+    # Call to find added blockedon issues.
+    issue_refs = {blockedon_issue.issue_id: (
+        blockedon_issue.project_name, blockedon_issue.local_id)}
+    self.services.issue.LookupIssueRefs(
+        self.cnxn, [blockedon_issue.issue_id]).AndReturn(issue_refs)
+
+    # Call to find removed blockedon issues.
+    self.services.issue.LookupIssueRefs(self.cnxn, []).AndReturn({})
+    # Call to sort blockedon issues.
+    # The return value is presumably (sorted issue IDs, ranks).
+    self.services.issue.SortBlockedOn(
+        self.cnxn, issue, [blockedon_issue.issue_id]).AndReturn(([78902], [0]))
+
+    self.services.issue.UpdateIssue(
+        self.cnxn, issue, commit=False, invalidate=False)
+    amendments = [
+        tracker_bizobj.MakeBlockedOnAmendment(
+            [('proj', 2)], [], default_project_name='proj')]
+    self.services.issue.CreateIssueComment(
+        self.cnxn, issue, commenter_id, 'comment text', attachments=None,
+        amendments=amendments, commit=False, is_description=False,
+        kept_attachments=None, importer_id=None, timestamp=ANY,
+        inbound_message=None)
+    # Call to find added blockedon issues.
+    self.services.issue.GetIssues(
+        self.cnxn, [blockedon_issue.issue_id]).AndReturn([blockedon_issue])
+    # The blocking side receives the mirror-image Blocking amendment.
+    self.services.issue.CreateIssueComment(
+        self.cnxn, blockedon_issue, commenter_id, content='',
+        amendments=[tracker_bizobj.MakeBlockingAmendment(
+            [(issue.project_name, issue.local_id)], [],
+            default_project_name='proj')],
+        importer_id=None, timestamp=ANY)
+    # Call to find removed blockedon issues.
+    self.services.issue.GetIssues(self.cnxn, []).AndReturn([])
+    # Call to find added blocking issues.
+    self.services.issue.GetIssues(self.cnxn, []).AndReturn([])
+    # Call to find removed blocking issues.
+    self.services.issue.GetIssues(self.cnxn, []).AndReturn([])
+
+    self.services.issue._UpdateIssuesModified(
+        self.cnxn, {issue.issue_id, blockedon_issue.issue_id},
+        modified_timestamp=self.now, invalidate=True)
+    self.SetUpEnqueueIssuesForIndexing([78901])
+
+    self.mox.ReplayAll()
+    delta = tracker_pb2.IssueDelta(blocked_on_add=[blockedon_issue.issue_id])
+    self.services.issue.DeltaUpdateIssue(
+        self.cnxn, self.services, commenter_id, issue.project_id, config,
+        issue, delta, comment='comment text',
+        index_now=False, timestamp=self.now)
+    self.mox.VerifyAll()
+
+ def testDeltaUpdateIssue_Blocking(self):
+ commenter_id = 222
+ issue = fake.MakeTestIssue(
+ project_id=789, local_id=1, owner_id=111, summary='sum',
+ status='Live', issue_id=78901, project_name='proj')
+ blocking_issue = fake.MakeTestIssue(
+ project_id=789, local_id=2, owner_id=111, summary='sum sum',
+ status='Live', issue_id=78902, project_name='proj')
+ config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+
+ self.mox.StubOutWithMock(self.services.issue, 'GetIssue')
+ self.mox.StubOutWithMock(self.services.issue, 'GetIssues')
+ self.mox.StubOutWithMock(self.services.issue, 'LookupIssueRefs')
+ self.mox.StubOutWithMock(self.services.issue, 'UpdateIssue')
+ self.mox.StubOutWithMock(self.services.issue, 'CreateIssueComment')
+ self.mox.StubOutWithMock(self.services.issue, '_UpdateIssuesModified')
+ self.mox.StubOutWithMock(self.services.issue, "SortBlockedOn")
+
+ # Calls in ApplyIssueDelta
+ # Call to find added blocking issues.
+ issue_refs = {blocking_issue: (
+ blocking_issue.project_name, blocking_issue.local_id)}
+ self.services.issue.LookupIssueRefs(
+ self.cnxn, [blocking_issue.issue_id]).AndReturn(issue_refs)
+ # Call to find removed blocking issues.
+ self.services.issue.LookupIssueRefs(self.cnxn, []).AndReturn({})
+
+ self.services.issue.UpdateIssue(
+ self.cnxn, issue, commit=False, invalidate=False)
+ amendments = [
+ tracker_bizobj.MakeBlockingAmendment(
+ [('proj', 2)], [], default_project_name='proj')]
+ self.services.issue.CreateIssueComment(
+ self.cnxn, issue, commenter_id, 'comment text', attachments=None,
+ amendments=amendments, commit=False, is_description=False,
+ kept_attachments=None, importer_id=None, timestamp=ANY,
+ inbound_message=None)
+ # Call to find added blockedon issues.
+ self.services.issue.GetIssues(self.cnxn, []).AndReturn([])
+ # Call to find removed blockedon issues.
+ self.services.issue.GetIssues(self.cnxn, []).AndReturn([])
+ # Call to find added blocking issues.
+ self.services.issue.GetIssues(
+ self.cnxn, [blocking_issue.issue_id]).AndReturn([blocking_issue])
+ self.services.issue.CreateIssueComment(
+ self.cnxn, blocking_issue, commenter_id, content='',
+ amendments=[tracker_bizobj.MakeBlockedOnAmendment(
+ [(issue.project_name, issue.local_id)], [],
+ default_project_name='proj')],
+ importer_id=None, timestamp=ANY)
+ # Call to find removed blocking issues.
+ self.services.issue.GetIssues(self.cnxn, []).AndReturn([])
+ self.services.issue._UpdateIssuesModified(
+ self.cnxn, {issue.issue_id, blocking_issue.issue_id},
+ modified_timestamp=self.now, invalidate=True)
+ self.SetUpEnqueueIssuesForIndexing([78901])
+
+ self.mox.ReplayAll()
+ delta = tracker_pb2.IssueDelta(blocking_add=[blocking_issue.issue_id])
+ self.services.issue.DeltaUpdateIssue(
+ self.cnxn, self.services, commenter_id, issue.project_id, config,
+ issue, delta, comment='comment text',
+ index_now=False, timestamp=self.now)
+ self.mox.VerifyAll()
+
+ def testDeltaUpdateIssue_Imported(self):
+ """If importer_id is specified, store it.
+
+ DeltaUpdateIssue should pass importer_id through to CreateIssueComment and
+ hand back the comment that call produces.
+ """
+ commenter_id = 222
+ issue = fake.MakeTestIssue(
+ project_id=789, local_id=1, owner_id=111, summary='sum',
+ status='Live', issue_id=78901, project_name='proj')
+ issue.assume_stale = False
+ config = tracker_bizobj.MakeDefaultProjectIssueConfig(789)
+ # An empty delta: no field changes, so no amendments are expected.
+ delta = tracker_pb2.IssueDelta()
+
+ self.mox.StubOutWithMock(self.services.issue, 'GetIssue')
+ self.mox.StubOutWithMock(self.services.issue, 'GetIssues')
+ self.mox.StubOutWithMock(self.services.issue, 'UpdateIssue')
+ self.mox.StubOutWithMock(self.services.issue, 'CreateIssueComment')
+ self.mox.StubOutWithMock(self.services.issue, '_UpdateIssuesModified')
+ self.mox.StubOutWithMock(self.services.issue, "SortBlockedOn")
+ self.services.issue.UpdateIssue(
+ self.cnxn, issue, commit=False, invalidate=False)
+ # Call to find added blockedon issues.
+ self.services.issue.GetIssues(self.cnxn, []).AndReturn([])
+ # Call to find removed blockedon issues.
+ self.services.issue.GetIssues(self.cnxn, []).AndReturn([])
+ # The comment is created with importer_id=333; its canned return value is
+ # what DeltaUpdateIssue is expected to return to the caller.
+ self.services.issue.CreateIssueComment(
+ self.cnxn, issue, commenter_id, 'a comment', attachments=None,
+ amendments=[], commit=False, is_description=False,
+ kept_attachments=None, importer_id=333, timestamp=ANY,
+ inbound_message=None).AndReturn(
+ tracker_pb2.IssueComment(content='a comment', importer_id=333))
+ # Calls to find added and removed blocking issues (none in this delta).
+ self.services.issue.GetIssues(self.cnxn, []).AndReturn([])
+ self.services.issue.GetIssues(self.cnxn, []).AndReturn([])
+ self.services.issue._UpdateIssuesModified(
+ self.cnxn, {issue.issue_id},
+ modified_timestamp=self.now, invalidate=True)
+ self.SetUpEnqueueIssuesForIndexing([78901])
+ self.mox.ReplayAll()
+
+ amendments, comment_pb = self.services.issue.DeltaUpdateIssue(
+ self.cnxn, self.services, commenter_id, issue.project_id, config,
+ issue, delta, comment='a comment', index_now=False, timestamp=self.now,
+ importer_id=333)
+
+ self.mox.VerifyAll()
+ self.assertEqual([], amendments)
+ self.assertEqual('a comment', comment_pb.content)
+ self.assertEqual(333, comment_pb.importer_id)
+
+ def SetUpMoveIssues_NewProject(self):
+ """Record the DB calls for moving issue 78901 into project 789.
+
+ Used by testMoveIssues_NewProject, which moves project 711 issue 2.
+ """
+ # The issue has no former location in the destination project.
+ self.services.issue.issueformerlocations_tbl.Select(
+ self.cnxn, cols=issue_svc.ISSUEFORMERLOCATIONS_COLS, project_id=789,
+ issue_id=[78901]).AndReturn([])
+ # A new local ID is allocated in the destination project.
+ self.SetUpAllocateNextLocalID(789, None, None)
+ self.SetUpUpdateIssues()
+ # The issue's comments are re-pointed at the destination project.
+ self.services.issue.comment_tbl.Update(
+ self.cnxn, {'project_id': 789}, issue_id=[78901], commit=False)
+
+ # Record the old location: issue 78901 was project 711, local_id 2.
+ old_location_rows = [(78901, 711, 2)]
+ self.services.issue.issueformerlocations_tbl.InsertRows(
+ self.cnxn, issue_svc.ISSUEFORMERLOCATIONS_COLS, old_location_rows,
+ ignore=True, commit=False)
+ self.cnxn.Commit()
+
+ def testMoveIssues_NewProject(self):
+ """Move project 711 issue 2 to become project 789 issue 1."""
+ dest_project = fake.Project(project_id=789)
+ issue = fake.MakeTestIssue(
+ project_id=711, local_id=2, owner_id=111, summary='sum',
+ status='Live', labels=['Type-Defect'], issue_id=78901,
+ opened_timestamp=123456789, modified_timestamp=123456789,
+ star_count=12)
+ issue.assume_stale = False
+ # The expected DB calls are recorded by the SetUp helper.
+ self.SetUpMoveIssues_NewProject()
+ self.mox.ReplayAll()
+ self.services.issue.MoveIssues(
+ self.cnxn, dest_project, [issue], self.services.user)
+ self.mox.VerifyAll()
+
+ # TODO(jrobbins): case where issue is moved back into former project
+
+ def testExpungeFormerLocations(self):
+ """ExpungeFormerLocations deletes the project's former-location rows."""
+ self.services.issue.issueformerlocations_tbl.Delete(
+ self.cnxn, project_id=789)
+
+ self.mox.ReplayAll()
+ self.services.issue.ExpungeFormerLocations(self.cnxn, 789)
+ self.mox.VerifyAll()
+
+ def testExpungeIssues(self):
+ """ExpungeIssues touches the search index and deletes rows in each table."""
+ issue_ids = [1, 2]
+
+ # search.Index is expected for format args 1 and 2 (presumably one lookup
+ # per issue id -- TODO confirm); MockIndex stands in for the real index.
+ self.mox.StubOutWithMock(search, 'Index')
+ search.Index(name=settings.search_index_name_format % 1).AndReturn(
+ MockIndex())
+ search.Index(name=settings.search_index_name_format % 2).AndReturn(
+ MockIndex())
+
+ # Rows for the issues are deleted from every issue-related table,
+ # including relations where the issues are the destination.
+ self.services.issue.issuesummary_tbl.Delete(self.cnxn, issue_id=[1, 2])
+ self.services.issue.issue2label_tbl.Delete(self.cnxn, issue_id=[1, 2])
+ self.services.issue.issue2component_tbl.Delete(self.cnxn, issue_id=[1, 2])
+ self.services.issue.issue2cc_tbl.Delete(self.cnxn, issue_id=[1, 2])
+ self.services.issue.issue2notify_tbl.Delete(self.cnxn, issue_id=[1, 2])
+ self.services.issue.issueupdate_tbl.Delete(self.cnxn, issue_id=[1, 2])
+ self.services.issue.attachment_tbl.Delete(self.cnxn, issue_id=[1, 2])
+ self.services.issue.comment_tbl.Delete(self.cnxn, issue_id=[1, 2])
+ self.services.issue.issuerelation_tbl.Delete(self.cnxn, issue_id=[1, 2])
+ self.services.issue.issuerelation_tbl.Delete(self.cnxn, dst_issue_id=[1, 2])
+ self.services.issue.danglingrelation_tbl.Delete(self.cnxn, issue_id=[1, 2])
+ self.services.issue.issueformerlocations_tbl.Delete(
+ self.cnxn, issue_id=[1, 2])
+ self.services.issue.reindexqueue_tbl.Delete(self.cnxn, issue_id=[1, 2])
+ # The issue rows themselves are deleted last.
+ self.services.issue.issue_tbl.Delete(self.cnxn, id=[1, 2])
+
+ self.mox.ReplayAll()
+ self.services.issue.ExpungeIssues(self.cnxn, issue_ids)
+ self.mox.VerifyAll()
+
+ def testSoftDeleteIssue(self):
+ """SoftDeleteIssue sets deleted on the issue row and on the cached PB."""
+ project = fake.Project(project_id=789)
+ issue_1, issue_2 = self.SetUpGetIssues()
+ self.services.issue.issue_2lc = TestableIssueTwoLevelCache(
+ [issue_1, issue_2])
+ # Map (project 789, local_id 1) to issue_id 78901 in the id cache.
+ self.services.issue.issue_id_2lc.CacheItem((789, 1), 78901)
+ delta = {'deleted': True}
+ self.services.issue.issue_tbl.Update(
+ self.cnxn, delta, id=78901, commit=False)
+
+ # Chart snapshots are stored as part of the deletion; args not checked.
+ self.services.chart.StoreIssueSnapshots(self.cnxn, mox.IgnoreArg(),
+ commit=False)
+
+ self.cnxn.Commit()
+ self.mox.ReplayAll()
+ self.services.issue.SoftDeleteIssue(
+ self.cnxn, project.project_id, 1, True, self.services.user)
+ self.mox.VerifyAll()
+ self.assertTrue(issue_1.deleted)
+
+ def SetUpDeleteComponentReferences(self, component_id):
+ """Expect issue2component rows for component_id to be deleted."""
+ self.services.issue.issue2component_tbl.Delete(
+ self.cnxn, component_id=component_id)
+
+ def testDeleteComponentReferences(self):
+ """DeleteComponentReferences removes issue links to the component."""
+ self.SetUpDeleteComponentReferences(123)
+ self.mox.ReplayAll()
+ self.services.issue.DeleteComponentReferences(self.cnxn, 123)
+ self.mox.VerifyAll()
+
+ ### Local ID generation
+
+ def SetUpInitializeLocalID(self, project_id):
+ """Expect a localidcounter row for project_id with both counters at 0."""
+ self.services.issue.localidcounter_tbl.InsertRow(
+ self.cnxn, project_id=project_id, used_local_id=0, used_spam_id=0)
+
+ def testInitializeLocalID(self):
+ """InitializeLocalID creates the project's zeroed local ID counter row."""
+ self.SetUpInitializeLocalID(789)
+ self.mox.ReplayAll()
+ self.services.issue.InitializeLocalID(self.cnxn, 789)
+ self.mox.VerifyAll()
+
+ def SetUpAllocateNextLocalID(
+ self, project_id, highest_in_use, highest_former):
+ """Expect one increment of project_id's used_local_id counter.
+
+ Args:
+ project_id: project whose counter is incremented.
+ highest_in_use: simulated highest local_id among current issues, or None.
+ highest_former: simulated highest former local_id, or None.
+ """
+ # The stubbed counter returns one past the larger value (None counts as 0).
+ highest_either = max(highest_in_use or 0, highest_former or 0)
+ self.services.issue.localidcounter_tbl.IncrementCounterValue(
+ self.cnxn, 'used_local_id', project_id=project_id).AndReturn(
+ highest_either + 1)
+
+ def testAllocateNextLocalID_NewProject(self):
+ """A project with no issues gets local ID 1."""
+ self.SetUpAllocateNextLocalID(789, None, None)
+ self.mox.ReplayAll()
+ next_local_id = self.services.issue.AllocateNextLocalID(self.cnxn, 789)
+ self.mox.VerifyAll()
+ self.assertEqual(1, next_local_id)
+
+ def testAllocateNextLocalID_HighestInUse(self):
+ """The next local ID follows the highest in-use local ID."""
+ self.SetUpAllocateNextLocalID(789, 14, None)
+ self.mox.ReplayAll()
+ next_local_id = self.services.issue.AllocateNextLocalID(self.cnxn, 789)
+ self.mox.VerifyAll()
+ self.assertEqual(15, next_local_id)
+
+ def testAllocateNextLocalID_HighestWasMoved(self):
+ """A higher former (moved-away) local ID is not reused."""
+ self.SetUpAllocateNextLocalID(789, 23, 66)
+ self.mox.ReplayAll()
+ next_local_id = self.services.issue.AllocateNextLocalID(self.cnxn, 789)
+ self.mox.VerifyAll()
+ self.assertEqual(67, next_local_id)
+
+ def SetUpGetHighestLocalID(self, project_id, highest_in_use, highest_former):
+ """Expect MAX(local_id) queries on the issue and former-location tables.
+
+ Args:
+ project_id: project being queried.
+ highest_in_use: value returned for the issue table (may be None).
+ highest_former: value returned for the former-locations table (may be
+ None).
+ """
+ self.services.issue.issue_tbl.SelectValue(
+ self.cnxn, 'MAX(local_id)', project_id=project_id).AndReturn(
+ highest_in_use)
+ self.services.issue.issueformerlocations_tbl.SelectValue(
+ self.cnxn, 'MAX(local_id)', project_id=project_id).AndReturn(
+ highest_former)
+
+ def testGetHighestLocalID_OnlyActiveLocalIDs(self):
+ """With no former locations, the highest in-use local ID is returned."""
+ self.SetUpGetHighestLocalID(789, 14, None)
+ self.mox.ReplayAll()
+ highest_id = self.services.issue.GetHighestLocalID(self.cnxn, 789)
+ self.mox.VerifyAll()
+ self.assertEqual(14, highest_id)
+
+ def testGetHighestLocalID_OnlyFormerIDs(self):
+ """With no in-use issues, the highest former local ID is returned."""
+ self.SetUpGetHighestLocalID(789, None, 97)
+ self.mox.ReplayAll()
+ highest_id = self.services.issue.GetHighestLocalID(self.cnxn, 789)
+ self.mox.VerifyAll()
+ self.assertEqual(97, highest_id)
+
+ def testGetHighestLocalID_BothActiveAndFormer(self):
+ """The larger of the in-use and former maxima is returned."""
+ self.SetUpGetHighestLocalID(789, 345, 97)
+ self.mox.ReplayAll()
+ highest_id = self.services.issue.GetHighestLocalID(self.cnxn, 789)
+ self.mox.VerifyAll()
+ self.assertEqual(345, highest_id)
+
+ def testGetAllLocalIDsInProject(self):
+ """GetAllLocalIDsInProject returns 1 through the highest local ID."""
+ self.SetUpGetHighestLocalID(789, 14, None)
+ self.mox.ReplayAll()
+ local_id_range = self.services.issue.GetAllLocalIDsInProject(self.cnxn, 789)
+ self.mox.VerifyAll()
+ self.assertEqual(list(range(1, 15)), local_id_range)
+
+ ### Comments
+
+ def testConsolidateAmendments_Empty(self):
+ amendments = []
+ actual = self.services.issue._ConsolidateAmendments(amendments)
+ self.assertEqual([], actual)
+
+ def testConsolidateAmendments_NoOp(self):
+ amendments = [
+ tracker_pb2.Amendment(field=tracker_pb2.FieldID('SUMMARY'),
+ oldvalue='old sum', newvalue='new sum'),
+ tracker_pb2.Amendment(field=tracker_pb2.FieldID('STATUS'),
+ oldvalue='New', newvalue='Accepted')]
+ actual = self.services.issue._ConsolidateAmendments(amendments)
+ self.assertEqual(amendments, actual)
+
+ def testConsolidateAmendments_StandardFields(self):
+ amendments = [
+ tracker_pb2.Amendment(field=tracker_pb2.FieldID('STATUS'),
+ oldvalue='New'),
+ tracker_pb2.Amendment(field=tracker_pb2.FieldID('STATUS'),
+ newvalue='Accepted'),
+ tracker_pb2.Amendment(field=tracker_pb2.FieldID('SUMMARY'),
+ oldvalue='old sum'),
+ tracker_pb2.Amendment(field=tracker_pb2.FieldID('SUMMARY'),
+ newvalue='new sum')]
+ actual = self.services.issue._ConsolidateAmendments(amendments)
+
+ expected = [
+ tracker_pb2.Amendment(field=tracker_pb2.FieldID('SUMMARY'),
+ oldvalue='old sum', newvalue='new sum'),
+ tracker_pb2.Amendment(field=tracker_pb2.FieldID('STATUS'),
+ oldvalue='New', newvalue='Accepted')]
+ self.assertEqual(expected, actual)
+
+ def testConsolidateAmendments_BlockerRelations(self):
+ amendments = [
+ tracker_pb2.Amendment(
+ field=tracker_pb2.FieldID('BLOCKEDON'), newvalue='78901'),
+ tracker_pb2.Amendment(
+ field=tracker_pb2.FieldID('BLOCKEDON'), newvalue='-b/3 b/1 b/2'),
+ tracker_pb2.Amendment(
+ field=tracker_pb2.FieldID('BLOCKING'), newvalue='78902'),
+ tracker_pb2.Amendment(
+ field=tracker_pb2.FieldID('BLOCKING'), newvalue='-b/33 b/11 b/22')
+ ]
+
+ actual = self.services.issue._ConsolidateAmendments(amendments)
+
+ expected = [
+ tracker_pb2.Amendment(
+ field=tracker_pb2.FieldID('BLOCKEDON'),
+ newvalue='78901 -b/3 b/1 b/2'),
+ tracker_pb2.Amendment(
+ field=tracker_pb2.FieldID('BLOCKING'),
+ newvalue='78902 -b/33 b/11 b/22')
+ ]
+ self.assertEqual(expected, actual)
+
+ def testConsolidateAmendments_CustomFields(self):
+ amendments = [
+ tracker_pb2.Amendment(field=tracker_pb2.FieldID('CUSTOM'),
+ custom_field_name='a', oldvalue='old a'),
+ tracker_pb2.Amendment(field=tracker_pb2.FieldID('CUSTOM'),
+ custom_field_name='b', oldvalue='old b')]
+ actual = self.services.issue._ConsolidateAmendments(amendments)
+ self.assertEqual(amendments, actual)
+
+ def testConsolidateAmendments_SortAmmendments(self):
+ amendments = [
+ tracker_pb2.Amendment(field=tracker_pb2.FieldID('STATUS'),
+ oldvalue='New', newvalue='Accepted'),
+ tracker_pb2.Amendment(field=tracker_pb2.FieldID('SUMMARY'),
+ oldvalue='old sum', newvalue='new sum'),
+ tracker_pb2.Amendment(field=tracker_pb2.FieldID('LABELS'),
+ oldvalue='Type-Defect', newvalue='-Type-Defect Type-Enhancement'),
+ tracker_pb2.Amendment(field=tracker_pb2.FieldID('CC'),
+ oldvalue='a@google.com', newvalue='b@google.com')]
+ expected = [
+ tracker_pb2.Amendment(field=tracker_pb2.FieldID('SUMMARY'),
+ oldvalue='old sum', newvalue='new sum'),
+ tracker_pb2.Amendment(field=tracker_pb2.FieldID('STATUS'),
+ oldvalue='New', newvalue='Accepted'),
+ tracker_pb2.Amendment(field=tracker_pb2.FieldID('CC'),
+ oldvalue='a@google.com', newvalue='b@google.com'),
+ tracker_pb2.Amendment(field=tracker_pb2.FieldID('LABELS'),
+ oldvalue='Type-Defect', newvalue='-Type-Defect Type-Enhancement')]
+ actual = self.services.issue._ConsolidateAmendments(amendments)
+ self.assertEqual(expected, actual)
+
+ def testDeserializeComments_Empty(self):
+ comments = self.services.issue._DeserializeComments([], [], [], [], [], [])
+ self.assertEqual([], comments)
+
+ def SetUpCommentRows(self):
+ comment_rows = [
+ (7890101, 78901, self.now, 789, 111,
+ None, False, False, 'unused_commentcontent_id'),
+ (7890102, 78901, self.now, 789, 111,
+ None, False, False, 'unused_commentcontent_id')]
+ commentcontent_rows = [(7890101, 'content', 'msg'),
+ (7890102, 'content2', 'msg')]
+ amendment_rows = [
+ (1, 78901, 7890101, 'cc', 'old', 'new val', 222, None, None)]
+ attachment_rows = []
+ approval_rows = [(23, 7890102)]
+ importer_rows = []
+ return (comment_rows, commentcontent_rows, amendment_rows,
+ attachment_rows, approval_rows, importer_rows)
+
+ def testDeserializeComments_Normal(self):
+ (comment_rows, commentcontent_rows, amendment_rows,
+ attachment_rows, approval_rows, importer_rows) = self.SetUpCommentRows()
+ commentcontent_rows = [(7890101, 'content', 'msg')]
+ comments = self.services.issue._DeserializeComments(
+ comment_rows, commentcontent_rows, amendment_rows, attachment_rows,
+ approval_rows, importer_rows)
+ self.assertEqual(2, len(comments))
+
+ def testDeserializeComments_Imported(self):
+ (comment_rows, commentcontent_rows, amendment_rows,
+ attachment_rows, approval_rows, _) = self.SetUpCommentRows()
+ importer_rows = [(7890101, 222)]
+ commentcontent_rows = [(7890101, 'content', 'msg')]
+ comments = self.services.issue._DeserializeComments(
+ comment_rows, commentcontent_rows, amendment_rows, attachment_rows,
+ approval_rows, importer_rows)
+ self.assertEqual(2, len(comments))
+ self.assertEqual(222, comments[0].importer_id)
+
+ def MockTheRestOfGetCommentsByID(self, comment_ids):
+ self.services.issue.commentcontent_tbl.Select = Mock(
+ return_value=[
+ (cid + 5000, 'content', None) for cid in comment_ids])
+ self.services.issue.issueupdate_tbl.Select = Mock(
+ return_value=[])
+ self.services.issue.attachment_tbl.Select = Mock(
+ return_value=[])
+ self.services.issue.issueapproval2comment_tbl.Select = Mock(
+ return_value=[])
+ self.services.issue.commentimporter_tbl.Select = Mock(
+ return_value=[])
+
+ def testGetCommentsByID_Normal(self):
+ """We can load comments by comment_ids."""
+ comment_ids = [101001, 101002, 101003]
+ self.services.issue.comment_tbl.Select = Mock(
+ return_value=[
+ (cid, cid - cid % 100, self.now, 789, 111,
+ None, False, False, cid + 5000)
+ for cid in comment_ids])
+ self.MockTheRestOfGetCommentsByID(comment_ids)
+
+ comments = self.services.issue.GetCommentsByID(
+ self.cnxn, comment_ids, [0, 1, 2])
+
+ self.services.issue.comment_tbl.Select.assert_called_with(
+ self.cnxn, cols=issue_svc.COMMENT_COLS,
+ id=comment_ids, shard_id=ANY)
+
+ self.assertEqual(3, len(comments))
+
+ def testGetCommentsByID_CacheReplicationLag(self):
+ """Replica-lag fallback works when use_cache=True."""
+ self._testGetCommentsByID_ReplicationLag(True)
+
+ def testGetCommentsByID_NoCacheReplicationLag(self):
+ """Replica-lag fallback works when use_cache=False."""
+ self._testGetCommentsByID_ReplicationLag(False)
+
+ def _testGetCommentsByID_ReplicationLag(self, use_cache):
+ """If not all comments are on the replica, we try the primary DB."""
+ comment_ids = [101001, 101002, 101003]
+ replica_comment_ids = comment_ids[:-1]
+
+ return_value_1 = [
+ (cid, cid - cid % 100, self.now, 789, 111,
+ None, False, False, cid + 5000)
+ for cid in replica_comment_ids]
+ return_value_2 = [
+ (cid, cid - cid % 100, self.now, 789, 111,
+ None, False, False, cid + 5000)
+ for cid in comment_ids]
+ return_values = [return_value_1, return_value_2]
+ self.services.issue.comment_tbl.Select = Mock(
+ side_effect=lambda *_args, **_kwargs: return_values.pop(0))
+
+ self.MockTheRestOfGetCommentsByID(comment_ids)
+
+ comments = self.services.issue.GetCommentsByID(
+ self.cnxn, comment_ids, [0, 1, 2], use_cache=use_cache)
+
+ self.services.issue.comment_tbl.Select.assert_called_with(
+ self.cnxn, cols=issue_svc.COMMENT_COLS,
+ id=comment_ids, shard_id=ANY)
+ self.services.issue.comment_tbl.Select.assert_called_with(
+ self.cnxn, cols=issue_svc.COMMENT_COLS,
+ id=comment_ids, shard_id=ANY)
+ self.assertEqual(3, len(comments))
+
+ def SetUpGetComments(self, issue_ids):
+ """Record the table Selects made when loading comments for issue_ids.
+
+ Each issue gets one comment with id issue_id + 1000 and content id
+ issue_id + 5000, linked to approval 23. Only the first issue's comment
+ has an attachment; there are no amendment or importer rows.
+ """
+ # Assumes one comment per issue.
+ cids = [issue_id + 1000 for issue_id in issue_ids]
+ self.services.issue.comment_tbl.Select(
+ self.cnxn, cols=issue_svc.COMMENT_COLS,
+ where=None, issue_id=issue_ids, order_by=[('created', [])],
+ shard_id=mox.IsA(int)).AndReturn([
+ (issue_id + 1000, issue_id, self.now, 789, 111,
+ None, False, False, issue_id + 5000)
+ for issue_id in issue_ids])
+ self.services.issue.commentcontent_tbl.Select(
+ self.cnxn, cols=issue_svc.COMMENTCONTENT_COLS,
+ id=[issue_id + 5000 for issue_id in issue_ids],
+ shard_id=mox.IsA(int)).AndReturn([
+ (issue_id + 5000, 'content', None) for issue_id in issue_ids])
+ self.services.issue.issueapproval2comment_tbl.Select(
+ self.cnxn, cols=issue_svc.ISSUEAPPROVAL2COMMENT_COLS,
+ comment_id=cids).AndReturn([
+ (23, cid) for cid in cids])
+
+ # No amendments; the first issue's comment (if any) has one attachment.
+ self.services.issue.issueupdate_tbl.Select(
+ self.cnxn, cols=issue_svc.ISSUEUPDATE_COLS,
+ comment_id=cids, shard_id=mox.IsA(int)).AndReturn([])
+ attachment_rows = []
+ if issue_ids:
+ attachment_rows = [
+ (1234, issue_ids[0], cids[0], 'a_filename', 1024, 'text/plain',
+ False, None)]
+
+ self.services.issue.attachment_tbl.Select(
+ self.cnxn, cols=issue_svc.ATTACHMENT_COLS,
+ comment_id=cids, shard_id=mox.IsA(int)).AndReturn(attachment_rows)
+
+ # No comments were imported.
+ self.services.issue.commentimporter_tbl.Select(
+ self.cnxn, cols=issue_svc.COMMENTIMPORTER_COLS,
+ comment_id=cids, shard_id=mox.IsA(int)).AndReturn([])
+
+ def testGetComments_Empty(self):
+ """GetComments for no issues returns no comments."""
+ self.SetUpGetComments([])
+ self.mox.ReplayAll()
+ comments = self.services.issue.GetComments(
+ self.cnxn, issue_id=[])
+ self.mox.VerifyAll()
+ self.assertEqual(0, len(comments))
+
+ def testGetComments_Normal(self):
+ """GetComments joins content and approval rows onto each comment."""
+ self.SetUpGetComments([100001, 100002])
+ self.mox.ReplayAll()
+ comments = self.services.issue.GetComments(
+ self.cnxn, issue_id=[100001, 100002])
+ self.mox.VerifyAll()
+ self.assertEqual(2, len(comments))
+ self.assertEqual('content', comments[0].content)
+ self.assertEqual('content', comments[1].content)
+ self.assertEqual(23, comments[0].approval_id)
+ self.assertEqual(23, comments[1].approval_id)
+
+ def SetUpGetComment_Found(self, comment_id):
+ """Record the Selects for loading one existing comment by id.
+
+ The comment's content id is comment_id * 10 and it is linked to
+ approval 23.
+ """
+ # Look up a single comment by id; its issue_id is comment_id // 100.
+ commentcontent_id = comment_id * 10
+ self.services.issue.comment_tbl.Select(
+ self.cnxn, cols=issue_svc.COMMENT_COLS,
+ where=None, id=comment_id, order_by=[('created', [])],
+ shard_id=mox.IsA(int)).AndReturn([
+ (comment_id, int(comment_id // 100), self.now, 789, 111,
+ None, False, True, commentcontent_id)])
+ self.services.issue.commentcontent_tbl.Select(
+ self.cnxn, cols=issue_svc.COMMENTCONTENT_COLS,
+ id=[commentcontent_id], shard_id=mox.IsA(int)).AndReturn([
+ (commentcontent_id, 'content', None)])
+ self.services.issue.issueapproval2comment_tbl.Select(
+ self.cnxn, cols=issue_svc.ISSUEAPPROVAL2COMMENT_COLS,
+ comment_id=[comment_id]).AndReturn([(23, comment_id)])
+ # Assume no amendments or attachment for now.
+ self.services.issue.issueupdate_tbl.Select(
+ self.cnxn, cols=issue_svc.ISSUEUPDATE_COLS,
+ comment_id=[comment_id], shard_id=mox.IsA(int)).AndReturn([])
+ self.services.issue.attachment_tbl.Select(
+ self.cnxn, cols=issue_svc.ATTACHMENT_COLS,
+ comment_id=[comment_id], shard_id=mox.IsA(int)).AndReturn([])
+ self.services.issue.commentimporter_tbl.Select(
+ self.cnxn, cols=issue_svc.COMMENTIMPORTER_COLS,
+ comment_id=[comment_id], shard_id=mox.IsA(int)).AndReturn([])
+
+ def testGetComment_Found(self):
+ """GetComment returns the comment with its content and approval_id."""
+ self.SetUpGetComment_Found(7890101)
+ self.mox.ReplayAll()
+ comment = self.services.issue.GetComment(self.cnxn, 7890101)
+ self.mox.VerifyAll()
+ self.assertEqual('content', comment.content)
+ self.assertEqual(23, comment.approval_id)
+
+ def SetUpGetComment_Missing(self, comment_id):
+ """Record the Selects for looking up a comment id that does not exist."""
+ # The comment lookup finds nothing, so follow-up Selects get empty lists.
+ self.services.issue.comment_tbl.Select(
+ self.cnxn, cols=issue_svc.COMMENT_COLS,
+ where=None, id=comment_id, order_by=[('created', [])],
+ shard_id=mox.IsA(int)).AndReturn([])
+ self.services.issue.commentcontent_tbl.Select(
+ self.cnxn, cols=issue_svc.COMMENTCONTENT_COLS,
+ id=[], shard_id=mox.IsA(int)).AndReturn([])
+ self.services.issue.issueapproval2comment_tbl.Select(
+ self.cnxn, cols=issue_svc.ISSUEAPPROVAL2COMMENT_COLS,
+ comment_id=[]).AndReturn([])
+ # Assume no amendments or attachment for now.
+ self.services.issue.issueupdate_tbl.Select(
+ self.cnxn, cols=issue_svc.ISSUEUPDATE_COLS,
+ comment_id=[], shard_id=mox.IsA(int)).AndReturn([])
+ self.services.issue.attachment_tbl.Select(
+ self.cnxn, cols=issue_svc.ATTACHMENT_COLS, comment_id=[],
+ shard_id=mox.IsA(int)).AndReturn([])
+ self.services.issue.commentimporter_tbl.Select(
+ self.cnxn, cols=issue_svc.COMMENTIMPORTER_COLS,
+ comment_id=[], shard_id=mox.IsA(int)).AndReturn([])
+
+ def testGetComment_Missing(self):
+ """GetComment raises NoSuchCommentException for an unknown id."""
+ self.SetUpGetComment_Missing(7890101)
+ self.mox.ReplayAll()
+ self.assertRaises(
+ exceptions.NoSuchCommentException,
+ self.services.issue.GetComment, self.cnxn, 7890101)
+ self.mox.VerifyAll()
+
+ def testGetCommentsForIssue(self):
+ """GetCommentsForIssue issues the same Selects as GetComments."""
+ issue = fake.MakeTestIssue(789, 1, 'Summary', 'New', 111)
+ self.SetUpGetComments([issue.issue_id])
+ self.mox.ReplayAll()
+ self.services.issue.GetCommentsForIssue(self.cnxn, issue.issue_id)
+ self.mox.VerifyAll()
+
+ def testGetCommentsForIssues(self):
+ """GetCommentsForIssues loads comments for several issues at once."""
+ self.SetUpGetComments([100001, 100002])
+ self.mox.ReplayAll()
+ self.services.issue.GetCommentsForIssues(
+ self.cnxn, issue_ids=[100001, 100002])
+ self.mox.VerifyAll()
+
+ def SetUpInsertComment(
+ self, comment_id, is_spam=False, is_description=False, approval_id=None,
+ content=None, amendment_rows=None, commit=True):
+ """Record the inserts for storing a comment by user 111 on issue 78901.
+
+ Args:
+ comment_id: id returned by the comment_tbl insert; the content row id
+ is comment_id * 10.
+ is_spam: expected is_spam column value.
+ is_description: expected is_description column value.
+ approval_id: if set, expect an issueapproval2comment row for it.
+ content: expected comment text; defaults to 'content'.
+ amendment_rows: expected issueupdate rows; defaults to none.
+ commit: whether to expect a final cnxn.Commit().
+ """
+ content = content or 'content'
+ commentcontent_id = comment_id * 10
+ # Content is inserted first so its id can be stored on the comment row.
+ self.services.issue.commentcontent_tbl.InsertRow(
+ self.cnxn, content=content,
+ inbound_message=None, commit=False).AndReturn(commentcontent_id)
+ self.services.issue.comment_tbl.InsertRow(
+ self.cnxn, issue_id=78901, created=self.now, project_id=789,
+ commenter_id=111, deleted_by=None, is_spam=is_spam,
+ is_description=is_description, commentcontent_id=commentcontent_id,
+ commit=False).AndReturn(comment_id)
+
+ amendment_rows = amendment_rows or []
+ self.services.issue.issueupdate_tbl.InsertRows(
+ self.cnxn, issue_svc.ISSUEUPDATE_COLS[1:], amendment_rows,
+ commit=False)
+
+ attachment_rows = []
+ self.services.issue.attachment_tbl.InsertRows(
+ self.cnxn, issue_svc.ATTACHMENT_COLS[1:], attachment_rows,
+ commit=False)
+
+ if approval_id:
+ self.services.issue.issueapproval2comment_tbl.InsertRows(
+ self.cnxn, issue_svc.ISSUEAPPROVAL2COMMENT_COLS,
+ [(approval_id, comment_id)], commit=False)
+
+ if commit:
+ self.cnxn.Commit()
+
+ def testInsertComment(self):
+ """InsertComment stores the comment and sets its id from the insert."""
+ self.SetUpInsertComment(7890101, approval_id=23)
+ self.mox.ReplayAll()
+ comment = tracker_pb2.IssueComment(
+ issue_id=78901, timestamp=self.now, project_id=789, user_id=111,
+ content='content', approval_id=23)
+ self.services.issue.InsertComment(self.cnxn, comment, commit=True)
+ self.mox.VerifyAll()
+ self.assertEqual(7890101, comment.id)
+
+ def SetUpUpdateComment(self, comment_id, delta=None):
+ """Expect a comment_tbl update of comment_id with the given column delta.
+
+ Args:
+ comment_id: id of the comment row being updated.
+ delta: expected column values; defaults to commenter 111, deleted by
+ 222, not spam.
+ """
+ delta = delta or {
+ 'commenter_id': 111,
+ 'deleted_by': 222,
+ 'is_spam': False,
+ }
+ self.services.issue.comment_tbl.Update(
+ self.cnxn, delta, id=comment_id)
+
+ def testUpdateComment(self):
+ """_UpdateComment writes commenter, deleted_by and is_spam columns."""
+ self.SetUpUpdateComment(7890101)
+ self.mox.ReplayAll()
+ comment = tracker_pb2.IssueComment(
+ id=7890101, issue_id=78901, timestamp=self.now, project_id=789,
+ user_id=111, content='new content', deleted_by=222,
+ is_spam=False)
+ self.services.issue._UpdateComment(self.cnxn, comment)
+ self.mox.VerifyAll()
+
+ def testMakeIssueComment(self):
+ comment = self.services.issue._MakeIssueComment(
+ 789, 111, 'content', timestamp=self.now, approval_id=23,
+ importer_id=222)
+ self.assertEqual('content', comment.content)
+ self.assertEqual([], comment.amendments)
+ self.assertEqual([], comment.attachments)
+ self.assertEqual(comment.approval_id, 23)
+ self.assertEqual(222, comment.importer_id)
+
+ def testMakeIssueComment_NonAscii(self):
+ _ = self.services.issue._MakeIssueComment(
+ 789, 111, 'content', timestamp=self.now,
+ inbound_message=u'sent by написа')
+
+ def testCreateIssueComment_Normal(self):
+ """CreateIssueComment inserts a comment linked to approval 24."""
+ issue_1, _issue_2 = self.SetUpGetIssues()
+ self.services.issue.issue_id_2lc.CacheItem((789, 1), 78901)
+ self.SetUpInsertComment(7890101, approval_id=24)
+ self.mox.ReplayAll()
+ comment = self.services.issue.CreateIssueComment(
+ self.cnxn, issue_1, 111, 'content', timestamp=self.now, approval_id=24)
+ self.mox.VerifyAll()
+ self.assertEqual('content', comment.content)
+
+ def testCreateIssueComment_EditDescription(self):
+ """A description edit is stored with is_description=True."""
+ issue_1, _issue_2 = self.SetUpGetIssues()
+ self.services.issue.issue_id_2lc.CacheItem((789, 1), 78901)
+ # kept_attachments=[123] leads to a lookup of that attachment; the mock
+ # returns nothing, so no attachment rows are inserted.
+ self.services.issue.attachment_tbl.Select(
+ self.cnxn, cols=issue_svc.ATTACHMENT_COLS, id=[123])
+ self.SetUpInsertComment(7890101, is_description=True)
+ self.mox.ReplayAll()
+
+ comment = self.services.issue.CreateIssueComment(
+ self.cnxn, issue_1, 111, 'content', is_description=True,
+ kept_attachments=[123], timestamp=self.now)
+ self.mox.VerifyAll()
+ self.assertEqual('content', comment.content)
+
+ def testCreateIssueComment_Spam(self):
+ """A spam comment is stored with is_spam=True and flagged on the PB."""
+ issue_1, _issue_2 = self.SetUpGetIssues()
+ self.services.issue.issue_id_2lc.CacheItem((789, 1), 78901)
+ self.SetUpInsertComment(7890101, is_spam=True)
+ self.mox.ReplayAll()
+ comment = self.services.issue.CreateIssueComment(
+ self.cnxn, issue_1, 111, 'content', timestamp=self.now, is_spam=True)
+ self.mox.VerifyAll()
+ self.assertEqual('content', comment.content)
+ self.assertTrue(comment.is_spam)
+
+ def testSoftDeleteComment(self):
+ """Deleting a comment with an attachment marks it and updates count."""
+ issue_1, issue_2 = self.SetUpGetIssues()
+ self.services.issue.issue_2lc = TestableIssueTwoLevelCache(
+ [issue_1, issue_2])
+ # The issue starts with one attachment, owned by the comment below.
+ issue_1.attachment_count = 1
+ issue_1.assume_stale = False
+ comment = tracker_pb2.IssueComment(id=7890101)
+ comment.attachments = [tracker_pb2.Attachment()]
+ self.services.issue.issue_id_2lc.CacheItem((789, 1), 78901)
+ # The comment row records the deleter, and the issue's count drops to 0.
+ self.SetUpUpdateComment(
+ comment.id, delta={'deleted_by': 222, 'is_spam': False})
+ self.SetUpUpdateIssues(given_delta={'attachment_count': 0})
+ self.SetUpEnqueueIssuesForIndexing([78901])
+ self.mox.ReplayAll()
+ self.services.issue.SoftDeleteComment(
+ self.cnxn, issue_1, comment, 222, self.services.user)
+ self.mox.VerifyAll()
+
+ ### Approvals
+
+ def testGetIssueApproval(self):
+ av_24 = tracker_pb2.ApprovalValue(approval_id=24)
+ av_25 = tracker_pb2.ApprovalValue(approval_id=25)
+ issue_1 = fake.MakeTestIssue(
+ project_id=789, local_id=1, owner_id=111, summary='sum',
+ status='Live', issue_id=78901, approval_values=[av_24, av_25])
+ issue_1.project_name = 'proj'
+ self.services.issue.issue_2lc.CacheItem(78901, issue_1)
+
+ issue, actual_approval_value = self.services.issue.GetIssueApproval(
+ self.cnxn, issue_1.issue_id, av_24.approval_id)
+
+ self.assertEqual(av_24, actual_approval_value)
+ self.assertEqual(issue, issue_1)
+
+ def testGetIssueApproval_NoSuchApproval(self):
+ issue_1 = fake.MakeTestIssue(
+ project_id=789, local_id=1, owner_id=111, summary='sum',
+ status='Live', issue_id=78901)
+ issue_1.project_name = 'proj'
+ self.services.issue.issue_2lc.CacheItem(78901, issue_1)
+ self.assertRaises(
+ exceptions.NoSuchIssueApprovalException,
+ self.services.issue.GetIssueApproval,
+ self.cnxn, issue_1.issue_id, 24)
+
+ def testDeltaUpdateIssueApproval(self):
+ config = self.services.config.GetProjectConfig(
+ self.cnxn, 789)
+ config.field_defs = [
+ tracker_pb2.FieldDef(
+ field_id=1, project_id=789, field_name='EstDays',
+ field_type=tracker_pb2.FieldTypes.INT_TYPE,
+ applicable_type=''),
+ tracker_pb2.FieldDef(
+ field_id=2, project_id=789, field_name='Tag',
+ field_type=tracker_pb2.FieldTypes.STR_TYPE,
+ applicable_type=''),
+ ]
+ self.services.config.StoreConfig(self.cnxn, config)
+
+ issue = fake.MakeTestIssue(
+ project_id=789, local_id=1, summary='summary', status='New',
+ owner_id=999, issue_id=78901, labels=['noodle-puppies'])
+ av = tracker_pb2.ApprovalValue(approval_id=23)
+ final_av = tracker_pb2.ApprovalValue(
+ approval_id=23, setter_id=111, set_on=1234,
+ status=tracker_pb2.ApprovalStatus.REVIEW_REQUESTED,
+ approver_ids=[222, 444])
+ labels_add = ['snakes-are']
+ label_id = 1001
+ labels_remove = ['noodle-puppies']
+ amendments = [
+ tracker_bizobj.MakeApprovalStatusAmendment(
+ tracker_pb2.ApprovalStatus.REVIEW_REQUESTED),
+ tracker_bizobj.MakeApprovalApproversAmendment([222, 444], []),
+ tracker_bizobj.MakeFieldAmendment(1, config, [4], []),
+ tracker_bizobj.MakeFieldClearedAmendment(2, config),
+ tracker_bizobj.MakeLabelsAmendment(labels_add, labels_remove)
+ ]
+ approval_delta = tracker_pb2.ApprovalDelta(
+ status=tracker_pb2.ApprovalStatus.REVIEW_REQUESTED,
+ approver_ids_add=[222, 444], set_on=1234,
+ subfield_vals_add=[
+ tracker_bizobj.MakeFieldValue(1, 4, None, None, None, None, False)
+ ],
+ labels_add=labels_add,
+ labels_remove=labels_remove,
+ subfields_clear=[2]
+ )
+
+ self.services.issue.issue2approvalvalue_tbl.Update = Mock()
+ self.services.issue.issueapproval2approver_tbl.Delete = Mock()
+ self.services.issue.issueapproval2approver_tbl.InsertRows = Mock()
+ self.services.issue.issue2fieldvalue_tbl.Delete = Mock()
+ self.services.issue.issue2fieldvalue_tbl.InsertRows = Mock()
+ self.services.issue.issue2label_tbl.Delete = Mock()
+ self.services.issue.issue2label_tbl.InsertRows = Mock()
+ self.services.issue.CreateIssueComment = Mock()
+ self.services.config.LookupLabelID = Mock(return_value=label_id)
+ shard = issue.issue_id % settings.num_logical_shards
+ fv_rows = [(78901, 1, 4, None, None, None, None, False, None, shard)]
+ label_rows = [(78901, label_id, False, shard)]
+
+ self.services.issue.DeltaUpdateIssueApproval(
+ self.cnxn, 111, config, issue, av, approval_delta, 'some comment',
+ attachments=[], commit=False, kept_attachments=[1, 2, 3])
+
+ self.assertEqual(av, final_av)
+
+ self.services.issue.issue2approvalvalue_tbl.Update.assert_called_once_with(
+ self.cnxn,
+ {'status': 'review_requested', 'setter_id': 111, 'set_on': 1234},
+ approval_id=23, issue_id=78901, commit=False)
+ self.services.issue.issueapproval2approver_tbl.\
+ Delete.assert_called_once_with(
+ self.cnxn, issue_id=78901, approval_id=23, commit=False)
+ self.services.issue.issueapproval2approver_tbl.\
+ InsertRows.assert_called_once_with(
+ self.cnxn, issue_svc.ISSUEAPPROVAL2APPROVER_COLS,
+ [(23, 222, 78901), (23, 444, 78901)], commit=False)
+ self.services.issue.issue2fieldvalue_tbl.\
+ Delete.assert_called_once_with(
+ self.cnxn, issue_id=[78901], commit=False)
+ self.services.issue.issue2fieldvalue_tbl.\
+ InsertRows.assert_called_once_with(
+ self.cnxn, issue_svc.ISSUE2FIELDVALUE_COLS + ['issue_shard'],
+ fv_rows, commit=False)
+ self.services.issue.issue2label_tbl.\
+ Delete.assert_called_once_with(
+ self.cnxn, issue_id=[78901], commit=False)
+ self.services.issue.issue2label_tbl.\
+ InsertRows.assert_called_once_with(
+ self.cnxn, issue_svc.ISSUE2LABEL_COLS + ['issue_shard'],
+ label_rows, ignore=True, commit=False)
+ self.services.issue.CreateIssueComment.assert_called_once_with(
+ self.cnxn, issue, 111, 'some comment', amendments=amendments,
+ approval_id=23, is_description=False, attachments=[], commit=False,
+ kept_attachments=[1, 2, 3])
+
+ def testDeltaUpdateIssueApproval_IsDescription(self):
+ config = self.services.config.GetProjectConfig(
+ self.cnxn, 789)
+ issue = fake.MakeTestIssue(
+ project_id=789, local_id=1, summary='summary', status='New',
+ owner_id=999, issue_id=78901)
+ av = tracker_pb2.ApprovalValue(approval_id=23)
+ approval_delta = tracker_pb2.ApprovalDelta()
+
+ self.services.issue.CreateIssueComment = Mock()
+
+ self.services.issue.DeltaUpdateIssueApproval(
+ self.cnxn, 111, config, issue, av, approval_delta, 'better response',
+ is_description=True, commit=False)
+
+ self.services.issue.CreateIssueComment.assert_called_once_with(
+ self.cnxn, issue, 111, 'better response', amendments=[],
+ approval_id=23, is_description=True, attachments=None, commit=False,
+ kept_attachments=None)
+
+ def testUpdateIssueApprovalStatus(self):
+ av = tracker_pb2.ApprovalValue(approval_id=23, setter_id=111, set_on=1234)
+
+ self.services.issue.issue2approvalvalue_tbl.Update(
+ self.cnxn, {'status': 'not_set', 'setter_id': 111, 'set_on': 1234},
+ approval_id=23, issue_id=78901, commit=False)
+
+ self.mox.ReplayAll()
+ self.services.issue._UpdateIssueApprovalStatus(
+ self.cnxn, 78901, av.approval_id, av.status,
+ av.setter_id, av.set_on)
+ self.mox.VerifyAll()
+
+ def testUpdateIssueApprovalApprovers(self):
+ self.services.issue.issueapproval2approver_tbl.Delete(
+ self.cnxn, issue_id=78901, approval_id=23, commit=False)
+ self.services.issue.issueapproval2approver_tbl.InsertRows(
+ self.cnxn, issue_svc.ISSUEAPPROVAL2APPROVER_COLS,
+ [(23, 111, 78901), (23, 222, 78901), (23, 444, 78901)], commit=False)
+
+ self.mox.ReplayAll()
+ self.services.issue._UpdateIssueApprovalApprovers(
+ self.cnxn, 78901, 23, [111, 222, 444])
+ self.mox.VerifyAll()
+
+ ### Attachments
+
+ def testGetAttachmentAndContext(self):
+ # TODO(jrobbins): re-implemnent to use Google Cloud Storage.
+ pass
+
+ def SetUpUpdateAttachment(self, comment_id, attachment_id, delta):
+ """Record mox expectations for an attachment update.
+
+ Expects attachment_tbl.Update with `delta` on row `attachment_id`, and
+ invalidation of `comment_id` in comment_2lc.
+ """
+ self.services.issue.attachment_tbl.Update(
+ self.cnxn, delta, id=attachment_id)
+ self.services.issue.comment_2lc.InvalidateKeys(
+ self.cnxn, [comment_id])
+
+
+ def testUpdateAttachment(self):
+ delta = {
+ 'filename': 'a_filename',
+ 'filesize': 1024,
+ 'mimetype': 'text/plain',
+ 'deleted': False,
+ }
+ self.SetUpUpdateAttachment(5678, 1234, delta)
+ self.mox.ReplayAll()
+ attach = tracker_pb2.Attachment(
+ attachment_id=1234, filename='a_filename', filesize=1024,
+ mimetype='text/plain')
+ comment = tracker_pb2.IssueComment(id=5678)
+ self.services.issue._UpdateAttachment(self.cnxn, comment, attach)
+ self.mox.VerifyAll()
+
+ def testStoreAttachmentBlob(self):
+ # TODO(jrobbins): re-implemnent to use Google Cloud Storage.
+ pass
+
+ def testSoftDeleteAttachment(self):
+ issue = fake.MakeTestIssue(789, 1, 'sum', 'New', 111, issue_id=78901)
+ issue.assume_stale = False
+ issue.attachment_count = 1
+
+ comment = tracker_pb2.IssueComment(
+ project_id=789, content='soon to be deleted', user_id=111,
+ issue_id=issue.issue_id)
+ attachment = tracker_pb2.Attachment(
+ attachment_id=1234)
+ comment.attachments.append(attachment)
+
+ self.SetUpUpdateAttachment(179901, 1234, {'deleted': True})
+ self.SetUpUpdateIssues(given_delta={'attachment_count': 0})
+ self.SetUpEnqueueIssuesForIndexing([78901])
+
+ self.mox.ReplayAll()
+ self.services.issue.SoftDeleteAttachment(
+ self.cnxn, issue, comment, 1234, self.services.user)
+ self.mox.VerifyAll()
+
+ ### Reindex queue
+
+ def SetUpEnqueueIssuesForIndexing(self, issue_ids):
+ reindex_rows = [(issue_id,) for issue_id in issue_ids]
+ self.services.issue.reindexqueue_tbl.InsertRows(
+ self.cnxn, ['issue_id'], reindex_rows, ignore=True, commit=True)
+
+ def testEnqueueIssuesForIndexing(self):
+ self.SetUpEnqueueIssuesForIndexing([78901])
+ self.mox.ReplayAll()
+ self.services.issue.EnqueueIssuesForIndexing(self.cnxn, [78901])
+ self.mox.VerifyAll()
+
+ def SetUpReindexIssues(self, issue_ids):
+ """Record mox expectations for ReindexIssues draining `issue_ids`.
+
+ The queue Select (ordered by 'created', limit 50) returns `issue_ids`.
+ When any are returned, the issue fetch recorded by SetUpGetIssues is
+ expected, followed by deletion of those IDs from the queue.
+ """
+ self.services.issue.reindexqueue_tbl.Select(
+ self.cnxn, order_by=[('created', [])],
+ limit=50).AndReturn([(issue_id,) for issue_id in issue_ids])
+
+ if issue_ids:
+ # NOTE(review): SetUpGetIssues yields exactly two issues, so callers are
+ # assumed to pass the two matching IDs (78901, 78902) — confirm.
+ _issue_1, _issue_2 = self.SetUpGetIssues()
+ self.services.issue.reindexqueue_tbl.Delete(
+ self.cnxn, issue_id=issue_ids)
+
+ def testReindexIssues_QueueEmpty(self):
+ self.SetUpReindexIssues([])
+ self.mox.ReplayAll()
+ self.services.issue.ReindexIssues(self.cnxn, 50, self.services.user)
+ self.mox.VerifyAll()
+
+ def testReindexIssues_QueueHasTwoIssues(self):
+ self.SetUpReindexIssues([78901, 78902])
+ self.mox.ReplayAll()
+ self.services.issue.ReindexIssues(self.cnxn, 50, self.services.user)
+ self.mox.VerifyAll()
+
+ ### Search functions
+
+ def SetUpRunIssueQuery(
+ self, rows, limit=settings.search_limit_per_shard):
+ self.services.issue.issue_tbl.Select(
+ self.cnxn, shard_id=1, distinct=True, cols=['Issue.id'],
+ left_joins=[], where=[('Issue.deleted = %s', [False])], order_by=[],
+ limit=limit).AndReturn(rows)
+
+ def testRunIssueQuery_NoResults(self):
+ self.SetUpRunIssueQuery([])
+ self.mox.ReplayAll()
+ result_iids, capped = self.services.issue.RunIssueQuery(
+ self.cnxn, [], [], [], shard_id=1)
+ self.mox.VerifyAll()
+ self.assertEqual([], result_iids)
+ self.assertFalse(capped)
+
+ def testRunIssueQuery_Normal(self):
+ self.SetUpRunIssueQuery([(1,), (11,), (21,)])
+ self.mox.ReplayAll()
+ result_iids, capped = self.services.issue.RunIssueQuery(
+ self.cnxn, [], [], [], shard_id=1)
+ self.mox.VerifyAll()
+ self.assertEqual([1, 11, 21], result_iids)
+ self.assertFalse(capped)
+
+ def testRunIssueQuery_Capped(self):
+ try:
+ orig = settings.search_limit_per_shard
+ settings.search_limit_per_shard = 3
+ self.SetUpRunIssueQuery([(1,), (11,), (21,)], limit=3)
+ self.mox.ReplayAll()
+ result_iids, capped = self.services.issue.RunIssueQuery(
+ self.cnxn, [], [], [], shard_id=1)
+ self.mox.VerifyAll()
+ self.assertEqual([1, 11, 21], result_iids)
+ self.assertTrue(capped)
+ finally:
+ settings.search_limit_per_shard = orig
+
+ def SetUpGetIIDsByLabelIDs(self):
+ """Expect the Issue/Issue2Label join for labels 123 and 456.
+
+ The query is scoped to project 789 and shard 1, and returns issue IDs 1-3.
+ """
+ self.services.issue.issue_tbl.Select(
+ self.cnxn, shard_id=1, cols=['id'],
+ left_joins=[('Issue2Label ON Issue.id = Issue2Label.issue_id', [])],
+ label_id=[123, 456], project_id=789,
+ where=[('shard = %s', [1])]
+ ).AndReturn([(1,), (2,), (3,)])
+
+ def testGetIIDsByLabelIDs(self):
+ self.SetUpGetIIDsByLabelIDs()
+ self.mox.ReplayAll()
+ iids = self.services.issue.GetIIDsByLabelIDs(self.cnxn, [123, 456], 789, 1)
+ self.mox.VerifyAll()
+ self.assertEqual([1, 2, 3], iids)
+
+ def testGetIIDsByLabelIDsWithEmptyLabelIds(self):
+ self.mox.ReplayAll()
+ iids = self.services.issue.GetIIDsByLabelIDs(self.cnxn, [], 789, 1)
+ self.mox.VerifyAll()
+ self.assertEqual([], iids)
+
+ def SetUpGetIIDsByParticipant(self):
+ """Record the five participant queries for users 111/888 in project 789.
+
+ Each query runs on shard 1 and returns one distinct issue ID (1-5), so the
+ caller can check that all results are combined.
+ """
+ # Reporter.
+ self.services.issue.issue_tbl.Select(
+ self.cnxn, shard_id=1, cols=['id'],
+ reporter_id=[111, 888],
+ where=[('shard = %s', [1]), ('Issue.project_id IN (%s)', [789])]
+ ).AndReturn([(1,)])
+ # Owner.
+ self.services.issue.issue_tbl.Select(
+ self.cnxn, shard_id=1, cols=['id'],
+ owner_id=[111, 888],
+ where=[('shard = %s', [1]), ('Issue.project_id IN (%s)', [789])]
+ ).AndReturn([(2,)])
+ # Derived owner.
+ self.services.issue.issue_tbl.Select(
+ self.cnxn, shard_id=1, cols=['id'],
+ derived_owner_id=[111, 888],
+ where=[('shard = %s', [1]), ('Issue.project_id IN (%s)', [789])]
+ ).AndReturn([(3,)])
+ # CC, via a join on Issue2Cc.
+ self.services.issue.issue_tbl.Select(
+ self.cnxn, shard_id=1, cols=['id'],
+ left_joins=[('Issue2Cc ON Issue2Cc.issue_id = Issue.id', [])],
+ cc_id=[111, 888],
+ where=[('shard = %s', [1]), ('Issue.project_id IN (%s)', [789]),
+ ('cc_id IS NOT NULL', [])]
+ ).AndReturn([(4,)])
+ # User-valued custom fields whose FieldDef grants the 'View' permission.
+ self.services.issue.issue_tbl.Select(
+ self.cnxn, shard_id=1, cols=['Issue.id'],
+ left_joins=[
+ ('Issue2FieldValue ON Issue.id = Issue2FieldValue.issue_id', []),
+ ('FieldDef ON Issue2FieldValue.field_id = FieldDef.id', [])],
+ user_id=[111, 888], grants_perm='View',
+ where=[('shard = %s', [1]), ('Issue.project_id IN (%s)', [789]),
+ ('user_id IS NOT NULL', [])]
+ ).AndReturn([(5,)])
+
+ def testGetIIDsByParticipant(self):
+ self.SetUpGetIIDsByParticipant()
+ self.mox.ReplayAll()
+ iids = self.services.issue.GetIIDsByParticipant(
+ self.cnxn, [111, 888], [789], 1)
+ self.mox.VerifyAll()
+ self.assertEqual([1, 2, 3, 4, 5], iids)
+
+ ### Issue Dependency reranking
+
+ def testSortBlockedOn(self):
+ issue = self.SetUpSortBlockedOn()
+ self.mox.ReplayAll()
+ ret = self.services.issue.SortBlockedOn(
+ self.cnxn, issue, issue.blocked_on_iids)
+ self.mox.VerifyAll()
+ self.assertEqual(ret, ([78902, 78903], [20, 10]))
+
+ def SetUpSortBlockedOn(self):
+ """Cache issue 78901 and record the rank-ordered blocked-on Select.
+
+ The issue is blocked on 78902 (rank 20) and 78903 (rank 10); the Select
+ orders by rank DESC, then dst_issue_id.
+
+ Returns:
+ The cached test issue.
+ """
+ issue = fake.MakeTestIssue(
+ project_id=789, local_id=1, owner_id=111, summary='sum',
+ status='Live', issue_id=78901)
+ issue.project_name = 'proj'
+ issue.blocked_on_iids = [78902, 78903]
+ issue.blocked_on_ranks = [20, 10]
+ self.services.issue.issue_2lc.CacheItem(78901, issue)
+ blocked_on_rows = (
+ (78901, 78902, 'blockedon', 20), (78901, 78903, 'blockedon', 10))
+ self.services.issue.issuerelation_tbl.Select(
+ self.cnxn, cols=issue_svc.ISSUERELATION_COLS,
+ issue_id=issue.issue_id, dst_issue_id=issue.blocked_on_iids,
+ kind='blockedon',
+ order_by=[('rank DESC', []), ('dst_issue_id', [])]).AndReturn(
+ blocked_on_rows)
+ return issue
+
+ def testApplyIssueRerank(self):
+ """ApplyIssueRerank rewrites blocked-on relations and invalidates the issue.
+
+ The old relation rows are deleted (commit=False), the new ranked rows are
+ inserted (commit=True), and issue 78901 is invalidated.
+ """
+ blocker_ids = [78902, 78903]
+ relations_to_change = list(zip(blocker_ids, [20, 10]))
+ self.services.issue.issuerelation_tbl.Delete(
+ self.cnxn, issue_id=78901, dst_issue_id=blocker_ids, commit=False)
+ insert_rows = [(78901, blocker_id, 'blockedon', rank)
+ for blocker_id, rank in relations_to_change]
+ self.services.issue.issuerelation_tbl.InsertRows(
+ self.cnxn, cols=issue_svc.ISSUERELATION_COLS, row_values=insert_rows,
+ commit=True)
+
+ # Stub InvalidateIIDs so that the cache invalidation can be verified.
+ self.mox.StubOutWithMock(self.services.issue, "InvalidateIIDs")
+
+ self.services.issue.InvalidateIIDs(self.cnxn, [78901])
+ self.mox.ReplayAll()
+ self.services.issue.ApplyIssueRerank(self.cnxn, 78901, relations_to_change)
+ self.mox.VerifyAll()
+
+ def testExpungeUsersInIssues(self):
+ comment_id_rows = [(12, 78901, 112), (13, 78902, 113)]
+ comment_ids = [12, 13]
+ content_ids = [112, 113]
+ self.services.issue.comment_tbl.Select = Mock(
+ return_value=comment_id_rows)
+ self.services.issue.commentcontent_tbl.Update = Mock()
+ self.services.issue.comment_tbl.Update = Mock()
+
+ fv_issue_id_rows = [(78902,), (78903,), (78904,)]
+ self.services.issue.issue2fieldvalue_tbl.Select = Mock(
+ return_value=fv_issue_id_rows)
+ self.services.issue.issue2fieldvalue_tbl.Delete = Mock()
+ self.services.issue.issueapproval2approver_tbl.Delete = Mock()
+ self.services.issue.issue2approvalvalue_tbl.Update = Mock()
+
+ self.services.issue.issueupdate_tbl.Update = Mock()
+
+ self.services.issue.issue2notify_tbl.Delete = Mock()
+
+ cc_issue_id_rows = [(78904,), (78905,), (78906,)]
+ self.services.issue.issue2cc_tbl.Select = Mock(
+ return_value=cc_issue_id_rows)
+ self.services.issue.issue2cc_tbl.Delete = Mock()
+ owner_issue_id_rows = [(78907,), (78908,), (78909,)]
+ derived_owner_issue_id_rows = [(78910,), (78911,), (78912,)]
+ reporter_issue_id_rows = [(78912,), (78913,)]
+ self.services.issue.issue_tbl.Select = Mock(
+ side_effect=[owner_issue_id_rows, derived_owner_issue_id_rows,
+ reporter_issue_id_rows])
+ self.services.issue.issue_tbl.Update = Mock()
+
+ self.services.issue.issuesnapshot_tbl.Update = Mock()
+ self.services.issue.issuesnapshot2cc_tbl.Delete = Mock()
+
+ emails = ['cow@farm.com', 'pig@farm.com', 'chicken@farm.com']
+ user_ids = [222, 888, 444]
+ user_ids_by_email = {
+ email: user_id for user_id, email in zip(user_ids, emails)}
+ commit = False
+ limit = 50
+
+ affected_user_ids = self.services.issue.ExpungeUsersInIssues(
+ self.cnxn, user_ids_by_email, limit=limit)
+ self.assertItemsEqual(
+ affected_user_ids,
+ [78901, 78902, 78903, 78904, 78905, 78906, 78907, 78908, 78909,
+ 78910, 78911, 78912, 78913])
+
+ self.services.issue.comment_tbl.Select.assert_called_once()
+ _cnxn, kwargs = self.services.issue.comment_tbl.Select.call_args
+ self.assertEqual(
+ kwargs['cols'], ['Comment.id', 'Comment.issue_id', 'commentcontent_id'])
+ self.assertItemsEqual(kwargs['commenter_id'], user_ids)
+ self.assertEqual(kwargs['limit'], limit)
+
+ # since user_ids are passed to ExpungeUsersInIssues via a dictionary,
+ # we cannot know the order of the user_ids list that the method
+ # ends up using. To be able to use assert_called_with()
+ # rather than extract call_args, we are saving the order of user_ids
+ # used by the method after confirming that it has the correct items.
+ user_ids = kwargs['commenter_id']
+
+ self.services.issue.commentcontent_tbl.Update.assert_called_once_with(
+ self.cnxn, {'inbound_message': None}, id=content_ids, commit=commit)
+ self.assertEqual(
+ len(self.services.issue.comment_tbl.Update.call_args_list), 2)
+ self.services.issue.comment_tbl.Update.assert_any_call(
+ self.cnxn, {'commenter_id': framework_constants.DELETED_USER_ID},
+ id=comment_ids, commit=False)
+ self.services.issue.comment_tbl.Update.assert_any_call(
+ self.cnxn, {'deleted_by': framework_constants.DELETED_USER_ID},
+ deleted_by=user_ids, commit=False, limit=limit)
+
+ # field values
+ self.services.issue.issue2fieldvalue_tbl.Select.assert_called_once_with(
+ self.cnxn, cols=['issue_id'], user_id=user_ids, limit=limit)
+ self.services.issue.issue2fieldvalue_tbl.Delete.assert_called_once_with(
+ self.cnxn, user_id=user_ids, limit=limit, commit=commit)
+
+ # approval values
+ self.services.issue.issueapproval2approver_tbl.\
+Delete.assert_called_once_with(
+ self.cnxn, approver_id=user_ids, commit=commit, limit=limit)
+ self.services.issue.issue2approvalvalue_tbl.Update.assert_called_once_with(
+ self.cnxn, {'setter_id': framework_constants.DELETED_USER_ID},
+ setter_id=user_ids, commit=commit, limit=limit)
+
+ # issue ccs
+ self.services.issue.issue2cc_tbl.Select.assert_called_once_with(
+ self.cnxn, cols=['issue_id'], cc_id=user_ids, limit=limit)
+ self.services.issue.issue2cc_tbl.Delete.assert_called_once_with(
+ self.cnxn, cc_id=user_ids, limit=limit, commit=commit)
+
+ # issue owners
+ self.services.issue.issue_tbl.Select.assert_any_call(
+ self.cnxn, cols=['id'], owner_id=user_ids, limit=limit)
+ self.services.issue.issue_tbl.Update.assert_any_call(
+ self.cnxn, {'owner_id': None},
+ id=[row[0] for row in owner_issue_id_rows], commit=commit)
+ self.services.issue.issue_tbl.Select.assert_any_call(
+ self.cnxn, cols=['id'], derived_owner_id=user_ids, limit=limit)
+ self.services.issue.issue_tbl.Update.assert_any_call(
+ self.cnxn, {'derived_owner_id': None},
+ id=[row[0] for row in derived_owner_issue_id_rows], commit=commit)
+
+ # issue reporter
+ self.services.issue.issue_tbl.Select.assert_any_call(
+ self.cnxn, cols=['id'], reporter_id=user_ids, limit=limit)
+ self.services.issue.issue_tbl.Update.assert_any_call(
+ self.cnxn, {'reporter_id': framework_constants.DELETED_USER_ID},
+ id=[row[0] for row in reporter_issue_id_rows], commit=commit)
+
+ self.assertEqual(
+ 3, len(self.services.issue.issue_tbl.Update.call_args_list))
+
+ # issue updates
+ self.services.issue.issueupdate_tbl.Update.assert_any_call(
+ self.cnxn, {'added_user_id': framework_constants.DELETED_USER_ID},
+ added_user_id=user_ids, commit=commit)
+ self.services.issue.issueupdate_tbl.Update.assert_any_call(
+ self.cnxn, {'removed_user_id': framework_constants.DELETED_USER_ID},
+ removed_user_id=user_ids, commit=commit)
+ self.assertEqual(
+ 2, len(self.services.issue.issueupdate_tbl.Update.call_args_list))
+
+ # issue notify
+ call_args_list = self.services.issue.issue2notify_tbl.Delete.call_args_list
+ self.assertEqual(1, len(call_args_list))
+ _cnxn, kwargs = call_args_list[0]
+ self.assertItemsEqual(kwargs['email'], emails)
+ self.assertEqual(kwargs['commit'], commit)
+
+ # issue snapshots
+ self.services.issue.issuesnapshot_tbl.Update.assert_any_call(
+ self.cnxn, {'owner_id': framework_constants.DELETED_USER_ID},
+ owner_id=user_ids, commit=commit, limit=limit)
+ self.services.issue.issuesnapshot_tbl.Update.assert_any_call(
+ self.cnxn, {'reporter_id': framework_constants.DELETED_USER_ID},
+ reporter_id=user_ids, commit=commit, limit=limit)
+ self.assertEqual(
+ 2, len(self.services.issue.issuesnapshot_tbl.Update.call_args_list))
+
+ self.services.issue.issuesnapshot2cc_tbl.Delete.assert_called_once_with(
+ self.cnxn, cc_id=user_ids, commit=commit, limit=limit)
diff --git a/services/test/ml_helpers_test.py b/services/test/ml_helpers_test.py
new file mode 100644
index 0000000..45a29cc
--- /dev/null
+++ b/services/test/ml_helpers_test.py
@@ -0,0 +1,120 @@
+# coding=utf-8
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+
+import io
+import unittest
+
+from services import ml_helpers
+
+
+NUM_WORD_HASHES = 5
+
+TOP_WORDS = {'cat': 0, 'dog': 1, 'bunny': 2, 'chinchilla': 3, 'hamster': 4}
+NUM_COMPONENT_FEATURES = len(TOP_WORDS)
+
+
+class MLHelpersTest(unittest.TestCase):
+ """Tests for the feature-extraction helpers in services.ml_helpers."""
+
+ def testSpamHashFeatures(self):
+ """Check _SpamHashFeatures bucket values for empty, blank and short input."""
+ hashes = ml_helpers._SpamHashFeatures(tuple(), NUM_WORD_HASHES)
+ self.assertEqual([0, 0, 0, 0, 0], hashes)
+
+ hashes = ml_helpers._SpamHashFeatures(('', ''), NUM_WORD_HASHES)
+ self.assertEqual([1.0, 0, 0, 0, 0], hashes)
+
+ hashes = ml_helpers._SpamHashFeatures(('abc', 'abc def'), NUM_WORD_HASHES)
+ self.assertEqual([0, 0, 2 / 3, 0, 1 / 3], hashes)
+
+ def testComponentFeatures(self):
+ """Check that _ComponentFeatures flags which TOP_WORDS occur in the text."""
+
+ features = ml_helpers._ComponentFeatures(['cat dog is not bunny'
+ ' chinchilla hamster'],
+ NUM_COMPONENT_FEATURES,
+ TOP_WORDS)
+ self.assertEqual([1, 1, 1, 1, 1], features)
+
+ features = ml_helpers._ComponentFeatures(['none of these are features'],
+ NUM_COMPONENT_FEATURES,
+ TOP_WORDS)
+ self.assertEqual([0, 0, 0, 0, 0], features)
+
+ # 'hamsters' does not count as 'hamster': only whole words match.
+ features = ml_helpers._ComponentFeatures(['do hamsters look like a'
+ ' chinchilla'],
+ NUM_COMPONENT_FEATURES,
+ TOP_WORDS)
+ self.assertEqual([0, 0, 0, 1, 0], features)
+
+ features = ml_helpers._ComponentFeatures([''],
+ NUM_COMPONENT_FEATURES,
+ TOP_WORDS)
+ self.assertEqual([0, 0, 0, 0, 0], features)
+
+ def testGenerateFeaturesRaw(self):
+ """Check GenerateFeaturesRaw on ASCII, URL, non-ASCII and empty inputs."""
+
+ features = ml_helpers.GenerateFeaturesRaw(
+ ['abc', 'abc def http://www.google.com http://www.google.com'],
+ NUM_WORD_HASHES)
+ self.assertEqual(
+ [1 / 2.75, 0.0, 1 / 5.5, 0.0, 1 / 2.2], features['word_hashes'])
+
+ features = ml_helpers.GenerateFeaturesRaw(['abc', 'abc def'],
+ NUM_WORD_HASHES)
+ self.assertEqual([0.0, 0.0, 2 / 3, 0.0, 1 / 3], features['word_hashes'])
+
+ features = ml_helpers.GenerateFeaturesRaw(['do hamsters look like a'
+ ' chinchilla'],
+ NUM_COMPONENT_FEATURES,
+ TOP_WORDS)
+ self.assertEqual([0, 0, 0, 1, 0], features['word_features'])
+
+ # BMP Unicode (U+2019 right single quotation mark).
+ features = ml_helpers.GenerateFeaturesRaw(
+ [u'abc’', u'abc ’ def'], NUM_WORD_HASHES)
+ self.assertEqual([0.0, 0.0, 0.25, 0.25, 0.5], features['word_hashes'])
+
+ # CJK Unicode (U+570B). Note: this character is inside the BMP, not
+ # outside it as the original label "Non-BMP" suggested.
+ features = ml_helpers.GenerateFeaturesRaw([u'abc國', u'abc 國 def'],
+ NUM_WORD_HASHES)
+ self.assertEqual([0.0, 0.0, 0.25, 0.25, 0.5], features['word_hashes'])
+
+ # Literal without the u prefix containing a non-ASCII character: a UTF-8
+ # bytestring under Python 2, but a plain str under Python 3.
+ features = ml_helpers.GenerateFeaturesRaw(['abc…', 'abc … def'],
+ NUM_WORD_HASHES)
+ self.assertEqual([0.25, 0.0, 0.25, 0.25, 0.25], features['word_hashes'])
+
+ # Empty input
+ features = ml_helpers.GenerateFeaturesRaw(['', ''], NUM_WORD_HASHES)
+ self.assertEqual([1.0, 0.0, 0.0, 0.0, 0.0], features['word_hashes'])
+
+ def test_from_file(self):
+ """spam_from_file skips short rows and drops the trailing email column."""
+ csv_file = io.StringIO(
+ u'''
+ "spam","the subject 1","the contents 1","spammer@gmail.com"
+ "ham","the subject 2"
+ "spam","the subject 3","the contents 2","spammer2@gmail.com"
+ '''.strip())
+ samples, skipped = ml_helpers.spam_from_file(csv_file)
+ self.assertEqual(len(samples), 2)
+ self.assertEqual(skipped, 1)
+ self.assertEqual(len(samples[1]), 3, 'Strips email')
+ self.assertEqual(samples[1][2], 'the contents 2')
+
+ def test_transform_csv_to_features(self):
+ """transform_spam_csv_to_features yields feature dicts and 1/0 labels."""
+ training_data = [
+ ['spam', 'subject 1', 'contents 1'],
+ ['ham', 'subject 2', 'contents 2'],
+ ['spam', 'subject 3', 'contents 3'],
+ ]
+ X, y = ml_helpers.transform_spam_csv_to_features(training_data)
+
+ self.assertIsInstance(X, list)
+ self.assertIsInstance(X[0], dict)
+ self.assertIsInstance(y, list)
+
+ self.assertEqual(len(X), 3)
+ self.assertEqual(len(y), 3)
+
+ # NOTE(review): 500 is presumably ml_helpers' default hash-bucket count —
+ # confirm against the module.
+ self.assertEqual(len(X[0]['word_hashes']), 500)
+ self.assertEqual(y, [1, 0, 1])
diff --git a/services/test/project_svc_test.py b/services/test/project_svc_test.py
new file mode 100644
index 0000000..2eb7a2b
--- /dev/null
+++ b/services/test/project_svc_test.py
@@ -0,0 +1,631 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for the project_svc module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import time
+import unittest
+
+import mox
+import mock
+
+from google.appengine.ext import testbed
+
+from framework import framework_constants
+from framework import sql
+from proto import project_pb2
+from proto import user_pb2
+from services import config_svc
+from services import project_svc
+from testing import fake
+
+# Fixed timestamp placed in the timestamp columns of the fake project rows.
+NOW = 12345678
+
+
+def MakeProjectService(cache_manager, my_mox):
+ project_service = project_svc.ProjectService(cache_manager)
+ project_service.project_tbl = my_mox.CreateMock(sql.SQLTableManager)
+ project_service.user2project_tbl = my_mox.CreateMock(sql.SQLTableManager)
+ project_service.extraperm_tbl = my_mox.CreateMock(sql.SQLTableManager)
+ project_service.membernotes_tbl = my_mox.CreateMock(sql.SQLTableManager)
+ project_service.usergroupprojects_tbl = my_mox.CreateMock(
+ sql.SQLTableManager)
+ project_service.acexclusion_tbl = my_mox.CreateMock(
+ sql.SQLTableManager)
+ return project_service
+
+
+class ProjectTwoLevelCacheTest(unittest.TestCase):
+
+ def setUp(self):
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_memcache_stub()
+
+ self.mox = mox.Mox()
+ self.cnxn = self.mox.CreateMock(sql.MonorailConnection)
+ self.cache_manager = fake.CacheManager()
+ self.project_service = MakeProjectService(self.cache_manager, self.mox)
+
+ def tearDown(self):
+ self.testbed.deactivate()
+
+ def testDeserializeProjects(self):
+ project_rows = [
+ (
+ 123, 'proj1', 'test proj 1', 'test project', 'live', 'anyone', '',
+ '', None, '', 0, 50 * 1024 * 1024, NOW, NOW, None, True, False,
+ False, None, None, None, None, None, None, False),
+ (
+ 234, 'proj2', 'test proj 2', 'test project', 'live', 'anyone', '',
+ '', None, '', 0, 50 * 1024 * 1024, NOW, NOW, None, True, False,
+ False, None, None, None, None, None, None, True)
+ ]
+ role_rows = [
+ (123, 111, 'owner'), (123, 444, 'owner'),
+ (123, 222, 'committer'),
+ (123, 333, 'contributor'),
+ (234, 111, 'owner')]
+ extraperm_rows = []
+
+ project_dict = self.project_service.project_2lc._DeserializeProjects(
+ project_rows, role_rows, extraperm_rows)
+
+ self.assertItemsEqual([123, 234], list(project_dict.keys()))
+ self.assertEqual(123, project_dict[123].project_id)
+ self.assertEqual('proj1', project_dict[123].project_name)
+ self.assertEqual(NOW, project_dict[123].recent_activity)
+ self.assertItemsEqual([111, 444], project_dict[123].owner_ids)
+ self.assertItemsEqual([222], project_dict[123].committer_ids)
+ self.assertItemsEqual([333], project_dict[123].contributor_ids)
+ self.assertEqual(234, project_dict[234].project_id)
+ self.assertItemsEqual([111], project_dict[234].owner_ids)
+ self.assertEqual(False, project_dict[123].issue_notify_always_detailed)
+ self.assertEqual(True, project_dict[234].issue_notify_always_detailed)
+
+
+class UserToProjectIdTwoLevelCacheTest(unittest.TestCase):
+
+ def setUp(self):
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_memcache_stub()
+
+ self.cnxn = fake.MonorailConnection()
+ self.cache_manager = fake.CacheManager()
+ self.mox = mox.Mox()
+ self.project_service = MakeProjectService(self.cache_manager, self.mox)
+ self.user_to_project_2lc = self.project_service.user_to_project_2lc
+
+ # Set up DB query mocks.
+ self.cached_user_ids = [100, 101]
+ self.from_db_user_ids = [102, 103]
+ test_table = [
+ (900, self.cached_user_ids[0]), # Project 900, User 100
+ (900, self.cached_user_ids[1]), # Project 900, User 101
+ (901, self.cached_user_ids[0]), # Project 901, User 101
+ (902, self.from_db_user_ids[0]), # Project 902, User 102
+ (902, self.from_db_user_ids[1]), # Project 902, User 103
+ (903, self.from_db_user_ids[0]), # Project 903, User 102
+ ]
+ self.project_service.user2project_tbl.Select = mock.Mock(
+ return_value=test_table)
+
+ def tearDown(self):
+ # memcache.flush_all()
+ self.testbed.deactivate()
+ self.mox.UnsetStubs()
+ self.mox.ResetAll()
+
+ def testGetAll(self):
+ # Cache user 100 and 101.
+ self.user_to_project_2lc.CacheItem(self.cached_user_ids[0], set([900, 901]))
+ self.user_to_project_2lc.CacheItem(self.cached_user_ids[1], set([900]))
+ # Test that other project_ids and user_ids get returned by DB queries.
+ first_hit, first_misses = self.user_to_project_2lc.GetAll(
+ self.cnxn, self.cached_user_ids + self.from_db_user_ids)
+
+ self.project_service.user2project_tbl.Select.assert_called_once_with(
+ self.cnxn, cols=['project_id', 'user_id'])
+
+ self.assertEqual(
+ first_hit, {
+ 100: set([900, 901]),
+ 101: set([900]),
+ 102: set([902, 903]),
+ 103: set([902]),
+ })
+ self.assertEqual([], first_misses)
+
+ def testGetAllRateLimit(self):
+ test_now = time.time()
+ # Initial request that queries table.
+ self.user_to_project_2lc._GetCurrentTime = mock.Mock(
+ return_value=test_now + 60)
+ self.user_to_project_2lc.GetAll(
+ self.cnxn, self.cached_user_ids + self.from_db_user_ids)
+
+ # Request a user with no projects right after the last request.
+ self.user_to_project_2lc._GetCurrentTime = mock.Mock(
+ return_value=test_now + 61)
+ second_hit, second_misses = self.user_to_project_2lc.GetAll(
+ self.cnxn, [104])
+
+ # Request one more user without project that should make a DB request
+ # because the required rate limit time has passed.
+ self.user_to_project_2lc._GetCurrentTime = mock.Mock(
+ return_value=test_now + 121)
+ third_hit, third_misses = self.user_to_project_2lc.GetAll(self.cnxn, [105])
+
+ # Queried only twice because the second request was rate limited.
+ self.assertEqual(self.project_service.user2project_tbl.Select.call_count, 2)
+
+ # Rate limited response will not return the full table.
+ self.assertEqual(second_hit, {
+ 104: set([]),
+ })
+ self.assertEqual([], second_misses)
+ self.assertEqual(
+ third_hit, {
+ 100: set([900, 901]),
+ 101: set([900]),
+ 102: set([902, 903]),
+ 103: set([902]),
+ 105: set([]),
+ })
+ self.assertEqual([], third_misses)
+
+
+class ProjectServiceTest(unittest.TestCase):
+
+ def setUp(self):
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_memcache_stub()
+
+ self.mox = mox.Mox()
+ self.cnxn = self.mox.CreateMock(sql.MonorailConnection)
+ self.cache_manager = fake.CacheManager()
+ self.config_service = self.mox.CreateMock(config_svc.ConfigService)
+ self.project_service = MakeProjectService(self.cache_manager, self.mox)
+
+ self.proj1 = fake.Project(project_name='proj1', project_id=123)
+ self.proj2 = fake.Project(project_name='proj2', project_id=234)
+
+ def tearDown(self):
+ """Deactivate the testbed, then undo mox stubs and reset all mocks."""
+ self.testbed.deactivate()
+ self.mox.UnsetStubs()
+ self.mox.ResetAll()
+
+ def SetUpCreateProject(self):
+ """Record the DB calls CreateProject makes for a member-less 'proj1'."""
+ # Check for existing project: there should be none.
+ self.project_service.project_tbl.Select(
+ self.cnxn, cols=['project_name', 'project_id'],
+ project_name=['proj1']).AndReturn([])
+
+ # Inserting the project gives the project ID.
+ self.project_service.project_tbl.InsertRow(
+ self.cnxn, project_name='proj1',
+ summary='Test project summary', description='Test project description',
+ home_page=None, docs_url=None, source_url=None,
+ logo_file_name=None, logo_gcs_id=None,
+ state='LIVE', access='ANYONE').AndReturn(123)
+
+ # Insert the users. There are none.
+ self.project_service.user2project_tbl.InsertRows(
+ self.cnxn, ['project_id', 'user_id', 'role_name'], [])
+
+ def testCreateProject(self):
+ self.SetUpCreateProject()
+ self.mox.ReplayAll()
+ self.project_service.CreateProject(
+ self.cnxn, 'proj1', owner_ids=[], committer_ids=[], contributor_ids=[],
+ summary='Test project summary', description='Test project description')
+ self.mox.VerifyAll()
+
+ def SetUpLookupProjectIDs(self):
+ """Expect a DB name lookup for 'proj2' only, returning ID 234.
+
+ Callers are expected to have cached any other names they look up.
+ """
+ self.project_service.project_tbl.Select(
+ self.cnxn, cols=['project_name', 'project_id'],
+ project_name=['proj2']).AndReturn([('proj2', 234)])
+
+ def testLookupProjectIDs(self):
+ self.SetUpLookupProjectIDs()
+ self.project_service.project_names_to_ids.CacheItem('proj1', 123)
+ self.mox.ReplayAll()
+ id_dict = self.project_service.LookupProjectIDs(
+ self.cnxn, ['proj1', 'proj2'])
+ self.mox.VerifyAll()
+ self.assertEqual({'proj1': 123, 'proj2': 234}, id_dict)
+
+ def testLookupProjectNames(self):
+ self.SetUpGetProjects() # Same as testGetProjects()
+ self.project_service.project_2lc.CacheItem(123, self.proj1)
+ self.mox.ReplayAll()
+ name_dict = self.project_service.LookupProjectNames(
+ self.cnxn, [123, 234])
+ self.mox.VerifyAll()
+ self.assertEqual({123: 'proj1', 234: 'proj2'}, name_dict)
+
+ def SetUpGetProjects(self, roles=None, extra_perms=None):
+ """Record the three Selects that load project 234 from the DB.
+
+ Args:
+ roles: rows for the user2project Select; no roles when omitted.
+ extra_perms: rows for the extraperm Select; none when omitted.
+ """
+ project_rows = [
+ (
+ 234, 'proj2', 'test proj 2', 'test project', 'live', 'anyone', '',
+ '', None, '', 0, 50 * 1024 * 1024, NOW, NOW, None, True, False,
+ False, None, None, None, None, None, None, False)
+ ]
+ self.project_service.project_tbl.Select(
+ self.cnxn, cols=project_svc.PROJECT_COLS,
+ project_id=[234]).AndReturn(project_rows)
+ self.project_service.user2project_tbl.Select(
+ self.cnxn, cols=['project_id', 'user_id', 'role_name'],
+ project_id=[234]).AndReturn(roles or [])
+ self.project_service.extraperm_tbl.Select(
+ self.cnxn, cols=project_svc.EXTRAPERM_COLS,
+ project_id=[234]).AndReturn(extra_perms or [])
+
+  def testGetProjects(self):
+    """GetProjects merges cached project 123 with project 234 from the DB."""
+    self.project_service.project_2lc.CacheItem(123, self.proj1)
+    self.SetUpGetProjects()
+    self.mox.ReplayAll()
+    project_dict = self.project_service.GetProjects(
+        self.cnxn, [123, 234])
+    self.mox.VerifyAll()
+    self.assertItemsEqual([123, 234], list(project_dict.keys()))
+    self.assertEqual('proj1', project_dict[123].project_name)
+    self.assertEqual('proj2', project_dict[234].project_name)
+
+  def testGetProjects_ExtraPerms(self):
+    """extraperm rows are attached to the loaded project as ExtraPerms."""
+    # Rows arrive with member 222 first, but the expected list has 111 first
+    # (presumably the result is sorted by member_id).
+    self.SetUpGetProjects(extra_perms=[(234, 222, 'BarPerm'),
+                                       (234, 111, 'FooPerm')])
+    self.mox.ReplayAll()
+    project_dict = self.project_service.GetProjects(self.cnxn, [234])
+    self.mox.VerifyAll()
+    self.assertItemsEqual([234], list(project_dict.keys()))
+    self.assertEqual(
+        [project_pb2.Project.ExtraPerms(
+             member_id=111, perms=['FooPerm']),
+         project_pb2.Project.ExtraPerms(
+             member_id=222, perms=['BarPerm'])],
+        project_dict[234].extra_perms)
+
+
+  def testGetVisibleLiveProjects_AnyoneAccessWithUser(self):
+    """A signed-in user can see a LIVE project whose access is 'anyone'."""
+    # Same 25-column row shape as SetUpGetProjects and the sibling tests (this
+    # fixture previously had only 22 columns). The Select below requests only
+    # 'project_id', so presumably just the first column is consumed.
+    project_rows = [
+        (
+            234, 'proj2', 'test proj 2', 'test project', 'live', 'anyone', '',
+            '', None, '', 0, 50 * 1024 * 1024, NOW, NOW, None, True, False,
+            False, None, None, None, None, None, None, False)
+    ]
+
+    self.project_service.project_tbl.Select(
+        self.cnxn, cols=['project_id'],
+        state=project_pb2.ProjectState.LIVE).AndReturn(project_rows)
+    # Project 234 is then expected to be loaded from the DB.
+    self.SetUpGetProjects()
+    self.mox.ReplayAll()
+    user_a = user_pb2.User(email='a@example.com')
+    project_ids = self.project_service.GetVisibleLiveProjects(
+        self.cnxn, user_a, set([111]))
+
+    self.mox.VerifyAll()
+    self.assertItemsEqual([234], project_ids)
+
+  def testGetVisibleLiveProjects_AnyoneAccessWithAnon(self):
+    """An anonymous viewer (no user, no effective IDs) sees 'anyone' projects."""
+    # Rows for the LIVE-state query; only 'project_id' is requested.
+    project_rows = [
+        (
+            234, 'proj2', 'test proj 2', 'test project', 'live', 'anyone', '',
+            '', None, '', 0, 50 * 1024 * 1024, NOW, NOW, None, True, False,
+            False, None, None, None, None, None, None, False)
+    ]
+
+    self.project_service.project_tbl.Select(
+        self.cnxn, cols=['project_id'],
+        state=project_pb2.ProjectState.LIVE).AndReturn(project_rows)
+    # Project 234 is then expected to be loaded from the DB.
+    self.SetUpGetProjects()
+    self.mox.ReplayAll()
+    project_ids = self.project_service.GetVisibleLiveProjects(
+        self.cnxn, None, None)
+
+    self.mox.VerifyAll()
+    self.assertItemsEqual([234], project_ids)
+
+  def testGetVisibleLiveProjects_RestrictedAccessWithMember(self):
+    """A project member can see a LIVE members-only project."""
+    # 25-column row, consistent with the sibling members_only fixtures (this
+    # fixture previously carried an extra 26th column).
+    project_rows = [
+        (
+            234, 'proj2', 'test proj 2', 'test project', 'live', 'members_only',
+            '', '', None, '', 0, 50 * 1024 * 1024, NOW, NOW, None, True, False,
+            False, None, None, None, None, None, None, False)
+    ]
+    # Make user 111 a contributor of the cached members-only project, so no DB
+    # load of the project itself is expected.
+    self.proj2.access = project_pb2.ProjectAccess.MEMBERS_ONLY
+    self.proj2.contributor_ids.append(111)
+    self.project_service.project_2lc.CacheItem(234, self.proj2)
+
+    self.project_service.project_tbl.Select(
+        self.cnxn, cols=['project_id'],
+        state=project_pb2.ProjectState.LIVE).AndReturn(project_rows)
+    self.mox.ReplayAll()
+    user_a = user_pb2.User(email='a@example.com')
+    project_ids = self.project_service.GetVisibleLiveProjects(
+        self.cnxn, user_a, set([111]))
+
+    self.mox.VerifyAll()
+    self.assertItemsEqual([234], project_ids)
+
+  def testGetVisibleLiveProjects_RestrictedAccessWithNonMember(self):
+    """A signed-in non-member cannot see a members-only project."""
+    project_rows = [
+        (
+            234, 'proj2', 'test proj 2', 'test project', 'live', 'members_only',
+            '', '', None, '', 0, 50 * 1024 * 1024, NOW, NOW, None, True, False,
+            False, None, None, None, None, None, None, False)
+    ]
+    # The cached project is members-only and user 111 is not added as a
+    # member (unlike the WithMember test).
+    self.proj2.access = project_pb2.ProjectAccess.MEMBERS_ONLY
+    self.project_service.project_2lc.CacheItem(234, self.proj2)
+
+    self.project_service.project_tbl.Select(
+        self.cnxn, cols=['project_id'],
+        state=project_pb2.ProjectState.LIVE).AndReturn(project_rows)
+    self.mox.ReplayAll()
+    user_a = user_pb2.User(email='a@example.com')
+    project_ids = self.project_service.GetVisibleLiveProjects(
+        self.cnxn, user_a, set([111]))
+
+    self.mox.VerifyAll()
+    self.assertItemsEqual([], project_ids)
+
+  def testGetVisibleLiveProjects_RestrictedAccessWithAnon(self):
+    """An anonymous viewer cannot see a members-only project."""
+    project_rows = [
+        (
+            234, 'proj2', 'test proj 2', 'test project', 'live', 'members_only',
+            '', '', None, '', 0, 50 * 1024 * 1024, NOW, NOW, None, True, False,
+            False, None, None, None, None, None, None, False)
+    ]
+    self.proj2.access = project_pb2.ProjectAccess.MEMBERS_ONLY
+    self.project_service.project_2lc.CacheItem(234, self.proj2)
+
+    self.project_service.project_tbl.Select(
+        self.cnxn, cols=['project_id'],
+        state=project_pb2.ProjectState.LIVE).AndReturn(project_rows)
+    self.mox.ReplayAll()
+    # No user and no effective IDs: an anonymous request.
+    project_ids = self.project_service.GetVisibleLiveProjects(
+        self.cnxn, None, None)
+
+    self.mox.VerifyAll()
+    self.assertItemsEqual([], project_ids)
+
+  def testGetVisibleLiveProjects_RestrictedAccessWithSiteAdmin(self):
+    """A site admin sees a members-only project without being a member."""
+    project_rows = [
+        (
+            234, 'proj2', 'test proj 2', 'test project', 'live', 'members_only',
+            '', '', None, '', 0, 50 * 1024 * 1024, NOW, NOW, None, True, False,
+            False, None, None, None, None, None, None, False)
+    ]
+    self.proj2.access = project_pb2.ProjectAccess.MEMBERS_ONLY
+    self.project_service.project_2lc.CacheItem(234, self.proj2)
+
+    self.project_service.project_tbl.Select(
+        self.cnxn, cols=['project_id'],
+        state=project_pb2.ProjectState.LIVE).AndReturn(project_rows)
+    self.mox.ReplayAll()
+    user_a = user_pb2.User(email='a@example.com')
+    # The site-admin bit is what grants access; 111 is not a project member.
+    user_a.is_site_admin = True
+    project_ids = self.project_service.GetVisibleLiveProjects(
+        self.cnxn, user_a, set([111]))
+
+    self.mox.VerifyAll()
+    self.assertItemsEqual([234], project_ids)
+
+  def testGetVisibleLiveProjects_ArchivedProject(self):
+    """A project whose loaded state is ARCHIVED is excluded."""
+    project_rows = [
+        (
+            234, 'proj2', 'test proj 2', 'test project', 'archived', 'anyone',
+            '', '', None, '', 0, 50 * 1024 * 1024, NOW, NOW, None, True, False,
+            False, None, None, None, None, None, None, False)
+    ]
+    # The LIVE-state query below is mocked to still return project 234, but
+    # the cached proj2 is ARCHIVED, so it must be filtered out anyway.
+    self.proj2.state = project_pb2.ProjectState.ARCHIVED
+    self.project_service.project_2lc.CacheItem(234, self.proj2)
+
+    self.project_service.project_tbl.Select(
+        self.cnxn, cols=['project_id'],
+        state=project_pb2.ProjectState.LIVE).AndReturn(project_rows)
+    self.mox.ReplayAll()
+    user_a = user_pb2.User(email='a@example.com')
+    project_ids = self.project_service.GetVisibleLiveProjects(
+        self.cnxn, user_a, set([111]))
+
+    self.mox.VerifyAll()
+    self.assertItemsEqual([], project_ids)
+
+  def testGetProjectsByName(self):
+    """proj1 resolves entirely from caches; proj2 needs an ID and row load."""
+    self.project_service.project_names_to_ids.CacheItem('proj1', 123)
+    self.project_service.project_2lc.CacheItem(123, self.proj1)
+    # Recorded in call order: name->ID lookup for proj2, then its project row.
+    self.SetUpLookupProjectIDs()
+    self.SetUpGetProjects()
+    self.mox.ReplayAll()
+    project_dict = self.project_service.GetProjectsByName(
+        self.cnxn, ['proj1', 'proj2'])
+    self.mox.VerifyAll()
+    self.assertItemsEqual(['proj1', 'proj2'], list(project_dict.keys()))
+    self.assertEqual(123, project_dict['proj1'].project_id)
+    self.assertEqual(234, project_dict['proj2'].project_id)
+
+  def SetUpExpungeProject(self):
+    """Expect project 234's rows to be deleted from each project table.
+
+    The tuple order is the order in which the Delete calls are recorded.
+    """
+    svc = self.project_service
+    tables_in_delete_order = (
+        svc.user2project_tbl,
+        svc.usergroupprojects_tbl,
+        svc.extraperm_tbl,
+        svc.membernotes_tbl,
+        svc.acexclusion_tbl,
+        svc.project_tbl,
+    )
+    for table in tables_in_delete_order:
+      table.Delete(self.cnxn, project_id=234)
+
+  def testExpungeProject(self):
+    """ExpungeProject deletes project 234 from every table in SetUp order."""
+    self.SetUpExpungeProject()
+    self.mox.ReplayAll()
+    self.project_service.ExpungeProject(self.cnxn, 234)
+    self.mox.VerifyAll()
+
+  def SetUpUpdateProject(self, project_id, delta):
+    """Expect one project row to be updated with `delta`, then committed.
+
+    Args:
+      project_id: ID of the project row being updated.
+      delta: dict of column name -> new value expected in project_tbl.Update.
+    """
+    # The project name is read first; 'projN' is an arbitrary stand-in value.
+    self.project_service.project_tbl.SelectValue(
+        self.cnxn, 'project_name', project_id=project_id).AndReturn('projN')
+    self.project_service.project_tbl.Update(
+        self.cnxn, delta, project_id=project_id, commit=False)
+    self.cnxn.Commit()
+
+  def testUpdateProject(self):
+    """A summary keyword argument becomes a 'summary' column update."""
+    delta = {'summary': 'An even better one-line summary'}
+    self.SetUpUpdateProject(234, delta)
+    self.mox.ReplayAll()
+    self.project_service.UpdateProject(
+        self.cnxn, 234, summary='An even better one-line summary')
+    self.mox.VerifyAll()
+
+  def testUpdateProject_NotifyAlwaysDetailed(self):
+    """issue_notify_always_detailed is written through as a column update."""
+    delta = {'issue_notify_always_detailed': True}
+    self.SetUpUpdateProject(234, delta)
+    self.mox.ReplayAll()
+    self.project_service.UpdateProject(
+        self.cnxn, 234, issue_notify_always_detailed=True)
+    self.mox.VerifyAll()
+
+  def SetUpUpdateProjectRoles(
+      self, project_id, owner_ids, committer_ids, contributor_ids):
+    """Expect every role membership of a project to be replaced.
+
+    All writes use commit=False and are followed by one cnxn.Commit().
+
+    Args:
+      project_id: project whose memberships are rewritten.
+      owner_ids, committer_ids, contributor_ids: user IDs expected to be
+          inserted for the corresponding role.
+    """
+    svc = self.project_service
+    svc.project_tbl.SelectValue(
+        self.cnxn, 'project_name', project_id=project_id).AndReturn('projN')
+    svc.project_tbl.Update(
+        self.cnxn, {'cached_content_timestamp': NOW}, project_id=project_id,
+        commit=False)
+
+    role_members = [
+        ('owner', owner_ids),
+        ('committer', committer_ids),
+        ('contributor', contributor_ids),
+    ]
+    # Old rows for every role are deleted before any new rows are inserted.
+    for role_name, _ in role_members:
+      svc.user2project_tbl.Delete(
+          self.cnxn, project_id=project_id, role_name=role_name, commit=False)
+    for role_name, member_ids in role_members:
+      svc.user2project_tbl.InsertRows(
+          self.cnxn, ['project_id', 'user_id', 'role_name'],
+          [(project_id, member_id, role_name) for member_id in member_ids],
+          commit=False)
+
+    self.cnxn.Commit()
+
+  def testUpdateProjectRoles(self):
+    """Two owners, one committer and no contributors are written for 234."""
+    self.SetUpUpdateProjectRoles(234, [111, 222], [333], [])
+    self.mox.ReplayAll()
+    # now=NOW pins the cached_content_timestamp expected by the SetUp helper.
+    self.project_service.UpdateProjectRoles(
+        self.cnxn, 234, [111, 222], [333], [], now=NOW)
+    self.mox.VerifyAll()
+
+  def SetUpMarkProjectDeletable(self):
+    """Expect project 123 to be renamed, marked deletable and uncached."""
+    # The new name embeds the project ID ('DELETABLE_123').
+    delta = {
+        'project_name': 'DELETABLE_123',
+        'state': 'deletable',
+    }
+    self.project_service.project_tbl.Update(self.cnxn, delta, project_id=123)
+    self.config_service.InvalidateMemcacheForEntireProject(123)
+
+  def testMarkProjectDeletable(self):
+    """MarkProjectDeletable updates the row and invalidates config memcache."""
+    self.SetUpMarkProjectDeletable()
+    self.mox.ReplayAll()
+    self.project_service.MarkProjectDeletable(
+        self.cnxn, 123, self.config_service)
+    self.mox.VerifyAll()
+
+  def testUpdateRecentActivity_SignificantlyLaterActivity(self):
+    """Activity three hours after NOW is written as the new timestamp."""
+    activity_time = NOW + framework_constants.SECS_PER_HOUR * 3
+    delta = {'recent_activity_timestamp': activity_time}
+    # The project is loaded first, then updated with the new timestamp.
+    self.SetUpGetProjects()
+    self.SetUpUpdateProject(234, delta)
+    self.mox.ReplayAll()
+    self.project_service.UpdateRecentActivity(self.cnxn, 234, now=activity_time)
+    self.mox.VerifyAll()
+
+  def testUpdateRecentActivity_NotSignificant(self):
+    """Activity only 123 seconds after NOW does not trigger a DB write."""
+    activity_time = NOW + 123
+    self.SetUpGetProjects()
+    # No SetUpUpdateProject expectations are recorded, so any project_tbl
+    # write (i.e. an UpdateProject call) would fail under mox.
+    self.mox.ReplayAll()
+    self.project_service.UpdateRecentActivity(self.cnxn, 234, now=activity_time)
+    self.mox.VerifyAll()
+
+  def SetUpGetUserRolesInAllProjects(self):
+    """Expect a role lookup for users {111, 888} across all projects.
+
+    The users are committers in project 123 and owners in project 234.
+    """
+    rows = [
+        (123, 'committer'),
+        (234, 'owner'),
+    ]
+    self.project_service.user2project_tbl.Select(
+        self.cnxn, cols=['project_id', 'role_name'],
+        user_id={111, 888}).AndReturn(rows)
+
+  def testGetUserRolesInAllProjects(self):
+    """Role rows are split into owned, membered and contributor project IDs."""
+    self.SetUpGetUserRolesInAllProjects()
+    self.mox.ReplayAll()
+    actual = self.project_service.GetUserRolesInAllProjects(
+        self.cnxn, {111, 888})
+    owned_project_ids, membered_project_ids, contrib_project_ids = actual
+    self.mox.VerifyAll()
+    # 'owner' -> owned, 'committer' -> membered, no contributor rows.
+    self.assertItemsEqual([234], owned_project_ids)
+    self.assertItemsEqual([123], membered_project_ids)
+    self.assertItemsEqual([], contrib_project_ids)
+
+  def testGetUserRolesInAllProjectsWithoutEffectiveIds(self):
+    """With no effective IDs, no query is made and every role set is empty."""
+    # Nothing is recorded, so any table access would fail under mox.
+    self.mox.ReplayAll()
+    # Pass an empty *set* of user IDs, the same type the non-empty test uses;
+    # `{}` would be an empty dict.
+    actual = self.project_service.GetUserRolesInAllProjects(self.cnxn, set())
+    owned_project_ids, membered_project_ids, contrib_project_ids = actual
+    self.mox.VerifyAll()
+    self.assertItemsEqual([], owned_project_ids)
+    self.assertItemsEqual([], membered_project_ids)
+    self.assertItemsEqual([], contrib_project_ids)
+
+  def SetUpUpdateExtraPerms(self):
+    """Expect user 111's extra perms in project 234 to be replaced.
+
+    Old rows are deleted, the new 'SecurityTeam' row is inserted and the
+    project's cached_content_timestamp is bumped, all with commit=False,
+    followed by a single commit.
+    """
+    self.project_service.extraperm_tbl.Delete(
+        self.cnxn, project_id=234, user_id=111, commit=False)
+    self.project_service.extraperm_tbl.InsertRows(
+        self.cnxn, project_svc.EXTRAPERM_COLS,
+        [(234, 111, 'SecurityTeam')], commit=False)
+    self.project_service.project_tbl.Update(
+        self.cnxn, {'cached_content_timestamp': NOW},
+        project_id=234, commit=False)
+    self.cnxn.Commit()
+
+  def testUpdateExtraPerms(self):
+    """Owner 111 of project 234 is granted the 'SecurityTeam' extra perm."""
+    # Project 234 is loaded (with 111 as owner) before the perm rows change.
+    self.SetUpGetProjects(roles=[(234, 111, 'owner')])
+    self.SetUpUpdateExtraPerms()
+    self.mox.ReplayAll()
+    self.project_service.UpdateExtraPerms(
+        self.cnxn, 234, 111, ['SecurityTeam'], now=NOW)
+    self.mox.VerifyAll()
+
+  def testExpungeUsersInProjects(self):
+    """The users' rows are deleted from each per-user project table."""
+    svc = self.project_service
+    tables = [
+        svc.extraperm_tbl,
+        svc.acexclusion_tbl,
+        svc.membernotes_tbl,
+        svc.user2project_tbl,
+    ]
+    # Use mock.Mock (not mox) for Delete so calls can be asserted afterwards.
+    for table in tables:
+      table.Delete = mock.Mock()
+
+    user_ids = [111, 222]
+    limit = 16
+    svc.ExpungeUsersInProjects(self.cnxn, user_ids, limit=limit)
+
+    call = [mock.call(self.cnxn, user_id=user_ids, limit=limit, commit=False)]
+    for table in tables:
+      table.Delete.assert_has_calls(call)
diff --git a/services/test/service_manager_test.py b/services/test/service_manager_test.py
new file mode 100644
index 0000000..33c8706
--- /dev/null
+++ b/services/test/service_manager_test.py
@@ -0,0 +1,44 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for the service_manager module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+from features import autolink
+from services import cachemanager_svc
+from services import config_svc
+from services import features_svc
+from services import issue_svc
+from services import service_manager
+from services import project_svc
+from services import star_svc
+from services import user_svc
+from services import usergroup_svc
+
+
+class ServiceManagerTest(unittest.TestCase):
+  """Tests for service_manager.set_up_services()."""
+
+  def testSetUpServices(self):
+    """Every service attribute has the expected concrete type."""
+    svcs = service_manager.set_up_services()
+    self.assertIsInstance(svcs, service_manager.Services)
+    self.assertIsInstance(svcs.autolink, autolink.Autolink)
+    self.assertIsInstance(svcs.cache_manager, cachemanager_svc.CacheManager)
+    self.assertIsInstance(svcs.user, user_svc.UserService)
+    self.assertIsInstance(svcs.user_star, star_svc.UserStarService)
+    self.assertIsInstance(svcs.project_star, star_svc.ProjectStarService)
+    self.assertIsInstance(svcs.issue_star, star_svc.IssueStarService)
+    self.assertIsInstance(svcs.project, project_svc.ProjectService)
+    self.assertIsInstance(svcs.usergroup, usergroup_svc.UserGroupService)
+    self.assertIsInstance(svcs.config, config_svc.ConfigService)
+    self.assertIsInstance(svcs.issue, issue_svc.IssueService)
+    self.assertIsInstance(svcs.features, features_svc.FeaturesService)
+
+    # Calling it again should return the very same Services object.
+    # assertIs reports both objects on failure, unlike assertTrue(a is b).
+    svcs2 = service_manager.set_up_services()
+    self.assertIs(svcs, svcs2)
diff --git a/services/test/spam_svc_test.py b/services/test/spam_svc_test.py
new file mode 100644
index 0000000..3aeba13
--- /dev/null
+++ b/services/test/spam_svc_test.py
@@ -0,0 +1,433 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for the spam service."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import mock
+import unittest
+
+import mox
+
+from google.appengine.ext import testbed
+
+import settings
+from framework import sql
+from framework import framework_constants
+from proto import user_pb2
+from proto import tracker_pb2
+from services import spam_svc
+from testing import fake
+from mock import Mock
+
+
+def assert_unreached():  # pragma: no cover
+  """Always raise; installed in place of code a test expects not to run."""
+  message = 'This code should not have been called.'
+  raise Exception(message)
+
+
+class SpamServiceTest(unittest.TestCase):
+  """Tests for spam_svc.SpamService.
+
+  SQL tables are mox mocks, except that Delete on the report and verdict
+  tables is a plain mock.Mock (see setUp). Delete calls must therefore be
+  asserted on the Mock; mox's VerifyAll() does not see them.
+  """
+
+  def setUp(self):
+    self.testbed = testbed.Testbed()
+    self.testbed.activate()
+
+    self.mox = mox.Mox()
+    self.mock_report_tbl = self.mox.CreateMock(sql.SQLTableManager)
+    self.mock_verdict_tbl = self.mox.CreateMock(sql.SQLTableManager)
+    self.mock_issue_tbl = self.mox.CreateMock(sql.SQLTableManager)
+    self.cnxn = self.mox.CreateMock(sql.MonorailConnection)
+    self.issue_service = fake.IssueService()
+    self.spam_service = spam_svc.SpamService()
+    self.spam_service.report_tbl = self.mock_report_tbl
+    self.spam_service.verdict_tbl = self.mock_verdict_tbl
+    self.spam_service.issue_tbl = self.mock_issue_tbl
+
+    # Calling these in a test records nothing with mox; inspect the Mock.
+    self.spam_service.report_tbl.Delete = Mock()
+    self.spam_service.verdict_tbl.Delete = Mock()
+
+  def tearDown(self):
+    self.testbed.deactivate()
+    self.mox.UnsetStubs()
+    self.mox.ResetAll()
+
+  def _ExpectClassifierQueueQuery(self, rows, count):
+    """Record the two queries GetIssueClassifierQueue makes for project 789.
+
+    Args:
+      rows: rows returned by the verdict Select (offset 0, limit 10).
+      count: value returned by the COUNT(*) SelectValue.
+    """
+    # Both queries share one filter: non-overruled, low-confidence verdicts
+    # on issues in project 789.
+    where = [
+        ('project_id = %s', [789]),
+        ('classifier_confidence <= %s',
+         [settings.classifier_moderation_thresh]),
+        ('overruled = %s', [False]),
+        ('issue_id IS NOT NULL', []),
+    ]
+    self.mock_verdict_tbl.Select(self.cnxn,
+        cols=['issue_id', 'is_spam', 'reason', 'classifier_confidence',
+              'created'],
+        where=where,
+        order_by=[
+            ('classifier_confidence ASC', []),
+            ('created ASC', [])
+        ],
+        group_by=['issue_id'],
+        offset=0,
+        limit=10,
+    ).AndReturn(rows)
+    self.mock_verdict_tbl.SelectValue(
+        self.cnxn, col='COUNT(*)', where=where).AndReturn(count)
+
+  def testLookupIssuesFlaggers(self):
+    self.mock_report_tbl.Select(
+        self.cnxn, cols=['issue_id', 'user_id', 'comment_id'],
+        issue_id=[234, 567, 890]).AndReturn([
+            [234, 111, None],
+            [234, 222, 1],
+            [567, 333, None]])
+    self.mox.ReplayAll()
+
+    reporters = (
+        self.spam_service.LookupIssuesFlaggers(self.cnxn, [234, 567, 890]))
+    self.mox.VerifyAll()
+    # A None comment_id is an issue-level flag; otherwise it is per comment.
+    # Issue 890 has no flags and so is absent from the result.
+    self.assertEqual({
+        234: ([111], {1: [222]}),
+        567: ([333], {}),
+        }, reporters)
+
+  def testLookupIssueFlaggers(self):
+    self.mock_report_tbl.Select(
+        self.cnxn, cols=['issue_id', 'user_id', 'comment_id'],
+        issue_id=[234]).AndReturn(
+            [[234, 111, None], [234, 222, 1]])
+    self.mox.ReplayAll()
+
+    issue_reporters, comment_reporters = (
+        self.spam_service.LookupIssueFlaggers(self.cnxn, 234))
+    self.mox.VerifyAll()
+    self.assertItemsEqual([111], issue_reporters)
+    self.assertEqual({1: [222]}, comment_reporters)
+
+  def testFlagIssues_overThresh(self):
+    """Reaching the flag threshold records a 'threshold' spam verdict."""
+    issue = fake.MakeTestIssue(
+        project_id=789,
+        local_id=1,
+        reporter_id=111,
+        owner_id=456,
+        summary='sum',
+        status='Live',
+        issue_id=78901,
+        project_name='proj')
+    issue.assume_stale = False  # We will store this issue.
+
+    self.mock_report_tbl.InsertRows(self.cnxn,
+        ['issue_id', 'reported_user_id', 'user_id'],
+        [(78901, 111, 111)], ignore=True)
+
+    self.mock_report_tbl.Select(self.cnxn,
+        cols=['issue_id', 'COUNT(*)'], group_by=['issue_id'],
+        issue_id=[78901]).AndReturn([(78901, settings.spam_flag_thresh)])
+    self.mock_verdict_tbl.Select(
+        self.cnxn, cols=['issue_id', 'reason', 'MAX(created)'],
+        group_by=['issue_id'], issue_id=[78901], comment_id=None).AndReturn([])
+    self.mock_verdict_tbl.InsertRows(
+        self.cnxn, ['issue_id', 'is_spam', 'reason', 'project_id'],
+        [(78901, True, 'threshold', 789)], ignore=True)
+
+    self.mox.ReplayAll()
+    self.spam_service.FlagIssues(
+        self.cnxn, self.issue_service, [issue], 111, True)
+    self.mox.VerifyAll()
+    self.assertIn(issue, self.issue_service.updated_issues)
+
+    self.assertEqual(
+        1,
+        self.spam_service.issue_actions.get(
+            fields={
+                'type': 'flag',
+                'reporter_id': str(111),
+                'issue': 'proj:1'
+            }))
+
+  def testFlagIssues_underThresh(self):
+    """Below the flag threshold, no verdict is written and no metric logged."""
+    issue = fake.MakeTestIssue(
+        project_id=789,
+        local_id=1,
+        reporter_id=111,
+        owner_id=456,
+        summary='sum',
+        status='Live',
+        issue_id=78901,
+        project_name='proj')
+
+    self.mock_report_tbl.InsertRows(self.cnxn,
+        ['issue_id', 'reported_user_id', 'user_id'],
+        [(78901, 111, 111)], ignore=True)
+
+    self.mock_report_tbl.Select(self.cnxn,
+        cols=['issue_id', 'COUNT(*)'], group_by=['issue_id'],
+        issue_id=[78901]).AndReturn([(78901, settings.spam_flag_thresh - 1)])
+
+    self.mock_verdict_tbl.Select(
+        self.cnxn, cols=['issue_id', 'reason', 'MAX(created)'],
+        group_by=['issue_id'], issue_id=[78901], comment_id=None).AndReturn([])
+
+    self.mox.ReplayAll()
+    self.spam_service.FlagIssues(
+        self.cnxn, self.issue_service, [issue], 111, True)
+    self.mox.VerifyAll()
+
+    self.assertNotIn(issue, self.issue_service.updated_issues)
+    self.assertIsNone(
+        self.spam_service.issue_actions.get(
+            fields={
+                'type': 'flag',
+                'reporter_id': str(111),
+                'issue': 'proj:1'
+            }))
+
+  def testUnflagIssue_overThresh(self):
+    issue = fake.MakeTestIssue(
+        project_id=789, local_id=1, reporter_id=111, owner_id=456,
+        summary='sum', status='Live', issue_id=78901, is_spam=True)
+    self.mock_report_tbl.Select(self.cnxn,
+        cols=['issue_id', 'COUNT(*)'], group_by=['issue_id'],
+        issue_id=[78901]).AndReturn([(78901, settings.spam_flag_thresh)])
+
+    self.mock_verdict_tbl.Select(
+        self.cnxn, cols=['issue_id', 'reason', 'MAX(created)'],
+        group_by=['issue_id'], issue_id=[78901], comment_id=None).AndReturn([])
+
+    self.mox.ReplayAll()
+    self.spam_service.FlagIssues(
+        self.cnxn, self.issue_service, [issue], 111, False)
+    self.mox.VerifyAll()
+    # Delete is a mock.Mock (see setUp), so the removal of user 111's report
+    # must be checked here; calling it before FlagIssues verified nothing.
+    self.mock_report_tbl.Delete.assert_called_once_with(
+        self.cnxn, issue_id=[issue.issue_id], comment_id=None, user_id=111)
+
+    self.assertNotIn(issue, self.issue_service.updated_issues)
+    self.assertEqual(True, issue.is_spam)
+
+  def testUnflagIssue_underThresh(self):
+    """A non-member un-flagging an issue as spam should not be able
+    to overturn the verdict to ham. This is different from previous
+    behavior. See https://crbug.com/monorail/2232 for details."""
+    issue = fake.MakeTestIssue(
+        project_id=789, local_id=1, reporter_id=111, owner_id=456,
+        summary='sum', status='Live', issue_id=78901, is_spam=True)
+    issue.assume_stale = False  # We will store this issue.
+    self.mock_report_tbl.Select(self.cnxn,
+        cols=['issue_id', 'COUNT(*)'], group_by=['issue_id'],
+        issue_id=[78901]).AndReturn([(78901, settings.spam_flag_thresh - 1)])
+
+    self.mock_verdict_tbl.Select(
+        self.cnxn, cols=['issue_id', 'reason', 'MAX(created)'],
+        group_by=['issue_id'], issue_id=[78901], comment_id=None).AndReturn([])
+
+    self.mox.ReplayAll()
+    self.spam_service.FlagIssues(
+        self.cnxn, self.issue_service, [issue], 111, False)
+    self.mox.VerifyAll()
+    # Delete is a mock.Mock (see setUp); check it explicitly.
+    self.mock_report_tbl.Delete.assert_called_once_with(
+        self.cnxn, issue_id=[issue.issue_id], comment_id=None, user_id=111)
+
+    self.assertNotIn(issue, self.issue_service.updated_issues)
+    self.assertEqual(True, issue.is_spam)
+
+  def testUnflagIssue_underThreshNoManualOverride(self):
+    """A prior 'manual' verdict is not overturned by dropping below threshold."""
+    issue = fake.MakeTestIssue(
+        project_id=789, local_id=1, reporter_id=111, owner_id=456,
+        summary='sum', status='Live', issue_id=78901, is_spam=True)
+    self.mock_report_tbl.Select(self.cnxn,
+        cols=['issue_id', 'COUNT(*)'], group_by=['issue_id'],
+        issue_id=[78901]).AndReturn([(78901, settings.spam_flag_thresh - 1)])
+
+    self.mock_verdict_tbl.Select(
+        self.cnxn, cols=['issue_id', 'reason', 'MAX(created)'],
+        group_by=['issue_id'], comment_id=None,
+        issue_id=[78901]).AndReturn([(78901, 'manual', '')])
+
+    self.mox.ReplayAll()
+    self.spam_service.FlagIssues(
+        self.cnxn, self.issue_service, [issue], 111, False)
+    self.mox.VerifyAll()
+    # Delete is a mock.Mock (see setUp); check it explicitly.
+    self.mock_report_tbl.Delete.assert_called_once_with(
+        self.cnxn, issue_id=[issue.issue_id], comment_id=None, user_id=111)
+
+    self.assertNotIn(issue, self.issue_service.updated_issues)
+    self.assertEqual(True, issue.is_spam)
+
+  def testGetIssueClassifierQueue_noVerdicts(self):
+    self._ExpectClassifierQueueQuery(rows=[], count=0)
+
+    self.mox.ReplayAll()
+    res, count = self.spam_service.GetIssueClassifierQueue(
+        self.cnxn, self.issue_service, 789)
+    self.mox.VerifyAll()
+
+    self.assertEqual([], res)
+    self.assertEqual(0, count)
+
+  def testGetIssueClassifierQueue_someVerdicts(self):
+    # One row in (issue_id, is_spam, reason, confidence, created) order; the
+    # total count (10) is independent of the page size returned.
+    self._ExpectClassifierQueueQuery(
+        rows=[[78901, 0, "classifier", 0.9, "2015-12-10 11:06:24"]],
+        count=10)
+
+    self.mox.ReplayAll()
+    res, count = self.spam_service.GetIssueClassifierQueue(
+        self.cnxn, self.issue_service, 789)
+    self.mox.VerifyAll()
+    self.assertEqual(1, len(res))
+    self.assertEqual(10, count)
+    self.assertEqual(78901, res[0].issue_id)
+    self.assertEqual(False, res[0].is_spam)
+    self.assertEqual("classifier", res[0].reason)
+    self.assertEqual(0.9, res[0].classifier_confidence)
+    self.assertEqual("2015-12-10 11:06:24", res[0].verdict_time)
+
+  def testIsExempt_RegularUser(self):
+    author = user_pb2.MakeUser(111, email='test@example.com')
+    self.assertFalse(self.spam_service._IsExempt(author, False))
+    # A domain that merely contains an allowlisted domain is not exempt.
+    author = user_pb2.MakeUser(111, email='test@chromium.org.example.com')
+    self.assertFalse(self.spam_service._IsExempt(author, False))
+
+  def testIsExempt_ProjectMember(self):
+    author = user_pb2.MakeUser(111, email='test@example.com')
+    self.assertTrue(self.spam_service._IsExempt(author, True))
+
+  def testIsExempt_AllowlistedDomain(self):
+    author = user_pb2.MakeUser(111, email='test@google.com')
+    self.assertTrue(self.spam_service._IsExempt(author, False))
+
+  def testClassifyIssue_spam(self):
+    issue = fake.MakeTestIssue(
+        project_id=789, local_id=1, reporter_id=111, owner_id=456,
+        summary='sum', status='Live', issue_id=78901, is_spam=True)
+    self.spam_service._predict = lambda body: 1.0
+
+    # Prevent missing service inits to fail the test.
+    self.spam_service.ml_engine = True
+
+    comment_pb = tracker_pb2.IssueComment()
+    comment_pb.content = "this is spam"
+    reporter = user_pb2.MakeUser(111, email='test@test.com')
+    res = self.spam_service.ClassifyIssue(issue, comment_pb, reporter, False)
+    self.assertEqual(1.0, res['confidence_is_spam'])
+
+    # Look-alike addresses must not be treated as allowlisted.
+    reporter.email = 'test@chromium.org.spam.com'
+    res = self.spam_service.ClassifyIssue(issue, comment_pb, reporter, False)
+    self.assertEqual(1.0, res['confidence_is_spam'])
+
+    reporter.email = 'test.google.com@test.com'
+    res = self.spam_service.ClassifyIssue(issue, comment_pb, reporter, False)
+    self.assertEqual(1.0, res['confidence_is_spam'])
+
+  def testClassifyIssue_Allowlisted(self):
+    issue = fake.MakeTestIssue(
+        project_id=789, local_id=1, reporter_id=111, owner_id=456,
+        summary='sum', status='Live', issue_id=78901, is_spam=True)
+    # Allowlisted reporters must short-circuit before prediction.
+    self.spam_service._predict = assert_unreached
+
+    # Prevent missing service inits to fail the test.
+    self.spam_service.ml_engine = True
+
+    comment_pb = tracker_pb2.IssueComment()
+    comment_pb.content = "this is spam"
+    reporter = user_pb2.MakeUser(111, email='test@google.com')
+    res = self.spam_service.ClassifyIssue(issue, comment_pb, reporter, False)
+    self.assertEqual(0.0, res['confidence_is_spam'])
+    reporter.email = 'test@chromium.org'
+    res = self.spam_service.ClassifyIssue(issue, comment_pb, reporter, False)
+    self.assertEqual(0.0, res['confidence_is_spam'])
+
+  def testClassifyComment_spam(self):
+    self.spam_service._predict = lambda body: 1.0
+
+    # Prevent missing service inits to fail the test.
+    self.spam_service.ml_engine = True
+
+    commenter = user_pb2.MakeUser(111, email='test@test.com')
+    res = self.spam_service.ClassifyComment('this is spam', commenter, False)
+    self.assertEqual(1.0, res['confidence_is_spam'])
+
+    # Look-alike addresses must not be treated as allowlisted.
+    commenter.email = 'test@chromium.org.spam.com'
+    res = self.spam_service.ClassifyComment('this is spam', commenter, False)
+    self.assertEqual(1.0, res['confidence_is_spam'])
+
+    commenter.email = 'test.google.com@test.com'
+    res = self.spam_service.ClassifyComment('this is spam', commenter, False)
+    self.assertEqual(1.0, res['confidence_is_spam'])
+
+  def testClassifyComment_Allowlisted(self):
+    # Allowlisted commenters must short-circuit before prediction.
+    self.spam_service._predict = assert_unreached
+
+    # Prevent missing service inits to fail the test.
+    self.spam_service.ml_engine = True
+
+    commenter = user_pb2.MakeUser(111, email='test@google.com')
+    res = self.spam_service.ClassifyComment('this is spam', commenter, False)
+    self.assertEqual(0.0, res['confidence_is_spam'])
+
+    commenter.email = 'test@chromium.org'
+    res = self.spam_service.ClassifyComment('this is spam', commenter, False)
+    self.assertEqual(0.0, res['confidence_is_spam'])
+
+  def test_ham_classification(self):
+    actual = self.spam_service.ham_classification()
+    self.assertEqual(actual['confidence_is_spam'], 0.0)
+    self.assertEqual(actual['failed_open'], False)
+
+  def testExpungeUsersInSpam(self):
+    user_ids = [3, 4, 5]
+    self.spam_service.ExpungeUsersInSpam(self.cnxn, user_ids=user_ids)
+
+    # Reports both *about* and *by* the users are removed.
+    self.spam_service.report_tbl.Delete.assert_has_calls(
+        [
+            mock.call(self.cnxn, reported_user_id=user_ids, commit=False),
+            mock.call(self.cnxn, user_id=user_ids, commit=False)
+        ])
+    self.spam_service.verdict_tbl.Delete.assert_called_once_with(
+        self.cnxn, user_id=user_ids, commit=False)
+
+  def testLookupIssueVerdicts(self):
+    self.spam_service.verdict_tbl.Select = Mock(return_value=[
+        [5, 10], [4, 11], [6, 12],
+    ])
+    actual = self.spam_service.LookupIssueVerdicts(self.cnxn, [4, 5, 6])
+
+    self.spam_service.verdict_tbl.Select.assert_called_once_with(
+        self.cnxn, cols=['issue_id', 'reason', 'MAX(created)'],
+        issue_id=[4, 5, 6], comment_id=None, group_by=['issue_id'])
+    self.assertEqual(actual, {
+        5: 10,
+        4: 11,
+        6: 12,
+    })
diff --git a/services/test/star_svc_test.py b/services/test/star_svc_test.py
new file mode 100644
index 0000000..03a0d23
--- /dev/null
+++ b/services/test/star_svc_test.py
@@ -0,0 +1,225 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for the star service."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+import mox
+import mock
+
+from google.appengine.ext import testbed
+
+import settings
+from mock import Mock
+from framework import sql
+from proto import user_pb2
+from services import star_svc
+from testing import fake
+
+
class AbstractStarServiceTest(unittest.TestCase):
  """Tests for star_svc.AbstractStarService.

  Select and InsertRows calls are verified through mox record/replay
  expectations on self.mock_tbl. Delete is replaced with a mock.Mock in
  setUp, so mox never records calls to it. Tests that exercise Delete
  therefore inspect that Mock after calling the service, instead of trying
  to record an expectation up front.
  """

  def setUp(self):
    self.testbed = testbed.Testbed()
    self.testbed.activate()
    self.testbed.init_memcache_stub()

    self.mox = mox.Mox()
    self.mock_tbl = self.mox.CreateMock(sql.SQLTableManager)
    self.cnxn = 'fake connection'
    self.cache_manager = fake.CacheManager()
    self.star_service = star_svc.AbstractStarService(
        self.cache_manager, self.mock_tbl, 'item_id', 'user_id', 'project')
    # Delete is a plain Mock rather than a mox method. Calling it before
    # ReplayAll() records no expectation, so Delete must be checked after
    # the service has been exercised.
    self.mock_tbl.Delete = Mock()

  def tearDown(self):
    self.testbed.deactivate()
    self.mox.UnsetStubs()
    self.mox.ResetAll()

  def assertDeletedOnce(self, **expected_kwargs):
    """Assert that mock_tbl.Delete was called exactly once on self.cnxn.

    Args:
      **expected_kwargs: keyword arguments the single Delete call must have
        had. An expected value of None also matches an omitted keyword.
        Keywords not listed here are not checked.
    """
    self.assertEqual(1, self.mock_tbl.Delete.call_count)
    args, kwargs = self.mock_tbl.Delete.call_args
    self.assertEqual((self.cnxn,), args)
    for name, expected in expected_kwargs.items():
      self.assertEqual(expected, kwargs.get(name), name)

  def testExpungeStars(self):
    # Replay with no recorded expectations: any mox table call fails the test.
    self.mox.ReplayAll()
    self.star_service.ExpungeStars(self.cnxn, 123)
    self.mox.VerifyAll()
    self.assertDeletedOnce(item_id=123, commit=True, limit=None)

  def testExpungeStars_Limit(self):
    self.star_service.ExpungeStars(self.cnxn, 123, limit=50)
    self.mock_tbl.Delete.assert_called_once_with(
        self.cnxn, commit=True, limit=50, item_id=123)

  def testExpungeStarsByUsers(self):
    user_ids = [2, 3, 4]
    self.star_service.ExpungeStarsByUsers(self.cnxn, user_ids, limit=40)
    self.mock_tbl.Delete.assert_called_once_with(
        self.cnxn, user_id=user_ids, commit=False, limit=40)

  def SetUpLookupItemsStarrers(self):
    self.mock_tbl.Select(
        self.cnxn, cols=['item_id', 'user_id'],
        item_id=[234]).AndReturn([(234, 111), (234, 222)])

  def testLookupItemsStarrers(self):
    self.star_service.starrer_cache.CacheItem(123, [111, 333])
    self.SetUpLookupItemsStarrers()
    self.mox.ReplayAll()
    starrer_list_dict = self.star_service.LookupItemsStarrers(
        self.cnxn, [123, 234])
    self.mox.VerifyAll()
    self.assertItemsEqual([123, 234], list(starrer_list_dict.keys()))
    self.assertItemsEqual([111, 333], starrer_list_dict[123])
    self.assertItemsEqual([111, 222], starrer_list_dict[234])
    self.assertItemsEqual([111, 333],
                          self.star_service.starrer_cache.GetItem(123))
    self.assertItemsEqual([111, 222],
                          self.star_service.starrer_cache.GetItem(234))

  def SetUpLookupStarredItemIDs(self):
    self.mock_tbl.Select(
        self.cnxn, cols=['item_id'], user_id=111).AndReturn(
            [(123,), (234,)])

  def testLookupStarredItemIDs(self):
    self.SetUpLookupStarredItemIDs()
    self.mox.ReplayAll()
    item_ids = self.star_service.LookupStarredItemIDs(self.cnxn, 111)
    self.mox.VerifyAll()
    self.assertItemsEqual([123, 234], item_ids)
    self.assertItemsEqual([123, 234],
                          self.star_service.star_cache.GetItem(111))

  def testIsItemStarredBy(self):
    self.SetUpLookupStarredItemIDs()
    self.mox.ReplayAll()
    self.assertTrue(self.star_service.IsItemStarredBy(self.cnxn, 123, 111))
    self.assertTrue(self.star_service.IsItemStarredBy(self.cnxn, 234, 111))
    self.assertFalse(
        self.star_service.IsItemStarredBy(self.cnxn, 435, 111))
    self.mox.VerifyAll()

  def SetUpCountItemStars(self):
    self.mock_tbl.Select(
        self.cnxn, cols=['item_id', 'COUNT(user_id)'], item_id=[234],
        group_by=['item_id']).AndReturn([(234, 2)])

  def testCountItemStars(self):
    self.star_service.star_count_cache.CacheItem(123, 3)
    self.SetUpCountItemStars()
    self.mox.ReplayAll()
    self.assertEqual(3, self.star_service.CountItemStars(self.cnxn, 123))
    self.assertEqual(2, self.star_service.CountItemStars(self.cnxn, 234))
    self.mox.VerifyAll()

  def testCountItemsStars(self):
    self.star_service.star_count_cache.CacheItem(123, 3)
    self.SetUpCountItemStars()
    self.mox.ReplayAll()
    count_dict = self.star_service.CountItemsStars(
        self.cnxn, [123, 234])
    self.mox.VerifyAll()
    self.assertItemsEqual([123, 234], list(count_dict.keys()))
    self.assertEqual(3, count_dict[123])
    self.assertEqual(2, count_dict[234])

  def SetUpSetStar_Add(self):
    self.mock_tbl.InsertRows(
        self.cnxn, ['item_id', 'user_id'], [(123, 111)], ignore=True,
        commit=True)

  def testSetStar_Add(self):
    self.SetUpSetStar_Add()
    self.mox.ReplayAll()
    self.star_service.SetStar(self.cnxn, 123, 111, True)
    self.mox.VerifyAll()
    self.assertFalse(self.star_service.star_cache.HasItem(123))
    self.assertFalse(self.star_service.starrer_cache.HasItem(123))
    self.assertFalse(self.star_service.star_count_cache.HasItem(123))

  def testSetStar_Remove(self):
    # Replay with no recorded expectations: any mox table call fails the test.
    self.mox.ReplayAll()
    self.star_service.SetStar(self.cnxn, 123, 111, False)
    self.mox.VerifyAll()
    self.assertDeletedOnce(item_id=123, user_id=[111])
    self.assertFalse(self.star_service.star_cache.HasItem(123))
    self.assertFalse(self.star_service.starrer_cache.HasItem(123))
    self.assertFalse(self.star_service.star_count_cache.HasItem(123))

  def SetUpSetStarsBatch_Add(self):
    self.mock_tbl.InsertRows(
        self.cnxn, ['item_id', 'user_id'], [(123, 111), (123, 222)],
        ignore=True, commit=True)

  def testSetStarsBatch_Add(self):
    self.SetUpSetStarsBatch_Add()
    self.mox.ReplayAll()
    self.star_service.SetStarsBatch(self.cnxn, 123, [111, 222], True)
    self.mox.VerifyAll()
    self.assertFalse(self.star_service.star_cache.HasItem(123))
    self.assertFalse(self.star_service.starrer_cache.HasItem(123))
    self.assertFalse(self.star_service.star_count_cache.HasItem(123))

  def testSetStarsBatch_Remove(self):
    # Replay with no recorded expectations: any mox table call fails the test.
    self.mox.ReplayAll()
    self.star_service.SetStarsBatch(self.cnxn, 123, [111, 222], False)
    self.mox.VerifyAll()
    self.assertDeletedOnce(item_id=123, user_id=[111, 222])
    self.assertFalse(self.star_service.star_cache.HasItem(123))
    self.assertFalse(self.star_service.starrer_cache.HasItem(123))
    self.assertFalse(self.star_service.star_count_cache.HasItem(123))
+
+
class IssueStarServiceTest(unittest.TestCase):
  """Tests for star_svc.IssueStarService backed by a mocked star table."""

  def setUp(self):
    table = mock.Mock()
    table.Delete = mock.Mock()
    table.InsertRows = mock.Mock()
    self.mock_tbl = table

    self.cache_manager = fake.CacheManager()
    # The service builds its table through SQLTableManager, so patching the
    # constructor hands it the mocked table.
    with mock.patch('framework.sql.SQLTableManager', return_value=table):
      self.issue_star = star_svc.IssueStarService(self.cache_manager)

    self.cnxn = 'fake connection'

  def _StarWithoutIssueUpdate(self, starred, **kwargs):
    """Star or unstar issue 78901 for users 111 and 222."""
    self.issue_star.SetStarsBatch_SkipIssueUpdate(
        self.cnxn, 78901, [111, 222], starred, **kwargs)

  def testSetStarsBatch_SkipIssueUpdate_Remove(self):
    self._StarWithoutIssueUpdate(False)
    self.mock_tbl.Delete.assert_called_once_with(
        self.cnxn, issue_id=78901, user_id=[111, 222], commit=True)

  def testSetStarsBatch_SkipIssueUpdate_Remove_NoCommit(self):
    self._StarWithoutIssueUpdate(False, commit=False)
    self.mock_tbl.Delete.assert_called_once_with(
        self.cnxn, issue_id=78901, user_id=[111, 222], commit=False)

  def testSetStarsBatch_SkipIssueUpdate_Add(self):
    self._StarWithoutIssueUpdate(True)
    self.mock_tbl.InsertRows.assert_called_once_with(
        self.cnxn, ['issue_id', 'user_id'], [(78901, 111), (78901, 222)],
        ignore=True, commit=True)

  def testSetStarsBatch_SkipIssueUpdate_Add_NoCommit(self):
    self._StarWithoutIssueUpdate(True, commit=False)
    self.mock_tbl.InsertRows.assert_called_once_with(
        self.cnxn, ['issue_id', 'user_id'], [(78901, 111), (78901, 222)],
        ignore=True, commit=False)
diff --git a/services/test/template_svc_test.py b/services/test/template_svc_test.py
new file mode 100644
index 0000000..964722d
--- /dev/null
+++ b/services/test/template_svc_test.py
@@ -0,0 +1,471 @@
+# Copyright 2018 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for services.template_svc module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import mock
+import unittest
+
+from mock import Mock, patch
+
+from proto import tracker_pb2
+from services import template_svc
+from testing import fake
+from testing import testing_helpers
+from tracker import tracker_bizobj
+from tracker import tracker_constants
+
+
class TemplateSetTwoLevelCacheTest(unittest.TestCase):
  """Tests for template_svc.TemplateSetTwoLevelCache."""

  def setUp(self):
    self.ts2lc = template_svc.TemplateSetTwoLevelCache(
        cache_manager=fake.CacheManager(),
        template_service=Mock(spec=template_svc.TemplateService))
    self.ts2lc.template_service.template_tbl = Mock()

  def testFetchItems_Empty(self):
    select = self.ts2lc.template_service.template_tbl.Select
    select.return_value = []
    self.assertEqual(
        {1: [], 2: []}, self.ts2lc.FetchItems(cnxn=None, keys=[1, 2]))

  def testFetchItems_Normal(self):
    rows_by_project = {
        1: [
            (8, 1, 'template-8', 'content', 'summary', False, 111, 'status',
             False, False, False),
            (9, 1, 'template-9', 'content', 'summary', False, 111, 'status',
             True, False, False),
        ],
        2: [
            (7, 2, 'template-7', 'content', 'summary', False, 111, 'status',
             False, False, False),
        ],
    }

    # pylint: disable=unused-argument
    def fakeSelect(cnxn, cols, project_id, order_by):
      assert project_id in rows_by_project
      # Hand out a fresh list on every call.
      return list(rows_by_project[project_id])

    self.ts2lc.template_service.template_tbl.Select.side_effect = fakeSelect

    # Each project maps to (template_id, name, members_only) tuples.
    self.assertEqual(
        {
            1: [(8, 'template-8', False), (9, 'template-9', True)],
            2: [(7, 'template-7', False)],
        },
        self.ts2lc.FetchItems(cnxn=None, keys=[1, 2]))
+
+
class TemplateDefTwoLevelCacheTest(unittest.TestCase):
  """Tests for template_svc.TemplateDefTwoLevelCache."""

  # Template-service tables that setUp replaces with Mocks.
  _MOCKED_TABLES = (
      'template_tbl', 'template2label_tbl', 'template2component_tbl',
      'template2admin_tbl', 'template2fieldvalue_tbl', 'issuephasedef_tbl',
      'template2approvalvalue_tbl')

  def setUp(self):
    self.template_def_2lc = template_svc.TemplateDefTwoLevelCache(
        cache_manager=fake.CacheManager(),
        template_service=Mock(spec=template_svc.TemplateService))
    service = self.template_def_2lc.template_service
    for table_name in self._MOCKED_TABLES:
      setattr(service, table_name, Mock())

  def testFetchItems_Empty(self):
    service = self.template_def_2lc.template_service
    # issuephasedef_tbl.Select keeps its default Mock return value.
    for table_name in ('template_tbl', 'template2label_tbl',
                       'template2component_tbl', 'template2admin_tbl',
                       'template2fieldvalue_tbl',
                       'template2approvalvalue_tbl'):
      getattr(service, table_name).Select.return_value = []

    self.assertEqual(
        {}, self.template_def_2lc.FetchItems(cnxn=None, keys=[1, 2]))

  def testFetchItems_Normal(self):
    service = self.template_def_2lc.template_service

    template_7_row, template_8_row, template_9_row = [
        (template_id, project_id, 'template-%d' % template_id, 'content',
         'summary', False, 111, 'status', False, False, False)
        for template_id, project_id in ((7, 2), (8, 1), (9, 1))]
    service.template_tbl.Select.return_value = [
        template_7_row, template_8_row, template_9_row]
    service.template2label_tbl.Select.return_value = [
        (9, 'label-1'), (7, 'label-2')]
    service.template2component_tbl.Select.return_value = [(9, 13), (7, 14)]
    service.template2admin_tbl.Select.return_value = [(9, 111), (7, 222)]

    fv1_row = (15, None, 'fv-1', None, None, None, False)
    fv2_row = (16, None, 'fv-2', None, None, None, False)
    fv1 = tracker_bizobj.MakeFieldValue(*fv1_row)
    fv2 = tracker_bizobj.MakeFieldValue(*fv2_row)
    # Stored rows are the template_id followed by every MakeFieldValue
    # argument except the last one.
    service.template2fieldvalue_tbl.Select.return_value = [
        (9,) + fv1_row[:-1], (7,) + fv2_row[:-1]]

    av1 = tracker_pb2.ApprovalValue(
        approval_id=17, phase_id=19, status=tracker_pb2.ApprovalStatus('NA'))
    av2 = tracker_pb2.ApprovalValue(
        approval_id=18, phase_id=20,
        status=tracker_pb2.ApprovalStatus('NOT_SET'))
    phase1 = tracker_pb2.Phase(phase_id=19, name='phase-1', rank=1)
    phase2 = tracker_pb2.Phase(phase_id=20, name='phase-2', rank=2)
    service.template2approvalvalue_tbl.Select.return_value = [
        (17, 9, 19, 'na'), (18, 7, 20, 'not_set')]
    service.issuephasedef_tbl.Select.return_value = [
        (19, 'phase-1', 1), (20, 'phase-2', 2)]

    actual = self.template_def_2lc.FetchItems(cnxn=None, keys=[7, 8, 9])

    self.assertEqual(3, len(actual))
    # template_id -> (labels, component_ids, admin_ids, field_values, phases,
    #                 approval_values)
    expected_by_id = {
        7: (['label-2'], [14], [222], [fv2], [phase2], [av2]),
        8: ([], [], [], [], [], []),
        9: (['label-1'], [13], [111], [fv1], [phase1], [av1]),
    }
    for template_id in sorted(expected_by_id):
      (labels, component_ids, admin_ids, field_values, phases,
       approval_values) = expected_by_id[template_id]
      template = actual[template_id]
      self.assertTrue(isinstance(template, tracker_pb2.TemplateDef))
      self.assertEqual(template_id, template.template_id)
      self.assertEqual(labels, template.labels)
      self.assertEqual(component_ids, template.component_ids)
      self.assertEqual(admin_ids, template.admin_ids)
      self.assertEqual(field_values, template.field_values)
      self.assertEqual(phases, template.phases)
      self.assertEqual(approval_values, template.approval_values)
+
+
class TemplateServiceTest(unittest.TestCase):
  """Tests for template_svc.TemplateService.

  Both two-level caches are replaced with Mocks, so each test controls what
  their GetAll methods return.
  """

  def setUp(self):
    self.cnxn = Mock()
    self.template_service = template_svc.TemplateService(fake.CacheManager())
    self.template_service.template_set_2lc = Mock()
    self.template_service.template_def_2lc = Mock()

  def testCreateDefaultProjectTemplates_Normal(self):
    """Every default template is created through CreateIssueTemplateDef."""
    self.template_service.CreateIssueTemplateDef = Mock()
    self.template_service.CreateDefaultProjectTemplates(self.cnxn, 789)

    expected_calls = [
        mock.call(self.cnxn, 789, tpl['name'], tpl['content'], tpl['summary'],
                  tpl['summary_must_be_edited'], tpl['status'],
                  tpl.get('members_only', False), True, False, None,
                  tpl['labels'], [], [], [], [])
        for tpl in tracker_constants.DEFAULT_TEMPLATES]
    self.template_service.CreateIssueTemplateDef.assert_has_calls(
        expected_calls, any_order=True)

  def testGetTemplateByName_Normal(self):
    """GetTemplateByName returns the template with the requested name."""
    result_dict = {789: [(1, 'one', 0)]}
    template = tracker_pb2.TemplateDef(name='one')
    self.template_service.template_set_2lc.GetAll.return_value = (
        result_dict, None)
    self.template_service.template_def_2lc.GetAll.return_value = (
        {1: template}, None)
    actual = self.template_service.GetTemplateByName(self.cnxn, 'one', 789)
    # template_id is unset on this template, so comparing only template_id
    # would pass for any TemplateDef returned; compare the whole message.
    self.assertEqual(template, actual)

  def testGetTemplateByName_NotFound(self):
    """GetTemplateByName returns None when no template has the given name."""
    result_dict = {789: [(1, 'one', 0)]}
    template = tracker_pb2.TemplateDef(name='one')
    self.template_service.template_set_2lc.GetAll.return_value = (
        result_dict, None)
    self.template_service.template_def_2lc.GetAll.return_value = (
        {1: template}, None)
    actual = self.template_service.GetTemplateByName(self.cnxn, 'two', 789)
    self.assertIsNone(actual)

  def testGetTemplateById_Normal(self):
    """GetTemplateById returns a template that exists."""
    template = tracker_pb2.TemplateDef(template_id=1, name='one')
    self.template_service.template_def_2lc.GetAll.return_value = (
        {1: template}, None)
    actual = self.template_service.GetTemplateById(self.cnxn, 1)
    self.assertEqual(actual.template_id, template.template_id)

  def testGetTemplateById_NotFound(self):
    """GetTemplateById returns None for an ID that does not exist."""
    self.template_service.template_def_2lc.GetAll.return_value = (
        {}, None)
    actual = self.template_service.GetTemplateById(self.cnxn, 1)
    self.assertIsNone(actual)

  def testGetTemplatesById_Normal(self):
    """GetTemplatesById returns the templates that exist."""
    template = tracker_pb2.TemplateDef(template_id=1, name='one')
    self.template_service.template_def_2lc.GetAll.return_value = (
        {1: template}, None)
    # GetTemplatesById is plural and returns a list, so it presumably takes a
    # collection of IDs. Pass a list rather than a bare int, which only
    # worked because GetAll is mocked.
    actual = self.template_service.GetTemplatesById(self.cnxn, [1])
    self.assertEqual(actual[0].template_id, template.template_id)

  def testGetTemplatesById_NotFound(self):
    """GetTemplatesById returns an empty list when no IDs exist."""
    self.template_service.template_def_2lc.GetAll.return_value = (
        {}, None)
    actual = self.template_service.GetTemplatesById(self.cnxn, [1])
    self.assertEqual([], actual)

  def testGetProjectTemplates_Normal(self):
    template_set = [(1, 'one', 0), (2, 'two', 1)]
    result_dict = {789: template_set}
    self.template_service.template_set_2lc.GetAll.return_value = (
        result_dict, None)
    self.template_service.template_def_2lc.GetAll.return_value = (
        {1: tracker_pb2.TemplateDef()}, None)

    self.assertEqual([tracker_pb2.TemplateDef()],
                     self.template_service.GetProjectTemplates(self.cnxn, 789))
    self.template_service.template_set_2lc.GetAll.assert_called_once_with(
        self.cnxn, [789])

  def testExpungeProjectTemplates(self):
    template_id_rows = [(1,), (2,)]
    self.template_service.template_tbl.Select = Mock(
        return_value=template_id_rows)
    self.template_service.template2label_tbl.Delete = Mock()
    self.template_service.template2component_tbl.Delete = Mock()
    self.template_service.template_tbl.Delete = Mock()

    self.template_service.ExpungeProjectTemplates(self.cnxn, 789)

    self.template_service.template_tbl.Select.assert_called_once_with(
        self.cnxn, project_id=789, cols=['id'])
    self.template_service.template2label_tbl.Delete.assert_called_once_with(
        self.cnxn, template_id=[1, 2])
    self.template_service.template2component_tbl.Delete.assert_called_once_with(
        self.cnxn, template_id=[1, 2])
    self.template_service.template_tbl.Delete.assert_called_once_with(
        self.cnxn, project_id=789)
+
+
class CreateIssueTemplateDefTest(TemplateServiceTest):
  """Tests for TemplateService.CreateIssueTemplateDef."""

  def setUp(self):
    super(CreateIssueTemplateDefTest, self).setUp()

    svc = self.template_service
    # The new template row gets ID 1.
    svc.template_tbl.InsertRow = Mock(return_value=1)
    for child_tbl in (svc.template2label_tbl, svc.template2component_tbl,
                      svc.template2admin_tbl, svc.template2fieldvalue_tbl):
      child_tbl.InsertRows = Mock()
    # The new phase row gets ID 81, which the approval rows must refer to.
    svc.issuephasedef_tbl.InsertRow = Mock(return_value=81)
    svc.template2approvalvalue_tbl.InsertRows = Mock()
    svc.template_set_2lc._StrToKey = Mock(return_value=789)

  def testCreateIssueTemplateDef(self):
    svc = self.template_service
    field_value = tracker_bizobj.MakeFieldValue(
        1, None, 'somestring', None, None, None, False)
    approval_values = [
        tracker_pb2.ApprovalValue(
            approval_id=23, phase_id=11,
            status=tracker_pb2.ApprovalStatus.NEEDS_REVIEW),
        tracker_pb2.ApprovalValue(approval_id=24, phase_id=11),
    ]
    phases = [tracker_pb2.Phase(name='Canary', rank=11, phase_id=11)]

    template_id = svc.CreateIssueTemplateDef(
        self.cnxn, 789, 'template', 'content', 'summary', True, 'Available',
        True, True, True, owner_id=111, labels=['label'], component_ids=[3],
        admin_ids=[222], field_values=[field_value], phases=phases,
        approval_values=approval_values)

    self.assertEqual(1, template_id)

    svc.template_tbl.InsertRow.assert_called_once_with(
        self.cnxn, project_id=789, name='template', content='content',
        summary='summary', summary_must_be_edited=True, owner_id=111,
        status='Available', members_only=True, owner_defaults_to_member=True,
        component_required=True, commit=False)
    svc.template2label_tbl.InsertRows.assert_called_once_with(
        self.cnxn, template_svc.TEMPLATE2LABEL_COLS, [(1, 'label')],
        commit=False)
    svc.template2component_tbl.InsertRows.assert_called_once_with(
        self.cnxn, template_svc.TEMPLATE2COMPONENT_COLS, [(1, 3)],
        commit=False)
    svc.template2admin_tbl.InsertRows.assert_called_once_with(
        self.cnxn, template_svc.TEMPLATE2ADMIN_COLS, [(1, 222)],
        commit=False)
    svc.template2fieldvalue_tbl.InsertRows.assert_called_once_with(
        self.cnxn, template_svc.TEMPLATE2FIELDVALUE_COLS,
        [(1, 1, None, 'somestring', None, None, None)], commit=False)
    svc.issuephasedef_tbl.InsertRow.assert_called_once_with(
        self.cnxn, name='Canary', rank=11, commit=False)
    svc.template2approvalvalue_tbl.InsertRows.assert_called_once_with(
        self.cnxn, template_svc.TEMPLATE2APPROVALVALUE_COLS,
        [(23, 1, 81, 'needs_review'), (24, 1, 81, 'not_set')], commit=False)

    # Everything is committed once, and the project's template set is
    # invalidated.
    self.cnxn.Commit.assert_called_once_with()
    svc.template_set_2lc.InvalidateKeys.assert_called_once_with(
        self.cnxn, [789])
+
+
class UpdateIssueTemplateDefTest(TemplateServiceTest):
  """Tests for TemplateService.UpdateIssueTemplateDef."""

  def setUp(self):
    super(UpdateIssueTemplateDefTest, self).setUp()

    svc = self.template_service
    svc.template_tbl.Update = Mock()
    for child_tbl in (svc.template2label_tbl, svc.template2admin_tbl):
      child_tbl.Delete = Mock()
      child_tbl.InsertRows = Mock()
    svc.template2approvalvalue_tbl.Delete = Mock()
    # The re-created phase row gets ID 1.
    svc.issuephasedef_tbl.InsertRow = Mock(return_value=1)
    svc.template2approvalvalue_tbl.InsertRows = Mock()
    svc.template_set_2lc._StrToKey = Mock(return_value=789)

  def testUpdateIssueTemplateDef(self):
    svc = self.template_service
    approval_values = [
        tracker_pb2.ApprovalValue(approval_id=approval_id, phase_id=11)
        for approval_id in (20, 21)]
    phases = [tracker_pb2.Phase(name='Canary', phase_id=11, rank=11)]

    svc.UpdateIssueTemplateDef(
        self.cnxn, 789, 1, content='content', summary='summary',
        component_required=True, labels=[], admin_ids=[111],
        phases=phases, approval_values=approval_values)

    svc.template_tbl.Update.assert_called_once_with(
        self.cnxn,
        {'content': 'content', 'summary': 'summary',
         'component_required': True},
        id=1, commit=False)

    # Child rows are replaced: old rows deleted, new rows inserted.
    for child_tbl in (svc.template2label_tbl, svc.template2admin_tbl,
                      svc.template2approvalvalue_tbl):
      child_tbl.Delete.assert_called_once_with(
          self.cnxn, template_id=1, commit=False)
    svc.template2label_tbl.InsertRows.assert_called_once_with(
        self.cnxn, template_svc.TEMPLATE2LABEL_COLS, [], commit=False)
    svc.template2admin_tbl.InsertRows.assert_called_once_with(
        self.cnxn, template_svc.TEMPLATE2ADMIN_COLS, [(1, 111)],
        commit=False)
    svc.issuephasedef_tbl.InsertRow.assert_called_once_with(
        self.cnxn, name='Canary', rank=11, commit=False)
    svc.template2approvalvalue_tbl.InsertRows.assert_called_once_with(
        self.cnxn, template_svc.TEMPLATE2APPROVALVALUE_COLS,
        [(20, 1, 1, 'not_set'), (21, 1, 1, 'not_set')], commit=False)

    self.cnxn.Commit.assert_called_once_with()
    svc.template_set_2lc.InvalidateKeys.assert_called_once_with(
        self.cnxn, [789])
    svc.template_def_2lc.InvalidateKeys.assert_called_once_with(
        self.cnxn, [1])
+
+
class DeleteTemplateTest(TemplateServiceTest):
  """Tests for TemplateService.DeleteIssueTemplateDef."""

  # Child tables whose rows are expected to be deleted by template_id.
  _CHILD_TABLES = (
      'template2label_tbl', 'template2component_tbl', 'template2admin_tbl',
      'template2fieldvalue_tbl', 'template2approvalvalue_tbl')

  def testDeleteIssueTemplateDef(self):
    svc = self.template_service
    for table_name in self._CHILD_TABLES:
      getattr(svc, table_name).Delete = Mock()
    svc.template_tbl.Delete = Mock()
    svc.template_set_2lc._StrToKey = Mock(return_value=789)

    svc.DeleteIssueTemplateDef(self.cnxn, 789, 1)

    for table_name in self._CHILD_TABLES:
      getattr(svc, table_name).Delete.assert_called_once_with(
          self.cnxn, template_id=1, commit=False)
    svc.template_tbl.Delete.assert_called_once_with(
        self.cnxn, id=1, commit=False)
    self.cnxn.Commit.assert_called_once_with()
    svc.template_set_2lc.InvalidateKeys.assert_called_once_with(
        self.cnxn, [789])
    svc.template_def_2lc.InvalidateKeys.assert_called_once_with(
        self.cnxn, [1])
+
+
class ExpungeUsersInTemplatesTest(TemplateServiceTest):
  """Tests for TemplateService.ExpungeUsersInTemplates."""

  def setUp(self):
    super(ExpungeUsersInTemplatesTest, self).setUp()

    svc = self.template_service
    svc.template2admin_tbl.Delete = Mock()
    svc.template2fieldvalue_tbl.Delete = Mock()
    svc.template_tbl.Update = Mock()

  def testExpungeUsersInTemplates(self):
    svc = self.template_service
    user_ids = [111, 222]

    svc.ExpungeUsersInTemplates(self.cnxn, user_ids, limit=60)

    # Admin and field-value rows referencing the users are deleted with the
    # given limit, and template ownership is cleared; nothing commits here.
    svc.template2admin_tbl.Delete.assert_called_once_with(
        self.cnxn, admin_id=user_ids, commit=False, limit=60)
    svc.template2fieldvalue_tbl.Delete.assert_called_once_with(
        self.cnxn, user_id=user_ids, commit=False, limit=60)
    svc.template_tbl.Update.assert_called_once_with(
        self.cnxn, {'owner_id': None}, owner_id=user_ids, commit=False)
+
+
class UnpackTemplateTest(unittest.TestCase):
  """Tests for template_svc.UnpackTemplate."""

  def testEmpty(self):
    self.assertRaises(ValueError, template_svc.UnpackTemplate, ())

  def testNormal(self):
    row = (
        1, 2, 'name', 'content', 'summary', False, 3, 'status',
        False, False, False)
    expected = tracker_pb2.TemplateDef(
        template_id=1, name='name', content='content', summary='summary',
        summary_must_be_edited=False, owner_id=3, status='status',
        members_only=False, owner_defaults_to_member=False,
        component_required=False)
    self.assertEqual(expected, template_svc.UnpackTemplate(row))
diff --git a/services/test/tracker_fulltext_test.py b/services/test/tracker_fulltext_test.py
new file mode 100644
index 0000000..db8a7a7
--- /dev/null
+++ b/services/test/tracker_fulltext_test.py
@@ -0,0 +1,283 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Unit tests for tracker_fulltext module."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+import mox
+
+from google.appengine.api import search
+
+import settings
+from framework import framework_views
+from proto import ast_pb2
+from proto import tracker_pb2
+from services import fulltext_helpers
+from services import tracker_fulltext
+from testing import fake
+from tracker import tracker_bizobj
+
+
+class TrackerFulltextTest(unittest.TestCase):
+
+ def setUp(self):
+ self.mox = mox.Mox()
+ self.mock_index = self.mox.CreateMockAnything()
+ self.mox.StubOutWithMock(search, 'Index')
+ self.docs = None
+ self.cnxn = 'fake connection'
+ self.user_service = fake.UserService()
+ self.user_service.TestAddUser('test@example.com', 111)
+ self.issue_service = fake.IssueService()
+ self.config_service = fake.ConfigService()
+
+ self.issue = fake.MakeTestIssue(
+ 123, 1, 'test summary', 'New', 111)
+ self.issue_service.TestAddIssue(self.issue)
+ self.comment = tracker_pb2.IssueComment(
+ project_id=789, issue_id=self.issue.issue_id, user_id=111,
+ content='comment content',
+ attachments=[
+ tracker_pb2.Attachment(filename='hello.c'),
+ tracker_pb2.Attachment(filename='hello.h')])
+ self.issue_service.TestAddComment(self.comment, 1)
+ self.users_by_id = framework_views.MakeAllUserViews(
+ self.cnxn, self.user_service, [111])
+
+ def tearDown(self):
+ self.mox.UnsetStubs()
+ self.mox.ResetAll()
+
+ def RecordDocs(self, docs):
+ self.docs = docs
+
+ def SetUpIndexIssues(self):
+ search.Index(name=settings.search_index_name_format % 1).AndReturn(
+ self.mock_index)
+ self.mock_index.put(mox.IgnoreArg()).WithSideEffects(self.RecordDocs)
+
+ def testIndexIssues(self):
+ self.SetUpIndexIssues()
+ self.mox.ReplayAll()
+ tracker_fulltext.IndexIssues(
+ self.cnxn, [self.issue], self.user_service, self.issue_service,
+ self.config_service)
+ self.mox.VerifyAll()
+ self.assertEqual(1, len(self.docs))
+ issue_doc = self.docs[0]
+ self.assertEqual(123, issue_doc.fields[0].value)
+ self.assertEqual('test summary', issue_doc.fields[1].value)
+
+ def SetUpCreateIssueSearchDocuments(self):
+ self.mox.StubOutWithMock(tracker_fulltext, '_IndexDocsInShard')
+ tracker_fulltext._IndexDocsInShard(1, mox.IgnoreArg()).WithSideEffects(
+ lambda shard_id, docs: self.RecordDocs(docs))
+
+ def testCreateIssueSearchDocuments_Normal(self):
+ self.SetUpCreateIssueSearchDocuments()
+ self.mox.ReplayAll()
+ config_dict = {123: tracker_bizobj.MakeDefaultProjectIssueConfig(123)}
+ tracker_fulltext._CreateIssueSearchDocuments(
+ [self.issue], {self.issue.issue_id: [self.comment]}, self.users_by_id,
+ config_dict)
+ self.mox.VerifyAll()
+ self.assertEqual(1, len(self.docs))
+ issue_doc = self.docs[0]
+ self.assertEqual(5, len(issue_doc.fields))
+ self.assertEqual(123, issue_doc.fields[0].value)
+ self.assertEqual('test summary', issue_doc.fields[1].value)
+ self.assertEqual('test@example.com comment content hello.c hello.h',
+ issue_doc.fields[3].value)
+ self.assertEqual('', issue_doc.fields[4].value)
+
+ def testCreateIssueSearchDocuments_NoIndexableComments(self):
+ """Sometimes all comments on a issue are spam or deleted."""
+ self.SetUpCreateIssueSearchDocuments()
+ self.mox.ReplayAll()
+ config_dict = {123: tracker_bizobj.MakeDefaultProjectIssueConfig(123)}
+ self.comment.deleted_by = 111
+ tracker_fulltext._CreateIssueSearchDocuments(
+ [self.issue], {self.issue.issue_id: [self.comment]}, self.users_by_id,
+ config_dict)
+ self.mox.VerifyAll()
+ self.assertEqual(1, len(self.docs))
+ issue_doc = self.docs[0]
+ self.assertEqual(5, len(issue_doc.fields))
+ self.assertEqual(123, issue_doc.fields[0].value)
+ self.assertEqual('test summary', issue_doc.fields[1].value)
+ self.assertEqual('', issue_doc.fields[3].value)
+ self.assertEqual('', issue_doc.fields[4].value)
+
+ def testCreateIssueSearchDocuments_CustomFields(self):
+ self.SetUpCreateIssueSearchDocuments()
+ self.mox.ReplayAll()
+ config = tracker_bizobj.MakeDefaultProjectIssueConfig(123)
+ config_dict = {123: tracker_bizobj.MakeDefaultProjectIssueConfig(123)}
+ int_field = tracker_bizobj.MakeFieldDef(
+ 1, 123, 'CustomInt', tracker_pb2.FieldTypes.INT_TYPE, None, False,
+ False, False, None, None, None, None, False, None, None, None,
+ 'no_action', 'A custom int field', False)
+ int_field_value = tracker_bizobj.MakeFieldValue(
+ 1, 42, None, None, False, None, None)
+ str_field = tracker_bizobj.MakeFieldDef(
+ 2, 123, 'CustomStr', tracker_pb2.FieldTypes.STR_TYPE, None, False,
+ False, False, None, None, None, None, False, None, None, None,
+ 'no_action', 'A custom string field', False)
+ str_field_value = tracker_bizobj.MakeFieldValue(
+ 2, None, u'\xf0\x9f\x92\x96\xef\xb8\x8f', None, None, None, False)
+ # TODO(jrobbins): user-type field 3
+ date_field = tracker_bizobj.MakeFieldDef(
+ 4, 123, 'CustomDate', tracker_pb2.FieldTypes.DATE_TYPE, None, False,
+ False, False, None, None, None, None, False, None, None, None,
+ 'no_action', 'A custom date field', False)
+ date_field_value = tracker_bizobj.MakeFieldValue(
+ 4, None, None, None, 1234567890, None, False)
+ config.field_defs.extend([int_field, str_field, date_field])
+ self.issue.field_values.extend([
+ int_field_value, str_field_value, date_field_value])
+
+ tracker_fulltext._CreateIssueSearchDocuments(
+ [self.issue], {self.issue.issue_id: [self.comment]}, self.users_by_id,
+ config_dict)
+ self.mox.VerifyAll()
+ self.assertEqual(1, len(self.docs))
+ issue_doc = self.docs[0]
+ metadata = issue_doc.fields[2]
+ self.assertEqual(
+ u'New test@example.com [] 42 \xf0\x9f\x92\x96\xef\xb8\x8f 2009-02-13 ',
+ metadata.value)
+
+ def testExtractCommentText(self):
+ extracted_text = tracker_fulltext._ExtractCommentText(
+ self.comment, self.users_by_id)
+ self.assertEqual(
+ 'test@example.com comment content hello.c hello.h',
+ extracted_text)
+
+ def testIndexableComments_NumberOfComments(self):
+ """We consider at most 100 initial comments and 500 most recent comments."""
+ comments = [self.comment]
+ indexable = tracker_fulltext._IndexableComments(comments, self.users_by_id)
+ self.assertEqual(1, len(indexable))
+
+ comments = [self.comment] * 100
+ indexable = tracker_fulltext._IndexableComments(comments, self.users_by_id)
+ self.assertEqual(100, len(indexable))
+
+ comments = [self.comment] * 101
+ indexable = tracker_fulltext._IndexableComments(comments, self.users_by_id)
+ self.assertEqual(101, len(indexable))
+
+ comments = [self.comment] * 600
+ indexable = tracker_fulltext._IndexableComments(comments, self.users_by_id)
+ self.assertEqual(600, len(indexable))
+
+ comments = [self.comment] * 601
+ indexable = tracker_fulltext._IndexableComments(comments, self.users_by_id)
+ self.assertEqual(600, len(indexable))
+ self.assertNotIn(100, indexable)
+
+ def testIndexableComments_NumberOfChars(self):
+ """We consider comments that can fit into the search index document."""
+ self.comment.content = 'x' * 1000
+ comments = [self.comment] * 100
+
+ indexable = tracker_fulltext._IndexableComments(
+ comments, self.users_by_id, remaining_chars=100000)
+ self.assertEqual(100, len(indexable))
+
+ indexable = tracker_fulltext._IndexableComments(
+ comments, self.users_by_id, remaining_chars=50000)
+ self.assertEqual(50, len(indexable))
+ indexable = tracker_fulltext._IndexableComments(
+ comments, self.users_by_id, remaining_chars=50999)
+ self.assertEqual(50, len(indexable))
+
+ indexable = tracker_fulltext._IndexableComments(
+ comments, self.users_by_id, remaining_chars=999)
+ self.assertEqual(0, len(indexable))
+
+ indexable = tracker_fulltext._IndexableComments(
+ comments, self.users_by_id, remaining_chars=0)
+ self.assertEqual(0, len(indexable))
+
+ indexable = tracker_fulltext._IndexableComments(
+ comments, self.users_by_id, remaining_chars=-1)
+ self.assertEqual(0, len(indexable))
+
+ def SetUpUnindexIssues(self):
+ search.Index(name=settings.search_index_name_format % 1).AndReturn(
+ self.mock_index)
+ self.mock_index.delete(['1'])
+
+ def testUnindexIssues(self):
+ self.SetUpUnindexIssues()
+ self.mox.ReplayAll()
+ tracker_fulltext.UnindexIssues([1])
+ self.mox.VerifyAll()
+
+ def SetUpSearchIssueFullText(self):
+ self.mox.StubOutWithMock(fulltext_helpers, 'ComprehensiveSearch')
+ fulltext_helpers.ComprehensiveSearch(
+ '(project_id:789) (summary:"test")',
+ settings.search_index_name_format % 1).AndReturn([123, 234])
+
+ def testSearchIssueFullText_Normal(self):
+ self.SetUpSearchIssueFullText()
+ self.mox.ReplayAll()
+ summary_fd = tracker_pb2.FieldDef(
+ field_name='summary', field_type=tracker_pb2.FieldTypes.STR_TYPE)
+ query_ast_conj = ast_pb2.Conjunction(conds=[
+ ast_pb2.Condition(
+ op=ast_pb2.QueryOp.TEXT_HAS, field_defs=[summary_fd],
+ str_values=['test'])])
+ issue_ids, capped = tracker_fulltext.SearchIssueFullText(
+ [789], query_ast_conj, 1)
+ self.mox.VerifyAll()
+ self.assertItemsEqual([123, 234], issue_ids)
+ self.assertFalse(capped)
+
+ def testSearchIssueFullText_CrossProject(self):
+ self.mox.StubOutWithMock(fulltext_helpers, 'ComprehensiveSearch')
+ fulltext_helpers.ComprehensiveSearch(
+ '(project_id:789 OR project_id:678) (summary:"test")',
+ settings.search_index_name_format % 1).AndReturn([123, 234])
+ self.mox.ReplayAll()
+
+ summary_fd = tracker_pb2.FieldDef(
+ field_name='summary', field_type=tracker_pb2.FieldTypes.STR_TYPE)
+ query_ast_conj = ast_pb2.Conjunction(conds=[
+ ast_pb2.Condition(
+ op=ast_pb2.QueryOp.TEXT_HAS, field_defs=[summary_fd],
+ str_values=['test'])])
+ issue_ids, capped = tracker_fulltext.SearchIssueFullText(
+ [789, 678], query_ast_conj, 1)
+ self.mox.VerifyAll()
+ self.assertItemsEqual([123, 234], issue_ids)
+ self.assertFalse(capped)
+
+ def testSearchIssueFullText_Capped(self):
+ try:
+ orig = settings.fulltext_limit_per_shard
+ settings.fulltext_limit_per_shard = 1
+ self.SetUpSearchIssueFullText()
+ self.mox.ReplayAll()
+ summary_fd = tracker_pb2.FieldDef(
+ field_name='summary', field_type=tracker_pb2.FieldTypes.STR_TYPE)
+ query_ast_conj = ast_pb2.Conjunction(conds=[
+ ast_pb2.Condition(
+ op=ast_pb2.QueryOp.TEXT_HAS, field_defs=[summary_fd],
+ str_values=['test'])])
+ issue_ids, capped = tracker_fulltext.SearchIssueFullText(
+ [789], query_ast_conj, 1)
+ self.mox.VerifyAll()
+ self.assertItemsEqual([123, 234], issue_ids)
+ self.assertTrue(capped)
+ finally:
+ settings.fulltext_limit_per_shard = orig
diff --git a/services/test/user_svc_test.py b/services/test/user_svc_test.py
new file mode 100644
index 0000000..4a8eb16
--- /dev/null
+++ b/services/test/user_svc_test.py
@@ -0,0 +1,600 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for the user service."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import unittest
+
+import mock
+import mox
+import time
+
+from google.appengine.ext import testbed
+
+from framework import exceptions
+from framework import framework_constants
+from framework import sql
+from proto import user_pb2
+from services import user_svc
+from testing import fake
+
+
+def SetUpGetUsers(user_service, cnxn):
+ """Set up expected calls to SQL tables."""
+ user_service.user_tbl.Select(
+ cnxn, cols=user_svc.USER_COLS, user_id=[333]).AndReturn(
+ [(333, 'c@example.com', False, False, False, False, True,
+ False, 'Spammer',
+ 'stay_same_issue', False, False, True, 0, 0, None)])
+ user_service.linkedaccount_tbl.Select(
+ cnxn, cols=user_svc.LINKEDACCOUNT_COLS, parent_id=[333], child_id=[333],
+ or_where_conds=True).AndReturn([])
+
+
+def MakeUserService(cache_manager, my_mox):
+ user_service = user_svc.UserService(cache_manager)
+ user_service.user_tbl = my_mox.CreateMock(sql.SQLTableManager)
+ user_service.hotlistvisithistory_tbl = my_mox.CreateMock(sql.SQLTableManager)
+ user_service.linkedaccount_tbl = my_mox.CreateMock(sql.SQLTableManager)
+ # Account linking invites are done with patch().
+ return user_service
+
+
+class UserTwoLevelCacheTest(unittest.TestCase):
+
+ def setUp(self):
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_memcache_stub()
+
+ self.mox = mox.Mox()
+ self.cnxn = fake.MonorailConnection()
+ self.cache_manager = fake.CacheManager()
+ self.user_service = MakeUserService(self.cache_manager, self.mox)
+
+ def tearDown(self):
+ self.testbed.deactivate()
+
+ def testDeserializeUsersByID(self):
+ user_rows = [
+ (111, 'a@example.com', False, False, False, False, True, False, '',
+ 'stay_same_issue', False, False, True, 0, 0, None),
+ (222, 'b@example.com', False, False, False, False, True, False, '',
+ 'next_in_list', False, False, True, 0, 0, None),
+ ]
+ linkedaccount_rows = []
+ user_dict = self.user_service.user_2lc._DeserializeUsersByID(
+ user_rows, linkedaccount_rows)
+ self.assertEqual(2, len(user_dict))
+ self.assertEqual('a@example.com', user_dict[111].email)
+ self.assertFalse(user_dict[111].is_site_admin)
+ self.assertEqual('', user_dict[111].banned)
+ self.assertFalse(user_dict[111].notify_issue_change)
+ self.assertEqual('b@example.com', user_dict[222].email)
+ self.assertIsNone(user_dict[111].linked_parent_id)
+ self.assertEqual([], user_dict[111].linked_child_ids)
+ self.assertIsNone(user_dict[222].linked_parent_id)
+ self.assertEqual([], user_dict[222].linked_child_ids)
+
+ def testDeserializeUsersByID_LinkedAccounts(self):
+ user_rows = [
+ (111, 'a@example.com', False, False, False, False, True, False, '',
+ 'stay_same_issue', False, False, True, 0, 0, None),
+ ]
+ linkedaccount_rows = [(111, 222), (111, 333), (444, 111)]
+ user_dict = self.user_service.user_2lc._DeserializeUsersByID(
+ user_rows, linkedaccount_rows)
+ self.assertEqual(1, len(user_dict))
+ user_pb = user_dict[111]
+ self.assertEqual('a@example.com', user_pb.email)
+ self.assertEqual(444, user_pb.linked_parent_id)
+ self.assertEqual([222, 333], user_pb.linked_child_ids)
+
+ def testFetchItems(self):
+ SetUpGetUsers(self.user_service, self.cnxn)
+ self.mox.ReplayAll()
+ user_dict = self.user_service.user_2lc.FetchItems(self.cnxn, [333])
+ self.mox.VerifyAll()
+ self.assertEqual([333], list(user_dict.keys()))
+ self.assertEqual('c@example.com', user_dict[333].email)
+ self.assertFalse(user_dict[333].is_site_admin)
+ self.assertEqual('Spammer', user_dict[333].banned)
+
+
+class UserServiceTest(unittest.TestCase):
+
+ def setUp(self):
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_memcache_stub()
+
+ self.mox = mox.Mox()
+ self.cnxn = fake.MonorailConnection()
+ self.cache_manager = fake.CacheManager()
+ self.user_service = MakeUserService(self.cache_manager, self.mox)
+
+ def tearDown(self):
+ self.testbed.deactivate()
+ self.mox.UnsetStubs()
+ self.mox.ResetAll()
+
+ def SetUpCreateUsers(self):
+ self.user_service.user_tbl.InsertRows(
+ self.cnxn,
+ ['user_id', 'email', 'obscure_email'],
+ [(3035911623, 'a@example.com', True),
+ (2996997680, 'b@example.com', True)]
+ ).AndReturn(None)
+
+ def testCreateUsers(self):
+ self.SetUpCreateUsers()
+ self.mox.ReplayAll()
+ self.user_service._CreateUsers(
+ self.cnxn, ['a@example.com', 'b@example.com'])
+ self.mox.VerifyAll()
+
+ def SetUpLookupUserEmails(self):
+ self.user_service.user_tbl.Select(
+ self.cnxn, cols=['user_id', 'email'], user_id=[222]).AndReturn(
+ [(222, 'b@example.com')])
+
+ def testLookupUserEmails(self):
+ self.SetUpLookupUserEmails()
+ self.user_service.email_cache.CacheItem(
+ 111, 'a@example.com')
+ self.mox.ReplayAll()
+ emails_dict = self.user_service.LookupUserEmails(
+ self.cnxn, [111, 222])
+ self.mox.VerifyAll()
+ self.assertEqual(
+ {111: 'a@example.com', 222: 'b@example.com'},
+ emails_dict)
+
+ def SetUpLookupUserEmails_Missed(self):
+ self.user_service.user_tbl.Select(
+ self.cnxn, cols=['user_id', 'email'], user_id=[222]).AndReturn([])
+ self.user_service.email_cache.CacheItem(
+ 111, 'a@example.com')
+
+ def testLookupUserEmails_Missed(self):
+ self.SetUpLookupUserEmails_Missed()
+ self.mox.ReplayAll()
+ with self.assertRaises(exceptions.NoSuchUserException):
+ self.user_service.LookupUserEmails(self.cnxn, [111, 222])
+ self.mox.VerifyAll()
+
+ def testLookUpUserEmails_IgnoreMissed(self):
+ self.SetUpLookupUserEmails_Missed()
+ self.mox.ReplayAll()
+ emails_dict = self.user_service.LookupUserEmails(
+ self.cnxn, [111, 222], ignore_missed=True)
+ self.mox.VerifyAll()
+ self.assertEqual({111: 'a@example.com'}, emails_dict)
+
+ def testLookupUserEmail(self):
+ self.SetUpLookupUserEmails() # Same as testLookupUserEmails()
+ self.mox.ReplayAll()
+ email_addr = self.user_service.LookupUserEmail(self.cnxn, 222)
+ self.mox.VerifyAll()
+ self.assertEqual('b@example.com', email_addr)
+
+ def SetUpLookupUserIDs(self):
+ self.user_service.user_tbl.Select(
+ self.cnxn, cols=['email', 'user_id'],
+ email=['b@example.com']).AndReturn([('b@example.com', 222)])
+
+ def testLookupUserIDs(self):
+ self.SetUpLookupUserIDs()
+ self.user_service.user_id_cache.CacheItem(
+ 'a@example.com', 111)
+ self.mox.ReplayAll()
+ user_id_dict = self.user_service.LookupUserIDs(
+ self.cnxn, ['a@example.com', 'b@example.com'])
+ self.mox.VerifyAll()
+ self.assertEqual(
+ {'a@example.com': 111, 'b@example.com': 222},
+ user_id_dict)
+
+ def testLookupUserIDs_InvalidEmail(self):
+ self.user_service.user_tbl.Select(
+ self.cnxn, cols=['email', 'user_id'], email=['abc']).AndReturn([])
+ self.mox.ReplayAll()
+ user_id_dict = self.user_service.LookupUserIDs(
+ self.cnxn, ['abc'], autocreate=True)
+ self.mox.VerifyAll()
+ self.assertEqual({}, user_id_dict)
+
+ def testLookupUserIDs_NoUserValue(self):
+ self.user_service.user_tbl.Select = mock.Mock(
+ return_value=[('b@example.com', 222)])
+ user_id_dict = self.user_service.LookupUserIDs(
+ self.cnxn, [framework_constants.NO_VALUES, '', 'b@example.com'])
+ self.assertEqual({'b@example.com': 222}, user_id_dict)
+ self.user_service.user_tbl.Select.assert_called_once_with(
+ self.cnxn, cols=['email', 'user_id'], email=['b@example.com'])
+
+ def testLookupUserID(self):
+ self.SetUpLookupUserIDs() # Same as testLookupUserIDs()
+ self.user_service.user_id_cache.CacheItem('a@example.com', 111)
+ self.mox.ReplayAll()
+ user_id = self.user_service.LookupUserID(self.cnxn, 'b@example.com')
+ self.mox.VerifyAll()
+ self.assertEqual(222, user_id)
+
+ def SetUpGetUsersByIDs(self):
+ self.user_service.user_tbl.Select(
+ self.cnxn, cols=user_svc.USER_COLS, user_id=[333, 444]).AndReturn(
+ [
+ (
+ 333, 'c@example.com', False, False, False, False, True,
+ False, 'Spammer', 'stay_same_issue', False, False, True, 0,
+ 0, None)
+ ])
+ self.user_service.linkedaccount_tbl.Select(
+ self.cnxn,
+ cols=user_svc.LINKEDACCOUNT_COLS,
+ parent_id=[333, 444],
+ child_id=[333, 444],
+ or_where_conds=True).AndReturn([])
+
+
+ def testGetUsersByIDs(self):
+ self.SetUpGetUsersByIDs()
+ user_a = user_pb2.User(email='a@example.com')
+ self.user_service.user_2lc.CacheItem(111, user_a)
+ self.mox.ReplayAll()
+ # 444 user does not exist.
+ user_dict = self.user_service.GetUsersByIDs(self.cnxn, [111, 333, 444])
+ self.mox.VerifyAll()
+ self.assertEqual(3, len(user_dict))
+ self.assertEqual('a@example.com', user_dict[111].email)
+ self.assertFalse(user_dict[111].is_site_admin)
+ self.assertFalse(user_dict[111].banned)
+ self.assertTrue(user_dict[111].notify_issue_change)
+ self.assertEqual('c@example.com', user_dict[333].email)
+ self.assertEqual(user_dict[444], user_pb2.MakeUser(444))
+
+ def testGetUsersByIDs_SkipMissed(self):
+ self.SetUpGetUsersByIDs()
+ user_a = user_pb2.User(email='a@example.com')
+ self.user_service.user_2lc.CacheItem(111, user_a)
+ self.mox.ReplayAll()
+ # 444 user does not exist
+ user_dict = self.user_service.GetUsersByIDs(
+ self.cnxn, [111, 333, 444], skip_missed=True)
+ self.mox.VerifyAll()
+ self.assertEqual(2, len(user_dict))
+ self.assertEqual('a@example.com', user_dict[111].email)
+ self.assertFalse(user_dict[111].is_site_admin)
+ self.assertFalse(user_dict[111].banned)
+ self.assertTrue(user_dict[111].notify_issue_change)
+ self.assertEqual('c@example.com', user_dict[333].email)
+
+ def testGetUser(self):
+ SetUpGetUsers(self.user_service, self.cnxn)
+ user_a = user_pb2.User(email='a@example.com')
+ self.user_service.user_2lc.CacheItem(111, user_a)
+ self.mox.ReplayAll()
+ user = self.user_service.GetUser(self.cnxn, 333)
+ self.mox.VerifyAll()
+ self.assertEqual('c@example.com', user.email)
+
+ def SetUpUpdateUser(self):
+ delta = {
+ 'keep_people_perms_open': False,
+ 'preview_on_hover': True,
+ 'notify_issue_change': True,
+ 'after_issue_update': 'STAY_SAME_ISSUE',
+ 'notify_starred_issue_change': True,
+ 'notify_starred_ping': False,
+ 'is_site_admin': False,
+ 'banned': 'Turned spammer',
+ 'obscure_email': True,
+ 'email_compact_subject': False,
+ 'email_view_widget': True,
+ 'last_visit_timestamp': 0,
+ 'email_bounce_timestamp': 0,
+ 'vacation_message': None,
+ }
+ self.user_service.user_tbl.Update(
+ self.cnxn, delta, user_id=111, commit=False)
+
+ def testUpdateUser(self):
+ self.SetUpUpdateUser()
+ user_a = user_pb2.User(
+ email='a@example.com', banned='Turned spammer')
+ self.mox.ReplayAll()
+ self.user_service.UpdateUser(self.cnxn, 111, user_a)
+ self.mox.VerifyAll()
+ self.assertFalse(self.user_service.user_2lc.HasItem(111))
+
+ def SetUpGetRecentlyVisitedHotlists(self):
+ self.user_service.hotlistvisithistory_tbl.Select(
+ self.cnxn, cols=['hotlist_id'], user_id=[111],
+ order_by=[('viewed DESC', [])], limit=10).AndReturn(
+ ((123,), (234,)))
+
+ def testGetRecentlyVisitedHotlists(self):
+ self.SetUpGetRecentlyVisitedHotlists()
+ self.mox.ReplayAll()
+ recent_hotlist_rows = self.user_service.GetRecentlyVisitedHotlists(
+ self.cnxn, 111)
+ self.mox.VerifyAll()
+ self.assertEqual(recent_hotlist_rows, [123, 234])
+
+ def SetUpAddVisitedHotlist(self, ts):
+ self.user_service.hotlistvisithistory_tbl.Delete(
+ self.cnxn, hotlist_id=123, user_id=111, commit=False)
+ self.user_service.hotlistvisithistory_tbl.InsertRows(
+ self.cnxn, user_svc.HOTLISTVISITHISTORY_COLS,
+ [(123, 111, ts)],
+ commit=False)
+
+ @mock.patch('time.time')
+ def testAddVisitedHotlist(self, mockTime):
+ ts = 122333
+ mockTime.return_value = ts
+ self.SetUpAddVisitedHotlist(ts)
+ self.mox.ReplayAll()
+ self.user_service.AddVisitedHotlist(self.cnxn, 111, 123, commit=False)
+ self.mox.VerifyAll()
+
+ def testExpungeHotlistsFromHistory(self):
+ self.user_service.hotlistvisithistory_tbl.Delete = mock.Mock()
+ hotlist_ids = [123, 223]
+ self.user_service.ExpungeHotlistsFromHistory(
+ self.cnxn, hotlist_ids, commit=False)
+ self.user_service.hotlistvisithistory_tbl.Delete.assert_called_once_with(
+ self.cnxn, hotlist_id=hotlist_ids, commit=False)
+
+ def testExpungeUsersHotlistsHistory(self):
+ self.user_service.hotlistvisithistory_tbl.Delete = mock.Mock()
+ user_ids = [111, 222]
+ self.user_service.ExpungeUsersHotlistsHistory(
+ self.cnxn, user_ids, commit=False)
+ self.user_service.hotlistvisithistory_tbl.Delete.assert_called_once_with(
+ self.cnxn, user_id=user_ids, commit=False)
+
+ def SetUpTrimUserVisitedHotlists(self, user_ids, ts):
+ self.user_service.hotlistvisithistory_tbl.Select(
+ self.cnxn, cols=['user_id'], group_by=['user_id'],
+ having=[('COUNT(*) > %s', [10])], limit=1000).AndReturn((
+ (111,), (222,), (333,)))
+ for user_id in user_ids:
+ self.user_service.hotlistvisithistory_tbl.Select(
+ self.cnxn, cols=['viewed'], user_id=user_id,
+ order_by=[('viewed DESC', [])]).AndReturn([
+ (ts,), (ts,), (ts,), (ts,), (ts,), (ts,),
+ (ts,), (ts,), (ts,), (ts,), (ts+1,)])
+ self.user_service.hotlistvisithistory_tbl.Delete(
+ self.cnxn, user_id=user_id, where=[('viewed < %s', [ts])],
+ commit=False)
+
+ @mock.patch('time.time')
+ def testTrimUserVisitedHotlists(self, mockTime):
+ ts = 122333
+ mockTime.return_value = ts
+ self.SetUpTrimUserVisitedHotlists([111, 222, 333], ts)
+ self.mox.ReplayAll()
+ self.user_service.TrimUserVisitedHotlists(self.cnxn, commit=False)
+ self.mox.VerifyAll()
+
+ def testGetPendingLinkedInvites_Anon(self):
+ """An Anon user never has invites to link accounts."""
+ as_parent, as_child = self.user_service.GetPendingLinkedInvites(
+ self.cnxn, 0)
+ self.assertEqual([], as_parent)
+ self.assertEqual([], as_child)
+
+ def testGetPendingLinkedInvites_None(self):
+ """A user who has no link invites gets empty lists."""
+ self.user_service.linkedaccountinvite_tbl = mock.Mock()
+ self.user_service.linkedaccountinvite_tbl.Select.return_value = []
+ as_parent, as_child = self.user_service.GetPendingLinkedInvites(
+ self.cnxn, 111)
+ self.assertEqual([], as_parent)
+ self.assertEqual([], as_child)
+
+ def testGetPendingLinkedInvites_Some(self):
+ """A user who has link invites can get them."""
+ self.user_service.linkedaccountinvite_tbl = mock.Mock()
+ self.user_service.linkedaccountinvite_tbl.Select.return_value = [
+ (111, 222), (111, 333), (888, 999), (333, 111)]
+ as_parent, as_child = self.user_service.GetPendingLinkedInvites(
+ self.cnxn, 111)
+ self.assertEqual([222, 333], as_parent)
+ self.assertEqual([333], as_child)
+
+ def testAssertNotAlreadyLinked_NotLinked(self):
+ """No exception is raised when accounts are not already linked."""
+ self.user_service.linkedaccount_tbl = mock.Mock()
+ self.user_service.linkedaccount_tbl.Select.return_value = []
+ self.user_service._AssertNotAlreadyLinked(self.cnxn, 111, 222)
+
+ def testAssertNotAlreadyLinked_AlreadyLinked(self):
+ """Reject attempt to link any account that is already linked."""
+ self.user_service.linkedaccount_tbl = mock.Mock()
+ self.user_service.linkedaccount_tbl.Select.return_value = [
+ (111, 222)]
+ with self.assertRaises(exceptions.InputException):
+ self.user_service._AssertNotAlreadyLinked(self.cnxn, 111, 333)
+
+ def testInviteLinkedParent_Anon(self):
+ """Anon cannot invite anyone to link accounts."""
+ with self.assertRaises(exceptions.InputException):
+ self.user_service.InviteLinkedParent(self.cnxn, 0, 0)
+ with self.assertRaises(exceptions.InputException):
+ self.user_service.InviteLinkedParent(self.cnxn, 111, 0)
+ with self.assertRaises(exceptions.InputException):
+ self.user_service.InviteLinkedParent(self.cnxn, 0, 111)
+
+ def testInviteLinkedParent_Normal(self):
+ """One account can invite another to link."""
+ self.user_service.linkedaccount_tbl = mock.Mock()
+ self.user_service.linkedaccount_tbl.Select.return_value = []
+ self.user_service.linkedaccountinvite_tbl = mock.Mock()
+ self.user_service.InviteLinkedParent(
+ self.cnxn, 111, 222)
+ self.user_service.linkedaccountinvite_tbl.InsertRow.assert_called_once_with(
+ self.cnxn, parent_id=111, child_id=222)
+
+ def testAcceptLinkedChild_Anon(self):
+ """Reject attempts for anon to accept any invite."""
+ with self.assertRaises(exceptions.InputException):
+ self.user_service.AcceptLinkedChild(self.cnxn, 0, 333)
+ with self.assertRaises(exceptions.InputException):
+ self.user_service.AcceptLinkedChild(self.cnxn, 333, 0)
+
+ def testAcceptLinkedChild_Missing(self):
+ """Reject attempts to link without a matching invite."""
+ self.user_service.linkedaccountinvite_tbl = mock.Mock()
+ self.user_service.linkedaccountinvite_tbl.Select.return_value = []
+ self.user_service.linkedaccount_tbl = mock.Mock()
+ self.user_service.linkedaccount_tbl.Select.return_value = []
+ with self.assertRaises(exceptions.InputException) as cm:
+ self.user_service.AcceptLinkedChild(self.cnxn, 111, 333)
+ self.assertEqual('No such invite', cm.exception.message)
+
+ def testAcceptLinkedChild_Normal(self):
+ """Create linkage between accounts and remove invite."""
+ self.user_service.linkedaccountinvite_tbl = mock.Mock()
+ self.user_service.linkedaccountinvite_tbl.Select.return_value = [
+ (111, 222), (333, 444)]
+ self.user_service.linkedaccount_tbl = mock.Mock()
+ self.user_service.linkedaccount_tbl.Select.return_value = []
+
+ self.user_service.AcceptLinkedChild(self.cnxn, 111, 222)
+ self.user_service.linkedaccount_tbl.InsertRow.assert_called_once_with(
+ self.cnxn, parent_id=111, child_id=222)
+ self.user_service.linkedaccountinvite_tbl.Delete.assert_called_once_with(
+ self.cnxn, parent_id=111, child_id=222)
+
+ def testUnlinkAccounts_MissingIDs(self):
+ """Reject an attempt to unlink anon."""
+ with self.assertRaises(exceptions.InputException):
+ self.user_service.UnlinkAccounts(self.cnxn, 0, 0)
+ with self.assertRaises(exceptions.InputException):
+ self.user_service.UnlinkAccounts(self.cnxn, 0, 111)
+ with self.assertRaises(exceptions.InputException):
+ self.user_service.UnlinkAccounts(self.cnxn, 111, 0)
+
+ def testUnlinkAccounts_Normal(self):
+ """We can unlink accounts."""
+ self.user_service.linkedaccount_tbl = mock.Mock()
+ self.user_service.UnlinkAccounts(self.cnxn, 111, 222)
+ self.user_service.linkedaccount_tbl.Delete.assert_called_once_with(
+ self.cnxn, parent_id=111, child_id=222)
+
+ def testUpdateUserSettings(self):
+ self.SetUpUpdateUser()
+ user_a = user_pb2.User(email='a@example.com')
+ self.mox.ReplayAll()
+ self.user_service.UpdateUserSettings(
+ self.cnxn, 111, user_a, is_banned=True,
+ banned_reason='Turned spammer')
+ self.mox.VerifyAll()
+
+ def testGetUsersPrefs(self):
+ self.user_service.userprefs_tbl = mock.Mock()
+ self.user_service.userprefs_tbl.Select.return_value = [
+ (111, 'code_font', 'true'),
+ (111, 'keep_perms_open', 'true'),
+ # Note: user 222 has not set any prefs.
+ (333, 'code_font', 'false')]
+
+ prefs_dict = self.user_service.GetUsersPrefs(self.cnxn, [111, 222, 333])
+
+ expected = {
+ 111: user_pb2.UserPrefs(
+ user_id=111,
+ prefs=[user_pb2.UserPrefValue(name='code_font', value='true'),
+ user_pb2.UserPrefValue(name='keep_perms_open', value='true')]),
+ 222: user_pb2.UserPrefs(user_id=222),
+ 333: user_pb2.UserPrefs(
+ user_id=333,
+ prefs=[user_pb2.UserPrefValue(name='code_font', value='false')]),
+ }
+ self.assertEqual(expected, prefs_dict)
+
+ def testGetUserPrefs(self):
+ self.user_service.userprefs_tbl = mock.Mock()
+ self.user_service.userprefs_tbl.Select.return_value = [
+ (111, 'code_font', 'true'),
+ (111, 'keep_perms_open', 'true'),
+ # Note: user 222 has not set any prefs.
+ (333, 'code_font', 'false')]
+
+ userprefs = self.user_service.GetUserPrefs(self.cnxn, 111)
+ expected = user_pb2.UserPrefs(
+ user_id=111,
+ prefs=[user_pb2.UserPrefValue(name='code_font', value='true'),
+ user_pb2.UserPrefValue(name='keep_perms_open', value='true')])
+ self.assertEqual(expected, userprefs)
+
+ userprefs = self.user_service.GetUserPrefs(self.cnxn, 222)
+ expected = user_pb2.UserPrefs(user_id=222)
+ self.assertEqual(expected, userprefs)
+
+ def testSetUserPrefs(self):
+ self.user_service.userprefs_tbl = mock.Mock()
+ pref_values = [user_pb2.UserPrefValue(name='code_font', value='true'),
+ user_pb2.UserPrefValue(name='keep_perms_open', value='true')]
+ self.user_service.SetUserPrefs(self.cnxn, 111, pref_values)
+ self.user_service.userprefs_tbl.InsertRows.assert_called_once_with(
+ self.cnxn, user_svc.USERPREFS_COLS,
+ [(111, 'code_font', 'true'),
+ (111, 'keep_perms_open', 'true')],
+ replace=True)
+
+ def testExpungeUsers(self):
+ self.user_service.linkedaccount_tbl.Delete = mock.Mock()
+ self.user_service.linkedaccountinvite_tbl.Delete = mock.Mock()
+ self.user_service.userprefs_tbl.Delete = mock.Mock()
+ self.user_service.user_tbl.Delete = mock.Mock()
+
+ user_ids = [222, 444]
+ self.user_service.ExpungeUsers(self.cnxn, user_ids)
+
+ linked_account_calls = [
+ mock.call(self.cnxn, parent_id=user_ids, commit=False),
+ mock.call(self.cnxn, child_id=user_ids, commit=False)]
+ self.user_service.linkedaccount_tbl.Delete.has_calls(linked_account_calls)
+ self.user_service.linkedaccountinvite_tbl.Delete.has_calls(
+ linked_account_calls)
+ user_calls = [mock.call(self.cnxn, user_id=user_ids, commit=False)]
+ self.user_service.userprefs_tbl.Delete.has_calls(user_calls)
+ self.user_service.user_tbl.Delete.has_calls(user_calls)
+
+ def testTotalUsersCount(self):
+ self.user_service.user_tbl.SelectValue = mock.Mock(return_value=10)
+ self.assertEqual(self.user_service.TotalUsersCount(self.cnxn), 9)
+ self.user_service.user_tbl.SelectValue.assert_called_once_with(
+ self.cnxn, col='COUNT(*)')
+
+ def testGetAllUserEmailsBatch(self):
+ rows = [('cow@test.com',), ('pig@test.com',), ('fox@test.com',)]
+ self.user_service.user_tbl.Select = mock.Mock(return_value=rows)
+ emails = self.user_service.GetAllUserEmailsBatch(self.cnxn)
+ self.user_service.user_tbl.Select.assert_called_once_with(
+ self.cnxn, cols=['email'], limit=1000, offset=0,
+ where=[('user_id != %s', [framework_constants.DELETED_USER_ID])],
+ order_by=[('user_id ASC', [])])
+ self.assertItemsEqual(
+ emails, ['cow@test.com', 'pig@test.com', 'fox@test.com'])
+
+ def testGetAllUserEmailsBatch_CustomLimit(self):
+ rows = [('cow@test.com',), ('pig@test.com',), ('fox@test.com',)]
+ self.user_service.user_tbl.Select = mock.Mock(return_value=rows)
+ emails = self.user_service.GetAllUserEmailsBatch(
+ self.cnxn, limit=30, offset=60)
+ self.user_service.user_tbl.Select.assert_called_once_with(
+ self.cnxn, cols=['email'], limit=30, offset=60,
+ where=[('user_id != %s', [framework_constants.DELETED_USER_ID])],
+ order_by=[('user_id ASC', [])])
+ self.assertItemsEqual(
+ emails, ['cow@test.com', 'pig@test.com', 'fox@test.com'])
diff --git a/services/test/usergroup_svc_test.py b/services/test/usergroup_svc_test.py
new file mode 100644
index 0000000..5bfd899
--- /dev/null
+++ b/services/test/usergroup_svc_test.py
@@ -0,0 +1,562 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Tests for the usergroup service."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import collections
+import mock
+import unittest
+
+import mox
+
+from google.appengine.ext import testbed
+
+from framework import exceptions
+from framework import permissions
+from framework import sql
+from proto import usergroup_pb2
+from services import service_manager
+from services import usergroup_svc
+from testing import fake
+
+
+def MakeUserGroupService(cache_manager, my_mox):
+ usergroup_service = usergroup_svc.UserGroupService(cache_manager)
+ usergroup_service.usergroup_tbl = my_mox.CreateMock(sql.SQLTableManager)
+ usergroup_service.usergroupsettings_tbl = my_mox.CreateMock(
+ sql.SQLTableManager)
+ usergroup_service.usergroupprojects_tbl = my_mox.CreateMock(
+ sql.SQLTableManager)
+ return usergroup_service
+
+
+class MembershipTwoLevelCacheTest(unittest.TestCase):
+
+ def setUp(self):
+ self.mox = mox.Mox()
+ self.cache_manager = fake.CacheManager()
+ self.usergroup_service = MakeUserGroupService(self.cache_manager, self.mox)
+
+ def testDeserializeMemberships(self):
+ memberships_rows = [(111, 777), (111, 888), (222, 888)]
+ actual = self.usergroup_service.memberships_2lc._DeserializeMemberships(
+ memberships_rows)
+ self.assertItemsEqual([111, 222], list(actual.keys()))
+ self.assertItemsEqual([777, 888], actual[111])
+ self.assertItemsEqual([888], actual[222])
+
+
+class UserGroupServiceTest(unittest.TestCase):
+
+ def setUp(self):
+ self.testbed = testbed.Testbed()
+ self.testbed.activate()
+ self.testbed.init_memcache_stub()
+
+ self.mox = mox.Mox()
+ self.cnxn = 'fake connection'
+ self.cache_manager = fake.CacheManager()
+ self.usergroup_service = MakeUserGroupService(self.cache_manager, self.mox)
+ self.services = service_manager.Services(
+ user=fake.UserService(),
+ usergroup=self.usergroup_service,
+ project=fake.ProjectService())
+
+ def tearDown(self):
+ self.testbed.deactivate()
+ self.mox.UnsetStubs()
+ self.mox.ResetAll()
+
+ def SetUpCreateGroup(
+ self, group_id, visiblity, external_group_type=None):
+ self.SetUpUpdateSettings(group_id, visiblity, external_group_type)
+
+ def testCreateGroup_Normal(self):
+ self.services.user.TestAddUser('group@example.com', 888)
+ self.SetUpCreateGroup(888, 'anyone')
+ self.mox.ReplayAll()
+ actual_group_id = self.usergroup_service.CreateGroup(
+ self.cnxn, self.services, 'group@example.com', 'anyone')
+ self.mox.VerifyAll()
+ self.assertEqual(888, actual_group_id)
+
+ def testCreateGroup_Import(self):
+ self.services.user.TestAddUser('troopers', 888)
+ self.SetUpCreateGroup(888, 'owners', 'mdb')
+ self.mox.ReplayAll()
+ actual_group_id = self.usergroup_service.CreateGroup(
+ self.cnxn, self.services, 'troopers', 'owners', 'mdb')
+ self.mox.VerifyAll()
+ self.assertEqual(888, actual_group_id)
+
+ def SetUpDetermineWhichUserIDsAreGroups(self, ids_to_query, mock_group_ids):
+ self.usergroup_service.usergroupsettings_tbl.Select(
+ self.cnxn, cols=['group_id'], group_id=ids_to_query).AndReturn(
+ (gid,) for gid in mock_group_ids)
+
+ def testDetermineWhichUserIDsAreGroups_NoGroups(self):
+ self.SetUpDetermineWhichUserIDsAreGroups([], [])
+ self.mox.ReplayAll()
+ actual_group_ids = self.usergroup_service.DetermineWhichUserIDsAreGroups(
+ self.cnxn, [])
+ self.mox.VerifyAll()
+ self.assertEqual([], actual_group_ids)
+
+ def testDetermineWhichUserIDsAreGroups_SomeGroups(self):
+ user_ids = [111, 222, 333]
+ group_ids = [888, 999]
+ self.SetUpDetermineWhichUserIDsAreGroups(user_ids + group_ids, group_ids)
+ self.mox.ReplayAll()
+ actual_group_ids = self.usergroup_service.DetermineWhichUserIDsAreGroups(
+ self.cnxn, user_ids + group_ids)
+ self.mox.VerifyAll()
+ self.assertEqual(group_ids, actual_group_ids)
+
+ def testLookupUserGroupID_Found(self):
+ mock_select = mock.MagicMock()
+ self.services.usergroup.usergroupsettings_tbl.Select = mock_select
+ mock_select.return_value = [('group@example.com', 888)]
+
+ actual = self.services.usergroup.LookupUserGroupID(
+ self.cnxn, 'group@example.com')
+
+ self.assertEqual(888, actual)
+ mock_select.assert_called_once_with(
+ self.cnxn, cols=['email', 'group_id'],
+ left_joins=[('User ON UserGroupSettings.group_id = User.user_id', [])],
+ email='group@example.com',
+ where=[('group_id IS NOT NULL', [])])
+
+ def testLookupUserGroupID_NotFound(self):
+ mock_select = mock.MagicMock()
+ self.services.usergroup.usergroupsettings_tbl.Select = mock_select
+ mock_select.return_value = []
+
+ actual = self.services.usergroup.LookupUserGroupID(
+ self.cnxn, 'user@example.com')
+
+ self.assertIsNone(actual)
+ mock_select.assert_called_once_with(
+ self.cnxn, cols=['email', 'group_id'],
+ left_joins=[('User ON UserGroupSettings.group_id = User.user_id', [])],
+ email='user@example.com',
+ where=[('group_id IS NOT NULL', [])])
+
+ def SetUpLookupAllMemberships(self, user_ids, mock_membership_rows):
+ self.usergroup_service.usergroup_tbl.Select(
+ self.cnxn, cols=['user_id', 'group_id'], distinct=True,
+ user_id=user_ids).AndReturn(mock_membership_rows)
+
+ def testLookupAllMemberships(self):
+ self.usergroup_service.group_dag.initialized = True
+ self.usergroup_service.memberships_2lc.CacheItem(111, {888, 999})
+ self.SetUpLookupAllMemberships([222], [(222, 777), (222, 999)])
+ self.usergroup_service.usergroupsettings_tbl.Select(
+ self.cnxn, cols=['group_id']).AndReturn([])
+ self.usergroup_service.usergroup_tbl.Select(
+ self.cnxn, cols=['user_id', 'group_id'], distinct=True,
+ user_id=[]).AndReturn([])
+ self.mox.ReplayAll()
+ actual_membership_dict = self.usergroup_service.LookupAllMemberships(
+ self.cnxn, [111, 222])
+ self.mox.VerifyAll()
+ self.assertEqual(
+ {111: {888, 999}, 222: {777, 999}},
+ actual_membership_dict)
+
+ def SetUpRemoveMembers(self, group_id, member_ids):
+ self.usergroup_service.usergroup_tbl.Delete(
+ self.cnxn, group_id=group_id, user_id=member_ids)
+
+ def testRemoveMembers(self):
+ self.usergroup_service.group_dag.initialized = True
+ self.SetUpRemoveMembers(888, [111, 222])
+ self.SetUpLookupAllMembers([111, 222], [], {}, {})
+ self.mox.ReplayAll()
+ self.usergroup_service.RemoveMembers(self.cnxn, 888, [111, 222])
+ self.mox.VerifyAll()
+
+ def testUpdateMembers(self):
+ self.usergroup_service.group_dag.initialized = True
+ self.usergroup_service.usergroup_tbl.Delete(
+ self.cnxn, group_id=888, user_id=[111, 222])
+ self.usergroup_service.usergroup_tbl.InsertRows(
+ self.cnxn, ['user_id', 'group_id', 'role'],
+ [(111, 888, 'member'), (222, 888, 'member')])
+ self.SetUpLookupAllMembers([111, 222], [], {}, {})
+ self.mox.ReplayAll()
+ self.usergroup_service.UpdateMembers(
+ self.cnxn, 888, [111, 222], 'member')
+ self.mox.VerifyAll()
+
+ def testUpdateMembers_CircleDetection(self):
+ # Two groups: 888 and 999 while 999 is a member of 888.
+ self.SetUpDAG([(888,), (999,)], [(999, 888)])
+ self.mox.ReplayAll()
+ self.assertRaises(
+ exceptions.CircularGroupException,
+ self.usergroup_service.UpdateMembers, self.cnxn, 999, [888], 'member')
+ self.mox.VerifyAll()
+
+ def SetUpLookupAllMembers(
+ self, group_ids, direct_member_rows,
+ descedants_dict, indirect_member_rows_dict):
+ self.usergroup_service.usergroup_tbl.Select(
+ self.cnxn, cols=['user_id', 'group_id', 'role'], distinct=True,
+ group_id=group_ids).AndReturn(direct_member_rows)
+ for gid in group_ids:
+ if descedants_dict.get(gid, []):
+ self.usergroup_service.usergroup_tbl.Select(
+ self.cnxn, cols=['user_id'], distinct=True,
+ group_id=descedants_dict.get(gid, [])).AndReturn(
+ indirect_member_rows_dict.get(gid, []))
+
+ def testLookupAllMembers(self):
+ self.usergroup_service.group_dag.initialized = True
+ self.usergroup_service.group_dag.user_group_children = (
+ collections.defaultdict(list))
+ self.usergroup_service.group_dag.user_group_children[777] = [888]
+ self.usergroup_service.group_dag.user_group_children[888] = [999]
+ self.SetUpLookupAllMembers(
+ [777],
+ [(888, 777, 'member'), (111, 888, 'member'), (999, 888, 'member'),
+ (222, 999, 'member')],
+ {777: [888, 999]},
+ {777: [(111,), (222,), (999,)]})
+
+ self.mox.ReplayAll()
+ members_dict, owners_dict = self.usergroup_service.LookupAllMembers(
+ self.cnxn, [777])
+ self.mox.VerifyAll()
+ self.assertItemsEqual([111, 222, 888, 999], members_dict[777])
+ self.assertItemsEqual([], owners_dict[777])
+
+ def testExpandAnyGroupEmailRecipients(self):
+ self.usergroup_service.group_dag.initialized = True
+ self.SetUpDetermineWhichUserIDsAreGroups(
+ [111, 777, 888, 999], [777, 888, 999])
+ self.SetUpGetGroupSettings(
+ [777, 888, 999],
+ [(777, 'anyone', None, 0, 1, 0),
+ (888, 'anyone', None, 0, 0, 1),
+ (999, 'anyone', None, 0, 1, 1)],
+ )
+ self.SetUpLookupAllMembers(
+ [777, 888, 999],
+ [(222, 777, 'member'), (333, 888, 'member'), (444, 999, 'member')],
+ {}, {})
+ self.mox.ReplayAll()
+ direct, indirect = self.usergroup_service.ExpandAnyGroupEmailRecipients(
+ self.cnxn, [111, 777, 888, 999])
+ self.mox.VerifyAll()
+ self.assertItemsEqual([111, 888, 999], direct)
+ self.assertItemsEqual([222, 444], indirect)
+
+ def SetUpLookupMembers(self, group_member_dict):
+ mock_membership_rows = []
+ group_ids = []
+ for gid, members in group_member_dict.items():
+ group_ids.append(gid)
+ mock_membership_rows.extend([(uid, gid, 'member') for uid in members])
+ group_ids.sort()
+ self.usergroup_service.usergroup_tbl.Select(
+ self.cnxn, cols=['user_id','group_id', 'role'], distinct=True,
+ group_id=group_ids).AndReturn(mock_membership_rows)
+
+ def testLookupMembers_NoneRequested(self):
+ self.mox.ReplayAll()
+ member_ids, _ = self.usergroup_service.LookupMembers(self.cnxn, [])
+ self.mox.VerifyAll()
+ self.assertItemsEqual({}, member_ids)
+
+ def testLookupMembers_Nonexistent(self):
+ """If some requested groups don't exist, they are ignored."""
+ self.SetUpLookupMembers({777: []})
+ self.mox.ReplayAll()
+ member_ids, _ = self.usergroup_service.LookupMembers(self.cnxn, [777])
+ self.mox.VerifyAll()
+ self.assertItemsEqual([], member_ids[777])
+
+ def testLookupMembers_AllEmpty(self):
+ """Requesting all empty groups results in no members."""
+ self.SetUpLookupMembers({888: [], 999: []})
+ self.mox.ReplayAll()
+ member_ids, _ = self.usergroup_service.LookupMembers(self.cnxn, [888, 999])
+ self.mox.VerifyAll()
+ self.assertItemsEqual([], member_ids[888])
+
+ def testLookupMembers_OneGroup(self):
+ self.SetUpLookupMembers({888: [111, 222]})
+ self.mox.ReplayAll()
+ member_ids, _ = self.usergroup_service.LookupMembers(self.cnxn, [888])
+ self.mox.VerifyAll()
+ self.assertItemsEqual([111, 222], member_ids[888])
+
+ def testLookupMembers_GroupsAndNonGroups(self):
+ """We ignore any non-groups passed in."""
+ self.SetUpLookupMembers({111: [], 333: [], 888: [111, 222]})
+ self.mox.ReplayAll()
+ member_ids, _ = self.usergroup_service.LookupMembers(
+ self.cnxn, [111, 333, 888])
+ self.mox.VerifyAll()
+ self.assertItemsEqual([111, 222], member_ids[888])
+
+ def testLookupMembers_OverlappingGroups(self):
+ """We get the union of IDs. Imagine 888 = {111} and 999 = {111, 222}."""
+ self.SetUpLookupMembers({888: [111], 999: [111, 222]})
+ self.mox.ReplayAll()
+ member_ids, _ = self.usergroup_service.LookupMembers(self.cnxn, [888, 999])
+ self.mox.VerifyAll()
+ self.assertItemsEqual([111, 222], member_ids[999])
+ self.assertItemsEqual([111], member_ids[888])
+
+ def testLookupVisibleMembers_LimitedVisiblity(self):
+ """We get only the member IDs in groups that the user is allowed to see."""
+ self.usergroup_service.group_dag.initialized = True
+ self.SetUpGetGroupSettings(
+ [888, 999],
+ [(888, 'anyone', None, 0, 1, 0), (999, 'members', None, 0, 1, 0)])
+ self.SetUpLookupMembers({888: [111], 999: [111]})
+ self.SetUpLookupAllMembers(
+ [888, 999], [(111, 888, 'member'), (111, 999, 'member')], {}, {})
+ self.mox.ReplayAll()
+ member_ids, _ = self.usergroup_service.LookupVisibleMembers(
+ self.cnxn, [888, 999], permissions.USER_PERMISSIONSET, set(),
+ self.services)
+ self.mox.VerifyAll()
+ self.assertItemsEqual([111], member_ids[888])
+ self.assertNotIn(999, member_ids)
+
+ def SetUpGetAllUserGroupsInfo(self, mock_settings_rows, mock_count_rows,
+ mock_friends=None):
+ mock_friends = mock_friends or []
+ self.usergroup_service.usergroupsettings_tbl.Select(
+ self.cnxn, cols=['email', 'group_id', 'who_can_view_members',
+ 'external_group_type', 'last_sync_time',
+ 'notify_members', 'notify_group'],
+ left_joins=[('User ON UserGroupSettings.group_id = User.user_id', [])]
+ ).AndReturn(mock_settings_rows)
+ self.usergroup_service.usergroup_tbl.Select(
+ self.cnxn, cols=['group_id', 'COUNT(*)'],
+ group_by=['group_id']).AndReturn(mock_count_rows)
+
+ group_ids = [g[1] for g in mock_settings_rows]
+ self.usergroup_service.usergroupprojects_tbl.Select(
+ self.cnxn, cols=usergroup_svc.USERGROUPPROJECTS_COLS,
+ group_id=group_ids).AndReturn(mock_friends)
+
+ def testGetAllUserGroupsInfo(self):
+ self.SetUpGetAllUserGroupsInfo(
+ [('group@example.com', 888, 'anyone', None, 0, 1, 0)],
+ [(888, 12)])
+ self.mox.ReplayAll()
+ actual_infos = self.usergroup_service.GetAllUserGroupsInfo(self.cnxn)
+ self.mox.VerifyAll()
+ self.assertEqual(1, len(actual_infos))
+ addr, count, group_settings, group_id = actual_infos[0]
+ self.assertEqual('group@example.com', addr)
+ self.assertEqual(12, count)
+ self.assertEqual(usergroup_pb2.MemberVisibility.ANYONE,
+ group_settings.who_can_view_members)
+ self.assertEqual(888, group_id)
+
+ def SetUpGetGroupSettings(self, group_ids, mock_result_rows,
+ mock_friends=None):
+ mock_friends = mock_friends or []
+ self.usergroup_service.usergroupsettings_tbl.Select(
+ self.cnxn, cols=usergroup_svc.USERGROUPSETTINGS_COLS,
+ group_id=group_ids).AndReturn(mock_result_rows)
+ self.usergroup_service.usergroupprojects_tbl.Select(
+ self.cnxn, cols=usergroup_svc.USERGROUPPROJECTS_COLS,
+ group_id=group_ids).AndReturn(mock_friends)
+
+ def testGetGroupSettings_NoGroupsRequested(self):
+ self.SetUpGetGroupSettings([], [])
+ self.mox.ReplayAll()
+ actual_settings_dict = self.usergroup_service.GetAllGroupSettings(
+ self.cnxn, [])
+ self.mox.VerifyAll()
+ self.assertEqual({}, actual_settings_dict)
+
+ def testGetGroupSettings_NoGroupsFound(self):
+ self.SetUpGetGroupSettings([777], [])
+ self.mox.ReplayAll()
+ actual_settings_dict = self.usergroup_service.GetAllGroupSettings(
+ self.cnxn, [777])
+ self.mox.VerifyAll()
+ self.assertEqual({}, actual_settings_dict)
+
+ def testGetGroupSettings_SomeGroups(self):
+ self.SetUpGetGroupSettings(
+ [777, 888, 999],
+ [(888, 'anyone', None, 0, 1, 0), (999, 'members', None, 0, 1, 0)])
+ self.mox.ReplayAll()
+ actual_settings_dict = self.usergroup_service.GetAllGroupSettings(
+ self.cnxn, [777, 888, 999])
+ self.mox.VerifyAll()
+ self.assertEqual(
+ {888: usergroup_pb2.MakeSettings('anyone'),
+ 999: usergroup_pb2.MakeSettings('members')},
+ actual_settings_dict)
+
+ def testGetGroupSettings_NoSuchGroup(self):
+ self.SetUpGetGroupSettings([777], [])
+ self.mox.ReplayAll()
+ actual_settings = self.usergroup_service.GetGroupSettings(self.cnxn, 777)
+ self.mox.VerifyAll()
+ self.assertEqual(None, actual_settings)
+
+ def testGetGroupSettings_Found(self):
+ self.SetUpGetGroupSettings([888], [(888, 'anyone', None, 0, 1, 0)])
+ self.mox.ReplayAll()
+ actual_settings = self.usergroup_service.GetGroupSettings(self.cnxn, 888)
+ self.mox.VerifyAll()
+ self.assertEqual(
+ usergroup_pb2.MemberVisibility.ANYONE,
+ actual_settings.who_can_view_members)
+
+ def testGetGroupSettings_Import(self):
+ self.SetUpGetGroupSettings(
+ [888], [(888, 'owners', 'mdb', 0, 1, 0)])
+ self.mox.ReplayAll()
+ actual_settings = self.usergroup_service.GetGroupSettings(self.cnxn, 888)
+ self.mox.VerifyAll()
+ self.assertEqual(
+ usergroup_pb2.MemberVisibility.OWNERS,
+ actual_settings.who_can_view_members)
+ self.assertEqual(
+ usergroup_pb2.GroupType.MDB,
+ actual_settings.ext_group_type)
+
+ def SetUpUpdateSettings(self, group_id, visiblity, external_group_type=None,
+ last_sync_time=0, friend_projects=None,
+ notify_members=True, notify_group=False):
+ friend_projects = friend_projects or []
+ self.usergroup_service.usergroupsettings_tbl.InsertRow(
+ self.cnxn, group_id=group_id, who_can_view_members=visiblity,
+ external_group_type=external_group_type,
+ last_sync_time=last_sync_time, notify_members=notify_members,
+ notify_group=notify_group, replace=True)
+ self.usergroup_service.usergroupprojects_tbl.Delete(
+ self.cnxn, group_id=group_id)
+ if friend_projects:
+ rows = [(group_id, p_id) for p_id in friend_projects]
+ self.usergroup_service.usergroupprojects_tbl.InsertRows(
+ self.cnxn, ['group_id', 'project_id'], rows)
+
+ def testUpdateSettings_Normal(self):
+ self.SetUpUpdateSettings(888, 'anyone')
+ self.mox.ReplayAll()
+ self.usergroup_service.UpdateSettings(
+ self.cnxn, 888, usergroup_pb2.MakeSettings('anyone'))
+ self.mox.VerifyAll()
+
+ def testUpdateSettings_Import(self):
+ self.SetUpUpdateSettings(888, 'owners', 'mdb')
+ self.mox.ReplayAll()
+ self.usergroup_service.UpdateSettings(
+ self.cnxn, 888,
+ usergroup_pb2.MakeSettings('owners', 'mdb'))
+ self.mox.VerifyAll()
+
+ def testUpdateSettings_WithFriends(self):
+ self.SetUpUpdateSettings(888, 'anyone', friend_projects=[789])
+ self.mox.ReplayAll()
+ self.usergroup_service.UpdateSettings(
+ self.cnxn, 888,
+ usergroup_pb2.MakeSettings('anyone', friend_projects=[789]))
+ self.mox.VerifyAll()
+
+ def testExpungeUsersInGroups(self):
+ self.usergroup_service.usergroupprojects_tbl.Delete = mock.Mock()
+ self.usergroup_service.usergroupsettings_tbl.Delete = mock.Mock()
+ self.usergroup_service.usergroup_tbl.Delete = mock.Mock()
+
+ ids = [222, 333, 444]
+ self.usergroup_service.ExpungeUsersInGroups(self.cnxn, ids)
+
+ self.usergroup_service.usergroupprojects_tbl.Delete.assert_called_once_with(
+ self.cnxn, group_id=ids, commit=False)
+ self.usergroup_service.usergroupsettings_tbl.Delete.assert_called_once_with(
+ self.cnxn, group_id=ids, commit=False)
+ self.usergroup_service.usergroup_tbl.Delete.assert_has_calls(
+ [mock.call(self.cnxn, group_id=ids, commit=False),
+ mock.call(self.cnxn, user_id=ids, commit=False)])
+
+ def SetUpDAG(self, group_id_rows, usergroup_rows):
+ self.usergroup_service.usergroupsettings_tbl.Select(
+ self.cnxn, cols=['group_id']).AndReturn(group_id_rows)
+ self.usergroup_service.usergroup_tbl.Select(
+ self.cnxn, cols=['user_id', 'group_id'], distinct=True,
+ user_id=[r[0] for r in group_id_rows]).AndReturn(usergroup_rows)
+
+ def testDAG_Build(self):
+ # Old entries should go away after rebuilding
+ self.usergroup_service.group_dag.user_group_parents = (
+ collections.defaultdict(list))
+ self.usergroup_service.group_dag.user_group_parents[111] = [222]
+ # Two groups: 888 and 999 while 999 is a member of 888.
+ self.SetUpDAG([(888,), (999,)], [(999, 888)])
+ self.mox.ReplayAll()
+ self.usergroup_service.group_dag.Build(self.cnxn)
+ self.mox.VerifyAll()
+ self.assertIn(888, self.usergroup_service.group_dag.user_group_children)
+ self.assertIn(999, self.usergroup_service.group_dag.user_group_parents)
+ self.assertNotIn(111, self.usergroup_service.group_dag.user_group_parents)
+
+ def testDAG_GetAllAncestors(self):
+ # Three groups: 777, 888 and 999.
+ # 999 is a direct member of 888, and 888 is a direct member of 777.
+ self.SetUpDAG([(777,), (888,), (999,)], [(999, 888), (888, 777)])
+ self.mox.ReplayAll()
+ ancestors = self.usergroup_service.group_dag.GetAllAncestors(
+ self.cnxn, 999)
+ self.mox.VerifyAll()
+ ancestors.sort()
+ self.assertEqual([777, 888], ancestors)
+
+ def testDAG_GetAllAncestorsDiamond(self):
+ # Four groups: 666, 777, 888 and 999.
+ # 999 is a direct member of both 888 and 777,
+ # 888 is a direct member of 666, and 777 is also a direct member of 666.
+ self.SetUpDAG([(666, ), (777,), (888,), (999,)],
+ [(999, 888), (999, 777), (888, 666), (777, 666)])
+ self.mox.ReplayAll()
+ ancestors = self.usergroup_service.group_dag.GetAllAncestors(
+ self.cnxn, 999)
+ self.mox.VerifyAll()
+ ancestors.sort()
+ self.assertEqual([666, 777, 888], ancestors)
+
+ def testDAG_GetAllDescendants(self):
+ # Four groups: 666, 777, 888 and 999.
+ # 999 is a direct member of both 888 and 777,
+ # 888 is a direct member of 666, and 777 is also a direct member of 666.
+ self.SetUpDAG([(666, ), (777,), (888,), (999,)],
+ [(999, 888), (999, 777), (888, 666), (777, 666)])
+ self.mox.ReplayAll()
+ descendants = self.usergroup_service.group_dag.GetAllDescendants(
+ self.cnxn, 666)
+ self.mox.VerifyAll()
+ descendants.sort()
+ self.assertEqual([777, 888, 999], descendants)
+
+ def testDAG_IsChild(self):
+ # Four groups: 666, 777, 888 and 999.
+ # 999 is a direct member of both 888 and 777,
+ # 888 is a direct member of 666, and 777 is also a direct member of 666.
+ self.SetUpDAG([(666, ), (777,), (888,), (999,)],
+ [(999, 888), (999, 777), (888, 666), (777, 666)])
+ self.mox.ReplayAll()
+ result1 = self.usergroup_service.group_dag.IsChild(
+ self.cnxn, 777, 666)
+ result2 = self.usergroup_service.group_dag.IsChild(
+ self.cnxn, 777, 888)
+ self.mox.VerifyAll()
+ self.assertTrue(result1)
+ self.assertFalse(result2)
diff --git a/services/tracker_fulltext.py b/services/tracker_fulltext.py
new file mode 100644
index 0000000..ecbfc44
--- /dev/null
+++ b/services/tracker_fulltext.py
@@ -0,0 +1,320 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""A set of functions that provide fulltext search for issues."""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import collections
+import logging
+import time
+
+from six import string_types
+
+from google.appengine.api import search
+
+import settings
+from framework import framework_constants
+from framework import framework_helpers
+from framework import framework_views
+from services import fulltext_helpers
+from tracker import tracker_bizobj
+
+
+# When updating and re-indexing all issues in a project, work in batches
+# of this size to manage memory usage and avoid rpc timeouts.
+_INDEX_BATCH_SIZE = 40
+
+
+# The user can search for text that occurs specifically in these
+# parts of an issue.
+ISSUE_FULLTEXT_FIELDS = ['summary', 'description', 'comment']
+# Note: issue documents also contain a "metadata" field, but we do not
+# expose that to users. Issue metadata can be searched in a structured way
+# by giving a specific field name such as "owner:" or "status:". The metadata
+# search field exists only for fulltext queries that do not specify any field.
+
+
+def IndexIssues(cnxn, issues, user_service, issue_service, config_service):
+ """(Re)index all the given issues.
+
+ Args:
+ cnxn: connection to SQL database.
+ issues: list of Issue PBs to index.
+ user_service: interface to user data storage.
+ issue_service: interface to issue data storage.
+ config_service: interface to configuration data storage.
+ """
+ issues = list(issues)
+ config_dict = config_service.GetProjectConfigs(
+ cnxn, {issue.project_id for issue in issues})
+ for start in range(0, len(issues), _INDEX_BATCH_SIZE):
+ logging.info('indexing issues: %d remaining', len(issues) - start)
+ _IndexIssueBatch(
+ cnxn, issues[start:start + _INDEX_BATCH_SIZE], user_service,
+ issue_service, config_dict)
+
+
+def _IndexIssueBatch(cnxn, issues, user_service, issue_service, config_dict):
+ """Internal method to (re)index the given batch of issues.
+
+ Args:
+ cnxn: connection to SQL database.
+ issues: list of Issue PBs to index.
+ user_service: interface to user data storage.
+ issue_service: interface to issue data storage.
+ config_dict: dict {project_id: config} for all the projects that
+ the given issues are in.
+ """
+ user_ids = tracker_bizobj.UsersInvolvedInIssues(issues)
+ comments_dict = issue_service.GetCommentsForIssues(
+ cnxn, [issue.issue_id for issue in issues])
+ for comments in comments_dict.values():
+ user_ids.update([ic.user_id for ic in comments])
+
+ users_by_id = framework_views.MakeAllUserViews(
+ cnxn, user_service, user_ids)
+ _CreateIssueSearchDocuments(issues, comments_dict, users_by_id, config_dict)
+
+
+def _CreateIssueSearchDocuments(
+ issues, comments_dict, users_by_id, config_dict):
+ """Make the GAE search index documents for the given issue batch.
+
+ Args:
+ issues: list of issues to index.
+ comments_dict: prefetched dictionary of comments on those issues.
+ users_by_id: dictionary {user_id: UserView} so that the email
+ addresses of users who left comments can be found via search.
+ config_dict: dict {project_id: config} for all the projects that
+ the given issues are in.
+ """
+ documents_by_shard = collections.defaultdict(list)
+ for issue in issues:
+ summary = issue.summary
+ # TODO(jrobbins): allow search specifically on explicit vs derived
+ # fields.
+ owner_id = tracker_bizobj.GetOwnerId(issue)
+ owner_email = users_by_id[owner_id].email
+ config = config_dict[issue.project_id]
+ component_paths = []
+ for component_id in issue.component_ids:
+ cd = tracker_bizobj.FindComponentDefByID(component_id, config)
+ if cd:
+ component_paths.append(cd.path)
+
+ field_values = [tracker_bizobj.GetFieldValue(fv, users_by_id)
+ for fv in issue.field_values]
+ # Convert to string only the values that are not strings already.
+ # This is done because the default encoding in appengine seems to be 'ascii'
+ # and string values might contain unicode characters, so str will fail to
+ # encode them.
+ field_values = [value if isinstance(value, string_types) else str(value)
+ for value in field_values]
+
+ metadata = '%s %s %s %s %s %s' % (
+ tracker_bizobj.GetStatus(issue),
+ owner_email,
+ [users_by_id[cc_id].email for cc_id in
+ tracker_bizobj.GetCcIds(issue)],
+ ' '.join(component_paths),
+ ' '.join(field_values),
+ ' '.join(tracker_bizobj.GetLabels(issue)))
+ custom_fields = _BuildCustomFTSFields(issue)
+
+ comments = comments_dict.get(issue.issue_id, [])
+ room_for_comments = (framework_constants.MAX_FTS_FIELD_SIZE -
+ len(summary) -
+ len(metadata) -
+ sum(len(cf.value) for cf in custom_fields))
+ comments = _IndexableComments(
+ comments, users_by_id, remaining_chars=room_for_comments)
+ logging.info('len(comments) is %r', len(comments))
+ if comments:
+ description = _ExtractCommentText(comments[0], users_by_id)
+ description = description[:framework_constants.MAX_FTS_FIELD_SIZE]
+ all_comments = ' '. join(
+ _ExtractCommentText(c, users_by_id) for c in comments[1:])
+ all_comments = all_comments[:framework_constants.MAX_FTS_FIELD_SIZE]
+ else:
+ description = ''
+ all_comments = ''
+ logging.info(
+ 'Issue %s:%r has zero indexable comments',
+ issue.project_name, issue.local_id)
+
+ logging.info('Building document for %s:%d',
+ issue.project_name, issue.local_id)
+ logging.info('len(summary) = %d', len(summary))
+ logging.info('len(metadata) = %d', len(metadata))
+ logging.info('len(description) = %d', len(description))
+ logging.info('len(comment) = %d', len(all_comments))
+ for cf in custom_fields:
+ logging.info('len(%s) = %d', cf.name, len(cf.value))
+
+ doc = search.Document(
+ doc_id=str(issue.issue_id),
+ fields=[
+ search.NumberField(name='project_id', value=issue.project_id),
+ search.TextField(name='summary', value=summary),
+ search.TextField(name='metadata', value=metadata),
+ search.TextField(name='description', value=description),
+ search.TextField(name='comment', value=all_comments),
+ ] + custom_fields)
+
+ shard_id = issue.issue_id % settings.num_logical_shards
+ documents_by_shard[shard_id].append(doc)
+
+ start_time = time.time()
+ promises = []
+ for shard_id, documents in documents_by_shard.items():
+ if documents:
+ promises.append(framework_helpers.Promise(
+ _IndexDocsInShard, shard_id, documents))
+
+ for promise in promises:
+ promise.WaitAndGetValue()
+
+ logging.info('Finished %d indexing in shards in %d ms',
+ len(documents_by_shard), int((time.time() - start_time) * 1000))
+
+
+def _IndexableComments(comments, users_by_id, remaining_chars=None):
+ """We only index the comments that are not deleted or banned.
+
+ Args:
+ comments: list of Comment PBs for one issue.
+ users_by_id: Dict of (user_id -> UserView) for all users.
+ remaining_chars: number of characters available for comment text
+ without hitting the GAE search index max document size.
+
+ Returns:
+ A list of comments filtered to not have any deleted comments or
+ comments from banned users. If the issue has a huge number of
+ comments, only a certain number of the first and last comments
+ are actually indexed.
+ """
+ if remaining_chars is None:
+ remaining_chars = framework_constants.MAX_FTS_FIELD_SIZE
+ allowed_comments = []
+ for comment in comments:
+ user_view = users_by_id.get(comment.user_id)
+ if not (comment.deleted_by or (user_view and user_view.banned)):
+ if comment.is_description and allowed_comments:
+ # index the latest description, but not older descriptions
+ allowed_comments[0] = comment
+ else:
+ allowed_comments.append(comment)
+
+ reasonable_size = (framework_constants.INITIAL_COMMENTS_TO_INDEX +
+ framework_constants.FINAL_COMMENTS_TO_INDEX)
+ if len(allowed_comments) <= reasonable_size:
+ candidates = allowed_comments
+ else:
+ candidates = ( # Prioritize the description and recent comments.
+ allowed_comments[0:1] +
+ allowed_comments[-framework_constants.FINAL_COMMENTS_TO_INDEX:] +
+ allowed_comments[1:framework_constants.INITIAL_COMMENTS_TO_INDEX])
+
+ total_length = 0
+ result = []
+ for comment in candidates:
+ total_length += len(comment.content)
+ if total_length > remaining_chars:
+ break
+ result.append(comment)
+
+ return result
+
+
+def _IndexDocsInShard(shard_id, documents):
+ search_index = search.Index(
+ name=settings.search_index_name_format % shard_id)
+ search_index.put(documents)
+ logging.info('FTS indexed %d docs in shard %d', len(documents), shard_id)
+ # TODO(jrobbins): catch OverQuotaError and add the issues to the
+ # ReindexQueue table instead.
+
+
+def _ExtractCommentText(comment, users_by_id):
+ """Return a string with all the searchable text of the given Comment PB."""
+ commenter_email = users_by_id[comment.user_id].email
+ return '%s %s %s' % (
+ commenter_email,
+ comment.content,
+ ' '.join(attach.filename
+ for attach in comment.attachments
+ if not attach.deleted))
+
+
+def _BuildCustomFTSFields(issue):
+ """Return a list of FTS Fields to index string-valued custom fields."""
+ fts_fields = []
+ for fv in issue.field_values:
+ if fv.str_value:
+ # TODO(jrobbins): also indicate which were derived vs. explicit.
+ # TODO(jrobbins): also toss in the email addresses of any users in
+ # user-valued custom fields, ints for int-valued fields, etc.
+ fts_field = search.TextField(
+ name='custom_%d' % fv.field_id, value=fv.str_value)
+ fts_fields.append(fts_field)
+
+ return fts_fields
+
+
+def UnindexIssues(issue_ids):
+ """Remove many issues from the sharded search indexes."""
+ iids_by_shard = {}
+ for issue_id in issue_ids:
+ shard_id = issue_id % settings.num_logical_shards
+ iids_by_shard.setdefault(shard_id, [])
+ iids_by_shard[shard_id].append(issue_id)
+
+ for shard_id, iids_in_shard in iids_by_shard.items():
+ try:
+ logging.info(
+ 'unindexing %r issue_ids in %r', len(iids_in_shard), shard_id)
+ search_index = search.Index(
+ name=settings.search_index_name_format % shard_id)
+ search_index.delete([str(iid) for iid in iids_in_shard])
+ except search.Error:
+ logging.exception('FTS deletion failed')
+
+
+def SearchIssueFullText(project_ids, query_ast_conj, shard_id):
+ """Do full-text search in GAE FTS.
+
+ Args:
+ project_ids: list of project ID numbers to consider.
+ query_ast_conj: One conjuctive clause from the AST parsed
+ from the user's query.
+ shard_id: int shard ID for the shard to consider.
+
+ Returns:
+ (issue_ids, capped) where issue_ids is a list of issue issue_ids that match
+ the full-text query. And, capped is True if the results were capped due to
+ an implementation limitation. Or, return (None, False) if the given AST
+ conjunction contains no full-text conditions.
+ """
+ fulltext_query = fulltext_helpers.BuildFTSQuery(
+ query_ast_conj, ISSUE_FULLTEXT_FIELDS)
+ if fulltext_query is None:
+ return None, False
+
+ if project_ids:
+ project_clause = ' OR '.join(
+ 'project_id:%d' % pid for pid in project_ids)
+ fulltext_query = '(%s) %s' % (project_clause, fulltext_query)
+
+ # TODO(jrobbins): it would be good to also include some other
+ # structured search terms to narrow down the set of index
+ # documents considered. E.g., most queries are only over the
+ # open issues.
+ logging.info('FTS query is %r', fulltext_query)
+ issue_ids = fulltext_helpers.ComprehensiveSearch(
+ fulltext_query, settings.search_index_name_format % shard_id)
+ capped = len(issue_ids) >= settings.fulltext_limit_per_shard
+ return issue_ids, capped
diff --git a/services/user_svc.py b/services/user_svc.py
new file mode 100644
index 0000000..28ad465
--- /dev/null
+++ b/services/user_svc.py
@@ -0,0 +1,729 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""A set of functions that provide persistence for users.
+
+Business objects are described in user_pb2.py.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import logging
+import time
+
+import settings
+from framework import exceptions
+from framework import framework_bizobj
+from framework import framework_constants
+from framework import framework_helpers
+from framework import sql
+from framework import validate
+from proto import user_pb2
+from services import caches
+
+
# Names of the SQL tables that UserService reads and writes.
USER_TABLE_NAME = 'User'
USERPREFS_TABLE_NAME = 'UserPrefs'
HOTLISTVISITHISTORY_TABLE_NAME = 'HotlistVisitHistory'
LINKEDACCOUNT_TABLE_NAME = 'LinkedAccount'
LINKEDACCOUNTINVITE_TABLE_NAME = 'LinkedAccountInvite'

# Column lists for each table.
# NOTE: the order of USER_COLS must match the tuple unpacking of each row in
# UserTwoLevelCache._DeserializeUsersByID.
USER_COLS = [
    'user_id',
    'email',
    'is_site_admin',
    'notify_issue_change',
    'notify_starred_issue_change',
    'email_compact_subject',
    'email_view_widget',
    'notify_starred_ping',
    'banned',
    'after_issue_update',
    'keep_people_perms_open',
    'preview_on_hover',
    'obscure_email',
    'last_visit_timestamp',
    'email_bounce_timestamp',
    'vacation_message',
]
USERPREFS_COLS = ['user_id', 'name', 'value']
HOTLISTVISITHISTORY_COLS = ['hotlist_id', 'user_id', 'viewed']
LINKEDACCOUNT_COLS = ['parent_id', 'child_id']
LINKEDACCOUNTINVITE_COLS = ['parent_id', 'child_id']
+
+
class UserTwoLevelCache(caches.AbstractTwoLevelCache):
  """Two-level (RAM + memcache) cache of User PBs, keyed by user ID."""

  def __init__(self, cache_manager, user_service):
    super(UserTwoLevelCache, self).__init__(
        cache_manager, 'user', 'user:', user_pb2.User,
        max_size=settings.user_cache_max_size)
    # Back-reference used by FetchItems to reach the SQL table managers.
    self.user_service = user_service

  def _DeserializeUsersByID(self, user_rows, linkedaccount_rows):
    """Turn User and LinkedAccount DB rows into User PBs.

    Args:
      user_rows: User table rows whose columns are ordered as in USER_COLS.
      linkedaccount_rows: (parent_id, child_id) rows from the LinkedAccount
          table.

    Returns:
      A dict {user_id: user_pb} with one entry per row of user_rows.  Linked
      account information is attached only to users present in that dict.
    """
    users_by_id = {}

    for (user_id, email, is_site_admin,
         notify_issue_change, notify_starred_issue_change,
         email_compact_subject, email_view_widget, notify_starred_ping,
         banned, after_issue_update, keep_people_perms_open,
         preview_on_hover, obscure_email, last_visit_timestamp,
         email_bounce_timestamp, vacation_message) in user_rows:
      pb = user_pb2.MakeUser(user_id, email=email, obscure_email=obscure_email)

      # Flag columns are coerced with bool() before being stored on the PB.
      flag_values = (
          ('is_site_admin', is_site_admin),
          ('notify_issue_change', notify_issue_change),
          ('notify_starred_issue_change', notify_starred_issue_change),
          ('email_compact_subject', email_compact_subject),
          ('email_view_widget', email_view_widget),
          ('notify_starred_ping', notify_starred_ping),
          ('keep_people_perms_open', keep_people_perms_open),
          ('preview_on_hover', preview_on_hover),
      )
      for field_name, raw_value in flag_values:
        setattr(pb, field_name, bool(raw_value))

      # Optional columns: only copied onto the PB when they hold a truthy
      # value, otherwise the PB keeps its default.
      if banned:
        pb.banned = banned
      if after_issue_update:
        pb.after_issue_update = user_pb2.IssueUpdateNav(
            after_issue_update.upper())
      if vacation_message:
        pb.vacation_message = vacation_message

      # Timestamps: any falsy value (e.g. NULL) is stored as 0.
      pb.last_visit_timestamp = last_visit_timestamp or 0
      pb.email_bounce_timestamp = email_bounce_timestamp or 0

      users_by_id[user_id] = pb

    # Attach linked-account relationships to whichever side was loaded.
    for parent_id, child_id in linkedaccount_rows:
      parent_pb = users_by_id.get(parent_id)
      if parent_pb is not None:
        parent_pb.linked_child_ids.append(child_id)
      child_pb = users_by_id.get(child_id)
      if child_pb is not None:
        child_pb.linked_parent_id = parent_id

    return users_by_id

  def FetchItems(self, cnxn, keys):
    """Load User PBs from the database after a RAM and memcache miss.

    Args:
      cnxn: connection to SQL database.
      keys: list of user IDs to retrieve.

    Returns:
      A dict {user_id: user_pb} for each requested user found in the User
      table.
    """
    svc = self.user_service
    user_rows = svc.user_tbl.Select(cnxn, cols=USER_COLS, user_id=keys)
    # A user may appear on either side of a link, so match both columns.
    link_rows = svc.linkedaccount_tbl.Select(
        cnxn, cols=LINKEDACCOUNT_COLS, parent_id=keys, child_id=keys,
        or_where_conds=True)
    return self._DeserializeUsersByID(user_rows, link_rows)
+
+
class UserPrefsTwoLevelCache(caches.AbstractTwoLevelCache):
  """Two-level (RAM + memcache) cache of UserPrefs PBs, keyed by user ID."""

  def __init__(self, cache_manager, user_service):
    super(UserPrefsTwoLevelCache, self).__init__(
        cache_manager, 'user', 'userprefs:', user_pb2.UserPrefs,
        max_size=settings.user_cache_max_size)
    # Back-reference used by FetchItems to reach the SQL table managers.
    self.user_service = user_service

  def _DeserializeUserPrefsByID(self, userprefs_rows):
    """Group (user_id, name, value) rows into one UserPrefs PB per user.

    Args:
      userprefs_rows: rows from the UserPrefs DB table.

    Returns:
      A dict {user_id: userprefs} containing only users that have at least
      one row in userprefs_rows.
    """
    prefs_by_user = {}

    for user_id, name, value in userprefs_rows:
      userprefs = prefs_by_user.get(user_id)
      if userprefs is None:
        # First row seen for this user: start an empty UserPrefs PB.
        userprefs = user_pb2.UserPrefs(user_id=user_id)
        prefs_by_user[user_id] = userprefs
      userprefs.prefs.append(user_pb2.UserPrefValue(name=name, value=value))

    return prefs_by_user

  def FetchItems(self, cnxn, keys):
    """Load UserPrefs PBs from the database after a RAM and memcache miss.

    Args:
      cnxn: connection to SQL database.
      keys: list of user IDs to retrieve.

    Returns:
      A dict {user_id: userprefs} for each requested user that has stored
      preferences.
    """
    prefs_rows = self.user_service.userprefs_tbl.Select(
        cnxn, cols=USERPREFS_COLS, user_id=keys)
    return self._DeserializeUserPrefsByID(prefs_rows)
+
+
class UserService(object):
  """The persistence layer for all user data."""

  def __init__(self, cache_manager):
    """Constructor.

    Args:
      cache_manager: local cache with distributed invalidation.
    """
    self.user_tbl = sql.SQLTableManager(USER_TABLE_NAME)
    self.userprefs_tbl = sql.SQLTableManager(USERPREFS_TABLE_NAME)
    self.hotlistvisithistory_tbl = sql.SQLTableManager(
        HOTLISTVISITHISTORY_TABLE_NAME)
    self.linkedaccount_tbl = sql.SQLTableManager(LINKEDACCOUNT_TABLE_NAME)
    self.linkedaccountinvite_tbl = sql.SQLTableManager(
        LINKEDACCOUNTINVITE_TABLE_NAME)

    # Like a dictionary {user_id: email}
    self.email_cache = caches.RamCache(cache_manager, 'user', max_size=50000)

    # Like a dictionary {email: user_id}.
    # This will never invalidate, and it doesn't need to.
    self.user_id_cache = caches.RamCache(cache_manager, 'user', max_size=50000)

    # Like a dictionary {user_id: user_pb}
    self.user_2lc = UserTwoLevelCache(cache_manager, self)

    # Like a dictionary {user_id: userprefs}
    self.userprefs_2lc = UserPrefsTwoLevelCache(cache_manager, self)

  ### Creating users

  def _CreateUsers(self, cnxn, emails):
    """Create many users in the database.

    Each email is lower-cased, and the new user's ID is
    framework_helpers.MurmurHash3_x86_32() of the lower-cased email.
    obscure_email is set to True unless
    framework_bizobj.IsPriviledgedDomainUser() returns True for the address.
    Cached User PBs for the new IDs are invalidated.

    Args:
      cnxn: connection to SQL database.
      emails: list of string email addresses of the users to create.  The
          caller should not pass duplicates: each address becomes one row.
    """
    emails = [email.lower() for email in emails]
    ids = [framework_helpers.MurmurHash3_x86_32(email) for email in emails]
    row_values = [
        (user_id, email, not framework_bizobj.IsPriviledgedDomainUser(email))
        for (user_id, email) in zip(ids, emails)]
    self.user_tbl.InsertRows(
        cnxn, ['user_id', 'email', 'obscure_email'], row_values)
    self.user_2lc.InvalidateKeys(cnxn, ids)

  ### Lookup of user ID and email address

  def LookupUserEmails(self, cnxn, user_ids, ignore_missed=False):
    """Return a dict of email addresses for the given user IDs.

    Args:
      cnxn: connection to SQL database.
      user_ids: list of int user IDs to look up.
      ignore_missed: if True, do not raise NoSuchUserException when some
          user_ids have no user row; just log them.

    Returns:
      A dict {user_id: email_addr} for all the requested IDs that were found.
      framework_constants.NO_USER_SPECIFIED maps to ''.

    Raises:
      exceptions.NoSuchUserException: if any requested user cannot be found
          and ignore_missed is False.
    """
    # Make sure the "no user" ID always resolves to an empty address.
    self.email_cache.CacheItem(framework_constants.NO_USER_SPECIFIED, '')
    emails_dict, missed_ids = self.email_cache.GetAll(user_ids)
    if missed_ids:
      logging.info('got %d user emails from cache', len(emails_dict))
      rows = self.user_tbl.Select(
          cnxn, cols=['user_id', 'email'], user_id=missed_ids)
      retrieved_dict = dict(rows)
      logging.info('looked up users %r', retrieved_dict)
      self.email_cache.CacheAll(retrieved_dict)
      emails_dict.update(retrieved_dict)

    # Check if there are any that we could not find. ID 0 means "no user".
    nonexist_ids = [user_id for user_id in user_ids
                    if user_id and user_id not in emails_dict]
    if nonexist_ids:
      if ignore_missed:
        logging.info('No email addresses found for users %r', nonexist_ids)
      else:
        raise exceptions.NoSuchUserException(
            'No email addresses found for users %r' % nonexist_ids)

    return emails_dict

  def LookupUserEmail(self, cnxn, user_id):
    """Get the email address of the given user.

    Args:
      cnxn: connection to SQL database.
      user_id: int user ID of the user whose email address is needed.

    Returns:
      String email address of that user, or None if user_id is falsy
      (e.g. 0, meaning "no user").

    Raises:
      exceptions.NoSuchUserException: if no email address was found for that
          user.
    """
    if not user_id:
      return None
    emails_dict = self.LookupUserEmails(cnxn, [user_id])
    return emails_dict[user_id]

  def LookupExistingUserIDs(self, cnxn, emails):
    """Return a dict of user IDs for the given emails for users that exist.

    The emails are used as given (no lower-casing is done here).

    Args:
      cnxn: connection to SQL database.
      emails: list of string email addresses.

    Returns:
      A dict {email_addr: user_id} for the requested emails that were found
      in the RAM cache or the User table.
    """
    # Look up these users in the RAM cache
    user_id_dict, missed_emails = self.user_id_cache.GetAll(emails)

    # Hit the DB to lookup any user IDs that were not cached.
    if missed_emails:
      rows = self.user_tbl.Select(
          cnxn, cols=['email', 'user_id'], email=missed_emails)
      retrieved_dict = dict(rows)
      # Cache all the user IDs that we retrieved to make later requests faster.
      self.user_id_cache.CacheAll(retrieved_dict)
      user_id_dict.update(retrieved_dict)

    return user_id_dict

  def LookupUserIDs(self, cnxn, emails, autocreate=False,
                    allowgroups=False):
    """Return a dict of user IDs for the given emails.

    Empty addresses and addresses matching framework_constants.NO_VALUE_RE
    (such as "--") mean "no user" and are skipped.  All other addresses are
    lower-cased; the keys of the returned dict are the lower-cased forms.

    Args:
      cnxn: connection to SQL database.
      emails: list of string email addresses.
      autocreate: set to True to create users that were not found.
      allowgroups: set to True to allow non-email user name for group
          creation.

    Returns:
      A dict {email_addr: user_id} for the requested emails.  When autocreate
      is True and allowgroups is False, missing addresses that fail
      validate.IsValidEmail() are not created and are left out of the dict.

    Raises:
      exceptions.NoSuchUserException: if some users were not found and
          autocreate is False.
    """
    # Skip "no user" values, lower-case the rest, and drop duplicates.
    # Without de-duplication, input such as ['a@x.com', 'A@x.com'] makes
    # len(needed_emails) exceed len(user_id_dict) even when every user exists,
    # which would raise a spurious NoSuchUserException (or, with autocreate,
    # insert the same new user twice).
    needed_emails = []
    seen_emails = set()
    for email in emails:
      if not email or framework_constants.NO_VALUE_RE.match(email):
        continue
      lowered = email.lower()
      if lowered not in seen_emails:
        seen_emails.add(lowered)
        needed_emails.append(lowered)

    # Look up these users in the RAM cache, falling back to the DB.
    user_id_dict = self.LookupExistingUserIDs(cnxn, needed_emails)
    if len(needed_emails) == len(user_id_dict):
      return user_id_dict

    # If any were not found in the DB, create them or raise an exception.
    nonexist_emails = [email for email in needed_emails
                       if email not in user_id_dict]
    if not nonexist_emails:
      # Every requested address was found; nothing to create or report.
      return user_id_dict
    logging.info('nonexist_emails: %r, autocreate is %r',
                 nonexist_emails, autocreate)
    if not autocreate:
      raise exceptions.NoSuchUserException('%r' % nonexist_emails)

    if not allowgroups:
      # Only create accounts for valid email addresses.
      nonexist_emails = [email for email in nonexist_emails
                         if validate.IsValidEmail(email)]
      if not nonexist_emails:
        return user_id_dict

    self._CreateUsers(cnxn, nonexist_emails)
    created_rows = self.user_tbl.Select(
        cnxn, cols=['email', 'user_id'], email=nonexist_emails)
    created_dict = dict(created_rows)
    # Cache all the user IDs that we retrieved to make later requests faster.
    self.user_id_cache.CacheAll(created_dict)
    user_id_dict.update(created_dict)

    logging.info('looked up User IDs %r', user_id_dict)
    return user_id_dict

  def LookupUserID(self, cnxn, email, autocreate=False, allowgroups=False):
    """Get one user ID for the given email address.

    Args:
      cnxn: connection to SQL database.
      email: string email address of the user to look up.
      autocreate: set to True to create users that were not found.
      allowgroups: set to True to allow non-email user name for group
          creation.

    Returns:
      The int user ID of the specified user.

    Raises:
      exceptions.NoSuchUserException if the user was not found and autocreate
      is False, or if the address was skipped or could not be created.
    """
    email = email.lower()
    email_dict = self.LookupUserIDs(
        cnxn, [email], autocreate=autocreate, allowgroups=allowgroups)
    if email not in email_dict:
      raise exceptions.NoSuchUserException('%r not found' % email)
    return email_dict[email]

  ### Retrieval of user objects: with preferences and cues

  def GetUsersByIDs(self, cnxn, user_ids, use_cache=True, skip_missed=False):
    """Return a dictionary of retrieved User PBs.

    Args:
      cnxn: connection to SQL database.
      user_ids: list of user IDs to fetch.
      use_cache: set to False to ignore cache and force DB lookup.
      skip_missed: set to True if default User objects for missed_ids should
          not be created.

    Returns:
      A dict {user_id: user_pb} for each specified user ID. Unless
      skip_missed is True, any user ID that is not found in the DB gets a
      default User PB created on-the-fly.
    """
    # Check the RAM cache and memcache, as appropriate.
    result_dict, missed_ids = self.user_2lc.GetAll(
        cnxn, user_ids, use_cache=use_cache)

    # TODO(crbug/monorail/7367): Never create default values for missed_ids
    # once we remove all code paths that hit this. See bug for more info.
    # Any new code that calls this method, should not rely on this
    # functionality.
    if missed_ids and not skip_missed:
      # Provide default values for any user ID that was not found.
      result_dict.update(
          (user_id, user_pb2.MakeUser(user_id)) for user_id in missed_ids)

    return result_dict

  def GetUser(self, cnxn, user_id):
    """Load the specified user's User PB (a default PB if not in the DB)."""
    return self.GetUsersByIDs(cnxn, [user_id])[user_id]

  ### Updating user objects

  def UpdateUser(self, cnxn, user_id, user):
    """Store a user PB in the database.

    Writes the settings fields of the User row, commits, and then
    invalidates the cached User PB.

    Args:
      cnxn: connection to SQL database.
      user_id: int user ID of the user to update.
      user: User PB to store.

    Returns:
      Nothing.

    Raises:
      exceptions.NoSuchUserException: if user_id is falsy (anonymous user).
    """
    if not user_id:
      raise exceptions.NoSuchUserException('Cannot update anonymous user')

    delta = {
        'is_site_admin': user.is_site_admin,
        'notify_issue_change': user.notify_issue_change,
        'notify_starred_issue_change': user.notify_starred_issue_change,
        'email_compact_subject': user.email_compact_subject,
        'email_view_widget': user.email_view_widget,
        'notify_starred_ping': user.notify_starred_ping,
        'banned': user.banned,
        'after_issue_update': str(user.after_issue_update or 'UP_TO_LIST'),
        'keep_people_perms_open': user.keep_people_perms_open,
        'preview_on_hover': user.preview_on_hover,
        'obscure_email': user.obscure_email,
        'last_visit_timestamp': user.last_visit_timestamp,
        'email_bounce_timestamp': user.email_bounce_timestamp,
        'vacation_message': user.vacation_message,
    }
    # Start sending UPDATE statements, but don't COMMIT until the end.
    self.user_tbl.Update(cnxn, delta, user_id=user_id, commit=False)

    cnxn.Commit()
    self.user_2lc.InvalidateKeys(cnxn, [user_id])

  @staticmethod
  def _ApplyBan(user, is_banned, banned_reason):
    """Set or clear user.banned in place; is_banned=None leaves it unchanged."""
    if is_banned is None:
      return
    if is_banned:
      user.banned = banned_reason or 'No reason given'
    else:
      user.reset('banned')

  def UpdateUserBan(
      self, cnxn, user_id, user,
      is_banned=None, banned_reason=None):
    """Ban or unban a user and store the updated User PB.

    Args:
      cnxn: connection to SQL database.
      user_id: int user ID of the user to update.
      user: User PB of the user; it is modified in place.
      is_banned: True to ban, False to unban, None to leave the ban as is.
      banned_reason: string reason stored when banning; defaults to
          'No reason given' when empty.
    """
    self._ApplyBan(user, is_banned, banned_reason)

    # Write the user settings to the database.
    self.UpdateUser(cnxn, user_id, user)

  def GetRecentlyVisitedHotlists(self, cnxn, user_id):
    """Return up to 10 hotlist IDs the user viewed, most recent first."""
    recent_hotlist_rows = self.hotlistvisithistory_tbl.Select(
        cnxn, cols=['hotlist_id'], user_id=[user_id],
        order_by=[('viewed DESC', [])], limit=10)
    return [row[0] for row in recent_hotlist_rows]

  def AddVisitedHotlist(self, cnxn, user_id, hotlist_id, commit=True):
    """Record that the user viewed the hotlist now.

    Any earlier history row for the same (hotlist, user) pair is deleted
    first, so each pair appears at most once.
    """
    self.hotlistvisithistory_tbl.Delete(
        cnxn, hotlist_id=hotlist_id, user_id=user_id, commit=False)
    self.hotlistvisithistory_tbl.InsertRows(
        cnxn, HOTLISTVISITHISTORY_COLS,
        [(hotlist_id, user_id, int(time.time()))],
        commit=commit)

  def ExpungeHotlistsFromHistory(self, cnxn, hotlist_ids, commit=True):
    """Delete all visit-history rows for the given hotlists."""
    self.hotlistvisithistory_tbl.Delete(
        cnxn, hotlist_id=hotlist_ids, commit=commit)

  def ExpungeUsersHotlistsHistory(self, cnxn, user_ids, commit=True):
    """Delete all hotlist visit-history rows for the given users."""
    self.hotlistvisithistory_tbl.Delete(cnxn, user_id=user_ids, commit=commit)

  def TrimUserVisitedHotlists(self, cnxn, commit=True):
    """For any user who has visited more than 10 hotlists, trim history.

    Processes at most 1000 such users per call.  For each, rows older than
    the user's 10th most recent view are deleted.
    """
    user_id_rows = self.hotlistvisithistory_tbl.Select(
        cnxn, cols=['user_id'], group_by=['user_id'],
        having=[('COUNT(*) > %s', [10])], limit=1000)

    for user_id in [row[0] for row in user_id_rows]:
      viewed_hotlist_rows = self.hotlistvisithistory_tbl.Select(
          cnxn,
          cols=['viewed'],
          user_id=user_id,
          order_by=[('viewed DESC', [])])
      if len(viewed_hotlist_rows) > 10:
        # Timestamp of the 10th most recent view; anything older is dropped.
        cut_off_date = viewed_hotlist_rows[9][0]
        self.hotlistvisithistory_tbl.Delete(
            cnxn,
            user_id=user_id,
            where=[('viewed < %s', [cut_off_date])],
            commit=commit)

  ### Linked account invites

  def GetPendingLinkedInvites(self, cnxn, user_id):
    """Return lists of accounts involved in pending invites with this account.

    Returns:
      A pair (invite_as_parent, invite_as_child): child IDs of invites where
      this user is the proposed parent, and parent IDs of invites where this
      user is the child.  Both are empty when user_id is falsy.
    """
    if not user_id:
      return [], []
    invite_rows = self.linkedaccountinvite_tbl.Select(
        cnxn, cols=LINKEDACCOUNTINVITE_COLS, parent_id=user_id,
        child_id=user_id, or_where_conds=True)
    invite_as_parent = [row[1] for row in invite_rows
                        if row[0] == user_id]
    invite_as_child = [row[0] for row in invite_rows
                       if row[1] == user_id]
    return invite_as_parent, invite_as_child

  def _AssertNotAlreadyLinked(self, cnxn, parent_id, child_id):
    """Check constraints on our linked account graph.

    Raises:
      exceptions.InputException: if linking would make the graph deeper than
          one level or give the child a second parent.
    """
    # Our linked account graph should be no more than one level deep.
    parent_is_already_a_child = self.linkedaccount_tbl.Select(
        cnxn, cols=LINKEDACCOUNT_COLS, child_id=parent_id)
    if parent_is_already_a_child:
      raise exceptions.InputException('Parent account is already a child')
    child_is_already_a_parent = self.linkedaccount_tbl.Select(
        cnxn, cols=LINKEDACCOUNT_COLS, parent_id=child_id)
    if child_is_already_a_parent:
      raise exceptions.InputException('Child account is already a parent')

    # A child account can only be linked to one parent.
    child_is_already_a_child = self.linkedaccount_tbl.Select(
        cnxn, cols=LINKEDACCOUNT_COLS, child_id=child_id)
    if child_is_already_a_child:
      raise exceptions.InputException('Child account is already linked')

  def InviteLinkedParent(self, cnxn, parent_id, child_id):
    """Child stores an invite for the proposed parent user to consider."""
    if not parent_id:
      raise exceptions.InputException('Parent account is missing')
    if not child_id:
      raise exceptions.InputException('Child account is missing')
    self._AssertNotAlreadyLinked(cnxn, parent_id, child_id)
    self.linkedaccountinvite_tbl.InsertRow(
        cnxn, parent_id=parent_id, child_id=child_id)

  def AcceptLinkedChild(self, cnxn, parent_id, child_id):
    """Parent accepts an invite from a child account.

    Creates the link, deletes the invite, and invalidates both cached PBs.
    """
    if not parent_id:
      raise exceptions.InputException('Parent account is missing')
    if not child_id:
      raise exceptions.InputException('Child account is missing')
    # Check that the child has previously created an invite for this parent.
    invite_rows = self.linkedaccountinvite_tbl.Select(
        cnxn, cols=LINKEDACCOUNTINVITE_COLS,
        parent_id=parent_id, child_id=child_id)
    if not invite_rows:
      raise exceptions.InputException('No such invite')

    self._AssertNotAlreadyLinked(cnxn, parent_id, child_id)

    self.linkedaccount_tbl.InsertRow(
        cnxn, parent_id=parent_id, child_id=child_id)
    self.linkedaccountinvite_tbl.Delete(
        cnxn, parent_id=parent_id, child_id=child_id)
    self.user_2lc.InvalidateKeys(cnxn, [parent_id, child_id])

  def UnlinkAccounts(self, cnxn, parent_id, child_id):
    """Delete a linked-account relationship."""
    if not parent_id:
      raise exceptions.InputException('Parent account is missing')
    if not child_id:
      raise exceptions.InputException('Child account is missing')
    self.linkedaccount_tbl.Delete(
        cnxn, parent_id=parent_id, child_id=child_id)
    self.user_2lc.InvalidateKeys(cnxn, [parent_id, child_id])

  ### User settings
  # Settings are details about a user account that are usually needed
  # every time that user is displayed to another user.

  # TODO(jrobbins): Move most of these into UserPrefs.
  def UpdateUserSettings(
      self, cnxn, user_id, user, notify=None, notify_starred=None,
      email_compact_subject=None, email_view_widget=None,
      notify_starred_ping=None, obscure_email=None, after_issue_update=None,
      is_site_admin=None, is_banned=None, banned_reason=None,
      keep_people_perms_open=None, preview_on_hover=None,
      vacation_message=None):
    """Update the preferences of the specified user.

    Each keyword argument left as None is not changed.

    Args:
      cnxn: connection to SQL database.
      user_id: int user ID of the user whose settings we are updating.
      user: User PB of user before changes are applied; it is modified in
          place.
      keyword args: dictionary of setting names mapped to new values.

    Returns:
      Nothing.  The given user PB holds the new settings and has been
      stored via UpdateUser().
    """
    # notifications
    if notify is not None:
      user.notify_issue_change = notify
    if notify_starred is not None:
      user.notify_starred_issue_change = notify_starred
    if notify_starred_ping is not None:
      user.notify_starred_ping = notify_starred_ping
    if email_compact_subject is not None:
      user.email_compact_subject = email_compact_subject
    if email_view_widget is not None:
      user.email_view_widget = email_view_widget

    # display options
    if after_issue_update is not None:
      user.after_issue_update = user_pb2.IssueUpdateNav(after_issue_update)
    if preview_on_hover is not None:
      user.preview_on_hover = preview_on_hover
    if keep_people_perms_open is not None:
      user.keep_people_perms_open = keep_people_perms_open

    # misc
    if obscure_email is not None:
      user.obscure_email = obscure_email

    # admin
    if is_site_admin is not None:
      user.is_site_admin = is_site_admin
    self._ApplyBan(user, is_banned, banned_reason)

    # user availability
    if vacation_message is not None:
      user.vacation_message = vacation_message

    # Write the user settings to the database.
    self.UpdateUser(cnxn, user_id, user)

  ### User preferences
  # These are separate from settings in the User objects because they are
  # only needed for the currently signed in user.

  def GetUsersPrefs(self, cnxn, user_ids, use_cache=True):
    """Return {user_id: userprefs} for the requested user IDs."""
    prefs_dict, misses = self.userprefs_2lc.GetAll(
        cnxn, user_ids, use_cache=use_cache)
    # Make sure that every user is represented in the result.
    for user_id in misses:
      prefs_dict[user_id] = user_pb2.UserPrefs(user_id=user_id)
    return prefs_dict

  def GetUserPrefs(self, cnxn, user_id, use_cache=True):
    """Return a UserPrefs PB for the requested user ID."""
    prefs_dict = self.GetUsersPrefs(cnxn, [user_id], use_cache=use_cache)
    return prefs_dict[user_id]

  def GetUserPrefsByEmail(self, cnxn, email, use_cache=True):
    """Return a UserPrefs PB for the requested email, or an empty UserPrefs."""
    try:
      user_id = self.LookupUserID(cnxn, email)
      user_prefs = self.GetUserPrefs(cnxn, user_id, use_cache=use_cache)
    except exceptions.NoSuchUserException:
      user_prefs = user_pb2.UserPrefs()
    return user_prefs

  def SetUserPrefs(self, cnxn, user_id, pref_values):
    """Store the given list of UserPrefValues, replacing same-named rows."""
    userprefs_rows = [(user_id, upv.name, upv.value) for upv in pref_values]
    self.userprefs_tbl.InsertRows(
        cnxn, USERPREFS_COLS, userprefs_rows, replace=True)
    self.userprefs_2lc.InvalidateKeys(cnxn, [user_id])

  ### Expunge all User Data from DB

  def ExpungeUsers(self, cnxn, user_ids):
    """Completely wipes user data from User DB tables for given users.

    This method will not commit the operation. This method will not make
    changes to in-memory data.
    NOTE: This method ends with an operation that deletes user rows. If
    appropriate methods that remove references to the User table rows are
    not called before, the commit will fail. See work_env.ExpungeUsers
    for more info.

    Args:
      cnxn: connection to SQL database.
      user_ids: list of user_ids for users we want to delete.
    """
    self.linkedaccount_tbl.Delete(cnxn, parent_id=user_ids, commit=False)
    self.linkedaccount_tbl.Delete(cnxn, child_id=user_ids, commit=False)
    self.linkedaccountinvite_tbl.Delete(cnxn, parent_id=user_ids, commit=False)
    self.linkedaccountinvite_tbl.Delete(cnxn, child_id=user_ids, commit=False)
    self.userprefs_tbl.Delete(cnxn, user_id=user_ids, commit=False)
    self.user_tbl.Delete(cnxn, user_id=user_ids, commit=False)

  def TotalUsersCount(self, cnxn):
    """Returns the total number of rows in the User table.

    The placeholder User reserved for representing deleted users within
    Monorail will not be counted.
    """
    # Subtract one so we don't count the deleted user with
    # user_id = framework_constants.DELETED_USER_ID
    return self.user_tbl.SelectValue(cnxn, col='COUNT(*)') - 1

  def GetAllUserEmailsBatch(self, cnxn, limit=1000, offset=0):
    """Returns a list of user emails.

    This method can be used for listing all user emails in Monorail's DB.
    The list will contain at most [limit] emails, and be ordered by
    user_id. The list will start at the given offset value. The email for
    the placeholder User reserved for representing deleted users within
    Monorail will never be returned.

    Args:
      cnxn: connection to SQL database.
      limit: limit on the number of emails returned, defaults to 1000.
      offset: starting index of the list, defaults to 0.

    Returns:
      A list of email address strings.
    """
    rows = self.user_tbl.Select(
        cnxn, cols=['email'],
        limit=limit,
        offset=offset,
        where=[('user_id != %s', [framework_constants.DELETED_USER_ID])],
        order_by=[('user_id ASC', [])])
    return [row[0] for row in rows]
diff --git a/services/usergroup_svc.py b/services/usergroup_svc.py
new file mode 100644
index 0000000..72797fc
--- /dev/null
+++ b/services/usergroup_svc.py
@@ -0,0 +1,616 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Persistence class for user groups.
+
+User groups are represented in the database by:
+- A row in the Users table giving an email address and user ID.
+ (A "group ID" is the user_id of the group in the User table.)
+- A row in the UserGroupSettings table giving user group settings.
+
+Membership of a user X in user group Y is represented as:
+- A row in the UserGroup table with user_id=X and group_id=Y.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import collections
+import logging
+import re
+
+from framework import exceptions
+from framework import permissions
+from framework import sql
+from proto import usergroup_pb2
+from services import caches
+
+
# Names of the SQL tables that back user groups.
USERGROUP_TABLE_NAME = 'UserGroup'
USERGROUPSETTINGS_TABLE_NAME = 'UserGroupSettings'
USERGROUPPROJECTS_TABLE_NAME = 'Group2Project'

# Column lists for each of the tables above.
USERGROUP_COLS = [
    'user_id',
    'group_id',
    'role',
]
USERGROUPSETTINGS_COLS = [
    'group_id',
    'who_can_view_members',
    'external_group_type',
    'last_sync_time',
    'notify_members',
    'notify_group',
]
USERGROUPPROJECTS_COLS = [
    'group_id',
    'project_id',
]

# Allowed (lowercase) values for a group's external_group_type.
GROUP_TYPE_ENUM = (
    'chrome_infra_auth',
    'mdb',
    'baggins',
    'computed',
)
+
+
class MembershipTwoLevelCache(caches.AbstractTwoLevelCache):
  """Class to manage RAM and memcache for each user's memberships."""

  def __init__(self, cache_manager, usergroup_service, group_dag):
    super(MembershipTwoLevelCache, self).__init__(
        cache_manager, 'user', 'memberships:', None)
    self.usergroup_service = usergroup_service
    self.group_dag = group_dag

  def _DeserializeMemberships(self, memberships_rows):
    """Group (user_id, group_id) pairs into a {user_id: {group_id}} dict."""
    result_dict = collections.defaultdict(set)
    for user_id, group_id in memberships_rows:
      result_dict[user_id].add(group_id)

    return result_dict

  def FetchItems(self, cnxn, keys):
    """On RAM and memcache miss, hit the database to get memberships.

    Args:
      cnxn: connection to SQL database.
      keys: user IDs whose memberships were not found in the caches.

    Returns:
      A dict {user_id: {group_id}} with an entry for every requested user.
      Each set holds the user's direct groups plus all of their ancestor
      groups; users with no memberships map to an empty set.
    """
    direct_memberships_rows = self.usergroup_service.usergroup_tbl.Select(
        cnxn, cols=['user_id', 'group_id'], distinct=True,
        user_id=keys)
    memberships_set = set()
    self.group_dag.MarkObsolete()
    logging.info('Rebuild group dag on RAM and memcache miss')
    # Many users typically share the same direct parent group, so compute each
    # parent's ancestor closure (itself included) only once.
    closure_by_parent = {}
    for c_id, p_id in direct_memberships_rows:
      if p_id not in closure_by_parent:
        all_parents = self.group_dag.GetAllAncestors(cnxn, p_id, True)
        all_parents.append(p_id)
        closure_by_parent[p_id] = all_parents
      memberships_set.update(
          (c_id, g_id) for g_id in closure_by_parent[p_id])
    retrieved_dict = self._DeserializeMemberships(list(memberships_set))

    # Make sure that every requested user is in the result, and gets cached.
    retrieved_dict.update(
        (user_id, set()) for user_id in keys
        if user_id not in retrieved_dict)
    return retrieved_dict
+
+
+class UserGroupService(object):
+ """The persistence layer for user group data."""
+
+ def __init__(self, cache_manager):
+ """Initialize this service so that it is ready to use.
+
+ Args:
+ cache_manager: local cache with distributed invalidation.
+ """
+ self.usergroup_tbl = sql.SQLTableManager(USERGROUP_TABLE_NAME)
+ self.usergroupsettings_tbl = sql.SQLTableManager(
+ USERGROUPSETTINGS_TABLE_NAME)
+ self.usergroupprojects_tbl = sql.SQLTableManager(
+ USERGROUPPROJECTS_TABLE_NAME)
+
+ self.group_dag = UserGroupDAG(self)
+
+ # Like a dictionary {user_id: {group_id}}
+ self.memberships_2lc = MembershipTwoLevelCache(
+ cache_manager, self, self.group_dag)
+ # Like a dictionary {group_email: [group_id]}
+ self.group_id_cache = caches.ValueCentricRamCache(
+ cache_manager, 'usergroup')
+
+ ### Group creation
+
+ def CreateGroup(self, cnxn, services, group_name, who_can_view_members,
+ ext_group_type=None, friend_projects=None):
+ """Create a new user group.
+
+ Args:
+ cnxn: connection to SQL database.
+ services: connections to backend services.
+ group_name: string email address of the group to create.
+ who_can_view_members: 'owners', 'members', or 'anyone'.
+ ext_group_type: The type of external group to import.
+ friend_projects: The project ids declared as group friends to view its
+ members.
+
+ Returns:
+ int group_id of the new group.
+ """
+ friend_projects = friend_projects or []
+ assert who_can_view_members in ('owners', 'members', 'anyone')
+ if ext_group_type:
+ ext_group_type = str(ext_group_type).lower()
+ assert ext_group_type in GROUP_TYPE_ENUM, ext_group_type
+ assert who_can_view_members == 'owners'
+ group_id = services.user.LookupUserID(
+ cnxn, group_name.lower(), autocreate=True, allowgroups=True)
+ group_settings = usergroup_pb2.MakeSettings(
+ who_can_view_members, ext_group_type, 0, friend_projects)
+ self.UpdateSettings(cnxn, group_id, group_settings)
+ self.group_id_cache.InvalidateAll(cnxn)
+ return group_id
+
+ def DeleteGroups(self, cnxn, group_ids):
+ """Delete groups' members and settings. It will NOT delete user entries.
+
+ Args:
+ cnxn: connection to SQL database.
+ group_ids: list of group ids to delete.
+ """
+ member_ids_dict, owner_ids_dict = self.LookupMembers(cnxn, group_ids)
+ citizens_id_dict = collections.defaultdict(list)
+ for g_id, user_ids in member_ids_dict.items():
+ citizens_id_dict[g_id].extend(user_ids)
+ for g_id, user_ids in owner_ids_dict.items():
+ citizens_id_dict[g_id].extend(user_ids)
+ for g_id, citizen_ids in citizens_id_dict.items():
+ logging.info('Deleting group %d', g_id)
+ # Remove group members, friend projects and settings
+ self.RemoveMembers(cnxn, g_id, citizen_ids)
+ self.usergroupprojects_tbl.Delete(cnxn, group_id=g_id)
+ self.usergroupsettings_tbl.Delete(cnxn, group_id=g_id)
+ self.group_id_cache.InvalidateAll(cnxn)
+
+ def DetermineWhichUserIDsAreGroups(self, cnxn, user_ids):
+ """From a list of user IDs, identify potential user groups.
+
+ Args:
+ cnxn: connection to SQL database.
+ user_ids: list of user IDs to examine.
+
+ Returns:
+ A list with a subset of the given user IDs that are user groups
+ rather than individual users.
+ """
+ # It is a group if there is any entry in the UserGroupSettings table.
+ group_id_rows = self.usergroupsettings_tbl.Select(
+ cnxn, cols=['group_id'], group_id=user_ids)
+ group_ids = [row[0] for row in group_id_rows]
+ return group_ids
+
+ ### User memberships in groups
+
+ def LookupComputedMemberships(self, cnxn, domain, use_cache=True):
+ """Look up the computed group memberships of a list of users.
+
+ Args:
+ cnxn: connection to SQL database.
+ domain: string with domain part of user's email address.
+ use_cache: set to False to ignore cached values.
+
+ Returns:
+ A list [group_id] of computed user groups that match the user.
+ For now, the length of this list will always be zero or one.
+ """
+ group_email = 'everyone@%s' % domain
+ group_id = self.LookupUserGroupID(cnxn, group_email, use_cache=use_cache)
+ if group_id:
+ return [group_id]
+
+ return []
+
+ def LookupUserGroupID(self, cnxn, group_email, use_cache=True):
+ """Lookup the group ID for the given user group email address.
+
+ Args:
+ cnxn: connection to SQL database.
+ group_email: string that identies the user group.
+ use_cache: set to False to ignore cached values.
+
+ Returns:
+ Int group_id if found, otherwise None.
+ """
+ if use_cache and self.group_id_cache.HasItem(group_email):
+ return self.group_id_cache.GetItem(group_email)
+
+ rows = self.usergroupsettings_tbl.Select(
+ cnxn, cols=['email', 'group_id'],
+ left_joins=[('User ON UserGroupSettings.group_id = User.user_id', [])],
+ email=group_email,
+ where=[('group_id IS NOT NULL', [])])
+ retrieved_dict = dict(rows)
+ # Cache a "not found" value for emails that are not user groups.
+ if group_email not in retrieved_dict:
+ retrieved_dict[group_email] = None
+ self.group_id_cache.CacheAll(retrieved_dict)
+
+ return retrieved_dict.get(group_email)
+
+ def LookupAllMemberships(self, cnxn, user_ids, use_cache=True):
+ """Lookup all the group memberships of a list of users.
+
+ Args:
+ cnxn: connection to SQL database.
+ user_ids: list of int user IDs to get memberships for.
+ use_cache: set to False to ignore cached values.
+
+ Returns:
+ A dict {user_id: {group_id}} for the given user_ids.
+ """
+ result_dict, missed_ids = self.memberships_2lc.GetAll(
+ cnxn, user_ids, use_cache=use_cache)
+ assert not missed_ids
+ return result_dict
+
+ def LookupMemberships(self, cnxn, user_id):
+ """Return a set of group_ids that this user is a member of."""
+ membership_dict = self.LookupAllMemberships(cnxn, [user_id])
+ return membership_dict[user_id]
+
+ ### Group member addition, removal, and retrieval
+
+ def RemoveMembers(self, cnxn, group_id, old_member_ids):
+ """Remove the given members/owners from the user group."""
+ self.usergroup_tbl.Delete(
+ cnxn, group_id=group_id, user_id=old_member_ids)
+
+ all_affected = self._GetAllMembersInList(cnxn, old_member_ids)
+
+ self.group_dag.MarkObsolete()
+ self.memberships_2lc.InvalidateAllKeys(cnxn, all_affected)
+
+ def UpdateMembers(self, cnxn, group_id, member_ids, new_role):
+ """Update role for given members/owners to the user group."""
+ # Circle detection
+ for mid in member_ids:
+ if self.group_dag.IsChild(cnxn, group_id, mid):
+ raise exceptions.CircularGroupException(
+ '%s is already an ancestor of group %s.' % (mid, group_id))
+
+ self.usergroup_tbl.Delete(
+ cnxn, group_id=group_id, user_id=member_ids)
+ rows = [(member_id, group_id, new_role) for member_id in member_ids]
+ self.usergroup_tbl.InsertRows(
+ cnxn, ['user_id', 'group_id', 'role'], rows)
+
+ all_affected = self._GetAllMembersInList(cnxn, member_ids)
+
+ self.group_dag.MarkObsolete()
+ self.memberships_2lc.InvalidateAllKeys(cnxn, all_affected)
+
+ def _GetAllMembersInList(self, cnxn, group_ids):
+ """Get all direct/indirect members/owners in a list."""
+ children_member_ids, children_owner_ids = self.LookupAllMembers(
+ cnxn, group_ids)
+ all_members_owners = set()
+ all_members_owners.update(group_ids)
+ for users in children_member_ids.values():
+ all_members_owners.update(users)
+ for users in children_owner_ids.values():
+ all_members_owners.update(users)
+ return list(all_members_owners)
+
+ def LookupAllMembers(self, cnxn, group_ids):
+ """Retrieve user IDs of members/owners of any of the given groups
+ transitively."""
+ member_ids_dict = {}
+ owner_ids_dict = {}
+ if not group_ids:
+ return member_ids_dict, owner_ids_dict
+ direct_member_rows = self.usergroup_tbl.Select(
+ cnxn, cols=['user_id', 'group_id', 'role'], distinct=True,
+ group_id=group_ids)
+ for gid in group_ids:
+ all_descendants = self.group_dag.GetAllDescendants(cnxn, gid, True)
+ indirect_member_rows = []
+ if all_descendants:
+ indirect_member_rows = self.usergroup_tbl.Select(
+ cnxn, cols=['user_id'], distinct=True,
+ group_id=all_descendants)
+
+ # Owners must have direct membership. All indirect users are members.
+ owner_ids_dict[gid] = [m[0] for m in direct_member_rows
+ if m[1] == gid and m[2] == 'owner']
+ member_ids_list = [r[0] for r in indirect_member_rows]
+ member_ids_list.extend([m[0] for m in direct_member_rows
+ if m[1] == gid and m[2] == 'member'])
+ member_ids_dict[gid] = list(set(member_ids_list))
+ return member_ids_dict, owner_ids_dict
+
+ def LookupMembers(self, cnxn, group_ids):
+ """"Retrieve user IDs of direct members/owners of any of the given groups.
+
+ Args:
+ cnxn: connection to SQL database.
+ group_ids: list of int user IDs for all user groups to be examined.
+
+ Returns:
+ A dict of member IDs, and a dict of owner IDs keyed by group id.
+ """
+ member_ids_dict = {}
+ owner_ids_dict = {}
+ if not group_ids:
+ return member_ids_dict, owner_ids_dict
+ member_rows = self.usergroup_tbl.Select(
+ cnxn, cols=['user_id', 'group_id', 'role'], distinct=True,
+ group_id=group_ids)
+ for gid in group_ids:
+ member_ids_dict[gid] = [row[0] for row in member_rows
+ if row[1] == gid and row[2] == 'member']
+ owner_ids_dict[gid] = [row[0] for row in member_rows
+ if row[1] == gid and row[2] == 'owner']
+ return member_ids_dict, owner_ids_dict
+
+ def ExpandAnyGroupEmailRecipients(self, cnxn, user_ids):
+ """Expand the list with members that are part of a group configured
+ to have notifications sent directly to members. Remove any groups
+ not configured to have notifications sent directly to the group.
+
+ Args:
+ cnxn: connection to SQL database.
+ user_ids: list of user IDs to check.
+
+ Returns:
+ A paire (individual user_ids, transitive_ids). individual_user_ids
+ is a list of user IDs that were in the given user_ids list and
+ that identify individual members or a group that has
+ settings.notify_group set to True. transitive_ids is a list of
+ user IDs of members of any user group in user_ids with
+ settings.notify_members set to True.
+ """
+ group_ids = self.DetermineWhichUserIDsAreGroups(cnxn, user_ids)
+ group_settings_dict = self.GetAllGroupSettings(cnxn, group_ids)
+ member_ids_dict, owner_ids_dict = self.LookupAllMembers(cnxn, group_ids)
+ indirect_ids = set()
+ direct_ids = {uid for uid in user_ids if uid not in group_ids}
+ for gid, settings in group_settings_dict.items():
+ if settings.notify_members:
+ indirect_ids.update(member_ids_dict.get(gid, set()))
+ indirect_ids.update(owner_ids_dict.get(gid, set()))
+ if settings.notify_group:
+ direct_ids.add(gid)
+
+ return list(direct_ids), list(indirect_ids)
+
+ def LookupVisibleMembers(
+ self, cnxn, group_id_list, perms, effective_ids, services):
+ """"Retrieve the list of user group direct member/owner IDs that the user
+ may see.
+
+ Args:
+ cnxn: connection to SQL database.
+ group_id_list: list of int user IDs for all user groups to be examined.
+ perms: optional PermissionSet for the user viewing this page.
+ effective_ids: set of int user IDs for that user and all
+ their group memberships.
+ services: backend services.
+
+ Returns:
+ A list of all the member IDs from any group that the user is allowed
+ to view.
+ """
+ settings_dict = self.GetAllGroupSettings(cnxn, group_id_list)
+ group_ids = list(settings_dict.keys())
+ (owned_project_ids, membered_project_ids,
+ contrib_project_ids) = services.project.GetUserRolesInAllProjects(
+ cnxn, effective_ids)
+ project_ids = owned_project_ids.union(
+ membered_project_ids).union(contrib_project_ids)
+ # We need to fetch all members/owners to determine whether the requester
+ # has permission to view.
+ direct_member_ids_dict, direct_owner_ids_dict = self.LookupMembers(
+ cnxn, group_ids)
+ all_member_ids_dict, all_owner_ids_dict = self.LookupAllMembers(
+ cnxn, group_ids)
+ visible_member_ids = {}
+ visible_owner_ids = {}
+ for gid in group_ids:
+ member_ids = all_member_ids_dict[gid]
+ owner_ids = all_owner_ids_dict[gid]
+
+ if permissions.CanViewGroupMembers(
+ perms, effective_ids, settings_dict[gid], member_ids, owner_ids,
+ project_ids):
+ visible_member_ids[gid] = direct_member_ids_dict[gid]
+ visible_owner_ids[gid] = direct_owner_ids_dict[gid]
+
+ return visible_member_ids, visible_owner_ids
+
+ ### Group settings
+
+ def GetAllUserGroupsInfo(self, cnxn):
+ """Fetch (addr, member_count, usergroup_settings) for all user groups."""
+ group_rows = self.usergroupsettings_tbl.Select(
+ cnxn, cols=['email'] + USERGROUPSETTINGS_COLS,
+ left_joins=[('User ON UserGroupSettings.group_id = User.user_id', [])])
+ count_rows = self.usergroup_tbl.Select(
+ cnxn, cols=['group_id', 'COUNT(*)'],
+ group_by=['group_id'])
+ count_dict = dict(count_rows)
+
+ group_ids = [g[1] for g in group_rows]
+ friends_dict = self.GetAllGroupFriendProjects(cnxn, group_ids)
+
+ user_group_info_tuples = [
+ (email, count_dict.get(group_id, 0),
+ usergroup_pb2.MakeSettings(visiblity, group_type, last_sync_time,
+ friends_dict.get(group_id, []),
+ bool(notify_members), bool(notify_group)),
+ group_id)
+ for (email, group_id, visiblity, group_type, last_sync_time,
+ notify_members, notify_group) in group_rows]
+ return user_group_info_tuples
+
+ def GetAllGroupSettings(self, cnxn, group_ids):
+ """Fetch {group_id: group_settings} for the specified groups."""
+ # TODO(jrobbins): add settings to control who can join, etc.
+ rows = self.usergroupsettings_tbl.Select(
+ cnxn, cols=USERGROUPSETTINGS_COLS, group_id=group_ids)
+ friends_dict = self.GetAllGroupFriendProjects(cnxn, group_ids)
+ settings_dict = {
+ group_id: usergroup_pb2.MakeSettings(
+ vis, group_type, last_sync_time, friends_dict.get(group_id, []),
+ notify_members=bool(notify_members),
+ notify_group=bool(notify_group))
+ for (group_id, vis, group_type, last_sync_time,
+ notify_members, notify_group) in rows}
+ return settings_dict
+
+ def GetGroupSettings(self, cnxn, group_id):
+ """Retrieve group settings for the specified user group.
+
+ Args:
+ cnxn: connection to SQL database.
+ group_id: int user ID of the user group.
+
+ Returns:
+ A UserGroupSettings object, or None if no such group exists.
+ """
+ return self.GetAllGroupSettings(cnxn, [group_id]).get(group_id)
+
+ def UpdateSettings(self, cnxn, group_id, group_settings):
+ """Update the visiblity settings of the specified group."""
+ who_can_view_members = str(group_settings.who_can_view_members).lower()
+ ext_group_type = group_settings.ext_group_type
+ assert who_can_view_members in ('owners', 'members', 'anyone')
+ if ext_group_type:
+ ext_group_type = str(group_settings.ext_group_type).lower()
+ assert ext_group_type in GROUP_TYPE_ENUM, ext_group_type
+ assert who_can_view_members == 'owners'
+ self.usergroupsettings_tbl.InsertRow(
+ cnxn, group_id=group_id, who_can_view_members=who_can_view_members,
+ external_group_type=ext_group_type,
+ last_sync_time=group_settings.last_sync_time,
+ notify_members=group_settings.notify_members,
+ notify_group=group_settings.notify_group,
+ replace=True)
+ self.usergroupprojects_tbl.Delete(
+ cnxn, group_id=group_id)
+ if group_settings.friend_projects:
+ rows = [(group_id, p_id) for p_id in group_settings.friend_projects]
+ self.usergroupprojects_tbl.InsertRows(
+ cnxn, ['group_id', 'project_id'], rows)
+
+ def GetAllGroupFriendProjects(self, cnxn, group_ids):
+ """Get {group_id: [project_ids]} for the specified user groups."""
+ rows = self.usergroupprojects_tbl.Select(
+ cnxn, cols=USERGROUPPROJECTS_COLS, group_id=group_ids)
+ friends_dict = {}
+ for group_id, project_id in rows:
+ friends_dict.setdefault(group_id, []).append(project_id)
+ return friends_dict
+
+ def GetGroupFriendProjects(self, cnxn, group_id):
+ """Get a list of friend projects for the specified user group."""
+ return self.GetAllGroupFriendProjects(cnxn, [group_id]).get(group_id)
+
+ def ValidateFriendProjects(self, cnxn, services, friend_projects):
+ """Validate friend projects.
+
+ Returns:
+ A list of project ids if no errors, or an error message.
+ """
+ project_names = list(filter(None, re.split('; |, | |;|,', friend_projects)))
+ id_dict = services.project.LookupProjectIDs(cnxn, project_names)
+ missed_projects = []
+ result = []
+ for p_name in project_names:
+ if p_name in id_dict:
+ result.append(id_dict[p_name])
+ else:
+ missed_projects.append(p_name)
+ error_msg = ''
+ if missed_projects:
+ error_msg = 'Project(s) %s do not exist' % ', '.join(missed_projects)
+ return None, error_msg
+ else:
+ return result, None
+
+ # TODO(jrobbins): re-implement FindUntrustedGroups()
+
+ def ExpungeUsersInGroups(self, cnxn, ids):
+ """Wipes the given user from the groups system.
+ The given user_ids may to members or groups, or groups themselves.
+ The groups and all their members will be deleted. The users will be
+ wiped from the groups they belong to.
+
+ It will NOT delete user entries. This method will not commit the
+ operations. This method will not make any changes to in-memory data.
+ """
+ # Delete any groups
+ self.usergroupprojects_tbl.Delete(cnxn, group_id=ids, commit=False)
+ self.usergroupsettings_tbl.Delete(cnxn, group_id=ids, commit=False)
+ self.usergroup_tbl.Delete(cnxn, group_id=ids, commit=False)
+
+ # Delete any group members
+ self.usergroup_tbl.Delete(cnxn, user_id=ids, commit=False)
+
+
class UserGroupDAG(object):
  """A directed-acyclic graph of potentially nested user groups."""

  def __init__(self, usergroup_service):
    self.usergroup_service = usergroup_service
    # {member_group_id: [group_id, ...]} -- groups that directly contain it.
    self.user_group_parents = collections.defaultdict(list)
    # {group_id: [member_group_id, ...]} -- groups it directly contains.
    self.user_group_children = collections.defaultdict(list)
    self.initialized = False
    # True once circle detection has run against the current build; the graph
    # cannot change until the next rebuild, so re-checking is wasted work.
    self.circles_checked = False

  def Build(self, cnxn, circle_detection=False):
    """(Re)load the group-in-group edges if the DAG is not initialized.

    Only UserGroup rows whose user_id is itself a group (has a
    UserGroupSettings row) become edges.

    Args:
      cnxn: connection to SQL database.
      circle_detection: if True, log an error for every cycle edge found.
        The scan runs at most once per build.
    """
    if not self.initialized:
      self.user_group_parents.clear()
      self.user_group_children.clear()
      group_ids = self.usergroup_service.usergroupsettings_tbl.Select(
          cnxn, cols=['group_id'])
      usergroup_rows = self.usergroup_service.usergroup_tbl.Select(
          cnxn, cols=['user_id', 'group_id'], distinct=True,
          user_id=[r[0] for r in group_ids])
      for user_id, group_id in usergroup_rows:
        self.user_group_parents[user_id].append(group_id)
        self.user_group_children[group_id].append(user_id)
      self.initialized = True
      self.circles_checked = False

    if circle_detection and not self.circles_checked:
      # Set before scanning; IsChild re-enters Build without detection.
      self.circles_checked = True
      for child_id, parent_ids in list(self.user_group_parents.items()):
        for parent_id in parent_ids:
          if self.IsChild(cnxn, parent_id, child_id):
            logging.error(
                'Circle exists between group %d and %d.', child_id, parent_id)

  def _Reachable(self, edges, start_id):
    """Return distinct IDs reachable from start_id by following edges.

    start_id itself is included only if it lies on a cycle. Uses .get so that
    lookups never insert empty entries into the defaultdict edge maps.
    """
    result = set()
    frontier = [start_id]
    while frontier:
      next_ids = set()
      for node_id in frontier:
        next_ids.update(
            n_id for n_id in edges.get(node_id, ()) if n_id not in result)
      result.update(next_ids)
      frontier = list(next_ids)
    return list(result)

  def GetAllAncestors(self, cnxn, group_id, circle_detection=False):
    """Return a list of distinct ancestor group IDs for the given group."""
    self.Build(cnxn, circle_detection)
    return self._Reachable(self.user_group_parents, group_id)

  def GetAllDescendants(self, cnxn, group_id, circle_detection=False):
    """Return a list of distinct descendant group IDs for the given group."""
    self.Build(cnxn, circle_detection)
    return self._Reachable(self.user_group_children, group_id)

  def IsChild(self, cnxn, child_id, parent_id):
    """Returns True if child_id is a direct/indirect child of parent_id."""
    all_descendants = self.GetAllDescendants(cnxn, parent_id)
    return child_id in all_descendants

  def MarkObsolete(self):
    """Mark the DAG as uninitialized so it'll be re-built."""
    self.initialized = False

  def __repr__(self):
    result = {}
    result['parents'] = self.user_group_parents
    result['children'] = self.user_group_children
    return str(result)