diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java new file mode 100644 index 00000000000..c76b62d5486 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -0,0 +1,286 @@ +/* + * Copyright 2020 The TensorFlow Authors. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.tensorflow; + +import java.io.IOException; +import java.util.List; +import java.util.ListIterator; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; +import org.tensorflow.op.Ops; +import org.tensorflow.proto.framework.SignatureDef; +import org.tensorflow.proto.framework.TensorInfo; + +/** + * A graph that can be invoked as a single function, with an input and output signature. + * + *

A function can also invoke a + * tf.function + * defined in a {@link SavedModelBundle}. + * + *

{@code
+ * ConcreteFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
+ * Map> outputTensorMap = myFunction.call(inputTensorMap);
+ * }
+ */ +public class ConcreteFunction implements AutoCloseable { + + /** + * Creates a function by building a new graph. + * + *

The {@code functionBuilder} must initialize the function graph from the provided + * {@link Ops} instance and return a valid signature that will be used to feed the input tensors + * and fetch the output tensors on execution. + * + *

The function will be the owner of the new graph and its resulting session. Therefore, + * the function must be enclosed properly with a try-with-resources block to guarantee that + * all native resources will be freed once the function is discarded. For example: + * + *

{@code
+   * public class MyModel {
+   *
+   *   public static Signature addTwo(Ops tf) {
+   *     Placeholder input = tf.placeholder(TFloat32.DTYPE);
+   *     Add output = tf.math.add(input, tf.constant(2.0f));
+   *     return Signature.builder("addTwo").input("x", input).output("y", output).build();
+   *   }
+   *
+   *   public static void main(String args[]) {
+   *     try (ConcreteFunction function = ConcreteFunction.create(MyModel::addTwo);
+   *         Tensor x = TFloat32.scalarOf(2.0f)) {
+   *       assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
+   *     }
+   *   }
+   * }
+   * }
+ * + * @param functionBuilder function builder + * @return the new function + */ + public static ConcreteFunction create(Function functionBuilder) { + Graph graph = new Graph(); + try { + Ops tf = Ops.create(graph); + Signature signature = functionBuilder.apply(tf); + return new ConcreteFunction(signature, graph, new Session(graph), Ownership.GRAPH_AND_SESSION); + } catch (Exception e) { + graph.close(); + throw e; + } + } + + /** + * Create a function from a signature and an existing graph. + * + *

The function will keep the ownership of the session used to run the graph but not + * the graph itself, meaning that the lifetime of the latter can extend beyond the scope + * of the function. For example: + * + *

{@code
+   * try (Graph g = new Graph()) {
+   *   Placeholder input = tf.placeholder(TFloat32.DTYPE);
+   *   Add output = tf.math.add(input, tf.constant(2.0f));
+   *   Signature signature = Signature.builder().input("x", input).output("y", output).build();
+   *
+   *   try (ConcreteFunction f = ConcreteFunction.create(signature, g);
+   *       Tensor x = TFloat32.scalarOf(2.0f)) {
+   *     assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
+   *   }
+   *   // Graph g is still valid at this point
+   * }
+   * }
+ * + * @param signature signature of the function to create + * @param graph a valid and initialized graph + * @return a new function + */ + public static ConcreteFunction create(Signature signature, Graph graph) { + return new ConcreteFunction(signature, graph, new Session(graph), Ownership.SESSION_ONLY); + } + + /** + * Create a function from a signature and a valid graph session. + * + *

The function will not own the session nor its graph, meaning that their lifetime + * can extend beyond the scope of the function. Therefore the function does not need to be + * closed after its usage. For example: + * + *

{@code
+   * try (Graph g = new Graph()) {
+   *   Placeholder input = tf.placeholder(TFloat32.DTYPE);
+   *   Add output = tf.math.add(input, tf.constant(2.0f));
+   *   Signature signature = Signature.builder().input("x", input).output("y", output).build();
+   *
+   *   try (Session s = new Session(g)) {
+   *     // Auto-closing the function just as an example but this is not required since it has
+   *     // no effect
+   *     try (ConcreteFunction f = ConcreteFunction.create(signature, s);
+   *         Tensor t = TFloat32.scalarOf(2.0f)) {
+   *       assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
+   *     }
+   *     // Session s is still valid at this point
+   *   }
+   *   // Graph g is still valid at this point
+   * }
+   * }
+ * + * @param signature signature of the function to create + * @param graph a valid session to an initialized graph + * @return a new function + */ + public static ConcreteFunction create(Signature signature, Session session) { + return new ConcreteFunction(signature, session.graph(), session, Ownership.NONE); + } + + /** + * Returns the signature of this function + */ + public Signature signature() { + return signature; + } + + /** + * Invokes a function. + * + *

Caller is responsible for closing all Tensors. + * + * @param tensor input tensor + * @return output tensor + */ + public Map> call(Map> arguments) + throws IllegalArgumentException { + + final SignatureDef signatureDef = signature.asSignatureDef(); + final Session.Runner runner = session.runner(); + + signatureDef.getInputsMap().forEach((argName, t) -> { + Tensor tensor = arguments.get(argName); + if (tensor == null) { + throw new IllegalArgumentException(String.format("Missing argument [%s]", argName)); + } + runner.feed(t.getName(), tensor); + }); + + Map outputToNode = signatureDef.getOutputsMap(); + outputToNode.values().forEach(t -> runner.fetch(t.getName())); + + List> resultTensors = runner.run(); + try { + ListIterator> resultTensorIter = resultTensors.listIterator(); + Map> returnMap = new HashMap>(); + + // Use the output names as present in the signature definition + for (String nodeName: outputToNode.keySet()) { + returnMap.put(nodeName, resultTensorIter.next()); + } + return returnMap; + + } catch (Exception e) { + // Release tensors before throwing exception + for (Tensor t : resultTensors) { + t.close(); + } + throw e; + } + } + + /** + * Invokes a function with a single input and output. + * + *

Caller is responsible for closing all Tensors. + * + * @param tensor input tensor + * @return output tensor + * @throws IllegalArgumentException if there are multiple input or output parameters defined + * in the function + */ + public Tensor call(Tensor tensor) throws IllegalArgumentException { + final SignatureDef signatureDef = signature.asSignatureDef(); + + if (signatureDef.getInputsCount() != 1) { + throw new IllegalArgumentException( + String.format("Function [%s] requires multiple inputs", signatureDef.getMethodName())); + } + String inputNodeName = signatureDef.getInputsMap().values().iterator().next().getName(); + + if (signatureDef.getOutputsCount() != 1) { + throw new IllegalArgumentException( + String.format("Function [%s] has multiple outputs", signatureDef.getMethodName())); + } + String outputNodeName = signatureDef.getOutputsMap().values().iterator().next().getName(); + + return session.runner().feed(inputNodeName, tensor).fetch(outputNodeName).run().get(0); + } + + /** + * Export this function as a saved model. + * + *

This method is convenient shortcut equivalent to + * {@code SavedModel.exporter(exportDir).withFunction(this).export()} + */ + public void save(String exportDir) throws IOException { + SavedModelBundle.exporter(exportDir) + .withFunction(this) + .export(); + } + + /** + * Returns the session used to execute the graph when calling this function + * + *

In general, a user does not need to handle directly the session of a function and rely + * on {@link #call(Map)} to execute the graph instead. But in some cases, direct access to + * the session might be necessary, as it allows more running options. + * + * @return the function session + */ + public Session session() { + return session; + } + + /** + * Returns the graph of this function + */ + public Graph graph() { + return graph; + } + + @Override + public void close() { + if (ownership != Ownership.NONE) { + session.close(); + if (ownership == Ownership.GRAPH_AND_SESSION) { + graph.close(); + } + } + } + + private enum Ownership { + GRAPH_AND_SESSION, SESSION_ONLY, NONE; + } + + private final Graph graph; + private final Session session; + private final Signature signature; + private final Ownership ownership; + + ConcreteFunction(Signature signature, Graph graph, Session session, Ownership ownership) { + this.graph = graph; + this.session = session; + this.signature = signature; + this.ownership = ownership; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index 221268191fd..f365956dff1 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -43,8 +43,17 @@ import org.tensorflow.internal.c_api.TF_Output; import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.internal.c_api.TF_WhileParams; +import org.tensorflow.ndarray.StdArrays; import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.NoOp; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.op.train.Restore; +import org.tensorflow.op.train.Save; import org.tensorflow.proto.framework.GraphDef; +import org.tensorflow.proto.util.SaverDef; +import org.tensorflow.types.TString; /** @@ -67,6 +76,11 @@ public Graph() { this.nativeHandle = nativeHandle; } + Graph(TF_Graph nativeHandle, SaverDef saverDef) { + this(nativeHandle); + this.saverDef = saverDef; + } + /** * Release resources associated with the Graph. * @@ -402,9 +416,27 @@ public Output[] whileLoop( } } + /** + * Return the {@link SaverDef} instance used to save the state of all variables present in + * this graph. + * + *

On the first call of this method, all nodes necessary to save and restore the state of the + * variables are added to the graph. Consequently, any variables that are added to the graph after + * this call could not be saved nor restored using this {@link SaverDef}. + * + * @return a {@link SaverDef} instance + */ + synchronized SaverDef saverDef() { + if (saverDef == null) { + saverDef = addVariableSaver(this); + } + return saverDef; + } + private final Object nativeHandleLock = new Object(); private TF_Graph nativeHandle; private int refcount = 0; + private SaverDef saverDef; private final List initializers = new ArrayList<>(); @@ -726,6 +758,53 @@ private static Object[] whileLoop( } } + private static SaverDef addVariableSaver(Graph graph) { + Ops tf = Ops.create(graph).withSubScope("save"); + + List varNames = new ArrayList<>(); + List> varOutputs = new ArrayList<>(); + List> varTypes = new ArrayList<>(); + + for (Iterator iter = graph.operations(); iter.hasNext();) { + Operation op = iter.next(); + if (op.type().equals("VariableV2")) { + varNames.add(op.name()); + varOutputs.add(op.output(0)); + varTypes.add(op.output(0).dataType()); + } + } + + // FIXME Need an easier way to initialize an NdArray from a list + String[] tmp = new String[varNames.size()]; + Constant varNamesTensor = tf.constant(StdArrays.ndCopyOf(varNames.toArray(tmp))); + Operand varSlices = tf.zerosLike(varNamesTensor); + + Placeholder saveFilename = tf.placeholder(TString.DTYPE); + Save saveVariables = tf.train.save( + saveFilename, + varNamesTensor, + varSlices, + varOutputs + ); + Restore restoreVariables = tf.train.restore( + saveFilename, + varNamesTensor, + varSlices, + varTypes + ); + List restoreOps = new ArrayList<>(varOutputs.size()); + for (int i = 0; i < varOutputs.size(); ++i) { + restoreOps.add(tf.assign(varOutputs.get(i), (Operand) restoreVariables.tensors().get(i))); + } + NoOp restoreAll = tf.withControlDependencies(restoreOps).noOp(); + + return SaverDef.newBuilder() + .setFilenameTensorName(saveFilename.op().name()) + .setSaveTensorName(saveVariables.op().name()) + .setRestoreOpName(restoreAll.op().name()) + .build(); + } + static { TensorFlow.init(); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index b9fbbd9dfd9..dc4c58ed478 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -20,9 +20,16 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig; import com.google.protobuf.InvalidProtocolBufferException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.PointerPointer; import org.bytedeco.javacpp.PointerScope; @@ -34,8 +41,9 @@ import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.MetaGraphDef; +import org.tensorflow.proto.framework.MetaGraphDef.MetaInfoDef; import org.tensorflow.proto.framework.RunOptions; -import org.tensorflow.proto.framework.SignatureDef; +import org.tensorflow.proto.framework.SavedModel; /** * SavedModelBundle represents a model loaded from storage. @@ -47,8 +55,12 @@ * protocol buffer. */ public class SavedModelBundle implements AutoCloseable { + + public static final String DEFAULT_TAG = "serve"; + /** Options for loading a SavedModel. */ public static final class Loader { + /** Load a SavedModelBundle with the configured options. */ public SavedModelBundle load() { return SavedModelBundle.load(exportDir, tags, configProto, runOptions); @@ -81,10 +93,14 @@ public Loader withConfigProto(ConfigProto configProto) { /** * Sets the set of tags that identify the specific graph in the saved model to load. * + *

Has no effect if {@code tags} is null or empty + * * @param tags the tags identifying the specific MetaGraphDef to load. */ public Loader withTags(String... tags) { - this.tags = tags; + if (tags != null && tags.length > 0) { + this.tags = tags; + } return this; } @@ -93,104 +109,105 @@ private Loader(String exportDir) { } private String exportDir = null; - private String[] tags = null; + private String[] tags = {DEFAULT_TAG}; private ConfigProto configProto = null; private RunOptions runOptions = null; } - /** - * SignatureToNodeName finds the node names in the {@link Graph} corresponding to the - * input / output parameters of a tf.function - */ - public static final class SignatureToNodeName { - - public SignatureToNodeName(SavedModelBundle savedModelBundle) { - loadSignatures(savedModelBundle); - } + /** Options for exporting a SavedModel. */ + public static final class Exporter { /** - * Given a tf.function signature name, find the node names corresponding - * to the input arguments + * Sets the set of tags that identify the specific graph in the saved model to save. * - * @param functionSignatureName tf.function signature name - * @return a map from input arguments to node names in the {@link Graph} + *

Note that only one graph per model can be saved right now using this API. + * + * @param tags the tags identifying the specific MetaGraphDef to save. + * @return this object */ - public Map inputNameToNode(String functionSignatureName) { - NameContainer nc = this.functionMap.get(functionSignatureName); - return (nc == null) ? null : nc.inputNameToNode(); + public Exporter withTags(String... tags) { + this.tags.addAll(Arrays.asList(tags)); + return this; } /** - * Given a tf.function signature name, find the node names corresponding - * to the output arguments + * Save a concrete function of this model. * - * @param functionSignatureName tf.function signature name - * @return a map from output arguments to node names in the {@link Graph} + *

The concrete function carries a signature (i.e. a list of user-friendly input and outputs + * names to a graph) and a valid session to a graph to be saved in the model. + * + *

Note:Eventually, TensorFlow for Java will support the export of functions objects like + * the Python API does but right now, only session-centric models are supported (i.e. models that + * has a single main graph and one or more signatures). These models are compatible with those + * exported by TensorFlow 1.x or by TensorFlow 2.x estimators. + * + *

Therefore, all functions exported in a model should share the same session at the moment + * or an exception will be thrown. + * + * @param function a function carrying a signature and a valid session to the graph to be saved + * @return this object + * @throws IllegalArgumentException if a function with the same name has already been added to the model */ - public Map outputNameToNode(String functionSignatureName) { - NameContainer nc = this.functionMap.get(functionSignatureName); - return (nc == null) ? null : nc.outputNameToNode(); + public Exporter withFunction(ConcreteFunction function) { + Signature signature = function.signature(); + if (functions.containsKey(signature.name())) { + throw new IllegalArgumentException("Function \"" + signature.name() + "\" was already added to the model"); + } + functions.put(signature.name(), function); + if (session == null) { + session = function.session(); + } else if (session != function.session()) { + throw new UnsupportedOperationException("Saving multiple functions with different graphs/sessions is not supported yet."); + } + metaGraphDefBuilder.putSignatureDef(signature.name(), signature.asSignatureDef()); + return this; } /** - * Given a tf.function signature name, find the method name + * Save the model into the export directory. + * + * @throws IOException if saved model or variable state can be written on disk */ - public String methodName(String functionSignatureName) { - NameContainer nc = this.functionMap.get(functionSignatureName); - return (nc == null) ? null : nc.methodName(); - } - - private void loadSignatures(SavedModelBundle savedModelBundle) { - MetaGraphDef metaGraph = savedModelBundle.metaGraphDef(); - Map signatureMap = metaGraph.getSignatureDefMap(); - - // A saved model can contain multiple SignatureDef - for (Map.Entry entry : signatureMap.entrySet()) { - NameContainer nc = new NameContainer(entry.getValue()); - this.functionMap.put(entry.getKey(), nc); + public void export() throws IOException { + if (functions.isEmpty() || session == null) { + throw new IllegalStateException("Model should contain at least one valid function"); + } + if (tags.isEmpty()) { + tags.add(DEFAULT_TAG); + } + // It is imperative to retrieve the graphDef after the saverDef, as the former might add + // new ops to the graph for saving and restoring the variables. + Graph graph = session.graph(); + MetaGraphDef.Builder metaGraphDef = metaGraphDefBuilder + .setSaverDef(graph.saverDef()) + .setGraphDef(graph.toGraphDef()) + .setMetaInfoDef(MetaInfoDef.newBuilder().addAllTags(tags)); + functions.forEach((k, f) -> metaGraphDef.putSignatureDef(k, f.signature().asSignatureDef())); + + // Make sure saved model directories exist + Path variableDir = Paths.get(exportDir, "variables"); + variableDir.toFile().mkdirs(); + + // Save the variables state + session.save(variableDir.resolve("variables").toString()); + + // Save the graph + SavedModel savedModelDef = SavedModel.newBuilder().addMetaGraphs(metaGraphDef).build(); + try (OutputStream file = + new FileOutputStream(Paths.get(exportDir, "saved_model.pb").toString())) { + savedModelDef.writeTo(file); } } - private Map functionMap = new HashMap<>(); - - private static final class NameContainer { - NameContainer(SignatureDef sd) { - this.inputNameToNodeName = sd.getInputsMap() - .entrySet() - .stream() - .collect(Collectors.toMap( - e -> e.getKey(), - e -> e.getValue().getName() - )); - - this.outputNameToNodeName = sd.getOutputsMap() - .entrySet() - .stream() - .collect(Collectors.toMap( - e -> e.getKey(), - e -> e.getValue().getName() - )); - - this.method = sd.getMethodName(); - } - - public Map inputNameToNode() { - return this.inputNameToNodeName; - } - - public Map outputNameToNode() { - return this.outputNameToNodeName; - } - - public String methodName() { - return this.method; - } - - private Map inputNameToNodeName; - private Map outputNameToNodeName; - private String method; + Exporter(String exportDir) { + this.exportDir = exportDir; } + + private final String exportDir; + private final List tags = new ArrayList<>(); + private final MetaGraphDef.Builder metaGraphDefBuilder = MetaGraphDef.newBuilder(); + private final Map functions = new HashMap<>(); + private Session session; } /** @@ -224,6 +241,18 @@ public static Loader loader(String exportDir) { return new Loader(exportDir); } + /** + * Export a saved model. + * + *

Returns a Exporter object for setting configuration options before actually + * saving the model. + * + * @param exportDir the directory path containing a saved model. + */ + public static Exporter exporter(String exportDir) { + return new Exporter(exportDir); + } + /** * Returns the MetaGraphDef @@ -248,31 +277,54 @@ public Session session() { } /** - * Returns the {@link SignatureToNodeName} translator for the model. + * Return a {@link ConcreteFunction} corresponding to the function signature. + * + *

{@code
+   * ConcreteFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
+   * Map> outputTensorMap = myFunction.call(session, inputTensorMap);
+   * }
* - * @return SignatureToNodeName translator + * @param functionSignatureName name of the {@code SignatureDef} in the saved model. + * @return TfFunction object that can be used to make calls to the tf.function + * @throws IllegalArgumentException if {@code functionSignatureName} is not found in this + * saved model. */ - public SignatureToNodeName getSignatureToNodeName() { - if (this.sigToNodeName == null) { - // no need to lock, ok to create multiple instances - this.sigToNodeName = new SignatureToNodeName(this); + public ConcreteFunction function(String functionSignatureName) { + ConcreteFunction function = functions.get(functionSignatureName); + if (function == null) { + throw new IllegalArgumentException( + String.format("Function with signature [%s] not found", functionSignatureName)); } - return this.sigToNodeName; + return function; } /** - * Return a {@link TfFunction} corresponding to the function signature. + * Invokes the default function directly from this model. * - *
{@code
-   * TfFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
-   * Map> outputTensorMap = myFunction.call(inputTensorMap);
-   * }
+ *

The default function selection is done based on the first of the following conditions that + * is true: + *

* - * @param functionSignatureName name of the {@code SignatureDef} in the saved model. - * @return TfFunction object that can be used to make calls to the tf.function + *

Caller is responsible for closing all returned Tensors. + * + * @param arguments list of input tensors, mapped by their signature name + * @return list of output tensors, mapped by the signature name + * @throws IllegalArgumentException if no function can be selected by default */ - public TfFunction function(String functionSignatureName) { - return new TfFunction(functionSignatureName, this.getSignatureToNodeName(), this.session); + public Map> call(Map> arguments) { + ConcreteFunction function = null; + if (functions.size() == 1) { + function = functions.values().iterator().next(); + } else { + function = functions.get(Signature.DEFAULT_NAME); + } + if (function == null) { + throw new IllegalArgumentException("Cannot elect a default function for this model"); + } + return function.call(arguments); } /** @@ -288,12 +340,13 @@ public void close() { private final Graph graph; private final Session session; private final MetaGraphDef metaGraphDef; - private SignatureToNodeName sigToNodeName; + private final Map functions; - private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef) { + private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef, Map functions) { this.graph = graph; this.session = session; this.metaGraphDef = metaGraphDef; + this.functions = functions; } /** @@ -303,10 +356,21 @@ private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef *

Invoked from the native load method. Takes ownership of the handles. */ private static SavedModelBundle fromHandle( - TF_Graph graphHandle, TF_Session sessionHandle, MetaGraphDef metaGraphDef) { - Graph graph = new Graph(graphHandle); - Session session = new Session(graph, sessionHandle); - return new SavedModelBundle(graph, session, metaGraphDef); + final TF_Graph graphHandle, final TF_Session sessionHandle, MetaGraphDef metaGraphDef) { + + final Graph graph = new Graph(graphHandle, metaGraphDef.getSaverDef()); + final Session session = new Session(graph, sessionHandle); + + // Create a separate function for each signature of the main graph. + // Note that the saved model will remain the owner of the graph and the session, meaning + // that the functions do not need to be closed by the user and if it does, it should have + // no effect. + final Map functions = new HashMap<>(metaGraphDef.getSignatureDefCount()); + metaGraphDef.getSignatureDefMap().forEach((signatureName, signatureDef) -> { + Signature signature = new Signature(signatureName, signatureDef); + functions.put(signatureName, ConcreteFunction.create(signature, session)); + }); + return new SavedModelBundle(graph, session, metaGraphDef, functions); } private static SavedModelBundle load( diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 0676ce8ec4e..6f6ee4e136f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -36,6 +36,8 @@ import java.util.ArrayList; import java.util.List; +import org.tensorflow.proto.util.SaverDef; +import org.tensorflow.types.TString; import static org.tensorflow.Graph.resolveOutputs; import static org.tensorflow.internal.c_api.global.tensorflow.*; @@ -444,6 +446,24 @@ public void run(Op op) { runner().addTarget(op.op()).run(); } + /** + * Saves the actual state of the variables of this session's graph. + * + *

{@code prefix} is a path where the files containing the variables state will be saved, + * followed by a prefix for naming these files. For example, if {@code prefix} is set to + * mymodel/myvariables/variables, then the generated files will be located under + * mymodel/myvariables and named variables.data-*-of-* + * + * @param prefix + */ + public void save(String prefix) { + SaverDef saverDef = graph.saverDef(); + runner() + .addTarget(saverDef.getSaveTensorName()) + .feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix)) + .run(); + } + /** * Output tensors and metadata obtained when executing a session. * @@ -463,6 +483,10 @@ public static final class Run { public RunMetadata metadata; } + Graph graph() { + return graph; + } + private final Graph graph; private final Graph.Reference graphRef; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java new file mode 100644 index 00000000000..479a7c168ac --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java @@ -0,0 +1,159 @@ +/* + * Copyright 2020 The TensorFlow Authors. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.tensorflow; + +import java.util.Set; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.proto.framework.SignatureDef; +import org.tensorflow.proto.framework.TensorInfo; +import org.tensorflow.proto.framework.TensorShapeProto; +import org.tensorflow.proto.framework.TensorShapeProto.Dim; + +/** + * Describe the inputs and outputs of an executable entity, such as a {@link ConcreteFunction}, among + * other useful metadata. + */ +public class Signature { + + /** The default signature name, when not provided */ + public static final String DEFAULT_NAME = "serving_default"; + + /** + * Builds a new function signature. + */ + public static class Builder { + + /** + * Sets the name of this signature. + * + *

When not set explicitly, the default value is {@link #DEFAULT_NAME}. + * + * @param name signature name + * @return this builder + */ + public Builder name(String name) { + this.name = name; + return this; + } + + /** + * Register a tensor as an input of the function. + * + * @param inputName user-friendly name for this input tensor + * @param input input tensor + * @return this builder + */ + public Builder input(String inputName, Operand input) { + signatureBuilder.putInputs(inputName, toTensorInfo(input.asOutput())); + return this; + } + + /** + * Register a tensor as an output of the function. + * + * @param inputName user-friendly name for this input tensor + * @param input input tensor + * @return this builder + */ + public Builder output(String outputName, Operand output) { + signatureBuilder.putOutputs(outputName, toTensorInfo(output.asOutput())); + return this; + } + + /** + * Provide extensible name information enabling third-party users to mark a signature as + * supporting a particular method + * + * @param methodName method name + * @return this builder + */ + public Builder methodName(String methodName) { + signatureBuilder.setMethodName(methodName); + return this; + } + + /** + * Returns a signature from the provided data. + */ + public Signature build() { + return new Signature(name, signatureBuilder.build()); + } + + private static TensorInfo toTensorInfo(Output operand) { + Shape shape = operand.shape(); + TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder(); + for (int i = 0; i < shape.numDimensions(); ++i) { + tensorShapeBuilder.addDim(Dim.newBuilder().setSize(shape.size(i))); + } + return TensorInfo.newBuilder() + .setDtype(DataType.forNumber(operand.dataType().nativeCode())) + .setTensorShape(tensorShapeBuilder) + .setName(operand.op().name() + ":" + operand.index()) + .build(); + } + + private String name = DEFAULT_NAME; + private final SignatureDef.Builder signatureBuilder = SignatureDef.newBuilder(); + } + + /** + * Returns a new builder for creating a signature + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Return the name of this signature + */ + public String name() { + return name; + } + + /** + * Returns the method name of this signature (e.g. as exposed by TF serving) + */ + public String methodName() { + return signatureDef.getMethodName(); + } + + /** + * Returns the names of the inputs in this signature + */ + public Set inputNames() { + return signatureDef.getInputsMap().keySet(); + } + + /** + * Returns the names of the outputs in this signature + */ + public Set outputNames() { + return signatureDef.getOutputsMap().keySet(); + } + + SignatureDef asSignatureDef() { + return signatureDef; + } + + private final String name; + private final SignatureDef signatureDef; + + Signature(String name, SignatureDef signatureDef) { + this.name = name; + this.signatureDef = signatureDef; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TfFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TfFunction.java deleted file mode 100644 index 5dc5a128898..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TfFunction.java +++ /dev/null @@ -1,157 +0,0 @@ -/* - * Copyright 2020 The TensorFlow Authors. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.tensorflow; - -import com.google.protobuf.InvalidProtocolBufferException; - -import java.util.List; -import java.util.ListIterator; -import java.util.HashMap; -import java.util.Map; - -/** - * Invoke tf.function - * defined in a {@link SavedModelBundle}. - * - *

{@code
- * TfFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
- * Map> outputTensorMap = myFunction.call(inputTensorMap);
- * }
- * - */ -public class TfFunction { - - public TfFunction( - String functionSignatureName, - SavedModelBundle.SignatureToNodeName nameToNode, Session session) { - this.nameToNode = nameToNode; - this.session = session; - this.functionSignatureName = functionSignatureName; - } - - /** - * Invokes a tf.function. - * Caller is responsible for closing all Tensors. - * - * @param arguments map of input tensors - * @return map of output tensors - */ - public Map> call( - Map> arguments) throws IllegalArgumentException { - - Session.Runner runner = this.session.runner(); - - Map inputToNode = this.nameToNode.inputNameToNode(this.functionSignatureName); - - if (inputToNode == null) { - throw new IllegalArgumentException( - String.format("Function [%s] is missing input", this.functionSignatureName)); - } - - // Join arguments.key, inputToNodeName.key - for (Map.Entry entry: inputToNode.entrySet()) { - String argName = entry.getKey(); - Tensor tensor = arguments.get(argName); - - if (tensor == null) { - throw new IllegalArgumentException(String.format("Missing argument [%s]", argName)); - } - - // Node name in the tensorflow graph, corresponding to the tf.function argument - runner = runner.feed(entry.getValue(), tensor); - } - - Map outputToNode = this.nameToNode.outputNameToNode(this.functionSignatureName); - if (outputToNode == null) { - throw new IllegalArgumentException( - String.format("Function [%] is missing output", this.functionSignatureName)); - } - - for (String nodeName: outputToNode.values()) { - // Node names corresponding to the return value - runner = runner.fetch(nodeName); - } - - List> resultTensors = runner.run(); - ListIterator> resultTensorIter = resultTensors.listIterator(); - - Map> returnMap = new HashMap>(); - - // Use the output names as present in the signature definition - for (String nodeName: outputToNode.keySet()) { - returnMap.put(nodeName, resultTensorIter.next()); - } - - return returnMap; - } - - /** - * Invokes a tf.function. - * Caller is responsible for closing all Tensors. - * - * Throws IllegalArgumentException if there are multiple input or output parameters defined - * in the tf.function - * - * @param tensor input tensor - * @return output tensor - */ - public Tensor call(Tensor tensor) throws IllegalArgumentException { - Session.Runner runner = this.session.runner(); - - Map inputToNode = this.nameToNode.inputNameToNode(this.functionSignatureName); - - if (inputToNode == null) { - throw new IllegalArgumentException( - String.format("Function [%s] is missing input", this.functionSignatureName)); - } - - if (inputToNode.size() != 1) { - throw new IllegalArgumentException( - String.format("Function [%s] requires multiple inputs", this.functionSignatureName)); - } - - // Feed the single argument - for (Map.Entry entry: inputToNode.entrySet()) { - // Node name in the tensorflow graph, corresponding to the tf.function argument - runner = runner.feed(entry.getValue(), tensor); - } - - Map outputToNode = this.nameToNode.outputNameToNode(this.functionSignatureName); - if (outputToNode == null) { - throw new IllegalArgumentException( - String.format("Function [%] is missing output", this.functionSignatureName)); - } - - if (outputToNode.size() != 1) { - throw new IllegalArgumentException( - String.format("Function [%s] has multiple outputs", this.functionSignatureName)); - } - - // Fetch the single return tensor - for (String nodeName: outputToNode.values()) { - // Node names corresponding to the return value - runner = runner.fetch(nodeName); - } - - List> resultTensors = runner.run(); - - return resultTensors.get(0); - } - - private final Session session; - private final SavedModelBundle.SignatureToNodeName nameToNode; - private final String functionSignatureName; -} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java new file mode 100644 index 00000000000..ec32c69c7f4 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java @@ -0,0 +1,122 @@ +package org.tensorflow; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +import org.junit.jupiter.api.Test; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Init; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.op.math.Add; +import org.tensorflow.op.math.Sub; +import org.tensorflow.types.TFloat32; + +public class ConcreteFunctionTest { + + private static Signature plusFive(Ops tf) { + Placeholder input = tf.placeholder(TFloat32.DTYPE); + Add output = tf.math.add(input, tf.constant(5.0f)); + Init init = tf.init(); // for native resource management tests + return Signature.builder().name("plusFive").input("x", input).output("y", output).build(); + } + + private static Signature minusTwo(Ops tf) { + Placeholder input = tf.placeholder(TFloat32.DTYPE); + Sub output = tf.math.sub(input, tf.constant(2.0f)); + return Signature.builder().name("minusTwo").input("x", input).output("y", output).build(); + } + + @Test + public void createFunction() { + try (ConcreteFunction f = ConcreteFunction.create(ConcreteFunctionTest::plusFive); + Tensor x = TFloat32.scalarOf(3.0f)) { + assertEquals(8.0f, f.call(x).expect(TFloat32.DTYPE).data().getFloat()); + } + } + + @Test + public void createFunctionFromGraph() { + try (Graph g = new Graph()) { + Signature signature = plusFive(Ops.create(g)); + try (ConcreteFunction f = ConcreteFunction.create(signature, g); + Tensor x = TFloat32.scalarOf(3.0f)) { + assertEquals(8.0f, f.call(x).expect(TFloat32.DTYPE).data().getFloat()); + } + } + } + + @Test + public void createFunctionFromSession() { + try (Graph g = new Graph()) { + Signature signature = plusFive(Ops.create(g)); + try (Session s = new Session(g)) { + try (ConcreteFunction f = ConcreteFunction.create(signature, s); + Tensor x = TFloat32.scalarOf(3.0f)) { + assertEquals(8.0f, f.call(x).expect(TFloat32.DTYPE).data().getFloat()); + } + } + } + } + + @Test + public void chainFunctions() { + try (ConcreteFunction f1 = ConcreteFunction.create(ConcreteFunctionTest::plusFive); + ConcreteFunction f2 = ConcreteFunction.create(ConcreteFunctionTest::minusTwo); + Tensor x = TFloat32.scalarOf(3.0f)) { + assertEquals(6.0f, f2.call(f1.call(x)).expect(TFloat32.DTYPE).data().getFloat()); + } + } + + @Test + public void closingFunctionReleaseAllResourcesItOwns() { + Graph g; + Session s; + try (ConcreteFunction f = ConcreteFunction.create(ConcreteFunctionTest::plusFive)) { + g = f.graph(); + s = f.session(); + } + try { + s.run("Add"); + fail(); + } catch (IllegalStateException e) { + // as expected + } + try { + g.toGraphDef(); + fail(); + } catch (IllegalStateException e) { + // as expected + } + } + + @Test + public void closingFunctionCreatedFromGraphOnlyReleaseResourcesItOwns() { + try (Graph g = new Graph()) { + Signature signature = plusFive(Ops.create(g)); + Session s; + try (ConcreteFunction f = ConcreteFunction.create(signature, g)) { + s = f.session(); + } + try { + s.run(Init.DEFAULT_NAME); + fail(); + } catch (IllegalStateException e) { + // as expected + } + g.toGraphDef(); // check that graph is still valid + } + } + + @Test + public void closingFunctionCreatedFromSessionDoesNotReleaseResources() { + try (Graph g = new Graph()) { + Signature signature = plusFive(Ops.create(g)); + try (Session s = new Session(g)) { + try (ConcreteFunction f = ConcreteFunction.create(signature, s)) { + } + s.run(Init.DEFAULT_NAME); // check that session is still valid + } + g.toGraphDef(); // check that graph is still valid + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index 91c07e3f4b6..43e9888176f 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -15,20 +15,40 @@ package org.tensorflow; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; +import java.io.IOException; import java.net.URISyntaxException; +import java.nio.file.Files; +import java.nio.file.Path; import java.nio.file.Paths; +import java.util.Collections; import org.junit.jupiter.api.Test; import org.tensorflow.exceptions.TensorFlowException; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Identity; +import org.tensorflow.op.core.Init; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.op.core.ReduceSum; +import org.tensorflow.op.core.Variable; +import org.tensorflow.op.math.Sign; import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.RunOptions; +import org.tensorflow.proto.framework.SavedModel; +import org.tensorflow.proto.framework.SignatureDef; +import org.tensorflow.proto.framework.TensorInfo; +import org.tensorflow.types.TFloat32; /** Unit tests for {@link org.tensorflow.SavedModelBundle}. */ public class SavedModelBundleTest { + private static final float EPSILON = 1e-7f; private static final String SAVED_MODEL_PATH; static { try { @@ -72,13 +92,190 @@ public void loader() { } } + @Test + public void exportFunctionWithVariables() throws IOException { + Path testFolder = Files.createTempDirectory("tf-saved-model-export-test"); + float reducedSum; + FloatNdArray xValue = StdArrays.ndCopyOf(new float[][]{{0, 1, 2}, {3, 4, 5}}); + Shape xyShape = Shape.of(2, 3L); + try (ConcreteFunction f = ConcreteFunction.create(tf -> buildGraphWithVariables(tf, xyShape))) { + // Init variable state by running the Init operation directly + f.session().run(Init.DEFAULT_NAME); + + // Call the graph and remember the result of computation for later + try (Tensor xTensor = TFloat32.tensorOf(xValue); + Tensor zTensor = f.call(xTensor).expect(TFloat32.DTYPE)) { + reducedSum = zTensor.data().getFloat(); + } + // Save/export the model (which is a single function in this case) + f.save(testFolder.toString()); + } + assertTrue(Files.exists(testFolder.resolve(Paths.get("variables", "variables.index")))); + assertTrue(Files + .exists(testFolder.resolve(Paths.get("variables", "variables.data-00000-of-00001")))); + assertTrue(Files.exists(testFolder.resolve("saved_model.pb"))); + + // Reload the model just saved and validate its data + try (SavedModelBundle savedModel = + SavedModelBundle.load(testFolder.toString(), SavedModelBundle.DEFAULT_TAG)) { + assertNotNull(savedModel.metaGraphDef()); + assertNotNull(savedModel.metaGraphDef().getSaverDef()); + assertEquals(1, savedModel.metaGraphDef().getSignatureDefCount()); + assertEquals(Signature.DEFAULT_NAME, + savedModel.metaGraphDef().getSignatureDefMap().keySet().iterator().next()); + + ConcreteFunction function = savedModel.function(Signature.DEFAULT_NAME); + assertNotNull(function); + + Signature signature = function.signature(); + assertNotNull(signature); + assertEquals(1, signature.inputNames().size()); + assertEquals("input", signature.inputNames().iterator().next()); + assertEquals(1, signature.outputNames().size()); + assertEquals("reducedSum", signature.outputNames().iterator().next()); + + SignatureDef signatureDef = signature.asSignatureDef(); + assertEquals(1, signatureDef.getInputsCount()); + assertEquals(1, signatureDef.getOutputsCount()); + + TensorInfo inputInfo = signatureDef.getInputsMap().get("input"); + assertNotNull(inputInfo); + assertEquals(xyShape.numDimensions(), inputInfo.getTensorShape().getDimCount()); + for (int i = 0; i < xyShape.numDimensions(); ++i) { + assertEquals(xyShape.size(i), inputInfo.getTensorShape().getDim(i).getSize()); + } + + TensorInfo outputInfo = signatureDef.getOutputsMap().get("reducedSum"); + assertNotNull(outputInfo); + assertEquals(0, outputInfo.getTensorShape().getDimCount()); + + try (Tensor xTensor = TFloat32.tensorOf(xValue)) { + // Call the saved model function and make sure it returns the same result as before + try (Tensor zTensor = function.call(xTensor).expect(TFloat32.DTYPE)) { + assertEquals(reducedSum, zTensor.data().getFloat(), EPSILON); + } + // Now call the same function directly from the model + try (Tensor zTensor = + savedModel.call(Collections.singletonMap("input", xTensor)).get("reducedSum").expect(TFloat32.DTYPE)) { + assertEquals(reducedSum, zTensor.data().getFloat(), EPSILON); + } + } + } + } + + @Test + public void exportMultipleFunctions() throws IOException { + Path testFolder = Files.createTempDirectory("tf-saved-model-export-test"); + float reducedSum; + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + Signature f1Signature = buildGraphWithVariables(tf, Shape.of(1, 1)); + Signature f2Signature = buildIdentityGraph(tf, "identity"); + try (Session s = new Session(g); + ConcreteFunction f1 = ConcreteFunction.create(f1Signature, s); + ConcreteFunction f2 = ConcreteFunction.create(f2Signature, s)) { + f1.session().run(Init.DEFAULT_NAME); + try (Tensor x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2})); + Tensor t = f1.call(x).expect(TFloat32.DTYPE)) { + reducedSum = t.data().getFloat(); + } + SavedModelBundle.exporter(testFolder.toString()) + .withFunction(f1) + .withFunction(f2) + .export(); + } + } + try (SavedModelBundle model = SavedModelBundle.load(testFolder.toString())) { + ConcreteFunction f1 = model.function(Signature.DEFAULT_NAME); + assertNotNull(f1); + try (Tensor x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2})); + Tensor t = f1.call(x).expect(TFloat32.DTYPE)) { + assertEquals(reducedSum, t.data().getFloat(), EPSILON); + } + ConcreteFunction f2 = model.function("identity"); + assertNotNull(f2); + try (Tensor x = TFloat32.scalarOf(10.0f); + Tensor t = f2.call(x).expect(TFloat32.DTYPE)) { + assertEquals(10.0f, t.data().getFloat(), 0.0f); + } + try { + model.function("NoSuchFunction"); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + } + } + + @Test + public void cannotExportMultipleFunctionsWithDifferentSessions() throws IOException { + Path testFolder = Files.createTempDirectory("tf-saved-model-export-test"); + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + Signature f1Signature = buildGraphWithVariables(tf, Shape.of(1, 1)); + Signature f2Signature = buildIdentityGraph(tf, "identity"); + try (ConcreteFunction f1 = ConcreteFunction.create(f1Signature, g); + ConcreteFunction f2 = ConcreteFunction.create(f2Signature, g)) { + f1.session().run(Init.DEFAULT_NAME); + try { + SavedModelBundle.exporter(testFolder.toString()) + .withFunction(f1) + .withFunction(f2) + .export(); + fail(); + } catch (UnsupportedOperationException e) { + // as expected + } + } + } + } + + @Test + public void cannotExportMultipleFunctionsWithSameSignatureName() throws IOException { + Path testFolder = Files.createTempDirectory("tf-saved-model-export-test"); + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + Signature f1Signature = buildGraphWithVariables(tf, Shape.of(1, 1)); + Signature f2Signature = buildIdentityGraph(tf, Signature.DEFAULT_NAME); + try (Session s = new Session(g); + ConcreteFunction f1 = ConcreteFunction.create(f1Signature, s); + ConcreteFunction f2 = ConcreteFunction.create(f2Signature, s)) { + f1.session().run(Init.DEFAULT_NAME); + try { + SavedModelBundle.exporter(testFolder.toString()) + .withFunction(f1) + .withFunction(f2) + .export(); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + } + } + } + + private static Signature buildGraphWithVariables(Ops tf, Shape xShape) { + Placeholder x = tf.placeholder(TFloat32.DTYPE, Placeholder.shape(xShape)); + Variable y = tf + .variable(tf.random.randomUniform(tf.constant(xShape), TFloat32.DTYPE)); + ReduceSum z = tf.reduceSum(tf.math.add(x, y), tf.array(0, 1)); + Init init = tf.init(); + return Signature.builder().input("input", x).output("reducedSum", z).build(); + } + + private static Signature buildIdentityGraph(Ops tf, String signatureName) { + Placeholder x = tf.placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + Identity xprime = tf.identity(x); + return Signature.builder().name(signatureName).input("x", x).output("x", xprime).build(); + } + private static RunOptions sillyRunOptions() { return RunOptions.newBuilder() .setTraceLevel(RunOptions.TraceLevel.FULL_TRACE) .build(); } - public static ConfigProto sillyConfigProto() { + private static ConfigProto sillyConfigProto() { return ConfigProto.newBuilder() .setInterOpParallelismThreads(1) .setIntraOpParallelismThreads(1) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index 7faf8c6fbdb..fa41af32a29 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -20,8 +20,13 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; import org.junit.jupiter.api.Test; import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Init; import org.tensorflow.op.core.Split; import org.tensorflow.op.core.Variable; import org.tensorflow.op.linalg.MatMul; @@ -31,6 +36,7 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; /** Unit tests for {@link org.tensorflow.Session}. */ @@ -205,6 +211,24 @@ public void runInitByName() { } } + @Test + public void save() throws IOException { + Path testFolder = Files.createTempDirectory("tf-session-save-test"); + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + Variable x = tf.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.DTYPE)); + Variable y = tf.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.DTYPE)); + Init init = tf.init(); + + try (Session s = new Session(g)) { + s.run(init); + s.save(testFolder.resolve("checkpoint").toString()); + } + } + assertTrue(Files.exists(testFolder.resolve("checkpoint.index"))); + assertTrue(Files.exists(testFolder.resolve("checkpoint.data-00000-of-00001"))); + } + private static RunOptions fullTraceRunOptions() { return RunOptions.newBuilder() .setTraceLevel(RunOptions.TraceLevel.FULL_TRACE)