-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathregression_savedmodel.py
More file actions
52 lines (41 loc) · 1.59 KB
/
regression_savedmodel.py
File metadata and controls
52 lines (41 loc) · 1.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import tensorflow as tf
class LinearRegresstion(tf.Module):
    """Linear model ``w * x + b`` bundled with its own SGD training step.

    The module owns two shape-``[1]`` variables, ``w`` and ``b``, and an SGD
    optimizer. Every public method is a ``tf.function`` so it can be traced
    into a concrete function.
    """

    def __init__(self, name=None):
        """Create the model parameters and the optimizer.

        Args:
            name: Optional module name, forwarded to ``tf.Module``.
        """
        super().__init__(name=name)
        # Slope is initialised randomly between -1.0 and 1.0; the intercept
        # starts at zero.
        self.w = tf.Variable(tf.random.uniform([1], -1.0, 1.0), name="w")
        self.b = tf.Variable(tf.zeros([1]), name="b")
        # Plain SGD with a learning rate of 0.5.
        self.optimizer = tf.optimizers.SGD(0.5)

    @tf.function
    def __call__(self, x):
        """Return the model prediction ``w * x + b`` for input ``x``."""
        return self.w * x + self.b

    @tf.function
    def get_w(self):
        """Return the slope variable under the key ``"output"``."""
        return {"output": self.w}

    @tf.function
    def get_b(self):
        """Return the intercept variable under the key ``"output"``."""
        return {"output": self.b}

    @tf.function
    def train(self, x, y):
        """Run one SGD step on the mean squared error and report the loss.

        Args:
            x: Model inputs.
            y: Target values compared against the prediction for ``x``.

        Returns:
            A dict with the single key ``"train"`` holding the loss that was
            computed before the parameters were updated.
        """
        with tf.GradientTape() as tape:
            residual = self(x) - y
            loss = tf.reduce_mean(tf.square(residual))
        # Fetch the variable collection once so gradients and variables are
        # paired from the same sequence.
        variables = self.trainable_variables
        gradients = tape.gradient(loss, variables)
        self.optimizer.apply_gradients(zip(gradients, variables))
        return {"train": loss}
# Build the model and trace each exported method into a concrete function.
model = LinearRegresstion()

# Both inputs of the training step are float32 vectors of unspecified length.
x = tf.TensorSpec(shape=[None], dtype=tf.float32, name="x")
y = tf.TensorSpec(shape=[None], dtype=tf.float32, name="y")

train = model.train.get_concrete_function(x, y)
w = model.get_w.get_concrete_function()
b = model.get_b.get_concrete_function()

# Write the SavedModel with one named signature per traced function.
directory = "examples/regression_savedmodel"
signatures = {
    "train": train,
    "w": w,
    "b": b,
}
tf.saved_model.save(obj=model, export_dir=directory, signatures=signatures)

# Export the graph of each traced function for inspection in TensorBoard.
logdir = "logs/regression_savedmodel"
writer = tf.summary.create_file_writer(logdir=logdir)
with writer.as_default():
    for concrete_fn in (train, w, b):
        tf.summary.graph(concrete_fn.graph)