# Copyright 2019 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
# Or at https://developers.google.com/open-source/licenses/bsd

from __future__ import absolute_import

import argparse
import json
import logging
import os

import tensorflow as tf
from tensorflow.estimator import RunConfig
from sklearn.model_selection import train_test_split

from trainer2 import dataset
from trainer2 import model
from trainer2 import top_words
from trainer2 import train_ml_helpers
from trainer2.train_ml_helpers import COMPONENT_FEATURES
from trainer2.train_ml_helpers import SPAM_FEATURE_HASHES

INPUT_TYPE_MAP = {
  'component': {'key': 'word_features', 'shape': (COMPONENT_FEATURES,)},
  'spam': {'key': 'word_hashes', 'shape': (SPAM_FEATURE_HASHES,)}
}
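
# Each element of the `features` argument passed to make_input_fn below is
# expected to be a dict keyed as above, e.g. (illustrative values only):
#   {'word_features': [0.0] * COMPONENT_FEATURES}   # component trainer
#   {'word_hashes': [0.0] * SPAM_FEATURE_HASHES}    # spam trainer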


def make_input_fn(trainer_type, features, targets,
                  num_epochs=None, shuffle=True, batch_size=128):
  """Generate input function for training and testing.

  Args:
    trainer_type: spam / component
    features: a list of feature dicts keyed and shaped as in INPUT_TYPE_MAP
    targets: a list of labels, the same length as features
    num_epochs: number of epochs to repeat the data; None repeats indefinitely
    shuffle: whether to shuffle examples before batching
    batch_size: dataset batch size

  Returns:
    input function to feed into TrainSpec and EvalSpec.
  """
  def _input_fn():
    def gen():
      """Generator function to format feature and target."""
      for feature, target in zip(features, targets):
        yield feature[INPUT_TYPE_MAP[trainer_type]['key']], target

    data = tf.data.Dataset.from_generator(
        gen, (tf.float64, tf.int32),
        output_shapes=(INPUT_TYPE_MAP[trainer_type]['shape'], ()))
    data = data.map(lambda x, y: ({INPUT_TYPE_MAP[trainer_type]['key']: x}, y))
    if shuffle:
      data = data.shuffle(buffer_size=batch_size * 10)
    data = data.repeat(num_epochs).batch(batch_size)
    return data

  return _input_fn
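
# Illustrative usage (variable names here are examples, not part of this
# module):
#   train_input_fn = make_input_fn('spam', features, targets, num_epochs=5)
#   batches = train_input_fn()  # tf.data.Dataset of ({'word_hashes': x}, y)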


def generate_json_input_fn(trainer_type):
  """Generate ServingInputReceiver function for testing.

  Args:
    trainer_type: spam / component

  Returns:
    ServingInputReceiver function to feed into exporter.
  """
  feature_spec = {
    INPUT_TYPE_MAP[trainer_type]['key']:
      tf.io.FixedLenFeature(INPUT_TYPE_MAP[trainer_type]['shape'], tf.float32)
  }
  return tf.estimator.export.build_parsing_serving_input_receiver_fn(
      feature_spec)
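
# The receiver returned above parses serialized tf.Example protos. A minimal
# sketch of building one request payload (illustrative only):
#   example = tf.train.Example(features=tf.train.Features(feature={
#       'word_hashes': tf.train.Feature(float_list=tf.train.FloatList(
#           value=[0.0] * SPAM_FEATURE_HASHES))}))
#   serialized = example.SerializeToString()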


def train_and_evaluate_model(config, hparams):
  """Runs the local training job given the provided command-line arguments.

  Args:
    config: RunConfig object.
    hparams: dictionary of hyperparameters parsed from the command line.
  """

  if hparams['train_file']:
    with open(hparams['train_file']) as f:
      if hparams['trainer_type'] == 'spam':
        contents, labels, _ = train_ml_helpers.spam_from_file(f)
      else:
        contents, labels = train_ml_helpers.component_from_file(f)
  else:
    contents, labels = dataset.fetch_training_data(
        hparams['gcs_bucket'], hparams['gcs_prefix'], hparams['trainer_type'])

  logging.info('Training data received. Len: %d', len(contents))

  # Generate features and targets from the extracted contents and labels.
  if hparams['trainer_type'] == 'spam':
    features, targets = (
        train_ml_helpers.transform_spam_csv_to_features(contents, labels))
  else:
    #top_list = top_words.make_top_words_list(contents, hparams['job_dir'])
    top_list = top_words.parse_words_from_content(contents)
    features, targets, index_to_component = (
        train_ml_helpers.transform_component_csv_to_features(
            contents, labels, top_list))

  logging.info('Features generated')

  # Split the data into training and testing sets.
  features_train, features_test, targets_train, targets_test = train_test_split(
      features, targets, test_size=0.2, random_state=42)

  # Generate TrainSpec and EvalSpec for training and evaluation.
  estimator = model.build_estimator(config=config,
                                    job_dir=hparams['job_dir'],
                                    trainer_type=hparams['trainer_type'],
                                    class_count=len(set(labels)))
  exporter = tf.estimator.LatestExporter(
      name='saved_model',
      serving_input_receiver_fn=generate_json_input_fn(hparams['trainer_type']))

  train_spec = tf.estimator.TrainSpec(
      input_fn=make_input_fn(hparams['trainer_type'],
          features_train, targets_train, num_epochs=hparams['num_epochs'],
          batch_size=hparams['train_batch_size']),
      max_steps=hparams['train_steps'])
  eval_spec = tf.estimator.EvalSpec(
      input_fn=make_input_fn(hparams['trainer_type'],
          features_test, targets_test, shuffle=False,
          batch_size=hparams['eval_batch_size']),
      exporters=exporter, steps=hparams['eval_steps'])

  if hparams['trainer_type'] == 'component':
    store_component_conversion(hparams['job_dir'], index_to_component)

  result = tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
  logging.info(result)

  # Export a final SavedModel in addition to the models saved by
  # LatestExporter at each evaluation during training.
  parsing_spec = tf.feature_column.make_parse_example_spec(
      model.INPUT_COLUMNS[hparams['trainer_type']])
  serving_input_fn = (
      tf.estimator.export.build_parsing_serving_input_receiver_fn(parsing_spec))
  estimator.export_saved_model(hparams['job_dir'], serving_input_fn)


def store_component_conversion(job_dir, data):
  """Stores the index-to-component mapping used to decode predictions."""
  logging.info('job_dir: %s', job_dir)

  # Store the component conversion mapping locally, creating any missing
  # intermediate directories.
  if not os.path.exists(job_dir):
    os.makedirs(job_dir)
  with open(job_dir + '/component_index.json', 'w') as f:
    f.write(json.dumps(data))
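
# component_index.json, written above, stores the index_to_component mapping
# produced by transform_component_csv_to_features, so that class indices
# predicted by the exported model can be mapped back to component names.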


if __name__ == '__main__':
  parser = argparse.ArgumentParser()

  # Input Arguments
  parser.add_argument(
      '--train-file',
      help='GCS or local path to training data',
  )
  parser.add_argument(
      '--gcs-bucket',
      help='GCS bucket for training data.',
  )
  parser.add_argument(
      '--gcs-prefix',
      help='Training data path prefix inside GCS bucket.',
  )
  parser.add_argument(
      '--num-epochs',
      help="""\
      Maximum number of epochs over the training data.
      If both --train-steps and --num-epochs are specified,
      the training job runs for --num-epochs.
      If unspecified, it runs for --train-steps.\
      """,
      type=int,
  )
  parser.add_argument(
      '--train-batch-size',
      help='Batch size for training steps',
      type=int,
      default=128
  )
  parser.add_argument(
      '--eval-batch-size',
      help='Batch size for evaluation steps',
      type=int,
      default=128
  )

  # Training arguments
  parser.add_argument(
      '--job-dir',
      help='GCS location to write checkpoints and export models',
      required=True
  )

  # Logging arguments
  parser.add_argument(
      '--verbosity',
      choices=[
          'DEBUG',
          'ERROR',
          'CRITICAL',
          'INFO',
          'WARNING'
      ],
      default='INFO',
  )

  # Input function arguments
  parser.add_argument(
      '--train-steps',
      help="""\
      Number of steps to run the training job for. Required; without it
      (and without --num-epochs) the training job would run indefinitely.\
      """,
      type=int,
      required=True
  )
  parser.add_argument(
      '--eval-steps',
      help='Number of steps to run evaluation for at each checkpoint',
      default=100,
      type=int
  )
  parser.add_argument(
      '--trainer-type',
      help='Which trainer to use (spam or component)',
      choices=['spam', 'component'],
      required=True
  )

  args = parser.parse_args()

  logger = logging.getLogger()
  logger.setLevel(getattr(logging, args.verbosity))

  # Default --num-epochs to --train-steps when it is not specified.
  if not args.num_epochs:
    args.num_epochs = args.train_steps

  # Set C++ Graph Execution level verbosity; TF_CPP_MIN_LOG_LEVEL expects a
  # small integer, so use floor division.
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(
      getattr(logging, args.verbosity) // 10)

  # Run the training job.
  train_and_evaluate_model(
      config=RunConfig(model_dir=args.job_dir),
      hparams=vars(args))
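
# Illustrative local invocation, assuming this module lives at
# trainer2/task.py (paths and values are examples only):
#   python -m trainer2.task --train-file ./training_data.csv \
#     --trainer-type spam --train-steps 1000 --eval-steps 100 \
#     --job-dir /tmp/spam_model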