-
Notifications
You must be signed in to change notification settings - Fork 7.1k
remove unnecessary checks from pad_image_tensor #6894
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
968f37b
remove unnecessary changes from pad_image_tensor
pmeier 3f841ca
cleanup
pmeier 1622cd6
fix fill=None workaround
pmeier 75f4ba1
Merge branch 'main' into perf/pad
pmeier 7a238ab
Merge branch 'main' into perf/pad
pmeier 84146c6
address review comments
pmeier 44e475e
Merge branch 'main' into perf/pad
pmeier 015ce01
remove more xfails
pmeier File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,8 @@ | |
|
||
import PIL.Image | ||
import torch | ||
from torch.nn.functional import interpolate | ||
from torch.nn.functional import interpolate, pad as torch_pad | ||
|
||
from torchvision.prototype import features | ||
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT | ||
from torchvision.transforms.functional import ( | ||
|
@@ -15,7 +16,6 @@ | |
pil_to_tensor, | ||
to_pil_image, | ||
) | ||
from torchvision.transforms.functional_tensor import _parse_pad_padding | ||
|
||
from ._meta import convert_format_bounding_box, get_spatial_size_image_pil | ||
|
||
|
@@ -663,7 +663,28 @@ def rotate( | |
return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) | ||
|
||
|
||
pad_image_pil = _FP.pad | ||
def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]: | ||
if isinstance(padding, int): | ||
pad_left = pad_right = pad_top = pad_bottom = padding | ||
elif isinstance(padding, (tuple, list)): | ||
if len(padding) == 1: | ||
pad_left = pad_right = pad_top = pad_bottom = padding[0] | ||
elif len(padding) == 2: | ||
pad_left = pad_right = padding[0] | ||
pad_top = pad_bottom = padding[1] | ||
elif len(padding) == 4: | ||
pad_left = padding[0] | ||
pad_top = padding[1] | ||
pad_right = padding[2] | ||
pad_bottom = padding[3] | ||
else: | ||
raise ValueError( | ||
f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple" | ||
) | ||
else: | ||
raise TypeError(f"`padding` should be an integer or tuple or list of integers, but got {padding}") | ||
|
||
return [pad_left, pad_right, pad_top, pad_bottom] | ||
|
||
|
||
def pad_image_tensor( | ||
|
@@ -672,50 +693,86 @@ def pad_image_tensor( | |
fill: features.FillTypeJIT = None, | ||
padding_mode: str = "constant", | ||
) -> torch.Tensor: | ||
# Be aware that while `padding` has order `[left, top, right, bottom]`, `torch_padding` uses | |
# `[left, right, top, bottom]`. This stems from the fact that we align our API with PIL, but need to use `torch_pad` | ||
# internally. | ||
torch_padding = _parse_pad_padding(padding) | ||
|
||
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: | ||
raise ValueError( | ||
f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, " | ||
f"but got `'{padding_mode}'`." | ||
) | ||
|
||
if fill is None: | ||
# This is a JIT workaround | ||
return _pad_with_scalar_fill(image, padding, fill=None, padding_mode=padding_mode) | ||
elif isinstance(fill, (int, float)) or len(fill) == 1: | ||
fill_number = fill[0] if isinstance(fill, list) else fill | ||
return _pad_with_scalar_fill(image, padding, fill=fill_number, padding_mode=padding_mode) | ||
fill = 0 | ||
|
||
if isinstance(fill, (int, float)): | ||
return _pad_with_scalar_fill(image, torch_padding, fill=fill, padding_mode=padding_mode) | ||
elif len(fill) == 1: | ||
return _pad_with_scalar_fill(image, torch_padding, fill=fill[0], padding_mode=padding_mode) | ||
else: | ||
return _pad_with_vector_fill(image, padding, fill=fill, padding_mode=padding_mode) | ||
return _pad_with_vector_fill(image, torch_padding, fill=fill, padding_mode=padding_mode) | ||
|
||
|
||
def _pad_with_scalar_fill( | ||
image: torch.Tensor, | ||
padding: Union[int, List[int]], | ||
fill: Union[int, float, None], | ||
padding_mode: str = "constant", | ||
torch_padding: List[int], | ||
fill: Union[int, float], | ||
padding_mode: str, | ||
) -> torch.Tensor: | ||
shape = image.shape | ||
num_channels, height, width = shape[-3:] | ||
|
||
if image.numel() > 0: | ||
image = _FT.pad( | ||
img=image.reshape(-1, num_channels, height, width), padding=padding, fill=fill, padding_mode=padding_mode | ||
) | ||
image = image.reshape(-1, num_channels, height, width) | ||
|
||
if padding_mode == "edge": | ||
# Similar to the padding order, `torch_pad`'s and PIL's padding modes don't have the same names. Thus, we map | |
# the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad` | ||
# name. | ||
padding_mode = "replicate" | ||
|
||
if padding_mode == "constant": | ||
image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill)) | ||
elif padding_mode in ("reflect", "replicate"): | ||
# `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs. | ||
# TODO: See https://github.com/pytorch/pytorch/issues/40763 | ||
dtype = image.dtype | ||
if not image.is_floating_point(): | ||
needs_cast = True | ||
image = image.to(torch.float32) | ||
else: | ||
needs_cast = False | ||
|
||
image = torch_pad(image, torch_padding, mode=padding_mode) | ||
|
||
if needs_cast: | ||
image = image.to(dtype) | ||
else: # padding_mode == "symmetric" | ||
image = _FT._pad_symmetric(image, torch_padding) | ||
|
||
new_height, new_width = image.shape[-2:] | ||
else: | ||
left, right, top, bottom = _FT._parse_pad_padding(padding) | ||
left, right, top, bottom = torch_padding | ||
new_height = height + top + bottom | ||
new_width = width + left + right | ||
|
||
return image.reshape(shape[:-3] + (num_channels, new_height, new_width)) | ||
|
||
|
||
# TODO: This should be removed once pytorch pad supports non-scalar padding values | ||
# TODO: This should be removed once torch_pad supports non-scalar padding values | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @vfdev-5 Is there an issue for that? |
||
def _pad_with_vector_fill( | ||
image: torch.Tensor, | ||
padding: Union[int, List[int]], | ||
torch_padding: List[int], | ||
fill: List[float], | ||
padding_mode: str = "constant", | ||
padding_mode: str, | ||
) -> torch.Tensor: | ||
if padding_mode != "constant": | ||
raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar") | ||
|
||
output = _pad_with_scalar_fill(image, padding, fill=0, padding_mode="constant") | ||
left, right, top, bottom = _parse_pad_padding(padding) | ||
output = _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant") | ||
left, right, top, bottom = torch_padding | ||
fill = torch.tensor(fill, dtype=image.dtype, device=image.device).reshape(-1, 1, 1) | ||
|
||
if top > 0: | ||
|
@@ -729,6 +786,9 @@ def _pad_with_vector_fill( | |
return output | ||
|
||
|
||
pad_image_pil = _FP.pad | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Minor cleanup, since we normally define the PIL kernel below the tensor one. |
||
|
||
|
||
def pad_mask( | ||
mask: torch.Tensor, | ||
padding: Union[int, List[int]], | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've merged the check for invalid types as well as wrong lengths here, since this function is also used by `pad_bounding_box` and that currently doesn't have these checks.