blob: 0def4b6ce56095096fdf9f8949efd8906e6cf36b [file] [log] [blame]
# Copyright 2018 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd
5
6from __future__ import print_function
7from __future__ import division
8from __future__ import absolute_import
9
10import StringIO
11import tensorflow as tf
12
13import csv
14import sys
15from googleapiclient import discovery
16from googleapiclient import errors
17from oauth2client.client import GoogleCredentials
18
19import trainer.ml_helpers
20
21
def fetch_training_data(bucket, prefix, trainer_type):
  """Download and parse every training CSV listed under a bucket prefix.

  Builds a Cloud Storage ('storage', 'v1') client with the application
  default credentials, lists the objects whose names start with `prefix`,
  and hands the resulting object names to the parser selected by
  `trainer_type`.

  Args:
    bucket: Name of the bucket to list and download from.
    prefix: Object-name prefix passed to the listing request.
    trainer_type: 'spam' selects fetch_spam; any other value selects
      fetch_component.

  Returns:
    The list of training rows returned by fetch_spam or fetch_component.

  Raises:
    googleapiclient.errors.HttpError: re-raised by make_api_request when a
      listing or download request fails.
  """
  credentials = GoogleCredentials.get_application_default()
  storage = discovery.build('storage', 'v1', credentials=credentials)
  objects = storage.objects()

  csv_filepaths = []
  request = objects.list(bucket=bucket, prefix=prefix)
  # The listing is paginated: keep requesting pages until list_next reports
  # that there is no further page by returning None.
  while request is not None:
    response = make_api_request(request)
    # 'items' may be missing from the response when no object matches the
    # prefix, so default to an empty list instead of iterating over None.
    for blob in response.get('items', []):
      csv_filepaths.append(blob.get('name'))
    request = objects.list_next(
        previous_request=request, previous_response=response)

  if trainer_type == 'spam':
    return fetch_spam(csv_filepaths, bucket, objects)
  return fetch_component(csv_filepaths, bucket, objects)
37
38
def fetch_spam(csv_filepaths, bucket, objects):
  """Download spam training CSVs and collect their parsed rows.

  Args:
    csv_filepaths: List of object names to download, in addition to the two
      fixed spam datasets that are always fetched first.
    bucket: Name of the bucket holding the CSV objects.
    objects: Storage objects resource used to issue the download requests.

  Returns:
    A single list holding the rows parsed from every file, in fetch order.
  """
  # These two datasets are always included, ahead of the listed files.
  paths_to_fetch = [
      'spam-training-data/full-android.csv',
      'spam-training-data/full-support.csv',
  ] + csv_filepaths

  collected = []
  for path in paths_to_fetch:
    csv_text = fetch_training_csv(path, objects, bucket)
    csv_file = StringIO.StringIO(csv_text)
    rows, skipped_rows = trainer.ml_helpers.spam_from_file(csv_file)

    row_count = len(rows)
    if row_count:
      collected.extend(rows)

    # One fixed-width log line per file: name, rows kept, rows skipped.
    added_msg = 'added %d rows' % row_count
    skipped_msg = 'skipped %d rows' % skipped_rows
    tf.logging.info(
        '{:<40}{:<20}{:<20}'.format(path, added_msg, skipped_msg))

  return collected
62
63
def fetch_component(csv_filepaths, bucket, objects):
  """Download component training CSVs and collect their parsed rows.

  Args:
    csv_filepaths: Iterable of object names to download.
    bucket: Name of the bucket holding the CSV objects.
    objects: Storage objects resource used to issue the download requests.

  Returns:
    A single list holding the rows parsed from every file, in fetch order.
  """
  collected = []
  for path in csv_filepaths:
    csv_text = fetch_training_csv(path, objects, bucket)
    csv_file = StringIO.StringIO(csv_text)
    rows = trainer.ml_helpers.component_from_file(csv_file)

    row_count = len(rows)
    if row_count:
      collected.extend(rows)

    # One fixed-width log line per file: name and number of rows kept.
    added_msg = 'added %d rows' % row_count
    tf.logging.info('{:<40}{:<20}'.format(path, added_msg))

  return collected
80
81
def fetch_training_csv(filepath, objects, bucket):
  """Download the contents of one object from the bucket.

  Args:
    filepath: Name of the object to download.
    objects: Storage objects resource used to build the media request.
    bucket: Name of the bucket that holds the object.

  Returns:
    Whatever make_api_request returns for the media download request.
  """
  return make_api_request(
      objects.get_media(bucket=bucket, object=filepath))
85
86
def make_api_request(request):
  """Execute an API client request, logging HTTP failures before re-raising.

  Args:
    request: Object exposing an execute() method, such as the request
      objects built by the googleapiclient resources used in this module.

  Returns:
    The value returned by request.execute().

  Raises:
    googleapiclient.errors.HttpError: the original error, re-raised unchanged
      after its reason has been logged.
  """
  try:
    return request.execute()
  # `except ... as err` is valid on Python 2.6+ and Python 3; the older
  # comma form is a syntax error on Python 3.
  except errors.HttpError as err:
    tf.logging.error('There was an error with the API. Details:')
    tf.logging.error(err._get_reason())
    raise
94
95