Skip to content

Commit 2247ea7

Browse files
committed
Add session function test, Signature.builder with name
Signed-off-by: Ryan Nett <[email protected]>
1 parent 3e08365 commit 2247ea7

File tree

3 files changed

+62
-43
lines changed

3 files changed

+62
-43
lines changed

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,8 +1120,8 @@ public Bucketize bucketize(Operand<? extends TNumber> input, List<Float> boundar
11201120
}
11211121

11221122
/**
1123-
* Calls the function in an execution environment, adding its graph as a function if it isn't already present. Only
1124-
* works for functions with a single input and output.
1123+
* Calls the function in an execution environment, adding its graph as a function if it isn't
1124+
* already present. Only works for functions with a single input and output.
11251125
*
11261126
* @param argument the argument to the call
11271127
* @return the output of the function
@@ -1132,8 +1132,8 @@ public Operand<?> call(ConcreteFunction function, Operand<?> argument) {
11321132
}
11331133

11341134
/**
1135-
* Calls the function in an execution environment, adding its graph as a function if it isn't already present. The
1136-
* inputs and outputs are keyed by the names set in the {@code Signature}.
1135+
* Calls the function in an execution environment, adding its graph as a function if it isn't
1136+
* already present. The inputs and outputs are keyed by the names set in the {@code Signature}.
11371137
*
11381138
* @param arguments the arguments to the call
11391139
* @return the outputs of the function

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,13 @@ public static Builder builder() {
179179
return new Builder();
180180
}
181181

182+
/**
183+
* Returns a new builder for creating a signature, with the methodName and key set to {@code name}
184+
*/
185+
public static Builder builder(String name) {
186+
return new Builder().methodName(name).key(name);
187+
}
188+
182189
/** Return the key of this signature */
183190
public String key() {
184191
return key;

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java

Lines changed: 51 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,33 @@
4343
import org.tensorflow.types.TFloat32;
4444
import org.tensorflow.types.TInt32;
4545

46-
/**
47-
* Unit tests for {@link org.tensorflow.Session}.
48-
*/
46+
/** Unit tests for {@link org.tensorflow.Session}. */
4947
public class SessionTest {
5048

49+
@Test
50+
public void runUsingFunction() {
51+
try (Graph g = new Graph();
52+
Session s = new Session(g)) {
53+
Ops tf = Ops.create(g);
54+
transpose_A_times_X(tf, new int[][] {{2}, {3}});
55+
Signature sig =
56+
Signature.builder("sess").input("X", g.output("X")).output("Y", g.output("Y")).build();
57+
SessionFunction func = s.function(sig);
58+
59+
try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}));
60+
TInt32 y = (TInt32) func.call(x)) {
61+
assertEquals(31, y.getInt(0, 0));
62+
}
63+
}
64+
}
65+
5166
@Test
5267
public void runUsingOperationNames() {
5368
try (Graph g = new Graph();
5469
Session s = new Session(g)) {
5570
Ops tf = Ops.create(g);
56-
transpose_A_times_X(tf, new int[][]{{2}, {3}});
57-
try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][]{{5}, {7}}));
71+
transpose_A_times_X(tf, new int[][] {{2}, {3}});
72+
try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}));
5873
AutoCloseableList<Tensor> outputs =
5974
new AutoCloseableList<>(s.runner().feed("X", x).fetch("Y").run())) {
6075
assertEquals(1, outputs.size());
@@ -68,10 +83,10 @@ public void runUsingOperationHandles() {
6883
try (Graph g = new Graph();
6984
Session s = new Session(g)) {
7085
Ops tf = Ops.create(g);
71-
transpose_A_times_X(tf, new int[][]{{2}, {3}});
86+
transpose_A_times_X(tf, new int[][] {{2}, {3}});
7287
Output<TInt32> feed = g.operation("X").output(0);
7388
Output<TInt32> fetch = g.operation("Y").output(0);
74-
try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][]{{5}, {7}}));
89+
try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}));
7590
AutoCloseableList<Tensor> outputs =
7691
new AutoCloseableList<>(s.runner().feed(feed, x).fetch(fetch).run())) {
7792
assertEquals(1, outputs.size());
@@ -95,12 +110,9 @@ public void runUsingColonSeparatedNames() {
95110
}
96111
// Feed using colon separated names.
97112
try (TInt32 fed = TInt32.vectorOf(4, 3, 2, 1);
98-
TInt32 fetched = (TInt32) s.runner()
99-
.feed("Split:0", fed)
100-
.feed("Split:1", fed)
101-
.fetch("Add")
102-
.run()
103-
.get(0)) {
113+
TInt32 fetched =
114+
(TInt32)
115+
s.runner().feed("Split:0", fed).feed("Split:1", fed).fetch("Add").run().get(0)) {
104116
assertEquals(NdArrays.vectorOf(8, 6, 4, 2), fetched);
105117
}
106118
}
@@ -111,13 +123,14 @@ public void runWithMetadata() {
111123
try (Graph g = new Graph();
112124
Session s = new Session(g)) {
113125
Ops tf = Ops.create(g);
114-
transpose_A_times_X(tf, new int[][]{{2}, {3}});
115-
try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][]{{5}, {7}}))) {
116-
Session.Run result = s.runner()
117-
.feed("X", x)
118-
.fetch("Y")
119-
.setOptions(fullTraceRunOptions())
120-
.runAndFetchMetadata();
126+
transpose_A_times_X(tf, new int[][] {{2}, {3}});
127+
try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}))) {
128+
Session.Run result =
129+
s.runner()
130+
.feed("X", x)
131+
.fetch("Y")
132+
.setOptions(fullTraceRunOptions())
133+
.runAndFetchMetadata();
121134
// Sanity check on outputs.
122135
AutoCloseableList<Tensor> outputs = new AutoCloseableList<>(result.outputs);
123136
assertEquals(1, outputs.size());
@@ -163,8 +176,7 @@ public void failOnUseAfterClose() {
163176
@Test
164177
public void createWithConfigProto() {
165178
try (Graph g = new Graph();
166-
Session s = new Session(g, singleThreadConfigProto())) {
167-
}
179+
Session s = new Session(g, singleThreadConfigProto())) {}
168180
}
169181

170182
@Test
@@ -219,10 +231,12 @@ public void saveAndRestore() throws IOException {
219231
Path testFolder = Files.createTempDirectory("tf-session-save-restore-test");
220232
try (Graph g = new Graph()) {
221233
Ops tf = Ops.create(g);
222-
Variable<TFloat32> x = tf.withName("x")
223-
.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class));
224-
Variable<TFloat32> y = tf.withName("y")
225-
.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class));
234+
Variable<TFloat32> x =
235+
tf.withName("x")
236+
.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class));
237+
Variable<TFloat32> y =
238+
tf.withName("y")
239+
.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class));
226240
Init init = tf.init();
227241

