Skip to content

Commit ef8e08a

Browse files
Used fixture instead. (#1331)
1 parent abe0200 commit ef8e08a

File tree

1 file changed

+127
-120
lines changed

1 file changed

+127
-120
lines changed

tensorflow_addons/image/dense_image_warp_test.py

+127-120
Original file line numberDiff line numberDiff line change
@@ -90,66 +90,9 @@ def test_unknown_shape(self):
9090

9191
@test_utils.run_all_in_graph_and_eager_modes
9292
class DenseImageWarpTest(tf.test.TestCase):
93-
def _get_random_image_and_flows(self, shape, image_type, flow_type):
94-
batch_size, height, width, num_channels = shape
95-
image_shape = [batch_size, height, width, num_channels]
96-
image = np.random.normal(size=image_shape)
97-
flow_shape = [batch_size, height, width, 2]
98-
flows = np.random.normal(size=flow_shape) * 3
99-
return image.astype(image_type), flows.astype(flow_type)
100-
101-
def _assert_correct_interpolation_value(
102-
self,
103-
image,
104-
flows,
105-
pred_interpolation,
106-
batch_index,
107-
y_index,
108-
x_index,
109-
low_precision=False,
110-
):
111-
"""Assert that the tf interpolation matches hand-computed value."""
112-
height = image.shape[1]
113-
width = image.shape[2]
114-
displacement = flows[batch_index, y_index, x_index, :]
115-
float_y = y_index - displacement[0]
116-
float_x = x_index - displacement[1]
117-
floor_y = max(min(height - 2, math.floor(float_y)), 0)
118-
floor_x = max(min(width - 2, math.floor(float_x)), 0)
119-
ceil_y = floor_y + 1
120-
ceil_x = floor_x + 1
121-
122-
alpha_y = min(max(0.0, float_y - floor_y), 1.0)
123-
alpha_x = min(max(0.0, float_x - floor_x), 1.0)
124-
125-
floor_y = int(floor_y)
126-
floor_x = int(floor_x)
127-
ceil_y = int(ceil_y)
128-
ceil_x = int(ceil_x)
129-
130-
top_left = image[batch_index, floor_y, floor_x, :]
131-
top_right = image[batch_index, floor_y, ceil_x, :]
132-
bottom_left = image[batch_index, ceil_y, floor_x, :]
133-
bottom_right = image[batch_index, ceil_y, ceil_x, :]
134-
135-
interp_top = alpha_x * (top_right - top_left) + top_left
136-
interp_bottom = alpha_x * (bottom_right - bottom_left) + bottom_left
137-
interp = alpha_y * (interp_bottom - interp_top) + interp_top
138-
atol = 1e-6
139-
rtol = 1e-6
140-
if low_precision:
141-
atol = 1e-2
142-
rtol = 1e-3
143-
self.assertAllClose(
144-
interp,
145-
pred_interpolation[batch_index, y_index, x_index, :],
146-
atol=atol,
147-
rtol=rtol,
148-
)
149-
15093
def _check_zero_flow_correctness(self, shape, image_type, flow_type):
15194
"""Assert using zero flows doesn't change the input image."""
152-
rand_image, rand_flows = self._get_random_image_and_flows(
95+
rand_image, rand_flows = _get_random_image_and_flows(
15396
shape, image_type, flow_type
15497
)
15598
rand_flows *= 0
@@ -169,62 +112,6 @@ def test_zero_flows(self):
169112
shape, image_type="float32", flow_type="float32"
170113
)
171114

172-
def _check_interpolation_correctness(
173-
self, shape, image_type, flow_type, call_with_unknown_shapes=False, num_probes=5
174-
):
175-
"""Interpolate, and then assert correctness for a few query
176-
locations."""
177-
low_precision = image_type == "float16" or flow_type == "float16"
178-
rand_image, rand_flows = self._get_random_image_and_flows(
179-
shape, image_type, flow_type
180-
)
181-
182-
if call_with_unknown_shapes:
183-
fn = dense_image_warp.get_concrete_function(
184-
tf.TensorSpec(shape=None, dtype=image_type),
185-
tf.TensorSpec(shape=None, dtype=flow_type),
186-
)
187-
interp = fn(
188-
image=tf.convert_to_tensor(rand_image),
189-
flow=tf.convert_to_tensor(rand_flows),
190-
)
191-
else:
192-
interp = dense_image_warp(
193-
image=tf.convert_to_tensor(rand_image),
194-
flow=tf.convert_to_tensor(rand_flows),
195-
)
196-
197-
for _ in range(num_probes):
198-
batch_index = np.random.randint(0, shape[0])
199-
y_index = np.random.randint(0, shape[1])
200-
x_index = np.random.randint(0, shape[2])
201-
202-
self._assert_correct_interpolation_value(
203-
rand_image,
204-
rand_flows,
205-
interp,
206-
batch_index,
207-
y_index,
208-
x_index,
209-
low_precision=low_precision,
210-
)
211-
212-
def test_interpolation(self):
213-
"""Apply _check_interpolation_correctness() for a few sizes and
214-
types."""
215-
shapes_to_try = [[3, 4, 5, 6], [1, 2, 2, 1]]
216-
for im_type in ["float32", "float64", "float16"]:
217-
for flow_type in ["float32", "float64", "float16"]:
218-
for shape in shapes_to_try:
219-
self._check_interpolation_correctness(shape, im_type, flow_type)
220-
221-
def test_unknown_shapes(self):
222-
"""Apply _check_interpolation_correctness() for a few sizes and check
223-
for tf.Dataset compatibility."""
224-
shapes_to_try = [[3, 4, 5, 6], [1, 2, 2, 1]]
225-
for shape in shapes_to_try:
226-
self._check_interpolation_correctness(shape, "float32", "float32", True)
227-
228115
def test_gradients_exist(self):
229116
"""Check that backprop can run.
230117
@@ -253,12 +140,132 @@ def loss():
253140
for _ in range(10):
254141
self.evaluate(minimize_op)
255142

256-
def test_size_exception(self):
257-
"""Make sure it throws an exception for images that are too small."""
258-
shape = [1, 2, 1, 1]
259-
errors = (ValueError, tf.errors.InvalidArgumentError)
260-
with self.assertRaisesRegexp(errors, "Grid width must be at least 2."):
261-
self._check_interpolation_correctness(shape, "float32", "float32")
143+
144+
def _assert_correct_interpolation_value(
145+
image,
146+
flows,
147+
pred_interpolation,
148+
batch_index,
149+
y_index,
150+
x_index,
151+
low_precision=False,
152+
):
153+
"""Assert that the tf interpolation matches hand-computed value."""
154+
height = image.shape[1]
155+
width = image.shape[2]
156+
displacement = flows[batch_index, y_index, x_index, :]
157+
float_y = y_index - displacement[0]
158+
float_x = x_index - displacement[1]
159+
floor_y = max(min(height - 2, math.floor(float_y)), 0)
160+
floor_x = max(min(width - 2, math.floor(float_x)), 0)
161+
ceil_y = floor_y + 1
162+
ceil_x = floor_x + 1
163+
164+
alpha_y = min(max(0.0, float_y - floor_y), 1.0)
165+
alpha_x = min(max(0.0, float_x - floor_x), 1.0)
166+
167+
floor_y = int(floor_y)
168+
floor_x = int(floor_x)
169+
ceil_y = int(ceil_y)
170+
ceil_x = int(ceil_x)
171+
172+
top_left = image[batch_index, floor_y, floor_x, :]
173+
top_right = image[batch_index, floor_y, ceil_x, :]
174+
bottom_left = image[batch_index, ceil_y, floor_x, :]
175+
bottom_right = image[batch_index, ceil_y, ceil_x, :]
176+
177+
interp_top = alpha_x * (top_right - top_left) + top_left
178+
interp_bottom = alpha_x * (bottom_right - bottom_left) + bottom_left
179+
interp = alpha_y * (interp_bottom - interp_top) + interp_top
180+
atol = 1e-6
181+
rtol = 1e-6
182+
if low_precision:
183+
atol = 1e-2
184+
rtol = 1e-3
185+
np.testing.assert_allclose(
186+
interp,
187+
pred_interpolation[batch_index, y_index, x_index, :],
188+
atol=atol,
189+
rtol=rtol,
190+
)
191+
192+
193+
def _get_random_image_and_flows(shape, image_type, flow_type):
194+
batch_size, height, width, num_channels = shape
195+
image_shape = [batch_size, height, width, num_channels]
196+
image = np.random.normal(size=image_shape)
197+
flow_shape = [batch_size, height, width, 2]
198+
flows = np.random.normal(size=flow_shape) * 3
199+
return image.astype(image_type), flows.astype(flow_type)
200+
201+
202+
def _check_interpolation_correctness(
203+
shape, image_type, flow_type, call_with_unknown_shapes=False, num_probes=5
204+
):
205+
"""Interpolate, and then assert correctness for a few query
206+
locations."""
207+
low_precision = image_type == "float16" or flow_type == "float16"
208+
rand_image, rand_flows = _get_random_image_and_flows(shape, image_type, flow_type)
209+
210+
if call_with_unknown_shapes:
211+
fn = dense_image_warp.get_concrete_function(
212+
tf.TensorSpec(shape=None, dtype=image_type),
213+
tf.TensorSpec(shape=None, dtype=flow_type),
214+
)
215+
interp = fn(
216+
image=tf.convert_to_tensor(rand_image),
217+
flow=tf.convert_to_tensor(rand_flows),
218+
)
219+
else:
220+
interp = dense_image_warp(
221+
image=tf.convert_to_tensor(rand_image),
222+
flow=tf.convert_to_tensor(rand_flows),
223+
)
224+
225+
for _ in range(num_probes):
226+
batch_index = np.random.randint(0, shape[0])
227+
y_index = np.random.randint(0, shape[1])
228+
x_index = np.random.randint(0, shape[2])
229+
230+
_assert_correct_interpolation_value(
231+
rand_image,
232+
rand_flows,
233+
interp,
234+
batch_index,
235+
y_index,
236+
x_index,
237+
low_precision=low_precision,
238+
)
239+
240+
241+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
242+
def test_interpolation():
243+
"""Apply _check_interpolation_correctness() for a few sizes and
244+
types."""
245+
shapes_to_try = [[3, 4, 5, 6], [1, 2, 2, 1]]
246+
for im_type in ["float32", "float64", "float16"]:
247+
for flow_type in ["float32", "float64", "float16"]:
248+
for shape in shapes_to_try:
249+
_check_interpolation_correctness(shape, im_type, flow_type)
250+
251+
252+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
253+
def test_size_exception():
254+
"""Make sure it throws an exception for images that are too small."""
255+
shape = [1, 2, 1, 1]
256+
errors = (ValueError, tf.errors.InvalidArgumentError)
257+
with pytest.raises(errors) as exception_raised:
258+
_check_interpolation_correctness(shape, "float32", "float32")
259+
assert "Grid width must be at least 2." in str(exception_raised.value)
260+
261+
262+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
263+
def test_unknown_shapes():
264+
"""Apply _check_interpolation_correctness() for a few sizes and check
265+
for tf.Dataset compatibility."""
266+
shapes_to_try = [[3, 4, 5, 6], [1, 2, 2, 1]]
267+
for shape in shapes_to_try:
268+
_check_interpolation_correctness(shape, "float32", "float32", True)
262269

263270

264271
if __name__ == "__main__":

0 commit comments

Comments
 (0)