-
Notifications
You must be signed in to change notification settings - Fork 7.1k
[prototype] Speed up adjust_sharpness_image_tensor
#6930
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
import torch | ||
from torch.nn.functional import conv2d | ||
from torchvision.prototype import features | ||
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT | ||
|
||
|
@@ -111,6 +112,8 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) | |
if image.numel() == 0 or height <= 2 or width <= 2: | ||
return image | ||
|
||
bound = _FT._max_value(image.dtype) | ||
fp = image.is_floating_point() | ||
shape = image.shape | ||
|
||
if image.ndim > 4: | ||
|
@@ -119,7 +122,30 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) | |
else: | ||
needs_unsquash = False | ||
|
||
output = _blend(image, _FT._blurred_degenerate_image(image), sharpness_factor) | ||
# The following is a normalized 3x3 kernel with 1s in the edges and a 5 in the middle. | ||
kernel_dtype = image.dtype if fp else torch.float32 | ||
a, b = 1.0 / 13.0, 5.0 / 13.0 | ||
kernel = torch.tensor([[a, a, a], [a, b, a], [a, a, a]], dtype=kernel_dtype, device=image.device) | ||
kernel = kernel.expand(num_channels, 1, 3, 3) | ||
|
||
# We copy and cast at the same time to avoid modifications on the original data | ||
output = image.to(dtype=kernel_dtype, copy=True) | ||
blurred_degenerate = conv2d(output, kernel, groups=num_channels) | ||
if not fp: | ||
# it is better to round before cast | ||
blurred_degenerate = blurred_degenerate.round_() | ||
|
||
# Create a view on the underlying output while pointing at the same data. We do this to avoid indexing twice. | ||
view = output[..., 1:-1, 1:-1] | ||
|
||
# We speed up blending by minimizing flops and doing in-place. The 2 blend options are mathematically equivalent: | ||
# x+(1-r)*(y-x) = x + (1-r)*y - (1-r)*x = x*r + y*(1-r) | ||
view.add_(blurred_degenerate.sub_(view), alpha=(1.0 - sharpness_factor)) | ||
Comment on lines
+141
to
+143
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can push a change like this to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's a special case :( We can do this only because we are allowed to subtract |
||
|
||
# The actual data of ouput have been modified by the above. We only need to clamp and cast now. | ||
output = output.clamp_(0, bound) | ||
if not fp: | ||
output = output.to(image.dtype) | ||
|
||
if needs_unsquash: | ||
output = output.reshape(shape) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To ease review, here is the old implementation:
vision/torchvision/transforms/functional_tensor.py
Line 803 in bf58902