add grid transform zero fill shortcut #8099

Draft
wants to merge 1 commit into
base: main
10 changes: 8 additions & 2 deletions torchvision/transforms/v2/functional/_geometry.py
@@ -570,6 +570,13 @@ def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill
         # Apply same grid to a batch of images
         grid = grid.expand(squashed_batch_size, -1, -1, -1)

+    if fill is not None and not isinstance(fill, (tuple, list)):
+        fill = [float(fill)]
+
+    # filling with zeros is the default behavior and thus we can skip the extra fill handling
+    if fill is not None and all(f == 0 for f in fill):
+        fill = None
Comment on lines +576 to +578
Member

Maybe it's because I should be sleeping already but I don't really understand the meaning of that comment, or what that code is really doing...

Is fill = None supposed to be equivalent to filling with zeros?

Collaborator Author

> Is fill = None supposed to be equivalent to filling with zeros?

Yes. IMO the parameter name "fill" is a misnomer. Filling will happen for most affine ops whether the user likes it or not, regardless of the value they pass for fill. This is just the nature of affine transformations: the transformed image may no longer cover the full rectangular canvas, so we need to fill the parts it leaves uncovered.

IMO, a better name would be "fill color". We had fillcolor initially, but #2479 settled on fill. IIUC, the only reason was that we had more ops that used fill than ops that used fillcolor 🤷

Coming back to your question: yes, fill=None, fill=0, and fill=[0,0,0] should in theory all do the same thing, namely fill with zeros. I don't know why we chose None as the default rather than fill=0. IMO that adds to the confusion by suggesting there is a case that doesn't fill.
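
To make the equivalence concrete, here is a minimal sketch of the normalization that the diff adds, pulled out into a standalone helper. The name `normalize_fill` and the `FillType` alias are hypothetical and not part of torchvision. The helper only mirrors the two `if` statements in the diff, plus a float conversion of list elements that the diff itself does not do.

```python
from typing import List, Optional, Sequence, Union

# Hypothetical alias for the values a caller might pass as `fill`.
FillType = Optional[Union[int, float, Sequence[Union[int, float]]]]


def normalize_fill(fill: FillType) -> Optional[List[float]]:
    """Hypothetical helper: map every all-zero fill to None."""
    if fill is None:
        return None
    # Wrap a scalar into a one-element list, as the diff does.
    if not isinstance(fill, (tuple, list)):
        fill = [float(fill)]
    # Zero filling is what the underlying op does anyway,
    # so signal "no extra fill handling needed".
    if all(f == 0 for f in fill):
        return None
    # Assumption for this sketch only: convert the elements to float.
    return [float(f) for f in fill]
```

With this helper, `None`, `0`, and `[0, 0, 0]` all collapse to `None`, while a non-zero value such as `[255, 0, 0]` is passed through as a list.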

Anyway, fill=None is a performance optimization. The PyTorch op that we use only supports zeros for constant filling (other padding modes are supported, but for constant filling with another value we have to apply the fill value ourselves). If we get any value besides fill=None, we perform some extra steps to apply the fill color the user wants. However, for fill=0 or fill=[0,0,0] this is unnecessary, since zeros are the default behavior of the PyTorch op anyway.
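
The extra steps can be sketched roughly as follows. This is a simplified illustration, not the torchvision implementation: `sample_with_fill` is a hypothetical name, it covers only `mode="nearest"`, and it uses `torch.where` instead of the in-place masked assignment in the diff. The idea is to append a channel of ones before calling `grid_sample`, so that after sampling the channel is zero exactly where the op padded.

```python
import torch
import torch.nn.functional as F


def sample_with_fill(img, grid, fill=None):
    """Hypothetical sketch. img: (N, C, H, W) float, grid: (N, H_out, W_out, 2)."""
    if fill is None:
        # Fast path: grid_sample already writes zeros outside the input.
        return F.grid_sample(img, grid, mode="nearest", padding_mode="zeros", align_corners=False)

    n, c, h, w = img.shape
    # Extra channel of ones; after sampling it is 0 wherever the op padded.
    mask = torch.ones((n, 1, h, w), dtype=img.dtype, device=img.device)
    out = F.grid_sample(
        torch.cat((img, mask), dim=1), grid, mode="nearest", padding_mode="zeros", align_corners=False
    )
    out, mask = out[:, :c], out[:, c:]
    fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, -1, 1, 1)
    # Replace padded pixels with the requested fill color.
    return torch.where(mask < 0.5, fill_img, out)


# Shift an all-ones image so that part of the output falls outside the input.
img = torch.ones(1, 3, 4, 4)
theta = torch.tensor([[[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]])
grid = F.affine_grid(theta, list(img.shape), align_corners=False)

zero_filled = sample_with_fill(img, grid)                       # fast path
explicit_zero = sample_with_fill(img, grid, fill=[0.0, 0.0, 0.0])  # mask path, same result
colored = sample_with_fill(img, grid, fill=[1.0, 0.5, 0.25])
```

In this sketch, the fast path and the mask path with an all-zero fill produce the same tensor, which is why the mask path can be skipped for zero fills.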

Right now there is a bug (?) in our affine ops that actually gives us different behavior when fill is not None. #8098 deals with that. If that is resolved, this PR will just be a performance shortcut. If we decide not to move forward with #8098, we should likely close this PR as well.


     # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
     if fill is not None:
         mask = torch.ones(
@@ -583,8 +590,7 @@ def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill
     if fill is not None:
         float_img, mask = torch.tensor_split(float_img, indices=(-1,), dim=-3)
         mask = mask.expand_as(float_img)
-        fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)]  # type: ignore[arg-type]
-        fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1)
+        fill_img = torch.tensor(fill, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1)
         if mode == "nearest":
             bool_mask = mask < 0.5
             float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask]