@@ -35,23 +35,27 @@ def to_4D_image(image):
35
35
Returns:
36
36
4D tensor with the same type.
37
37
"""
38
- tf .debugging .assert_rank_in (image , [2 , 3 , 4 ])
39
- ndims = image .get_shape ().ndims
40
- if ndims is None :
41
- return _dynamic_to_4D_image (image )
42
- elif ndims == 2 :
43
- return image [None , :, :, None ]
44
- elif ndims == 3 :
45
- return image [None , :, :, :]
46
- else :
47
- return image
38
+ with tf .control_dependencies (
39
+ [tf .debugging .assert_rank_in (
40
+ image ,
41
+ [2 , 3 , 4 ],
42
+ message = '`image` must be 2/3/4D tensor' )]):
43
+ ndims = image .get_shape ().ndims
44
+ if ndims is None :
45
+ return _dynamic_to_4D_image (image )
46
+ elif ndims == 2 :
47
+ return image [None , :, :, None ]
48
+ elif ndims == 3 :
49
+ return image [None , :, :, :]
50
+ else :
51
+ return image
48
52
49
53
50
54
def _dynamic_to_4D_image (image ):
51
55
shape = tf .shape (image )
52
56
original_rank = tf .rank (image )
53
- # 4D image => [N, H, W, C]
54
- # 3D image => [1, H, W, C]
57
+ # 4D image => [N, H, W, C] or [N, C, H, W]
58
+ # 3D image => [1, H, W, C] or [1, C, H, W]
55
59
# 2D image => [1, H, W, 1]
56
60
left_pad = tf .cast (tf .less_equal (original_rank , 3 ), dtype = tf .int32 )
57
61
right_pad = tf .cast (tf .equal (original_rank , 2 ), dtype = tf .int32 )
@@ -76,21 +80,25 @@ def from_4D_image(image, ndims):
76
80
Returns:
77
81
`ndims`-D tensor with the same type.
78
82
"""
79
- tf .debugging .assert_rank (image , 4 )
80
- if isinstance (ndims , tf .Tensor ):
81
- return _dynamic_from_4D_image (image , ndims )
82
- elif ndims == 2 :
83
- return tf .squeeze (image , [0 , 3 ])
84
- elif ndims == 3 :
85
- return tf .squeeze (image , [0 ])
86
- else :
87
- return image
83
+ with tf .control_dependencies (
84
+ [tf .debugging .assert_rank (
85
+ image ,
86
+ 4 ,
87
+ message = '`image` must be 4D tensor' )]):
88
+ if isinstance (ndims , tf .Tensor ):
89
+ return _dynamic_from_4D_image (image , ndims )
90
+ elif ndims == 2 :
91
+ return tf .squeeze (image , [0 , 3 ])
92
+ elif ndims == 3 :
93
+ return tf .squeeze (image , [0 ])
94
+ else :
95
+ return image
88
96
89
97
90
98
def _dynamic_from_4D_image (image , original_rank ):
91
99
shape = tf .shape (image )
92
- # 4D image <= [N, H, W, C]
93
- # 3D image <= [1, H, W, C]
100
+ # 4D image <= [N, H, W, C] or [N, C, H, W]
101
+ # 3D image <= [1, H, W, C] or [1, C, H, W]
94
102
# 2D image <= [1, H, W, 1]
95
103
begin = tf .cast (tf .less_equal (original_rank , 3 ), dtype = tf .int32 )
96
104
end = 4 - tf .cast (tf .equal (original_rank , 2 ), dtype = tf .int32 )
0 commit comments