From e7c5fcb721b0c29ff71a362283ac1a884831772b Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 14 Oct 2022 17:50:51 +0200 Subject: [PATCH 1/4] perf improvements for equalize Co-authored-by: lezcano --- .../prototype/transforms/functional/_color.py | 61 +++++++++---------- 1 file changed, 28 insertions(+), 33 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 8460f9c6422..66add24622c 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -183,51 +183,46 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT: return autocontrast_image_pil(inpt) -def _equalize_image_tensor_vec(img: torch.Tensor) -> torch.Tensor: - # input img shape should be [N, H, W] - shape = img.shape +def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: + if image.dtype != torch.uint8: + raise TypeError(f"Only torch.uint8 image tensors are supported, but found {image.dtype}") + + num_channels, height, width = get_dimensions_image_tensor(image) + if num_channels not in (1, 3): + raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}") + + if image.numel() == 0: + return image + + # input image shape should be (*, H, W) + shape = image.shape # Compute image histogram: - flat_img = img.flatten(start_dim=1).to(torch.long) # -> [N, H * W] - hist = flat_img.new_zeros(shape[0], 256) - hist.scatter_add_(dim=1, index=flat_img, src=flat_img.new_ones(1).expand_as(flat_img)) + flat_image = image.flatten(start_dim=-2).to(torch.long) # -> [*, H * W] + hist = flat_image.new_zeros(shape[:-2] + (256,), dtype=torch.int32) + hist.scatter_add_(dim=-1, index=flat_image, src=hist.new_ones(1).expand_as(flat_image)) # Compute image cdf - chist = hist.cumsum_(dim=1) + pixels = flat_image.size(-1) + chist = hist.cumsum(dim=-1) # Compute steps, where step per channel is nonzero_hist[:-1].sum() // 255 - # Trick: nonzero_hist[:-1].sum() == chist[idx - 1], where idx = chist.argmax() - idx = chist.argmax(dim=1).sub_(1) - # If histogram is degenerate (hist of zero image), index is -1 - neg_idx_mask = idx < 0 - idx.clamp_(min=0) - step = chist.gather(dim=1, index=idx.unsqueeze(1)) - step[neg_idx_mask] = 0 + # Trick: nonzero_hist[-1] == hist[chist.argmax()] + idx = chist.argmax(dim=-1).unsqueeze_(-1) + step = pixels - hist.gather(dim=-1, index=idx) step.div_(255, rounding_mode="floor") # Compute batched Look-up-table: + non_zero = step == 0 # Necessary to avoid an integer division by zero, which raises - clamped_step = step.clamp(min=1) + clamped_step = step.clamp_(min=1) + chist = chist[..., :-1] chist.add_(torch.div(step, 2, rounding_mode="floor")).div_(clamped_step, rounding_mode="floor").clamp_(0, 255) - lut = chist.to(torch.uint8) # [N, 256] + lut = chist.to(torch.uint8) # Pad lut with zeros - zeros = lut.new_zeros((1, 1)).expand(shape[0], 1) - lut = torch.cat([zeros, lut[:, :-1]], dim=1) - - return torch.where((step == 0).unsqueeze(-1), img, lut.gather(dim=1, index=flat_img).view_as(img)) - - -def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: - if image.dtype != torch.uint8: - raise TypeError(f"Only torch.uint8 image tensors are supported, but found {image.dtype}") - - num_channels, height, width = get_dimensions_image_tensor(image) - if num_channels not in (1, 3): - raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}") - - if image.numel() == 0: - return image + zeros = lut.new_zeros(1).expand(shape[:-2] + (1,)) + lut = torch.cat([zeros, lut], dim=-1) - return _equalize_image_tensor_vec(image.view(-1, height, width)).reshape(image.shape) + return torch.where(non_zero.unsqueeze_(-1), image, lut.gather(dim=-1, index=flat_image).view_as(image)) equalize_image_pil = _FP.equalize From 5291e349c0c4fed758c008584fa9bf1ca5e75582 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 21 Oct 2022 08:21:03 +0200 Subject: [PATCH 2/4] improve reference tests --- test/prototype_transforms_kernel_infos.py | 40 +++++++++++++++++++++-- 1 file changed, 38 insertions(+), 2 deletions(-) diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index f8b237f2e96..a9b38930a17 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -12,6 +12,8 @@ from datasets_utils import combinations_grid from prototype_common_utils import ( ArgsKwargs, + get_num_channels, + ImageLoader, InfoBase, make_bounding_box_loaders, make_image_loader, @@ -1357,9 +1359,43 @@ def sample_inputs_equalize_image_tensor(): def reference_inputs_equalize_image_tensor(): - for image_loader in make_image_loaders( - extra_dims=[()], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), dtypes=[torch.uint8] + # We are not using `make_image_loaders` here since that uniformly samples the values over the whole value range. + # Since the whole point of this kernel is to transform an arbitrary distribution of values into a uniform one, + # the information gain is low if we already provide something really close to the expected value. + spatial_size = (256, 256) + for fn, color_space in itertools.product( + [ + *[ + lambda shape, dtype, device, low=low, high=high: torch.randint( + low, high, shape, dtype=dtype, device=device + ) + for low, high in [ + (0, 1), + (255, 256), + (0, 64), + (64, 192), + (192, 256), + ] + ], + *[ + lambda shape, dtype, device, alpha=alpha, beta=beta: torch.distributions.Beta(alpha, beta) + .sample(shape) + .mul_(255) + .round_() + .to(dtype=dtype, device=device) + for alpha, beta in [ + (0.5, 0.5), + (2, 2), + (2, 5), + (5, 2), + ] + ], + ], + [features.ColorSpace.GRAY, features.ColorSpace.RGB], ): + image_loader = ImageLoader( + fn, shape=(get_num_channels(color_space), *spatial_size), dtype=torch.uint8, color_space=color_space + ) yield ArgsKwargs(image_loader) From 5aa4f8b2ebd09900ef7741c1282fe4f4d412d66c Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 21 Oct 2022 09:03:30 +0200 Subject: [PATCH 3/4] add extensive comments and minor fixes to the kernel --- .../prototype/transforms/functional/_color.py | 78 ++++++++++++------- 1 file changed, 50 insertions(+), 28 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 2765002e0f0..a1440b1d0af 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -194,35 +194,57 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: if image.numel() == 0: return image - # input image shape should be (*, H, W) - shape = image.shape - # Compute image histogram: - flat_image = image.flatten(start_dim=-2).to(torch.long) # -> [*, H * W] - hist = flat_image.new_zeros(shape[:-2] + (256,), dtype=torch.int32) + # The algorithm for histogram equalization is mirrored from PIL: + # https://github.com/python-pillow/Pillow/blob/eb59cb61d5239ee69cbbf12709a0c6fd7314e6d7/src/PIL/ImageOps.py#L368-L385 + # Although PyTorch has builtin functionality for histograms, we are not using them here since they can only be + # applied to the complete tensor. We need to apply them to the spatial size, i.e. the last two dimensions only, and + # thus we would need to call them in a for loop, which is inefficient. + batch_shape = image.shape[:-2] + flat_image = image.flatten(start_dim=-2).to(torch.long) + + # The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image + # corresponds to adding 1 to index 127 in the histogram. + hist = flat_image.new_zeros(batch_shape + (256,), dtype=torch.int32) hist.scatter_add_(dim=-1, index=flat_image, src=hist.new_ones(1).expand_as(flat_image)) - - # Compute image cdf - pixels = flat_image.size(-1) - chist = hist.cumsum(dim=-1) - # Compute steps, where step per channel is nonzero_hist[:-1].sum() // 255 - # Trick: nonzero_hist[-1] == hist[chist.argmax()] - idx = chist.argmax(dim=-1).unsqueeze_(-1) - step = pixels - hist.gather(dim=-1, index=idx) - step.div_(255, rounding_mode="floor") - - # Compute batched Look-up-table: - non_zero = step == 0 - # Necessary to avoid an integer division by zero, which raises - clamped_step = step.clamp_(min=1) - chist = chist[..., :-1] - chist.add_(torch.div(step, 2, rounding_mode="floor")).div_(clamped_step, rounding_mode="floor").clamp_(0, 255) - lut = chist.to(torch.uint8) - - # Pad lut with zeros - zeros = lut.new_zeros(1).expand(shape[:-2] + (1,)) - lut = torch.cat([zeros, lut], dim=-1) - - return torch.where(non_zero.unsqueeze_(-1), image, lut.gather(dim=-1, index=flat_image).view_as(image)) + cum_hist = hist.cumsum(dim=-1) + + # The simplest form of lookup-table (LUT) that also achieves histogram equalization is + # `lut = cum_hist / flat_image.shape[-1] * 255` + # However, PIL uses a more elaborate scheme: + # 1. Instead of normalizing the cumulative histogram by the number of pixels (` / flat_image.shape[-1]` above), + # it is normalized by the number of pixels that are don't have the maximum value. Note that maximum value here + # does not mean 255 per se but rather the maximum value in the image, which might be or not be 255. + # 2. Instead of normalizing just the cumulative histogram, a constant based on the number of non-maximum is added + # to it. + # This brings the computation of the LUT to + # `lut = ((cum_hist + num_non_max_pixels // (2 * 255)) // num_non_max_pixels) * 255` + + # The last non-zero element in the histogram is the first element in the cumulative histogram with the maximum value + index = cum_hist.argmax(dim=-1) + num_non_max_pixels = flat_image.shape[-1] - hist.gather(dim=-1, index=index.unsqueeze_(-1)) + + # This is performance optimization that saves us one multiplication later. With this, the LUT computation simplifies + # to `lut = (cum_hist + step // 2) // step` and thus saving the final multiplication by 255 while keeping the + # division count the same. PIL uses the variable name `step` for this, so we keep that for easier comparison. + step = num_non_max_pixels.div_(255, rounding_mode="floor") + + # Although it looks like we could return early if we find `step == 0` like PIL does, that is unfortunately not as + # easy due to our support for batched images. We can only return early if `(step == 0).all()` holds. If it doesn't, + # we have to go through the computation below anyway. Since `step == 0` is an edge case anyway, it makes no sense to + # pay the runtime cost for checking it every time. + no_equalization = step.eq(0).unsqueeze_(-1) + + # `lut[k]` is computed with `cum_hist[k-1]` with `lut[0] == (step // 2) // step == 0`. Thus, we perform the + # computation only for `lut[1:]` and add `lut[0] == 0` afterwards. + cum_hist = cum_hist[..., :-1] + # We need the `step.clamp_`(min=1) call here to avoid zero division. This has no effect on the returned result of + # this kernel since images inside the batch with `step == 0` are returned as is rather than their equalized version. + cum_hist.add_(step // 2).div_(step.clamp_(min=1), rounding_mode="floor").clamp_(0, 255) + lut = cum_hist.to(torch.uint8) + lut = torch.cat([lut.new_zeros(1).expand(batch_shape + (1,)), lut], dim=-1) + equalized_image = lut.gather(dim=-1, index=flat_image).view_as(image) + + return torch.where(no_equalization, image, equalized_image) equalize_image_pil = _FP.equalize From 6d448cb17a6be40348915f18d0a9a75720449021 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 21 Oct 2022 10:05:39 +0200 Subject: [PATCH 4/4] improve comments --- .../prototype/transforms/functional/_color.py | 39 ++++++++++--------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index a1440b1d0af..3947d87d111 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -194,16 +194,16 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: if image.numel() == 0: return image - # The algorithm for histogram equalization is mirrored from PIL: - # https://github.com/python-pillow/Pillow/blob/eb59cb61d5239ee69cbbf12709a0c6fd7314e6d7/src/PIL/ImageOps.py#L368-L385 - # Although PyTorch has builtin functionality for histograms, we are not using them here since they can only be - # applied to the complete tensor. We need to apply them to the spatial size, i.e. the last two dimensions only, and - # thus we would need to call them in a for loop, which is inefficient. batch_shape = image.shape[:-2] flat_image = image.flatten(start_dim=-2).to(torch.long) - # The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image - # corresponds to adding 1 to index 127 in the histogram. + # The algorithm for histogram equalization is mirrored from PIL: + # https://github.com/python-pillow/Pillow/blob/eb59cb61d5239ee69cbbf12709a0c6fd7314e6d7/src/PIL/ImageOps.py#L368-L385 + + # Although PyTorch has builtin functionality for histograms, it doesn't support batches. Since we deal with uint8 + # images here and thus the values are already binned, the computation is trivial. The histogram is computed by using + # the flattened image as index. For example, a pixel value of 127 in the image corresponds to adding 1 to index 127 + # in the histogram. hist = flat_image.new_zeros(batch_shape + (256,), dtype=torch.int32) hist.scatter_add_(dim=-1, index=flat_image, src=hist.new_ones(1).expand_as(flat_image)) cum_hist = hist.cumsum(dim=-1) @@ -211,15 +211,11 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: # The simplest form of lookup-table (LUT) that also achieves histogram equalization is # `lut = cum_hist / flat_image.shape[-1] * 255` # However, PIL uses a more elaborate scheme: - # 1. Instead of normalizing the cumulative histogram by the number of pixels (` / flat_image.shape[-1]` above), - # it is normalized by the number of pixels that are don't have the maximum value. Note that maximum value here - # does not mean 255 per se but rather the maximum value in the image, which might be or not be 255. - # 2. Instead of normalizing just the cumulative histogram, a constant based on the number of non-maximum is added - # to it. - # This brings the computation of the LUT to # `lut = ((cum_hist + num_non_max_pixels // (2 * 255)) // num_non_max_pixels) * 255` - # The last non-zero element in the histogram is the first element in the cumulative histogram with the maximum value + # The last non-zero element in the histogram is the first element in the cumulative histogram with the maximum + # value. Thus, the "max" in `num_non_max_pixels` does not refer to 255 as the maximum value of uint8 images, but + # rather the maximum value in the image, which might be or not be 255. index = cum_hist.argmax(dim=-1) num_non_max_pixels = flat_image.shape[-1] - hist.gather(dim=-1, index=index.unsqueeze_(-1)) @@ -235,11 +231,18 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: no_equalization = step.eq(0).unsqueeze_(-1) # `lut[k]` is computed with `cum_hist[k-1]` with `lut[0] == (step // 2) // step == 0`. Thus, we perform the - # computation only for `lut[1:]` and add `lut[0] == 0` afterwards. + # computation only for `lut[1:]` with `cum_hist[:-1]` and add `lut[0] == 0` afterwards. cum_hist = cum_hist[..., :-1] - # We need the `step.clamp_`(min=1) call here to avoid zero division. This has no effect on the returned result of - # this kernel since images inside the batch with `step == 0` are returned as is rather than their equalized version. - cum_hist.add_(step // 2).div_(step.clamp_(min=1), rounding_mode="floor").clamp_(0, 255) + ( + cum_hist.add_(step // 2) + # We need the `clamp_`(min=1) call here to avoid zero division since they fail for integer dtypes. This has no + # effect on the returned result of this kernel since images inside the batch with `step == 0` are returned as is + # instead of equalized version. + .div_(step.clamp_(min=1), rounding_mode="floor") + # We need the `clamp_` call here since PILs LUT computation scheme can produce values outside the valid value + # range of uint8 images + .clamp_(0, 255) + ) lut = cum_hist.to(torch.uint8) lut = torch.cat([lut.new_zeros(1).expand(batch_shape + (1,)), lut], dim=-1) equalized_image = lut.gather(dim=-1, index=flat_image).view_as(image)