-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtf_model.py
More file actions
99 lines (75 loc) · 3.01 KB
/
tf_model.py
File metadata and controls
99 lines (75 loc) · 3.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Jun 14 14:24:39 2019
@author: zjy
"""
import tensorflow as tf
import numpy as np
import os
# Command-line flags (TF1 / absl-style). All three are strings.
tf.app.flags.DEFINE_string("model_dir", "model_dir",
                           help="dir where model checkpoints are saved")
# Fixed help text: the original string ("Use half floats instead of full
# floats if True.") was copy-pasted from an unrelated boolean flag.
tf.app.flags.DEFINE_string('export_dir', "export_dir",
                           help="dir where the exported SavedModel is written")
tf.app.flags.DEFINE_string('model_version', "1",
                           help="the export model version")
# TODO: model version trail
FLAGS = tf.app.flags.FLAGS
class Model:
    """Thin wrapper exposing an addition layer behind a feature-dict API.

    `mode` and `params` are stored but unused by the forward pass; they
    mirror the estimator `model_fn` signature.
    """

    def __init__(self, mode, params):
        """Remember the run mode/params and build the single layer."""
        self.mode = mode
        self.params = params
        self.addition_layer = AdditionLayer()

    def __call__(self, features):
        """Apply the layer to the 'x' and 'y' entries of `features`."""
        x = features["x"]
        y = features["y"]
        return self.addition_layer(x, y)
class AdditionLayer(tf.layers.Layer):
    """A TF layer whose forward pass is plain elementwise addition."""

    def __init__(self):
        """No weights to create; defer entirely to the base layer."""
        super(AdditionLayer, self).__init__()

    def call(self, x, y, *args, **kwargs):
        """Return the elementwise sum of `x` and `y` (extra args ignored)."""
        return x + y
def model_fn(features, labels, mode, params):
    """Estimator model_fn: predict the sum of features['x'] and features['y'].

    Args:
        features: dict with float tensors under keys 'x' and 'y'.
        labels: unused (the dummy "training" has a constant loss).
        mode: a tf.estimator.ModeKeys value.
        params: hyperparameter dict forwarded by the Estimator.

    Returns:
        tf.estimator.EstimatorSpec for the given mode.
    """
    # BUGFIX: the original called Model(features, mode), passing `features`
    # where Model.__init__ expects `mode` and `mode` where it expects
    # `params`. Pass the arguments the constructor actually declares.
    model = Model(mode, params)
    results = model(features)
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.PREDICT, predictions={"sum": results})
    else:
        global_step = tf.train.get_global_step()
        # BUGFIX: loss must be a float scalar; tf.constant(0) is int32.
        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=tf.constant(0.0),
            train_op=tf.assign_add(global_step, 1))
def dummy_input_fn():
    """Build a tiny random dataset of ({'x', 'y'} feature dicts, label) pairs.

    The label is just the 'y' sample again; model_fn ignores labels, so any
    tensor works here.
    """
    xs = np.random.rand(10)
    ys = np.random.rand(10)
    dataset = tf.data.Dataset.from_tensor_slices((xs, ys))

    def to_features(x, y):
        # Map the (x, y) tuple into the dict shape model_fn expects.
        return {"x": x, "y": y}, y

    return dataset.map(to_features)
def export_input_fn():
    """Serving input fn: raw float placeholders 'x' and 'y' -> feature dict."""

    def preprocess(tensor):
        # Identity on purpose: this hook only marks where real serving-time
        # preprocessing (raw receiver tensors -> model features) would go.
        return tensor

    receiver_tensors = {}
    features = {}
    for name in ("x", "y"):
        plh = tf.placeholder(dtype=tf.float32, shape=(None,), name=name)
        receiver_tensors[name] = plh
        features[name] = preprocess(plh)
    return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
def export(*args, **kwargs):
    """Train one dummy step, then export the model as a SavedModel.

    Signature accepts *args/**kwargs so it can be used as tf.app.run's main.
    """
    estimator = tf.estimator.Estimator(model_fn=model_fn,
                                       model_dir=FLAGS.model_dir,
                                       params=FLAGS.flag_values_dict(),
                                       config=None)
    # A single step is enough to materialize a checkpoint to export from.
    estimator.train(input_fn=dummy_input_fn, steps=1)
    # Version the export under <export_dir>/<model_version>.
    target_dir = os.path.join(FLAGS.export_dir, FLAGS.model_version)
    estimator.export_savedmodel(target_dir, export_input_fn)
if __name__ == '__main__':
    # Surface Estimator training/export progress on stderr.
    tf.logging.set_verbosity(tf.logging.INFO)
    # tf.app.run parses the flags defined above, then calls export(argv).
    tf.app.run(main=export)