Replace get_image_size/num_channels with get_dimensions #5487

Merged: 10 commits, Feb 28, 2022
1 change: 1 addition & 0 deletions docs/source/transforms.rst
@@ -270,6 +270,7 @@ you can use a functional transform to build transform classes with custom behavior
erase
five_crop
gaussian_blur
get_dimensions
get_image_num_channels
get_image_size
hflip
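For reference, a minimal sketch (not part of this diff; the tensor shape is made up) of how the new accessor relates to the two helpers it replaces: `get_image_size` returns `(width, height)` and `get_image_num_channels` returns the channel count, while `get_dimensions` returns `(channels, height, width)` in a single call.

```python
import torch
from torchvision.transforms import functional as F

img = torch.rand(3, 224, 320)  # example CHW tensor (the helpers also accept PIL images)

# Older helpers: spatial size as (width, height), channels via a second call.
width, height = F.get_image_size(img)         # 320, 224
num_channels = F.get_image_num_channels(img)  # 3

# Unified helper: one call returning (channels, height, width).
channels, h, w = F.get_dimensions(img)        # 3, 224, 320
assert (channels, h, w) == (num_channels, height, width)
```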
2 changes: 1 addition & 1 deletion references/classification/transforms.py
@@ -141,7 +141,7 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:

# Implemented as on cutmix paper, page 12 (with minor corrections on typos).
lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
W, H = F.get_image_size(batch)
_, H, W = F.get_dimensions(batch)

r_x = torch.randint(W, (1,))
r_y = torch.randint(H, (1,))
10 changes: 5 additions & 5 deletions references/detection/transforms.py
@@ -34,7 +34,7 @@ def forward(
if torch.rand(1) < self.p:
image = F.hflip(image)
if target is not None:
width, _ = F.get_image_size(image)
_, _, width = F.get_dimensions(image)
target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
if "masks" in target:
target["masks"] = target["masks"].flip(-1)
@@ -107,7 +107,7 @@ def forward(
elif image.ndimension() == 2:
image = image.unsqueeze(0)

orig_w, orig_h = F.get_image_size(image)
_, orig_h, orig_w = F.get_dimensions(image)

while True:
# sample an option
@@ -192,7 +192,7 @@ def forward(
if torch.rand(1) >= self.p:
return image, target

orig_w, orig_h = F.get_image_size(image)
_, orig_h, orig_w = F.get_dimensions(image)

r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
canvas_width = int(orig_w * r)
@@ -270,7 +270,7 @@ def forward(
image = self._contrast(image)

if r[6] < self.p:
channels = F.get_image_num_channels(image)
channels, _, _ = F.get_dimensions(image)
permutation = torch.randperm(channels)

is_pil = F._is_pil_image(image)
@@ -317,7 +317,7 @@ def forward(
elif image.ndimension() == 2:
image = image.unsqueeze(0)

orig_width, orig_height = F.get_image_size(image)
_, orig_height, orig_width = F.get_dimensions(image)

r = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
new_width = int(self.target_size[1] * r)
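The detection transforms above need the spatial size, the channel count, or both; with the unified accessor a single call covers all three cases. A hedged sketch follows (dummy shapes and ratio, not taken from the PR):

```python
import torch
from torchvision.transforms import functional as F

image = torch.rand(3, 480, 640)                # dummy CHW detection image
channels, orig_h, orig_w = F.get_dimensions(image)

# Horizontal-flip box update uses only the width, as in RandomHorizontalFlip above.
boxes = torch.tensor([[10.0, 20.0, 110.0, 220.0]])
boxes[:, [0, 2]] = orig_w - boxes[:, [2, 0]]

# RandomZoomOut-style canvas, with an illustrative ratio r = 1.5.
r = 1.5
canvas_width, canvas_height = int(orig_w * r), int(orig_h * r)

# RandomPhotometricDistort-style channel shuffle uses only the channel count.
permutation = torch.randperm(channels)
```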
4 changes: 3 additions & 1 deletion test/test_functional_tensor.py
@@ -29,7 +29,7 @@


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("fn", [F.get_image_size, F.get_image_num_channels])
@pytest.mark.parametrize("fn", [F.get_image_size, F.get_image_num_channels, F.get_dimensions])
def test_image_sizes(device, fn):
script_F = torch.jit.script(fn)

@@ -1020,7 +1020,9 @@ def test_resized_crop(device, mode):
@pytest.mark.parametrize(
"func, args",
[
(F_t.get_dimensions, ()),
(F_t.get_image_size, ()),
(F_t.get_image_num_channels, ()),
(F_t.vflip, ()),
(F_t.hflip, ()),
(F_t.crop, (1, 2, 4, 5)),
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/__init__.py
@@ -8,7 +8,7 @@
from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop
from ._meta_conversion import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
from ._misc import Identity, Normalize, ToDtype, Lambda
from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
from ._type_conversion import DecodeImage, LabelToOneHot
7 changes: 3 additions & 4 deletions torchvision/prototype/transforms/_augment.py
@@ -7,7 +7,7 @@
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F

from ._utils import query_image
from ._utils import query_image, get_image_dimensions


class RandomErasing(Transform):
@@ -41,8 +41,7 @@ def __init__(

def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
img_c = F.get_image_num_channels(image)
img_w, img_h = F.get_image_size(image)
img_c, img_h, img_w = get_image_dimensions(image)

if isinstance(self.value, (int, float)):
value = [self.value]
@@ -138,7 +137,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
lam = float(self._dist.sample(()))

image = query_image(sample)
W, H = F.get_image_size(image)
_, H, W = get_image_dimensions(image)

r_x = torch.randint(W, ())
r_y = torch.randint(H, ())
24 changes: 12 additions & 12 deletions torchvision/prototype/transforms/_auto_augment.py
@@ -7,7 +7,7 @@
from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, functional as F
from torchvision.prototype.utils._internal import apply_recursively

from ._utils import query_image
from ._utils import query_image, get_image_dimensions

K = TypeVar("K")
V = TypeVar("V")
@@ -47,7 +47,7 @@ def dispatch(
return input

image = query_image(sample)
num_channels = F.get_image_num_channels(image)
num_channels, *_ = get_image_dimensions(image)

fill = self.fill
if isinstance(fill, (int, float)):
@@ -160,8 +160,8 @@ class AutoAugment(_AutoAugmentBase):
_AUGMENTATION_SPACE = {
"ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
"Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
"Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
@@ -278,7 +278,7 @@ def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]

image = query_image(sample)
image_size = F.get_image_size(image)
_, height, width = get_image_dimensions(image)

policy = self._policies[int(torch.randint(len(self._policies), ()))]

@@ -288,7 +288,7 @@

magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]

magnitudes = magnitudes_fn(10, image_size)
magnitudes = magnitudes_fn(10, (height, width))
if magnitudes is not None:
magnitude = float(magnitudes[magnitude_idx])
if signed and torch.rand(()) <= 0.5:
@@ -306,8 +306,8 @@ class RandAugment(_AutoAugmentBase):
"Identity": (lambda num_bins, image_size: None, False),
"ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
"Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
"Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
"Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
@@ -334,12 +334,12 @@ def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]

image = query_image(sample)
image_size = F.get_image_size(image)
_, height, width = get_image_dimensions(image)

for _ in range(self.num_ops):
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

magnitudes = magnitudes_fn(self.num_magnitude_bins, image_size)
magnitudes = magnitudes_fn(self.num_magnitude_bins, (height, width))
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
if signed and torch.rand(()) <= 0.5:
@@ -383,11 +383,11 @@ def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]

image = query_image(sample)
image_size = F.get_image_size(image)
_, height, width = get_image_dimensions(image)

transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

magnitudes = magnitudes_fn(self.num_magnitude_bins, image_size)
magnitudes = magnitudes_fn(self.num_magnitude_bins, (height, width))
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
if signed and torch.rand(()) <= 0.5:
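Note that the TranslateX/TranslateY lambdas now index the size tuple the other way around: the tuple passed to the magnitude functions is (height, width) from get_image_dimensions rather than the (width, height) pair that get_image_size returned, so the horizontal range still scales with the image width. A small sketch with made-up numbers:

```python
import torch

num_bins = 31             # illustrative bin count
height, width = 224, 320  # image_size is now (height, width)

# image_size[1] is the width and image_size[0] is the height, matching the diff above.
translate_x = torch.linspace(0.0, 150.0 / 331.0 * width, num_bins)
translate_y = torch.linspace(0.0, 150.0 / 331.0 * height, num_bins)
```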
4 changes: 2 additions & 2 deletions torchvision/prototype/transforms/_geometry.py
@@ -8,7 +8,7 @@
from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int

from ._utils import query_image
from ._utils import query_image, get_image_dimensions


class HorizontalFlip(Transform):
@@ -109,7 +109,7 @@ def __init__(

def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
width, height = F.get_image_size(image)
_, height, width = get_image_dimensions(image)
area = height * width

log_ratio = torch.log(torch.tensor(self.ratio))
17 changes: 16 additions & 1 deletion torchvision/prototype/transforms/_utils.py
@@ -1,10 +1,12 @@
from typing import Any, Optional, Union
from typing import Any, Optional, Tuple, Union

import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.utils._internal import query_recursively

from .functional._meta import get_dimensions_image_tensor, get_dimensions_image_pil


def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
def fn(input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Image]]:
@@ -17,3 +19,16 @@ def fn(input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Image]]:
return next(query_recursively(fn, sample))
except StopIteration:
raise TypeError("No image was found in the sample")


def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
if isinstance(image, features.Image):
channels = image.num_channels
height, width = image.image_size
elif isinstance(image, torch.Tensor):
channels, height, width = get_dimensions_image_tensor(image)
elif isinstance(image, PIL.Image.Image):
channels, height, width = get_dimensions_image_pil(image)
else:
raise TypeError(f"unable to get image dimensions from object of type {type(image).__name__}")
return channels, height, width
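
A short usage sketch for the new helper, assuming the prototype module is importable at this path; it returns the same (channels, height, width) triple for tensors, PIL images, and features.Image alike:

```python
import PIL.Image
import torch
from torchvision.prototype.transforms._utils import get_image_dimensions

tensor_image = torch.rand(3, 64, 96)
pil_image = PIL.Image.new("RGB", (96, 64))  # PIL sizes are given as (width, height)

assert get_image_dimensions(tensor_image) == (3, 64, 96)
assert get_image_dimensions(pil_image) == (3, 64, 96)
```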
3 changes: 1 addition & 2 deletions torchvision/prototype/transforms/functional/__init__.py
@@ -1,6 +1,5 @@
from torchvision.transforms import InterpolationMode # usort: skip
from ._utils import get_image_size, get_image_num_channels # usort: skip
from ._meta_conversion import (
from ._meta import (
convert_bounding_box_format,
convert_image_color_space_tensor,
convert_image_color_space_pil,
24 changes: 11 additions & 13 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -5,11 +5,10 @@
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import InterpolationMode
from torchvision.prototype.transforms.functional import get_image_size
from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP
from torchvision.transforms.functional import pil_modes_mapping, _get_inverse_affine_matrix

from ._meta_conversion import convert_bounding_box_format
from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil


horizontal_flip_image_tensor = _FT.hflip
@@ -40,8 +39,7 @@ def resize_image_tensor(
antialias: Optional[bool] = None,
) -> torch.Tensor:
new_height, new_width = size
old_width, old_height = _FT.get_image_size(image)
num_channels = _FT.get_image_num_channels(image)
num_channels, old_height, old_width = get_dimensions_image_tensor(image)
batch_shape = image.shape[:-3]
return _FT.resize(
image.reshape((-1, num_channels, old_height, old_width)),
Expand Down Expand Up @@ -143,9 +141,9 @@ def affine_image_tensor(

center_f = [0.0, 0.0]
if center is not None:
width, height = get_image_size(img)
_, height, width = get_dimensions_image_tensor(img)
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))]
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]

translate_f = [1.0 * t for t in translate]
matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
@@ -169,7 +167,7 @@ def affine_image_pil(
# it is visually better to estimate the center without 0.5 offset
# otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
if center is None:
width, height = get_image_size(img)
_, height, width = get_dimensions_image_pil(img)
center = [width * 0.5, height * 0.5]
matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)

@@ -186,9 +184,9 @@ def rotate_image_tensor(
) -> torch.Tensor:
center_f = [0.0, 0.0]
if center is not None:
width, height = get_image_size(img)
_, height, width = get_dimensions_image_tensor(img)
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))]
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]

# due to current incoherence of rotation angle direction between affine and rotate implementations
# we need to set -angle.
@@ -262,13 +260,13 @@ def _center_crop_compute_crop_anchor(

def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch.Tensor:
crop_height, crop_width = _center_crop_parse_output_size(output_size)
image_width, image_height = get_image_size(img)
_, image_height, image_width = get_dimensions_image_tensor(img)

if crop_height > image_height or crop_width > image_width:
padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
img = pad_image_tensor(img, padding_ltrb, fill=0)

image_width, image_height = get_image_size(img)
_, image_height, image_width = get_dimensions_image_tensor(img)
if crop_width == image_width and crop_height == image_height:
return img

@@ -278,13 +276,13 @@ def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> torch

def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
crop_height, crop_width = _center_crop_parse_output_size(output_size)
image_width, image_height = get_image_size(img)
_, image_height, image_width = get_dimensions_image_pil(img)

if crop_height > image_height or crop_width > image_width:
padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
img = pad_image_pil(img, padding_ltrb, fill=0)

image_width, image_height = get_image_size(img)
_, image_height, image_width = get_dimensions_image_pil(img)
if crop_width == image_width and crop_height == image_height:
return img

torchvision/prototype/transforms/functional/_meta.py (renamed from _meta_conversion.py)
@@ -4,6 +4,10 @@
from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP


get_dimensions_image_tensor = _FT.get_dimensions
get_dimensions_image_pil = _FP.get_dimensions


def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor:
xyxy = xywh.clone()
xyxy[..., 2:] += xyxy[..., :2]
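The two aliases above simply re-export the stable kernels, so calling them is equivalent to calling _FT.get_dimensions / _FP.get_dimensions directly; a minimal sketch for the tensor kernel (shape is illustrative):

```python
import torch
from torchvision.transforms import functional_tensor as _FT

img = torch.rand(1, 28, 28)
assert _FT.get_dimensions(img) == [1, 28, 28]  # [channels, height, width]
```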
29 changes: 0 additions & 29 deletions torchvision/prototype/transforms/functional/_utils.py

This file was deleted.
