# Copyright 2018 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

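"""Trains Monorail's spam or component model, locally or on Cloud ML Engine.

Training data is read from a local file or a GCS bucket, split into train
and eval sets, and fed to a tf.contrib.learn Experiment that learn_runner
executes. Evaluation results (and, for the component trainer, the
index-to-component mapping) are written back to the job directory.

Example local run (hypothetical file and data paths):
  python task.py --trainer-type spam --train-file /tmp/spam_training.csv \
      --train-steps 1000 --job-dir /tmp/spam_trainer
"""
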
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import json
import os
import re

import numpy as np
import tensorflow as tf
from googleapiclient import discovery
from googleapiclient import errors
from oauth2client.client import GoogleCredentials
from sklearn.model_selection import train_test_split
from tensorflow.contrib.learn.python.learn import learn_runner
from tensorflow.contrib.learn.python.learn.estimators import run_config
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
from tensorflow.contrib.training.python.training import hparam

from google.cloud.storage import blob, bucket, client

import trainer.dataset
import trainer.model
import trainer.ml_helpers
import trainer.top_words

def generate_experiment_fn(**experiment_args):
  """Create an experiment function.

  Args:
    experiment_args: keyword arguments to be passed through to the
      Experiment. See `tf.contrib.learn.Experiment` for the full set of args.
  Returns:
    A function:
      (tf.contrib.learn.RunConfig, tf.contrib.training.HParams) -> Experiment

    This function is used by learn_runner to create an Experiment which
    executes model code provided in the form of an Estimator and
    input functions.
  """
  def _experiment_fn(config, hparams):
    index_to_component = {}

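    # Read training data either from a local file (--train-file) or from the
    # configured GCS bucket.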
    if hparams.train_file:
      with open(hparams.train_file) as f:
        if hparams.trainer_type == 'spam':
          training_data = trainer.ml_helpers.spam_from_file(f)
        else:
          training_data = trainer.ml_helpers.component_from_file(f)
    else:
      training_data = trainer.dataset.fetch_training_data(hparams.gcs_bucket,
        hparams.gcs_prefix, hparams.trainer_type)

    tf.logging.info('Training data received. Len: %d' % len(training_data))

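    # Convert the raw CSV rows into feature vectors and labels. The component
    # trainer also yields a mapping from class index to component ID.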
    if hparams.trainer_type == 'spam':
      X, y = trainer.ml_helpers.transform_spam_csv_to_features(
          training_data)
    else:
      top_list = trainer.top_words.make_top_words_list(hparams.job_dir)
      X, y, index_to_component = trainer.ml_helpers \
          .transform_component_csv_to_features(training_data, top_list)

    tf.logging.info('Features generated')
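    # Hold out 20% of the examples for evaluation.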
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
      random_state=42)

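    # Wrap the feature dicts and label arrays in numpy-backed input functions
    # for the Estimator.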
    train_input_fn = tf.estimator.inputs.numpy_input_fn(
      x=trainer.model.feature_list_to_dict(X_train, hparams.trainer_type),
      y=np.array(y_train),
      num_epochs=hparams.num_epochs,
      batch_size=hparams.train_batch_size,
      shuffle=True
    )
    eval_input_fn = tf.estimator.inputs.numpy_input_fn(
      x=trainer.model.feature_list_to_dict(X_test, hparams.trainer_type),
      y=np.array(y_test),
      num_epochs=None,
      batch_size=hparams.eval_batch_size,
      shuffle=False  # Don't shuffle evaluation data
    )

    tf.logging.info('Numpy fns created')
    if hparams.trainer_type == 'component':
      store_component_conversion(hparams.job_dir, index_to_component)

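    # Bundle the estimator and input functions into an Experiment that
    # learn_runner can execute.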
    return tf.contrib.learn.Experiment(
      trainer.model.build_estimator(config=config,
        trainer_type=hparams.trainer_type,
        class_count=len(set(y))),
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      **experiment_args
    )
  return _experiment_fn


def store_component_conversion(job_dir, data):
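  """Stores the mapping from class index to Monorail component ID.

  The mapping is written as JSON to <job_dir>/component_index.json, either in
  the job's GCS bucket (ML Engine runs) or on the local filesystem (local
  runs).
  """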
  tf.logging.info('job_dir: %s' % job_dir)
  job_info = re.search(r'gs://(monorail-.+)-mlengine/(component_trainer_\d+)',
                       job_dir)

  # Check whether training is running on ML Engine or locally.
  if job_info:
    project = job_info.group(1)
    job_name = job_info.group(2)

    client_obj = client.Client(project=project)
    bucket_name = '%s-mlengine' % project
    bucket_obj = bucket.Bucket(client_obj, bucket_name)

    index_blob = blob.Blob(job_name + '/component_index.json', bucket_obj)
    index_blob.upload_from_string(json.dumps(data),
                                  content_type='application/json')

  else:
    # Create the job directory and any missing intermediate directories.
    paths = job_dir.split('/')
    for depth in range(2, len(paths) + 1):
      subpath = '/'.join(paths[:depth])
      if not os.path.exists(subpath):
        os.makedirs(subpath)
    with open(job_dir + '/component_index.json', 'w') as f:
      f.write(json.dumps(data))


def store_eval(job_dir, results):
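  """Uploads the evaluation results of a spam training job to GCS.

  Only uploads when running on ML Engine; local runs log an error instead.
  """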
  tf.logging.info('job_dir: %s' % job_dir)
  job_info = re.search(r'gs://(monorail-.+)-mlengine/(spam_trainer_\d+)',
                       job_dir)

  # Only upload eval data if this is not being run locally.
  if job_info:
    project = job_info.group(1)
    job_name = job_info.group(2)

    tf.logging.info('project: %s' % project)
    tf.logging.info('job_name: %s' % job_name)

    client_obj = client.Client(project=project)
    bucket_name = '%s-mlengine' % project
    bucket_obj = bucket.Bucket(client_obj, bucket_name)

    eval_blob = blob.Blob(job_name + '/eval_data.json', bucket_obj)
    # np.float32 values are not JSON-serializable; convert them to native
    # Python floats before dumping.
    for key, value in results[0].items():
      if isinstance(value, np.float32):
        results[0][key] = value.item()

    eval_blob.upload_from_string(json.dumps(results[0]),
                                 content_type='application/json')

  else:
    tf.logging.error('Could not parse GCS bucket from job_dir "%s" to output '
                     'evaluation to.' % job_dir)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()

  # Input Arguments
  parser.add_argument(
      '--train-file',
      help='GCS or local path to training data',
  )
  parser.add_argument(
      '--gcs-bucket',
      help='GCS bucket for training data.',
  )
  parser.add_argument(
      '--gcs-prefix',
      help='Training data path prefix inside GCS bucket.',
  )
  parser.add_argument(
      '--num-epochs',
      help="""\
      Maximum number of training data epochs on which to train.
      If both --train-steps and --num-epochs are specified,
      the training job will run for --train-steps or --num-epochs,
      whichever occurs first. If unspecified, the job will run
      for --train-steps.\
      """,
      type=int,
  )
  parser.add_argument(
      '--train-batch-size',
      help='Batch size for training steps',
      type=int,
      default=128
  )
  parser.add_argument(
      '--eval-batch-size',
      help='Batch size for evaluation steps',
      type=int,
      default=128
  )

  # Training arguments
  parser.add_argument(
      '--job-dir',
      help='GCS location to write checkpoints and export models',
      required=True
  )

  # Logging arguments
  parser.add_argument(
      '--verbosity',
      choices=[
          'DEBUG',
          'ERROR',
          'FATAL',
          'INFO',
          'WARN'
      ],
      default='INFO',
  )

  # Experiment arguments
  parser.add_argument(
      '--eval-delay-secs',
      help='How long to wait before running the first evaluation',
      default=10,
      type=int
  )
  parser.add_argument(
      '--min-eval-frequency',
      help='Minimum number of training steps between evaluations',
      default=None,  # Use TensorFlow's default (currently 1000).
      type=int
  )
  parser.add_argument(
      '--train-steps',
      help="""\
      Steps to run the training job for. This must be specified if
      --num-epochs is not; otherwise the training job will run
      indefinitely.\
      """,
      type=int
  )
  parser.add_argument(
      '--eval-steps',
      help='Number of steps to run evaluation for at each checkpoint',
      default=100,
      type=int
  )
  parser.add_argument(
      '--trainer-type',
      help='Which trainer to use (spam or component)',
      choices=['spam', 'component'],
      required=True
  )

  args = parser.parse_args()

  tf.logging.set_verbosity(args.verbosity)

  # Run the training job.
  # learn_runner pulls configuration information from environment
  # variables using tf.learn.RunConfig and uses this configuration
  # to conditionally execute the Experiment or parameter-server code.
  eval_results = learn_runner.run(
      generate_experiment_fn(
          min_eval_frequency=args.min_eval_frequency,
          eval_delay_secs=args.eval_delay_secs,
          train_steps=args.train_steps,
          eval_steps=args.eval_steps,
          export_strategies=[saved_model_export_utils.make_export_strategy(
              trainer.model.SERVING_FUNCTIONS['JSON-' + args.trainer_type],
              exports_to_keep=1,
              default_output_alternative_key=None,
          )],
      ),
      run_config=run_config.RunConfig(model_dir=args.job_dir),
      hparams=hparam.HParams(**args.__dict__)
  )

  # Store a JSON blob in GCS with the results of the training job (AUC of
  # precision/recall, etc).
  if args.trainer_type == 'spam':
    store_eval(args.job_dir, eval_results)