blob: d05a5822974e0b1f145b44c5d1cbec3c7a9996e3 [file] [log] [blame]
Copybara854996b2021-09-07 19:36:02 +00001# Copyright 2018 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style
3# license that can be found in the LICENSE file or at
4# https://developers.google.com/open-source/licenses/bsd
5
6"""
7Helper functions for spam and component classification. These are mostly for
8feature extraction, so that the serving code and training code both use the same
9set of features.
10"""
Adrià Vilanova Martínezde942802022-07-15 14:06:55 +020011# TODO(crbug.com/monorail/7515): DELETE THIS FILE and all references.
Copybara854996b2021-09-07 19:36:02 +000012
13from __future__ import division
14from __future__ import print_function
15from __future__ import absolute_import
16
17import csv
18import hashlib
19import httplib2
20import logging
21import re
22import sys
23
24from six import text_type
25
26from apiclient.discovery import build
27from apiclient.errors import Error as ApiClientError
28from oauth2client.client import GoogleCredentials
29from oauth2client.client import Error as Oauth2ClientError
30
31
# Schema of a full spam-training CSV row; the trailing 'email' column is
# dropped before training (see spam_from_file).
SPAM_COLUMNS = ['verdict', 'subject', 'content', 'email']
# Schema of older training CSVs that never carried the email column.
LEGACY_CSV_COLUMNS = ['verdict', 'subject', 'content']
# Regex alternatives joined with '|' to tokenize text in _SpamHashFeatures.
# Raw strings avoid invalid-escape-sequence warnings in Python 3; the values
# are unchanged ('\,' in a non-raw literal already yielded backslash+comma).
DELIMITERS = [r'\s', r'\,', r'\.', r'\?', r'!', r'\:', r'\(', r'\)']

# Must be identical to settings.spam_feature_hashes.
SPAM_FEATURE_HASHES = 500
# Must be identical to settings.component_features.
COMPONENT_FEATURES = 5000
40
41
42def _ComponentFeatures(content, num_features, top_words):
43 """
44 This uses the most common words in the entire dataset as features.
45 The count of common words in the issue comments makes up the features.
46 """
47
48 features = [0] * num_features
49 for blob in content:
50 words = blob.split()
51 for word in words:
52 if word in top_words:
53 features[top_words[word]] += 1
54
55 return features
56
57
def _SpamHashFeatures(content, num_features):
  """Hashes the words of each blob into a normalized feature vector.

  Feature hashing is a fast and compact way to turn a string of text into a
  vector of feature values for classification and training.
  See also: https://en.wikipedia.org/wiki/Feature_hashing
  This is a simple implementation that doesn't try to minimize collisions
  or anything else fancy.

  Args:
    content: An iterable of text blobs (unicode or byte strings).
    num_features: Length of the returned vector.

  Returns:
    A list of num_features floats whose entries sum to ~1.0 (a list of
    zeros when `content` yields no tokens).
  """
  features = [0] * num_features
  total = 0.0
  # Hoisted out of the loops: the token-splitting pattern is loop-invariant,
  # so build and compile it once per call instead of once per blob.
  splitter = re.compile('|'.join(DELIMITERS))
  for blob in content:
    for word in splitter.split(blob):
      encoded_word = word
      # If we've been passed real unicode strings, convert them to bytestrings.
      if isinstance(word, text_type):
        encoded_word = word.encode('utf-8')
      # sha1 gives a stable, platform-independent bucket for each word.
      feature_index = int(
          int(hashlib.sha1(encoded_word).hexdigest(), 16) % num_features)
      features[feature_index] += 1.0
      total += 1.0

  if total > 0:
    features = [f / total for f in features]

  return features
