
[prototype] Speed up adjust_contrast_image_tensor #6933

Merged · 2 commits · Nov 9, 2022
11 changes: 8 additions & 3 deletions torchvision/prototype/transforms/functional/_color.py
@@ -79,9 +79,14 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
     c = image.shape[-3]
     if c not in [1, 3]:
         raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
-    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
-    grayscale_image = _rgb_to_gray(image) if c == 3 else image
-    mean = torch.mean(grayscale_image.to(dtype), dim=(-3, -2, -1), keepdim=True)
+    fp = image.is_floating_point()
+    if c == 3:
+        grayscale_image = _rgb_to_gray(image, cast=False)
+        if not fp:
+            grayscale_image = grayscale_image.floor_()
Contributor Author:
@pmeier In an early iteration of the PR, I had missed this floor_() call, which is necessary to reproduce results identical to stable. Unfortunately, the tests passed without catching the issue. It might be worth checking that our reference tests cover both floats and ints.
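
To make the difference concrete, here is a minimal sketch (hypothetical uint8 input, grayscale math inlined): stable truncates the grayscale image via the uint8 cast, so skipping floor_() shifts the mean by the dropped fractions.

import torch

img = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)
r, g, b = img.unbind(dim=-3)
gray = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)  # float32 intermediate

mean_truncated = gray.floor().mean()  # matches stable, where the uint8 cast truncates
mean_raw = gray.mean()                # what the early PR iteration computed
print(mean_truncated.item(), mean_raw.item())  # typically differ by around half a unit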

Collaborator:
This is properly tested with

def reference_inputs_adjust_contrast_image_tensor():
    for image_loader, contrast_factor in itertools.product(
        make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]),
        _ADJUST_CONTRAST_FACTORS,
    ):
        yield ArgsKwargs(image_loader, contrast_factor=contrast_factor)

but it seems our tolerances are too high to pick up on this

closeness_kwargs=DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS,

DEFAULT_PIL_REFERENCE_CLOSENESS_KWARGS = {
    (("TestKernels", "test_against_reference"), torch.float32, "cpu"): dict(atol=1e-5, rtol=0, agg_method="mean"),
    (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): dict(atol=1e-5, rtol=0, agg_method="mean"),
}

However, we can't remove the tolerances completely here, since the output differs even with the current implementation.

Contributor Author:
Hmm... 99% of the performance optimizations were tested on training runs and we know they are correct, so any remaining differences are likely rounding-error issues. I think we should adjust the tolerances just enough to catch BC issues without going ballistic over the 8th decimal. That might be a hard exercise, though.

+    else:
+        grayscale_image = image if fp else image.to(torch.float32)
+    mean = torch.mean(grayscale_image, dim=(-3, -2, -1), keepdim=True)
     return _blend(image, mean, contrast_factor)
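
As a rough illustration of where the speedup comes from (a hypothetical micro-benchmark, not from this PR; timings vary by machine): the old path round-trips the grayscale image through the integer dtype before taking the mean, while the new path keeps the float32 intermediate.

import torch
from torch.utils.benchmark import Timer

img = torch.randint(0, 256, (3, 512, 512), dtype=torch.uint8)

def old_path(image):
    # previous behaviour: gray is cast back to uint8, then to float32 for the mean
    r, g, b = image.unbind(dim=-3)
    gray = (0.2989 * r).add_(g, alpha=0.587).add_(b, alpha=0.114).to(image.dtype)
    return torch.mean(gray.to(torch.float32), dim=(-2, -1), keepdim=True)

def new_path(image):
    # this PR: keep the float32 intermediate; floor_() stands in for the uint8 truncation
    r, g, b = image.unbind(dim=-3)
    gray = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114).floor_()
    return torch.mean(gray, dim=(-2, -1), keepdim=True)

for fn in (old_path, new_path):
    print(fn.__name__, Timer("fn(img)", globals={"fn": fn, "img": img}).timeit(100))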


8 changes: 5 additions & 3 deletions torchvision/prototype/transforms/functional/_meta.py
@@ -213,10 +213,12 @@ def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
     return grayscale.repeat(repeats)


-def _rgb_to_gray(image: torch.Tensor) -> torch.Tensor:
+def _rgb_to_gray(image: torch.Tensor, cast: bool = True) -> torch.Tensor:
     r, g, b = image.unbind(dim=-3)
-    l_img = (0.2989 * r).add_(g, alpha=0.587).add_(b, alpha=0.114)
-    l_img = l_img.to(image.dtype).unsqueeze(dim=-3)
+    l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
+    if cast:
+        l_img = l_img.to(image.dtype)
+    l_img = l_img.unsqueeze(dim=-3)
     return l_img
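
For illustration (assuming the helper above and a hypothetical uint8 input), the new flag lets callers keep the float intermediate instead of truncating back to the input dtype, which is what adjust_contrast_image_tensor now relies on.

img = torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8)
gray_u8 = _rgb_to_gray(img)               # default cast=True: truncated back to uint8
gray_f32 = _rgb_to_gray(img, cast=False)  # float32 intermediate; caller decides how to round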

