Skip to content

Commit 980e77b

Browse files
committed
address review
1 parent ae089a8 commit 980e77b

File tree

1 file changed

+9
-15
lines changed

1 file changed

+9
-15
lines changed

test/test_transforms_v2_refactored.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -896,7 +896,7 @@ def _make_input(
896896
return input
897897

898898
def _adapt_fill(self, value, *, dtype):
899-
"""Adapt ``fill`` values in the range ``[0.0, 1.0]`` to the value range of the dtype"""
899+
"""Adapt fill values in the range [0.0, 1.0] to the value range of the dtype"""
900900
if value is None:
901901
return value
902902

@@ -907,9 +907,7 @@ def _adapt_fill(self, value, *, dtype):
907907
elif isinstance(value, (list, tuple)):
908908
return type(value)(type(v)(v * max_value) for v in value)
909909
else:
910-
raise pytest.UsageError(
911-
f"`fill` should be an int or float, or a list or tuple of the former, but got {value}"
912-
)
910+
raise ValueError(f"fill should be an int or float, or a list or tuple of the former, but got '{value}'")
913911

914912
_EXHAUSTIVE_TYPE_AFFINE_KWARGS = dict(
915913
# float, int
@@ -926,7 +924,7 @@ def _adapt_fill(self, value, *, dtype):
926924
# two-list of float, two-list of int, two-tuple of float, two-tuple of int
927925
center=[None, [1.2, 4.9], [-3, 1], (2.5, -4.7), (3, 2)],
928926
)
929-
# The special case for `"shear"` makes sure we pick a value that is supported while JIT scripting
927+
# The special case for shear makes sure we pick a value that is supported while JIT scripting
930928
_MINIMAL_AFFINE_KWARGS = {
931929
k: vs[0] if k != "shear" else next(v for v in vs if isinstance(v, list))
932930
for k, vs in _EXHAUSTIVE_TYPE_AFFINE_KWARGS.items()
@@ -1110,14 +1108,14 @@ def test_functional_image_correctness(self, angle, translate, scale, shear, cent
11101108
)
11111109

11121110
mae = (actual.float() - expected.float()).abs().mean()
1113-
assert mae < 5
1111+
assert mae < 2 if interpolation is transforms.InterpolationMode.NEAREST else 8
11141112

11151113
@pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
11161114
@pytest.mark.parametrize(
11171115
"interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR]
11181116
)
11191117
@pytest.mark.parametrize("fill", _CORRECTNESS_FILL)
1120-
@pytest.mark.parametrize("seed", list(range(10)))
1118+
@pytest.mark.parametrize("seed", list(range(5)))
11211119
def test_transform_image_correctness(self, center, interpolation, fill, seed):
11221120
image = self._make_input(torch.Tensor, dtype=torch.uint8, device="cpu")
11231121

@@ -1127,18 +1125,14 @@ def test_transform_image_correctness(self, center, interpolation, fill, seed):
11271125
**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES, center=center, interpolation=interpolation, fill=fill
11281126
)
11291127

1130-
torch.manual_seed(seed)
1131-
params = transform._get_params([image])
1132-
11331128
torch.manual_seed(seed)
11341129
actual = transform(image)
11351130

1136-
expected = F.to_image_tensor(
1137-
F.affine(F.to_image_pil(image), **params, center=center, interpolation=interpolation, fill=fill)
1138-
)
1131+
torch.manual_seed(seed)
1132+
expected = F.to_image_tensor(transform(F.to_image_pil(image)))
11391133

11401134
mae = (actual.float() - expected.float()).abs().mean()
1141-
assert mae < 7
1135+
assert mae < 2 if interpolation is transforms.InterpolationMode.NEAREST else 8
11421136

11431137
def _compute_affine_matrix(self, *, angle, translate, scale, shear, center):
11441138
rot = math.radians(angle)
@@ -1210,7 +1204,7 @@ def test_functional_bounding_box_correctness(self, format, angle, translate, sca
12101204

12111205
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
12121206
@pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
1213-
@pytest.mark.parametrize("seed", list(range(10)))
1207+
@pytest.mark.parametrize("seed", list(range(5)))
12141208
def test_transform_bounding_box_correctness(self, format, center, seed):
12151209
bounding_box = self._make_input(datapoints.BoundingBox, format=format)
12161210

0 commit comments

Comments
 (0)