diff --git a/tensorflow_addons/image/utils_test.py b/tensorflow_addons/image/utils_test.py index c1ebdeb6dc..1f90fd7f53 100644 --- a/tensorflow_addons/image/utils_test.py +++ b/tensorflow_addons/image/utils_test.py @@ -76,18 +76,18 @@ def test_from_4D_image_with_invalid_data(self): img_utils.from_4D_image(tf.ones(shape=(2, 2, 4, 1)), tf.constant(2)) ) - def test_from_4D_image_with_invalid_shape(self): - errors = (ValueError, tf.errors.InvalidArgumentError) - for rank in 2, tf.constant(2): - with self.subTest(rank=rank): - with self.assertRaisesRegexp(errors, "`image` must be 4D tensor"): - img_utils.from_4D_image(tf.ones(shape=(2, 4)), rank) - with self.assertRaisesRegexp(errors, "`image` must be 4D tensor"): - img_utils.from_4D_image(tf.ones(shape=(2, 4, 1)), rank) +@pytest.mark.parametrize("rank", [2, tf.constant(2)]) +def test_from_4d_image_with_invalid_shape(rank): + errors = (ValueError, tf.errors.InvalidArgumentError) + with pytest.raises(errors, match="`image` must be 4D tensor"): + img_utils.from_4D_image(tf.ones(shape=(2, 4)), rank) + + with pytest.raises(errors, match="`image` must be 4D tensor"): + img_utils.from_4D_image(tf.ones(shape=(2, 4, 1)), rank) - with self.assertRaisesRegexp(errors, "`image` must be 4D tensor"): - img_utils.from_4D_image(tf.ones(shape=(1, 2, 4, 1, 1)), rank) + with pytest.raises(errors, match="`image` must be 4D tensor"): + img_utils.from_4D_image(tf.ones(shape=(1, 2, 4, 1, 1)), rank) if __name__ == "__main__":