Project import generated by Copybara.

GitOrigin-RevId: d9e9e3fb4e31372ec1fb43b178994ca78fa8fe70
diff --git a/search/backendsearchpipeline.py b/search/backendsearchpipeline.py
new file mode 100644
index 0000000..69fdc6b
--- /dev/null
+++ b/search/backendsearchpipeline.py
@@ -0,0 +1,325 @@
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Backend issue issue search and sorting.
+
+Each of several "besearch" backend jobs manages one shard of the overall set
+of issues in the system. The backend search pipeline retrieves the issues
+that match the user query, puts them into memcache, and returns them to
+the frontend search pipeline.
+"""
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+
+import logging
+import re
+import time
+
+from google.appengine.api import memcache
+
+import settings
+from features import savedqueries_helpers
+from framework import authdata
+from framework import framework_constants
+from framework import framework_helpers
+from framework import sorting
+from framework import sql
+from proto import ast_pb2
+from proto import tracker_pb2
+from search import ast2ast
+from search import ast2select
+from search import ast2sort
+from search import query2ast
+from search import searchpipeline
+from services import tracker_fulltext
+from services import fulltext_helpers
+from tracker import tracker_bizobj
+
+
+# Used in constructing the at-risk query.
+AT_RISK_LABEL_RE = re.compile(r'^(restrict-view-.+)$', re.IGNORECASE)
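+# For example, this matches labels such as 'Restrict-View-SecurityTeam'
+# (a hypothetical label name; matching is case-insensitive).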
+
+# Limit on the number of list items to show in debug log statements
+MAX_LOG = 200
+
+
+class BackendSearchPipeline(object):
+  """Manage the process of issue search, including Promises and caching.
+
+  Even though the code is divided into several methods, the public
+  methods should be called in sequence, so the execution of the code
+  is pretty much in the order of the source code lines here.
+  """
+
+  def __init__(
+      self, mr, services, default_results_per_page,
+      query_project_names, logged_in_user_id, me_user_ids):
+
+    self.mr = mr
+    self.services = services
+    self.default_results_per_page = default_results_per_page
+
+    self.query_project_list = list(services.project.GetProjectsByName(
+        mr.cnxn, query_project_names).values())
+    self.query_project_ids = [
+        p.project_id for p in self.query_project_list]
+
+    self.me_user_ids = me_user_ids
+    self.mr.auth = authdata.AuthData.FromUserID(
+        mr.cnxn, logged_in_user_id, services)
+
+    # The following fields are filled in as the pipeline progresses.
+    # The value None means that we still need to compute that value.
+    self.result_iids = None  # Sorted issue IDs that match the query
+    self.search_limit_reached = False  # True if search results limit is hit.
+    self.error = None
+
+    self._MakePromises()
+
+  def _MakePromises(self):
+    config_dict = self.services.config.GetProjectConfigs(
+        self.mr.cnxn, self.query_project_ids)
+    self.harmonized_config = tracker_bizobj.HarmonizeConfigs(
+        list(config_dict.values()))
+
+    self.canned_query = savedqueries_helpers.SavedQueryIDToCond(
+        self.mr.cnxn, self.services.features, self.mr.can)
+
+    self.canned_query, warnings = searchpipeline.ReplaceKeywordsWithUserIDs(
+        self.me_user_ids, self.canned_query)
+    self.mr.warnings.extend(warnings)
+    self.user_query, warnings = searchpipeline.ReplaceKeywordsWithUserIDs(
+        self.me_user_ids, self.mr.query)
+    self.mr.warnings.extend(warnings)
+    logging.debug('Searching query: %s %s', self.canned_query, self.user_query)
+
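+    # Narrow results to the logical shard this backend job is serving.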
+    slice_term = ('Issue.shard = %s', [self.mr.shard_id])
+
+    sd = sorting.ComputeSortDirectives(
+        self.harmonized_config, self.mr.group_by_spec, self.mr.sort_spec)
+
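+    # Start the real search work now; SearchForIIDs() later blocks on this
+    # Promise and unpacks (result_iids, search_limit_reached, error).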
+    self.result_iids_promise = framework_helpers.Promise(
+        _GetQueryResultIIDs, self.mr.cnxn,
+        self.services, self.canned_query, self.user_query,
+        self.query_project_ids, self.harmonized_config, sd,
+        slice_term, self.mr.shard_id, self.mr.invalidation_timestep)
+
+  def SearchForIIDs(self):
+    """Wait for the search Promises and store their results."""
+    with self.mr.profiler.Phase('WaitOnPromises'):
+      self.result_iids, self.search_limit_reached, self.error = (
+          self.result_iids_promise.WaitAndGetValue())
+
+
+def SearchProjectCan(
+    cnxn, services, project_ids, query_ast, shard_id, harmonized_config,
+    left_joins=None, where=None, sort_directives=None, query_desc=''):
+  """Return a list of issue global IDs in the projects that satisfy the query.
+
+  Args:
+    cnxn: Regular database connection to the primary DB.
+    services: interface to issue storage backends.
+    project_ids: list of int IDs of the projects to search.
+    query_ast: A QueryAST PB with conjunctions and conditions.
+    shard_id: limit search to the specified shard ID int.
+    harmonized_config: harmonized config for all projects being searched.
+    left_joins: SQL LEFT JOIN clauses that are needed in addition to
+        anything generated from the query_ast.
+    where: SQL WHERE clauses that are needed in addition to
+        anything generated from the query_ast.
+    sort_directives: list of strings specifying the columns to sort on.
+    query_desc: descriptive string for debugging.
+
+  Returns:
+    (issue_ids, capped, error) where issue_ids is a list of issue IDs
+    that satisfy the query, capped is True if the number of results was
+    capped due to an implementation limit, and error is any well-known
+    error (probably a query parsing error) encountered during search.
+  """
+  logging.info('searching projects %r for AST %r', project_ids, query_ast)
+  start_time = time.time()
+  left_joins = left_joins or []
+  where = where or []
+  if project_ids:
+    cond_str = 'Issue.project_id IN (%s)' % sql.PlaceHolders(project_ids)
+    where.append((cond_str, project_ids))
+
+  try:
+    query_ast = ast2ast.PreprocessAST(
+        cnxn, query_ast, project_ids, services, harmonized_config)
+    logging.info('simplified AST is %r', query_ast)
+    query_left_joins, query_where, _ = ast2select.BuildSQLQuery(query_ast)
+    left_joins.extend(query_left_joins)
+    where.extend(query_where)
+  except ast2ast.MalformedQuery as e:
+    # TODO(jrobbins): inform the user that their query had invalid tokens.
+    logging.info('Invalid query tokens %s.\n %r\n\n', e, query_ast)
+    return [], False, e
+  except ast2select.NoPossibleResults as e:
+    # TODO(jrobbins): inform the user that their query was impossible.
+    logging.info('Impossible query %s.\n %r\n\n', e, query_ast)
+    return [], False, e
+  logging.info('translated to left_joins %r', left_joins)
+  logging.info('translated to where %r', where)
+
+  fts_capped = False
+  if query_ast.conjunctions:
+    # TODO(jrobbins): Handle "OR" in queries.  For now, we just process the
+    # first conjunction.
+    assert len(query_ast.conjunctions) == 1
+    conj = query_ast.conjunctions[0]
+    full_text_iids, fts_capped = tracker_fulltext.SearchIssueFullText(
+        project_ids, conj, shard_id)
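+    # None means fulltext produced no verdict, so we skip narrowing; an
+    # empty list means fulltext matched nothing, so skip the DB entirely.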
+    if full_text_iids is not None:
+      if not full_text_iids:
+        return [], False, None  # No match on fulltext, so don't bother DB.
+      cond_str = 'Issue.id IN (%s)' % sql.PlaceHolders(full_text_iids)
+      where.append((cond_str, full_text_iids))
+
+  label_def_rows = []
+  status_def_rows = []
+  if sort_directives:
+    if project_ids:
+      for pid in project_ids:
+        label_def_rows.extend(services.config.GetLabelDefRows(cnxn, pid))
+        status_def_rows.extend(services.config.GetStatusDefRows(cnxn, pid))
+    else:
+      label_def_rows = services.config.GetLabelDefRowsAnyProject(cnxn)
+      status_def_rows = services.config.GetStatusDefRowsAnyProject(cnxn)
+
+  harmonized_labels = tracker_bizobj.HarmonizeLabelOrStatusRows(
+      label_def_rows)
+  harmonized_statuses = tracker_bizobj.HarmonizeLabelOrStatusRows(
+      status_def_rows)
+  harmonized_fields = harmonized_config.field_defs
+  sort_left_joins, order_by = ast2sort.BuildSortClauses(
+      sort_directives, harmonized_labels, harmonized_statuses,
+      harmonized_fields)
+  logging.info('translated to sort left_joins %r', sort_left_joins)
+  logging.info('translated to order_by %r', order_by)
+
+  issue_ids, db_capped = services.issue.RunIssueQuery(
+      cnxn, left_joins + sort_left_joins, where, order_by, shard_id=shard_id)
+  logging.warning('executed "%s" query %r for %d issues in %dms',
+               query_desc, query_ast, len(issue_ids),
+               int((time.time() - start_time) * 1000))
+  capped = fts_capped or db_capped
+  return issue_ids, capped, None
+
+
+def _FilterSpam(query_ast):
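+  """Ensure the query filters out spam unless it explicitly mentions spam.
+
+  If no condition in the first conjunction references the built-in 'spam'
+  field, append a spam != condition so spam issues are excluded.
+  """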
+  uses_spam = False
+  # TODO(jrobbins): Handle "OR" in queries.  For now, we just modify the
+  # first conjunction.
+  conjunction = query_ast.conjunctions[0]
+  for condition in conjunction.conds:
+    for field in condition.field_defs:
+      if field.field_name == 'spam':
+        uses_spam = True
+
+  if not uses_spam:
+    query_ast.conjunctions[0].conds.append(
+        ast_pb2.MakeCond(
+            ast_pb2.QueryOp.NE,
+            [tracker_pb2.FieldDef(
+                field_name='spam',
+                field_type=tracker_pb2.FieldTypes.BOOL_TYPE)
+             ],
+            [], []))
+
+  return query_ast
+
+
+def _GetQueryResultIIDs(
+    cnxn, services, canned_query, user_query,
+    query_project_ids, harmonized_config, sd, slice_term,
+    shard_id, invalidation_timestep):
+  """Do a search and return a list of matching issue IDs.
+
+  Args:
+    cnxn: connection to the database.
+    services: interface to issue storage backends.
+    canned_query: string part of the query from the drop-down menu.
+    user_query: string part of the query that the user typed in.
+    query_project_ids: list of project IDs to search.
+    harmonized_config: combined configs for all the queried projects.
+    sd: list of sort directives.
+    slice_term: additional query term to narrow results to a logical shard
+        within a physical shard.
+    shard_id: int number of the database shard to search.
+    invalidation_timestep: int timestep used to keep memcached items fresh.
+
+  Returns:
+    Tuple consisting of:
+      A list of issue IDs that match the user's query.  An empty list, [],
+      is returned if no issues match the query.
+      Boolean that is set to True if the search results limit of this shard is
+      hit.
+      An error (subclass of Exception) encountered during query processing. None
+      means that no error was encountered.
+  """
+  query_ast = _FilterSpam(query2ast.ParseUserQuery(
+      user_query, canned_query, query2ast.BUILTIN_ISSUE_FIELDS,
+      harmonized_config))
+
+  logging.info('query_project_ids is %r', query_project_ids)
+
+  is_fulltext_query = bool(
+      query_ast.conjunctions and
+      fulltext_helpers.BuildFTSQuery(
+          query_ast.conjunctions[0], tracker_fulltext.ISSUE_FULLTEXT_FIELDS))
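+  # Fulltext-backed results get a separate expiration, presumably because
+  # the fulltext index is updated independently of the primary DB.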
+  expiration = framework_constants.CACHE_EXPIRATION
+  if is_fulltext_query:
+    expiration = framework_constants.FULLTEXT_MEMCACHE_EXPIRATION
+
+  # Might raise ast2ast.MalformedQuery or ast2select.NoPossibleResults.
+  result_iids, search_limit_reached, error = SearchProjectCan(
+      cnxn, services, query_project_ids, query_ast, shard_id,
+      harmonized_config, sort_directives=sd, where=[slice_term],
+      query_desc='getting query issue IDs')
+  logging.info('Found %d result_iids', len(result_iids))
+  if error:
+    logging.warning('Got error %r', error)
+
+  projects_str = ','.join(str(pid) for pid in sorted(query_project_ids))
+  projects_str = projects_str or 'all'
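+  # Resulting key sketch (hypothetical values):
+  #   '180;-status:Duplicate;owner:111;pri id;2'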
+  memcache_key = ';'.join([
+      projects_str, canned_query, user_query, ' '.join(sd), str(shard_id)])
+  memcache.set(memcache_key, (result_iids, invalidation_timestep),
+               time=expiration, namespace=settings.memcache_namespace)
+  logging.info('set memcache key %r', memcache_key)
+
+  search_limit_memcache_key = ';'.join([
+      projects_str, canned_query, user_query, ' '.join(sd),
+      'search_limit_reached', str(shard_id)])
+  memcache.set(search_limit_memcache_key,
+               (search_limit_reached, invalidation_timestep),
+               time=expiration, namespace=settings.memcache_namespace)
+  logging.info('set search limit memcache key %r',
+               search_limit_memcache_key)
+
+  timestamps_for_projects = memcache.get_multi(
+      keys=(['%d;%d' % (pid, shard_id) for pid in query_project_ids] +
+            ['all;%d' % shard_id]),
+      namespace=settings.memcache_namespace)
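+
+  # Record an invalidation timestep for any queried shard that lacks one,
+  # so the cached results above can later be checked for freshness.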
+
+  if query_project_ids:
+    for pid in query_project_ids:
+      key = '%d;%d' % (pid, shard_id)
+      if key not in timestamps_for_projects:
+        memcache.set(
+            key,
+            invalidation_timestep,
+            time=framework_constants.CACHE_EXPIRATION,
+            namespace=settings.memcache_namespace)
+  else:
+    key = 'all;%d' % shard_id
+    if key not in timestamps_for_projects:
+      memcache.set(
+          key,
+          invalidation_timestep,
+          time=framework_constants.CACHE_EXPIRATION,
+          namespace=settings.memcache_namespace)
+
+  return result_iids, search_limit_reached, error