[prototype] Speed up adjust_sharpness_image_tensor #6930

Merged on Nov 8, 2022 (2 commits)

28 changes: 27 additions & 1 deletion torchvision/prototype/transforms/functional/_color.py
@@ -1,4 +1,5 @@
import torch
from torch.nn.functional import conv2d
from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT

@@ -111,6 +112,8 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
    if image.numel() == 0 or height <= 2 or width <= 2:
        return image

    # Largest representable value for the dtype (e.g. 255 for uint8), used below as the clamping bound.
    bound = _FT._max_value(image.dtype)
    fp = image.is_floating_point()
    shape = image.shape

    if image.ndim > 4:
@@ -119,7 +122,30 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
    else:
        needs_unsquash = False

    output = _blend(image, _FT._blurred_degenerate_image(image), sharpness_factor)

    # The following is a normalized 3x3 kernel with 1s at the edges and a 5 in the middle.
    kernel_dtype = image.dtype if fp else torch.float32
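    # Note: conv2d does not support integer input dtypes, so non-float images are
    # convolved in float32 and cast back to their original dtype at the end.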
Collaborator:

To ease review, here is the old implementation:

def _blurred_degenerate_image(img: Tensor) -> Tensor:
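For context, a simplified sketch of what that helper did (a reconstruction, not the verbatim code; the real version also handled batching and dtype round-tripping via internal helpers):

import torch
from torch import Tensor
from torch.nn.functional import conv2d

def _blurred_degenerate_image_sketch(img: Tensor) -> Tensor:
    # Same normalized smoothing kernel: 1s at the edges, 5 in the middle, all divided by 13.
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    kernel = torch.ones((3, 3), dtype=dtype, device=img.device)
    kernel[1, 1] = 5.0
    kernel /= kernel.sum()
    kernel = kernel.expand(img.shape[-3], 1, 3, 3)

    # Convolve in float, then write the blurred interior back into a clone of the
    # input, leaving the 1-pixel border untouched.
    blurred = conv2d(img.to(dtype), kernel, groups=img.shape[-3])
    if not torch.is_floating_point(img):
        blurred = blurred.round_()

    result = img.clone()
    result[..., 1:-1, 1:-1] = blurred.to(img.dtype)
    return result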

    a, b = 1.0 / 13.0, 5.0 / 13.0
    kernel = torch.tensor([[a, a, a], [a, b, a], [a, a, a]], dtype=kernel_dtype, device=image.device)
    kernel = kernel.expand(num_channels, 1, 3, 3)
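    # Normalization check: eight edge weights of 1/13 plus a center weight of 5/13
    # sum to (8 + 5) / 13 = 1, so the blur pass preserves overall brightness.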

    # We copy and cast at the same time to avoid modifications of the original data.
    output = image.to(dtype=kernel_dtype, copy=True)
    blurred_degenerate = conv2d(output, kernel, groups=num_channels)
    if not fp:
        # it is better to round before casting
        blurred_degenerate = blurred_degenerate.round_()

    # Create a view on output that points at the same underlying data, to avoid indexing twice.
    view = output[..., 1:-1, 1:-1]

    # We speed up blending by minimizing flops and doing it in place. The two blend options are
    # mathematically equivalent: x + (1 - r) * (y - x) = x + (1 - r) * y - (1 - r) * x = x * r + y * (1 - r)
    view.add_(blurred_degenerate.sub_(view), alpha=(1.0 - sharpness_factor))
Comment on lines +141 to +143

Collaborator:

Can we push a change like this to _blend, or is this a special case here?

Contributor (author):

It's a special case :( We can do this only because we are allowed to subtract image1 from image2 in place. In all other cases where _blend() is used, we rely on broadcasting, so that's not possible.
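To make the equivalence concrete, here is a minimal standalone check (illustrative only; the tensor names and values are made up):

import torch

x = torch.rand(3, 8, 8)   # stands in for the original image view
y = torch.rand(3, 8, 8)   # stands in for the blurred degenerate image
r = 0.7                   # stands in for sharpness_factor

generic = x * r + y * (1.0 - r)                  # the broadcasting form used by _blend()
fused = x.clone().add_(y.sub(x), alpha=1.0 - r)  # the in-place form used in this PR

torch.testing.assert_close(generic, fused)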


    # The actual data of output has been modified by the above. We only need to clamp and cast now.
    output = output.clamp_(0, bound)
    if not fp:
        output = output.to(image.dtype)

    if needs_unsquash:
        output = output.reshape(shape)
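For an end-to-end sanity check, a minimal usage sketch (assuming the prototype functional namespace on this branch exposes the op; the input shape and factor are arbitrary):

import torch
from torchvision.prototype.transforms import functional as F

# sharpness_factor semantics: 0.0 blurs the interior, 1.0 returns the original, 2.0 sharpens.
img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
out = F.adjust_sharpness_image_tensor(img, sharpness_factor=2.0)

assert out.shape == img.shape and out.dtype == img.dtype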