Skip to content

Commit ad60b50

Browse files
committed
Refactoring typing definitions.
1 parent 4eb82f4 commit ad60b50

File tree

20 files changed

+136
-147
lines changed

20 files changed

+136
-147
lines changed

torchvision/prototype/features/_bounding_box.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
from torchvision._utils import StrEnum
77
from torchvision.transforms import InterpolationMode # TODO: this needs to be moved out of transforms
88

9-
from ._feature import _Feature
10-
from ._utils import FillType
9+
from ._feature import _Feature, FillTypeJIT
1110

1211

1312
class BoundingBoxFormat(StrEnum):
@@ -116,7 +115,7 @@ def resized_crop(
116115
def pad(
117116
self,
118117
padding: Union[int, Sequence[int]],
119-
fill: FillType = None,
118+
fill: FillTypeJIT = None,
120119
padding_mode: str = "constant",
121120
) -> BoundingBox:
122121
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
@@ -138,7 +137,7 @@ def rotate(
138137
angle: float,
139138
interpolation: InterpolationMode = InterpolationMode.NEAREST,
140139
expand: bool = False,
141-
fill: FillType = None,
140+
fill: FillTypeJIT = None,
142141
center: Optional[List[float]] = None,
143142
) -> BoundingBox:
144143
output = self._F.rotate_bounding_box(
@@ -166,7 +165,7 @@ def affine(
166165
scale: float,
167166
shear: List[float],
168167
interpolation: InterpolationMode = InterpolationMode.NEAREST,
169-
fill: FillType = None,
168+
fill: FillTypeJIT = None,
170169
center: Optional[List[float]] = None,
171170
) -> BoundingBox:
172171
output = self._F.affine_bounding_box(
@@ -185,7 +184,7 @@ def perspective(
185184
self,
186185
perspective_coeffs: List[float],
187186
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
188-
fill: FillType = None,
187+
fill: FillTypeJIT = None,
189188
) -> BoundingBox:
190189
output = self._F.perspective_bounding_box(self, self.format, perspective_coeffs)
191190
return BoundingBox.new_like(self, output, dtype=output.dtype)
@@ -194,7 +193,7 @@ def elastic(
194193
self,
195194
displacement: torch.Tensor,
196195
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
197-
fill: FillType = None,
196+
fill: FillTypeJIT = None,
198197
) -> BoundingBox:
199198
output = self._F.elastic_bounding_box(self, self.format, displacement)
200199
return BoundingBox.new_like(self, output, dtype=output.dtype)

torchvision/prototype/features/_feature.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
from types import ModuleType
44
from typing import Any, Callable, cast, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union
55

6+
import PIL.Image
67
import torch
78
from torch._C import _TensorBase, DisableTorchFunction
89
from torchvision.transforms import InterpolationMode
910

10-
from ._utils import FillType
11-
12-
1311
F = TypeVar("F", bound="_Feature")
12+
FillType = Union[int, float, Sequence[int], Sequence[float], None]
13+
FillTypeJIT = Union[int, float, List[float], None]
1414

1515

1616
def is_simple_tensor(inpt: Any) -> bool:
@@ -152,7 +152,7 @@ def resized_crop(
152152
def pad(
153153
self,
154154
padding: Union[int, List[int]],
155-
fill: FillType = None,
155+
fill: FillTypeJIT = None,
156156
padding_mode: str = "constant",
157157
) -> _Feature:
158158
return self
@@ -162,7 +162,7 @@ def rotate(
162162
angle: float,
163163
interpolation: InterpolationMode = InterpolationMode.NEAREST,
164164
expand: bool = False,
165-
fill: FillType = None,
165+
fill: FillTypeJIT = None,
166166
center: Optional[List[float]] = None,
167167
) -> _Feature:
168168
return self
@@ -174,7 +174,7 @@ def affine(
174174
scale: float,
175175
shear: List[float],
176176
interpolation: InterpolationMode = InterpolationMode.NEAREST,
177-
fill: FillType = None,
177+
fill: FillTypeJIT = None,
178178
center: Optional[List[float]] = None,
179179
) -> _Feature:
180180
return self
@@ -183,15 +183,15 @@ def perspective(
183183
self,
184184
perspective_coeffs: List[float],
185185
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
186-
fill: FillType = None,
186+
fill: FillTypeJIT = None,
187187
) -> _Feature:
188188
return self
189189

190190
def elastic(
191191
self,
192192
displacement: torch.Tensor,
193193
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
194-
fill: FillType = None,
194+
fill: FillTypeJIT = None,
195195
) -> _Feature:
196196
return self
197197

@@ -230,3 +230,7 @@ def invert(self) -> _Feature:
230230

231231
def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> _Feature:
232232
return self
233+
234+
235+
InputType = Union[torch.Tensor, PIL.Image.Image, _Feature]
236+
InputTypeJIT = torch.Tensor

torchvision/prototype/features/_image.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
import warnings
44
from typing import Any, cast, List, Optional, Tuple, Union
55

6+
import PIL.Image
67
import torch
78
from torchvision._utils import StrEnum
89
from torchvision.transforms.functional import InterpolationMode, to_pil_image
910
from torchvision.utils import draw_bounding_boxes, make_grid
1011

1112
from ._bounding_box import BoundingBox
12-
from ._feature import _Feature
13-
from ._utils import FillType
13+
from ._feature import _Feature, FillTypeJIT
1414

1515

1616
class ColorSpace(StrEnum):
@@ -177,7 +177,7 @@ def resized_crop(
177177
def pad(
178178
self,
179179
padding: Union[int, List[int]],
180-
fill: FillType = None,
180+
fill: FillTypeJIT = None,
181181
padding_mode: str = "constant",
182182
) -> Image:
183183
output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
@@ -188,7 +188,7 @@ def rotate(
188188
angle: float,
189189
interpolation: InterpolationMode = InterpolationMode.NEAREST,
190190
expand: bool = False,
191-
fill: FillType = None,
191+
fill: FillTypeJIT = None,
192192
center: Optional[List[float]] = None,
193193
) -> Image:
194194
output = self._F._geometry.rotate_image_tensor(
@@ -203,7 +203,7 @@ def affine(
203203
scale: float,
204204
shear: List[float],
205205
interpolation: InterpolationMode = InterpolationMode.NEAREST,
206-
fill: FillType = None,
206+
fill: FillTypeJIT = None,
207207
center: Optional[List[float]] = None,
208208
) -> Image:
209209
output = self._F._geometry.affine_image_tensor(
@@ -222,7 +222,7 @@ def perspective(
222222
self,
223223
perspective_coeffs: List[float],
224224
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
225-
fill: FillType = None,
225+
fill: FillTypeJIT = None,
226226
) -> Image:
227227
output = self._F._geometry.perspective_image_tensor(
228228
self, perspective_coeffs, interpolation=interpolation, fill=fill
@@ -233,7 +233,7 @@ def elastic(
233233
self,
234234
displacement: torch.Tensor,
235235
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
236-
fill: FillType = None,
236+
fill: FillTypeJIT = None,
237237
) -> Image:
238238
output = self._F._geometry.elastic_image_tensor(self, displacement, interpolation=interpolation, fill=fill)
239239
return Image.new_like(self, output)
@@ -285,3 +285,11 @@ def invert(self) -> Image:
285285
def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image:
286286
output = self._F.gaussian_blur_image_tensor(self, kernel_size=kernel_size, sigma=sigma)
287287
return Image.new_like(self, output)
288+
289+
290+
ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
291+
ImageTypeJIT = torch.Tensor
292+
LegacyImageType = Union[torch.Tensor, PIL.Image.Image]
293+
LegacyImageTypeJIT = torch.Tensor
294+
TensorImageType = Union[torch.Tensor, Image]
295+
TensorImageTypeJIT = torch.Tensor

torchvision/prototype/features/_mask.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
import torch
66
from torchvision.transforms import InterpolationMode
77

8-
from ._feature import _Feature
9-
from ._utils import FillType
8+
from ._feature import _Feature, FillTypeJIT
109

1110

1211
class Mask(_Feature):
@@ -52,7 +51,7 @@ def resized_crop(
5251
def pad(
5352
self,
5453
padding: Union[int, List[int]],
55-
fill: FillType = None,
54+
fill: FillTypeJIT = None,
5655
padding_mode: str = "constant",
5756
) -> Mask:
5857
output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill)
@@ -63,7 +62,7 @@ def rotate(
6362
angle: float,
6463
interpolation: InterpolationMode = InterpolationMode.NEAREST,
6564
expand: bool = False,
66-
fill: FillType = None,
65+
fill: FillTypeJIT = None,
6766
center: Optional[List[float]] = None,
6867
) -> Mask:
6968
output = self._F.rotate_mask(self, angle, expand=expand, center=center, fill=fill)
@@ -76,7 +75,7 @@ def affine(
7675
scale: float,
7776
shear: List[float],
7877
interpolation: InterpolationMode = InterpolationMode.NEAREST,
79-
fill: FillType = None,
78+
fill: FillTypeJIT = None,
8079
center: Optional[List[float]] = None,
8180
) -> Mask:
8281
output = self._F.affine_mask(
@@ -94,7 +93,7 @@ def perspective(
9493
self,
9594
perspective_coeffs: List[float],
9695
interpolation: InterpolationMode = InterpolationMode.NEAREST,
97-
fill: FillType = None,
96+
fill: FillTypeJIT = None,
9897
) -> Mask:
9998
output = self._F.perspective_mask(self, perspective_coeffs, fill=fill)
10099
return Mask.new_like(self, output)
@@ -103,7 +102,7 @@ def elastic(
103102
self,
104103
displacement: torch.Tensor,
105104
interpolation: InterpolationMode = InterpolationMode.NEAREST,
106-
fill: FillType = None,
105+
fill: FillTypeJIT = None,
107106
) -> Mask:
108107
output = self._F.elastic_mask(self, displacement, fill=fill)
109108
return Mask.new_like(self, output, dtype=output.dtype)

torchvision/prototype/features/_utils.py

Lines changed: 0 additions & 4 deletions
This file was deleted.

torchvision/prototype/transforms/_augment.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
from torchvision.prototype import features
1111
from torchvision.prototype.transforms import functional as F, InterpolationMode
1212

13+
from ..features._image import ImageType, TensorImageType
14+
1315
from ._transform import _RandomApplyTransform
14-
from ._utils import has_any, ImageType, query_chw, TensorImageType
16+
from ._utils import has_any, query_chw
1517

1618

1719
class RandomErasing(_RandomApplyTransform):

torchvision/prototype/transforms/_auto_augment.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
1010
from torchvision.prototype.transforms.functional._meta import get_chw
1111

12-
from ._utils import _isinstance, _setup_fill_arg, FillType, ImageType
12+
from ..features._feature import FillType
13+
from ..features._image import ImageType
14+
15+
from ._utils import _isinstance, _setup_fill_arg
1316

1417
K = TypeVar("K")
1518
V = TypeVar("V")

torchvision/prototype/transforms/_color.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
from torchvision.prototype import features
77
from torchvision.prototype.transforms import functional as F, Transform
88

9+
from ..features._image import ImageType
10+
911
from ._transform import _RandomApplyTransform
10-
from ._utils import ImageType, query_chw
12+
from ._utils import query_chw
1113

1214

1315
class ColorJitter(Transform):

torchvision/prototype/transforms/_deprecated.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
from torchvision.transforms import functional as _F
1111
from typing_extensions import Literal
1212

13+
from ..features._image import ImageType
14+
1315
from ._transform import _RandomApplyTransform
14-
from ._utils import ImageType, query_chw
16+
from ._utils import query_chw
1517

1618

1719
class ToTensor(Transform):

torchvision/prototype/transforms/_geometry.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212

1313
from typing_extensions import Literal
1414

15+
from ..features._feature import FillType
16+
from ..features._image import ImageType
17+
1518
from ._transform import _RandomApplyTransform
1619
from ._utils import (
1720
_check_padding_arg,
@@ -20,10 +23,8 @@
2023
_setup_angle,
2124
_setup_fill_arg,
2225
_setup_size,
23-
FillType,
2426
has_all,
2527
has_any,
26-
ImageType,
2728
query_bounding_box,
2829
query_chw,
2930
)

0 commit comments

Comments
 (0)