Skip to content

Commit b90db9d

Browse files
committed
Add ignored gradient test
Signed-off-by: Ryan Nett <[email protected]>
1 parent d5bd64e commit b90db9d

File tree

3 files changed

+56
-13
lines changed

3 files changed

+56
-13
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java

-9
Original file line numberDiff line numberDiff line change
@@ -406,15 +406,6 @@ TF_Function nativeHandle() {
406406
return nativeFunction.getNativeHandle();
407407
}
408408

409-
/**
410-
* Get the native handle of the function's gradient, so that it can be attached to a Graph. Not implemented yet.
411-
*
412-
* TODO implement
413-
*/
414-
TF_Function gradNativeHandle() {
415-
return null;
416-
}
417-
418409
private final Signature signature;
419410
private final NativeFunction nativeFunction;
420411
private final PointerScope scope;

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -331,14 +331,14 @@ public GraphOperationBuilder opBuilder(String type, String name) {
331331
public void attachFunction(ConcreteFunction function) {
332332
try (Reference ref = ref();
333333
PointerScope scope = new PointerScope()) {
334-
attachNativeFunction(ref.nativeHandle(), function.nativeHandle(), function.gradNativeHandle());
335-
function.getDependencies().forEach(x -> attachNativeFunction(ref.nativeHandle(), x, null));
334+
attachNativeFunction(ref.nativeHandle(), function.nativeHandle());
335+
function.getDependencies().forEach(x -> attachNativeFunction(ref.nativeHandle(), x));
336336
}
337337
}
338338

339-
private void attachNativeFunction(TF_Graph graph, TF_Function fn, TF_Function grad) {
339+
private void attachNativeFunction(TF_Graph graph, TF_Function fn) {
340340
TF_Status status = TF_Status.newStatus();
341-
TF_GraphCopyFunction(graph, fn, grad, status);
341+
TF_GraphCopyFunction(graph, fn, null, status);
342342
status.throwExceptionIfNotOK();
343343
}
344344

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java

+52
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@
1717
import static org.junit.jupiter.api.Assertions.assertEquals;
1818
import static org.junit.jupiter.api.Assertions.assertNotNull;
1919

20+
import java.util.Arrays;
2021
import org.junit.jupiter.api.Test;
2122
import org.tensorflow.op.Ops;
2223
import org.tensorflow.op.core.Init;
2324
import org.tensorflow.op.core.Placeholder;
2425
import org.tensorflow.op.math.Add;
2526
import org.tensorflow.op.math.Sub;
27+
import org.tensorflow.proto.framework.DataType;
2628
import org.tensorflow.types.TFloat32;
2729

2830
public class ConcreteFunctionTest {
@@ -134,4 +136,54 @@ public void testNestedFunctionGraph() {
134136
}
135137
}
136138
}
139+
140+
private static Signature square(Ops tf) {
141+
Placeholder<TFloat32> input = tf.placeholder(TFloat32.class);
142+
Operand<TFloat32> output = tf.math.square(input);
143+
return Signature.builder().methodName("square").key("square").input("x", input).output("y", output).build();
144+
}
145+
146+
// call op gradients are not defined in c++
147+
// @Test
148+
public void testGradientsGraph() {
149+
try (Graph g = new Graph();
150+
ConcreteFunction square = ConcreteFunction.create(ConcreteFunctionTest::square);
151+
Session s = new Session(g)) {
152+
Ops tf = Ops.create(g);
153+
154+
Output<TFloat32> x1 = tf.placeholder(TFloat32.class).output();
155+
Output<TFloat32> x2 = tf.placeholder(TFloat32.class).output();
156+
Output<TFloat32> y0 = (Output<TFloat32>) square.call(tf, x1);
157+
Output<TFloat32> y1 = (Output<TFloat32>) square.call(tf, y0);
158+
Output<TFloat32> y2 = tf.math.addN(Arrays.asList(y0, x2)).sum();
159+
160+
Output<?>[] grads0 = g.addGradients(y1, new Output[]{x1});
161+
assertNotNull(grads0);
162+
assertEquals(1, grads0.length);
163+
assertEquals(DataType.DT_FLOAT, grads0[0].dataType());
164+
165+
Output<?>[] grads1 = g.addGradients(y2, new Output[]{x1, x2});
166+
assertNotNull(grads1);
167+
assertEquals(2, grads1.length);
168+
assertEquals(DataType.DT_FLOAT, grads1[0].dataType());
169+
assertEquals(DataType.DT_FLOAT, grads1[1].dataType());
170+
171+
try (TFloat32 c1 = TFloat32.scalarOf(3.0f);
172+
TFloat32 c2 = TFloat32.scalarOf(2.0f);
173+
AutoCloseableList<Tensor> outputs = new AutoCloseableList<>(
174+
s.runner()
175+
.feed(x1, c1)
176+
.feed(x2, c2)
177+
.fetch(grads0[0])
178+
.fetch(grads1[0])
179+
.fetch(grads1[1])
180+
.run())) {
181+
182+
assertEquals(3, outputs.size());
183+
assertEquals(108.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f);
184+
assertEquals(6.0f, ((TFloat32) outputs.get(1)).getFloat(), 0.0f);
185+
assertEquals(1.0f, ((TFloat32) outputs.get(2)).getFloat(), 0.0f);
186+
}
187+
}
188+
}
137189
}

0 commit comments

Comments
 (0)