# blob: 7416c6818272f895609b5480e048ca1a676700c8 [file] [log] [blame]
# Copyright 2018 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import json
import os
import re
import numpy as np
import tensorflow as tf
from googleapiclient import discovery
from googleapiclient import errors
from oauth2client.client import GoogleCredentials
from sklearn.model_selection import train_test_split
from tensorflow.contrib.learn.python.learn import learn_runner
from tensorflow.contrib.learn.python.learn.estimators import run_config
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
from tensorflow.contrib.training.python.training import hparam
from google.cloud.storage import blob, bucket, client
import trainer.dataset
import trainer.model
import trainer.ml_helpers
import trainer.top_words
def generate_experiment_fn(**experiment_args):
  """Build the factory that learn_runner uses to construct an Experiment.

  Args:
    experiment_args: keyword arguments forwarded unchanged to
      `tf.contrib.learn.Experiment`; see that class for the accepted names.

  Returns:
    A function:
      (tf.contrib.learn.RunConfig, tf.contrib.training.HParams) -> Experiment
    It loads the training data, converts it to features, splits it into
    train/eval sets and wires an Estimator together with the matching
    numpy input functions.
  """

  def _experiment_fn(config, hparams):
    trainer_type = hparams.trainer_type
    is_spam = trainer_type == 'spam'
    helpers = trainer.ml_helpers

    # Training rows come from a local file when one is given, otherwise
    # they are fetched from the configured GCS bucket/prefix.
    if not hparams.train_file:
      rows = trainer.dataset.fetch_training_data(
          hparams.gcs_bucket, hparams.gcs_prefix, trainer_type)
    else:
      with open(hparams.train_file) as train_file:
        parse_rows = (
            helpers.spam_from_file if is_spam else helpers.component_from_file)
        rows = parse_rows(train_file)
    tf.logging.info('Training data received. Len: %d' % len(rows))

    # Only the component trainer produces an index -> component mapping.
    component_index = {}
    if is_spam:
      features, labels = helpers.transform_spam_csv_to_features(rows)
    else:
      top_words = trainer.top_words.make_top_words_list(hparams.job_dir)
      features, labels, component_index = (
          helpers.transform_component_csv_to_features(rows, top_words))
    tf.logging.info('Features generated')

    # Fixed seed so the 80/20 split is reproducible between runs.
    train_x, eval_x, train_y, eval_y = train_test_split(
        features, labels, test_size=0.2, random_state=42)

    to_feature_dict = trainer.model.feature_list_to_dict
    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x=to_feature_dict(train_x, trainer_type),
        y=np.array(train_y),
        num_epochs=hparams.num_epochs,
        batch_size=hparams.train_batch_size,
        shuffle=True)
    eval_input_fn = tf.estimator.inputs.numpy_input_fn(
        x=to_feature_dict(eval_x, trainer_type),
        y=np.array(eval_y),
        num_epochs=None,
        batch_size=hparams.eval_batch_size,
        shuffle=False)  # Evaluation data is not shuffled.
    tf.logging.info('Numpy fns created')

    if trainer_type == 'component':
      store_component_conversion(hparams.job_dir, component_index)

    # The class count is taken from every label, not just the training split.
    estimator = trainer.model.build_estimator(
        config=config,
        trainer_type=trainer_type,
        class_count=len(set(labels)))
    return tf.contrib.learn.Experiment(
        estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        **experiment_args)

  return _experiment_fn
def store_component_conversion(job_dir, data):
  """Write `data` as JSON to `component_index.json` under the job directory.

  When `job_dir` matches
  gs://monorail-<...>-mlengine/component_trainer_<digits> the JSON is uploaded
  to `<job name>/component_index.json` in the `<project>-mlengine` bucket.
  Any other `job_dir` is treated as a local directory, which is created
  (including parents) when it does not exist yet.

  Args:
    job_dir: GCS URL or local directory of the training job.
    data: JSON-serializable object to store; the caller in this file passes
      the index-to-component mapping produced during feature generation.
  """
  tf.logging.info('job_dir: %s' % job_dir)
  # Raw string: '\d' in a plain literal is an invalid escape sequence.
  job_info = re.search(
      r'gs://(monorail-.+)-mlengine/(component_trainer_\d+)', job_dir)

  # Check if training is being done on GAE or locally.
  if job_info:
    project = job_info.group(1)
    job_name = job_info.group(2)

    client_obj = client.Client(project=project)
    bucket_name = '%s-mlengine' % project
    bucket_obj = bucket.Bucket(client_obj, bucket_name)
    # Keep the blob in a local variable; assigning it to `bucket_obj.blob`
    # would overwrite the bucket's own `blob` attribute on this instance.
    index_blob = blob.Blob(job_name + '/component_index.json', bucket_obj)
    index_blob.upload_from_string(json.dumps(data),
                                  content_type='application/json')
  else:
    # os.makedirs creates every missing parent in one call. This also covers
    # a single-segment relative path such as 'output', which a
    # split('/')-based loop over path prefixes never creates.
    if job_dir and not os.path.isdir(job_dir):
      os.makedirs(job_dir)
    with open(os.path.join(job_dir, 'component_index.json'), 'w') as f:
      json.dump(data, f)
