43
43
import org .tensorflow .types .TFloat32 ;
44
44
import org .tensorflow .types .TInt32 ;
45
45
46
- /**
47
- * Unit tests for {@link org.tensorflow.Session}.
48
- */
46
+ /** Unit tests for {@link org.tensorflow.Session}. */
49
47
public class SessionTest {
50
48
49
+ @ Test
50
+ public void runUsingFunction () {
51
+ try (Graph g = new Graph ();
52
+ Session s = new Session (g )) {
53
+ Ops tf = Ops .create (g );
54
+ transpose_A_times_X (tf , new int [][] {{2 }, {3 }});
55
+ Signature sig =
56
+ Signature .builder ("sess" ).input ("X" , g .output ("X" )).output ("Y" , g .output ("Y" )).build ();
57
+ SessionFunction func = s .function (sig );
58
+
59
+ try (TInt32 x = TInt32 .tensorOf (StdArrays .ndCopyOf (new int [][] {{5 }, {7 }}));
60
+ TInt32 y = (TInt32 ) func .call (x )) {
61
+ assertEquals (31 , y .getInt (0 , 0 ));
62
+ }
63
+ }
64
+ }
65
+
51
66
@ Test
52
67
public void runUsingOperationNames () {
53
68
try (Graph g = new Graph ();
54
69
Session s = new Session (g )) {
55
70
Ops tf = Ops .create (g );
56
- transpose_A_times_X (tf , new int [][]{{2 }, {3 }});
57
- try (TInt32 x = TInt32 .tensorOf (StdArrays .ndCopyOf (new int [][]{{5 }, {7 }}));
71
+ transpose_A_times_X (tf , new int [][] {{2 }, {3 }});
72
+ try (TInt32 x = TInt32 .tensorOf (StdArrays .ndCopyOf (new int [][] {{5 }, {7 }}));
58
73
AutoCloseableList <Tensor > outputs =
59
74
new AutoCloseableList <>(s .runner ().feed ("X" , x ).fetch ("Y" ).run ())) {
60
75
assertEquals (1 , outputs .size ());
@@ -68,10 +83,10 @@ public void runUsingOperationHandles() {
68
83
try (Graph g = new Graph ();
69
84
Session s = new Session (g )) {
70
85
Ops tf = Ops .create (g );
71
- transpose_A_times_X (tf , new int [][]{{2 }, {3 }});
86
+ transpose_A_times_X (tf , new int [][] {{2 }, {3 }});
72
87
Output <TInt32 > feed = g .operation ("X" ).output (0 );
73
88
Output <TInt32 > fetch = g .operation ("Y" ).output (0 );
74
- try (TInt32 x = TInt32 .tensorOf (StdArrays .ndCopyOf (new int [][]{{5 }, {7 }}));
89
+ try (TInt32 x = TInt32 .tensorOf (StdArrays .ndCopyOf (new int [][] {{5 }, {7 }}));
75
90
AutoCloseableList <Tensor > outputs =
76
91
new AutoCloseableList <>(s .runner ().feed (feed , x ).fetch (fetch ).run ())) {
77
92
assertEquals (1 , outputs .size ());
@@ -95,12 +110,9 @@ public void runUsingColonSeparatedNames() {
95
110
}
96
111
// Feed using colon separated names.
97
112
try (TInt32 fed = TInt32 .vectorOf (4 , 3 , 2 , 1 );
98
- TInt32 fetched = (TInt32 ) s .runner ()
99
- .feed ("Split:0" , fed )
100
- .feed ("Split:1" , fed )
101
- .fetch ("Add" )
102
- .run ()
103
- .get (0 )) {
113
+ TInt32 fetched =
114
+ (TInt32 )
115
+ s .runner ().feed ("Split:0" , fed ).feed ("Split:1" , fed ).fetch ("Add" ).run ().get (0 )) {
104
116
assertEquals (NdArrays .vectorOf (8 , 6 , 4 , 2 ), fetched );
105
117
}
106
118
}
@@ -111,13 +123,14 @@ public void runWithMetadata() {
111
123
try (Graph g = new Graph ();
112
124
Session s = new Session (g )) {
113
125
Ops tf = Ops .create (g );
114
- transpose_A_times_X (tf , new int [][]{{2 }, {3 }});
115
- try (TInt32 x = TInt32 .tensorOf (StdArrays .ndCopyOf (new int [][]{{5 }, {7 }}))) {
116
- Session .Run result = s .runner ()
117
- .feed ("X" , x )
118
- .fetch ("Y" )
119
- .setOptions (fullTraceRunOptions ())
120
- .runAndFetchMetadata ();
126
+ transpose_A_times_X (tf , new int [][] {{2 }, {3 }});
127
+ try (TInt32 x = TInt32 .tensorOf (StdArrays .ndCopyOf (new int [][] {{5 }, {7 }}))) {
128
+ Session .Run result =
129
+ s .runner ()
130
+ .feed ("X" , x )
131
+ .fetch ("Y" )
132
+ .setOptions (fullTraceRunOptions ())
133
+ .runAndFetchMetadata ();
121
134
// Sanity check on outputs.
122
135
AutoCloseableList <Tensor > outputs = new AutoCloseableList <>(result .outputs );
123
136
assertEquals (1 , outputs .size ());
@@ -163,8 +176,7 @@ public void failOnUseAfterClose() {
163
176
@ Test
164
177
public void createWithConfigProto () {
165
178
try (Graph g = new Graph ();
166
- Session s = new Session (g , singleThreadConfigProto ())) {
167
- }
179
+ Session s = new Session (g , singleThreadConfigProto ())) {}
168
180
}
169
181
170
182
@ Test
@@ -219,10 +231,12 @@ public void saveAndRestore() throws IOException {
219
231
Path testFolder = Files .createTempDirectory ("tf-session-save-restore-test" );
220
232
try (Graph g = new Graph ()) {
221
233
Ops tf = Ops .create (g );
222
- Variable <TFloat32 > x = tf .withName ("x" )
223
- .variable (tf .random .randomUniform (tf .constant (Shape .of (3 , 3L )), TFloat32 .class ));
224
- Variable <TFloat32 > y = tf .withName ("y" )
225
- .variable (tf .random .randomUniform (tf .constant (Shape .of (3 , 3L )), TFloat32 .class ));
234
+ Variable <TFloat32 > x =
235
+ tf .withName ("x" )
236
+ .variable (tf .random .randomUniform (tf .constant (Shape .of (3 , 3L )), TFloat32 .class ));
237
+ Variable <TFloat32 > y =
238
+ tf .withName ("y" )
239
+ .variable (tf .random .randomUniform (tf .constant (Shape .of (3 , 3L )), TFloat32 .class ));
226
240
Init init = tf .init ();
227
241
228
242
try (Session s = new Session (g )) {
@@ -234,9 +248,10 @@ public void saveAndRestore() throws IOException {
234
248
restoredGraph .importGraphDef (graphDef );
235
249
try (Session restoredSession = new Session (restoredGraph )) {
236
250
restoredSession .restore (testFolder .resolve ("checkpoint" ).toString ());
237
- try (AutoCloseableList <Tensor > oldList = new AutoCloseableList <>(s .runner ().fetch ("x" ).fetch ("y" ).run ());
238
- AutoCloseableList <Tensor > newList = new AutoCloseableList <>(
239
- restoredSession .runner ().fetch ("x" ).fetch ("y" ).run ())) {
251
+ try (AutoCloseableList <Tensor > oldList =
252
+ new AutoCloseableList <>(s .runner ().fetch ("x" ).fetch ("y" ).run ());
253
+ AutoCloseableList <Tensor > newList =
254
+ new AutoCloseableList <>(restoredSession .runner ().fetch ("x" ).fetch ("y" ).run ())) {
240
255
assertEquals (oldList .get (0 ), newList .get (0 ));
241
256
assertEquals (oldList .get (1 ), newList .get (1 ));
242
257
}
@@ -265,7 +280,6 @@ public static void testFetchVariable() {
265
280
try (TInt32 value = (TInt32 ) s .runner ().addTarget (assign ).fetch (variable ).run ().get (0 )) {
266
281
assertEquals (2 , value .getInt ());
267
282
}
268
-
269
283
}
270
284
}
271
285
@@ -295,14 +309,11 @@ public static void testFetchVariableReusingRead() {
295
309
}
296
310
297
311
assertEquals (0 , numOperations (g ) - ops );
298
-
299
312
}
300
313
}
301
314
302
315
private static RunOptions fullTraceRunOptions () {
303
- return RunOptions .newBuilder ()
304
- .setTraceLevel (RunOptions .TraceLevel .FULL_TRACE )
305
- .build ();
316
+ return RunOptions .newBuilder ().setTraceLevel (RunOptions .TraceLevel .FULL_TRACE ).build ();
306
317
}
307
318
308
319
private static ConfigProto singleThreadConfigProto () {
@@ -313,10 +324,11 @@ private static ConfigProto singleThreadConfigProto() {
313
324
}
314
325
315
326
private static void transpose_A_times_X (Ops tf , int [][] a ) {
316
- tf .withName ("Y" ).linalg .matMul (
317
- tf .withName ("A" ).constant (a ),
318
- tf .withName ("X" ).placeholder (TInt32 .class ),
319
- MatMul .transposeA (true ).transposeB (false )
320
- );
327
+ tf .withName ("Y" )
328
+ .linalg
329
+ .matMul (
330
+ tf .withName ("A" ).constant (a ),
331
+ tf .withName ("X" ).placeholder (TInt32 .class ),
332
+ MatMul .transposeA (true ).transposeB (false ));
321
333
}
322
334
}
0 commit comments