diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java index 0ffd6c2205e..3d390d33406 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java @@ -39,6 +39,17 @@ public Output[] outputList(int idx, int length) { @Override public Output output(int idx) { + if (getUnsafeNativeHandle(idx) != null && !getUnsafeNativeHandle(idx).isNull()) { + int numOutputs = this.numOutputs(); + if (idx >= numOutputs) { + throw new IndexOutOfBoundsException( + "Can't get output with index " + idx + ", this op only has " + numOutputs + " outputs."); + } + + if (idx < 0) { + throw new IndexOutOfBoundsException("Can't get output with index < 0."); + } + } return new Output<>(this, idx); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index 0f7a291466c..ff805c73b53 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -27,11 +27,16 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewWhile; import com.google.protobuf.InvalidProtocolBufferException; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Iterator; +import java.util.LinkedHashSet; import java.util.List; +import java.util.Queue; +import java.util.Set; +import java.util.stream.Collectors; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerScope; @@ -69,12 +74,16 @@ */ public final class Graph implements ExecutionEnvironment, AutoCloseable { - /** Create an empty Graph. */ + /** + * Create an empty Graph. + */ public Graph() { nativeHandle = allocate(); } - /** Create a Graph from an existing handle (takes ownership). */ + /** + * Create a Graph from an existing handle (takes ownership). + */ Graph(TF_Graph nativeHandle) { this.nativeHandle = nativeHandle; } @@ -128,6 +137,73 @@ public GraphOperation operation(String name) { } } + /** + * Returns the operation (node in the Graph) with the provided name, or throws {@link IllegalArgumentException} if + * there isn't one. + * + * @param name name of the operation to look for + * @return operation in the graph with this name + * @throws IllegalArgumentException if no such operation exists in the Graph + */ + public GraphOperation operationOrThrow(String name) { + GraphOperation op = operation(name); + if (op == null) { + throw new IllegalArgumentException("No Operation named [" + name + "] in the Graph"); + } + return op; + } + + /** + * Returns the output with the provided name, or {@code null} if there is no such output. + *
<p>
Names should be of the + * format {@code /scope/op}, with an optional index: {@code /scope/op:1}. {@code 0} is used if the index is not + * specified. + * + * @param output the output to get + * @return the output with this name, or null if there isn't one + */ + @SuppressWarnings("rawtypes") + public Output output(String output) { + int colon = output.lastIndexOf(':'); + if (colon == -1 || colon == output.length() - 1) { + GraphOperation op = operation(output); + if (op == null) { + return null; + } + return new Output(op, 0); + } + try { + String op = output.substring(0, colon); + int index = Integer.parseInt(output.substring(colon + 1)); + GraphOperation operation = operation(op); + if (operation == null) { + return null; + } + return new Output(operation, index); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Could not get output for badly formatted output name: \"" + output + "\"", e); + } + } + + /** + * Returns the output with the provided name, or throws {@link IllegalArgumentException} if there isn't one. + *
<p>
Names should be of the + * format {@code /scope/op}, with an optional index: {@code /scope/op:1}. {@code 0} is used if the index is not + * specified. + * + * @param output the output to get + * @return the output with this name + * @throws IllegalArgumentException if no such output exists in the Graph + * @see #output(String) + */ + public Output outputOrThrow(String output) { + Output op = output(output); + if (op == null) { + throw new IllegalArgumentException("No Operation named [" + output + "] in the Graph"); + } + return op; + } + /** * Iterator over all the {@link Operation}s in the graph. * @@ -138,14 +214,157 @@ public Iterator operations() { return new OperationIterator(this); } + private GraphOperation graphOp(Operand operand) { + checkInput(operand); + return (GraphOperation) operand.op(); + } + + /** + * Finds the operations used to produce {@code outputs}, assuming {@code inputs} are provided. Includes control dependencies. + *
<p>
+ * Note that this function can easily return ops upstream of inputs as part of the body. Depending on your use, the + * returned body should probably be filtered for {@code Placeholder}s, at least. + * + * @param inputs the inputs of the subgraph. Must be from single output ops. May not be null. + * @param outputs the outputs of the subgraph. May not be null. + * @return the set of operations needed to calculate outputs from inputs, including outputs and inputs + */ + public synchronized Set completeSubgraph(Set> inputs, Set> outputs) { + + if (inputs == null) { + throw new IllegalArgumentException("Inputs can't be null."); + } + + if (outputs == null) { + throw new IllegalArgumentException("Outputs can't be null."); + } + + Queue currents = new ArrayDeque<>(outputs.size()); + Set seen = new LinkedHashSet<>(inputs.size()); + Set inputOps = new LinkedHashSet<>(inputs.size()); + + for (Operand input : inputs) { + if (input.op().numOutputs() > 1) { + throw new IllegalStateException("Only ops with one output are supported as subgraph inputs"); + } + GraphOperation op = graphOp(input); + inputOps.add(op); + seen.add(op); + } + + for (Operand operand : outputs) { + GraphOperation op = graphOp(operand); + currents.add(op); + } + + while (!currents.isEmpty()) { + GraphOperation op = currents.poll(); + + // skip if already present + if (!seen.add(op)) { + continue; + } + + for (GraphOperation control : op.controlInputs()) { + if (!inputOps.contains(control)) { + currents.add(control); + } + } + + for (Operand input : op.inputs()) { + GraphOperation inputOp = graphOp(input); + if (!inputOps.contains(inputOp)) { + currents.add(inputOp); + } + } + + } + + return seen; + } + + /** + * Get all ops directly or indirectly required to calculate {@code outputs} (not including {@code outputs}), including + * control dependencies. + * + * @param outputs the starting points of the traversal. + * @return the ops needed to calculate {@code outputs}, not including {@code outputs} + */ + public Set subgraphToOps(Set outputs) { + Set seen = new LinkedHashSet<>(outputs.size()); + Queue todo = new ArrayDeque<>(outputs); + while (!todo.isEmpty()) { + GraphOperation current = todo.poll(); + + if (seen.add(current)) { + todo.addAll(current.inputs().stream().map(this::graphOp).collect(Collectors.toSet())); + todo.addAll(current.controlInputs()); + } + } + seen.removeAll(outputs); + return seen; + } + + /** + * Get all ops that use one of {@code inputs} directly or indirectly (not including {@code inputs}), including control + * dependencies. + * + * @param inputs the starting points of the traversal. + * @return the ops that depend on {@code inputs}, not including {@code inputs} + */ + public synchronized Set subgraphFromOps(Set inputs) { + Set seen = new LinkedHashSet<>(inputs.size()); + Queue todo = new ArrayDeque<>(inputs); + while (!todo.isEmpty()) { + GraphOperation current = todo.poll(); + + if (seen.add(current)) { + todo.addAll(current.consumers()); + todo.addAll(current.controlConsumers()); + } + } + seen.removeAll(inputs); + return seen; + } + + /** + * Get all ops directly or indirectly required to calculate {@code outputs} (not including {@code outputs}), including + * control dependencies. + * + * @param outputs the starting points of the traversal. 
+ * @return the ops needed to calculate {@code outputs}, not including {@code outputs} + */ + public Set subgraphTo(Set> outputs) { + return subgraphToOps(outputs.stream().map(this::graphOp).collect(Collectors.toSet())); + } + + /** + * Get all ops that use one of {@code inputs} directly or indirectly (not including {@code inputs}), including control + * dependencies. + * + * @param inputs the starting points of the traversal. + * @return the ops that depend on {@code inputs}, not including {@code inputs} + */ + public synchronized Set subgraphFrom(Set> inputs) { + Set ops = new LinkedHashSet<>(); + for (Operand input : inputs) { + GraphOperation op = graphOp(input); + ops.addAll(op.consumers(input.asOutput().index())); + ops.addAll(op.controlConsumers()); + } + Set downstream = subgraphFromOps(ops); + downstream.addAll(ops); + return downstream; + } + /** * Returns a builder to add {@link Operation}s to the Graph. * * @param type of the Operation (i.e., identifies the computation to be performed) * @param name to refer to the created Operation in the graph. * @return an {@link OperationBuilder}, which will add the Operation to the graph when {@link - * OperationBuilder#build()} is invoked. If {@link OperationBuilder#build()} is not invoked, - * then some resources may leak. + * OperationBuilder#build()} is invoked. If {@link OperationBuilder#build()} is not invoked, then some resources may + * leak. */ @Override public GraphOperationBuilder opBuilder(String type, String name) { @@ -216,6 +435,7 @@ public GraphDef toGraphDef() { /** * Adds an initializer to the graph initializer list. + * * @param initializer An initializer to add to the list. */ public synchronized void addInitializer(Op initializer) { @@ -230,12 +450,11 @@ public List initializers() { } /** - * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e., - * {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...} + * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e., {@code d(y_1 + y_2 + * + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...} * *
<p>
{@code dx} are used as initial gradients (which represent the symbolic partial derivatives - * of some loss function {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of - * {@code y}. + * of some loss function {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of {@code y}. * *
<p>
If {@code dx} is null, the implementation will use dx of {@link * org.tensorflow.op.core.OnesLike OnesLike} for all shapes in {@code y}. @@ -245,8 +464,8 @@ public List initializers() { * *
<p>
If {@code prefix} is null, then one will be chosen automatically. * - * @param prefix unique string prefix applied before the names of nodes added to the graph to - * compute gradients. If null, a default one will be chosen. + * @param prefix unique string prefix applied before the names of nodes added to the graph to compute gradients. If + * null, a default one will be chosen. * @param y output of the function to derive * @param x inputs of the function for which partial derivatives are computed * @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y} @@ -263,11 +482,11 @@ public Output[] addGradients(String prefix, Output[] y, Output[] x, Out try (Reference ref = ref()) { for (int i = 0; i < y.length; ++i) { - yHandles[i] = (TF_Operation)y[i].getUnsafeNativeHandle(); + yHandles[i] = (TF_Operation) y[i].getUnsafeNativeHandle(); yIndices[i] = y[i].index(); } for (int i = 0; i < x.length; ++i) { - xHandles[i] = (TF_Operation)x[i].getUnsafeNativeHandle(); + xHandles[i] = (TF_Operation) x[i].getUnsafeNativeHandle(); xIndices[i] = x[i].index(); } if (dx != null && dx.length > 0) { @@ -275,7 +494,7 @@ public Output[] addGradients(String prefix, Output[] y, Output[] x, Out dxIndices = new int[dx.length]; for (int i = 0; i < dx.length; ++i) { - dxHandles[i] = (TF_Operation)dx[i].getUnsafeNativeHandle(); + dxHandles[i] = (TF_Operation) dx[i].getUnsafeNativeHandle(); dxIndices[i] = dx[i].index(); } } @@ -300,7 +519,7 @@ public Output[] addGradients(String prefix, Output[] y, Output[] x, Out + " were expected"); } for (int i = 0, j = ndy; i < ndy; ++i, ++j) { - GraphOperation op = new GraphOperation(this, (TF_Operation)dyHandlesAndIndices[i]); + GraphOperation op = new GraphOperation(this, (TF_Operation) dyHandlesAndIndices[i]); dy[i] = new Output<>(op, (int) dyHandlesAndIndices[j]); } } @@ -308,24 +527,23 @@ public Output[] addGradients(String prefix, Output[] y, Output[] x, Out } /** - * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, - * i.e., {@code dy/dx_1, dy/dx_2...} + * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e., {@code dy/dx_1, + * dy/dx_2...} *
<p>
- * This is a simplified version of {@link #addGradients(String, Output[], Output[], Output[])} - * where {@code y} is a single output, {@code dx} is null and {@code prefix} is null. + * This is a simplified version of {@link #addGradients(String, Output[], Output[], Output[])} where {@code y} is a + * single output, {@code dx} is null and {@code prefix} is null. * * @param y output of the function to derive * @param x inputs of the function for which partial derivatives are computed * @return the partial derivatives {@code dy} with the size of {@code x} */ public Output[] addGradients(Output y, Output[] x) { - return addGradients(null, new Output[] {y}, x, null); + return addGradients(null, new Output[]{y}, x, null); } /** - * Used to instantiate an abstract class which overrides the buildSubgraph method to build a - * conditional or body subgraph for a while loop. After Java 8, this can alternatively be used to - * create a lambda for the same purpose. + * Used to instantiate an abstract class which overrides the buildSubgraph method to build a conditional or body + * subgraph for a while loop. After Java 8, this can alternatively be used to create a lambda for the same purpose. * *
<p>
To be used when calling {@link #whileLoop(Output[], * org.tensorflow.Graph.WhileSubgraphBuilder, org.tensorflow.Graph.WhileSubgraphBuilder, String)} @@ -348,6 +566,7 @@ public Output[] addGradients(Output y, Output[] x) { * */ public interface WhileSubgraphBuilder { + /** * To be overridden by user with code to build conditional or body subgraph for a while loop * @@ -421,7 +640,7 @@ public Output[] whileLoop( try (Reference ref = ref()) { for (int i = 0; i < ninputs; i++) { - inputHandles[i] = (TF_Operation)inputs[i].getUnsafeNativeHandle(); + inputHandles[i] = (TF_Operation) inputs[i].getUnsafeNativeHandle(); inputIndices[i] = inputs[i].index(); } @@ -429,7 +648,7 @@ public Output[] whileLoop( whileLoop(nativeHandle, inputHandles, inputIndices, name, cgBuilder, bgBuilder); for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) { - Operation op = new GraphOperation(this, (TF_Operation)outputHandlesAndIndices[i]); + Operation op = new GraphOperation(this, (TF_Operation) outputHandlesAndIndices[i]); outputs[i] = op.output((int) outputHandlesAndIndices[j]); } } @@ -438,15 +657,13 @@ public Output[] whileLoop( } /** - * Return the {@link SaverDef} instance used to save the state of all variables present in - * this graph. + * Return the {@link SaverDef} instance used to save the state of all variables present in this graph. * - *
<p>
The first time this method is called it builds the {@link SaverDef}. If this graph already - * contains a "save/restore_all" operation then it is assumed to contain all necessary saving and - * restoring operations. If that operation does not exist then the graph is mutated to add all - * the nodes necessary to save and restore the state of the graph. Consequently, any variables - * that are added to the graph after this call will not be saved nor restored using this - * {@link SaverDef}. + *
<p>
The first time this method is called it builds the {@link SaverDef}. If this graph already contains a + * "save/restore_all" operation then it is assumed to contain all necessary saving and restoring operations. If that + * operation does not exist then the graph is mutated to add all the nodes necessary to save and restore the state of + * the graph. Consequently, any variables that are added to the graph after this call will not be saved nor restored + * using this {@link SaverDef}. * * @return a {@link SaverDef} instance */ @@ -462,10 +679,10 @@ synchronized SaverDef saverDef() { // the python implementation for compatibility. // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py saverDef = SaverDef.newBuilder() - .setFilenameTensorName("save/filename") - .setSaveTensorName("save/control_dependency") - .setRestoreOpName("save/restore_all") - .build(); + .setFilenameTensorName("save/filename") + .setSaveTensorName("save/control_dependency") + .setRestoreOpName("save/restore_all") + .build(); } } return saverDef; @@ -485,6 +702,7 @@ synchronized SaverDef saverDef() { // Instances of the Reference class should be used to ensure the Graph has not been closed // while dependent handles are in use. class Reference implements AutoCloseable { + private Reference() { synchronized (Graph.this.nativeHandleLock) { active = Graph.this.nativeHandle != null && !Graph.this.nativeHandle.isNull(); @@ -539,9 +757,9 @@ private final void advance() { try { Object[] nativeReturn = nextOperation(reference.nativeHandle(), this.position); - if (nativeReturn != null && nativeReturn[0] != null && !((TF_Operation)nativeReturn[0]).isNull()) { - this.operation = new GraphOperation(this.graph, (TF_Operation)nativeReturn[0]); - this.position = (Integer)nativeReturn[1]; + if (nativeReturn != null && nativeReturn[0] != null && !((TF_Operation) nativeReturn[0]).isNull()) { + this.operation = new GraphOperation(this.graph, (TF_Operation) nativeReturn[0]); + this.position = (Integer) nativeReturn[1]; } } finally { reference.close(); @@ -571,11 +789,13 @@ public void remove() { } private static TF_Graph allocate() { - return TF_NewGraph(); + return TF_NewGraph(); } private static void delete(TF_Graph handle) { - if (handle == null || handle.isNull()) return; + if (handle == null || handle.isNull()) { + return; + } TF_DeleteGraph(handle); } @@ -598,11 +818,13 @@ private static Object[] nextOperation(TF_Graph handle, int position) { try (PointerScope scope = new PointerScope()) { SizeTPointer pos = new SizeTPointer(1).put(position); TF_Operation operation = TF_GraphNextOperation(handle, pos); - if (operation == null || operation.isNull()) return null; + if (operation == null || operation.isNull()) { + return null; + } Object[] handleAndPosition = new Object[2]; handleAndPosition[0] = operation; - handleAndPosition[1] = (int)pos.get(); + handleAndPosition[1] = (int) pos.get(); return handleAndPosition; } } @@ -642,12 +864,13 @@ private static GraphDef toGraphDef(TF_Graph handle) { } static void resolveOutputs(String type, TF_Operation[] srcOps, - int[] srcIndices, TF_Output dst, int n) { + int[] srcIndices, TF_Output dst, int n) { if (srcOps.length != n) { throw new IllegalArgumentException("expected " + n + ", got " + srcOps.length + " " + type + " Operations"); } if (srcIndices.length != n) { - throw new IllegalArgumentException("expected " + n + ", got " + srcIndices.length + " " + type + " Operation output indices"); + throw new IllegalArgumentException( + "expected " + n + ", got " + 
srcIndices.length + " " + type + " Operation output indices"); } for (int i = 0; i < n; ++i) { if (srcOps[i] == null || srcOps[i].isNull()) { @@ -731,16 +954,16 @@ private static Object[] whileLoop( TF_Operation[] condOutputHandles = new TF_Operation[1]; int[] condOutputIndices = new int[1]; for (int i = 0; i < ninputs; i++) { - condInputHandles[i] = condInputsOutput.position(i).oper(); - condInputIndices[i] = condInputsOutput.position(i).index(); + condInputHandles[i] = condInputsOutput.position(i).oper(); + condInputIndices[i] = condInputsOutput.position(i).index(); } condOutputHandles[0] = condOutputOutput.oper(); condOutputIndices[0] = condOutputOutput.index(); Object[] condOutputHandlesAndIndices = buildSubgraph(condGraphBuilder, params.cond_graph(), - condInputHandles, condInputIndices, - condOutputHandles, condOutputIndices); + condInputHandles, condInputIndices, + condOutputHandles, condOutputIndices); // build body subgraph TF_Output bodyInputsOutput = params.body_inputs(); @@ -750,29 +973,30 @@ private static Object[] whileLoop( TF_Operation[] bodyOutputHandles = new TF_Operation[ninputs]; int[] bodyOutputIndices = new int[ninputs]; for (int i = 0; i < ninputs; i++) { - bodyInputHandles[i] = bodyInputsOutput.position(i).oper(); - bodyInputIndices[i] = bodyInputsOutput.position(i).index(); - bodyOutputHandles[i] = bodyOutputsOutput.position(i).oper(); - bodyOutputIndices[i] = bodyOutputsOutput.position(i).index(); + bodyInputHandles[i] = bodyInputsOutput.position(i).oper(); + bodyInputIndices[i] = bodyInputsOutput.position(i).index(); + bodyOutputHandles[i] = bodyOutputsOutput.position(i).oper(); + bodyOutputIndices[i] = bodyOutputsOutput.position(i).index(); } Object[] bodyOutputHandlesAndIndices = buildSubgraph(bodyGraphBuilder, params.body_graph(), - bodyInputHandles, bodyInputIndices, - bodyOutputHandles, bodyOutputIndices); + bodyInputHandles, bodyInputIndices, + bodyOutputHandles, bodyOutputIndices); if (condOutputHandlesAndIndices == null || - bodyOutputHandlesAndIndices == null) + bodyOutputHandlesAndIndices == null) { return null; + } // set cond_output param to output of the conditional subgraph - condOutputOutput.oper((TF_Operation)condOutputHandlesAndIndices[0]) - .index((Integer)condOutputHandlesAndIndices[1]); + condOutputOutput.oper((TF_Operation) condOutputHandlesAndIndices[0]) + .index((Integer) condOutputHandlesAndIndices[1]); // set body_outputs param to outputs of the body subgraph for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) { - bodyOutputsOutput.position(i).oper((TF_Operation)bodyOutputHandlesAndIndices[i]) - .index((Integer)bodyOutputHandlesAndIndices[j]); + bodyOutputsOutput.position(i).oper((TF_Operation) bodyOutputHandlesAndIndices[i]) + .index((Integer) bodyOutputHandlesAndIndices[j]); } // set loop name param @@ -803,7 +1027,7 @@ private static SaverDef addVariableSaver(Graph graph) { List> varOutputs = new ArrayList<>(); List> varTypes = new ArrayList<>(); - for (Iterator iter = graph.operations(); iter.hasNext();) { + for (Iterator iter = graph.operations(); iter.hasNext(); ) { Operation op = iter.next(); if (op.type().equals("VariableV2")) { varNames.add(op.name()); @@ -824,8 +1048,8 @@ private static SaverDef addVariableSaver(Graph graph) { varSlices, varOutputs ); - Identity id = tf.withControlDependencies(Arrays.asList(saveFilename,saveVariables)) - .withName("control_dependency").identity(saveFilename); + Identity id = tf.withControlDependencies(Arrays.asList(saveFilename, saveVariables)) + 
.withName("control_dependency").identity(saveFilename); Restore restoreVariables = tf.train.restore( saveFilename, varNamesTensor, diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java index fbad92160a2..b97ef09a9e4 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java @@ -17,16 +17,30 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphGetTensorNumDims; import static org.tensorflow.internal.c_api.global.tensorflow.TF_GraphGetTensorShape; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationAllInputs; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationGetControlInputs; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationGetControlOutputs; import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationInputListLength; import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationName; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationNumControlInputs; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationNumControlOutputs; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationNumInputs; import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationNumOutputs; import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationOpType; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationOutputConsumers; import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationOutputListLength; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationOutputNumConsumers; import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationOutputType; +import java.util.ArrayList; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Set; import org.bytedeco.javacpp.Pointer; +import org.bytedeco.javacpp.PointerPointer; import org.bytedeco.javacpp.PointerScope; import org.tensorflow.internal.c_api.TF_Graph; +import org.tensorflow.internal.c_api.TF_Input; import org.tensorflow.internal.c_api.TF_Operation; import org.tensorflow.internal.c_api.TF_Output; import org.tensorflow.internal.c_api.TF_Status; @@ -170,6 +184,153 @@ Tensor tensor(int outputIdx) { throw new IllegalStateException("Graph tensors must be fetched by running a session"); } + /** + * Get the number of inputs to the op, not including control inputs. + */ + public int numInputs() { + return TF_OperationNumInputs(getUnsafeNativeHandle()); + } + + /** + * Get the op's inputs, not including control inputs. + */ + public List> inputs() { + try (PointerScope scope = new PointerScope()) { + int numInputs = numInputs(); + TF_Output handles = new TF_Output(numInputs); + + TF_OperationAllInputs(getUnsafeNativeHandle(), handles, numInputs); + + List> operands = new ArrayList<>(numInputs); + for (int i = 0; i < numInputs; i++) { + TF_Output atPos = handles.position(i); + TF_Operation op = atPos.oper(); + int index = atPos.index(); + String opName = TF_OperationName(op).getString(); + operands.add(graph.operation(opName).output(index)); + } + return operands; + } + } + + /** + * Get the number of ops that use this op's designated output as an input, not including control dependencies. 
+ * + * @param index the output to look for usages of + */ + public int numConsumers(int index) { + try (PointerScope scope = new PointerScope()) { + TF_Output output = new TF_Output().oper(getUnsafeNativeHandle()).index(index); + return TF_OperationOutputNumConsumers(output); + } + } + + /** + * Get the ops that use this op's designated output as an input, not including control dependencies. + * + * @param index the output to look for usages of + */ + public Set consumers(int index) { + try (PointerScope scope = new PointerScope()) { + TF_Output output = new TF_Output().oper(getUnsafeNativeHandle()).index(index); + int numConsumers = numConsumers(index); + TF_Input handles = new TF_Input(numConsumers); + + TF_OperationOutputConsumers(output, handles, numConsumers); + + Set operands = new LinkedHashSet<>(numConsumers); + for (int i = 0; i < numConsumers; i++) { + TF_Input atPos = handles.position(i); + TF_Operation op = atPos.oper(); + String opName = TF_OperationName(op).getString(); + operands.add(graph.operation(opName)); + } + return operands; + } + } + + /** + * Get the number of ops that use any of this op's outputs as an input, not including control dependencies. + */ + public int numConsumers() { + int all = 0; + for (int i = 0; i < numOutputs(); i++) { + all += numConsumers(i); + } + return all; + } + + + /** + * Get the ops that use any of this op's outputs as an input, not including control dependencies. + */ + public Set consumers() { + Set all = new LinkedHashSet<>(); + for (int i = 0; i < numOutputs(); i++) { + all.addAll(consumers(i)); + } + return all; + } + + /** + * Get the number of control inputs for this op. + */ + public int numControlInputs() { + try (PointerScope scope = new PointerScope()) { + return TF_OperationNumControlInputs(getUnsafeNativeHandle()); + } + } + + /** + * Get the control inputs of this op. + */ + public Set controlInputs() { + try (PointerScope scope = new PointerScope()) { + int numInputs = numControlInputs(); + PointerPointer handles = new PointerPointer<>(numInputs); + + TF_OperationGetControlInputs(getUnsafeNativeHandle(), handles, numInputs); + + Set operands = new LinkedHashSet<>(numInputs); + for (int i = 0; i < numInputs; i++) { + TF_Operation op = handles.get(TF_Operation.class, i); + String opName = TF_OperationName(op).getString(); + operands.add(graph.operation(opName)); + } + return operands; + } + } + + /** + * Get the number of ops with this op as a control dependency. + */ + public int numControlConsumers() { + try (PointerScope scope = new PointerScope()) { + return TF_OperationNumControlOutputs(getUnsafeNativeHandle()); + } + } + + /** + * Get the ops with this op as a control dependency. 
+ */ + public Set controlConsumers() { + try (PointerScope scope = new PointerScope()) { + int numConsumers = numControlConsumers(); + PointerPointer handles = new PointerPointer<>(numConsumers); + + TF_OperationGetControlOutputs(getUnsafeNativeHandle(), handles, numConsumers); + + Set operands = new LinkedHashSet<>(numConsumers); + for (int i = 0; i < numConsumers; i++) { + TF_Operation op = handles.get(TF_Operation.class, i); + String opName = TF_OperationName(op).getString(); + operands.add(graph.operation(opName)); + } + return operands; + } + } + + TF_Operation getUnsafeNativeHandle() { return unsafeNativeHandle; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index b67f4a611e6..8c6a28ba4c8 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -15,7 +15,16 @@ package org.tensorflow; +import static org.tensorflow.Graph.resolveOutputs; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_CloseSession; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteSession; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewSession; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_SessionRun; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig; + import com.google.protobuf.InvalidProtocolBufferException; +import java.util.ArrayList; +import java.util.List; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerPointer; @@ -33,15 +42,9 @@ import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.RunMetadata; import org.tensorflow.proto.framework.RunOptions; - -import java.util.ArrayList; -import java.util.List; import org.tensorflow.proto.util.SaverDef; import org.tensorflow.types.TString; -import static org.tensorflow.Graph.resolveOutputs; -import static org.tensorflow.internal.c_api.global.tensorflow.*; - /** * Driver for {@link Graph} execution. * @@ -84,11 +87,9 @@ public Session(Graph g) { * Construct a new session with the associated {@link Graph} and configuration options. * * @param g The {@link Graph} the created Session will operate on. - * @param config Configuration parameters for the session specified as a ConfigProto - * protocol buffer. - * @throws IllegalArgumentException if the config is not a valid serialization of the ConfigProto - * protocol buffer. + * @param config Configuration parameters for the session specified as a ConfigProto + * protocol buffer. + * @throws IllegalArgumentException if the config is not a valid serialization of the ConfigProto protocol buffer. */ public Session(Graph g, ConfigProto config) { graph = g; @@ -101,7 +102,9 @@ public Session(Graph g, ConfigProto config) { } } - /** Wrap an existing session with the associated {@link Graph}. */ + /** + * Wrap an existing session with the associated {@link Graph}. + */ Session(Graph g, TF_Session nativeHandle) { graph = g; this.nativeHandle = nativeHandle; @@ -139,32 +142,31 @@ public void close() { * Run {@link Operation}s and evaluate {@link Tensor Tensors}. * *
<p>
A Runner runs the necessary graph fragments to execute every {@link Operation} required to - * evaluate the {@link Tensor Tensors} to fetch. The {@link #feed(String,int,Tensor)} call allows - * callers to override the value of {@link Tensor Tensors} in the graph by substituting the - * provided {@link Tensor Tensors} for the outputs of the operations provided to {@link - * #feed(String,int,Tensor)}. + * evaluate the {@link Tensor Tensors} to fetch. The {@link #feed(String, int, Tensor)} call allows callers to + * override the value of {@link Tensor Tensors} in the graph by substituting the provided {@link Tensor Tensors} for + * the outputs of the operations provided to {@link #feed(String, int, Tensor)}. */ public final class Runner { /** * Avoid evaluating {@code operation} and substitute {@code t} for the value it produces. * - * @param operation Is either the string name of the operation, in which case this method is a - * shorthand for {@code feed(operation, 0)}, or it is a string of the form - * operation_name:output_index , in which case this method acts like {@code - * feed(operation_name, output_index)}. These colon-separated names are commonly used in the - * {@code SignatureDef} protocol buffer messages that are included in {@link - * SavedModelBundle#metaGraphDef()}. + * @param operation Is either the string name of the operation, in which case this method is a shorthand for {@code + * feed(operation, 0)}, or it is a string of the form + * operation_name:output_index , in which case this method acts like {@code + * feed(operation_name, output_index)}. These colon-separated names are commonly used in the {@code SignatureDef} + * protocol buffer messages that are included in {@link SavedModelBundle#metaGraphDef()}. * @param t the tensor substituting the operation * @return this session runner + * @throws IllegalArgumentException if no output exists with the provided name */ public Runner feed(String operation, Tensor t) { - return feed(parseOutput(operation), t); + return feed(graph.outputOrThrow(operation), t); } /** - * Avoid evaluating the {@code index}-th output of {@code operation} by substituting {@code t} - * for the value it produces. + * Avoid evaluating the {@code index}-th output of {@code operation} by substituting {@code t} for the value it + * produces. * *
<p>
Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which * one {@code t} is being provided for. @@ -172,19 +174,18 @@ public Runner feed(String operation, Tensor t) { * @param operation the string name of the operation * @param t the tensor substituting the operation * @return this session runner + * @throws IllegalArgumentException if no operation exists with the provided name + * @throws IndexOutOfBoundsException if the operation has no output with the given index */ public Runner feed(String operation, int index, Tensor t) { - Operation op = operationByName(operation); - if (op != null) { - inputs.add(op.output(index)); - inputTensors.add(t); - } + Operation op = graph.operationOrThrow(operation); + inputs.add(op.output(index)); + inputTensors.add(t); return this; } /** - * Use {@code t} instead of the Tensor referred to by executing the operation referred to by - * {@code operand}. + * Use {@code t} instead of the Tensor referred to by executing the operation referred to by {@code operand}. * * @param operand the node in the graph representing the operation to substitute * @param t the tensor substituting the operation @@ -199,16 +200,16 @@ public Runner feed(Operand operand, Tensor t) { /** * Make {@link #run()} return the output of {@code operation}. * - * @param operation Is either the string name of the operation, in which case this method is a - * shorthand for {@code fetch(operation, 0)}, or it is a string of the form - * operation_name:output_index , in which case this method acts like {@code - * fetch(operation_name, output_index)}. These colon-separated names are commonly used in - * the {@code SignatureDef} protocol buffer messages that are included in {@link - * SavedModelBundle#metaGraphDef()}. + * @param operation Is either the string name of the operation, in which case this method is a shorthand for {@code + * fetch(operation, 0)}, or it is a string of the form + * operation_name:output_index , in which case this method acts like {@code + * fetch(operation_name, output_index)}. These colon-separated names are commonly used in the {@code SignatureDef} + * protocol buffer messages that are included in {@link SavedModelBundle#metaGraphDef()}. * @return this session runner + * @throws IllegalArgumentException if no output exists with the provided name */ public Runner fetch(String operation) { - return fetch(parseOutput(operation)); + return fetch(graph.outputOrThrow(operation)); } /** @@ -219,12 +220,12 @@ public Runner fetch(String operation) { * * @param operation the string name of the operation * @return this session runner + * @throws IllegalArgumentException if no operation exists with the provided name + * @throws IndexOutOfBoundsException if the operation has no output with the given index */ public Runner fetch(String operation, int index) { - Operation op = operationByName(operation); - if (op != null) { - outputs.add(op.output(index)); - } + Operation op = graph.operationOrThrow(operation); + outputs.add(op.output(index)); return this; } @@ -250,23 +251,20 @@ public Runner fetch(Operand operand) { } /** - * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor - * Tensors}. + * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor Tensors}. 
* * @param operation the string name of the operation to execute * @return this session runner + * @throws IllegalArgumentException if no operation exists with the provided name */ public Runner addTarget(String operation) { - GraphOperation op = operationByName(operation); - if (op != null) { - targets.add(op); - } + GraphOperation op = graph.operationOrThrow(operation); + targets.add(op); return this; } /** - * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor - * Tensors}. + * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor Tensors}. * * @param operation the operation to execute * @return this session runner @@ -297,8 +295,7 @@ public Runner addTarget(Op op) { * Set options (typically for debugging) for this run. * *
<p>
The options are presented as a RunOptions - * protocol buffer. + * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunOptions protocol buffer. * * @param options a {@code RunOptions} proto * @return this session runner @@ -312,13 +309,11 @@ public Runner setOptions(RunOptions options) { * Execute the graph fragments necessary to compute all requested fetches. * *
<p>
WARNING: The caller assumes ownership of all returned {@link Tensor Tensors}, i.e., - * the caller must call {@link Tensor#close} on all elements of the returned list to free up - * resources. + * the caller must call {@link Tensor#close} on all elements of the returned list to free up resources. * *
<p>
TODO(ashankar): Reconsider the return type here. Two things in particular: (a) Make it - * easier for the caller to cleanup (perhaps returning something like AutoCloseableList in - * SessionTest.java), and (b) Evaluate whether the return value should be a list, or maybe a - * {@code Map}? + * easier for the caller to cleanup (perhaps returning something like AutoCloseableList in SessionTest.java), and + * (b) Evaluate whether the return value should be a list, or maybe a {@code Map}? * *
<p>
TODO(andrewmyers): It would also be good if whatever is returned here made it easier to * extract output tensors in a type-safe way. @@ -333,8 +328,7 @@ public List run() { * Execute graph fragments to compute requested fetches and return metadata about the run. * *
<p>
This is exactly like {@link #run()}, but in addition to the requested Tensors, also - * returns metadata about the graph execution in the form of a RunMetadata + * returns metadata about the graph execution in the form of a RunMetadata * protocol buffer. * * @return list of resulting tensors fetched by this session runner, with execution metadata @@ -405,6 +399,7 @@ private Run runHelper(boolean wantMetadata) { } private class Reference implements AutoCloseable { + public Reference() { synchronized (nativeHandleLock) { if (nativeHandle == null || nativeHandle.isNull()) { @@ -427,29 +422,6 @@ public void close() { } } - private GraphOperation operationByName(String opName) { - GraphOperation op = graph.operation(opName); - if (op == null) { - throw new IllegalArgumentException("No Operation named [" + opName + "] in the Graph"); - } - return op; - } - - @SuppressWarnings("rawtypes") - private Output parseOutput(String opName) { - int colon = opName.lastIndexOf(':'); - if (colon == -1 || colon == opName.length() - 1) { - return new Output(operationByName(opName), 0); - } - try { - String op = opName.substring(0, colon); - int index = Integer.parseInt(opName.substring(colon + 1)); - return new Output(operationByName(op), index); - } catch (NumberFormatException e) { - return new Output(operationByName(opName), 0); - } - } - private final ArrayList> inputs = new ArrayList<>(); private final ArrayList inputTensors = new ArrayList<>(); private final ArrayList> outputs = new ArrayList<>(); @@ -457,7 +429,9 @@ private Output parseOutput(String opName) { private RunOptions runOptions = null; } - /** Create a Runner to execute graph operations and evaluate Tensors. */ + /** + * Create a Runner to execute graph operations and evaluate Tensors. + */ public Runner runner() { return new Runner(); } @@ -471,12 +445,7 @@ public Runner runner() { * @throws IllegalArgumentException if no operation of that name can be found in the graph */ public void run(String opName) { - Operation operation = graph.operation(opName); - if (operation == null) { - throw new IllegalArgumentException( - "Operation named '" + opName + "' cannot be found in the graph"); - } - runner().addTarget(operation).run(); + runner().addTarget(opName).run(); } /** @@ -495,9 +464,8 @@ public void run(Op op) { * Execute the graph's initializers. * *
<p>
This method is equivalent to {@code session.run(Ops.create(session.graph).init())}. - * */ - public void runInit(){ + public void runInit() { Runner runner = runner(); graph.initializers().forEach(runner::addTarget); runner.run(); @@ -519,8 +487,8 @@ public void runInit(){ public void save(String prefix) { SaverDef saverDef = graph.saverDef(); runner().addTarget(saverDef.getSaveTensorName()) - .feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix)) - .run(); + .feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix)) + .run(); } /** @@ -539,8 +507,8 @@ public void save(String prefix) { public void restore(String prefix) { SaverDef saverDef = graph.saverDef(); runner().addTarget(saverDef.getRestoreOpName()) - .feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix)) - .run(); + .feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix)) + .run(); } /** @@ -549,15 +517,17 @@ public void restore(String prefix) { *
<p>
See {@link Runner#runAndFetchMetadata()} */ public static final class Run { - /** Tensors from requested fetches. */ + + /** + * Tensors from requested fetches. + */ public List outputs; /** * Metadata about the run. * *
<p>
A RunMetadata - * protocol buffer. + * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata protocol buffer. */ public RunMetadata metadata; } @@ -633,20 +603,19 @@ private static void delete(TF_Session handle) { * @param runOptions A RunOptions protocol buffer, or null * @param inputOpHandles (see inputOpIndices) * @param inputOpIndices (see inputTensorHandles) - * @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values - * that are being "fed" (do not need to be computed) during graph execution. - * inputTensorHandles[i] (which corresponds to a Tensor.nativeHandle) is considered to be the - * inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus, it is required that - * inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length. + * @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values that are being "fed" + * (do not need to be computed) during graph execution. inputTensorHandles[i] (which corresponds to a + * Tensor.nativeHandle) is considered to be the inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus, + * it is required that inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length. * @param outputOpHandles (see outputOpIndices) - * @param outputOpIndices together with outputOpHandles identifies the set of values that should - * be computed. The outputOpIndices[i]-th output of the Operation outputOpHandles[i], It is - * required that outputOpHandles.length == outputOpIndices.length. - * @param targetOpHandles is the set of Operations in the graph that are to be executed but whose - * output will not be returned + * @param outputOpIndices together with outputOpHandles identifies the set of values that should be computed. The + * outputOpIndices[i]-th output of the Operation outputOpHandles[i], It is required that outputOpHandles.length == + * outputOpIndices.length. + * @param targetOpHandles is the set of Operations in the graph that are to be executed but whose output will not be + * returned * @param wantRunMetadata indicates whether metadata about this execution should be returned. - * @param outputTensors will be filled in with tensors to the outputs requested. It is required - * that outputs.length == outputOpHandles.length. + * @param outputTensors will be filled in with tensors to the outputs requested. It is required that outputs.length == + * outputOpHandles.length. * @return if wantRunMetadata is true, a RunMetadata protocol buffer, false otherwise. 
*/ private static RunMetadata run( diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java index 497ee5f2d46..0e3d30acc54 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java @@ -18,6 +18,7 @@ import java.nio.charset.Charset; import org.tensorflow.Operand; import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; import org.tensorflow.Output; import org.tensorflow.Tensor; import org.tensorflow.ndarray.BooleanNdArray; @@ -1355,13 +1356,15 @@ public static Constant tensorOfSameType(Scope scope, Oper */ @Endpoint(name = "constantOf") public static Constant create(Scope scope, T tensor) { - return new Constant<>( - scope - .env() - .opBuilder("Const", scope.makeOpName("Const")) - .setAttr("value", tensor) - .setAttr("dtype", tensor.dataType()) - .build()); + OperationBuilder builder = scope + .env() + .opBuilder(OP_NAME, scope.makeOpName(OP_NAME)) + .setAttr("value", tensor) + .setAttr("dtype", tensor.dataType()); + + scope.apply(builder); + + return new Constant<>(builder.build()); } @Override @@ -1375,4 +1378,6 @@ private Constant(Operation operation) { } private final Output output; + + public static final String OP_NAME = "Const"; } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationTest.java index b164c129745..ac9694adabd 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationTest.java @@ -22,11 +22,14 @@ import static org.junit.jupiter.api.Assertions.fail; import java.util.Arrays; +import java.util.Collections; import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.Set; import org.junit.jupiter.api.Test; import org.tensorflow.exceptions.TFInvalidArgumentException; import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; /** Unit tests for {@link org.tensorflow.GraphOperation}. 
*/ @@ -189,4 +192,68 @@ public void outputTensorNotSupported() { } } } + + @Test + public void inputs() { + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + + Operand a = tf.constant(1f); + Operand b = tf.constant(2f); + Operand c = tf.math.add(a, b); + + GraphOperation op = (GraphOperation) c.op(); + + assertEquals(2, op.numInputs()); + assertEquals(Arrays.asList(a.asOutput(), b.asOutput()), op.inputs()); + } + } + + @Test + public void consumers() { + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + + Operand a = tf.constant(1f); + Operand b = tf.constant(2f); + Operand c = tf.math.add(a, b); + + GraphOperation op = (GraphOperation) a.op(); + + assertEquals(1, op.numConsumers()); + assertEquals(new LinkedHashSet<>(Collections.singletonList(c.op())), op.consumers()); + } + } + + @Test + public void controlInputs() { + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + + Operand a = tf.constant(1f); + Operand b = tf.constant(2f); + Operand c = tf.withControlDependencies(Arrays.asList(a, b)).constant(3f); + + GraphOperation op = (GraphOperation) c.op(); + + assertEquals(2, op.numControlInputs()); + assertEquals(new LinkedHashSet<>(Arrays.asList(a.op(), b.op())), op.controlInputs()); + } + } + + @Test + public void controlConsumers() { + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + + Operand a = tf.constant(1f); + Operand b = tf.constant(2f); + Operand c = tf.withControlDependencies(Arrays.asList(a, b)).constant(3f); + + GraphOperation op = (GraphOperation) a.op(); + + assertEquals(1, op.numControlConsumers()); + assertEquals(new LinkedHashSet<>(Collections.singletonList(c.op())), op.controlConsumers()); + } + } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java index d8ffc1a475b..dddd5867d33 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java @@ -19,13 +19,18 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; import java.util.Arrays; +import java.util.Collections; import java.util.HashSet; import java.util.Iterator; +import java.util.LinkedHashSet; +import java.util.Set; import org.junit.jupiter.api.Test; import org.tensorflow.exceptions.TFInvalidArgumentException; import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Constant; import org.tensorflow.op.linalg.MatMul; import org.tensorflow.proto.framework.DataType; import org.tensorflow.proto.framework.GraphDef; @@ -113,6 +118,46 @@ public void iterateOverOperations() { } } + @Test + public void completeSubgraph() { + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + Operand control = tf.constant(0); + Operand a = tf.withControlDependencies(Collections.singletonList(control)).constant(1); + Operand b = tf.constant(2); + Operand c = tf.constant(3); + + Operand d = tf.math.add(a, b); + Operand output = tf.math.mul(d, c); + + Set subgraph = g + .completeSubgraph(new LinkedHashSet<>(Arrays.asList(control, a, b, c)), Collections.singleton(output)); + + assertEquals(new LinkedHashSet<>(Arrays.asList(control.op(), a.op(), b.op(), c.op(), d.op(), output.op())), + subgraph); + } + } + + @Test + public void completeSubgraphWithConstants() { + try (Graph g = new 
Graph()) { + Ops tf = Ops.create(g); + Operand control = tf.constant(0); + Operand a = tf.withControlDependencies(Collections.singletonList(control)).constant(1); + Operand b = tf.constant(2); + Operand c = tf.constant(3); + + Operand d = tf.math.add(a, b); + Operand output = tf.math.mul(d, c); + + Set subgraph = g + .completeSubgraph(Collections.emptySet(), Collections.singleton(output)); + + assertEquals(new LinkedHashSet<>(Arrays.asList(control.op(), a.op(), b.op(), c.op(), d.op(), output.op())), + subgraph); + } + } + @Test public void failImportOnInvalidGraphDefs() { try (Graph g = new Graph()) {