Skip to content

Commit 46eae18

Browse files
authored
use pytest markers instead of custom solution for prototype transforms functional tests (#6653)
* use pytest markers instead of custom solution for prototype transforms functional tests * cleanup * cleanup * trigger CI
1 parent a46c4f0 commit 46eae18

File tree

3 files changed

+201
-201
lines changed

3 files changed

+201
-201
lines changed

test/prototype_transforms_dispatcher_infos.py

Lines changed: 69 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
import dataclasses
33

44
from collections import defaultdict
5+
56
from typing import Callable, Dict, List, Optional, Sequence, Type
67

78
import pytest
89
import torchvision.prototype.transforms.functional as F
9-
from prototype_common_utils import BoundingBoxLoader
10-
from prototype_transforms_kernel_infos import KERNEL_INFOS, KernelInfo, Skip
10+
from prototype_transforms_kernel_infos import KERNEL_INFOS, TestMark
1111
from torchvision.prototype import features
1212

1313
__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
@@ -24,35 +24,27 @@ def __post_init__(self):
2424
self.kernel_name = self.kernel_name or self.kernel.__name__
2525

2626

27-
def skip_python_scalar_arg_jit(name, *, reason="Python scalar int or float is not supported when scripting"):
28-
return Skip(
29-
"test_scripted_smoke",
30-
condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs[name], (int, float)),
31-
reason=reason,
32-
)
33-
34-
35-
def skip_integer_size_jit(name="size"):
36-
return skip_python_scalar_arg_jit(name, reason="Integer size is not supported when scripting.")
37-
38-
3927
@dataclasses.dataclass
4028
class DispatcherInfo:
4129
dispatcher: Callable
4230
kernels: Dict[Type, Callable]
43-
kernel_infos: Dict[Type, KernelInfo] = dataclasses.field(default=None)
4431
pil_kernel_info: Optional[PILKernelInfo] = None
4532
method_name: str = dataclasses.field(default=None)
46-
skips: Sequence[Skip] = dataclasses.field(default_factory=list)
47-
_skips_map: Dict[str, List[Skip]] = dataclasses.field(default=None, init=False)
33+
test_marks: Sequence[TestMark] = dataclasses.field(default_factory=list)
34+
_test_marks_map: Dict[str, List[TestMark]] = dataclasses.field(default=None, init=False)
4835

4936
def __post_init__(self):
5037
self.kernel_infos = {feature_type: KERNEL_INFO_MAP[kernel] for feature_type, kernel in self.kernels.items()}
5138
self.method_name = self.method_name or self.dispatcher.__name__
52-
skips_map = defaultdict(list)
53-
for skip in self.skips:
54-
skips_map[skip.test_name].append(skip)
55-
self._skips_map = dict(skips_map)
39+
test_marks_map = defaultdict(list)
40+
for test_mark in self.test_marks:
41+
test_marks_map[test_mark.test_id].append(test_mark)
42+
self._test_marks_map = dict(test_marks_map)
43+
44+
def get_marks(self, test_id, args_kwargs):
45+
return [
46+
test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs)
47+
]
5648

5749
def sample_inputs(self, *feature_types, filter_metadata=True):
5850
for feature_type in feature_types or self.kernels.keys():
@@ -70,17 +62,27 @@ def sample_inputs(self, *feature_types, filter_metadata=True):
7062

7163
yield args_kwargs
7264

73-
def maybe_skip(self, *, test_name, args_kwargs, device):
74-
skips = self._skips_map.get(test_name)
75-
if not skips:
76-
return
7765

78-
for skip in skips:
79-
if skip.condition(args_kwargs, device):
80-
pytest.skip(skip.reason)
66+
def xfail_python_scalar_arg_jit(name, *, reason=None):
67+
reason = reason or f"Python scalar int or float for `{name}` is not supported when scripting"
68+
return TestMark(
69+
("TestDispatchers", "test_scripted_smoke"),
70+
pytest.mark.xfail(reason=reason),
71+
condition=lambda args_kwargs: isinstance(args_kwargs.kwargs[name], (int, float)),
72+
)
73+
8174

75+
def xfail_integer_size_jit(name="size"):
76+
return xfail_python_scalar_arg_jit(name, reason=f"Integer `{name}` is not supported when scripting.")
8277

83-
def fill_sequence_needs_broadcast(args_kwargs, device):
78+
79+
skip_dispatch_feature = TestMark(
80+
("TestDispatchers", "test_dispatch_feature"),
81+
pytest.mark.skip(reason="Dispatcher doesn't support arbitrary feature dispatch."),
82+
)
83+
84+
85+
def fill_sequence_needs_broadcast(args_kwargs):
8486
(image_loader, *_), kwargs = args_kwargs
8587
try:
8688
fill = kwargs["fill"]
@@ -93,15 +95,12 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
9395
return image_loader.num_channels > 1
9496

9597

