# Copyright 2019 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
# Or at https://developers.google.com/open-source/licenses/bsd

from __future__ import absolute_import

import argparse
import json
import logging
import os

import tensorflow as tf
from tensorflow.estimator import RunConfig
from sklearn.model_selection import train_test_split

from trainer2 import dataset
from trainer2 import model
from trainer2 import top_words
from trainer2 import train_ml_helpers
from trainer2.train_ml_helpers import COMPONENT_FEATURES
from trainer2.train_ml_helpers import SPAM_FEATURE_HASHES

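# For each trainer type, the feature-dict key the model reads its input from
# and the fixed length of that feature vector.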
INPUT_TYPE_MAP = {
  'component': {'key': 'word_features', 'shape': (COMPONENT_FEATURES,)},
  'spam': {'key': 'word_hashes', 'shape': (SPAM_FEATURE_HASHES,)}
}


def make_input_fn(trainer_type, features, targets,
                  num_epochs=None, shuffle=True, batch_size=128):
  """Generate an input function for training and evaluation.

  Args:
    trainer_type: 'spam' or 'component'
    features: an array of feature dicts, keyed as described by INPUT_TYPE_MAP
    targets: an array of labels with the same length as features
    num_epochs: number of training epochs
    shuffle: whether to shuffle the dataset before batching
    batch_size: dataset batch size

  Returns:
    An input function to feed into TrainSpec and EvalSpec.
  """
  def _input_fn():
    def gen():
      """Generator yielding (feature vector, target) pairs."""
      for feature, target in zip(features, targets):
        yield feature[INPUT_TYPE_MAP[trainer_type]['key']], target

    data = tf.data.Dataset.from_generator(
        gen, (tf.float64, tf.int32),
        output_shapes=(INPUT_TYPE_MAP[trainer_type]['shape'], ()))
    data = data.map(lambda x, y: ({INPUT_TYPE_MAP[trainer_type]['key']: x}, y))
    if shuffle:
      data = data.shuffle(buffer_size=batch_size * 10)
    data = data.repeat(num_epochs).batch(batch_size)
    return data

  return _input_fn

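# Illustrative usage (argument values here are hypothetical):
#   train_input_fn = make_input_fn(
#       'spam', features_train, targets_train, num_epochs=5, batch_size=64)
#   ds = train_input_fn()  # tf.data.Dataset of ({'word_hashes': ...}, labels)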

def generate_json_input_fn(trainer_type):
  """Generate a ServingInputReceiver function for model export.

  Args:
    trainer_type: 'spam' or 'component'

  Returns:
    ServingInputReceiver function to feed into the exporter.
  """
  feature_spec = {
      INPUT_TYPE_MAP[trainer_type]['key']:
          tf.io.FixedLenFeature(INPUT_TYPE_MAP[trainer_type]['shape'],
                                tf.float32)
  }
  return tf.estimator.export.build_parsing_serving_input_receiver_fn(
      feature_spec)

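# Note: build_parsing_serving_input_receiver_fn produces a receiver that
# parses serialized tf.Example protos, so clients of the exported model are
# expected to send the word_features / word_hashes vector as a float feature
# inside a tf.Example.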

def train_and_evaluate_model(config, hparams):
  """Runs the local training job given the provided command line arguments.

  Args:
    config: RunConfig object
    hparams: dictionary of command line arguments
  """

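  # Load training data either from an explicit file (--train-file) or by
  # fetching it from the configured GCS bucket and prefix.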
  if hparams['train_file']:
    with open(hparams['train_file']) as f:
      if hparams['trainer_type'] == 'spam':
        contents, labels, _ = train_ml_helpers.spam_from_file(f)
      else:
        contents, labels = train_ml_helpers.component_from_file(f)
  else:
    contents, labels = dataset.fetch_training_data(
        hparams['gcs_bucket'], hparams['gcs_prefix'], hparams['trainer_type'])

  logger.info('Training data received. Len: %d' % len(contents))

  # Generate features and targets from extracted contents and labels.
  if hparams['trainer_type'] == 'spam':
    features, targets = train_ml_helpers.transform_spam_csv_to_features(
        contents, labels)
  else:
    #top_list = top_words.make_top_words_list(contents, hparams['job_dir'])
    top_list = top_words.parse_words_from_content(contents)
    features, targets, index_to_component = (
        train_ml_helpers.transform_component_csv_to_features(
            contents, labels, top_list))

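  # `features` now holds fixed-length vectors keyed per INPUT_TYPE_MAP (word
  # hashes for spam, top-word features for component), and `targets` holds the
  # corresponding numeric labels.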
  logger.info('Features generated')

  # Split training and testing set.
  features_train, features_test, targets_train, targets_test = train_test_split(
      features, targets, test_size=0.2, random_state=42)

  # Generate TrainSpec and EvalSpec for train and evaluate.
  estimator = model.build_estimator(config=config,
      job_dir=hparams['job_dir'],
      trainer_type=hparams['trainer_type'],
      class_count=len(set(labels)))
  exporter = tf.estimator.LatestExporter(name='saved_model',
      serving_input_receiver_fn=generate_json_input_fn(
          hparams['trainer_type']))

  train_spec = tf.estimator.TrainSpec(
      input_fn=make_input_fn(hparams['trainer_type'],
          features_train, targets_train, num_epochs=hparams['num_epochs'],
          batch_size=hparams['train_batch_size']),
      max_steps=hparams['train_steps'])
  eval_spec = tf.estimator.EvalSpec(
      input_fn=make_input_fn(hparams['trainer_type'],
          features_test, targets_test, shuffle=False,
          batch_size=hparams['eval_batch_size']),
      exporters=exporter, steps=hparams['eval_steps'])

  if hparams['trainer_type'] == 'component':
    store_component_conversion(hparams['job_dir'], index_to_component)

  result = tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
  logger.info(result)

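  # Besides the SavedModels written by the LatestExporter after each
  # evaluation, export a final SavedModel (with a tf.Example-parsing serving
  # signature) directly under job_dir.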
  parsing_spec = tf.feature_column.make_parse_example_spec(
      model.INPUT_COLUMNS[hparams['trainer_type']])
  serving_input_fn = (
      tf.estimator.export.build_parsing_serving_input_receiver_fn(parsing_spec))
  estimator.export_saved_model(hparams['job_dir'], serving_input_fn)


def store_component_conversion(job_dir, data):
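  """Stores the index-to-component mapping as JSON under the job directory.

  Args:
    job_dir: directory the training job writes its outputs to.
    data: dict mapping model class indices to components (the
      index_to_component mapping built during feature generation).
  """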
  logger.info('job_dir: %s' % job_dir)

  # Store component conversion locally, creating the job directory
  # (and any missing parent directories) first if needed.
  if not os.path.exists(job_dir):
    os.makedirs(job_dir)
  with open(job_dir + '/component_index.json', 'w') as f:
    f.write(json.dumps(data))


if __name__ == '__main__':
  parser = argparse.ArgumentParser()

  # Input Arguments
  parser.add_argument(
      '--train-file',
      help='GCS or local path to training data',
  )
  parser.add_argument(
      '--gcs-bucket',
      help='GCS bucket for training data.',
  )
  parser.add_argument(
      '--gcs-prefix',
      help='Training data path prefix inside GCS bucket.',
  )
  parser.add_argument(
      '--num-epochs',
      help="""\
      Maximum number of training data epochs on which to train.
      If both --train-steps and --num-epochs are specified,
      the training job will run for --num-epochs.
      If unspecified, the job will run for --train-steps.\
      """,
      type=int,
  )
  parser.add_argument(
      '--train-batch-size',
      help='Batch size for training steps',
      type=int,
      default=128
  )
  parser.add_argument(
      '--eval-batch-size',
      help='Batch size for evaluation steps',
      type=int,
      default=128
  )

  # Training arguments
  parser.add_argument(
      '--job-dir',
      help='GCS location to write checkpoints and export models',
      required=True
  )

  # Logging arguments
  parser.add_argument(
      '--verbosity',
      choices=[
          'DEBUG',
          'ERROR',
          'CRITICAL',
          'INFO',
          'WARNING'
      ],
      default='INFO',
  )

  # Input function arguments
  parser.add_argument(
      '--train-steps',
      help="""\
      Steps to run the training job for. If --num-epochs is not specified,
      this must be; otherwise the training job will run indefinitely.\
      """,
      type=int,
      required=True
  )
  parser.add_argument(
      '--eval-steps',
      help='Number of steps to run evaluation for at each checkpoint',
      default=100,
      type=int
  )
  parser.add_argument(
      '--trainer-type',
      help='Which trainer to use (spam or component)',
      choices=['spam', 'component'],
      required=True
  )

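  # Example invocation (module path and values are illustrative only):
  #   python -m trainer2.task --trainer-type spam --train-file spam.csv \
  #       --train-steps 1000 --job-dir /tmp/spam_model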
  args = parser.parse_args()

  logger = logging.getLogger()
  logger.setLevel(getattr(logging, args.verbosity))

  if not args.num_epochs:
    args.num_epochs = args.train_steps

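  # Python logging levels are DEBUG=10, INFO=20, WARNING=30, ERROR=40,
  # CRITICAL=50; integer-dividing by 10 maps them onto the small integer
  # values TF_CPP_MIN_LOG_LEVEL expects.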
  # Set C++ Graph Execution level verbosity.
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(
      getattr(logging, args.verbosity) // 10)

  # Run the training job.
  train_and_evaluate_model(
      config=RunConfig(model_dir=args.job_dir),
      hparams=vars(args))