Project import generated by Copybara.
GitOrigin-RevId: d9e9e3fb4e31372ec1fb43b178994ca78fa8fe70
diff --git a/tools/ml/trainer/model.py b/tools/ml/trainer/model.py
new file mode 100644
index 0000000..3b627a9
--- /dev/null
+++ b/tools/ml/trainer/model.py
@@ -0,0 +1,134 @@
+# Copyright 2018 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from trainer.ml_helpers import COMPONENT_FEATURES
+from trainer.ml_helpers import SPAM_FEATURE_HASHES
+
+# Important: we assume this list mirrors the output of GenerateFeaturesRaw.
+INPUT_COLUMNS = {'component': [
+ tf.feature_column.numeric_column(
+ key='word_features',
+ shape=(COMPONENT_FEATURES,)),
+ ],
+ 'spam': [
+ tf.feature_column.numeric_column(
+ key='word_hashes',
+ shape=(SPAM_FEATURE_HASHES,)),
+ ]}
+
+
+def build_estimator(config, trainer_type, class_count):
+  """Returns a DNNClassifier estimator for the given trainer type.
+
+  Args:
+    config: tf.contrib.learn.RunConfig defining the runtime environment for the
+      estimator (including model_dir).
+    trainer_type: 'component' or 'spam'; selects the feature columns from
+      INPUT_COLUMNS.
+    class_count: the number of output classes for the classifier.
+  Returns:
+    A DNNClassifier.
+  """
+ return tf.contrib.learn.DNNClassifier(
+ config=config,
+ feature_columns=(INPUT_COLUMNS[trainer_type]),
+ hidden_units=[1024, 512, 256],
+ optimizer=tf.train.AdamOptimizer(learning_rate=0.001,
+ beta1=0.9,
+ beta2=0.999,
+ epsilon=1e-08,
+ use_locking=False,
+ name='Adam'),
+ n_classes=class_count
+ )
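+
+# A minimal usage sketch, not part of this module's API; the model_dir path
+# and class_count value below are illustrative assumptions:
+#
+#   config = tf.contrib.learn.RunConfig(model_dir='/tmp/model_dir')
+#   estimator = build_estimator(config, trainer_type='spam', class_count=2)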
+
+
+def feature_list_to_dict(X, trainer_type):
+  """Converts an array of feature dicts into one dict of
+  {feature_name: [feature_values]}.
+
+  Important: this assumes the ordering of X and INPUT_COLUMNS is the same.
+
+  Args:
+    X: an array of feature dicts.
+    trainer_type: 'component' or 'spam'; selects the feature columns from
+      INPUT_COLUMNS.
+  Returns:
+    A dictionary mapping each feature name to a numpy array whose first
+    dimension has length len(X).
+  """
+ feature_dict = {}
+
+ for feature_column in INPUT_COLUMNS[trainer_type]:
+ feature_dict[feature_column.name] = []
+
+ for instance in X:
+ for key in instance.keys():
+ feature_dict[key].append(instance[key])
+
+ for key in [f.name for f in INPUT_COLUMNS[trainer_type]]:
+ feature_dict[key] = np.array(feature_dict[key])
+
+ return feature_dict
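+
+# Illustrative sketch of the transformation, using hypothetical 2-dimensional
+# feature vectors rather than the real SPAM_FEATURE_HASHES width:
+#
+#   X = [{'word_hashes': [0.0, 1.0]}, {'word_hashes': [1.0, 0.0]}]
+#   feature_list_to_dict(X, 'spam')
+#   # => {'word_hashes': np.array([[0.0, 1.0], [1.0, 0.0]])}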
+
+
+def generate_json_serving_input_fn(trainer_type):
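+  """Returns a serving input fn for trainer_type ('component' or 'spam')."""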
+ def json_serving_input_fn():
+ """Build the serving inputs.
+
+ Returns:
+ An InputFnOps containing features with placeholders.
+ """
+ features_placeholders = {}
+ for column in INPUT_COLUMNS[trainer_type]:
+ name = '%s_placeholder' % column.name
+
+ # Special case non-scalar features.
+ if column.shape[0] > 1:
+ shape = [None, column.shape[0]]
+ else:
+ shape = [None]
+
+ placeholder = tf.placeholder(tf.float32, shape, name=name)
+ features_placeholders[column.name] = placeholder
+
+ labels = None # Unknown at serving time
+ return tf.contrib.learn.InputFnOps(features_placeholders, labels,
+ features_placeholders)
+
+ return json_serving_input_fn
+
+
+SERVING_FUNCTIONS = {
+ 'JSON-component': generate_json_serving_input_fn('component'),
+ 'JSON-spam': generate_json_serving_input_fn('spam')
+}
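+
+# Sketch of how these entries are expected to be used (the export wiring lives
+# outside this module):
+#
+#   serving_input_fn = SERVING_FUNCTIONS['JSON-spam']
+#   input_fn_ops = serving_input_fn()  # InputFnOps of feature placeholders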