Copybara | 854996b | 2021-09-07 19:36:02 +0000 | [diff] [blame] | 1 | # Copyright 2018 The Chromium Authors. All rights reserved. |
| 2 | # Use of this source code is governed by a BSD-style |
| 3 | # license that can be found in the LICENSE file or at |
| 4 | # https://developers.google.com/open-source/licenses/bsd |
| 5 | |
| 6 | from __future__ import absolute_import |
| 7 | from __future__ import division |
| 8 | from __future__ import print_function |
| 9 | |
| 10 | import argparse |
| 11 | import json |
| 12 | import os |
| 13 | import re |
| 14 | |
| 15 | import numpy as np |
| 16 | import tensorflow as tf |
| 17 | from googleapiclient import discovery |
| 18 | from googleapiclient import errors |
| 19 | from oauth2client.client import GoogleCredentials |
| 20 | from sklearn.model_selection import train_test_split |
| 21 | from tensorflow.contrib.learn.python.learn import learn_runner |
| 22 | from tensorflow.contrib.learn.python.learn.estimators import run_config |
| 23 | from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils |
| 24 | from tensorflow.contrib.training.python.training import hparam |
| 25 | |
| 26 | from google.cloud.storage import blob, bucket, client |
| 27 | |
| 28 | import trainer.dataset |
| 29 | import trainer.model |
| 30 | import trainer.ml_helpers |
| 31 | import trainer.top_words |
| 32 | |
| 33 | def generate_experiment_fn(**experiment_args): |
| 34 | """Create an experiment function. |
| 35 | |
| 36 | Args: |
| 37 | experiment_args: keyword arguments to be passed through to experiment |
| 38 | See `tf.contrib.learn.Experiment` for full args. |
| 39 | Returns: |
| 40 | A function: |
| 41 | (tf.contrib.learn.RunConfig, tf.contrib.training.HParams) -> Experiment |
| 42 | |
| 43 | This function is used by learn_runner to create an Experiment which |
| 44 | executes model code provided in the form of an Estimator and |
| 45 | input functions. |
| 46 | """ |
| 47 | def _experiment_fn(config, hparams): |
| 48 | index_to_component = {} |
| 49 | |
| 50 | if hparams.train_file: |
| 51 | with open(hparams.train_file) as f: |
| 52 | if hparams.trainer_type == 'spam': |
| 53 | training_data = trainer.ml_helpers.spam_from_file(f) |
| 54 | else: |
| 55 | training_data = trainer.ml_helpers.component_from_file(f) |
| 56 | else: |
| 57 | training_data = trainer.dataset.fetch_training_data(hparams.gcs_bucket, |
| 58 | hparams.gcs_prefix, hparams.trainer_type) |
| 59 | |
| 60 | tf.logging.info('Training data received. Len: %d' % len(training_data)) |
| 61 | |
| 62 | if hparams.trainer_type == 'spam': |
| 63 | X, y = trainer.ml_helpers.transform_spam_csv_to_features( |
| 64 | training_data) |
| 65 | else: |
| 66 | top_list = trainer.top_words.make_top_words_list(hparams.job_dir) |
| 67 | X, y, index_to_component = trainer.ml_helpers \ |
| 68 | .transform_component_csv_to_features(training_data, top_list) |
| 69 | |
| 70 | tf.logging.info('Features generated') |
| 71 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, |
| 72 | random_state=42) |
| 73 | |
| 74 | train_input_fn = tf.estimator.inputs.numpy_input_fn( |
| 75 | x=trainer.model.feature_list_to_dict(X_train, hparams.trainer_type), |
| 76 | y=np.array(y_train), |
| 77 | num_epochs=hparams.num_epochs, |
| 78 | batch_size=hparams.train_batch_size, |
| 79 | shuffle=True |
| 80 | ) |
| 81 | eval_input_fn = tf.estimator.inputs.numpy_input_fn( |
| 82 | x=trainer.model.feature_list_to_dict(X_test, hparams.trainer_type), |
| 83 | y=np.array(y_test), |
| 84 | num_epochs=None, |
| 85 | batch_size=hparams.eval_batch_size, |
| 86 | shuffle=False # Don't shuffle evaluation data |
| 87 | ) |
| 88 | |
| 89 | tf.logging.info('Numpy fns created') |
| 90 | if hparams.trainer_type == 'component': |
| 91 | store_component_conversion(hparams.job_dir, index_to_component) |
| 92 | |
| 93 | return tf.contrib.learn.Experiment( |
| 94 | trainer.model.build_estimator(config=config, |
| 95 | trainer_type=hparams.trainer_type, |
| 96 | class_count=len(set(y))), |
| 97 | train_input_fn=train_input_fn, |
| 98 | eval_input_fn=eval_input_fn, |
| 99 | **experiment_args |
| 100 | ) |
| 101 | return _experiment_fn |
| 102 | |
| 103 | |
def store_component_conversion(job_dir, data):
  """Persist the component index mapping as component_index.json.

  If `job_dir` contains `gs://<project>-mlengine/component_trainer_<N>`
  (where <project> starts with "monorail-"), the JSON is uploaded to the
  `<project>-mlengine` bucket as `component_trainer_<N>/component_index.json`.
  Otherwise `job_dir` is treated as a local directory: it is created if
  missing and the JSON is written to `<job_dir>/component_index.json`.

  Args:
    job_dir: GCS URL or local directory of the training job.
    data: JSON-serializable mapping to store.
  """
  tf.logging.info('job_dir: %s', job_dir)
  job_info = re.search(
      r'gs://(monorail-.+)-mlengine/(component_trainer_\d+)', job_dir)

  # Job dirs on the <project>-mlengine GCS bucket are uploaded; anything else
  # is treated as a local path.
  if job_info:
    project = job_info.group(1)
    job_name = job_info.group(2)

    client_obj = client.Client(project=project)
    bucket_obj = bucket.Bucket(client_obj, '%s-mlengine' % project)
    # Keep the Blob in a local: assigning it to bucket_obj.blob would shadow
    # the Bucket.blob() factory method.
    blob_obj = blob.Blob(job_name + '/component_index.json', bucket_obj)
    blob_obj.upload_from_string(json.dumps(data),
                                content_type='application/json')
  else:
    # os.makedirs creates all missing intermediate directories in one call.
    # An empty job_dir means the current directory, which always exists.
    if job_dir and not os.path.exists(job_dir):
      os.makedirs(job_dir)
    with open(os.path.join(job_dir, 'component_index.json'), 'w') as f:
      json.dump(data, f)