96-
skip_dispatch_pil_if_fill_sequence_needs_broadcast = Skip(
97-
"test_dispatch_pil",
98+
xfail_dispatch_pil_if_fill_sequence_needs_broadcast = TestMark(
99+
("TestDispatchers", "test_dispatch_pil"),
100+
pytest.mark.xfail(
101+
reason="PIL kernel doesn't support sequences of length 1 for `fill` if the number of color channels is larger."
102+
),
98103
condition=fill_sequence_needs_broadcast,
99-
reason="PIL kernel doesn't support sequences of length 1 if the number of channels is larger.",
100-
)
101-
102-
skip_dispatch_feature = Skip(
103-
"test_dispatch_feature",
104-
reason="Dispatcher doesn't support arbitrary feature dispatch.",
105104
)
106105

107106

@@ -123,8 +122,8 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
123122
features.Mask: F.resize_mask,
124123
},
125124
pil_kernel_info=PILKernelInfo(F.resize_image_pil),
126-
skips=[
127-
skip_integer_size_jit(),
125+
test_marks=[
126+
xfail_integer_size_jit(),
128127
],
129128
),
130129
DispatcherInfo(
@@ -135,9 +134,9 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
135134
features.Mask: F.affine_mask,
136135
},
137136
pil_kernel_info=PILKernelInfo(F.affine_image_pil),
138-
skips=[
139-
skip_dispatch_pil_if_fill_sequence_needs_broadcast,
140-
skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT"),
137+
test_marks=[
138+
xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
139+
xfail_python_scalar_arg_jit("shear"),
141140
],
142141
),
143142
DispatcherInfo(
@@ -166,16 +165,6 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
166165
features.Mask: F.crop_mask,
167166
},
168167
pil_kernel_info=PILKernelInfo(F.crop_image_pil, kernel_name="crop_image_pil"),
169-
skips=[
170-
Skip(
171-
"test_dispatch_feature",
172-
condition=lambda args_kwargs, device: isinstance(args_kwargs.args[0], BoundingBoxLoader),
173-
reason=(
174-
"F.crop expects 4 coordinates as input, but bounding box sample inputs only generate two "
175-
"since that is sufficient for the kernel."
176-
),
177-
)
178-
],
179168
),
180169
DispatcherInfo(
181170
F.resized_crop,
@@ -193,10 +182,20 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
193182
features.BoundingBox: F.pad_bounding_box,
194183
features.Mask: F.pad_mask,
195184
},
196-
skips=[
197-
skip_dispatch_pil_if_fill_sequence_needs_broadcast,
198-
],
199185
pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"),
186+
test_marks=[
187+
TestMark(
188+
("TestDispatchers", "test_dispatch_pil"),
189+
pytest.mark.xfail(
190+
reason=(
191+
"PIL kernel doesn't support sequences of length 1 for argument `fill` and "
192+
"`padding_mode='constant'`, if the number of color channels is larger."
193+
)
194+
),
195+
condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs)
196+
and args_kwargs.kwargs.get("padding_mode", "constant") == "constant",
197+
)
198+
],
200199
),
201200
DispatcherInfo(
202201
F.perspective,
@@ -205,10 +204,10 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
205204
features.BoundingBox: F.perspective_bounding_box,
206205
features.Mask: F.perspective_mask,
207206
},
208-
skips=[
209-
skip_dispatch_pil_if_fill_sequence_needs_broadcast,
210-
],
211207
pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
208+
test_marks=[
209+
xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
210+
],
212211
),
213212
DispatcherInfo(
214213
F.elastic,
@@ -227,8 +226,8 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
227226
features.Mask: F.center_crop_mask,
228227
},
229228
pil_kernel_info=PILKernelInfo(F.center_crop_image_pil),
230-
skips=[
231-
skip_integer_size_jit("output_size"),
229+
test_marks=[
230+
xfail_integer_size_jit("output_size"),
232231
],
233232
),
234233
DispatcherInfo(
@@ -237,9 +236,9 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
237236
features.Image: F.gaussian_blur_image_tensor,
238237
},
239238
pil_kernel_info=PILKernelInfo(F.gaussian_blur_image_pil),
240-
skips=[
241-
skip_python_scalar_arg_jit("kernel_size"),
242-
skip_python_scalar_arg_jit("sigma"),
239+
test_marks=[
240+
xfail_python_scalar_arg_jit("kernel_size"),
241+
xfail_python_scalar_arg_jit("sigma"),
243242
],
244243
),
245244
DispatcherInfo(
@@ -290,7 +289,7 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
290289
features.Image: F.erase_image_tensor,
291290
},
292291
pil_kernel_info=PILKernelInfo(F.erase_image_pil),
293-
skips=[
292+
test_marks=[
294293
skip_dispatch_feature,
295294
],
296295
),
@@ -335,8 +334,8 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
335334
features.Image: F.five_crop_image_tensor,
336335
},
337336
pil_kernel_info=PILKernelInfo(F.five_crop_image_pil),
338-
skips=[
339-
skip_integer_size_jit(),
337+
test_marks=[
338+
xfail_integer_size_jit(),
340339
skip_dispatch_feature,
341340
],
342341
),
@@ -345,18 +344,18 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
345344
kernels={
346345
features.Image: F.ten_crop_image_tensor,
347346
},
348-
pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil),
349-
skips=[
350-
skip_integer_size_jit(),
347+
test_marks=[
348+
xfail_integer_size_jit(),
351349
skip_dispatch_feature,
352350
],
351+
pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil),
353352
),
354353
DispatcherInfo(
355354
F.normalize,
356355
kernels={
357356
features.Image: F.normalize_image_tensor,
358357
},
359-
skips=[
358+
test_marks=[
360359
skip_dispatch_feature,
361360
],
362361
),

0 commit comments

Comments
 (0)