Skip to content

Commit 095cabb

Browse files
pmeier and datumbox authored
port image type conversion transforms to prototype API (#5640)
* port image type conversion transforms to prototype API
* implement proposal for image type conversion
* add deprecation warnings
* appease mypy

Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent a7746ef commit 095cabb

File tree

7 files changed: +116 additions, -7 deletions

test/test_prototype_transforms_functional.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,10 @@ def rotate_segmentation_mask():
330330
and callable(kernel)
331331
and any(feature_type in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label"})
332332
and "pil" not in name
333+
and name
334+
not in {
335+
"to_image_tensor",
336+
}
333337
],
334338
)
335339
def test_scriptable(kernel):

torchvision/prototype/transforms/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,4 @@
2222
from ._misc import Identity, Normalize, ToDtype, Lambda
2323
from ._type_conversion import DecodeImage, LabelToOneHot
2424

25-
from ._legacy import Grayscale, RandomGrayscale # usort: skip
25+
from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor # usort: skip

torchvision/prototype/transforms/_legacy.py renamed to torchvision/prototype/transforms/_deprecated.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,63 @@
1-
from __future__ import annotations
2-
31
import warnings
4-
from typing import Any, Dict
2+
from typing import Any, Dict, Optional
53

4+
import numpy as np
5+
import PIL.Image
6+
from torchvision.prototype import features
67
from torchvision.prototype.features import ColorSpace
78
from torchvision.prototype.transforms import Transform
9+
from torchvision.transforms import functional as _F
810
from typing_extensions import Literal
911

1012
from ._meta import ConvertImageColorSpace
1113
from ._transform import _RandomApplyTransform
14+
from ._utils import is_simple_tensor
15+
16+
17+
class ToTensor(Transform):
18+
def __init__(self) -> None:
19+
warnings.warn(
20+
"The transform `ToTensor()` is deprecated and will be removed in a future release. "
21+
"Instead, please use `transforms.ToImageTensor()`."
22+
)
23+
super().__init__()
24+
25+
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
26+
if isinstance(input, (PIL.Image.Image, np.ndarray)):
27+
return _F.to_tensor(input)
28+
else:
29+
return input
30+
31+
32+
class PILToTensor(Transform):
33+
def __init__(self) -> None:
34+
warnings.warn(
35+
"The transform `PILToTensor()` is deprecated and will be removed in a future release. "
36+
"Instead, please use `transforms.ToImageTensor()`."
37+
)
38+
super().__init__()
39+
40+
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
41+
if isinstance(input, PIL.Image.Image):
42+
return _F.pil_to_tensor(input)
43+
else:
44+
return input
45+
46+
47+
class ToPILImage(Transform):
48+
def __init__(self, mode: Optional[str] = None) -> None:
49+
warnings.warn(
50+
"The transform `ToPILImage()` is deprecated and will be removed in a future release. "
51+
"Instead, please use `transforms.ToImagePIL()`."
52+
)
53+
super().__init__()
54+
self.mode = mode
55+
56+
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
57+
if is_simple_tensor(input) or isinstance(input, (features.Image, np.ndarray)):
58+
return _F.to_pil_image(input, mode=self.mode)
59+
else:
60+
return input
1261

1362

1463
class Grayscale(Transform):

torchvision/prototype/transforms/_type_conversion.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
from typing import Any, Dict
22

3+
import numpy as np
4+
import PIL.Image
35
from torchvision.prototype import features
46
from torchvision.prototype.transforms import Transform, functional as F
57

8+
from ._utils import is_simple_tensor
9+
610

711
class DecodeImage(Transform):
812
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
@@ -33,3 +37,28 @@ def extra_repr(self) -> str:
3337
return ""
3438

3539
return f"num_categories={self.num_categories}"
40+
41+
42+
class ToImageTensor(Transform):
43+
def __init__(self, *, copy: bool = False) -> None:
44+
super().__init__()
45+
self.copy = copy
46+
47+
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
48+
if isinstance(input, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(input):
49+
output = F.to_image_tensor(input, copy=self.copy)
50+
return features.Image(output)
51+
else:
52+
return input
53+
54+
55+
class ToImagePIL(Transform):
56+
def __init__(self, *, copy: bool = False) -> None:
57+
super().__init__()
58+
self.copy = copy
59+
60+
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
61+
if isinstance(input, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(input):
62+
return F.to_image_pil(input, copy=self.copy)
63+
else:
64+
return input

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,4 +74,10 @@
7474
ten_crop_image_pil,
7575
)
7676
from ._misc import normalize_image_tensor, gaussian_blur_image_tensor
77-
from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot
77+
from ._type_conversion import (
78+
decode_image_with_pil,
79+
decode_video_with_av,
80+
label_to_one_hot,
81+
to_image_tensor,
82+
to_image_pil,
83+
)

torchvision/prototype/transforms/functional/_type_conversion.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import unittest.mock
2-
from typing import Dict, Any, Tuple
2+
from typing import Dict, Any, Tuple, Union
33

44
import numpy as np
55
import PIL.Image
66
import torch
77
from torch.nn.functional import one_hot
88
from torchvision.io.video import read_video
99
from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer
10+
from torchvision.transforms import functional as _F
1011

1112

1213
def decode_image_with_pil(encoded_image: torch.Tensor) -> torch.Tensor:
@@ -23,3 +24,23 @@ def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, tor
2324

2425
def label_to_one_hot(label: torch.Tensor, *, num_categories: int) -> torch.Tensor:
2526
return one_hot(label, num_classes=num_categories) # type: ignore[no-any-return]
27+
28+
29+
def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], copy: bool = False) -> torch.Tensor:
30+
if isinstance(image, torch.Tensor):
31+
if copy:
32+
return image.clone()
33+
else:
34+
return image
35+
36+
return _F.to_tensor(image)
37+
38+
39+
def to_image_pil(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], copy: bool = False) -> PIL.Image.Image:
40+
if isinstance(image, PIL.Image.Image):
41+
if copy:
42+
return image.copy()
43+
else:
44+
return image
45+
46+
return _F.to_pil_image(to_image_tensor(image, copy=False))

torchvision/transforms/functional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def _is_numpy_image(img: Any) -> bool:
120120
return img.ndim in {2, 3}
121121

122122

123-
def to_tensor(pic):
123+
def to_tensor(pic) -> Tensor:
124124
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
125125
This function does not support torchscript.
126126

0 commit comments

Comments
 (0)