| 131 | |
| 132 | |
def store_eval(job_dir, results):
  """Upload spam-trainer evaluation metrics to GCS as eval_data.json.

  The upload only happens when `job_dir` contains
  `gs://<project>-mlengine/spam_trainer_<N>` (where <project> starts with
  "monorail-"); the object is written to the `<project>-mlengine` bucket as
  `spam_trainer_<N>/eval_data.json`. For any other job_dir an error is logged
  and nothing is written.

  Args:
    job_dir: the --job-dir used for training.
    results: sequence whose first element is a dict of evaluation metrics
      (the value returned by learn_runner.run -- presumably the
      (eval_results, export_results) pair; confirm). It is not modified.
  """
  tf.logging.info('job_dir: %s', job_dir)
  job_info = re.search(r'gs://(monorail-.+)-mlengine/(spam_trainer_\d+)',
                       job_dir)

  # Only upload eval data if this is not being run locally.
  if not job_info:
    tf.logging.error('Could not find bucket "%s" to output evalution to.',
                     job_dir)
    return

  project = job_info.group(1)
  job_name = job_info.group(2)
  tf.logging.info('project: %s', project)
  tf.logging.info('job_name: %s', job_name)

  # json.dumps rejects most numpy scalar types (e.g. np.float32, and np.int64
  # on Python 3) as well as arrays; tolist() converts both to built-in Python
  # values. A new dict is built so the caller's results are left untouched.
  metrics = {
      key: (value.tolist()
            if isinstance(value, (np.generic, np.ndarray)) else value)
      for key, value in results[0].items()
  }

  client_obj = client.Client(project=project)
  bucket_obj = bucket.Bucket(client_obj, '%s-mlengine' % project)
  # Keep the Blob in a local: assigning it to bucket_obj.blob would shadow
  # the Bucket.blob() factory method.
  blob_obj = blob.Blob(job_name + '/eval_data.json', bucket_obj)
  blob_obj.upload_from_string(json.dumps(metrics),
                              content_type='application/json')
