Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphToFunction;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -66,7 +66,7 @@
* Map<String, Tensor> outputTensorMap = myFunction.call(inputTensorMap);
* }</pre>
*/
public class ConcreteFunction implements AutoCloseable, TensorFunction {
public final class ConcreteFunction implements AutoCloseable, TensorFunction {

/**
* Creates a function by building a new graph.
Expand Down Expand Up @@ -220,11 +220,11 @@ public String toString() {
public Map<String, Operand<?>> call(Scope scope, Map<String, Operand<?>> arguments) {
List<Operand<?>> inputList = new ArrayList<>(signature.inputNames().size());

for (String inputName : signature().inputNames()) {
for (String inputName : signature.inputNames()) {
if (!arguments.containsKey(inputName)) {
throw new IllegalArgumentException(
"Function "
+ signature().methodName()
+ signature.methodName()
+ " has parameter \""
+ inputName
+ "\", but no argument was passed for it.");
Expand All @@ -241,30 +241,30 @@ public Map<String, Operand<?>> call(Scope scope, Map<String, Operand<?>> argumen
}

List<Output<?>> outputList =
PartitionedCall.create(
scope,
inputList,
Arrays.stream(outputDtypes)
.map(x -> TensorTypeRegistry.find(x).type())
.collect(Collectors.toList()),
this)
.output();

Map<String, Operand<?>> namedOutputs = new LinkedHashMap<>(signature().outputNames().size());

List<String> outputNames = new ArrayList<>(signature().outputNames());
for (int i = 0; i < outputNames.size(); i++) {
String outputName = outputNames.get(i);

if (i > outputList.size()) {
throw new IllegalStateException(
"Somehow, not all required outputs were returned from the function");
}
PartitionedCall.create(scope, inputList, outputTypes, this).output();

if (signature.outputNames().size() == 0) {
return Collections.emptyMap();
}
if (signature.outputNames().size() == 1) {
return Collections.singletonMap(signature.outputNames().iterator().next(), outputList.get(0));
}
if (outputList.size() < signature.outputNames().size()) {
throw new IllegalStateException(
"Somehow, not all required outputs were returned from the function"
+ "(expected: "
+ signature.outputNames().size()
+ ", returned: "
+ outputList.size()
+ ")");
}
Map<String, Operand<?>> namedOutputs = new LinkedHashMap<>(signature.outputNames().size());
Iterator<String> outputNames = signature.outputNames().iterator();
for (int i = 0; outputNames.hasNext(); i++) {
String outputName = outputNames.next();
Operand<?> output = outputList.get(i);
namedOutputs.put(outputName, output);
}

return Collections.unmodifiableMap(namedOutputs);
}

Expand All @@ -291,10 +291,7 @@ public Operand<?> call(Scope scope, Operand<?> argument) {
}
String outputName = signatureDef.getOutputsMap().keySet().iterator().next();

Map<String, Operand<?>> inputMap = new LinkedHashMap<>();
inputMap.put(inputName, argument);

return call(scope, inputMap).get(outputName);
return call(scope, Collections.singletonMap(inputName, argument)).get(outputName);
}

@Override
Expand Down Expand Up @@ -395,8 +392,7 @@ static ConcreteFunction fromNativeHandle(
private final NativeFunction nativeFunction;
private final PointerScope scope;
private final Set<TF_Function> dependencies;
private final DataType[] inputDtypes;
private final DataType[] outputDtypes;
private final List<Class<? extends TType>> outputTypes;

/** All native functions should have deallocators registered */
private ConcreteFunction(
Expand All @@ -405,7 +401,7 @@ private ConcreteFunction(
this.nativeFunction = nativeFunction;
this.dependencies = Collections.unmodifiableSet(dependencies);

if (this.signature.getInputs().size()
if (signature.getInputs().size()
!= nativeFunction.getFunctionDef().getSignature().getInputArgCount()) {
throw new IllegalArgumentException(
"Signature must have the same number of inputs as the native function. Expected "
Expand All @@ -414,7 +410,7 @@ private ConcreteFunction(
+ this.signature.getInputs().size());
}

if (this.signature.getOutputs().size()
if (signature.getOutputs().size()
!= nativeFunction.getFunctionDef().getSignature().getOutputArgCount()) {
throw new IllegalArgumentException(
"New signature must have the same number of outputs as the native function. Expected "
Expand All @@ -423,10 +419,8 @@ private ConcreteFunction(
+ this.signature.getOutputs().size());
}

inputDtypes =
this.signature.getInputs().values().stream().map(x -> x.dataType).toArray(DataType[]::new);

List<DataType> inputs = Arrays.asList(inputDtypes);
List<DataType> inputs =
signature.getInputs().values().stream().map(x -> x.dataType).collect(Collectors.toList());
List<DataType> nativeInputs =
nativeFunction.getFunctionDef().getSignature().getInputArgList().stream()
.map(ArgDef::getType)
Expand All @@ -440,10 +434,8 @@ private ConcreteFunction(
+ inputs);
}

outputDtypes =
signature().getOutputs().values().stream().map(x -> x.dataType).toArray(DataType[]::new);

List<DataType> outputs = Arrays.asList(outputDtypes);
List<DataType> outputs =
signature.getOutputs().values().stream().map(x -> x.dataType).collect(Collectors.toList());
List<DataType> nativeOutputs =
nativeFunction.getFunctionDef().getSignature().getOutputArgList().stream()
.map(ArgDef::getType)
Expand All @@ -457,6 +449,9 @@ private ConcreteFunction(
+ outputs);
}

outputTypes =
outputs.stream().map(x -> TensorTypeRegistry.find(x).type()).collect(Collectors.toList());

try (PointerScope scope = new PointerScope()) {
this.scope = scope;
scope.extend();
Expand All @@ -469,6 +464,8 @@ private ConcreteFunction(
* FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because
* how to enable XLA JIT is extremely non-obvious.
*
* <p>See https://github.com/tensorflow/java/issues/347
*
* <p>Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered
* platform with id: 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails).
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
/* Copyright 2020-2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================
*/
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================
*/
package org.tensorflow;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.junit.jupiter.api.Test;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Init;
Expand All @@ -27,6 +29,7 @@
import org.tensorflow.op.math.Sub;
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;

public class ConcreteFunctionTest {

Expand Down Expand Up @@ -144,6 +147,28 @@ public void testNestedFunctionGraph() {
}
}

@Test
public void testFunctionWithTwoOutputs() {
ConcreteFunction cf =
ConcreteFunction.create(
tf -> {
Placeholder<TInt32> x = tf.placeholder(TInt32.class);
Operand<TInt32> dblX = tf.math.add(x, x);
Operand<TInt32> tripX = tf.math.add(x, dblX);
return Signature.builder()
.input("x", x)
.output("dbl", dblX)
.output("trpl", tripX)
.build();
});

Map<String, Tensor> inputs = new HashMap<>();
inputs.put("x", TInt32.scalarOf(2));
Map<String, Tensor> outputs = cf.call(inputs);
assertEquals(4, ((TInt32) outputs.get("dbl")).getInt());
assertEquals(6, ((TInt32) outputs.get("trpl")).getInt());
}

private static Signature square(Ops tf) {
Placeholder<TFloat32> input = tf.placeholder(TFloat32.class);
Operand<TFloat32> output = tf.math.square(input);
Expand Down