diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 9aa0688e7a0..190867523eb 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -1,9 +1,8 @@ import itertools -import PIL.Image import pytest import torch -from test_prototype_transforms_kernels import make_images, make_bounding_boxes, make_one_hot_labels +from test_prototype_transforms_functional import make_images, make_bounding_boxes, make_one_hot_labels from torchvision.prototype import transforms, features from torchvision.transforms.functional import to_pil_image @@ -25,15 +24,6 @@ def make_vanilla_tensor_bounding_boxes(*args, **kwargs): yield bounding_box.data -INPUT_CREATIONS_FNS = { - features.Image: make_images, - features.BoundingBox: make_bounding_boxes, - features.OneHotLabel: make_one_hot_labels, - torch.Tensor: make_vanilla_tensor_images, - PIL.Image.Image: make_pil_images, -} - - def parametrize(transforms_with_inputs): return pytest.mark.parametrize( ("transform", "input"), @@ -52,15 +42,21 @@ def parametrize(transforms_with_inputs): def parametrize_from_transforms(*transforms): transforms_with_inputs = [] for transform in transforms: - dispatcher = transform._DISPATCHER - if dispatcher is None: - continue - - for type_ in dispatcher._kernels: + for creation_fn in [ + make_images, + make_bounding_boxes, + make_one_hot_labels, + make_vanilla_tensor_images, + make_pil_images, + ]: + inputs = list(creation_fn()) try: - inputs = INPUT_CREATIONS_FNS[type_]() - except KeyError: + output = transform(inputs[0]) + except Exception: continue + else: + if output is inputs[0]: + continue transforms_with_inputs.append((transform, inputs)) @@ -69,7 +65,7 @@ def parametrize_from_transforms(*transforms): class TestSmoke: @parametrize_from_transforms( - transforms.RandomErasing(), + transforms.RandomErasing(p=1.0), transforms.HorizontalFlip(), transforms.Resize([16, 16]), transforms.CenterCrop([16, 16]), @@ -141,35 +137,6 @@ def test_auto_augment(self, transform, input): def test_normalize(self, transform, input): transform(input) - @parametrize( - [ - ( - transforms.ConvertColorSpace("grayscale"), - itertools.chain( - make_images(), - make_vanilla_tensor_images(color_spaces=["rgb"]), - make_pil_images(color_spaces=["rgb"]), - ), - ) - ] - ) - def test_convert_bounding_color_space(self, transform, input): - transform(input) - - @parametrize( - [ - ( - transforms.ConvertBoundingBoxFormat("xyxy", old_format="xywh"), - itertools.chain( - make_bounding_boxes(), - make_vanilla_tensor_bounding_boxes(formats=["xywh"]), - ), - ) - ] - ) - def test_convert_bounding_box_format(self, transform, input): - transform(input) - @parametrize( [ ( diff --git a/test/test_prototype_transforms_kernels.py b/test/test_prototype_transforms_functional.py similarity index 84% rename from test/test_prototype_transforms_kernels.py rename to test/test_prototype_transforms_functional.py index fb436a6a830..4bfca28ae37 100644 --- a/test/test_prototype_transforms_kernels.py +++ b/test/test_prototype_transforms_functional.py @@ -3,7 +3,7 @@ import pytest import torch.testing -import torchvision.prototype.transforms.kernels as K +import torchvision.prototype.transforms.functional as F from torch import jit from torch.nn.functional import one_hot from torchvision.prototype import features @@ -134,10 +134,10 @@ def __init__(self, *args, **kwargs): self.kwargs = kwargs -class KernelInfo: +class FunctionalInfo: def __init__(self, name, *, sample_inputs_fn): self.name = name - self.kernel = getattr(K, name) 
+ self.functional = getattr(F, name) self._sample_inputs_fn = sample_inputs_fn def sample_inputs(self): @@ -146,21 +146,21 @@ def sample_inputs(self): def __call__(self, *args, **kwargs): if len(args) == 1 and not kwargs and isinstance(args[0], SampleInput): sample_input = args[0] - return self.kernel(*sample_input.args, **sample_input.kwargs) + return self.functional(*sample_input.args, **sample_input.kwargs) - return self.kernel(*args, **kwargs) + return self.functional(*args, **kwargs) -KERNEL_INFOS = [] +FUNCTIONAL_INFOS = [] def register_kernel_info_from_sample_inputs_fn(sample_inputs_fn): - KERNEL_INFOS.append(KernelInfo(sample_inputs_fn.__name__, sample_inputs_fn=sample_inputs_fn)) + FUNCTIONAL_INFOS.append(FunctionalInfo(sample_inputs_fn.__name__, sample_inputs_fn=sample_inputs_fn)) return sample_inputs_fn @register_kernel_info_from_sample_inputs_fn -def horizontal_flip_image(): +def horizontal_flip_image_tensor(): for image in make_images(): yield SampleInput(image) @@ -172,12 +172,12 @@ def horizontal_flip_bounding_box(): @register_kernel_info_from_sample_inputs_fn -def resize_image(): +def resize_image_tensor(): for image, interpolation in itertools.product( make_images(), [ - K.InterpolationMode.BILINEAR, - K.InterpolationMode.NEAREST, + F.InterpolationMode.BILINEAR, + F.InterpolationMode.NEAREST, ], ): height, width = image.shape[-2:] @@ -200,20 +200,20 @@ def resize_bounding_box(): class TestKernelsCommon: - @pytest.mark.parametrize("kernel_info", KERNEL_INFOS, ids=lambda kernel_info: kernel_info.name) - def test_scriptable(self, kernel_info): - jit.script(kernel_info.kernel) + @pytest.mark.parametrize("functional_info", FUNCTIONAL_INFOS, ids=lambda functional_info: functional_info.name) + def test_scriptable(self, functional_info): + jit.script(functional_info.functional) @pytest.mark.parametrize( - ("kernel_info", "sample_input"), + ("functional_info", "sample_input"), [ - pytest.param(kernel_info, sample_input, id=f"{kernel_info.name}-{idx}") - for kernel_info in KERNEL_INFOS - for idx, sample_input in enumerate(kernel_info.sample_inputs()) + pytest.param(functional_info, sample_input, id=f"{functional_info.name}-{idx}") + for functional_info in FUNCTIONAL_INFOS + for idx, sample_input in enumerate(functional_info.sample_inputs()) ], ) - def test_eager_vs_scripted(self, kernel_info, sample_input): - eager = kernel_info(sample_input) - scripted = jit.script(kernel_info.kernel)(*sample_input.args, **sample_input.kwargs) + def test_eager_vs_scripted(self, functional_info, sample_input): + eager = functional_info(sample_input) + scripted = jit.script(functional_info.functional)(*sample_input.args, **sample_input.kwargs) torch.testing.assert_close(eager, scripted) diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index fbe19549dca..5b60d7ee55c 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -41,7 +41,7 @@ def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox: # promote this out of the prototype state # import at runtime to avoid cyclic imports - from torchvision.prototype.transforms.kernels import convert_bounding_box_format + from torchvision.prototype.transforms.functional import convert_bounding_box_format if isinstance(format, str): format = BoundingBoxFormat[format] diff --git a/torchvision/prototype/features/_encoded.py b/torchvision/prototype/features/_encoded.py index ab6b821d673..276aeec2529 100644 --- 
a/torchvision/prototype/features/_encoded.py +++ b/torchvision/prototype/features/_encoded.py @@ -43,7 +43,7 @@ def decode(self) -> Image: # promote this out of the prototype state # import at runtime to avoid cyclic imports - from torchvision.prototype.transforms.kernels import decode_image_with_pil + from torchvision.prototype.transforms.functional import decode_image_with_pil return Image(decode_image_with_pil(self)) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 420db8a4b4f..73235720d58 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -1,13 +1,14 @@ -from torchvision.transforms import AutoAugmentPolicy, InterpolationMode # usort: skip -from . import kernels # usort: skip +from torchvision.transforms import InterpolationMode, AutoAugmentPolicy # usort: skip + from . import functional # usort: skip + from ._transform import Transform # usort: skip from ._augment import RandomErasing, RandomMixup, RandomCutmix from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment from ._container import Compose, RandomApply, RandomChoice, RandomOrder from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop -from ._meta_conversion import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertColorSpace +from ._meta_conversion import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace from ._misc import Identity, Normalize, ToDtype, Lambda from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval from ._type_conversion import DecodeImage, LabelToOneHot diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 9cd389bced0..ce198d39b33 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -3,7 +3,6 @@ import warnings from typing import Any, Dict, Tuple -import PIL.Image import torch from torchvision.prototype import features from torchvision.prototype.transforms import Transform, functional as F @@ -12,9 +11,6 @@ class RandomErasing(Transform): - _DISPATCHER = F.erase - _FAIL_TYPES = {PIL.Image.Image, features.BoundingBox, features.SegmentationMask} - def __init__( self, p: float = 0.5, @@ -45,8 +41,8 @@ def __init__( def _get_params(self, sample: Any) -> Dict[str, Any]: image = query_image(sample) - img_h, img_w = F.get_image_size(image) img_c = F.get_image_num_channels(image) + img_w, img_h = F.get_image_size(image) if isinstance(self.value, (int, float)): value = [self.value] @@ -93,16 +89,24 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(zip("ijhwv", (i, j, h, w, v))) def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if torch.rand(1) >= self.p: + if isinstance(input, (features.BoundingBox, features.SegmentationMask)): + raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()") + elif isinstance(input, features.Image): + output = F.erase_image_tensor(input, **params) + return features.Image.new_like(input, output) + elif isinstance(input, torch.Tensor): + return F.erase_image_tensor(input, **params) + else: return input - return super()._transform(input, params) + def forward(self, *inputs: Any) -> Any: + if torch.rand(1) >= self.p: + return inputs if len(inputs) > 1 else inputs[0] + + return super().forward(*inputs) class RandomMixup(Transform): - _DISPATCHER = F.mixup - _FAIL_TYPES = {features.BoundingBox, 
features.SegmentationMask} - def __init__(self, *, alpha: float) -> None: super().__init__() self.alpha = alpha @@ -111,11 +115,20 @@ def __init__(self, *, alpha: float) -> None: def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(lam=float(self._dist.sample(()))) + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if isinstance(input, (features.BoundingBox, features.SegmentationMask)): + raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()") + elif isinstance(input, features.Image): + output = F.mixup_image_tensor(input, **params) + return features.Image.new_like(input, output) + elif isinstance(input, features.OneHotLabel): + output = F.mixup_one_hot_label(input, **params) + return features.OneHotLabel.new_like(input, output) + else: + return input -class RandomCutmix(Transform): - _DISPATCHER = F.cutmix - _FAIL_TYPES = {features.BoundingBox, features.SegmentationMask} +class RandomCutmix(Transform): def __init__(self, *, alpha: float) -> None: super().__init__() self.alpha = alpha @@ -125,7 +138,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: lam = float(self._dist.sample(())) image = query_image(sample) - H, W = F.get_image_size(image) + W, H = F.get_image_size(image) r_x = torch.randint(W, ()) r_y = torch.randint(H, ()) @@ -143,3 +156,15 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) return dict(box=box, lam_adjusted=lam_adjusted) + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if isinstance(input, (features.BoundingBox, features.SegmentationMask)): + raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()") + elif isinstance(input, features.Image): + output = F.cutmix_image_tensor(input, box=params["box"]) + return features.Image.new_like(input, output) + elif isinstance(input, features.OneHotLabel): + output = F.cutmix_one_hot_label(input, lam_adjusted=params["lam_adjusted"]) + return features.OneHotLabel.new_like(input, output) + else: + return input diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 57336d8e517..7eae25a681e 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -21,65 +21,31 @@ def __init__( self.interpolation = interpolation self.fill = fill - _DISPATCHER_MAP: Dict[str, Callable[[Any, float, InterpolationMode, Optional[List[float]]], Any]] = { - "Identity": lambda input, magnitude, interpolation, fill: input, - "ShearX": lambda input, magnitude, interpolation, fill: F.affine( - input, - angle=0.0, - translate=[0, 0], - scale=1.0, - shear=[math.degrees(magnitude), 0.0], - interpolation=interpolation, - fill=fill, - ), - "ShearY": lambda input, magnitude, interpolation, fill: F.affine( - input, - angle=0.0, - translate=[0, 0], - scale=1.0, - shear=[0.0, math.degrees(magnitude)], - interpolation=interpolation, - fill=fill, - ), - "TranslateX": lambda input, magnitude, interpolation, fill: F.affine( - input, - angle=0.0, - translate=[int(magnitude), 0], - scale=1.0, - shear=[0.0, 0.0], - interpolation=interpolation, - fill=fill, - ), - "TranslateY": lambda input, magnitude, interpolation, fill: F.affine( - input, - angle=0.0, - translate=[0, int(magnitude)], - scale=1.0, - shear=[0.0, 0.0], - interpolation=interpolation, - fill=fill, - ), - "Rotate": lambda input, magnitude, interpolation, fill: F.rotate(input, angle=magnitude), - 
"Brightness": lambda input, magnitude, interpolation, fill: F.adjust_brightness( - input, brightness_factor=1.0 + magnitude - ), - "Color": lambda input, magnitude, interpolation, fill: F.adjust_saturation( - input, saturation_factor=1.0 + magnitude - ), - "Contrast": lambda input, magnitude, interpolation, fill: F.adjust_contrast( - input, contrast_factor=1.0 + magnitude - ), - "Sharpness": lambda input, magnitude, interpolation, fill: F.adjust_sharpness( - input, sharpness_factor=1.0 + magnitude - ), - "Posterize": lambda input, magnitude, interpolation, fill: F.posterize(input, bits=int(magnitude)), - "Solarize": lambda input, magnitude, interpolation, fill: F.solarize(input, threshold=magnitude), - "AutoContrast": lambda input, magnitude, interpolation, fill: F.autocontrast(input), - "Equalize": lambda input, magnitude, interpolation, fill: F.equalize(input), - "Invert": lambda input, magnitude, interpolation, fill: F.invert(input), - } + def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]: + keys = tuple(dct.keys()) + key = keys[int(torch.randint(len(keys), ()))] + return key, dct[key] + + def _apply_transform(self, sample: Any, transform_id: str, magnitude: float) -> Any: + def dispatch( + image_tensor_kernel: Callable, + image_pil_kernel: Callable, + input: Any, + *args: Any, + **kwargs: Any, + ) -> Any: + if isinstance(input, (features.BoundingBox, features.SegmentationMask)): + raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()") + elif isinstance(input, features.Image): + output = image_tensor_kernel(input, *args, **kwargs) + return features.Image.new_like(input, output) + elif isinstance(input, torch.Tensor): + return image_tensor_kernel(input, *args, **kwargs) + elif isinstance(input, PIL.Image.Image): + return image_pil_kernel(input, *args, **kwargs) + else: + return input - def _get_params(self, sample: Any) -> Dict[str, Any]: image = query_image(sample) num_channels = F.get_image_num_channels(image) @@ -89,24 +55,104 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: elif fill is not None: fill = [float(f) for f in fill] - return dict(interpolation=self.interpolation, fill=fill) - - def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]: - keys = tuple(dct.keys()) - key = keys[int(torch.randint(len(keys), ()))] - return key, dct[key] - - def _apply_transform(self, sample: Any, params: Dict[str, Any], transform_id: str, magnitude: float) -> Any: - dispatcher = self._DISPATCHER_MAP[transform_id] + interpolation = self.interpolation def transform(input: Any) -> Any: - if type(input) in {features.Image, torch.Tensor} or isinstance(input, PIL.Image.Image): - return dispatcher(input, magnitude, params["interpolation"], params["fill"]) - elif type(input) in {features.BoundingBox, features.SegmentationMask}: + if type(input) in {features.BoundingBox, features.SegmentationMask}: raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()") - else: + elif not (type(input) in {features.Image, torch.Tensor} or isinstance(input, PIL.Image.Image)): return input + if transform_id == "Identity": + return input + elif transform_id == "ShearX": + return dispatch( + F.affine_image_tensor, + F.affine_image_pil, + input, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[math.degrees(magnitude), 0.0], + interpolation=interpolation, + fill=fill, + ) + elif transform_id == "ShearY": + return dispatch( + F.affine_image_tensor, + F.affine_image_pil, + input, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[0.0, 
math.degrees(magnitude)], + interpolation=interpolation, + fill=fill, + ) + elif transform_id == "TranslateX": + return dispatch( + F.affine_image_tensor, + F.affine_image_pil, + input, + angle=0.0, + translate=[int(magnitude), 0], + scale=1.0, + shear=[0.0, 0.0], + interpolation=interpolation, + fill=fill, + ) + elif transform_id == "TranslateY": + return dispatch( + F.affine_image_tensor, + F.affine_image_pil, + input, + angle=0.0, + translate=[0, int(magnitude)], + scale=1.0, + shear=[0.0, 0.0], + interpolation=interpolation, + fill=fill, + ) + elif transform_id == "Rotate": + return dispatch(F.rotate_image_tensor, F.rotate_image_pil, input, angle=magnitude) + elif transform_id == "Brightness": + return dispatch( + F.adjust_brightness_image_tensor, + F.adjust_brightness_image_pil, + input, + brightness_factor=1.0 + magnitude, + ) + elif transform_id == "Color": + return dispatch( + F.adjust_saturation_image_tensor, + F.adjust_saturation_image_pil, + input, + saturation_factor=1.0 + magnitude, + ) + elif transform_id == "Contrast": + return dispatch( + F.adjust_contrast_image_tensor, F.adjust_contrast_image_pil, input, contrast_factor=1.0 + magnitude + ) + elif transform_id == "Sharpness": + return dispatch( + F.adjust_sharpness_image_tensor, + F.adjust_sharpness_image_pil, + input, + sharpness_factor=1.0 + magnitude, + ) + elif transform_id == "Posterize": + return dispatch(F.posterize_image_tensor, F.posterize_image_pil, input, bits=int(magnitude)) + elif transform_id == "Solarize": + return dispatch(F.solarize_image_tensor, F.solarize_image_pil, input, threshold=magnitude) + elif transform_id == "AutoContrast": + return dispatch(F.autocontrast_image_tensor, F.autocontrast_image_pil, input) + elif transform_id == "Equalize": + return dispatch(F.equalize_image_tensor, F.equalize_image_pil, input) + elif transform_id == "Invert": + return dispatch(F.invert_image_tensor, F.invert_image_pil, input) + else: + raise ValueError(f"No transform available for {transform_id}") + return apply_recursively(transform, sample) @@ -114,7 +160,7 @@ class AutoAugment(_AutoAugmentBase): _AUGMENTATION_SPACE = { "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), - "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True), "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), @@ -228,9 +274,8 @@ def _get_policies( else: raise ValueError(f"The provided policy {policy} is not recognized.") - def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: + def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] - params = params or self._get_params(sample) image = query_image(sample) image_size = F.get_image_size(image) @@ -251,7 +296,7 @@ def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: else: magnitude = 0.0 - sample = self._apply_transform(sample, params, transform_id, magnitude) + sample = self._apply_transform(sample, transform_id, magnitude) return sample @@ -261,7 +306,7 @@ class RandAugment(_AutoAugmentBase): 
"Identity": (lambda num_bins, image_size: None, False), "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), - "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True), "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), @@ -285,9 +330,8 @@ def __init__(self, *, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: self.magnitude = magnitude self.num_magnitude_bins = num_magnitude_bins - def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: + def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] - params = params or self._get_params(sample) image = query_image(sample) image_size = F.get_image_size(image) @@ -303,7 +347,7 @@ def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: else: magnitude = 0.0 - sample = self._apply_transform(sample, params, transform_id, magnitude) + sample = self._apply_transform(sample, transform_id, magnitude) return sample @@ -335,9 +379,8 @@ def __init__(self, *, num_magnitude_bins: int = 31, **kwargs: Any): super().__init__(**kwargs) self.num_magnitude_bins = num_magnitude_bins - def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: + def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] - params = params or self._get_params(sample) image = query_image(sample) image_size = F.get_image_size(image) @@ -352,4 +395,4 @@ def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: else: magnitude = 0.0 - return self._apply_transform(sample, params, transform_id, magnitude) + return self._apply_transform(sample, transform_id, magnitude) diff --git a/torchvision/prototype/transforms/_container.py b/torchvision/prototype/transforms/_container.py index d2a0d642626..bd20d0c701a 100644 --- a/torchvision/prototype/transforms/_container.py +++ b/torchvision/prototype/transforms/_container.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Dict +from typing import Any import torch @@ -6,13 +6,13 @@ class Compose(Transform): - def __init__(self, *transforms: Transform): + def __init__(self, *transforms: Transform) -> None: super().__init__() self.transforms = transforms for idx, transform in enumerate(transforms): self.add_module(str(idx), transform) - def forward(self, *inputs: Any) -> Any: # type: ignore[override] + def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] for transform in self.transforms: sample = transform(sample) @@ -25,38 +25,38 @@ def __init__(self, transform: Transform, *, p: float = 0.5) -> None: self.transform = transform self.p = p - def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: + def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] if float(torch.rand(())) < self.p: return sample - return self.transform(sample, params=params) + return self.transform(sample) def extra_repr(self) -> str: return f"p={self.p}" class RandomChoice(Transform): - def __init__(self, 
*transforms: Transform): + def __init__(self, *transforms: Transform) -> None: super().__init__() self.transforms = transforms for idx, transform in enumerate(transforms): self.add_module(str(idx), transform) - def forward(self, *inputs: Any) -> Any: # type: ignore[override] + def forward(self, *inputs: Any) -> Any: idx = int(torch.randint(len(self.transforms), size=())) transform = self.transforms[idx] return transform(*inputs) class RandomOrder(Transform): - def __init__(self, *transforms: Transform): + def __init__(self, *transforms: Transform) -> None: super().__init__() self.transforms = transforms for idx, transform in enumerate(transforms): self.add_module(str(idx), transform) - def forward(self, *inputs: Any) -> Any: # type: ignore[override] + def forward(self, *inputs: Any) -> Any: for idx in torch.randperm(len(self.transforms)): transform = self.transforms[idx] inputs = transform(*inputs) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index cc7838df83b..4c9d9192ac8 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -2,6 +2,7 @@ import warnings from typing import Any, Dict, List, Union, Sequence, Tuple, cast +import PIL.Image import torch from torchvision.prototype import features from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F @@ -11,41 +12,69 @@ class HorizontalFlip(Transform): - _DISPATCHER = F.horizontal_flip + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if isinstance(input, features.Image): + output = F.horizontal_flip_image_tensor(input) + return features.Image.new_like(input, output) + elif isinstance(input, features.BoundingBox): + output = F.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size) + return features.BoundingBox.new_like(input, output) + elif isinstance(input, PIL.Image.Image): + return F.horizontal_flip_image_pil(input) + elif isinstance(input, torch.Tensor): + return F.horizontal_flip_image_tensor(input) + else: + return input class Resize(Transform): - _DISPATCHER = F.resize - def __init__( self, size: Union[int, Sequence[int]], interpolation: InterpolationMode = InterpolationMode.BILINEAR, ) -> None: super().__init__() - self.size = size + self.size = [size] if isinstance(size, int) else list(size) self.interpolation = interpolation - def _get_params(self, sample: Any) -> Dict[str, Any]: - return dict(size=self.size, interpolation=self.interpolation) + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if isinstance(input, features.Image): + output = F.resize_image_tensor(input, self.size, interpolation=self.interpolation) + return features.Image.new_like(input, output) + elif isinstance(input, features.SegmentationMask): + output = F.resize_segmentation_mask(input, self.size) + return features.SegmentationMask.new_like(input, output) + elif isinstance(input, features.BoundingBox): + output = F.resize_bounding_box(input, self.size, image_size=input.image_size) + return features.BoundingBox.new_like(input, output, image_size=self.size) + elif isinstance(input, PIL.Image.Image): + return F.resize_image_pil(input, self.size, interpolation=self.interpolation) + elif isinstance(input, torch.Tensor): + return F.resize_image_tensor(input, self.size, interpolation=self.interpolation) + else: + return input class CenterCrop(Transform): - _DISPATCHER = F.center_crop - _FAIL_TYPES = {features.BoundingBox, features.SegmentationMask} - 
def __init__(self, output_size: List[int]): super().__init__() self.output_size = output_size - def _get_params(self, sample: Any) -> Dict[str, Any]: - return dict(output_size=self.output_size) + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if isinstance(input, (features.BoundingBox, features.SegmentationMask)): + raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()") + elif isinstance(input, features.Image): + output = F.center_crop_image_tensor(input, self.output_size) + return features.Image.new_like(input, output) + elif isinstance(input, torch.Tensor): + return F.center_crop_image_tensor(input, self.output_size) + elif isinstance(input, PIL.Image.Image): + return F.center_crop_image_pil(input, self.output_size) + else: + return input class RandomResizedCrop(Transform): - _DISPATCHER = F.resized_crop - _FAIL_TYPES = {features.BoundingBox, features.SegmentationMask} - def __init__( self, size: Union[int, Sequence[int]], @@ -80,7 +109,7 @@ def __init__( def _get_params(self, sample: Any) -> Dict[str, Any]: image = query_image(sample) - height, width = F.get_image_size(image) + width, height = F.get_image_size(image) area = height * width log_ratio = torch.log(torch.tensor(self.ratio)) @@ -115,4 +144,19 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: i = (height - h) // 2 j = (width - w) // 2 - return dict(top=i, left=j, height=h, width=w, size=self.size) + return dict(top=i, left=j, height=h, width=w) + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if isinstance(input, (features.BoundingBox, features.SegmentationMask)): + raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()") + elif isinstance(input, features.Image): + output = F.resized_crop_image_tensor( + input, **params, size=list(self.size), interpolation=self.interpolation + ) + return features.Image.new_like(input, output) + elif isinstance(input, torch.Tensor): + return F.resized_crop_image_tensor(input, **params, size=list(self.size), interpolation=self.interpolation) + elif isinstance(input, PIL.Image.Image): + return F.resized_crop_image_pil(input, **params, size=list(self.size), interpolation=self.interpolation) + else: + return input diff --git a/torchvision/prototype/transforms/_meta_conversion.py b/torchvision/prototype/transforms/_meta_conversion.py index d7a1e6a76fa..3675e1d8ada 100644 --- a/torchvision/prototype/transforms/_meta_conversion.py +++ b/torchvision/prototype/transforms/_meta_conversion.py @@ -1,5 +1,6 @@ from typing import Union, Any, Dict, Optional +import PIL.Image import torch from torchvision.prototype import features from torchvision.prototype.transforms import Transform, functional as F @@ -7,24 +8,18 @@ class ConvertBoundingBoxFormat(Transform): - _DISPATCHER = F.convert_format - - def __init__( - self, - format: Union[str, features.BoundingBoxFormat], - old_format: Optional[Union[str, features.BoundingBoxFormat]] = None, - ) -> None: + def __init__(self, format: Union[str, features.BoundingBoxFormat]) -> None: super().__init__() if isinstance(format, str): format = features.BoundingBoxFormat[format] self.format = format - if isinstance(old_format, str): - old_format = features.BoundingBoxFormat[old_format] - self.old_format = old_format - - def _get_params(self, sample: Any) -> Dict[str, Any]: - return dict(format=self.format, old_format=self.old_format) + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if type(input) is features.BoundingBox: + output = 
F.convert_bounding_box_format(input, old_format=input.format, new_format=params["format"]) + return features.BoundingBox.new_like(input, output, format=params["format"]) + else: + return input class ConvertImageDtype(Transform): @@ -33,21 +28,50 @@ def __init__(self, dtype: torch.dtype = torch.float32) -> None: self.dtype = dtype def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if not isinstance(input, features.Image): + if type(input) is features.Image: + output = convert_image_dtype(input, dtype=self.dtype) + return features.Image.new_like(input, output, dtype=self.dtype) + else: return input - output = convert_image_dtype(input, dtype=self.dtype) - return features.Image.new_like(input, output, dtype=self.dtype) - - -class ConvertColorSpace(Transform): - _DISPATCHER = F.convert_color_space - def __init__(self, color_space: Union[str, features.ColorSpace]) -> None: +class ConvertImageColorSpace(Transform): + def __init__( + self, + color_space: Union[str, features.ColorSpace], + old_color_space: Optional[Union[str, features.ColorSpace]] = None, + ) -> None: super().__init__() + if isinstance(color_space, str): color_space = features.ColorSpace[color_space] self.color_space = color_space - def _get_params(self, sample: Any) -> Dict[str, Any]: - return dict(color_space=self.color_space) + if isinstance(old_color_space, str): + old_color_space = features.ColorSpace[old_color_space] + self.old_color_space = old_color_space + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if isinstance(input, features.Image): + output = F.convert_image_color_space_tensor( + input, old_color_space=input.color_space, new_color_space=self.color_space + ) + return features.Image.new_like(input, output, color_space=self.color_space) + elif isinstance(input, torch.Tensor): + if self.old_color_space is None: + raise RuntimeError("") + + return F.convert_image_color_space_tensor( + input, old_color_space=self.old_color_space, new_color_space=self.color_space + ) + elif isinstance(input, PIL.Image.Image): + old_color_space = { + "L": features.ColorSpace.GRAYSCALE, + "RGB": features.ColorSpace.RGB, + }.get(input.mode, features.ColorSpace.OTHER) + + return F.convert_image_color_space_pil( + input, old_color_space=old_color_space, new_color_space=self.color_space + ) + else: + return input diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index 42502e74874..54440ee05a5 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -17,11 +17,11 @@ def __init__(self, fn: Callable[[Any], Any], *types: Type): self.types = types def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if not isinstance(input, self.types): + if type(input) in self.types: + return self.fn(input) + else: return input - return self.fn(input) - def extra_repr(self) -> str: extras = [] name = getattr(self.fn, "__name__", None) @@ -32,15 +32,18 @@ def extra_repr(self) -> str: class Normalize(Transform): - _DISPATCHER = F.normalize - def __init__(self, mean: List[float], std: List[float]): super().__init__() self.mean = mean self.std = std - def _get_params(self, sample: Any) -> Dict[str, Any]: - return dict(mean=self.mean, std=self.std) + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if isinstance(input, torch.Tensor): + # We don't need to differentiate between vanilla tensors and features.Image's here, since the result of the + # normalization transform is no longer a features.Image + 
return F.normalize_image_tensor(input, mean=self.mean, std=self.std) + else: + return input class ToDtype(Lambda): diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py index e164d14cb00..923b90c6777 100644 --- a/torchvision/prototype/transforms/_transform.py +++ b/torchvision/prototype/transforms/_transform.py @@ -1,18 +1,13 @@ import enum import functools -from typing import Any, Dict, Optional, Set, Type +from typing import Any, Dict from torch import nn from torchvision.prototype.utils._internal import apply_recursively from torchvision.utils import _log_api_usage_once -from .functional._utils import Dispatcher - class Transform(nn.Module): - _DISPATCHER: Optional[Dispatcher] = None - _FAIL_TYPES: Set[Type] = set() - def __init__(self) -> None: super().__init__() _log_api_usage_once(self) @@ -21,19 +16,11 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict() def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if not self._DISPATCHER: - raise NotImplementedError() - - if input in self._DISPATCHER: - return self._DISPATCHER(input, **params) - elif type(input) in self._FAIL_TYPES: - raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()") - else: - return input + raise NotImplementedError - def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: + def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] - return apply_recursively(functools.partial(self._transform, params=params or self._get_params(sample)), sample) + return apply_recursively(functools.partial(self._transform, params=self._get_params(sample)), sample) def extra_repr(self) -> str: extra = [] diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py index 3f4afc571d9..fa49c35265e 100644 --- a/torchvision/prototype/transforms/_type_conversion.py +++ b/torchvision/prototype/transforms/_type_conversion.py @@ -1,16 +1,17 @@ from typing import Any, Dict from torchvision.prototype import features -from torchvision.prototype.transforms import Transform, kernels as K +from torchvision.prototype.transforms import Transform, functional as F class DecodeImage(Transform): def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if not isinstance(input, features.EncodedImage): + if type(input) is features.EncodedImage: + output = F.decode_image_with_pil(input) + return features.Image(output) + else: return input - return features.Image(K.decode_image_with_pil(input)) - class LabelToOneHot(Transform): def __init__(self, num_categories: int = -1): @@ -18,16 +19,15 @@ def __init__(self, num_categories: int = -1): self.num_categories = num_categories def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if not isinstance(input, features.Label): + if type(input) is features.Label: + num_categories = self.num_categories + if num_categories == -1 and input.categories is not None: + num_categories = len(input.categories) + output = F.label_to_one_hot(input, num_categories=num_categories) + return features.OneHotLabel(output, categories=input.categories) + else: return input - num_categories = self.num_categories - if num_categories == -1 and input.categories is not None: - num_categories = len(input.categories) - return features.OneHotLabel( - K.label_to_one_hot(input, num_categories=num_categories), categories=input.categories - ) - def extra_repr(self) -> str: if self.num_categories == -1: return 
"" diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 7f29d817499..24d794a2cb4 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Union, Optional +from typing import Any, Optional, Union import PIL.Image import torch diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 9ebe46989d3..c487aba7fa2 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -1,15 +1,66 @@ -from ._augment import erase, mixup, cutmix +from torchvision.transforms import InterpolationMode # usort: skip +from ._utils import get_image_size, get_image_num_channels # usort: skip +from ._meta_conversion import ( + convert_bounding_box_format, + convert_image_color_space_tensor, + convert_image_color_space_pil, +) # usort: skip + +from ._augment import ( + erase_image_tensor, + mixup_image_tensor, + mixup_one_hot_label, + cutmix_image_tensor, + cutmix_one_hot_label, +) from ._color import ( - adjust_brightness, - adjust_contrast, - adjust_saturation, - adjust_sharpness, - posterize, - solarize, - autocontrast, - equalize, - invert, + adjust_brightness_image_tensor, + adjust_brightness_image_pil, + adjust_contrast_image_tensor, + adjust_contrast_image_pil, + adjust_saturation_image_tensor, + adjust_saturation_image_pil, + adjust_sharpness_image_tensor, + adjust_sharpness_image_pil, + posterize_image_tensor, + posterize_image_pil, + solarize_image_tensor, + solarize_image_pil, + autocontrast_image_tensor, + autocontrast_image_pil, + equalize_image_tensor, + equalize_image_pil, + invert_image_tensor, + invert_image_pil, + adjust_hue_image_tensor, + adjust_hue_image_pil, + adjust_gamma_image_tensor, + adjust_gamma_image_pil, +) +from ._geometry import ( + horizontal_flip_bounding_box, + horizontal_flip_image_tensor, + horizontal_flip_image_pil, + resize_bounding_box, + resize_image_tensor, + resize_image_pil, + resize_segmentation_mask, + center_crop_image_tensor, + center_crop_image_pil, + resized_crop_image_tensor, + resized_crop_image_pil, + affine_image_tensor, + affine_image_pil, + rotate_image_tensor, + rotate_image_pil, + pad_image_tensor, + pad_image_pil, + crop_image_tensor, + crop_image_pil, + perspective_image_tensor, + perspective_image_pil, + vertical_flip_image_tensor, + vertical_flip_image_pil, ) -from ._geometry import horizontal_flip, resize, center_crop, resized_crop, affine, rotate -from ._meta_conversion import convert_color_space, convert_format -from ._misc import normalize, get_image_size, get_image_num_channels +from ._misc import normalize_image_tensor, gaussian_blur_image_tensor +from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index 4f9835bfe01..5004ac550dd 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -1,64 +1,45 @@ -from typing import Any +from typing import Tuple import torch -from torchvision.prototype import features -from torchvision.prototype.transforms import kernels as K -from torchvision.transforms import functional as _F - -from ._utils import dispatch - - -@dispatch( - { - torch.Tensor: _F.erase, - features.Image: K.erase_image, - } 
-) -def erase(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - ... - - -@dispatch( - { - features.Image: K.mixup_image, - features.OneHotLabel: K.mixup_one_hot_label, - } -) -def mixup(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - ... - - -@dispatch( - { - features.Image: None, - features.OneHotLabel: None, - } -) -def cutmix(input: Any, *args: Any, **kwargs: Any) -> Any: - """Perform the CutMix operation as introduced in the paper - `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" `_. - - Dispatch to the corresponding kernels happens according to this table: - - .. table:: - :widths: 30 70 - - ==================================================== ================================================================ - :class:`~torchvision.prototype.features.Image` :func:`~torch.prototype.transforms.kernels.cutmix_image` - :class:`~torchvision.prototype.features.OneHotLabel` :func:`~torch.prototype.transforms.kernels.cutmix_one_hot_label` - ==================================================== ================================================================ - - Please refer to the kernel documentations for a detailed explanation of the functionality and parameters. - """ - if isinstance(input, features.Image): - kwargs.pop("lam_adjusted", None) - output = K.cutmix_image(input, **kwargs) - return features.Image.new_like(input, output) - elif isinstance(input, features.OneHotLabel): - kwargs.pop("box", None) - output = K.cutmix_one_hot_label(input, **kwargs) - return features.OneHotLabel.new_like(input, output) - - raise RuntimeError +from torchvision.transforms import functional_tensor as _FT + + +erase_image_tensor = _FT.erase + + +def _mixup_tensor(input: torch.Tensor, batch_dim: int, lam: float) -> torch.Tensor: + input = input.clone() + return input.roll(1, batch_dim).mul_(1 - lam).add_(input.mul_(lam)) + + +def mixup_image_tensor(image_batch: torch.Tensor, *, lam: float) -> torch.Tensor: + if image_batch.ndim < 4: + raise ValueError("Need a batch of images") + + return _mixup_tensor(image_batch, -4, lam) + + +def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float) -> torch.Tensor: + if one_hot_label_batch.ndim < 2: + raise ValueError("Need a batch of one hot labels") + + return _mixup_tensor(one_hot_label_batch, -2, lam) + + +def cutmix_image_tensor(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int]) -> torch.Tensor: + if image_batch.ndim < 4: + raise ValueError("Need a batch of images") + + x1, y1, x2, y2 = box + image_rolled = image_batch.roll(1, -4) + + image_batch = image_batch.clone() + image_batch[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2] + return image_batch + + +def cutmix_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam_adjusted: float) -> torch.Tensor: + if one_hot_label_batch.ndim < 2: + raise ValueError("Need a batch of one hot labels") + + return _mixup_tensor(one_hot_label_batch, -2, lam_adjusted) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 6c4f9d33a28..fa632d7df58 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -1,141 +1,34 @@ -from typing import Any +from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP -import PIL.Image -import torch -from torchvision.prototype import features -from torchvision.prototype.transforms import kernels as K -from 
torchvision.transforms import functional as _F +adjust_brightness_image_tensor = _FT.adjust_brightness +adjust_brightness_image_pil = _FP.adjust_brightness -from ._utils import dispatch +adjust_saturation_image_tensor = _FT.adjust_saturation +adjust_saturation_image_pil = _FP.adjust_saturation +adjust_contrast_image_tensor = _FT.adjust_contrast +adjust_contrast_image_pil = _FP.adjust_contrast -@dispatch( - { - torch.Tensor: _F.adjust_brightness, - PIL.Image.Image: _F.adjust_brightness, - features.Image: K.adjust_brightness_image, - } -) -def adjust_brightness(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - ... +adjust_sharpness_image_tensor = _FT.adjust_sharpness +adjust_sharpness_image_pil = _FP.adjust_sharpness +posterize_image_tensor = _FT.posterize +posterize_image_pil = _FP.posterize -@dispatch( - { - torch.Tensor: _F.adjust_saturation, - PIL.Image.Image: _F.adjust_saturation, - features.Image: K.adjust_saturation_image, - } -) -def adjust_saturation(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - ... +solarize_image_tensor = _FT.solarize +solarize_image_pil = _FP.solarize +autocontrast_image_tensor = _FT.autocontrast +autocontrast_image_pil = _FP.autocontrast -@dispatch( - { - torch.Tensor: _F.adjust_contrast, - PIL.Image.Image: _F.adjust_contrast, - features.Image: K.adjust_contrast_image, - } -) -def adjust_contrast(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - ... +equalize_image_tensor = _FT.equalize +equalize_image_pil = _FP.equalize +invert_image_tensor = _FT.invert +invert_image_pil = _FP.invert -@dispatch( - { - torch.Tensor: _F.adjust_sharpness, - PIL.Image.Image: _F.adjust_sharpness, - features.Image: K.adjust_sharpness_image, - } -) -def adjust_sharpness(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - ... +adjust_hue_image_tensor = _FT.adjust_hue +adjust_hue_image_pil = _FP.adjust_hue - -@dispatch( - { - torch.Tensor: _F.posterize, - PIL.Image.Image: _F.posterize, - features.Image: K.posterize_image, - } -) -def posterize(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - ... - - -@dispatch( - { - torch.Tensor: _F.solarize, - PIL.Image.Image: _F.solarize, - features.Image: K.solarize_image, - } -) -def solarize(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - ... - - -@dispatch( - { - torch.Tensor: _F.autocontrast, - PIL.Image.Image: _F.autocontrast, - features.Image: K.autocontrast_image, - } -) -def autocontrast(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - ... - - -@dispatch( - { - torch.Tensor: _F.equalize, - PIL.Image.Image: _F.equalize, - features.Image: K.equalize_image, - } -) -def equalize(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - ... - - -@dispatch( - { - torch.Tensor: _F.invert, - PIL.Image.Image: _F.invert, - features.Image: K.invert_image, - } -) -def invert(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - ... - - -@dispatch( - { - torch.Tensor: _F.adjust_hue, - PIL.Image.Image: _F.adjust_hue, - features.Image: K.adjust_hue_image, - } -) -def adjust_hue(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - ... - - -@dispatch( - { - torch.Tensor: _F.adjust_gamma, - PIL.Image.Image: _F.adjust_gamma, - features.Image: K.adjust_gamma_image, - } -) -def adjust_gamma(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - ... 
+adjust_gamma_image_tensor = _FT.adjust_gamma +adjust_gamma_image_pil = _FP.adjust_gamma diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 06ecd38dac0..d4214f791b3 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -1,165 +1,318 @@ -from typing import Any +import numbers +from typing import Tuple, List, Optional, Sequence, Union import PIL.Image import torch from torchvision.prototype import features -from torchvision.prototype.transforms import kernels as K -from torchvision.transforms import functional as _F - -from ._utils import dispatch - - -@dispatch( - { - torch.Tensor: _F.hflip, - PIL.Image.Image: _F.hflip, - features.Image: K.horizontal_flip_image, - features.BoundingBox: None, - }, -) -def horizontal_flip(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - if isinstance(input, features.BoundingBox): - output = K.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size) - return features.BoundingBox.new_like(input, output) - - raise RuntimeError - - -@dispatch( - { - torch.Tensor: _F.resize, - PIL.Image.Image: _F.resize, - features.Image: K.resize_image, - features.SegmentationMask: K.resize_segmentation_mask, - features.BoundingBox: None, - } -) -def resize(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - if isinstance(input, features.BoundingBox): - size = kwargs.pop("size") - output = K.resize_bounding_box(input, size=size, image_size=input.image_size) - return features.BoundingBox.new_like(input, output, image_size=size) - - raise RuntimeError - - -@dispatch( - { - torch.Tensor: _F.center_crop, - PIL.Image.Image: _F.center_crop, - features.Image: K.center_crop_image, - } -) -def center_crop(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - ... - - -@dispatch( - { - torch.Tensor: _F.resized_crop, - PIL.Image.Image: _F.resized_crop, - features.Image: K.resized_crop_image, - } -) -def resized_crop(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - ... - - -@dispatch( - { - torch.Tensor: _F.affine, - PIL.Image.Image: _F.affine, - features.Image: K.affine_image, - } -) -def affine(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - ... - - -@dispatch( - { - torch.Tensor: _F.rotate, - PIL.Image.Image: _F.rotate, - features.Image: K.rotate_image, - } -) -def rotate(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - ... - - -@dispatch( - { - torch.Tensor: _F.pad, - PIL.Image.Image: _F.pad, - features.Image: K.pad_image, - } -) -def pad(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - ... - - -@dispatch( - { - torch.Tensor: _F.crop, - PIL.Image.Image: _F.crop, - features.Image: K.crop_image, - } -) -def crop(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - ... - - -@dispatch( - { - torch.Tensor: _F.perspective, - PIL.Image.Image: _F.perspective, - features.Image: K.perspective_image, - } -) -def perspective(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - ... - - -@dispatch( - { - torch.Tensor: _F.vflip, - PIL.Image.Image: _F.vflip, - features.Image: K.vertical_flip_image, - } -) -def vertical_flip(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - ... 
- - -@dispatch( - { - torch.Tensor: _F.five_crop, - PIL.Image.Image: _F.five_crop, - features.Image: K.five_crop_image, - } -) -def five_crop(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - ... - - -@dispatch( - { - torch.Tensor: _F.ten_crop, - PIL.Image.Image: _F.ten_crop, - features.Image: K.ten_crop_image, - } -) -def ten_crop(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - ... +from torchvision.prototype.transforms import InterpolationMode +from torchvision.prototype.transforms.functional import get_image_size +from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP +from torchvision.transforms.functional import pil_modes_mapping, _get_inverse_affine_matrix + +from ._meta_conversion import convert_bounding_box_format + + +horizontal_flip_image_tensor = _FT.hflip +horizontal_flip_image_pil = _FP.hflip + + +def horizontal_flip_bounding_box( + bounding_box: torch.Tensor, format: features.BoundingBoxFormat, image_size: Tuple[int, int] +) -> torch.Tensor: + shape = bounding_box.shape + + bounding_box = convert_bounding_box_format( + bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY + ).view(-1, 4) + + bounding_box[:, [0, 2]] = image_size[1] - bounding_box[:, [2, 0]] + + return convert_bounding_box_format( + bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format + ).view(shape) + + +def resize_image_tensor( + image: torch.Tensor, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[bool] = None, +) -> torch.Tensor: + new_height, new_width = size + old_width, old_height = _FT.get_image_size(image) + num_channels = _FT.get_image_num_channels(image) + batch_shape = image.shape[:-3] + return _FT.resize( + image.reshape((-1, num_channels, old_height, old_width)), + size=size, + interpolation=interpolation.value, + max_size=max_size, + antialias=antialias, + ).reshape(batch_shape + (num_channels, new_height, new_width)) + + +def resize_image_pil( + img: PIL.Image.Image, + size: Union[Sequence[int], int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, +) -> PIL.Image.Image: + return _FP.resize(img, size, interpolation=pil_modes_mapping[interpolation], max_size=max_size) + + +def resize_segmentation_mask( + segmentation_mask: torch.Tensor, size: List[int], max_size: Optional[int] = None +) -> torch.Tensor: + return resize_image_tensor(segmentation_mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size) + + +# TODO: handle max_size +def resize_bounding_box(bounding_box: torch.Tensor, size: List[int], image_size: Tuple[int, int]) -> torch.Tensor: + old_height, old_width = image_size + new_height, new_width = size + ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device) + return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape) + + +vertical_flip_image_tensor = _FT.vflip +vertical_flip_image_pil = _FP.vflip + + +def _affine_parse_args( + angle: float, + translate: List[float], + scale: float, + shear: List[float], + interpolation: InterpolationMode = InterpolationMode.NEAREST, + center: Optional[List[float]] = None, +) -> Tuple[float, List[float], List[float], Optional[List[float]]]: + if not isinstance(angle, (int, float)): + raise TypeError("Argument angle should be int or float") + + if not isinstance(translate, (list, tuple)): + raise 
TypeError("Argument translate should be a sequence") + + if len(translate) != 2: + raise ValueError("Argument translate should be a sequence of length 2") + + if scale <= 0.0: + raise ValueError("Argument scale should be positive") + + if not isinstance(shear, (numbers.Number, (list, tuple))): + raise TypeError("Shear should be either a single value or a sequence of two values") + + if not isinstance(interpolation, InterpolationMode): + raise TypeError("Argument interpolation should be a InterpolationMode") + + if isinstance(angle, int): + angle = float(angle) + + if isinstance(translate, tuple): + translate = list(translate) + + if isinstance(shear, numbers.Number): + shear = [shear, 0.0] + + if isinstance(shear, tuple): + shear = list(shear) + + if len(shear) == 1: + shear = [shear[0], shear[0]] + + if len(shear) != 2: + raise ValueError(f"Shear should be a sequence containing two values. Got {shear}") + + if center is not None and not isinstance(center, (list, tuple)): + raise TypeError("Argument center should be a sequence") + + return angle, translate, shear, center + + +def affine_image_tensor( + img: torch.Tensor, + angle: float, + translate: List[float], + scale: float, + shear: List[float], + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None, + center: Optional[List[float]] = None, +) -> torch.Tensor: + angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center) + + center_f = [0.0, 0.0] + if center is not None: + width, height = get_image_size(img) + # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. + center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))] + + translate_f = [1.0 * t for t in translate] + matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear) + + return _FT.affine(img, matrix, interpolation=interpolation.value, fill=fill) + + +def affine_image_pil( + img: PIL.Image.Image, + angle: float, + translate: List[float], + scale: float, + shear: List[float], + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None, + center: Optional[List[float]] = None, +) -> PIL.Image.Image: + angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center) + + # center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5) + # it is visually better to estimate the center without 0.5 offset + # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine + if center is None: + width, height = get_image_size(img) + center = [width * 0.5, height * 0.5] + matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) + + return _FP.affine(img, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill, center=center) + + +def rotate_image_tensor( + img: torch.Tensor, + angle: float, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + expand: bool = False, + fill: Optional[List[float]] = None, + center: Optional[List[float]] = None, +) -> torch.Tensor: + center_f = [0.0, 0.0] + if center is not None: + width, height = get_image_size(img) + # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. + center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))] + + # due to current incoherence of rotation angle direction between affine and rotate implementations + # we need to set -angle. 
+ matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0]) + return _FT.rotate(img, matrix, interpolation=interpolation.value, expand=expand, fill=fill) + + +def rotate_image_pil( + img: PIL.Image.Image, + angle: float, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + expand: bool = False, + fill: Optional[List[float]] = None, + center: Optional[List[float]] = None, +) -> PIL.Image.Image: + return _FP.rotate( + img, angle, interpolation=pil_modes_mapping[interpolation], expand=expand, fill=fill, center=center + ) + + +pad_image_tensor = _FT.pad +pad_image_pil = _FP.pad + +crop_image_tensor = _FT.crop +crop_image_pil = _FP.crop + + +def perspective_image_tensor( + img: torch.Tensor, + perspective_coeffs: List[float], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[List[float]] = None, +) -> torch.Tensor: + return _FT.perspective(img, perspective_coeffs, interpolation=interpolation.value, fill=fill) + + +def perspective_image_pil( + img: PIL.Image.Image, + perspective_coeffs: float, + interpolation: InterpolationMode = InterpolationMode.BICUBIC, + fill: Optional[List[float]] = None, +) -> PIL.Image.Image: + return _FP.perspective(img, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill) + + +def _center_crop_parse_output_size(output_size: List[int]) -> List[int]: + if isinstance(output_size, numbers.Number): + return [int(output_size), int(output_size)] + elif isinstance(output_size, (tuple, list)) and len(output_size) == 1: + return [output_size[0], output_size[0]] + else: + return list(output_size) + + +def _center_crop_compute_padding(crop_height: int, crop_width: int, image_height: int, image_width: int) -> List[int]: + return [ + (crop_width - image_width) // 2 if crop_width > image_width else 0, + (crop_height - image_height) // 2 if crop_height > image_height else 0, + (crop_width - image_width + 1) // 2 if crop_width > image_width else 0, + (crop_height - image_height + 1) // 2 if crop_height > image_height else 0, + ] + + +def _center_crop_compute_crop_anchor( + crop_height: int, crop_width: int, image_height: int, image_width: int +) -> Tuple[int, int]: + crop_top = int(round((image_height - crop_height) / 2.0)) + crop_left = int(round((image_width - crop_width) / 2.0)) + return crop_top, crop_left + + +def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch.Tensor: + crop_height, crop_width = _center_crop_parse_output_size(output_size) + image_width, image_height = get_image_size(img) + + if crop_height > image_height or crop_width > image_width: + padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) + img = pad_image_tensor(img, padding_ltrb, fill=0) + + image_width, image_height = get_image_size(img) + if crop_width == image_width and crop_height == image_height: + return img + + crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width) + return crop_image_tensor(img, crop_top, crop_left, crop_height, crop_width) + + +def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image: + crop_height, crop_width = _center_crop_parse_output_size(output_size) + image_width, image_height = get_image_size(img) + + if crop_height > image_height or crop_width > image_width: + padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) + img = pad_image_pil(img, padding_ltrb, fill=0) + + image_width, image_height 
= get_image_size(img) + if crop_width == image_width and crop_height == image_height: + return img + + crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width) + return crop_image_pil(img, crop_top, crop_left, crop_height, crop_width) + + +def resized_crop_image_tensor( + img: torch.Tensor, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, +) -> torch.Tensor: + img = crop_image_tensor(img, top, left, height, width) + return resize_image_tensor(img, size, interpolation=interpolation) + + +def resized_crop_image_pil( + img: PIL.Image.Image, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, +) -> PIL.Image.Image: + img = crop_image_pil(img, top, left, height, width) + return resize_image_pil(img, size, interpolation=interpolation) diff --git a/torchvision/prototype/transforms/functional/_meta_conversion.py b/torchvision/prototype/transforms/functional/_meta_conversion.py index bbda3ea939a..b260beaa361 100644 --- a/torchvision/prototype/transforms/functional/_meta_conversion.py +++ b/torchvision/prototype/transforms/functional/_meta_conversion.py @@ -1,50 +1,91 @@ -from typing import Any - import PIL.Image import torch -from torchvision.ops import box_convert -from torchvision.prototype import features -from torchvision.prototype.transforms import kernels as K -from torchvision.transforms import functional as _F - -from ._utils import dispatch - - -@dispatch( - { - torch.Tensor: None, - features.BoundingBox: None, - } -) -def convert_format(input: Any, *args: Any, **kwargs: Any) -> Any: - format = kwargs["format"] - if type(input) is torch.Tensor: - old_format = kwargs.get("old_format") - if old_format is None: - raise TypeError("For vanilla tensors the `old_format` needs to be provided.") - return box_convert(input, in_fmt=kwargs["old_format"].name.lower(), out_fmt=format.name.lower()) - elif isinstance(input, features.BoundingBox): - output = K.convert_bounding_box_format(input, old_format=input.format, new_format=kwargs["format"]) - return features.BoundingBox.new_like(input, output, format=format) - - raise RuntimeError - - -@dispatch( - { - torch.Tensor: None, - PIL.Image.Image: None, - features.Image: None, - } -) -def convert_color_space(input: Any, *args: Any, **kwargs: Any) -> Any: - color_space = kwargs["color_space"] - if type(input) is torch.Tensor or isinstance(input, PIL.Image.Image): - if color_space != features.ColorSpace.GRAYSCALE: - raise ValueError("For vanilla tensors and PIL images only RGB to grayscale is supported") - return _F.rgb_to_grayscale(input) - elif isinstance(input, features.Image): - output = K.convert_color_space(input, old_color_space=input.color_space, new_color_space=color_space) - return features.Image.new_like(input, output, color_space=color_space) - - raise RuntimeError +from torchvision.prototype.features import BoundingBoxFormat, ColorSpace +from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP + + +def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor: + xyxy = xywh.clone() + xyxy[..., 2:] += xyxy[..., :2] + return xyxy + + +def _xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor: + xywh = xyxy.clone() + xywh[..., 2:] -= xywh[..., :2] + return xywh + + +def _cxcywh_to_xyxy(cxcywh: torch.Tensor) -> torch.Tensor: + cx, cy, w, h = torch.unbind(cxcywh, dim=-1) + x1 = cx - 0.5 * w + y1 = cy - 0.5 * h + x2 
= cx + 0.5 * w + y2 = cy + 0.5 * h + return torch.stack((x1, y1, x2, y2), dim=-1) + + +def _xyxy_to_cxcywh(xyxy: torch.Tensor) -> torch.Tensor: + x1, y1, x2, y2 = torch.unbind(xyxy, dim=-1) + cx = (x1 + x2) / 2 + cy = (y1 + y2) / 2 + w = x2 - x1 + h = y2 - y1 + return torch.stack((cx, cy, w, h), dim=-1) + + +def convert_bounding_box_format( + bounding_box: torch.Tensor, *, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat +) -> torch.Tensor: + if new_format == old_format: + return bounding_box.clone() + + if old_format == BoundingBoxFormat.XYWH: + bounding_box = _xywh_to_xyxy(bounding_box) + elif old_format == BoundingBoxFormat.CXCYWH: + bounding_box = _cxcywh_to_xyxy(bounding_box) + + if new_format == BoundingBoxFormat.XYWH: + bounding_box = _xyxy_to_xywh(bounding_box) + elif new_format == BoundingBoxFormat.CXCYWH: + bounding_box = _xyxy_to_cxcywh(bounding_box) + + return bounding_box + + +def _grayscale_to_rgb_tensor(grayscale: torch.Tensor) -> torch.Tensor: + return grayscale.expand(3, 1, 1) + + +def convert_image_color_space_tensor( + image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace +) -> torch.Tensor: + if new_color_space == old_color_space: + return image.clone() + + if old_color_space == ColorSpace.GRAYSCALE: + image = _grayscale_to_rgb_tensor(image) + + if new_color_space == ColorSpace.GRAYSCALE: + image = _FT.rgb_to_grayscale(image) + + return image + + +def _grayscale_to_rgb_pil(grayscale: PIL.Image.Image) -> PIL.Image.Image: + return grayscale.convert("RGB") + + +def convert_image_color_space_pil( + image: PIL.Image.Image, old_color_space: ColorSpace, new_color_space: ColorSpace +) -> PIL.Image.Image: + if new_color_space == old_color_space: + return image.copy() + + if old_color_space == ColorSpace.GRAYSCALE: + image = _grayscale_to_rgb_pil(image) + + if new_color_space == ColorSpace.GRAYSCALE: + image = _FP.to_grayscale(image) + + return image diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 212492230ea..fd0507cca4d 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -1,69 +1,42 @@ -from typing import Any +from typing import Optional, List import PIL.Image import torch -from torchvision.prototype import features -from torchvision.prototype.transforms import kernels as K -from torchvision.transforms import functional as _F -from torchvision.transforms.functional_pil import ( - get_image_size as _get_image_size_pil, - get_image_num_channels as _get_image_num_channels_pil, -) -from torchvision.transforms.functional_tensor import ( - get_image_size as _get_image_size_tensor, - get_image_num_channels as _get_image_num_channels_tensor, -) +from torchvision.transforms import functional_tensor as _FT +from torchvision.transforms.functional import to_tensor, to_pil_image -from ._utils import dispatch +normalize_image_tensor = _FT.normalize -@dispatch( - { - torch.Tensor: _F.normalize, - features.Image: K.normalize_image, - } -) -def normalize(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - ... +def gaussian_blur_image_tensor( + img: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None +) -> torch.Tensor: + if isinstance(kernel_size, int): + kernel_size = [kernel_size, kernel_size] + if len(kernel_size) != 2: + raise ValueError(f"If kernel_size is a sequence its length should be 2. 
Got {len(kernel_size)}") + for ksize in kernel_size: + if ksize % 2 == 0 or ksize < 0: + raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}") -@dispatch( - { - torch.Tensor: _F.gaussian_blur, - PIL.Image.Image: _F.gaussian_blur, - features.Image: K.gaussian_blur_image, - } -) -def gaussian_blur(input: Any, *args: Any, **kwargs: Any) -> Any: - """TODO: add docstring""" - ... + if sigma is None: + sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size] + if sigma is not None and not isinstance(sigma, (int, float, list, tuple)): + raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}") + if isinstance(sigma, (int, float)): + sigma = [float(sigma), float(sigma)] + if isinstance(sigma, (list, tuple)) and len(sigma) == 1: + sigma = [sigma[0], sigma[0]] + if len(sigma) != 2: + raise ValueError(f"If sigma is a sequence, its length should be 2. Got {len(sigma)}") + for s in sigma: + if s <= 0.0: + raise ValueError(f"sigma should have positive values. Got {sigma}") -@dispatch( - { - torch.Tensor: _get_image_size_tensor, - PIL.Image.Image: _get_image_size_pil, - features.Image: None, - features.BoundingBox: None, - } -) -def get_image_size(input: Any, *args: Any, **kwargs: Any) -> Any: - if isinstance(input, (features.Image, features.BoundingBox)): - return list(input.image_size) + return _FT.gaussian_blur(img, kernel_size, sigma) - raise RuntimeError - -@dispatch( - { - torch.Tensor: _get_image_num_channels_tensor, - PIL.Image.Image: _get_image_num_channels_pil, - features.Image: None, - } -) -def get_image_num_channels(input: Any, *args: Any, **kwargs: Any) -> Any: - if isinstance(input, features.Image): - return input.num_channels - - raise RuntimeError +def gaussian_blur_image_pil(img: PIL.Image, kernel_size: List[int], sigma: Optional[List[float]] = None) -> PIL.Image: + return to_pil_image(gaussian_blur_image_tensor(to_tensor(img), kernel_size=kernel_size, sigma=sigma)) diff --git a/torchvision/prototype/transforms/kernels/_type_conversion.py b/torchvision/prototype/transforms/functional/_type_conversion.py similarity index 100% rename from torchvision/prototype/transforms/kernels/_type_conversion.py rename to torchvision/prototype/transforms/functional/_type_conversion.py diff --git a/torchvision/prototype/transforms/functional/_utils.py b/torchvision/prototype/transforms/functional/_utils.py index abdee565bc4..07235d63716 100644 --- a/torchvision/prototype/transforms/functional/_utils.py +++ b/torchvision/prototype/transforms/functional/_utils.py @@ -1,102 +1,29 @@ -import inspect -from typing import Any, Optional, Callable, TypeVar, Mapping, Type +from typing import Tuple, Union, cast +import PIL.Image import torch -import torch.overrides from torchvision.prototype import features - -F = TypeVar("F", bound=features._Feature) - - -class Dispatcher: - """Wrap a function to automatically dispatch to registered kernels based on the call arguments. - - The wrapped function should have this signature - - .. code:: python - - @dispatch( - ... - ) - def dispatch_fn(input, *args, **kwargs): - ... - - where ``input`` is used to determine which kernel to dispatch to. - - Args: - kernels: Dictionary with types as keys that maps to a kernel to call. The resolution order is checking for - exact type matches first and if none is found falls back to checking for subclasses. If a value is - ``None``, the decorated function is called. 
- - Raises: - TypeError: If any value in ``kernels`` is not callable with ``kernel(input, *args, **kwargs)``. - TypeError: If the decorated function is called with an input that cannot be dispatched. - """ - - def __init__(self, fn: Callable, kernels: Mapping[Type, Optional[Callable]]): - self._fn = fn - - for feature_type, kernel in kernels.items(): - if not self._check_kernel(kernel): - raise TypeError( - f"Kernel for feature type {feature_type.__name__} is not callable with " - f"kernel(input, *args, **kwargs)." - ) - - self._kernels = kernels - - def _check_kernel(self, kernel: Optional[Callable]) -> bool: - if kernel is None: - return True - - if not callable(kernel): - return False - - params = list(inspect.signature(kernel).parameters.values()) - if not params: - return False - - return params[0].kind != inspect.Parameter.KEYWORD_ONLY - - def _resolve(self, feature_type: Type) -> Optional[Callable]: - try: - return self._kernels[feature_type] - except KeyError: - try: - return next( - kernel - for registered_feature_type, kernel in self._kernels.items() - if issubclass(feature_type, registered_feature_type) - ) - except StopIteration: - raise TypeError(f"No support for feature type {feature_type.__name__}") from None - - def __contains__(self, obj: Any) -> bool: - try: - self._resolve(type(obj)) - return True - except TypeError: - return False - - def __call__(self, input: Any, *args: Any, **kwargs: Any) -> Any: - kernel = self._resolve(type(input)) - - if kernel is None: - output = self._fn(input, *args, **kwargs) - if output is None: - raise RuntimeError( - f"{self._fn.__name__}() did not handle inputs of type {type(input).__name__} " - f"although it was configured to do so." - ) - else: - output = kernel(input, *args, **kwargs) - - if isinstance(input, features._Feature) and type(output) is torch.Tensor: - output = type(input).new_like(input, output) - - return output - - -def dispatch(kernels: Mapping[Type, Optional[Callable]]) -> Callable[[Callable], Dispatcher]: - """Decorates a function and turns it into a :class:`Dispatcher`.""" - return lambda fn: Dispatcher(fn, kernels) +from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP + + +def get_image_size(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int]: + if isinstance(image, features.Image): + height, width = image.image_size + return width, height + elif isinstance(image, torch.Tensor): + return cast(Tuple[int, int], tuple(_FT.get_image_size(image))) + if isinstance(image, PIL.Image.Image): + return cast(Tuple[int, int], tuple(_FP.get_image_size(image))) + else: + raise TypeError(f"unable to get image size from object of type {type(image).__name__}") + + +def get_image_num_channels(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> int: + if isinstance(image, features.Image): + return image.num_channels + elif isinstance(image, torch.Tensor): + return _FT.get_image_num_channels(image) + if isinstance(image, PIL.Image.Image): + return cast(int, _FP.get_image_num_channels(image)) + else: + raise TypeError(f"unable to get num channels from object of type {type(image).__name__}") diff --git a/torchvision/prototype/transforms/kernels/__init__.py b/torchvision/prototype/transforms/kernels/__init__.py deleted file mode 100644 index 1cac91d29c1..00000000000 --- a/torchvision/prototype/transforms/kernels/__init__.py +++ /dev/null @@ -1,42 +0,0 @@ -from torchvision.transforms import InterpolationMode # usort: skip -from ._meta_conversion import 
convert_bounding_box_format, convert_color_space # usort: skip - -from ._augment import ( - erase_image, - mixup_image, - mixup_one_hot_label, - cutmix_image, - cutmix_one_hot_label, -) -from ._color import ( - adjust_brightness_image, - adjust_contrast_image, - adjust_saturation_image, - adjust_sharpness_image, - posterize_image, - solarize_image, - autocontrast_image, - equalize_image, - invert_image, - adjust_hue_image, - adjust_gamma_image, -) -from ._geometry import ( - horizontal_flip_bounding_box, - horizontal_flip_image, - resize_bounding_box, - resize_image, - resize_segmentation_mask, - center_crop_image, - resized_crop_image, - affine_image, - rotate_image, - pad_image, - crop_image, - perspective_image, - vertical_flip_image, - five_crop_image, - ten_crop_image, -) -from ._misc import normalize_image, gaussian_blur_image -from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot diff --git a/torchvision/prototype/transforms/kernels/_augment.py b/torchvision/prototype/transforms/kernels/_augment.py deleted file mode 100644 index 526ed85ffd8..00000000000 --- a/torchvision/prototype/transforms/kernels/_augment.py +++ /dev/null @@ -1,45 +0,0 @@ -from typing import Tuple - -import torch -from torchvision.transforms import functional as _F - - -erase_image = _F.erase - - -def _mixup(input: torch.Tensor, batch_dim: int, lam: float) -> torch.Tensor: - input = input.clone() - return input.roll(1, batch_dim).mul_(1 - lam).add_(input.mul_(lam)) - - -def mixup_image(image_batch: torch.Tensor, *, lam: float) -> torch.Tensor: - if image_batch.ndim < 4: - raise ValueError("Need a batch of images") - - return _mixup(image_batch, -4, lam) - - -def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float) -> torch.Tensor: - if one_hot_label_batch.ndim < 2: - raise ValueError("Need a batch of one hot labels") - - return _mixup(one_hot_label_batch, -2, lam) - - -def cutmix_image(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int]) -> torch.Tensor: - if image_batch.ndim < 4: - raise ValueError("Need a batch of images") - - x1, y1, x2, y2 = box - image_rolled = image_batch.roll(1, -4) - - image_batch = image_batch.clone() - image_batch[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2] - return image_batch - - -def cutmix_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam_adjusted: float) -> torch.Tensor: - if one_hot_label_batch.ndim < 2: - raise ValueError("Need a batch of one hot labels") - - return _mixup(one_hot_label_batch, -2, lam_adjusted) diff --git a/torchvision/prototype/transforms/kernels/_color.py b/torchvision/prototype/transforms/kernels/_color.py deleted file mode 100644 index 00ed5cfbfc7..00000000000 --- a/torchvision/prototype/transforms/kernels/_color.py +++ /dev/null @@ -1,14 +0,0 @@ -from torchvision.transforms import functional as _F - - -adjust_brightness_image = _F.adjust_brightness -adjust_saturation_image = _F.adjust_saturation -adjust_contrast_image = _F.adjust_contrast -adjust_sharpness_image = _F.adjust_sharpness -posterize_image = _F.posterize -solarize_image = _F.solarize -autocontrast_image = _F.autocontrast -equalize_image = _F.equalize -invert_image = _F.invert -adjust_hue_image = _F.adjust_hue -adjust_gamma_image = _F.adjust_gamma diff --git a/torchvision/prototype/transforms/kernels/_geometry.py b/torchvision/prototype/transforms/kernels/_geometry.py deleted file mode 100644 index 72afc2e62a3..00000000000 --- a/torchvision/prototype/transforms/kernels/_geometry.py +++ /dev/null @@ -1,76 +0,0 @@ -from 
typing import Tuple, List, Optional, TypeVar - -import torch -from torchvision.prototype import features -from torchvision.transforms import functional as _F, InterpolationMode - -from ._meta_conversion import convert_bounding_box_format - - -T = TypeVar("T", bound=features._Feature) - - -horizontal_flip_image = _F.hflip - - -def horizontal_flip_bounding_box( - bounding_box: torch.Tensor, *, format: features.BoundingBoxFormat, image_size: Tuple[int, int] -) -> torch.Tensor: - shape = bounding_box.shape - - bounding_box = convert_bounding_box_format( - bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY - ).view(-1, 4) - - bounding_box[:, [0, 2]] = image_size[1] - bounding_box[:, [2, 0]] - - return convert_bounding_box_format( - bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format - ).view(shape) - - -def resize_image( - image: torch.Tensor, - size: List[int], - interpolation: InterpolationMode = InterpolationMode.BILINEAR, - max_size: Optional[int] = None, - antialias: Optional[bool] = None, -) -> torch.Tensor: - new_height, new_width = size - num_channels, old_height, old_width = image.shape[-3:] - batch_shape = image.shape[:-3] - return _F.resize( - image.reshape((-1, num_channels, old_height, old_width)), - size=size, - interpolation=interpolation, - max_size=max_size, - antialias=antialias, - ).reshape(batch_shape + (num_channels, new_height, new_width)) - - -def resize_segmentation_mask( - segmentation_mask: torch.Tensor, - size: List[int], - max_size: Optional[int] = None, -) -> torch.Tensor: - return resize_image(segmentation_mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size) - - -# TODO: handle max_size -def resize_bounding_box(bounding_box: torch.Tensor, *, size: List[int], image_size: Tuple[int, int]) -> torch.Tensor: - old_height, old_width = image_size - new_height, new_width = size - ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device) - return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape) - - -center_crop_image = _F.center_crop -resized_crop_image = _F.resized_crop -affine_image = _F.affine -rotate_image = _F.rotate -pad_image = _F.pad -crop_image = _F.crop -perspective_image = _F.perspective -vertical_flip_image = _F.vflip -five_crop_image = _F.five_crop -ten_crop_image = _F.ten_crop diff --git a/torchvision/prototype/transforms/kernels/_meta_conversion.py b/torchvision/prototype/transforms/kernels/_meta_conversion.py deleted file mode 100644 index 4acaf9fe9e4..00000000000 --- a/torchvision/prototype/transforms/kernels/_meta_conversion.py +++ /dev/null @@ -1,69 +0,0 @@ -import torch -from torchvision.prototype.features import BoundingBoxFormat, ColorSpace -from torchvision.transforms.functional_tensor import rgb_to_grayscale as _rgb_to_grayscale - - -def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor: - xyxy = xywh.clone() - xyxy[..., 2:] += xyxy[..., :2] - return xyxy - - -def _xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor: - xywh = xyxy.clone() - xywh[..., 2:] -= xywh[..., :2] - return xywh - - -def _cxcywh_to_xyxy(cxcywh: torch.Tensor) -> torch.Tensor: - cx, cy, w, h = torch.unbind(cxcywh, dim=-1) - x1 = cx - 0.5 * w - y1 = cy - 0.5 * h - x2 = cx + 0.5 * w - y2 = cy + 0.5 * h - return torch.stack((x1, y1, x2, y2), dim=-1) - - -def _xyxy_to_cxcywh(xyxy: torch.Tensor) -> torch.Tensor: - x1, y1, x2, y2 = torch.unbind(xyxy, dim=-1) - cx = (x1 + x2) / 2 - cy = (y1 + y2) / 2 - w = x2 - x1 - h = y2 - y1 - return torch.stack((cx, cy, w, 
h), dim=-1) - - -def convert_bounding_box_format( - bounding_box: torch.Tensor, *, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat -) -> torch.Tensor: - if new_format == old_format: - return bounding_box.clone() - - if old_format == BoundingBoxFormat.XYWH: - bounding_box = _xywh_to_xyxy(bounding_box) - elif old_format == BoundingBoxFormat.CXCYWH: - bounding_box = _cxcywh_to_xyxy(bounding_box) - - if new_format == BoundingBoxFormat.XYWH: - bounding_box = _xyxy_to_xywh(bounding_box) - elif new_format == BoundingBoxFormat.CXCYWH: - bounding_box = _xyxy_to_cxcywh(bounding_box) - - return bounding_box - - -def _grayscale_to_rgb(grayscale: torch.Tensor) -> torch.Tensor: - return grayscale.expand(3, 1, 1) - - -def convert_color_space(image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace) -> torch.Tensor: - if new_color_space == old_color_space: - return image.clone() - - if old_color_space == ColorSpace.GRAYSCALE: - image = _grayscale_to_rgb(image) - - if new_color_space == ColorSpace.GRAYSCALE: - image = _rgb_to_grayscale(image) - - return image diff --git a/torchvision/prototype/transforms/kernels/_misc.py b/torchvision/prototype/transforms/kernels/_misc.py deleted file mode 100644 index f4e2c69c7ee..00000000000 --- a/torchvision/prototype/transforms/kernels/_misc.py +++ /dev/null @@ -1,5 +0,0 @@ -from torchvision.transforms import functional as _F - - -normalize_image = _F.normalize -gaussian_blur_image = _F.gaussian_blur
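# Editor's note: a minimal sketch exercising the renamed low-level API end to end. It assumes
# `convert_bounding_box_format` is re-exported from `torchvision.prototype.transforms.functional`
# (that package __init__ is not part of this hunk); if it is not, import it from the private
# `_meta_conversion` module instead.
import torch
from torchvision.prototype.features import BoundingBoxFormat
from torchvision.prototype.transforms import functional as F

# An XYWH box (x=10, y=20, w=30, h=40) becomes (10, 20, 40, 60) in XYXY: the new helpers
# simply add/subtract the top-left corner, cf. _xywh_to_xyxy / _xyxy_to_xywh above.
box = torch.tensor([10.0, 20.0, 30.0, 40.0])
xyxy = F.convert_bounding_box_format(
    box, old_format=BoundingBoxFormat.XYWH, new_format=BoundingBoxFormat.XYXY
)
assert xyxy.tolist() == [10.0, 20.0, 40.0, 60.0]

# The round trip recovers the original values and always returns a new tensor,
# since the helpers clone the input before mutating it.
xywh = F.convert_bounding_box_format(
    xyxy, old_format=BoundingBoxFormat.XYXY, new_format=BoundingBoxFormat.XYWH
)
assert torch.equal(xywh, box)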