another round of perf improvements for equalize #6776

Merged
merged 8 commits on Oct 21, 2022
40 changes: 38 additions & 2 deletions test/prototype_transforms_kernel_infos.py
@@ -13,6 +13,8 @@
 from datasets_utils import combinations_grid
 from prototype_common_utils import (
     ArgsKwargs,
+    get_num_channels,
+    ImageLoader,
     InfoBase,
     make_bounding_box_loaders,
     make_image_loader,
@@ -1359,9 +1361,43 @@ def sample_inputs_equalize_image_tensor():


 def reference_inputs_equalize_image_tensor():
-    for image_loader in make_image_loaders(
-        extra_dims=[()], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), dtypes=[torch.uint8]
+    # We are not using `make_image_loaders` here since that uniformly samples the values over the whole value range.
+    # Since the whole point of this kernel is to transform an arbitrary distribution of values into a uniform one,
+    # the information gain is low if we already provide something really close to the expected value.
+    spatial_size = (256, 256)
+    for fn, color_space in itertools.product(
+        [
+            *[
+                lambda shape, dtype, device, low=low, high=high: torch.randint(
+                    low, high, shape, dtype=dtype, device=device
+                )
+                for low, high in [
+                    (0, 1),
+                    (255, 256),
+                    (0, 64),
+                    (64, 192),
+                    (192, 256),
+                ]
+            ],
+            *[
+                lambda shape, dtype, device, alpha=alpha, beta=beta: torch.distributions.Beta(alpha, beta)
+                .sample(shape)
+                .mul_(255)
+                .round_()
+                .to(dtype=dtype, device=device)
+                for alpha, beta in [
+                    (0.5, 0.5),
+                    (2, 2),
+                    (2, 5),
+                    (5, 2),
+                ]
+            ],
+        ],
+        [features.ColorSpace.GRAY, features.ColorSpace.RGB],
     ):
+        image_loader = ImageLoader(
+            fn, shape=(get_num_channels(color_space), *spatial_size), dtype=torch.uint8, color_space=color_space
+        )
         yield ArgsKwargs(image_loader)


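As an aside, here is a minimal standalone sketch (hypothetical, not part of the test suite) of what one of the Beta-based loaders above generates; `shape` and `image` are illustrative names:

import torch

# Mirrors one Beta-based generator above: sample from a skewed distribution
# and quantize to uint8 so the histogram is far from uniform.
shape = (3, 256, 256)
image = (
    torch.distributions.Beta(2.0, 5.0)
    .sample(shape)  # float values in [0, 1], concentrated toward the low end
    .mul_(255)
    .round_()
    .to(torch.uint8)
)

# Most of the mass sits in the low bins -- exactly the kind of input where
# equalization has visible work to do.
print(torch.bincount(image.flatten().long(), minlength=256)[:16])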
88 changes: 54 additions & 34 deletions torchvision/prototype/transforms/functional/_color.py
@@ -228,39 +228,6 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
         return autocontrast_image_pil(inpt)


-def _equalize_image_tensor_vec(image: torch.Tensor) -> torch.Tensor:
-    # input image shape should be [N, H, W]
-    shape = image.shape
-    # Compute image histogram:
-    flat_image = image.flatten(start_dim=1).to(torch.long)  # -> [N, H * W]
-    hist = flat_image.new_zeros(shape[0], 256)
-    hist.scatter_add_(dim=1, index=flat_image, src=flat_image.new_ones(1).expand_as(flat_image))
-
-    # Compute image cdf
-    chist = hist.cumsum_(dim=1)
-    # Compute steps, where step per channel is nonzero_hist[:-1].sum() // 255
-    # Trick: nonzero_hist[:-1].sum() == chist[idx - 1], where idx = chist.argmax()
-    idx = chist.argmax(dim=1).sub_(1)
-    # If histogram is degenerate (hist of zero image), index is -1
-    neg_idx_mask = idx < 0
-    idx.clamp_(min=0)
-    step = chist.gather(dim=1, index=idx.unsqueeze(1))
-    step[neg_idx_mask] = 0
-    step.div_(255, rounding_mode="floor")
-
-    # Compute batched Look-up-table:
-    # Necessary to avoid an integer division by zero, which raises
-    clamped_step = step.clamp(min=1)
-    chist.add_(torch.div(step, 2, rounding_mode="floor")).div_(clamped_step, rounding_mode="floor").clamp_(0, 255)
-    lut = chist.to(torch.uint8)  # [N, 256]
-
-    # Pad lut with zeros
-    zeros = lut.new_zeros((1, 1)).expand(shape[0], 1)
-    lut = torch.cat([zeros, lut[:, :-1]], dim=1)
-
-    return torch.where((step == 0).unsqueeze(-1), image, lut.gather(dim=1, index=flat_image).reshape_as(image))


 def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     if image.dtype != torch.uint8:
         raise TypeError(f"Only torch.uint8 image tensors are supported, but found {image.dtype}")
@@ -272,7 +239,60 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     if image.numel() == 0:
         return image

-    return _equalize_image_tensor_vec(image.reshape(-1, height, width)).reshape(image.shape)
+    batch_shape = image.shape[:-2]
+    flat_image = image.flatten(start_dim=-2).to(torch.long)
+
+    # The algorithm for histogram equalization is mirrored from PIL:
+    # https://github.com/python-pillow/Pillow/blob/eb59cb61d5239ee69cbbf12709a0c6fd7314e6d7/src/PIL/ImageOps.py#L368-L385
+
+    # Although PyTorch has builtin functionality for histograms, it doesn't support batches. Since we deal with uint8
+    # images here and thus the values are already binned, the computation is trivial. The histogram is computed by
+    # using the flattened image as index. For example, a pixel value of 127 in the image corresponds to adding 1 to
+    # index 127 in the histogram.
+    hist = flat_image.new_zeros(batch_shape + (256,), dtype=torch.int32)
+    hist.scatter_add_(dim=-1, index=flat_image, src=hist.new_ones(1).expand_as(flat_image))
+    cum_hist = hist.cumsum(dim=-1)
+
+    # The simplest form of lookup-table (LUT) that also achieves histogram equalization is
+    # `lut = cum_hist / flat_image.shape[-1] * 255`
+    # However, PIL uses a more elaborate scheme:
+    # `lut = ((cum_hist + num_non_max_pixels // (2 * 255)) // num_non_max_pixels) * 255`
+
+    # The last non-zero element in the histogram is the first element in the cumulative histogram with the maximum
+    # value. Thus, the "max" in `num_non_max_pixels` does not refer to 255 as the maximum value of uint8 images, but
+    # rather the maximum value in the image, which may or may not be 255.
+    index = cum_hist.argmax(dim=-1)
+    num_non_max_pixels = flat_image.shape[-1] - hist.gather(dim=-1, index=index.unsqueeze_(-1))
+
+    # This is a performance optimization that saves us one multiplication later. With it, the LUT computation
+    # simplifies to `lut = (cum_hist + step // 2) // step`, saving the final multiplication by 255 while keeping the
+    # division count the same. PIL uses the variable name `step` for this, so we keep that for easier comparison.
+    step = num_non_max_pixels.div_(255, rounding_mode="floor")
+
+    # Although it looks like we could return early if we find `step == 0` like PIL does, that is unfortunately not as
+    # easy due to our support for batched images. We can only return early if `(step == 0).all()` holds. If it
+    # doesn't, we have to go through the computation below anyway. Since `step == 0` is an edge case, it makes no
+    # sense to pay the runtime cost of checking it every time.
+    no_equalization = step.eq(0).unsqueeze_(-1)
+
+    # `lut[k]` is computed from `cum_hist[k-1]`, with `lut[0] == (step // 2) // step == 0`. Thus, we perform the
+    # computation only for `lut[1:]` with `cum_hist[:-1]` and prepend `lut[0] == 0` afterwards.
+    cum_hist = cum_hist[..., :-1]
+    (
+        cum_hist.add_(step // 2)
+        # We need the `clamp_(min=1)` call here to avoid a division by zero, which fails for integer dtypes. This has
+        # no effect on the result of this kernel, since images inside the batch with `step == 0` are returned as is
+        # instead of their equalized version.
+        .div_(step.clamp_(min=1), rounding_mode="floor")
+        # We need the `clamp_` call here since PIL's LUT computation scheme can produce values outside the valid
+        # value range of uint8 images.
+        .clamp_(0, 255)
+    )
+    lut = cum_hist.to(torch.uint8)
+    lut = torch.cat([lut.new_zeros(1).expand(batch_shape + (1,)), lut], dim=-1)
+    equalized_image = lut.gather(dim=-1, index=flat_image).view_as(image)
+
+    return torch.where(no_equalization, image, equalized_image)


 equalize_image_pil = _FP.equalize
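For readers who want to experiment with the algorithm outside of torchvision, below is a minimal, self-contained sketch of the same batched recipe (scatter_add_ histogram, cumulative histogram, PIL-style LUT). All names (`equalize_uint8` and its locals) are illustrative; this mirrors the approach of the kernel above rather than reproducing it verbatim:

import torch

def equalize_uint8(image: torch.Tensor) -> torch.Tensor:
    # image: uint8 tensor of shape (..., H, W); every leading index is equalized independently.
    batch_shape = image.shape[:-2]
    flat = image.flatten(start_dim=-2).to(torch.long)  # (..., H * W)

    # Batched 256-bin histogram: the pixel values themselves serve as scatter indices.
    hist = flat.new_zeros(batch_shape + (256,), dtype=torch.int32)
    hist.scatter_add_(dim=-1, index=flat, src=hist.new_ones(1).expand_as(flat))
    cum_hist = hist.cumsum(dim=-1)

    # PIL-style step: number of pixels below the image maximum, floor-divided by 255.
    index = cum_hist.argmax(dim=-1, keepdim=True)
    step = (flat.shape[-1] - hist.gather(dim=-1, index=index)).div_(255, rounding_mode="floor")
    no_equalization = step.eq(0).unsqueeze(-1)  # constant images are passed through as is

    # LUT: lut[k] = (cum_hist[k - 1] + step // 2) // step, with lut[0] == 0 prepended.
    lut = (
        cum_hist[..., :-1]
        .add_(step // 2)
        .div_(step.clamp(min=1), rounding_mode="floor")  # clamp avoids integer division by zero
        .clamp_(0, 255)
        .to(torch.uint8)
    )
    lut = torch.cat([lut.new_zeros(1).expand(batch_shape + (1,)), lut], dim=-1)

    equalized = lut.gather(dim=-1, index=flat).view_as(image)
    return torch.where(no_equalization, image, equalized)

# Example: a batch of two 1x4x4 images; the constant one has step == 0 and is returned unchanged.
images = torch.stack([
    torch.full((1, 4, 4), 7, dtype=torch.uint8),
    (torch.arange(16, dtype=torch.uint8) * 16).reshape(1, 4, 4),
])
print(equalize_uint8(images))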