Skip to content

Commit bf1096b

Browse files
YosuaMichael and facebook-github-bot
authored and committed
[fbsync] [proto] Use the proper _transformed_types in all Transforms and eliminate unnecessary dispatching (#6494)
Summary: * Update types in deprecated transforms. * Update types in type conversion transforms. * Fixing types in meta transforms. * More changes on type conversion. * Bug fix. * Fix types * Remove unnecessary conversions. * Remove unnecessary import. * Fixing tests * Remove copy support from `to_image_tensor` * restore test param * Fix further tests Reviewed By: NicolasHug Differential Revision: D39131008 fbshipit-source-id: f44bff9066888661a764fe0a50a77894d2c31140
1 parent 2cadbda commit bf1096b

File tree

6 files changed

+43
-97
lines changed

6 files changed

+43
-97
lines changed

test/test_prototype_transforms.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,10 +1042,10 @@ def test__transform(self, inpt_type, mocker):
10421042
inpt = mocker.MagicMock(spec=inpt_type)
10431043
transform = transforms.ToImageTensor()
10441044
transform(inpt)
1045-
if inpt_type in (features.BoundingBox, str, int):
1045+
if inpt_type in (features.BoundingBox, features.Image, str, int):
10461046
assert fn.call_count == 0
10471047
else:
1048-
fn.assert_called_once_with(inpt, copy=transform.copy)
1048+
fn.assert_called_once_with(inpt)
10491049

10501050

10511051
class TestToImagePIL:
@@ -1059,7 +1059,7 @@ def test__transform(self, inpt_type, mocker):
10591059
inpt = mocker.MagicMock(spec=inpt_type)
10601060
transform = transforms.ToImagePIL()
10611061
transform(inpt)
1062-
if inpt_type in (features.BoundingBox, str, int):
1062+
if inpt_type in (features.BoundingBox, PIL.Image.Image, str, int):
10631063
assert fn.call_count == 0
10641064
else:
10651065
fn.assert_called_once_with(inpt, mode=transform.mode)

test/test_prototype_transforms_functional.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1867,39 +1867,30 @@ def test_midlevel_normalize_output_type():
18671867
@pytest.mark.parametrize(
18681868
"inpt",
18691869
[
1870-
torch.randint(0, 256, size=(3, 32, 32)),
18711870
127 * np.ones((32, 32, 3), dtype="uint8"),
18721871
PIL.Image.new("RGB", (32, 32), 122),
18731872
],
18741873
)
1875-
@pytest.mark.parametrize("copy", [True, False])
1876-
def test_to_image_tensor(inpt, copy):
1877-
output = F.to_image_tensor(inpt, copy=copy)
1874+
def test_to_image_tensor(inpt):
1875+
output = F.to_image_tensor(inpt)
18781876
assert isinstance(output, torch.Tensor)
18791877

18801878
assert np.asarray(inpt).sum() == output.sum().item()
18811879

1882-
if isinstance(inpt, PIL.Image.Image) and not copy:
1880+
if isinstance(inpt, PIL.Image.Image):
18831881
# we can't check this option
18841882
# as PIL -> numpy is always copying
18851883
return
18861884

1887-
if isinstance(inpt, PIL.Image.Image):
1888-
inpt.putpixel((0, 0), 11)
1889-
else:
1890-
inpt[0, 0, 0] = 11
1891-
if copy:
1892-
assert output[0, 0, 0] != 11
1893-
else:
1894-
assert output[0, 0, 0] == 11
1885+
inpt[0, 0, 0] = 11
1886+
assert output[0, 0, 0] == 11
18951887

18961888

18971889
@pytest.mark.parametrize(
18981890
"inpt",
18991891
[
19001892
torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8),
19011893
127 * np.ones((32, 32, 3), dtype="uint8"),
1902-
PIL.Image.new("RGB", (32, 32), 122),
19031894
],
19041895
)
19051896
@pytest.mark.parametrize("mode", [None, "RGB"])

