Skip to content

Commit 39def9f

Browse files
authored
Merge pull request #312 from dskkato/ex/reg_savedmodel_tf25
Update regression_savedmodel example for tf2.5
2 parents 38c48e2 + 0682aef commit 39def9f

File tree

6 files changed

+65
-48
lines changed

6 files changed

+65
-48
lines changed

examples/regression_savedmodel.rs

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,31 @@ fn main() -> Result<(), Box<dyn Error>> {
4646

4747
// Load the saved model exported by regression_savedmodel.py.
4848
let mut graph = Graph::new();
49-
let session = SavedModelBundle::load(
50-
&SessionOptions::new(),
51-
&["train", "serve"],
52-
&mut graph,
53-
export_dir,
54-
)?
55-
.session;
56-
let op_x = graph.operation_by_name_required("x")?;
57-
let op_y = graph.operation_by_name_required("y")?;
58-
let op_train = graph.operation_by_name_required("train")?;
59-
let op_w = graph.operation_by_name_required("w")?;
60-
let op_b = graph.operation_by_name_required("b")?;
49+
let bundle =
50+
SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, export_dir)?;
51+
let session = &bundle.session;
52+
53+
// train
54+
let train_signature = bundle.meta_graph_def().get_signature("train")?;
55+
let x_info = train_signature.get_input("x")?;
56+
let y_info = train_signature.get_input("y")?;
57+
let loss_info = train_signature.get_output("loss")?;
58+
let op_x = graph.operation_by_name_required(&x_info.name().name)?;
59+
let op_y = graph.operation_by_name_required(&y_info.name().name)?;
60+
let op_train = graph.operation_by_name_required(&loss_info.name().name)?;
61+
62+
// internal parameters
63+
let op_b = {
64+
let b_signature = bundle.meta_graph_def().get_signature("b")?;
65+
let b_info = b_signature.get_output("output")?;
66+
graph.operation_by_name_required(&b_info.name().name)?
67+
};
68+
69+
let op_w = {
70+
let w_signature = bundle.meta_graph_def().get_signature("w")?;
71+
let w_info = w_signature.get_output("output")?;
72+
graph.operation_by_name_required(&w_info.name().name)?
73+
};
6174

6275
// Train the model (e.g. for fine tuning).
6376
let mut train_step = SessionRunArgs::new();

examples/regression_savedmodel/assets/.gitkeep

Whitespace-only changes.
Lines changed: 40 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,47 @@
11
import tensorflow as tf
2-
from tensorflow.python.saved_model.builder import SavedModelBuilder
3-
from tensorflow.python.saved_model.signature_def_utils import build_signature_def
4-
from tensorflow.python.saved_model.signature_constants import REGRESS_METHOD_NAME
5-
from tensorflow.python.saved_model.tag_constants import TRAINING, SERVING
6-
from tensorflow.python.saved_model.utils import build_tensor_info
72

8-
x = tf.placeholder(tf.float32, name='x')
9-
y = tf.placeholder(tf.float32, name='y')
103

11-
w = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='w')
12-
b = tf.Variable(tf.zeros([1]), name='b')
13-
y_hat = tf.add(w * x, b, name="y_hat")
4+
class LinearRegresstion(tf.Module):
5+
def __init__(self, name=None):
6+
super(LinearRegresstion, self).__init__(name=name)
7+
self.w = tf.Variable(tf.random.uniform([1], -1.0, 1.0), name='w')
8+
self.b = tf.Variable(tf.zeros([1]), name='b')
9+
self.optimizer = tf.keras.optimizers.SGD(0.5)
1410

15-
loss = tf.reduce_mean(tf.square(y_hat - y))
16-
optimizer = tf.train.GradientDescentOptimizer(0.5)
17-
train = optimizer.minimize(loss, name='train')
11+
@tf.function
12+
def __call__(self, x):
13+
y_hat = self.w * x + self.b
14+
return y_hat
1815

19-
init = tf.variables_initializer(tf.global_variables(), name='init')
16+
@tf.function
17+
def get_w(self):
18+
return {'output': self.w}
19+
20+
@tf.function
21+
def get_b(self):
22+
return {'output': self.b}
23+
24+
@tf.function
25+
def train(self, x, y):
26+
with tf.GradientTape() as tape:
27+
y_hat = self(x)
28+
loss = tf.reduce_mean(tf.square(y_hat - y))
29+
grads = tape.gradient(loss, self.trainable_variables)
30+
_ = self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
31+
return {'loss': loss}
32+
33+
34+
model = LinearRegresstion()
35+
36+
# Get concrete functions to generate signatures
37+
x = tf.TensorSpec([None], tf.float32, name='x')
38+
y = tf.TensorSpec([None], tf.float32, name='y')
39+
40+
train = model.train.get_concrete_function(x, y)
41+
w = model.get_w.get_concrete_function()
42+
b = model.get_b.get_concrete_function()
43+
44+
signatures = {'train': train, 'w': w, 'b': b}
2045

2146
directory = 'examples/regression_savedmodel'
22-
builder = SavedModelBuilder(directory)
23-
24-
with tf.Session(graph=tf.get_default_graph()) as sess:
25-
sess.run(init)
26-
27-
signature_inputs = {
28-
"x": build_tensor_info(x),
29-
"y": build_tensor_info(y)
30-
}
31-
signature_outputs = {
32-
"out": build_tensor_info(y_hat)
33-
}
34-
signature_def = build_signature_def(
35-
signature_inputs, signature_outputs,
36-
REGRESS_METHOD_NAME)
37-
builder.add_meta_graph_and_variables(
38-
sess, [TRAINING, SERVING],
39-
signature_def_map={
40-
REGRESS_METHOD_NAME: signature_def
41-
},
42-
assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS))
43-
builder.save(as_text=False)
47+
tf.saved_model.save(model, directory, signatures=signatures)
30.6 KB
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)