Commit 04a3eab

More refactoring
1 parent f7513a4 · commit 04a3eab

8 files changed: +27 additions, -24 deletions

torchvision/prototype/transforms/_augment.py

Lines changed: 3 additions & 3 deletions
@@ -7,7 +7,7 @@
 from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, functional as F

-from ._utils import query_image, get_image_dims
+from ._utils import query_image, get_image_dimensions


 class RandomErasing(Transform):
@@ -41,7 +41,7 @@ def __init__(

     def _get_params(self, sample: Any) -> Dict[str, Any]:
         image = query_image(sample)
-        img_c, img_h, img_w = get_image_dims(image)
+        img_c, img_h, img_w = get_image_dimensions(image)

         if isinstance(self.value, (int, float)):
             value = [self.value]
@@ -137,7 +137,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         lam = float(self._dist.sample(()))

         image = query_image(sample)
-        _, H, W = get_image_dims(image)
+        _, H, W = get_image_dimensions(image)

         r_x = torch.randint(W, ())
         r_y = torch.randint(H, ())

torchvision/prototype/transforms/_auto_augment.py

Lines changed: 5 additions & 5 deletions
@@ -7,7 +7,7 @@
 from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, functional as F
 from torchvision.prototype.utils._internal import apply_recursively

-from ._utils import query_image, get_image_dims
+from ._utils import query_image, get_image_dimensions

 K = TypeVar("K")
 V = TypeVar("V")
@@ -47,7 +47,7 @@ def dispatch(
             return input

         image = query_image(sample)
-        num_channels, *_ = get_image_dims(image)
+        num_channels, *_ = get_image_dimensions(image)

         fill = self.fill
         if isinstance(fill, (int, float)):
@@ -278,7 +278,7 @@ def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]

         image = query_image(sample)
-        _, height, width = get_image_dims(image)
+        _, height, width = get_image_dimensions(image)

         policy = self._policies[int(torch.randint(len(self._policies), ()))]

@@ -334,7 +334,7 @@ def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]

         image = query_image(sample)
-        _, height, width = get_image_dims(image)
+        _, height, width = get_image_dimensions(image)

         for _ in range(self.num_ops):
             transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
@@ -383,7 +383,7 @@ def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]

         image = query_image(sample)
-        _, height, width = get_image_dims(image)
+        _, height, width = get_image_dimensions(image)

         transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)


torchvision/prototype/transforms/_geometry.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
 from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int

-from ._utils import query_image, get_image_dims
+from ._utils import query_image, get_image_dimensions


 class HorizontalFlip(Transform):
@@ -109,7 +109,7 @@ def __init__(

     def _get_params(self, sample: Any) -> Dict[str, Any]:
         image = query_image(sample)
-        _, height, width = get_image_dims(image)
+        _, height, width = get_image_dimensions(image)
         area = height * width

         log_ratio = torch.log(torch.tensor(self.ratio))

torchvision/prototype/transforms/_utils.py

Lines changed: 3 additions & 3 deletions
@@ -20,14 +20,14 @@ def fn(input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Ima
         raise TypeError("No image was found in the sample")


-def get_image_dims(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
+def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
     if isinstance(image, features.Image):
         channels = image.num_channels
         height, width = image.image_size
     elif isinstance(image, torch.Tensor):
-        channels, height, width = _FT.get_image_dims(image)
+        channels, height, width = _FT.get_dimensions(image)
     elif isinstance(image, PIL.Image.Image):
-        channels, height, width = _FP.get_image_dims(image)
+        channels, height, width = _FP.get_dimensions(image)
     else:
         raise TypeError(f"unable to get image dimensions from object of type {type(image).__name__}")
     return channels, height, width
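
For orientation, a minimal usage sketch of the renamed helper (not part of the commit; the import path and sample values are illustrative):

    import torch
    import PIL.Image

    from torchvision.prototype.transforms._utils import get_image_dimensions

    # Tensor images are reported as (channels, height, width).
    channels, height, width = get_image_dimensions(torch.rand(3, 224, 224))  # 3, 224, 224

    # PIL reports sizes as (width, height); the helper normalizes the order.
    channels, height, width = get_image_dimensions(PIL.Image.new("RGB", (320, 240)))  # 3, 240, 320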

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 8 additions & 8 deletions
@@ -39,7 +39,7 @@ def resize_image_tensor(
     antialias: Optional[bool] = None,
 ) -> torch.Tensor:
     new_height, new_width = size
-    num_channels, old_height, old_width = _FT.get_image_dims(image)
+    num_channels, old_height, old_width = _FT.get_dimensions(image)
     batch_shape = image.shape[:-3]
     return _FT.resize(
         image.reshape((-1, num_channels, old_height, old_width)),
@@ -141,7 +141,7 @@ def affine_image_tensor(

     center_f = [0.0, 0.0]
     if center is not None:
-        _, height, width = _FT.get_image_dims(img)
+        _, height, width = _FT.get_dimensions(img)
         # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
         center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))]

@@ -167,7 +167,7 @@ def affine_image_pil(
     # it is visually better to estimate the center without 0.5 offset
     # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
     if center is None:
-        _, height, width = _FP.get_image_dims(img)
+        _, height, width = _FP.get_dimensions(img)
         center = [width * 0.5, height * 0.5]
     matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)

@@ -184,7 +184,7 @@ def rotate_image_tensor(
 ) -> torch.Tensor:
     center_f = [0.0, 0.0]
     if center is not None:
-        _, height, width = _FT.get_image_dims(img)
+        _, height, width = _FT.get_dimensions(img)
         # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
         center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))]

@@ -260,13 +260,13 @@ def _center_crop_compute_crop_anchor(

 def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    _, image_height, image_width = _FT.get_image_dims(img)
+    _, image_height, image_width = _FT.get_dimensions(img)

     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
         img = pad_image_tensor(img, padding_ltrb, fill=0)

-    _, image_height, image_width = _FT.get_image_dims(img)
+    _, image_height, image_width = _FT.get_dimensions(img)
     if crop_width == image_width and crop_height == image_height:
         return img

@@ -276,13 +276,13 @@ def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch

 def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    _, image_height, image_width = _FP.get_image_dims(img)
+    _, image_height, image_width = _FP.get_dimensions(img)

     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
         img = pad_image_pil(img, padding_ltrb, fill=0)

-    _, image_height, image_width = _FP.get_image_dims(img)
+    _, image_height, image_width = _FP.get_dimensions(img)
     if crop_width == image_width and crop_height == image_height:
         return img


torchvision/prototype/transforms/functional/_misc.py

Lines changed: 4 additions & 1 deletion
@@ -2,10 +2,13 @@

 import PIL.Image
 import torch
-from torchvision.transforms import functional_tensor as _FT
+from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP
 from torchvision.transforms.functional import to_tensor, to_pil_image


+get_dimensions_image_tensor = _FT.get_dimensions
+get_dimensions_image_pil = _FP.get_dimensions
+
 normalize_image_tensor = _FT.normalize

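The two new names are plain aliases, so a hypothetical call site could use them directly (sketch only; the import path mirrors the file location shown above):

    import torch
    from torchvision.prototype.transforms.functional._misc import get_dimensions_image_tensor

    # Forwards to torchvision.transforms.functional_tensor.get_dimensions;
    # get_dimensions_image_pil forwards to the PIL kernel in the same way.
    print(get_dimensions_image_tensor(torch.rand(3, 16, 32)))  # [3, 16, 32]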

torchvision/transforms/functional_pil.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ def _is_pil_image(img: Any) -> bool:


 @torch.jit.unused
-def get_image_dims(img: Any) -> List[int]:
+def get_dimensions(img: Any) -> List[int]:
     if _is_pil_image(img):
         channels = len(img.getbands())
         width, height = img.size
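
The hunk ends mid-function; here is a standalone sketch of what the PIL variant presumably computes, with the return statement and the error branch filled in as assumptions:

    from typing import Any, List
    import PIL.Image

    def get_dimensions_pil_sketch(img: Any) -> List[int]:
        if isinstance(img, PIL.Image.Image):
            channels = len(img.getbands())  # e.g. 3 for "RGB", 1 for "L"
            width, height = img.size        # PIL reports (width, height)
            return [channels, height, width]  # assumed: result normalized to C, H, W order
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    print(get_dimensions_pil_sketch(PIL.Image.new("RGB", (320, 240))))  # [3, 240, 320]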

torchvision/transforms/functional_tensor.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ def _assert_threshold(img: Tensor, threshold: float) -> None:
         raise TypeError("Threshold should be less than bound of img.")


-def get_image_dims(img: Tensor) -> List[int]:
+def get_dimensions(img: Tensor) -> List[int]:
     _assert_image_tensor(img)
     channels = 1 if img.ndim == 2 else img.shape[-3]
     height, width = img.shape[-2:]
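
Likewise, a standalone sketch of the tensor variant (the final return is assumed, since the hunk cuts off after the height/width line):

    from typing import List
    import torch
    from torch import Tensor

    def get_dimensions_tensor_sketch(img: Tensor) -> List[int]:
        # A 2D tensor is treated as a single-channel image; otherwise the channel
        # dimension is the third-to-last axis, so batched inputs work as well.
        channels = 1 if img.ndim == 2 else img.shape[-3]
        height, width = img.shape[-2:]
        return [channels, height, width]

    print(get_dimensions_tensor_sketch(torch.rand(3, 32, 64)))  # [3, 32, 64]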

0 commit comments