blob: 45a29cc31ed43a019adfc9b3941fd299f7aa28fa [file] [log] [blame]
Copybara854996b2021-09-07 19:36:02 +00001# coding=utf-8
2from __future__ import division
3from __future__ import print_function
4from __future__ import absolute_import
5
6import io
7import unittest
8
9from services import ml_helpers
10
11
# Length of the word-hash feature vector requested by the hashing tests.
NUM_WORD_HASHES = 5

# Vocabulary for the component-feature tests: each word maps to the index of
# the feature-vector slot that counts its occurrences.
TOP_WORDS = {
    word: position
    for position, word in enumerate(
        ('cat', 'dog', 'bunny', 'chinchilla', 'hamster'))
}
# One component feature per vocabulary word.
NUM_COMPONENT_FEATURES = len(TOP_WORDS)
16
17
class MLHelpersTest(unittest.TestCase):
  """Tests for the feature-extraction and training-data helpers in ml_helpers."""

  def testSpamHashFeatures(self):
    """_SpamHashFeatures returns the ratio of words landing in each bucket."""
    # No input at all: every bucket stays at zero.
    hashes = ml_helpers._SpamHashFeatures(tuple(), NUM_WORD_HASHES)
    self.assertEqual([0, 0, 0, 0, 0], hashes)

    # Empty strings are still counted (all in bucket 0), unlike the no-input
    # case above.
    hashes = ml_helpers._SpamHashFeatures(('', ''), NUM_WORD_HASHES)
    self.assertEqual([1.0, 0, 0, 0, 0], hashes)

    # 'abc' appears twice and 'def' once, giving ratios of 2/3 and 1/3.
    hashes = ml_helpers._SpamHashFeatures(('abc', 'abc def'), NUM_WORD_HASHES)
    self.assertEqual([0, 0, 2 / 3, 0, 1 / 3], hashes)

  def testComponentFeatures(self):
    """_ComponentFeatures counts occurrences of each TOP_WORDS entry."""
    features = ml_helpers._ComponentFeatures(
        ['cat dog is not bunny chinchilla hamster'],
        NUM_COMPONENT_FEATURES, TOP_WORDS)
    self.assertEqual([1, 1, 1, 1, 1], features)

    features = ml_helpers._ComponentFeatures(
        ['none of these are features'], NUM_COMPONENT_FEATURES, TOP_WORDS)
    self.assertEqual([0, 0, 0, 0, 0], features)

    # Only exact words count: 'hamsters' does not match 'hamster'.
    features = ml_helpers._ComponentFeatures(
        ['do hamsters look like a chinchilla'],
        NUM_COMPONENT_FEATURES, TOP_WORDS)
    self.assertEqual([0, 0, 0, 1, 0], features)

    features = ml_helpers._ComponentFeatures(
        [''], NUM_COMPONENT_FEATURES, TOP_WORDS)
    self.assertEqual([0, 0, 0, 0, 0], features)

  def testGenerateFeaturesRaw(self):
    """GenerateFeaturesRaw yields word-hash ratios, or word counts with TOP_WORDS."""
    features = ml_helpers.GenerateFeaturesRaw(
        ['abc', 'abc def http://www.google.com http://www.google.com'],
        NUM_WORD_HASHES)
    self.assertEqual(
        [1 / 2.75, 0.0, 1 / 5.5, 0.0, 1 / 2.2], features['word_hashes'])

    features = ml_helpers.GenerateFeaturesRaw(
        ['abc', 'abc def'], NUM_WORD_HASHES)
    self.assertEqual([0.0, 0.0, 2 / 3, 0.0, 1 / 3], features['word_hashes'])

    # With TOP_WORDS supplied, the result is read from 'word_features'
    # rather than 'word_hashes'.
    features = ml_helpers.GenerateFeaturesRaw(
        ['do hamsters look like a chinchilla'],
        NUM_COMPONENT_FEATURES, TOP_WORDS)
    self.assertEqual([0, 0, 0, 1, 0], features['word_features'])

    # BMP Unicode: a typographic apostrophe.
    features = ml_helpers.GenerateFeaturesRaw(
        [u'abc’', u'abc ’ def'], NUM_WORD_HASHES)
    self.assertEqual([0.0, 0.0, 0.25, 0.25, 0.5], features['word_hashes'])

    # BMP Unicode: a CJK ideograph. Despite being non-Latin, it is still
    # inside the Basic Multilingual Plane (code point below U+10000).
    features = ml_helpers.GenerateFeaturesRaw(
        [u'abc國', u'abc 國 def'], NUM_WORD_HASHES)
    self.assertEqual([0.0, 0.0, 0.25, 0.25, 0.5], features['word_hashes'])

    # Genuinely non-BMP Unicode (U+1F600, outside the Basic Multilingual
    # Plane). Which bucket each word hashes to is implementation-defined, so
    # only check that the vector has the requested length and its ratios
    # are normalized.
    features = ml_helpers.GenerateFeaturesRaw(
        [u'abc\U0001F600', u'abc \U0001F600 def'], NUM_WORD_HASHES)
    word_hashes = features['word_hashes']
    self.assertEqual(NUM_WORD_HASHES, len(word_hashes))
    self.assertAlmostEqual(1.0, sum(word_hashes))

    # A plain (non-u-prefixed) literal with a non-ASCII character. Under
    # Python 2 this was a UTF-8 bytestring; under Python 3 it is a str.
    features = ml_helpers.GenerateFeaturesRaw(
        ['abc…', 'abc … def'], NUM_WORD_HASHES)
    self.assertEqual([0.25, 0.0, 0.25, 0.25, 0.25], features['word_hashes'])

    # Empty input.
    features = ml_helpers.GenerateFeaturesRaw(['', ''], NUM_WORD_HASHES)
    self.assertEqual([1.0, 0.0, 0.0, 0.0, 0.0], features['word_hashes'])

  def test_from_file(self):
    """spam_from_file keeps well-formed rows and strips the email column."""
    # Rows are joined explicitly so the CSV content does not depend on how
    # this test file is indented.
    csv_file = io.StringIO(u'\n'.join([
        u'"spam","the subject 1","the contents 1","spammer@gmail.com"',
        u'"ham","the subject 2"',
        u'"spam","the subject 3","the contents 2","spammer2@gmail.com"',
    ]))
    samples, skipped = ml_helpers.spam_from_file(csv_file)
    # The short "ham" row is skipped; the third row becomes samples[1].
    self.assertEqual(len(samples), 2)
    self.assertEqual(skipped, 1)
    self.assertEqual(len(samples[1]), 3, 'Strips email')
    self.assertEqual(samples[1][2], 'the contents 2')

  def test_transform_csv_to_features(self):
    """transform_spam_csv_to_features yields feature dicts and 0/1 labels."""
    training_data = [
        ['spam', 'subject 1', 'contents 1'],
        ['ham', 'subject 2', 'contents 2'],
        ['spam', 'subject 3', 'contents 3'],
    ]
    X, y = ml_helpers.transform_spam_csv_to_features(training_data)

    self.assertIsInstance(X, list)
    self.assertIsInstance(X[0], dict)
    self.assertIsInstance(y, list)

    self.assertEqual(len(X), 3)
    self.assertEqual(len(y), 3)

    # NOTE(review): 500 is the hash-vector length this helper is expected to
    # use internally; confirm against ml_helpers if that value changes.
    self.assertEqual(len(X[0]['word_hashes']), 500)
    # 'spam' rows are labelled 1 and 'ham' rows 0.
    self.assertEqual(y, [1, 0, 1])
120 self.assertEqual(y, [1, 0, 1])