84
85
def GenerateFeaturesRaw(content, num_features, top_words=None):
  """Generates a vector of features for a given issue or comment.

  Args:
    content: The content of the issue's description and comments.
    num_features: The number of features to generate.
    top_words: Optional mapping of common word -> feature index; when given,
      word-count features are produced instead of hashed features.

  Returns:
    A single-key dict: {'word_features': [...]} when top_words is truthy,
    otherwise {'word_hashes': [...]}.
  """
  if not top_words:
    return {'word_hashes': _SpamHashFeatures(content, num_features)}
  return {'word_features': _ComponentFeatures(content, num_features, top_words)}
99
100
def transform_spam_csv_to_features(csv_training_data):
  """Turns spam CSV rows into parallel (features, labels) lists.

  Rows are either (verdict, subject, content) or
  (verdict, subject, content, email); the email field is ignored.
  Labels are 1 for 'spam' verdicts and 0 otherwise.
  """
  # Handle if the list is double-wrapped.
  if csv_training_data and len(csv_training_data[0]) > 4:
    csv_training_data = csv_training_data[0]

  feature_vectors = []
  labels = []
  for row in csv_training_data:
    # Drop the trailing email field from full-width rows before unpacking.
    if len(row) == 4:
      row = row[:3]
    verdict, subject, content = row
    feature_vectors.append(
        GenerateFeaturesRaw([str(subject), str(content)],
                            SPAM_FEATURE_HASHES))
    labels.append(1 if verdict == 'spam' else 0)
  return feature_vectors, labels
118
119
def transform_component_csv_to_features(csv_training_data, top_list):
  """Turns component CSV rows into training features and labels.

  Args:
    csv_training_data: Iterable of (component, content) rows.
    top_list: Ordered list of the most common words in the dataset.

  Returns:
    A tuple (X, y, index_to_component) where y holds dense integer labels
    and index_to_component maps each label back to its component name.
  """
  top_words = {word: rank for rank, word in enumerate(top_list)}

  component_to_index = {}
  index_to_component = {}
  feature_vectors = []
  labels = []

  for component, content in csv_training_data:
    # A row may list several components; only the first one is used.
    primary = str(component).split(",")[0]

    if primary not in component_to_index:
      next_index = len(component_to_index)
      component_to_index[primary] = next_index
      index_to_component[next_index] = primary

    feature_vectors.append(
        GenerateFeaturesRaw([content], COMPONENT_FEATURES, top_words))
    labels.append(component_to_index[primary])

  return feature_vectors, labels, index_to_component
147
148
def spam_from_file(f):
  """Reads a spam training data file.

  Returns:
    A tuple (rows, skipped_rows): parsed 3-column rows (email dropped from
    full-width rows) and the count of rows with an unexpected width.
  """
  parsed = []
  skipped = 0
  full_width = len(SPAM_COLUMNS)
  legacy_width = len(LEGACY_CSV_COLUMNS)
  for row in csv.reader(f):
    width = len(row)
    if width == full_width:
      # Throw out email field.
      parsed.append(row[:3])
    elif width == legacy_width:
      parsed.append(row)
    else:
      skipped += 1
  return parsed, skipped
162
163
def component_from_file(f):
  """Reads a component training data file and returns its rows as a list."""
  # Component rows can carry very large text blobs; lift csv's field cap
  # before parsing so long fields don't raise.
  csv.field_size_limit(sys.maxsize)
  return list(csv.reader(f))
172
173
def setup_ml_engine():
  """Sets up an instance of ml engine for ml classes.

  Returns:
    A Cloud ML Engine ('ml', v1) API client, or None when credential
    discovery or API client construction fails (the failure is logged).
  """
  try:
    credentials = GoogleCredentials.get_application_default()
    ml_engine = build('ml', 'v1', http=httplib2.Http(), credentials=credentials)
    return ml_engine

  except (Oauth2ClientError, ApiClientError):
    # Lazy %-args let logging defer formatting; sys.exc_info()[0] is the
    # exception class that was raised.
    logging.error("Error setting up ML Engine API: %s", sys.exc_info()[0])
    return None