
Commit 4beffea

NicolasHug authored and facebook-github-bot committed
[fbsync] assume that integer images are [0, 255] in equalize (#6859)
Reviewed By: datumbox
Differential Revision: D40851027
fbshipit-source-id: c3e55e63bd176e3e88501139a1a0400304a477f8
1 parent 7bf6314 commit 4beffea

1 file changed: +7 -11 lines changed
  • torchvision/prototype/transforms/functional

torchvision/prototype/transforms/functional/_color.py

@@ -385,18 +385,14 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     if image.numel() == 0:
         return image
 
-    # 1. The algorithm below can easily be extended to support arbitrary integer dtypes. However, the histogram that
-    #    would be needed to computed will have at least `torch.iinfo(dtype).max + 1` values. That is perfectly fine for
-    #    `torch.int8`, `torch.uint8`, and `torch.int16`, at least questionable for `torch.int32` and completely
-    #    unfeasible for `torch.int64`.
-    # 2. Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we
-    #    could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition
-    #    to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it slower
-    #    and more complicated to implement than a simple conversion and a fast histogram implementation for integers.
-    # Since we need to convert in most cases anyway and out of the acceptable dtypes mentioned in 1. `torch.uint8` is
-    # by far the most common, we choose it as base.
     output_dtype = image.dtype
-    image = convert_dtype_image_tensor(image, torch.uint8)
+    if image.is_floating_point():
+        # Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we
+        # could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition
+        # to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it
+        # slower and more complicated to implement than a simple conversion and a fast histogram implementation for
+        # integers.
+        image = convert_dtype_image_tensor(image, torch.uint8)
 
     # The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image
     # corresponds to adding 1 to index 127 in the histogram.
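
The comments in this hunk describe three ideas: the histogram size grows with the integer dtype, floating point inputs are binned before counting, and the histogram is filled by using each pixel value as an index. The following plain-Python sketch illustrates them on a flat list of pixels. It is not the torchvision implementation: the helper names (`min_histogram_bins`, `bin_float_pixel`, `prepare_pixels`, `histogram`) are hypothetical, float pixels are assumed to lie in [0.0, 1.0], and the scaling factor used for binning is an assumption about how a float-to-uint8 conversion might behave.

```python
from typing import List, Sequence, Union


def min_histogram_bins(bits: int, signed: bool) -> int:
    """Smallest histogram length for an integer dtype: its maximum value + 1.

    For an N-bit dtype this mirrors `torch.iinfo(dtype).max + 1`.
    """
    return 2 ** (bits - 1 if signed else bits)


def bin_float_pixel(value: float, eps: float = 1e-3) -> int:
    """Map a float pixel in [0.0, 1.0] to an integer bin in [0, 255].

    The `256 - eps` factor is an assumed scaling that keeps 1.0 inside bin 255.
    """
    return int(value * (256.0 - eps))


def prepare_pixels(pixels: Sequence[Union[int, float]]) -> List[int]:
    """Bin float pixels; leave integer pixels untouched.

    Integer inputs are assumed to already be in [0, 255], which is the
    behaviour this commit introduces for integer images.
    """
    if pixels and isinstance(pixels[0], float):
        return [bin_float_pixel(p) for p in pixels]
    return list(pixels)


def histogram(pixels: Sequence[int], num_bins: int = 256) -> List[int]:
    """Count pixels by using each pixel value as the index into the histogram."""
    hist = [0] * num_bins
    for p in pixels:
        hist[p] += 1  # e.g. a pixel value of 127 adds 1 to index 127
    return hist


if __name__ == "__main__":
    # Histogram sizes for the dtypes named in the removed comment.
    for name, bits, signed in [
        ("int8", 8, True),
        ("uint8", 8, False),
        ("int16", 16, True),
        ("int32", 32, True),
        ("int64", 64, True),
    ]:
        print(f"{name}: {min_histogram_bins(bits, signed)} bins")

    int_hist = histogram(prepare_pixels([0, 127, 127, 255]))
    print(int_hist[127])

    float_hist = histogram(prepare_pixels([0.0, 0.5, 1.0]))
    print(float_hist[0], float_hist[127], float_hist[255])
```

The bin counts make the trade-off in the removed comment concrete: 256 bins for `uint8` are trivial, while `int32` and `int64` would need about 2.1 billion and 9.2 quintillion bins respectively.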
