Skip to content

Commit 424fe94

Browse files
rnettkarllessard
authored andcommitted
ConcreteFunction fix and performance improvements
Signed-off-by: Ryan Nett <[email protected]>
1 parent 9a105fa commit 424fe94

File tree

2 files changed

+58
-38
lines changed

2 files changed

+58
-38
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java

Lines changed: 39 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919
import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphToFunction;
2020

2121
import java.util.ArrayList;
22-
import java.util.Arrays;
2322
import java.util.Collection;
2423
import java.util.Collections;
2524
import java.util.HashSet;
25+
import java.util.Iterator;
2626
import java.util.LinkedHashMap;
2727
import java.util.List;
2828
import java.util.Map;
@@ -218,7 +218,7 @@ public String toString() {
218218
* @return the outputs of the function
219219
*/
220220
public Map<String, Operand<?>> call(Scope scope, Map<String, Operand<?>> arguments) {
221-
List<Operand<?>> inputList = new ArrayList<>(signature.inputNames().size());
221+
List<Operand<?>> inputList = new ArrayList<>(signature().inputNames().size());
222222

223223
for (String inputName : signature().inputNames()) {
224224
if (!arguments.containsKey(inputName)) {
@@ -241,31 +241,31 @@ public Map<String, Operand<?>> call(Scope scope, Map<String, Operand<?>> argumen
241241
}
242242

243243
List<Output<?>> outputList =
244-
PartitionedCall.create(
245-
scope,
246-
inputList,
247-
Arrays.stream(outputDtypes)
248-
.map(x -> TensorTypeRegistry.find(x).type())
249-
.collect(Collectors.toList()),
250-
this)
251-
.output();
252-
253-
Map<String, Operand<?>> namedOutputs = new LinkedHashMap<>(signature().outputNames().size());
254-
255-
List<String> outputNames = new ArrayList<>(signature().outputNames());
256-
for (int i = 0; i < outputNames.size(); i++) {
257-
String outputName = outputNames.get(i);
258-
259-
if (i > outputList.size()) {
260-
throw new IllegalStateException(
261-
"Somehow, not all required outputs were returned from the function");
244+
PartitionedCall.create(scope, inputList, outputTypes, this).output();
245+
246+
if (signature().outputNames().size() == 0) {
247+
return Collections.emptyMap();
248+
} else if (signature().outputNames().size() == 1) {
249+
return Collections.singletonMap(
250+
signature().outputNames().iterator().next(), outputList.get(0));
251+
} else {
252+
Map<String, Operand<?>> namedOutputs = new LinkedHashMap<>(signature().outputNames().size());
253+
254+
Iterator<String> outputNames = signature().outputNames().iterator();
255+
for (int i = 0; outputNames.hasNext(); i++) {
256+
String outputName = outputNames.next();
257+
258+
if (i > outputList.size()) {
259+
throw new IllegalStateException(
260+
"Somehow, not all required outputs were returned from the function");
261+
}
262+
263+
Operand<?> output = outputList.get(i);
264+
namedOutputs.put(outputName, output);
262265
}
263266

264-
Operand<?> output = outputList.get(i);
265-
namedOutputs.put(outputName, output);
267+
return Collections.unmodifiableMap(namedOutputs);
266268
}
267-
268-
return Collections.unmodifiableMap(namedOutputs);
269269
}
270270

271271
/**
@@ -291,10 +291,7 @@ public Operand<?> call(Scope scope, Operand<?> argument) {
291291
}
292292
String outputName = signatureDef.getOutputsMap().keySet().iterator().next();
293293

294-
Map<String, Operand<?>> inputMap = new LinkedHashMap<>();
295-
inputMap.put(inputName, argument);
296-
297-
return call(scope, inputMap).get(outputName);
294+
return call(scope, Collections.singletonMap(inputName, argument)).get(outputName);
298295
}
299296

300297
@Override
@@ -395,8 +392,7 @@ static ConcreteFunction fromNativeHandle(
395392
private final NativeFunction nativeFunction;
396393
private final PointerScope scope;
397394
private final Set<TF_Function> dependencies;
398-
private final DataType[] inputDtypes;
399-
private final DataType[] outputDtypes;
395+
private final List<Class<? extends TType>> outputTypes;
400396

401397
/** All native functions should have deallocators registered */
402398
private ConcreteFunction(
@@ -423,10 +419,10 @@ private ConcreteFunction(
423419
+ this.signature.getOutputs().size());
424420
}
425421

426-
inputDtypes =
427-
this.signature.getInputs().values().stream().map(x -> x.dataType).toArray(DataType[]::new);
428-
429-
List<DataType> inputs = Arrays.asList(inputDtypes);
422+
List<DataType> inputs =
423+
this.signature.getInputs().values().stream()
424+
.map(x -> x.dataType)
425+
.collect(Collectors.toList());
430426
List<DataType> nativeInputs =
431427
nativeFunction.getFunctionDef().getSignature().getInputArgList().stream()
432428
.map(ArgDef::getType)
@@ -440,10 +436,10 @@ private ConcreteFunction(
440436
+ inputs);
441437
}
442438

443-
outputDtypes =
444-
signature().getOutputs().values().stream().map(x -> x.dataType).toArray(DataType[]::new);
445-
446-
List<DataType> outputs = Arrays.asList(outputDtypes);
439+
List<DataType> outputs =
440+
signature().getOutputs().values().stream()
441+
.map(x -> x.dataType)
442+
.collect(Collectors.toList());
447443
List<DataType> nativeOutputs =
448444
nativeFunction.getFunctionDef().getSignature().getOutputArgList().stream()
449445
.map(ArgDef::getType)
@@ -457,6 +453,9 @@ private ConcreteFunction(
457453
+ outputs);
458454
}
459455

456+
outputTypes =
457+
outputs.stream().map(x -> TensorTypeRegistry.find(x).type()).collect(Collectors.toList());
458+
460459
try (PointerScope scope = new PointerScope()) {
461460
this.scope = scope;
462461
scope.extend();
@@ -469,6 +468,8 @@ private ConcreteFunction(
469468
* FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because
470469
* how to enable XLA JIT is extremely non-obvious.
471470
*
471+
* <p>See https://github.com/tensorflow/java/issues/347
472+
*
472473
* <p>Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered
473474
* platform with id: 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails).
474475
*/

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import static org.junit.jupiter.api.Assertions.assertNotNull;
2020

2121
import java.util.Arrays;
22+
import java.util.HashMap;
23+
import java.util.Map;
2224
import org.junit.jupiter.api.Test;
2325
import org.tensorflow.op.Ops;
2426
import org.tensorflow.op.core.Init;
@@ -27,6 +29,7 @@
2729
import org.tensorflow.op.math.Sub;
2830
import org.tensorflow.proto.framework.DataType;
2931
import org.tensorflow.types.TFloat32;
32+
import org.tensorflow.types.TInt32;
3033

3134
public class ConcreteFunctionTest {
3235

@@ -144,6 +147,22 @@ public void testNestedFunctionGraph() {
144147
}
145148
}
146149

150+
@Test
151+
public void testFunctionWithTwoOutputs() {
152+
ConcreteFunction cf = ConcreteFunction.create(tf -> {
153+
Placeholder<TInt32> x = tf.placeholder(TInt32.class);
154+
Operand<TInt32> dblX = tf.math.add(x, x);
155+
Operand<TInt32> tripX = tf.math.add(x, dblX);
156+
return Signature.builder().input("x", x).output("dbl", dblX).output("trpl", tripX).build();
157+
});
158+
159+
Map<String, Tensor> inputs = new HashMap<>();
160+
inputs.put("x", TInt32.scalarOf(2));
161+
Map<String, Tensor> outputs = cf.call(inputs);
162+
assertEquals(4, ((TInt32)outputs.get("dbl")).getInt());
163+
assertEquals(6, ((TInt32)outputs.get("trpl")).getInt());
164+
}
165+
147166
private static Signature square(Ops tf) {
148167
Placeholder<TFloat32> input = tf.placeholder(TFloat32.class);
149168
Operand<TFloat32> output = tf.math.square(input);

0 commit comments

Comments
 (0)