Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphToFunction;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -218,7 +218,7 @@ public String toString() {
* @return the outputs of the function
*/
public Map<String, Operand<?>> call(Scope scope, Map<String, Operand<?>> arguments) {
List<Operand<?>> inputList = new ArrayList<>(signature.inputNames().size());
List<Operand<?>> inputList = new ArrayList<>(signature().inputNames().size());

for (String inputName : signature().inputNames()) {
if (!arguments.containsKey(inputName)) {
Expand All @@ -241,31 +241,31 @@ public Map<String, Operand<?>> call(Scope scope, Map<String, Operand<?>> argumen
}

List<Output<?>> outputList =
PartitionedCall.create(
scope,
inputList,
Arrays.stream(inputDtypes)
.map(x -> TensorTypeRegistry.find(x).type())
.collect(Collectors.toList()),
this)
.output();

Map<String, Operand<?>> namedOutputs = new LinkedHashMap<>(signature().outputNames().size());

List<String> outputNames = new ArrayList<>(signature().outputNames());
for (int i = 0; i < outputNames.size(); i++) {
String outputName = outputNames.get(i);

if (i > outputList.size()) {
throw new IllegalStateException(
"Somehow, not all required outputs were returned from the function");
PartitionedCall.create(scope, inputList, outputTypes, this).output();

if (signature().outputNames().size() == 0) {
return Collections.emptyMap();
} else if (signature().outputNames().size() == 1) {
return Collections.singletonMap(
signature().outputNames().iterator().next(), outputList.get(0));
} else {
Map<String, Operand<?>> namedOutputs = new LinkedHashMap<>(signature().outputNames().size());

Iterator<String> outputNames = signature().outputNames().iterator();
for (int i = 0; outputNames.hasNext(); i++) {
String outputName = outputNames.next();

if (i > outputList.size()) {
throw new IllegalStateException(
"Somehow, not all required outputs were returned from the function");
}

Operand<?> output = outputList.get(i);
namedOutputs.put(outputName, output);
}

Operand<?> output = outputList.get(i);
namedOutputs.put(outputName, output);
return Collections.unmodifiableMap(namedOutputs);
}

return Collections.unmodifiableMap(namedOutputs);
}

/**
Expand All @@ -291,10 +291,7 @@ public Operand<?> call(Scope scope, Operand<?> argument) {
}
String outputName = signatureDef.getOutputsMap().keySet().iterator().next();

Map<String, Operand<?>> inputMap = new LinkedHashMap<>();
inputMap.put(inputName, argument);

return call(scope, inputMap).get(outputName);
return call(scope, Collections.singletonMap(inputName, argument)).get(outputName);
}

@Override
Expand Down Expand Up @@ -395,8 +392,7 @@ static ConcreteFunction fromNativeHandle(
private final NativeFunction nativeFunction;
private final PointerScope scope;
private final Set<TF_Function> dependencies;
private final DataType[] inputDtypes;
private final DataType[] outputDtypes;
private final List<Class<? extends TType>> outputTypes;

/** All native functions should have deallocators registered */
private ConcreteFunction(
Expand All @@ -423,10 +419,10 @@ private ConcreteFunction(
+ this.signature.getOutputs().size());
}

inputDtypes =
this.signature.getInputs().values().stream().map(x -> x.dataType).toArray(DataType[]::new);

List<DataType> inputs = Arrays.asList(inputDtypes);
List<DataType> inputs =
this.signature.getInputs().values().stream()
.map(x -> x.dataType)
.collect(Collectors.toList());
List<DataType> nativeInputs =
nativeFunction.getFunctionDef().getSignature().getInputArgList().stream()
.map(ArgDef::getType)
Expand All @@ -440,10 +436,10 @@ private ConcreteFunction(
+ inputs);
}

outputDtypes =
signature().getOutputs().values().stream().map(x -> x.dataType).toArray(DataType[]::new);

List<DataType> outputs = Arrays.asList(outputDtypes);
List<DataType> outputs =
signature().getOutputs().values().stream()
.map(x -> x.dataType)
.collect(Collectors.toList());
List<DataType> nativeOutputs =
nativeFunction.getFunctionDef().getSignature().getOutputArgList().stream()
.map(ArgDef::getType)
Expand All @@ -457,6 +453,9 @@ private ConcreteFunction(
+ outputs);
}

outputTypes =
outputs.stream().map(x -> TensorTypeRegistry.find(x).type()).collect(Collectors.toList());

try (PointerScope scope = new PointerScope()) {
this.scope = scope;
scope.extend();
Expand All @@ -469,6 +468,8 @@ private ConcreteFunction(
* FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because
* how to enable XLA JIT is extremely non-obvious.
*
* <p>See https://github.com/tensorflow/java/issues/347
*
* <p>Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered
* platform with id: 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails).
*/
Expand Down