Skip to content

port image type conversion transforms to prototype API #5640

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,10 @@ def rotate_segmentation_mask():
and callable(kernel)
and any(feature_type in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label"})
and "pil" not in name
and name
not in {
"to_image_tensor",
}
],
)
def test_scriptable(kernel):
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@
from ._misc import Identity, Normalize, ToDtype, Lambda
from ._type_conversion import DecodeImage, LabelToOneHot

from ._legacy import Grayscale, RandomGrayscale # usort: skip
from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor # usort: skip
Original file line number Diff line number Diff line change
@@ -1,14 +1,63 @@
from __future__ import annotations

import warnings
from typing import Any, Dict
from typing import Any, Dict, Optional

import numpy as np
import PIL.Image
from torchvision.prototype import features
from torchvision.prototype.features import ColorSpace
from torchvision.prototype.transforms import Transform
from torchvision.transforms import functional as _F
from typing_extensions import Literal

from ._meta import ConvertImageColorSpace
from ._transform import _RandomApplyTransform
from ._utils import is_simple_tensor


class ToTensor(Transform):
def __init__(self) -> None:
warnings.warn(
"The transform `ToTensor()` is deprecated and will be removed in a future release. "
"Instead, please use `transforms.ToImageTensor()`."
)
super().__init__()

def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, (PIL.Image.Image, np.ndarray)):
return _F.to_tensor(input)
else:
return input


class PILToTensor(Transform):
def __init__(self) -> None:
warnings.warn(
"The transform `PILToTensor()` is deprecated and will be removed in a future release. "
"Instead, please use `transforms.ToImageTensor()`."
)
super().__init__()

def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, PIL.Image.Image):
return _F.pil_to_tensor(input)
else:
return input


class ToPILImage(Transform):
def __init__(self, mode: Optional[str] = None) -> None:
warnings.warn(
"The transform `ToPILImage()` is deprecated and will be removed in a future release. "
"Instead, please use `transforms.ToImagePIL()`."
)
super().__init__()
self.mode = mode

def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if is_simple_tensor(input) or isinstance(input, (features.Image, np.ndarray)):
return _F.to_pil_image(input, mode=self.mode)
else:
return input


class Grayscale(Transform):
Expand Down
29 changes: 29 additions & 0 deletions torchvision/prototype/transforms/_type_conversion.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from typing import Any, Dict

import numpy as np
import PIL.Image
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F

from ._utils import is_simple_tensor


class DecodeImage(Transform):
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
Expand Down Expand Up @@ -33,3 +37,28 @@ def extra_repr(self) -> str:
return ""

return f"num_categories={self.num_categories}"


class ToImageTensor(Transform):
def __init__(self, *, copy: bool = False) -> None:
super().__init__()
self.copy = copy

def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(input):
output = F.to_image_tensor(input, copy=self.copy)
return features.Image(output)
else:
return input


class ToImagePIL(Transform):
def __init__(self, *, copy: bool = False) -> None:
super().__init__()
self.copy = copy

def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(input):
return F.to_image_pil(input, copy=self.copy)
else:
return input
8 changes: 7 additions & 1 deletion torchvision/prototype/transforms/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,10 @@
ten_crop_image_pil,
)
from ._misc import normalize_image_tensor, gaussian_blur_image_tensor
from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot
from ._type_conversion import (
decode_image_with_pil,
decode_video_with_av,
label_to_one_hot,
to_image_tensor,
to_image_pil,
)
23 changes: 22 additions & 1 deletion torchvision/prototype/transforms/functional/_type_conversion.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import unittest.mock
from typing import Dict, Any, Tuple
from typing import Dict, Any, Tuple, Union

import numpy as np
import PIL.Image
import torch
from torch.nn.functional import one_hot
from torchvision.io.video import read_video
from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer
from torchvision.transforms import functional as _F


def decode_image_with_pil(encoded_image: torch.Tensor) -> torch.Tensor:
Expand All @@ -23,3 +24,23 @@ def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, tor

def label_to_one_hot(label: torch.Tensor, *, num_categories: int) -> torch.Tensor:
return one_hot(label, num_classes=num_categories) # type: ignore[no-any-return]


def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], copy: bool = False) -> torch.Tensor:
if isinstance(image, torch.Tensor):
if copy:
return image.clone()
else:
return image

return _F.to_tensor(image)


def to_image_pil(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], copy: bool = False) -> PIL.Image.Image:
if isinstance(image, PIL.Image.Image):
if copy:
return image.copy()
else:
return image

return _F.to_pil_image(to_image_tensor(image, copy=False))
2 changes: 1 addition & 1 deletion torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _is_numpy_image(img: Any) -> bool:
return img.ndim in {2, 3}


def to_tensor(pic):
def to_tensor(pic) -> Tensor:
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
This function does not support torchscript.

Expand Down