From 7d0057eb3fa63135ce1cf81d32dd4e6b7561f6a0 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 29 Jun 2022 16:07:49 +0000 Subject: [PATCH 1/7] Added mid-level ops and feature-based ops --- test/test_prototype_transforms_functional.py | 45 ++++ .../prototype/features/_bounding_box.py | 122 +++++++++- torchvision/prototype/features/_feature.py | 169 ++++++++++++- torchvision/prototype/features/_image.py | 159 +++++++++++- torchvision/prototype/features/_label.py | 13 +- .../prototype/features/_segmentation_mask.py | 99 +++++++- torchvision/prototype/transforms/_augment.py | 87 +++---- .../transforms/functional/__init__.py | 33 ++- .../transforms/functional/_augment.py | 44 +--- .../prototype/transforms/functional/_color.py | 145 ++++++++++- .../transforms/functional/_geometry.py | 227 ++++++++++++++++-- torchvision/transforms/functional_pil.py | 4 +- 12 files changed, 1005 insertions(+), 142 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 30d9b833ec8..4d48bd7cb7f 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -365,6 +365,18 @@ def rotate_segmentation_mask(): ) +@register_kernel_info_from_sample_inputs_fn +def crop_image_tensor(): + for image, top, left, height, width in itertools.product(make_images(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20]): + yield SampleInput( + image, + top=top, + left=left, + height=height, + width=width, + ) + + @register_kernel_info_from_sample_inputs_fn def crop_bounding_box(): for bounding_box, top, left in itertools.product(make_bounding_boxes(), [-8, 0, 9], [-8, 0, 9]): @@ -499,6 +511,39 @@ def test_scriptable(kernel): jit.script(kernel) +# Test below is intended to test mid-level op vs low-level ops it calls +# For example, resize -> resize_image_tensor, resize_bounding_boxes etc +# TODO: Rewrite this tests as sample args may include more or less params +# than needed by functions +@pytest.mark.parametrize( + "func", + [ + pytest.param(func, id=name) + for name, func in F.__dict__.items() + if not name.startswith("_") + and callable(func) + and all( + feature_type not in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label", "pil"} + ) + and name not in {"to_image_tensor", "InterpolationMode", "decode_video_with_av", "crop", "rotate"} + # We skip 'crop' due to missing 'height' and 'width' + # We skip 'rotate' due to non implemented yet expand=True case for bboxes + ], +) +def test_functional_mid_level(func): + finfos = [finfo for finfo in FUNCTIONAL_INFOS if f"{func.__name__}_" in finfo.name] + for finfo in finfos: + for sample_input in finfo.sample_inputs(): + expected = finfo(sample_input) + kwargs = dict(sample_input.kwargs) + for key in ["format", "image_size"]: + if key in kwargs: + del kwargs[key] + output = func(*sample_input.args, **kwargs) + torch.testing.assert_close(output, expected, msg=f"finfo={finfo}, output={output}, expected={expected}") + break + + @pytest.mark.parametrize( ("functional_info", "sample_input"), [ diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index cd5cdc69836..dd97e741704 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import Any, Tuple, Union, Optional +from typing import Any, List, Tuple, Union, Optional import torch from torchvision._utils import StrEnum +from 
torchvision.transforms import InterpolationMode from ._feature import _Feature @@ -69,3 +70,122 @@ def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox: return BoundingBox.new_like( self, convert_bounding_box_format(self, old_format=self.format, new_format=format), format=format ) + + def horizontal_flip(self) -> BoundingBox: + output = self._F.horizontal_flip_bounding_box(self, format=self.format, image_size=self.image_size) # type: ignore[attr-defined] + return BoundingBox.new_like(self, output) + + def vertical_flip(self) -> BoundingBox: + output = self._F.vertical_flip_bounding_box(self, format=self.format, image_size=self.image_size) # type: ignore[attr-defined] + return BoundingBox.new_like(self, output) + + def resize( # type: ignore[override] + self, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: bool = False, + ) -> BoundingBox: + # interpolation, antialias # unused + output = self._F.resize_bounding_box(self, size, image_size=self.image_size, max_size=max_size) # type: ignore[attr-defined] + image_size = (size[0], size[0]) if len(size) == 1 else (size[0], size[1]) + return BoundingBox.new_like(self, output, image_size=image_size, dtype=output.dtype) + + def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox: + output = self._F.crop_bounding_box(self, self.format, top, left) # type: ignore[attr-defined] + return BoundingBox.new_like(self, output, image_size=(height, width)) + + def center_crop(self, output_size: List[int]) -> BoundingBox: + output = self._F.center_crop_bounding_box( # type: ignore[attr-defined] + self, format=self.format, output_size=output_size, image_size=self.image_size + ) + image_size = (output_size[0], output_size[0]) if len(output_size) == 1 else (output_size[0], output_size[1]) + return BoundingBox.new_like(self, output, image_size=image_size) + + def resized_crop( + self, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + antialias: bool = False, + ) -> BoundingBox: + output = self._F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size) # type: ignore[attr-defined] + image_size = (size[0], size[0]) if len(size) == 1 else (size[0], size[1]) + return BoundingBox.new_like(self, output, image_size=image_size, dtype=output.dtype) + + def pad(self, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> BoundingBox: + # fill # unused + if padding_mode not in ["constant"]: + raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes") + + output = self._F.pad_bounding_box(self, padding, format=self.format) # type: ignore[attr-defined] + + # Update output image size: + # TODO: remove the import below and make _parse_pad_padding available + from torchvision.transforms.functional_tensor import _parse_pad_padding + + left, top, right, bottom = _parse_pad_padding(padding) + height, width = self.image_size + height += top + bottom + width += left + right + + return BoundingBox.new_like(self, output, image_size=(height, width)) + + def rotate( + self, + angle: float, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + expand: bool = False, + fill: Optional[List[float]] = None, + center: Optional[List[float]] = None, + ) -> BoundingBox: + output = self._F.rotate_bounding_box( # type: ignore[attr-defined] + self, format=self.format, image_size=self.image_size, angle=angle, 
expand=expand, center=center + ) + # TODO: update output image size if expand is True + if expand: + raise RuntimeError("Not yet implemented") + return BoundingBox.new_like(self, output, dtype=output.dtype) + + def affine( + self, + angle: float, + translate: List[float], + scale: float, + shear: List[float], + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None, + center: Optional[List[float]] = None, + ) -> BoundingBox: + output = self._F.affine_bounding_box( # type: ignore[attr-defined] + self, + self.format, + self.image_size, + angle, + translate=translate, + scale=scale, + shear=shear, + center=center, + ) + return BoundingBox.new_like(self, output, dtype=output.dtype) + + def perspective( + self, + perspective_coeffs: List[float], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[List[float]] = None, + ) -> BoundingBox: + output = self._F.perspective_bounding_box(self, self.format, perspective_coeffs) # type: ignore[attr-defined] + return BoundingBox.new_like(self, output, dtype=output.dtype) + + def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> BoundingBox: + raise TypeError("Erase transformation does not support bounding boxes") + + def mixup(self, lam: float) -> BoundingBox: + raise TypeError("Mixup transformation does not support bounding boxes") + + def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> BoundingBox: + raise TypeError("Cutmix transformation does not support bounding boxes") diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index f8026b4d34d..6a645b8d289 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -1,8 +1,8 @@ -from typing import Any, cast, TypeVar, Union, Optional, Type, Callable, Tuple, Sequence, Mapping +from typing import Any, cast, TypeVar, Union, Optional, Type, Callable, List, Tuple, Sequence, Mapping import torch from torch._C import _TensorBase, DisableTorchFunction - +from torchvision.transforms import InterpolationMode F = TypeVar("F", bound="_Feature") @@ -16,7 +16,7 @@ def __new__( device: Optional[Union[torch.device, str, int]] = None, requires_grad: bool = False, ) -> F: - return cast( + feature = cast( F, torch.Tensor._make_subclass( cast(_TensorBase, cls), @@ -25,6 +25,13 @@ def __new__( ), ) + # To avoid circular dependency between features and transforms + from ..transforms import functional + + feature._F = functional # type: ignore[attr-defined] + + return feature + @classmethod def new_like( cls: Type[F], @@ -83,3 +90,159 @@ def __torch_function__( return cls.new_like(args[0], output, dtype=output.dtype, device=output.device) else: return output + + def horizontal_flip(self) -> Any: + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def vertical_flip(self) -> Any: + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def resize( # type: ignore[override] + self, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: bool = False, + ) -> Any: + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def crop(self, top: int, left: int, height: int, width: int) -> Any: + # Just output itself + # How dangerous to do this instead of raising an error ? 
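+ # Note: the base _Feature falls back to a no-op on purpose, returning the input
+ # unchanged, so the mid-level dispatchers can call these methods on any feature;
+ # subclasses (Image, BoundingBox, SegmentationMask, ...) override the ops they
+ # support and raise for ops that are undefined for them (e.g. BoundingBox.erase).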
+ return self + + def center_crop(self, output_size: List[int]) -> Any: + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def resized_crop( + self, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + antialias: bool = False, + ) -> Any: + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def pad(self, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Any: + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def rotate( + self, + angle: float, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + expand: bool = False, + fill: Optional[List[float]] = None, + center: Optional[List[float]] = None, + ) -> Any: + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def affine( + self, + angle: float, + translate: List[float], + scale: float, + shear: List[float], + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None, + center: Optional[List[float]] = None, + ) -> Any: + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def perspective( + self, + perspective_coeffs: List[float], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[List[float]] = None, + ) -> Any: + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def adjust_brightness(self, brightness_factor: float) -> Any: + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def adjust_saturation(self, saturation_factor: float) -> Any: + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def adjust_contrast(self, contrast_factor: float) -> Any: + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def adjust_sharpness(self, sharpness_factor: float) -> Any: + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def adjust_hue(self, hue_factor: float) -> Any: + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def adjust_gamma(self, gamma: float, gain: float = 1) -> Any: + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def posterize(self, bits: int) -> Any: + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def solarize(self, threshold: float) -> Any: + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def autocontrast(self) -> Any: + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def equalize(self) -> Any: + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def invert(self) -> Any: + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> Any: + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def mixup(self, lam: float) -> Any: + # Just output itself + # How dangerous to do this instead of raising an error ? 
+ return self + + def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> Any: + # Just output itself + # How dangerous to do this instead of raising an error ? + return self diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 9206a844b6d..5b9f4a1ebd3 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -1,11 +1,11 @@ from __future__ import annotations import warnings -from typing import Any, Optional, Union, Tuple, cast +from typing import Any, List, Optional, Union, Tuple, cast import torch from torchvision._utils import StrEnum -from torchvision.transforms.functional import to_pil_image +from torchvision.transforms.functional import to_pil_image, InterpolationMode from torchvision.utils import draw_bounding_boxes from torchvision.utils import make_grid @@ -109,3 +109,158 @@ def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> Image: # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we # promote this out of the prototype state return Image.new_like(self, draw_bounding_boxes(self, bounding_box.to_format("xyxy").view(-1, 4), **kwargs)) + + def horizontal_flip(self) -> Image: + output = self._F.horizontal_flip_image_tensor(self) # type: ignore[attr-defined] + return Image.new_like(self, output) + + def vertical_flip(self) -> Image: + output = self._F.vertical_flip_image_tensor(self) # type: ignore[attr-defined] + return Image.new_like(self, output) + + def resize( # type: ignore[override] + self, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: bool = False, + ) -> Image: + output = self._F.resize_image_tensor( # type: ignore[attr-defined] + self, size, interpolation=interpolation, max_size=max_size, antialias=antialias + ) + return Image.new_like(self, output) + + def crop(self, top: int, left: int, height: int, width: int) -> Image: + output = self._F.crop_image_tensor(self, top, left, height, width) # type: ignore[attr-defined] + return Image.new_like(self, output) + + def center_crop(self, output_size: List[int]) -> Image: + output = self._F.center_crop_image_tensor(self, output_size=output_size) # type: ignore[attr-defined] + return Image.new_like(self, output) + + def resized_crop( + self, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + antialias: bool = False, + ) -> Image: + output = self._F.resized_crop_image_tensor( # type: ignore[attr-defined] + self, top, left, height, width, size=list(size), interpolation=interpolation, antialias=antialias + ) + return Image.new_like(self, output) + + def pad(self, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Image: + output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode) # type: ignore[attr-defined] + return Image.new_like(self, output) + + def rotate( + self, + angle: float, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + expand: bool = False, + fill: Optional[List[float]] = None, + center: Optional[List[float]] = None, + ) -> Image: + output = self._F.rotate_image_tensor( # type: ignore[attr-defined] + self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center + ) + return Image.new_like(self, output) + + def affine( + self, + angle: float, + translate: List[float], + scale: float, + 
shear: List[float], + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None, + center: Optional[List[float]] = None, + ) -> Image: + output = self._F.affine_image_tensor( # type: ignore[attr-defined] + self, + angle, + translate=translate, + scale=scale, + shear=shear, + interpolation=interpolation, + fill=fill, + center=center, + ) + return Image.new_like(self, output) + + def perspective( + self, + perspective_coeffs: List[float], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[List[float]] = None, + ) -> Image: + output = self._F.perspective_image_tensor(self, perspective_coeffs, interpolation=interpolation, fill=fill) # type: ignore[attr-defined] + return Image.new_like(self, output) + + def adjust_brightness(self, brightness_factor: float) -> Image: + output = self._F.adjust_brightness_image_tensor(self, brightness_factor=brightness_factor) # type: ignore[attr-defined] + return Image.new_like(self, output) + + def adjust_saturation(self, saturation_factor: float) -> Image: + output = self._F.adjust_saturation_image_tensor(self, saturation_factor=saturation_factor) # type: ignore[attr-defined] + return Image.new_like(self, output) + + def adjust_contrast(self, contrast_factor: float) -> Image: + output = self._F.adjust_contrast_image_tensor(self, contrast_factor=contrast_factor) # type: ignore[attr-defined] + return Image.new_like(self, output) + + def adjust_sharpness(self, sharpness_factor: float) -> Image: + output = self._F.adjust_sharpness_image_tensor(self, sharpness_factor=sharpness_factor) # type: ignore[attr-defined] + return Image.new_like(self, output) + + def adjust_hue(self, hue_factor: float) -> Image: + output = self._F.adjust_hue_image_tensor(self, hue_factor=hue_factor) # type: ignore[attr-defined] + return Image.new_like(self, output) + + def adjust_gamma(self, gamma: float, gain: float = 1) -> Image: + output = self._F.adjust_gamma_image_tensor(self, gamma=gamma, gain=gain) # type: ignore[attr-defined] + return Image.new_like(self, output) + + def posterize(self, bits: int) -> Image: + output = self._F.posterize_image_tensor(self, bits=bits) # type: ignore[attr-defined] + return Image.new_like(self, output) + + def solarize(self, threshold: float) -> Image: + output = self._F.solarize_image_tensor(self, threshold=threshold) # type: ignore[attr-defined] + return Image.new_like(self, output) + + def autocontrast(self) -> Image: + output = self._F.autocontrast_image_tensor(self) # type: ignore[attr-defined] + return Image.new_like(self, output) + + def equalize(self) -> Image: + output = self._F.equalize_image_tensor(self) # type: ignore[attr-defined] + return Image.new_like(self, output) + + def invert(self) -> Image: + output = self._F.invert_image_tensor(self) # type: ignore[attr-defined] + return Image.new_like(self, output) + + def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> Image: + output = self._F.erase_image_tensor(self, i, j, h, w, v) # type: ignore[attr-defined] + return Image.new_like(self, output) + + def mixup(self, lam: float) -> Image: + if self.ndim < 4: + raise ValueError("Need a batch of images") + output = self.clone() + output = output.roll(1, -4).mul_(1 - lam).add_(output.mul_(lam)) + return Image.new_like(self, output) + + def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> Image: + if self.ndim < 4: + raise ValueError("Need a batch of images") + x1, y1, x2, y2 = box + image_rolled = self.roll(1, -4) + output = self.clone() + 
output[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2] + return Image.new_like(self, output) diff --git a/torchvision/prototype/features/_label.py b/torchvision/prototype/features/_label.py index e3433b7bb08..94e22f76f19 100644 --- a/torchvision/prototype/features/_label.py +++ b/torchvision/prototype/features/_label.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Optional, Sequence, cast, Union +from typing import Any, Optional, Sequence, cast, Union, Tuple import torch from torchvision.prototype.utils._internal import apply_recursively @@ -77,3 +77,14 @@ def new_like( return super().new_like( other, data, categories=categories if categories is not None else other.categories, **kwargs ) + + def mixup(self, lam: float) -> OneHotLabel: + if self.ndim < 2: + raise ValueError("Need a batch of one hot labels") + output = self.clone() + output = output.roll(1, -2).mul_(1 - lam).add_(output.mul_(lam)) + return OneHotLabel.new_like(self, output) + + def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> OneHotLabel: + box # unused + return self.mixup(lam_adjusted) diff --git a/torchvision/prototype/features/_segmentation_mask.py b/torchvision/prototype/features/_segmentation_mask.py index dc41697ae9b..1277177fdc8 100644 --- a/torchvision/prototype/features/_segmentation_mask.py +++ b/torchvision/prototype/features/_segmentation_mask.py @@ -1,5 +1,102 @@ +from __future__ import annotations + +from typing import Tuple, List, Optional + +import torch +from torchvision.transforms import InterpolationMode + from ._feature import _Feature class SegmentationMask(_Feature): - pass + def horizontal_flip(self) -> SegmentationMask: + output = self._F.horizontal_flip_segmentation_mask(self) # type: ignore[attr-defined] + return SegmentationMask.new_like(self, output) + + def vertical_flip(self) -> SegmentationMask: + output = self._F.vertical_flip_segmentation_mask(self) # type: ignore[attr-defined] + return SegmentationMask.new_like(self, output) + + def resize( # type: ignore[override] + self, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: bool = False, + ) -> SegmentationMask: + output = self._F.resize_segmentation_mask(self, size, max_size=max_size) # type: ignore[attr-defined] + return SegmentationMask.new_like(self, output) + + def crop(self, top: int, left: int, height: int, width: int) -> SegmentationMask: + output = self._F.center_crop_segmentation_mask(self, top, left, height, width) # type: ignore[attr-defined] + return SegmentationMask.new_like(self, output) + + def center_crop(self, output_size: List[int]) -> SegmentationMask: + output = self._F.center_crop_segmentation_mask(self, output_size=output_size) # type: ignore[attr-defined] + return SegmentationMask.new_like(self, output) + + def resized_crop( + self, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + antialias: bool = False, + ) -> SegmentationMask: + output = self._F.resized_crop_segmentation_mask(self, top, left, height, width, size=size) # type: ignore[attr-defined] + return SegmentationMask.new_like(self, output) + + def pad(self, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> SegmentationMask: + output = self._F.pad_segmentation_mask(self, padding, padding_mode=padding_mode) # type: ignore[attr-defined] + return SegmentationMask.new_like(self, output) + + def rotate( + self, + angle: 
float, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + expand: bool = False, + fill: Optional[List[float]] = None, + center: Optional[List[float]] = None, + ) -> SegmentationMask: + output = self._F.rotate_segmentation_mask(self, angle, expand=expand, center=center) # type: ignore[attr-defined] + return SegmentationMask.new_like(self, output) + + def affine( + self, + angle: float, + translate: List[float], + scale: float, + shear: List[float], + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None, + center: Optional[List[float]] = None, + ) -> SegmentationMask: + output = self._F.affine_segmentation_mask( # type: ignore[attr-defined] + self, + angle, + translate=translate, + scale=scale, + shear=shear, + center=center, + ) + return SegmentationMask.new_like(self, output) + + def perspective( + self, + perspective_coeffs: List[float], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[List[float]] = None, + ) -> SegmentationMask: + output = self._F.perspective_segmentation_mask(self, perspective_coeffs) # type: ignore[attr-defined] + return SegmentationMask.new_like(self, output) + + def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> SegmentationMask: + raise TypeError("Erase transformation does not support segmentation masks") + + def mixup(self, lam: float) -> SegmentationMask: + raise TypeError("Mixup transformation does not support segmentation masks") + + def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> SegmentationMask: + raise TypeError("Cutmix transformation does not support segmentation masks") diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 82c5f52f1dc..4ad9c7302b7 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -3,12 +3,13 @@ import warnings from typing import Any, Dict, Tuple +import PIL.Image import torch from torchvision.prototype import features from torchvision.prototype.transforms import Transform, functional as F from ._transform import _RandomApplyTransform -from ._utils import query_image, get_image_dimensions, has_all, has_any, is_simple_tensor +from ._utils import query_image, get_image_dimensions, has_all class RandomErasing(_RandomApplyTransform): @@ -51,7 +52,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: if value is not None and not (len(value) in (1, img_c)): raise ValueError( - f"If value is a sequence, it should have either a single value or {img_c} (number of input channels)" + f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)" ) area = img_h * img_w @@ -82,59 +83,45 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: else: i, j, h, w, v = 0, 0, img_h, img_w, image - return dict(zip("ijhwv", (i, j, h, w, v))) + return dict(i=i, j=j, h=h, w=w, v=v) - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image): - output = F.erase_image_tensor(input, **params) - return features.Image.new_like(input, output) - elif is_simple_tensor(input): - return F.erase_image_tensor(input, **params) + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if isinstance(inpt, features._Feature): + return inpt.erase(**params) + elif isinstance(inpt, PIL.Image.Image): + # Shouldn't we implement a fallback to tensor ? 
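+ # No erase kernel exists for PIL images in this patch, so PIL inputs raise for
+ # now; a fallback (e.g. converting to a tensor first) is left as an open question.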
+ raise RuntimeError("Not implemented") + elif isinstance(inpt, torch.Tensor): + return F.erase_image_tensor(inpt, **params) else: - return input + return inpt - def forward(self, *inputs: Any) -> Any: - sample = inputs if len(inputs) > 1 else inputs[0] - if has_any(sample, features.BoundingBox, features.SegmentationMask): - raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()") - return super().forward(sample) - - -class RandomMixup(Transform): +class _BaseMixupCutmix(Transform): def __init__(self, *, alpha: float) -> None: super().__init__() self.alpha = alpha self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) + def forward(self, *inpts: Any) -> Any: + sample = inpts if len(inpts) > 1 else inpts[0] + if not has_all(sample, features.Image, features.OneHotLabel): + raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.") + return super().forward(sample) + + +class RandomMixup(_BaseMixupCutmix): def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(lam=float(self._dist.sample(()))) - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image): - output = F.mixup_image_tensor(input, **params) - return features.Image.new_like(input, output) - elif isinstance(input, features.OneHotLabel): - output = F.mixup_one_hot_label(input, **params) - return features.OneHotLabel.new_like(input, output) + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if isinstance(inpt, features._Feature): + return inpt.mixup(**params) else: - return input - - def forward(self, *inputs: Any) -> Any: - sample = inputs if len(inputs) > 1 else inputs[0] - if has_any(sample, features.BoundingBox, features.SegmentationMask): - raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()") - elif not has_all(sample, features.Image, features.OneHotLabel): - raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.") - return super().forward(sample) + return inpt -class RandomCutmix(Transform): - def __init__(self, *, alpha: float) -> None: - super().__init__() - self.alpha = alpha - self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) - +class RandomCutmix(_BaseMixupCutmix): def _get_params(self, sample: Any) -> Dict[str, Any]: lam = float(self._dist.sample(())) @@ -158,20 +145,8 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(box=box, lam_adjusted=lam_adjusted) - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image): - output = F.cutmix_image_tensor(input, box=params["box"]) - return features.Image.new_like(input, output) - elif isinstance(input, features.OneHotLabel): - output = F.cutmix_one_hot_label(input, lam_adjusted=params["lam_adjusted"]) - return features.OneHotLabel.new_like(input, output) + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if isinstance(inpt, features._Feature): + return inpt.cutmix(**params) else: - return input - - def forward(self, *inputs: Any) -> Any: - sample = inputs if len(inputs) > 1 else inputs[0] - if has_any(sample, features.BoundingBox, features.SegmentationMask): - raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()") - elif not has_all(sample, features.Image, features.OneHotLabel): - raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* 
OneHotLabel's.") - return super().forward(sample) + return inpt diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 2a6c7dce516..a8c17577a56 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -7,72 +7,89 @@ from ._augment import ( erase_image_tensor, - mixup_image_tensor, - mixup_one_hot_label, - cutmix_image_tensor, - cutmix_one_hot_label, ) from ._color import ( + adjust_brightness, adjust_brightness_image_tensor, adjust_brightness_image_pil, + adjust_contrast, adjust_contrast_image_tensor, adjust_contrast_image_pil, + adjust_saturation, adjust_saturation_image_tensor, adjust_saturation_image_pil, + adjust_sharpness, adjust_sharpness_image_tensor, adjust_sharpness_image_pil, + adjust_hue, + adjust_hue_image_tensor, + adjust_hue_image_pil, + adjust_gamma, + adjust_gamma_image_tensor, + adjust_gamma_image_pil, + posterize, posterize_image_tensor, posterize_image_pil, + solarize, solarize_image_tensor, solarize_image_pil, + autocontrast, autocontrast_image_tensor, autocontrast_image_pil, + equalize, equalize_image_tensor, equalize_image_pil, + invert, invert_image_tensor, invert_image_pil, - adjust_hue_image_tensor, - adjust_hue_image_pil, - adjust_gamma_image_tensor, - adjust_gamma_image_pil, ) from ._geometry import ( + horizontal_flip, horizontal_flip_bounding_box, horizontal_flip_image_tensor, horizontal_flip_image_pil, horizontal_flip_segmentation_mask, + resize, resize_bounding_box, resize_image_tensor, resize_image_pil, resize_segmentation_mask, + center_crop, center_crop_bounding_box, center_crop_segmentation_mask, center_crop_image_tensor, center_crop_image_pil, + resized_crop, resized_crop_bounding_box, resized_crop_image_tensor, resized_crop_image_pil, resized_crop_segmentation_mask, + affine, affine_bounding_box, affine_image_tensor, affine_image_pil, affine_segmentation_mask, + rotate, rotate_bounding_box, rotate_image_tensor, rotate_image_pil, rotate_segmentation_mask, + pad, pad_bounding_box, pad_image_tensor, pad_image_pil, pad_segmentation_mask, + crop, crop_bounding_box, crop_image_tensor, crop_image_pil, crop_segmentation_mask, + perspective, perspective_bounding_box, perspective_image_tensor, perspective_image_pil, perspective_segmentation_mask, + vertical_flip, vertical_flip_image_tensor, vertical_flip_image_pil, vertical_flip_bounding_box, diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index 5004ac550dd..3920d1b3065 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -1,45 +1,13 @@ -from typing import Tuple - -import torch from torchvision.transforms import functional_tensor as _FT erase_image_tensor = _FT.erase -def _mixup_tensor(input: torch.Tensor, batch_dim: int, lam: float) -> torch.Tensor: - input = input.clone() - return input.roll(1, batch_dim).mul_(1 - lam).add_(input.mul_(lam)) - - -def mixup_image_tensor(image_batch: torch.Tensor, *, lam: float) -> torch.Tensor: - if image_batch.ndim < 4: - raise ValueError("Need a batch of images") - - return _mixup_tensor(image_batch, -4, lam) - - -def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float) -> torch.Tensor: - if one_hot_label_batch.ndim < 2: - raise ValueError("Need a batch of one hot labels") - - return _mixup_tensor(one_hot_label_batch, -2, lam) - - -def 
cutmix_image_tensor(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int]) -> torch.Tensor: - if image_batch.ndim < 4: - raise ValueError("Need a batch of images") - - x1, y1, x2, y2 = box - image_rolled = image_batch.roll(1, -4) - - image_batch = image_batch.clone() - image_batch[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2] - return image_batch - - -def cutmix_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam_adjusted: float) -> torch.Tensor: - if one_hot_label_batch.ndim < 2: - raise ValueError("Need a batch of one hot labels") +# TODO: Don't forget to clean up from the primitives kernels those that shouldn't be kernels. +# Like the mixup and cutmix stuff - return _mixup_tensor(one_hot_label_batch, -2, lam_adjusted) +# This function is copy-pasted to Image and OneHotLabel and may be refactored +# def _mixup_tensor(input: torch.Tensor, batch_dim: int, lam: float) -> torch.Tensor: +# input = input.clone() +# return input.roll(1, batch_dim).mul_(1 - lam).add_(input.mul_(lam)) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index fa632d7df58..f8016b43a36 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -1,34 +1,171 @@ +from typing import Any + +import PIL.Image +import torch +from torchvision.prototype import features from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP + adjust_brightness_image_tensor = _FT.adjust_brightness adjust_brightness_image_pil = _FP.adjust_brightness + +def adjust_brightness(inpt: Any, brightness_factor: float) -> Any: + if isinstance(inpt, features._Feature): + return inpt.adjust_brightness(brightness_factor=brightness_factor) + elif isinstance(inpt, PIL.Image.Image): + return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor) + elif isinstance(inpt, torch.Tensor): + return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor) + else: + return inpt + + adjust_saturation_image_tensor = _FT.adjust_saturation adjust_saturation_image_pil = _FP.adjust_saturation + +def adjust_saturation(inpt: Any, saturation_factor: float) -> Any: + if isinstance(inpt, features._Feature): + return inpt.adjust_saturation(saturation_factor=saturation_factor) + elif isinstance(inpt, PIL.Image.Image): + return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor) + elif isinstance(inpt, torch.Tensor): + return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor) + else: + return inpt + + adjust_contrast_image_tensor = _FT.adjust_contrast adjust_contrast_image_pil = _FP.adjust_contrast + +def adjust_contrast(inpt: Any, contrast_factor: float) -> Any: + if isinstance(inpt, features._Feature): + return inpt.adjust_contrast(contrast_factor=contrast_factor) + elif isinstance(inpt, PIL.Image.Image): + return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor) + elif isinstance(inpt, torch.Tensor): + return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor) + else: + return inpt + + adjust_sharpness_image_tensor = _FT.adjust_sharpness adjust_sharpness_image_pil = _FP.adjust_sharpness + +def adjust_sharpness(inpt: Any, sharpness_factor: float) -> Any: + if isinstance(inpt, features._Feature): + return inpt.adjust_sharpness(sharpness_factor=sharpness_factor) + elif isinstance(inpt, PIL.Image.Image): + return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor) + elif 
isinstance(inpt, torch.Tensor): + return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor) + else: + return inpt + + +adjust_hue_image_tensor = _FT.adjust_hue +adjust_hue_image_pil = _FP.adjust_hue + + +def adjust_hue(inpt: Any, hue_factor: float) -> Any: + if isinstance(inpt, features._Feature): + return inpt.adjust_hue(hue_factor=hue_factor) + elif isinstance(inpt, PIL.Image.Image): + return adjust_hue_image_pil(inpt, hue_factor=hue_factor) + elif isinstance(inpt, torch.Tensor): + return adjust_hue_image_tensor(inpt, hue_factor=hue_factor) + else: + return inpt + + +adjust_gamma_image_tensor = _FT.adjust_gamma +adjust_gamma_image_pil = _FP.adjust_gamma + + +def adjust_gamma(inpt: Any, gamma: float, gain: float = 1) -> Any: + if isinstance(inpt, features._Feature): + return inpt.adjust_gamma(gamma=gamma, gain=gain) + elif isinstance(inpt, PIL.Image.Image): + return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain) + elif isinstance(inpt, torch.Tensor): + return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain) + else: + return inpt + + posterize_image_tensor = _FT.posterize posterize_image_pil = _FP.posterize + +def posterize(inpt: Any, bits: int) -> Any: + if isinstance(inpt, features._Feature): + return inpt.posterize(bits=bits) + elif isinstance(inpt, PIL.Image.Image): + return posterize_image_pil(inpt, bits=bits) + elif isinstance(inpt, torch.Tensor): + return posterize_image_tensor(inpt, bits=bits) + else: + return inpt + + solarize_image_tensor = _FT.solarize solarize_image_pil = _FP.solarize + +def solarize(inpt: Any, threshold: float) -> Any: + if isinstance(inpt, features._Feature): + return inpt.solarize(threshold=threshold) + elif isinstance(inpt, PIL.Image.Image): + return solarize_image_pil(inpt, threshold=threshold) + elif isinstance(inpt, torch.Tensor): + return solarize_image_tensor(inpt, threshold=threshold) + else: + return inpt + + autocontrast_image_tensor = _FT.autocontrast autocontrast_image_pil = _FP.autocontrast + +def autocontrast(inpt: Any) -> Any: + if isinstance(inpt, features._Feature): + return inpt.autocontrast() + elif isinstance(inpt, PIL.Image.Image): + return autocontrast_image_pil(inpt) + elif isinstance(inpt, torch.Tensor): + return autocontrast_image_tensor(inpt) + else: + return inpt + + equalize_image_tensor = _FT.equalize equalize_image_pil = _FP.equalize + +def equalize(inpt: Any) -> Any: + if isinstance(inpt, features._Feature): + return inpt.equalize() + elif isinstance(inpt, PIL.Image.Image): + return equalize_image_pil(inpt) + elif isinstance(inpt, torch.Tensor): + return equalize_image_tensor(inpt) + else: + return inpt + + invert_image_tensor = _FT.invert invert_image_pil = _FP.invert -adjust_hue_image_tensor = _FT.adjust_hue -adjust_hue_image_pil = _FP.adjust_hue -adjust_gamma_image_tensor = _FT.adjust_gamma -adjust_gamma_image_pil = _FP.adjust_gamma +def invert(inpt: Any) -> Any: + if isinstance(inpt, features._Feature): + return inpt.invert() + elif isinstance(inpt, PIL.Image.Image): + return invert_image_pil(inpt) + elif isinstance(inpt, torch.Tensor): + return invert_image_tensor(inpt) + else: + return inpt diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 95e094ad798..6a7f62f4198 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -1,6 +1,6 @@ import numbers import warnings -from typing import Tuple, List, Optional, Sequence, Union +from 
typing import Any, Tuple, List, Optional, Sequence, Union import PIL.Image import torch @@ -40,12 +40,58 @@ def horizontal_flip_bounding_box( ).view(shape) +def horizontal_flip(inpt: Any) -> Any: + if isinstance(inpt, features._Feature): + return inpt.horizontal_flip() + elif isinstance(inpt, PIL.Image.Image): + return horizontal_flip_image_pil(inpt) + elif isinstance(inpt, torch.Tensor): + return horizontal_flip_image_tensor(inpt) + else: + return inpt + + +vertical_flip_image_tensor = _FT.vflip +vertical_flip_image_pil = _FP.vflip + + +def vertical_flip_segmentation_mask(segmentation_mask: torch.Tensor) -> torch.Tensor: + return vertical_flip_image_tensor(segmentation_mask) + + +def vertical_flip_bounding_box( + bounding_box: torch.Tensor, format: features.BoundingBoxFormat, image_size: Tuple[int, int] +) -> torch.Tensor: + shape = bounding_box.shape + + bounding_box = convert_bounding_box_format( + bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY + ).view(-1, 4) + + bounding_box[:, [1, 3]] = image_size[0] - bounding_box[:, [3, 1]] + + return convert_bounding_box_format( + bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False + ).view(shape) + + +def vertical_flip(inpt: Any) -> Any: + if isinstance(inpt, features._Feature): + return inpt.vertical_flip() + elif isinstance(inpt, PIL.Image.Image): + return vertical_flip_image_pil(inpt) + elif isinstance(inpt, torch.Tensor): + return vertical_flip_image_tensor(inpt) + else: + return inpt + + def resize_image_tensor( image: torch.Tensor, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, max_size: Optional[int] = None, - antialias: Optional[bool] = None, + antialias: bool = False, ) -> torch.Tensor: num_channels, old_height, old_width = get_dimensions_image_tensor(image) new_height, new_width = _compute_output_size((old_height, old_width), size=size, max_size=max_size) @@ -87,28 +133,25 @@ def resize_bounding_box( return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape) -vertical_flip_image_tensor = _FT.vflip -vertical_flip_image_pil = _FP.vflip - - -def vertical_flip_segmentation_mask(segmentation_mask: torch.Tensor) -> torch.Tensor: - return vertical_flip_image_tensor(segmentation_mask) - - -def vertical_flip_bounding_box( - bounding_box: torch.Tensor, format: features.BoundingBoxFormat, image_size: Tuple[int, int] -) -> torch.Tensor: - shape = bounding_box.shape - - bounding_box = convert_bounding_box_format( - bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY - ).view(-1, 4) - - bounding_box[:, [1, 3]] = image_size[0] - bounding_box[:, [3, 1]] - - return convert_bounding_box_format( - bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False - ).view(shape) +def resize( + inpt: Any, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[bool] = None, +) -> Any: + if isinstance(inpt, features._Feature): + antialias = False if antialias is None else antialias + return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias) + elif isinstance(inpt, PIL.Image.Image): + if antialias is not None and not antialias: + warnings.warn("Anti-alias option is always applied for PIL Image input. 
Argument antialias is ignored.") + return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size) + elif isinstance(inpt, torch.Tensor): + antialias = False if antialias is None else antialias + return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias) + else: + return inpt def _affine_parse_args( @@ -323,6 +366,46 @@ def affine_segmentation_mask( ) +def affine( + inpt: Any, + angle: float, + translate: List[float], + scale: float, + shear: List[float], + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None, + center: Optional[List[float]] = None, +) -> Any: + if isinstance(inpt, features._Feature): + return inpt.affine( + angle, translate=translate, scale=scale, shear=shear, interpolation=interpolation, fill=fill, center=center + ) + elif isinstance(inpt, PIL.Image.Image): + return affine_image_pil( + inpt, + angle, + translate=translate, + scale=scale, + shear=shear, + interpolation=interpolation, + fill=fill, + center=center, + ) + elif isinstance(inpt, torch.Tensor): + return affine_image_tensor( + inpt, + angle, + translate=translate, + scale=scale, + shear=shear, + interpolation=interpolation, + fill=fill, + center=center, + ) + else: + return inpt + + def rotate_image_tensor( img: torch.Tensor, angle: float, @@ -402,6 +485,24 @@ def rotate_segmentation_mask( ) +def rotate( + inpt: Any, + angle: float, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + expand: bool = False, + fill: Optional[List[float]] = None, + center: Optional[List[float]] = None, +) -> Any: + if isinstance(inpt, features._Feature): + return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center) + elif isinstance(inpt, PIL.Image.Image): + return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) + elif isinstance(inpt, torch.Tensor): + return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) + else: + return inpt + + pad_image_tensor = _FT.pad pad_image_pil = _FP.pad @@ -436,6 +537,17 @@ def pad_bounding_box( return bounding_box +def pad(inpt: Any, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Any: + if isinstance(inpt, features._Feature): + return inpt.pad(padding, fill=fill, padding_mode=padding_mode) + elif isinstance(inpt, PIL.Image.Image): + return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode) + elif isinstance(inpt, torch.Tensor): + return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode) + else: + return inpt + + crop_image_tensor = _FT.crop crop_image_pil = _FP.crop @@ -463,6 +575,17 @@ def crop_segmentation_mask(img: torch.Tensor, top: int, left: int, height: int, return crop_image_tensor(img, top, left, height, width) +def crop(inpt: Any, top: int, left: int, height: int, width: int) -> Any: + if isinstance(inpt, features._Feature): + return inpt.crop(top, left, height, width) + elif isinstance(inpt, PIL.Image.Image): + return crop_image_pil(inpt, top, left, height, width) + elif isinstance(inpt, torch.Tensor): + return crop_image_tensor(inpt, top, left, height, width) + else: + return inpt + + def perspective_image_tensor( img: torch.Tensor, perspective_coeffs: List[float], @@ -474,7 +597,7 @@ def perspective_image_tensor( def perspective_image_pil( img: PIL.Image.Image, - perspective_coeffs: float, + perspective_coeffs: List[float], interpolation: InterpolationMode = 
InterpolationMode.BICUBIC, fill: Optional[List[float]] = None, ) -> PIL.Image.Image: @@ -570,6 +693,22 @@ def perspective_segmentation_mask(img: torch.Tensor, perspective_coeffs: List[fl return perspective_image_tensor(img, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST) +def perspective( + inpt: Any, + perspective_coeffs: List[float], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[List[float]] = None, +) -> Any: + if isinstance(inpt, features._Feature): + return inpt.perspective(perspective_coeffs, interpolation=interpolation, fill=fill) + elif isinstance(inpt, PIL.Image.Image): + return perspective_image_pil(inpt, perspective_coeffs, interpolation=interpolation, fill=fill) + elif isinstance(inpt, torch.Tensor): + return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill) + else: + return inpt + + def _center_crop_parse_output_size(output_size: List[int]) -> List[int]: if isinstance(output_size, numbers.Number): return [int(output_size), int(output_size)] @@ -643,6 +782,17 @@ def center_crop_segmentation_mask(segmentation_mask: torch.Tensor, output_size: return center_crop_image_tensor(img=segmentation_mask, output_size=output_size) +def center_crop(inpt: Any, output_size: List[int]) -> Any: + if isinstance(inpt, features._Feature): + return inpt.center_crop(output_size) + elif isinstance(inpt, PIL.Image.Image): + return center_crop_image_pil(inpt, output_size) + elif isinstance(inpt, torch.Tensor): + return center_crop_image_tensor(inpt, output_size) + else: + return inpt + + def resized_crop_image_tensor( img: torch.Tensor, top: int, @@ -651,9 +801,10 @@ def resized_crop_image_tensor( width: int, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, + antialias: bool = False, ) -> torch.Tensor: img = crop_image_tensor(img, top, left, height, width) - return resize_image_tensor(img, size, interpolation=interpolation) + return resize_image_tensor(img, size, interpolation=interpolation, antialias=antialias) def resized_crop_image_pil( @@ -694,6 +845,30 @@ def resized_crop_segmentation_mask( return resize_segmentation_mask(mask, size) +def resized_crop( + inpt: Any, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + antialias: Optional[bool] = None, +) -> Any: + if isinstance(inpt, features._Feature): + antialias = False if antialias is None else antialias + return inpt.resized_crop(top, left, height, width, antialias=antialias, size=size, interpolation=interpolation) + elif isinstance(inpt, PIL.Image.Image): + return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation) + elif isinstance(inpt, torch.Tensor): + antialias = False if antialias is None else antialias + return resized_crop_image_tensor( + inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation + ) + else: + return inpt + + def _parse_five_crop_size(size: List[int]) -> List[int]: if isinstance(size, numbers.Number): size = [int(size), int(size)] diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 3710dc21e8c..54bf926762a 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -317,9 +317,9 @@ def rotate( @torch.jit.unused def perspective( img: Image.Image, - perspective_coeffs: float, + perspective_coeffs: List[float], interpolation: int = 
_pil_constants.BICUBIC, - fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0, + fill: Optional[Union[float, List[float], Tuple[float, ...]]] = None, ) -> Image.Image: if not _is_pil_image(img): From 2b3e9168ffadf0a3bb14c8d4adcf84e91cc7f59c Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 30 Jun 2022 09:05:43 +0000 Subject: [PATCH 2/7] Fixing deadlock in dataloader with circular imports --- .github/workflows/prototype-tests.yml | 2 +- .../prototype/features/_bounding_box.py | 42 ++++++--- torchvision/prototype/features/_feature.py | 6 +- torchvision/prototype/features/_image.py | 90 ++++++++++++++----- .../prototype/features/_segmentation_mask.py | 40 ++++++--- 5 files changed, 130 insertions(+), 50 deletions(-) diff --git a/.github/workflows/prototype-tests.yml b/.github/workflows/prototype-tests.yml index 518fa1cedc2..ff29168d9a7 100644 --- a/.github/workflows/prototype-tests.yml +++ b/.github/workflows/prototype-tests.yml @@ -41,4 +41,4 @@ jobs: - name: Run prototype tests shell: bash - run: pytest --durations=20 test/test_prototype_*.py + run: pytest -vvv --durations=20 test/test_prototype_*.py diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index dd97e741704..0e2dc913ddc 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -72,11 +72,15 @@ def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox: ) def horizontal_flip(self) -> BoundingBox: - output = self._F.horizontal_flip_bounding_box(self, format=self.format, image_size=self.image_size) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.horizontal_flip_bounding_box(self, format=self.format, image_size=self.image_size) return BoundingBox.new_like(self, output) def vertical_flip(self) -> BoundingBox: - output = self._F.vertical_flip_bounding_box(self, format=self.format, image_size=self.image_size) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.vertical_flip_bounding_box(self, format=self.format, image_size=self.image_size) return BoundingBox.new_like(self, output) def resize( # type: ignore[override] @@ -86,17 +90,22 @@ def resize( # type: ignore[override] max_size: Optional[int] = None, antialias: bool = False, ) -> BoundingBox: - # interpolation, antialias # unused - output = self._F.resize_bounding_box(self, size, image_size=self.image_size, max_size=max_size) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.resize_bounding_box(self, size, image_size=self.image_size, max_size=max_size) image_size = (size[0], size[0]) if len(size) == 1 else (size[0], size[1]) return BoundingBox.new_like(self, output, image_size=image_size, dtype=output.dtype) def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox: - output = self._F.crop_bounding_box(self, self.format, top, left) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.crop_bounding_box(self, self.format, top, left) return BoundingBox.new_like(self, output, image_size=(height, width)) def center_crop(self, output_size: List[int]) -> BoundingBox: - output = self._F.center_crop_bounding_box( # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.center_crop_bounding_box( self, format=self.format, output_size=output_size, 
image_size=self.image_size ) image_size = (output_size[0], output_size[0]) if len(output_size) == 1 else (output_size[0], output_size[1]) @@ -112,16 +121,19 @@ def resized_crop( interpolation: InterpolationMode = InterpolationMode.BILINEAR, antialias: bool = False, ) -> BoundingBox: - output = self._F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size) image_size = (size[0], size[0]) if len(size) == 1 else (size[0], size[1]) return BoundingBox.new_like(self, output, image_size=image_size, dtype=output.dtype) def pad(self, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> BoundingBox: - # fill # unused + from torchvision.prototype.transforms import functional as _F + if padding_mode not in ["constant"]: raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes") - output = self._F.pad_bounding_box(self, padding, format=self.format) # type: ignore[attr-defined] + output = _F.pad_bounding_box(self, padding, format=self.format) # Update output image size: # TODO: remove the import below and make _parse_pad_padding available @@ -142,7 +154,9 @@ def rotate( fill: Optional[List[float]] = None, center: Optional[List[float]] = None, ) -> BoundingBox: - output = self._F.rotate_bounding_box( # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.rotate_bounding_box( self, format=self.format, image_size=self.image_size, angle=angle, expand=expand, center=center ) # TODO: update output image size if expand is True @@ -160,7 +174,9 @@ def affine( fill: Optional[List[float]] = None, center: Optional[List[float]] = None, ) -> BoundingBox: - output = self._F.affine_bounding_box( # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.affine_bounding_box( self, self.format, self.image_size, @@ -178,7 +194,9 @@ def perspective( interpolation: InterpolationMode = InterpolationMode.BILINEAR, fill: Optional[List[float]] = None, ) -> BoundingBox: - output = self._F.perspective_bounding_box(self, self.format, perspective_coeffs) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.perspective_bounding_box(self, self.format, perspective_coeffs) return BoundingBox.new_like(self, output, dtype=output.dtype) def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> BoundingBox: diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 6a645b8d289..3493e0c98c4 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -26,9 +26,9 @@ def __new__( ) # To avoid circular dependency between features and transforms - from ..transforms import functional - - feature._F = functional # type: ignore[attr-defined] + # This does not work with dataloader multi-worker setup + # from torchvision.prototype.transforms import functional + # setattr(feature, "_F", functional) # type: ignore[attr-defined] return feature diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 5b9f4a1ebd3..5ec502cc3b7 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -111,11 +111,15 @@ def draw_bounding_box(self, bounding_box: 
BoundingBox, **kwargs: Any) -> Image: return Image.new_like(self, draw_bounding_boxes(self, bounding_box.to_format("xyxy").view(-1, 4), **kwargs)) def horizontal_flip(self) -> Image: - output = self._F.horizontal_flip_image_tensor(self) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.horizontal_flip_image_tensor(self) return Image.new_like(self, output) def vertical_flip(self) -> Image: - output = self._F.vertical_flip_image_tensor(self) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.vertical_flip_image_tensor(self) return Image.new_like(self, output) def resize( # type: ignore[override] @@ -125,17 +129,21 @@ def resize( # type: ignore[override] max_size: Optional[int] = None, antialias: bool = False, ) -> Image: - output = self._F.resize_image_tensor( # type: ignore[attr-defined] - self, size, interpolation=interpolation, max_size=max_size, antialias=antialias - ) + from torchvision.prototype.transforms import functional as _F + + output = _F.resize_image_tensor(self, size, interpolation=interpolation, max_size=max_size, antialias=antialias) return Image.new_like(self, output) def crop(self, top: int, left: int, height: int, width: int) -> Image: - output = self._F.crop_image_tensor(self, top, left, height, width) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.crop_image_tensor(self, top, left, height, width) return Image.new_like(self, output) def center_crop(self, output_size: List[int]) -> Image: - output = self._F.center_crop_image_tensor(self, output_size=output_size) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.center_crop_image_tensor(self, output_size=output_size) return Image.new_like(self, output) def resized_crop( @@ -148,13 +156,17 @@ def resized_crop( interpolation: InterpolationMode = InterpolationMode.BILINEAR, antialias: bool = False, ) -> Image: - output = self._F.resized_crop_image_tensor( # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.resized_crop_image_tensor( self, top, left, height, width, size=list(size), interpolation=interpolation, antialias=antialias ) return Image.new_like(self, output) def pad(self, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Image: - output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode) return Image.new_like(self, output) def rotate( @@ -165,7 +177,9 @@ def rotate( fill: Optional[List[float]] = None, center: Optional[List[float]] = None, ) -> Image: - output = self._F.rotate_image_tensor( # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.rotate_image_tensor( self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center ) return Image.new_like(self, output) @@ -180,7 +194,9 @@ def affine( fill: Optional[List[float]] = None, center: Optional[List[float]] = None, ) -> Image: - output = self._F.affine_image_tensor( # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.affine_image_tensor( self, angle, translate=translate, @@ -198,55 +214,81 @@ def perspective( interpolation: 
InterpolationMode = InterpolationMode.BILINEAR, fill: Optional[List[float]] = None, ) -> Image: - output = self._F.perspective_image_tensor(self, perspective_coeffs, interpolation=interpolation, fill=fill) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.perspective_image_tensor(self, perspective_coeffs, interpolation=interpolation, fill=fill) return Image.new_like(self, output) def adjust_brightness(self, brightness_factor: float) -> Image: - output = self._F.adjust_brightness_image_tensor(self, brightness_factor=brightness_factor) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.adjust_brightness_image_tensor(self, brightness_factor=brightness_factor) return Image.new_like(self, output) def adjust_saturation(self, saturation_factor: float) -> Image: - output = self._F.adjust_saturation_image_tensor(self, saturation_factor=saturation_factor) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.adjust_saturation_image_tensor(self, saturation_factor=saturation_factor) return Image.new_like(self, output) def adjust_contrast(self, contrast_factor: float) -> Image: - output = self._F.adjust_contrast_image_tensor(self, contrast_factor=contrast_factor) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.adjust_contrast_image_tensor(self, contrast_factor=contrast_factor) return Image.new_like(self, output) def adjust_sharpness(self, sharpness_factor: float) -> Image: - output = self._F.adjust_sharpness_image_tensor(self, sharpness_factor=sharpness_factor) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.adjust_sharpness_image_tensor(self, sharpness_factor=sharpness_factor) return Image.new_like(self, output) def adjust_hue(self, hue_factor: float) -> Image: - output = self._F.adjust_hue_image_tensor(self, hue_factor=hue_factor) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.adjust_hue_image_tensor(self, hue_factor=hue_factor) return Image.new_like(self, output) def adjust_gamma(self, gamma: float, gain: float = 1) -> Image: - output = self._F.adjust_gamma_image_tensor(self, gamma=gamma, gain=gain) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.adjust_gamma_image_tensor(self, gamma=gamma, gain=gain) return Image.new_like(self, output) def posterize(self, bits: int) -> Image: - output = self._F.posterize_image_tensor(self, bits=bits) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.posterize_image_tensor(self, bits=bits) return Image.new_like(self, output) def solarize(self, threshold: float) -> Image: - output = self._F.solarize_image_tensor(self, threshold=threshold) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.solarize_image_tensor(self, threshold=threshold) return Image.new_like(self, output) def autocontrast(self) -> Image: - output = self._F.autocontrast_image_tensor(self) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.autocontrast_image_tensor(self) return Image.new_like(self, output) def equalize(self) -> Image: - output = self._F.equalize_image_tensor(self) # type: ignore[attr-defined] + from 
torchvision.prototype.transforms import functional as _F + + output = _F.equalize_image_tensor(self) return Image.new_like(self, output) def invert(self) -> Image: - output = self._F.invert_image_tensor(self) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.invert_image_tensor(self) return Image.new_like(self, output) def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> Image: - output = self._F.erase_image_tensor(self, i, j, h, w, v) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.erase_image_tensor(self, i, j, h, w, v) return Image.new_like(self, output) def mixup(self, lam: float) -> Image: diff --git a/torchvision/prototype/features/_segmentation_mask.py b/torchvision/prototype/features/_segmentation_mask.py index 1277177fdc8..3cc5336f684 100644 --- a/torchvision/prototype/features/_segmentation_mask.py +++ b/torchvision/prototype/features/_segmentation_mask.py @@ -10,11 +10,15 @@ class SegmentationMask(_Feature): def horizontal_flip(self) -> SegmentationMask: - output = self._F.horizontal_flip_segmentation_mask(self) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.horizontal_flip_segmentation_mask(self) return SegmentationMask.new_like(self, output) def vertical_flip(self) -> SegmentationMask: - output = self._F.vertical_flip_segmentation_mask(self) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.vertical_flip_segmentation_mask(self) return SegmentationMask.new_like(self, output) def resize( # type: ignore[override] @@ -24,15 +28,21 @@ def resize( # type: ignore[override] max_size: Optional[int] = None, antialias: bool = False, ) -> SegmentationMask: - output = self._F.resize_segmentation_mask(self, size, max_size=max_size) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.resize_segmentation_mask(self, size, max_size=max_size) return SegmentationMask.new_like(self, output) def crop(self, top: int, left: int, height: int, width: int) -> SegmentationMask: - output = self._F.center_crop_segmentation_mask(self, top, left, height, width) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.crop_segmentation_mask(self, top, left, height, width) return SegmentationMask.new_like(self, output) def center_crop(self, output_size: List[int]) -> SegmentationMask: - output = self._F.center_crop_segmentation_mask(self, output_size=output_size) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.center_crop_segmentation_mask(self, output_size=output_size) return SegmentationMask.new_like(self, output) def resized_crop( @@ -45,11 +55,15 @@ def resized_crop( interpolation: InterpolationMode = InterpolationMode.BILINEAR, antialias: bool = False, ) -> SegmentationMask: - output = self._F.resized_crop_segmentation_mask(self, top, left, height, width, size=size) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.resized_crop_segmentation_mask(self, top, left, height, width, size=size) return SegmentationMask.new_like(self, output) def pad(self, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> SegmentationMask: - output = self._F.pad_segmentation_mask(self, padding, padding_mode=padding_mode) # type: 
ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.pad_segmentation_mask(self, padding, padding_mode=padding_mode) return SegmentationMask.new_like(self, output) def rotate( @@ -60,7 +74,9 @@ def rotate( fill: Optional[List[float]] = None, center: Optional[List[float]] = None, ) -> SegmentationMask: - output = self._F.rotate_segmentation_mask(self, angle, expand=expand, center=center) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.rotate_segmentation_mask(self, angle, expand=expand, center=center) return SegmentationMask.new_like(self, output) def affine( @@ -73,7 +89,9 @@ def affine( fill: Optional[List[float]] = None, center: Optional[List[float]] = None, ) -> SegmentationMask: - output = self._F.affine_segmentation_mask( # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.affine_segmentation_mask( self, angle, translate=translate, @@ -89,7 +107,9 @@ def perspective( interpolation: InterpolationMode = InterpolationMode.BILINEAR, fill: Optional[List[float]] = None, ) -> SegmentationMask: - output = self._F.perspective_segmentation_mask(self, perspective_coeffs) # type: ignore[attr-defined] + from torchvision.prototype.transforms import functional as _F + + output = _F.perspective_segmentation_mask(self, perspective_coeffs) return SegmentationMask.new_like(self, output) def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> SegmentationMask: From d483b1680bd2ae6dbcdb6777dc7abd601a6d76de Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 30 Jun 2022 11:28:52 +0000 Subject: [PATCH 3/7] Added non-scalar fill support workaround for pad --- test/test_prototype_transforms_functional.py | 11 +++++ torchvision/prototype/features/_feature.py | 2 +- torchvision/prototype/features/_image.py | 11 ++++- .../transforms/functional/_geometry.py | 47 +++++++++++++++++-- 4 files changed, 65 insertions(+), 6 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 4d48bd7cb7f..616bee5c26d 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -426,6 +426,17 @@ def resized_crop_segmentation_mask(): yield SampleInput(mask, top=top, left=left, height=height, width=width, size=size) +@register_kernel_info_from_sample_inputs_fn +def pad_image_tensor(): + for image, padding, fill, padding_mode in itertools.product( + make_images(), + [[1], [1, 1], [1, 1, 2, 2]], # padding + [12], # fill + ["constant", "symmetric", "edge", "reflect"], # padding mode, + ): + yield SampleInput(image, padding=padding, fill=fill, padding_mode=padding_mode) + + @register_kernel_info_from_sample_inputs_fn def pad_segmentation_mask(): for mask, padding, padding_mode in itertools.product( diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 3493e0c98c4..76afa334344 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -136,7 +136,7 @@ def resized_crop( # How dangerous to do this instead of raising an error ? return self - def pad(self, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Any: + def pad(self, padding: List[int], fill: Union[float, Sequence[float]] = 0, padding_mode: str = "constant") -> Any: # Just output itself # How dangerous to do this instead of raising an error ? 
return self diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 5ec502cc3b7..88bb4225d69 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -163,10 +163,17 @@ def resized_crop( ) return Image.new_like(self, output) - def pad(self, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Image: + def pad(self, padding: List[int], fill: Union[float, List[float]] = 0.0, padding_mode: str = "constant") -> Image: from torchvision.prototype.transforms import functional as _F - output = _F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode) + # PyTorch's pad supports only scalars on fill. So we need to overwrite the colour + if isinstance(fill, (int, float)): + output = _F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode) + else: + from torchvision.prototype.transforms.functional._geometry import _pad_with_vector_fill + + output = _pad_with_vector_fill(self, padding, fill=fill, padding_mode=padding_mode) + return Image.new_like(self, output) def rotate( diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 6a7f62f4198..ef962767ff1 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -503,10 +503,45 @@ def rotate( return inpt -pad_image_tensor = _FT.pad pad_image_pil = _FP.pad +def pad_image_tensor( + img: torch.Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant" +) -> torch.Tensor: + num_masks, height, width = img.shape[-3:] + extra_dims = img.shape[:-3] + + padded_image = _FT.pad( + img=img.view(-1, num_masks, height, width), padding=padding, fill=fill, padding_mode=padding_mode + ) + + new_height, new_width = padded_image.shape[-2:] + return padded_image.view(extra_dims + (num_masks, new_height, new_width)) + + +# TODO: This should be removed once pytorch pad supports non-scalar padding values +def _pad_with_vector_fill( + img: torch.Tensor, padding: List[int], fill: Union[float, List[float]] = 0.0, padding_mode: str = "constant" +): + if padding_mode != "constant": + raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar") + + output = pad_image_tensor(img, padding, fill=0, padding_mode="constant") + left, top, right, bottom = padding + fill = torch.tensor(fill, dtype=img.dtype, device=img.device).view(-1, 1, 1) + + if top > 0: + output[..., :top, :] = fill + if left > 0: + output[..., :, :left] = fill + if bottom > 0: + output[..., -bottom:, :] = fill + if right > 0: + output[..., :, -right:] = fill + return output + + def pad_segmentation_mask( segmentation_mask: torch.Tensor, padding: List[int], padding_mode: str = "constant" ) -> torch.Tensor: @@ -537,13 +572,19 @@ def pad_bounding_box( return bounding_box -def pad(inpt: Any, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Any: +def pad( + inpt: Any, padding: List[int], fill: Union[float, Sequence[float]] = 0.0, padding_mode: str = "constant" +) -> Any: if isinstance(inpt, features._Feature): return inpt.pad(padding, fill=fill, padding_mode=padding_mode) elif isinstance(inpt, PIL.Image.Image): return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode) elif isinstance(inpt, torch.Tensor): - return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode) + # PyTorch's pad supports only scalars on fill. 
So we need to overwrite the colour + if isinstance(fill, (int, float)): + return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode) + else: + return _pad_with_vector_fill(inpt, padding, fill=fill, padding_mode=padding_mode) else: return inpt From 8ef7b3c7605495df38667da4bb1ee5e6c38cceb6 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 30 Jun 2022 13:22:55 +0000 Subject: [PATCH 4/7] Removed comments --- torchvision/prototype/features/_feature.py | 59 ++-------------------- 1 file changed, 3 insertions(+), 56 deletions(-) diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 76afa334344..6cde729c9ca 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -16,7 +16,7 @@ def __new__( device: Optional[Union[torch.device, str, int]] = None, requires_grad: bool = False, ) -> F: - feature = cast( + return cast( F, torch.Tensor._make_subclass( cast(_TensorBase, cls), @@ -25,13 +25,6 @@ def __new__( ), ) - # To avoid circular dependency between features and transforms - # This does not work with dataloader multi-worker setup - # from torchvision.prototype.transforms import functional - # setattr(feature, "_F", functional) # type: ignore[attr-defined] - - return feature - @classmethod def new_like( cls: Type[F], @@ -92,15 +85,13 @@ def __torch_function__( return output def horizontal_flip(self) -> Any: - # Just output itself - # How dangerous to do this instead of raising an error ? return self def vertical_flip(self) -> Any: - # Just output itself - # How dangerous to do this instead of raising an error ? return self + # TODO: We have to ignore override mypy error as there is torch.Tensor built-in deprecated op: Tensor.resize + # https://github.com/pytorch/pytorch/blob/e8727994eb7cdb2ab642749d6549bc497563aa06/torch/_tensor.py#L588-L593 def resize( # type: ignore[override] self, size: List[int], @@ -108,18 +99,12 @@ def resize( # type: ignore[override] max_size: Optional[int] = None, antialias: bool = False, ) -> Any: - # Just output itself - # How dangerous to do this instead of raising an error ? return self def crop(self, top: int, left: int, height: int, width: int) -> Any: - # Just output itself - # How dangerous to do this instead of raising an error ? return self def center_crop(self, output_size: List[int]) -> Any: - # Just output itself - # How dangerous to do this instead of raising an error ? return self def resized_crop( @@ -132,13 +117,9 @@ def resized_crop( interpolation: InterpolationMode = InterpolationMode.BILINEAR, antialias: bool = False, ) -> Any: - # Just output itself - # How dangerous to do this instead of raising an error ? return self def pad(self, padding: List[int], fill: Union[float, Sequence[float]] = 0, padding_mode: str = "constant") -> Any: - # Just output itself - # How dangerous to do this instead of raising an error ? return self def rotate( @@ -149,8 +130,6 @@ def rotate( fill: Optional[List[float]] = None, center: Optional[List[float]] = None, ) -> Any: - # Just output itself - # How dangerous to do this instead of raising an error ? return self def affine( @@ -163,8 +142,6 @@ def affine( fill: Optional[List[float]] = None, center: Optional[List[float]] = None, ) -> Any: - # Just output itself - # How dangerous to do this instead of raising an error ? 
return self def perspective( @@ -173,76 +150,46 @@ def perspective( interpolation: InterpolationMode = InterpolationMode.BILINEAR, fill: Optional[List[float]] = None, ) -> Any: - # Just output itself - # How dangerous to do this instead of raising an error ? return self def adjust_brightness(self, brightness_factor: float) -> Any: - # Just output itself - # How dangerous to do this instead of raising an error ? return self def adjust_saturation(self, saturation_factor: float) -> Any: - # Just output itself - # How dangerous to do this instead of raising an error ? return self def adjust_contrast(self, contrast_factor: float) -> Any: - # Just output itself - # How dangerous to do this instead of raising an error ? return self def adjust_sharpness(self, sharpness_factor: float) -> Any: - # Just output itself - # How dangerous to do this instead of raising an error ? return self def adjust_hue(self, hue_factor: float) -> Any: - # Just output itself - # How dangerous to do this instead of raising an error ? return self def adjust_gamma(self, gamma: float, gain: float = 1) -> Any: - # Just output itself - # How dangerous to do this instead of raising an error ? return self def posterize(self, bits: int) -> Any: - # Just output itself - # How dangerous to do this instead of raising an error ? return self def solarize(self, threshold: float) -> Any: - # Just output itself - # How dangerous to do this instead of raising an error ? return self def autocontrast(self) -> Any: - # Just output itself - # How dangerous to do this instead of raising an error ? return self def equalize(self) -> Any: - # Just output itself - # How dangerous to do this instead of raising an error ? return self def invert(self) -> Any: - # Just output itself - # How dangerous to do this instead of raising an error ? return self def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> Any: - # Just output itself - # How dangerous to do this instead of raising an error ? return self def mixup(self, lam: float) -> Any: - # Just output itself - # How dangerous to do this instead of raising an error ? return self def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> Any: - # Just output itself - # How dangerous to do this instead of raising an error ? 
return self From c68afd692e64b5ed1d5b9c161107e64fe96afaf2 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 30 Jun 2022 15:05:21 +0000 Subject: [PATCH 5/7] int/float support for fill in pad op --- test/test_prototype_transforms_functional.py | 2 +- torchvision/prototype/features/_bounding_box.py | 6 ++++-- torchvision/prototype/features/_feature.py | 4 +++- torchvision/prototype/features/_image.py | 6 ++++-- torchvision/prototype/features/_segmentation_mask.py | 6 ++++-- .../prototype/transforms/functional/_geometry.py | 11 +++++++---- 6 files changed, 23 insertions(+), 12 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 616bee5c26d..d4fb3136ff4 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -431,7 +431,7 @@ def pad_image_tensor(): for image, padding, fill, padding_mode in itertools.product( make_images(), [[1], [1, 1], [1, 1, 2, 2]], # padding - [12], # fill + [12, 12.0], # fill ["constant", "symmetric", "edge", "reflect"], # padding mode, ): yield SampleInput(image, padding=padding, fill=fill, padding_mode=padding_mode) diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index 0e2dc913ddc..49bd6eba865 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, List, Tuple, Union, Optional +from typing import Any, List, Tuple, Union, Optional, Sequence import torch from torchvision._utils import StrEnum @@ -127,7 +127,9 @@ def resized_crop( image_size = (size[0], size[0]) if len(size) == 1 else (size[0], size[1]) return BoundingBox.new_like(self, output, image_size=image_size, dtype=output.dtype) - def pad(self, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> BoundingBox: + def pad( + self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant" + ) -> BoundingBox: from torchvision.prototype.transforms import functional as _F if padding_mode not in ["constant"]: diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 6cde729c9ca..e1d7d56d23d 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -119,7 +119,9 @@ def resized_crop( ) -> Any: return self - def pad(self, padding: List[int], fill: Union[float, Sequence[float]] = 0, padding_mode: str = "constant") -> Any: + def pad( + self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant" + ) -> Any: return self def rotate( diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 88bb4225d69..6acbba38d62 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import Any, List, Optional, Union, Tuple, cast +from typing import Any, List, Optional, Union, Sequence, Tuple, cast import torch from torchvision._utils import StrEnum @@ -163,7 +163,9 @@ def resized_crop( ) return Image.new_like(self, output) - def pad(self, padding: List[int], fill: Union[float, List[float]] = 0.0, padding_mode: str = "constant") -> Image: + def pad( + self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant" + ) -> Image: 
from torchvision.prototype.transforms import functional as _F # PyTorch's pad supports only scalars on fill. So we need to overwrite the colour diff --git a/torchvision/prototype/features/_segmentation_mask.py b/torchvision/prototype/features/_segmentation_mask.py index 3cc5336f684..82e52208c0d 100644 --- a/torchvision/prototype/features/_segmentation_mask.py +++ b/torchvision/prototype/features/_segmentation_mask.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Tuple, List, Optional +from typing import Tuple, List, Optional, Union, Sequence import torch from torchvision.transforms import InterpolationMode @@ -60,7 +60,9 @@ def resized_crop( output = _F.resized_crop_segmentation_mask(self, top, left, height, width, size=size) return SegmentationMask.new_like(self, output) - def pad(self, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> SegmentationMask: + def pad( + self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant" + ) -> SegmentationMask: from torchvision.prototype.transforms import functional as _F output = _F.pad_segmentation_mask(self, padding, padding_mode=padding_mode) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index ef962767ff1..fd57188aa5f 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -507,7 +507,7 @@ def rotate( def pad_image_tensor( - img: torch.Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant" + img: torch.Tensor, padding: List[int], fill: Union[int, float] = 0, padding_mode: str = "constant" ) -> torch.Tensor: num_masks, height, width = img.shape[-3:] extra_dims = img.shape[:-3] @@ -522,8 +522,11 @@ def pad_image_tensor( # TODO: This should be removed once pytorch pad supports non-scalar padding values def _pad_with_vector_fill( - img: torch.Tensor, padding: List[int], fill: Union[float, List[float]] = 0.0, padding_mode: str = "constant" -): + img: torch.Tensor, + padding: List[int], + fill: Sequence[float] = [0.0], + padding_mode: str = "constant", +) -> torch.Tensor: if padding_mode != "constant": raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar") @@ -573,7 +576,7 @@ def pad_bounding_box( def pad( - inpt: Any, padding: List[int], fill: Union[float, Sequence[float]] = 0.0, padding_mode: str = "constant" + inpt: Any, padding: List[int], fill: Union[int, float, Sequence[float]] = 0.0, padding_mode: str = "constant" ) -> Any: if isinstance(inpt, features._Feature): return inpt.pad(padding, fill=fill, padding_mode=padding_mode) From 7b8d79ba00f63ee90902211d8cd46612dc7f90e9 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 5 Jul 2022 15:50:12 +0000 Subject: [PATCH 6/7] Updated type hints and removed bypass option from mid-level methods --- .../prototype/features/_segmentation_mask.py | 6 +- torchvision/prototype/transforms/_augment.py | 2 +- .../prototype/transforms/functional/_color.py | 104 +++++--------- .../transforms/functional/_geometry.py | 136 ++++++++---------- 4 files changed, 97 insertions(+), 151 deletions(-) diff --git a/torchvision/prototype/features/_segmentation_mask.py b/torchvision/prototype/features/_segmentation_mask.py index 82e52208c0d..653f0f12ba4 100644 --- a/torchvision/prototype/features/_segmentation_mask.py +++ b/torchvision/prototype/features/_segmentation_mask.py @@ -24,7 +24,7 @@ def vertical_flip(self) -> 
SegmentationMask: def resize( # type: ignore[override] self, size: List[int], - interpolation: InterpolationMode = InterpolationMode.BILINEAR, + interpolation: InterpolationMode = InterpolationMode.NEAREST, max_size: Optional[int] = None, antialias: bool = False, ) -> SegmentationMask: @@ -52,7 +52,7 @@ def resized_crop( height: int, width: int, size: List[int], - interpolation: InterpolationMode = InterpolationMode.BILINEAR, + interpolation: InterpolationMode = InterpolationMode.NEAREST, antialias: bool = False, ) -> SegmentationMask: from torchvision.prototype.transforms import functional as _F @@ -106,7 +106,7 @@ def affine( def perspective( self, perspective_coeffs: List[float], - interpolation: InterpolationMode = InterpolationMode.BILINEAR, + interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None, ) -> SegmentationMask: from torchvision.prototype.transforms import functional as _F diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 4ad9c7302b7..f4dad53a210 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -89,7 +89,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if isinstance(inpt, features._Feature): return inpt.erase(**params) elif isinstance(inpt, PIL.Image.Image): - # Shouldn't we implement a fallback to tensor ? + # TODO: We should implement a fallback to tensor, like gaussian_blur etc raise RuntimeError("Not implemented") elif isinstance(inpt, torch.Tensor): return F.erase_image_tensor(inpt, **params) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index f8016b43a36..70b7a8e1dfe 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Union import PIL.Image import torch @@ -6,166 +6,136 @@ from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP +# shortcut type +DType = Union[torch.Tensor, PIL.Image.Image, features._Feature] + adjust_brightness_image_tensor = _FT.adjust_brightness adjust_brightness_image_pil = _FP.adjust_brightness -def adjust_brightness(inpt: Any, brightness_factor: float) -> Any: +def adjust_brightness(inpt: DType, brightness_factor: float) -> DType: if isinstance(inpt, features._Feature): return inpt.adjust_brightness(brightness_factor=brightness_factor) - elif isinstance(inpt, PIL.Image.Image): + if isinstance(inpt, PIL.Image.Image): return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor) - elif isinstance(inpt, torch.Tensor): - return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor) - else: - return inpt + return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor) adjust_saturation_image_tensor = _FT.adjust_saturation adjust_saturation_image_pil = _FP.adjust_saturation -def adjust_saturation(inpt: Any, saturation_factor: float) -> Any: +def adjust_saturation(inpt: DType, saturation_factor: float) -> DType: if isinstance(inpt, features._Feature): return inpt.adjust_saturation(saturation_factor=saturation_factor) - elif isinstance(inpt, PIL.Image.Image): + if isinstance(inpt, PIL.Image.Image): return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor) - elif isinstance(inpt, torch.Tensor): - return adjust_saturation_image_tensor(inpt, 
saturation_factor=saturation_factor) - else: - return inpt + return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor) adjust_contrast_image_tensor = _FT.adjust_contrast adjust_contrast_image_pil = _FP.adjust_contrast -def adjust_contrast(inpt: Any, contrast_factor: float) -> Any: +def adjust_contrast(inpt: DType, contrast_factor: float) -> DType: if isinstance(inpt, features._Feature): return inpt.adjust_contrast(contrast_factor=contrast_factor) - elif isinstance(inpt, PIL.Image.Image): + if isinstance(inpt, PIL.Image.Image): return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor) - elif isinstance(inpt, torch.Tensor): - return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor) - else: - return inpt + return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor) adjust_sharpness_image_tensor = _FT.adjust_sharpness adjust_sharpness_image_pil = _FP.adjust_sharpness -def adjust_sharpness(inpt: Any, sharpness_factor: float) -> Any: +def adjust_sharpness(inpt: DType, sharpness_factor: float) -> DType: if isinstance(inpt, features._Feature): return inpt.adjust_sharpness(sharpness_factor=sharpness_factor) - elif isinstance(inpt, PIL.Image.Image): + if isinstance(inpt, PIL.Image.Image): return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor) - elif isinstance(inpt, torch.Tensor): - return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor) - else: - return inpt + return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor) adjust_hue_image_tensor = _FT.adjust_hue adjust_hue_image_pil = _FP.adjust_hue -def adjust_hue(inpt: Any, hue_factor: float) -> Any: +def adjust_hue(inpt: DType, hue_factor: float) -> DType: if isinstance(inpt, features._Feature): return inpt.adjust_hue(hue_factor=hue_factor) - elif isinstance(inpt, PIL.Image.Image): + if isinstance(inpt, PIL.Image.Image): return adjust_hue_image_pil(inpt, hue_factor=hue_factor) - elif isinstance(inpt, torch.Tensor): - return adjust_hue_image_tensor(inpt, hue_factor=hue_factor) - else: - return inpt + return adjust_hue_image_tensor(inpt, hue_factor=hue_factor) adjust_gamma_image_tensor = _FT.adjust_gamma adjust_gamma_image_pil = _FP.adjust_gamma -def adjust_gamma(inpt: Any, gamma: float, gain: float = 1) -> Any: +def adjust_gamma(inpt: DType, gamma: float, gain: float = 1) -> DType: if isinstance(inpt, features._Feature): return inpt.adjust_gamma(gamma=gamma, gain=gain) - elif isinstance(inpt, PIL.Image.Image): + if isinstance(inpt, PIL.Image.Image): return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain) - elif isinstance(inpt, torch.Tensor): - return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain) - else: - return inpt + return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain) posterize_image_tensor = _FT.posterize posterize_image_pil = _FP.posterize -def posterize(inpt: Any, bits: int) -> Any: +def posterize(inpt: DType, bits: int) -> DType: if isinstance(inpt, features._Feature): return inpt.posterize(bits=bits) - elif isinstance(inpt, PIL.Image.Image): + if isinstance(inpt, PIL.Image.Image): return posterize_image_pil(inpt, bits=bits) - elif isinstance(inpt, torch.Tensor): - return posterize_image_tensor(inpt, bits=bits) - else: - return inpt + return posterize_image_tensor(inpt, bits=bits) solarize_image_tensor = _FT.solarize solarize_image_pil = _FP.solarize -def solarize(inpt: Any, threshold: float) -> Any: +def solarize(inpt: DType, threshold: float) -> DType: if isinstance(inpt, 
features._Feature): return inpt.solarize(threshold=threshold) - elif isinstance(inpt, PIL.Image.Image): + if isinstance(inpt, PIL.Image.Image): return solarize_image_pil(inpt, threshold=threshold) - elif isinstance(inpt, torch.Tensor): - return solarize_image_tensor(inpt, threshold=threshold) - else: - return inpt + return solarize_image_tensor(inpt, threshold=threshold) autocontrast_image_tensor = _FT.autocontrast autocontrast_image_pil = _FP.autocontrast -def autocontrast(inpt: Any) -> Any: +def autocontrast(inpt: DType) -> DType: if isinstance(inpt, features._Feature): return inpt.autocontrast() - elif isinstance(inpt, PIL.Image.Image): + if isinstance(inpt, PIL.Image.Image): return autocontrast_image_pil(inpt) - elif isinstance(inpt, torch.Tensor): - return autocontrast_image_tensor(inpt) - else: - return inpt + return autocontrast_image_tensor(inpt) equalize_image_tensor = _FT.equalize equalize_image_pil = _FP.equalize -def equalize(inpt: Any) -> Any: +def equalize(inpt: DType) -> DType: if isinstance(inpt, features._Feature): return inpt.equalize() - elif isinstance(inpt, PIL.Image.Image): + if isinstance(inpt, PIL.Image.Image): return equalize_image_pil(inpt) - elif isinstance(inpt, torch.Tensor): - return equalize_image_tensor(inpt) - else: - return inpt + return equalize_image_tensor(inpt) invert_image_tensor = _FT.invert invert_image_pil = _FP.invert -def invert(inpt: Any) -> Any: +def invert(inpt: DType) -> DType: if isinstance(inpt, features._Feature): return inpt.invert() - elif isinstance(inpt, PIL.Image.Image): + if isinstance(inpt, PIL.Image.Image): return invert_image_pil(inpt) - elif isinstance(inpt, torch.Tensor): - return invert_image_tensor(inpt) - else: - return inpt + return invert_image_tensor(inpt) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index fd57188aa5f..f7de0d95d4c 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -1,6 +1,6 @@ import numbers import warnings -from typing import Any, Tuple, List, Optional, Sequence, Union +from typing import Tuple, List, Optional, Sequence, Union import PIL.Image import torch @@ -16,6 +16,10 @@ from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil +# shortcut type +DType = Union[torch.Tensor, PIL.Image.Image, features._Feature] + + horizontal_flip_image_tensor = _FT.hflip horizontal_flip_image_pil = _FP.hflip @@ -40,15 +44,12 @@ def horizontal_flip_bounding_box( ).view(shape) -def horizontal_flip(inpt: Any) -> Any: +def horizontal_flip(inpt: DType) -> DType: if isinstance(inpt, features._Feature): return inpt.horizontal_flip() - elif isinstance(inpt, PIL.Image.Image): + if isinstance(inpt, PIL.Image.Image): return horizontal_flip_image_pil(inpt) - elif isinstance(inpt, torch.Tensor): - return horizontal_flip_image_tensor(inpt) - else: - return inpt + return horizontal_flip_image_tensor(inpt) vertical_flip_image_tensor = _FT.vflip @@ -75,15 +76,12 @@ def vertical_flip_bounding_box( ).view(shape) -def vertical_flip(inpt: Any) -> Any: +def vertical_flip(inpt: DType) -> DType: if isinstance(inpt, features._Feature): return inpt.vertical_flip() - elif isinstance(inpt, PIL.Image.Image): + if isinstance(inpt, PIL.Image.Image): return vertical_flip_image_pil(inpt) - elif isinstance(inpt, torch.Tensor): - return vertical_flip_image_tensor(inpt) - else: - return inpt + return vertical_flip_image_tensor(inpt) def 
resize_image_tensor( @@ -134,24 +132,22 @@ def resize_bounding_box( def resize( - inpt: Any, + inpt: DType, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, max_size: Optional[int] = None, antialias: Optional[bool] = None, -) -> Any: +) -> DType: if isinstance(inpt, features._Feature): antialias = False if antialias is None else antialias return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias) - elif isinstance(inpt, PIL.Image.Image): + if isinstance(inpt, PIL.Image.Image): if antialias is not None and not antialias: warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.") return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size) - elif isinstance(inpt, torch.Tensor): - antialias = False if antialias is None else antialias - return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias) - else: - return inpt + + antialias = False if antialias is None else antialias + return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias) def _affine_parse_args( @@ -367,7 +363,7 @@ def affine_segmentation_mask( def affine( - inpt: Any, + inpt: DType, angle: float, translate: List[float], scale: float, @@ -375,12 +371,12 @@ def affine( interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None, center: Optional[List[float]] = None, -) -> Any: +) -> DType: if isinstance(inpt, features._Feature): return inpt.affine( angle, translate=translate, scale=scale, shear=shear, interpolation=interpolation, fill=fill, center=center ) - elif isinstance(inpt, PIL.Image.Image): + if isinstance(inpt, PIL.Image.Image): return affine_image_pil( inpt, angle, @@ -391,19 +387,16 @@ def affine( fill=fill, center=center, ) - elif isinstance(inpt, torch.Tensor): - return affine_image_tensor( - inpt, - angle, - translate=translate, - scale=scale, - shear=shear, - interpolation=interpolation, - fill=fill, - center=center, - ) - else: - return inpt + return affine_image_tensor( + inpt, + angle, + translate=translate, + scale=scale, + shear=shear, + interpolation=interpolation, + fill=fill, + center=center, + ) def rotate_image_tensor( @@ -486,21 +479,18 @@ def rotate_segmentation_mask( def rotate( - inpt: Any, + inpt: DType, angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, fill: Optional[List[float]] = None, center: Optional[List[float]] = None, -) -> Any: +) -> DType: if isinstance(inpt, features._Feature): return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center) elif isinstance(inpt, PIL.Image.Image): return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) - elif isinstance(inpt, torch.Tensor): - return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) - else: - return inpt + return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) pad_image_pil = _FP.pad @@ -576,20 +566,18 @@ def pad_bounding_box( def pad( - inpt: Any, padding: List[int], fill: Union[int, float, Sequence[float]] = 0.0, padding_mode: str = "constant" -) -> Any: + inpt: DType, padding: List[int], fill: Union[int, float, Sequence[float]] = 0.0, padding_mode: str = "constant" +) -> DType: if isinstance(inpt, features._Feature): return inpt.pad(padding, 
fill=fill, padding_mode=padding_mode) - elif isinstance(inpt, PIL.Image.Image): + if isinstance(inpt, PIL.Image.Image): return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode) - elif isinstance(inpt, torch.Tensor): - # PyTorch's pad supports only scalars on fill. So we need to overwrite the colour - if isinstance(fill, (int, float)): - return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode) - else: - return _pad_with_vector_fill(inpt, padding, fill=fill, padding_mode=padding_mode) + + # PyTorch's pad supports only scalars on fill. So we need to overwrite the colour + if isinstance(fill, (int, float)): + return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode) else: - return inpt + return _pad_with_vector_fill(inpt, padding, fill=fill, padding_mode=padding_mode) crop_image_tensor = _FT.crop @@ -619,15 +607,12 @@ def crop_segmentation_mask(img: torch.Tensor, top: int, left: int, height: int, return crop_image_tensor(img, top, left, height, width) -def crop(inpt: Any, top: int, left: int, height: int, width: int) -> Any: +def crop(inpt: DType, top: int, left: int, height: int, width: int) -> DType: if isinstance(inpt, features._Feature): return inpt.crop(top, left, height, width) elif isinstance(inpt, PIL.Image.Image): return crop_image_pil(inpt, top, left, height, width) - elif isinstance(inpt, torch.Tensor): - return crop_image_tensor(inpt, top, left, height, width) - else: - return inpt + return crop_image_tensor(inpt, top, left, height, width) def perspective_image_tensor( @@ -738,19 +723,16 @@ def perspective_segmentation_mask(img: torch.Tensor, perspective_coeffs: List[fl def perspective( - inpt: Any, + inpt: DType, perspective_coeffs: List[float], interpolation: InterpolationMode = InterpolationMode.BILINEAR, fill: Optional[List[float]] = None, -) -> Any: +) -> DType: if isinstance(inpt, features._Feature): return inpt.perspective(perspective_coeffs, interpolation=interpolation, fill=fill) - elif isinstance(inpt, PIL.Image.Image): + if isinstance(inpt, PIL.Image.Image): return perspective_image_pil(inpt, perspective_coeffs, interpolation=interpolation, fill=fill) - elif isinstance(inpt, torch.Tensor): - return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill) - else: - return inpt + return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill) def _center_crop_parse_output_size(output_size: List[int]) -> List[int]: @@ -826,15 +808,12 @@ def center_crop_segmentation_mask(segmentation_mask: torch.Tensor, output_size: return center_crop_image_tensor(img=segmentation_mask, output_size=output_size) -def center_crop(inpt: Any, output_size: List[int]) -> Any: +def center_crop(inpt: DType, output_size: List[int]) -> DType: if isinstance(inpt, features._Feature): return inpt.center_crop(output_size) elif isinstance(inpt, PIL.Image.Image): return center_crop_image_pil(inpt, output_size) - elif isinstance(inpt, torch.Tensor): - return center_crop_image_tensor(inpt, output_size) - else: - return inpt + return center_crop_image_tensor(inpt, output_size) def resized_crop_image_tensor( @@ -890,7 +869,7 @@ def resized_crop_segmentation_mask( def resized_crop( - inpt: Any, + inpt: DType, top: int, left: int, height: int, @@ -898,19 +877,16 @@ def resized_crop( size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, antialias: Optional[bool] = None, -) -> Any: +) -> DType: if isinstance(inpt, features._Feature): antialias = False if antialias is 
None else antialias return inpt.resized_crop(top, left, height, width, antialias=antialias, size=size, interpolation=interpolation) - elif isinstance(inpt, PIL.Image.Image): + if isinstance(inpt, PIL.Image.Image): return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation) - elif isinstance(inpt, torch.Tensor): - antialias = False if antialias is None else antialias - return resized_crop_image_tensor( - inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation - ) - else: - return inpt + antialias = False if antialias is None else antialias + return resized_crop_image_tensor( + inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation + ) def _parse_five_crop_size(size: List[int]) -> List[int]: From 5501fd31ffc637beec702f9e7b9a2f66524eb61b Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 5 Jul 2022 17:46:01 +0000 Subject: [PATCH 7/7] Minor nit fixes --- .../transforms/functional/_geometry.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index f7de0d95d4c..5044441d612 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -488,7 +488,7 @@ def rotate( ) -> DType: if isinstance(inpt, features._Feature): return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center) - elif isinstance(inpt, PIL.Image.Image): + if isinstance(inpt, PIL.Image.Image): return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) @@ -573,11 +573,11 @@ def pad( if isinstance(inpt, PIL.Image.Image): return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode) - # PyTorch's pad supports only scalars on fill. So we need to overwrite the colour + # TODO: PyTorch's pad supports only scalars on fill. 
So we need to overwrite the colour if isinstance(fill, (int, float)): return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode) - else: - return _pad_with_vector_fill(inpt, padding, fill=fill, padding_mode=padding_mode) + + return _pad_with_vector_fill(inpt, padding, fill=fill, padding_mode=padding_mode) crop_image_tensor = _FT.crop @@ -610,7 +610,7 @@ def crop_segmentation_mask(img: torch.Tensor, top: int, left: int, height: int, def crop(inpt: DType, top: int, left: int, height: int, width: int) -> DType: if isinstance(inpt, features._Feature): return inpt.crop(top, left, height, width) - elif isinstance(inpt, PIL.Image.Image): + if isinstance(inpt, PIL.Image.Image): return crop_image_pil(inpt, top, left, height, width) return crop_image_tensor(inpt, top, left, height, width) @@ -738,10 +738,9 @@ def perspective( def _center_crop_parse_output_size(output_size: List[int]) -> List[int]: if isinstance(output_size, numbers.Number): return [int(output_size), int(output_size)] - elif isinstance(output_size, (tuple, list)) and len(output_size) == 1: + if isinstance(output_size, (tuple, list)) and len(output_size) == 1: return [output_size[0], output_size[0]] - else: - return list(output_size) + return list(output_size) def _center_crop_compute_padding(crop_height: int, crop_width: int, image_height: int, image_width: int) -> List[int]: @@ -811,7 +810,7 @@ def center_crop_segmentation_mask(segmentation_mask: torch.Tensor, output_size: def center_crop(inpt: DType, output_size: List[int]) -> DType: if isinstance(inpt, features._Feature): return inpt.center_crop(output_size) - elif isinstance(inpt, PIL.Image.Image): + if isinstance(inpt, PIL.Image.Image): return center_crop_image_pil(inpt, output_size) return center_crop_image_tensor(inpt, output_size)
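
Note on the non-scalar fill workaround used by _pad_with_vector_fill and the pad dispatcher above: the helper first pads with a scalar zero and then overwrites the four newly created border bands with a per-channel fill value, which is also why it is restricted to padding_mode="constant". The snippet below is a minimal standalone sketch of that idea, assuming the [left, top, right, bottom] padding order and a fully expanded 4-element padding list; the function and variable names are illustrative only and are not the torchvision API.

    import torch
    import torch.nn.functional as F

    def pad_with_vector_fill(img, padding, fill):
        # padding: [left, top, right, bottom]; fill: one value per channel
        left, top, right, bottom = padding
        # F.pad expects (left, right, top, bottom) for the last two dims
        out = F.pad(img, [left, right, top, bottom], mode="constant", value=0.0)
        fill = torch.tensor(fill, dtype=img.dtype, device=img.device).view(-1, 1, 1)
        # Overwrite only the padded bands with the per-channel value
        if top > 0:
            out[..., :top, :] = fill
        if left > 0:
            out[..., :, :left] = fill
        if bottom > 0:
            out[..., -bottom:, :] = fill
        if right > 0:
            out[..., :, -right:] = fill
        return out

    img = torch.zeros(3, 4, 4)
    padded = pad_with_vector_fill(img, [1, 1, 2, 2], fill=[0.5, 0.25, 1.0])
    print(padded.shape)  # torch.Size([3, 7, 7])

Padding with zero first keeps the existing scalar padding kernel for the bulk of the work; the band assignments then touch only the freshly padded pixels, so the extra cost scales with the padding size rather than the image size.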