Skip to content

Commit 2b7cae6

Browse files
author
Federico Pozzi
committed
refactor: improve tests
1 parent c9d5ca9 commit 2b7cae6

File tree

1 file changed

+26
-24
lines changed

1 file changed

+26
-24
lines changed

test/test_prototype_transforms.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -191,54 +191,56 @@ def test_convert_image_color_space(self, transform, input):
191191
transform(input)
192192

193193

194+
@pytest.mark.parametrize("p", [0.0, 1.0])
194195
class TestRandomHorizontalFlip:
195-
def input_tensor(self, dtype=torch.float32):
196-
return torch.tensor([[[0, 1], [0, 1]], [[1, 0], [1, 0]]], dtype=dtype)
196+
def input_expected_image_tensor(self, p, dtype=torch.float32):
197+
input = torch.tensor([[[0, 1], [0, 1]], [[1, 0], [1, 0]]], dtype=dtype)
198+
expected = torch.tensor([[[1, 0], [1, 0]], [[0, 1], [0, 1]]], dtype=dtype)
197199

198-
def expected_tensor(self, dtype=torch.float32):
199-
return torch.tensor([[[1, 0], [1, 0]], [[0, 1], [0, 1]]], dtype=dtype)
200+
if p == 1.0:
201+
return input, expected
202+
return input, input
200203

201-
@pytest.mark.parametrize("p", [0.0, 1.0], ids=["p=0", "p=1"])
202204
def test_simple_tensor(self, p):
203-
input = self.input_tensor()
205+
input, expected = self.input_expected_image_tensor(p)
206+
transform = transforms.RandomHorizontalFlip(p=p)
204207

205-
actual = transforms.RandomHorizontalFlip(p=p)(input)
208+
actual = transform(input)
206209

207-
expected = self.expected_tensor() if p == 1.0 else input
208210
assert_equal(expected, actual)
209211

210-
@pytest.mark.parametrize("p", [0.0, 1.0], ids=["p=0", "p=1"])
211212
def test_pil_image(self, p):
212-
input = self.input_tensor(dtype=torch.uint8)
213+
input, expected = self.input_expected_image_tensor(p, dtype=torch.uint8)
214+
transform = transforms.RandomHorizontalFlip(p=p)
213215

214-
actual = transforms.RandomHorizontalFlip(p=p)(to_pil_image(input))
216+
actual = transform(to_pil_image(input))
215217

216-
expected = self.expected_tensor(dtype=torch.uint8) if p == 1.0 else input
217218
assert_equal(expected, pil_to_tensor(actual))
218219

219-
@pytest.mark.parametrize("p", [0.0, 1.0], ids=["p=0", "p=1"])
220220
def test_features_image(self, p):
221-
input = self.input_tensor()
221+
input, expected = self.input_expected_image_tensor(p)
222+
transform = transforms.RandomHorizontalFlip(p=p)
222223

223-
actual = transforms.RandomHorizontalFlip(p=p)(features.Image(input))
224+
actual = transform(features.Image(input))
224225

225-
expected = self.expected_tensor() if p == 1.0 else input
226226
assert_equal(features.Image(expected), actual)
227227

228-
@pytest.mark.parametrize("p", [0.0, 1.0], ids=["p=0", "p=1"])
229228
def test_features_segmentation_mask(self, p):
230-
input = features.SegmentationMask(self.input_tensor())
229+
input, expected = self.input_expected_image_tensor(p)
230+
transform = transforms.RandomHorizontalFlip(p=p)
231231

232-
actual = transforms.RandomHorizontalFlip(p=p)(input)
232+
actual = transform(features.SegmentationMask(input))
233233

234-
expected = self.expected_tensor() if p == 1.0 else input
235234
assert_equal(features.SegmentationMask(expected), actual)
236235

237-
@pytest.mark.parametrize("p", [0.0, 1.0], ids=["p=0", "p=1"])
238236
def test_features_bounding_box(self, p):
239237
input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10))
238+
transform = transforms.RandomHorizontalFlip(p=p)
240239

241-
actual = transforms.RandomHorizontalFlip(p=p)(input)
240+
actual = transform(input)
242241

243-
expected = torch.tensor([5, 0, 10, 5]) if p == 1.0 else input
244-
assert_equal(features.BoundingBox.new_like(input, expected), actual)
242+
expected_image_tensor = torch.tensor([5, 0, 10, 5]) if p == 1.0 else input
243+
expected = features.BoundingBox.new_like(input, data=expected_image_tensor)
244+
assert_equal(expected, actual)
245+
assert actual.format == expected.format
246+
assert actual.image_size == expected.image_size

0 commit comments

Comments
 (0)