19
19
import static org .tensorflow .internal .c_api .global .tensorflow .TF_GraphToFunction ;
20
20
21
21
import java .util .ArrayList ;
22
- import java .util .Arrays ;
23
22
import java .util .Collection ;
24
23
import java .util .Collections ;
25
24
import java .util .HashSet ;
25
+ import java .util .Iterator ;
26
26
import java .util .LinkedHashMap ;
27
27
import java .util .List ;
28
28
import java .util .Map ;
@@ -218,7 +218,7 @@ public String toString() {
218
218
* @return the outputs of the function
219
219
*/
220
220
public Map <String , Operand <?>> call (Scope scope , Map <String , Operand <?>> arguments ) {
221
- List <Operand <?>> inputList = new ArrayList <>(signature .inputNames ().size ());
221
+ List <Operand <?>> inputList = new ArrayList <>(signature () .inputNames ().size ());
222
222
223
223
for (String inputName : signature ().inputNames ()) {
224
224
if (!arguments .containsKey (inputName )) {
@@ -241,31 +241,31 @@ public Map<String, Operand<?>> call(Scope scope, Map<String, Operand<?>> argumen
241
241
}
242
242
243
243
List <Output <?>> outputList =
244
- PartitionedCall .create (
245
- scope ,
246
- inputList ,
247
- Arrays .stream (outputDtypes )
248
- .map (x -> TensorTypeRegistry .find (x ).type ())
249
- .collect (Collectors .toList ()),
250
- this )
251
- .output ();
252
-
253
- Map <String , Operand <?>> namedOutputs = new LinkedHashMap <>(signature ().outputNames ().size ());
254
-
255
- List <String > outputNames = new ArrayList <>(signature ().outputNames ());
256
- for (int i = 0 ; i < outputNames .size (); i ++) {
257
- String outputName = outputNames .get (i );
258
-
259
- if (i > outputList .size ()) {
260
- throw new IllegalStateException (
261
- "Somehow, not all required outputs were returned from the function" );
244
+ PartitionedCall .create (scope , inputList , outputTypes , this ).output ();
245
+
246
+ if (signature ().outputNames ().size () == 0 ) {
247
+ return Collections .emptyMap ();
248
+ } else if (signature ().outputNames ().size () == 1 ) {
249
+ return Collections .singletonMap (
250
+ signature ().outputNames ().iterator ().next (), outputList .get (0 ));
251
+ } else {
252
+ Map <String , Operand <?>> namedOutputs = new LinkedHashMap <>(signature ().outputNames ().size ());
253
+
254
+ Iterator <String > outputNames = signature ().outputNames ().iterator ();
255
+ for (int i = 0 ; outputNames .hasNext (); i ++) {
256
+ String outputName = outputNames .next ();
257
+
258
+ if (i > outputList .size ()) {
259
+ throw new IllegalStateException (
260
+ "Somehow, not all required outputs were returned from the function" );
261
+ }
262
+
263
+ Operand <?> output = outputList .get (i );
264
+ namedOutputs .put (outputName , output );
262
265
}
263
266
264
- Operand <?> output = outputList .get (i );
265
- namedOutputs .put (outputName , output );
267
+ return Collections .unmodifiableMap (namedOutputs );
266
268
}
267
-
268
- return Collections .unmodifiableMap (namedOutputs );
269
269
}
270
270
271
271
/**
@@ -291,10 +291,7 @@ public Operand<?> call(Scope scope, Operand<?> argument) {
291
291
}
292
292
String outputName = signatureDef .getOutputsMap ().keySet ().iterator ().next ();
293
293
294
- Map <String , Operand <?>> inputMap = new LinkedHashMap <>();
295
- inputMap .put (inputName , argument );
296
-
297
- return call (scope , inputMap ).get (outputName );
294
+ return call (scope , Collections .singletonMap (inputName , argument )).get (outputName );
298
295
}
299
296
300
297
@ Override
@@ -395,8 +392,7 @@ static ConcreteFunction fromNativeHandle(
395
392
private final NativeFunction nativeFunction ;
396
393
private final PointerScope scope ;
397
394
private final Set <TF_Function > dependencies ;
398
- private final DataType [] inputDtypes ;
399
- private final DataType [] outputDtypes ;
395
+ private final List <Class <? extends TType >> outputTypes ;
400
396
401
397
/** All native functions should have deallocators registered */
402
398
private ConcreteFunction (
@@ -423,10 +419,10 @@ private ConcreteFunction(
423
419
+ this .signature .getOutputs ().size ());
424
420
}
425
421
426
- inputDtypes =
427
- this .signature .getInputs ().values ().stream (). map ( x -> x . dataType ). toArray ( DataType []:: new );
428
-
429
- List < DataType > inputs = Arrays . asList ( inputDtypes );
422
+ List < DataType > inputs =
423
+ this .signature .getInputs ().values ().stream ()
424
+ . map ( x -> x . dataType )
425
+ . collect ( Collectors . toList () );
430
426
List <DataType > nativeInputs =
431
427
nativeFunction .getFunctionDef ().getSignature ().getInputArgList ().stream ()
432
428
.map (ArgDef ::getType )
@@ -440,10 +436,10 @@ private ConcreteFunction(
440
436
+ inputs );
441
437
}
442
438
443
- outputDtypes =
444
- signature ().getOutputs ().values ().stream (). map ( x -> x . dataType ). toArray ( DataType []:: new );
445
-
446
- List < DataType > outputs = Arrays . asList ( outputDtypes );
439
+ List < DataType > outputs =
440
+ signature ().getOutputs ().values ().stream ()
441
+ . map ( x -> x . dataType )
442
+ . collect ( Collectors . toList () );
447
443
List <DataType > nativeOutputs =
448
444
nativeFunction .getFunctionDef ().getSignature ().getOutputArgList ().stream ()
449
445
.map (ArgDef ::getType )
@@ -457,6 +453,9 @@ private ConcreteFunction(
457
453
+ outputs );
458
454
}
459
455
456
+ outputTypes =
457
+ outputs .stream ().map (x -> TensorTypeRegistry .find (x ).type ()).collect (Collectors .toList ());
458
+
460
459
try (PointerScope scope = new PointerScope ()) {
461
460
this .scope = scope ;
462
461
scope .extend ();
@@ -469,6 +468,8 @@ private ConcreteFunction(
469
468
* FIXME: This causes native errors when I use it (Linux GPU, 6.1 CC), but I'm leaving it because
470
469
* how to enable XLA JIT is extremely non-obvious.
471
470
*
471
+ * <p>See https://github.com/tensorflow/java/issues/347
472
+ *
472
473
* <p>Causes {@code OP_REQUIRES failed at xla_ops.cc:363 : Not found: could not find registered
473
474
* platform with id: 0x7f75af03e6e8} (it's a warning, but the resulting TF_Status fails).
474
475
*/
0 commit comments