Project import generated by Copybara.

GitOrigin-RevId: d9e9e3fb4e31372ec1fb43b178994ca78fa8fe70
diff --git a/tools/ml/trainer2/model.py b/tools/ml/trainer2/model.py
new file mode 100644
index 0000000..823d0d1
--- /dev/null
+++ b/tools/ml/trainer2/model.py
@@ -0,0 +1,45 @@
+# Copyright 2019 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+# Or at https://developers.google.com/open-source/licenses/bsd
+
+from __future__ import absolute_import
+
+import tensorflow as tf
+
+from trainer2.train_ml_helpers import COMPONENT_FEATURES
+from trainer2.train_ml_helpers import SPAM_FEATURE_HASHES
+
+# Feature columns handed to the estimator, keyed by trainer type name. Each
+# entry is a list holding one numeric column whose width is the matching
+# constant imported from train_ml_helpers.
+# Important: we assume each list mirrors the output of GenerateFeaturesRaw
+# (which is not defined in this file).
+INPUT_COLUMNS = {
+    'component': [tf.feature_column.numeric_column(
+        key='word_features', shape=(COMPONENT_FEATURES,))],
+    'spam': [tf.feature_column.numeric_column(
+        key='word_hashes', shape=(SPAM_FEATURE_HASHES,))],
+}
+
+def build_estimator(config, job_dir, trainer_type, class_count):
+  """Builds the DNN classifier used for the requested trainer type.
+
+  Args:
+    config: run configuration, handed to the estimator as `config`.
+    job_dir: directory handed to the estimator as `model_dir`.
+    trainer_type: key into INPUT_COLUMNS ('component' or 'spam').
+    class_count: number of target classes, handed on as `n_classes`.
+  Returns:
+    A tf.estimator.DNNClassifier with hidden layers of 1024, 512 and 256 units.
+  """
+  columns = INPUT_COLUMNS[trainer_type]
+  adam = tf.keras.optimizers.Adam(
+      learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, name='Adam')
+  return tf.estimator.DNNClassifier(
+      config=config,
+      model_dir=job_dir,
+      feature_columns=columns,
+      hidden_units=[1024, 512, 256],
+      optimizer=adam,
+      n_classes=class_count)