28
28
class UtilsOpsTest (tf .test .TestCase ):
29
29
def test_to_4D_image (self ):
30
30
for shape in (2 , 4 ), (2 , 4 , 1 ), (1 , 2 , 4 , 1 ):
31
- self .assertAllEqual (
32
- self .evaluate (tf .ones (shape = (1 , 2 , 4 , 1 ))),
33
- self .evaluate (img_utils .to_4D_image (tf .ones (shape = shape ))))
31
+ exp = tf .ones (shape = (1 , 2 , 4 , 1 ))
32
+ res = img_utils .to_4D_image (tf .ones (shape = shape ))
33
+ # static shape:
34
+ self .assertAllEqual (exp .get_shape (), res .get_shape ())
35
+ self .assertAllEqual (self .evaluate (exp ), self .evaluate (res ))
34
36
35
37
def test_to_4D_image_with_unknown_shape (self ):
36
38
fn = img_utils .to_4D_image .get_concrete_function (
37
39
tf .TensorSpec (shape = None , dtype = tf .float32 ))
38
40
for shape in (2 , 4 ), (2 , 4 , 1 ), (1 , 2 , 4 , 1 ):
39
- image = tf .ones (shape = shape )
40
- self .assertAllEqual (
41
- self .evaluate (tf .ones (shape = (1 , 2 , 4 , 1 ))),
42
- self .evaluate (fn (image )))
41
+ exp = tf .ones (shape = (1 , 2 , 4 , 1 ))
42
+ res = fn (tf .ones (shape = shape ))
43
+ self .assertAllEqual (self .evaluate (exp ), self .evaluate (res ))
43
44
44
45
def test_to_4D_image_with_invalid_shape (self ):
45
46
with self .assertRaises ((ValueError , tf .errors .InvalidArgumentError )):
@@ -50,19 +51,20 @@ def test_to_4D_image_with_invalid_shape(self):
50
51
51
52
def test_from_4D_image (self ):
52
53
for shape in (2 , 4 ), (2 , 4 , 1 ), (1 , 2 , 4 , 1 ):
53
- self .assertAllEqual (
54
- self .evaluate (tf .ones (shape = shape )),
55
- self .evaluate (
56
- img_utils .from_4D_image (
57
- tf .ones (shape = (1 , 2 , 4 , 1 )), len (shape ))))
54
+ exp = tf .ones (shape = shape )
55
+ res = img_utils .from_4D_image (
56
+ tf .ones (shape = (1 , 2 , 4 , 1 )), len (shape ))
57
+ # static shape:
58
+ self .assertAllEqual (exp .get_shape (), res .get_shape ())
59
+ self .assertAllEqual (self .evaluate (exp ), self .evaluate (res ))
58
60
59
61
def test_from_4D_image_with_unknown_shape (self ):
60
62
for shape in (2 , 4 ), (2 , 4 , 1 ), (1 , 2 , 4 , 1 ):
63
+ exp = tf .ones (shape = shape )
61
64
fn = img_utils .from_4D_image .get_concrete_function (
62
65
tf .TensorSpec (shape = None , dtype = tf .float32 ), tf .size (shape ))
63
- self .assertAllEqual (
64
- self .evaluate (tf .ones (shape = shape )),
65
- self .evaluate (fn (tf .ones (shape = (1 , 2 , 4 , 1 )), tf .size (shape ))))
66
+ res = fn (tf .ones (shape = (1 , 2 , 4 , 1 )), tf .size (shape ))
67
+ self .assertAllEqual (self .evaluate (exp ), self .evaluate (res ))
66
68
67
69
def test_from_4D_image_with_invalid_data (self ):
68
70
with self .assertRaises (ValueError ):
0 commit comments