228242
try (Session s = new Session(g)) {
@@ -234,9 +248,10 @@ public void saveAndRestore() throws IOException {
234248
restoredGraph.importGraphDef(graphDef);
235249
try (Session restoredSession = new Session(restoredGraph)) {
236250
restoredSession.restore(testFolder.resolve("checkpoint").toString());
237-
try (AutoCloseableList<Tensor> oldList = new AutoCloseableList<>(s.runner().fetch("x").fetch("y").run());
238-
AutoCloseableList<Tensor> newList = new AutoCloseableList<>(
239-
restoredSession.runner().fetch("x").fetch("y").run())) {
251+
try (AutoCloseableList<Tensor> oldList =
252+
new AutoCloseableList<>(s.runner().fetch("x").fetch("y").run());
253+
AutoCloseableList<Tensor> newList =
254+
new AutoCloseableList<>(restoredSession.runner().fetch("x").fetch("y").run())) {
240255
assertEquals(oldList.get(0), newList.get(0));
241256
assertEquals(oldList.get(1), newList.get(1));
242257
}
@@ -265,7 +280,6 @@ public static void testFetchVariable() {
265280
try (TInt32 value = (TInt32) s.runner().addTarget(assign).fetch(variable).run().get(0)) {
266281
assertEquals(2, value.getInt());
267282
}
268-
269283
}
270284
}
271285

@@ -295,14 +309,11 @@ public static void testFetchVariableReusingRead() {
295309
}
296310

297311
assertEquals(0, numOperations(g) - ops);
298-
299312
}
300313
}
301314

302315
private static RunOptions fullTraceRunOptions() {
303-
return RunOptions.newBuilder()
304-
.setTraceLevel(RunOptions.TraceLevel.FULL_TRACE)
305-
.build();
316+
return RunOptions.newBuilder().setTraceLevel(RunOptions.TraceLevel.FULL_TRACE).build();
306317
}
307318

308319
private static ConfigProto singleThreadConfigProto() {
@@ -313,10 +324,11 @@ private static ConfigProto singleThreadConfigProto() {
313324
}
314325

315326
private static void transpose_A_times_X(Ops tf, int[][] a) {
316-
tf.withName("Y").linalg.matMul(
317-
tf.withName("A").constant(a),
318-
tf.withName("X").placeholder(TInt32.class),
319-
MatMul.transposeA(true).transposeB(false)
320-
);
327+
tf.withName("Y")
328+
.linalg
329+
.matMul(
330+
tf.withName("A").constant(a),
331+
tf.withName("X").placeholder(TInt32.class),
332+
MatMul.transposeA(true).transposeB(false));
321333
}
322334
}

0 commit comments

Comments
 (0)