Skip to content

Commit eaf339e

Browse files
committed
ENH: use tf.control_dependencies
1 parent 9001774 commit eaf339e

File tree

2 files changed

+39
-31
lines changed

2 files changed

+39
-31
lines changed

tensorflow_addons/image/utils.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -35,23 +35,26 @@ def to_4D_image(image):
3535
Returns:
3636
4D tensor with the same type.
3737
"""
38-
tf.debugging.assert_rank_in(image, [2, 3, 4])
39-
ndims = image.get_shape().ndims
40-
if ndims is None:
41-
return _dynamic_to_4D_image(image)
42-
elif ndims == 2:
43-
return image[None, :, :, None]
44-
elif ndims == 3:
45-
return image[None, :, :, :]
46-
else:
47-
return image
38+
with tf.control_dependencies([
39+
tf.debugging.assert_rank_in(
40+
image, [2, 3, 4], message='`image` must be 2/3/4D tensor')
41+
]):
42+
ndims = image.get_shape().ndims
43+
if ndims is None:
44+
return _dynamic_to_4D_image(image)
45+
elif ndims == 2:
46+
return image[None, :, :, None]
47+
elif ndims == 3:
48+
return image[None, :, :, :]
49+
else:
50+
return image
4851

4952

5053
def _dynamic_to_4D_image(image):
5154
shape = tf.shape(image)
5255
original_rank = tf.rank(image)
53-
# 4D image => [N, H, W, C]
54-
# 3D image => [1, H, W, C]
56+
# 4D image => [N, H, W, C] or [N, C, H, W]
57+
# 3D image => [1, H, W, C] or [1, C, H, W]
5558
# 2D image => [1, H, W, 1]
5659
left_pad = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32)
5760
right_pad = tf.cast(tf.equal(original_rank, 2), dtype=tf.int32)
@@ -76,21 +79,24 @@ def from_4D_image(image, ndims):
7679
Returns:
7780
`ndims`-D tensor with the same type.
7881
"""
79-
tf.debugging.assert_rank(image, 4)
80-
if isinstance(ndims, tf.Tensor):
81-
return _dynamic_from_4D_image(image, ndims)
82-
elif ndims == 2:
83-
return tf.squeeze(image, [0, 3])
84-
elif ndims == 3:
85-
return tf.squeeze(image, [0])
86-
else:
87-
return image
82+
with tf.control_dependencies([
83+
tf.debugging.assert_rank(
84+
image, 4, message='`image` must be 4D tensor')
85+
]):
86+
if isinstance(ndims, tf.Tensor):
87+
return _dynamic_from_4D_image(image, ndims)
88+
elif ndims == 2:
89+
return tf.squeeze(image, [0, 3])
90+
elif ndims == 3:
91+
return tf.squeeze(image, [0])
92+
else:
93+
return image
8894

8995

9096
def _dynamic_from_4D_image(image, original_rank):
9197
shape = tf.shape(image)
92-
# 4D image <= [N, H, W, C]
93-
# 3D image <= [1, H, W, C]
98+
# 4D image <= [N, H, W, C] or [N, C, H, W]
99+
# 3D image <= [1, H, W, C] or [1, C, H, W]
94100
# 2D image <= [1, H, W, 1]
95101
begin = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32)
96102
end = 4 - tf.cast(tf.equal(original_rank, 2), dtype=tf.int32)

tensorflow_addons/image/utils_test.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,11 @@ def test_to_4D_image_with_unknown_shape(self):
4343
self.assertAllEqual(self.evaluate(exp), self.evaluate(res))
4444

4545
def test_to_4D_image_with_invalid_shape(self):
46-
with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)):
46+
errors = (ValueError, tf.errors.InvalidArgumentError)
47+
with self.assertRaisesRegexp(errors, '`image` must be 2/3/4D tensor'):
4748
img_utils.to_4D_image(tf.ones(shape=(1,)))
4849

49-
with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)):
50+
with self.assertRaisesRegexp(errors, '`image` must be 2/3/4D tensor'):
5051
img_utils.to_4D_image(tf.ones(shape=(1, 2, 4, 3, 2)))
5152

5253
def test_from_4D_image(self):
@@ -77,18 +78,19 @@ def test_from_4D_image_with_invalid_data(self):
7778
tf.ones(shape=(2, 2, 4, 1)), tf.constant(2)))
7879

7980
def test_from_4D_image_with_invalid_shape(self):
81+
errors = (ValueError, tf.errors.InvalidArgumentError)
8082
for rank in 2, tf.constant(2):
8183
with self.subTest(rank=rank):
82-
with self.assertRaises((ValueError,
83-
tf.errors.InvalidArgumentError)):
84+
with self.assertRaisesRegexp(errors,
85+
'`image` must be 4D tensor'):
8486
img_utils.from_4D_image(tf.ones(shape=(2, 4)), rank)
8587

86-
with self.assertRaises((ValueError,
87-
tf.errors.InvalidArgumentError)):
88+
with self.assertRaisesRegexp(errors,
89+
'`image` must be 4D tensor'):
8890
img_utils.from_4D_image(tf.ones(shape=(2, 4, 1)), rank)
8991

90-
with self.assertRaises((ValueError,
91-
tf.errors.InvalidArgumentError)):
92+
with self.assertRaisesRegexp(errors,
93+
'`image` must be 4D tensor'):
9294
img_utils.from_4D_image(
9395
tf.ones(shape=(1, 2, 4, 1, 1)), rank)
9496

0 commit comments

Comments
 (0)