Project import generated by Copybara.
GitOrigin-RevId: d9e9e3fb4e31372ec1fb43b178994ca78fa8fe70
diff --git a/tools/ml/trainer/task.py b/tools/ml/trainer/task.py
new file mode 100644
index 0000000..7416c68
--- /dev/null
+++ b/tools/ml/trainer/task.py
@@ -0,0 +1,284 @@
+# Copyright 2018 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import json
+import os
+import re
+
+import numpy as np
+import tensorflow as tf
+from googleapiclient import discovery
+from googleapiclient import errors
+from oauth2client.client import GoogleCredentials
+from sklearn.model_selection import train_test_split
+from tensorflow.contrib.learn.python.learn import learn_runner
+from tensorflow.contrib.learn.python.learn.estimators import run_config
+from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
+from tensorflow.contrib.training.python.training import hparam
+
+from google.cloud.storage import blob, bucket, client
+
+import trainer.dataset
+import trainer.model
+import trainer.ml_helpers
+import trainer.top_words
+
+def generate_experiment_fn(**experiment_args):
+  """Build the factory that learn_runner calls to obtain an Experiment.
+
+  Args:
+    experiment_args: keyword arguments forwarded untouched to the
+      `tf.contrib.learn.Experiment` constructor.
+  Returns:
+    A function:
+      (tf.contrib.learn.RunConfig, tf.contrib.training.HParams) -> Experiment
+
+    The function loads the training data, converts it to features, holds
+    out 20% of the examples for evaluation and wraps the estimator and both
+    input functions in an Experiment.
+  """
+  def _experiment_fn(config, hparams):
+    # Rows come from a local file when one is given, otherwise from GCS.
+    if not hparams.train_file:
+      raw_rows = trainer.dataset.fetch_training_data(
+          hparams.gcs_bucket, hparams.gcs_prefix, hparams.trainer_type)
+    else:
+      with open(hparams.train_file) as train_fp:
+        if hparams.trainer_type != 'spam':
+          raw_rows = trainer.ml_helpers.component_from_file(train_fp)
+        else:
+          raw_rows = trainer.ml_helpers.spam_from_file(train_fp)
+    tf.logging.info('Training data received. Len: %d', len(raw_rows))
+
+    # Only the component trainer yields an index -> component mapping.
+    index_to_component = {}
+    if hparams.trainer_type != 'spam':
+      top_words = trainer.top_words.make_top_words_list(hparams.job_dir)
+      features, labels, index_to_component = (
+          trainer.ml_helpers.transform_component_csv_to_features(
+              raw_rows, top_words))
+    else:
+      features, labels = trainer.ml_helpers.transform_spam_csv_to_features(
+          raw_rows)
+    tf.logging.info('Features generated')
+
+    # Fixed seed so the 80/20 train/eval split is reproducible.
+    train_x, eval_x, train_y, eval_y = train_test_split(
+        features, labels, test_size=0.2, random_state=42)
+    numpy_input_fn = tf.estimator.inputs.numpy_input_fn
+    to_feature_dict = trainer.model.feature_list_to_dict
+    train_input_fn = numpy_input_fn(
+        x=to_feature_dict(train_x, hparams.trainer_type),
+        y=np.array(train_y),
+        num_epochs=hparams.num_epochs,
+        batch_size=hparams.train_batch_size,
+        shuffle=True)
+    eval_input_fn = numpy_input_fn(
+        x=to_feature_dict(eval_x, hparams.trainer_type),
+        y=np.array(eval_y),
+        num_epochs=None,
+        batch_size=hparams.eval_batch_size,
+        shuffle=False)  # Evaluation data keeps its original order.
+    tf.logging.info('Numpy fns created')
+
+    if hparams.trainer_type == 'component':
+      store_component_conversion(hparams.job_dir, index_to_component)
+
+    estimator = trainer.model.build_estimator(
+        config=config,
+        trainer_type=hparams.trainer_type,
+        class_count=len(set(labels)))
+    return tf.contrib.learn.Experiment(
+        estimator, train_input_fn=train_input_fn,
+        eval_input_fn=eval_input_fn, **experiment_args)
+  return _experiment_fn
+
+
+def store_component_conversion(job_dir, data):
+  """Write `data` as component_index.json to GCS or a local directory.
+
+  Args:
+    job_dir: gs://monorail-*-mlengine/component_trainer_<n> path for a cloud
+      job; any other value is treated as a local directory.
+    data: JSON-serializable object (the caller passes index_to_component).
+  """
+  tf.logging.info('job_dir: %s' % job_dir)
+  job_info = re.search(r'gs://(monorail-.+)-mlengine/(component_trainer_\d+)',
+                       job_dir)
+  # A matching gs:// job path means a cloud run; anything else is local.
+  if job_info:
+    project = job_info.group(1)
+    bucket_obj = bucket.Bucket(client.Client(project=project),
+                               '%s-mlengine' % project)
+    # Use a local name: assigning to bucket_obj.blob hides Bucket.blob().
+    index_blob = blob.Blob(job_info.group(2) + '/component_index.json',
+                           bucket_obj)
+    index_blob.upload_from_string(json.dumps(data),
+                                  content_type='application/json')
+  else:
+    # makedirs creates parents; also covers one-component paths like 'out'.
+    if job_dir and not os.path.isdir(job_dir):
+      os.makedirs(job_dir)
+    with open(os.path.join(job_dir, 'component_index.json'), 'w') as f:
+      f.write(json.dumps(data))
+
+
+def store_eval(job_dir, results):
+  """Upload the evaluation metrics in results[0] to GCS as eval_data.json.
+
+  Args:
+    job_dir: gs://monorail-*-mlengine/spam_trainer_<n> job path; for any
+      other value an error is logged and nothing is uploaded.
+    results: indexable whose first item is a dict of metric name -> value.
+  """
+  tf.logging.info('job_dir: %s' % job_dir)
+  job_info = re.search(r'gs://(monorail-.+)-mlengine/(spam_trainer_\d+)',
+                       job_dir)
+  # Only upload eval data if this is not being run locally.
+  if job_info:
+    project = job_info.group(1)
+    job_name = job_info.group(2)
+    tf.logging.info('project: %s' % project)
+    tf.logging.info('job_name: %s' % job_name)
+    bucket_obj = bucket.Bucket(client.Client(project=project),
+                               '%s-mlengine' % project)
+    eval_blob = blob.Blob(job_name + '/eval_data.json', bucket_obj)
+    # json rejects most NumPy scalar types; convert every kind (not only
+    # float32) into a fresh dict so the caller's results stay untouched.
+    metrics = {key: value.item() if isinstance(value, np.generic) else value
+               for key, value in results[0].items()}
+    eval_blob.upload_from_string(json.dumps(metrics),
+                                 content_type='application/json')
+  else:
+    tf.logging.error('Could not find bucket "%s" to output evalution to.'
+                     % job_dir)
+
+
+if __name__ == '__main__':
+  # Entry point: parse flags, run the experiment, then store eval results.
+  parser = argparse.ArgumentParser()
+  # Input Arguments
+  parser.add_argument(
+      '--train-file',
+      help='GCS or local path to training data',
+  )
+  parser.add_argument(
+      '--gcs-bucket',
+      help='GCS bucket for training data.',
+  )
+  parser.add_argument(
+      '--gcs-prefix',
+      help='Training data path prefix inside GCS bucket.',
+  )
+  parser.add_argument(
+      '--num-epochs',
+      help="""\
+      Maximum number of training data epochs on which to train.
+      If both --train-steps and --num-epochs are specified,
+      the training job will run for --train-steps or --num-epochs,
+      whichever occurs first. If unspecified will run for --train-steps.\
+      """,
+      type=int,
+  )
+  parser.add_argument(
+      '--train-batch-size',
+      help='Batch size for training steps',
+      type=int,
+      default=128
+  )
+  parser.add_argument(
+      '--eval-batch-size',
+      help='Batch size for evaluation steps',
+      type=int,
+      default=128
+  )
+
+  # Training arguments
+  parser.add_argument(
+      '--job-dir',
+      help='GCS location to write checkpoints and export models',
+      required=True
+  )
+
+  # Logging arguments
+  parser.add_argument(
+      '--verbosity',
+      choices=[
+          'DEBUG',
+          'ERROR',
+          'FATAL',
+          'INFO',
+          'WARN'
+      ],
+      default='INFO',
+  )
+
+  # Experiment arguments
+  parser.add_argument(
+      '--eval-delay-secs',
+      help='How long to wait before running first evaluation',
+      default=10,
+      type=int
+  )
+  parser.add_argument(
+      '--min-eval-frequency',
+      help='Minimum number of training steps between evaluations',
+      default=None,  # Use TensorFlow's default (currently, 1000)
+      type=int
+  )
+  parser.add_argument(
+      '--train-steps',
+      help="""\
+      Steps to run the training job for. If --num-epochs is not specified,
+      this must be. Otherwise the training job will run indefinitely.\
+      """,
+      type=int
+  )
+  parser.add_argument(
+      '--eval-steps',
+      help='Number of steps to run evalution for at each checkpoint',
+      default=100,
+      type=int
+  )
+  parser.add_argument(
+      '--trainer-type',
+      help='Which trainer to use (spam or component)',
+      choices=['spam', 'component'],
+      required=True
+  )
+
+  args = parser.parse_args()
+
+  tf.logging.set_verbosity(args.verbosity)
+
+  # Run the training job
+  # learn_runner pulls configuration information from environment
+  # variables using tf.learn.RunConfig and uses this configuration
+  # to conditionally execute Experiment, or param server code.
+  eval_results = learn_runner.run(
+      generate_experiment_fn(
+          min_eval_frequency=args.min_eval_frequency,
+          eval_delay_secs=args.eval_delay_secs,
+          train_steps=args.train_steps,
+          eval_steps=args.eval_steps,
+          export_strategies=[saved_model_export_utils.make_export_strategy(
+              trainer.model.SERVING_FUNCTIONS['JSON-' + args.trainer_type],
+              exports_to_keep=1,
+              default_output_alternative_key=None,
+          )],
+      ),
+      run_config=run_config.RunConfig(model_dir=args.job_dir),
+      hparams=hparam.HParams(**vars(args))
+  )
+
+  # Store a json blob in GCS with the results of training job (AUC of
+  # precision/recall, etc).
+  if args.trainer_type == 'spam':
+    store_eval(args.job_dir, eval_results)