From efbf42ea26c2c9c0347378d42609c301ae4e359a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 18 Mar 2022 10:26:21 +0100 Subject: [PATCH 1/4] port image type conversion transforms to prototype API --- torchvision/prototype/transforms/__init__.py | 2 +- .../prototype/transforms/_type_conversion.py | 34 ++++++++++++++++++- .../transforms/functional/__init__.py | 9 ++++- .../transforms/functional/_type_conversion.py | 6 ++++ 4 files changed, 48 insertions(+), 3 deletions(-) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 81e914e8383..e0226e79b9d 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -28,4 +28,4 @@ VideoClassificationEval, OpticalFlowEval, ) -from ._type_conversion import DecodeImage, LabelToOneHot +from ._type_conversion import DecodeImage, LabelToOneHot, ToTensor, ImageToPIL, PILToTensor diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py index f2dc426897b..dade689b07e 100644 --- a/torchvision/prototype/transforms/_type_conversion.py +++ b/torchvision/prototype/transforms/_type_conversion.py @@ -1,8 +1,12 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional +import numpy as np +import PIL.Image from torchvision.prototype import features from torchvision.prototype.transforms import Transform, functional as F +from ._utils import is_simple_tensor + class DecodeImage(Transform): def _transform(self, input: Any, params: Dict[str, Any]) -> Any: @@ -33,3 +37,31 @@ def extra_repr(self) -> str: return "" return f"num_categories={self.num_categories}" + + +class ToTensor(Transform): + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if isinstance(input, (PIL.Image.Image, np.ndarray)): + return F.to_tensor(input) + else: + return input + + +class PILToTensor(Transform): + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if isinstance(input, PIL.Image.Image): + return F.pil_to_tensor(input) + else: + return input + + +class ImageToPIL(Transform): + def __init__(self, mode: Optional[str] = None) -> None: + super().__init__() + self.mode = mode + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if is_simple_tensor(input) or isinstance(input, (features.Image, np.ndarray)): + return F.image_to_pil(input, mode=self.mode) + else: + return input diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 469768ba9c2..5a71ca3aba8 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -71,4 +71,11 @@ ten_crop_image_pil, ) from ._misc import normalize_image_tensor, gaussian_blur_image_tensor -from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot +from ._type_conversion import ( + decode_image_with_pil, + decode_video_with_av, + label_to_one_hot, + to_tensor, + pil_to_tensor, + image_to_pil, +) diff --git a/torchvision/prototype/transforms/functional/_type_conversion.py b/torchvision/prototype/transforms/functional/_type_conversion.py index 06b2daaf6f1..acdafc0a224 100644 --- a/torchvision/prototype/transforms/functional/_type_conversion.py +++ b/torchvision/prototype/transforms/functional/_type_conversion.py @@ -7,6 +7,7 @@ from torch.nn.functional import one_hot from torchvision.io.video import read_video from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer +from torchvision.transforms import functional as F def decode_image_with_pil(encoded_image: torch.Tensor) -> torch.Tensor: @@ -23,3 +24,8 @@ def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, tor def label_to_one_hot(label: torch.Tensor, *, num_categories: int) -> torch.Tensor: return one_hot(label, num_classes=num_categories) # type: ignore[no-any-return] + + +to_tensor = F.to_tensor +pil_to_tensor = F.pil_to_tensor +image_to_pil = F.to_pil_image From ae88555464c82fad955fbeae1de7292ceabee200 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 29 Mar 2022 18:05:32 +0200 Subject: [PATCH 2/4] implement proposal for image type conversion --- test/test_prototype_transforms_functional.py | 4 ++ torchvision/prototype/transforms/__init__.py | 4 +- .../prototype/transforms/_deprecated.py | 40 +++++++++++++++++++ .../prototype/transforms/_type_conversion.py | 29 ++++++-------- .../transforms/functional/__init__.py | 5 +-- .../transforms/functional/_type_conversion.py | 25 +++++++++--- torchvision/transforms/functional.py | 2 +- 7 files changed, 83 insertions(+), 26 deletions(-) create mode 100644 torchvision/prototype/transforms/_deprecated.py diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 6f10945feaf..b63919d947d 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -306,6 +306,10 @@ def rotate_bounding_box(): and callable(kernel) and any(feature_type in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label"}) and "pil" not in name + and name + not in { + "to_image_tensor", + } ], ) def test_scriptable(kernel): diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 8b1777411fb..b685ea0b35e 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -20,4 +20,6 @@ ) from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace from ._misc import Identity, Normalize, ToDtype, Lambda -from ._type_conversion import DecodeImage, LabelToOneHot, ToTensor, ImageToPIL, PILToTensor +from ._type_conversion import DecodeImage, LabelToOneHot + +from ._deprecated import ToTensor, ToPILImage, PILToTensor # usort: skip diff --git a/torchvision/prototype/transforms/_deprecated.py b/torchvision/prototype/transforms/_deprecated.py new file mode 100644 index 00000000000..3c08e4195f6 --- /dev/null +++ b/torchvision/prototype/transforms/_deprecated.py @@ -0,0 +1,40 @@ +from typing import Any, Dict, Optional + +import numpy as np +import PIL.Image +from torchvision.prototype import features +from torchvision.prototype.transforms import Transform +from torchvision.transforms import functional as _F + +from ._utils import is_simple_tensor + + +# TODO: add deprecation warning +class ToTensor(Transform): + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if isinstance(input, (PIL.Image.Image, np.ndarray)): + return _F.to_tensor(input) + else: + return input + + +# TODO: add deprecation warning +class PILToTensor(Transform): + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if isinstance(input, PIL.Image.Image): + return _F.pil_to_tensor(input) + else: + return input + + +# TODO: add deprecation warning +class ToPILImage(Transform): + def __init__(self, mode: Optional[str] = None) -> None: + super().__init__() + self.mode = mode + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if is_simple_tensor(input) or isinstance(input, (features.Image, np.ndarray)): + return _F.to_pil_image(input, mode=self.mode) + else: + return input diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py index dade689b07e..09c071a27e0 100644 --- a/torchvision/prototype/transforms/_type_conversion.py +++ b/torchvision/prototype/transforms/_type_conversion.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict import numpy as np import PIL.Image @@ -39,29 +39,26 @@ def extra_repr(self) -> str: return f"num_categories={self.num_categories}" -class ToTensor(Transform): - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, (PIL.Image.Image, np.ndarray)): - return F.to_tensor(input) - else: - return input - +class ToImageTensor(Transform): + def __init__(self, *, copy: bool = False) -> None: + super().__init__() + self.copy = copy -class PILToTensor(Transform): def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, PIL.Image.Image): - return F.pil_to_tensor(input) + if isinstance(input, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(input): + output = F.to_image_tensor(input, copy=self.copy) + return features.Image(output) else: return input -class ImageToPIL(Transform): - def __init__(self, mode: Optional[str] = None) -> None: +class ToImagePIL(Transform): + def __init__(self, *, copy: bool = False) -> None: super().__init__() - self.mode = mode + self.copy = copy def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if is_simple_tensor(input) or isinstance(input, (features.Image, np.ndarray)): - return F.image_to_pil(input, mode=self.mode) + if isinstance(input, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(input): + return F.to_image_pil(input, copy=self.copy) else: return input diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 9aa51269502..9b9a87fd17c 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -77,7 +77,6 @@ decode_image_with_pil, decode_video_with_av, label_to_one_hot, - to_tensor, - pil_to_tensor, - image_to_pil, + to_image_tensor, + to_image_pil, ) diff --git a/torchvision/prototype/transforms/functional/_type_conversion.py b/torchvision/prototype/transforms/functional/_type_conversion.py index acdafc0a224..37f8f9b70a3 100644 --- a/torchvision/prototype/transforms/functional/_type_conversion.py +++ b/torchvision/prototype/transforms/functional/_type_conversion.py @@ -1,5 +1,5 @@ import unittest.mock -from typing import Dict, Any, Tuple +from typing import Dict, Any, Tuple, Union import numpy as np import PIL.Image @@ -7,7 +7,7 @@ from torch.nn.functional import one_hot from torchvision.io.video import read_video from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer -from torchvision.transforms import functional as F +from torchvision.transforms import functional as _F def decode_image_with_pil(encoded_image: torch.Tensor) -> torch.Tensor: @@ -26,6 +26,21 @@ def label_to_one_hot(label: torch.Tensor, *, num_categories: int) -> torch.Tenso return one_hot(label, num_classes=num_categories) # type: ignore[no-any-return] -to_tensor = F.to_tensor -pil_to_tensor = F.pil_to_tensor -image_to_pil = F.to_pil_image +def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], copy: bool = False) -> torch.Tensor: + if isinstance(image, torch.Tensor): + if copy: + return image.clone() + else: + return image + + return _F.to_tensor(image) + + +def to_image_pil(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], copy: bool = False) -> PIL.Image.Image: + if isinstance(image, PIL.Image.Image): + if copy: + return image.copy() + else: + return image + + return _F.to_pil_image(to_image_tensor(image, copy=False)) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 230ad67f683..e964b10e18e 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -120,7 +120,7 @@ def _is_numpy_image(img: Any) -> bool: return img.ndim in {2, 3} -def to_tensor(pic): +def to_tensor(pic) -> Tensor: """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This function does not support torchscript. From 5742968fbb7e02782a8c19010a0a8f2470c1e483 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 4 Apr 2022 13:57:47 +0200 Subject: [PATCH 3/4] add deprecation warnings --- .../prototype/transforms/_deprecated.py | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/torchvision/prototype/transforms/_deprecated.py b/torchvision/prototype/transforms/_deprecated.py index 46009cf1101..820dfd9060d 100644 --- a/torchvision/prototype/transforms/_deprecated.py +++ b/torchvision/prototype/transforms/_deprecated.py @@ -14,8 +14,14 @@ from ._utils import is_simple_tensor -# TODO: add deprecation warning class ToTensor(Transform): + def __init__(self): + warnings.warn( + "The transform `ToTensor()` is deprecated and will be removed in a future release. " + "Instead, please use `transforms.ToImageTensor()`." + ) + super().__init__() + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if isinstance(input, (PIL.Image.Image, np.ndarray)): return _F.to_tensor(input) @@ -23,8 +29,14 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: return input -# TODO: add deprecation warning class PILToTensor(Transform): + def __init__(self): + warnings.warn( + "The transform `PILToTensor()` is deprecated and will be removed in a future release. " + "Instead, please use `transforms.ToImageTensor()`." + ) + super().__init__() + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if isinstance(input, PIL.Image.Image): return _F.pil_to_tensor(input) @@ -32,9 +44,12 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: return input -# TODO: add deprecation warning class ToPILImage(Transform): def __init__(self, mode: Optional[str] = None) -> None: + warnings.warn( + "The transform `ToPILImage()` is deprecated and will be removed in a future release. " + "Instead, please use `transforms.ToImagePIL()`." + ) super().__init__() self.mode = mode From ab86ce0d350e16c233428c8f7133f9d5dac7c7a7 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 4 Apr 2022 14:34:07 +0200 Subject: [PATCH 4/4] appease mypy --- torchvision/prototype/transforms/_deprecated.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/_deprecated.py b/torchvision/prototype/transforms/_deprecated.py index 820dfd9060d..b9b712ebcae 100644 --- a/torchvision/prototype/transforms/_deprecated.py +++ b/torchvision/prototype/transforms/_deprecated.py @@ -15,7 +15,7 @@ class ToTensor(Transform): - def __init__(self): + def __init__(self) -> None: warnings.warn( "The transform `ToTensor()` is deprecated and will be removed in a future release. " "Instead, please use `transforms.ToImageTensor()`." @@ -30,7 +30,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: class PILToTensor(Transform): - def __init__(self): + def __init__(self) -> None: warnings.warn( "The transform `PILToTensor()` is deprecated and will be removed in a future release. " "Instead, please use `transforms.ToImageTensor()`."