Skip to content

Commit 7b8d79b

Browse files
committed
Updated type hints and removed bypass option from mid-level methods
1 parent c68afd6 commit 7b8d79b

File tree

4 files changed

+97
-151
lines changed

4 files changed

+97
-151
lines changed

torchvision/prototype/features/_segmentation_mask.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def vertical_flip(self) -> SegmentationMask:
2424
def resize( # type: ignore[override]
2525
self,
2626
size: List[int],
27-
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
27+
interpolation: InterpolationMode = InterpolationMode.NEAREST,
2828
max_size: Optional[int] = None,
2929
antialias: bool = False,
3030
) -> SegmentationMask:
@@ -52,7 +52,7 @@ def resized_crop(
5252
height: int,
5353
width: int,
5454
size: List[int],
55-
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
55+
interpolation: InterpolationMode = InterpolationMode.NEAREST,
5656
antialias: bool = False,
5757
) -> SegmentationMask:
5858
from torchvision.prototype.transforms import functional as _F
@@ -106,7 +106,7 @@ def affine(
106106
def perspective(
107107
self,
108108
perspective_coeffs: List[float],
109-
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
109+
interpolation: InterpolationMode = InterpolationMode.NEAREST,
110110
fill: Optional[List[float]] = None,
111111
) -> SegmentationMask:
112112
from torchvision.prototype.transforms import functional as _F

torchvision/prototype/transforms/_augment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
8989
if isinstance(inpt, features._Feature):
9090
return inpt.erase(**params)
9191
elif isinstance(inpt, PIL.Image.Image):
92-
# Shouldn't we implement a fallback to tensor ?
92+
# TODO: We should implement a fallback to tensor, like gaussian_blur etc
9393
raise RuntimeError("Not implemented")
9494
elif isinstance(inpt, torch.Tensor):
9595
return F.erase_image_tensor(inpt, **params)
Lines changed: 37 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,171 +1,141 @@
1-
from typing import Any
1+
from typing import Union
22

33
import PIL.Image
44
import torch
55
from torchvision.prototype import features
66
from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP
77

88

9+
# shortcut type
10+
DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
11+
912
adjust_brightness_image_tensor = _FT.adjust_brightness
1013
adjust_brightness_image_pil = _FP.adjust_brightness
1114

1215

13-
def adjust_brightness(inpt: Any, brightness_factor: float) -> Any:
16+
def adjust_brightness(inpt: DType, brightness_factor: float) -> DType:
1417
if isinstance(inpt, features._Feature):
1518
return inpt.adjust_brightness(brightness_factor=brightness_factor)
16-
elif isinstance(inpt, PIL.Image.Image):
19+
if isinstance(inpt, PIL.Image.Image):
1720
return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
18-
elif isinstance(inpt, torch.Tensor):
19-
return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
20-
else:
21-
return inpt
21+
return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
2222

2323

2424
adjust_saturation_image_tensor = _FT.adjust_saturation
2525
adjust_saturation_image_pil = _FP.adjust_saturation
2626

2727

28-
def adjust_saturation(inpt: Any, saturation_factor: float) -> Any:
28+
def adjust_saturation(inpt: DType, saturation_factor: float) -> DType:
2929
if isinstance(inpt, features._Feature):
3030
return inpt.adjust_saturation(saturation_factor=saturation_factor)
31-
elif isinstance(inpt, PIL.Image.Image):
31+
if isinstance(inpt, PIL.Image.Image):
3232
return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)
33-
elif isinstance(inpt, torch.Tensor):
34-
return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
35-
else:
36-
return inpt
33+
return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
3734

3835

3936
adjust_contrast_image_tensor = _FT.adjust_contrast
4037
adjust_contrast_image_pil = _FP.adjust_contrast
4138

4239

43-
def adjust_contrast(inpt: Any, contrast_factor: float) -> Any:
40+
def adjust_contrast(inpt: DType, contrast_factor: float) -> DType:
4441
if isinstance(inpt, features._Feature):
4542
return inpt.adjust_contrast(contrast_factor=contrast_factor)
46-
elif isinstance(inpt, PIL.Image.Image):
43+
if isinstance(inpt, PIL.Image.Image):
4744
return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor)
48-
elif isinstance(inpt, torch.Tensor):
49-
return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
50-
else:
51-
return inpt
45+
return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
5246

5347

5448
adjust_sharpness_image_tensor = _FT.adjust_sharpness
5549
adjust_sharpness_image_pil = _FP.adjust_sharpness
5650

5751

58-
def adjust_sharpness(inpt: Any, sharpness_factor: float) -> Any:
52+
def adjust_sharpness(inpt: DType, sharpness_factor: float) -> DType:
5953
if isinstance(inpt, features._Feature):
6054
return inpt.adjust_sharpness(sharpness_factor=sharpness_factor)
61-
elif isinstance(inpt, PIL.Image.Image):
55+
if isinstance(inpt, PIL.Image.Image):
6256
return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)
63-
elif isinstance(inpt, torch.Tensor):
64-
return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
65-
else:
66-
return inpt
57+
return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
6758

6859

