diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java new file mode 100644 index 00000000000..2d434ba98fc --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -0,0 +1,291 @@ +/* + * Copyright 2020 The TensorFlow Authors. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.tensorflow; + +import java.io.IOException; +import java.util.List; +import java.util.ListIterator; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; +import org.tensorflow.op.Ops; +import org.tensorflow.proto.framework.SignatureDef; +import org.tensorflow.proto.framework.TensorInfo; + +/** + * A graph that can be invoked as a single function, with an input and output signature. + * + *

A function can also invoke a + * tf.function + * defined in a {@link SavedModelBundle}. + * + *

{@code
+ * ConcreteFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
+ * Map> outputTensorMap = myFunction.call(inputTensorMap);
+ * }
+ */ +public class ConcreteFunction implements AutoCloseable { + + /** + * Creates a function by building a new graph. + * + *

The {@code functionBuilder} must initialize the function graph from the provided + * {@link Ops} instance and return a valid signature that will be used to feed the input tensors + * and fetch the output tensors on execution. + * + *

The function will be the owner of the new graph and its resulting session. Therefore, + * the function must be enclosed properly with a try-with-resources block to guarantee that + * all native resources will be freed once the function is discarded. For example: + * + *

{@code
+   * public class MyModel {
+   *
+   *   public static Signature addTwo(Ops tf) {
+   *     Placeholder input = tf.placeholder(TFloat32.DTYPE);
+   *     Add output = tf.math.add(input, tf.constant(2.0f));
+   *     return Signature.builder("addTwo").input("x", input).output("y", output).build();
+   *   }
+   *
+   *   public static void main(String args[]) {
+   *     try (ConcreteFunction function = ConcreteFunction.create(MyModel::addTwo);
+   *         Tensor x = TFloat32.scalarOf(2.0f)) {
+   *       assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
+   *     }
+   *   }
+   * }
+   * }
+ * + * @param functionBuilder function builder + * @return the new function + */ + public static ConcreteFunction create(Function functionBuilder) { + Graph graph = new Graph(); + try { + Ops tf = Ops.create(graph); + Signature signature = functionBuilder.apply(tf); + return new ConcreteFunction(signature, graph, new Session(graph), Ownership.GRAPH_AND_SESSION); + } catch (Exception e) { + graph.close(); + throw e; + } + } + + /** + * Create a function from a signature and an existing graph. + * + *

The function will keep the ownership of the session used to run the graph but not + * the graph itself, meaning that the lifetime of the latter can extend beyond the scope + * of the function. For example: + * + *

{@code
+   * try (Graph g = new Graph()) {
+   *   Placeholder input = tf.placeholder(TFloat32.DTYPE);
+   *   Add output = tf.math.add(input, tf.constant(2.0f));
+   *   Signature signature = Signature.builder().input("x", input).output("y", output).build();
+   *
+   *   try (ConcreteFunction f = ConcreteFunction.create(signature, g);
+   *       Tensor x = TFloat32.scalarOf(2.0f)) {
+   *     assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
+   *   }
+   *   // Graph g is still valid at this point
+   * }
+   * }
+ * + * @param signature signature of the function to create + * @param graph a valid and initialized graph + * @return a new function + */ + public static ConcreteFunction create(Signature signature, Graph graph) { + return new ConcreteFunction(signature, graph, new Session(graph), Ownership.SESSION_ONLY); + } + + /** + * Create a function from a signature and a valid graph session. + * + *

The function will not own the session nor its graph, meaning that their lifetime + * can extend beyond the scope of the function. Therefore the function does not need to be + * closed after its usage. For example: + * + *

{@code
+   * try (Graph g = new Graph()) {
+   *   Placeholder input = tf.placeholder(TFloat32.DTYPE);
+   *   Add output = tf.math.add(input, tf.constant(2.0f));
+   *   Signature signature = Signature.builder().input("x", input).output("y", output).build();
+   *
+   *   try (Session s = new Session(g)) {
+   *     // Auto-closing the function just as an example but this is not required since it has
+   *     // no effect
+   *     try (ConcreteFunction f = ConcreteFunction.create(signature, s);
+   *         Tensor t = TFloat32.scalarOf(2.0f)) {
+   *       assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
+   *     }
+   *     // Session s is still valid at this point
+   *   }
+   *   // Graph g is still valid at this point
+   * }
+   * }
+ * + * @param signature signature of the function to create + * @param graph a valid session to an initialized graph + * @return a new function + */ + public static ConcreteFunction create(Signature signature, Session session) { + return new ConcreteFunction(signature, session.graph(), session, Ownership.NONE); + } + + /** + * Returns the signature of this function + */ + public Signature signature() { + return signature; + } + + /** + * Invokes a function. + * + *

Caller is responsible for closing all Tensors. + * + * @param tensor input tensor + * @return output tensor + */ + public Map> call(Map> arguments) + throws IllegalArgumentException { + + final SignatureDef signatureDef = signature.asSignatureDef(); + final Session.Runner runner = session.runner(); + + signatureDef.getInputsMap().forEach((argName, t) -> { + Tensor tensor = arguments.get(argName); + if (tensor == null) { + throw new IllegalArgumentException(String.format("Missing argument [%s]", argName)); + } + runner.feed(t.getName(), tensor); + }); + + Map outputToNode = signatureDef.getOutputsMap(); + outputToNode.values().forEach(t -> runner.fetch(t.getName())); + + List> resultTensors = runner.run(); + try { + ListIterator> resultTensorIter = resultTensors.listIterator(); + Map> returnMap = new HashMap>(); + + // Use the output names as present in the signature definition + for (String nodeName: outputToNode.keySet()) { + returnMap.put(nodeName, resultTensorIter.next()); + } + return returnMap; + + } catch (Exception e) { + // Release tensors before throwing exception + for (Tensor t : resultTensors) { + t.close(); + } + throw e; + } + } + + /** + * Invokes a function with a single input and output. + * + *

Caller is responsible for closing all Tensors. + * + * @param tensor input tensor + * @return output tensor + * @throws IllegalArgumentException if there are multiple input or output parameters defined + * in the function + */ + public Tensor call(Tensor tensor) throws IllegalArgumentException { + final SignatureDef signatureDef = signature.asSignatureDef(); + + if (signatureDef.getInputsCount() != 1) { + throw new IllegalArgumentException( + String.format("Function [%s] requires multiple inputs", signatureDef.getMethodName())); + } + String inputNodeName = signatureDef.getInputsMap().values().iterator().next().getName(); + + if (signatureDef.getOutputsCount() != 1) { + throw new IllegalArgumentException( + String.format("Function [%s] has multiple outputs", signatureDef.getMethodName())); + } + String outputNodeName = signatureDef.getOutputsMap().values().iterator().next().getName(); + + return session.runner().feed(inputNodeName, tensor).fetch(outputNodeName).run().get(0); + } + + /** + * Export this function as a saved model. + * + *

This method is convenient shortcut equivalent to + * {@code SavedModel.exporter(exportDir).withFunction(this).export()} + * + * @throws IOException if saved model or variable state cannot be written on disk + */ + public void save(String exportDir) throws IOException { + SavedModelBundle.exporter(exportDir).withFunction(this).export(); + } + + /** + * Returns the session used to execute the graph when calling this function + * + *

In general, a user does not need to handle directly the session of a function and rely + * on {@link #call(Map)} to execute the graph instead. But in some cases, direct access to + * the session might be necessary, as it allows more running options. + * + * @return the function session + */ + public Session session() { + return session; + } + + /** + * Returns the graph of this function + */ + public Graph graph() { + return graph; + } + + @Override + public void close() { + if (ownership != Ownership.NONE) { + session.close(); + if (ownership == Ownership.GRAPH_AND_SESSION) { + graph.close(); + } + } + } + + @Override + public String toString() { + return signature.toString(); + } + + private enum Ownership { + GRAPH_AND_SESSION, SESSION_ONLY, NONE; + } + + private final Graph graph; + private final Session session; + private final Signature signature; + private final Ownership ownership; + + ConcreteFunction(Signature signature, Graph graph, Session session, Ownership ownership) { + this.graph = graph; + this.session = session; + this.signature = signature; + this.ownership = ownership; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index 221268191fd..f365956dff1 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -43,8 +43,17 @@ import org.tensorflow.internal.c_api.TF_Output; import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.internal.c_api.TF_WhileParams; +import org.tensorflow.ndarray.StdArrays; import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.NoOp; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.op.train.Restore; +import org.tensorflow.op.train.Save; import org.tensorflow.proto.framework.GraphDef; +import org.tensorflow.proto.util.SaverDef; +import org.tensorflow.types.TString; /** @@ -67,6 +76,11 @@ public Graph() { this.nativeHandle = nativeHandle; } + Graph(TF_Graph nativeHandle, SaverDef saverDef) { + this(nativeHandle); + this.saverDef = saverDef; + } + /** * Release resources associated with the Graph. * @@ -402,9 +416,27 @@ public Output[] whileLoop( } } + /** + * Return the {@link SaverDef} instance used to save the state of all variables present in + * this graph. + * + *

On the first call of this method, all nodes necessary to save and restore the state of the + * variables are added to the graph. Consequently, any variables that are added to the graph after + * this call could not be saved nor restored using this {@link SaverDef}. + * + * @return a {@link SaverDef} instance + */ + synchronized SaverDef saverDef() { + if (saverDef == null) { + saverDef = addVariableSaver(this); + } + return saverDef; + } + private final Object nativeHandleLock = new Object(); private TF_Graph nativeHandle; private int refcount = 0; + private SaverDef saverDef; private final List initializers = new ArrayList<>(); @@ -726,6 +758,53 @@ private static Object[] whileLoop( } } + private static SaverDef addVariableSaver(Graph graph) { + Ops tf = Ops.create(graph).withSubScope("save"); + + List varNames = new ArrayList<>(); + List> varOutputs = new ArrayList<>(); + List> varTypes = new ArrayList<>(); + + for (Iterator iter = graph.operations(); iter.hasNext();) { + Operation op = iter.next(); + if (op.type().equals("VariableV2")) { + varNames.add(op.name()); + varOutputs.add(op.output(0)); + varTypes.add(op.output(0).dataType()); + } + } + + // FIXME Need an easier way to initialize an NdArray from a list + String[] tmp = new String[varNames.size()]; + Constant varNamesTensor = tf.constant(StdArrays.ndCopyOf(varNames.toArray(tmp))); + Operand varSlices = tf.zerosLike(varNamesTensor); + + Placeholder saveFilename = tf.placeholder(TString.DTYPE); + Save saveVariables = tf.train.save( + saveFilename, + varNamesTensor, + varSlices, + varOutputs + ); + Restore restoreVariables = tf.train.restore( + saveFilename, + varNamesTensor, + varSlices, + varTypes + ); + List restoreOps = new ArrayList<>(varOutputs.size()); + for (int i = 0; i < varOutputs.size(); ++i) { + restoreOps.add(tf.assign(varOutputs.get(i), (Operand) restoreVariables.tensors().get(i))); + } + NoOp restoreAll = tf.withControlDependencies(restoreOps).noOp(); + + return SaverDef.newBuilder() + .setFilenameTensorName(saveFilename.op().name()) + .setSaveTensorName(saveVariables.op().name()) + .setRestoreOpName(restoreAll.op().name()) + .build(); + } + static { TensorFlow.init(); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index 8f683a59d89..3ee294c8fdc 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -20,6 +20,17 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig; import com.google.protobuf.InvalidProtocolBufferException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.PointerPointer; import org.bytedeco.javacpp.PointerScope; @@ -31,7 +42,10 @@ import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.MetaGraphDef; +import org.tensorflow.proto.framework.MetaGraphDef.MetaInfoDef; import org.tensorflow.proto.framework.RunOptions; +import org.tensorflow.proto.framework.SavedModel; +import org.tensorflow.proto.util.SaverDef; /** * SavedModelBundle represents a model loaded from storage. @@ -43,8 +57,12 @@ * protocol buffer. */ public class SavedModelBundle implements AutoCloseable { + + public static final String DEFAULT_TAG = "serve"; + /** Options for loading a SavedModel. */ public static final class Loader { + /** Load a SavedModelBundle with the configured options. */ public SavedModelBundle load() { return SavedModelBundle.load(exportDir, tags, configProto, runOptions); @@ -56,6 +74,7 @@ public SavedModelBundle load() { * @param options A RunOptions * protocol buffer. + * @return this object */ public Loader withRunOptions(RunOptions options) { this.runOptions = options; @@ -68,6 +87,7 @@ public Loader withRunOptions(RunOptions options) { * @param configProto A ConfigProto * protocol buffer. + * @return this object */ public Loader withConfigProto(ConfigProto configProto) { this.configProto = configProto; @@ -77,9 +97,14 @@ public Loader withConfigProto(ConfigProto configProto) { /** * Sets the set of tags that identify the specific graph in the saved model to load. * + *

Has no effect if {@code tags} is null or empty + * * @param tags the tags identifying the specific MetaGraphDef to load. + * @return this object + * @throws IllegalArgumentException if tags are invalid */ public Loader withTags(String... tags) { + validateTags(tags); this.tags = tags; return this; } @@ -89,11 +114,111 @@ private Loader(String exportDir) { } private String exportDir = null; - private String[] tags = null; + private String[] tags = { DEFAULT_TAG }; private ConfigProto configProto = null; private RunOptions runOptions = null; } + /** Options for exporting a SavedModel. */ + public static final class Exporter { + + /** + * Sets the set of tags that identify the specific graph in the saved model to save. + * + *

Note that only one graph per model can be saved right now using this API. + * + * @param tags the tags identifying the specific MetaGraphDef to save. + * @return this object + * @throws IllegalArgumentException if tags are invalid + */ + public Exporter withTags(String... tags) { + validateTags(tags); + this.tags = tags; + return this; + } + + /** + * Save a concrete function of this model. + * + *

The concrete function carries a signature (i.e. a list of user-friendly input and outputs + * names to a graph) and a valid session to a graph to be saved in the model. + * + *

Note:Eventually, TensorFlow for Java will support the export of functions objects like + * the Python API does but right now, only session-centric models are supported (i.e. models that + * has a single main graph and one or more signatures). These models are compatible with those + * exported by TensorFlow 1.x or by TensorFlow 2.x estimators. + * + *

Therefore, all functions exported in a model should share the same session at the moment + * or an exception will be thrown. + * + * @param function a function carrying a signature and a valid session to the graph to be saved + * @return this object + * @throws IllegalArgumentException if a function with the same name has already been added to the model + * @throws UnsupportedOperationException if this function does not share the same session with the other + * functions added to this model + */ + public Exporter withFunction(ConcreteFunction function) { + Signature signature = function.signature(); + if (functions.containsKey(signature.key())) { + throw new IllegalArgumentException("Function \"" + signature.key() + "\" was already added to the model"); + } + functions.put(signature.key(), function); + if (session == null) { + session = function.session(); + } else if (session != function.session()) { + throw new UnsupportedOperationException("Saving multiple functions with different graphs/sessions is not supported yet."); + } + metaGraphDefBuilder.putSignatureDef(signature.key(), signature.asSignatureDef()); + return this; + } + + /** + * Save the model into the export directory. + * + * @throws IOException if saved model or variable state cannot be written on disk + */ + public void export() throws IOException { + if (functions.isEmpty() || session == null) { + throw new IllegalStateException("Model should contain at least one valid function"); + } + Graph graph = session.graph(); + + // It is imperative to retrieve the graphDef after the saverDef, as the former might add + // new ops to the graph for saving and restoring the variables. + SaverDef saverDef = graph.saverDef(); + + MetaGraphDef.Builder metaGraphDef = metaGraphDefBuilder + .setSaverDef(saverDef) + .setGraphDef(graph.toGraphDef()) + .setMetaInfoDef(MetaInfoDef.newBuilder().addAllTags(Arrays.asList(tags))); + functions.forEach((k, f) -> metaGraphDef.putSignatureDef(k, f.signature().asSignatureDef())); + + // Make sure saved model directories exist + Path variableDir = Paths.get(exportDir, "variables"); + variableDir.toFile().mkdirs(); + + // Save the variables state + session.save(variableDir.resolve("variables").toString()); + + // Save the graph + SavedModel savedModelDef = SavedModel.newBuilder().addMetaGraphs(metaGraphDef).build(); + try (OutputStream file = + new FileOutputStream(Paths.get(exportDir, "saved_model.pb").toString())) { + savedModelDef.writeTo(file); + } + } + + Exporter(String exportDir) { + this.exportDir = exportDir; + } + + private final String exportDir; + private String[] tags = { DEFAULT_TAG }; + private final MetaGraphDef.Builder metaGraphDefBuilder = MetaGraphDef.newBuilder(); + private final Map functions = new LinkedHashMap<>(); + private Session session; + } + /** * Load a saved model from an export directory. The model that is being loaded should be created * using the Saved Model @@ -110,7 +235,11 @@ private Loader(String exportDir) { * @return a bundle containing the graph and associated session. */ public static SavedModelBundle load(String exportDir, String... tags) { - return loader(exportDir).withTags(tags).load(); + Loader loader = loader(exportDir); + if (tags != null && tags.length > 0) { + loader.withTags(tags); + } + return loader.load(); } /** @@ -125,6 +254,18 @@ public static Loader loader(String exportDir) { return new Loader(exportDir); } + /** + * Export a saved model. + * + *

Returns a Exporter object for setting configuration options before actually + * saving the model. + * + * @param exportDir the directory path containing a saved model. + */ + public static Exporter exporter(String exportDir) { + return new Exporter(exportDir); + } + /** * Returns the MetaGraphDef @@ -148,6 +289,64 @@ public Session session() { return session; } + /** + * Return the signature of all functions available in this saved model. + */ + public List signatures() { + return functions.values().stream().map(f -> f.signature()).collect(Collectors.toList()); + } + + /** + * Return a {@link ConcreteFunction} corresponding to the function signature. + * + *

{@code
+   * ConcreteFunction myFunction = savedModelBundle.function("mySignatureKey");
+   * Map> outputTensorMap = myFunction.call(session, inputTensorMap);
+   * }
+ * + * @param signatureKey name of the {@code SignatureDef} in the saved model. + * @return object that can be used to make calls to a function + * @throws IllegalArgumentException if {@code signatureKey} is not found in this + * saved model. + */ + public ConcreteFunction function(String signatureKey) { + ConcreteFunction function = functions.get(signatureKey); + if (function == null) { + throw new IllegalArgumentException( + String.format("Function with signature [%s] not found", signatureKey)); + } + return function; + } + + /** + * Invokes the default function directly from this model. + * + *

The default function selection is done based on the first of the following conditions that + * is true: + *

+ * + *

Caller is responsible for closing all returned Tensors. + * + * @param arguments list of input tensors, mapped by their signature name + * @return list of output tensors, mapped by the signature name + * @throws IllegalArgumentException if no function can be selected by default + */ + public Map> call(Map> arguments) { + ConcreteFunction function = null; + if (functions.size() == 1) { + function = functions.values().iterator().next(); + } else { + function = functions.get(Signature.DEFAULT_KEY); + } + if (function == null) { + throw new IllegalArgumentException("Cannot elect a default function for this model"); + } + return function.call(arguments); + } + /** * Releases resources (the {@link Graph} and {@link Session}) associated with the saved model * bundle. @@ -161,11 +360,13 @@ public void close() { private final Graph graph; private final Session session; private final MetaGraphDef metaGraphDef; + private final Map functions; - private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef) { + private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef, Map functions) { this.graph = graph; this.session = session; this.metaGraphDef = metaGraphDef; + this.functions = functions; } /** @@ -175,10 +376,21 @@ private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef *

Invoked from the native load method. Takes ownership of the handles. */ private static SavedModelBundle fromHandle( - TF_Graph graphHandle, TF_Session sessionHandle, MetaGraphDef metaGraphDef) { - Graph graph = new Graph(graphHandle); - Session session = new Session(graph, sessionHandle); - return new SavedModelBundle(graph, session, metaGraphDef); + final TF_Graph graphHandle, final TF_Session sessionHandle, MetaGraphDef metaGraphDef) { + + final Graph graph = new Graph(graphHandle, metaGraphDef.getSaverDef()); + final Session session = new Session(graph, sessionHandle); + + // Create a separate function for each signature of the main graph. + // Note that the saved model will remain the owner of the graph and the session, meaning + // that the functions do not need to be closed by the user and if it does, it should have + // no effect. + final Map functions = new HashMap<>(metaGraphDef.getSignatureDefCount()); + metaGraphDef.getSignatureDefMap().forEach((signatureName, signatureDef) -> { + Signature signature = new Signature(signatureName, signatureDef); + functions.put(signatureName, ConcreteFunction.create(signature, session)); + }); + return new SavedModelBundle(graph, session, metaGraphDef, functions); } private static SavedModelBundle load( @@ -216,6 +428,12 @@ opts, runOpts, new BytePointer(exportDir), new PointerPointer(tags), return bundle; } + private static void validateTags(String[] tags) { + if (tags == null || tags.length == 0 || Arrays.stream(tags).anyMatch(t -> t == null || t.isEmpty())) { + throw new IllegalArgumentException("Invalid tags: " + Arrays.toString(tags)); + } + } + static { TensorFlow.init(); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 0676ce8ec4e..cffad0db976 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -36,6 +36,8 @@ import java.util.ArrayList; import java.util.List; +import org.tensorflow.proto.util.SaverDef; +import org.tensorflow.types.TString; import static org.tensorflow.Graph.resolveOutputs; import static org.tensorflow.internal.c_api.global.tensorflow.*; @@ -444,6 +446,27 @@ public void run(Op op) { runner().addTarget(op.op()).run(); } + /** + * Saves the actual state of the variables of this session's graph. + * + *

{@code prefix} is a path where the files containing the variables state will be saved, + * followed by a prefix for naming these files. For example, if {@code prefix} is set to + * mymodel/myvariables/variables, then the generated files will be located under + * mymodel/myvariables and named variables.data-*-of-* + * + *

Note that this method might alter the underlying graph if it is the first time that one + * of its session is saved, see {@link Graph#saverDef()} for more details. + * + * @param prefix prefix to the variable files to save + */ + public void save(String prefix) { + SaverDef saverDef = graph.saverDef(); + runner() + .addTarget(saverDef.getSaveTensorName()) + .feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix)) + .run(); + } + /** * Output tensors and metadata obtained when executing a session. * @@ -463,6 +486,10 @@ public static final class Run { public RunMetadata metadata; } + Graph graph() { + return graph; + } + private final Graph graph; private final Graph.Reference graphRef; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java new file mode 100644 index 00000000000..cd138f44c3c --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java @@ -0,0 +1,206 @@ +/* + * Copyright 2020 The TensorFlow Authors. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.tensorflow; + +import java.util.Map; +import java.util.Set; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.proto.framework.SignatureDef; +import org.tensorflow.proto.framework.TensorInfo; +import org.tensorflow.proto.framework.TensorShapeProto; +import org.tensorflow.proto.framework.TensorShapeProto.Dim; + +/** + * Describe the inputs and outputs of an executable entity, such as a {@link ConcreteFunction}, among + * other useful metadata. + */ +public class Signature { + + /** The default signature key, when not provided */ + public static final String DEFAULT_KEY = "serving_default"; + + /** + * Builds a new function signature. + */ + public static class Builder { + + /** + * Sets the unique key of this signature. + * + *

When not set explicitly, the default value is {@link #DEFAULT_KEY}. + * + * @param key signature key + * @return this builder + * @throws IllegalArgumentException if the key is invalid + */ + public Builder key(String key) { + if (key == null || key.isEmpty()) { + throw new IllegalArgumentException("Invalid key: " + key); + } + this.key = key; + return this; + } + + /** + * Register a tensor as an input of the function. + * + * @param inputName user-friendly name for this input tensor + * @param input input tensor + * @return this builder + * @throws IllegalArgumentException if {@code inputName} is already mapped to another input + */ + public Builder input(String inputName, Operand input) { + if (signatureBuilder.containsInputs(inputName)) { + throw new IllegalArgumentException("\"" + inputName + "\" is already being mapped to another input"); + } + signatureBuilder.putInputs(inputName, toTensorInfo(input.asOutput())); + return this; + } + + /** + * Register a tensor as an output of the function. + * + * @param inputName user-friendly name for this input tensor + * @param input input tensor + * @return this builder + * @throws IllegalArgumentException if {@code outputName} is already mapped to another output + */ + public Builder output(String outputName, Operand output) { + if (signatureBuilder.containsOutputs(outputName)) { + throw new IllegalArgumentException("\"" + outputName + "\" is already being mapped to another output"); + } + signatureBuilder.putOutputs(outputName, toTensorInfo(output.asOutput())); + return this; + } + + /** + * Provide extensible name information enabling third-party users to mark a signature as + * supporting a particular method + * + * @param methodName method name or null for none (default) + * @return this builder + */ + public Builder methodName(String methodName) { + signatureBuilder.setMethodName(methodName == null ? "" : methodName); + return this; + } + + /** + * Returns a signature from the provided data. + */ + public Signature build() { + return new Signature(key, signatureBuilder.build()); + } + + private static TensorInfo toTensorInfo(Output operand) { + Shape shape = operand.shape(); + TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder(); + for (int i = 0; i < shape.numDimensions(); ++i) { + tensorShapeBuilder.addDim(Dim.newBuilder().setSize(shape.size(i))); + } + return TensorInfo.newBuilder() + .setDtype(DataType.forNumber(operand.dataType().nativeCode())) + .setTensorShape(tensorShapeBuilder) + .setName(operand.op().name() + ":" + operand.index()) + .build(); + } + + private String key = DEFAULT_KEY; + private final SignatureDef.Builder signatureBuilder = SignatureDef.newBuilder(); + } + + /** + * Returns a new builder for creating a signature + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Return the key of this signature + */ + public String key() { + return key; + } + + /** + * Returns the method name of this signature (e.g. as exposed by TF serving) or null if none + */ + public String methodName() { + return signatureDef.getMethodName().isEmpty() ? null : signatureDef.getMethodName(); + } + + /** + * Returns the names of the inputs in this signature + */ + public Set inputNames() { + return signatureDef.getInputsMap().keySet(); + } + + /** + * Returns the names of the outputs in this signature + */ + public Set outputNames() { + return signatureDef.getOutputsMap().keySet(); + } + + @Override + public String toString() { + StringBuilder strBuilder = new StringBuilder("Signature for \"" + key +"\":\n"); + if (!methodName().isEmpty()) { + strBuilder.append("\tMethod: \"").append(methodName()).append("\"\n"); + } + if (signatureDef.getInputsCount() > 0) { + strBuilder.append("\tInputs:\n"); + printTensorInfo(signatureDef.getInputsMap(), strBuilder); + } + if (signatureDef.getOutputsCount() > 0) { + strBuilder.append("\tOutputs:\n"); + printTensorInfo(signatureDef.getOutputsMap(), strBuilder); + } + return strBuilder.toString(); + } + + Signature(String key, SignatureDef signatureDef) { + this.key = key; + this.signatureDef = signatureDef; + } + + SignatureDef asSignatureDef() { + return signatureDef; + } + + private final String key; + private final SignatureDef signatureDef; + + private static void printTensorInfo(Map tensorMap, StringBuilder strBuilder) { + tensorMap.forEach((key, tensorInfo) -> { + strBuilder.append("\t\t\"") + .append(key) + .append("\": dtype=") + .append(tensorInfo.getDtype().name()) + .append(", shape=("); + for (int i = 0; i < tensorInfo.getTensorShape().getDimCount(); ++i) { + strBuilder.append(tensorInfo.getTensorShape().getDim(i).getSize()); + if (i < tensorInfo.getTensorShape().getDimCount() - 1) { + strBuilder.append(", "); + } + } + strBuilder.append(")\n"); + }); + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java new file mode 100644 index 00000000000..3ea20fcbb46 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java @@ -0,0 +1,122 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.fail; + +import org.junit.jupiter.api.Test; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Init; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.op.math.Add; +import org.tensorflow.op.math.Sub; +import org.tensorflow.types.TFloat32; + +public class ConcreteFunctionTest { + + private static Signature plusFive(Ops tf) { + Placeholder input = tf.placeholder(TFloat32.DTYPE); + Add output = tf.math.add(input, tf.constant(5.0f)); + Init init = tf.init(); // for native resource management tests + return Signature.builder().key("plusFive").input("x", input).output("y", output).build(); + } + + private static Signature minusTwo(Ops tf) { + Placeholder input = tf.placeholder(TFloat32.DTYPE); + Sub output = tf.math.sub(input, tf.constant(2.0f)); + return Signature.builder().key("minusTwo").input("x", input).output("y", output).build(); + } + + @Test + public void createFunction() { + try (ConcreteFunction f = ConcreteFunction.create(ConcreteFunctionTest::plusFive); + Tensor x = TFloat32.scalarOf(3.0f)) { + assertEquals(8.0f, f.call(x).expect(TFloat32.DTYPE).data().getFloat()); + } + } + + @Test + public void createFunctionFromGraph() { + try (Graph g = new Graph()) { + Signature signature = plusFive(Ops.create(g)); + try (ConcreteFunction f = ConcreteFunction.create(signature, g); + Tensor x = TFloat32.scalarOf(3.0f)) { + assertEquals(8.0f, f.call(x).expect(TFloat32.DTYPE).data().getFloat()); + } + } + } + + @Test + public void createFunctionFromSession() { + try (Graph g = new Graph()) { + Signature signature = plusFive(Ops.create(g)); + try (Session s = new Session(g)) { + try (ConcreteFunction f = ConcreteFunction.create(signature, s); + Tensor x = TFloat32.scalarOf(3.0f)) { + assertEquals(8.0f, f.call(x).expect(TFloat32.DTYPE).data().getFloat()); + } + } + } + } + + @Test + public void chainFunctions() { + try (ConcreteFunction f1 = ConcreteFunction.create(ConcreteFunctionTest::plusFive); + ConcreteFunction f2 = ConcreteFunction.create(ConcreteFunctionTest::minusTwo); + Tensor x = TFloat32.scalarOf(3.0f)) { + assertEquals(6.0f, f2.call(f1.call(x)).expect(TFloat32.DTYPE).data().getFloat()); + } + } + + @Test + public void closingFunctionReleaseAllResourcesItOwns() { + Graph g; + Session s; + try (ConcreteFunction f = ConcreteFunction.create(ConcreteFunctionTest::plusFive)) { + g = f.graph(); + s = f.session(); + } + assertThrows(IllegalStateException.class, () -> s.run("Add")); + assertThrows(IllegalStateException.class, () -> g.toGraphDef()); + } + + @Test + public void closingFunctionCreatedFromGraphOnlyReleaseResourcesItOwns() { + try (Graph g = new Graph()) { + Signature signature = plusFive(Ops.create(g)); + Session s; + try (ConcreteFunction f = ConcreteFunction.create(signature, g)) { + s = f.session(); + } + assertThrows(IllegalStateException.class, () -> s.run(Init.DEFAULT_NAME)); + g.toGraphDef(); // check that graph is still valid + } + } + + @Test + public void closingFunctionCreatedFromSessionDoesNotReleaseResources() { + try (Graph g = new Graph()) { + Signature signature = plusFive(Ops.create(g)); + try (Session s = new Session(g)) { + try (ConcreteFunction f = ConcreteFunction.create(signature, s)) { + } + s.run(Init.DEFAULT_NAME); // check that session is still valid + } + g.toGraphDef(); // check that graph is still valid + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java index f144b11b840..314c3063422 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java @@ -1,3 +1,17 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ package org.tensorflow; import org.junit.jupiter.api.Test; diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index 91c07e3f4b6..eabb86f732f 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -15,24 +15,48 @@ package org.tensorflow; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; +import java.io.IOException; import java.net.URISyntaxException; +import java.nio.file.Files; +import java.nio.file.Path; import java.nio.file.Paths; +import java.util.Collections; +import java.util.Map; +import java.util.HashMap; import org.junit.jupiter.api.Test; import org.tensorflow.exceptions.TensorFlowException; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Identity; +import org.tensorflow.op.core.Init; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.op.core.ReduceSum; +import org.tensorflow.op.core.Variable; import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.RunOptions; +import org.tensorflow.proto.framework.SignatureDef; +import org.tensorflow.proto.framework.TensorInfo; +import org.tensorflow.types.TFloat32; /** Unit tests for {@link org.tensorflow.SavedModelBundle}. */ public class SavedModelBundleTest { + private static final float EPSILON = 1e-7f; private static final String SAVED_MODEL_PATH; + private static final String SAVED_MODEL_PY_PATH; + static { try { SAVED_MODEL_PATH = Paths.get(SavedModelBundleTest.class.getResource("/saved_model").toURI()).toString(); + SAVED_MODEL_PY_PATH = Paths.get(SavedModelBundleTest.class.getResource("/saved_model_using_python/model").toURI()).toString(); } catch (URISyntaxException e) { throw new RuntimeException(e); } @@ -72,13 +96,242 @@ public void loader() { } } + @Test + public void exportFunctionWithVariables() throws IOException { + Path testFolder = Files.createTempDirectory("tf-saved-model-export-test"); + float reducedSum; + FloatNdArray xValue = StdArrays.ndCopyOf(new float[][]{{0, 1, 2}, {3, 4, 5}}); + Shape xyShape = Shape.of(2, 3L); + try (ConcreteFunction f = ConcreteFunction.create(tf -> buildGraphWithVariables(tf, xyShape))) { + // Init variable state by running the Init operation directly + f.session().run(Init.DEFAULT_NAME); + + // Call the graph and remember the result of computation for later + try (Tensor xTensor = TFloat32.tensorOf(xValue); + Tensor zTensor = f.call(xTensor).expect(TFloat32.DTYPE)) { + reducedSum = zTensor.data().getFloat(); + } + // Save/export the model (which is a single function in this case) + f.save(testFolder.toString()); + } + assertTrue(Files.exists(testFolder.resolve(Paths.get("variables", "variables.index")))); + assertTrue(Files + .exists(testFolder.resolve(Paths.get("variables", "variables.data-00000-of-00001")))); + assertTrue(Files.exists(testFolder.resolve("saved_model.pb"))); + + // Reload the model just saved and validate its data + try (SavedModelBundle savedModel = + SavedModelBundle.load(testFolder.toString(), SavedModelBundle.DEFAULT_TAG)) { + assertNotNull(savedModel.metaGraphDef()); + assertNotNull(savedModel.metaGraphDef().getSaverDef()); + assertEquals(1, savedModel.metaGraphDef().getSignatureDefCount()); + assertEquals(Signature.DEFAULT_KEY, + savedModel.metaGraphDef().getSignatureDefMap().keySet().iterator().next()); + + ConcreteFunction function = savedModel.function(Signature.DEFAULT_KEY); + assertNotNull(function); + + Signature signature = function.signature(); + assertNotNull(signature); + assertEquals(1, signature.inputNames().size()); + assertEquals("input", signature.inputNames().iterator().next()); + assertEquals(1, signature.outputNames().size()); + assertEquals("reducedSum", signature.outputNames().iterator().next()); + + SignatureDef signatureDef = signature.asSignatureDef(); + assertEquals(1, signatureDef.getInputsCount()); + assertEquals(1, signatureDef.getOutputsCount()); + + TensorInfo inputInfo = signatureDef.getInputsMap().get("input"); + assertNotNull(inputInfo); + assertEquals(xyShape.numDimensions(), inputInfo.getTensorShape().getDimCount()); + for (int i = 0; i < xyShape.numDimensions(); ++i) { + assertEquals(xyShape.size(i), inputInfo.getTensorShape().getDim(i).getSize()); + } + + TensorInfo outputInfo = signatureDef.getOutputsMap().get("reducedSum"); + assertNotNull(outputInfo); + assertEquals(0, outputInfo.getTensorShape().getDimCount()); + + try (Tensor xTensor = TFloat32.tensorOf(xValue)) { + // Call the saved model function and make sure it returns the same result as before + try (Tensor zTensor = function.call(xTensor).expect(TFloat32.DTYPE)) { + assertEquals(reducedSum, zTensor.data().getFloat(), EPSILON); + } + // Now call the same function directly from the model + try (Tensor zTensor = + savedModel.call(Collections.singletonMap("input", xTensor)).get("reducedSum").expect(TFloat32.DTYPE)) { + assertEquals(reducedSum, zTensor.data().getFloat(), EPSILON); + } + } + } + } + + @Test + public void exportMultipleFunctions() throws IOException { + Path testFolder = Files.createTempDirectory("tf-saved-model-export-test"); + float reducedSum; + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + Signature f1Signature = buildGraphWithVariables(tf, Shape.of(1, 1)); + Signature f2Signature = buildIdentityGraph(tf, "identity"); + try (Session s = new Session(g); + ConcreteFunction f1 = ConcreteFunction.create(f1Signature, s); + ConcreteFunction f2 = ConcreteFunction.create(f2Signature, s)) { + f1.session().run(Init.DEFAULT_NAME); + try (Tensor x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2})); + Tensor t = f1.call(x).expect(TFloat32.DTYPE)) { + reducedSum = t.data().getFloat(); + } + SavedModelBundle.exporter(testFolder.toString()) + .withFunction(f1) + .withFunction(f2) + .export(); + } + } + try (SavedModelBundle model = SavedModelBundle.load(testFolder.toString())) { + assertEquals(2, model.signatures().size()); + ConcreteFunction f1 = model.function(Signature.DEFAULT_KEY); + assertNotNull(f1); + try (Tensor x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2})); + Tensor t = f1.call(x).expect(TFloat32.DTYPE)) { + assertEquals(reducedSum, t.data().getFloat(), EPSILON); + } + ConcreteFunction f2 = model.function("identity"); + assertNotNull(f2); + try (Tensor x = TFloat32.scalarOf(10.0f); + Tensor t = f2.call(x).expect(TFloat32.DTYPE)) { + assertEquals(10.0f, t.data().getFloat(), 0.0f); + } + try { + model.function("NoSuchFunction"); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + } + } + + @Test + public void cannotExportMultipleFunctionsWithDifferentSessions() throws IOException { + Path testFolder = Files.createTempDirectory("tf-saved-model-export-test"); + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + Signature f1Signature = buildGraphWithVariables(tf, Shape.of(1, 1)); + Signature f2Signature = buildIdentityGraph(tf, "identity"); + try (ConcreteFunction f1 = ConcreteFunction.create(f1Signature, g); + ConcreteFunction f2 = ConcreteFunction.create(f2Signature, g)) { + f1.session().run(Init.DEFAULT_NAME); + try { + SavedModelBundle.exporter(testFolder.toString()) + .withFunction(f1) + .withFunction(f2) + .export(); + fail(); + } catch (UnsupportedOperationException e) { + // as expected + } + } + } + } + + @Test + public void cannotExportMultipleFunctionsWithSameSignatureKey() throws IOException { + Path testFolder = Files.createTempDirectory("tf-saved-model-export-test"); + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + Signature f1Signature = buildGraphWithVariables(tf, Shape.of(1, 1)); + Signature f2Signature = buildIdentityGraph(tf, Signature.DEFAULT_KEY); + try (Session s = new Session(g); + ConcreteFunction f1 = ConcreteFunction.create(f1Signature, s); + ConcreteFunction f2 = ConcreteFunction.create(f2Signature, s)) { + f1.session().run(Init.DEFAULT_NAME); + try { + SavedModelBundle.exporter(testFolder.toString()) + .withFunction(f1) + .withFunction(f2) + .export(); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + } + } + } + + @Test + public void cannotExportOrImportInvalidTags() { + assertThrows(IllegalArgumentException.class, () -> + SavedModelBundle.loader("/").withTags() + ); + assertThrows(IllegalArgumentException.class, () -> + SavedModelBundle.loader("/").withTags(new String[]{}) + ); + assertThrows(IllegalArgumentException.class, () -> + SavedModelBundle.loader("/").withTags(new String[]{"tag", null}) + ); + assertThrows(IllegalArgumentException.class, () -> + SavedModelBundle.loader("/").withTags(new String[]{"tag", ""}) + ); + assertThrows(IllegalArgumentException.class, () -> + SavedModelBundle.exporter("/").withTags() + ); + assertThrows(IllegalArgumentException.class, () -> + SavedModelBundle.exporter("/").withTags(new String[]{}) + ); + assertThrows(IllegalArgumentException.class, () -> + SavedModelBundle.exporter("/").withTags(new String[]{"tag", null}) + ); + assertThrows(IllegalArgumentException.class, () -> + SavedModelBundle.exporter("/").withTags(new String[]{"tag", ""}) + ); + } + + @Test + public void pythonTfFunction() { + // ConcreteFunctions on models saved using python + try (SavedModelBundle bundle = SavedModelBundle.load(SAVED_MODEL_PY_PATH, "serve")) { + /* + * Test model was created in python + * Signature name used for saving 'add', argument names 'a' and 'b' + */ + ConcreteFunction add = bundle.function("add"); + Map> args = new HashMap(); + try (Tensor a = TFloat32.scalarOf(10.0f); + Tensor b = TFloat32.scalarOf(15.5f)) { + args.put("a", a); + args.put("b", b); + Map> result = add.call(args); + assertEquals(result.size(), 1); + try (Tensor c = result.values().iterator().next().expect(TFloat32.DTYPE)) { + assertEquals(25.5f, c.data().getFloat()); + } + } + } + } + + private static Signature buildGraphWithVariables(Ops tf, Shape xShape) { + Placeholder x = tf.placeholder(TFloat32.DTYPE, Placeholder.shape(xShape)); + Variable y = tf + .variable(tf.random.randomUniform(tf.constant(xShape), TFloat32.DTYPE)); + ReduceSum z = tf.reduceSum(tf.math.add(x, y), tf.array(0, 1)); + Init init = tf.init(); + return Signature.builder().input("input", x).output("reducedSum", z).build(); + } + + private static Signature buildIdentityGraph(Ops tf, String signatureKey) { + Placeholder x = tf.placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + Identity xprime = tf.identity(x); + return Signature.builder().key(signatureKey).input("x", x).output("x", xprime).build(); + } + private static RunOptions sillyRunOptions() { return RunOptions.newBuilder() .setTraceLevel(RunOptions.TraceLevel.FULL_TRACE) .build(); } - public static ConfigProto sillyConfigProto() { + private static ConfigProto sillyConfigProto() { return ConfigProto.newBuilder() .setInterOpParallelismThreads(1) .setIntraOpParallelismThreads(1) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index 7faf8c6fbdb..fa41af32a29 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -20,8 +20,13 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; import org.junit.jupiter.api.Test; import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Init; import org.tensorflow.op.core.Split; import org.tensorflow.op.core.Variable; import org.tensorflow.op.linalg.MatMul; @@ -31,6 +36,7 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; /** Unit tests for {@link org.tensorflow.Session}. */ @@ -205,6 +211,24 @@ public void runInitByName() { } } + @Test + public void save() throws IOException { + Path testFolder = Files.createTempDirectory("tf-session-save-test"); + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + Variable x = tf.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.DTYPE)); + Variable y = tf.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.DTYPE)); + Init init = tf.init(); + + try (Session s = new Session(g)) { + s.run(init); + s.save(testFolder.resolve("checkpoint").toString()); + } + } + assertTrue(Files.exists(testFolder.resolve("checkpoint.index"))); + assertTrue(Files.exists(testFolder.resolve("checkpoint.data-00000-of-00001"))); + } + private static RunOptions fullTraceRunOptions() { return RunOptions.newBuilder() .setTraceLevel(RunOptions.TraceLevel.FULL_TRACE) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java new file mode 100644 index 00000000000..8931ecbbde1 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import org.junit.jupiter.api.Test; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Init; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.op.math.Add; +import org.tensorflow.op.math.Sub; +import org.tensorflow.types.TFloat32; + +public class SignatureTest { + + @Test + public void cannotUseEmptyKey() { + Signature.Builder builder = Signature.builder(); + assertThrows(IllegalArgumentException.class, () -> builder.key(null)); + assertThrows(IllegalArgumentException.class, () -> builder.key("")); + builder.key("valid key"); + } + + @Test + public void cannotDuplicateInputOutputNames() { + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + Signature.Builder builder = Signature.builder() + .input("x", tf.constant(10.0f)) + .output("x", tf.constant(10.0f)) // can add an output with the same name as an input + .output("y", tf.constant(20.0f)); + assertThrows(IllegalArgumentException.class, () -> builder.input("x", tf.constant(10))); + assertThrows(IllegalArgumentException.class, () -> builder.output("y", tf.constant(20.0f))); + } + } + + @Test + public void emptyMethodNameConvertedToNull() { + Signature signature = Signature.builder().key("f").build(); + assertNull(signature.methodName()); + signature = Signature.builder().key("f").methodName("").build(); + assertNull(signature.methodName()); + signature = Signature.builder().key("f").methodName(null).build(); + assertNull(signature.methodName()); + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/saved_model.pb b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/saved_model.pb new file mode 100644 index 00000000000..169e0095a3e Binary files /dev/null and b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/saved_model.pb differ diff --git a/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.data-00000-of-00001 b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.data-00000-of-00001 new file mode 100644 index 00000000000..ac369237d31 Binary files /dev/null and b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.data-00000-of-00001 differ diff --git a/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.index b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.index new file mode 100644 index 00000000000..8be702a36ed Binary files /dev/null and b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.index differ diff --git a/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/source_model.py b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/source_model.py new file mode 100644 index 00000000000..f5160401515 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/source_model.py @@ -0,0 +1,63 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# Saved model created using Python @tf.function, tensorflow version 2.3.0. +# +# Python code used to create saved model below +# +# WARNING: This script is just attached to the test code base for reference and is not used nor +# executed by this project to generate the saved model, which has been added manually. + +import tensorflow as tf + +class MyModel(tf.keras.Model): + def __init__(self): + super(MyModel, self).__init__() + self.const_scalar = tf.constant(0.0) + self.const_vector = tf.constant([0.0, 0.0, 0.0]) + self.const_matrix = tf.constant([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) + + @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32, name='request')]) + def serve(self, x): + return self.const_scalar + x + + @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32, name='input')]) + def get_scalar(self, x): + return self.const_scalar + x + + @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32, name='input')]) + def get_vector(self, x): + return self.const_vector + x + + @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32, name='input')]) + def get_matrix(self, x): + return self.const_matrix + x + + @tf.function(input_signature=[ + tf.TensorSpec(shape=None, dtype=tf.float32, name='a'), + tf.TensorSpec(shape=None, dtype=tf.float32, name='b')]) + def add(self, a, b): + return a + b + +model = MyModel() + +signatures = { + "get_const_scalar": model.get_scalar, + "get_const_vector": model.get_vector, + "get_const_matrix": model.get_matrix, + "add": model.add +} + +tf.saved_model.save(obj=model, export_dir='model', signatures=signatures)