diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index 28b21ebbaf6..d61004c61a1 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -200,7 +200,7 @@ def test_random_resized_crop(self, transform, input):
     @parametrize(
         [
             (
-                transforms.ConvertImageColorSpace(color_space=new_color_space, old_color_space=old_color_space),
+                transforms.ConvertColorSpace(color_space=new_color_space, old_color_space=old_color_space),
                 itertools.chain.from_iterable(
                     [
                         fn(color_spaces=[old_color_space])
@@ -223,7 +223,7 @@ def test_random_resized_crop(self, transform, input):
             )
         ]
     )
-    def test_convert_image_color_space(self, transform, input):
+    def test_convert_color_space(self, transform, input):
         transform(input)
diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py
index 54e1315c9ab..78acdcaa9c6 100644
--- a/torchvision/prototype/features/_bounding_box.py
+++ b/torchvision/prototype/features/_bounding_box.py
@@ -60,17 +60,13 @@ def new_like(
         )
 
     def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
-        # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
-        # promote this out of the prototype state
-
-        # import at runtime to avoid cyclic imports
-        from torchvision.prototype.transforms.functional import convert_bounding_box_format
+        from torchvision.prototype.transforms import functional as _F
 
         if isinstance(format, str):
             format = BoundingBoxFormat.from_str(format.upper())
 
         return BoundingBox.new_like(
-            self, convert_bounding_box_format(self, old_format=self.format, new_format=format), format=format
+            self, _F.convert_bounding_box_format(self, old_format=self.format, new_format=format), format=format
         )
 
     def horizontal_flip(self) -> BoundingBox:
diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py
index 1726f672eb7..ad0899b289f 100644
--- a/torchvision/prototype/features/_image.py
+++ b/torchvision/prototype/features/_image.py
@@ -99,6 +99,20 @@ def guess_color_space(data: torch.Tensor) -> ColorSpace:
         else:
             return ColorSpace.OTHER
 
+    def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) -> Image:
+        from torchvision.prototype.transforms import functional as _F
+
+        if isinstance(color_space, str):
+            color_space = ColorSpace.from_str(color_space.upper())
+
+        return Image.new_like(
+            self,
+            _F.convert_color_space_image_tensor(
+                self, old_color_space=self.color_space, new_color_space=color_space, copy=copy
+            ),
+            color_space=color_space,
+        )
+
     def show(self) -> None:
         # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
         # promote this out of the prototype state
diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py
index e92ab2f154c..fb7aa7015fd 100644
--- a/torchvision/prototype/transforms/__init__.py
+++ b/torchvision/prototype/transforms/__init__.py
@@ -33,7 +33,7 @@
     ScaleJitter,
     TenCrop,
 )
-from ._meta import ConvertBoundingBoxFormat, ConvertImageColorSpace, ConvertImageDtype
+from ._meta import ConvertBoundingBoxFormat, ConvertColorSpace, ConvertImageDtype
 from ._misc import GaussianBlur, Identity, Lambda, Normalize, ToDtype
 from ._type_conversion import DecodeImage, LabelToOneHot, ToImagePIL, ToImageTensor
diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py
index bc29fe5b677..58cc72cda08 100644
--- a/torchvision/prototype/transforms/_color.py
+++ b/torchvision/prototype/transforms/_color.py
@@ -84,30 +84,6 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return output
 
 
-class _RandomChannelShuffle(Transform):
-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        image = query_image(sample)
-        num_channels, _, _ = get_image_dimensions(image)
-        return dict(permutation=torch.randperm(num_channels))
-
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
-            return inpt
-
-        image = inpt
-        if isinstance(inpt, PIL.Image.Image):
-            image = _F.pil_to_tensor(image)
-
-        output = image[..., params["permutation"], :, :]
-
-        if isinstance(inpt, features.Image):
-            output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.OTHER)
-        elif isinstance(inpt, PIL.Image.Image):
-            output = _F.to_pil_image(output)
-
-        return output
-
-
 class RandomPhotometricDistort(Transform):
     def __init__(
         self,
@@ -118,35 +94,62 @@ def __init__(
         p: float = 0.5,
     ):
         super().__init__()
-        self._brightness = ColorJitter(brightness=brightness)
-        self._contrast = ColorJitter(contrast=contrast)
-        self._hue = ColorJitter(hue=hue)
-        self._saturation = ColorJitter(saturation=saturation)
-        self._channel_shuffle = _RandomChannelShuffle()
+        self.brightness = brightness
+        self.contrast = contrast
+        self.hue = hue
+        self.saturation = saturation
         self.p = p
 
     def _get_params(self, sample: Any) -> Dict[str, Any]:
+        image = query_image(sample)
+        num_channels, _, _ = get_image_dimensions(image)
         return dict(
             zip(
-                ["brightness", "contrast1", "saturation", "hue", "contrast2", "channel_shuffle"],
+                ["brightness", "contrast1", "saturation", "hue", "contrast2"],
                 torch.rand(6) < self.p,
             ),
            contrast_before=torch.rand(()) < 0.5,
+            channel_permutation=torch.randperm(num_channels) if torch.rand(()) < self.p else None,
        )
 
+    def _permute_channels(self, inpt: Any, *, permutation: torch.Tensor) -> Any:
+        if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
+            return inpt
+
+        image = inpt
+        if isinstance(inpt, PIL.Image.Image):
+            image = _F.pil_to_tensor(image)
+
+        output = image[..., permutation, :, :]
+
+        if isinstance(inpt, features.Image):
+            output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.OTHER)
+        elif isinstance(inpt, PIL.Image.Image):
+            output = _F.to_pil_image(output)
+
+        return output
+
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if params["brightness"]:
-            inpt = self._brightness(inpt)
+            inpt = F.adjust_brightness(
+                inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1])
+            )
         if params["contrast1"] and params["contrast_before"]:
-            inpt = self._contrast(inpt)
-        if params["saturation"]:
-            inpt = self._saturation(inpt)
+            inpt = F.adjust_contrast(
+                inpt, contrast_factor=ColorJitter._generate_value(self.contrast[0], self.contrast[1])
+            )
         if params["saturation"]:
-            inpt = self._saturation(inpt)
+            inpt = F.adjust_saturation(
+                inpt, saturation_factor=ColorJitter._generate_value(self.saturation[0], self.saturation[1])
+            )
+        if params["hue"]:
+            inpt = F.adjust_hue(inpt, hue_factor=ColorJitter._generate_value(self.hue[0], self.hue[1]))
         if params["contrast2"] and not params["contrast_before"]:
-            inpt = self._contrast(inpt)
-        if params["channel_shuffle"]:
-            inpt = self._channel_shuffle(inpt)
+            inpt = F.adjust_contrast(
+                inpt, contrast_factor=ColorJitter._generate_value(self.contrast[0], self.contrast[1])
+            )
+        if params["channel_permutation"] is not None:
+            inpt = self._permute_channels(inpt, permutation=params["channel_permutation"])
         return inpt
diff --git a/torchvision/prototype/transforms/_deprecated.py b/torchvision/prototype/transforms/_deprecated.py
index 9edf5b404e7..02e827916ce 100644
--- a/torchvision/prototype/transforms/_deprecated.py
+++ b/torchvision/prototype/transforms/_deprecated.py
@@ -4,13 +4,13 @@
 import numpy as np
 import PIL.Image
 import torch
+import torchvision.prototype.transforms.functional as F
 from torchvision.prototype import features
 from torchvision.prototype.features import ColorSpace
 from torchvision.prototype.transforms import Transform
 from torchvision.transforms import functional as _F
 from typing_extensions import Literal
 
-from ._meta import ConvertImageColorSpace
 from ._transform import _RandomApplyTransform
 from ._utils import is_simple_tensor
@@ -90,13 +90,11 @@ def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
         super().__init__()
         self.num_output_channels = num_output_channels
-        self._rgb_to_gray = ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY)
-        self._gray_to_rgb = ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        output = self._rgb_to_gray(inpt)
+        output = F.convert_color_space(inpt, color_space=ColorSpace.GRAY, old_color_space=ColorSpace.RGB)
         if self.num_output_channels == 3:
-            output = self._gray_to_rgb(output)
+            output = F.convert_color_space(output, color_space=ColorSpace.RGB, old_color_space=ColorSpace.GRAY)
         return output
@@ -115,8 +113,7 @@ def __init__(self, p: float = 0.1) -> None:
         )
         super().__init__(p=p)
-        self._rgb_to_gray = ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY)
-        self._gray_to_rgb = ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._gray_to_rgb(self._rgb_to_gray(inpt))
+        output = F.convert_color_space(inpt, color_space=ColorSpace.GRAY, old_color_space=ColorSpace.RGB)
+        return F.convert_color_space(output, color_space=ColorSpace.RGB, old_color_space=ColorSpace.GRAY)
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index c88d05cd58f..23fb311a73b 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -1,4 +1,3 @@
-import collections.abc
 import math
 import numbers
 import warnings
@@ -180,9 +179,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
             output = F.ten_crop_image_tensor(inpt, self.size, vertical_flip=self.vertical_flip)
             return MultiCropResult(features.Image.new_like(inpt, o) for o in output)
         elif is_simple_tensor(inpt):
-            return MultiCropResult(F.ten_crop_image_tensor(inpt, self.size))
+            return MultiCropResult(F.ten_crop_image_tensor(inpt, self.size, vertical_flip=self.vertical_flip))
         elif isinstance(inpt, PIL.Image.Image):
-            return MultiCropResult(F.ten_crop_image_pil(inpt, self.size))
+            return MultiCropResult(F.ten_crop_image_pil(inpt, self.size, vertical_flip=self.vertical_flip))
         else:
             return inpt
@@ -194,31 +193,19 @@ def forward(self, *inputs: Any) -> Any:
 
 
 class BatchMultiCrop(Transform):
-    def forward(self, *inputs: Any) -> Any:
-        # This is basically the functionality of `torchvision.prototype.utils._internal.apply_recursively` with one
-        # significant difference:
-        # Since we need multiple images to batch them together, we need to explicitly exclude `MultiCropResult` from
-        # the sequence case.
-        def apply_recursively(obj: Any) -> Any:
-            if isinstance(obj, MultiCropResult):
-                crops = obj
-                if isinstance(obj[0], PIL.Image.Image):
-                    crops = [pil_to_tensor(crop) for crop in crops]  # type: ignore[assignment]
-
-                batch = torch.stack(crops)
-
-                if isinstance(obj[0], features.Image):
-                    batch = features.Image.new_like(obj[0], batch)
-
-                return batch
-            elif isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
-                return [apply_recursively(item) for item in obj]
-            elif isinstance(obj, collections.abc.Mapping):
-                return {key: apply_recursively(item) for key, item in obj.items()}
-            else:
-                return obj
-
-        return apply_recursively(inputs if len(inputs) > 1 else inputs[0])
+    _transformed_types = (MultiCropResult,)
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        crops = inpt
+        if isinstance(inpt[0], PIL.Image.Image):
+            crops = [pil_to_tensor(crop) for crop in crops]
+
+        batch = torch.stack(crops)
+
+        if isinstance(inpt[0], features.Image):
+            batch = features.Image.new_like(inpt[0], batch)
+
+        return batch
 
 
 def _check_fill_arg(fill: Union[int, float, Sequence[int], Sequence[float]]) -> None:
diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py
index 74b7e473cad..b3b87b7cb09 100644
--- a/torchvision/prototype/transforms/_meta.py
+++ b/torchvision/prototype/transforms/_meta.py
@@ -1,6 +1,7 @@
 from typing import Any, Dict, Optional, Union
 
 import PIL.Image
+
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F, Transform
@@ -39,11 +40,15 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt
 
 
-class ConvertImageColorSpace(Transform):
+class ConvertColorSpace(Transform):
+    # F.convert_color_space does NOT handle `_Feature`'s in general
+    _transformed_types = (torch.Tensor, features.Image, PIL.Image.Image)
+
     def __init__(
         self,
         color_space: Union[str, features.ColorSpace],
         old_color_space: Optional[Union[str, features.ColorSpace]] = None,
+        copy: bool = True,
     ) -> None:
         super().__init__()
@@ -55,23 +60,9 @@ def __init__(
             old_color_space = features.ColorSpace.from_str(old_color_space)
         self.old_color_space = old_color_space
 
+        self.copy = copy
+
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, features.Image):
-            output = F.convert_image_color_space_tensor(
-                inpt, old_color_space=inpt.color_space, new_color_space=self.color_space
-            )
-            return features.Image.new_like(inpt, output, color_space=self.color_space)
-        elif is_simple_tensor(inpt):
-            if self.old_color_space is None:
-                raise RuntimeError(
-                    f"In order to convert simple tensor images, `{type(self).__name__}(...)` "
-                    f"needs to be constructed with the `old_color_space=...` parameter."
-                )
-
-            return F.convert_image_color_space_tensor(
-                inpt, old_color_space=self.old_color_space, new_color_space=self.color_space
-            )
-        elif isinstance(inpt, PIL.Image.Image):
-            return F.convert_image_color_space_pil(inpt, color_space=self.color_space)
-        else:
-            return inpt
+        return F.convert_color_space(
+            inpt, color_space=self.color_space, old_color_space=self.old_color_space, copy=self.copy
+        )
diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py
index 958f9103e06..fee0c4dd1e3 100644
--- a/torchvision/prototype/transforms/functional/__init__.py
+++ b/torchvision/prototype/transforms/functional/__init__.py
@@ -1,8 +1,9 @@
 from torchvision.transforms import InterpolationMode  # usort: skip
 from ._meta import (
     convert_bounding_box_format,
-    convert_image_color_space_tensor,
-    convert_image_color_space_pil,
+    convert_color_space_image_tensor,
+    convert_color_space_image_pil,
+    convert_color_space,
 )  # usort: skip
 from ._augment import erase_image_pil, erase_image_tensor
diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py
index db7918558bc..f1aea2018bc 100644
--- a/torchvision/prototype/transforms/functional/_meta.py
+++ b/torchvision/prototype/transforms/functional/_meta.py
@@ -1,8 +1,8 @@
-from typing import Optional, Tuple
+from typing import Any, Optional, Tuple
 
 import PIL.Image
 import torch
-from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
+from torchvision.prototype.features import BoundingBoxFormat, ColorSpace, Image
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
 
 get_dimensions_image_tensor = _FT.get_dimensions
@@ -91,7 +91,7 @@ def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
 _rgb_to_gray = _FT.rgb_to_grayscale
 
 
-def convert_image_color_space_tensor(
+def convert_color_space_image_tensor(
     image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace, copy: bool = True
 ) -> torch.Tensor:
     if new_color_space == old_color_space:
@@ -141,7 +141,7 @@ def convert_image_color_space_tensor(
     }
 
 
-def convert_image_color_space_pil(
+def convert_color_space_image_pil(
     image: PIL.Image.Image, color_space: ColorSpace, copy: bool = True
 ) -> PIL.Image.Image:
     old_mode = image.mode
@@ -154,3 +154,21 @@ def convert_image_color_space_pil(
         return image
 
     return image.convert(new_mode)
+
+
+def convert_color_space(
+    inpt: Any, *, color_space: ColorSpace, old_color_space: Optional[ColorSpace] = None, copy: bool = True
+) -> Any:
+    if isinstance(inpt, Image):
+        return inpt.to_color_space(color_space, copy=copy)
+    elif isinstance(inpt, PIL.Image.Image):
+        return convert_color_space_image_pil(inpt, color_space, copy=copy)
+    else:
+        if old_color_space is None:
+            raise RuntimeError(
+                "In order to convert the color space of simple tensor images, "
+                "the `old_color_space=...` parameter needs to be passed."
+            )
+        return convert_color_space_image_tensor(
+            inpt, old_color_space=old_color_space, new_color_space=color_space, copy=copy
+        )
diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py
index fb5c3b83de6..8ddf2aa6178 100644
--- a/torchvision/prototype/utils/_internal.py
+++ b/torchvision/prototype/utils/_internal.py
@@ -14,7 +14,6 @@
     "add_suggestion",
     "fromfile",
     "ReadOnlyTensorBuffer",
-    "apply_recursively",
     "query_recursively",
 ]
@@ -128,17 +127,6 @@ def read(self, size: int = -1) -> bytes:
         return self._memory[slice(cursor, self.seek(offset, whence))].tobytes()
 
 
-def apply_recursively(fn: Callable, obj: Any) -> Any:
-    # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
-    # "a" == "a"[0][0]...
-    if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
-        return [apply_recursively(fn, item) for item in obj]
-    elif isinstance(obj, collections.abc.Mapping):
-        return {key: apply_recursively(fn, item) for key, item in obj.items()}
-    else:
-        return fn(obj)
-
-
 def query_recursively(
     fn: Callable[[Tuple[Any, ...], Any], Optional[D]], obj: Any, *, id: Tuple[Any, ...] = ()
 ) -> Iterator[D]:
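For reference, a minimal usage sketch of the color-space API touched by this patch. This is not part of the diff; it assumes the prototype modules are importable as shown above and that `features.Image` accepts a `color_space` keyword, and the prototype namespace may still change.

import torch

from torchvision.prototype import features, transforms
from torchvision.prototype.features import ColorSpace
from torchvision.prototype.transforms import functional as F

# Transform class, renamed from ConvertImageColorSpace to ConvertColorSpace.
# `old_color_space` is only needed for plain tensors, whose color space cannot be inferred.
convert = transforms.ConvertColorSpace(color_space=ColorSpace.GRAY, old_color_space=ColorSpace.RGB)
gray_tensor = convert(torch.rand(3, 32, 32))

# New functional dispatcher from functional/_meta.py: features.Image goes through
# Image.to_color_space, PIL images through convert_color_space_image_pil, and plain
# tensors through convert_color_space_image_tensor (which requires old_color_space).
img = features.Image(torch.rand(3, 32, 32), color_space=ColorSpace.RGB)
gray_img = F.convert_color_space(img, color_space=ColorSpace.GRAY)

# Equivalent method added on features.Image in this patch.
gray_img = img.to_color_space(ColorSpace.GRAY)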