diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index 44b32df0894..4d07b678811 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -19,10 +19,10 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphToFunction; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashSet; +import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -66,7 +66,7 @@ * Map outputTensorMap = myFunction.call(inputTensorMap); * } */ -public class ConcreteFunction implements AutoCloseable, TensorFunction { +public final class ConcreteFunction implements AutoCloseable, TensorFunction { /** * Creates a function by building a new graph. @@ -220,11 +220,11 @@ public String toString() { public Map> call(Scope scope, Map> arguments) { List> inputList = new ArrayList<>(signature.inputNames().size()); - for (String inputName : signature().inputNames()) { + for (String inputName : signature.inputNames()) { if (!arguments.containsKey(inputName)) { throw new IllegalArgumentException( "Function " - + signature().methodName() + + signature.methodName() + " has parameter \"" + inputName + "\", but no argument was passed for it."); @@ -241,30 +241,30 @@ public Map> call(Scope scope, Map> argumen } List> outputList = - PartitionedCall.create( - scope, - inputList, - Arrays.stream(outputDtypes) - .map(x -> TensorTypeRegistry.find(x).type()) - .collect(Collectors.toList()), - this) - .output(); - - Map> namedOutputs = new LinkedHashMap<>(signature().outputNames().size()); - - List outputNames = new ArrayList<>(signature().outputNames()); - for (int i = 0; i < outputNames.size(); i++) { - String outputName = outputNames.get(i); - - if (i > outputList.size()) { - throw new IllegalStateException( - "Somehow, not all required outputs were returned from the function"); - } + PartitionedCall.create(scope, inputList, outputTypes, this).output(); + if (signature.outputNames().size() == 0) { + return Collections.emptyMap(); + } + if (signature.outputNames().size() == 1) { + return Collections.singletonMap(signature.outputNames().iterator().next(), outputList.get(0)); + } + if (outputList.size() < signature.outputNames().size()) { + throw new IllegalStateException( + "Somehow, not all required outputs were returned from the function" + + "(expected: " + + signature.outputNames().size() + + ", returned: " + + outputList.size() + + ")"); + } + Map> namedOutputs = new LinkedHashMap<>(signature.outputNames().size()); + Iterator outputNames = signature.outputNames().iterator(); + for (int i = 0; outputNames.hasNext(); i++) { + String outputName = outputNames.next(); Operand output = outputList.get(i); namedOutputs.put(outputName, output); } - return Collections.unmodifiableMap(namedOutputs); } @@ -291,10 +291,7 @@ public Operand call(Scope scope, Operand argument) { } String outputName = signatureDef.getOutputsMap().keySet().iterator().next(); - Map> inputMap = new LinkedHashMap<>(); - inputMap.put(inputName, argument); - - return call(scope, inputMap).get(outputName); + return call(scope, Collections.singletonMap(inputName, argument)).get(outputName); } @Override @@ -395,8 +392,7 @@ static ConcreteFunction fromNativeHandle( private final NativeFunction nativeFunction; private final PointerScope scope; private final Set dependencies; - private final DataType[] inputDtypes; - private final DataType[] outputDtypes; + private final List> outputTypes; /** All native functions should have deallocators registered */ private ConcreteFunction( @@ -405,7 +401,7 @@ private ConcreteFunction( this.nativeFunction = nativeFunction; this.dependencies = Collections.unmodifiableSet(dependencies); - if (this.signature.getInputs().size() + if (signature.getInputs().size() != nativeFunction.getFunctionDef().getSignature().getInputArgCount()) { throw new IllegalArgumentException( "Signature must have the same number of inputs as the native function. Expected " @@ -414,7 +410,7 @@ private ConcreteFunction( + this.signature.getInputs().size()); } - if (this.signature.getOutputs().size() + if (signature.getOutputs().size() != nativeFunction.getFunctionDef().getSignature().getOutputArgCount()) { throw new IllegalArgumentException( "New signature must have the same number of outputs as the native function. Expected " @@ -423,10 +419,8 @@ private ConcreteFunction( + this.signature.getOutputs().size()); } - inputDtypes = - this.signature.getInputs().values().stream().map(x -> x.dataType).toArray(DataType[]::new); - - List inputs = Arrays.asList(inputDtypes); + List inputs = + signature.getInputs().values().stream().map(x -> x.dataType).collect(Collectors.toList()); List nativeInputs = nativeFunction.getFunctionDef().getSignature().getInputArgList().stream() .map(ArgDef::getType) @@ -440,10 +434,8 @@ private ConcreteFunction( + inputs); } - outputDtypes = - signature().getOutputs().values().stream().map(x -> x.dataType).toArray(DataType[]::new); - - List outputs = Arrays.asList(outputDtypes); + List outputs = + signature.getOutputs().values().stream().map(x -> x.dataType).collect(Collectors.toList()); List nativeOutputs = nativeFunction.getFunctionDef().getSignature().getOutputArgList().stream() .map(ArgDef::getType) @@ -457,6 +449,9 @@ private ConcreteFunction( + outputs); } + outputTypes = + outputs.stream().map(x -> TensorTypeRegistry.find(x).type()).collect(Collectors.toList()); + try (PointerScope scope = new PointerScope()) { this.scope = scope; scope.extend(); @@ -469,6 +464,8 @@ private ConcreteFunction( * FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because * how to enable XLA JIT is extremely non-obvious. * + *

See https://github.com/tensorflow/java/issues/347 + * *

Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered * platform with id: 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails). */ diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java index 64c33f451fb..eb4348a3fd6 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java @@ -1,24 +1,26 @@ /* Copyright 2020-2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ package org.tensorflow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; import org.junit.jupiter.api.Test; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Init; @@ -27,6 +29,7 @@ import org.tensorflow.op.math.Sub; import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; public class ConcreteFunctionTest { @@ -144,6 +147,28 @@ public void testNestedFunctionGraph() { } } + @Test + public void testFunctionWithTwoOutputs() { + ConcreteFunction cf = + ConcreteFunction.create( + tf -> { + Placeholder x = tf.placeholder(TInt32.class); + Operand dblX = tf.math.add(x, x); + Operand tripX = tf.math.add(x, dblX); + return Signature.builder() + .input("x", x) + .output("dbl", dblX) + .output("trpl", tripX) + .build(); + }); + + Map inputs = new HashMap<>(); + inputs.put("x", TInt32.scalarOf(2)); + Map outputs = cf.call(inputs); + assertEquals(4, ((TInt32) outputs.get("dbl")).getInt()); + assertEquals(6, ((TInt32) outputs.get("trpl")).getInt()); + } + private static Signature square(Ops tf) { Placeholder input = tf.placeholder(TFloat32.class); Operand output = tf.math.square(input);