| 162 | |
| 163 | |
if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  add_arg = parser.add_argument

  # Input Arguments
  add_arg('--train-file', help='GCS or local path to training data')
  add_arg('--gcs-bucket', help='GCS bucket for training data.')
  add_arg('--gcs-prefix', help='Training data path prefix inside GCS bucket.')
  add_arg('--num-epochs', type=int, help="""\
Maximum number of training data epochs on which to train.
If both --max-steps and --num-epochs are specified,
the training job will run for --max-steps or --num-epochs,
whichever occurs first. If unspecified will run for --max-steps.\
""")
  add_arg('--train-batch-size', type=int, default=128,
          help='Batch size for training steps')
  add_arg('--eval-batch-size', type=int, default=128,
          help='Batch size for evaluation steps')

  # Training arguments
  add_arg('--job-dir', required=True,
          help='GCS location to write checkpoints and export models')

  # Logging arguments
  add_arg('--verbosity', default='INFO',
          choices=['DEBUG', 'ERROR', 'FATAL', 'INFO', 'WARN'])

  # Experiment arguments
  add_arg('--eval-delay-secs', type=int, default=10,
          help='How long to wait before running first evaluation')
  add_arg('--min-eval-frequency', type=int,
          default=None,  # Use TensorFlow's default (currently, 1000)
          help='Minimum number of training steps between evaluations')
  add_arg('--train-steps', type=int, help="""\
Steps to run the training job for. If --num-epochs is not specified,
this must be. Otherwise the training job will run indefinitely.\
""")
  add_arg('--eval-steps', type=int, default=100,
          help='Number of steps to run evalution for at each checkpoint')
  add_arg('--trainer-type', required=True, choices=['spam', 'component'],
          help='Which trainer to use (spam or component)')

  opts = parser.parse_args()
  tf.logging.set_verbosity(opts.verbosity)

  # Export the model with the JSON serving function for the chosen trainer,
  # keeping only the most recent export.
  export_strategy = saved_model_export_utils.make_export_strategy(
      trainer.model.SERVING_FUNCTIONS['JSON-' + opts.trainer_type],
      exports_to_keep=1,
      default_output_alternative_key=None,
  )
  experiment_fn = generate_experiment_fn(
      min_eval_frequency=opts.min_eval_frequency,
      eval_delay_secs=opts.eval_delay_secs,
      train_steps=opts.train_steps,
      eval_steps=opts.eval_steps,
      export_strategies=[export_strategy],
  )

  # learn_runner pulls configuration from environment variables through
  # RunConfig and uses it to decide whether to run the Experiment or
  # parameter-server code. Every parsed flag is forwarded as an hparam.
  eval_results = learn_runner.run(
      experiment_fn,
      run_config=run_config.RunConfig(model_dir=opts.job_dir),
      hparams=hparam.HParams(**vars(opts)),
  )

  # Only the spam trainer stores its evaluation results (AUC of
  # precision/recall, etc.) as a JSON blob in GCS.
  if opts.trainer_type == 'spam':
    store_eval(opts.job_dir, eval_results)