
Commit c041798

pmeier and lezcano authored

another round of perf improvements for equalize (#6776)

* perf improvements for equalize
* improve reference tests
* add extensive comments and minor fixes to the kernel
* improve comments

Co-authored-by: lezcano <[email protected]>

1 parent 9f024a6 commit c041798

2 files changed: +92 additions, -36 deletions

test/prototype_transforms_kernel_infos.py

Lines changed: 38 additions & 2 deletions
@@ -13,6 +13,8 @@
 from datasets_utils import combinations_grid
 from prototype_common_utils import (
     ArgsKwargs,
+    get_num_channels,
+    ImageLoader,
     InfoBase,
     make_bounding_box_loaders,
     make_image_loader,
@@ -1359,9 +1361,43 @@ def sample_inputs_equalize_image_tensor():


 def reference_inputs_equalize_image_tensor():
-    for image_loader in make_image_loaders(
-        extra_dims=[()], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), dtypes=[torch.uint8]
+    # We are not using `make_image_loaders` here since that uniformly samples the values over the whole value range.
+    # Since the whole point of this kernel is to transform an arbitrary distribution of values into a uniform one,
+    # the information gain is low if we already provide something really close to the expected value.
+    spatial_size = (256, 256)
+    for fn, color_space in itertools.product(
+        [
+            *[
+                lambda shape, dtype, device, low=low, high=high: torch.randint(
+                    low, high, shape, dtype=dtype, device=device
+                )
+                for low, high in [
+                    (0, 1),
+                    (255, 256),
+                    (0, 64),
+                    (64, 192),
+                    (192, 256),
+                ]
+            ],
+            *[
+                lambda shape, dtype, device, alpha=alpha, beta=beta: torch.distributions.Beta(alpha, beta)
+                .sample(shape)
+                .mul_(255)
+                .round_()
+                .to(dtype=dtype, device=device)
+                for alpha, beta in [
+                    (0.5, 0.5),
+                    (2, 2),
+                    (2, 5),
+                    (5, 2),
+                ]
+            ],
+        ],
+        [features.ColorSpace.GRAY, features.ColorSpace.RGB],
     ):
+        image_loader = ImageLoader(
+            fn, shape=(get_num_channels(color_space), *spatial_size), dtype=torch.uint8, color_space=color_space
+        )
         yield ArgsKwargs(image_loader)
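The Beta-based loaders above are the interesting addition: unlike uniform sampling, they produce skewed value distributions that the kernel actually has work to do on. A standalone sketch (not part of the commit; names are illustrative) of what those distributions look like:

import torch

# Standalone sketch: one of the Beta loaders from the diff, compared against
# uniform sampling. Beta(2, 5) concentrates pixel values in the low range,
# which is the kind of input histogram equalization is meant to flatten.
skewed = torch.distributions.Beta(2.0, 5.0).sample((256, 256)).mul_(255).round_().to(torch.uint8)
uniform = torch.randint(0, 256, (256, 256), dtype=torch.uint8)

for name, image in [("beta(2, 5)", skewed), ("uniform", uniform)]:
    hist = torch.bincount(image.flatten().long(), minlength=256)
    # Count how many of the 256 bins are needed to cover 90% of the pixels:
    # a skewed image needs far fewer bins than a uniformly sampled one.
    covered = hist.sort(descending=True).values.cumsum(dim=0) < 0.9 * image.numel()
    print(f"{name}: 90% of pixels fall into {covered.sum().item() + 1} bins")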
torchvision/prototype/transforms/functional/_color.py

Lines changed: 54 additions & 34 deletions
@@ -228,39 +228,6 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
         return autocontrast_image_pil(inpt)


-def _equalize_image_tensor_vec(image: torch.Tensor) -> torch.Tensor:
-    # input image shape should be [N, H, W]
-    shape = image.shape
-    # Compute image histogram:
-    flat_image = image.flatten(start_dim=1).to(torch.long)  # -> [N, H * W]
-    hist = flat_image.new_zeros(shape[0], 256)
-    hist.scatter_add_(dim=1, index=flat_image, src=flat_image.new_ones(1).expand_as(flat_image))
-
-    # Compute image cdf
-    chist = hist.cumsum_(dim=1)
-    # Compute steps, where step per channel is nonzero_hist[:-1].sum() // 255
-    # Trick: nonzero_hist[:-1].sum() == chist[idx - 1], where idx = chist.argmax()
-    idx = chist.argmax(dim=1).sub_(1)
-    # If histogram is degenerate (hist of zero image), index is -1
-    neg_idx_mask = idx < 0
-    idx.clamp_(min=0)
-    step = chist.gather(dim=1, index=idx.unsqueeze(1))
-    step[neg_idx_mask] = 0
-    step.div_(255, rounding_mode="floor")
-
-    # Compute batched Look-up-table:
-    # Necessary to avoid an integer division by zero, which raises
-    clamped_step = step.clamp(min=1)
-    chist.add_(torch.div(step, 2, rounding_mode="floor")).div_(clamped_step, rounding_mode="floor").clamp_(0, 255)
-    lut = chist.to(torch.uint8)  # [N, 256]
-
-    # Pad lut with zeros
-    zeros = lut.new_zeros((1, 1)).expand(shape[0], 1)
-    lut = torch.cat([zeros, lut[:, :-1]], dim=1)
-
-    return torch.where((step == 0).unsqueeze(-1), image, lut.gather(dim=1, index=flat_image).reshape_as(image))
-
-
 def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     if image.dtype != torch.uint8:
         raise TypeError(f"Only torch.uint8 image tensors are supported, but found {image.dtype}")
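Note that the replacement kernel (next hunk) keeps the central trick of the deleted helper: building all per-image histograms with a single scatter_add_ instead of a Python loop. A minimal, self-contained sketch of that trick (variable names here are illustrative):

import torch

# Minimal sketch of the batched-histogram trick used by both the old and the
# new kernel: the flattened uint8 image doubles as the scatter index, so one
# scatter_add_ call produces a 256-bin histogram per image in the batch.
images = torch.randint(0, 256, (4, 32, 32), dtype=torch.uint8)
flat = images.flatten(start_dim=-2).to(torch.long)  # [4, 1024]
hist = flat.new_zeros(flat.shape[:-1] + (256,))  # [4, 256]
hist.scatter_add_(dim=-1, index=flat, src=hist.new_ones(1).expand_as(flat))

# Sanity check against torch.bincount applied image by image.
for h, img in zip(hist, flat):
    assert torch.equal(h, torch.bincount(img, minlength=256))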
@@ -272,7 +239,60 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     if image.numel() == 0:
         return image

-    return _equalize_image_tensor_vec(image.reshape(-1, height, width)).reshape(image.shape)
+    batch_shape = image.shape[:-2]
+    flat_image = image.flatten(start_dim=-2).to(torch.long)
+
+    # The algorithm for histogram equalization is mirrored from PIL:
+    # https://github.com/python-pillow/Pillow/blob/eb59cb61d5239ee69cbbf12709a0c6fd7314e6d7/src/PIL/ImageOps.py#L368-L385
+
+    # Although PyTorch has builtin functionality for histograms, it doesn't support batches. Since we deal with
+    # uint8 images here and thus the values are already binned, the computation is trivial. The histogram is
+    # computed by using the flattened image as index. For example, a pixel value of 127 in the image corresponds
+    # to adding 1 to index 127 in the histogram.
+    hist = flat_image.new_zeros(batch_shape + (256,), dtype=torch.int32)
+    hist.scatter_add_(dim=-1, index=flat_image, src=hist.new_ones(1).expand_as(flat_image))
+    cum_hist = hist.cumsum(dim=-1)
+
+    # The simplest form of lookup-table (LUT) that also achieves histogram equalization is
+    # `lut = cum_hist / flat_image.shape[-1] * 255`
+    # However, PIL uses a more elaborate scheme:
+    # `lut = ((cum_hist + num_non_max_pixels // (2 * 255)) // num_non_max_pixels) * 255`
+
+    # The last non-zero element in the histogram is the first element in the cumulative histogram with the maximum
+    # value. Thus, the "max" in `num_non_max_pixels` does not refer to 255 as the maximum value of uint8 images,
+    # but rather to the maximum value in the image, which might or might not be 255.
+    index = cum_hist.argmax(dim=-1)
+    num_non_max_pixels = flat_image.shape[-1] - hist.gather(dim=-1, index=index.unsqueeze_(-1))
+
+    # This is a performance optimization that saves us one multiplication later. With it, the LUT computation
+    # simplifies to `lut = (cum_hist + step // 2) // step`, saving the final multiplication by 255 while keeping
+    # the division count the same. PIL uses the variable name `step` for this, so we keep it for easier comparison.
+    step = num_non_max_pixels.div_(255, rounding_mode="floor")
+
+    # Although it looks like we could return early if we find `step == 0` like PIL does, that is unfortunately not
+    # as easy due to our support for batched images. We can only return early if `(step == 0).all()` holds. If it
+    # doesn't, we have to go through the computation below anyway. Since `step == 0` is an edge case anyway, it
+    # makes no sense to pay the runtime cost of checking for it every time.
+    no_equalization = step.eq(0).unsqueeze_(-1)
+
+    # `lut[k]` is computed with `cum_hist[k-1]`, with `lut[0] == (step // 2) // step == 0`. Thus, we perform the
+    # computation only for `lut[1:]` with `cum_hist[:-1]` and add `lut[0] == 0` afterwards.
+    cum_hist = cum_hist[..., :-1]
+    (
+        cum_hist.add_(step // 2)
+        # We need the `clamp_(min=1)` call here to avoid zero division, which raises for integer dtypes. This has
+        # no effect on the returned result of this kernel, since images inside the batch with `step == 0` are
+        # returned as is instead of their equalized version.
+        .div_(step.clamp_(min=1), rounding_mode="floor")
+        # We need the `clamp_` call here since PIL's LUT computation scheme can produce values outside the valid
+        # value range of uint8 images.
+        .clamp_(0, 255)
+    )
+    lut = cum_hist.to(torch.uint8)
+    lut = torch.cat([lut.new_zeros(1).expand(batch_shape + (1,)), lut], dim=-1)
+    equalized_image = lut.gather(dim=-1, index=flat_image).view_as(image)
+
+    return torch.where(no_equalization, image, equalized_image)


 equalize_image_pil = _FP.equalize
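Since the rewritten kernel mirrors PIL's algorithm, the natural sanity check is a direct comparison with ImageOps.equalize. A rough sketch, assuming the prototype kernel is importable under the path below (prototype APIs moved around between releases) and that the agreement is exact for uint8:

import numpy as np
import torch
from PIL import Image, ImageOps

# Assumed import path for the prototype kernel at the time of this commit.
from torchvision.prototype.transforms.functional import equalize_image_tensor

# Rough sanity check, not part of the commit: the tensor kernel should agree
# with PIL on a single CHW uint8 image, since it mirrors PIL's LUT scheme.
image = torch.randint(64, 192, (3, 256, 256), dtype=torch.uint8)

pil_result = ImageOps.equalize(Image.fromarray(image.permute(1, 2, 0).contiguous().numpy()))
expected = torch.from_numpy(np.asarray(pil_result)).permute(2, 0, 1)

torch.testing.assert_close(equalize_image_tensor(image), expected, rtol=0, atol=0)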
