Skip to content

Implement the equalize transform #3119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,6 +922,16 @@ def test_autocontrast(self):
agg_method="max"
)

def test_equalize(self):
    """Run ``equalize`` through the shared ``_test_adjust_fn`` harness.

    The harness receives the generic, PIL and tensor variants of the
    function, a single empty keyword configuration and ``dts=(None,)``.
    """
    # NOTE(review): determinism is switched off before the check, presumably
    # because the tensor implementation relies on an op without a
    # deterministic variant -- confirm. The flag is not restored afterwards.
    torch.set_deterministic(False)
    no_kwargs = {}
    self._test_adjust_fn(F.equalize, F_pil.equalize, F_t.equalize, [no_kwargs], dts=(None,))


@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
Expand Down
10 changes: 9 additions & 1 deletion test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1795,7 +1795,7 @@ def test_gaussian_blur_asserts(self):
def _test_randomness(self, fn, trans, configs):
random_state = random.getstate()
random.seed(42)
img = transforms.ToPILImage()(torch.rand(3, 10, 10))
img = transforms.ToPILImage()(torch.rand(3, 16, 18))

for p in [0.5, 0.7]:
for config in configs:
Expand Down Expand Up @@ -1846,6 +1846,14 @@ def test_random_autocontrast(self):
[{}]
)

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_equalize(self):
    """Check ``RandomEqualize`` against ``F.equalize`` via ``_test_randomness``.

    Skipped when ``stats`` (scipy.stats) could not be imported.
    """
    configs = [{}]  # the transform takes no extra arguments beyond ``p``
    self._test_randomness(F.equalize, transforms.RandomEqualize, configs)


if __name__ == '__main__':
unittest.main()
4 changes: 4 additions & 0 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ def test_random_solarize(self):
def test_random_autocontrast(self):
    """Run the shared ``_test_op`` check for autocontrast and its random transform."""
    functional_name, transform_name = 'autocontrast', 'RandomAutocontrast'
    self._test_op(functional_name, transform_name)

def test_random_equalize(self):
    """Run the shared ``_test_op`` check for equalize and ``RandomEqualize``."""
    # NOTE(review): determinism is disabled first, presumably because the
    # tensor equalize path uses an op lacking a deterministic variant --
    # confirm. The flag is not restored afterwards.
    torch.set_deterministic(False)
    functional_name, transform_name = 'equalize', 'RandomEqualize'
    self._test_op(functional_name, transform_name)

def test_color_jitter(self):

tol = 1.0 + 1e-10
Expand Down
20 changes: 20 additions & 0 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1273,3 +1273,23 @@ def autocontrast(img: Tensor) -> Tensor:
return F_pil.autocontrast(img)

return F_t.autocontrast(img)


def equalize(img: Tensor) -> Tensor:
    """Equalize the histogram of a PIL Image or torch Tensor by applying
    a non-linear mapping to the input in order to create a uniform
    distribution of grayscale values in the output.

    Args:
        img (PIL Image or Tensor): Image on which equalize is applied.
            If img is a Tensor, it is expected to be in [..., H, W] format,
            where ... means leading dimensions. The tensor backend accepts
            only ``torch.uint8`` tensors with 3 dimensions (a single image)
            or 4 dimensions (a batch of images).

    Returns:
        PIL Image or Tensor: An image that was equalized.
    """
    # Anything that is not a Tensor is handed to the PIL backend, which
    # performs its own type validation.
    if not isinstance(img, torch.Tensor):
        return F_pil.equalize(img)

    return F_t.equalize(img)
7 changes: 7 additions & 0 deletions torchvision/transforms/functional_pil.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,3 +644,10 @@ def autocontrast(img):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.autocontrast(img)


@torch.jit.unused
def equalize(img):
    """Equalize the histogram of a PIL Image using ``ImageOps.equalize``.

    Args:
        img (PIL Image): Image to equalize.

    Returns:
        PIL Image: The equalized image.

    Raises:
        TypeError: If ``img`` fails the ``_is_pil_image`` check.
    """
    if _is_pil_image(img):
        return ImageOps.equalize(img)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
38 changes: 38 additions & 0 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1284,3 +1284,41 @@ def autocontrast(img: Tensor) -> Tensor:
scale = bound / (maximum - minimum)

return ((img.to(dtype) - minimum) * scale).clamp(0, bound).to(img.dtype)


