Skip to content

Commit 39b91c8

Browse files
committed
break down massive parametrizations
1 parent 262e9c2 commit 39b91c8

File tree

1 file changed

+100
-22
lines changed

1 file changed

+100
-22
lines changed

test/test_transforms_v2_refactored.py

Lines changed: 100 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -963,57 +963,135 @@ def _adapt_fill(self, value, *, dtype):
963963
k: next(v for v in vs if v is not None) for k, vs in _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES.items()
964964
}
965965

966+
def _check_kernel(self, kernel, input, *args, **kwargs):
967+
kwargs_ = self._MINIMAL_AFFINE_KWARGS.copy()
968+
kwargs_.update(kwargs)
969+
check_kernel(kernel, input, *args, **kwargs_)
970+
966971
@pytest.mark.parametrize("angle", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["angle"])
972+
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
973+
@pytest.mark.parametrize("device", cpu_and_cuda())
974+
def test_kernel_image_tensor_angle(self, angle, dtype, device):
975+
self._check_kernel(
976+
F.affine_image_tensor,
977+
self._make_input(torch.Tensor, dtype=dtype, device=device),
978+
angle=angle,
979+
)
980+
967981
@pytest.mark.parametrize("translate", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["translate"])
968-
@pytest.mark.parametrize("scale", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["scale"])
982+
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
983+
@pytest.mark.parametrize("device", cpu_and_cuda())
984+
def test_kernel_image_tensor_translate(self, translate, dtype, device):
985+
self._check_kernel(
986+
F.affine_image_tensor,
987+
self._make_input(torch.Tensor, dtype=dtype, device=device),
988+
translate=translate,
989+
)
990+
969991
@pytest.mark.parametrize("shear", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["shear"])
992+
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
993+
@pytest.mark.parametrize("device", cpu_and_cuda())
994+
def test_kernel_image_tensor_shear(self, shear, dtype, device):
995+
self._check_kernel(
996+
F.affine_image_tensor,
997+
self._make_input(torch.Tensor, dtype=dtype, device=device),
998+
shear=shear,
999+
check_scripted_vs_eager=not isinstance(shear, (int, float)),
1000+
)
1001+
9701002
@pytest.mark.parametrize("center", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"])
1003+
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
1004+
@pytest.mark.parametrize("device", cpu_and_cuda())
1005+
def test_kernel_image_tensor_center(self, center, dtype, device):
1006+
self._check_kernel(
1007+
F.affine_image_tensor,
1008+
self._make_input(torch.Tensor, dtype=dtype, device=device),
1009+
center=center,
1010+
)
1011+
9711012
@pytest.mark.parametrize(
9721013
"interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR]
9731014
)
974-
@pytest.mark.parametrize("fill", _EXHAUSTIVE_TYPE_FILLS)
9751015
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
9761016
@pytest.mark.parametrize("device", cpu_and_cuda())
977-
def test_kernel_image_tensor(self, angle, translate, scale, shear, center, interpolation, fill, dtype, device):
978-
check_kernel(
1017+
def test_kernel_image_tensor_interpolation(self, interpolation, dtype, device):
1018+
self._check_kernel(
9791019
F.affine_image_tensor,
9801020
self._make_input(torch.Tensor, dtype=dtype, device=device),
981-
angle=angle,
982-
translate=translate,
983-
scale=scale,
984-
shear=shear,
985-
center=center,
9861021
interpolation=interpolation,
987-
fill=self._adapt_fill(fill, dtype=dtype),
988-
check_scripted_vs_eager=not (isinstance(shear, (int, float)) or isinstance(fill, (int, float))),
9891022
check_cuda_vs_cpu=dict(atol=1, rtol=0)
9901023
if dtype is torch.uint8 and interpolation is transforms.InterpolationMode.BILINEAR
9911024
else True,
9921025
)
9931026

994-
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
1027+
@pytest.mark.parametrize("fill", _EXHAUSTIVE_TYPE_FILLS)
1028+
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
1029+
@pytest.mark.parametrize("device", cpu_and_cuda())
1030+
def test_kernel_image_tensor_fill(self, fill, dtype, device):
1031+
self._check_kernel(
1032+
F.affine_image_tensor,
1033+
self._make_input(torch.Tensor, dtype=dtype, device=device),
1034+
fill=self._adapt_fill(fill, dtype=dtype),
1035+
check_scripted_vs_eager=not isinstance(fill, (int, float)),
1036+
)
1037+
9951038
@pytest.mark.parametrize("angle", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["angle"])
996-
@pytest.mark.parametrize("translate", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["translate"])
997-
@pytest.mark.parametrize("scale", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["scale"])
998-
@pytest.mark.parametrize("shear", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["shear"])
999-
@pytest.mark.parametrize("center", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"])
1039+
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
10001040
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
10011041
@pytest.mark.parametrize("device", cpu_and_cuda())
1002-
def test_kernel_bounding_box(self, format, angle, translate, scale, shear, center, dtype, device):
1003-
bounding_box = self._make_input(datapoints.BoundingBox, dtype=dtype, device=device, format=format)
1004-
check_kernel(
1042+
def test_kernel_bounding_box_angle(self, angle, format, dtype, device):
1043+
bounding_box = self._make_input(datapoints.BoundingBox, format=format, dtype=dtype, device=device)
1044+
self._check_kernel(
10051045
F.affine_bounding_box,
1006-
bounding_box,
1046+
self._make_input(datapoints.BoundingBox, format=format, dtype=dtype, device=device),
10071047
format=format,
10081048
spatial_size=bounding_box.spatial_size,
10091049
angle=angle,
1050+
)
1051+
1052+
@pytest.mark.parametrize("translate", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["translate"])
1053+
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
1054+
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
1055+
@pytest.mark.parametrize("device", cpu_and_cuda())
1056+
def test_kernel_bounding_box_translate(self, translate, format, dtype, device):
1057+
bounding_box = self._make_input(datapoints.BoundingBox, format=format, dtype=dtype, device=device)
1058+
self._check_kernel(
1059+
F.affine_bounding_box,
1060+
self._make_input(datapoints.BoundingBox, format=format, dtype=dtype, device=device),
1061+
format=format,
1062+
spatial_size=bounding_box.spatial_size,
10101063
translate=translate,
1011-
scale=scale,
1064+
)
1065+
1066+
@pytest.mark.parametrize("shear", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["shear"])
1067+
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
1068+
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
1069+
@pytest.mark.parametrize("device", cpu_and_cuda())
1070+
def test_kernel_bounding_box_shear(self, shear, format, dtype, device):
1071+
bounding_box = self._make_input(datapoints.BoundingBox, format=format, dtype=dtype, device=device)
1072+
self._check_kernel(
1073+
F.affine_bounding_box,
1074+
self._make_input(datapoints.BoundingBox, format=format, dtype=dtype, device=device),
1075+
format=format,
1076+
spatial_size=bounding_box.spatial_size,
10121077
shear=shear,
1013-
center=center,
10141078
check_scripted_vs_eager=not isinstance(shear, (int, float)),
10151079
)
10161080

1081+
@pytest.mark.parametrize("center", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"])
1082+
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
1083+
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
1084+
@pytest.mark.parametrize("device", cpu_and_cuda())
1085+
def test_kernel_bounding_box_center(self, center, format, dtype, device):
1086+
bounding_box = self._make_input(datapoints.BoundingBox, format=format, dtype=dtype, device=device)
1087+
self._check_kernel(
1088+
F.affine_bounding_box,
1089+
self._make_input(datapoints.BoundingBox, format=format, dtype=dtype, device=device),
1090+
format=format,
1091+
spatial_size=bounding_box.spatial_size,
1092+
center=center,
1093+
)
1094+
10171095
@pytest.mark.parametrize("mask_type", ["segmentation", "detection"])
10181096
def test_kernel_mask(self, mask_type):
10191097
check_kernel(

0 commit comments

Comments
 (0)