@@ -74,18 +74,20 @@ def test_interpolate_small_grid_batched(self):
74
74
75
75
self .assertAllClose (expected_results , interp )
76
76
77
- def test_unknown_shape (self ):
78
- query_points = tf .constant (
79
- [[0.0 , 0.0 ], [0.0 , 1.0 ], [0.5 , 2.0 ], [1.5 , 1.5 ]], shape = [1 , 4 , 2 ]
80
- )
81
- fn = interpolate_bilinear .get_concrete_function (
82
- tf .TensorSpec (shape = None , dtype = tf .float32 ),
83
- tf .TensorSpec (shape = None , dtype = tf .float32 ),
84
- )
85
- for shape in (2 , 4 , 3 , 6 ), (6 , 2 , 4 , 3 ), (1 , 2 , 4 , 3 ):
86
- image = tf .ones (shape = shape )
87
- res = fn (image , query_points )
88
- self .assertAllEqual (res .shape , (shape [0 ], 4 , shape [3 ]))
77
+
78
+ @pytest .mark .usefixtures ("maybe_run_functions_eagerly" )
79
+ def test_unknown_shape ():
80
+ query_points = tf .constant (
81
+ [[0.0 , 0.0 ], [0.0 , 1.0 ], [0.5 , 2.0 ], [1.5 , 1.5 ]], shape = [1 , 4 , 2 ]
82
+ )
83
+ fn = interpolate_bilinear .get_concrete_function (
84
+ tf .TensorSpec (shape = None , dtype = tf .float32 ),
85
+ tf .TensorSpec (shape = None , dtype = tf .float32 ),
86
+ )
87
+ for shape in (2 , 4 , 3 , 6 ), (6 , 2 , 4 , 3 ), (1 , 2 , 4 , 3 ):
88
+ image = tf .ones (shape = shape )
89
+ res = fn (image , query_points )
90
+ assert res .shape == (shape [0 ], 4 , shape [3 ])
89
91
90
92
91
93
@test_utils .run_all_in_graph_and_eager_modes
0 commit comments