blob: 823d0d120c76d1c978dd42fca8542fd910a72cbe [file] [log] [blame]
Copybara854996b2021-09-07 19:36:02 +00001# Copyright 2019 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4# Or at https://developers.google.com/open-source/licenses/bsd
5
6from __future__ import absolute_import
7
8import tensorflow as tf
9
10from trainer2.train_ml_helpers import COMPONENT_FEATURES
11from trainer2.train_ml_helpers import SPAM_FEATURE_HASHES
12
# Feature columns for each supported trainer type. The keys of this mapping
# are the trainer types that build_estimator() uses to pick its columns.
# Important: we assume this list mirrors the output of GenerateFeaturesRaw.
INPUT_COLUMNS = {
    # Component prediction: one dense numeric vector with COMPONENT_FEATURES
    # entries, read from the 'word_features' input key.
    'component': [
        tf.feature_column.numeric_column(
            key='word_features',
            shape=(COMPONENT_FEATURES,),
        ),
    ],
    # Spam detection: one dense numeric vector with SPAM_FEATURE_HASHES
    # entries, read from the 'word_hashes' input key.
    'spam': [
        tf.feature_column.numeric_column(
            key='word_hashes',
            shape=(SPAM_FEATURE_HASHES,),
        ),
    ],
}
24
def build_estimator(config, job_dir, trainer_type, class_count):
  """Builds the DNN classifier used for the given trainer type.

  Args:
    config: Run configuration object forwarded to the estimator as its
      `config` argument.
    job_dir: Directory forwarded to the estimator as `model_dir`.
    trainer_type: Key into INPUT_COLUMNS selecting the feature columns
      ('component' or 'spam'). An unknown key raises KeyError.
    class_count: Number of label classes, forwarded as `n_classes`.

  Returns:
    A tf.estimator.DNNClassifier with three hidden layers (1024, 512 and 256
    units) trained with the Adam optimizer.
  """
  # Resolve the feature columns first so an unsupported trainer type fails
  # before any optimizer or estimator object is created.
  feature_columns = INPUT_COLUMNS[trainer_type]

  adam_optimizer = tf.keras.optimizers.Adam(
      learning_rate=0.001,
      beta_1=0.9,
      beta_2=0.999,
      epsilon=1e-08,
      name='Adam',
  )

  classifier = tf.estimator.DNNClassifier(
      config=config,
      model_dir=job_dir,
      feature_columns=feature_columns,
      hidden_units=[1024, 512, 256],
      optimizer=adam_optimizer,
      n_classes=class_count,
  )
  return classifier