# Copyright 2018 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd
6from __future__ import absolute_import
7from __future__ import division
8from __future__ import print_function
9
10import csv
11import os
12import re
13import StringIO
14import sys
15import tensorflow as tf
16import time
17
18from googleapiclient import discovery
19from googleapiclient import errors
20from oauth2client.client import GoogleCredentials
21import google
22from google.cloud.storage import blob, bucket, client
23
24import trainer.ml_helpers
25import trainer.dataset
26
27
28TOP_WORDS = 'topwords.txt'
29STOP_WORDS = 'stopwords.txt'
30
31
def fetch_stop_words(project_id, objects):
  """Downloads the stop-words file from the project's -mlengine bucket.

  Args:
    project_id: GCP project id; the file lives in <project_id>-mlengine.
    objects: GCS objects resource used to issue the media request.

  Returns:
    The stop words as a list (whitespace-split file contents).
  """
  bucket_name = project_id + '-mlengine'
  media_request = objects.get_media(bucket=bucket_name, object=STOP_WORDS)
  return trainer.dataset.make_api_request(media_request).split()
37
38
def fetch_training_csv(filepath, objects, b):
  """Returns the raw contents of one training CSV stored in GCS bucket `b`."""
  return trainer.dataset.make_api_request(
      objects.get_media(bucket=b, object=filepath))
42
43
def GenerateTopWords(objects, word_dict, project_id):
  """Returns the most frequent non-stop words, up to COMPONENT_FEATURES.

  Args:
    objects: GCS objects resource, passed through to fetch_stop_words.
    word_dict: Dict mapping word -> occurrence count.
    project_id: GCP project id; the stop-words file lives in
      <project_id>-mlengine.

  Returns:
    A list of at most trainer.ml_helpers.COMPONENT_FEATURES words in
    descending frequency order, excluding stop words.
  """
  # A set makes the per-word membership test O(1) instead of O(n) against
  # the list fetch_stop_words returns.
  stop_words = set(fetch_stop_words(project_id, objects))
  sorted_words = sorted(word_dict, key=word_dict.get, reverse=True)

  top_words = []
  # Iterate the sorted list directly: the previous unbounded while-loop
  # raised IndexError when the corpus contained fewer than
  # COMPONENT_FEATURES non-stop words.
  for word in sorted_words:
    if len(top_words) >= trainer.ml_helpers.COMPONENT_FEATURES:
      break
    if word not in stop_words:
      top_words.append(word)

  return top_words
57
58
def make_top_words_list(job_dir):
  """Returns the top (most common) words in the entire dataset for component
  prediction. If a file is already stored in GCS containing these words, the
  words from the file are simply returned. Otherwise, the most common words are
  determined and written to GCS, before being returned.

  Args:
    job_dir: GCS path of the trainer job, e.g.
      gs://monorail-<env>-mlengine/component_trainer_<timestamp>.

  Returns:
    A list of the most common words in the dataset (the number of them
    determined by ml_helpers.COMPONENT_FEATURES).
  """

  credentials = GoogleCredentials.get_application_default()
  storage = discovery.build('storage', 'v1', credentials=credentials)
  objects = storage.objects()

  # Raw string so `\d` is a regex digit class, not a (deprecated) string
  # escape.
  subpaths = re.match(r'gs://(monorail-.*)-mlengine/(component_trainer_\d+)',
                      job_dir)

  if subpaths:
    project_id = subpaths.group(1)
    trainer_folder = subpaths.group(2)
  else:
    # Fall back to prod when the job dir doesn't look like a GCS mlengine
    # path; in that case no top-words file is uploaded below.
    project_id = 'monorail-prod'

  storage_bucket = project_id + '.appspot.com'
  request = objects.list(bucket=storage_bucket,
                         prefix='component_training_data')

  response = trainer.dataset.make_api_request(request)

  # `items` is absent from the response when no objects match the prefix.
  items = response.get('items') or []
  csv_filepaths = [item.get('name') for item in items]

  # join is linear; the previous repeated `+=` concatenation was quadratic.
  final_string = ''.join(
      word + '\n'
      for word in parse_words(csv_filepaths, objects, storage_bucket,
                              project_id))

  if subpaths:
    client_obj = client.Client(project=project_id)
    bucket_obj = bucket.Bucket(client_obj, project_id + '-mlengine')

    # Use a local Blob instead of mutating the library bucket object by
    # assigning to `bucket_obj.blob`.
    top_words_blob = blob.Blob(trainer_folder + '/' + TOP_WORDS, bucket_obj)
    top_words_blob.upload_from_string(final_string,
                                      content_type='text/plain')

  return final_string.split()
108
109
def parse_words(files, objects, b, project_id):
  """Counts word frequencies across all training CSVs and returns the top
  words via GenerateTopWords.

  Args:
    files: List of CSV object names within bucket `b`.
    objects: GCS objects resource used for downloads.
    b: Name of the GCS bucket holding the training CSVs.
    project_id: GCP project id, forwarded to GenerateTopWords.

  Returns:
    The list produced by GenerateTopWords for the accumulated counts.
  """
  word_counts = {}

  # Issue text fields can exceed csv's default field size cap.
  csv.field_size_limit(sys.maxsize)

  for filepath in files:
    media = fetch_training_csv(filepath, objects, b)

    # Each row is (label, content); only the content contributes words.
    for _, content in csv.reader(StringIO.StringIO(media)):
      for word in content.split():
        word_counts[word] = word_counts.get(word, 0) + 1

  return GenerateTopWords(objects, word_counts, project_id)