Commit c0ba3ec

Proto transform cleanup (#6408)
* fix TenCrop
* use dispatchers for RandomPhotometricDistort
* add convert_color_space dispatcher and use it in conversion transforms
* fix convert_color_space naming scheme
* add to_color_space method to Image feature
* remove TODO from BoundingBox.to_format()
* fix test
* fix imports
* fix passthrough
* remove apply_recursively in favor of pytree
* refactor BatchMultiCrop
1 parent 94960fe commit c0ba3ec

File tree

11 files changed, +116 -121 lines changed

test/test_prototype_transforms.py

Lines changed: 2 additions & 2 deletions
@@ -200,7 +200,7 @@ def test_random_resized_crop(self, transform, input):
     @parametrize(
         [
             (
-                transforms.ConvertImageColorSpace(color_space=new_color_space, old_color_space=old_color_space),
+                transforms.ConvertColorSpace(color_space=new_color_space, old_color_space=old_color_space),
                 itertools.chain.from_iterable(
                     [
                         fn(color_spaces=[old_color_space])
@@ -223,7 +223,7 @@ def test_random_resized_crop(self, transform, input):
             )
         ]
     )
-    def test_convert_image_color_space(self, transform, input):
+    def test_convert_color_space(self, transform, input):
         transform(input)

torchvision/prototype/features/_bounding_box.py

Lines changed: 2 additions & 6 deletions
@@ -60,17 +60,13 @@ def new_like(
         )

     def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
-        # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
-        # promote this out of the prototype state
-
-        # import at runtime to avoid cyclic imports
-        from torchvision.prototype.transforms.functional import convert_bounding_box_format
+        from torchvision.prototype.transforms import functional as _F

         if isinstance(format, str):
             format = BoundingBoxFormat.from_str(format.upper())

         return BoundingBox.new_like(
-            self, convert_bounding_box_format(self, old_format=self.format, new_format=format), format=format
+            self, _F.convert_bounding_box_format(self, old_format=self.format, new_format=format), format=format
         )

     def horizontal_flip(self) -> BoundingBox:

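For reference, the kind of kernel `convert_bounding_box_format` dispatches to for a single format pair can be sketched in plain PyTorch; the helper below is hypothetical and only illustrates the XYXY -> CXCYWH case, not the actual torchvision implementation:

import torch

# Hypothetical sketch of an XYXY -> CXCYWH conversion kernel:
# box centers plus width/height computed from the corner points.
def xyxy_to_cxcywh(boxes: torch.Tensor) -> torch.Tensor:
    x1, y1, x2, y2 = boxes.unbind(-1)
    return torch.stack([(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1], dim=-1)

print(xyxy_to_cxcywh(torch.tensor([[0.0, 0.0, 4.0, 2.0]])))  # tensor([[2., 1., 4., 2.]])
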
torchvision/prototype/features/_image.py

Lines changed: 14 additions & 0 deletions
@@ -99,6 +99,20 @@ def guess_color_space(data: torch.Tensor) -> ColorSpace:
         else:
             return ColorSpace.OTHER

+    def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) -> Image:
+        from torchvision.prototype.transforms import functional as _F
+
+        if isinstance(color_space, str):
+            color_space = ColorSpace.from_str(color_space.upper())
+
+        return Image.new_like(
+            self,
+            _F.convert_color_space_image_tensor(
+                self, old_color_space=self.color_space, new_color_space=color_space, copy=copy
+            ),
+            color_space=color_space,
+        )
+
     def show(self) -> None:
         # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
         # promote this out of the prototype state

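A minimal usage sketch of the new method, assuming the prototype `Image` feature can be constructed with a `color_space` keyword as used elsewhere in this commit (prototype API, subject to change):

import torch
from torchvision.prototype import features

# Wrap a tensor with explicit color-space metadata, then convert; strings are
# upper-cased and parsed via ColorSpace.from_str, so "gray" works as well.
image = features.Image(torch.rand(3, 16, 16), color_space=features.ColorSpace.RGB)
gray = image.to_color_space("GRAY")
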
torchvision/prototype/transforms/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@
     ScaleJitter,
     TenCrop,
 )
-from ._meta import ConvertBoundingBoxFormat, ConvertImageColorSpace, ConvertImageDtype
+from ._meta import ConvertBoundingBoxFormat, ConvertColorSpace, ConvertImageDtype
 from ._misc import GaussianBlur, Identity, Lambda, Normalize, ToDtype
 from ._type_conversion import DecodeImage, LabelToOneHot, ToImagePIL, ToImageTensor

torchvision/prototype/transforms/_color.py

Lines changed: 41 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -84,30 +84,6 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
8484
return output
8585

8686

87-
class _RandomChannelShuffle(Transform):
88-
def _get_params(self, sample: Any) -> Dict[str, Any]:
89-
image = query_image(sample)
90-
num_channels, _, _ = get_image_dimensions(image)
91-
return dict(permutation=torch.randperm(num_channels))
92-
93-
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
94-
if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
95-
return inpt
96-
97-
image = inpt
98-
if isinstance(inpt, PIL.Image.Image):
99-
image = _F.pil_to_tensor(image)
100-
101-
output = image[..., params["permutation"], :, :]
102-
103-
if isinstance(inpt, features.Image):
104-
output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.OTHER)
105-
elif isinstance(inpt, PIL.Image.Image):
106-
output = _F.to_pil_image(output)
107-
108-
return output
109-
110-
11187
class RandomPhotometricDistort(Transform):
11288
def __init__(
11389
self,
@@ -118,35 +94,62 @@ def __init__(
11894
p: float = 0.5,
11995
):
12096
super().__init__()
121-
self._brightness = ColorJitter(brightness=brightness)
122-
self._contrast = ColorJitter(contrast=contrast)
123-
self._hue = ColorJitter(hue=hue)
124-
self._saturation = ColorJitter(saturation=saturation)
125-
self._channel_shuffle = _RandomChannelShuffle()
97+
self.brightness = brightness
98+
self.contrast = contrast
99+
self.hue = hue
100+
self.saturation = saturation
126101
self.p = p
127102

128103
def _get_params(self, sample: Any) -> Dict[str, Any]:
104+
image = query_image(sample)
105+
num_channels, _, _ = get_image_dimensions(image)
129106
return dict(
130107
zip(
131-
["brightness", "contrast1", "saturation", "hue", "contrast2", "channel_shuffle"],
108+
["brightness", "contrast1", "saturation", "hue", "contrast2"],
132109
torch.rand(6) < self.p,
133110
),
134111
contrast_before=torch.rand(()) < 0.5,
112+
channel_permutation=torch.randperm(num_channels) if torch.rand(()) < self.p else None,
135113
)
136114

115+
def _permute_channels(self, inpt: Any, *, permutation: torch.Tensor) -> Any:
116+
if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
117+
return inpt
118+
119+
image = inpt
120+
if isinstance(inpt, PIL.Image.Image):
121+
image = _F.pil_to_tensor(image)
122+
123+
output = image[..., permutation, :, :]
124+
125+
if isinstance(inpt, features.Image):
126+
output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.OTHER)
127+
elif isinstance(inpt, PIL.Image.Image):
128+
output = _F.to_pil_image(output)
129+
130+
return output
131+
137132
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
138133
if params["brightness"]:
139-
inpt = self._brightness(inpt)
134+
inpt = F.adjust_brightness(
135+
inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1])
136+
)
140137
if params["contrast1"] and params["contrast_before"]:
141-
inpt = self._contrast(inpt)
142-
if params["saturation"]:
143-
inpt = self._saturation(inpt)
138+
inpt = F.adjust_contrast(
139+
inpt, contrast_factor=ColorJitter._generate_value(self.contrast[0], self.contrast[1])
140+
)
144141
if params["saturation"]:
145-
inpt = self._saturation(inpt)
142+
inpt = F.adjust_saturation(
143+
inpt, saturation_factor=ColorJitter._generate_value(self.saturation[0], self.saturation[1])
144+
)
145+
if params["hue"]:
146+
inpt = F.adjust_hue(inpt, hue_factor=ColorJitter._generate_value(self.hue[0], self.hue[1]))
146147
if params["contrast2"] and not params["contrast_before"]:
147-
inpt = self._contrast(inpt)
148-
if params["channel_shuffle"]:
149-
inpt = self._channel_shuffle(inpt)
148+
inpt = F.adjust_contrast(
149+
inpt, contrast_factor=ColorJitter._generate_value(self.contrast[0], self.contrast[1])
150+
)
151+
if params["channel_permutation"]:
152+
inpt = self._permute_channels(inpt, permutation=params["channel_permutation"])
150153
return inpt
151154

152155

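Two details of the refactored sampling are easy to reproduce standalone (values illustrative): `zip` silently drops the sixth, now-unused draw from `torch.rand(6)`, and the channel permutation uses the same `[..., permutation, :, :]` indexing for both CHW and batched NCHW tensors:

import torch

# One Bernoulli draw per distortion; zip() truncates to the five key names.
p = 0.5
keys = ["brightness", "contrast1", "saturation", "hue", "contrast2"]
params = dict(zip(keys, torch.rand(6) < p))

# Channel shuffle via advanced indexing on the channel dimension.
image = torch.rand(3, 4, 4)
permutation = torch.randperm(image.shape[-3])
shuffled = image[..., permutation, :, :]
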
torchvision/prototype/transforms/_deprecated.py

Lines changed: 5 additions & 8 deletions
@@ -4,13 +4,13 @@
 import numpy as np
 import PIL.Image
 import torch
+import torchvision.prototype.transforms.functional as F
 from torchvision.prototype import features
 from torchvision.prototype.features import ColorSpace
 from torchvision.prototype.transforms import Transform
 from torchvision.transforms import functional as _F
 from typing_extensions import Literal

-from ._meta import ConvertImageColorSpace
 from ._transform import _RandomApplyTransform
 from ._utils import is_simple_tensor

@@ -90,13 +90,11 @@ def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:

         super().__init__()
         self.num_output_channels = num_output_channels
-        self._rgb_to_gray = ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY)
-        self._gray_to_rgb = ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        output = self._rgb_to_gray(inpt)
+        output = F.convert_color_space(inpt, color_space=ColorSpace.GRAY, old_color_space=ColorSpace.RGB)
         if self.num_output_channels == 3:
-            output = self._gray_to_rgb(output)
+            output = F.convert_color_space(inpt, color_space=ColorSpace.RGB, old_color_space=ColorSpace.GRAY)
         return output


@@ -115,8 +113,7 @@ def __init__(self, p: float = 0.1) -> None:
         )

         super().__init__(p=p)
-        self._rgb_to_gray = ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY)
-        self._gray_to_rgb = ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._gray_to_rgb(self._rgb_to_gray(inpt))
+        output = F.convert_color_space(inpt, color_space=ColorSpace.GRAY, old_color_space=ColorSpace.RGB)
+        return F.convert_color_space(output, color_space=ColorSpace.RGB, old_color_space=ColorSpace.GRAY)

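The RGB -> GRAY -> RGB round trip that RandomGrayscale now spells out can be illustrated without the prototype API; the helpers below are hypothetical stand-ins for the tensor kernels (ITU-R 601 luma weights, matching `_FT.rgb_to_grayscale` up to implementation details):

import torch

# Hypothetical stand-ins for the RGB <-> GRAY conversion kernels.
def rgb_to_gray(image: torch.Tensor) -> torch.Tensor:
    r, g, b = image.unbind(-3)
    return (0.2989 * r + 0.587 * g + 0.114 * b).unsqueeze(-3)

def gray_to_rgb(image: torch.Tensor) -> torch.Tensor:
    return image.repeat_interleave(3, dim=-3)

# Net effect: still three channels, but only luma information remains.
desaturated = gray_to_rgb(rgb_to_gray(torch.rand(3, 8, 8)))
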
torchvision/prototype/transforms/_geometry.py

Lines changed: 15 additions & 28 deletions
@@ -1,4 +1,3 @@
-import collections.abc
 import math
 import numbers
 import warnings

@@ -180,9 +179,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
             output = F.ten_crop_image_tensor(inpt, self.size, vertical_flip=self.vertical_flip)
             return MultiCropResult(features.Image.new_like(inpt, o) for o in output)
         elif is_simple_tensor(inpt):
-            return MultiCropResult(F.ten_crop_image_tensor(inpt, self.size))
+            return MultiCropResult(F.ten_crop_image_tensor(inpt, self.size, vertical_flip=self.vertical_flip))
         elif isinstance(inpt, PIL.Image.Image):
-            return MultiCropResult(F.ten_crop_image_pil(inpt, self.size))
+            return MultiCropResult(F.ten_crop_image_pil(inpt, self.size, vertical_flip=self.vertical_flip))
         else:
             return inpt

@@ -194,31 +193,19 @@ def forward(self, *inputs: Any) -> Any:


 class BatchMultiCrop(Transform):
-    def forward(self, *inputs: Any) -> Any:
-        # This is basically the functionality of `torchvision.prototype.utils._internal.apply_recursively` with one
-        # significant difference:
-        # Since we need multiple images to batch them together, we need to explicitly exclude `MultiCropResult` from
-        # the sequence case.
-        def apply_recursively(obj: Any) -> Any:
-            if isinstance(obj, MultiCropResult):
-                crops = obj
-                if isinstance(obj[0], PIL.Image.Image):
-                    crops = [pil_to_tensor(crop) for crop in crops]  # type: ignore[assignment]
-
-                batch = torch.stack(crops)
-
-                if isinstance(obj[0], features.Image):
-                    batch = features.Image.new_like(obj[0], batch)
-
-                return batch
-            elif isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
-                return [apply_recursively(item) for item in obj]
-            elif isinstance(obj, collections.abc.Mapping):
-                return {key: apply_recursively(item) for key, item in obj.items()}
-            else:
-                return obj
-
-        return apply_recursively(inputs if len(inputs) > 1 else inputs[0])
+    _transformed_types = (MultiCropResult,)
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        crops = inpt
+        if isinstance(inpt[0], PIL.Image.Image):
+            crops = [pil_to_tensor(crop) for crop in crops]
+
+        batch = torch.stack(crops)
+
+        if isinstance(inpt[0], features.Image):
+            batch = features.Image.new_like(inpt[0], batch)
+
+        return batch


 def _check_fill_arg(fill: Union[int, float, Sequence[int], Sequence[float]]) -> None:

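The commit message's "remove apply_recursively in favor of pytree" refers to traversing nested samples with `torch.utils._pytree` instead of hand-written recursion. A minimal sketch of that pattern, independent of the transforms themselves (`_pytree` is a private PyTorch utility):

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# Flatten an arbitrarily nested sample, transform only the leaves of
# interest, and rebuild the original structure.
sample = {"crops": [torch.rand(3, 8, 8) for _ in range(10)], "label": 1}
leaves, spec = tree_flatten(sample)
leaves = [x * 255 if isinstance(x, torch.Tensor) else x for x in leaves]
sample = tree_unflatten(leaves, spec)
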
torchvision/prototype/transforms/_meta.py

Lines changed: 11 additions & 20 deletions
@@ -1,6 +1,7 @@
 from typing import Any, Dict, Optional, Union

 import PIL.Image
+
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F, Transform

@@ -39,11 +40,15 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return inpt


-class ConvertImageColorSpace(Transform):
+class ConvertColorSpace(Transform):
+    # F.convert_color_space does NOT handle `_Feature`'s in general
+    _transformed_types = (torch.Tensor, features.Image, PIL.Image.Image)
+
     def __init__(
         self,
         color_space: Union[str, features.ColorSpace],
         old_color_space: Optional[Union[str, features.ColorSpace]] = None,
+        copy: bool = True,
     ) -> None:
         super().__init__()

@@ -55,23 +60,9 @@ def __init__(
             old_color_space = features.ColorSpace.from_str(old_color_space)
         self.old_color_space = old_color_space

+        self.copy = copy
+
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, features.Image):
-            output = F.convert_image_color_space_tensor(
-                inpt, old_color_space=inpt.color_space, new_color_space=self.color_space
-            )
-            return features.Image.new_like(inpt, output, color_space=self.color_space)
-        elif is_simple_tensor(inpt):
-            if self.old_color_space is None:
-                raise RuntimeError(
-                    f"In order to convert simple tensor images, `{type(self).__name__}(...)` "
-                    f"needs to be constructed with the `old_color_space=...` parameter."
-                )
-
-            return F.convert_image_color_space_tensor(
-                inpt, old_color_space=self.old_color_space, new_color_space=self.color_space
-            )
-        elif isinstance(inpt, PIL.Image.Image):
-            return F.convert_image_color_space_pil(inpt, color_space=self.color_space)
-        else:
-            return inpt
+        return F.convert_color_space(
+            inpt, color_space=self.color_space, old_color_space=self.old_color_space, copy=self.copy
+        )

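Hypothetical usage of the renamed transform, mirroring the constructor above (prototype API, subject to change); `old_color_space` only matters for plain tensors, which carry no color-space metadata of their own:

import torch
from torchvision.prototype import features, transforms

convert = transforms.ConvertColorSpace(
    color_space=features.ColorSpace.GRAY, old_color_space=features.ColorSpace.RGB
)
gray = convert(torch.rand(3, 16, 16))
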
torchvision/prototype/transforms/functional/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -1,8 +1,9 @@
 from torchvision.transforms import InterpolationMode  # usort: skip
 from ._meta import (
     convert_bounding_box_format,
-    convert_image_color_space_tensor,
-    convert_image_color_space_pil,
+    convert_color_space_image_tensor,
+    convert_color_space_image_pil,
+    convert_color_space,
 )  # usort: skip

 from ._augment import erase_image_pil, erase_image_tensor

torchvision/prototype/transforms/functional/_meta.py

Lines changed: 22 additions & 4 deletions
@@ -1,8 +1,8 @@
-from typing import Optional, Tuple
+from typing import Any, Optional, Tuple

 import PIL.Image
 import torch
-from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
+from torchvision.prototype.features import BoundingBoxFormat, ColorSpace, Image
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT

 get_dimensions_image_tensor = _FT.get_dimensions

@@ -91,7 +91,7 @@ def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
 _rgb_to_gray = _FT.rgb_to_grayscale


-def convert_image_color_space_tensor(
+def convert_color_space_image_tensor(
     image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace, copy: bool = True
 ) -> torch.Tensor:
     if new_color_space == old_color_space:

@@ -141,7 +141,7 @@ def convert_image_color_space_tensor(
     }


-def convert_image_color_space_pil(
+def convert_color_space_image_pil(
     image: PIL.Image.Image, color_space: ColorSpace, copy: bool = True
 ) -> PIL.Image.Image:
     old_mode = image.mode

@@ -154,3 +154,21 @@ def convert_image_color_space_pil(
         return image

     return image.convert(new_mode)
+
+
+def convert_color_space(
+    inpt: Any, *, color_space: ColorSpace, old_color_space: Optional[ColorSpace] = None, copy: bool = True
+) -> Any:
+    if isinstance(inpt, Image):
+        return inpt.to_color_space(color_space, copy=copy)
+    elif isinstance(inpt, PIL.Image.Image):
+        return convert_color_space_image_pil(inpt, color_space, copy=copy)
+    else:
+        if old_color_space is None:
+            raise RuntimeError(
+                "In order to convert the color space of simple tensor images, "
+                "the `old_color_space=...` parameter needs to be passed."
+            )
+        return convert_color_space_image_tensor(
+            inpt, old_color_space=old_color_space, new_color_space=color_space, copy=copy
+        )

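A hedged usage sketch of the new dispatcher (prototype API, subject to change). As the RuntimeError above enforces, plain tensors need `old_color_space` passed explicitly, while Image features and PIL images carry their own metadata:

import torch
from torchvision.prototype.features import ColorSpace
from torchvision.prototype.transforms import functional as F

gray = F.convert_color_space(
    torch.rand(3, 16, 16), color_space=ColorSpace.GRAY, old_color_space=ColorSpace.RGB
)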