torchvision/prototype/transforms/_deprecated.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55
import PIL.Image
6+
import torch
67
import torchvision.prototype.transforms.functional as F
78
from torchvision.prototype import features
89
from torchvision.prototype.features import ColorSpace
@@ -15,9 +16,7 @@
1516

1617

1718
class ToTensor(Transform):
18-
19-
# Updated transformed types for ToTensor
20-
_transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
19+
_transformed_types = (PIL.Image.Image, np.ndarray)
2120

2221
def __init__(self) -> None:
2322
warnings.warn(
@@ -26,32 +25,26 @@ def __init__(self) -> None:
2625
)
2726
super().__init__()
2827

29-
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
30-
if isinstance(inpt, (PIL.Image.Image, np.ndarray)):
31-
return _F.to_tensor(inpt)
32-
else:
33-
return inpt
28+
def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
29+
return _F.to_tensor(inpt)
3430

3531

3632
class PILToTensor(Transform):
33+
_transformed_types = (PIL.Image.Image,)
34+
3735
def __init__(self) -> None:
3836
warnings.warn(
3937
"The transform `PILToTensor()` is deprecated and will be removed in a future release. "
4038
"Instead, please use `transforms.ToImageTensor()`."
4139
)
4240
super().__init__()
4341

44-
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
45-
if isinstance(inpt, PIL.Image.Image):
46-
return _F.pil_to_tensor(inpt)
47-
else:
48-
return inpt
42+
def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
43+
return _F.pil_to_tensor(inpt)
4944

5045

5146
class ToPILImage(Transform):
52-
53-
# Updated transformed types for ToPILImage
54-
_transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
47+
_transformed_types = (is_simple_tensor, features.Image, np.ndarray)
5548

