# Copyright 2018 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""
Helper functions for spam and component classification. These are mostly for
feature extraction, so that the serving code and training code both use the same
set of features.
"""
# TODO(crbug.com/monorail/7515): DELETE THIS FILE and all references.
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
import csv
import hashlib
import logging
import re
import sys
from six import text_type
from apiclient.discovery import build
from apiclient.errors import Error as ApiClientError
from oauth2client.client import GoogleCredentials
from oauth2client.client import Error as Oauth2ClientError
SPAM_COLUMNS = ['verdict', 'subject', 'content', 'email']
LEGACY_CSV_COLUMNS = ['verdict', 'subject', 'content']
DELIMITERS = [r'\s', r'\,', r'\.', r'\?', '!', r'\:', r'\(', r'\)']
# Must be identical to settings.spam_feature_hashes.
SPAM_FEATURE_HASHES = 500
# Must be identical to settings.component_features.
COMPONENT_FEATURES = 5000
def _ComponentFeatures(content, num_features, top_words):
"""
This uses the most common words in the entire dataset as features.
The count of common words in the issue comments makes up the features.
"""
features = [0] * num_features
for blob in content:
words = blob.split()
for word in words:
if word in top_words:
features[top_words[word]] += 1
return features
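# Illustrative sketch for _ComponentFeatures (assumed data, not part of the
# original module): with top_words = {'crash': 0, 'tab': 1} and
# num_features = 2, the content ['crash on new tab', 'crash again'] would
# yield [2, 1] -- one count per occurrence of a tracked word.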
def _SpamHashFeatures(content, num_features):
"""
Feature hashing is a fast and compact way to turn a string of text into a
vector of feature values for classification and training.
See also: https://en.wikipedia.org/wiki/Feature_hashing
This is a simple implementation that doesn't try to minimize collisions
or anything else fancy.
"""
features = [0] * num_features
total = 0.0
for blob in content:
words = re.split('|'.join(DELIMITERS), blob)
for word in words:
encoded_word = word
# If we've been passed real unicode strings, convert them to bytestrings.
if isinstance(word, text_type):
encoded_word = word.encode('utf-8')
feature_index = int(
int(hashlib.sha1(encoded_word).hexdigest(), 16) % num_features)
features[feature_index] += 1.0
total += 1.0
if total > 0:
features = [ f / total for f in features ]
return features
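# Illustrative sketch for _SpamHashFeatures (not part of the original
# module): each word is mapped to a bucket via its SHA-1 digest, e.g.
#   int(hashlib.sha1(b'crash').hexdigest(), 16) % num_features
# and the bucket counts are normalized so that, for non-empty input, the
# returned vector sums to 1.0.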
def GenerateFeaturesRaw(content, num_features, top_words=None):
"""Generates a vector of features for a given issue or comment.
Args:
content: The content of the issue's description and comments.
num_features: The number of features to generate.
"""
if top_words:
return { 'word_features': _ComponentFeatures(content,
num_features,
top_words)}
return { 'word_hashes': _SpamHashFeatures(content, num_features)}
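# A minimal usage sketch for GenerateFeaturesRaw (the literal inputs are
# illustrative assumptions):
#   GenerateFeaturesRaw(['issue text here'], SPAM_FEATURE_HASHES)
#     -> {'word_hashes': [<SPAM_FEATURE_HASHES normalized floats>]}
#   GenerateFeaturesRaw(['issue text here'], COMPONENT_FEATURES, top_words)
#     -> {'word_features': [<COMPONENT_FEATURES integer counts>]}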
def transform_spam_csv_to_features(csv_training_data):
X = []
y = []
  # Handle the case where the list of rows is itself wrapped in an outer
  # list; a real row has at most len(SPAM_COLUMNS) columns.
if csv_training_data and len(csv_training_data[0]) > 4:
csv_training_data = csv_training_data[0]
for row in csv_training_data:
if len(row) == 4:
verdict, subject, content, _email = row
else:
verdict, subject, content = row
X.append(GenerateFeaturesRaw([str(subject), str(content)],
SPAM_FEATURE_HASHES))
y.append(1 if verdict == 'spam' else 0)
return X, y
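# Illustrative row for transform_spam_csv_to_features (assumed data,
# mirroring SPAM_COLUMNS): a row such as
#   ['spam', 'Buy now', 'cheap pills', 'user@example.com']
# contributes one hashed-feature dict to X and the label 1 to y; any other
# verdict value is labeled 0. The email column is ignored.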
def transform_component_csv_to_features(csv_training_data, top_list):
X = []
y = []
top_words = {}
  for i, word in enumerate(top_list):
    top_words[word] = i
component_to_index = {}
index_to_component = {}
component_index = 0
for row in csv_training_data:
component, content = row
component = str(component).split(",")[0]
if component not in component_to_index:
component_to_index[component] = component_index
index_to_component[component_index] = component
component_index += 1
X.append(GenerateFeaturesRaw([content],
COMPONENT_FEATURES,
top_words))
y.append(component_to_index[component])
return X, y, index_to_component
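# Illustrative row for transform_component_csv_to_features (assumed data):
# a row such as
#   ['UI>Browser,UI>Browser>Bookmarks', '<issue text>']
# keeps only the first listed component ('UI>Browser'), assigns it the next
# unused integer index, and appends that index to y; index_to_component
# maps the indexes back to component names.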
def spam_from_file(f):
"""Reads a training data file and returns an array."""
rows = []
skipped_rows = 0
for row in csv.reader(f):
if len(row) == len(SPAM_COLUMNS):
# Throw out email field.
rows.append(row[:3])
elif len(row) == len(LEGACY_CSV_COLUMNS):
rows.append(row)
else:
skipped_rows += 1
return rows, skipped_rows
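# Behavior sketch for spam_from_file (illustrative): 4-column rows matching
# SPAM_COLUMNS are kept with the trailing email field dropped, 3-column
# LEGACY_CSV_COLUMNS rows are kept as-is, and anything else is counted in
# skipped_rows.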
def component_from_file(f):
"""Reads a training data file and returns an array."""
rows = []
csv.field_size_limit(sys.maxsize)
for row in csv.reader(f):
rows.append(row)
return rows
def setup_ml_engine():
"""Sets up an instance of ml engine for ml classes."""
try:
credentials = GoogleCredentials.get_application_default()
ml_engine = build('ml', 'v1', credentials=credentials)
return ml_engine
except (Oauth2ClientError, ApiClientError):
logging.error("Error setting up ML Engine API: %s" % sys.exc_info()[0])