Vectorize equalize transformation #3334


Closed
wants to merge 6 commits

Conversation

NicolasHug
Member

Closes #3173

This PR vectorizes the histograms computation (over batches and over channels) in the equalize transformation.
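The core idea of vectorizing the histogram computation can be done with a single bincount call by offsetting each (image, channel) slice into its own disjoint 256-bin range. A self-contained sketch of that idea (not necessarily the PR's exact code):

```python
import torch

def batched_histogram(imgs: torch.Tensor) -> torch.Tensor:
    """Per-channel histograms for a (N, C, H, W) uint8 batch, returned as (N, C, 256)."""
    n, c = imgs.shape[:2]
    flat = imgs.reshape(n * c, -1).long()
    # Shift each (image, channel) slice into its own disjoint 256-wide bin range,
    # so a single bincount over the flattened tensor counts all slices at once.
    offsets = torch.arange(n * c, device=imgs.device)[:, None] * 256
    counts = torch.bincount((flat + offsets).reshape(-1), minlength=n * c * 256)
    return counts.reshape(n, c, 256)
```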

Comment on lines 921 to 927
if img.ndim == 3:
    return _remap_single_image(img, luts)
else:  # more than one image
    imgs = img
    return torch.stack([
        _remap_single_image(img, luts) for (img, luts) in zip(imgs, luts)
    ])
Member Author


Ideally we would get rid of the stack() calls: at that point all we want to do is to remap img (...xCxHxW) according to the values in luts (...xCx256).

There's probably a vectorized one-liner to do that, maybe with gather(), but I couldn't get it to work. Any help welcome :)
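For what it's worth, gather can express this remap: with imgs of shape (N, C, H, W) and luts of shape (N, C, 256), flattening the spatial dimensions makes the index tensor line up with luts along dim 2. A sketch with assumed shapes, not the PR's code:

```python
import torch

def remap_with_gather(imgs: torch.Tensor, luts: torch.Tensor) -> torch.Tensor:
    # imgs: (N, C, H, W) integer values in [0, 255]; luts: (N, C, 256)
    n, c, h, w = imgs.shape
    # gather along dim 2: out[i, j, k] = luts[i, j, index[i, j, k]]
    return luts.gather(2, imgs.reshape(n, c, h * w).long()).reshape(n, c, h, w)
```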

Member


You can get both the single image and the batch of images to work with the approach I mentioned just above. The implementation would look somewhat like:

idx_batch = torch.arange(0, batch_size, device=imgs.device)[:, None, None, None]
idx_channels = torch.arange(0, num_channels, device=imgs.device)[None, :, None, None]
return luts[idx_batch, idx_channels, imgs]
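As a runnable version of the snippet above (shapes assumed: imgs is (N, C, H, W), luts is (N, C, 256)):

```python
import torch

def remap_batch(imgs: torch.Tensor, luts: torch.Tensor) -> torch.Tensor:
    # imgs: (N, C, H, W) integer values in [0, 255]; luts: (N, C, 256)
    batch_size, num_channels = imgs.shape[:2]
    idx_batch = torch.arange(0, batch_size, device=imgs.device)[:, None, None, None]
    idx_channels = torch.arange(0, num_channels, device=imgs.device)[None, :, None, None]
    # Advanced indexing broadcasts the three index tensors to (N, C, H, W),
    # pairing every pixel with the LUT of its own image and channel.
    return luts[idx_batch, idx_channels, imgs.long()]
```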

@datumbox datumbox requested a review from fmassa February 1, 2021 13:57
Member

@fmassa fmassa left a comment


Thanks for the PR!

I have a few suggestions which I think would allow vectorizing the rest of the operations.

Additionally, it would be good to check whether we get any runtime speedups (either on CPU or GPU). The vectorized implementation uses a bit more memory, so it would be nice to show speedups to justify that.

Comment on lines 889 to 891
return torch.stack([
    channel_values[channel_indices] for (channel_values, channel_indices) in zip(luts, img)
]).to(torch.uint8)
Member


You can replace the for loop and the stack with something like

idxs = torch.arange(0, img.shape[-3], device=img.device)[:, None, None]
luts[idxs, img]
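A self-contained version of that single-image indexing (shapes assumed: img is (C, H, W), luts is (C, 256)):

```python
import torch

def remap_single(img: torch.Tensor, luts: torch.Tensor) -> torch.Tensor:
    # img: (C, H, W) integer values in [0, 255]; luts: (C, 256)
    idxs = torch.arange(0, img.shape[-3], device=img.device)[:, None, None]
    # idxs broadcasts against img, pairing each channel with its own LUT
    return luts[idxs, img.long()].to(torch.uint8)
```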


@NicolasHug
Member Author

Thanks for the review @fmassa .

Interestingly, this PR does not seem to lead to any improvement (quite the opposite actually). Locally on my laptop I get the following for a 64x3x128x128 tensor:

This PR: 1.51 s ± 88.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Master:  42.1 ms ± 436 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
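Numbers like these can also be collected with torch.utils.benchmark instead of %timeit, which handles warmup and threading more carefully. A sketch, with a smaller assumed shape and a per-channel histogram loop as the timed statement:

```python
import torch
import torch.utils.benchmark as benchmark

imgs = torch.randint(0, 256, (8, 3, 32, 32), dtype=torch.uint8)

# Time a per-channel histogram loop over the flattened (N*C, H*W) slices
timer = benchmark.Timer(
    stmt="torch.stack([torch.bincount(c.long(), minlength=256)"
         " for c in imgs.flatten(0, 1).flatten(1)])",
    globals={"imgs": imgs, "torch": torch},
)
measurement = timer.timeit(10)
print(measurement.mean)  # seconds per run
```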

Do you think it's worth investigating further where the slowdown might come from, or should we just close this PR and the original issue?

@codecov

codecov bot commented Feb 1, 2021

Codecov Report

Merging #3334 (9f2fb98) into master (859a535) will increase coverage by 0.01%.
The diff coverage is 100.00%.


@@            Coverage Diff             @@
##           master    #3334      +/-   ##
==========================================
+ Coverage   73.90%   73.91%   +0.01%     
==========================================
  Files         104      104              
  Lines        9618     9622       +4     
  Branches     1544     1542       -2     
==========================================
+ Hits         7108     7112       +4     
  Misses       2028     2028              
  Partials      482      482              
Impacted Files Coverage Δ
torchvision/transforms/functional_tensor.py 79.39% <100.00%> (+0.15%) ⬆️



return torch.stack([_equalize_single_image(x) for x in img])
luts = (torch.cumsum(hist, dim=-1) + (step // 2)) // step
Contributor


I think there is a potential bug here. In the original implementation the step was a scalar, so the if statement was enough to protect against division by zero. Here it's a C-dimensional vector, and if some of its values are zero it can lead to infs/NaNs.
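One way to guard against that, sketched under the assumption that hist has shape (..., 256) and follows the same cumsum formula (the real code may additionally need to clamp the result to [0, 255]):

```python
import torch

def safe_luts(hist: torch.Tensor) -> torch.Tensor:
    # hist: (..., 256) per-channel histograms
    hist = hist.float()
    step = ((hist.sum(-1) - hist[..., -1]) // 255)[..., None]  # (..., 1)
    # Divide by max(step, 1) to avoid inf/nan, then restore the identity
    # mapping (image left unchanged) wherever step was actually zero.
    luts = (torch.cumsum(hist, dim=-1) + step // 2) // step.clamp(min=1)
    identity = torch.arange(256, dtype=luts.dtype, device=luts.device)
    return torch.where(step == 0, identity, luts)
```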

@fmassa
Member

fmassa commented Feb 2, 2021

Do you think it's worth investigating further where the slowdown might come from, or should we just close this PR and the original issue?

I think the 20x slowdown on a CPU is a bit surprising. I would have expected that it could be a bit slower on CPU and faster on the GPU.

I would say to investigate this a bit further to see where the slowdown comes from -- is it from the histogram creation? If yes, then we could probably leave the histogram creation in a for loop and then perform the batch indexing as we currently have in this PR.

@NicolasHug
Member Author

The slowdown does seem to come from the histogram creation, which accounts for the vast majority of the computation time. As a result, I'm not sure we can expect a significant improvement even if we only use batch indexing (we'd still need to stack the histograms anyway, and the execution time is dominated by the histogram computation).

Note that on GPU the difference is smaller, but master is still faster.
Histogram creation for a (32, 3, 64, 64) batch:

This PR: 14.6 ms ± 616 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Master:  8.81 ms ± 87 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

On top of that the solution in this PR is much more memory-consuming, so we can probably just close this PR?

@fmassa
Member

fmassa commented Feb 16, 2021

Thanks for investigating this @NicolasHug !

Successfully merging this pull request may close these issues.

Vectorize the equalize transformation
4 participants