signatures() {
+ return functions.values().stream().map(f -> f.signature()).collect(Collectors.toList());
+ }
+
+ /**
+ * Return a {@link ConcreteFunction} corresponding to the function signature.
+ *
+ * {@code
+ * ConcreteFunction myFunction = savedModelBundle.function("mySignatureKey");
+ * Map> outputTensorMap = myFunction.call(session, inputTensorMap);
+ * }
+ *
+ * @param signatureKey name of the {@code SignatureDef} in the saved model.
+ * @return object that can be used to make calls to a function
+ * @throws IllegalArgumentException if {@code signatureKey} is not found in this
+ * saved model.
+ */
+ public ConcreteFunction function(String signatureKey) {
+ ConcreteFunction function = functions.get(signatureKey);
+ if (function == null) {
+ throw new IllegalArgumentException(
+ String.format("Function with signature [%s] not found", signatureKey));
+ }
+ return function;
+ }
+
+ /**
+ * Invokes the default function directly from this model.
+ *
+ * The default function selection is done based on the first of the following conditions that
+ * is true:
+ *
+ * - The function is the only signature available attached to the main graph of this saved model
+ * - The function is mapped to the default signature name, which is "serving_default"
+ *
+ *
+ * Caller is responsible for closing all returned Tensors.
+ *
+ * @param arguments list of input tensors, mapped by their signature name
+ * @return list of output tensors, mapped by the signature name
+ * @throws IllegalArgumentException if no function can be selected by default
+ */
+ public Map> call(Map> arguments) {
+ ConcreteFunction function = null;
+ if (functions.size() == 1) {
+ function = functions.values().iterator().next();
+ } else {
+ function = functions.get(Signature.DEFAULT_KEY);
+ }
+ if (function == null) {
+ throw new IllegalArgumentException("Cannot elect a default function for this model");
+ }
+ return function.call(arguments);
+ }
+
/**
* Releases resources (the {@link Graph} and {@link Session}) associated with the saved model
* bundle.
@@ -161,11 +360,13 @@ public void close() {
private final Graph graph;
private final Session session;
private final MetaGraphDef metaGraphDef;
+ private final Map functions;
- private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef) {
+ private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef, Map functions) {
this.graph = graph;
this.session = session;
this.metaGraphDef = metaGraphDef;
+ this.functions = functions;
}
/**
@@ -175,10 +376,21 @@ private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef
* Invoked from the native load method. Takes ownership of the handles.
*/
private static SavedModelBundle fromHandle(
- TF_Graph graphHandle, TF_Session sessionHandle, MetaGraphDef metaGraphDef) {
- Graph graph = new Graph(graphHandle);
- Session session = new Session(graph, sessionHandle);
- return new SavedModelBundle(graph, session, metaGraphDef);
+ final TF_Graph graphHandle, final TF_Session sessionHandle, MetaGraphDef metaGraphDef) {
+
+ final Graph graph = new Graph(graphHandle, metaGraphDef.getSaverDef());
+ final Session session = new Session(graph, sessionHandle);
+
+ // Create a separate function for each signature of the main graph.
+ // Note that the saved model will remain the owner of the graph and the session, meaning
+ // that the functions do not need to be closed by the user and if it does, it should have
+ // no effect.
+ final Map functions = new HashMap<>(metaGraphDef.getSignatureDefCount());
+ metaGraphDef.getSignatureDefMap().forEach((signatureName, signatureDef) -> {
+ Signature signature = new Signature(signatureName, signatureDef);
+ functions.put(signatureName, ConcreteFunction.create(signature, session));
+ });
+ return new SavedModelBundle(graph, session, metaGraphDef, functions);
}
private static SavedModelBundle load(
@@ -216,6 +428,12 @@ opts, runOpts, new BytePointer(exportDir), new PointerPointer(tags),
return bundle;
}
+ private static void validateTags(String[] tags) {
+ if (tags == null || tags.length == 0 || Arrays.stream(tags).anyMatch(t -> t == null || t.isEmpty())) {
+ throw new IllegalArgumentException("Invalid tags: " + Arrays.toString(tags));
+ }
+ }
+
static {
TensorFlow.init();
}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java
index 0676ce8ec4e..cffad0db976 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java
@@ -36,6 +36,8 @@
import java.util.ArrayList;
import java.util.List;
+import org.tensorflow.proto.util.SaverDef;
+import org.tensorflow.types.TString;
import static org.tensorflow.Graph.resolveOutputs;
import static org.tensorflow.internal.c_api.global.tensorflow.*;
@@ -444,6 +446,27 @@ public void run(Op op) {
runner().addTarget(op.op()).run();
}
+ /**
+ * Saves the actual state of the variables of this session's graph.
+ *
+ * {@code prefix} is a path where the files containing the variables state will be saved,
+ * followed by a prefix for naming these files. For example, if {@code prefix} is set to
+ * mymodel/myvariables/variables, then the generated files will be located under
+ * mymodel/myvariables and named variables.data-*-of-*
+ *
+ * Note that this method might alter the underlying graph if it is the first time that one
+ * of its session is saved, see {@link Graph#saverDef()} for more details.
+ *
+ * @param prefix prefix to the variable files to save
+ */
+ public void save(String prefix) {
+ SaverDef saverDef = graph.saverDef();
+ runner()
+ .addTarget(saverDef.getSaveTensorName())
+ .feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix))
+ .run();
+ }
+
/**
* Output tensors and metadata obtained when executing a session.
*
@@ -463,6 +486,10 @@ public static final class Run {
public RunMetadata metadata;
}
+ Graph graph() {
+ return graph;
+ }
+
private final Graph graph;
private final Graph.Reference graphRef;
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java
new file mode 100644
index 00000000000..cd138f44c3c
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java
@@ -0,0 +1,206 @@
+/*
+ * Copyright 2020 The TensorFlow Authors. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.tensorflow;
+
+import java.util.Map;
+import java.util.Set;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.proto.framework.DataType;
+import org.tensorflow.proto.framework.SignatureDef;
+import org.tensorflow.proto.framework.TensorInfo;
+import org.tensorflow.proto.framework.TensorShapeProto;
+import org.tensorflow.proto.framework.TensorShapeProto.Dim;
+
+/**
+ * Describe the inputs and outputs of an executable entity, such as a {@link ConcreteFunction}, among
+ * other useful metadata.
+ */
+public class Signature {
+
+ /** The default signature key, when not provided */
+ public static final String DEFAULT_KEY = "serving_default";
+
+ /**
+ * Builds a new function signature.
+ */
+ public static class Builder {
+
+ /**
+ * Sets the unique key of this signature.
+ *
+ * When not set explicitly, the default value is {@link #DEFAULT_KEY}.
+ *
+ * @param key signature key
+ * @return this builder
+ * @throws IllegalArgumentException if the key is invalid
+ */
+ public Builder key(String key) {
+ if (key == null || key.isEmpty()) {
+ throw new IllegalArgumentException("Invalid key: " + key);
+ }
+ this.key = key;
+ return this;
+ }
+
+ /**
+ * Register a tensor as an input of the function.
+ *
+ * @param inputName user-friendly name for this input tensor
+ * @param input input tensor
+ * @return this builder
+ * @throws IllegalArgumentException if {@code inputName} is already mapped to another input
+ */
+ public Builder input(String inputName, Operand> input) {
+ if (signatureBuilder.containsInputs(inputName)) {
+ throw new IllegalArgumentException("\"" + inputName + "\" is already being mapped to another input");
+ }
+ signatureBuilder.putInputs(inputName, toTensorInfo(input.asOutput()));
+ return this;
+ }
+
+ /**
+ * Register a tensor as an output of the function.
+ *
+ * @param inputName user-friendly name for this input tensor
+ * @param input input tensor
+ * @return this builder
+ * @throws IllegalArgumentException if {@code outputName} is already mapped to another output
+ */
+ public Builder output(String outputName, Operand> output) {
+ if (signatureBuilder.containsOutputs(outputName)) {
+ throw new IllegalArgumentException("\"" + outputName + "\" is already being mapped to another output");
+ }
+ signatureBuilder.putOutputs(outputName, toTensorInfo(output.asOutput()));
+ return this;
+ }
+
+ /**
+ * Provide extensible name information enabling third-party users to mark a signature as
+ * supporting a particular method
+ *
+ * @param methodName method name or null for none (default)
+ * @return this builder
+ */
+ public Builder methodName(String methodName) {
+ signatureBuilder.setMethodName(methodName == null ? "" : methodName);
+ return this;
+ }
+
+ /**
+ * Returns a signature from the provided data.
+ */
+ public Signature build() {
+ return new Signature(key, signatureBuilder.build());
+ }
+
+ private static TensorInfo toTensorInfo(Output> operand) {
+ Shape shape = operand.shape();
+ TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
+ for (int i = 0; i < shape.numDimensions(); ++i) {
+ tensorShapeBuilder.addDim(Dim.newBuilder().setSize(shape.size(i)));
+ }
+ return TensorInfo.newBuilder()
+ .setDtype(DataType.forNumber(operand.dataType().nativeCode()))
+ .setTensorShape(tensorShapeBuilder)
+ .setName(operand.op().name() + ":" + operand.index())
+ .build();
+ }
+
+ private String key = DEFAULT_KEY;
+ private final SignatureDef.Builder signatureBuilder = SignatureDef.newBuilder();
+ }
+
+ /**
+ * Returns a new builder for creating a signature
+ */
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ /**
+ * Return the key of this signature
+ */
+ public String key() {
+ return key;
+ }
+
+ /**
+ * Returns the method name of this signature (e.g. as exposed by TF serving) or null if none
+ */
+ public String methodName() {
+ return signatureDef.getMethodName().isEmpty() ? null : signatureDef.getMethodName();
+ }
+
+ /**
+ * Returns the names of the inputs in this signature
+ */
+ public Set inputNames() {
+ return signatureDef.getInputsMap().keySet();
+ }
+
+ /**
+ * Returns the names of the outputs in this signature
+ */
+ public Set outputNames() {
+ return signatureDef.getOutputsMap().keySet();
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder strBuilder = new StringBuilder("Signature for \"" + key +"\":\n");
+ if (!methodName().isEmpty()) {
+ strBuilder.append("\tMethod: \"").append(methodName()).append("\"\n");
+ }
+ if (signatureDef.getInputsCount() > 0) {
+ strBuilder.append("\tInputs:\n");
+ printTensorInfo(signatureDef.getInputsMap(), strBuilder);
+ }
+ if (signatureDef.getOutputsCount() > 0) {
+ strBuilder.append("\tOutputs:\n");
+ printTensorInfo(signatureDef.getOutputsMap(), strBuilder);
+ }
+ return strBuilder.toString();
+ }
+
+ Signature(String key, SignatureDef signatureDef) {
+ this.key = key;
+ this.signatureDef = signatureDef;
+ }
+
+ SignatureDef asSignatureDef() {
+ return signatureDef;
+ }
+
+ private final String key;
+ private final SignatureDef signatureDef;
+
+ private static void printTensorInfo(Map tensorMap, StringBuilder strBuilder) {
+ tensorMap.forEach((key, tensorInfo) -> {
+ strBuilder.append("\t\t\"")
+ .append(key)
+ .append("\": dtype=")
+ .append(tensorInfo.getDtype().name())
+ .append(", shape=(");
+ for (int i = 0; i < tensorInfo.getTensorShape().getDimCount(); ++i) {
+ strBuilder.append(tensorInfo.getTensorShape().getDim(i).getSize());
+ if (i < tensorInfo.getTensorShape().getDimCount() - 1) {
+ strBuilder.append(", ");
+ }
+ }
+ strBuilder.append(")\n");
+ });
+ }
+}
diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java
new file mode 100644
index 00000000000..3ea20fcbb46
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java
@@ -0,0 +1,122 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.fail;
+
+import org.junit.jupiter.api.Test;
+import org.tensorflow.op.Ops;
+import org.tensorflow.op.core.Init;
+import org.tensorflow.op.core.Placeholder;
+import org.tensorflow.op.math.Add;
+import org.tensorflow.op.math.Sub;
+import org.tensorflow.types.TFloat32;
+
+public class ConcreteFunctionTest {
+
+ private static Signature plusFive(Ops tf) {
+ Placeholder input = tf.placeholder(TFloat32.DTYPE);
+ Add output = tf.math.add(input, tf.constant(5.0f));
+ Init init = tf.init(); // for native resource management tests
+ return Signature.builder().key("plusFive").input("x", input).output("y", output).build();
+ }
+
+ private static Signature minusTwo(Ops tf) {
+ Placeholder input = tf.placeholder(TFloat32.DTYPE);
+ Sub output = tf.math.sub(input, tf.constant(2.0f));
+ return Signature.builder().key("minusTwo").input("x", input).output("y", output).build();
+ }
+
+ @Test
+ public void createFunction() {
+ try (ConcreteFunction f = ConcreteFunction.create(ConcreteFunctionTest::plusFive);
+ Tensor x = TFloat32.scalarOf(3.0f)) {
+ assertEquals(8.0f, f.call(x).expect(TFloat32.DTYPE).data().getFloat());
+ }
+ }
+
+ @Test
+ public void createFunctionFromGraph() {
+ try (Graph g = new Graph()) {
+ Signature signature = plusFive(Ops.create(g));
+ try (ConcreteFunction f = ConcreteFunction.create(signature, g);
+ Tensor x = TFloat32.scalarOf(3.0f)) {
+ assertEquals(8.0f, f.call(x).expect(TFloat32.DTYPE).data().getFloat());
+ }
+ }
+ }
+
+ @Test
+ public void createFunctionFromSession() {
+ try (Graph g = new Graph()) {
+ Signature signature = plusFive(Ops.create(g));
+ try (Session s = new Session(g)) {
+ try (ConcreteFunction f = ConcreteFunction.create(signature, s);
+ Tensor x = TFloat32.scalarOf(3.0f)) {
+ assertEquals(8.0f, f.call(x).expect(TFloat32.DTYPE).data().getFloat());
+ }
+ }
+ }
+ }
+
+ @Test
+ public void chainFunctions() {
+ try (ConcreteFunction f1 = ConcreteFunction.create(ConcreteFunctionTest::plusFive);
+ ConcreteFunction f2 = ConcreteFunction.create(ConcreteFunctionTest::minusTwo);
+ Tensor x = TFloat32.scalarOf(3.0f)) {
+ assertEquals(6.0f, f2.call(f1.call(x)).expect(TFloat32.DTYPE).data().getFloat());
+ }
+ }
+
+ @Test
+ public void closingFunctionReleaseAllResourcesItOwns() {
+ Graph g;
+ Session s;
+ try (ConcreteFunction f = ConcreteFunction.create(ConcreteFunctionTest::plusFive)) {
+ g = f.graph();
+ s = f.session();
+ }
+ assertThrows(IllegalStateException.class, () -> s.run("Add"));
+ assertThrows(IllegalStateException.class, () -> g.toGraphDef());
+ }
+
+ @Test
+ public void closingFunctionCreatedFromGraphOnlyReleaseResourcesItOwns() {
+ try (Graph g = new Graph()) {
+ Signature signature = plusFive(Ops.create(g));
+ Session s;
+ try (ConcreteFunction f = ConcreteFunction.create(signature, g)) {
+ s = f.session();
+ }
+ assertThrows(IllegalStateException.class, () -> s.run(Init.DEFAULT_NAME));
+ g.toGraphDef(); // check that graph is still valid
+ }
+ }
+
+ @Test
+ public void closingFunctionCreatedFromSessionDoesNotReleaseResources() {
+ try (Graph g = new Graph()) {
+ Signature signature = plusFive(Ops.create(g));
+ try (Session s = new Session(g)) {
+ try (ConcreteFunction f = ConcreteFunction.create(signature, s)) {
+ }
+ s.run(Init.DEFAULT_NAME); // check that session is still valid
+ }
+ g.toGraphDef(); // check that graph is still valid
+ }
+ }
+}
diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java
index f144b11b840..314c3063422 100644
--- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java
+++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java
@@ -1,3 +1,17 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
package org.tensorflow;
import org.junit.jupiter.api.Test;
diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java
index 91c07e3f4b6..eabb86f732f 100644
--- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java
+++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java
@@ -15,24 +15,48 @@
package org.tensorflow;
+import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
+import java.io.IOException;
import java.net.URISyntaxException;
+import java.nio.file.Files;
+import java.nio.file.Path;
import java.nio.file.Paths;
+import java.util.Collections;
+import java.util.Map;
+import java.util.HashMap;
import org.junit.jupiter.api.Test;
import org.tensorflow.exceptions.TensorFlowException;
+import org.tensorflow.ndarray.FloatNdArray;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.ndarray.StdArrays;
+import org.tensorflow.op.Ops;
+import org.tensorflow.op.core.Identity;
+import org.tensorflow.op.core.Init;
+import org.tensorflow.op.core.Placeholder;
+import org.tensorflow.op.core.ReduceSum;
+import org.tensorflow.op.core.Variable;
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.proto.framework.RunOptions;
+import org.tensorflow.proto.framework.SignatureDef;
+import org.tensorflow.proto.framework.TensorInfo;
+import org.tensorflow.types.TFloat32;
/** Unit tests for {@link org.tensorflow.SavedModelBundle}. */
public class SavedModelBundleTest {
+ private static final float EPSILON = 1e-7f;
private static final String SAVED_MODEL_PATH;
+ private static final String SAVED_MODEL_PY_PATH;
+
static {
try {
SAVED_MODEL_PATH = Paths.get(SavedModelBundleTest.class.getResource("/saved_model").toURI()).toString();
+ SAVED_MODEL_PY_PATH = Paths.get(SavedModelBundleTest.class.getResource("/saved_model_using_python/model").toURI()).toString();
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
@@ -72,13 +96,242 @@ public void loader() {
}
}
+ @Test
+ public void exportFunctionWithVariables() throws IOException {
+ Path testFolder = Files.createTempDirectory("tf-saved-model-export-test");
+ float reducedSum;
+ FloatNdArray xValue = StdArrays.ndCopyOf(new float[][]{{0, 1, 2}, {3, 4, 5}});
+ Shape xyShape = Shape.of(2, 3L);
+ try (ConcreteFunction f = ConcreteFunction.create(tf -> buildGraphWithVariables(tf, xyShape))) {
+ // Init variable state by running the Init operation directly
+ f.session().run(Init.DEFAULT_NAME);
+
+ // Call the graph and remember the result of computation for later
+ try (Tensor xTensor = TFloat32.tensorOf(xValue);
+ Tensor zTensor = f.call(xTensor).expect(TFloat32.DTYPE)) {
+ reducedSum = zTensor.data().getFloat();
+ }
+ // Save/export the model (which is a single function in this case)
+ f.save(testFolder.toString());
+ }
+ assertTrue(Files.exists(testFolder.resolve(Paths.get("variables", "variables.index"))));
+ assertTrue(Files
+ .exists(testFolder.resolve(Paths.get("variables", "variables.data-00000-of-00001"))));
+ assertTrue(Files.exists(testFolder.resolve("saved_model.pb")));
+
+ // Reload the model just saved and validate its data
+ try (SavedModelBundle savedModel =
+ SavedModelBundle.load(testFolder.toString(), SavedModelBundle.DEFAULT_TAG)) {
+ assertNotNull(savedModel.metaGraphDef());
+ assertNotNull(savedModel.metaGraphDef().getSaverDef());
+ assertEquals(1, savedModel.metaGraphDef().getSignatureDefCount());
+ assertEquals(Signature.DEFAULT_KEY,
+ savedModel.metaGraphDef().getSignatureDefMap().keySet().iterator().next());
+
+ ConcreteFunction function = savedModel.function(Signature.DEFAULT_KEY);
+ assertNotNull(function);
+
+ Signature signature = function.signature();
+ assertNotNull(signature);
+ assertEquals(1, signature.inputNames().size());
+ assertEquals("input", signature.inputNames().iterator().next());
+ assertEquals(1, signature.outputNames().size());
+ assertEquals("reducedSum", signature.outputNames().iterator().next());
+
+ SignatureDef signatureDef = signature.asSignatureDef();
+ assertEquals(1, signatureDef.getInputsCount());
+ assertEquals(1, signatureDef.getOutputsCount());
+
+ TensorInfo inputInfo = signatureDef.getInputsMap().get("input");
+ assertNotNull(inputInfo);
+ assertEquals(xyShape.numDimensions(), inputInfo.getTensorShape().getDimCount());
+ for (int i = 0; i < xyShape.numDimensions(); ++i) {
+ assertEquals(xyShape.size(i), inputInfo.getTensorShape().getDim(i).getSize());
+ }
+
+ TensorInfo outputInfo = signatureDef.getOutputsMap().get("reducedSum");
+ assertNotNull(outputInfo);
+ assertEquals(0, outputInfo.getTensorShape().getDimCount());
+
+ try (Tensor xTensor = TFloat32.tensorOf(xValue)) {
+ // Call the saved model function and make sure it returns the same result as before
+ try (Tensor zTensor = function.call(xTensor).expect(TFloat32.DTYPE)) {
+ assertEquals(reducedSum, zTensor.data().getFloat(), EPSILON);
+ }
+ // Now call the same function directly from the model
+ try (Tensor zTensor =
+ savedModel.call(Collections.singletonMap("input", xTensor)).get("reducedSum").expect(TFloat32.DTYPE)) {
+ assertEquals(reducedSum, zTensor.data().getFloat(), EPSILON);
+ }
+ }
+ }
+ }
+
+ @Test
+ public void exportMultipleFunctions() throws IOException {
+ Path testFolder = Files.createTempDirectory("tf-saved-model-export-test");
+ float reducedSum;
+ try (Graph g = new Graph()) {
+ Ops tf = Ops.create(g);
+ Signature f1Signature = buildGraphWithVariables(tf, Shape.of(1, 1));
+ Signature f2Signature = buildIdentityGraph(tf, "identity");
+ try (Session s = new Session(g);
+ ConcreteFunction f1 = ConcreteFunction.create(f1Signature, s);
+ ConcreteFunction f2 = ConcreteFunction.create(f2Signature, s)) {
+ f1.session().run(Init.DEFAULT_NAME);
+ try (Tensor x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2}));
+ Tensor t = f1.call(x).expect(TFloat32.DTYPE)) {
+ reducedSum = t.data().getFloat();
+ }
+ SavedModelBundle.exporter(testFolder.toString())
+ .withFunction(f1)
+ .withFunction(f2)
+ .export();
+ }
+ }
+ try (SavedModelBundle model = SavedModelBundle.load(testFolder.toString())) {
+ assertEquals(2, model.signatures().size());
+ ConcreteFunction f1 = model.function(Signature.DEFAULT_KEY);
+ assertNotNull(f1);
+ try (Tensor x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2}));
+ Tensor t = f1.call(x).expect(TFloat32.DTYPE)) {
+ assertEquals(reducedSum, t.data().getFloat(), EPSILON);
+ }
+ ConcreteFunction f2 = model.function("identity");
+ assertNotNull(f2);
+ try (Tensor x = TFloat32.scalarOf(10.0f);
+ Tensor t = f2.call(x).expect(TFloat32.DTYPE)) {
+ assertEquals(10.0f, t.data().getFloat(), 0.0f);
+ }
+ try {
+ model.function("NoSuchFunction");
+ fail();
+ } catch (IllegalArgumentException e) {
+ // as expected
+ }
+ }
+ }
+
+ @Test
+ public void cannotExportMultipleFunctionsWithDifferentSessions() throws IOException {
+ Path testFolder = Files.createTempDirectory("tf-saved-model-export-test");
+ try (Graph g = new Graph()) {
+ Ops tf = Ops.create(g);
+ Signature f1Signature = buildGraphWithVariables(tf, Shape.of(1, 1));
+ Signature f2Signature = buildIdentityGraph(tf, "identity");
+ try (ConcreteFunction f1 = ConcreteFunction.create(f1Signature, g);
+ ConcreteFunction f2 = ConcreteFunction.create(f2Signature, g)) {
+ f1.session().run(Init.DEFAULT_NAME);
+ try {
+ SavedModelBundle.exporter(testFolder.toString())
+ .withFunction(f1)
+ .withFunction(f2)
+ .export();
+ fail();
+ } catch (UnsupportedOperationException e) {
+ // as expected
+ }
+ }
+ }
+ }
+
+ @Test
+ public void cannotExportMultipleFunctionsWithSameSignatureKey() throws IOException {
+ Path testFolder = Files.createTempDirectory("tf-saved-model-export-test");
+ try (Graph g = new Graph()) {
+ Ops tf = Ops.create(g);
+ Signature f1Signature = buildGraphWithVariables(tf, Shape.of(1, 1));
+ Signature f2Signature = buildIdentityGraph(tf, Signature.DEFAULT_KEY);
+ try (Session s = new Session(g);
+ ConcreteFunction f1 = ConcreteFunction.create(f1Signature, s);
+ ConcreteFunction f2 = ConcreteFunction.create(f2Signature, s)) {
+ f1.session().run(Init.DEFAULT_NAME);
+ try {
+ SavedModelBundle.exporter(testFolder.toString())
+ .withFunction(f1)
+ .withFunction(f2)
+ .export();
+ fail();
+ } catch (IllegalArgumentException e) {
+ // as expected
+ }
+ }
+ }
+ }
+
+ @Test
+ public void cannotExportOrImportInvalidTags() {
+ assertThrows(IllegalArgumentException.class, () ->
+ SavedModelBundle.loader("/").withTags()
+ );
+ assertThrows(IllegalArgumentException.class, () ->
+ SavedModelBundle.loader("/").withTags(new String[]{})
+ );
+ assertThrows(IllegalArgumentException.class, () ->
+ SavedModelBundle.loader("/").withTags(new String[]{"tag", null})
+ );
+ assertThrows(IllegalArgumentException.class, () ->
+ SavedModelBundle.loader("/").withTags(new String[]{"tag", ""})
+ );
+ assertThrows(IllegalArgumentException.class, () ->
+ SavedModelBundle.exporter("/").withTags()
+ );
+ assertThrows(IllegalArgumentException.class, () ->
+ SavedModelBundle.exporter("/").withTags(new String[]{})
+ );
+ assertThrows(IllegalArgumentException.class, () ->
+ SavedModelBundle.exporter("/").withTags(new String[]{"tag", null})
+ );
+ assertThrows(IllegalArgumentException.class, () ->
+ SavedModelBundle.exporter("/").withTags(new String[]{"tag", ""})
+ );
+ }
+
+ @Test
+ public void pythonTfFunction() {
+ // ConcreteFunctions on models saved using python
+ try (SavedModelBundle bundle = SavedModelBundle.load(SAVED_MODEL_PY_PATH, "serve")) {
+ /*
+ * Test model was created in python
+ * Signature name used for saving 'add', argument names 'a' and 'b'
+ */
+ ConcreteFunction add = bundle.function("add");
+ Map> args = new HashMap();
+ try (Tensor a = TFloat32.scalarOf(10.0f);
+ Tensor b = TFloat32.scalarOf(15.5f)) {
+ args.put("a", a);
+ args.put("b", b);
+ Map> result = add.call(args);
+ assertEquals(result.size(), 1);
+ try (Tensor c = result.values().iterator().next().expect(TFloat32.DTYPE)) {
+ assertEquals(25.5f, c.data().getFloat());
+ }
+ }
+ }
+ }
+
+ private static Signature buildGraphWithVariables(Ops tf, Shape xShape) {
+ Placeholder x = tf.placeholder(TFloat32.DTYPE, Placeholder.shape(xShape));
+ Variable y = tf
+ .variable(tf.random.randomUniform(tf.constant(xShape), TFloat32.DTYPE));
+ ReduceSum z = tf.reduceSum(tf.math.add(x, y), tf.array(0, 1));
+ Init init = tf.init();
+ return Signature.builder().input("input", x).output("reducedSum", z).build();
+ }
+
+ private static Signature buildIdentityGraph(Ops tf, String signatureKey) {
+ Placeholder x = tf.placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar()));
+ Identity xprime = tf.identity(x);
+ return Signature.builder().key(signatureKey).input("x", x).output("x", xprime).build();
+ }
+
private static RunOptions sillyRunOptions() {
return RunOptions.newBuilder()
.setTraceLevel(RunOptions.TraceLevel.FULL_TRACE)
.build();
}
- public static ConfigProto sillyConfigProto() {
+ private static ConfigProto sillyConfigProto() {
return ConfigProto.newBuilder()
.setInterOpParallelismThreads(1)
.setIntraOpParallelismThreads(1)
diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java
index 7faf8c6fbdb..fa41af32a29 100644
--- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java
+++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java
@@ -20,8 +20,13 @@
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
import org.junit.jupiter.api.Test;
import org.tensorflow.op.Ops;
+import org.tensorflow.op.core.Init;
import org.tensorflow.op.core.Split;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.linalg.MatMul;
@@ -31,6 +36,7 @@
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.StdArrays;
+import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;
/** Unit tests for {@link org.tensorflow.Session}. */
@@ -205,6 +211,24 @@ public void runInitByName() {
}
}
+ @Test
+ public void save() throws IOException {
+ Path testFolder = Files.createTempDirectory("tf-session-save-test");
+ try (Graph g = new Graph()) {
+ Ops tf = Ops.create(g);
+ Variable x = tf.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.DTYPE));
+ Variable y = tf.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.DTYPE));
+ Init init = tf.init();
+
+ try (Session s = new Session(g)) {
+ s.run(init);
+ s.save(testFolder.resolve("checkpoint").toString());
+ }
+ }
+ assertTrue(Files.exists(testFolder.resolve("checkpoint.index")));
+ assertTrue(Files.exists(testFolder.resolve("checkpoint.data-00000-of-00001")));
+ }
+
private static RunOptions fullTraceRunOptions() {
return RunOptions.newBuilder()
.setTraceLevel(RunOptions.TraceLevel.FULL_TRACE)
diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java
new file mode 100644
index 00000000000..8931ecbbde1
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java
@@ -0,0 +1,61 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+import org.junit.jupiter.api.Test;
+import org.tensorflow.op.Ops;
+import org.tensorflow.op.core.Init;
+import org.tensorflow.op.core.Placeholder;
+import org.tensorflow.op.math.Add;
+import org.tensorflow.op.math.Sub;
+import org.tensorflow.types.TFloat32;
+
+public class SignatureTest {
+
+ @Test
+ public void cannotUseEmptyKey() {
+ Signature.Builder builder = Signature.builder();
+ assertThrows(IllegalArgumentException.class, () -> builder.key(null));
+ assertThrows(IllegalArgumentException.class, () -> builder.key(""));
+ builder.key("valid key");
+ }
+
+ @Test
+ public void cannotDuplicateInputOutputNames() {
+ try (Graph g = new Graph()) {
+ Ops tf = Ops.create(g);
+ Signature.Builder builder = Signature.builder()
+ .input("x", tf.constant(10.0f))
+ .output("x", tf.constant(10.0f)) // can add an output with the same name as an input
+ .output("y", tf.constant(20.0f));
+ assertThrows(IllegalArgumentException.class, () -> builder.input("x", tf.constant(10)));
+ assertThrows(IllegalArgumentException.class, () -> builder.output("y", tf.constant(20.0f)));
+ }
+ }
+
+ @Test
+ public void emptyMethodNameConvertedToNull() {
+ Signature signature = Signature.builder().key("f").build();
+ assertNull(signature.methodName());
+ signature = Signature.builder().key("f").methodName("").build();
+ assertNull(signature.methodName());
+ signature = Signature.builder().key("f").methodName(null).build();
+ assertNull(signature.methodName());
+ }
+}
diff --git a/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/saved_model.pb b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/saved_model.pb
new file mode 100644
index 00000000000..169e0095a3e
Binary files /dev/null and b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/saved_model.pb differ
diff --git a/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.data-00000-of-00001 b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.data-00000-of-00001
new file mode 100644
index 00000000000..ac369237d31
Binary files /dev/null and b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.data-00000-of-00001 differ
diff --git a/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.index b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.index
new file mode 100644
index 00000000000..8be702a36ed
Binary files /dev/null and b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.index differ
diff --git a/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/source_model.py b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/source_model.py
new file mode 100644
index 00000000000..f5160401515
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/source_model.py
@@ -0,0 +1,63 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Saved model created using Python @tf.function, tensorflow version 2.3.0.
+#
+# Python code used to create saved model below
+#
+# WARNING: This script is just attached to the test code base for reference and is not used nor
+# executed by this project to generate the saved model, which has been added manually.
+
+import tensorflow as tf
+
+class MyModel(tf.keras.Model):
+ def __init__(self):
+ super(MyModel, self).__init__()
+ self.const_scalar = tf.constant(0.0)
+ self.const_vector = tf.constant([0.0, 0.0, 0.0])
+ self.const_matrix = tf.constant([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
+
+ @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32, name='request')])
+ def serve(self, x):
+ return self.const_scalar + x
+
+ @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32, name='input')])
+ def get_scalar(self, x):
+ return self.const_scalar + x
+
+ @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32, name='input')])
+ def get_vector(self, x):
+ return self.const_vector + x
+
+ @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32, name='input')])
+ def get_matrix(self, x):
+ return self.const_matrix + x
+
+ @tf.function(input_signature=[
+ tf.TensorSpec(shape=None, dtype=tf.float32, name='a'),
+ tf.TensorSpec(shape=None, dtype=tf.float32, name='b')])
+ def add(self, a, b):
+ return a + b
+
+model = MyModel()
+
+signatures = {
+ "get_const_scalar": model.get_scalar,
+ "get_const_vector": model.get_vector,
+ "get_const_matrix": model.get_matrix,
+ "add": model.add
+}
+
+tf.saved_model.save(obj=model, export_dir='model', signatures=signatures)