Skip to content

Commit 9001774

Browse files
committed
TST: static shape check
1 parent 60c3d34 commit 9001774

File tree

1 file changed

+17
-15
lines changed

1 file changed

+17
-15
lines changed

tensorflow_addons/image/utils_test.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,18 +28,19 @@
2828
class UtilsOpsTest(tf.test.TestCase):
2929
def test_to_4D_image(self):
3030
for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1):
31-
self.assertAllEqual(
32-
self.evaluate(tf.ones(shape=(1, 2, 4, 1))),
33-
self.evaluate(img_utils.to_4D_image(tf.ones(shape=shape))))
31+
exp = tf.ones(shape=(1, 2, 4, 1))
32+
res = img_utils.to_4D_image(tf.ones(shape=shape))
33+
# static shape:
34+
self.assertAllEqual(exp.get_shape(), res.get_shape())
35+
self.assertAllEqual(self.evaluate(exp), self.evaluate(res))
3436

3537
def test_to_4D_image_with_unknown_shape(self):
3638
fn = img_utils.to_4D_image.get_concrete_function(
3739
tf.TensorSpec(shape=None, dtype=tf.float32))
3840
for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1):
39-
image = tf.ones(shape=shape)
40-
self.assertAllEqual(
41-
self.evaluate(tf.ones(shape=(1, 2, 4, 1))),
42-
self.evaluate(fn(image)))
41+
exp = tf.ones(shape=(1, 2, 4, 1))
42+
res = fn(tf.ones(shape=shape))
43+
self.assertAllEqual(self.evaluate(exp), self.evaluate(res))
4344

4445
def test_to_4D_image_with_invalid_shape(self):
4546
with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)):
@@ -50,19 +51,20 @@ def test_to_4D_image_with_invalid_shape(self):
5051

5152
def test_from_4D_image(self):
5253
for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1):
53-
self.assertAllEqual(
54-
self.evaluate(tf.ones(shape=shape)),
55-
self.evaluate(
56-
img_utils.from_4D_image(
57-
tf.ones(shape=(1, 2, 4, 1)), len(shape))))
54+
exp = tf.ones(shape=shape)
55+
res = img_utils.from_4D_image(
56+
tf.ones(shape=(1, 2, 4, 1)), len(shape))
57+
# static shape:
58+
self.assertAllEqual(exp.get_shape(), res.get_shape())
59+
self.assertAllEqual(self.evaluate(exp), self.evaluate(res))
5860

5961
def test_from_4D_image_with_unknown_shape(self):
6062
for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1):
63+
exp = tf.ones(shape=shape)
6164
fn = img_utils.from_4D_image.get_concrete_function(
6265
tf.TensorSpec(shape=None, dtype=tf.float32), tf.size(shape))
63-
self.assertAllEqual(
64-
self.evaluate(tf.ones(shape=shape)),
65-
self.evaluate(fn(tf.ones(shape=(1, 2, 4, 1)), tf.size(shape))))
66+
res = fn(tf.ones(shape=(1, 2, 4, 1)), tf.size(shape))
67+
self.assertAllEqual(self.evaluate(exp), self.evaluate(res))
6668

6769
def test_from_4D_image_with_invalid_data(self):
6870
with self.assertRaises(ValueError):

0 commit comments

Comments
 (0)