def store_eval(job_dir, results):
  """Upload the evaluation metrics of a spam training job to GCS as JSON.

  The upload only happens when `job_dir` matches
  gs://monorail-<...>-mlengine/spam_trainer_<digits>; the metrics are then
  written to `<job name>/eval_data.json` in the `<project>-mlengine` bucket.
  For any other `job_dir` an error is logged and nothing is uploaded.

  Args:
    job_dir: GCS URL of the training job.
    results: sequence whose first element is a dict mapping metric names to
      values. The dict is not modified.
  """
  tf.logging.info('job_dir: %s' % job_dir)
  # Raw string: '\d' in a plain literal is an invalid escape sequence.
  job_info = re.search(
      r'gs://(monorail-.+)-mlengine/(spam_trainer_\d+)', job_dir)

  # Only upload eval data if this is not being run locally.
  if not job_info:
    tf.logging.error('Could not find bucket "%s" to output evaluation to.'
                     % job_dir)
    return

  project = job_info.group(1)
  job_name = job_info.group(2)
  tf.logging.info('project: %s' % project)
  tf.logging.info('job_name: %s' % job_name)

  # Convert numpy values to plain Python ones, since json.dumps cannot encode
  # most numpy types. A new dict is built so the caller's metrics stay
  # untouched.
  metrics = {}
  for key, value in results[0].items():
    if isinstance(value, np.generic):
      value = value.item()
    elif isinstance(value, np.ndarray):
      value = value.tolist()
    metrics[key] = value

  client_obj = client.Client(project=project)
  bucket_name = '%s-mlengine' % project
  bucket_obj = bucket.Bucket(client_obj, bucket_name)
  # Keep the blob in a local variable; assigning it to `bucket_obj.blob`
  # would overwrite the bucket's own `blob` attribute on this instance.
  eval_blob = blob.Blob(job_name + '/eval_data.json', bucket_obj)
  eval_blob.upload_from_string(json.dumps(metrics),
                               content_type='application/json')
def _build_arg_parser():
  """Return the ArgumentParser describing every flag of the trainer."""
  parser = argparse.ArgumentParser()

  # Input Arguments
  parser.add_argument(
      '--train-file',
      help='GCS or local path to training data',
  )
  parser.add_argument(
      '--gcs-bucket',
      help='GCS bucket for training data.',
  )
  parser.add_argument(
      '--gcs-prefix',
      help='Training data path prefix inside GCS bucket.',
  )
  parser.add_argument(
      '--num-epochs',
      # The step limit flag of this script is --train-steps.
      help=('Maximum number of training data epochs on which to train. '
            'If both --train-steps and --num-epochs are specified, '
            'the training job will run for --train-steps or --num-epochs, '
            'whichever occurs first. If unspecified will run for '
            '--train-steps.'),
      type=int,
  )
  parser.add_argument(
      '--train-batch-size',
      help='Batch size for training steps',
      type=int,
      default=128,
  )
  parser.add_argument(
      '--eval-batch-size',
      help='Batch size for evaluation steps',
      type=int,
      default=128,
  )

  # Training arguments
  parser.add_argument(
      '--job-dir',
      help='GCS location to write checkpoints and export models',
      required=True,
  )

  # Logging arguments
  parser.add_argument(
      '--verbosity',
      choices=['DEBUG', 'ERROR', 'FATAL', 'INFO', 'WARN'],
      default='INFO',
  )

  # Experiment arguments
  parser.add_argument(
      '--eval-delay-secs',
      help='How long to wait before running first evaluation',
      default=10,
      type=int,
  )
  parser.add_argument(
      '--min-eval-frequency',
      help='Minimum number of training steps between evaluations',
      default=None,  # Use TensorFlow's default (currently, 1000)
      type=int,
  )
  parser.add_argument(
      '--train-steps',
      help=('Steps to run the training job for. If --num-epochs is not '
            'specified, this must be. Otherwise the training job will run '
            'indefinitely.'),
      type=int,
  )
  parser.add_argument(
      '--eval-steps',
      help='Number of steps to run evaluation for at each checkpoint',
      default=100,
      type=int,
  )
  parser.add_argument(
      '--trainer-type',
      help='Which trainer to use (spam or component)',
      choices=['spam', 'component'],
      required=True,
  )
  return parser


def main(argv=None):
  """Parse the flags, run the training job and store spam eval results.

  Args:
    argv: optional list of command-line arguments without the program name.
      When None, argparse reads them from sys.argv.
  """
  args = _build_arg_parser().parse_args(argv)
  tf.logging.set_verbosity(args.verbosity)

  # The serving function is looked up by trainer type ('JSON-spam' or
  # 'JSON-component').
  export_strategy = saved_model_export_utils.make_export_strategy(
      trainer.model.SERVING_FUNCTIONS['JSON-' + args.trainer_type],
      exports_to_keep=1,
      default_output_alternative_key=None,
  )

  # Run the training job
  # learn_runner pulls configuration information from environment
  # variables using tf.learn.RunConfig and uses this configuration
  # to conditionally execute Experiment, or param server code.
  eval_results = learn_runner.run(
      generate_experiment_fn(
          min_eval_frequency=args.min_eval_frequency,
          eval_delay_secs=args.eval_delay_secs,
          train_steps=args.train_steps,
          eval_steps=args.eval_steps,
          export_strategies=[export_strategy],
      ),
      run_config=run_config.RunConfig(model_dir=args.job_dir),
      # Every parsed flag becomes a hyperparameter of the same name.
      hparams=hparam.HParams(**vars(args))
  )

  # Store a json blob in GCS with the results of training job (AUC of
  # precision/recall, etc).
  if args.trainer_type == 'spam':
    store_eval(args.job_dir, eval_results)


if __name__ == '__main__':
  main()