# Copyright 2018 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

"""
Helper functions for spam and component classification. These are mostly for
feature extraction, so that the serving code and training code both use the
same set of features.
"""
11
12from __future__ import division
13from __future__ import print_function
14from __future__ import absolute_import
15
16import csv
17import hashlib
18import httplib2
19import logging
20import re
21import sys
22
23from six import text_type
24
25from apiclient.discovery import build
26from apiclient.errors import Error as ApiClientError
27from oauth2client.client import GoogleCredentials
28from oauth2client.client import Error as Oauth2ClientError
29
30
# Column layouts for spam training-data CSV rows, with and without the
# trailing email field.
SPAM_COLUMNS = ['verdict', 'subject', 'content', 'email']
LEGACY_CSV_COLUMNS = ['verdict', 'subject', 'content']
# Regex alternatives used to tokenize text.  Raw strings keep the backslashes
# literal (e.g. r'\s' reaches the regex engine intact); the plain-string forms
# like '\s' are invalid escape sequences and warn on modern Python.
DELIMITERS = [r'\s', r'\,', r'\.', r'\?', '!', r'\:', r'\(', r'\)']

# Must be identical to settings.spam_feature_hashes.
SPAM_FEATURE_HASHES = 500
# Must be identical to settings.component_features.
COMPONENT_FEATURES = 5000
39
40
41def _ComponentFeatures(content, num_features, top_words):
42 """
43 This uses the most common words in the entire dataset as features.
44 The count of common words in the issue comments makes up the features.
45 """
46
47 features = [0] * num_features
48 for blob in content:
49 words = blob.split()
50 for word in words:
51 if word in top_words:
52 features[top_words[word]] += 1
53
54 return features
55
56
def _SpamHashFeatures(content, num_features):
  """Turn text blobs into a normalized vector of hashed word frequencies.

  Feature hashing is a fast and compact way to turn a string of text into a
  vector of feature values for classification and training.
  See also: https://en.wikipedia.org/wiki/Feature_hashing
  This is a simple implementation that doesn't try to minimize collisions
  or anything else fancy.
  """
  features = [0] * num_features
  total = 0.0
  tokenizer = '|'.join(DELIMITERS)
  for blob in content:
    for word in re.split(tokenizer, blob):
      # If we've been passed real unicode strings, convert them to bytestrings
      # so the hash function accepts them.
      if isinstance(word, text_type):
        word = word.encode('utf-8')
      bucket = int(hashlib.sha1(word).hexdigest(), 16) % num_features
      features[bucket] += 1.0
      total += 1.0

  if total > 0:
    # Normalize counts into frequencies.
    features = [count / total for count in features]

  return features
83
84
def GenerateFeaturesRaw(content, num_features, top_words=None):
  """Generates a vector of features for a given issue or comment.

  Args:
    content: The content of the issue's description and comments.
    num_features: The number of features to generate.
    top_words: optional dict mapping common words to feature indexes; when
      supplied, word-count features are produced instead of hashed features.

  Returns:
    A single-entry dict keyed 'word_features' (component model) when
    top_words is given, otherwise 'word_hashes' (spam model).
  """
  if not top_words:
    return {'word_hashes': _SpamHashFeatures(content, num_features)}
  return {'word_features': _ComponentFeatures(content, num_features, top_words)}
98
99
def transform_spam_csv_to_features(csv_training_data):
  """Turn spam CSV rows into (feature dicts, labels) for training.

  Accepts rows of either (verdict, subject, content, email) or the legacy
  (verdict, subject, content) layout; the email field is ignored.

  Returns:
    A tuple (X, y) where X is a list of hashed-feature dicts and y is a
    parallel list of 1 (spam) / 0 (ham) labels.
  """
  X = []
  y = []

  # Handle if the list is double-wrapped.
  if csv_training_data and len(csv_training_data[0]) > 4:
    csv_training_data = csv_training_data[0]

  for row in csv_training_data:
    if len(row) == 4:
      # Drop the trailing email field.
      verdict, subject, content = row[0], row[1], row[2]
    else:
      verdict, subject, content = row
    features = GenerateFeaturesRaw([str(subject), str(content)],
                                   SPAM_FEATURE_HASHES)
    X.append(features)
    y.append(1 if verdict == 'spam' else 0)
  return X, y
117
118
def transform_component_csv_to_features(csv_training_data, top_list):
  """Turn component CSV rows into training features and labels.

  Args:
    csv_training_data: iterable of (component, content) rows.
    top_list: the most common words in the dataset, in rank order.

  Returns:
    A tuple (X, y, index_to_component) where X is a list of word-count
    feature dicts, y is a parallel list of integer component indexes, and
    index_to_component maps each index back to its component name.
  """
  X = []
  y = []
  # Map each common word to its feature-vector index.
  top_words = {word: i for i, word in enumerate(top_list)}

  component_to_index = {}
  index_to_component = {}

  for row in csv_training_data:
    component, content = row
    # A row may list several components; only the first one is used.
    component = str(component).split(",")[0]

    if component not in component_to_index:
      # Assign the next free index to a newly-seen component.
      index = len(component_to_index)
      component_to_index[component] = index
      index_to_component[index] = component

    X.append(GenerateFeaturesRaw([content], COMPONENT_FEATURES, top_words))
    y.append(component_to_index[component])

  return X, y, index_to_component
146
147
def spam_from_file(f):
  """Read spam training data, returning (rows, skipped_row_count).

  Rows with the full column set have their email field dropped; rows in the
  legacy layout are kept as-is; rows of any other width are counted and
  skipped.
  """
  kept = []
  skipped = 0
  full_width = len(SPAM_COLUMNS)
  legacy_width = len(LEGACY_CSV_COLUMNS)
  for record in csv.reader(f):
    width = len(record)
    if width == full_width:
      # Throw out email field.
      kept.append(record[:3])
    elif width == legacy_width:
      kept.append(record)
    else:
      skipped += 1
  return kept, skipped
161
162
def component_from_file(f):
  """Read component training data and return all CSV rows as a list."""
  # Component descriptions can be very large; lift the default field cap.
  csv.field_size_limit(sys.maxsize)
  return list(csv.reader(f))
171
172
def setup_ml_engine():
  """Sets up an instance of ml engine for ml classes.

  Returns:
    A Cloud ML Engine API client, or None if credential/API setup fails.
  """
  try:
    credentials = GoogleCredentials.get_application_default()
    ml_engine = build('ml', 'v1', http=httplib2.Http(), credentials=credentials)
    return ml_engine

  except (Oauth2ClientError, ApiClientError) as e:
    # Lazy %-args (not eager % formatting), and log the caught exception
    # itself so the message includes details, not just the exception class.
    logging.error("Error setting up ML Engine API: %s", e)
    return None