Vectorize equalize transformation #3334
Conversation
if img.ndim == 3:
    return _remap_single_image(img, luts)
else:  # more than one image
    imgs = img
    return torch.stack([
        _remap_single_image(img, luts) for (img, luts) in zip(imgs, luts)
    ])
Ideally we would get rid of the stack() calls: at that point all we want to do is to remap img (...xCxHxW) according to the values in luts (...xCx256).
There's probably a vectorized one-liner to do that, maybe with gather(), but I couldn't get it to work. Any help welcome :)
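For reference, a minimal sketch of one gather()-based possibility (an illustration, not code from this PR; it assumes img holds integer values in [0, 255] and luts has shape (..., C, 256)):

import torch

def _remap_with_gather(img: torch.Tensor, luts: torch.Tensor) -> torch.Tensor:
    # img: (..., C, H, W) with integer values in [0, 255]
    # luts: (..., C, 256) lookup tables
    # Flatten the spatial dims so gather() can index along the last axis,
    # then restore the original shape.
    flat = img.flatten(start_dim=-2).to(torch.int64)  # (..., C, H*W)
    return torch.gather(luts, -1, flat).reshape(img.shape)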
You can get both the single image and the batch of images to work with the approach I mentioned just above. The implementation would look somewhat like:
idx_batch = torch.arange(0, batch_size, device=imgs.device)[:, None, None, None]
idx_channels = torch.arange(0, num_channels, device=imgs.device)[None, :, None, None]
return luts[idx_batch, idx_channels, imgs]
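To make the broadcasting concrete, a small self-contained example (shapes are hypothetical, chosen just for illustration):

import torch

batch_size, num_channels = 2, 3
imgs = torch.randint(0, 256, (batch_size, num_channels, 4, 4))   # (B, C, H, W), int64
luts = torch.randint(0, 256, (batch_size, num_channels, 256), dtype=torch.uint8)

idx_batch = torch.arange(0, batch_size, device=imgs.device)[:, None, None, None]
idx_channels = torch.arange(0, num_channels, device=imgs.device)[None, :, None, None]
# idx_batch (B,1,1,1), idx_channels (1,C,1,1) and imgs (B,C,H,W) broadcast
# to a common (B,C,H,W) shape, so every pixel is looked up in the LUT of
# its own image and channel.
out = luts[idx_batch, idx_channels, imgs]
assert out.shape == imgs.shape and out.dtype == torch.uint8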
Thanks for the PR!
I have a few suggestions which I think would allow vectorizing the rest of the operations.
Additionally, it would be good to check whether we get any runtime speedups (either on CPU or GPU). The vectorized implementation uses a bit more memory, so showing a speedup would make the trade-off worthwhile.
return torch.stack([
    channel_values[channel_indices] for (channel_values, channel_indices) in zip(luts, img)
]).to(torch.uint8)
You can replace the for loop and the stack with something like:
idxs = torch.arange(0, img.shape[-3], device=img.device)[:, None, None]
luts[idxs, img]
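Concretely (hypothetical shapes for illustration):

import torch

img = torch.randint(0, 256, (3, 4, 4))                  # (C, H, W), int64
luts = torch.randint(0, 256, (3, 256), dtype=torch.uint8)

idxs = torch.arange(0, img.shape[-3], device=img.device)[:, None, None]  # (C, 1, 1)
# idxs broadcasts against img, so luts[c, img[c, h, w]] is selected for
# every pixel without a Python-level loop over channels.
out = luts[idxs, img]
assert out.shape == img.shape and out.dtype == torch.uint8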
Thanks for the review @fmassa. Interestingly, this PR does not seem to lead to any improvement (quite the opposite, actually): locally on my laptop the vectorized version comes out roughly 20x slower on CPU.
Do you think it's worth investigating further where the slowdown might come from, or should we just close this PR and the original issue?
Codecov Report

@@            Coverage Diff             @@
##           master    #3334      +/-   ##
==========================================
+ Coverage   73.90%   73.91%   +0.01%
==========================================
  Files         104      104
  Lines        9618     9622       +4
  Branches     1544     1542       -2
==========================================
+ Hits         7108     7112       +4
  Misses       2028     2028
  Partials      482      482

Continue to review the full report at Codecov.
return torch.stack([_equalize_single_image(x) for x in img])
luts = (torch.cumsum(hist, dim=-1) + (step // 2)) // step
I think there is a potential bug here. In the original implementation, step was a scalar, so the if statement was enough to protect against division by zero. Here it's a C-dimensional vector, and if some of its values are zero it can lead to infs/NaNs.
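A minimal sketch of one way to guard the vectorized version (names and shapes are assumptions, not the PR's actual code; step is taken to have shape (..., C) and hist shape (..., C, 256)):

import torch

def _safe_luts(hist: torch.Tensor, step: torch.Tensor) -> torch.Tensor:
    # step can contain zeros for constant channels; clamp to 1 to avoid
    # division by zero, then restore an identity LUT for those channels,
    # mirroring the early return that the scalar-step implementation had.
    safe_step = step.clamp(min=1)[..., None]                  # (..., C, 1)
    luts = (torch.cumsum(hist, dim=-1) + safe_step // 2) // safe_step
    identity = torch.arange(256, device=hist.device, dtype=luts.dtype)
    return torch.where((step == 0)[..., None], identity, luts)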
I think the 20x slowdown on CPU is a bit surprising. I would have expected it to be a bit slower on CPU and faster on GPU. I would say to investigate this a bit further to see where the slowdown comes from: is it the histogram creation? If so, we could probably leave the histogram creation in a for loop and only perform the batch indexing as we currently have in this PR.
The slowdown does seem to come from the histogram creation, which accounts for the vast majority of the computation time. As a result, I'm not sure we can expect a significant improvement even if we only use batch indexing: we'd still need to stack the histograms anyway, and the execution time is dominated by the histogram computation. Note that on GPU the difference is smaller, but master is still faster.
On top of that, the solution in this PR consumes much more memory, so we can probably just close this PR?
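For anyone who wants to reproduce the profiling, a rough timing sketch (not the actual benchmark script from this thread; the shapes and the bincount-based batched variant are assumptions):

import time
import torch

imgs = torch.randint(0, 256, (8, 3, 256, 256), dtype=torch.uint8)

def looped_hist(imgs):
    # One histc call per image/channel, as on master.
    return torch.stack([
        torch.stack([torch.histc(c.to(torch.float32), bins=256, min=0, max=255)
                     for c in img])
        for img in imgs
    ])

def batched_hist(imgs):
    # Single bincount over offset batch*channel bins, roughly what a
    # vectorized implementation does.
    B, C = imgs.shape[:2]
    offsets = torch.arange(B * C, device=imgs.device)[:, None] * 256
    flat = imgs.reshape(B * C, -1).long() + offsets
    return torch.bincount(flat.reshape(-1), minlength=B * C * 256).reshape(B, C, 256)

for fn in (looped_hist, batched_hist):
    start = time.perf_counter()
    for _ in range(20):
        fn(imgs)
    print(fn.__name__, f"{(time.perf_counter() - start) / 20 * 1e3:.1f} ms")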
Thanks for investigating this @NicolasHug!
Closes #3173
This PR vectorizes the histogram computation (over batches and over channels) in the equalize transformation.