5649
def __init__(self, mode: Optional[str] = None) -> None:
5750
warnings.warn(
@@ -61,11 +54,8 @@ def __init__(self, mode: Optional[str] = None) -> None:
6154
super().__init__()
6255
self.mode = mode
6356

64-
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
65-
if is_simple_tensor(inpt) or isinstance(inpt, (features.Image, np.ndarray)):
66-
return _F.to_pil_image(inpt, mode=self.mode)
67-
else:
68-
return inpt
57+
def _transform(self, inpt: Any, params: Dict[str, Any]) -> PIL.Image:
58+
return _F.to_pil_image(inpt, mode=self.mode)
6959

7060

7161
class Grayscale(Transform):

torchvision/prototype/transforms/_meta.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,37 +11,32 @@
1111

1212

1313
class ConvertBoundingBoxFormat(Transform):
14+
_transformed_types = (features.BoundingBox,)
15+
1416
def __init__(self, format: Union[str, features.BoundingBoxFormat]) -> None:
1517
super().__init__()
1618
if isinstance(format, str):
1719
format = features.BoundingBoxFormat[format]
1820
self.format = format
1921

2022
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
21-
if isinstance(inpt, features.BoundingBox):
22-
output = F.convert_bounding_box_format(inpt, old_format=inpt.format, new_format=params["format"])
23-
return features.BoundingBox.new_like(inpt, output, format=params["format"])
24-
else:
25-
return inpt
23+
output = F.convert_bounding_box_format(inpt, old_format=inpt.format, new_format=params["format"])
24+
return features.BoundingBox.new_like(inpt, output, format=params["format"])
2625

2726

2827
class ConvertImageDtype(Transform):
28+
_transformed_types = (is_simple_tensor, features.Image)
29+
2930
def __init__(self, dtype: torch.dtype = torch.float32) -> None:
3031
super().__init__()
3132
self.dtype = dtype
3233

3334
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
34-
if isinstance(inpt, features.Image):
35-
output = convert_image_dtype(inpt, dtype=self.dtype)
36-
return features.Image.new_like(inpt, output, dtype=self.dtype)
37-
elif is_simple_tensor(inpt):
38-
return convert_image_dtype(inpt, dtype=self.dtype)
39-
else:
40-
return inpt
35+
output = convert_image_dtype(inpt, dtype=self.dtype)
36+
return output if is_simple_tensor(inpt) else features.Image.new_like(inpt, output, dtype=self.dtype)
4137

4238

4339
class ConvertColorSpace(Transform):
44-
# F.convert_color_space does NOT handle `_Feature`'s in general
4540
_transformed_types = (is_simple_tensor, features.Image, PIL.Image.Image)
4641

4742
def __init__(

torchvision/prototype/transforms/_type_conversion.py

Lines changed: 12 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,11 @@
1111

1212

1313
class DecodeImage(Transform):
14-
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
15-
if isinstance(inpt, features.EncodedImage):
16-
output = F.decode_image_with_pil(inpt)
17-
return features.Image(output)
18-
else:
19-
return inpt
14+
_transformed_types = (features.EncodedImage,)
15+
16+
def _transform(self, inpt: Any, params: Dict[str, Any]) -> features.Image:
17+
output = F.decode_image_with_pil(inpt)
18+
return features.Image(output)
2019

2120

2221
class LabelToOneHot(Transform):
@@ -41,33 +40,19 @@ def extra_repr(self) -> str:
4140

4241

4342
class ToImageTensor(Transform):
43+
_transformed_types = (is_simple_tensor, PIL.Image.Image, np.ndarray)
4444

45-
# Updated transformed types for ToImageTensor
46-
_transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
47-
48-
def __init__(self, *, copy: bool = False) -> None:
49-
super().__init__()
50-
self.copy = copy
51-
52-
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
53-
if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt):
54-
output = F.to_image_tensor(inpt, copy=self.copy)
55-
return features.Image(output)
56-
else:
57-
return inpt
45+
def _transform(self, inpt: Any, params: Dict[str, Any]) -> features.Image:
46+
output = F.to_image_tensor(inpt)
47+
return features.Image(output)
5848

5949

6050
class ToImagePIL(Transform):
61-
62-
# Updated transformed types for ToImagePIL
63-
_transformed_types = (is_simple_tensor, features._Feature, PIL.Image.Image, np.ndarray)
51+
_transformed_types = (is_simple_tensor, features.Image, np.ndarray)
6452

6553
def __init__(self, *, mode: Optional[str] = None) -> None:
6654
super().__init__()
6755
self.mode = mode
6856

69-
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
70-
if isinstance(inpt, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(inpt):
71-
return F.to_image_pil(inpt, mode=self.mode)
72-
else:
73-
return inpt
57+
def _transform(self, inpt: Any, params: Dict[str, Any]) -> PIL.Image.Image:
58+
return F.to_image_pil(inpt, mode=self.mode)
Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import unittest.mock
2-
from typing import Any, Dict, Optional, Tuple, Union
2+
from typing import Any, Dict, Tuple, Union
33

44
import numpy as np
55
import PIL.Image
@@ -21,26 +21,11 @@ def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, tor
2121
return read_video(ReadOnlyTensorBuffer(encoded_video)) # type: ignore[arg-type]
2222

2323

24-
def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], copy: bool = False) -> torch.Tensor:
24+
def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> torch.Tensor:
2525
if isinstance(image, np.ndarray):
26-
image = torch.from_numpy(image)
27-
28-
if isinstance(image, torch.Tensor):
29-
if copy:
30-
return image.clone()
31-
else:
32-
return image
26+
return torch.from_numpy(image)
3327

3428
return _F.pil_to_tensor(image)
3529

3630

37-
def to_image_pil(
38-
image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], mode: Optional[str] = None
39-
) -> PIL.Image.Image:
40-
if isinstance(image, PIL.Image.Image):
41-
if mode != image.mode:
42-
return image.convert(mode)
43-
else:
44-
return image
45-
46-
return _F.to_pil_image(image, mode=mode)
31+
to_image_pil = _F.to_pil_image

0 commit comments

Comments (0)