diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index f5dcd7c2ce3..b441f421efa 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -19,10 +19,10 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphToFunction; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashSet; +import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -218,7 +218,7 @@ public String toString() { * @return the outputs of the function */ public Map> call(Scope scope, Map> arguments) { - List> inputList = new ArrayList<>(signature.inputNames().size()); + List> inputList = new ArrayList<>(signature().inputNames().size()); for (String inputName : signature().inputNames()) { if (!arguments.containsKey(inputName)) { @@ -241,31 +241,31 @@ public Map> call(Scope scope, Map> argumen } List> outputList = - PartitionedCall.create( - scope, - inputList, - Arrays.stream(inputDtypes) - .map(x -> TensorTypeRegistry.find(x).type()) - .collect(Collectors.toList()), - this) - .output(); - - Map> namedOutputs = new LinkedHashMap<>(signature().outputNames().size()); - - List outputNames = new ArrayList<>(signature().outputNames()); - for (int i = 0; i < outputNames.size(); i++) { - String outputName = outputNames.get(i); - - if (i > outputList.size()) { - throw new IllegalStateException( - "Somehow, not all required outputs were returned from the function"); + PartitionedCall.create(scope, inputList, outputTypes, this).output(); + + if (signature().outputNames().size() == 0) { + return Collections.emptyMap(); + } else if (signature().outputNames().size() == 1) { + return Collections.singletonMap( + signature().outputNames().iterator().next(), outputList.get(0)); + } else { + Map> namedOutputs = new LinkedHashMap<>(signature().outputNames().size()); + + Iterator outputNames = signature().outputNames().iterator(); + for (int i = 0; outputNames.hasNext(); i++) { + String outputName = outputNames.next(); + + if (i > outputList.size()) { + throw new IllegalStateException( + "Somehow, not all required outputs were returned from the function"); + } + + Operand output = outputList.get(i); + namedOutputs.put(outputName, output); } - Operand output = outputList.get(i); - namedOutputs.put(outputName, output); + return Collections.unmodifiableMap(namedOutputs); } - - return Collections.unmodifiableMap(namedOutputs); } /** @@ -291,10 +291,7 @@ public Operand call(Scope scope, Operand argument) { } String outputName = signatureDef.getOutputsMap().keySet().iterator().next(); - Map> inputMap = new LinkedHashMap<>(); - inputMap.put(inputName, argument); - - return call(scope, inputMap).get(outputName); + return call(scope, Collections.singletonMap(inputName, argument)).get(outputName); } @Override @@ -395,8 +392,7 @@ static ConcreteFunction fromNativeHandle( private final NativeFunction nativeFunction; private final PointerScope scope; private final Set dependencies; - private final DataType[] inputDtypes; - private final DataType[] outputDtypes; + private final List> outputTypes; /** All native functions should have deallocators registered */ private ConcreteFunction( @@ -423,10 +419,10 @@ private ConcreteFunction( + this.signature.getOutputs().size()); } - inputDtypes = - this.signature.getInputs().values().stream().map(x -> x.dataType).toArray(DataType[]::new); - - List inputs = Arrays.asList(inputDtypes); + List inputs = + this.signature.getInputs().values().stream() + .map(x -> x.dataType) + .collect(Collectors.toList()); List nativeInputs = nativeFunction.getFunctionDef().getSignature().getInputArgList().stream() .map(ArgDef::getType) @@ -440,10 +436,10 @@ private ConcreteFunction( + inputs); } - outputDtypes = - signature().getOutputs().values().stream().map(x -> x.dataType).toArray(DataType[]::new); - - List outputs = Arrays.asList(outputDtypes); + List outputs = + signature().getOutputs().values().stream() + .map(x -> x.dataType) + .collect(Collectors.toList()); List nativeOutputs = nativeFunction.getFunctionDef().getSignature().getOutputArgList().stream() .map(ArgDef::getType) @@ -457,6 +453,9 @@ private ConcreteFunction( + outputs); } + outputTypes = + outputs.stream().map(x -> TensorTypeRegistry.find(x).type()).collect(Collectors.toList()); + try (PointerScope scope = new PointerScope()) { this.scope = scope; scope.extend(); @@ -469,6 +468,8 @@ private ConcreteFunction( * FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because * how to enable XLA JIT is extremely non-obvious. * + *

See https://github.com/tensorflow/java/issues/347 + * *

Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered * platform with id: 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails). */