@@ -90,66 +90,9 @@ def test_unknown_shape(self):
90
90
91
91
@test_utils .run_all_in_graph_and_eager_modes
92
92
class DenseImageWarpTest (tf .test .TestCase ):
93
- def _get_random_image_and_flows (self , shape , image_type , flow_type ):
94
- batch_size , height , width , num_channels = shape
95
- image_shape = [batch_size , height , width , num_channels ]
96
- image = np .random .normal (size = image_shape )
97
- flow_shape = [batch_size , height , width , 2 ]
98
- flows = np .random .normal (size = flow_shape ) * 3
99
- return image .astype (image_type ), flows .astype (flow_type )
100
-
101
- def _assert_correct_interpolation_value (
102
- self ,
103
- image ,
104
- flows ,
105
- pred_interpolation ,
106
- batch_index ,
107
- y_index ,
108
- x_index ,
109
- low_precision = False ,
110
- ):
111
- """Assert that the tf interpolation matches hand-computed value."""
112
- height = image .shape [1 ]
113
- width = image .shape [2 ]
114
- displacement = flows [batch_index , y_index , x_index , :]
115
- float_y = y_index - displacement [0 ]
116
- float_x = x_index - displacement [1 ]
117
- floor_y = max (min (height - 2 , math .floor (float_y )), 0 )
118
- floor_x = max (min (width - 2 , math .floor (float_x )), 0 )
119
- ceil_y = floor_y + 1
120
- ceil_x = floor_x + 1
121
-
122
- alpha_y = min (max (0.0 , float_y - floor_y ), 1.0 )
123
- alpha_x = min (max (0.0 , float_x - floor_x ), 1.0 )
124
-
125
- floor_y = int (floor_y )
126
- floor_x = int (floor_x )
127
- ceil_y = int (ceil_y )
128
- ceil_x = int (ceil_x )
129
-
130
- top_left = image [batch_index , floor_y , floor_x , :]
131
- top_right = image [batch_index , floor_y , ceil_x , :]
132
- bottom_left = image [batch_index , ceil_y , floor_x , :]
133
- bottom_right = image [batch_index , ceil_y , ceil_x , :]
134
-
135
- interp_top = alpha_x * (top_right - top_left ) + top_left
136
- interp_bottom = alpha_x * (bottom_right - bottom_left ) + bottom_left
137
- interp = alpha_y * (interp_bottom - interp_top ) + interp_top
138
- atol = 1e-6
139
- rtol = 1e-6
140
- if low_precision :
141
- atol = 1e-2
142
- rtol = 1e-3
143
- self .assertAllClose (
144
- interp ,
145
- pred_interpolation [batch_index , y_index , x_index , :],
146
- atol = atol ,
147
- rtol = rtol ,
148
- )
149
-
150
93
def _check_zero_flow_correctness (self , shape , image_type , flow_type ):
151
94
"""Assert using zero flows doesn't change the input image."""
152
- rand_image , rand_flows = self . _get_random_image_and_flows (
95
+ rand_image , rand_flows = _get_random_image_and_flows (
153
96
shape , image_type , flow_type
154
97
)
155
98
rand_flows *= 0
@@ -169,62 +112,6 @@ def test_zero_flows(self):
169
112
shape , image_type = "float32" , flow_type = "float32"
170
113
)
171
114
172
- def _check_interpolation_correctness (
173
- self , shape , image_type , flow_type , call_with_unknown_shapes = False , num_probes = 5
174
- ):
175
- """Interpolate, and then assert correctness for a few query
176
- locations."""
177
- low_precision = image_type == "float16" or flow_type == "float16"
178
- rand_image , rand_flows = self ._get_random_image_and_flows (
179
- shape , image_type , flow_type
180
- )
181
-
182
- if call_with_unknown_shapes :
183
- fn = dense_image_warp .get_concrete_function (
184
- tf .TensorSpec (shape = None , dtype = image_type ),
185
- tf .TensorSpec (shape = None , dtype = flow_type ),
186
- )
187
- interp = fn (
188
- image = tf .convert_to_tensor (rand_image ),
189
- flow = tf .convert_to_tensor (rand_flows ),
190
- )
191
- else :
192
- interp = dense_image_warp (
193
- image = tf .convert_to_tensor (rand_image ),
194
- flow = tf .convert_to_tensor (rand_flows ),
195
- )
196
-
197
- for _ in range (num_probes ):
198
- batch_index = np .random .randint (0 , shape [0 ])
199
- y_index = np .random .randint (0 , shape [1 ])
200
- x_index = np .random .randint (0 , shape [2 ])
201
-
202
- self ._assert_correct_interpolation_value (
203
- rand_image ,
204
- rand_flows ,
205
- interp ,
206
- batch_index ,
207
- y_index ,
208
- x_index ,
209
- low_precision = low_precision ,
210
- )
211
-
212
- def test_interpolation (self ):
213
- """Apply _check_interpolation_correctness() for a few sizes and
214
- types."""
215
- shapes_to_try = [[3 , 4 , 5 , 6 ], [1 , 2 , 2 , 1 ]]
216
- for im_type in ["float32" , "float64" , "float16" ]:
217
- for flow_type in ["float32" , "float64" , "float16" ]:
218
- for shape in shapes_to_try :
219
- self ._check_interpolation_correctness (shape , im_type , flow_type )
220
-
221
- def test_unknown_shapes (self ):
222
- """Apply _check_interpolation_correctness() for a few sizes and check
223
- for tf.Dataset compatibility."""
224
- shapes_to_try = [[3 , 4 , 5 , 6 ], [1 , 2 , 2 , 1 ]]
225
- for shape in shapes_to_try :
226
- self ._check_interpolation_correctness (shape , "float32" , "float32" , True )
227
-
228
115
def test_gradients_exist (self ):
229
116
"""Check that backprop can run.
230
117
@@ -253,12 +140,132 @@ def loss():
253
140
for _ in range (10 ):
254
141
self .evaluate (minimize_op )
255
142
256
- def test_size_exception (self ):
257
- """Make sure it throws an exception for images that are too small."""
258
- shape = [1 , 2 , 1 , 1 ]
259
- errors = (ValueError , tf .errors .InvalidArgumentError )
260
- with self .assertRaisesRegexp (errors , "Grid width must be at least 2." ):
261
- self ._check_interpolation_correctness (shape , "float32" , "float32" )
143
+
144
+ def _assert_correct_interpolation_value (
145
+ image ,
146
+ flows ,
147
+ pred_interpolation ,
148
+ batch_index ,
149
+ y_index ,
150
+ x_index ,
151
+ low_precision = False ,
152
+ ):
153
+ """Assert that the tf interpolation matches hand-computed value."""
154
+ height = image .shape [1 ]
155
+ width = image .shape [2 ]
156
+ displacement = flows [batch_index , y_index , x_index , :]
157
+ float_y = y_index - displacement [0 ]
158
+ float_x = x_index - displacement [1 ]
159
+ floor_y = max (min (height - 2 , math .floor (float_y )), 0 )
160
+ floor_x = max (min (width - 2 , math .floor (float_x )), 0 )
161
+ ceil_y = floor_y + 1
162
+ ceil_x = floor_x + 1
163
+
164
+ alpha_y = min (max (0.0 , float_y - floor_y ), 1.0 )
165
+ alpha_x = min (max (0.0 , float_x - floor_x ), 1.0 )
166
+
167
+ floor_y = int (floor_y )
168
+ floor_x = int (floor_x )
169
+ ceil_y = int (ceil_y )
170
+ ceil_x = int (ceil_x )
171
+
172
+ top_left = image [batch_index , floor_y , floor_x , :]
173
+ top_right = image [batch_index , floor_y , ceil_x , :]
174
+ bottom_left = image [batch_index , ceil_y , floor_x , :]
175
+ bottom_right = image [batch_index , ceil_y , ceil_x , :]
176
+
177
+ interp_top = alpha_x * (top_right - top_left ) + top_left
178
+ interp_bottom = alpha_x * (bottom_right - bottom_left ) + bottom_left
179
+ interp = alpha_y * (interp_bottom - interp_top ) + interp_top
180
+ atol = 1e-6
181
+ rtol = 1e-6
182
+ if low_precision :
183
+ atol = 1e-2
184
+ rtol = 1e-3
185
+ np .testing .assert_allclose (
186
+ interp ,
187
+ pred_interpolation [batch_index , y_index , x_index , :],
188
+ atol = atol ,
189
+ rtol = rtol ,
190
+ )
191
+
192
+
193
+ def _get_random_image_and_flows (shape , image_type , flow_type ):
194
+ batch_size , height , width , num_channels = shape
195
+ image_shape = [batch_size , height , width , num_channels ]
196
+ image = np .random .normal (size = image_shape )
197
+ flow_shape = [batch_size , height , width , 2 ]
198
+ flows = np .random .normal (size = flow_shape ) * 3
199
+ return image .astype (image_type ), flows .astype (flow_type )
200
+
201
+
202
+ def _check_interpolation_correctness (
203
+ shape , image_type , flow_type , call_with_unknown_shapes = False , num_probes = 5
204
+ ):
205
+ """Interpolate, and then assert correctness for a few query
206
+ locations."""
207
+ low_precision = image_type == "float16" or flow_type == "float16"
208
+ rand_image , rand_flows = _get_random_image_and_flows (shape , image_type , flow_type )
209
+
210
+ if call_with_unknown_shapes :
211
+ fn = dense_image_warp .get_concrete_function (
212
+ tf .TensorSpec (shape = None , dtype = image_type ),
213
+ tf .TensorSpec (shape = None , dtype = flow_type ),
214
+ )
215
+ interp = fn (
216
+ image = tf .convert_to_tensor (rand_image ),
217
+ flow = tf .convert_to_tensor (rand_flows ),
218
+ )
219
+ else :
220
+ interp = dense_image_warp (
221
+ image = tf .convert_to_tensor (rand_image ),
222
+ flow = tf .convert_to_tensor (rand_flows ),
223
+ )
224
+
225
+ for _ in range (num_probes ):
226
+ batch_index = np .random .randint (0 , shape [0 ])
227
+ y_index = np .random .randint (0 , shape [1 ])
228
+ x_index = np .random .randint (0 , shape [2 ])
229
+
230
+ _assert_correct_interpolation_value (
231
+ rand_image ,
232
+ rand_flows ,
233
+ interp ,
234
+ batch_index ,
235
+ y_index ,
236
+ x_index ,
237
+ low_precision = low_precision ,
238
+ )
239
+
240
+
241
+ @pytest .mark .usefixtures ("maybe_run_functions_eagerly" )
242
+ def test_interpolation ():
243
+ """Apply _check_interpolation_correctness() for a few sizes and
244
+ types."""
245
+ shapes_to_try = [[3 , 4 , 5 , 6 ], [1 , 2 , 2 , 1 ]]
246
+ for im_type in ["float32" , "float64" , "float16" ]:
247
+ for flow_type in ["float32" , "float64" , "float16" ]:
248
+ for shape in shapes_to_try :
249
+ _check_interpolation_correctness (shape , im_type , flow_type )
250
+
251
+
252
+ @pytest .mark .usefixtures ("maybe_run_functions_eagerly" )
253
+ def test_size_exception ():
254
+ """Make sure it throws an exception for images that are too small."""
255
+ shape = [1 , 2 , 1 , 1 ]
256
+ errors = (ValueError , tf .errors .InvalidArgumentError )
257
+ with pytest .raises (errors ) as exception_raised :
258
+ _check_interpolation_correctness (shape , "float32" , "float32" )
259
+ assert "Grid width must be at least 2." in str (exception_raised .value )
260
+
261
+
262
+ @pytest .mark .usefixtures ("maybe_run_functions_eagerly" )
263
+ def test_unknown_shapes ():
264
+ """Apply _check_interpolation_correctness() for a few sizes and check
265
+ for tf.Dataset compatibility."""
266
+ shapes_to_try = [[3 , 4 , 5 , 6 ], [1 , 2 , 2 , 1 ]]
267
+ for shape in shapes_to_try :
268
+ _check_interpolation_correctness (shape , "float32" , "float32" , True )
262
269
263
270
264
271
if __name__ == "__main__" :
0 commit comments