Skip to content

Commit ff37a27

Browse files
committed
Change completeSubgraph signature, leave filtering to users
Signed-off-by: Ryan Nett <[email protected]>
1 parent f5c52dd commit ff37a27

File tree

2 files changed

+7
-99
lines changed

2 files changed

+7
-99
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java

+5-53
Original file line numberDiff line numberDiff line change
@@ -220,52 +220,16 @@ private GraphOperation graphOp(Operand<?> operand) {
220220
}
221221

222222
/**
223-
* Finds the operations used to produce {@code outputs} from {@code inputs}, or throws if that is not possible.
224-
* Includes control dependencies.
225-
*
226-
* @param inputs the inputs of the subgraph. Must be from single output ops. May not be null.
227-
* @param outputs the outputs of the subgraph. May not be null.
228-
* @param allowNoInputBodyOps whether to allow 0-input ops in the body. For more specificy use {@link
229-
* #completeSubgraph(Set, Set, Set, Set)}.
230-
* @return the set of operations needed to calculate outputs from inputs, including outputs and inputs
231-
* @throws IllegalStateException if outputs depends on ops outside of the subgraph (i.e. is not calculable based
232-
* solely on inputs)
233-
* @see #completeSubgraph(Set, Set, Set, Set)
234-
*/
235-
public Set<GraphOperation> completeSubgraph(Set<Operand<?>> inputs, Set<Operand<?>> outputs,
236-
boolean allowNoInputBodyOps) {
237-
return completeSubgraph(inputs, outputs, null, allowNoInputBodyOps ? Collections.emptySet() : null);
238-
}
239-
240-
/**
241-
* Finds the operations used to produce {@code outputs} from {@code inputs}, or throws if that is not possible.
242-
* Includes control dependencies.
243-
*
244-
* If both {@code allowedNoInputBodyOps} and {@code forbiddenNoInputBodyOps} are {@code null}, forbids 0-input ops in
245-
* the body. To allow all ops in the body, use {@code null} for {@code allowedNoInputBodyOps} and the empty set for
246-
* {@code forbiddenNoInputBodyOps}.
223+
* Finds the operations used to produce {@code outputs}, assuming {@code inputs} are provided. Includes control dependencies.
224+
* <p>
225+
* Note that this function can easily return ops upstream of inputs as part of the body. Depending on your use, the
226+
* returned body should probably be filtered for {@code Placeholder}s, at least.
247227
*
248228
* @param inputs the inputs of the subgraph. Must be from single output ops. May not be null.
249229
* @param outputs the outputs of the subgraph. May not be null.
250-
* @param allowedNoInputBodyOps types of ops to allow as 0-input ops in the body. Allows all (except {@code
251-
* forbiddenNoInputBodyOps}) if null.
252-
* @param forbiddenNoInputBodyOps types of ops to never allow as 0-input ops in the body. Forbids all (except {@code
253-
* allowedNoInputBodyOps}) if null.
254230
* @return the set of operations needed to calculate outputs from inputs, including outputs and inputs
255-
* @throws IllegalStateException if outputs depends on ops outside of the subgraph (i.e. is not calculable based
256-
* solely on inputs)
257-
* @see #completeSubgraph(Set, Set, boolean)
258231
*/
259-
public synchronized Set<GraphOperation> completeSubgraph(Set<Operand<?>> inputs, Set<Operand<?>> outputs,
260-
Set<String> allowedNoInputBodyOps, Set<String> forbiddenNoInputBodyOps) {
261-
262-
if (forbiddenNoInputBodyOps != null && allowedNoInputBodyOps != null) {
263-
for (String t : forbiddenNoInputBodyOps) {
264-
if (allowedNoInputBodyOps.contains(t)) {
265-
throw new IllegalArgumentException("Can't allow and forbid op type " + t + ".");
266-
}
267-
}
268-
}
232+
public synchronized Set<GraphOperation> completeSubgraph(Set<Operand<?>> inputs, Set<Operand<?>> outputs) {
269233

270234
if (inputs == null) {
271235
throw new IllegalArgumentException("Inputs can't be null.");
@@ -301,18 +265,6 @@ public synchronized Set<GraphOperation> completeSubgraph(Set<Operand<?>> inputs,
301265
continue;
302266
}
303267

304-
if (op.numControlInputs() + op.numInputs() == 0) {
305-
// inverted: (nothing is forbidden || not forbidden) and (everything is allowed || allowed) and (not both null)
306-
if ((forbiddenNoInputBodyOps != null && forbiddenNoInputBodyOps.contains(op.type()))
307-
|| (allowedNoInputBodyOps != null && !allowedNoInputBodyOps.contains(op.type()))
308-
|| (forbiddenNoInputBodyOps == null && allowedNoInputBodyOps == null)) {
309-
throw new IllegalStateException("Operation " + op
310-
+ " of type " + op.type() +
311-
" has no inputs and is not an allowed 0-input op type, but is not set as an input. "
312-
+ "It is impossible to calculate the specified outputs with the given inputs.");
313-
}
314-
}
315-
316268
for (GraphOperation control : op.controlInputs()) {
317269
if (!inputOps.contains(control)) {
318270
currents.add(control);

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java

+2-46
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,7 @@ public void completeSubgraph() {
131131
Operand<TInt32> output = tf.math.mul(d, c);
132132

133133
Set<GraphOperation> subgraph = g
134-
.completeSubgraph(new LinkedHashSet<>(Arrays.asList(control, a, b, c)), Collections.singleton(output), null,
135-
null);
134+
.completeSubgraph(new LinkedHashSet<>(Arrays.asList(control, a, b, c)), Collections.singleton(output));
136135

137136
assertEquals(new LinkedHashSet<>(Arrays.asList(control.op(), a.op(), b.op(), c.op(), d.op(), output.op())),
138137
subgraph);
@@ -152,56 +151,13 @@ public void completeSubgraphWithConstants() {
152151
Operand<TInt32> output = tf.math.mul(d, c);
153152

154153
Set<GraphOperation> subgraph = g
155-
.completeSubgraph(Collections.emptySet(), Collections.singleton(output), Collections.singleton(
156-
Constant.OP_NAME), null);
154+
.completeSubgraph(Collections.emptySet(), Collections.singleton(output));
157155

158156
assertEquals(new LinkedHashSet<>(Arrays.asList(control.op(), a.op(), b.op(), c.op(), d.op(), output.op())),
159157
subgraph);
160158
}
161159
}
162160

163-
@Test
164-
public void completeSubgraphMissingInput() {
165-
try (Graph g = new Graph()) {
166-
Ops tf = Ops.create(g);
167-
Operand<TInt32> control = tf.constant(0);
168-
Operand<TInt32> a = tf.withControlDependencies(Collections.singletonList(control)).constant(1);
169-
Operand<TInt32> b = tf.constant(2);
170-
Operand<TInt32> c = tf.constant(3);
171-
172-
Operand<TInt32> d = tf.math.add(a, b);
173-
Operand<TInt32> output = tf.math.mul(d, c);
174-
175-
try {
176-
g.completeSubgraph(new LinkedHashSet<>(Arrays.asList(control, b)), Collections.singleton(output), null, null);
177-
fail();
178-
} catch (IllegalStateException e) {
179-
assertTrue(e.getMessage().contains("is not set as an input"));
180-
}
181-
}
182-
}
183-
184-
@Test
185-
public void completeSubgraphMissingControlInput() {
186-
try (Graph g = new Graph()) {
187-
Ops tf = Ops.create(g);
188-
Operand<TInt32> control = tf.constant(0);
189-
Operand<TInt32> a = tf.withControlDependencies(Collections.singletonList(control)).constant(1);
190-
Operand<TInt32> b = tf.constant(2);
191-
Operand<TInt32> c = tf.constant(3);
192-
193-
Operand<TInt32> d = tf.math.add(a, b);
194-
Operand<TInt32> output = tf.math.mul(d, c);
195-
196-
try {
197-
g.completeSubgraph(new LinkedHashSet<>(Arrays.asList(a, b)), Collections.singleton(output), null, null);
198-
fail();
199-
} catch (IllegalStateException e) {
200-
assertTrue(e.getMessage().contains("is not set as an input"));
201-
}
202-
}
203-
}
204-
205161
@Test
206162
public void failImportOnInvalidGraphDefs() {
207163
try (Graph g = new Graph()) {

0 commit comments

Comments
 (0)