Skip to content

Commit 4adbe3c

Browse files
jdsgomesfacebook-github-bot
authored andcommitted
[fbsync] add more KernelInfo's and DispatcherInfo's (#6626)
Summary: * add KernelInfo for adjust_brightness * add KernelInfo for adjust_contrast * add KernelInfo for adjust_hue * add KernelInfo for adjust_saturation * add KernelInfo for clamp_bounding_box * add KernelInfo for {five, ten}_crop_image_tensor as well as skip functionality * add KernelInfo for normalize * add KernelInfo for adjust_gamma * cleanup * add DispatcherInfo's for previously add KernelInfo's * add dispatcher info for elastic Reviewed By: NicolasHug Differential Revision: D39765294 fbshipit-source-id: 42bdb40ead8807f60fc75fb9259dd34e2d4f0724
1 parent 4407eb6 commit 4adbe3c

File tree

3 files changed

+382
-31
lines changed

3 files changed

+382
-31
lines changed

test/prototype_transforms_dispatcher_infos.py

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import dataclasses
2-
from typing import Callable, Dict, Type
2+
from typing import Callable, Dict, Sequence, Type
33

44
import pytest
55
import torchvision.prototype.transforms.functional as F
6-
from prototype_transforms_kernel_infos import KERNEL_INFOS
6+
from prototype_transforms_kernel_infos import KERNEL_INFOS, Skip
77
from torchvision.prototype import features
88

99
__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
@@ -15,6 +15,11 @@
1515
class DispatcherInfo:
1616
dispatcher: Callable
1717
kernels: Dict[Type, Callable]
18+
skips: Sequence[Skip] = dataclasses.field(default_factory=list)
19+
_skips_map: Dict[str, Skip] = dataclasses.field(default=None, init=False)
20+
21+
def __post_init__(self):
22+
self._skips_map = {skip.test_name: skip for skip in self.skips}
1823

1924
def sample_inputs(self, *types):
2025
for type in types or self.kernels.keys():
@@ -23,6 +28,11 @@ def sample_inputs(self, *types):
2328

2429
yield from KERNEL_SAMPLE_INPUTS_FN_MAP[self.kernels[type]]()
2530

31+
def maybe_skip(self, *, test_name, args_kwargs, device):
32+
skip = self._skips_map.get(test_name)
33+
if skip and skip.condition(args_kwargs, device):
34+
pytest.skip(skip.reason)
35+
2636

2737
DISPATCHER_INFOS = [
2838
DispatcherInfo(
@@ -97,6 +107,14 @@ def sample_inputs(self, *types):
97107
features.Mask: F.perspective_mask,
98108
},
99109
),
110+
DispatcherInfo(
111+
F.elastic,
112+
kernels={
113+
features.Image: F.elastic_image_tensor,
114+
features.BoundingBox: F.elastic_bounding_box,
115+
features.Mask: F.elastic_mask,
116+
},
117+
),
100118
DispatcherInfo(
101119
F.center_crop,
102120
kernels={
@@ -153,4 +171,66 @@ def sample_inputs(self, *types):
153171
features.Image: F.erase_image_tensor,
154172
},
155173
),
174+
DispatcherInfo(
175+
F.adjust_brightness,
176+
kernels={
177+
features.Image: F.adjust_brightness_image_tensor,
178+
},
179+
),
180+
DispatcherInfo(
181+
F.adjust_contrast,
182+
kernels={
183+
features.Image: F.adjust_contrast_image_tensor,
184+
},
185+
),
186+
DispatcherInfo(
187+
F.adjust_gamma,
188+
kernels={
189+
features.Image: F.adjust_gamma_image_tensor,
190+
},
191+
),
192+
DispatcherInfo(
193+
F.adjust_hue,
194+
kernels={
195+
features.Image: F.adjust_hue_image_tensor,
196+
},
197+
),
198+
DispatcherInfo(
199+
F.adjust_saturation,
200+
kernels={
201+
features.Image: F.adjust_saturation_image_tensor,
202+
},
203+
),
204+
DispatcherInfo(
205+
F.five_crop,
206+
kernels={
207+
features.Image: F.five_crop_image_tensor,
208+
},
209+
skips=[
210+
Skip(
211+
"test_scripted_smoke",
212+
condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["size"], int),
213+
reason="Integer size is not supported when scripting five_crop_image_tensor.",
214+
),
215+
],
216+
),
217+
DispatcherInfo(
218+
F.ten_crop,
219+
kernels={
220+
features.Image: F.ten_crop_image_tensor,
221+
},
222+
skips=[
223+
Skip(
224+
"test_scripted_smoke",
225+
condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["size"], int),
226+
reason="Integer size is not supported when scripting ten_crop_image_tensor.",
227+
),
228+
],
229+
),
230+
DispatcherInfo(
231+
F.normalize,
232+
kernels={
233+
features.Image: F.normalize_image_tensor,
234+
},
235+
),
156236
]

0 commit comments

Comments
 (0)