# Copyright 2016 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""The FrontendSearchPipeline class manages issue search and sorting.

The frontend pipeline checks memcache for cached results in each shard.  It
then calls backend jobs to do any shards that had a cache miss.  On cache hit,
the cached results must be filtered by permissions, so the at-risk cache and
backends are consulted.  Next, the sharded results are combined into an overall
list of IIDs.  Then, that list is paginated and the issues on the current
pagination page can be shown.  Alternatively, this class can determine just the
position the currently shown issue would occupy in the overall sorted list.
"""

from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

import json

import collections
import logging
import math
import random
import six
import time

from google.appengine.api import apiproxy_stub_map
from google.appengine.api import memcache
from google.appengine.api import modules
from google.appengine.api import urlfetch

import settings
from features import savedqueries_helpers
from framework import framework_bizobj
from framework import framework_constants
from framework import framework_helpers
from framework import paginate
from framework import permissions
from framework import sorting
from framework import urls
from search import ast2ast
from search import query2ast
from search import searchpipeline
from services import fulltext_helpers
from tracker import tracker_bizobj
from tracker import tracker_constants
from tracker import tracker_helpers


# Fail-fast responses usually finish in less than 50ms.  If we see a failure
# in under FAIL_FAST_LIMIT_SEC, we don't bother logging it.
# NOTE(review): the limit below is 0.1 seconds (100ms), which is twice the
# 50ms figure quoted above -- confirm that the extra headroom is intended.
FAIL_FAST_LIMIT_SEC = 0.1

# How long to sleep, in seconds, between polls for completed backend RPCs.
DELAY_BETWEEN_RPC_COMPLETION_POLLS = 0.04  # 40 milliseconds

# The choices help balance the cost of choosing samples vs. the cost of
# selecting issues that are in a range bounded by neighboring samples.
# Preferred chunk size parameters were determined by experimentation.
# The sample chunk size is clamped to the range
# [MIN_SAMPLE_CHUNK_SIZE, MAX_SAMPLE_CHUNK_SIZE]; within that range it is
# chosen so that a shard is split into about PREFERRED_NUM_CHUNKS chunks.
MIN_SAMPLE_CHUNK_SIZE = int(
    math.sqrt(tracker_constants.DEFAULT_RESULTS_PER_PAGE))
MAX_SAMPLE_CHUNK_SIZE = int(math.sqrt(settings.search_limit_per_shard))
PREFERRED_NUM_CHUNKS = 50


# TODO(jojwang): monorail:4127: combine some url parameters info or
# query info into dicts or tuples to make argument management easier.
class FrontendSearchPipeline(object):
  """Manage the process of issue search, including backends and caching.

  Even though the code is divided into several methods, the public
  methods should be called in sequence, so the execution of the code
  is pretty much in the order of the source code lines here.
  """

  def __init__(
      self,
      cnxn,
      services,
      auth,
      me_user_ids,
      query,
      query_project_names,
      items_per_page,
      paginate_start,
      can,
      group_by_spec,
      sort_spec,
      warnings,
      errors,
      use_cached_searches,
      profiler,
      project=None):
    """Record the search parameters and validate and split up the query.

    Args:
      cnxn: monorail connection to the database.
      services: connections to backends.
      auth: info about the requesting user; user_id, user_pb, and
          effective_ids are read from it.
      me_user_ids: user IDs that the keyword "me" stands for in the query.
      query: the user's query string.
      query_project_names: names of the projects to search; only those that
          the user is permitted to view are kept.
      items_per_page: number of issues shown on one pagination page.
      paginate_start: offset of the first issue on the current page.
      can: "canned query" number to scope the user's search.
      group_by_spec: string that lists the grouping order.
      sort_spec: string that lists the sort order.
      warnings: list to accumulate warning messages.
      errors: object whose query attribute is set to the first query
          parsing error, if there is one.
      use_cached_searches: Bool for whether to use cached searches.
      profiler: object whose Phase() context manager wraps each step.
      project: optional current project; when given it is always added to
          the list of projects being searched.
    """
    # Plain copies of what the caller passed in.
    self.cnxn = cnxn
    self.services = services
    self.auth = auth
    self.me_user_ids = me_user_ids
    self.logged_in_user_id = auth.user_id or 0
    self.can = can
    self.items_per_page = items_per_page
    self.paginate_start = paginate_start
    self.group_by_spec = group_by_spec
    self.sort_spec = sort_spec
    self.warnings = warnings
    self.errors = errors
    self.use_cached_searches = use_cached_searches
    self.profiler = profiler

    # Pagination state, filled in by SearchForIIDs() and Paginate().
    self.pagination = None
    self.num_skipped_at_start = 0
    self.total_count = 0

    # Work out which projects will actually be searched.
    self.project_name = project.project_name if project else ''
    searched_projects = []
    if query_project_names:
      projects_by_name = services.project.GetProjectsByName(
          self.cnxn, query_project_names)
      for candidate in list(projects_by_name.values()):
        if permissions.UserCanViewProject(
            self.auth.user_pb, self.auth.effective_ids, candidate):
          searched_projects.append(candidate)
    if project:
      searched_projects.append(project)
    self.query_projects = searched_projects

    member_of_all_projects = self.auth.user_pb.is_site_admin
    if not member_of_all_projects:
      member_of_all_projects = all(
          framework_bizobj.UserIsInProject(p, self.auth.effective_ids)
          for p in self.query_projects)
    self.query_project_ids = sorted(
        p.project_id for p in self.query_projects)
    self.query_project_names = sorted(
        p.project_name for p in self.query_projects)

    configs_by_project_id = self.services.config.GetProjectConfigs(
        self.cnxn, self.query_project_ids)
    self.harmonized_config = tracker_bizobj.HarmonizeConfigs(
        list(configs_by_project_id.values()))

    # The following fields are filled in as the pipeline progresses.
    # The value None means that we still need to compute that value.
    # A shard_key is a tuple (shard_id, subquery).
    self.users_by_id = {}
    self.nonviewable_iids = {}  # {shard_id: set(iid)}
    self.unfiltered_iids = {}  # {shard_key: [iid, ...]} needing perm checks.
    self.filtered_iids = {}  # {shard_key: [iid, ...]} already perm checked.
    self.search_limit_reached = {}  # {shard_key: [bool, ...]}.
    self.allowed_iids = []  # Matching iids that user is permitted to view.
    self.allowed_results = None  # results that the user is permitted to view.
    self.visible_results = None  # allowed_results on current pagination page.
    self.error_responses = set()

    # Report the first problem with the query, if any, to the caller.
    query_error = _CheckQuery(
        self.cnxn, self.services, query, self.harmonized_config,
        self.query_project_ids, member_of_all_projects,
        warnings=self.warnings)
    if query_error:
      self.errors.query = query_error

    # Split up query into smaller subqueries that would get the same results
    # to improve performance. Smaller queries are more likely to get cache
    # hits and subqueries can be parallelized by querying for them across
    # multiple shards.
    self.subqueries = []
    try:
      self.subqueries = query2ast.QueryToSubqueries(query)
    except query2ast.InvalidQueryError:
      # Nothing more to do: the problem was already recorded in
      # self.errors.query above, so we search with no subqueries.
      pass

  def SearchForIIDs(self):
    """Search every shard via cache and backends, then filter the results.

    Fills in self.filtered_iids with the permission-checked, de-duplicated
    IIDs of each shard, trimmed to what the current pagination page could
    need, and adds the number of matches to self.total_count.
    """
    with self.profiler.Phase('Checking cache and calling Backends'):
      rpc_tuples = _StartBackendSearch(
          self.cnxn, self.query_project_names, self.query_project_ids,
          self.harmonized_config, self.unfiltered_iids,
          self.search_limit_reached, self.nonviewable_iids,
          self.error_responses, self.services, self.me_user_ids,
          self.logged_in_user_id, self.items_per_page + self.paginate_start,
          self.subqueries, self.can, self.group_by_spec, self.sort_spec,
          self.warnings, self.use_cached_searches)

    with self.profiler.Phase('Waiting for Backends'):
      try:
        _FinishBackendSearch(rpc_tuples)
      except Exception as err:
        # Log the failure here, then let the caller deal with it.
        logging.exception(err)
        raise

    if self.error_responses:
      logging.error('%r error responses. Incomplete search results.',
                    self.error_responses)

    with self.profiler.Phase('Filtering cached results'):
      for shard_key, unfiltered_shard_iids in self.unfiltered_iids.items():
        shard_id, _ = shard_key
        if shard_id in self.nonviewable_iids:
          hidden_iids = self.nonviewable_iids[shard_id]
          # TODO(jrobbins): avoid creating large temporary lists.
          viewable_iids = [
              iid for iid in unfiltered_shard_iids if iid not in hidden_iids]
        else:
          # Without the nonviewable set we cannot check permissions, so we
          # show nothing from this shard and record it as an error.
          logging.error(
            'Not displaying shard %r because of no nonviewable_iids', shard_id)
          self.error_responses.add(shard_id)
          viewable_iids = []
        self.filtered_iids[shard_key] = viewable_iids

    seen_iids_by_shard_id = collections.defaultdict(set)
    with self.profiler.Phase('Dedupping result IIDs across shards'):
      # Different subqueries on the same shard can match the same issue;
      # keep only the first occurrence.
      for shard_key, shard_iids in self.filtered_iids.items():
        shard_id, _ = shard_key
        already_seen = seen_iids_by_shard_id[shard_id]
        fresh_iids = [iid for iid in shard_iids if iid not in already_seen]
        self.filtered_iids[shard_key] = fresh_iids
        already_seen.update(fresh_iids)

    with self.profiler.Phase('Counting all filtered results'):
      self.total_count += sum(
          len(shard_iids) for shard_iids in self.filtered_iids.values())

    with self.profiler.Phase('Trimming results beyond pagination page'):
      page_end = self.paginate_start + self.items_per_page
      for shard_key, shard_iids in self.filtered_iids.items():
        self.filtered_iids[shard_key] = shard_iids[:page_end]

  def MergeAndSortIssues(self):
    """Combine the per-shard results into one sorted list of issues.

    Sets self.allowed_iids to the IIDs kept from all shards and
    self.allowed_results to the corresponding issues in sorted order.
    """
    with self.profiler.Phase('selecting issues to merge and sort'):
      self._NarrowFilteredIIDs()
      self.allowed_iids = [
          iid
          for shard_iids in self.filtered_iids.values()
          for iid in shard_iids]

    with self.profiler.Phase('getting allowed results'):
      fetched_issues = self.services.issue.GetIssues(
          self.cnxn, self.allowed_iids)
      self.allowed_results = fetched_issues

    # Each backend only sorted its own shard, so the concatenated list is
    # not in order yet.  Sorting may need user info, so look that up first.
    self._LookupNeededUsers(self.allowed_results)
    with self.profiler.Phase('merging and sorting issues'):
      sorted_issues = _SortIssues(
          self.allowed_results, self.harmonized_config, self.users_by_id,
          self.group_by_spec, self.sort_spec)
      self.allowed_results = sorted_issues

  def _NarrowFilteredIIDs(self):
    """Combine filtered shards into a range of IIDs for issues to sort.

    The naive way is to concatenate shard_iids[:start + num] for all
    shards then select [start:start + num].  We do better by sampling
    issues and then determining which of those samples are known to
    come before start or after start+num.  We then trim off all those IIDs
    and sort a smaller range of IIDs that might actually be displayed.
    See the design doc at go/monorail-sorting.

    This method modifies self.filtered_iids and self.num_skipped_at_start.
    It returns early, changing nothing, when the total number of filtered
    IIDs is less than four pages' worth.
    """
    # Sample issues and skip those that are known to come before start.
    # See the "Sorting in Monorail" design doc.

    # If the result set is small, don't bother optimizing it.
    orig_length = _TotalLength(self.filtered_iids)
    if orig_length < self.items_per_page * 4:
      return

    # 1. Get sample issues in each shard and sort them all together.
    # "last" is the position just past the final issue on the current page.
    last = self.paginate_start + self.items_per_page

    samples_by_shard, sample_iids_to_shard = self._FetchAllSamples(
        self.filtered_iids)
    sample_issues = []
    for issue_dict in samples_by_shard.values():
      sample_issues.extend(list(issue_dict.values()))

    # Sorting may need user info for the sampled issues.
    self._LookupNeededUsers(sample_issues)
    sample_issues = _SortIssues(
        sample_issues, self.harmonized_config, self.users_by_id,
        self.group_by_spec, self.sort_spec)
    # Pair each sample, in overall sorted order, with the shard it is in.
    sample_iid_tuples = [
        (issue.issue_id, sample_iids_to_shard[issue.issue_id])
        for issue in sample_issues]

    # 2. Trim off some IIDs that are sure to be positioned after last.
    num_trimmed_end = _TrimEndShardedIIDs(
        self.filtered_iids, sample_iid_tuples, last)
    logging.info('Trimmed %r issues from the end of shards', num_trimmed_end)

    # 3. Trim off some IIDs that are sure to be positioned before start.
    # This reuses the end-trimming helper: with every shard and the sample
    # list reversed, the start of the results becomes the end.
    keep = _TotalLength(self.filtered_iids) - self.paginate_start
    # Reverse the sharded lists.
    _ReverseShards(self.filtered_iids)
    sample_iid_tuples.reverse()
    self.num_skipped_at_start = _TrimEndShardedIIDs(
        self.filtered_iids, sample_iid_tuples, keep)
    logging.info('Trimmed %r issues from the start of shards',
                 self.num_skipped_at_start)
    # Reverse sharded lists again to get back into forward order.
    _ReverseShards(self.filtered_iids)

  def DetermineIssuePosition(self, issue):
    """Calculate info needed to show the issue flipper.

    Args:
      issue: The issue currently being viewed.

    Returns:
      A 3-tuple (prev_iid, index, next_iid) were prev_iid is the
      IID of the previous issue in the total ordering (or None),
      index is the index that the current issue has in the total
      ordering, and next_iid is the next issue (or None).  If the current
      issue is not in the list of results at all, returns None, None, None.
    """
    # 1. If the current issue is not in the results at all, then exit.
    if not any(issue.issue_id in filtered_shard_iids
               for filtered_shard_iids in self.filtered_iids.values()):
      return None, None, None

    # 2. Choose and retrieve sample issues in each shard.
    samples_by_shard, _ = self._FetchAllSamples(self.filtered_iids)

    # 3. Build up partial results for each shard.
    preceeding_counts = {}  # dict {shard_key: num_issues_preceeding_current}
    prev_candidates, next_candidates = [], []
    for shard_key in self.filtered_iids:
      prev_candidate, index_in_shard, next_candidate = (
          self._DetermineIssuePositionInShard(
              shard_key, issue, samples_by_shard[shard_key]))
      preceeding_counts[shard_key] = index_in_shard
      if prev_candidate:
        prev_candidates.append(prev_candidate)
      if next_candidate:
        next_candidates.append(next_candidate)

    # 4. Combine the results.
    index = sum(preceeding_counts.values())
    prev_candidates = _SortIssues(
        prev_candidates, self.harmonized_config, self.users_by_id,
        self.group_by_spec, self.sort_spec)
    prev_iid = prev_candidates[-1].issue_id if prev_candidates else None
    next_candidates = _SortIssues(
        next_candidates, self.harmonized_config, self.users_by_id,
        self.group_by_spec, self.sort_spec)
    next_iid = next_candidates[0].issue_id if next_candidates else None

    return prev_iid, index, next_iid

  def _DetermineIssuePositionInShard(self, shard_key, issue, sample_dict):
    """Determine where the given issue would fit into results from a shard.

    Args:
      shard_key: key into self.filtered_iids for the shard to examine.
      issue: the issue whose position is wanted; it does not need to be one
          of the results in this shard.
      sample_dict: dict {iid: issue} of sample issues from this shard that
          have already been retrieved.

    Returns:
      A 3-tuple (prev_candidate, index, next_candidate) where prev_candidate
      and next_candidate are the issues that sort immediately before and
      after the given issue among this shard's results (or None), and index
      is the position that the given issue would occupy in this shard.
    """
    # See the design doc for details.  Basically, it first surveys the results
    # to bound a range where the given issue would belong, then it fetches the
    # issues in that range and sorts them.

    filtered_shard_iids = self.filtered_iids[shard_key]

    # 1. Select a sample of issues, leveraging ones we have in RAM already.
    # The given issue is added so that it can be located among the samples.
    issues_on_hand = list(sample_dict.values())
    if issue.issue_id not in sample_dict:
      issues_on_hand.append(issue)

    self._LookupNeededUsers(issues_on_hand)
    sorted_on_hand = _SortIssues(
        issues_on_hand, self.harmonized_config, self.users_by_id,
        self.group_by_spec, self.sort_spec)
    sorted_on_hand_iids = [soh.issue_id for soh in sorted_on_hand]
    index_in_on_hand = sorted_on_hand_iids.index(issue.issue_id)

    # 2. Bound the gap around where issue belongs.
    # The gap starts just after the sample that sorts right before the issue,
    # or at the start of the shard when no sample sorts before it.
    if index_in_on_hand == 0:
      fetch_start = 0
    else:
      prev_on_hand_iid = sorted_on_hand_iids[index_in_on_hand - 1]
      fetch_start = filtered_shard_iids.index(prev_on_hand_iid) + 1

    # The gap ends at the sample that sorts right after the issue, or at the
    # end of the shard when no sample sorts after it.
    if index_in_on_hand == len(sorted_on_hand) - 1:
      fetch_end = len(filtered_shard_iids)
    else:
      next_on_hand_iid = sorted_on_hand_iids[index_in_on_hand + 1]
      fetch_end = filtered_shard_iids.index(next_on_hand_iid)

    # 3. Retrieve all the issues in that gap to get an exact answer.
    fetched_issues = self.services.issue.GetIssues(
        self.cnxn, filtered_shard_iids[fetch_start:fetch_end])
    # Make sure the given issue is among the issues that get sorted.
    if issue.issue_id not in filtered_shard_iids[fetch_start:fetch_end]:
      fetched_issues.append(issue)
    self._LookupNeededUsers(fetched_issues)
    sorted_fetched = _SortIssues(
        fetched_issues, self.harmonized_config, self.users_by_id,
        self.group_by_spec, self.sort_spec)
    sorted_fetched_iids = [sf.issue_id for sf in sorted_fetched]
    index_in_fetched = sorted_fetched_iids.index(issue.issue_id)

    # 4. Find the issues that come immediately before and after the place where
    # the given issue would belong in this shard.
    # Prefer a neighbor from within the gap; fall back to the neighboring
    # sample when the issue is at the edge of the gap.
    if index_in_fetched > 0:
      prev_candidate = sorted_fetched[index_in_fetched - 1]
    elif index_in_on_hand > 0:
      prev_candidate = sorted_on_hand[index_in_on_hand - 1]
    else:
      prev_candidate = None

    if index_in_fetched < len(sorted_fetched) - 1:
      next_candidate = sorted_fetched[index_in_fetched + 1]
    elif index_in_on_hand < len(sorted_on_hand) - 1:
      next_candidate = sorted_on_hand[index_in_on_hand + 1]
    else:
      next_candidate = None

    # Everything before the gap precedes the issue, as does everything that
    # sorted ahead of it within the gap.
    return prev_candidate, fetch_start + index_in_fetched, next_candidate

  def _FetchAllSamples(self, filtered_iids):
    """Return a dict {shard_key: {iid: sample_issue}}."""
    samples_by_shard = {}  # {shard_key: {iid: sample_issue}}
    sample_iids_to_shard = {}  # {iid: shard_key}
    all_needed_iids = []  # List of iids to retrieve.

    for shard_key in filtered_iids:
      on_hand_issues, shard_needed_iids = self._ChooseSampleIssues(
          filtered_iids[shard_key])
      samples_by_shard[shard_key] = on_hand_issues
      for iid in on_hand_issues:
        sample_iids_to_shard[iid] = shard_key
      for iid in shard_needed_iids:
        sample_iids_to_shard[iid] = shard_key
      all_needed_iids.extend(shard_needed_iids)

    retrieved_samples, _misses = self.services.issue.GetIssuesDict(
        self.cnxn, all_needed_iids)
    for retrieved_iid, retrieved_issue in retrieved_samples.items():
      retr_shard_key = sample_iids_to_shard[retrieved_iid]
      samples_by_shard[retr_shard_key][retrieved_iid] = retrieved_issue

    return samples_by_shard, sample_iids_to_shard

  def _ChooseSampleIssues(self, issue_ids):
    """Pick one sample issue per chunk of the list, preferring RAM cache.

    Args:
      issue_ids: A list of issue IDs that comprise the results in a shard.

    Returns:
      A pair (on_hand_issues, needed_iids) where on_hand_issues is
      an issue dict {iid: issue} of issues already in RAM, and
      needed_iids is a list of iids of issues that need to be retrieved.
    """
    num_iids = len(issue_ids)
    # Aim for PREFERRED_NUM_CHUNKS chunks, within the allowed size limits.
    chunk_size = int(num_iids // PREFERRED_NUM_CHUNKS)
    chunk_size = min(MAX_SAMPLE_CHUNK_SIZE, chunk_size)
    chunk_size = max(MIN_SAMPLE_CHUNK_SIZE, chunk_size)

    on_hand_issues = {}
    needed_iids = []
    # No sample is taken from the first chunk.
    for chunk_start in range(chunk_size, num_iids, chunk_size):
      chunk_end = min(chunk_start + chunk_size, num_iids)
      cached_issue = self.services.issue.GetAnyOnHandIssue(
          issue_ids, start=chunk_start, end=chunk_end)
      if not cached_issue:
        # Nothing in this chunk is in RAM, so sample the chunk's first issue.
        needed_iids.append(issue_ids[chunk_start])
        continue
      on_hand_issues[cached_issue.issue_id] = cached_issue

    return on_hand_issues, needed_iids

  def _LookupNeededUsers(self, issues):
    """Add views for users in the given issues to self.users_by_id.

    Users that are already in self.users_by_id are not looked up again.

    Args:
      issues: list of issues whose users may be needed for sorting.
    """
    with self.profiler.Phase('lookup of owner, reporter, and cc'):
      already_known_ids = list(self.users_by_id.keys())
      new_views_by_id = tracker_helpers.MakeViewsForUsersInIssues(
          self.cnxn, issues, self.services.user, omit_ids=already_known_ids)
      self.users_by_id.update(new_views_by_id)

  def Paginate(self):
    """Build the pagination object and pick out the visible issues.

    Sets self.pagination and self.visible_results, and looks up any users
    in the visible issues that have not been looked up yet.
    """
    # The search limit counts as reached if any single shard reached it.
    limit_reached = False
    for shard_limit_reached in self.search_limit_reached.values():
      limit_reached = limit_reached | shard_limit_reached

    # The issues were already retrieved, so this only slices out one page.
    pagination = paginate.ArtifactPagination(
        self.allowed_results,
        self.items_per_page,
        self.paginate_start,
        self.project_name,
        urls.ISSUE_LIST,
        total_count=self.total_count,
        limit_reached=limit_reached,
        skipped=self.num_skipped_at_start)
    self.pagination = pagination
    self.visible_results = pagination.visible_results

    # If we were not forced to look up visible users already, do it now.
    self._LookupNeededUsers(self.visible_results)

  def __repr__(self):
    """Return a string that shows the internal state of this pipeline."""
    if self.allowed_iids:
      shown_allowed_iids = self.allowed_iids[:200]
    else:
      shown_allowed_iids = self.allowed_iids

    if self.allowed_results:
      shown_allowed_results = self.allowed_results[:200]
    else:
      shown_allowed_results = self.allowed_results

    parts = [
        'allowed_iids: %r' % shown_allowed_iids,
        'allowed_results: %r' % shown_allowed_results,
        'len(visible_results): %r' % (
            self.visible_results and len(self.visible_results))]
    return '%s(%s)' % (self.__class__.__name__, '\n'.join(parts))


def _CheckQuery(
    cnxn, services, query, harmonized_config, project_ids,
    member_of_all_projects, warnings=None):
  """Parse and preprocess the query to see whether it is valid.

  Args:
    cnxn: monorail connection to the database.
    services: connections to backends.
    query: the user's query string.
    harmonized_config: combined config for all projects being searched.
    project_ids: list of IDs of the projects being searched.
    member_of_all_projects: passed to the AST preprocessing step as
        is_member.
    warnings: optional list to accumulate warning messages.

  Returns:
    The message of the first error found, or None if the query is OK.
  """
  query_problems = (query2ast.InvalidQueryError, ast2ast.MalformedQuery)
  try:
    parsed_ast = query2ast.ParseUserQuery(
        query, '', query2ast.BUILTIN_ISSUE_FIELDS, harmonized_config,
        warnings=warnings)
    # Only failures matter here, so the preprocessed AST is thrown away.
    ast2ast.PreprocessAST(
        cnxn, parsed_ast, project_ids, services, harmonized_config,
        is_member=member_of_all_projects)
  except query_problems as e:
    return str(e)

  return None


def _MakeBackendCallback(func, *args):
  # type: (Callable[[*Any], Any], *Any) -> Callable[[*Any], Any]
  """Helper to store a particular function and argument set into a callback.

  Args:
    func: Function to callback.
    *args: The arguments to pass into the function.

  Returns:
    Callback function based on specified arguments.
  """
  return lambda: func(*args)


def _StartBackendSearch(
    cnxn, query_project_names, query_project_ids, harmonized_config,
    unfiltered_iids_dict, search_limit_reached_dict, nonviewable_iids,
    error_responses, services, me_user_ids, logged_in_user_id, new_url_num,
    subqueries, can, group_by_spec, sort_spec, warnings, use_cached_searches):
  # type: (MonorailConnection, Sequence[str], Sequence[int],
  #     mrproto.tracker_pb2.ProjectIssueConfig,
  #     Mapping[Tuple(int, str), Sequence[int]],
  #     Mapping[Tuple(int, str), Sequence[bool]],
  #     Mapping[Tuple(int, str), Collection[int]], Sequence[Tuple(int, str)],
  #     Services, Sequence[int], int, int, Sequence[str], int, str, str,
  #     Sequence[Tuple(str, Sequence[str])], bool) ->
  #     Sequence[Tuple(int, Tuple(int, str),
  #         google.appengine.api.apiproxy_stub_map.UserRPC)]
  """Request that our backends search and return a list of matching issue IDs.

  Args:
    cnxn: monorail connection to the database.
    query_project_names: set of project names to search.
    query_project_ids: list of project IDs to search.
    harmonized_config: combined ProjectIssueConfig for all projects being
        searched.
    unfiltered_iids_dict: dict {shard_key: [iid, ...]} of unfiltered search
        results to accumulate into.  They need to be later filtered by
        permissions and merged into filtered_iids_dict.
    search_limit_reached_dict: dict {shard_key: [bool, ...]} to determine if
        the search limit of any shard was reached.
    nonviewable_iids: dict {shard_id: set(iid)} of restricted issues in the
        projects being searched that the signed in user cannot view.
    error_responses: shard_iids of shards that encountered errors.
    services: connections to backends.
    me_user_ids: Empty list when no user is logged in, or user ID of the logged
        in user when doing an interactive search, or the viewed user ID when
        viewing someone else's dashboard, or the subscribing user's ID when
        evaluating subscriptions.  And, any linked accounts.
    logged_in_user_id: user_id of the logged in user, 0 otherwise
    new_url_num: the number of issues for BackendSearchPipeline to query.
        Computed based on pagination offset + number of items per page.
    subqueries: split up list of query string segments.
    can: "canned query" number to scope the user's search.
    group_by_spec: string that lists the grouping order.
    sort_spec: string that lists the sort order.
    warnings: list to accumulate warning messages.
    use_cached_searches: Bool for whether to use cached searches.

  Returns:
    A list of rpc_tuples that can be passed to _FinishBackendSearch to wait
    on any remaining backend calls.

  SIDE-EFFECTS:
    Any data found in memcache is immediately put into unfiltered_iids_dict.
    As the backends finish their work, _HandleBackendSearchResponse will update
    unfiltered_iids_dict for those shards.

    Any warnings produced throughout this process will be added to the list
    warnings.
  """
  rpc_tuples = []
  needed_shard_keys = set()
  for subquery in subqueries:
    # Keep the warnings for this subquery in their own variable.  Binding
    # them to the name "warnings" would hide the caller's list, so that the
    # new warnings would never reach the caller.
    subquery, subquery_warnings = searchpipeline.ReplaceKeywordsWithUserIDs(
        me_user_ids, subquery)
    warnings.extend(subquery_warnings)
    for shard_id in range(settings.num_logical_shards):
      needed_shard_keys.add((shard_id, subquery))

  # 1. Get whatever we can from memcache.  Cache hits are only kept if they are
  # not already expired.
  project_shard_timestamps = _GetProjectTimestamps(
      query_project_ids, needed_shard_keys)

  if use_cached_searches:
    cached_unfiltered_iids_dict, cached_search_limit_reached_dict = (
        _GetCachedSearchResults(
            cnxn, query_project_ids, needed_shard_keys,
            harmonized_config, project_shard_timestamps, services, me_user_ids,
            can, group_by_spec, sort_spec, warnings))
    unfiltered_iids_dict.update(cached_unfiltered_iids_dict)
    search_limit_reached_dict.update(cached_search_limit_reached_dict)
  # Shards that already have results do not need a backend call.
  for cache_hit_shard_key in unfiltered_iids_dict:
    needed_shard_keys.remove(cache_hit_shard_key)

  # 2. Each kept cache hit will have unfiltered IIDs, so we filter them by
  # removing non-viewable IDs.
  _GetNonviewableIIDs(
    query_project_ids, logged_in_user_id,
    set(range(settings.num_logical_shards)),
    rpc_tuples, nonviewable_iids, project_shard_timestamps,
    services.cache_manager.processed_invalidations_up_to,
    use_cached_searches)

  # 3. Hit backends for any shards that are still needed.  When these results
  # come back, they are also put into unfiltered_iids_dict.
  for shard_key in needed_shard_keys:
    rpc = _StartBackendSearchCall(
        query_project_names,
        shard_key,
        services.cache_manager.processed_invalidations_up_to,
        me_user_ids,
        logged_in_user_id,
        new_url_num,
        can=can,
        sort_spec=sort_spec,
        group_by_spec=group_by_spec)
    # The start time is recorded alongside the RPC in its tuple.
    rpc_tuple = (time.time(), shard_key, rpc)
    rpc.callback = _MakeBackendCallback(
        _HandleBackendSearchResponse, query_project_names, rpc_tuple,
        rpc_tuples, settings.backend_retries, unfiltered_iids_dict,
        search_limit_reached_dict,
        services.cache_manager.processed_invalidations_up_to, error_responses,
        me_user_ids, logged_in_user_id, new_url_num, can, sort_spec,
        group_by_spec)
    rpc_tuples.append(rpc_tuple)

  return rpc_tuples


def _FinishBackendSearch(rpc_tuples):
  """Wait for all backend calls to complete, including any retries."""
  while rpc_tuples:
    active_rpcs = [rpc for (_time, _shard_key, rpc) in rpc_tuples]
    # Wait for any active RPC to complete.  It's callback function will
    # automatically be called.
    finished_rpc = real_wait_any(active_rpcs)
    # Figure out which rpc_tuple finished and remove it from our list.
    for rpc_tuple in rpc_tuples:
      _time, _shard_key, rpc = rpc_tuple
      if rpc == finished_rpc:
        rpc_tuples.remove(rpc_tuple)
        break
    else:
      raise ValueError('We somehow finished an RPC that is not in rpc_tuples')


def real_wait_any(active_rpcs):
  """Return a finished RPC without blocking on any single RPC.

  wait_any() checks for any finished RPCs, and returns one if found.
  If no RPC is finished, it simply blocks on the last RPC in the list.
  This is not the desired behavior because we are not able to detect
  FAST-FAIL RPC results and retry them if wait_any() is blocked on a
  request that is taking a long time to do actual work.

  Instead, we do the same check, without blocking on any individual RPC,
  and sleep briefly between checks.

  Args:
    active_rpcs: list of RPCs that are still in flight.

  Returns:
    One of the RPCs that has finished.
  """
  use_plain_wait_any = six.PY3 or settings.local_mode
  if not use_plain_wait_any:
    while True:
      finished_rpc, _ = apiproxy_stub_map.UserRPC._UserRPC__check_one(
          active_rpcs)
      if finished_rpc:
        return finished_rpc
      time.sleep(DELAY_BETWEEN_RPC_COMPLETION_POLLS)

  # The development server has very different code for RPCs than the
  # code used in the hosted environment.
  return apiproxy_stub_map.UserRPC.wait_any(active_rpcs)

def _GetProjectTimestamps(query_project_ids, needed_shard_keys):
  """Look up the cached modified_ts of every project-shard being searched.

  Args:
    query_project_ids: list of IDs of the projects being searched; when it
        is empty, the timestamps kept for all projects are used instead.
    needed_shard_keys: collection of (shard_id, subquery) shard keys.

  Returns:
    A dict {(project_id, shard_id): timestamp} with an entry for each
    timestamp found in memcache.  The project_id is the string 'all' for
    the timestamps kept for all projects.
  """
  if query_project_ids:
    keys = [
        '%d;%d' % (pid, sid)
        for pid in query_project_ids
        for sid, _ in needed_shard_keys]
  else:
    keys = ['all;%d' % sid for sid, _ in needed_shard_keys]

  cached_timestamps = memcache.get_multi(
      keys=keys, namespace=settings.memcache_namespace)

  # Turn each "pid;sid" memcache key back into a (pid, sid) tuple.
  project_shard_timestamps = {}
  for memcache_key, timestamp in cached_timestamps.items():
    pid_str, sid_str = memcache_key.split(';')
    project_part = 'all' if pid_str == 'all' else int(pid_str)
    project_shard_timestamps[project_part, int(sid_str)] = timestamp

  return project_shard_timestamps


def _GetNonviewableIIDs(
    query_project_ids, logged_in_user_id, needed_shard_ids, rpc_tuples,
    nonviewable_iids, project_shard_timestamps, invalidation_timestep,
    use_cached_searches):
  """Build a set of at-risk IIDs, and accumulate RPCs to get uncached ones.

  Args:
    query_project_ids: list of IDs of the projects being searched; when it
        is empty, the entries kept for all projects are used instead.
    logged_in_user_id: user_id of the logged in user, 0 otherwise.
    needed_shard_ids: collection of IDs of the shards to check.
    rpc_tuples: list that any backend calls that get started are added to.
    nonviewable_iids: dict {shard_id: set(iid)} to accumulate into.
    project_shard_timestamps: dict {(project_id, shard_id): timestamp} used
        to tell whether a cached entry is stale.
    invalidation_timestep: passed along to any backend calls.
    use_cached_searches: Bool for whether to consult memcache at all.
  """
  if query_project_ids:
    keys = [
        '%d;%d;%d' % (pid, logged_in_user_id, sid)
        for pid in query_project_ids
        for sid in needed_shard_ids]
  else:
    keys = [
        'all;%d;%d' % (logged_in_user_id, sid) for sid in needed_shard_ids]

  cached_dict = {}
  if use_cached_searches:
    cached_dict = memcache.get_multi(
        keys, key_prefix='nonviewable:', namespace=settings.memcache_namespace)

  # A project ID of None stands for a search across all projects.
  pids_to_check = query_project_ids if query_project_ids else [None]
  for sid in needed_shard_ids:
    for pid in pids_to_check:
      _AccumulateNonviewableIIDs(
          pid, logged_in_user_id, sid, cached_dict, nonviewable_iids,
          project_shard_timestamps, rpc_tuples, invalidation_timestep)


def _AccumulateNonviewableIIDs(
    pid, logged_in_user_id, sid, cached_dict, nonviewable_iids,
    project_shard_timestamps, rpc_tuples, invalidation_timestep):
  """Use one of the retrieved cache entries or call a backend if needed."""
  if pid is None:
    key = 'all;%d;%d' % (logged_in_user_id, sid)
  else:
    key = '%d;%d;%d' % (pid, logged_in_user_id, sid)

  if key in cached_dict:
    issue_ids, cached_ts = cached_dict.get(key)
    modified_ts = project_shard_timestamps.get((pid, sid))
    if modified_ts is None or modified_ts > cached_ts:
      logging.info('nonviewable too stale on (project %r, shard %r)',
                   pid, sid)
    else:
      logging.info('adding %d nonviewable issue_ids', len(issue_ids))
      nonviewable_iids[sid] = set(issue_ids)

  if sid not in nonviewable_iids:
    logging.info('starting backend call for nonviewable iids %r', key)
    rpc = _StartBackendNonviewableCall(
      pid, logged_in_user_id, sid, invalidation_timestep)
    rpc_tuple = (time.time(), sid, rpc)
    rpc.callback = _MakeBackendCallback(
        _HandleBackendNonviewableResponse, pid, logged_in_user_id, sid,
        rpc_tuple, rpc_tuples, settings.backend_retries, nonviewable_iids,
        invalidation_timestep)
    rpc_tuples.append(rpc_tuple)


def _GetCachedSearchResults(
    cnxn, query_project_ids, needed_shard_keys, harmonized_config,
    project_shard_timestamps, services, me_user_ids, can, group_by_spec,
    sort_spec, warnings):
  """Return a dict of cached search results that are not already stale.

  If it were not for cross-project search, we would simply cache when we do a
  search and then invalidate when an issue is modified.  But, with
  cross-project search we don't know all the memcache entries that would
  need to be invalidated.  So, instead, we write the search result cache
  entries and then an initial modified_ts value for each project if it was
  not already there. And, when we update an issue we write a new
  modified_ts entry, which implicitly invalidates all search result
  cache entries that were written earlier because they are now stale.  When
  reading from the cache, we ignore any query project with modified_ts
  after its search result cache timestamp, because it is stale.

  Args:
    cnxn: monorail connection to the database.
    query_project_ids: list of project ID numbers for all projects being
        searched.
    needed_shard_keys: set of (shard_id, subquery) shard keys that need to
        be checked.
    harmonized_config: ProjectIssueConfig with combined information for all
        projects involved in this search.
    project_shard_timestamps: a dict {(project_id, shard_id): timestamp, ...}
        that tells when each shard was last invalidated.  Cross-project
        searches use 'all' in place of the project_id.
    services: connections to backends.
    me_user_ids: Empty list when no user is logged in, or user ID of the logged
        in user when doing an interactive search, or the viewed user ID when
        viewing someone else's dashboard, or the subscribing user's ID when
        evaluating subscriptions.  And, any linked accounts.
    can: "canned query" number to scope the user's search.
    group_by_spec: string that lists the grouping order.
    sort_spec: string that lists the sort order.
    warnings: list to accumulate warning messages.  It is extended in place
        with any warnings produced while replacing keywords in the canned
        query.

  Returns:
    Tuple consisting of:
      A dictionary {shard_key: [issue_id, ...], ...} of unfiltered search result
      issue IDs. Only shard_keys found in memcache and not stale will be in
      that dictionary.  The result issue IDs must be permission checked before
      they can be considered to be part of the user's result set.
      A dictionary {shard_key: bool, ...}. The boolean is set to True if
      the search results limit of the shard is hit.
  """
  projects_str = ','.join(str(pid) for pid in sorted(query_project_ids))
  projects_str = projects_str or 'all'
  canned_query = savedqueries_helpers.SavedQueryIDToCond(
      cnxn, services.features, can)
  # Keep the helper's warnings under their own name.  Unpacking them into
  # `warnings` would rebind the parameter, and the caller's list would never
  # receive them.
  canned_query, keyword_warnings = searchpipeline.ReplaceKeywordsWithUserIDs(
      me_user_ids, canned_query)
  warnings.extend(keyword_warnings)

  sd = sorting.ComputeSortDirectives(
      harmonized_config, group_by_spec, sort_spec)
  sd_str = ' '.join(sd)
  memcache_key_prefix = '%s;%s' % (projects_str, canned_query)
  limit_reached_key_prefix = '%s;%s' % (projects_str, canned_query)

  cached_dict = memcache.get_multi(
      ['%s;%s;%s;%d' % (memcache_key_prefix, subquery, sd_str, sid)
       for sid, subquery in needed_shard_keys],
      namespace=settings.memcache_namespace)
  cached_search_limit_reached_dict = memcache.get_multi(
      ['%s;%s;%s;search_limit_reached;%d' % (
          limit_reached_key_prefix, subquery, sd_str, sid)
       for sid, subquery in needed_shard_keys],
      namespace=settings.memcache_namespace)

  unfiltered_dict = {}
  search_limit_reached_dict = {}
  for shard_key in needed_shard_keys:
    shard_id, subquery = shard_key
    memcache_key = '%s;%s;%s;%d' % (
        memcache_key_prefix, subquery, sd_str, shard_id)
    limit_reached_key = '%s;%s;%s;search_limit_reached;%d' % (
        limit_reached_key_prefix, subquery, sd_str, shard_id)
    if memcache_key not in cached_dict:
      logging.info('memcache miss on shard %r', shard_key)
      continue

    cached_iids, cached_ts = cached_dict[memcache_key]
    # A missing limit-reached entry is treated as "limit not reached".
    limit_reached_entry = cached_search_limit_reached_dict.get(
        limit_reached_key)
    if limit_reached_entry:
      search_limit_reached, _ = limit_reached_entry
    else:
      search_limit_reached = False

    # The entry is stale if any searched project (or 'all' for cross-project
    # search) has no timestamp or was modified after the entry was cached.
    stale = False
    if query_project_ids:
      for project_id in query_project_ids:
        modified_ts = project_shard_timestamps.get((project_id, shard_id))
        if modified_ts is None or modified_ts > cached_ts:
          stale = True
          logging.info('memcache too stale on shard %r because of %r',
                       shard_id, project_id)
          break
    else:
      modified_ts = project_shard_timestamps.get(('all', shard_id))
      if modified_ts is None or modified_ts > cached_ts:
        stale = True
        logging.info('memcache too stale on shard %r because of all',
                     shard_id)

    if not stale:
      unfiltered_dict[shard_key] = cached_iids
      search_limit_reached_dict[shard_key] = search_limit_reached

  return unfiltered_dict, search_limit_reached_dict


def _MakeBackendRequestHeaders(failfast):
  headers = {
    # This is needed to allow frontends to talk to backends without going
    # through a login screen on googleplex.com.
    # http://wiki/Main/PrometheusInternal#Internal_Applications_and_APIs
    'X-URLFetch-Service-Id': 'GOOGLEPLEX',
    }
  if failfast:
    headers['X-AppEngine-FailFast'] = 'Yes'
  return headers


def _StartBackendSearchCall(
    query_project_names,
    shard_key,
    invalidation_timestep,
    me_user_ids,
    logged_in_user_id,
    new_url_num,
    can=None,
    sort_spec=None,
    group_by_spec=None,
    deadline=None,
    failfast=True):
  # type: (Sequence[str], Tuple(int, str), int, Sequence[int], int,
  #     int, str, str, int, bool) ->
  #     google.appengine.api.apiproxy_stub_map.UserRPC
  """Start an asynchronous backend request that searches one DB shard.

  Args:
    query_project_names: List of project names queried.
    shard_key: (shard_id, subquery) tuple specifying which DB shard to query
        and the subquery to run against it.
    invalidation_timestep: int timestep to use to keep cached items fresh.
    me_user_ids: Empty list when no user is logged in, or user ID of the logged
        in user when doing an interactive search, or the viewed user ID when
        viewing someone else's dashboard, or the subscribing user's ID when
        evaluating subscriptions.  And, any linked accounts.
    logged_in_user_id: Id of the logged in user.
    new_url_num: the number of issues for BackendSearchPipeline to query.
        Computed based on pagination offset + number of items per page.
    can: Id of the canned query to use.
    sort_spec: Str specifying how issues should be sorted.
    group_by_spec: Str specifying how issues should be grouped.
    deadline: Max time for the RPC to take before failing.  When falsy,
        settings.backend_deadline is used instead.
    failfast: Whether to set the X-AppEngine-FailFast request header.

  Returns:
    UserRPC for the created RPC call.
  """
  shard_id, subquery = shard_key
  protocol = 'http' if settings.local_mode else 'https'
  backend_host = modules.get_hostname(module='default')
  search_path = framework_helpers.FormatURL(
      [],
      urls.BACKEND_SEARCH,
      projects=','.join(query_project_names),
      q=subquery,
      start=0,
      num=new_url_num,
      can=can,
      sort=sort_spec,
      groupby=group_by_spec,
      logged_in_user_id=logged_in_user_id,
      me_user_ids=','.join(str(uid) for uid in me_user_ids),
      shard_id=shard_id,
      invalidation_timestep=invalidation_timestep)
  url = '%s://%s%s' % (protocol, backend_host, search_path)
  logging.info('\n\nCalling backend: %s', url)

  rpc = urlfetch.create_rpc(
      deadline=deadline or settings.backend_deadline)
  # follow_redirects=False is needed to avoid a login screen on googleplex.
  urlfetch.make_fetch_call(
      rpc, url, follow_redirects=False,
      headers=_MakeBackendRequestHeaders(failfast))
  return rpc


def _StartBackendNonviewableCall(
    project_id, logged_in_user_id, shard_id, invalidation_timestep,
    deadline=None, failfast=True):
  """Start an asynchronous backend request for one shard's nonviewable IIDs.

  Args:
    project_id: project ID to send; a falsy value is sent as ''.
    logged_in_user_id: user ID to send; a falsy value is sent as ''.
    shard_id: ID of the shard to query.
    invalidation_timestep: timestep value passed through in the URL.
    deadline: Max time for the RPC to take before failing.  When falsy,
        settings.backend_deadline is used instead.
    failfast: Whether to set the X-AppEngine-FailFast request header.

  Returns:
    The RPC object on which the fetch call was started.
  """
  protocol = 'http' if settings.local_mode else 'https'
  backend_host = modules.get_hostname(module='default')
  nonviewable_path = framework_helpers.FormatURL(
      None, urls.BACKEND_NONVIEWABLE,
      project_id=project_id or '',
      logged_in_user_id=logged_in_user_id or '',
      shard_id=shard_id,
      invalidation_timestep=invalidation_timestep)
  url = '%s://%s%s' % (protocol, backend_host, nonviewable_path)
  logging.info('Calling backend nonviewable: %s', url)

  rpc = urlfetch.create_rpc(deadline=deadline or settings.backend_deadline)
  # follow_redirects=False is needed to avoid a login screen on googleplex.
  urlfetch.make_fetch_call(
      rpc, url, follow_redirects=False,
      headers=_MakeBackendRequestHeaders(failfast))
  return rpc


def _HandleBackendSearchResponse(
    query_project_names, rpc_tuple, rpc_tuples, remaining_retries,
    unfiltered_iids, search_limit_reached, invalidation_timestep,
    error_responses, me_user_ids, logged_in_user_id, new_url_num, can,
    sort_spec, group_by_spec):
  # type: (Sequence[str], Tuple(int, Tuple(int, str),
  #         google.appengine.api.apiproxy_stub_map.UserRPC),
  #     Sequence[Tuple(int, Tuple(int, str),
  #         google.appengine.api.apiproxy_stub_map.UserRPC)],
  #     int, Mapping[Tuple(int, str), Sequence[int]],
  #     Mapping[Tuple(int, str), bool], int, Collection[Tuple(int, str)],
  #     Sequence[int], int, int, int, str, str) -> None
  #
  """Process one backend search response, retrying the call on failure.

  SIDE EFFECTS: This function edits many of the passed in parameters in place.
    On success, unfiltered_iids and search_limit_reached get an entry for the
    shard_key.  On a retry, a new RPC tuple is appended to rpc_tuples.  When
    the backend reports an error or we give up, the shard_key is added to
    error_responses.

  Args:
    query_project_names: List of projects to query.
    rpc_tuple: (start_time, shard_key, rpc) tuple for the call being handled.
    rpc_tuples: List of RPC tuples; any retry that is started is appended.
    remaining_retries: Number of times left to retry.
    unfiltered_iids: Dict of issue IDs, before they've been filtered by
      permissions, keyed by shard_key.
    search_limit_reached: Dict of whether the search limit for a particular
      shard has been hit, keyed by shard_key.
    invalidation_timestep: int timestep to use to keep cached items fresh.
    error_responses: Collection supporting add(); receives the shard_key of
      each shard whose response contained an error or that was given up on.
    me_user_ids: List of relevant user IDs. ie: the currently logged in user
      and linked account IDs if applicable.
    logged_in_user_id: Logged in user's ID.
    new_url_num: the number of issues for BackendSearchPipeline to query.
        Computed based on pagination offset + number of items per page.
    can: Canned query ID to use.
    sort_spec: str specifying how issues should be sorted.
    group_by_spec: str specifying how issues should be grouped.
  """
  started_at, shard_key, rpc = rpc_tuple
  elapsed_sec = time.time() - started_at

  try:
    backend_response = rpc.get_result()
    logging.info('call to backend took %d sec', elapsed_sec)
    # Skip the 5 characters "})]'\n" that are prepended to response.content.
    payload = backend_response.content[5:]
    logging.info('got json text: %r length %r',
                 payload[:framework_constants.LOGGING_MAX_LENGTH],
                 len(payload))
    if payload == b'':
      raise Exception('Fast fail')
    parsed = json.loads(payload)
    unfiltered_iids[shard_key] = parsed['unfiltered_iids']
    search_limit_reached[shard_key] = parsed['search_limit_reached']
    backend_error = parsed.get('error')
    if backend_error:
      # Don't raise an exception, just log, because these errors are more like
      # 400s than 500s, and shouldn't be retried.
      logging.error('Backend shard %r returned error "%r"' % (
          shard_key, backend_error))
      error_responses.add(shard_key)

  except Exception as e:
    if elapsed_sec > FAIL_FAST_LIMIT_SEC:  # Don't log fail-fast exceptions.
      logging.exception(e)

    if not remaining_retries:
      # Used all retries, so give up.
      logging.error('backend search retries exceeded')
      giving_up = True
    elif elapsed_sec >= settings.backend_deadline:
      # That backend shard is overloaded, so give up.
      logging.error('backend search on %r took too long', shard_key)
      giving_up = True
    else:
      giving_up = False

    if giving_up:
      error_responses.add(shard_key)
      return

    logging.error('backend call for shard %r failed, retrying', shard_key)
    next_rpc = _StartBackendSearchCall(
        query_project_names,
        shard_key,
        invalidation_timestep,
        me_user_ids,
        logged_in_user_id,
        new_url_num,
        can=can,
        sort_spec=sort_spec,
        group_by_spec=group_by_spec,
        failfast=remaining_retries > 2)
    next_rpc_tuple = (time.time(), shard_key, next_rpc)
    # The retry's callback re-enters this handler with one fewer retry left.
    next_rpc.callback = _MakeBackendCallback(
        _HandleBackendSearchResponse, query_project_names, next_rpc_tuple,
        rpc_tuples, remaining_retries - 1, unfiltered_iids,
        search_limit_reached, invalidation_timestep, error_responses,
        me_user_ids, logged_in_user_id, new_url_num, can, sort_spec,
        group_by_spec)
    rpc_tuples.append(next_rpc_tuple)


def _HandleBackendNonviewableResponse(
    project_id, logged_in_user_id, shard_id, rpc_tuple, rpc_tuples,
    remaining_retries, nonviewable_iids, invalidation_timestep):
  """Process one backend nonviewable response, retrying the call on failure.

  SIDE EFFECTS: On success, nonviewable_iids gets an entry for the shard.  On
    a retry, a new RPC tuple is appended to rpc_tuples.

  Args:
    project_id: project ID that was sent to the backend.
    logged_in_user_id: user ID that was sent to the backend.
    shard_id: shard ID; it is overwritten by the one stored in rpc_tuple.
    rpc_tuple: (start_time, shard_id, rpc) tuple for the call being handled.
    rpc_tuples: List of RPC tuples; any retry that is started is appended.
    remaining_retries: Number of times left to retry.
    nonviewable_iids: Dict {shard_id: set(issue_id)} to update.
    invalidation_timestep: timestep value to pass through on a retry.
  """
  started_at, shard_id, rpc = rpc_tuple
  elapsed_sec = time.time() - started_at

  try:
    backend_response = rpc.get_result()
    logging.info('call to backend nonviewable took %d sec', elapsed_sec)
    # Skip the 5 characters "})]'\n" that are prepended to response.content.
    payload = backend_response.content[5:]
    logging.info('got json text: %r length %r',
                 payload[:framework_constants.LOGGING_MAX_LENGTH],
                 len(payload))
    if payload == b'':
      raise Exception('Fast fail')
    parsed = json.loads(payload)
    nonviewable_iids[shard_id] = set(parsed['nonviewable'])

  except Exception as e:
    if elapsed_sec > FAIL_FAST_LIMIT_SEC:  # Don't log fail-fast exceptions.
      logging.exception(e)

    if not remaining_retries:
      logging.warning('Used all retries, so give up on shard %r', shard_id)
      return
    elif elapsed_sec >= settings.backend_deadline:
      # That backend shard is overloaded, so give up.
      logging.error('nonviewable call on %r took too long', shard_id)
      return

    logging.error(
      'backend nonviewable call for shard %r;%r;%r failed, retrying',
      project_id, logged_in_user_id, shard_id)
    next_rpc = _StartBackendNonviewableCall(
        project_id, logged_in_user_id, shard_id, invalidation_timestep,
        failfast=remaining_retries > 2)
    next_rpc_tuple = (time.time(), shard_id, next_rpc)
    # The retry's callback re-enters this handler with one fewer retry left.
    next_rpc.callback = _MakeBackendCallback(
        _HandleBackendNonviewableResponse, project_id, logged_in_user_id,
        shard_id, next_rpc_tuple, rpc_tuples, remaining_retries - 1,
        nonviewable_iids, invalidation_timestep)
    rpc_tuples.append(next_rpc_tuple)


def _TotalLength(sharded_iids):
  """Return the total length of all issue_iids lists."""
  return sum(len(issue_iids) for issue_iids in sharded_iids.values())


def _ReverseShards(sharded_iids):
  """Reverse each issue_iids list in place."""
  for shard_key in sharded_iids:
    sharded_iids[shard_key].reverse()


def _TrimEndShardedIIDs(sharded_iids, sample_iid_tuples, num_needed):
  """Drop trailing IIDs from shards while keeping at least num_needed items.

  Args:
    sharded_iids: dict {shard_key: issue_id_list} for search results.  This is
        modified in place: trimmed shards are rebound to shortened lists.
    sample_iid_tuples: list of (iid, shard_key) from a sorted list of sample
        issues.
    num_needed: int minimum total number of items to keep.  Some IIDs that are
        known to belong in positions > num_needed will be trimmed off.

  Returns:
    The total number of IIDs removed from the IID lists.
  """
  # 1. Locate each sample: (sample_iid, shard_key, position_in_shard).
  located_samples = _CalcSamplePositions(sharded_iids, sample_iid_tuples)

  # 2. Advance through the samples, tracking the latest sample position seen
  # in each shard.  The sum of those positions is a lower bound on how many
  # IIDs come before the current sample.  Stop once that reaches num_needed.
  latest_pos_by_shard = {}
  cutoff_index = None
  for index, (_iid, shard_key, pos) in enumerate(located_samples):
    latest_pos_by_shard[shard_key] = pos
    if sum(latest_pos_by_shard.values()) >= num_needed:
      cutoff_index = index
      break

  if cutoff_index is None:
    return 0  # We went through all samples and never reached num_needed.

  # 3. Every sample after the cutoff is excess.  Cut each shard at the first
  # excess sample found in it; later samples in that shard are ignored.
  removed_count = 0
  trimmed_shards = set()
  for _iid, shard_key, pos in located_samples[cutoff_index + 1:]:
    if shard_key in trimmed_shards:
      continue
    shard_iids = sharded_iids[shard_key]
    removed_count += len(shard_iids) - pos
    sharded_iids[shard_key] = shard_iids[:pos]
    trimmed_shards.add(shard_key)

  return removed_count


# TODO(jrobbins): Convert this to a python generator.
def _CalcSamplePositions(sharded_iids, sample_iids):
  """Return [(iid, shard_key, position_in_shard), ...] for each sample."""
  # We keep track of how far index() has scanned in each shard to avoid
  # starting over at position 0 when looking for the next sample in
  # the same shard.
  scan_positions = collections.defaultdict(lambda: 0)
  sample_positions = []
  for sample_iid, sample_shard_key in sample_iids:
    try:
      pos = sharded_iids.get(sample_shard_key, []).index(
          sample_iid, scan_positions[sample_shard_key])
      scan_positions[sample_shard_key] = pos
      sample_positions.append((sample_iid, sample_shard_key, pos))
    except ValueError:
      pass

  return sample_positions


def _SortIssues(issues, config, users_by_id, group_by_spec, sort_spec):
  """Sort the found issues by delegating to sorting.SortArtifacts.

  Args:
    issues: A list of issues to be sorted.
    config: A ProjectIssueConfig that could impact sort order.
    users_by_id: dictionary {user_id: user_view,...} for all users who
      participate in any issue in the entire list.
    group_by_spec: string that lists the grouping order
    sort_spec: string that lists the sort order

  Returns:
    The value returned by sorting.SortArtifacts for the given issues, config,
    group_by_spec, and sort_spec.
  """
  sortable_fields = tracker_helpers.SORTABLE_FIELDS
  postprocessors = tracker_helpers.SORTABLE_FIELDS_POSTPROCESSORS
  return sorting.SortArtifacts(
      issues, config, sortable_fields, postprocessors, group_by_spec,
      sort_spec, users_by_id=users_by_id)
