 import dataclasses
 
 from collections import defaultdict
+
 from typing import Callable, Dict, List, Optional, Sequence, Type
 
 import pytest
 import torchvision.prototype.transforms.functional as F
-from prototype_common_utils import BoundingBoxLoader
-from prototype_transforms_kernel_infos import KERNEL_INFOS, KernelInfo, Skip
+from prototype_transforms_kernel_infos import KERNEL_INFOS, TestMark
 from torchvision.prototype import features
 
 __all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
@@ -24,35 +24,27 @@ def __post_init__(self):
         self.kernel_name = self.kernel_name or self.kernel.__name__
 
 
-def skip_python_scalar_arg_jit(name, *, reason="Python scalar int or float is not supported when scripting"):
-    return Skip(
-        "test_scripted_smoke",
-        condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs[name], (int, float)),
-        reason=reason,
-    )
-
-
-def skip_integer_size_jit(name="size"):
-    return skip_python_scalar_arg_jit(name, reason="Integer size is not supported when scripting.")
-
-
 @dataclasses.dataclass
 class DispatcherInfo:
     dispatcher: Callable
     kernels: Dict[Type, Callable]
-    kernel_infos: Dict[Type, KernelInfo] = dataclasses.field(default=None)
     pil_kernel_info: Optional[PILKernelInfo] = None
     method_name: str = dataclasses.field(default=None)
-    skips: Sequence[Skip] = dataclasses.field(default_factory=list)
-    _skips_map: Dict[str, List[Skip]] = dataclasses.field(default=None, init=False)
+    test_marks: Sequence[TestMark] = dataclasses.field(default_factory=list)
+    _test_marks_map: Dict[str, List[TestMark]] = dataclasses.field(default=None, init=False)
 
     def __post_init__(self):
         self.kernel_infos = {feature_type: KERNEL_INFO_MAP[kernel] for feature_type, kernel in self.kernels.items()}
         self.method_name = self.method_name or self.dispatcher.__name__
-        skips_map = defaultdict(list)
-        for skip in self.skips:
-            skips_map[skip.test_name].append(skip)
-        self._skips_map = dict(skips_map)
+        test_marks_map = defaultdict(list)
+        for test_mark in self.test_marks:
+            test_marks_map[test_mark.test_id].append(test_mark)
+        self._test_marks_map = dict(test_marks_map)
+
+    def get_marks(self, test_id, args_kwargs):
+        return [
+            test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs)
+        ]
 
     def sample_inputs(self, *feature_types, filter_metadata=True):
         for feature_type in feature_types or self.kernels.keys():
@@ -70,17 +62,27 @@ def sample_inputs(self, *feature_types, filter_metadata=True):
 
                     yield args_kwargs
 
-    def maybe_skip(self, *, test_name, args_kwargs, device):
-        skips = self._skips_map.get(test_name)
-        if not skips:
-            return
 
-        for skip in skips:
-            if skip.condition(args_kwargs, device):
-                pytest.skip(skip.reason)
+def xfail_python_scalar_arg_jit(name, *, reason=None):
+    reason = reason or f"Python scalar int or float for `{name}` is not supported when scripting"
+    return TestMark(
+        ("TestDispatchers", "test_scripted_smoke"),
+        pytest.mark.xfail(reason=reason),
+        condition=lambda args_kwargs: isinstance(args_kwargs.kwargs[name], (int, float)),
+    )
+
 
+def xfail_integer_size_jit(name="size"):
+    return xfail_python_scalar_arg_jit(name, reason=f"Integer `{name}` is not supported when scripting.")
 
-def fill_sequence_needs_broadcast(args_kwargs, device):
+
+skip_dispatch_feature = TestMark(
+    ("TestDispatchers", "test_dispatch_feature"),
+    pytest.mark.skip(reason="Dispatcher doesn't support arbitrary feature dispatch."),
+)
+
+
+def fill_sequence_needs_broadcast(args_kwargs):
     (image_loader, *_), kwargs = args_kwargs
     try:
         fill = kwargs["fill"]
@@ -93,15 +95,12 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
     return image_loader.num_channels > 1
 
 
-skip_dispatch_pil_if_fill_sequence_needs_broadcast = Skip(
-    "test_dispatch_pil",
+xfail_dispatch_pil_if_fill_sequence_needs_broadcast = TestMark(
+    ("TestDispatchers", "test_dispatch_pil"),
+    pytest.mark.xfail(
+        reason="PIL kernel doesn't support sequences of length 1 for `fill` if the number of color channels is larger."
+    ),
     condition=fill_sequence_needs_broadcast,
-    reason="PIL kernel doesn't support sequences of length 1 if the number of channels is larger.",
-)
-
-skip_dispatch_feature = Skip(
-    "test_dispatch_feature",
-    reason="Dispatcher doesn't support arbitrary feature dispatch.",
 )
 
 
@@ -123,8 +122,8 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
             features.Mask: F.resize_mask,
         },
         pil_kernel_info=PILKernelInfo(F.resize_image_pil),
-        skips=[
-            skip_integer_size_jit(),
+        test_marks=[
+            xfail_integer_size_jit(),
         ],
     ),
     DispatcherInfo(
@@ -135,9 +134,9 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
             features.Mask: F.affine_mask,
         },
         pil_kernel_info=PILKernelInfo(F.affine_image_pil),
-        skips=[
-            skip_dispatch_pil_if_fill_sequence_needs_broadcast,
-            skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT"),
+        test_marks=[
+            xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
+            xfail_python_scalar_arg_jit("shear"),
         ],
     ),
     DispatcherInfo(
@@ -166,16 +165,6 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
             features.Mask: F.crop_mask,
         },
         pil_kernel_info=PILKernelInfo(F.crop_image_pil, kernel_name="crop_image_pil"),
-        skips=[
-            Skip(
-                "test_dispatch_feature",
-                condition=lambda args_kwargs, device: isinstance(args_kwargs.args[0], BoundingBoxLoader),
-                reason=(
-                    "F.crop expects 4 coordinates as input, but bounding box sample inputs only generate two "
-                    "since that is sufficient for the kernel."
-                ),
-            )
-        ],
     ),
     DispatcherInfo(
         F.resized_crop,
@@ -193,10 +182,20 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
             features.BoundingBox: F.pad_bounding_box,
             features.Mask: F.pad_mask,
         },
-        skips=[
-            skip_dispatch_pil_if_fill_sequence_needs_broadcast,
-        ],
         pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"),
+        test_marks=[
+            TestMark(
+                ("TestDispatchers", "test_dispatch_pil"),
+                pytest.mark.xfail(
+                    reason=(
+                        "PIL kernel doesn't support sequences of length 1 for argument `fill` and "
+                        "`padding_mode='constant'`, if the number of color channels is larger."
+                    )
+                ),
+                condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs)
+                and args_kwargs.kwargs.get("padding_mode", "constant") == "constant",
+            )
+        ],
     ),
     DispatcherInfo(
         F.perspective,
@@ -205,10 +204,10 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
             features.BoundingBox: F.perspective_bounding_box,
             features.Mask: F.perspective_mask,
         },
-        skips=[
-            skip_dispatch_pil_if_fill_sequence_needs_broadcast,
-        ],
         pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
+        test_marks=[
+            xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
+        ],
     ),
     DispatcherInfo(
         F.elastic,
@@ -227,8 +226,8 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
             features.Mask: F.center_crop_mask,
         },
         pil_kernel_info=PILKernelInfo(F.center_crop_image_pil),
-        skips=[
-            skip_integer_size_jit("output_size"),
+        test_marks=[
+            xfail_integer_size_jit("output_size"),
         ],
     ),
     DispatcherInfo(
@@ -237,9 +236,9 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
             features.Image: F.gaussian_blur_image_tensor,
         },
         pil_kernel_info=PILKernelInfo(F.gaussian_blur_image_pil),
-        skips=[
-            skip_python_scalar_arg_jit("kernel_size"),
-            skip_python_scalar_arg_jit("sigma"),
+        test_marks=[
+            xfail_python_scalar_arg_jit("kernel_size"),
+            xfail_python_scalar_arg_jit("sigma"),
         ],
     ),
     DispatcherInfo(
@@ -290,7 +289,7 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
             features.Image: F.erase_image_tensor,
         },
         pil_kernel_info=PILKernelInfo(F.erase_image_pil),
-        skips=[
+        test_marks=[
             skip_dispatch_feature,
         ],
     ),
@@ -335,8 +334,8 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
             features.Image: F.five_crop_image_tensor,
         },
         pil_kernel_info=PILKernelInfo(F.five_crop_image_pil),
-        skips=[
-            skip_integer_size_jit(),
+        test_marks=[
+            xfail_integer_size_jit(),
             skip_dispatch_feature,
         ],
     ),
@@ -345,18 +344,18 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
         kernels={
             features.Image: F.ten_crop_image_tensor,
         },
-        pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil),
-        skips=[
-            skip_integer_size_jit(),
+        test_marks=[
+            xfail_integer_size_jit(),
             skip_dispatch_feature,
         ],
+        pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil),
    ),
     DispatcherInfo(
         F.normalize,
         kernels={
             features.Image: F.normalize_image_tensor,
         },
-        skips=[
+        test_marks=[
             skip_dispatch_feature,
         ],
     ),
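
For orientation only (not part of the patch): after this change each DispatcherInfo carries declarative TestMark entries, and a test is expected to ask the info for the marks that apply to a given sample input via get_marks(test_id, args_kwargs). The sketch below shows one way that could be wired into pytest parametrization; the helper name make_dispatcher_params and the parametrization layout are assumptions for illustration, while get_marks(), sample_inputs(), and the ("TestDispatchers", "<test name>") test ids come from the diff above.

# Minimal sketch (assumed consumer, not part of this patch) of how get_marks()
# can feed pytest parametrization: each sample input becomes a pytest.param whose
# marks are the TestMark entries whose condition matches that input.
import pytest


def make_dispatcher_params(dispatcher_infos, *, test_id):
    # test_id mirrors the ("TestDispatchers", "<test name>") tuples used by TestMark above.
    for info in dispatcher_infos:
        for args_kwargs in info.sample_inputs():
            yield pytest.param(
                info,
                args_kwargs,
                # e.g. pytest.mark.xfail(...) from xfail_python_scalar_arg_jit, or
                # pytest.mark.skip(...) from skip_dispatch_feature
                marks=info.get_marks(test_id, args_kwargs),
            )


# Hypothetical usage: feed these params to a test in TestDispatchers, e.g.
# @pytest.mark.parametrize(
#     ("info", "args_kwargs"),
#     list(make_dispatcher_params(DISPATCHER_INFOS, test_id=("TestDispatchers", "test_scripted_smoke"))),
# )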