Skip to content

Commit 61463a0

Browse files
committed
ENH: use tf.control_dependencies
1 parent 9001774 commit 61463a0

File tree

2 files changed

+38
-31
lines changed

2 files changed

+38
-31
lines changed

tensorflow_addons/image/utils.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -35,23 +35,27 @@ def to_4D_image(image):
3535
Returns:
3636
4D tensor with the same type.
3737
"""
38-
tf.debugging.assert_rank_in(image, [2, 3, 4])
39-
ndims = image.get_shape().ndims
40-
if ndims is None:
41-
return _dynamic_to_4D_image(image)
42-
elif ndims == 2:
43-
return image[None, :, :, None]
44-
elif ndims == 3:
45-
return image[None, :, :, :]
46-
else:
47-
return image
38+
with tf.control_dependencies(
39+
[tf.debugging.assert_rank_in(
40+
image,
41+
[2, 3, 4],
42+
message='`image` must be 2/3/4D tensor')]):
43+
ndims = image.get_shape().ndims
44+
if ndims is None:
45+
return _dynamic_to_4D_image(image)
46+
elif ndims == 2:
47+
return image[None, :, :, None]
48+
elif ndims == 3:
49+
return image[None, :, :, :]
50+
else:
51+
return image
4852

4953

5054
def _dynamic_to_4D_image(image):
5155
shape = tf.shape(image)
5256
original_rank = tf.rank(image)
53-
# 4D image => [N, H, W, C]
54-
# 3D image => [1, H, W, C]
57+
# 4D image => [N, H, W, C] or [N, C, H, W]
58+
# 3D image => [1, H, W, C] or [1, C, H, W]
5559
# 2D image => [1, H, W, 1]
5660
left_pad = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32)
5761
right_pad = tf.cast(tf.equal(original_rank, 2), dtype=tf.int32)
@@ -76,21 +80,25 @@ def from_4D_image(image, ndims):
7680
Returns:
7781
`ndims`-D tensor with the same type.
7882
"""
79-
tf.debugging.assert_rank(image, 4)
80-
if isinstance(ndims, tf.Tensor):
81-
return _dynamic_from_4D_image(image, ndims)
82-
elif ndims == 2:
83-
return tf.squeeze(image, [0, 3])
84-
elif ndims == 3:
85-
return tf.squeeze(image, [0])
86-
else:
87-
return image
83+
with tf.control_dependencies(
84+
[tf.debugging.assert_rank(
85+
image,
86+
4,
87+
message='`image` must be 4D tensor')]):
88+
if isinstance(ndims, tf.Tensor):
89+
return _dynamic_from_4D_image(image, ndims)
90+
elif ndims == 2:
91+
return tf.squeeze(image, [0, 3])
92+
elif ndims == 3:
93+
return tf.squeeze(image, [0])
94+
else:
95+
return image
8896

8997

9098
def _dynamic_from_4D_image(image, original_rank):
9199
shape = tf.shape(image)
92-
# 4D image <= [N, H, W, C]
93-
# 3D image <= [1, H, W, C]
100+
# 4D image <= [N, H, W, C] or [N, C, H, W]
101+
# 3D image <= [1, H, W, C] or [1, C, H, W]
94102
# 2D image <= [1, H, W, 1]
95103
begin = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32)
96104
end = 4 - tf.cast(tf.equal(original_rank, 2), dtype=tf.int32)

tensorflow_addons/image/utils_test.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,11 @@ def test_to_4D_image_with_unknown_shape(self):
4343
self.assertAllEqual(self.evaluate(exp), self.evaluate(res))
4444

4545
def test_to_4D_image_with_invalid_shape(self):
46-
with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)):
46+
errors = (ValueError, tf.errors.InvalidArgumentError)
47+
with self.assertRaisesRegexp(errors, '`image` must be 2/3/4D tensor'):
4748
img_utils.to_4D_image(tf.ones(shape=(1,)))
4849

49-
with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)):
50+
with self.assertRaisesRegexp(errors, '`image` must be 2/3/4D tensor'):
5051
img_utils.to_4D_image(tf.ones(shape=(1, 2, 4, 3, 2)))
5152

5253
def test_from_4D_image(self):
@@ -77,18 +78,16 @@ def test_from_4D_image_with_invalid_data(self):
7778
tf.ones(shape=(2, 2, 4, 1)), tf.constant(2)))
7879

7980
def test_from_4D_image_with_invalid_shape(self):
81+
errors = (ValueError, tf.errors.InvalidArgumentError)
8082
for rank in 2, tf.constant(2):
8183
with self.subTest(rank=rank):
82-
with self.assertRaises((ValueError,
83-
tf.errors.InvalidArgumentError)):
84+
with self.assertRaisesRegexp(errors, '`image` must be 4D tensor'):
8485
img_utils.from_4D_image(tf.ones(shape=(2, 4)), rank)
8586

86-
with self.assertRaises((ValueError,
87-
tf.errors.InvalidArgumentError)):
87+
with self.assertRaisesRegexp(errors, '`image` must be 4D tensor'):
8888
img_utils.from_4D_image(tf.ones(shape=(2, 4, 1)), rank)
8989

90-
with self.assertRaises((ValueError,
91-
tf.errors.InvalidArgumentError)):
90+
with self.assertRaisesRegexp(errors, '`image` must be 4D tensor'):
9291
img_utils.from_4D_image(
9392
tf.ones(shape=(1, 2, 4, 1, 1)), rank)
9493

0 commit comments

Comments
 (0)