1
1
import dataclasses
2
- from typing import Callable , Dict , Type
2
+ from typing import Callable , Dict , Sequence , Type
3
3
4
4
import pytest
5
5
import torchvision .prototype .transforms .functional as F
6
- from prototype_transforms_kernel_infos import KERNEL_INFOS
6
+ from prototype_transforms_kernel_infos import KERNEL_INFOS , Skip
7
7
from torchvision .prototype import features
8
8
9
9
__all__ = ["DispatcherInfo" , "DISPATCHER_INFOS" ]
15
15
class DispatcherInfo :
16
16
dispatcher : Callable
17
17
kernels : Dict [Type , Callable ]
18
+ skips : Sequence [Skip ] = dataclasses .field (default_factory = list )
19
+ _skips_map : Dict [str , Skip ] = dataclasses .field (default = None , init = False )
20
+
21
+ def __post_init__ (self ):
22
+ self ._skips_map = {skip .test_name : skip for skip in self .skips }
18
23
19
24
def sample_inputs (self , * types ):
20
25
for type in types or self .kernels .keys ():
@@ -23,6 +28,11 @@ def sample_inputs(self, *types):
23
28
24
29
yield from KERNEL_SAMPLE_INPUTS_FN_MAP [self .kernels [type ]]()
25
30
31
+ def maybe_skip (self , * , test_name , args_kwargs , device ):
32
+ skip = self ._skips_map .get (test_name )
33
+ if skip and skip .condition (args_kwargs , device ):
34
+ pytest .skip (skip .reason )
35
+
26
36
27
37
DISPATCHER_INFOS = [
28
38
DispatcherInfo (
@@ -97,6 +107,14 @@ def sample_inputs(self, *types):
97
107
features .Mask : F .perspective_mask ,
98
108
},
99
109
),
110
+ DispatcherInfo (
111
+ F .elastic ,
112
+ kernels = {
113
+ features .Image : F .elastic_image_tensor ,
114
+ features .BoundingBox : F .elastic_bounding_box ,
115
+ features .Mask : F .elastic_mask ,
116
+ },
117
+ ),
100
118
DispatcherInfo (
101
119
F .center_crop ,
102
120
kernels = {
@@ -153,4 +171,66 @@ def sample_inputs(self, *types):
153
171
features .Image : F .erase_image_tensor ,
154
172
},
155
173
),
174
+ DispatcherInfo (
175
+ F .adjust_brightness ,
176
+ kernels = {
177
+ features .Image : F .adjust_brightness_image_tensor ,
178
+ },
179
+ ),
180
+ DispatcherInfo (
181
+ F .adjust_contrast ,
182
+ kernels = {
183
+ features .Image : F .adjust_contrast_image_tensor ,
184
+ },
185
+ ),
186
+ DispatcherInfo (
187
+ F .adjust_gamma ,
188
+ kernels = {
189
+ features .Image : F .adjust_gamma_image_tensor ,
190
+ },
191
+ ),
192
+ DispatcherInfo (
193
+ F .adjust_hue ,
194
+ kernels = {
195
+ features .Image : F .adjust_hue_image_tensor ,
196
+ },
197
+ ),
198
+ DispatcherInfo (
199
+ F .adjust_saturation ,
200
+ kernels = {
201
+ features .Image : F .adjust_saturation_image_tensor ,
202
+ },
203
+ ),
204
+ DispatcherInfo (
205
+ F .five_crop ,
206
+ kernels = {
207
+ features .Image : F .five_crop_image_tensor ,
208
+ },
209
+ skips = [
210
+ Skip (
211
+ "test_scripted_smoke" ,
212
+ condition = lambda args_kwargs , device : isinstance (args_kwargs .kwargs ["size" ], int ),
213
+ reason = "Integer size is not supported when scripting five_crop_image_tensor." ,
214
+ ),
215
+ ],
216
+ ),
217
+ DispatcherInfo (
218
+ F .ten_crop ,
219
+ kernels = {
220
+ features .Image : F .ten_crop_image_tensor ,
221
+ },
222
+ skips = [
223
+ Skip (
224
+ "test_scripted_smoke" ,
225
+ condition = lambda args_kwargs , device : isinstance (args_kwargs .kwargs ["size" ], int ),
226
+ reason = "Integer size is not supported when scripting ten_crop_image_tensor." ,
227
+ ),
228
+ ],
229
+ ),
230
+ DispatcherInfo (
231
+ F .normalize ,
232
+ kernels = {
233
+ features .Image : F .normalize_image_tensor ,
234
+ },
235
+ ),
156
236
]
0 commit comments