Skip to content

Commit d00975b

Browse files
Shajanshajandasan
authored andcommitted
Draft: Java API to use tf.function available on SavedModel. (#89)
Python models that contain tf.function is inconvenient to be consumed by Java clients. This proposal provides an API to (a) Invoke a tf.function, given the signature name (b) Retrieve the node name in the graph corresponding to a tf.function Co-authored-by: Shajan Dasan <[email protected]>
1 parent 86840e7 commit d00975b

File tree

2 files changed

+285
-0
lines changed

2 files changed

+285
-0
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig;
2121

2222
import com.google.protobuf.InvalidProtocolBufferException;
23+
import java.util.HashMap;
24+
import java.util.Map;
25+
import java.util.stream.Collectors;
2326
import org.bytedeco.javacpp.BytePointer;
2427
import org.bytedeco.javacpp.PointerPointer;
2528
import org.bytedeco.javacpp.PointerScope;
@@ -32,6 +35,7 @@
3235
import org.tensorflow.proto.framework.ConfigProto;
3336
import org.tensorflow.proto.framework.MetaGraphDef;
3437
import org.tensorflow.proto.framework.RunOptions;
38+
import org.tensorflow.proto.framework.SignatureDef;
3539

3640
/**
3741
* SavedModelBundle represents a model loaded from storage.
@@ -94,6 +98,101 @@ private Loader(String exportDir) {
9498
private RunOptions runOptions = null;
9599
}
96100

101+
/**
102+
* SignatureToNodeName finds the node names in the {@link Graph} corresponding to the
103+
* input / output parameters of a <a
104+
* href="https://www.tensorflow.org/api_docs/python/tf/function">tf.function</a>
105+
*/
106+
public static final class SignatureToNodeName {
107+
108+
public SignatureToNodeName(SavedModelBundle savedModelBundle) {
109+
loadSignatures(savedModelBundle);
110+
}
111+
112+
/**
113+
* Given a tf.function signature name, find the node names corresponding
114+
* to the input arguments
115+
*
116+
* @param functionSignatureName tf.function signature name
117+
* @return a map from input arguments to node names in the {@link Graph}
118+
*/
119+
public Map<String, String> inputNameToNode(String functionSignatureName) {
120+
NameContainer nc = this.functionMap.get(functionSignatureName);
121+
return (nc == null) ? null : nc.inputNameToNode();
122+
}
123+
124+
/**
125+
* Given a tf.function signature name, find the node names corresponding
126+
* to the output arguments
127+
*
128+
* @param functionSignatureName tf.function signature name
129+
* @return a map from output arguments to node names in the {@link Graph}
130+
*/
131+
public Map<String, String> outputNameToNode(String functionSignatureName) {
132+
NameContainer nc = this.functionMap.get(functionSignatureName);
133+
return (nc == null) ? null : nc.outputNameToNode();
134+
}
135+
136+
/**
137+
* Given a tf.function signature name, find the method name
138+
*/
139+
public String methodName(String functionSignatureName) {
140+
NameContainer nc = this.functionMap.get(functionSignatureName);
141+
return (nc == null) ? null : nc.methodName();
142+
}
143+
144+
private void loadSignatures(SavedModelBundle savedModelBundle) {
145+
MetaGraphDef metaGraph = savedModelBundle.metaGraphDef();
146+
Map<String, SignatureDef> signatureMap = metaGraph.getSignatureDefMap();
147+
148+
// A saved model can contain multiple SignatureDef
149+
for (Map.Entry<String, SignatureDef> entry : signatureMap.entrySet()) {
150+
NameContainer nc = new NameContainer(entry.getValue());
151+
this.functionMap.put(entry.getKey(), nc);
152+
}
153+
}
154+
155+
private Map<String, NameContainer> functionMap = new HashMap<>();
156+
157+
private static final class NameContainer {
158+
NameContainer(SignatureDef sd) {
159+
this.inputNameToNodeName = sd.getInputsMap()
160+
.entrySet()
161+
.stream()
162+
.collect(Collectors.toMap(
163+
e -> e.getKey(),
164+
e -> e.getValue().getName()
165+
));
166+
167+
this.outputNameToNodeName = sd.getOutputsMap()
168+
.entrySet()
169+
.stream()
170+
.collect(Collectors.toMap(
171+
e -> e.getKey(),
172+
e -> e.getValue().getName()
173+
));
174+
175+
this.method = sd.getMethodName();
176+
}
177+
178+
public Map<String, String> inputNameToNode() {
179+
return this.inputNameToNodeName;
180+
}
181+
182+
public Map<String, String> outputNameToNode() {
183+
return this.outputNameToNodeName;
184+
}
185+
186+
public String methodName() {
187+
return this.method;
188+
}
189+
190+
private Map<String, String> inputNameToNodeName;
191+
private Map<String, String> outputNameToNodeName;
192+
private String method;
193+
}
194+
}
195+
97196
/**
98197
* Load a saved model from an export directory. The model that is being loaded should be created
99198
* using the <a href="https://www.tensorflow.org/api_docs/python/tf/saved_model">Saved Model
@@ -148,6 +247,34 @@ public Session session() {
148247
return session;
149248
}
150249

250+
/**
251+
* Returns the {@link SignatureToNodeName} translator for the model.
252+
*
253+
* @return SignatureToNodeName translator
254+
*/
255+
public SignatureToNodeName getSignatureToNodeName() {
256+
if (this.sigToNodeName == null) {
257+
// no need to lock, ok to create multiple instances
258+
this.sigToNodeName = new SignatureToNodeName(this);
259+
}
260+
return this.sigToNodeName;
261+
}
262+
263+
/**
264+
* Return a {@link TfFunction} corresponding to the function signature.
265+
*
266+
* <pre>{@code
267+
* TfFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
268+
* Map<String, Tensor<?>> outputTensorMap = myFunction.call(inputTensorMap);
269+
* }</pre>
270+
*
271+
* @param functionSignatureName name of the {@code SignatureDef} in the saved model.
272+
* @return TfFunction object that can be used to make calls to the tf.function
273+
*/
274+
public TfFunction function(String functionSignatureName) {
275+
return new TfFunction(functionSignatureName, this.getSignatureToNodeName(), this.session);
276+
}
277+
151278
/**
152279
* Releases resources (the {@link Graph} and {@link Session}) associated with the saved model
153280
* bundle.
@@ -161,6 +288,7 @@ public void close() {
161288
private final Graph graph;
162289
private final Session session;
163290
private final MetaGraphDef metaGraphDef;
291+
private SignatureToNodeName sigToNodeName;
164292

165293
private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef) {
166294
this.graph = graph;
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
/*
2+
* Copyright 2020 The TensorFlow Authors. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.tensorflow;
17+
18+
import com.google.protobuf.InvalidProtocolBufferException;
19+
20+
import java.util.List;
21+
import java.util.ListIterator;
22+
import java.util.HashMap;
23+
import java.util.Map;
24+
25+
/**
26+
* Invoke <a href="https://www.tensorflow.org/api_docs/python/tf/function">tf.function</a>
27+
* defined in a {@link SavedModelBundle}.
28+
*
29+
* <pre>{@code
30+
* TfFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
31+
* Map<String, Tensor<?>> outputTensorMap = myFunction.call(inputTensorMap);
32+
* }</pre>
33+
*
34+
*/
35+
public class TfFunction {
36+
37+
public TfFunction(
38+
String functionSignatureName,
39+
SavedModelBundle.SignatureToNodeName nameToNode, Session session) {
40+
this.nameToNode = nameToNode;
41+
this.session = session;
42+
this.functionSignatureName = functionSignatureName;
43+
}
44+
45+
/**
46+
* Invokes a tf.function.
47+
* Caller is responsible for closing all Tensors.
48+
*
49+
* @param arguments map of input tensors
50+
* @return map of output tensors
51+
*/
52+
public Map<String, Tensor<?>> call(
53+
Map<String, Tensor<?>> arguments) throws IllegalArgumentException {
54+
55+
Session.Runner runner = this.session.runner();
56+
57+
Map<String, String> inputToNode = this.nameToNode.inputNameToNode(this.functionSignatureName);
58+
59+
if (inputToNode == null) {
60+
throw new IllegalArgumentException(
61+
String.format("Function [%s] is missing input", this.functionSignatureName));
62+
}
63+
64+
// Join arguments.key, inputToNodeName.key
65+
for (Map.Entry<String, String> entry: inputToNode.entrySet()) {
66+
String argName = entry.getKey();
67+
Tensor<?> tensor = arguments.get(argName);
68+
69+
if (tensor == null) {
70+
throw new IllegalArgumentException(String.format("Missing argument [%s]", argName));
71+
}
72+
73+
// Node name in the tensorflow graph, corresponding to the tf.function argument
74+
runner = runner.feed(entry.getValue(), tensor);
75+
}
76+
77+
Map<String, String> outputToNode = this.nameToNode.outputNameToNode(this.functionSignatureName);
78+
if (outputToNode == null) {
79+
throw new IllegalArgumentException(
80+
String.format("Function [%] is missing output", this.functionSignatureName));
81+
}
82+
83+
for (String nodeName: outputToNode.values()) {
84+
// Node names corresponding to the return value
85+
runner = runner.fetch(nodeName);
86+
}
87+
88+
List<Tensor<?>> resultTensors = runner.run();
89+
ListIterator<Tensor<?>> resultTensorIter = resultTensors.listIterator();
90+
91+
Map<String, Tensor<?>> returnMap = new HashMap<String, Tensor<?>>();
92+
93+
// Use the output names as present in the signature definition
94+
for (String nodeName: outputToNode.keySet()) {
95+
returnMap.put(nodeName, resultTensorIter.next());
96+
}
97+
98+
return returnMap;
99+
}
100+
101+
/**
102+
* Invokes a tf.function.
103+
* Caller is responsible for closing all Tensors.
104+
*
105+
* Throws IllegalArgumentException if there are multiple input or output parameters defined
106+
* in the tf.function
107+
*
108+
* @param tensor input tensor
109+
* @return output tensor
110+
*/
111+
public Tensor<?> call(Tensor<?> tensor) throws IllegalArgumentException {
112+
Session.Runner runner = this.session.runner();
113+
114+
Map<String, String> inputToNode = this.nameToNode.inputNameToNode(this.functionSignatureName);
115+
116+
if (inputToNode == null) {
117+
throw new IllegalArgumentException(
118+
String.format("Function [%s] is missing input", this.functionSignatureName));
119+
}
120+
121+
if (inputToNode.size() != 1) {
122+
throw new IllegalArgumentException(
123+
String.format("Function [%s] requires multiple inputs", this.functionSignatureName));
124+
}
125+
126+
// Feed the single argument
127+
for (Map.Entry<String, String> entry: inputToNode.entrySet()) {
128+
// Node name in the tensorflow graph, corresponding to the tf.function argument
129+
runner = runner.feed(entry.getValue(), tensor);
130+
}
131+
132+
Map<String, String> outputToNode = this.nameToNode.outputNameToNode(this.functionSignatureName);
133+
if (outputToNode == null) {
134+
throw new IllegalArgumentException(
135+
String.format("Function [%] is missing output", this.functionSignatureName));
136+
}
137+
138+
if (outputToNode.size() != 1) {
139+
throw new IllegalArgumentException(
140+
String.format("Function [%s] has multiple outputs", this.functionSignatureName));
141+
}
142+
143+
// Fetch the single return tensor
144+
for (String nodeName: outputToNode.values()) {
145+
// Node names corresponding to the return value
146+
runner = runner.fetch(nodeName);
147+
}
148+
149+
List<Tensor<?>> resultTensors = runner.run();
150+
151+
return resultTensors.get(0);
152+
}
153+
154+
private final Session session;
155+
private final SavedModelBundle.SignatureToNodeName nameToNode;
156+
private final String functionSignatureName;
157+
}

0 commit comments

Comments
 (0)