@@ -35,23 +35,26 @@ def to_4D_image(image):
35
35
Returns:
36
36
4D tensor with the same type.
37
37
"""
38
- tf .debugging .assert_rank_in (image , [2 , 3 , 4 ])
39
- ndims = image .get_shape ().ndims
40
- if ndims is None :
41
- return _dynamic_to_4D_image (image )
42
- elif ndims == 2 :
43
- return image [None , :, :, None ]
44
- elif ndims == 3 :
45
- return image [None , :, :, :]
46
- else :
47
- return image
38
+ with tf .control_dependencies ([
39
+ tf .debugging .assert_rank_in (
40
+ image , [2 , 3 , 4 ], message = '`image` must be 2/3/4D tensor' )
41
+ ]):
42
+ ndims = image .get_shape ().ndims
43
+ if ndims is None :
44
+ return _dynamic_to_4D_image (image )
45
+ elif ndims == 2 :
46
+ return image [None , :, :, None ]
47
+ elif ndims == 3 :
48
+ return image [None , :, :, :]
49
+ else :
50
+ return image
48
51
49
52
50
53
def _dynamic_to_4D_image (image ):
51
54
shape = tf .shape (image )
52
55
original_rank = tf .rank (image )
53
- # 4D image => [N, H, W, C]
54
- # 3D image => [1, H, W, C]
56
+ # 4D image => [N, H, W, C] or [N, C, H, W]
57
+ # 3D image => [1, H, W, C] or [1, C, H, W]
55
58
# 2D image => [1, H, W, 1]
56
59
left_pad = tf .cast (tf .less_equal (original_rank , 3 ), dtype = tf .int32 )
57
60
right_pad = tf .cast (tf .equal (original_rank , 2 ), dtype = tf .int32 )
@@ -76,21 +79,24 @@ def from_4D_image(image, ndims):
76
79
Returns:
77
80
`ndims`-D tensor with the same type.
78
81
"""
79
- tf .debugging .assert_rank (image , 4 )
80
- if isinstance (ndims , tf .Tensor ):
81
- return _dynamic_from_4D_image (image , ndims )
82
- elif ndims == 2 :
83
- return tf .squeeze (image , [0 , 3 ])
84
- elif ndims == 3 :
85
- return tf .squeeze (image , [0 ])
86
- else :
87
- return image
82
+ with tf .control_dependencies ([
83
+ tf .debugging .assert_rank (
84
+ image , 4 , message = '`image` must be 4D tensor' )
85
+ ]):
86
+ if isinstance (ndims , tf .Tensor ):
87
+ return _dynamic_from_4D_image (image , ndims )
88
+ elif ndims == 2 :
89
+ return tf .squeeze (image , [0 , 3 ])
90
+ elif ndims == 3 :
91
+ return tf .squeeze (image , [0 ])
92
+ else :
93
+ return image
88
94
89
95
90
96
def _dynamic_from_4D_image (image , original_rank ):
91
97
shape = tf .shape (image )
92
- # 4D image <= [N, H, W, C]
93
- # 3D image <= [1, H, W, C]
98
+ # 4D image <= [N, H, W, C] or [N, C, H, W]
99
+ # 3D image <= [1, H, W, C] or [1, C, H, W]
94
100
# 2D image <= [1, H, W, 1]
95
101
begin = tf .cast (tf .less_equal (original_rank , 3 ), dtype = tf .int32 )
96
102
end = 4 - tf .cast (tf .equal (original_rank , 2 ), dtype = tf .int32 )
0 commit comments