
Commit f7513a4

Refactoring
1 parent c8f7b14 commit f7513a4

File tree

9 files changed: 50 additions & 42 deletions


torchvision/prototype/transforms/_augment.py

Lines changed: 3 additions & 3 deletions

@@ -7,7 +7,7 @@
 from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, functional as F

-from ._utils import query_image
+from ._utils import query_image, get_image_dims


 class RandomErasing(Transform):
@@ -41,7 +41,7 @@ def __init__(

     def _get_params(self, sample: Any) -> Dict[str, Any]:
         image = query_image(sample)
-        img_c, img_h, img_w = F.get_image_dims(image)
+        img_c, img_h, img_w = get_image_dims(image)

         if isinstance(self.value, (int, float)):
             value = [self.value]
@@ -137,7 +137,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         lam = float(self._dist.sample(()))

         image = query_image(sample)
-        _, H, W = F.get_image_dims(image)
+        _, H, W = get_image_dims(image)

         r_x = torch.randint(W, ())
         r_y = torch.randint(H, ())

torchvision/prototype/transforms/_auto_augment.py

Lines changed: 5 additions & 5 deletions

@@ -7,7 +7,7 @@
 from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, functional as F
 from torchvision.prototype.utils._internal import apply_recursively

-from ._utils import query_image
+from ._utils import query_image, get_image_dims

 K = TypeVar("K")
 V = TypeVar("V")
@@ -47,7 +47,7 @@ def dispatch(
            return input

        image = query_image(sample)
-        num_channels, _, _ = F.get_image_dims(image)
+        num_channels, *_ = get_image_dims(image)

        fill = self.fill
        if isinstance(fill, (int, float)):
@@ -278,7 +278,7 @@ def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]

         image = query_image(sample)
-        _, height, width = F.get_image_dims(image)
+        _, height, width = get_image_dims(image)

         policy = self._policies[int(torch.randint(len(self._policies), ()))]

@@ -334,7 +334,7 @@ def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]

         image = query_image(sample)
-        _, height, width = F.get_image_dims(image)
+        _, height, width = get_image_dims(image)

         for _ in range(self.num_ops):
             transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
@@ -383,7 +383,7 @@ def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]

         image = query_image(sample)
-        _, height, width = F.get_image_dims(image)
+        _, height, width = get_image_dims(image)

         transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

torchvision/prototype/transforms/_geometry.py

Lines changed: 2 additions & 2 deletions

@@ -8,7 +8,7 @@
 from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
 from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int

-from ._utils import query_image
+from ._utils import query_image, get_image_dims


 class HorizontalFlip(Transform):
@@ -109,7 +109,7 @@ def __init__(

     def _get_params(self, sample: Any) -> Dict[str, Any]:
         image = query_image(sample)
-        _, height, width = F.get_image_dims(image)
+        _, height, width = get_image_dims(image)
         area = height * width

         log_ratio = torch.log(torch.tensor(self.ratio))

torchvision/prototype/transforms/_utils.py

Lines changed: 15 additions & 1 deletion

@@ -1,9 +1,10 @@
-from typing import Any, Optional, Union
+from typing import Any, Optional, Tuple, Union

 import PIL.Image
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.utils._internal import query_recursively
+from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP


 def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
@@ -17,3 +18,16 @@ def fn(input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Image]]:
         return next(query_recursively(fn, sample))
     except StopIteration:
         raise TypeError("No image was found in the sample")
+
+
+def get_image_dims(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
+    if isinstance(image, features.Image):
+        channels = image.num_channels
+        height, width = image.image_size
+    elif isinstance(image, torch.Tensor):
+        channels, height, width = _FT.get_image_dims(image)
+    elif isinstance(image, PIL.Image.Image):
+        channels, height, width = _FP.get_image_dims(image)
+    else:
+        raise TypeError(f"unable to get image dimensions from object of type {type(image).__name__}")
+    return channels, height, width
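
Note: the new get_image_dims gives the prototype transforms a single entry point for querying (channels, height, width) across the three supported image types. A minimal usage sketch (the sample dict and shapes below are hypothetical, assuming the module layout at this commit):

    import torch
    from torchvision.prototype.transforms._utils import query_image, get_image_dims

    # hypothetical sample, as a Transform would receive it
    sample = {"image": torch.rand(3, 224, 224), "label": 17}
    image = query_image(sample)                      # recursively locates the image
    channels, height, width = get_image_dims(image)  # -> (3, 224, 224)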

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -1,5 +1,4 @@
 from torchvision.transforms import InterpolationMode  # usort: skip
-from ._utils import get_image_dims  # usort: skip
 from ._meta_conversion import (
     convert_bounding_box_format,
     convert_image_color_space_tensor,

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 8 additions & 9 deletions

@@ -5,7 +5,6 @@
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import InterpolationMode
-from torchvision.prototype.transforms.functional import get_image_dims
 from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP
 from torchvision.transforms.functional import pil_modes_mapping, _get_inverse_affine_matrix

@@ -40,7 +39,7 @@ def resize_image_tensor(
     antialias: Optional[bool] = None,
 ) -> torch.Tensor:
     new_height, new_width = size
-    num_channels, old_height, old_width = image.shape[-3:]
+    num_channels, old_height, old_width = _FT.get_image_dims(image)
     batch_shape = image.shape[:-3]
     return _FT.resize(
         image.reshape((-1, num_channels, old_height, old_width)),
@@ -142,7 +141,7 @@ def affine_image_tensor(

     center_f = [0.0, 0.0]
     if center is not None:
-        _, height, width = get_image_dims(img)
+        _, height, width = _FT.get_image_dims(img)
         # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
         center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))]

@@ -168,7 +167,7 @@ def affine_image_pil(
     # it is visually better to estimate the center without 0.5 offset
     # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
     if center is None:
-        _, height, width = get_image_dims(img)
+        _, height, width = _FP.get_image_dims(img)
         center = [width * 0.5, height * 0.5]
     matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)

@@ -185,7 +184,7 @@ def rotate_image_tensor(
 ) -> torch.Tensor:
     center_f = [0.0, 0.0]
     if center is not None:
-        _, height, width = get_image_dims(img)
+        _, height, width = _FT.get_image_dims(img)
         # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
         center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))]

@@ -261,13 +260,13 @@ def _center_crop_compute_crop_anchor(

 def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    _, image_height, image_width = get_image_dims(img)
+    _, image_height, image_width = _FT.get_image_dims(img)

     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
         img = pad_image_tensor(img, padding_ltrb, fill=0)

-    _, image_height, image_width = get_image_dims(img)
+    _, image_height, image_width = _FT.get_image_dims(img)
     if crop_width == image_width and crop_height == image_height:
         return img

@@ -277,13 +276,13 @@ def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch.Tensor:

 def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    _, image_height, image_width = get_image_dims(img)
+    _, image_height, image_width = _FP.get_image_dims(img)

     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
         img = pad_image_pil(img, padding_ltrb, fill=0)

-    _, image_height, image_width = get_image_dims(img)
+    _, image_height, image_width = _FP.get_image_dims(img)
     if crop_width == image_width and crop_height == image_height:
         return img
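
Note: after this change the low-level kernels follow one convention: *_image_tensor kernels query dimensions via functional_tensor (_FT) and *_image_pil kernels via functional_pil (_FP), so no type dispatch happens inside a kernel. A minimal sketch of that convention (shapes and mode are illustrative, assuming this commit is applied):

    import PIL.Image
    import torch
    from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT

    _, h_t, w_t = _FT.get_image_dims(torch.rand(3, 8, 8))            # tensor backend
    _, h_p, w_p = _FP.get_image_dims(PIL.Image.new("RGB", (8, 8)))   # PIL backend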

torchvision/prototype/transforms/functional/_utils.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

torchvision/transforms/functional_pil.py

Lines changed: 10 additions & 1 deletion

@@ -20,6 +20,15 @@ def _is_pil_image(img: Any) -> bool:
     return isinstance(img, Image.Image)


+@torch.jit.unused
+def get_image_dims(img: Any) -> List[int]:
+    if _is_pil_image(img):
+        channels = len(img.getbands())
+        width, height = img.size
+        return [channels, height, width]
+    raise TypeError(f"Unexpected type {type(img)}")
+
+
 @torch.jit.unused
 def get_image_size(img: Any) -> List[int]:
     if _is_pil_image(img):
@@ -30,7 +39,7 @@ def get_image_size(img: Any) -> List[int]:
 @torch.jit.unused
 def get_image_num_channels(img: Any) -> int:
     if _is_pil_image(img):
-        return 1 if img.mode == "L" else 3
+        return len(img.getbands())
     raise TypeError(f"Unexpected type {type(img)}")
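
Note: len(img.getbands()) reports the actual band count, which also fixes get_image_num_channels for modes the old hard-coded "1 if L else 3" got wrong (RGBA and CMYK have 4 bands, palette mode P has 1). A quick illustration (modes are arbitrary examples):

    import PIL.Image

    for mode in ("L", "P", "RGB", "RGBA", "CMYK"):
        img = PIL.Image.new(mode, (4, 2))
        print(mode, len(img.getbands()))  # L: 1, P: 1, RGB: 3, RGBA: 4, CMYK: 4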

torchvision/transforms/functional_tensor.py

Lines changed: 7 additions & 0 deletions

@@ -21,6 +21,13 @@ def _assert_threshold(img: Tensor, threshold: float) -> None:
         raise TypeError("Threshold should be less than bound of img.")


+def get_image_dims(img: Tensor) -> List[int]:
+    _assert_image_tensor(img)
+    channels = 1 if img.ndim == 2 else img.shape[-3]
+    height, width = img.shape[-2:]
+    return [channels, height, width]
+
+
 def get_image_size(img: Tensor) -> List[int]:
     # Returns (w, h) of tensor image
     _assert_image_tensor(img)
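
Note: the tensor kernel reads the channel count from shape[-3], falls back to 1 channel for 2-D images, and ignores any leading batch dimensions. A quick sketch of the resulting behavior (shapes are illustrative, assuming this commit is applied):

    import torch
    from torchvision.transforms import functional_tensor as _FT

    print(_FT.get_image_dims(torch.rand(3, 32, 48)))     # [3, 32, 48]
    print(_FT.get_image_dims(torch.rand(32, 48)))        # [1, 32, 48]  (2-D image)
    print(_FT.get_image_dims(torch.rand(5, 3, 32, 48)))  # [3, 32, 48]  (batch dim ignored)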
