diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index cae53728b96..5909b68966b 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -270,6 +270,7 @@ you can use a functional transform to build transform classes with custom behavi erase five_crop gaussian_blur + get_dimensions get_image_num_channels get_image_size hflip diff --git a/references/classification/transforms.py b/references/classification/transforms.py index 892b4e7e6c0..e72cd67fbfd 100644 --- a/references/classification/transforms.py +++ b/references/classification/transforms.py @@ -141,7 +141,7 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: # Implemented as on cutmix paper, page 12 (with minor corrections on typos). lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]) - W, H = F.get_image_size(batch) + _, H, W = F.get_dimensions(batch) r_x = torch.randint(W, (1,)) r_y = torch.randint(H, (1,)) diff --git a/references/detection/transforms.py b/references/detection/transforms.py index 342d491ceb1..3a9bad78c25 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -34,7 +34,7 @@ def forward( if torch.rand(1) < self.p: image = F.hflip(image) if target is not None: - width, _ = F.get_image_size(image) + _, _, width = F.get_dimensions(image) target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]] if "masks" in target: target["masks"] = target["masks"].flip(-1) @@ -107,7 +107,7 @@ def forward( elif image.ndimension() == 2: image = image.unsqueeze(0) - orig_w, orig_h = F.get_image_size(image) + _, orig_h, orig_w = F.get_dimensions(image) while True: # sample an option @@ -192,7 +192,7 @@ def forward( if torch.rand(1) >= self.p: return image, target - orig_w, orig_h = F.get_image_size(image) + _, orig_h, orig_w = F.get_dimensions(image) r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0]) canvas_width = int(orig_w * r) @@ -270,7 +270,7 @@ def forward( image = self._contrast(image) if r[6] < self.p: - channels = F.get_image_num_channels(image) + channels, _, _ = F.get_dimensions(image) permutation = torch.randperm(channels) is_pil = F._is_pil_image(image) @@ -317,7 +317,7 @@ def forward( elif image.ndimension() == 2: image = image.unsqueeze(0) - orig_width, orig_height = F.get_image_size(image) + _, orig_height, orig_width = F.get_dimensions(image) r = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0]) new_width = int(self.target_size[1] * r) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 0ac559565b7..3bdf0cfe34e 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -29,7 +29,7 @@ @pytest.mark.parametrize("device", cpu_and_gpu()) -@pytest.mark.parametrize("fn", [F.get_image_size, F.get_image_num_channels]) +@pytest.mark.parametrize("fn", [F.get_image_size, F.get_image_num_channels, F.get_dimensions]) def test_image_sizes(device, fn): script_F = torch.jit.script(fn) @@ -1020,7 +1020,9 @@ def test_resized_crop(device, mode): @pytest.mark.parametrize( "func, args", [ + (F_t.get_dimensions, ()), (F_t.get_image_size, ()), + (F_t.get_image_num_channels, ()), (F_t.vflip, ()), (F_t.hflip, ()), (F_t.crop, (1, 2, 4, 5)), diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 73235720d58..73d45097a93 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -8,7 +8,7 @@ from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment from ._container import Compose, RandomApply, RandomChoice, RandomOrder from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop -from ._meta_conversion import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace +from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace from ._misc import Identity, Normalize, ToDtype, Lambda from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval from ._type_conversion import DecodeImage, LabelToOneHot diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index ce198d39b33..5862c6a06dc 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -7,7 +7,7 @@ from torchvision.prototype import features from torchvision.prototype.transforms import Transform, functional as F -from ._utils import query_image +from ._utils import query_image, get_image_dimensions class RandomErasing(Transform): @@ -41,8 +41,7 @@ def __init__( def _get_params(self, sample: Any) -> Dict[str, Any]: image = query_image(sample) - img_c = F.get_image_num_channels(image) - img_w, img_h = F.get_image_size(image) + img_c, img_h, img_w = get_image_dimensions(image) if isinstance(self.value, (int, float)): value = [self.value] @@ -138,7 +137,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: lam = float(self._dist.sample(())) image = query_image(sample) - W, H = F.get_image_size(image) + _, H, W = get_image_dimensions(image) r_x = torch.randint(W, ()) r_y = torch.randint(H, ()) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 7eae25a681e..78cbf958ccd 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -7,7 +7,7 @@ from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, functional as F from torchvision.prototype.utils._internal import apply_recursively -from ._utils import query_image +from ._utils import query_image, get_image_dimensions K = TypeVar("K") V = TypeVar("V") @@ -47,7 +47,7 @@ def dispatch( return input image = query_image(sample) - num_channels = F.get_image_num_channels(image) + num_channels, *_ = get_image_dimensions(image) fill = self.fill if isinstance(fill, (int, float)): @@ -160,8 +160,8 @@ class AutoAugment(_AutoAugmentBase): _AUGMENTATION_SPACE = { "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), - "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), - "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True), "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), @@ -278,7 +278,7 @@ def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] image = query_image(sample) - image_size = F.get_image_size(image) + _, height, width = get_image_dimensions(image) policy = self._policies[int(torch.randint(len(self._policies), ()))] @@ -288,7 +288,7 @@ def forward(self, *inputs: Any) -> Any: magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id] - magnitudes = magnitudes_fn(10, image_size) + magnitudes = magnitudes_fn(10, (height, width)) if magnitudes is not None: magnitude = float(magnitudes[magnitude_idx]) if signed and torch.rand(()) <= 0.5: @@ -306,8 +306,8 @@ class RandAugment(_AutoAugmentBase): "Identity": (lambda num_bins, image_size: None, False), "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), - "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), - "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True), "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), @@ -334,12 +334,12 @@ def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] image = query_image(sample) - image_size = F.get_image_size(image) + _, height, width = get_image_dimensions(image) for _ in range(self.num_ops): transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) - magnitudes = magnitudes_fn(self.num_magnitude_bins, image_size) + magnitudes = magnitudes_fn(self.num_magnitude_bins, (height, width)) if magnitudes is not None: magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))]) if signed and torch.rand(()) <= 0.5: @@ -383,11 +383,11 @@ def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] image = query_image(sample) - image_size = F.get_image_size(image) + _, height, width = get_image_dimensions(image) transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) - magnitudes = magnitudes_fn(self.num_magnitude_bins, image_size) + magnitudes = magnitudes_fn(self.num_magnitude_bins, (height, width)) if magnitudes is not None: magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))]) if signed and torch.rand(()) <= 0.5: diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 4c9d9192ac8..c58f26a0e06 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -8,7 +8,7 @@ from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int -from ._utils import query_image +from ._utils import query_image, get_image_dimensions class HorizontalFlip(Transform): @@ -109,7 +109,7 @@ def __init__( def _get_params(self, sample: Any) -> Dict[str, Any]: image = query_image(sample) - width, height = F.get_image_size(image) + _, height, width = get_image_dimensions(image) area = height * width log_ratio = torch.log(torch.tensor(self.ratio)) diff --git a/torchvision/prototype/transforms/_meta_conversion.py b/torchvision/prototype/transforms/_meta.py similarity index 100% rename from torchvision/prototype/transforms/_meta_conversion.py rename to torchvision/prototype/transforms/_meta.py diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 24d794a2cb4..d8677d451c8 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -1,10 +1,12 @@ -from typing import Any, Optional, Union +from typing import Any, Optional, Tuple, Union import PIL.Image import torch from torchvision.prototype import features from torchvision.prototype.utils._internal import query_recursively +from .functional._meta import get_dimensions_image_tensor, get_dimensions_image_pil + def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]: def fn(input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Image]]: @@ -17,3 +19,16 @@ def fn(input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Ima return next(query_recursively(fn, sample)) except StopIteration: raise TypeError("No image was found in the sample") + + +def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]: + if isinstance(image, features.Image): + channels = image.num_channels + height, width = image.image_size + elif isinstance(image, torch.Tensor): + channels, height, width = get_dimensions_image_tensor(image) + elif isinstance(image, PIL.Image.Image): + channels, height, width = get_dimensions_image_pil(image) + else: + raise TypeError(f"unable to get image dimensions from object of type {type(image).__name__}") + return channels, height, width diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index c487aba7fa2..e3fe60a7919 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -1,6 +1,5 @@ from torchvision.transforms import InterpolationMode # usort: skip -from ._utils import get_image_size, get_image_num_channels # usort: skip -from ._meta_conversion import ( +from ._meta import ( convert_bounding_box_format, convert_image_color_space_tensor, convert_image_color_space_pil, diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 76564fdd54d..080fe5da891 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -5,11 +5,10 @@ import torch from torchvision.prototype import features from torchvision.prototype.transforms import InterpolationMode -from torchvision.prototype.transforms.functional import get_image_size from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP from torchvision.transforms.functional import pil_modes_mapping, _get_inverse_affine_matrix -from ._meta_conversion import convert_bounding_box_format +from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil horizontal_flip_image_tensor = _FT.hflip @@ -40,8 +39,7 @@ def resize_image_tensor( antialias: Optional[bool] = None, ) -> torch.Tensor: new_height, new_width = size - old_width, old_height = _FT.get_image_size(image) - num_channels = _FT.get_image_num_channels(image) + num_channels, old_height, old_width = get_dimensions_image_tensor(image) batch_shape = image.shape[:-3] return _FT.resize( image.reshape((-1, num_channels, old_height, old_width)), @@ -143,9 +141,9 @@ def affine_image_tensor( center_f = [0.0, 0.0] if center is not None: - width, height = get_image_size(img) + _, height, width = get_dimensions_image_tensor(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. - center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))] + center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])] translate_f = [1.0 * t for t in translate] matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear) @@ -169,7 +167,7 @@ def affine_image_pil( # it is visually better to estimate the center without 0.5 offset # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine if center is None: - width, height = get_image_size(img) + _, height, width = get_dimensions_image_pil(img) center = [width * 0.5, height * 0.5] matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) @@ -186,9 +184,9 @@ def rotate_image_tensor( ) -> torch.Tensor: center_f = [0.0, 0.0] if center is not None: - width, height = get_image_size(img) + _, height, width = get_dimensions_image_tensor(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. - center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))] + center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])] # due to current incoherence of rotation angle direction between affine and rotate implementations # we need to set -angle. @@ -262,13 +260,13 @@ def _center_crop_compute_crop_anchor( def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch.Tensor: crop_height, crop_width = _center_crop_parse_output_size(output_size) - image_width, image_height = get_image_size(img) + _, image_height, image_width = get_dimensions_image_tensor(img) if crop_height > image_height or crop_width > image_width: padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) img = pad_image_tensor(img, padding_ltrb, fill=0) - image_width, image_height = get_image_size(img) + _, image_height, image_width = get_dimensions_image_tensor(img) if crop_width == image_width and crop_height == image_height: return img @@ -278,13 +276,13 @@ def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image: crop_height, crop_width = _center_crop_parse_output_size(output_size) - image_width, image_height = get_image_size(img) + _, image_height, image_width = get_dimensions_image_pil(img) if crop_height > image_height or crop_width > image_width: padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) img = pad_image_pil(img, padding_ltrb, fill=0) - image_width, image_height = get_image_size(img) + _, image_height, image_width = get_dimensions_image_pil(img) if crop_width == image_width and crop_height == image_height: return img diff --git a/torchvision/prototype/transforms/functional/_meta_conversion.py b/torchvision/prototype/transforms/functional/_meta.py similarity index 96% rename from torchvision/prototype/transforms/functional/_meta_conversion.py rename to torchvision/prototype/transforms/functional/_meta.py index b260beaa361..6ecb5aff257 100644 --- a/torchvision/prototype/transforms/functional/_meta_conversion.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -4,6 +4,10 @@ from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP +get_dimensions_image_tensor = _FT.get_dimensions +get_dimensions_image_pil = _FP.get_dimensions + + def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor: xyxy = xywh.clone() xyxy[..., 2:] += xyxy[..., :2] diff --git a/torchvision/prototype/transforms/functional/_utils.py b/torchvision/prototype/transforms/functional/_utils.py deleted file mode 100644 index 07235d63716..00000000000 --- a/torchvision/prototype/transforms/functional/_utils.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Tuple, Union, cast - -import PIL.Image -import torch -from torchvision.prototype import features -from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP - - -def get_image_size(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int]: - if isinstance(image, features.Image): - height, width = image.image_size - return width, height - elif isinstance(image, torch.Tensor): - return cast(Tuple[int, int], tuple(_FT.get_image_size(image))) - if isinstance(image, PIL.Image.Image): - return cast(Tuple[int, int], tuple(_FP.get_image_size(image))) - else: - raise TypeError(f"unable to get image size from object of type {type(image).__name__}") - - -def get_image_num_channels(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> int: - if isinstance(image, features.Image): - return image.num_channels - elif isinstance(image, torch.Tensor): - return _FT.get_image_num_channels(image) - if isinstance(image, PIL.Image.Image): - return cast(int, _FP.get_image_num_channels(image)) - else: - raise TypeError(f"unable to get num channels from object of type {type(image).__name__}") diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index d820e5126a1..357e5bf250e 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -220,13 +220,13 @@ def _get_policies( else: raise ValueError(f"The provided policy {policy} is not recognized.") - def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: + def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]: return { # op_name: (magnitudes, signed) "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), - "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), - "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), "Rotate": (torch.linspace(0.0, 30.0, num_bins), True), "Brightness": (torch.linspace(0.0, 0.9, num_bins), True), "Color": (torch.linspace(0.0, 0.9, num_bins), True), @@ -260,15 +260,16 @@ def forward(self, img: Tensor) -> Tensor: PIL Image or Tensor: AutoAugmented image. """ fill = self.fill + channels, height, width = F.get_dimensions(img) if isinstance(img, Tensor): if isinstance(fill, (int, float)): - fill = [float(fill)] * F.get_image_num_channels(img) + fill = [float(fill)] * channels elif fill is not None: fill = [float(f) for f in fill] transform_id, probs, signs = self.get_params(len(self.policies)) - op_meta = self._augmentation_space(10, F.get_image_size(img)) + op_meta = self._augmentation_space(10, (height, width)) for i, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]): if probs[i] <= p: magnitudes, signed = op_meta[op_name] @@ -317,14 +318,14 @@ def __init__( self.interpolation = interpolation self.fill = fill - def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: + def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]: return { # op_name: (magnitudes, signed) "Identity": (torch.tensor(0.0), False), "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), - "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), - "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), "Rotate": (torch.linspace(0.0, 30.0, num_bins), True), "Brightness": (torch.linspace(0.0, 0.9, num_bins), True), "Color": (torch.linspace(0.0, 0.9, num_bins), True), @@ -344,13 +345,14 @@ def forward(self, img: Tensor) -> Tensor: PIL Image or Tensor: Transformed image. """ fill = self.fill + channels, height, width = F.get_dimensions(img) if isinstance(img, Tensor): if isinstance(fill, (int, float)): - fill = [float(fill)] * F.get_image_num_channels(img) + fill = [float(fill)] * channels elif fill is not None: fill = [float(f) for f in fill] - op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img)) + op_meta = self._augmentation_space(self.num_magnitude_bins, (height, width)) for _ in range(self.num_ops): op_index = int(torch.randint(len(op_meta), (1,)).item()) op_name = list(op_meta.keys())[op_index] @@ -429,9 +431,10 @@ def forward(self, img: Tensor) -> Tensor: PIL Image or Tensor: Transformed image. """ fill = self.fill + channels, height, width = F.get_dimensions(img) if isinstance(img, Tensor): if isinstance(fill, (int, float)): - fill = [float(fill)] * F.get_image_num_channels(img) + fill = [float(fill)] * channels elif fill is not None: fill = [float(f) for f in fill] @@ -503,13 +506,13 @@ def __init__( self.interpolation = interpolation self.fill = fill - def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: + def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]: s = { # op_name: (magnitudes, signed) "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), - "TranslateX": (torch.linspace(0.0, image_size[0] / 3.0, num_bins), True), - "TranslateY": (torch.linspace(0.0, image_size[1] / 3.0, num_bins), True), + "TranslateX": (torch.linspace(0.0, image_size[1] / 3.0, num_bins), True), + "TranslateY": (torch.linspace(0.0, image_size[0] / 3.0, num_bins), True), "Rotate": (torch.linspace(0.0, 30.0, num_bins), True), "Posterize": (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False), "Solarize": (torch.linspace(255.0, 0.0, num_bins), False), @@ -547,16 +550,17 @@ def forward(self, orig_img: Tensor) -> Tensor: PIL Image or Tensor: Transformed image. """ fill = self.fill + channels, height, width = F.get_dimensions(orig_img) if isinstance(orig_img, Tensor): img = orig_img if isinstance(fill, (int, float)): - fill = [float(fill)] * F.get_image_num_channels(img) + fill = [float(fill)] * channels elif fill is not None: fill = [float(f) for f in fill] else: img = self._pil_to_tensor(orig_img) - op_meta = self._augmentation_space(self._PARAMETER_MAX, F.get_image_size(img)) + op_meta = self._augmentation_space(self._PARAMETER_MAX, (height, width)) orig_dims = list(img.shape) batch = img.view([1] * max(4 - img.ndim, 0) + orig_dims) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 41c6ceada03..b2fc3f44f55 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -59,6 +59,23 @@ def _interpolation_modes_from_int(i: int) -> InterpolationMode: _is_pil_image = F_pil._is_pil_image +def get_dimensions(img: Tensor) -> List[int]: + """Returns the dimensions of an image as [channels, height, width]. + + Args: + img (PIL Image or Tensor): The image to be checked. + + Returns: + List[int]: The image dimensions. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(get_dimensions) + if isinstance(img, torch.Tensor): + return F_t.get_dimensions(img) + + return F_pil.get_dimensions(img) + + def get_image_size(img: Tensor) -> List[int]: """Returns the size of an image as [width, height]. @@ -512,7 +529,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor: elif isinstance(output_size, (tuple, list)) and len(output_size) == 1: output_size = (output_size[0], output_size[0]) - image_width, image_height = get_image_size(img) + _, image_height, image_width = get_dimensions(img) crop_height, crop_width = output_size if crop_width > image_width or crop_height > image_height: @@ -523,7 +540,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor: (crop_height - image_height + 1) // 2 if crop_height > image_height else 0, ] img = pad(img, padding_ltrb, fill=0) # PIL uses fill value 0 - image_width, image_height = get_image_size(img) + _, image_height, image_width = get_dimensions(img) if crop_width == image_width and crop_height == image_height: return img @@ -721,7 +738,7 @@ def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Ten if len(size) != 2: raise ValueError("Please provide only two dimensions (h, w) for size.") - image_width, image_height = get_image_size(img) + _, image_height, image_width = get_dimensions(img) crop_height, crop_width = size if crop_width > image_width or crop_height > image_height: msg = "Requested crop size {} is bigger than input size {}" @@ -1047,9 +1064,9 @@ def rotate( center_f = [0.0, 0.0] if center is not None: - img_size = get_image_size(img) + _, height, width = get_dimensions(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. - center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)] + center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])] # due to current incoherence of rotation angle direction between affine and rotate implementations # we need to set -angle. @@ -1167,22 +1184,22 @@ def affine( if center is not None and not isinstance(center, (list, tuple)): raise TypeError("Argument center should be a sequence") - img_size = get_image_size(img) + _, height, width = get_dimensions(img) if not isinstance(img, torch.Tensor): - # center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5) + # center = (width * 0.5 + 0.5, height * 0.5 + 0.5) # it is visually better to estimate the center without 0.5 offset # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine if center is None: - center = [img_size[0] * 0.5, img_size[1] * 0.5] + center = [width * 0.5, height * 0.5] matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) pil_interpolation = pil_modes_mapping[interpolation] return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill) center_f = [0.0, 0.0] if center is not None: - img_size = get_image_size(img) + _, height, width = get_dimensions(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. - center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)] + center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])] translate_f = [1.0 * t for t in translate] matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear) diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 01c321dabfa..5e383ff3286 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -20,6 +20,15 @@ def _is_pil_image(img: Any) -> bool: return isinstance(img, Image.Image) +@torch.jit.unused +def get_dimensions(img: Any) -> List[int]: + if _is_pil_image(img): + channels = len(img.getbands()) + width, height = img.size + return [channels, height, width] + raise TypeError(f"Unexpected type {type(img)}") + + @torch.jit.unused def get_image_size(img: Any) -> List[int]: if _is_pil_image(img): @@ -30,7 +39,7 @@ def get_image_size(img: Any) -> List[int]: @torch.jit.unused def get_image_num_channels(img: Any) -> int: if _is_pil_image(img): - return 1 if img.mode == "L" else 3 + return len(img.getbands()) raise TypeError(f"Unexpected type {type(img)}") diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index fae681b3aa9..18b2c721f4e 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -21,6 +21,13 @@ def _assert_threshold(img: Tensor, threshold: float) -> None: raise TypeError("Threshold should be less than bound of img.") +def get_dimensions(img: Tensor) -> List[int]: + _assert_image_tensor(img) + channels = 1 if img.ndim == 2 else img.shape[-3] + height, width = img.shape[-2:] + return [channels, height, width] + + def get_image_size(img: Tensor) -> List[int]: # Returns (w, h) of tensor image _assert_image_tensor(img) @@ -28,6 +35,7 @@ def get_image_size(img: Tensor) -> List[int]: def get_image_num_channels(img: Tensor) -> int: + _assert_image_tensor(img) if img.ndim == 2: return 1 elif img.ndim > 2: @@ -55,7 +63,7 @@ def _max_value(dtype: torch.dtype) -> float: def _assert_channels(img: Tensor, permitted: List[int]) -> None: - c = get_image_num_channels(img) + c = get_dimensions(img)[0] if c not in permitted: raise TypeError(f"Input image tensor permitted channel values are {permitted}, but found {c}") @@ -127,7 +135,7 @@ def hflip(img: Tensor) -> Tensor: def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor: _assert_image_tensor(img) - w, h = get_image_size(img) + _, h, w = get_dimensions(img) right = left + width bottom = top + height @@ -175,7 +183,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: _assert_image_tensor(img) _assert_channels(img, [3, 1]) - c = get_image_num_channels(img) + c = get_dimensions(img)[0] dtype = img.dtype if torch.is_floating_point(img) else torch.float32 if c == 3: mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True) @@ -195,7 +203,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: _assert_image_tensor(img) _assert_channels(img, [1, 3]) - if get_image_num_channels(img) == 1: # Match PIL behaviour + if get_dimensions(img)[0] == 1: # Match PIL behaviour return img orig_dtype = img.dtype @@ -222,7 +230,7 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: _assert_channels(img, [1, 3]) - if get_image_num_channels(img) == 1: # Match PIL behaviour + if get_dimensions(img)[0] == 1: # Match PIL behaviour return img return _blend(img, rgb_to_grayscale(img), saturation_factor) @@ -451,7 +459,7 @@ def resize( if antialias and interpolation not in ["bilinear", "bicubic"]: raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only") - w, h = get_image_size(img) + _, h, w = get_dimensions(img) if isinstance(size, int) or len(size) == 1: # specified size only for the smallest edge short, long = (w, h) if w <= h else (h, w) @@ -518,7 +526,7 @@ def _assert_grid_transform_inputs( warnings.warn("Argument fill should be either int, float, tuple or list") # Check fill - num_channels = get_image_num_channels(img) + num_channels = get_dimensions(img)[0] if isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels): msg = ( "The number of elements in 'fill' cannot broadcast to match the number of " diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 9fc79c1d8cc..37556cd4984 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -628,7 +628,7 @@ def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int Returns: tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. """ - w, h = F.get_image_size(img) + _, h, w = F.get_dimensions(img) th, tw = output_size if h + 1 < th or w + 1 < tw: @@ -663,7 +663,7 @@ def forward(self, img): if self.padding is not None: img = F.pad(img, self.padding, self.fill, self.padding_mode) - width, height = F.get_image_size(img) + _, height, width = F.get_dimensions(img) # pad the width if needed if self.pad_if_needed and width < self.size[1]: padding = [self.size[1] - width, 0] @@ -793,14 +793,14 @@ def forward(self, img): """ fill = self.fill + channels, height, width = F.get_dimensions(img) if isinstance(img, Tensor): if isinstance(fill, (int, float)): - fill = [float(fill)] * F.get_image_num_channels(img) + fill = [float(fill)] * channels else: fill = [float(f) for f in fill] if torch.rand(1) < self.p: - width, height = F.get_image_size(img) startpoints, endpoints = self.get_params(width, height, self.distortion_scale) return F.perspective(img, startpoints, endpoints, self.interpolation, fill) return img @@ -910,7 +910,7 @@ def get_params(img: Tensor, scale: List[float], ratio: List[float]) -> Tuple[int tuple: params (i, j, h, w) to be passed to ``crop`` for a random sized crop. """ - width, height = F.get_image_size(img) + _, height, width = F.get_dimensions(img) area = height * width log_ratio = torch.log(torch.tensor(ratio)) @@ -1339,9 +1339,10 @@ def forward(self, img): PIL Image or Tensor: Rotated image. """ fill = self.fill + channels, _, _ = F.get_dimensions(img) if isinstance(img, Tensor): if isinstance(fill, (int, float)): - fill = [float(fill)] * F.get_image_num_channels(img) + fill = [float(fill)] * channels else: fill = [float(f) for f in fill] angle = self.get_params(self.degrees) @@ -1519,13 +1520,14 @@ def forward(self, img): PIL Image or Tensor: Affine transformed image. """ fill = self.fill + channels, height, width = F.get_dimensions(img) if isinstance(img, Tensor): if isinstance(fill, (int, float)): - fill = [float(fill)] * F.get_image_num_channels(img) + fill = [float(fill)] * channels else: fill = [float(f) for f in fill] - img_size = F.get_image_size(img) + img_size = [width, height] # flip for keeping BC on get_params call ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size) @@ -1608,7 +1610,7 @@ def forward(self, img): Returns: PIL Image or Tensor: Randomly grayscaled image. """ - num_output_channels = F.get_image_num_channels(img) + num_output_channels, _, _ = F.get_dimensions(img) if torch.rand(1) < self.p: return F.rgb_to_grayscale(img, num_output_channels=num_output_channels) return img