Skip to content

Commit fa957d7

Browse files
Moved test out of run_in_graph_and_eager_mode in dense_image_warp. (#1419)
See #1328
1 parent 369b8f2 commit fa957d7

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

tensorflow_addons/image/dense_image_warp_test.py

+14-12
Original file line numberDiff line numberDiff line change
@@ -74,18 +74,20 @@ def test_interpolate_small_grid_batched(self):
7474

7575
self.assertAllClose(expected_results, interp)
7676

77-
def test_unknown_shape(self):
78-
query_points = tf.constant(
79-
[[0.0, 0.0], [0.0, 1.0], [0.5, 2.0], [1.5, 1.5]], shape=[1, 4, 2]
80-
)
81-
fn = interpolate_bilinear.get_concrete_function(
82-
tf.TensorSpec(shape=None, dtype=tf.float32),
83-
tf.TensorSpec(shape=None, dtype=tf.float32),
84-
)
85-
for shape in (2, 4, 3, 6), (6, 2, 4, 3), (1, 2, 4, 3):
86-
image = tf.ones(shape=shape)
87-
res = fn(image, query_points)
88-
self.assertAllEqual(res.shape, (shape[0], 4, shape[3]))
77+
78+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
79+
def test_unknown_shape():
80+
query_points = tf.constant(
81+
[[0.0, 0.0], [0.0, 1.0], [0.5, 2.0], [1.5, 1.5]], shape=[1, 4, 2]
82+
)
83+
fn = interpolate_bilinear.get_concrete_function(
84+
tf.TensorSpec(shape=None, dtype=tf.float32),
85+
tf.TensorSpec(shape=None, dtype=tf.float32),
86+
)
87+
for shape in (2, 4, 3, 6), (6, 2, 4, 3), (1, 2, 4, 3):
88+
image = tf.ones(shape=shape)
89+
res = fn(image, query_points)
90+
assert res.shape == (shape[0], 4, shape[3])
8991

9092

9193
@test_utils.run_all_in_graph_and_eager_modes

0 commit comments

Comments
 (0)