diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index fa6297fa5ef..0e039be6041 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -922,6 +922,16 @@ def test_autocontrast(self):
                 agg_method="max"
             )
 
+    def test_equalize(self):
+        torch.set_deterministic(False)
+        self._test_adjust_fn(
+            F.equalize,
+            F_pil.equalize,
+            F_t.equalize,
+            [{}],
+            dts=(None,)
+        )
+
 
 @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
 class CUDATester(Tester):
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 81104c10f21..5defce28588 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -1795,7 +1795,7 @@ def test_gaussian_blur_asserts(self):
     def _test_randomness(self, fn, trans, configs):
         random_state = random.getstate()
         random.seed(42)
-        img = transforms.ToPILImage()(torch.rand(3, 10, 10))
+        img = transforms.ToPILImage()(torch.rand(3, 16, 18))
 
         for p in [0.5, 0.7]:
             for config in configs:
@@ -1846,6 +1846,14 @@ def test_random_autocontrast(self):
             [{}]
         )
 
+    @unittest.skipIf(stats is None, 'scipy.stats not available')
+    def test_random_equalize(self):
+        self._test_randomness(
+            F.equalize,
+            transforms.RandomEqualize,
+            [{}]
+        )
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py
index 2c36664a517..7af1f1d4c46 100644
--- a/test/test_transforms_tensor.py
+++ b/test/test_transforms_tensor.py
@@ -107,6 +107,10 @@ def test_random_solarize(self):
 
     def test_random_autocontrast(self):
         self._test_op('autocontrast', 'RandomAutocontrast')
 
+    def test_random_equalize(self):
+        torch.set_deterministic(False)
+        self._test_op('equalize', 'RandomEqualize')
+
     def test_color_jitter(self):
         tol = 1.0 + 1e-10
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index d401aa4cc90..948638f3dd8 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -1273,3 +1273,23 @@ def autocontrast(img: Tensor) -> Tensor:
         return F_pil.autocontrast(img)
 
     return F_t.autocontrast(img)
+
+
+def equalize(img: Tensor) -> Tensor:
+    """Equalize the histogram of a PIL Image or torch Tensor by applying
+    a non-linear mapping to the input in order to create a uniform
+    distribution of grayscale values in the output.
+
+    Args:
+        img (PIL Image or Tensor): Image on which equalize is applied.
+            If img is a Tensor, it is expected to be in [..., H, W] format,
+            where ... means it can have an arbitrary number of leading
+            dimensions.
+
+    Returns:
+        PIL Image or Tensor: An image that was equalized.
+    """
+    if not isinstance(img, torch.Tensor):
+        return F_pil.equalize(img)
+
+    return F_t.equalize(img)
diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py
index 14f91713aca..26f3b504d99 100644
--- a/torchvision/transforms/functional_pil.py
+++ b/torchvision/transforms/functional_pil.py
@@ -644,3 +644,10 @@ def autocontrast(img):
     if not _is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
     return ImageOps.autocontrast(img)
+
+
+@torch.jit.unused
+def equalize(img):
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+    return ImageOps.equalize(img)
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 8a0432fa456..47437385828 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -1284,3 +1284,41 @@ def autocontrast(img: Tensor) -> Tensor:
     scale = bound / (maximum - minimum)
 
     return ((img.to(dtype) - minimum) * scale).clamp(0, bound).to(img.dtype)
+
+
+def _scale_channel(img_chan):
+    hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
+
+    nonzero_hist = hist[hist != 0]
+    if nonzero_hist.numel() > 0:
+        step = (nonzero_hist.sum() - nonzero_hist[-1]) // 255
+    else:
+        step = torch.tensor(0, device=img_chan.device)
+    if step == 0:
+        return img_chan
+
+    lut = (torch.cumsum(hist, 0) + (step // 2)) // step
+    lut = torch.cat([torch.zeros(1, device=img_chan.device), lut[:-1]]).clamp(0, 255)
+
+    return lut[img_chan.to(torch.int64)].to(torch.uint8)
+
+
+def _equalize_single_image(img: Tensor) -> Tensor:
+    return torch.stack([_scale_channel(img[c]) for c in range(img.size(0))])
+
+
+def equalize(img: Tensor) -> Tensor:
+    if not _is_tensor_a_torch_image(img):
+        raise TypeError('tensor is not a torch image.')
+
+    if not (3 <= img.ndim <= 4):
+        raise TypeError("Input image tensor should have 3 or 4 dimensions, but found {}".format(img.ndim))
+    if img.dtype != torch.uint8:
+        raise TypeError("Only torch.uint8 image tensors are supported, but found {}".format(img.dtype))
+
+    _assert_channels(img, [1, 3])
+
+    if img.ndim == 3:
+        return _equalize_single_image(img)
+
+    return torch.stack([_equalize_single_image(x) for x in img])
diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index 963806b9962..f4416b36acd 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -22,7 +22,7 @@
            "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
            "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
            "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize",
-           "RandomSolarize", "RandomAutocontrast"]
+           "RandomSolarize", "RandomAutocontrast", "RandomEqualize"]
 
 
 class Compose:
@@ -1876,3 +1876,43 @@ def forward(self, img):
 
     def __repr__(self):
         return self.__class__.__name__ + '(p={})'.format(self.p)
+
+
+class RandomEqualize(torch.nn.Module):
+    """Equalize the histogram of the given image randomly with a given probability.
+    The image can be a PIL Image or a torch Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading
+    dimensions.
+
+    Args:
+        p (float): probability of the image being equalized. Default value is 0.5
+    """
+
+    def __init__(self, p=0.5):
+        super().__init__()
+        self.p = p
+
+    @staticmethod
+    def get_params() -> float:
+        """Choose a value for the random transformation.
+
+        Returns:
+            float: Random value which is used to determine whether the random transformation
+            should occur.
+        """
+        return torch.rand(1).item()
+
+    def forward(self, img):
+        """
+        Args:
+            img (PIL Image or Tensor): Image to be equalized.
+
+        Returns:
+            PIL Image or Tensor: Randomly equalized image.
+        """
+        if self.get_params() < self.p:
+            return F.equalize(img)
+        return img
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(p={})'.format(self.p)
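
A minimal usage sketch of the API this patch adds (not part of the diff itself). It assumes the patch has been applied to a torchvision checkout; the input tensor below is illustrative and uses torch.uint8 with 3 channels, which is what the tensor kernel above requires.

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F

# Illustrative input: a 3-channel uint8 image tensor in CHW layout.
img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)

# Functional form: deterministically equalize the histogram of each channel.
out = F.equalize(img)

# Transform form: equalize with probability p (p=1.0 forces the equalization here).
out_t = T.RandomEqualize(p=1.0)(img)

assert out.shape == img.shape and out.dtype == torch.uint8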