Skip to content
84 changes: 82 additions & 2 deletions test/prototype_transforms_dispatcher_infos.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import dataclasses
from typing import Callable, Dict, Type
from typing import Callable, Dict, Sequence, Type

import pytest
import torchvision.prototype.transforms.functional as F
from prototype_transforms_kernel_infos import KERNEL_INFOS
from prototype_transforms_kernel_infos import KERNEL_INFOS, Skip
from torchvision.prototype import features

__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
Expand All @@ -15,6 +15,11 @@
class DispatcherInfo:
dispatcher: Callable
kernels: Dict[Type, Callable]
skips: Sequence[Skip] = dataclasses.field(default_factory=list)
_skips_map: Dict[str, Skip] = dataclasses.field(default=None, init=False)

def __post_init__(self):
self._skips_map = {skip.test_name: skip for skip in self.skips}

def sample_inputs(self, *types):
for type in types or self.kernels.keys():
Expand All @@ -23,6 +28,11 @@ def sample_inputs(self, *types):

yield from KERNEL_SAMPLE_INPUTS_FN_MAP[self.kernels[type]]()

def maybe_skip(self, *, test_name, args_kwargs, device):
skip = self._skips_map.get(test_name)
if skip and skip.condition(args_kwargs, device):
pytest.skip(skip.reason)


DISPATCHER_INFOS = [
DispatcherInfo(
Expand Down Expand Up @@ -97,6 +107,14 @@ def sample_inputs(self, *types):
features.Mask: F.perspective_mask,
},
),
DispatcherInfo(
F.elastic,
kernels={
features.Image: F.elastic_image_tensor,
features.BoundingBox: F.elastic_bounding_box,
features.Mask: F.elastic_mask,
},
),
DispatcherInfo(
F.center_crop,
kernels={
Expand Down Expand Up @@ -153,4 +171,66 @@ def sample_inputs(self, *types):
features.Image: F.erase_image_tensor,
},
),
DispatcherInfo(
F.adjust_brightness,
kernels={
features.Image: F.adjust_brightness_image_tensor,
},
),
DispatcherInfo(
F.adjust_contrast,
kernels={
features.Image: F.adjust_contrast_image_tensor,
},
),
DispatcherInfo(
F.adjust_gamma,
kernels={
features.Image: F.adjust_gamma_image_tensor,
},
),
DispatcherInfo(
F.adjust_hue,
kernels={
features.Image: F.adjust_hue_image_tensor,
},
),
DispatcherInfo(
F.adjust_saturation,
kernels={
features.Image: F.adjust_saturation_image_tensor,
},
),
DispatcherInfo(
F.five_crop,
kernels={
features.Image: F.five_crop_image_tensor,
},
skips=[
Skip(
"test_scripted_smoke",
condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["size"], int),
reason="Integer size is not supported when scripting five_crop_image_tensor.",
),
],
),
DispatcherInfo(
F.ten_crop,
kernels={
features.Image: F.ten_crop_image_tensor,
},
skips=[
Skip(
"test_scripted_smoke",
condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["size"], int),
reason="Integer size is not supported when scripting ten_crop_image_tensor.",
),
],
),
DispatcherInfo(
F.normalize,
kernels={
features.Image: F.normalize_image_tensor,
},
),
]
Loading