6960
adjust_hue_image_tensor = _FT.adjust_hue
7061
adjust_hue_image_pil = _FP.adjust_hue
7162

7263

73-
def adjust_hue(inpt: Any, hue_factor: float) -> Any:
64+
def adjust_hue(inpt: DType, hue_factor: float) -> DType:
7465
if isinstance(inpt, features._Feature):
7566
return inpt.adjust_hue(hue_factor=hue_factor)
76-
elif isinstance(inpt, PIL.Image.Image):
67+
if isinstance(inpt, PIL.Image.Image):
7768
return adjust_hue_image_pil(inpt, hue_factor=hue_factor)
78-
elif isinstance(inpt, torch.Tensor):
79-
return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
80-
else:
81-
return inpt
69+
return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
8270

8371

8472
adjust_gamma_image_tensor = _FT.adjust_gamma
8573
adjust_gamma_image_pil = _FP.adjust_gamma
8674

8775

88-
def adjust_gamma(inpt: Any, gamma: float, gain: float = 1) -> Any:
76+
def adjust_gamma(inpt: DType, gamma: float, gain: float = 1) -> DType:
8977
if isinstance(inpt, features._Feature):
9078
return inpt.adjust_gamma(gamma=gamma, gain=gain)
91-
elif isinstance(inpt, PIL.Image.Image):
79+
if isinstance(inpt, PIL.Image.Image):
9280
return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain)
93-
elif isinstance(inpt, torch.Tensor):
94-
return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
95-
else:
96-
return inpt
81+
return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
9782

9883

9984
posterize_image_tensor = _FT.posterize
10085
posterize_image_pil = _FP.posterize
10186

10287

103-
def posterize(inpt: Any, bits: int) -> Any:
88+
def posterize(inpt: DType, bits: int) -> DType:
10489
if isinstance(inpt, features._Feature):
10590
return inpt.posterize(bits=bits)
106-
elif isinstance(inpt, PIL.Image.Image):
91+
if isinstance(inpt, PIL.Image.Image):
10792
return posterize_image_pil(inpt, bits=bits)
108-
elif isinstance(inpt, torch.Tensor):
109-
return posterize_image_tensor(inpt, bits=bits)
110-
else:
111-
return inpt
93+
return posterize_image_tensor(inpt, bits=bits)
11294

11395

11496
solarize_image_tensor = _FT.solarize
11597
solarize_image_pil = _FP.solarize
11698

11799

118-
def solarize(inpt: Any, threshold: float) -> Any:
100+
def solarize(inpt: DType, threshold: float) -> DType:
119101
if isinstance(inpt, features._Feature):
120102
return inpt.solarize(threshold=threshold)
121-
elif isinstance(inpt, PIL.Image.Image):
103+
if isinstance(inpt, PIL.Image.Image):
122104
return solarize_image_pil(inpt, threshold=threshold)
123-
elif isinstance(inpt, torch.Tensor):
124-
return solarize_image_tensor(inpt, threshold=threshold)
125-
else:
126-
return inpt
105+
return solarize_image_tensor(inpt, threshold=threshold)
127106

128107

129108
autocontrast_image_tensor = _FT.autocontrast
130109
autocontrast_image_pil = _FP.autocontrast
131110

132111

133-
def autocontrast(inpt: Any) -> Any:
112+
def autocontrast(inpt: DType) -> DType:
134113
if isinstance(inpt, features._Feature):
135114
return inpt.autocontrast()
136-
elif isinstance(inpt, PIL.Image.Image):
115+
if isinstance(inpt, PIL.Image.Image):
137116
return autocontrast_image_pil(inpt)
138-
elif isinstance(inpt, torch.Tensor):
139-
return autocontrast_image_tensor(inpt)
140-
else:
141-
return inpt
117+
return autocontrast_image_tensor(inpt)
142118

143119

144120
equalize_image_tensor = _FT.equalize
145121
equalize_image_pil = _FP.equalize
146122

147123

148-
def equalize(inpt: Any) -> Any:
124+
def equalize(inpt: DType) -> DType:
149125
if isinstance(inpt, features._Feature):
150126
return inpt.equalize()
151-
elif isinstance(inpt, PIL.Image.Image):
127+
if isinstance(inpt, PIL.Image.Image):
152128
return equalize_image_pil(inpt)
153-
elif isinstance(inpt, torch.Tensor):
154-
return equalize_image_tensor(inpt)
155-
else:
156-
return inpt
129+
return equalize_image_tensor(inpt)
157130

158131

159132
invert_image_tensor = _FT.invert
160133
invert_image_pil = _FP.invert
161134

162135

163-
def invert(inpt: Any) -> Any:
136+
def invert(inpt: DType) -> DType:
164137
if isinstance(inpt, features._Feature):
165138
return inpt.invert()
166-
elif isinstance(inpt, PIL.Image.Image):
139+
if isinstance(inpt, PIL.Image.Image):
167140
return invert_image_pil(inpt)
168-
elif isinstance(inpt, torch.Tensor):
169-
return invert_image_tensor(inpt)
170-
else:
171-
return inpt
141+
return invert_image_tensor(inpt)

0 commit comments

Comments
 (0)