diff --git a/test/prototype_transforms_dispatcher_infos.py b/test/prototype_transforms_dispatcher_infos.py
index e570e4355c5..d817e4a71fb 100644
--- a/test/prototype_transforms_dispatcher_infos.py
+++ b/test/prototype_transforms_dispatcher_infos.py
@@ -234,7 +234,6 @@ def fill_sequence_needs_broadcast(args_kwargs):
                 condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs)
                 and args_kwargs.kwargs.get("padding_mode", "constant") == "constant",
             ),
-            xfail_jit_python_scalar_arg("padding"),
             xfail_jit_tuple_instead_of_list("padding"),
             xfail_jit_tuple_instead_of_list("fill"),
             # TODO: check if this is a regression since it seems that should be supported if `int` is ok
diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py
index 008887539a5..a106aea65ba 100644
--- a/test/prototype_transforms_kernel_infos.py
+++ b/test/prototype_transforms_kernel_infos.py
@@ -1146,7 +1146,6 @@ def reference_inputs_pad_bounding_box():
         reference_inputs_fn=reference_inputs_pad_image_tensor,
         closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
         test_marks=[
-            xfail_jit_python_scalar_arg("padding"),
             xfail_jit_tuple_instead_of_list("padding"),
             xfail_jit_tuple_instead_of_list("fill"),
             # TODO: check if this is a regression since it seems that should be supported if `int` is ok
@@ -1159,7 +1158,6 @@ def reference_inputs_pad_bounding_box():
         reference_fn=reference_pad_bounding_box,
         reference_inputs_fn=reference_inputs_pad_bounding_box,
         test_marks=[
-            xfail_jit_python_scalar_arg("padding"),
             xfail_jit_tuple_instead_of_list("padding"),
         ],
     ),
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 40fa904ade2..f2a12d6f609 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -4,7 +4,8 @@
 
 import PIL.Image
 import torch
-from torch.nn.functional import interpolate
+from torch.nn.functional import interpolate, pad as torch_pad
+
 from torchvision.prototype import features
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
 from torchvision.transforms.functional import (
@@ -15,7 +16,6 @@
     pil_to_tensor,
     to_pil_image,
 )
-from torchvision.transforms.functional_tensor import _parse_pad_padding
 
 from ._meta import convert_format_bounding_box, get_spatial_size_image_pil
 
@@ -663,7 +663,28 @@ def rotate(
     return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
 
 
-pad_image_pil = _FP.pad
+def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
+    if isinstance(padding, int):
+        pad_left = pad_right = pad_top = pad_bottom = padding
+    elif isinstance(padding, (tuple, list)):
+        if len(padding) == 1:
+            pad_left = pad_right = pad_top = pad_bottom = padding[0]
+        elif len(padding) == 2:
+            pad_left = pad_right = padding[0]
+            pad_top = pad_bottom = padding[1]
+        elif len(padding) == 4:
+            pad_left = padding[0]
+            pad_top = padding[1]
+            pad_right = padding[2]
+            pad_bottom = padding[3]
+        else:
+            raise ValueError(
+                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
+            )
+    else:
+        raise TypeError(f"`padding` should be an integer or tuple or list of integers, but got {padding}")
+
+    return [pad_left, pad_right, pad_top, pad_bottom]
 
 
 def pad_image_tensor(
@@ -672,50 +693,86 @@ def pad_image_tensor(
     fill: features.FillTypeJIT = None,
     padding_mode: str = "constant",
 ) -> torch.Tensor:
+    # Be aware that while `padding` has order `[left, top, right, bottom]`, `torch_padding` uses
+    # `[left, right, top, bottom]`. This stems from the fact that we align our API with PIL, but need to use
+    # `torch_pad` internally.
+    torch_padding = _parse_pad_padding(padding)
+
+    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
+        raise ValueError(
+            f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "
+            f"but got `'{padding_mode}'`."
+        )
+
     if fill is None:
-        # This is a JIT workaround
-        return _pad_with_scalar_fill(image, padding, fill=None, padding_mode=padding_mode)
-    elif isinstance(fill, (int, float)) or len(fill) == 1:
-        fill_number = fill[0] if isinstance(fill, list) else fill
-        return _pad_with_scalar_fill(image, padding, fill=fill_number, padding_mode=padding_mode)
+        fill = 0
+
+    if isinstance(fill, (int, float)):
+        return _pad_with_scalar_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
+    elif len(fill) == 1:
+        return _pad_with_scalar_fill(image, torch_padding, fill=fill[0], padding_mode=padding_mode)
     else:
-        return _pad_with_vector_fill(image, padding, fill=fill, padding_mode=padding_mode)
+        return _pad_with_vector_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
 
 
 def _pad_with_scalar_fill(
     image: torch.Tensor,
-    padding: Union[int, List[int]],
-    fill: Union[int, float, None],
-    padding_mode: str = "constant",
+    torch_padding: List[int],
+    fill: Union[int, float],
+    padding_mode: str,
 ) -> torch.Tensor:
     shape = image.shape
     num_channels, height, width = shape[-3:]
 
     if image.numel() > 0:
-        image = _FT.pad(
-            img=image.reshape(-1, num_channels, height, width), padding=padding, fill=fill, padding_mode=padding_mode
-        )
+        image = image.reshape(-1, num_channels, height, width)
+
+        if padding_mode == "edge":
+            # Similar to the padding order, `torch_pad`'s and PIL's padding modes don't have the same names. Thus,
+            # we map the PIL name for the padding mode, which we are also using for our API, to the corresponding
+            # `torch_pad` name.
+            padding_mode = "replicate"
+
+        if padding_mode == "constant":
+            image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
+        elif padding_mode in ("reflect", "replicate"):
+            # `torch_pad` only supports `"reflect"` and `"replicate"` padding for floating point inputs.
+            # TODO: See https://github.com/pytorch/pytorch/issues/40763
+            dtype = image.dtype
+            if not image.is_floating_point():
+                needs_cast = True
+                image = image.to(torch.float32)
+            else:
+                needs_cast = False
+
+            image = torch_pad(image, torch_padding, mode=padding_mode)
+
+            if needs_cast:
+                image = image.to(dtype)
+        else:  # padding_mode == "symmetric"
+            image = _FT._pad_symmetric(image, torch_padding)
+
         new_height, new_width = image.shape[-2:]
     else:
-        left, right, top, bottom = _FT._parse_pad_padding(padding)
+        left, right, top, bottom = torch_padding
         new_height = height + top + bottom
         new_width = width + left + right
 
     return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
 
 
-# TODO: This should be removed once pytorch pad supports non-scalar padding values
+# TODO: This should be removed once torch_pad supports non-scalar padding values
 def _pad_with_vector_fill(
     image: torch.Tensor,
-    padding: Union[int, List[int]],
+    torch_padding: List[int],
     fill: List[float],
-    padding_mode: str = "constant",
+    padding_mode: str,
 ) -> torch.Tensor:
     if padding_mode != "constant":
         raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")
 
-    output = _pad_with_scalar_fill(image, padding, fill=0, padding_mode="constant")
-    left, right, top, bottom = _parse_pad_padding(padding)
+    output = _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
+    left, right, top, bottom = torch_padding
     fill = torch.tensor(fill, dtype=image.dtype, device=image.device).reshape(-1, 1, 1)
 
     if top > 0:
@@ -729,6 +786,9 @@ def _pad_with_vector_fill(
     return output
 
 
+pad_image_pil = _FP.pad
+
+
 def pad_mask(
     mask: torch.Tensor,
     padding: Union[int, List[int]],
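
For context, a minimal usage sketch of the reworked kernel (not part of the patch; the input shape and values below are illustrative assumptions, and the import assumes the prototype namespace `torchvision.prototype.transforms.functional` is available): `padding` is accepted in the PIL order `[left, top, right, bottom]`, converted by `_parse_pad_padding` to the `[left, right, top, bottom]` order that `torch.nn.functional.pad` expects, and integer images are temporarily cast to `float32` for the `"edge"`/`"reflect"` modes.

import torch
from torchvision.prototype.transforms import functional as F

# Illustrative uint8 CHW image; the shape is an arbitrary choice for this sketch.
image = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)

# `padding` follows the PIL convention [left, top, right, bottom].
padded = F.pad_image_tensor(image, padding=[1, 2, 3, 4], fill=0, padding_mode="constant")
print(padded.shape)  # torch.Size([3, 10, 8]): height 4 + 2 + 4, width 4 + 1 + 3

# "edge" maps to torch_pad's "replicate" internally; the uint8 input is cast to
# float32 around the call and cast back, so the output dtype is preserved.
padded = F.pad_image_tensor(image, padding=2, padding_mode="edge")
print(padded.dtype)  # torch.uint8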