def _scale_channel(img_chan):
    """Equalize one image channel through a histogram-derived lookup table.

    A 256-bin histogram over the value range [0, 255] is accumulated into a
    cumulative distribution, which is then quantised by ``step`` to build
    the lookup table applied to every pixel.

    Args:
        img_chan (Tensor): Single channel whose values index the 256-entry
            lookup table (callers in this file pass uint8 data).

    Returns:
        Tensor: The remapped channel as ``torch.uint8``, or ``img_chan``
        itself when the computed step is zero (nothing to equalize).
    """
    device = img_chan.device
    counts = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)

    # Step is derived from the populated bins only, excluding the last one.
    populated = counts[counts != 0]
    if populated.numel() > 0:
        step = (populated.sum() - populated[-1]) // 255
    else:
        step = torch.tensor(0, device=device)
    if step == 0:
        return img_chan

    half_step = step // 2
    quantised_cdf = (torch.cumsum(counts, 0) + half_step) // step
    # Shift right by one so that entry i depends only on bins below i.
    shifted = torch.cat([torch.zeros(1, device=device), quantised_cdf[:-1]])
    lut = shifted.clamp(0, 255)

    pixel_index = img_chan.to(torch.int64)
    return lut[pixel_index].to(torch.uint8)


def _equalize_single_image(img: Tensor) -> Tensor:
    """Equalize every channel of one image independently.

    Args:
        img (Tensor): Image whose first dimension enumerates the channels.

    Returns:
        Tensor: The per-channel results of ``_scale_channel`` stacked back
        along a new first dimension.
    """
    equalized_channels = [_scale_channel(channel) for channel in img]
    return torch.stack(equalized_channels)


def equalize(img: Tensor) -> Tensor:
    """Equalize the histogram of a uint8 image tensor or a batch of them.

    Args:
        img (Tensor): ``torch.uint8`` tensor with 3 dimensions (one image)
            or 4 dimensions (a batch of images).

    Returns:
        Tensor: The equalized image, or the stacked equalized images for a
        4-dimensional input.

    Raises:
        TypeError: If ``img`` is not a torch image, does not have 3 or 4
            dimensions, or is not of dtype ``torch.uint8``.
    """
    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    num_dims = img.ndim
    if num_dims < 3 or num_dims > 4:
        raise TypeError("Input image tensor should have 3 or 4 dimensions, but found {}".format(num_dims))
    if img.dtype != torch.uint8:
        raise TypeError("Only torch.uint8 image tensors are supported, but found {}".format(img.dtype))

    # NOTE(review): presumably restricts the channel count to 1 or 3 -- confirm
    # against the helper's definition.
    _assert_channels(img, [1, 3])

    if num_dims == 4:
        # Batched input: equalize each image on its own, then re-stack.
        return torch.stack([_equalize_single_image(sample) for sample in img])

    return _equalize_single_image(img)
42 changes: 41 additions & 1 deletion torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
"LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
"RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize",
"RandomSolarize", "RandomAutocontrast"]
"RandomSolarize", "RandomAutocontrast", "RandomEqualize"]


class Compose:
Expand Down Expand Up @@ -1876,3 +1876,43 @@ def forward(self, img):

def __repr__(self):
    """Return the class name followed by the configured probability ``p``."""
    class_name = type(self).__name__
    return class_name + '(p={})'.format(self.p)


class RandomEqualize(torch.nn.Module):
    """Randomly equalize the histogram of the given image.

    With probability ``p`` the input is passed through ``F.equalize``;
    otherwise it is returned untouched. The image can be a PIL Image or a
    torch Tensor, in which case it is expected to have [..., H, W] shape,
    where ... means an arbitrary number of leading dimensions.

    Args:
        p (float): probability of the image being equalized. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    @staticmethod
    def get_params() -> float:
        """Draw the random value that decides whether the transform is applied.

        Returns:
            float: A single sample from ``torch.rand``; ``forward`` compares
            it against ``p``.
        """
        sample = torch.rand(1)
        return sample.item()

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be equalized.

        Returns:
            PIL Image or Tensor: Randomly equalized image.
        """
        should_equalize = self.get_params() < self.p
        if should_equalize:
            return F.equalize(img)
        return img

    def __repr__(self):
        class_name = type(self).__name__
        return class_name + '(p={})'.format(self.p)