From bcd533f6ff5f48027d5975e75b479a80c92ff0eb Mon Sep 17 00:00:00 2001 From: Shajan Dasan Date: Mon, 27 Jul 2020 19:41:38 -0700 Subject: [PATCH 1/6] Draft: Java API to use tf.function available on SavedModel. (#89) Python models that contain tf.function is inconvenient to be consumed by Java clients. This proposal provides an API to (a) Invoke a tf.function, given the signature name (b) Retrieve the node name in the graph corresponding to a tf.function Co-authored-by: Shajan Dasan --- .../java/org/tensorflow/SavedModelBundle.java | 128 ++++++++++++++ .../main/java/org/tensorflow/TfFunction.java | 157 ++++++++++++++++++ 2 files changed, 285 insertions(+) create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TfFunction.java diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index 8f683a59d89..b9fbbd9dfd9 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -20,6 +20,9 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig; import com.google.protobuf.InvalidProtocolBufferException; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Collectors; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.PointerPointer; import org.bytedeco.javacpp.PointerScope; @@ -32,6 +35,7 @@ import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.MetaGraphDef; import org.tensorflow.proto.framework.RunOptions; +import org.tensorflow.proto.framework.SignatureDef; /** * SavedModelBundle represents a model loaded from storage. @@ -94,6 +98,101 @@ private Loader(String exportDir) { private RunOptions runOptions = null; } + /** + * SignatureToNodeName finds the node names in the {@link Graph} corresponding to the + * input / output parameters of a tf.function + */ + public static final class SignatureToNodeName { + + public SignatureToNodeName(SavedModelBundle savedModelBundle) { + loadSignatures(savedModelBundle); + } + + /** + * Given a tf.function signature name, find the node names corresponding + * to the input arguments + * + * @param functionSignatureName tf.function signature name + * @return a map from input arguments to node names in the {@link Graph} + */ + public Map inputNameToNode(String functionSignatureName) { + NameContainer nc = this.functionMap.get(functionSignatureName); + return (nc == null) ? null : nc.inputNameToNode(); + } + + /** + * Given a tf.function signature name, find the node names corresponding + * to the output arguments + * + * @param functionSignatureName tf.function signature name + * @return a map from output arguments to node names in the {@link Graph} + */ + public Map outputNameToNode(String functionSignatureName) { + NameContainer nc = this.functionMap.get(functionSignatureName); + return (nc == null) ? null : nc.outputNameToNode(); + } + + /** + * Given a tf.function signature name, find the method name + */ + public String methodName(String functionSignatureName) { + NameContainer nc = this.functionMap.get(functionSignatureName); + return (nc == null) ? null : nc.methodName(); + } + + private void loadSignatures(SavedModelBundle savedModelBundle) { + MetaGraphDef metaGraph = savedModelBundle.metaGraphDef(); + Map signatureMap = metaGraph.getSignatureDefMap(); + + // A saved model can contain multiple SignatureDef + for (Map.Entry entry : signatureMap.entrySet()) { + NameContainer nc = new NameContainer(entry.getValue()); + this.functionMap.put(entry.getKey(), nc); + } + } + + private Map functionMap = new HashMap<>(); + + private static final class NameContainer { + NameContainer(SignatureDef sd) { + this.inputNameToNodeName = sd.getInputsMap() + .entrySet() + .stream() + .collect(Collectors.toMap( + e -> e.getKey(), + e -> e.getValue().getName() + )); + + this.outputNameToNodeName = sd.getOutputsMap() + .entrySet() + .stream() + .collect(Collectors.toMap( + e -> e.getKey(), + e -> e.getValue().getName() + )); + + this.method = sd.getMethodName(); + } + + public Map inputNameToNode() { + return this.inputNameToNodeName; + } + + public Map outputNameToNode() { + return this.outputNameToNodeName; + } + + public String methodName() { + return this.method; + } + + private Map inputNameToNodeName; + private Map outputNameToNodeName; + private String method; + } + } + /** * Load a saved model from an export directory. The model that is being loaded should be created * using the Saved Model @@ -148,6 +247,34 @@ public Session session() { return session; } + /** + * Returns the {@link SignatureToNodeName} translator for the model. + * + * @return SignatureToNodeName translator + */ + public SignatureToNodeName getSignatureToNodeName() { + if (this.sigToNodeName == null) { + // no need to lock, ok to create multiple instances + this.sigToNodeName = new SignatureToNodeName(this); + } + return this.sigToNodeName; + } + + /** + * Return a {@link TfFunction} corresponding to the function signature. + * + *
{@code
+   * TfFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
+   * Map> outputTensorMap = myFunction.call(inputTensorMap);
+   * }
+ * + * @param functionSignatureName name of the {@code SignatureDef} in the saved model. + * @return TfFunction object that can be used to make calls to the tf.function + */ + public TfFunction function(String functionSignatureName) { + return new TfFunction(functionSignatureName, this.getSignatureToNodeName(), this.session); + } + /** * Releases resources (the {@link Graph} and {@link Session}) associated with the saved model * bundle. @@ -161,6 +288,7 @@ public void close() { private final Graph graph; private final Session session; private final MetaGraphDef metaGraphDef; + private SignatureToNodeName sigToNodeName; private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef) { this.graph = graph; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TfFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TfFunction.java new file mode 100644 index 00000000000..5dc5a128898 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TfFunction.java @@ -0,0 +1,157 @@ +/* + * Copyright 2020 The TensorFlow Authors. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.tensorflow; + +import com.google.protobuf.InvalidProtocolBufferException; + +import java.util.List; +import java.util.ListIterator; +import java.util.HashMap; +import java.util.Map; + +/** + * Invoke
tf.function + * defined in a {@link SavedModelBundle}. + * + *
{@code
+ * TfFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
+ * Map> outputTensorMap = myFunction.call(inputTensorMap);
+ * }
+ * + */ +public class TfFunction { + + public TfFunction( + String functionSignatureName, + SavedModelBundle.SignatureToNodeName nameToNode, Session session) { + this.nameToNode = nameToNode; + this.session = session; + this.functionSignatureName = functionSignatureName; + } + + /** + * Invokes a tf.function. + * Caller is responsible for closing all Tensors. + * + * @param arguments map of input tensors + * @return map of output tensors + */ + public Map> call( + Map> arguments) throws IllegalArgumentException { + + Session.Runner runner = this.session.runner(); + + Map inputToNode = this.nameToNode.inputNameToNode(this.functionSignatureName); + + if (inputToNode == null) { + throw new IllegalArgumentException( + String.format("Function [%s] is missing input", this.functionSignatureName)); + } + + // Join arguments.key, inputToNodeName.key + for (Map.Entry entry: inputToNode.entrySet()) { + String argName = entry.getKey(); + Tensor tensor = arguments.get(argName); + + if (tensor == null) { + throw new IllegalArgumentException(String.format("Missing argument [%s]", argName)); + } + + // Node name in the tensorflow graph, corresponding to the tf.function argument + runner = runner.feed(entry.getValue(), tensor); + } + + Map outputToNode = this.nameToNode.outputNameToNode(this.functionSignatureName); + if (outputToNode == null) { + throw new IllegalArgumentException( + String.format("Function [%] is missing output", this.functionSignatureName)); + } + + for (String nodeName: outputToNode.values()) { + // Node names corresponding to the return value + runner = runner.fetch(nodeName); + } + + List> resultTensors = runner.run(); + ListIterator> resultTensorIter = resultTensors.listIterator(); + + Map> returnMap = new HashMap>(); + + // Use the output names as present in the signature definition + for (String nodeName: outputToNode.keySet()) { + returnMap.put(nodeName, resultTensorIter.next()); + } + + return returnMap; + } + + /** + * Invokes a tf.function. + * Caller is responsible for closing all Tensors. + * + * Throws IllegalArgumentException if there are multiple input or output parameters defined + * in the tf.function + * + * @param tensor input tensor + * @return output tensor + */ + public Tensor call(Tensor tensor) throws IllegalArgumentException { + Session.Runner runner = this.session.runner(); + + Map inputToNode = this.nameToNode.inputNameToNode(this.functionSignatureName); + + if (inputToNode == null) { + throw new IllegalArgumentException( + String.format("Function [%s] is missing input", this.functionSignatureName)); + } + + if (inputToNode.size() != 1) { + throw new IllegalArgumentException( + String.format("Function [%s] requires multiple inputs", this.functionSignatureName)); + } + + // Feed the single argument + for (Map.Entry entry: inputToNode.entrySet()) { + // Node name in the tensorflow graph, corresponding to the tf.function argument + runner = runner.feed(entry.getValue(), tensor); + } + + Map outputToNode = this.nameToNode.outputNameToNode(this.functionSignatureName); + if (outputToNode == null) { + throw new IllegalArgumentException( + String.format("Function [%] is missing output", this.functionSignatureName)); + } + + if (outputToNode.size() != 1) { + throw new IllegalArgumentException( + String.format("Function [%s] has multiple outputs", this.functionSignatureName)); + } + + // Fetch the single return tensor + for (String nodeName: outputToNode.values()) { + // Node names corresponding to the return value + runner = runner.fetch(nodeName); + } + + List> resultTensors = runner.run(); + + return resultTensors.get(0); + } + + private final Session session; + private final SavedModelBundle.SignatureToNodeName nameToNode; + private final String functionSignatureName; +} From 0dbdd3ea83216e58fdf7ecf8dcabe8d7f9e60370 Mon Sep 17 00:00:00 2001 From: Karl Lessard Date: Wed, 5 Aug 2020 23:38:04 -0400 Subject: [PATCH 2/6] tmp --- .../annotations/org/tensorflow/op/Ops.java | 12 +- .../main/java/org/tensorflow/DataTypes.java | 2 + .../java/org/tensorflow/FunctionGraph.java | 213 +++++++++++++++++ .../src/main/java/org/tensorflow/Graph.java | 73 ++++++ .../java/org/tensorflow/SavedModelBundle.java | 214 +++++++++--------- .../src/main/java/org/tensorflow/Session.java | 14 ++ .../main/java/org/tensorflow/TfFunction.java | 157 ------------- .../org/tensorflow/SavedModelBundleTest.java | 97 ++++++++ .../test/java/org/tensorflow/SessionTest.java | 24 ++ 9 files changed, 535 insertions(+), 271 deletions(-) create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/FunctionGraph.java delete mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TfFunction.java diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index ed7f136a1ac..61b383162e9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -323,10 +323,10 @@ public final class Ops { public final ImageOps image; - public final DataOps data; - public final ShapeOps shape; + public final DataOps data; + public final IoOps io; public final DtypesOps dtypes; @@ -349,10 +349,10 @@ public final class Ops { public final SignalOps signal; - public final TrainOps train; - public final QuantizationOps quantization; + public final TrainOps train; + private final Scope scope; private Ops(Scope scope) { @@ -360,8 +360,8 @@ private Ops(Scope scope) { nn = new NnOps(scope); summary = new SummaryOps(scope); image = new ImageOps(scope); - data = new DataOps(scope); shape = new ShapeOps(scope); + data = new DataOps(scope); io = new IoOps(scope); dtypes = new DtypesOps(scope); xla = new XlaOps(scope); @@ -373,8 +373,8 @@ private Ops(Scope scope) { math = new MathOps(scope); audio = new AudioOps(scope); signal = new SignalOps(scope); - train = new TrainOps(scope); quantization = new QuantizationOps(scope); + train = new TrainOps(scope); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataTypes.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataTypes.java index 77c0de0c83f..2dfa15619e4 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataTypes.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataTypes.java @@ -19,6 +19,7 @@ import java.util.HashMap; import java.util.Map; +import javax.annotation.Nullable; import org.tensorflow.types.TBfloat16; import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat16; @@ -68,6 +69,7 @@ static DataType fromNativeCode(int nativeCode) { // TODO (karllessard): Right now this method is private but we might want to expose it // to allow user to register custom data types? + @Nullable private static void register(DataType dataType) { DATA_TYPE_REGISTRY.put(dataType.nativeCode(), dataType); DATA_TYPE_REGISTRY.put(dataType.nativeCode() + 100, dataType); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/FunctionGraph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/FunctionGraph.java new file mode 100644 index 00000000000..e0ec42d249d --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/FunctionGraph.java @@ -0,0 +1,213 @@ +/* + * Copyright 2020 The TensorFlow Authors. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.tensorflow; + +import java.util.List; +import java.util.ListIterator; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Function; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.proto.framework.SignatureDef; +import org.tensorflow.proto.framework.TensorInfo; +import org.tensorflow.proto.framework.TensorShapeProto; +import org.tensorflow.proto.framework.TensorShapeProto.Dim; + +/** + * Invoke tf.function + * defined in a {@link SavedModelBundle}. + * + *
{@code
+ * FunctionGraph myFunction = savedModelBundle.function("myFunctionSignatureName");
+ * Map> outputTensorMap = myFunction.call(session, inputTensorMap);
+ * }
+ * + */ +public class FunctionGraph implements AutoCloseable { + + public static class SignatureBuilder { + + public SignatureBuilder addInput(String inputName, Operand input) { + signatureBuilder.putInputs(inputName, toTensorInfo(input.asOutput())); + return this; + } + + public SignatureBuilder addOutput(String outputName, Operand output) { + signatureBuilder.putOutputs(outputName, toTensorInfo(output.asOutput())); + return this; + } + + public SignatureBuilder methodName(String methodName) { + signatureBuilder.setMethodName(methodName); + return this; + } + + private SignatureDef build() { + return signatureBuilder.build(); + } + + private final SignatureDef.Builder signatureBuilder = SignatureDef.newBuilder(); + + private static TensorInfo toTensorInfo(Output operand) { + Shape shape = operand.shape(); + TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder(); + for (int i = 0; i < shape.numDimensions(); ++i) { + tensorShapeBuilder.addDim(Dim.newBuilder().setSize(shape.size(i))); + } + return TensorInfo.newBuilder() + .setDtype(DataType.forNumber(operand.dataType().nativeCode())) + .setTensorShape(tensorShapeBuilder) + .setName(operand.op().name() + ":" + operand.index()) + .build(); + } + } + + public static FunctionGraph create(BiConsumer, Graph> function) { + Graph graph = new Graph(); + Map signatures = new HashMap<>(); + function.accept(signatures, graph); + return new FunctionGraph(signatures, graph); + } + + public static FunctionGraph create(SignatureDef signature, Graph graph) { + return new FunctionGraph(signature, graph); + } + + /** + * Returns the method name of this function + */ + public String methodName() { + return signature.getMethodName(); + } + + /** + * Returns the names of the inputs of this function. + */ + public Set inputNames() { + return signature.getInputsMap().keySet(); + } + + /** + * Returns the names of the outputs of this function. + */ + public Set outputNames() { + return signature.getOutputsMap().keySet(); + } + + /** + * Invokes a function. + * + *

Caller is responsible for closing all Tensors. + * + * @param tensor input tensor + * @return output tensor + */ + public Map> call(Map> arguments) + throws IllegalArgumentException { + + final Session.Runner runner = session.runner(); + + signature.getInputsMap().forEach((argName, t) -> { + Tensor tensor = arguments.get(argName); + if (tensor == null) { + throw new IllegalArgumentException(String.format("Missing argument [%s]", argName)); + } + runner.feed(t.getName(), tensor); + }); + + Map outputToNode = signature.getOutputsMap(); + outputToNode.values().forEach(t -> runner.fetch(t.getName())); + + List> resultTensors = runner.run(); + try { + ListIterator> resultTensorIter = resultTensors.listIterator(); + Map> returnMap = new HashMap>(); + + // Use the output names as present in the signature definition + for (String nodeName: outputToNode.keySet()) { + returnMap.put(nodeName, resultTensorIter.next()); + } + return returnMap; + + } catch (Exception e) { + // Release tensors before throwing exception + for (Tensor t : resultTensors) { + t.close(); + } + throw e; + } + } + + /** + * Invokes a function with a single input and output. + * + *

Caller is responsible for closing all Tensors. + * + * @param tensor input tensor + * @return output tensor + * @throws IllegalArgumentException if there are multiple input or output parameters defined + * in the function + */ + public Tensor call(Tensor tensor) throws IllegalArgumentException { + if (signature.getInputsCount() != 1) { + throw new IllegalArgumentException( + String.format("Function [%s] requires multiple inputs", signature.getMethodName())); + } + String inputNodeName = signature.getInputsMap().values().iterator().next().getName(); + + if (signature.getOutputsCount() != 1) { + throw new IllegalArgumentException( + String.format("Function [%s] has multiple outputs", signature.getMethodName())); + } + String outputNodeName = signature.getOutputsMap().values().iterator().next().getName(); + + return session.runner().feed(inputNodeName, tensor).fetch(outputNodeName).run().get(0); + } + + /** + * Returns the signature of this function + */ + public SignatureDef signature() { + return signature; + } + + @Override + public void close() { + session.close(); + graph.close(); + } + + FunctionGraph(SignatureDef signature, Graph graph) { + this.graph = graph; + this.session = new Session(graph); + this.signature = signature; + } + + FunctionGraph(Session session, SignatureDef signature) { + this.graph = session.graph(); + this.session = session; + this.signature = signature; + } + + private final Map signatures; + private final Graph graph; + private final Session session; + private final SignatureDef signature; +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index 221268191fd..071a28bb6c9 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -43,8 +43,17 @@ import org.tensorflow.internal.c_api.TF_Output; import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.internal.c_api.TF_WhileParams; +import org.tensorflow.ndarray.StdArrays; import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.NoOp; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.op.train.Restore; +import org.tensorflow.op.train.Save; import org.tensorflow.proto.framework.GraphDef; +import org.tensorflow.proto.util.SaverDef; +import org.tensorflow.types.TString; /** @@ -67,6 +76,11 @@ public Graph() { this.nativeHandle = nativeHandle; } + Graph(TF_Graph nativeHandle, SaverDef saverDef) { + this(nativeHandle); + this.saverDef = saverDef; + } + /** * Release resources associated with the Graph. * @@ -287,6 +301,17 @@ public Output[] addGradients(Output y, Output[] x) { return addGradients(null, new Output[] {y}, x, null); } + public SaverDef saverDef() { + if (saverDef == null) { + synchronized (this) { + if (saverDef == null) { + saverDef = addVariableSaver(this); + } + } + } + return saverDef; + } + /** * Used to instantiate an abstract class which overrides the buildSubgraph method to build a * conditional or body subgraph for a while loop. After Java 8, this can alternatively be used to @@ -405,6 +430,7 @@ public Output[] whileLoop( private final Object nativeHandleLock = new Object(); private TF_Graph nativeHandle; private int refcount = 0; + private SaverDef saverDef; private final List initializers = new ArrayList<>(); @@ -726,6 +752,53 @@ private static Object[] whileLoop( } } + private static SaverDef addVariableSaver(Graph graph) { + Ops tf = Ops.create(graph).withSubScope("save"); + + List varNames = new ArrayList<>(); + List> varOutputs = new ArrayList<>(); + List> varTypes = new ArrayList<>(); + + for (Iterator iter = graph.operations(); iter.hasNext();) { + Operation op = iter.next(); + if (op.type().equals("VariableV2")) { + varNames.add(op.name()); + varOutputs.add(op.output(0)); + varTypes.add(op.output(0).dataType()); + } + } + + // FIXME Need an easier way to initialize an NdArray from a list + String[] tmp = new String[varNames.size()]; + Constant varNamesTensor = tf.constant(StdArrays.ndCopyOf(varNames.toArray(tmp))); + Operand varSlices = tf.zerosLike(varNamesTensor); + + Placeholder saveFilename = tf.placeholder(TString.DTYPE); + Save saveVariables = tf.train.save( + saveFilename, + varNamesTensor, + varSlices, + varOutputs + ); + Restore restoreVariables = tf.train.restore( + saveFilename, + varNamesTensor, + varSlices, + varTypes + ); + List restoreOps = new ArrayList<>(varOutputs.size()); + for (int i = 0; i < varOutputs.size(); ++i) { + restoreOps.add(tf.assign(varOutputs.get(i), (Operand) restoreVariables.tensors().get(i))); + } + NoOp restoreAll = tf.withControlDependencies(restoreOps).noOp(); + + return SaverDef.newBuilder() + .setFilenameTensorName(saveFilename.op().name()) + .setSaveTensorName(saveVariables.op().name()) + .setRestoreOpName(restoreAll.op().name()) + .build(); + } + static { TensorFlow.init(); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index b9fbbd9dfd9..2efaf5e883e 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -20,9 +20,17 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig; import com.google.protobuf.InvalidProtocolBufferException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.PointerPointer; import org.bytedeco.javacpp.PointerScope; @@ -32,10 +40,17 @@ import org.tensorflow.internal.c_api.TF_Session; import org.tensorflow.internal.c_api.TF_SessionOptions; import org.tensorflow.internal.c_api.TF_Status; +import org.tensorflow.ndarray.Shape; import org.tensorflow.proto.framework.ConfigProto; +import org.tensorflow.proto.framework.DataType; import org.tensorflow.proto.framework.MetaGraphDef; +import org.tensorflow.proto.framework.MetaGraphDef.MetaInfoDef; import org.tensorflow.proto.framework.RunOptions; +import org.tensorflow.proto.framework.SavedModel; import org.tensorflow.proto.framework.SignatureDef; +import org.tensorflow.proto.framework.TensorInfo; +import org.tensorflow.proto.framework.TensorShapeProto; +import org.tensorflow.proto.framework.TensorShapeProto.Dim; /** * SavedModelBundle represents a model loaded from storage. @@ -47,8 +62,12 @@ * protocol buffer. */ public class SavedModelBundle implements AutoCloseable { + + public static final String DEFAULT_SIGNATURE_NAME = "serving_default"; + /** Options for loading a SavedModel. */ public static final class Loader { + /** Load a SavedModelBundle with the configured options. */ public SavedModelBundle load() { return SavedModelBundle.load(exportDir, tags, configProto, runOptions); @@ -98,99 +117,61 @@ private Loader(String exportDir) { private RunOptions runOptions = null; } - /** - * SignatureToNodeName finds the node names in the {@link Graph} corresponding to the - * input / output parameters of a tf.function - */ - public static final class SignatureToNodeName { + /** Options for exporting a SavedModel. */ + public static final class Exporter { - public SignatureToNodeName(SavedModelBundle savedModelBundle) { - loadSignatures(savedModelBundle); + public Exporter withTags(String... tags) { + this.tags.addAll(Arrays.asList(tags)); + return this; } - /** - * Given a tf.function signature name, find the node names corresponding - * to the input arguments - * - * @param functionSignatureName tf.function signature name - * @return a map from input arguments to node names in the {@link Graph} - */ - public Map inputNameToNode(String functionSignatureName) { - NameContainer nc = this.functionMap.get(functionSignatureName); - return (nc == null) ? null : nc.inputNameToNode(); + public Exporter withFunction(FunctionGraph functionGraph) { } - /** - * Given a tf.function signature name, find the node names corresponding - * to the output arguments - * - * @param functionSignatureName tf.function signature name - * @return a map from output arguments to node names in the {@link Graph} - */ - public Map outputNameToNode(String functionSignatureName) { - NameContainer nc = this.functionMap.get(functionSignatureName); - return (nc == null) ? null : nc.outputNameToNode(); + public Exporter withSignature(SignatureDef signatureDef) { + return withSignature(DEFAULT_SIGNATURE_NAME, signatureDef); } - /** - * Given a tf.function signature name, find the method name - */ - public String methodName(String functionSignatureName) { - NameContainer nc = this.functionMap.get(functionSignatureName); - return (nc == null) ? null : nc.methodName(); + public Exporter withSignature(String signatureName, SignatureDef signature) { + metaGraphDefBuilder.putSignatureDef(signatureName, signature); + return this; } - private void loadSignatures(SavedModelBundle savedModelBundle) { - MetaGraphDef metaGraph = savedModelBundle.metaGraphDef(); - Map signatureMap = metaGraph.getSignatureDefMap(); - - // A saved model can contain multiple SignatureDef - for (Map.Entry entry : signatureMap.entrySet()) { - NameContainer nc = new NameContainer(entry.getValue()); - this.functionMap.put(entry.getKey(), nc); + public void export(Session session) throws IOException { + Graph graph = session.graph(); + if (tags.isEmpty()) { + tags.add("serve"); + } + // It is imperative to retrieve the graphDef after the saverDef, as the former might add + // new ops to the graph. + MetaGraphDef metaGraphDef = metaGraphDefBuilder + .setSaverDef(graph.saverDef()) + .setGraphDef(graph.toGraphDef()) + .setMetaInfoDef(MetaInfoDef.newBuilder().addAllTags(tags)) + .build(); + + // Make sure saved model directories exist + Path variableDir = Paths.get(exportDir, "variables"); + variableDir.toFile().mkdirs(); + + // Save variables state, using the "variables-*" prefix + session.save(variableDir.resolve("variables").toString()); + + // Save graph + SavedModel savedModelDef = SavedModel.newBuilder().addMetaGraphs(metaGraphDef).build(); + try (OutputStream file = + new FileOutputStream(Paths.get(exportDir, "saved_model.pb").toString())) { + savedModelDef.writeTo(file); } } - private Map functionMap = new HashMap<>(); - - private static final class NameContainer { - NameContainer(SignatureDef sd) { - this.inputNameToNodeName = sd.getInputsMap() - .entrySet() - .stream() - .collect(Collectors.toMap( - e -> e.getKey(), - e -> e.getValue().getName() - )); - - this.outputNameToNodeName = sd.getOutputsMap() - .entrySet() - .stream() - .collect(Collectors.toMap( - e -> e.getKey(), - e -> e.getValue().getName() - )); - - this.method = sd.getMethodName(); - } - - public Map inputNameToNode() { - return this.inputNameToNodeName; - } - - public Map outputNameToNode() { - return this.outputNameToNodeName; - } - - public String methodName() { - return this.method; - } - - private Map inputNameToNodeName; - private Map outputNameToNodeName; - private String method; + Exporter(String exportDir) { + this.exportDir = exportDir; } + + private final String exportDir; + private final MetaGraphDef.Builder metaGraphDefBuilder = MetaGraphDef.newBuilder(); + private final List tags = new ArrayList<>(); } /** @@ -224,6 +205,10 @@ public static Loader loader(String exportDir) { return new Loader(exportDir); } + public static Exporter exporter(String exportDir) { + return new Exporter(exportDir); + } + /** * Returns the MetaGraphDef @@ -248,31 +233,37 @@ public Session session() { } /** - * Returns the {@link SignatureToNodeName} translator for the model. + * Return a {@link FunctionGraph} corresponding to the function signature. + * + *

{@code
+   * FunctionGraph myFunction = savedModelBundle.function("myFunctionSignatureName");
+   * Map> outputTensorMap = myFunction.call(session, inputTensorMap);
+   * }
* - * @return SignatureToNodeName translator + * @param functionSignatureName name of the {@code SignatureDef} in the saved model. + * @return TfFunction object that can be used to make calls to the tf.function + * @throws IllegalArgumentException if {@code functionSignatureName} is not found in this + * saved model. */ - public SignatureToNodeName getSignatureToNodeName() { - if (this.sigToNodeName == null) { - // no need to lock, ok to create multiple instances - this.sigToNodeName = new SignatureToNodeName(this); + public FunctionGraph function(String functionSignatureName) { + SignatureDef signature = metaGraphDef.getSignatureDefMap().get(functionSignatureName); + if (signature == null) { + throw new IllegalArgumentException( + String.format("Function with signature [%s] not found", functionSignatureName)); } - return this.sigToNodeName; + return new FunctionGraph(session, signature); } /** - * Return a {@link TfFunction} corresponding to the function signature. - * - *
{@code
-   * TfFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
-   * Map> outputTensorMap = myFunction.call(inputTensorMap);
-   * }
+ * Return the {@link FunctionGraph} corresponding to the default function signature of this model. * * @param functionSignatureName name of the {@code SignatureDef} in the saved model. * @return TfFunction object that can be used to make calls to the tf.function + * @throws IllegalArgumentException if no function with the default signature name can be found in + * this saved model. */ - public TfFunction function(String functionSignatureName) { - return new TfFunction(functionSignatureName, this.getSignatureToNodeName(), this.session); + public FunctionGraph function() { + return function(DEFAULT_SIGNATURE_NAME); } /** @@ -281,19 +272,15 @@ public TfFunction function(String functionSignatureName) { */ @Override public void close() { - session.close(); - graph.close(); + functions.forEach((s, f) -> f.close()); } - private final Graph graph; - private final Session session; private final MetaGraphDef metaGraphDef; - private SignatureToNodeName sigToNodeName; + private final Map functions; - private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef) { - this.graph = graph; - this.session = session; + private SavedModelBundle(MetaGraphDef metaGraphDef, Map functions) { this.metaGraphDef = metaGraphDef; + this.functions = functions; } /** @@ -303,10 +290,21 @@ private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef *

Invoked from the native load method. Takes ownership of the handles. */ private static SavedModelBundle fromHandle( - TF_Graph graphHandle, TF_Session sessionHandle, MetaGraphDef metaGraphDef) { - Graph graph = new Graph(graphHandle); - Session session = new Session(graph, sessionHandle); - return new SavedModelBundle(graph, session, metaGraphDef); + final TF_Graph graphHandle, final TF_Session sessionHandle, MetaGraphDef metaGraphDef) { + + final Graph graph = new Graph(graphHandle, metaGraphDef.getSaverDef()); + final Session session = new Session(graph, sessionHandle); + + // For each signature, we will create a separate function. To support cases where multiple + // signatures are attached to the same graph, each function instance will retain a reference + // to the underlying resource, so that they are freed only when the last function is released. + final Map functions = new HashMap<>(metaGraphDef.getSignatureDefCount()); + metaGraphDef.getSignatureDefMap().forEach((signatureName, signatureDef) -> { + graphHandle.retainReference(); + sessionHandle.retainReference(); + functions.put(signatureName, new FunctionGraph(session, signatureDef)); + }); + return new SavedModelBundle(metaGraphDef, functions); } private static SavedModelBundle load( diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 0676ce8ec4e..21d338ca765 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -36,6 +36,8 @@ import java.util.ArrayList; import java.util.List; +import org.tensorflow.proto.util.SaverDef; +import org.tensorflow.types.TString; import static org.tensorflow.Graph.resolveOutputs; import static org.tensorflow.internal.c_api.global.tensorflow.*; @@ -444,6 +446,14 @@ public void run(Op op) { runner().addTarget(op.op()).run(); } + public void save(String prefix) { + SaverDef saverDef = graph.saverDef(); + runner() + .addTarget(saverDef.getSaveTensorName()) + .feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix)) + .run(); + } + /** * Output tensors and metadata obtained when executing a session. * @@ -463,6 +473,10 @@ public static final class Run { public RunMetadata metadata; } + Graph graph() { + return graph; + } + private final Graph graph; private final Graph.Reference graphRef; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TfFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TfFunction.java deleted file mode 100644 index 5dc5a128898..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TfFunction.java +++ /dev/null @@ -1,157 +0,0 @@ -/* - * Copyright 2020 The TensorFlow Authors. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.tensorflow; - -import com.google.protobuf.InvalidProtocolBufferException; - -import java.util.List; -import java.util.ListIterator; -import java.util.HashMap; -import java.util.Map; - -/** - * Invoke tf.function - * defined in a {@link SavedModelBundle}. - * - *

{@code
- * TfFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
- * Map> outputTensorMap = myFunction.call(inputTensorMap);
- * }
- * - */ -public class TfFunction { - - public TfFunction( - String functionSignatureName, - SavedModelBundle.SignatureToNodeName nameToNode, Session session) { - this.nameToNode = nameToNode; - this.session = session; - this.functionSignatureName = functionSignatureName; - } - - /** - * Invokes a tf.function. - * Caller is responsible for closing all Tensors. - * - * @param arguments map of input tensors - * @return map of output tensors - */ - public Map> call( - Map> arguments) throws IllegalArgumentException { - - Session.Runner runner = this.session.runner(); - - Map inputToNode = this.nameToNode.inputNameToNode(this.functionSignatureName); - - if (inputToNode == null) { - throw new IllegalArgumentException( - String.format("Function [%s] is missing input", this.functionSignatureName)); - } - - // Join arguments.key, inputToNodeName.key - for (Map.Entry entry: inputToNode.entrySet()) { - String argName = entry.getKey(); - Tensor tensor = arguments.get(argName); - - if (tensor == null) { - throw new IllegalArgumentException(String.format("Missing argument [%s]", argName)); - } - - // Node name in the tensorflow graph, corresponding to the tf.function argument - runner = runner.feed(entry.getValue(), tensor); - } - - Map outputToNode = this.nameToNode.outputNameToNode(this.functionSignatureName); - if (outputToNode == null) { - throw new IllegalArgumentException( - String.format("Function [%] is missing output", this.functionSignatureName)); - } - - for (String nodeName: outputToNode.values()) { - // Node names corresponding to the return value - runner = runner.fetch(nodeName); - } - - List> resultTensors = runner.run(); - ListIterator> resultTensorIter = resultTensors.listIterator(); - - Map> returnMap = new HashMap>(); - - // Use the output names as present in the signature definition - for (String nodeName: outputToNode.keySet()) { - returnMap.put(nodeName, resultTensorIter.next()); - } - - return returnMap; - } - - /** - * Invokes a tf.function. - * Caller is responsible for closing all Tensors. - * - * Throws IllegalArgumentException if there are multiple input or output parameters defined - * in the tf.function - * - * @param tensor input tensor - * @return output tensor - */ - public Tensor call(Tensor tensor) throws IllegalArgumentException { - Session.Runner runner = this.session.runner(); - - Map inputToNode = this.nameToNode.inputNameToNode(this.functionSignatureName); - - if (inputToNode == null) { - throw new IllegalArgumentException( - String.format("Function [%s] is missing input", this.functionSignatureName)); - } - - if (inputToNode.size() != 1) { - throw new IllegalArgumentException( - String.format("Function [%s] requires multiple inputs", this.functionSignatureName)); - } - - // Feed the single argument - for (Map.Entry entry: inputToNode.entrySet()) { - // Node name in the tensorflow graph, corresponding to the tf.function argument - runner = runner.feed(entry.getValue(), tensor); - } - - Map outputToNode = this.nameToNode.outputNameToNode(this.functionSignatureName); - if (outputToNode == null) { - throw new IllegalArgumentException( - String.format("Function [%] is missing output", this.functionSignatureName)); - } - - if (outputToNode.size() != 1) { - throw new IllegalArgumentException( - String.format("Function [%s] has multiple outputs", this.functionSignatureName)); - } - - // Fetch the single return tensor - for (String nodeName: outputToNode.values()) { - // Node names corresponding to the return value - runner = runner.fetch(nodeName); - } - - List> resultTensors = runner.run(); - - return resultTensors.get(0); - } - - private final Session session; - private final SavedModelBundle.SignatureToNodeName nameToNode; - private final String functionSignatureName; -} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index 91c07e3f4b6..3c9b40939f2 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -15,20 +15,37 @@ package org.tensorflow; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; +import java.io.File; +import java.io.IOException; import java.net.URISyntaxException; +import java.nio.file.Files; +import java.nio.file.Path; import java.nio.file.Paths; import org.junit.jupiter.api.Test; import org.tensorflow.exceptions.TensorFlowException; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Init; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.op.core.ReduceSum; +import org.tensorflow.op.core.Variable; import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.RunOptions; +import org.tensorflow.proto.framework.SignatureDef; +import org.tensorflow.proto.framework.TensorInfo; +import org.tensorflow.types.TFloat32; /** Unit tests for {@link org.tensorflow.SavedModelBundle}. */ public class SavedModelBundleTest { + private static final float EPSILON = 1e-7f; private static final String SAVED_MODEL_PATH; static { try { @@ -72,6 +89,86 @@ public void loader() { } } + @Test + public void export() throws IOException { + Path testFolder = Files.createTempDirectory("tf-saved-model-export-test"); + float reducedSum; + FloatNdArray xValue = StdArrays.ndCopyOf(new float[][]{{0, 1, 2}, {3, 4, 5}}); + Shape xyShape = Shape.of(2, 3L); + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + Placeholder x = tf.placeholder(TFloat32.DTYPE, Placeholder.shape(xyShape)); + Variable y = tf + .variable(tf.random.randomUniform(tf.constant(xyShape), TFloat32.DTYPE)); + ReduceSum z = tf.reduceSum(tf.math.add(x, y), tf.array(0, 1)); + Init init = tf.init(); + + try (Session s = new Session(g)) { + s.run(init); + + FunctionGraph function = FunctionGraph.builder() + .addInput("input", x) + .addOutput("reducedSum", z) + .build(s); + + // Call the graph and remember the result of computation for later + try (Tensor xTensor = TFloat32.tensorOf(xValue); + Tensor zTensor = function.call(xTensor).expect(TFloat32.DTYPE)) { + reducedSum = zTensor.data().getFloat(); + } + // Export the model + SavedModelBundle.exporter(testFolder.toString()) + .withTags("test") + .withFunction(function) + .export(s); + } + } + assertTrue(Files.exists(testFolder.resolve(Paths.get("variables", "variables.index")))); + assertTrue(Files + .exists(testFolder.resolve(Paths.get("variables", "variables.data-00000-of-00001")))); + assertTrue(Files.exists(testFolder.resolve("saved_model.pb"))); + + // Reload the model just saved and validate its data + try (SavedModelBundle savedModel = SavedModelBundle.load(testFolder.toString(), "test")) { + assertNotNull(savedModel.metaGraphDef()); + assertNotNull(savedModel.metaGraphDef().getSaverDef()); + assertEquals(1, savedModel.metaGraphDef().getSignatureDefCount()); + assertEquals(SavedModelBundle.DEFAULT_SIGNATURE_NAME, + savedModel.metaGraphDef().getSignatureDefMap().keySet().iterator().next()); + + SignatureDef signature = savedModel.metaGraphDef().getSignatureDefMap() + .get(SavedModelBundle.DEFAULT_SIGNATURE_NAME); + assertNotNull(signature); + assertEquals(1, signature.getInputsCount()); + assertEquals(1, signature.getOutputsCount()); + + TensorInfo inputInfo = signature.getInputsMap().get("input"); + assertNotNull(inputInfo); + assertEquals(xyShape.numDimensions(), inputInfo.getTensorShape().getDimCount()); + for (int i = 0; i < xyShape.numDimensions(); ++i) { + assertEquals(xyShape.size(i), inputInfo.getTensorShape().getDim(i).getSize()); + } + + TensorInfo outputInfo = signature.getOutputsMap().get("reducedSum"); + assertNotNull(outputInfo); + assertEquals(0, outputInfo.getTensorShape().getDimCount()); + + FunctionGraph function = savedModel.function(); + assertNotNull(function); + assertEquals(1, function.inputNames().size()); + assertEquals("input", function.inputNames().iterator().next()); + assertEquals(1, function.outputNames().size()); + assertEquals("reducedSum", function.outputNames().iterator().next()); + assertEquals(FunctionGraph.DEFAULT_METHOD_NAME, function.methodName()); + + // Call the saved model function and make sure it returns the same result as before + try (Tensor xTensor = TFloat32.tensorOf(xValue); + Tensor zTensor = function.call(xTensor).expect(TFloat32.DTYPE)) { + assertEquals(reducedSum, zTensor.data().getFloat(), EPSILON); + } + } + } + private static RunOptions sillyRunOptions() { return RunOptions.newBuilder() .setTraceLevel(RunOptions.TraceLevel.FULL_TRACE) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index 7faf8c6fbdb..fa41af32a29 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -20,8 +20,13 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; import org.junit.jupiter.api.Test; import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Init; import org.tensorflow.op.core.Split; import org.tensorflow.op.core.Variable; import org.tensorflow.op.linalg.MatMul; @@ -31,6 +36,7 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; /** Unit tests for {@link org.tensorflow.Session}. */ @@ -205,6 +211,24 @@ public void runInitByName() { } } + @Test + public void save() throws IOException { + Path testFolder = Files.createTempDirectory("tf-session-save-test"); + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + Variable x = tf.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.DTYPE)); + Variable y = tf.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.DTYPE)); + Init init = tf.init(); + + try (Session s = new Session(g)) { + s.run(init); + s.save(testFolder.resolve("checkpoint").toString()); + } + } + assertTrue(Files.exists(testFolder.resolve("checkpoint.index"))); + assertTrue(Files.exists(testFolder.resolve("checkpoint.data-00000-of-00001"))); + } + private static RunOptions fullTraceRunOptions() { return RunOptions.newBuilder() .setTraceLevel(RunOptions.TraceLevel.FULL_TRACE) From b311e27b71ca70c11c0e4eb8b75b3d4a462b2258 Mon Sep 17 00:00:00 2001 From: Karl Lessard Date: Mon, 24 Aug 2020 00:38:37 -0400 Subject: [PATCH 3/6] Create function objects as callable graphs --- .../main/java/org/tensorflow/DataTypes.java | 2 - .../java/org/tensorflow/FunctionGraph.java | 177 ++++++++++++------ .../java/org/tensorflow/SavedModelBundle.java | 144 +++++++++----- .../main/java/org/tensorflow/TfFunction.java | 157 ---------------- .../org/tensorflow/FunctionGraphTest.java | 58 ++++++ .../org/tensorflow/SavedModelBundleTest.java | 17 +- 6 files changed, 289 insertions(+), 266 deletions(-) delete mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TfFunction.java create mode 100644 tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/FunctionGraphTest.java diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataTypes.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataTypes.java index 2dfa15619e4..77c0de0c83f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataTypes.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataTypes.java @@ -19,7 +19,6 @@ import java.util.HashMap; import java.util.Map; -import javax.annotation.Nullable; import org.tensorflow.types.TBfloat16; import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat16; @@ -69,7 +68,6 @@ static DataType fromNativeCode(int nativeCode) { // TODO (karllessard): Right now this method is private but we might want to expose it // to allow user to register custom data types? - @Nullable private static void register(DataType dataType) { DATA_TYPE_REGISTRY.put(dataType.nativeCode(), dataType); DATA_TYPE_REGISTRY.put(dataType.nativeCode() + 100, dataType); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/FunctionGraph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/FunctionGraph.java index e0ec42d249d..60a6ce3e797 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/FunctionGraph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/FunctionGraph.java @@ -20,9 +20,6 @@ import java.util.HashMap; import java.util.Map; import java.util.Set; -import java.util.function.BiConsumer; -import java.util.function.Consumer; -import java.util.function.Function; import org.tensorflow.ndarray.Shape; import org.tensorflow.proto.framework.DataType; import org.tensorflow.proto.framework.SignatureDef; @@ -31,40 +28,99 @@ import org.tensorflow.proto.framework.TensorShapeProto.Dim; /** - * Invoke tf.function + * A graph that can be invoked as a single function, with an input and output signature. + * + *

Note that the lifetime of a function is coupled with the lifetime of its graph or session, i.e. + * the function will failed to be invoked after the graph or session is released, which ever comes + * first. e.g. + * + *

{@code
+ * FunctionGraph function;
+ * try (Graph g = new Graph()) {
+ *   Ops tf = Ops.create(g);
+ *   Placeholder x = tf.placeholder(TFloat32.DTYPE);
+ *   Add y = tf.math.add(x, tf.constant(2.0f));
+ *   try (Session s = new Session(s)) {
+ *     function = FunctionGraph.builder("myFunction").input("x", x).output("y", y).build(s);
+ *     try (Tensor xValue = TFloat32.scalarOf(10.0f);
+ *          Tensor yValue = function.call(xValue).expect(TFloat32.DTYPE)) {
+ *       assertEquals(12.0f, yValue.data().getFloat());
+ *     }
+ *   }
+ * }
+ * try (Tensor xValue = TFloat32.scalarOf(10.0f)) {
+ *   function.call(xValue); // fails, graph has been closed
+ * }
+ * }
+ * + *

A function can also invoke a + * tf.function * defined in a {@link SavedModelBundle}. * *

{@code
  * FunctionGraph myFunction = savedModelBundle.function("myFunctionSignatureName");
- * Map> outputTensorMap = myFunction.call(session, inputTensorMap);
+ * Map> outputTensorMap = myFunction.call(inputTensorMap);
  * }
- * */ -public class FunctionGraph implements AutoCloseable { +public class FunctionGraph { - public static class SignatureBuilder { + /** The default signature name, when not provided */ + public static final String DEFAULT_NAME = "serving_default"; + + /** + * Builds a new function signature. + */ + public static class Builder { - public SignatureBuilder addInput(String inputName, Operand input) { + /** + * Register a tensor as an input of the function. + * + * @param inputName user-friendly name for this input tensor + * @param input input tensor + * @return this builder + */ + public Builder input(String inputName, Operand input) { signatureBuilder.putInputs(inputName, toTensorInfo(input.asOutput())); return this; } - public SignatureBuilder addOutput(String outputName, Operand output) { + /** + * Register a tensor as an output of the function. + * + * @param inputName user-friendly name for this input tensor + * @param input input tensor + * @return this builder + */ + public Builder output(String outputName, Operand output) { signatureBuilder.putOutputs(outputName, toTensorInfo(output.asOutput())); return this; } - public SignatureBuilder methodName(String methodName) { + /** + * Provide extensible name information enabling third-party users to mark a signature as + * supporting a particular method + * + * @param methodName method name + * @return this builder + */ + public Builder methodName(String methodName) { signatureBuilder.setMethodName(methodName); return this; } - private SignatureDef build() { - return signatureBuilder.build(); + /** + * Creates a function from a graph session. + * + *

The provided session will be used for running or saving this function. + * + * @param signature signature of the function + * @param session a graph session + * @return a function + */ + public FunctionGraph build(Session session) { + return new FunctionGraph(name, signatureBuilder.build(), session); } - private final SignatureDef.Builder signatureBuilder = SignatureDef.newBuilder(); - private static TensorInfo toTensorInfo(Output operand) { Shape shape = operand.shape(); TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder(); @@ -77,38 +133,59 @@ private static TensorInfo toTensorInfo(Output operand) { .setName(operand.op().name() + ":" + operand.index()) .build(); } + + private final String name; + private final SignatureDef.Builder signatureBuilder = SignatureDef.newBuilder(); + + private Builder(String name) { + this.name = name; + } } - public static FunctionGraph create(BiConsumer, Graph> function) { - Graph graph = new Graph(); - Map signatures = new HashMap<>(); - function.accept(signatures, graph); - return new FunctionGraph(signatures, graph); + /** + * Returns a new builder for creating a function + * + *

"serving_default" will be used as the default function signature name. + */ + public static Builder builder() { + return new Builder(DEFAULT_NAME); } - public static FunctionGraph create(SignatureDef signature, Graph graph) { - return new FunctionGraph(signature, graph); + /** + * Returns a new builder for creating a function. + * + * @param name function signature name + */ + public static Builder builder(String name) { + return new Builder(name); } /** - * Returns the method name of this function + * Return the name of this function + */ + public String name() { + return name; + } + + /** + * Returns the method name of this function (e.g. as exposed by a server) */ public String methodName() { - return signature.getMethodName(); + return signatureDef.getMethodName(); } /** * Returns the names of the inputs of this function. */ public Set inputNames() { - return signature.getInputsMap().keySet(); + return signatureDef.getInputsMap().keySet(); } /** * Returns the names of the outputs of this function. */ public Set outputNames() { - return signature.getOutputsMap().keySet(); + return signatureDef.getOutputsMap().keySet(); } /** @@ -124,7 +201,7 @@ public Map> call(Map> arguments) final Session.Runner runner = session.runner(); - signature.getInputsMap().forEach((argName, t) -> { + signatureDef.getInputsMap().forEach((argName, t) -> { Tensor tensor = arguments.get(argName); if (tensor == null) { throw new IllegalArgumentException(String.format("Missing argument [%s]", argName)); @@ -132,7 +209,7 @@ public Map> call(Map> arguments) runner.feed(t.getName(), tensor); }); - Map outputToNode = signature.getOutputsMap(); + Map outputToNode = signatureDef.getOutputsMap(); outputToNode.values().forEach(t -> runner.fetch(t.getName())); List> resultTensors = runner.run(); @@ -166,48 +243,36 @@ public Map> call(Map> arguments) * in the function */ public Tensor call(Tensor tensor) throws IllegalArgumentException { - if (signature.getInputsCount() != 1) { + if (signatureDef.getInputsCount() != 1) { throw new IllegalArgumentException( - String.format("Function [%s] requires multiple inputs", signature.getMethodName())); + String.format("Function [%s] requires multiple inputs", signatureDef.getMethodName())); } - String inputNodeName = signature.getInputsMap().values().iterator().next().getName(); + String inputNodeName = signatureDef.getInputsMap().values().iterator().next().getName(); - if (signature.getOutputsCount() != 1) { + if (signatureDef.getOutputsCount() != 1) { throw new IllegalArgumentException( - String.format("Function [%s] has multiple outputs", signature.getMethodName())); + String.format("Function [%s] has multiple outputs", signatureDef.getMethodName())); } - String outputNodeName = signature.getOutputsMap().values().iterator().next().getName(); + String outputNodeName = signatureDef.getOutputsMap().values().iterator().next().getName(); return session.runner().feed(inputNodeName, tensor).fetch(outputNodeName).run().get(0); } - /** - * Returns the signature of this function - */ - public SignatureDef signature() { - return signature; + Session session() { + return session; } - @Override - public void close() { - session.close(); - graph.close(); + SignatureDef signatureDef() { + return signatureDef; } - FunctionGraph(SignatureDef signature, Graph graph) { - this.graph = graph; - this.session = new Session(graph); - this.signature = signature; - } + private final String name; + private final Session session; + private final SignatureDef signatureDef; - FunctionGraph(Session session, SignatureDef signature) { - this.graph = session.graph(); + FunctionGraph(String name, SignatureDef signatureDef, Session session) { + this.name = name; this.session = session; - this.signature = signature; + this.signatureDef = signatureDef; } - - private final Map signatures; - private final Graph graph; - private final Session session; - private final SignatureDef signature; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index 2efaf5e883e..10e80c14e1f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -27,10 +27,10 @@ import java.nio.file.Paths; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Function; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.PointerPointer; import org.bytedeco.javacpp.PointerScope; @@ -40,17 +40,13 @@ import org.tensorflow.internal.c_api.TF_Session; import org.tensorflow.internal.c_api.TF_SessionOptions; import org.tensorflow.internal.c_api.TF_Status; -import org.tensorflow.ndarray.Shape; import org.tensorflow.proto.framework.ConfigProto; -import org.tensorflow.proto.framework.DataType; import org.tensorflow.proto.framework.MetaGraphDef; import org.tensorflow.proto.framework.MetaGraphDef.MetaInfoDef; +import org.tensorflow.proto.framework.MetaGraphDefOrBuilder; import org.tensorflow.proto.framework.RunOptions; import org.tensorflow.proto.framework.SavedModel; import org.tensorflow.proto.framework.SignatureDef; -import org.tensorflow.proto.framework.TensorInfo; -import org.tensorflow.proto.framework.TensorShapeProto; -import org.tensorflow.proto.framework.TensorShapeProto.Dim; /** * SavedModelBundle represents a model loaded from storage. @@ -63,7 +59,7 @@ */ public class SavedModelBundle implements AutoCloseable { - public static final String DEFAULT_SIGNATURE_NAME = "serving_default"; + public static final String DEFAULT_TAG = "serve"; /** Options for loading a SavedModel. */ public static final class Loader { @@ -120,35 +116,71 @@ private Loader(String exportDir) { /** Options for exporting a SavedModel. */ public static final class Exporter { + /** + * Sets the set of tags that identify the specific graph in the saved model to save. + * + *

Note that only one graph per model can be saved right now using this API. + * + * @param tags the tags identifying the specific MetaGraphDef to save. + * @return this object + */ public Exporter withTags(String... tags) { this.tags.addAll(Arrays.asList(tags)); return this; } - public Exporter withFunction(FunctionGraph functionGraph) { - } - - public Exporter withSignature(SignatureDef signatureDef) { - return withSignature(DEFAULT_SIGNATURE_NAME, signatureDef); - } - - public Exporter withSignature(String signatureName, SignatureDef signature) { - metaGraphDefBuilder.putSignatureDef(signatureName, signature); + /** + * Save a function with this model. + * + *

The function carries a signature (i.e. a list of user-friendly input and outputs names to + * a graph) and a valid session to a graph to be saved in the model. + * + *

Note:Eventually, TensorFlow for Java will support the export of functions objects like the + * Python API does. But right now, only session-centric models are supported (i.e. models that + * has a single main graph and one or more signatures), like TensorFlow 1.x and estimators do. + * + *

Still, the actual Java API is "function-ready", meaning that the signatures exposed by the + * main graph are provided as `FunctionGraph` objects. Only functions based on the same graph + * can then be saved within a single model, or an exception will be thrown. + * + * @param function a function carrying a signature and a valid session to graph to be saved + * @return this object + * @throws IllegalArgumentException if a function with the same name has already been added to the model + */ + public Exporter function(FunctionGraph function) { + if (functions.containsKey(function.name())) { + throw new IllegalArgumentException("Function \"" + function.name() + "\" was already added to the model"); + } + functions.put(function.name(), function); + if (session == null) { + session = function.session(); + } else if (session != function.session()) { + throw new UnsupportedOperationException("Saving multiple functions with different graphs/sessions is not supported yet."); + } + metaGraphDefBuilder.putSignatureDef(function.name(), function.signatureDef()); return this; } - public void export(Session session) throws IOException { - Graph graph = session.graph(); + /** + * Save the model into the export directory. + * + * @throws IOException if saved model or variable state can be written on disk + */ + public void export() throws IOException { + if (functions.isEmpty() || session == null) { + throw new IllegalStateException("Model should contain at least one valid function"); + } if (tags.isEmpty()) { - tags.add("serve"); + tags.add(DEFAULT_TAG); } // It is imperative to retrieve the graphDef after the saverDef, as the former might add // new ops to the graph. - MetaGraphDef metaGraphDef = metaGraphDefBuilder + Graph graph = session.graph(); + MetaGraphDef.Builder metaGraphDef = metaGraphDefBuilder .setSaverDef(graph.saverDef()) .setGraphDef(graph.toGraphDef()) - .setMetaInfoDef(MetaInfoDef.newBuilder().addAllTags(tags)) - .build(); + .setMetaInfoDef(MetaInfoDef.newBuilder().addAllTags(tags)); + functions.forEach((k, f) -> metaGraphDef.putSignatureDef(k, f.signatureDef())); // Make sure saved model directories exist Path variableDir = Paths.get(exportDir, "variables"); @@ -170,8 +202,10 @@ public void export(Session session) throws IOException { } private final String exportDir; - private final MetaGraphDef.Builder metaGraphDefBuilder = MetaGraphDef.newBuilder(); private final List tags = new ArrayList<>(); + private final MetaGraphDef.Builder metaGraphDefBuilder = MetaGraphDef.newBuilder(); + private final Map functions = new HashMap<>(); + private Session session; } /** @@ -205,6 +239,14 @@ public static Loader loader(String exportDir) { return new Loader(exportDir); } + /** + * Export a saved model. + * + *

Returns a Exporter object for setting configuration options before actually + * saving the model. + * + * @param exportDir the directory path containing a saved model. + */ public static Exporter exporter(String exportDir) { return new Exporter(exportDir); } @@ -246,24 +288,41 @@ public Session session() { * saved model. */ public FunctionGraph function(String functionSignatureName) { - SignatureDef signature = metaGraphDef.getSignatureDefMap().get(functionSignatureName); - if (signature == null) { + FunctionGraph function = functions.get(functionSignatureName); + if (function == null) { throw new IllegalArgumentException( String.format("Function with signature [%s] not found", functionSignatureName)); } - return new FunctionGraph(session, signature); + return function; } /** - * Return the {@link FunctionGraph} corresponding to the default function signature of this model. + * Invokes the default function directly from this model. * - * @param functionSignatureName name of the {@code SignatureDef} in the saved model. - * @return TfFunction object that can be used to make calls to the tf.function - * @throws IllegalArgumentException if no function with the default signature name can be found in - * this saved model. + *

The default function selection is done based on the first of the following conditions that + * is true: + *

    + *
  • The function is the only signature available attached to the main graph of this saved model
  • + *
  • The function is mapped to the default signature name, which is "serving_default"
  • + *
+ * + *

Caller is responsible for closing all returned Tensors. + * + * @param arguments list of input tensors, mapped by their signature name + * @return list of output tensors, mapped by the signature name + * @throws IllegalArgumentException if no function can be selected by default */ - public FunctionGraph function() { - return function(DEFAULT_SIGNATURE_NAME); + public Map> call(Map> arguments) { + FunctionGraph function = null; + if (functions.size() == 1) { + function = functions.values().iterator().next(); + } else { + function = functions.get(FunctionGraph.DEFAULT_NAME); + } + if (function == null) { + throw new IllegalArgumentException("Cannot elect a default function for this model"); + } + return function.call(arguments); } /** @@ -272,13 +331,18 @@ public FunctionGraph function() { */ @Override public void close() { - functions.forEach((s, f) -> f.close()); + session.close(); + graph.close(); } + private final Graph graph; + private final Session session; private final MetaGraphDef metaGraphDef; private final Map functions; - private SavedModelBundle(MetaGraphDef metaGraphDef, Map functions) { + private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef, Map functions) { + this.graph = graph; + this.session = session; this.metaGraphDef = metaGraphDef; this.functions = functions; } @@ -295,16 +359,12 @@ private static SavedModelBundle fromHandle( final Graph graph = new Graph(graphHandle, metaGraphDef.getSaverDef()); final Session session = new Session(graph, sessionHandle); - // For each signature, we will create a separate function. To support cases where multiple - // signatures are attached to the same graph, each function instance will retain a reference - // to the underlying resource, so that they are freed only when the last function is released. + // For each signature definition, create a distinct function based on the main graph/session final Map functions = new HashMap<>(metaGraphDef.getSignatureDefCount()); metaGraphDef.getSignatureDefMap().forEach((signatureName, signatureDef) -> { - graphHandle.retainReference(); - sessionHandle.retainReference(); - functions.put(signatureName, new FunctionGraph(session, signatureDef)); + functions.put(signatureName, new FunctionGraph(signatureName, signatureDef, session)); }); - return new SavedModelBundle(metaGraphDef, functions); + return new SavedModelBundle(graph, session, metaGraphDef, functions); } private static SavedModelBundle load( diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TfFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TfFunction.java deleted file mode 100644 index 5dc5a128898..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TfFunction.java +++ /dev/null @@ -1,157 +0,0 @@ -/* - * Copyright 2020 The TensorFlow Authors. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.tensorflow; - -import com.google.protobuf.InvalidProtocolBufferException; - -import java.util.List; -import java.util.ListIterator; -import java.util.HashMap; -import java.util.Map; - -/** - * Invoke tf.function - * defined in a {@link SavedModelBundle}. - * - *

{@code
- * TfFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
- * Map> outputTensorMap = myFunction.call(inputTensorMap);
- * }
- * - */ -public class TfFunction { - - public TfFunction( - String functionSignatureName, - SavedModelBundle.SignatureToNodeName nameToNode, Session session) { - this.nameToNode = nameToNode; - this.session = session; - this.functionSignatureName = functionSignatureName; - } - - /** - * Invokes a tf.function. - * Caller is responsible for closing all Tensors. - * - * @param arguments map of input tensors - * @return map of output tensors - */ - public Map> call( - Map> arguments) throws IllegalArgumentException { - - Session.Runner runner = this.session.runner(); - - Map inputToNode = this.nameToNode.inputNameToNode(this.functionSignatureName); - - if (inputToNode == null) { - throw new IllegalArgumentException( - String.format("Function [%s] is missing input", this.functionSignatureName)); - } - - // Join arguments.key, inputToNodeName.key - for (Map.Entry entry: inputToNode.entrySet()) { - String argName = entry.getKey(); - Tensor tensor = arguments.get(argName); - - if (tensor == null) { - throw new IllegalArgumentException(String.format("Missing argument [%s]", argName)); - } - - // Node name in the tensorflow graph, corresponding to the tf.function argument - runner = runner.feed(entry.getValue(), tensor); - } - - Map outputToNode = this.nameToNode.outputNameToNode(this.functionSignatureName); - if (outputToNode == null) { - throw new IllegalArgumentException( - String.format("Function [%] is missing output", this.functionSignatureName)); - } - - for (String nodeName: outputToNode.values()) { - // Node names corresponding to the return value - runner = runner.fetch(nodeName); - } - - List> resultTensors = runner.run(); - ListIterator> resultTensorIter = resultTensors.listIterator(); - - Map> returnMap = new HashMap>(); - - // Use the output names as present in the signature definition - for (String nodeName: outputToNode.keySet()) { - returnMap.put(nodeName, resultTensorIter.next()); - } - - return returnMap; - } - - /** - * Invokes a tf.function. - * Caller is responsible for closing all Tensors. - * - * Throws IllegalArgumentException if there are multiple input or output parameters defined - * in the tf.function - * - * @param tensor input tensor - * @return output tensor - */ - public Tensor call(Tensor tensor) throws IllegalArgumentException { - Session.Runner runner = this.session.runner(); - - Map inputToNode = this.nameToNode.inputNameToNode(this.functionSignatureName); - - if (inputToNode == null) { - throw new IllegalArgumentException( - String.format("Function [%s] is missing input", this.functionSignatureName)); - } - - if (inputToNode.size() != 1) { - throw new IllegalArgumentException( - String.format("Function [%s] requires multiple inputs", this.functionSignatureName)); - } - - // Feed the single argument - for (Map.Entry entry: inputToNode.entrySet()) { - // Node name in the tensorflow graph, corresponding to the tf.function argument - runner = runner.feed(entry.getValue(), tensor); - } - - Map outputToNode = this.nameToNode.outputNameToNode(this.functionSignatureName); - if (outputToNode == null) { - throw new IllegalArgumentException( - String.format("Function [%] is missing output", this.functionSignatureName)); - } - - if (outputToNode.size() != 1) { - throw new IllegalArgumentException( - String.format("Function [%s] has multiple outputs", this.functionSignatureName)); - } - - // Fetch the single return tensor - for (String nodeName: outputToNode.values()) { - // Node names corresponding to the return value - runner = runner.fetch(nodeName); - } - - List> resultTensors = runner.run(); - - return resultTensors.get(0); - } - - private final Session session; - private final SavedModelBundle.SignatureToNodeName nameToNode; - private final String functionSignatureName; -} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/FunctionGraphTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/FunctionGraphTest.java new file mode 100644 index 00000000000..a798266dfc4 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/FunctionGraphTest.java @@ -0,0 +1,58 @@ +package org.tensorflow; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +import java.util.Collections; +import jdk.nashorn.internal.codegen.FunctionSignature; +import org.junit.jupiter.api.Test; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.op.math.Add; +import org.tensorflow.types.TFloat32; + +public class FunctionGraphTest { + + @Test + public void createFunctionFromGraph() { + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + Placeholder x = tf.placeholder(TFloat32.DTYPE); + Add y = tf.math.add(x, tf.constant(2.0f)); + try (Session s = new Session(g)) { + FunctionGraph function = FunctionGraph.builder().input("x", x).output("y", y).build(s); + try (Tensor xValue = TFloat32.scalarOf(10.0f)) { + // Call with explicit input/output names + try (Tensor yValue = function.call(Collections.singletonMap("x", xValue)) + .get("y") + .expect(TFloat32.DTYPE)) { + assertEquals(12.0f, yValue.data().getFloat()); + } + // Call with implicit single input/output names + try (Tensor yValue = function.call(xValue).expect(TFloat32.DTYPE)) { + assertEquals(12.0f, yValue.data().getFloat()); + } + } + } + } + } + + @Test + public void cannotCallFunctionAfterSessionIsClosed() { + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + Placeholder x = tf.placeholder(TFloat32.DTYPE); + Add y = tf.math.add(x, tf.constant(2.0f)); + FunctionGraph function; + try (Session s = new Session(g)) { + function = FunctionGraph.builder().input("x", x).output("y", y).build(s); + } + try (Tensor xValue = TFloat32.scalarOf(10.0f)) { + function.call(xValue); + fail(); + } catch (IllegalStateException e) { + // as expected + } + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index 3c9b40939f2..f8e63e0da59 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -20,12 +20,12 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; -import java.io.File; import java.io.IOException; import java.net.URISyntaxException; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import jdk.nashorn.internal.codegen.FunctionSignature; import org.junit.jupiter.api.Test; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.ndarray.FloatNdArray; @@ -107,8 +107,8 @@ public void export() throws IOException { s.run(init); FunctionGraph function = FunctionGraph.builder() - .addInput("input", x) - .addOutput("reducedSum", z) + .input("input", x) + .output("reducedSum", z) .build(s); // Call the graph and remember the result of computation for later @@ -119,8 +119,8 @@ public void export() throws IOException { // Export the model SavedModelBundle.exporter(testFolder.toString()) .withTags("test") - .withFunction(function) - .export(s); + .function(function) + .export(); } } assertTrue(Files.exists(testFolder.resolve(Paths.get("variables", "variables.index")))); @@ -133,11 +133,11 @@ public void export() throws IOException { assertNotNull(savedModel.metaGraphDef()); assertNotNull(savedModel.metaGraphDef().getSaverDef()); assertEquals(1, savedModel.metaGraphDef().getSignatureDefCount()); - assertEquals(SavedModelBundle.DEFAULT_SIGNATURE_NAME, + assertEquals(FunctionGraph.DEFAULT_NAME, savedModel.metaGraphDef().getSignatureDefMap().keySet().iterator().next()); SignatureDef signature = savedModel.metaGraphDef().getSignatureDefMap() - .get(SavedModelBundle.DEFAULT_SIGNATURE_NAME); + .get(FunctionGraph.DEFAULT_NAME); assertNotNull(signature); assertEquals(1, signature.getInputsCount()); assertEquals(1, signature.getOutputsCount()); @@ -153,13 +153,12 @@ public void export() throws IOException { assertNotNull(outputInfo); assertEquals(0, outputInfo.getTensorShape().getDimCount()); - FunctionGraph function = savedModel.function(); + FunctionGraph function = savedModel.function(FunctionGraph.DEFAULT_NAME); assertNotNull(function); assertEquals(1, function.inputNames().size()); assertEquals("input", function.inputNames().iterator().next()); assertEquals(1, function.outputNames().size()); assertEquals("reducedSum", function.outputNames().iterator().next()); - assertEquals(FunctionGraph.DEFAULT_METHOD_NAME, function.methodName()); // Call the saved model function and make sure it returns the same result as before try (Tensor xTensor = TFloat32.tensorOf(xValue); From 175d9e6676d360efc4832d51d96a75b8d7c09b7f Mon Sep 17 00:00:00 2001 From: Karl Lessard Date: Mon, 24 Aug 2020 11:40:15 -0400 Subject: [PATCH 4/6] Make FunctionGraph auto-closeable --- .../java/org/tensorflow/FunctionGraph.java | 280 +++++++++--------- .../java/org/tensorflow/SavedModelBundle.java | 21 +- .../main/java/org/tensorflow/Signature.java | 165 +++++++++++ .../org/tensorflow/FunctionGraphTest.java | 54 ++-- .../org/tensorflow/SavedModelBundleTest.java | 44 ++- 5 files changed, 361 insertions(+), 203 deletions(-) create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/FunctionGraph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/FunctionGraph.java index 60a6ce3e797..af16afbd2f7 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/FunctionGraph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/FunctionGraph.java @@ -19,40 +19,15 @@ import java.util.ListIterator; import java.util.HashMap; import java.util.Map; -import java.util.Set; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.proto.framework.DataType; +import java.util.function.Function; +import org.tensorflow.op.Ops; +import org.tensorflow.op.math.Sign; import org.tensorflow.proto.framework.SignatureDef; import org.tensorflow.proto.framework.TensorInfo; -import org.tensorflow.proto.framework.TensorShapeProto; -import org.tensorflow.proto.framework.TensorShapeProto.Dim; /** * A graph that can be invoked as a single function, with an input and output signature. * - *

Note that the lifetime of a function is coupled with the lifetime of its graph or session, i.e. - * the function will failed to be invoked after the graph or session is released, which ever comes - * first. e.g. - * - *

{@code
- * FunctionGraph function;
- * try (Graph g = new Graph()) {
- *   Ops tf = Ops.create(g);
- *   Placeholder x = tf.placeholder(TFloat32.DTYPE);
- *   Add y = tf.math.add(x, tf.constant(2.0f));
- *   try (Session s = new Session(s)) {
- *     function = FunctionGraph.builder("myFunction").input("x", x).output("y", y).build(s);
- *     try (Tensor xValue = TFloat32.scalarOf(10.0f);
- *          Tensor yValue = function.call(xValue).expect(TFloat32.DTYPE)) {
- *       assertEquals(12.0f, yValue.data().getFloat());
- *     }
- *   }
- * }
- * try (Tensor xValue = TFloat32.scalarOf(10.0f)) {
- *   function.call(xValue); // fails, graph has been closed
- * }
- * }
- * *

A function can also invoke a * tf.function * defined in a {@link SavedModelBundle}. @@ -62,130 +37,120 @@ * Map> outputTensorMap = myFunction.call(inputTensorMap); * } */ -public class FunctionGraph { - - /** The default signature name, when not provided */ - public static final String DEFAULT_NAME = "serving_default"; +public class FunctionGraph implements AutoCloseable { /** - * Builds a new function signature. + * Creates a function by building a new graph. + * + *

The {@code functionBuilder} must initialize the function graph from the provided + * {@link Ops} instance and return a valid signature that will be used to feed the input tensors + * and fetch the output tensors on execution. + * + *

The function will be the owner of the new graph and its resulting session. Therefore, + * the function must be enclosed properly with a try-with-resources block to guarantee that + * all native resources will be freed once the function is discarded. For example: + * + *

{@code
+   * public class MyModel {
+   *
+   *   public static Signature addTwo(Ops tf) {
+   *     Placeholder input = tf.placeholder(TFloat32.DTYPE);
+   *     Add output = tf.math.add(input, tf.constant(2.0f));
+   *     return Signature.builder("addTwo").input("x", input).output("y", output).build();
+   *   }
+   *
+   *   public static void main(String args[]) {
+   *     try (FunctionGraph function = FunctionGraph.create(MyModel::addTwo);
+   *         Tensor x = TFloat32.scalarOf(2.0f)) {
+   *       assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
+   *     }
+   *   }
+   * }
+   * }
+ * + * @param functionBuilder function builder + * @return the new function */ - public static class Builder { - - /** - * Register a tensor as an input of the function. - * - * @param inputName user-friendly name for this input tensor - * @param input input tensor - * @return this builder - */ - public Builder input(String inputName, Operand input) { - signatureBuilder.putInputs(inputName, toTensorInfo(input.asOutput())); - return this; - } - - /** - * Register a tensor as an output of the function. - * - * @param inputName user-friendly name for this input tensor - * @param input input tensor - * @return this builder - */ - public Builder output(String outputName, Operand output) { - signatureBuilder.putOutputs(outputName, toTensorInfo(output.asOutput())); - return this; - } - - /** - * Provide extensible name information enabling third-party users to mark a signature as - * supporting a particular method - * - * @param methodName method name - * @return this builder - */ - public Builder methodName(String methodName) { - signatureBuilder.setMethodName(methodName); - return this; - } - - /** - * Creates a function from a graph session. - * - *

The provided session will be used for running or saving this function. - * - * @param signature signature of the function - * @param session a graph session - * @return a function - */ - public FunctionGraph build(Session session) { - return new FunctionGraph(name, signatureBuilder.build(), session); - } - - private static TensorInfo toTensorInfo(Output operand) { - Shape shape = operand.shape(); - TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder(); - for (int i = 0; i < shape.numDimensions(); ++i) { - tensorShapeBuilder.addDim(Dim.newBuilder().setSize(shape.size(i))); - } - return TensorInfo.newBuilder() - .setDtype(DataType.forNumber(operand.dataType().nativeCode())) - .setTensorShape(tensorShapeBuilder) - .setName(operand.op().name() + ":" + operand.index()) - .build(); - } - - private final String name; - private final SignatureDef.Builder signatureBuilder = SignatureDef.newBuilder(); - - private Builder(String name) { - this.name = name; + public static FunctionGraph create(Function functionBuilder) { + Graph graph = new Graph(); + try { + Ops tf = Ops.create(graph); + Signature signature = functionBuilder.apply(tf); + return new FunctionGraph(signature, graph, new Session(graph), Ownership.GRAPH); + } catch (Exception e) { + graph.close(); + throw e; } } /** - * Returns a new builder for creating a function + * Create a function from a signature and an existing graph. * - *

"serving_default" will be used as the default function signature name. - */ - public static Builder builder() { - return new Builder(DEFAULT_NAME); - } - - /** - * Returns a new builder for creating a function. + *

The function will keep the ownership of the session used to run the graph but not + * the graph itself, meaning that the lifetime of the latter can extend beyond the scope + * of the function. For example: * - * @param name function signature name - */ - public static Builder builder(String name) { - return new Builder(name); - } - - /** - * Return the name of this function - */ - public String name() { - return name; - } - - /** - * Returns the method name of this function (e.g. as exposed by a server) + *

{@code
+   * try (Graph g = new Graph()) {
+   *   Placeholder input = tf.placeholder(TFloat32.DTYPE);
+   *   Add output = tf.math.add(input, tf.constant(2.0f));
+   *   Signature signature = Signature.builder().input("x", input).output("y", output).build();
+   *
+   *   try (FunctionGraph f = FunctionGraph.create(signature, g);
+   *       Tensor x = TFloat32.scalarOf(2.0f)) {
+   *     assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
+   *   }
+   *   // Graph g is still valid at this point
+   * }
+   * }
+ * + * @param signature signature of the function to create + * @param graph a valid and initialized graph + * @return a new function */ - public String methodName() { - return signatureDef.getMethodName(); + public static FunctionGraph create(Signature signature, Graph graph) { + return new FunctionGraph(signature, graph, new Session(graph), Ownership.SESSION); } /** - * Returns the names of the inputs of this function. + * Create a function from a signature and a valid graph session. + * + *

The function will not own the session nor its graph, meaning that their lifetime + * can extend beyond the scope of the function. Therefore the function does not need to be + * closed after its usage. For example: + * + *

{@code
+   * try (Graph g = new Graph()) {
+   *   Placeholder input = tf.placeholder(TFloat32.DTYPE);
+   *   Add output = tf.math.add(input, tf.constant(2.0f));
+   *   Signature signature = Signature.builder().input("x", input).output("y", output).build();
+   *
+   *   try (Session s = new Session(g)) {
+   *     // Auto-closing the function just as an example but this is not required since it has
+   *     // no effect
+   *     try (FunctionGraph f = FunctionGraph.create(signature, s);
+   *         Tensor t = TFloat32.scalarOf(2.0f)) {
+   *       assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
+   *     }
+   *     // Session s is still valid at this point
+   *   }
+   *   // Graph g is still valid at this point
+   * }
+   * }
+ * + * @param signature signature of the function to create + * @param graph a valid session to an initialized graph + * @return a new function */ - public Set inputNames() { - return signatureDef.getInputsMap().keySet(); + public static FunctionGraph create(Signature signature, Session session) { + return new FunctionGraph(signature, session.graph(), session, Ownership.NONE); } /** - * Returns the names of the outputs of this function. + * Returns the signature of this function */ - public Set outputNames() { - return signatureDef.getOutputsMap().keySet(); + public Signature signature() { + return signature; } /** @@ -199,6 +164,7 @@ public Set outputNames() { public Map> call(Map> arguments) throws IllegalArgumentException { + final SignatureDef signatureDef = signature.asSignatureDef(); final Session.Runner runner = session.runner(); signatureDef.getInputsMap().forEach((argName, t) -> { @@ -243,6 +209,8 @@ public Map> call(Map> arguments) * in the function */ public Tensor call(Tensor tensor) throws IllegalArgumentException { + final SignatureDef signatureDef = signature.asSignatureDef(); + if (signatureDef.getInputsCount() != 1) { throw new IllegalArgumentException( String.format("Function [%s] requires multiple inputs", signatureDef.getMethodName())); @@ -258,21 +226,49 @@ public Tensor call(Tensor tensor) throws IllegalArgumentException { return session.runner().feed(inputNodeName, tensor).fetch(outputNodeName).run().get(0); } - Session session() { + /** + * Returns the session used to execute the graph when calling this function + * + *

In general, a user does not need to handle directly the session of a function and rely + * on {@link #call(Map)} to execute the graph instead. But in some cases, direct access to + * the session might be necessary, as it allows more running options. + * + * @return the function session + */ + public Session session() { return session; } - SignatureDef signatureDef() { - return signatureDef; + /** + * Returns the graph of this function + */ + public Graph graph() { + return graph; + } + + @Override + public void close() { + if (ownership != Ownership.NONE) { + session.close(); + if (ownership == Ownership.GRAPH) { + graph.close(); + } + } + } + + private enum Ownership { + GRAPH, SESSION, NONE; } - private final String name; + private final Graph graph; private final Session session; - private final SignatureDef signatureDef; + private final Signature signature; + private final Ownership ownership; - FunctionGraph(String name, SignatureDef signatureDef, Session session) { - this.name = name; + FunctionGraph(Signature signature, Graph graph, Session session, Ownership ownership) { + this.graph = graph; this.session = session; - this.signatureDef = signatureDef; + this.signature = signature; + this.ownership = ownership; } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index 10e80c14e1f..e1b81006104 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -148,16 +148,17 @@ public Exporter withTags(String... tags) { * @throws IllegalArgumentException if a function with the same name has already been added to the model */ public Exporter function(FunctionGraph function) { - if (functions.containsKey(function.name())) { - throw new IllegalArgumentException("Function \"" + function.name() + "\" was already added to the model"); + Signature signature = function.signature(); + if (functions.containsKey(signature.name())) { + throw new IllegalArgumentException("Function \"" + signature.name() + "\" was already added to the model"); } - functions.put(function.name(), function); + functions.put(signature.name(), function); if (session == null) { session = function.session(); } else if (session != function.session()) { throw new UnsupportedOperationException("Saving multiple functions with different graphs/sessions is not supported yet."); } - metaGraphDefBuilder.putSignatureDef(function.name(), function.signatureDef()); + metaGraphDefBuilder.putSignatureDef(signature.name(), signature.asSignatureDef()); return this; } @@ -180,7 +181,7 @@ public void export() throws IOException { .setSaverDef(graph.saverDef()) .setGraphDef(graph.toGraphDef()) .setMetaInfoDef(MetaInfoDef.newBuilder().addAllTags(tags)); - functions.forEach((k, f) -> metaGraphDef.putSignatureDef(k, f.signatureDef())); + functions.forEach((k, f) -> metaGraphDef.putSignatureDef(k, f.signature().asSignatureDef())); // Make sure saved model directories exist Path variableDir = Paths.get(exportDir, "variables"); @@ -317,7 +318,7 @@ public Map> call(Map> arguments) { if (functions.size() == 1) { function = functions.values().iterator().next(); } else { - function = functions.get(FunctionGraph.DEFAULT_NAME); + function = functions.get(Signature.DEFAULT_NAME); } if (function == null) { throw new IllegalArgumentException("Cannot elect a default function for this model"); @@ -359,10 +360,14 @@ private static SavedModelBundle fromHandle( final Graph graph = new Graph(graphHandle, metaGraphDef.getSaverDef()); final Session session = new Session(graph, sessionHandle); - // For each signature definition, create a distinct function based on the main graph/session + // Create a separate function for each signature of the main graph. + // Note that the saved model will remain the owner of the graph and the session, meaning + // that the functions do not need to be closed by the user and if it does, it should have + // no effect. final Map functions = new HashMap<>(metaGraphDef.getSignatureDefCount()); metaGraphDef.getSignatureDefMap().forEach((signatureName, signatureDef) -> { - functions.put(signatureName, new FunctionGraph(signatureName, signatureDef, session)); + Signature signature = new Signature(signatureName, signatureDef); + functions.put(signatureName, FunctionGraph.create(signature, session)); }); return new SavedModelBundle(graph, session, metaGraphDef, functions); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java new file mode 100644 index 00000000000..e7d0ef9d319 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java @@ -0,0 +1,165 @@ +/* + * Copyright 2020 The TensorFlow Authors. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.tensorflow; + +import java.util.HashMap; +import java.util.List; +import java.util.ListIterator; +import java.util.Map; +import java.util.Set; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.proto.framework.DataType; +import org.tensorflow.proto.framework.SignatureDef; +import org.tensorflow.proto.framework.TensorInfo; +import org.tensorflow.proto.framework.TensorShapeProto; +import org.tensorflow.proto.framework.TensorShapeProto.Dim; + +/** + * Describe the inputs and outputs of an executable entity, such as a {@link FunctionGraph}, among + * other useful metadata. + */ +public class Signature { + + /** The default signature name, when not provided */ + public static final String DEFAULT_NAME = "serving_default"; + + /** + * Builds a new function signature. + */ + public static class Builder { + + /** + * Register a tensor as an input of the function. + * + * @param inputName user-friendly name for this input tensor + * @param input input tensor + * @return this builder + */ + public Builder input(String inputName, Operand input) { + signatureBuilder.putInputs(inputName, toTensorInfo(input.asOutput())); + return this; + } + + /** + * Register a tensor as an output of the function. + * + * @param inputName user-friendly name for this input tensor + * @param input input tensor + * @return this builder + */ + public Builder output(String outputName, Operand output) { + signatureBuilder.putOutputs(outputName, toTensorInfo(output.asOutput())); + return this; + } + + /** + * Provide extensible name information enabling third-party users to mark a signature as + * supporting a particular method + * + * @param methodName method name + * @return this builder + */ + public Builder methodName(String methodName) { + signatureBuilder.setMethodName(methodName); + return this; + } + + /** + * Returns a signature from the provided data. + */ + public Signature build() { + return new Signature(name, signatureBuilder.build()); + } + + private static TensorInfo toTensorInfo(Output operand) { + Shape shape = operand.shape(); + TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder(); + for (int i = 0; i < shape.numDimensions(); ++i) { + tensorShapeBuilder.addDim(Dim.newBuilder().setSize(shape.size(i))); + } + return TensorInfo.newBuilder() + .setDtype(DataType.forNumber(operand.dataType().nativeCode())) + .setTensorShape(tensorShapeBuilder) + .setName(operand.op().name() + ":" + operand.index()) + .build(); + } + + private final String name; + private final SignatureDef.Builder signatureBuilder = SignatureDef.newBuilder(); + + private Builder(String name) { + this.name = name; + } + } + + /** + * Returns a new builder for creating a signature + * + *

"serving_default" will be used as the default signature name. + */ + public static Builder builder() { + return new Builder(DEFAULT_NAME); + } + + /** + * Returns a new builder for creating a signature. + * + * @param name signature name + */ + public static Builder builder(String name) { + return new Builder(name); + } + + /** + * Return the name of this signature + */ + public String name() { + return name; + } + + /** + * Returns the method name of this signature (e.g. as exposed by TF serving) + */ + public String methodName() { + return signatureDef.getMethodName(); + } + + /** + * Returns the names of the inputs in this signature + */ + public Set inputNames() { + return signatureDef.getInputsMap().keySet(); + } + + /** + * Returns the names of the outputs in this signature + */ + public Set outputNames() { + return signatureDef.getOutputsMap().keySet(); + } + + SignatureDef asSignatureDef() { + return signatureDef; + } + + private final String name; + private final SignatureDef signatureDef; + + Signature(String name, SignatureDef signatureDef) { + this.name = name; + this.signatureDef = signatureDef; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/FunctionGraphTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/FunctionGraphTest.java index a798266dfc4..eb928b843d4 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/FunctionGraphTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/FunctionGraphTest.java @@ -4,7 +4,6 @@ import static org.junit.jupiter.api.Assertions.fail; import java.util.Collections; -import jdk.nashorn.internal.codegen.FunctionSignature; import org.junit.jupiter.api.Test; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Placeholder; @@ -13,45 +12,40 @@ public class FunctionGraphTest { + private static Signature addTwo(Ops tf) { + Placeholder input = tf.placeholder(TFloat32.DTYPE); + Add output = tf.math.add(input, tf.constant(2.0f)); + return Signature.builder("addTwo").input("x", input).output("y", output).build(); + } + + @Test + public void createFunctionFromScratch() { + try (FunctionGraph f = FunctionGraph.create(FunctionGraphTest::addTwo); + Tensor x = TFloat32.scalarOf(2.0f)) { + assertEquals(4.0f, f.call(x).expect(TFloat32.DTYPE).data().getFloat()); + } + } + @Test public void createFunctionFromGraph() { try (Graph g = new Graph()) { - Ops tf = Ops.create(g); - Placeholder x = tf.placeholder(TFloat32.DTYPE); - Add y = tf.math.add(x, tf.constant(2.0f)); - try (Session s = new Session(g)) { - FunctionGraph function = FunctionGraph.builder().input("x", x).output("y", y).build(s); - try (Tensor xValue = TFloat32.scalarOf(10.0f)) { - // Call with explicit input/output names - try (Tensor yValue = function.call(Collections.singletonMap("x", xValue)) - .get("y") - .expect(TFloat32.DTYPE)) { - assertEquals(12.0f, yValue.data().getFloat()); - } - // Call with implicit single input/output names - try (Tensor yValue = function.call(xValue).expect(TFloat32.DTYPE)) { - assertEquals(12.0f, yValue.data().getFloat()); - } - } + Signature signature = addTwo(Ops.create(g)); + try (FunctionGraph f = FunctionGraph.create(signature, g); + Tensor x = TFloat32.scalarOf(2.0f)) { + assertEquals(4.0f, f.call(x).expect(TFloat32.DTYPE).data().getFloat()); } } } @Test - public void cannotCallFunctionAfterSessionIsClosed() { + public void createFunctionFromSession() { try (Graph g = new Graph()) { - Ops tf = Ops.create(g); - Placeholder x = tf.placeholder(TFloat32.DTYPE); - Add y = tf.math.add(x, tf.constant(2.0f)); - FunctionGraph function; + Signature signature = addTwo(Ops.create(g)); try (Session s = new Session(g)) { - function = FunctionGraph.builder().input("x", x).output("y", y).build(s); - } - try (Tensor xValue = TFloat32.scalarOf(10.0f)) { - function.call(xValue); - fail(); - } catch (IllegalStateException e) { - // as expected + try (FunctionGraph f = FunctionGraph.create(signature, s); + Tensor x = TFloat32.scalarOf(2.0f)) { + assertEquals(4.0f, f.call(x).expect(TFloat32.DTYPE).data().getFloat()); + } } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index f8e63e0da59..e7949f14d4f 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -36,6 +36,7 @@ import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.ReduceSum; import org.tensorflow.op.core.Variable; +import org.tensorflow.op.math.Sign; import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.RunOptions; import org.tensorflow.proto.framework.SignatureDef; @@ -102,24 +103,20 @@ public void export() throws IOException { .variable(tf.random.randomUniform(tf.constant(xyShape), TFloat32.DTYPE)); ReduceSum z = tf.reduceSum(tf.math.add(x, y), tf.array(0, 1)); Init init = tf.init(); + Signature signature = Signature.builder().input("input", x).output("reducedSum", z).build(); - try (Session s = new Session(g)) { - s.run(init); - - FunctionGraph function = FunctionGraph.builder() - .input("input", x) - .output("reducedSum", z) - .build(s); + try (FunctionGraph f = FunctionGraph.create(signature, g)) { + f.session().run(init); // Call the graph and remember the result of computation for later try (Tensor xTensor = TFloat32.tensorOf(xValue); - Tensor zTensor = function.call(xTensor).expect(TFloat32.DTYPE)) { + Tensor zTensor = f.call(xTensor).expect(TFloat32.DTYPE)) { reducedSum = zTensor.data().getFloat(); } // Export the model SavedModelBundle.exporter(testFolder.toString()) .withTags("test") - .function(function) + .function(f) .export(); } } @@ -133,33 +130,34 @@ public void export() throws IOException { assertNotNull(savedModel.metaGraphDef()); assertNotNull(savedModel.metaGraphDef().getSaverDef()); assertEquals(1, savedModel.metaGraphDef().getSignatureDefCount()); - assertEquals(FunctionGraph.DEFAULT_NAME, + assertEquals(Signature.DEFAULT_NAME, savedModel.metaGraphDef().getSignatureDefMap().keySet().iterator().next()); - SignatureDef signature = savedModel.metaGraphDef().getSignatureDefMap() - .get(FunctionGraph.DEFAULT_NAME); + FunctionGraph function = savedModel.function(Signature.DEFAULT_NAME); + assertNotNull(function); + + Signature signature = function.signature(); assertNotNull(signature); - assertEquals(1, signature.getInputsCount()); - assertEquals(1, signature.getOutputsCount()); + assertEquals(1, signature.inputNames().size()); + assertEquals("input", signature.inputNames().iterator().next()); + assertEquals(1, signature.outputNames().size()); + assertEquals("reducedSum", signature.outputNames().iterator().next()); - TensorInfo inputInfo = signature.getInputsMap().get("input"); + SignatureDef signatureDef = signature.asSignatureDef(); + assertEquals(1, signatureDef.getInputsCount()); + assertEquals(1, signatureDef.getOutputsCount()); + + TensorInfo inputInfo = signatureDef.getInputsMap().get("input"); assertNotNull(inputInfo); assertEquals(xyShape.numDimensions(), inputInfo.getTensorShape().getDimCount()); for (int i = 0; i < xyShape.numDimensions(); ++i) { assertEquals(xyShape.size(i), inputInfo.getTensorShape().getDim(i).getSize()); } - TensorInfo outputInfo = signature.getOutputsMap().get("reducedSum"); + TensorInfo outputInfo = signatureDef.getOutputsMap().get("reducedSum"); assertNotNull(outputInfo); assertEquals(0, outputInfo.getTensorShape().getDimCount()); - FunctionGraph function = savedModel.function(FunctionGraph.DEFAULT_NAME); - assertNotNull(function); - assertEquals(1, function.inputNames().size()); - assertEquals("input", function.inputNames().iterator().next()); - assertEquals(1, function.outputNames().size()); - assertEquals("reducedSum", function.outputNames().iterator().next()); - // Call the saved model function and make sure it returns the same result as before try (Tensor xTensor = TFloat32.tensorOf(xValue); Tensor zTensor = function.call(xTensor).expect(TFloat32.DTYPE)) { From 99a74503f7a97e4615f754b916b9d7fe95add1dc Mon Sep 17 00:00:00 2001 From: Karl Lessard Date: Thu, 27 Aug 2020 12:07:21 -0400 Subject: [PATCH 5/6] Rename `FunctionGraph` to `ConcreteFunction` --- ...nctionGraph.java => ConcreteFunction.java} | 42 +++--- .../src/main/java/org/tensorflow/Graph.java | 28 ++-- .../java/org/tensorflow/SavedModelBundle.java | 51 ++++---- .../src/main/java/org/tensorflow/Session.java | 10 ++ .../main/java/org/tensorflow/Signature.java | 6 +- .../org/tensorflow/ConcreteFunctionTest.java | 122 ++++++++++++++++++ .../org/tensorflow/FunctionGraphTest.java | 52 -------- .../org/tensorflow/SavedModelBundleTest.java | 8 +- 8 files changed, 204 insertions(+), 115 deletions(-) rename tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/{FunctionGraph.java => ConcreteFunction.java} (85%) create mode 100644 tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java delete mode 100644 tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/FunctionGraphTest.java diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/FunctionGraph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java similarity index 85% rename from tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/FunctionGraph.java rename to tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index af16afbd2f7..c76b62d5486 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/FunctionGraph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -15,13 +15,13 @@ */ package org.tensorflow; +import java.io.IOException; import java.util.List; import java.util.ListIterator; import java.util.HashMap; import java.util.Map; import java.util.function.Function; import org.tensorflow.op.Ops; -import org.tensorflow.op.math.Sign; import org.tensorflow.proto.framework.SignatureDef; import org.tensorflow.proto.framework.TensorInfo; @@ -33,11 +33,11 @@ * defined in a {@link SavedModelBundle}. * *

{@code
- * FunctionGraph myFunction = savedModelBundle.function("myFunctionSignatureName");
+ * ConcreteFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
  * Map> outputTensorMap = myFunction.call(inputTensorMap);
  * }
*/ -public class FunctionGraph implements AutoCloseable { +public class ConcreteFunction implements AutoCloseable { /** * Creates a function by building a new graph. @@ -60,7 +60,7 @@ public class FunctionGraph implements AutoCloseable { * } * * public static void main(String args[]) { - * try (FunctionGraph function = FunctionGraph.create(MyModel::addTwo); + * try (ConcreteFunction function = ConcreteFunction.create(MyModel::addTwo); * Tensor x = TFloat32.scalarOf(2.0f)) { * assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat()); * } @@ -71,12 +71,12 @@ public class FunctionGraph implements AutoCloseable { * @param functionBuilder function builder * @return the new function */ - public static FunctionGraph create(Function functionBuilder) { + public static ConcreteFunction create(Function functionBuilder) { Graph graph = new Graph(); try { Ops tf = Ops.create(graph); Signature signature = functionBuilder.apply(tf); - return new FunctionGraph(signature, graph, new Session(graph), Ownership.GRAPH); + return new ConcreteFunction(signature, graph, new Session(graph), Ownership.GRAPH_AND_SESSION); } catch (Exception e) { graph.close(); throw e; @@ -96,7 +96,7 @@ public static FunctionGraph create(Function functionBuilder) { * Add output = tf.math.add(input, tf.constant(2.0f)); * Signature signature = Signature.builder().input("x", input).output("y", output).build(); * - * try (FunctionGraph f = FunctionGraph.create(signature, g); + * try (ConcreteFunction f = ConcreteFunction.create(signature, g); * Tensor x = TFloat32.scalarOf(2.0f)) { * assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat()); * } @@ -108,8 +108,8 @@ public static FunctionGraph create(Function functionBuilder) { * @param graph a valid and initialized graph * @return a new function */ - public static FunctionGraph create(Signature signature, Graph graph) { - return new FunctionGraph(signature, graph, new Session(graph), Ownership.SESSION); + public static ConcreteFunction create(Signature signature, Graph graph) { + return new ConcreteFunction(signature, graph, new Session(graph), Ownership.SESSION_ONLY); } /** @@ -128,7 +128,7 @@ public static FunctionGraph create(Signature signature, Graph graph) { * try (Session s = new Session(g)) { * // Auto-closing the function just as an example but this is not required since it has * // no effect - * try (FunctionGraph f = FunctionGraph.create(signature, s); + * try (ConcreteFunction f = ConcreteFunction.create(signature, s); * Tensor t = TFloat32.scalarOf(2.0f)) { * assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat()); * } @@ -142,8 +142,8 @@ public static FunctionGraph create(Signature signature, Graph graph) { * @param graph a valid session to an initialized graph * @return a new function */ - public static FunctionGraph create(Signature signature, Session session) { - return new FunctionGraph(signature, session.graph(), session, Ownership.NONE); + public static ConcreteFunction create(Signature signature, Session session) { + return new ConcreteFunction(signature, session.graph(), session, Ownership.NONE); } /** @@ -226,6 +226,18 @@ public Tensor call(Tensor tensor) throws IllegalArgumentException { return session.runner().feed(inputNodeName, tensor).fetch(outputNodeName).run().get(0); } + /** + * Export this function as a saved model. + * + *

This method is convenient shortcut equivalent to + * {@code SavedModel.exporter(exportDir).withFunction(this).export()} + */ + public void save(String exportDir) throws IOException { + SavedModelBundle.exporter(exportDir) + .withFunction(this) + .export(); + } + /** * Returns the session used to execute the graph when calling this function * @@ -250,14 +262,14 @@ public Graph graph() { public void close() { if (ownership != Ownership.NONE) { session.close(); - if (ownership == Ownership.GRAPH) { + if (ownership == Ownership.GRAPH_AND_SESSION) { graph.close(); } } } private enum Ownership { - GRAPH, SESSION, NONE; + GRAPH_AND_SESSION, SESSION_ONLY, NONE; } private final Graph graph; @@ -265,7 +277,7 @@ private enum Ownership { private final Signature signature; private final Ownership ownership; - FunctionGraph(Signature signature, Graph graph, Session session, Ownership ownership) { + ConcreteFunction(Signature signature, Graph graph, Session session, Ownership ownership) { this.graph = graph; this.session = session; this.signature = signature; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index 071a28bb6c9..f365956dff1 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -301,17 +301,6 @@ public Output[] addGradients(Output y, Output[] x) { return addGradients(null, new Output[] {y}, x, null); } - public SaverDef saverDef() { - if (saverDef == null) { - synchronized (this) { - if (saverDef == null) { - saverDef = addVariableSaver(this); - } - } - } - return saverDef; - } - /** * Used to instantiate an abstract class which overrides the buildSubgraph method to build a * conditional or body subgraph for a while loop. After Java 8, this can alternatively be used to @@ -427,6 +416,23 @@ public Output[] whileLoop( } } + /** + * Return the {@link SaverDef} instance used to save the state of all variables present in + * this graph. + * + *

On the first call of this method, all nodes necessary to save and restore the state of the + * variables are added to the graph. Consequently, any variables that are added to the graph after + * this call could not be saved nor restored using this {@link SaverDef}. + * + * @return a {@link SaverDef} instance + */ + synchronized SaverDef saverDef() { + if (saverDef == null) { + saverDef = addVariableSaver(this); + } + return saverDef; + } + private final Object nativeHandleLock = new Object(); private TF_Graph nativeHandle; private int refcount = 0; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index e1b81006104..473e7a22e33 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -30,7 +30,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.Function; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.PointerPointer; import org.bytedeco.javacpp.PointerScope; @@ -43,10 +42,8 @@ import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.MetaGraphDef; import org.tensorflow.proto.framework.MetaGraphDef.MetaInfoDef; -import org.tensorflow.proto.framework.MetaGraphDefOrBuilder; import org.tensorflow.proto.framework.RunOptions; import org.tensorflow.proto.framework.SavedModel; -import org.tensorflow.proto.framework.SignatureDef; /** * SavedModelBundle represents a model loaded from storage. @@ -130,24 +127,24 @@ public Exporter withTags(String... tags) { } /** - * Save a function with this model. + * Save a concrete function of this model. * - *

The function carries a signature (i.e. a list of user-friendly input and outputs names to - * a graph) and a valid session to a graph to be saved in the model. + *

The concrete function carries a signature (i.e. a list of user-friendly input and outputs + * names to a graph) and a valid session to a graph to be saved in the model. * - *

Note:Eventually, TensorFlow for Java will support the export of functions objects like the - * Python API does. But right now, only session-centric models are supported (i.e. models that - * has a single main graph and one or more signatures), like TensorFlow 1.x and estimators do. + *

Note:Eventually, TensorFlow for Java will support the export of functions objects like + * the Python API does but right now, only session-centric models are supported (i.e. models that + * has a single main graph and one or more signatures). These models are compatible with those + * exported by TensorFlow 1.x or by TensorFlow 2.x estimators. * - *

Still, the actual Java API is "function-ready", meaning that the signatures exposed by the - * main graph are provided as `FunctionGraph` objects. Only functions based on the same graph - * can then be saved within a single model, or an exception will be thrown. + *

Therefore, all functions exported in a model should share the same session at the moment + * or an exception will be thrown. * - * @param function a function carrying a signature and a valid session to graph to be saved + * @param function a function carrying a signature and a valid session to the graph to be saved * @return this object * @throws IllegalArgumentException if a function with the same name has already been added to the model */ - public Exporter function(FunctionGraph function) { + public Exporter withFunction(ConcreteFunction function) { Signature signature = function.signature(); if (functions.containsKey(signature.name())) { throw new IllegalArgumentException("Function \"" + signature.name() + "\" was already added to the model"); @@ -175,7 +172,7 @@ public void export() throws IOException { tags.add(DEFAULT_TAG); } // It is imperative to retrieve the graphDef after the saverDef, as the former might add - // new ops to the graph. + // new ops to the graph for saving and restoring the variables. Graph graph = session.graph(); MetaGraphDef.Builder metaGraphDef = metaGraphDefBuilder .setSaverDef(graph.saverDef()) @@ -187,10 +184,10 @@ public void export() throws IOException { Path variableDir = Paths.get(exportDir, "variables"); variableDir.toFile().mkdirs(); - // Save variables state, using the "variables-*" prefix + // Save the variables state session.save(variableDir.resolve("variables").toString()); - // Save graph + // Save the graph SavedModel savedModelDef = SavedModel.newBuilder().addMetaGraphs(metaGraphDef).build(); try (OutputStream file = new FileOutputStream(Paths.get(exportDir, "saved_model.pb").toString())) { @@ -205,7 +202,7 @@ public void export() throws IOException { private final String exportDir; private final List tags = new ArrayList<>(); private final MetaGraphDef.Builder metaGraphDefBuilder = MetaGraphDef.newBuilder(); - private final Map functions = new HashMap<>(); + private final Map functions = new HashMap<>(); private Session session; } @@ -276,10 +273,10 @@ public Session session() { } /** - * Return a {@link FunctionGraph} corresponding to the function signature. + * Return a {@link ConcreteFunction} corresponding to the function signature. * *

{@code
-   * FunctionGraph myFunction = savedModelBundle.function("myFunctionSignatureName");
+   * ConcreteFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
    * Map> outputTensorMap = myFunction.call(session, inputTensorMap);
    * }
* @@ -288,8 +285,8 @@ public Session session() { * @throws IllegalArgumentException if {@code functionSignatureName} is not found in this * saved model. */ - public FunctionGraph function(String functionSignatureName) { - FunctionGraph function = functions.get(functionSignatureName); + public ConcreteFunction function(String functionSignatureName) { + ConcreteFunction function = functions.get(functionSignatureName); if (function == null) { throw new IllegalArgumentException( String.format("Function with signature [%s] not found", functionSignatureName)); @@ -314,7 +311,7 @@ public FunctionGraph function(String functionSignatureName) { * @throws IllegalArgumentException if no function can be selected by default */ public Map> call(Map> arguments) { - FunctionGraph function = null; + ConcreteFunction function = null; if (functions.size() == 1) { function = functions.values().iterator().next(); } else { @@ -339,9 +336,9 @@ public void close() { private final Graph graph; private final Session session; private final MetaGraphDef metaGraphDef; - private final Map functions; + private final Map functions; - private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef, Map functions) { + private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef, Map functions) { this.graph = graph; this.session = session; this.metaGraphDef = metaGraphDef; @@ -364,10 +361,10 @@ private static SavedModelBundle fromHandle( // Note that the saved model will remain the owner of the graph and the session, meaning // that the functions do not need to be closed by the user and if it does, it should have // no effect. - final Map functions = new HashMap<>(metaGraphDef.getSignatureDefCount()); + final Map functions = new HashMap<>(metaGraphDef.getSignatureDefCount()); metaGraphDef.getSignatureDefMap().forEach((signatureName, signatureDef) -> { Signature signature = new Signature(signatureName, signatureDef); - functions.put(signatureName, FunctionGraph.create(signature, session)); + functions.put(signatureName, ConcreteFunction.create(signature, session)); }); return new SavedModelBundle(graph, session, metaGraphDef, functions); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 21d338ca765..6f6ee4e136f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -446,6 +446,16 @@ public void run(Op op) { runner().addTarget(op.op()).run(); } + /** + * Saves the actual state of the variables of this session's graph. + * + *

{@code prefix} is a path where the files containing the variables state will be saved, + * followed by a prefix for naming these files. For example, if {@code prefix} is set to + * mymodel/myvariables/variables, then the generated files will be located under + * mymodel/myvariables and named variables.data-*-of-* + * + * @param prefix + */ public void save(String prefix) { SaverDef saverDef = graph.saverDef(); runner() diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java index e7d0ef9d319..5ced9ae9c82 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java @@ -15,10 +15,6 @@ */ package org.tensorflow; -import java.util.HashMap; -import java.util.List; -import java.util.ListIterator; -import java.util.Map; import java.util.Set; import org.tensorflow.ndarray.Shape; import org.tensorflow.proto.framework.DataType; @@ -28,7 +24,7 @@ import org.tensorflow.proto.framework.TensorShapeProto.Dim; /** - * Describe the inputs and outputs of an executable entity, such as a {@link FunctionGraph}, among + * Describe the inputs and outputs of an executable entity, such as a {@link ConcreteFunction}, among * other useful metadata. */ public class Signature { diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java new file mode 100644 index 00000000000..ec2ebbacaee --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java @@ -0,0 +1,122 @@ +package org.tensorflow; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +import org.junit.jupiter.api.Test; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Init; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.op.math.Add; +import org.tensorflow.op.math.Sub; +import org.tensorflow.types.TFloat32; + +public class ConcreteFunctionTest { + + private static Signature plusFive(Ops tf) { + Placeholder input = tf.placeholder(TFloat32.DTYPE); + Add output = tf.math.add(input, tf.constant(5.0f)); + Init init = tf.init(); // for native resource management tests + return Signature.builder("plusFive").input("x", input).output("y", output).build(); + } + + private static Signature minusTwo(Ops tf) { + Placeholder input = tf.placeholder(TFloat32.DTYPE); + Sub output = tf.math.sub(input, tf.constant(2.0f)); + return Signature.builder("minusTwo").input("x", input).output("y", output).build(); + } + + @Test + public void createFunction() { + try (ConcreteFunction f = ConcreteFunction.create(ConcreteFunctionTest::plusFive); + Tensor x = TFloat32.scalarOf(3.0f)) { + assertEquals(8.0f, f.call(x).expect(TFloat32.DTYPE).data().getFloat()); + } + } + + @Test + public void createFunctionFromGraph() { + try (Graph g = new Graph()) { + Signature signature = plusFive(Ops.create(g)); + try (ConcreteFunction f = ConcreteFunction.create(signature, g); + Tensor x = TFloat32.scalarOf(3.0f)) { + assertEquals(8.0f, f.call(x).expect(TFloat32.DTYPE).data().getFloat()); + } + } + } + + @Test + public void createFunctionFromSession() { + try (Graph g = new Graph()) { + Signature signature = plusFive(Ops.create(g)); + try (Session s = new Session(g)) { + try (ConcreteFunction f = ConcreteFunction.create(signature, s); + Tensor x = TFloat32.scalarOf(3.0f)) { + assertEquals(8.0f, f.call(x).expect(TFloat32.DTYPE).data().getFloat()); + } + } + } + } + + @Test + public void chainFunctions() { + try (ConcreteFunction f1 = ConcreteFunction.create(ConcreteFunctionTest::plusFive); + ConcreteFunction f2 = ConcreteFunction.create(ConcreteFunctionTest::minusTwo); + Tensor x = TFloat32.scalarOf(3.0f)) { + assertEquals(6.0f, f2.call(f1.call(x)).expect(TFloat32.DTYPE).data().getFloat()); + } + } + + @Test + public void closingFunctionReleaseAllResourcesItOwns() { + Graph g; + Session s; + try (ConcreteFunction f = ConcreteFunction.create(ConcreteFunctionTest::plusFive)) { + g = f.graph(); + s = f.session(); + } + try { + s.run("Add"); + fail(); + } catch (IllegalStateException e) { + // as expected + } + try { + g.toGraphDef(); + fail(); + } catch (IllegalStateException e) { + // as expected + } + } + + @Test + public void closingFunctionCreatedFromGraphOnlyReleaseResourcesItOwns() { + try (Graph g = new Graph()) { + Signature signature = plusFive(Ops.create(g)); + Session s; + try (ConcreteFunction f = ConcreteFunction.create(signature, g)) { + s = f.session(); + } + try { + s.run(Init.DEFAULT_NAME); + fail(); + } catch (IllegalStateException e) { + // as expected + } + g.toGraphDef(); // check that graph is still valid + } + } + + @Test + public void closingFunctionCreatedFromSessionDoesNotReleaseResources() { + try (Graph g = new Graph()) { + Signature signature = plusFive(Ops.create(g)); + try (Session s = new Session(g)) { + try (ConcreteFunction f = ConcreteFunction.create(signature, s)) { + } + s.run(Init.DEFAULT_NAME); // check that session is still valid + } + g.toGraphDef(); // check that graph is still valid + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/FunctionGraphTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/FunctionGraphTest.java deleted file mode 100644 index eb928b843d4..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/FunctionGraphTest.java +++ /dev/null @@ -1,52 +0,0 @@ -package org.tensorflow; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.fail; - -import java.util.Collections; -import org.junit.jupiter.api.Test; -import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Placeholder; -import org.tensorflow.op.math.Add; -import org.tensorflow.types.TFloat32; - -public class FunctionGraphTest { - - private static Signature addTwo(Ops tf) { - Placeholder input = tf.placeholder(TFloat32.DTYPE); - Add output = tf.math.add(input, tf.constant(2.0f)); - return Signature.builder("addTwo").input("x", input).output("y", output).build(); - } - - @Test - public void createFunctionFromScratch() { - try (FunctionGraph f = FunctionGraph.create(FunctionGraphTest::addTwo); - Tensor x = TFloat32.scalarOf(2.0f)) { - assertEquals(4.0f, f.call(x).expect(TFloat32.DTYPE).data().getFloat()); - } - } - - @Test - public void createFunctionFromGraph() { - try (Graph g = new Graph()) { - Signature signature = addTwo(Ops.create(g)); - try (FunctionGraph f = FunctionGraph.create(signature, g); - Tensor x = TFloat32.scalarOf(2.0f)) { - assertEquals(4.0f, f.call(x).expect(TFloat32.DTYPE).data().getFloat()); - } - } - } - - @Test - public void createFunctionFromSession() { - try (Graph g = new Graph()) { - Signature signature = addTwo(Ops.create(g)); - try (Session s = new Session(g)) { - try (FunctionGraph f = FunctionGraph.create(signature, s); - Tensor x = TFloat32.scalarOf(2.0f)) { - assertEquals(4.0f, f.call(x).expect(TFloat32.DTYPE).data().getFloat()); - } - } - } - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index e7949f14d4f..9d2e380d0c5 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -25,7 +25,6 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; -import jdk.nashorn.internal.codegen.FunctionSignature; import org.junit.jupiter.api.Test; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.ndarray.FloatNdArray; @@ -36,7 +35,6 @@ import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.ReduceSum; import org.tensorflow.op.core.Variable; -import org.tensorflow.op.math.Sign; import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.RunOptions; import org.tensorflow.proto.framework.SignatureDef; @@ -105,7 +103,7 @@ public void export() throws IOException { Init init = tf.init(); Signature signature = Signature.builder().input("input", x).output("reducedSum", z).build(); - try (FunctionGraph f = FunctionGraph.create(signature, g)) { + try (ConcreteFunction f = ConcreteFunction.create(signature, g)) { f.session().run(init); // Call the graph and remember the result of computation for later @@ -116,7 +114,7 @@ public void export() throws IOException { // Export the model SavedModelBundle.exporter(testFolder.toString()) .withTags("test") - .function(f) + .withFunction(f) .export(); } } @@ -133,7 +131,7 @@ public void export() throws IOException { assertEquals(Signature.DEFAULT_NAME, savedModel.metaGraphDef().getSignatureDefMap().keySet().iterator().next()); - FunctionGraph function = savedModel.function(Signature.DEFAULT_NAME); + ConcreteFunction function = savedModel.function(Signature.DEFAULT_NAME); assertNotNull(function); Signature signature = function.signature(); From 1383a38006bf2ebbc08146bc76881db6f66d9860 Mon Sep 17 00:00:00 2001 From: Karl Lessard Date: Thu, 27 Aug 2020 21:57:44 -0400 Subject: [PATCH 6/6] Add more unit tests --- .../java/org/tensorflow/SavedModelBundle.java | 8 +- .../main/java/org/tensorflow/Signature.java | 32 ++-- .../org/tensorflow/ConcreteFunctionTest.java | 4 +- .../org/tensorflow/SavedModelBundleTest.java | 163 ++++++++++++++---- 4 files changed, 157 insertions(+), 50 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index 473e7a22e33..dc4c58ed478 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -93,10 +93,14 @@ public Loader withConfigProto(ConfigProto configProto) { /** * Sets the set of tags that identify the specific graph in the saved model to load. * + *

Has no effect if {@code tags} is null or empty + * * @param tags the tags identifying the specific MetaGraphDef to load. */ public Loader withTags(String... tags) { - this.tags = tags; + if (tags != null && tags.length > 0) { + this.tags = tags; + } return this; } @@ -105,7 +109,7 @@ private Loader(String exportDir) { } private String exportDir = null; - private String[] tags = null; + private String[] tags = {DEFAULT_TAG}; private ConfigProto configProto = null; private RunOptions runOptions = null; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java index 5ced9ae9c82..479a7c168ac 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java @@ -37,6 +37,19 @@ public class Signature { */ public static class Builder { + /** + * Sets the name of this signature. + * + *

When not set explicitly, the default value is {@link #DEFAULT_NAME}. + * + * @param name signature name + * @return this builder + */ + public Builder name(String name) { + this.name = name; + return this; + } + /** * Register a tensor as an input of the function. * @@ -93,30 +106,15 @@ private static TensorInfo toTensorInfo(Output operand) { .build(); } - private final String name; + private String name = DEFAULT_NAME; private final SignatureDef.Builder signatureBuilder = SignatureDef.newBuilder(); - - private Builder(String name) { - this.name = name; - } } /** * Returns a new builder for creating a signature - * - *

"serving_default" will be used as the default signature name. */ public static Builder builder() { - return new Builder(DEFAULT_NAME); - } - - /** - * Returns a new builder for creating a signature. - * - * @param name signature name - */ - public static Builder builder(String name) { - return new Builder(name); + return new Builder(); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java index ec2ebbacaee..ec32c69c7f4 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java @@ -17,13 +17,13 @@ private static Signature plusFive(Ops tf) { Placeholder input = tf.placeholder(TFloat32.DTYPE); Add output = tf.math.add(input, tf.constant(5.0f)); Init init = tf.init(); // for native resource management tests - return Signature.builder("plusFive").input("x", input).output("y", output).build(); + return Signature.builder().name("plusFive").input("x", input).output("y", output).build(); } private static Signature minusTwo(Ops tf) { Placeholder input = tf.placeholder(TFloat32.DTYPE); Sub output = tf.math.sub(input, tf.constant(2.0f)); - return Signature.builder("minusTwo").input("x", input).output("y", output).build(); + return Signature.builder().name("minusTwo").input("x", input).output("y", output).build(); } @Test diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index 9d2e380d0c5..43e9888176f 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -25,18 +25,22 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.Collections; import org.junit.jupiter.api.Test; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Identity; import org.tensorflow.op.core.Init; import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.ReduceSum; import org.tensorflow.op.core.Variable; +import org.tensorflow.op.math.Sign; import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.RunOptions; +import org.tensorflow.proto.framework.SavedModel; import org.tensorflow.proto.framework.SignatureDef; import org.tensorflow.proto.framework.TensorInfo; import org.tensorflow.types.TFloat32; @@ -89,34 +93,22 @@ public void loader() { } @Test - public void export() throws IOException { + public void exportFunctionWithVariables() throws IOException { Path testFolder = Files.createTempDirectory("tf-saved-model-export-test"); float reducedSum; FloatNdArray xValue = StdArrays.ndCopyOf(new float[][]{{0, 1, 2}, {3, 4, 5}}); Shape xyShape = Shape.of(2, 3L); - try (Graph g = new Graph()) { - Ops tf = Ops.create(g); - Placeholder x = tf.placeholder(TFloat32.DTYPE, Placeholder.shape(xyShape)); - Variable y = tf - .variable(tf.random.randomUniform(tf.constant(xyShape), TFloat32.DTYPE)); - ReduceSum z = tf.reduceSum(tf.math.add(x, y), tf.array(0, 1)); - Init init = tf.init(); - Signature signature = Signature.builder().input("input", x).output("reducedSum", z).build(); - - try (ConcreteFunction f = ConcreteFunction.create(signature, g)) { - f.session().run(init); - - // Call the graph and remember the result of computation for later - try (Tensor xTensor = TFloat32.tensorOf(xValue); - Tensor zTensor = f.call(xTensor).expect(TFloat32.DTYPE)) { - reducedSum = zTensor.data().getFloat(); - } - // Export the model - SavedModelBundle.exporter(testFolder.toString()) - .withTags("test") - .withFunction(f) - .export(); + try (ConcreteFunction f = ConcreteFunction.create(tf -> buildGraphWithVariables(tf, xyShape))) { + // Init variable state by running the Init operation directly + f.session().run(Init.DEFAULT_NAME); + + // Call the graph and remember the result of computation for later + try (Tensor xTensor = TFloat32.tensorOf(xValue); + Tensor zTensor = f.call(xTensor).expect(TFloat32.DTYPE)) { + reducedSum = zTensor.data().getFloat(); } + // Save/export the model (which is a single function in this case) + f.save(testFolder.toString()); } assertTrue(Files.exists(testFolder.resolve(Paths.get("variables", "variables.index")))); assertTrue(Files @@ -124,7 +116,8 @@ public void export() throws IOException { assertTrue(Files.exists(testFolder.resolve("saved_model.pb"))); // Reload the model just saved and validate its data - try (SavedModelBundle savedModel = SavedModelBundle.load(testFolder.toString(), "test")) { + try (SavedModelBundle savedModel = + SavedModelBundle.load(testFolder.toString(), SavedModelBundle.DEFAULT_TAG)) { assertNotNull(savedModel.metaGraphDef()); assertNotNull(savedModel.metaGraphDef().getSaverDef()); assertEquals(1, savedModel.metaGraphDef().getSignatureDefCount()); @@ -156,21 +149,133 @@ public void export() throws IOException { assertNotNull(outputInfo); assertEquals(0, outputInfo.getTensorShape().getDimCount()); - // Call the saved model function and make sure it returns the same result as before - try (Tensor xTensor = TFloat32.tensorOf(xValue); - Tensor zTensor = function.call(xTensor).expect(TFloat32.DTYPE)) { - assertEquals(reducedSum, zTensor.data().getFloat(), EPSILON); + try (Tensor xTensor = TFloat32.tensorOf(xValue)) { + // Call the saved model function and make sure it returns the same result as before + try (Tensor zTensor = function.call(xTensor).expect(TFloat32.DTYPE)) { + assertEquals(reducedSum, zTensor.data().getFloat(), EPSILON); + } + // Now call the same function directly from the model + try (Tensor zTensor = + savedModel.call(Collections.singletonMap("input", xTensor)).get("reducedSum").expect(TFloat32.DTYPE)) { + assertEquals(reducedSum, zTensor.data().getFloat(), EPSILON); + } + } + } + } + + @Test + public void exportMultipleFunctions() throws IOException { + Path testFolder = Files.createTempDirectory("tf-saved-model-export-test"); + float reducedSum; + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + Signature f1Signature = buildGraphWithVariables(tf, Shape.of(1, 1)); + Signature f2Signature = buildIdentityGraph(tf, "identity"); + try (Session s = new Session(g); + ConcreteFunction f1 = ConcreteFunction.create(f1Signature, s); + ConcreteFunction f2 = ConcreteFunction.create(f2Signature, s)) { + f1.session().run(Init.DEFAULT_NAME); + try (Tensor x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2})); + Tensor t = f1.call(x).expect(TFloat32.DTYPE)) { + reducedSum = t.data().getFloat(); + } + SavedModelBundle.exporter(testFolder.toString()) + .withFunction(f1) + .withFunction(f2) + .export(); + } + } + try (SavedModelBundle model = SavedModelBundle.load(testFolder.toString())) { + ConcreteFunction f1 = model.function(Signature.DEFAULT_NAME); + assertNotNull(f1); + try (Tensor x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2})); + Tensor t = f1.call(x).expect(TFloat32.DTYPE)) { + assertEquals(reducedSum, t.data().getFloat(), EPSILON); + } + ConcreteFunction f2 = model.function("identity"); + assertNotNull(f2); + try (Tensor x = TFloat32.scalarOf(10.0f); + Tensor t = f2.call(x).expect(TFloat32.DTYPE)) { + assertEquals(10.0f, t.data().getFloat(), 0.0f); + } + try { + model.function("NoSuchFunction"); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + } + } + + @Test + public void cannotExportMultipleFunctionsWithDifferentSessions() throws IOException { + Path testFolder = Files.createTempDirectory("tf-saved-model-export-test"); + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + Signature f1Signature = buildGraphWithVariables(tf, Shape.of(1, 1)); + Signature f2Signature = buildIdentityGraph(tf, "identity"); + try (ConcreteFunction f1 = ConcreteFunction.create(f1Signature, g); + ConcreteFunction f2 = ConcreteFunction.create(f2Signature, g)) { + f1.session().run(Init.DEFAULT_NAME); + try { + SavedModelBundle.exporter(testFolder.toString()) + .withFunction(f1) + .withFunction(f2) + .export(); + fail(); + } catch (UnsupportedOperationException e) { + // as expected + } + } + } + } + + @Test + public void cannotExportMultipleFunctionsWithSameSignatureName() throws IOException { + Path testFolder = Files.createTempDirectory("tf-saved-model-export-test"); + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + Signature f1Signature = buildGraphWithVariables(tf, Shape.of(1, 1)); + Signature f2Signature = buildIdentityGraph(tf, Signature.DEFAULT_NAME); + try (Session s = new Session(g); + ConcreteFunction f1 = ConcreteFunction.create(f1Signature, s); + ConcreteFunction f2 = ConcreteFunction.create(f2Signature, s)) { + f1.session().run(Init.DEFAULT_NAME); + try { + SavedModelBundle.exporter(testFolder.toString()) + .withFunction(f1) + .withFunction(f2) + .export(); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } } } } + private static Signature buildGraphWithVariables(Ops tf, Shape xShape) { + Placeholder x = tf.placeholder(TFloat32.DTYPE, Placeholder.shape(xShape)); + Variable y = tf + .variable(tf.random.randomUniform(tf.constant(xShape), TFloat32.DTYPE)); + ReduceSum z = tf.reduceSum(tf.math.add(x, y), tf.array(0, 1)); + Init init = tf.init(); + return Signature.builder().input("input", x).output("reducedSum", z).build(); + } + + private static Signature buildIdentityGraph(Ops tf, String signatureName) { + Placeholder x = tf.placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + Identity xprime = tf.identity(x); + return Signature.builder().name(signatureName).input("x", x).output("x", xprime).build(); + } + private static RunOptions sillyRunOptions() { return RunOptions.newBuilder() .setTraceLevel(RunOptions.TraceLevel.FULL_TRACE) .build(); } - public static ConfigProto sillyConfigProto() { + private static ConfigProto sillyConfigProto() { return ConfigProto.newBuilder() .setInterOpParallelismThreads(1) .setIntraOpParallelismThreads(1)