Skip to content

Commit 7f5fcf1

Browse files
committed
Small fixes
Signed-off-by: Ryan Nett <[email protected]>
1 parent 978848d commit 7f5fcf1

File tree

2 files changed

+8
-15
lines changed

2 files changed

+8
-15
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java

+7-12
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
import java.util.Collections;
3434
import java.util.Iterator;
3535
import java.util.LinkedHashSet;
36-
import java.util.LinkedList;
3736
import java.util.List;
3837
import java.util.Queue;
3938
import java.util.Set;
@@ -182,11 +181,7 @@ public Output<?> output(String output) {
182181
}
183182
return new Output(operation, index);
184183
} catch (NumberFormatException e) {
185-
GraphOperation op = operation(output);
186-
if (op == null) {
187-
return null;
188-
}
189-
return new Output(op, 0);
184+
throw new IllegalArgumentException("Could not get output for badly formatted output name: \"" + output + "\"", e);
190185
}
191186
}
192187

@@ -343,7 +338,7 @@ public synchronized Set<GraphOperation> completeSubgraph(Set<Operand<?>> inputs,
343338
* @param outputs the starting points of the traversal.
344339
* @return the ops needed to calculate {@code outputs}, not including {@code outputs}
345340
*/
346-
public Set<GraphOperation> upstreamOps(Set<GraphOperation> outputs) {
341+
public Set<GraphOperation> subgraphToOps(Set<GraphOperation> outputs) {
347342
Set<GraphOperation> seen = new LinkedHashSet<>(outputs.size());
348343
Queue<GraphOperation> todo = new ArrayDeque<>(outputs);
349344
while (!todo.isEmpty()) {
@@ -365,7 +360,7 @@ public Set<GraphOperation> upstreamOps(Set<GraphOperation> outputs) {
365360
* @param inputs the starting points of the traversal.
366361
* @return the ops that depend on {@code inputs}, not including {@code inputs}
367362
*/
368-
public synchronized Set<GraphOperation> downstreamOps(Set<GraphOperation> inputs) {
363+
public synchronized Set<GraphOperation> subgraphFromOps(Set<GraphOperation> inputs) {
369364
Set<GraphOperation> seen = new LinkedHashSet<>(inputs.size());
370365
Queue<GraphOperation> todo = new ArrayDeque<>(inputs);
371366
while (!todo.isEmpty()) {
@@ -387,8 +382,8 @@ public synchronized Set<GraphOperation> downstreamOps(Set<GraphOperation> inputs
387382
* @param outputs the starting points of the traversal.
388383
* @return the ops needed to calculate {@code outputs}, not including {@code outputs}
389384
*/
390-
public Set<GraphOperation> upstream(Set<Operand<?>> outputs) {
391-
return upstreamOps(outputs.stream().map(this::graphOp).collect(Collectors.toSet()));
385+
public Set<GraphOperation> subgraphTo(Set<Operand<?>> outputs) {
386+
return subgraphToOps(outputs.stream().map(this::graphOp).collect(Collectors.toSet()));
392387
}
393388

394389
/**
@@ -398,14 +393,14 @@ public Set<GraphOperation> upstream(Set<Operand<?>> outputs) {
398393
* @param inputs the starting points of the traversal.
399394
* @return the ops that depend on {@code inputs}, not including {@code inputs}
400395
*/
401-
public synchronized Set<GraphOperation> downstream(Set<Operand<?>> inputs) {
396+
public synchronized Set<GraphOperation> subgraphFrom(Set<Operand<?>> inputs) {
402397
Set<GraphOperation> ops = new LinkedHashSet<>();
403398
for (Operand<?> input : inputs) {
404399
GraphOperation op = graphOp(input);
405400
ops.addAll(op.consumers(input.asOutput().index()));
406401
ops.addAll(op.controlConsumers());
407402
}
408-
Set<GraphOperation> downstream = downstreamOps(ops);
403+
Set<GraphOperation> downstream = subgraphFromOps(ops);
409404
downstream.addAll(ops);
410405
return downstream;
411406
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java

+1-3
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,7 @@ Tensor tensor(int outputIdx) {
188188
* Get the number of inputs to the op, not including control inputs.
189189
*/
190190
public int numInputs() {
191-
try (PointerScope scope = new PointerScope()) {
192-
return TF_OperationNumInputs(getUnsafeNativeHandle());
193-
}
191+
return TF_OperationNumInputs(getUnsafeNativeHandle());
194192
}
195193

196194
/**

0 commit comments

Comments
 (0)