diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java new file mode 100644 index 00000000000..c76b62d5486 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -0,0 +1,286 @@ +/* + * Copyright 2020 The TensorFlow Authors. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.tensorflow; + +import java.io.IOException; +import java.util.List; +import java.util.ListIterator; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; +import org.tensorflow.op.Ops; +import org.tensorflow.proto.framework.SignatureDef; +import org.tensorflow.proto.framework.TensorInfo; + +/** + * A graph that can be invoked as a single function, with an input and output signature. + * + *
A function can also invoke a + * tf.function + * defined in a {@link SavedModelBundle}. + * + *
{@code + * ConcreteFunction myFunction = savedModelBundle.function("myFunctionSignatureName"); + * Map+ */ +public class ConcreteFunction implements AutoCloseable { + + /** + * Creates a function by building a new graph. + * + * The {@code functionBuilder} must initialize the function graph from the provided + * {@link Ops} instance and return a valid signature that will be used to feed the input tensors + * and fetch the output tensors on execution. + * + * The function will be the owner of the new graph and its resulting session. Therefore, + * the function must be enclosed properly with a try-with-resources block to guarantee that + * all native resources will be freed once the function is discarded. For example: + * + *> outputTensorMap = myFunction.call(inputTensorMap); + * }
{@code + * public class MyModel { + * + * public static Signature addTwo(Ops tf) { + * Placeholder+ * + * @param functionBuilder function builder + * @return the new function + */ + public static ConcreteFunction create(Functioninput = tf.placeholder(TFloat32.DTYPE); + * Add output = tf.math.add(input, tf.constant(2.0f)); + * return Signature.builder("addTwo").input("x", input).output("y", output).build(); + * } + * + * public static void main(String args[]) { + * try (ConcreteFunction function = ConcreteFunction.create(MyModel::addTwo); + * Tensor x = TFloat32.scalarOf(2.0f)) { + * assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat()); + * } + * } + * } + * }
{@code + * try (Graph g = new Graph()) { + * Placeholder+ * + * @param signature signature of the function to create + * @param graph a valid and initialized graph + * @return a new function + */ + public static ConcreteFunction create(Signature signature, Graph graph) { + return new ConcreteFunction(signature, graph, new Session(graph), Ownership.SESSION_ONLY); + } + + /** + * Create a function from a signature and a valid graph session. + * + * The function will not own the session nor its graph, meaning that their lifetime + * can extend beyond the scope of the function. Therefore the function does not need to be + * closed after its usage. For example: + * + *input = tf.placeholder(TFloat32.DTYPE); + * Add output = tf.math.add(input, tf.constant(2.0f)); + * Signature signature = Signature.builder().input("x", input).output("y", output).build(); + * + * try (ConcreteFunction f = ConcreteFunction.create(signature, g); + * Tensor x = TFloat32.scalarOf(2.0f)) { + * assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat()); + * } + * // Graph g is still valid at this point + * } + * }
{@code + * try (Graph g = new Graph()) { + * Placeholder+ * + * @param signature signature of the function to create + * @param graph a valid session to an initialized graph + * @return a new function + */ + public static ConcreteFunction create(Signature signature, Session session) { + return new ConcreteFunction(signature, session.graph(), session, Ownership.NONE); + } + + /** + * Returns the signature of this function + */ + public Signature signature() { + return signature; + } + + /** + * Invokes a function. + * + *input = tf.placeholder(TFloat32.DTYPE); + * Add output = tf.math.add(input, tf.constant(2.0f)); + * Signature signature = Signature.builder().input("x", input).output("y", output).build(); + * + * try (Session s = new Session(g)) { + * // Auto-closing the function just as an example but this is not required since it has + * // no effect + * try (ConcreteFunction f = ConcreteFunction.create(signature, s); + * Tensor t = TFloat32.scalarOf(2.0f)) { + * assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat()); + * } + * // Session s is still valid at this point + * } + * // Graph g is still valid at this point + * } + * }
Caller is responsible for closing all Tensors.
+ *
+ * @param tensor input tensor
+ * @return output tensor
+ */
+ public Map Caller is responsible for closing all Tensors.
+ *
+ * @param tensor input tensor
+ * @return output tensor
+ * @throws IllegalArgumentException if there are multiple input or output parameters defined
+ * in the function
+ */
+ public Tensor> call(Tensor> tensor) throws IllegalArgumentException {
+ final SignatureDef signatureDef = signature.asSignatureDef();
+
+ if (signatureDef.getInputsCount() != 1) {
+ throw new IllegalArgumentException(
+ String.format("Function [%s] requires multiple inputs", signatureDef.getMethodName()));
+ }
+ String inputNodeName = signatureDef.getInputsMap().values().iterator().next().getName();
+
+ if (signatureDef.getOutputsCount() != 1) {
+ throw new IllegalArgumentException(
+ String.format("Function [%s] has multiple outputs", signatureDef.getMethodName()));
+ }
+ String outputNodeName = signatureDef.getOutputsMap().values().iterator().next().getName();
+
+ return session.runner().feed(inputNodeName, tensor).fetch(outputNodeName).run().get(0);
+ }
+
+ /**
+ * Export this function as a saved model.
+ *
+ * This method is convenient shortcut equivalent to
+ * {@code SavedModel.exporter(exportDir).withFunction(this).export()}
+ */
+ public void save(String exportDir) throws IOException {
+ SavedModelBundle.exporter(exportDir)
+ .withFunction(this)
+ .export();
+ }
+
+ /**
+ * Returns the session used to execute the graph when calling this function
+ *
+ * In general, a user does not need to handle directly the session of a function and rely
+ * on {@link #call(Map)} to execute the graph instead. But in some cases, direct access to
+ * the session might be necessary, as it allows more running options.
+ *
+ * @return the function session
+ */
+ public Session session() {
+ return session;
+ }
+
+ /**
+ * Returns the graph of this function
+ */
+ public Graph graph() {
+ return graph;
+ }
+
+ @Override
+ public void close() {
+ if (ownership != Ownership.NONE) {
+ session.close();
+ if (ownership == Ownership.GRAPH_AND_SESSION) {
+ graph.close();
+ }
+ }
+ }
+
+ private enum Ownership {
+ GRAPH_AND_SESSION, SESSION_ONLY, NONE;
+ }
+
+ private final Graph graph;
+ private final Session session;
+ private final Signature signature;
+ private final Ownership ownership;
+
+ ConcreteFunction(Signature signature, Graph graph, Session session, Ownership ownership) {
+ this.graph = graph;
+ this.session = session;
+ this.signature = signature;
+ this.ownership = ownership;
+ }
+}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java
index 221268191fd..f365956dff1 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java
@@ -43,8 +43,17 @@
import org.tensorflow.internal.c_api.TF_Output;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.internal.c_api.TF_WhileParams;
+import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
+import org.tensorflow.op.core.Constant;
+import org.tensorflow.op.core.NoOp;
+import org.tensorflow.op.core.Placeholder;
+import org.tensorflow.op.train.Restore;
+import org.tensorflow.op.train.Save;
import org.tensorflow.proto.framework.GraphDef;
+import org.tensorflow.proto.util.SaverDef;
+import org.tensorflow.types.TString;
/**
@@ -67,6 +76,11 @@ public Graph() {
this.nativeHandle = nativeHandle;
}
+ Graph(TF_Graph nativeHandle, SaverDef saverDef) {
+ this(nativeHandle);
+ this.saverDef = saverDef;
+ }
+
/**
* Release resources associated with the Graph.
*
@@ -402,9 +416,27 @@ public Output>[] whileLoop(
}
}
+ /**
+ * Return the {@link SaverDef} instance used to save the state of all variables present in
+ * this graph.
+ *
+ * Has no effect if {@code tags} is null or empty
+ *
* @param tags the tags identifying the specific MetaGraphDef to load.
*/
public Loader withTags(String... tags) {
- this.tags = tags;
+ if (tags != null && tags.length > 0) {
+ this.tags = tags;
+ }
return this;
}
@@ -93,104 +109,105 @@ private Loader(String exportDir) {
}
private String exportDir = null;
- private String[] tags = null;
+ private String[] tags = {DEFAULT_TAG};
private ConfigProto configProto = null;
private RunOptions runOptions = null;
}
- /**
- * SignatureToNodeName finds the node names in the {@link Graph} corresponding to the
- * input / output parameters of a tf.function
- */
- public static final class SignatureToNodeName {
-
- public SignatureToNodeName(SavedModelBundle savedModelBundle) {
- loadSignatures(savedModelBundle);
- }
+ /** Options for exporting a SavedModel. */
+ public static final class Exporter {
/**
- * Given a tf.function signature name, find the node names corresponding
- * to the input arguments
+ * Sets the set of tags that identify the specific graph in the saved model to save.
*
- * @param functionSignatureName tf.function signature name
- * @return a map from input arguments to node names in the {@link Graph}
+ * Therefore, all functions exported in a model should share the same session at the moment
+ * or an exception will be thrown.
+ *
+ * @param function a function carrying a signature and a valid session to the graph to be saved
+ * @return this object
+ * @throws IllegalArgumentException if a function with the same name has already been added to the model
*/
- public Map Caller is responsible for closing all returned Tensors.
+ *
+ * @param arguments list of input tensors, mapped by their signature name
+ * @return list of output tensors, mapped by the signature name
+ * @throws IllegalArgumentException if no function can be selected by default
*/
- public TfFunction function(String functionSignatureName) {
- return new TfFunction(functionSignatureName, this.getSignatureToNodeName(), this.session);
+ public Map Invoked from the native load method. Takes ownership of the handles.
*/
private static SavedModelBundle fromHandle(
- TF_Graph graphHandle, TF_Session sessionHandle, MetaGraphDef metaGraphDef) {
- Graph graph = new Graph(graphHandle);
- Session session = new Session(graph, sessionHandle);
- return new SavedModelBundle(graph, session, metaGraphDef);
+ final TF_Graph graphHandle, final TF_Session sessionHandle, MetaGraphDef metaGraphDef) {
+
+ final Graph graph = new Graph(graphHandle, metaGraphDef.getSaverDef());
+ final Session session = new Session(graph, sessionHandle);
+
+ // Create a separate function for each signature of the main graph.
+ // Note that the saved model will remain the owner of the graph and the session, meaning
+ // that the functions do not need to be closed by the user and if it does, it should have
+ // no effect.
+ final MapSavedModelBundle
with the configured options. */
public SavedModelBundle load() {
return SavedModelBundle.load(exportDir, tags, configProto, runOptions);
@@ -81,10 +93,14 @@ public Loader withConfigProto(ConfigProto configProto) {
/**
* Sets the set of tags that identify the specific graph in the saved model to load.
*
+ * Exporter
object for setting configuration options before actually
+ * saving the model.
+ *
+ * @param exportDir the directory path containing a saved model.
+ */
+ public static Exporter exporter(String exportDir) {
+ return new Exporter(exportDir);
+ }
+
/**
* Returns the MetaGraphDef
@@ -248,31 +277,54 @@ public Session session() {
}
/**
- * Returns the {@link SignatureToNodeName} translator for the model.
+ * Return a {@link ConcreteFunction} corresponding to the function signature.
+ *
+ * {@code
+ * ConcreteFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
+ * Map
*
- * @return SignatureToNodeName translator
+ * @param functionSignatureName name of the {@code SignatureDef} in the saved model.
+ * @return TfFunction object that can be used to make calls to the tf.function
+ * @throws IllegalArgumentException if {@code functionSignatureName} is not found in this
+ * saved model.
*/
- public SignatureToNodeName getSignatureToNodeName() {
- if (this.sigToNodeName == null) {
- // no need to lock, ok to create multiple instances
- this.sigToNodeName = new SignatureToNodeName(this);
+ public ConcreteFunction function(String functionSignatureName) {
+ ConcreteFunction function = functions.get(functionSignatureName);
+ if (function == null) {
+ throw new IllegalArgumentException(
+ String.format("Function with signature [%s] not found", functionSignatureName));
}
- return this.sigToNodeName;
+ return function;
}
/**
- * Return a {@link TfFunction} corresponding to the function signature.
+ * Invokes the default function directly from this model.
*
- * {@code
- * TfFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
- * Map
+ * The default function selection is done based on the first of the following conditions that
+ * is true:
+ *
+ *
*
- * @param functionSignatureName name of the {@code SignatureDef} in the saved model.
- * @return TfFunction object that can be used to make calls to the tf.function
+ * {@code
- * TfFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
- * Map
- *
- */
-public class TfFunction {
-
- public TfFunction(
- String functionSignatureName,
- SavedModelBundle.SignatureToNodeName nameToNode, Session session) {
- this.nameToNode = nameToNode;
- this.session = session;
- this.functionSignatureName = functionSignatureName;
- }
-
- /**
- * Invokes a tf.function.
- * Caller is responsible for closing all Tensors.
- *
- * @param arguments map of input tensors
- * @return map of output tensors
- */
- public Map