diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index 413d24cac69..c16d64c903d 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -1042,10 +1042,10 @@ def test__transform(self, inpt_type, mocker):
         inpt = mocker.MagicMock(spec=inpt_type)
         transform = transforms.ToImageTensor()
         transform(inpt)
-        if inpt_type in (features.BoundingBox, str, int):
+        if inpt_type in (features.BoundingBox, features.Image, str, int):
             assert fn.call_count == 0
         else:
-            fn.assert_called_once_with(inpt, copy=transform.copy)
+            fn.assert_called_once_with(inpt)
 
 
 class TestToImagePIL:
@@ -1059,7 +1059,7 @@ def test__transform(self, inpt_type, mocker):
         inpt = mocker.MagicMock(spec=inpt_type)
         transform = transforms.ToImagePIL()
         transform(inpt)
-        if inpt_type in (features.BoundingBox, str, int):
+        if inpt_type in (features.BoundingBox, PIL.Image.Image, str, int):
             assert fn.call_count == 0
         else:
             fn.assert_called_once_with(inpt, mode=transform.mode)
diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index cd11eb2a35e..ed47fc1addf 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -1867,31 +1867,23 @@ def test_midlevel_normalize_output_type():
 @pytest.mark.parametrize(
     "inpt",
     [
-        torch.randint(0, 256, size=(3, 32, 32)),
         127 * np.ones((32, 32, 3), dtype="uint8"),
         PIL.Image.new("RGB", (32, 32), 122),
     ],
 )
-@pytest.mark.parametrize("copy", [True, False])
-def test_to_image_tensor(inpt, copy):
-    output = F.to_image_tensor(inpt, copy=copy)
+def test_to_image_tensor(inpt):
+    output = F.to_image_tensor(inpt)
     assert isinstance(output, torch.Tensor)
 
     assert np.asarray(inpt).sum() == output.sum().item()
 
-    if isinstance(inpt, PIL.Image.Image) and not copy:
+    if isinstance(inpt, PIL.Image.Image):
         # we can't check this option
         # as PIL -> numpy is always copying
         return
 
-    if isinstance(inpt, PIL.Image.Image):
-        inpt.putpixel((0, 0), 11)
-    else:
-        inpt[0, 0, 0] = 11
-    if copy:
-        assert output[0, 0, 0] != 11
-    else:
-        assert output[0, 0, 0] == 11
+    inpt[0, 0, 0] = 11
+    assert output[0, 0, 0] == 11
 
 
 @pytest.mark.parametrize(
@@ -1899,7 +1891,6 @@ def test_to_image_tensor(inpt, copy):
     [
         torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8),
         127 * np.ones((32, 32, 3), dtype="uint8"),
-        PIL.Image.new("RGB", (32, 32), 122),
     ],
 )
 @pytest.mark.parametrize("mode", [None, "RGB"])
diff --git a/torchvision/prototype/transforms/_deprecated.py b/torchvision/prototype/transforms/_deprecated.py
index fd2af16ac31..35e6635ce06 100644
--- a/torchvision/prototype/transforms/_deprecated.py
+++ b/torchvision/prototype/transforms/_deprecated.py
@@ -3,6 +3,7 @@
 
 import numpy as np
 import PIL.Image
+import torch
 import torchvision.prototype.transforms.functional as F
 from torchvision.prototype import features
 from torchvision.prototype.features import ColorSpace
@@ -15,9 +16,7 @@
 
 
 class ToTensor(Transform):
-
-    # Updated transformed types for ToTensor
-    _transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
+    _transformed_types = (PIL.Image.Image, np.ndarray)
 
     def __init__(self) -> None:
         warnings.warn(
@@ -26,14 +25,13 @@ def __init__(self) -> None:
         )
         super().__init__()
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, (PIL.Image.Image, np.ndarray)):
-            return _F.to_tensor(inpt)
-        else:
-            return inpt
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
+        return _F.to_tensor(inpt)
 
 
 class PILToTensor(Transform):
+    _transformed_types = (PIL.Image.Image,)
+
     def __init__(self) -> None:
         warnings.warn(
             "The transform `PILToTensor()` is deprecated and will be removed in a future release. "
@@ -41,17 +39,12 @@ def __init__(self) -> None:
         )
         super().__init__()
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, PIL.Image.Image):
-            return _F.pil_to_tensor(inpt)
-        else:
-            return inpt
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
+        return _F.pil_to_tensor(inpt)
 
 
 class ToPILImage(Transform):
-
-    # Updated transformed types for ToPILImage
-    _transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
+    _transformed_types = (is_simple_tensor, features.Image, np.ndarray)
 
     def __init__(self, mode: Optional[str] = None) -> None:
         warnings.warn(
@@ -61,11 +54,8 @@ def __init__(self, mode: Optional[str] = None) -> None:
         super().__init__()
         self.mode = mode
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if is_simple_tensor(inpt) or isinstance(inpt, (features.Image, np.ndarray)):
-            return _F.to_pil_image(inpt, mode=self.mode)
-        else:
-            return inpt
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> PIL.Image.Image:
+        return _F.to_pil_image(inpt, mode=self.mode)
 
 
 class Grayscale(Transform):
diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py
index 5b98e90aee1..eaf0399464e 100644
--- a/torchvision/prototype/transforms/_meta.py
+++ b/torchvision/prototype/transforms/_meta.py
@@ -11,6 +11,8 @@
 
 
 class ConvertBoundingBoxFormat(Transform):
+    _transformed_types = (features.BoundingBox,)
+
     def __init__(self, format: Union[str, features.BoundingBoxFormat]) -> None:
         super().__init__()
         if isinstance(format, str):
@@ -18,30 +20,23 @@ def __init__(self, format: Union[str, features.BoundingBoxFormat]) -> None:
         self.format = format
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, features.BoundingBox):
-            output = F.convert_bounding_box_format(inpt, old_format=inpt.format, new_format=params["format"])
-            return features.BoundingBox.new_like(inpt, output, format=params["format"])
-        else:
-            return inpt
+        output = F.convert_bounding_box_format(inpt, old_format=inpt.format, new_format=params["format"])
+        return features.BoundingBox.new_like(inpt, output, format=params["format"])
 
 
 class ConvertImageDtype(Transform):
+    _transformed_types = (is_simple_tensor, features.Image)
+
     def __init__(self, dtype: torch.dtype = torch.float32) -> None:
         super().__init__()
         self.dtype = dtype
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, features.Image):
-            output = convert_image_dtype(inpt, dtype=self.dtype)
-            return features.Image.new_like(inpt, output, dtype=self.dtype)
-        elif is_simple_tensor(inpt):
-            return convert_image_dtype(inpt, dtype=self.dtype)
-        else:
-            return inpt
+        output = convert_image_dtype(inpt, dtype=self.dtype)
+        return output if is_simple_tensor(inpt) else features.Image.new_like(inpt, output, dtype=self.dtype)
 
 
 class ConvertColorSpace(Transform):
-    # F.convert_color_space does NOT handle `_Feature`'s in general
     _transformed_types = (is_simple_tensor, features.Image, PIL.Image.Image)
 
     def __init__(
diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py
index b677ccc9d9c..dff129fe807 100644
--- a/torchvision/prototype/transforms/_type_conversion.py
+++ b/torchvision/prototype/transforms/_type_conversion.py
@@ -11,12 +11,11 @@
 
 
 class DecodeImage(Transform):
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, features.EncodedImage):
-            output = F.decode_image_with_pil(inpt)
-            return features.Image(output)
-        else:
-            return inpt
+    _transformed_types = (features.EncodedImage,)
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> features.Image:
+        output = F.decode_image_with_pil(inpt)
+        return features.Image(output)
 
 
 class LabelToOneHot(Transform):
@@ -41,33 +40,19 @@ def extra_repr(self) -> str:
 
 
 class ToImageTensor(Transform):
-
-    # Updated transformed types for ToImageTensor
-    _transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
-
-    def __init__(self, *, copy: bool = False) -> None:
-        super().__init__()
-        self.copy = copy
-
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt):
-            output = F.to_image_tensor(inpt, copy=self.copy)
-            return features.Image(output)
-        else:
-            return inpt
+    _transformed_types = (is_simple_tensor, PIL.Image.Image, np.ndarray)
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> features.Image:
+        output = F.to_image_tensor(inpt)
+        return features.Image(output)
 
 
 class ToImagePIL(Transform):
-
-    # Updated transformed types for ToImagePIL
-    _transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
+    _transformed_types = (is_simple_tensor, features.Image, np.ndarray)
 
     def __init__(self, *, mode: Optional[str] = None) -> None:
         super().__init__()
         self.mode = mode
 
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt):
-            return F.to_image_pil(inpt, mode=self.mode)
-        else:
-            return inpt
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> PIL.Image.Image:
+        return F.to_image_pil(inpt, mode=self.mode)
diff --git a/torchvision/prototype/transforms/functional/_type_conversion.py b/torchvision/prototype/transforms/functional/_type_conversion.py
index c44cd1f50ef..c134b7b9831 100644
--- a/torchvision/prototype/transforms/functional/_type_conversion.py
+++ b/torchvision/prototype/transforms/functional/_type_conversion.py
@@ -1,5 +1,5 @@
 import unittest.mock
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Tuple, Union
 
 import numpy as np
 import PIL.Image
@@ -21,26 +21,11 @@ def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
     return read_video(ReadOnlyTensorBuffer(encoded_video))  # type: ignore[arg-type]
 
 
-def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], copy: bool = False) -> torch.Tensor:
+def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> torch.Tensor:
     if isinstance(image, np.ndarray):
-        image = torch.from_numpy(image)
-
-    if isinstance(image, torch.Tensor):
-        if copy:
-            return image.clone()
-        else:
-            return image
+        return torch.from_numpy(image)
 
     return _F.pil_to_tensor(image)
 
 
-def to_image_pil(
-    image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], mode: Optional[str] = None
-) -> PIL.Image.Image:
-    if isinstance(image, PIL.Image.Image):
-        if mode != image.mode:
-            return image.convert(mode)
-        else:
-            return image
-
-    return _F.to_pil_image(image, mode=mode)
+to_image_pil = _F.to_pil_image
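
For context when reviewing, here is a minimal usage sketch (not part of the patch) of the zero-copy behavior that the updated `test_to_image_tensor` pins down; it assumes the prototype `functional` namespace imported as `F`, as in the tests above:

```python
import numpy as np
import torch

from torchvision.prototype.transforms import functional as F

# With the `copy` parameter removed, to_image_tensor wraps a NumPy array
# via torch.from_numpy, which shares memory with the input instead of copying.
arr = 127 * np.ones((32, 32, 3), dtype="uint8")
out = F.to_image_tensor(arr)
assert isinstance(out, torch.Tensor)

arr[0, 0, 0] = 11
assert out[0, 0, 0] == 11  # the in-place mutation is visible through the tensor

# PIL inputs go through pil_to_tensor, which always materializes a new
# tensor; that is why the test returns early instead of checking aliasing.
```

Callers that previously relied on `copy=True` can clone explicitly, e.g. `F.to_image_tensor(arr).clone()`, which is the usual PyTorch idiom for opting into a copy.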