blob: 3b627a91d65b499692c9ddd6feea06cc69e643d8 [file] [log] [blame]
# Copyright 2018 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd
5
6from __future__ import absolute_import
7from __future__ import division
8from __future__ import print_function
9
10import numpy as np
11import tensorflow as tf
12
13from trainer.ml_helpers import COMPONENT_FEATURES
14from trainer.ml_helpers import SPAM_FEATURE_HASHES
15
# Feature columns consumed by each trainer type, keyed by trainer_type.
# Important: we assume these columns mirror the output of GenerateFeaturesRaw.
INPUT_COLUMNS = {
    'component': [
        tf.feature_column.numeric_column(
            key='word_features', shape=(COMPONENT_FEATURES,)),
    ],
    'spam': [
        tf.feature_column.numeric_column(
            key='word_hashes', shape=(SPAM_FEATURE_HASHES,)),
    ],
}
27
28
def build_estimator(config, trainer_type, class_count):
  """Returns a DNN classifier configured for the given trainer type.

  Args:
    config: tf.contrib.learn.RunConfig defining the runtime environment for the
      estimator (including model_dir).
    trainer_type: key into INPUT_COLUMNS ('component' or 'spam') selecting the
      feature columns the model reads.
    class_count: number of target classes; passed through as n_classes.

  Returns:
    A tf.contrib.learn.DNNClassifier with hidden layers of 1024, 512 and 256
    units, optimized with Adam (learning rate 0.001).

  Raises:
    KeyError: if trainer_type is not a key of INPUT_COLUMNS.
  """
  return tf.contrib.learn.DNNClassifier(
      config=config,
      feature_columns=INPUT_COLUMNS[trainer_type],
      hidden_units=[1024, 512, 256],
      optimizer=tf.train.AdamOptimizer(learning_rate=0.001,
                                       beta1=0.9,
                                       beta2=0.999,
                                       epsilon=1e-08,
                                       use_locking=False,
                                       name='Adam'),
      n_classes=class_count)
50
51
def feature_list_to_dict(X, trainer_type):
  """Converts a list of per-instance feature dicts into one dict of arrays.

  The result maps each feature name to a numpy array holding that feature's
  values from every instance, in the order the instances appear in X.
  Features are matched by name (dict key), not by position, so the key order
  inside each instance does not matter.

  Args:
    X: an iterable of feature dicts. Every key must be the name of a column in
      INPUT_COLUMNS[trainer_type].
    trainer_type: key into INPUT_COLUMNS ('component' or 'spam') selecting
      which feature columns are expected.

  Returns:
    A dict mapping each column name of INPUT_COLUMNS[trainer_type] to
    np.array(values), where values lists that feature's value from each
    instance that supplied it. Its first dimension is therefore len(X) when
    every instance has every feature; vector-valued features add a trailing
    dimension, e.g. (len(X), feature_dim).

  Raises:
    KeyError: if trainer_type is unknown, or an instance contains a key that
      is not a column of INPUT_COLUMNS[trainer_type].
  """
  column_names = [column.name for column in INPUT_COLUMNS[trainer_type]]
  values_by_name = {name: [] for name in column_names}

  for instance in X:
    for key, value in instance.items():
      # Unknown keys raise KeyError instead of being silently dropped.
      values_by_name[key].append(value)

  return {name: np.array(values) for name, values in values_by_name.items()}
77
78
def generate_json_serving_input_fn(trainer_type):
  """Returns a serving input function for the given trainer type.

  Args:
    trainer_type: key into INPUT_COLUMNS ('component' or 'spam'). The lookup
      happens when the returned function is called, not here.

  Returns:
    A zero-argument function that builds one float32 placeholder per feature
    column and wraps them in a tf.contrib.learn.InputFnOps.
  """
  def json_serving_input_fn():
    """Build the serving inputs.

    Each placeholder is named '<column name>_placeholder' and has a leading
    batch dimension of None. A column of shape (1,) gets a rank-1 [None]
    placeholder; any other column shape S gets [None] + S, so
    multi-dimensional columns keep all of their dimensions.

    Returns:
      An InputFnOps built from the placeholder dict (passed as both the
      features and the default inputs) with labels=None.

    Raises:
      KeyError: if trainer_type is not a key of INPUT_COLUMNS.
    """
    features_placeholders = {}
    for column in INPUT_COLUMNS[trainer_type]:
      name = '%s_placeholder' % column.name
      column_shape = tuple(column.shape)

      # Scalar features are fed as a flat batch vector; every other column
      # keeps its full shape behind the batch dimension rather than only its
      # first dimension.
      if column_shape == (1,):
        shape = [None]
      else:
        shape = [None] + list(column_shape)

      placeholder = tf.placeholder(tf.float32, shape, name=name)
      features_placeholders[column.name] = placeholder

    labels = None  # Unknown at serving time
    return tf.contrib.learn.InputFnOps(features_placeholders, labels,
                                       features_placeholders)

  return json_serving_input_fn
104
105
# Serving input functions keyed by export name; each one serves the
# INPUT_COLUMNS entry of the paired trainer_type.
SERVING_FUNCTIONS = dict(
    (export_name, generate_json_serving_input_fn(trainer_type))
    for export_name, trainer_type in (('JSON-component', 'component'),
                                      ('JSON-spam', 'spam')))