@@ -228,39 +228,6 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
228
228
return autocontrast_image_pil (inpt )
229
229
230
230
231
- def _equalize_image_tensor_vec (image : torch .Tensor ) -> torch .Tensor :
232
- # input image shape should be [N, H, W]
233
- shape = image .shape
234
- # Compute image histogram:
235
- flat_image = image .flatten (start_dim = 1 ).to (torch .long ) # -> [N, H * W]
236
- hist = flat_image .new_zeros (shape [0 ], 256 )
237
- hist .scatter_add_ (dim = 1 , index = flat_image , src = flat_image .new_ones (1 ).expand_as (flat_image ))
238
-
239
- # Compute image cdf
240
- chist = hist .cumsum_ (dim = 1 )
241
- # Compute steps, where step per channel is nonzero_hist[:-1].sum() // 255
242
- # Trick: nonzero_hist[:-1].sum() == chist[idx - 1], where idx = chist.argmax()
243
- idx = chist .argmax (dim = 1 ).sub_ (1 )
244
- # If histogram is degenerate (hist of zero image), index is -1
245
- neg_idx_mask = idx < 0
246
- idx .clamp_ (min = 0 )
247
- step = chist .gather (dim = 1 , index = idx .unsqueeze (1 ))
248
- step [neg_idx_mask ] = 0
249
- step .div_ (255 , rounding_mode = "floor" )
250
-
251
- # Compute batched Look-up-table:
252
- # Necessary to avoid an integer division by zero, which raises
253
- clamped_step = step .clamp (min = 1 )
254
- chist .add_ (torch .div (step , 2 , rounding_mode = "floor" )).div_ (clamped_step , rounding_mode = "floor" ).clamp_ (0 , 255 )
255
- lut = chist .to (torch .uint8 ) # [N, 256]
256
-
257
- # Pad lut with zeros
258
- zeros = lut .new_zeros ((1 , 1 )).expand (shape [0 ], 1 )
259
- lut = torch .cat ([zeros , lut [:, :- 1 ]], dim = 1 )
260
-
261
- return torch .where ((step == 0 ).unsqueeze (- 1 ), image , lut .gather (dim = 1 , index = flat_image ).reshape_as (image ))
262
-
263
-
264
231
def equalize_image_tensor (image : torch .Tensor ) -> torch .Tensor :
265
232
if image .dtype != torch .uint8 :
266
233
raise TypeError (f"Only torch.uint8 image tensors are supported, but found { image .dtype } " )
@@ -272,7 +239,60 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
272
239
if image .numel () == 0 :
273
240
return image
274
241
275
- return _equalize_image_tensor_vec (image .reshape (- 1 , height , width )).reshape (image .shape )
242
+ batch_shape = image .shape [:- 2 ]
243
+ flat_image = image .flatten (start_dim = - 2 ).to (torch .long )
244
+
245
+ # The algorithm for histogram equalization is mirrored from PIL:
246
+ # https://github.com/python-pillow/Pillow/blob/eb59cb61d5239ee69cbbf12709a0c6fd7314e6d7/src/PIL/ImageOps.py#L368-L385
247
+
248
+ # Although PyTorch has builtin functionality for histograms, it doesn't support batches. Since we deal with uint8
249
+ # images here and thus the values are already binned, the computation is trivial. The histogram is computed by using
250
+ # the flattened image as index. For example, a pixel value of 127 in the image corresponds to adding 1 to index 127
251
+ # in the histogram.
252
+ hist = flat_image .new_zeros (batch_shape + (256 ,), dtype = torch .int32 )
253
+ hist .scatter_add_ (dim = - 1 , index = flat_image , src = hist .new_ones (1 ).expand_as (flat_image ))
254
+ cum_hist = hist .cumsum (dim = - 1 )
255
+
256
+ # The simplest form of lookup-table (LUT) that also achieves histogram equalization is
257
+ # `lut = cum_hist / flat_image.shape[-1] * 255`
258
+ # However, PIL uses a more elaborate scheme:
259
+ # `lut = ((cum_hist + num_non_max_pixels // (2 * 255)) // num_non_max_pixels) * 255`
260
+
261
+ # The last non-zero element in the histogram is the first element in the cumulative histogram with the maximum
262
+ # value. Thus, the "max" in `num_non_max_pixels` does not refer to 255 as the maximum value of uint8 images, but
263
+ # rather the maximum value in the image, which might or might not be 255.
264
+ index = cum_hist .argmax (dim = - 1 )
265
+ num_non_max_pixels = flat_image .shape [- 1 ] - hist .gather (dim = - 1 , index = index .unsqueeze_ (- 1 ))
266
+
267
+ # This is a performance optimization that saves us one multiplication later. With this, the LUT computation simplifies
268
+ # to `lut = (cum_hist + step // 2) // step`, thus saving the final multiplication by 255 while keeping the
269
+ # division count the same. PIL uses the variable name `step` for this, so we keep that for easier comparison.
270
+ step = num_non_max_pixels .div_ (255 , rounding_mode = "floor" )
271
+
272
+ # Although it looks like we could return early if we find `step == 0` like PIL does, that is unfortunately not as
273
+ # easy due to our support for batched images. We can only return early if `(step == 0).all()` holds. If it doesn't,
274
+ # we have to go through the computation below anyway. Since `step == 0` is an edge case anyway, it makes no sense to
275
+ # pay the runtime cost for checking it every time.
276
+ no_equalization = step .eq (0 ).unsqueeze_ (- 1 )
277
+
278
+ # `lut[k]` is computed with `cum_hist[k-1]` with `lut[0] == (step // 2) // step == 0`. Thus, we perform the
279
+ # computation only for `lut[1:]` with `cum_hist[:-1]` and add `lut[0] == 0` afterwards.
280
+ cum_hist = cum_hist [..., :- 1 ]
281
+ (
282
+ cum_hist .add_ (step // 2 )
283
+ # We need the `clamp_(min=1)` call here to avoid division by zero, since that fails for integer dtypes. This has no
284
+ # effect on the returned result of this kernel since images inside the batch with `step == 0` are returned as is
285
+ # instead of their equalized version.
286
+ .div_ (step .clamp_ (min = 1 ), rounding_mode = "floor" )
287
+ # We need the `clamp_` call here since PIL's LUT computation scheme can produce values outside the valid value
288
+ # range of uint8 images
289
+ .clamp_ (0 , 255 )
290
+ )
291
+ lut = cum_hist .to (torch .uint8 )
292
+ lut = torch .cat ([lut .new_zeros (1 ).expand (batch_shape + (1 ,)), lut ], dim = - 1 )
293
+ equalized_image = lut .gather (dim = - 1 , index = flat_image ).view_as (image )
294
+
295
+ return torch .where (no_equalization , image , equalized_image )
276
296
277
297
278
298
equalize_image_pil = _FP .equalize
0 commit comments