From 9c6643955847126c150d649e075f0edf1f308a09 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 21 Sep 2022 14:26:18 +0200 Subject: [PATCH 01/11] add KernelInfo for adjust_brightness --- test/prototype_transforms_kernel_infos.py | 31 +++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index b41e8409d54..eba6890ec9e 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -1223,3 +1223,34 @@ def sample_inputs_erase_image_tensor(): sample_inputs_fn=sample_inputs_erase_image_tensor, ) ) + +_ADJUST_BRIGHTNESS_FACTORS = [0.1, 0.5] + + +def sample_inputs_adjust_brightness_image_tensor(): + for image_loader in make_image_loaders( + sizes=["random"], + color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), + # FIXME: kernel should support arbitrary batch sizes + ): + yield ArgsKwargs(image_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0]) + + +def reference_inputs_adjust_brightness_image_tensor(): + for image_loader, brightness_factor in itertools.product( + make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]), + _ADJUST_BRIGHTNESS_FACTORS, + ): + yield ArgsKwargs(image_loader, brightness_factor=brightness_factor) + + +KERNEL_INFOS.append( + KernelInfo( + F.adjust_brightness_image_tensor, + kernel_name="adjust_brightness_image_tensor", + sample_inputs_fn=sample_inputs_adjust_brightness_image_tensor, + reference_fn=pil_reference_wrapper(F.adjust_brightness_image_pil), + reference_inputs_fn=reference_inputs_adjust_brightness_image_tensor, + closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, + ) +) From fc21ef8eb63152e5cfc553ca2559619ef4f596b8 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 21 Sep 2022 14:28:51 +0200 Subject: [PATCH 02/11] add KernelInfo for adjust_contrast --- test/prototype_transforms_kernel_infos.py | 34 +++++++++++++++++++++-- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index eba6890ec9e..5034d2ddeb3 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -1229,9 +1229,7 @@ def sample_inputs_erase_image_tensor(): def sample_inputs_adjust_brightness_image_tensor(): for image_loader in make_image_loaders( - sizes=["random"], - color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), - # FIXME: kernel should support arbitrary batch sizes + sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) ): yield ArgsKwargs(image_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0]) @@ -1254,3 +1252,33 @@ def reference_inputs_adjust_brightness_image_tensor(): closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, ) ) + + +_ADJUST_CONTRAST_FACTORS = [0.1, 0.5] + + +def sample_inputs_adjust_contrast_image_tensor(): + for image_loader in make_image_loaders( + sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) + ): + yield ArgsKwargs(image_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0]) + + +def reference_inputs_adjust_contrast_image_tensor(): + for image_loader, contrast_factor in itertools.product( + make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]), + _ADJUST_CONTRAST_FACTORS, + ): + yield ArgsKwargs(image_loader, contrast_factor=contrast_factor) + + +KERNEL_INFOS.append( + KernelInfo( + F.adjust_contrast_image_tensor, + kernel_name="adjust_contrast_image_tensor", + sample_inputs_fn=sample_inputs_adjust_contrast_image_tensor, + reference_fn=pil_reference_wrapper(F.adjust_contrast_image_pil), + reference_inputs_fn=reference_inputs_adjust_contrast_image_tensor, + closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, + ) +) From d89a68753acea7266776cad3af6ef440978a1f06 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 21 Sep 2022 14:31:16 +0200 Subject: [PATCH 03/11] add KernelInfo for adjust_hue --- test/prototype_transforms_kernel_infos.py | 30 +++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index 5034d2ddeb3..701bef1f993 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -1282,3 +1282,33 @@ def reference_inputs_adjust_contrast_image_tensor(): closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, ) ) + + +_ADJUST_HUE_FACTORS = [-0.1, 0.5] + + +def sample_inputs_adjust_hue_image_tensor(): + for image_loader in make_image_loaders( + sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) + ): + yield ArgsKwargs(image_loader, hue_factor=_ADJUST_HUE_FACTORS[0]) + + +def reference_inputs_adjust_hue_image_tensor(): + for image_loader, hue_factor in itertools.product( + make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]), + _ADJUST_HUE_FACTORS, + ): + yield ArgsKwargs(image_loader, hue_factor=hue_factor) + + +KERNEL_INFOS.append( + KernelInfo( + F.adjust_hue_image_tensor, + kernel_name="adjust_hue_image_tensor", + sample_inputs_fn=sample_inputs_adjust_hue_image_tensor, + reference_fn=pil_reference_wrapper(F.adjust_hue_image_pil), + reference_inputs_fn=reference_inputs_adjust_hue_image_tensor, + closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, + ) +) From b26ee039f00de0fef3b3f790e6796e145b23925a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 21 Sep 2022 14:33:06 +0200 Subject: [PATCH 04/11] add KernelInfo for adjust_saturation --- test/prototype_transforms_kernel_infos.py | 29 +++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index 701bef1f993..ea810b00fe4 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -1312,3 +1312,32 @@ def reference_inputs_adjust_hue_image_tensor(): closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, ) ) + +_ADJUST_SATURATION_FACTORS = [0.1, 0.5] + + +def sample_inputs_adjust_saturation_image_tensor(): + for image_loader in make_image_loaders( + sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) + ): + yield ArgsKwargs(image_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0]) + + +def reference_inputs_adjust_saturation_image_tensor(): + for image_loader, saturation_factor in itertools.product( + make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]), + _ADJUST_SATURATION_FACTORS, + ): + yield ArgsKwargs(image_loader, saturation_factor=saturation_factor) + + +KERNEL_INFOS.append( + KernelInfo( + F.adjust_saturation_image_tensor, + kernel_name="adjust_saturation_image_tensor", + sample_inputs_fn=sample_inputs_adjust_saturation_image_tensor, + reference_fn=pil_reference_wrapper(F.adjust_saturation_image_pil), + reference_inputs_fn=reference_inputs_adjust_saturation_image_tensor, + closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, + ) +) From dfb790dae427cb819d050606bcf8c7afdb40c1a3 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 21 Sep 2022 15:56:09 +0200 Subject: [PATCH 05/11] add KernelInfo for clamp_bounding_box --- test/prototype_transforms_kernel_infos.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index ea810b00fe4..1d0af64f661 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -1341,3 +1341,18 @@ def reference_inputs_adjust_saturation_image_tensor(): closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, ) ) + + +def sample_inputs_clamp_bounding_box(): + for bounding_box_loader in make_bounding_box_loaders(): + yield ArgsKwargs( + bounding_box_loader, format=bounding_box_loader.format, image_size=bounding_box_loader.image_size + ) + + +KERNEL_INFOS.append( + KernelInfo( + F.clamp_bounding_box, + sample_inputs_fn=sample_inputs_clamp_bounding_box, + ) +) From 1b5d991c6765f736b5e9a7e94e83d8dc309433d4 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 21 Sep 2022 15:56:43 +0200 Subject: [PATCH 06/11] add KernelInfo for {five, ten}_crop_image_tensor as well as skip functionality --- test/prototype_transforms_kernel_infos.py | 93 ++++++++++++++++++++++- test/test_prototype_transforms_kernels.py | 24 +++++- 2 files changed, 115 insertions(+), 2 deletions(-) diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index 1d0af64f661..752510e254d 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -2,7 +2,7 @@ import functools import itertools import math -from typing import Any, Callable, Dict, Iterable, Optional +from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Tuple import numpy as np import pytest @@ -17,6 +17,13 @@ __all__ = ["KernelInfo", "KERNEL_INFOS"] +@dataclasses.dataclass +class Skip: + test_name: str + reason: str + condition: Callable[[Tuple[ArgsKwargs, str]], bool] = lambda args_kwargs, device: True + + @dataclasses.dataclass class KernelInfo: kernel: Callable @@ -36,10 +43,18 @@ class KernelInfo: reference_inputs_fn: Optional[Callable[[], Iterable[ArgsKwargs]]] = None # Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`. closeness_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + skips: Sequence[Skip] = dataclasses.field(default_factory=list) + _skip_map: Dict[Tuple[Optional[str], str], Skip] = dataclasses.field(default=None, init=False) def __post_init__(self): self.kernel_name = self.kernel_name or self.kernel.__name__ self.reference_inputs_fn = self.reference_inputs_fn or self.sample_inputs_fn + self._skips_map = {skip.test_name: skip for skip in self.skips} + + def maybe_skip(self, *, test_name, args_kwargs, device): + skip = self._skips_map.get(test_name) + if skip and skip.condition(args_kwargs, device): + pytest.skip(skip.reason) DEFAULT_IMAGE_CLOSENESS_KWARGS = dict( @@ -1356,3 +1371,79 @@ def sample_inputs_clamp_bounding_box(): sample_inputs_fn=sample_inputs_clamp_bounding_box, ) ) + +_FIVE_TEN_CROP_SIZES = [7, (6,), [5], (6, 5), [7, 6]] + + +def _get_five_ten_crop_image_size(size): + if isinstance(size, int): + crop_height = crop_width = size + elif len(size) == 1: + crop_height = crop_width = size[0] + else: + crop_height, crop_width = size + return 2 * crop_height, 2 * crop_width + + +def sample_inputs_five_crop_image_tensor(): + for size in _FIVE_TEN_CROP_SIZES: + for image_loader in make_image_loaders(sizes=[_get_five_ten_crop_image_size(size)]): + yield ArgsKwargs(image_loader, size=size) + + +def reference_inputs_five_crop_image_tensor(): + for size in _FIVE_TEN_CROP_SIZES: + for image_loader in make_image_loaders(sizes=[_get_five_ten_crop_image_size(size)], extra_dims=[()]): + yield ArgsKwargs(image_loader, size=size) + + +def sample_inputs_ten_crop_image_tensor(): + for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]): + for image_loader in make_image_loaders(sizes=[_get_five_ten_crop_image_size(size)]): + yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip) + + +def reference_inputs_ten_crop_image_tensor(): + for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]): + for image_loader in make_image_loaders(sizes=[_get_five_ten_crop_image_size(size)], extra_dims=[()]): + yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip) + + +KERNEL_INFOS.extend( + [ + KernelInfo( + F.five_crop_image_tensor, + sample_inputs_fn=sample_inputs_five_crop_image_tensor, + reference_fn=pil_reference_wrapper(F.five_crop_image_pil), + reference_inputs_fn=reference_inputs_five_crop_image_tensor, + skips=[ + Skip( + "test_scripted_vs_eager", + condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["size"], int), + reason="Integer size is not supported when scripting five_crop_image_tensor.", + ), + Skip("test_batched_vs_single", reason="Custom batching needed for five_crop_image_tensor."), + Skip("test_no_inplace", reason="Output of five_crop_image_tensor is not a tensor."), + Skip("test_dtype_and_device_consistency", reason="Output of five_crop_image_tensor is not a tensor."), + ], + closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, + ), + KernelInfo( + F.ten_crop_image_tensor, + sample_inputs_fn=sample_inputs_ten_crop_image_tensor, + reference_fn=pil_reference_wrapper(F.ten_crop_image_pil), + reference_inputs_fn=reference_inputs_ten_crop_image_tensor, + skips=[ + Skip( + "test_scripted_vs_eager", + condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["size"], int), + reason="Integer size is not supported when scripting ten_crop_image_tensor.", + ), + Skip("test_batched_vs_single", reason="Custom batching needed for ten_crop_image_tensor."), + Skip("test_no_inplace", reason="Output of ten_crop_image_tensor is not a tensor."), + Skip("test_dtype_and_device_consistency", reason="Output of ten_crop_image_tensor is not a tensor."), + ], + closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, + ), + ] +) diff --git a/test/test_prototype_transforms_kernels.py b/test/test_prototype_transforms_kernels.py index b4d048c34a1..82002f86551 100644 --- a/test/test_prototype_transforms_kernels.py +++ b/test/test_prototype_transforms_kernels.py @@ -9,8 +9,11 @@ from torchvision.prototype import features from torchvision.prototype.transforms import functional as F +KERNEL_INFOS = KERNEL_INFOS[-3:] -def test_coverage(): + +def test_coverage(maybe_skip): + maybe_skip(None) tested = {info.kernel_name for info in KERNEL_INFOS} exposed = { name @@ -58,6 +61,25 @@ def test_coverage(): ) +@pytest.fixture(autouse=True) +def maybe_skip(request): + # In case the test uses no parametrization or fixtures, the `callspec` attribute does not exist + try: + callspec = request.node.callspec + except AttributeError: + return + + try: + info = callspec.params["info"] + args_kwargs = callspec.params["args_kwargs"] + except KeyError: + return + + info.maybe_skip( + test_name=request.node.originalname, args_kwargs=args_kwargs, device=callspec.params.get("device", "cpu") + ) + + class TestCommon: sample_inputs = pytest.mark.parametrize( ("info", "args_kwargs"), From 9df4206e34fc3c50dccb6e1bbd5f96c15771f47c Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 21 Sep 2022 16:02:29 +0200 Subject: [PATCH 07/11] add KernelInfo for normalize --- test/prototype_transforms_kernel_infos.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index 752510e254d..dd40a01925b 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -1447,3 +1447,25 @@ def reference_inputs_ten_crop_image_tensor(): ), ] ) + +_NORMALIZE_MEANS_STDS = [ + ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]), +] + + +def sample_inputs_normalize_image_tensor(): + for image_loader, (mean, std) in itertools.product( + make_image_loaders(sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]), + _NORMALIZE_MEANS_STDS, + ): + yield ArgsKwargs(image_loader, mean=mean, std=std) + + +KERNEL_INFOS.append( + KernelInfo( + F.normalize, + kernel_name="normalize_image_tensor", + sample_inputs_fn=sample_inputs_normalize_image_tensor, + ) +) From b7fa1c65b9dc0e74ae9a2ba68d67f1ae2157ab38 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 21 Sep 2022 16:09:36 +0200 Subject: [PATCH 08/11] add KernelInfo for adjust_gamma --- test/prototype_transforms_kernel_infos.py | 33 +++++++++++++++++++++++ test/test_prototype_transforms_kernels.py | 19 +------------ 2 files changed, 34 insertions(+), 18 deletions(-) diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index dd40a01925b..255fa3dfe32 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -1298,6 +1298,39 @@ def reference_inputs_adjust_contrast_image_tensor(): ) ) +_ADJUST_GAMMA_GAMMAS_GAINS = [ + (0.5, 2.0), + (0.0, 1.0), +] + + +def sample_inputs_adjust_gamma_image_tensor(): + gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0] + for image_loader in make_image_loaders( + sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) + ): + yield ArgsKwargs(image_loader, gamma=gamma, gain=gain) + + +def reference_inputs_adjust_gamma_image_tensor(): + for image_loader, (gamma, gain) in itertools.product( + make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]), + _ADJUST_GAMMA_GAMMAS_GAINS, + ): + yield ArgsKwargs(image_loader, gamma=gamma, gain=gain) + + +KERNEL_INFOS.append( + KernelInfo( + F.adjust_gamma_image_tensor, + kernel_name="adjust_gamma_image_tensor", + sample_inputs_fn=sample_inputs_adjust_gamma_image_tensor, + reference_fn=pil_reference_wrapper(F.adjust_gamma_image_pil), + reference_inputs_fn=reference_inputs_adjust_gamma_image_tensor, + closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, + ) +) + _ADJUST_HUE_FACTORS = [-0.1, 0.5] diff --git a/test/test_prototype_transforms_kernels.py b/test/test_prototype_transforms_kernels.py index 82002f86551..a1cca79f66a 100644 --- a/test/test_prototype_transforms_kernels.py +++ b/test/test_prototype_transforms_kernels.py @@ -9,11 +9,8 @@ from torchvision.prototype import features from torchvision.prototype.transforms import functional as F -KERNEL_INFOS = KERNEL_INFOS[-3:] - -def test_coverage(maybe_skip): - maybe_skip(None) +def test_coverage(): tested = {info.kernel_name for info in KERNEL_INFOS} exposed = { name @@ -29,20 +26,6 @@ def test_coverage(maybe_skip): } ) and name not in {"to_image_tensor"} - # TODO: The list below should be quickly reduced in the transition period. There is nothing that prevents us - # from adding `KernelInfo`'s for these kernels other than time. - and name - not in { - "adjust_brightness_image_tensor", - "adjust_contrast_image_tensor", - "adjust_gamma_image_tensor", - "adjust_hue_image_tensor", - "adjust_saturation_image_tensor", - "clamp_bounding_box", - "five_crop_image_tensor", - "normalize_image_tensor", - "ten_crop_image_tensor", - } } needlessly_ignored = tested - exposed From 7c0027057fa8cd2b1cb039a3334a2dff71667ef6 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 21 Sep 2022 16:10:44 +0200 Subject: [PATCH 09/11] cleanup --- test/test_prototype_transforms_kernels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_prototype_transforms_kernels.py b/test/test_prototype_transforms_kernels.py index a1cca79f66a..8a213b38874 100644 --- a/test/test_prototype_transforms_kernels.py +++ b/test/test_prototype_transforms_kernels.py @@ -25,7 +25,7 @@ def test_coverage(): "mask", } ) - and name not in {"to_image_tensor"} + and name != "to_image_tensor" } needlessly_ignored = tested - exposed From deb0d05e37b56bc32b0bdbe51522d00f7f1ac40e Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 21 Sep 2022 16:26:58 +0200 Subject: [PATCH 10/11] add DispatcherInfo's for previously add KernelInfo's --- test/prototype_transforms_dispatcher_infos.py | 100 +++++++++++++----- test/prototype_transforms_kernel_infos.py | 8 +- test/test_prototype_transforms_dispatchers.py | 19 ++++ 3 files changed, 97 insertions(+), 30 deletions(-) diff --git a/test/prototype_transforms_dispatcher_infos.py b/test/prototype_transforms_dispatcher_infos.py index 61f158f0a55..da89bdbaf1a 100644 --- a/test/prototype_transforms_dispatcher_infos.py +++ b/test/prototype_transforms_dispatcher_infos.py @@ -1,13 +1,9 @@ import dataclasses -import functools -from typing import Callable, Dict, Type +from typing import Callable, Dict, Sequence, Type import pytest -import torch import torchvision.prototype.transforms.functional as F -from prototype_common_utils import ArgsKwargs -from prototype_transforms_kernel_infos import KERNEL_INFOS -from test_prototype_transforms_functional import FUNCTIONAL_INFOS +from prototype_transforms_kernel_infos import KERNEL_INFOS, Skip from torchvision.prototype import features __all__ = ["DispatcherInfo", "DISPATCHER_INFOS"] @@ -15,30 +11,15 @@ KERNEL_SAMPLE_INPUTS_FN_MAP = {info.kernel: info.sample_inputs_fn for info in KERNEL_INFOS} -# Helper class to use the infos from the old framework for now tests -class PreloadedArgsKwargs(ArgsKwargs): - def load(self, device="cpu"): - args = tuple(arg.to(device) if isinstance(arg, torch.Tensor) else arg for arg in self.args) - kwargs = { - keyword: arg.to(device) if isinstance(arg, torch.Tensor) else arg for keyword, arg in self.kwargs.items() - } - return args, kwargs - - -def preloaded_sample_inputs(args_kwargs): - for args, kwargs in args_kwargs: - yield PreloadedArgsKwargs(*args, **kwargs) - - -KERNEL_SAMPLE_INPUTS_FN_MAP.update( - {info.functional: functools.partial(preloaded_sample_inputs, info.sample_inputs()) for info in FUNCTIONAL_INFOS} -) - - @dataclasses.dataclass class DispatcherInfo: dispatcher: Callable kernels: Dict[Type, Callable] + skips: Sequence[Skip] = dataclasses.field(default_factory=list) + _skips_map: Dict[str, Skip] = dataclasses.field(default=None, init=False) + + def __post_init__(self): + self._skips_map = {skip.test_name: skip for skip in self.skips} def sample_inputs(self, *types): for type in types or self.kernels.keys(): @@ -47,6 +28,11 @@ def sample_inputs(self, *types): yield from KERNEL_SAMPLE_INPUTS_FN_MAP[self.kernels[type]]() + def maybe_skip(self, *, test_name, args_kwargs, device): + skip = self._skips_map.get(test_name) + if skip and skip.condition(args_kwargs, device): + pytest.skip(skip.reason) + DISPATCHER_INFOS = [ DispatcherInfo( @@ -177,4 +163,66 @@ def sample_inputs(self, *types): features.Image: F.erase_image_tensor, }, ), + DispatcherInfo( + F.adjust_brightness, + kernels={ + features.Image: F.adjust_brightness_image_tensor, + }, + ), + DispatcherInfo( + F.adjust_contrast, + kernels={ + features.Image: F.adjust_contrast_image_tensor, + }, + ), + DispatcherInfo( + F.adjust_gamma, + kernels={ + features.Image: F.adjust_gamma_image_tensor, + }, + ), + DispatcherInfo( + F.adjust_hue, + kernels={ + features.Image: F.adjust_hue_image_tensor, + }, + ), + DispatcherInfo( + F.adjust_saturation, + kernels={ + features.Image: F.adjust_saturation_image_tensor, + }, + ), + DispatcherInfo( + F.five_crop, + kernels={ + features.Image: F.five_crop_image_tensor, + }, + skips=[ + Skip( + "test_scripted_smoke", + condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["size"], int), + reason="Integer size is not supported when scripting five_crop_image_tensor.", + ), + ], + ), + DispatcherInfo( + F.ten_crop, + kernels={ + features.Image: F.ten_crop_image_tensor, + }, + skips=[ + Skip( + "test_scripted_smoke", + condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["size"], int), + reason="Integer size is not supported when scripting ten_crop_image_tensor.", + ), + ], + ), + DispatcherInfo( + F.normalize, + kernels={ + features.Image: F.normalize_image_tensor, + }, + ), ] diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index 255fa3dfe32..430018ca4c2 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -2,7 +2,7 @@ import functools import itertools import math -from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Tuple +from typing import Any, Callable, Dict, Iterable, Optional, Sequence import numpy as np import pytest @@ -21,7 +21,7 @@ class Skip: test_name: str reason: str - condition: Callable[[Tuple[ArgsKwargs, str]], bool] = lambda args_kwargs, device: True + condition: Callable[[ArgsKwargs, str], bool] = lambda args_kwargs, device: True @dataclasses.dataclass @@ -44,7 +44,7 @@ class KernelInfo: # Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`. closeness_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) skips: Sequence[Skip] = dataclasses.field(default_factory=list) - _skip_map: Dict[Tuple[Optional[str], str], Skip] = dataclasses.field(default=None, init=False) + _skips_map: Dict[str, Skip] = dataclasses.field(default=None, init=False) def __post_init__(self): self.kernel_name = self.kernel_name or self.kernel.__name__ @@ -1497,7 +1497,7 @@ def sample_inputs_normalize_image_tensor(): KERNEL_INFOS.append( KernelInfo( - F.normalize, + F.normalize_image_tensor, kernel_name="normalize_image_tensor", sample_inputs_fn=sample_inputs_normalize_image_tensor, ) diff --git a/test/test_prototype_transforms_dispatchers.py b/test/test_prototype_transforms_dispatchers.py index 6cbcdedf702..5ae99cb53e9 100644 --- a/test/test_prototype_transforms_dispatchers.py +++ b/test/test_prototype_transforms_dispatchers.py @@ -8,6 +8,25 @@ from torchvision.prototype import features +@pytest.fixture(autouse=True) +def maybe_skip(request): + # In case the test uses no parametrization or fixtures, the `callspec` attribute does not exist + try: + callspec = request.node.callspec + except AttributeError: + return + + try: + info = callspec.params["info"] + args_kwargs = callspec.params["args_kwargs"] + except KeyError: + return + + info.maybe_skip( + test_name=request.node.originalname, args_kwargs=args_kwargs, device=callspec.params.get("device", "cpu") + ) + + class TestCommon: @pytest.mark.parametrize( ("info", "args_kwargs"), From ba9d4fd426f9160e8de83a6dc5bc308f65423af5 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 22 Sep 2022 11:17:50 +0200 Subject: [PATCH 11/11] add dispatcher info for elastic --- test/prototype_transforms_dispatcher_infos.py | 8 ++++++++ test/test_prototype_transforms_functional.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/test/prototype_transforms_dispatcher_infos.py b/test/prototype_transforms_dispatcher_infos.py index da89bdbaf1a..3cd9b0e8e00 100644 --- a/test/prototype_transforms_dispatcher_infos.py +++ b/test/prototype_transforms_dispatcher_infos.py @@ -107,6 +107,14 @@ def maybe_skip(self, *, test_name, args_kwargs, device): features.Mask: F.perspective_mask, }, ), + DispatcherInfo( + F.elastic, + kernels={ + features.Image: F.elastic_image_tensor, + features.BoundingBox: F.elastic_bounding_box, + features.Mask: F.elastic_mask, + }, + ), DispatcherInfo( F.center_crop, kernels={ diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index e1f5bcf0184..d29407820d7 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -192,7 +192,6 @@ def test_scripted_smoke(self, info, args_kwargs, device): [ F.convert_color_space, F.convert_image_dtype, - F.elastic_transform, F.get_dimensions, F.get_image_num_channels, F.get_image_size, @@ -214,6 +213,7 @@ def test_scriptable(self, dispatcher): (F.vflip, F.vertical_flip), (F.get_image_num_channels, F.get_num_channels), (F.to_pil_image, F.to_image_pil), + (F.elastic_transform, F.elastic), ] ], )