Skip to content

remove unnecessary checks from pad_image_tensor #6894

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Nov 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion test/prototype_transforms_dispatcher_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,6 @@ def fill_sequence_needs_broadcast(args_kwargs):
condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs)
and args_kwargs.kwargs.get("padding_mode", "constant") == "constant",
),
xfail_jit_python_scalar_arg("padding"),
xfail_jit_tuple_instead_of_list("padding"),
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
Expand Down
2 changes: 0 additions & 2 deletions test/prototype_transforms_kernel_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,7 +1146,6 @@ def reference_inputs_pad_bounding_box():
reference_inputs_fn=reference_inputs_pad_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
test_marks=[
xfail_jit_python_scalar_arg("padding"),
xfail_jit_tuple_instead_of_list("padding"),
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
Expand All @@ -1159,7 +1158,6 @@ def reference_inputs_pad_bounding_box():
reference_fn=reference_pad_bounding_box,
reference_inputs_fn=reference_inputs_pad_bounding_box,
test_marks=[
xfail_jit_python_scalar_arg("padding"),
xfail_jit_tuple_instead_of_list("padding"),
],
),
Expand Down
102 changes: 81 additions & 21 deletions torchvision/prototype/transforms/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import PIL.Image
import torch
from torch.nn.functional import interpolate
from torch.nn.functional import interpolate, pad as torch_pad

from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
from torchvision.transforms.functional import (
Expand All @@ -15,7 +16,6 @@
pil_to_tensor,
to_pil_image,
)
from torchvision.transforms.functional_tensor import _parse_pad_padding

from ._meta import convert_format_bounding_box, get_spatial_size_image_pil

Expand Down Expand Up @@ -663,7 +663,28 @@ def rotate(
return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)


pad_image_pil = _FP.pad
def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've merged the check for invalid types as well as wrong lengths here, since this function is also used by pad_bounding_box and that currently doesn't have these checks.

if isinstance(padding, int):
pad_left = pad_right = pad_top = pad_bottom = padding
elif isinstance(padding, (tuple, list)):
if len(padding) == 1:
pad_left = pad_right = pad_top = pad_bottom = padding[0]
elif len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
elif len(padding) == 4:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]
else:
raise ValueError(
f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
)
else:
raise TypeError(f"`padding` should be an integer or tuple or list of integers, but got {padding}")

return [pad_left, pad_right, pad_top, pad_bottom]


def pad_image_tensor(
Expand All @@ -672,50 +693,86 @@ def pad_image_tensor(
fill: features.FillTypeJIT = None,
padding_mode: str = "constant",
) -> torch.Tensor:
# Be aware that while `padding` has order `[left, top, right, bottom]` has order, `torch_padding` uses
# `[left, right, top, bottom]`. This stems from the fact that we align our API with PIL, but need to use `torch_pad`
# internally.
torch_padding = _parse_pad_padding(padding)

if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError(
f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "
f"but got `'{padding_mode}'`."
)

if fill is None:
# This is a JIT workaround
return _pad_with_scalar_fill(image, padding, fill=None, padding_mode=padding_mode)
elif isinstance(fill, (int, float)) or len(fill) == 1:
fill_number = fill[0] if isinstance(fill, list) else fill
return _pad_with_scalar_fill(image, padding, fill=fill_number, padding_mode=padding_mode)
fill = 0

if isinstance(fill, (int, float)):
return _pad_with_scalar_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
elif len(fill) == 1:
return _pad_with_scalar_fill(image, torch_padding, fill=fill[0], padding_mode=padding_mode)
else:
return _pad_with_vector_fill(image, padding, fill=fill, padding_mode=padding_mode)
return _pad_with_vector_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)


def _pad_with_scalar_fill(
image: torch.Tensor,
padding: Union[int, List[int]],
fill: Union[int, float, None],
padding_mode: str = "constant",
torch_padding: List[int],
fill: Union[int, float],
padding_mode: str,
) -> torch.Tensor:
shape = image.shape
num_channels, height, width = shape[-3:]

if image.numel() > 0:
image = _FT.pad(
img=image.reshape(-1, num_channels, height, width), padding=padding, fill=fill, padding_mode=padding_mode
)
image = image.reshape(-1, num_channels, height, width)

if padding_mode == "edge":
# Similar to the padding order, `torch_pad`'s PIL's padding modes don't have the same names. Thus, we map
# the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
# name.
padding_mode = "replicate"

if padding_mode == "constant":
image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
elif padding_mode in ("reflect", "replicate"):
# `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
# TODO: See https://github.com/pytorch/pytorch/issues/40763
dtype = image.dtype
if not image.is_floating_point():
needs_cast = True
image = image.to(torch.float32)
else:
needs_cast = False

image = torch_pad(image, torch_padding, mode=padding_mode)

if needs_cast:
image = image.to(dtype)
else: # padding_mode == "symmetric"
image = _FT._pad_symmetric(image, torch_padding)

new_height, new_width = image.shape[-2:]
else:
left, right, top, bottom = _FT._parse_pad_padding(padding)
left, right, top, bottom = torch_padding
new_height = height + top + bottom
new_width = width + left + right

return image.reshape(shape[:-3] + (num_channels, new_height, new_width))


# TODO: This should be removed once pytorch pad supports non-scalar padding values
# TODO: This should be removed once torch_pad supports non-scalar padding values
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vfdev-5 Is there an issue for that?

def _pad_with_vector_fill(
image: torch.Tensor,
padding: Union[int, List[int]],
torch_padding: List[int],
fill: List[float],
padding_mode: str = "constant",
padding_mode: str,
) -> torch.Tensor:
if padding_mode != "constant":
raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")

output = _pad_with_scalar_fill(image, padding, fill=0, padding_mode="constant")
left, right, top, bottom = _parse_pad_padding(padding)
output = _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
left, right, top, bottom = torch_padding
fill = torch.tensor(fill, dtype=image.dtype, device=image.device).reshape(-1, 1, 1)

if top > 0:
Expand All @@ -729,6 +786,9 @@ def _pad_with_vector_fill(
return output


pad_image_pil = _FP.pad
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor cleanup, since we normally define the PIL kernel below the tensor one.



def pad_mask(
mask: torch.Tensor,
padding: Union[int, List[int]],
Expand Down