
Commit 51f149d

Implement the equalize transform.
1 parent 94fc573 commit 51f149d

7 files changed, with 126 additions and 2 deletions

test/test_functional_tensor.py

Lines changed: 9 additions & 0 deletions
@@ -922,6 +922,15 @@ def test_autocontrast(self):
             agg_method="max"
         )
 
+    def test_equalize(self):
+        self._test_adjust_fn(
+            F.equalize,
+            F_pil.equalize,
+            F_t.equalize,
+            [{}],
+            dts=(None,)
+        )
+
 
 @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
 class CUDATester(Tester):

test/test_transforms.py

Lines changed: 9 additions & 1 deletion
@@ -1795,7 +1795,7 @@ def test_gaussian_blur_asserts(self):
     def _test_randomness(self, fn, trans, configs):
         random_state = random.getstate()
         random.seed(42)
-        img = transforms.ToPILImage()(torch.rand(3, 10, 10))
+        img = transforms.ToPILImage()(torch.rand(3, 16, 18))
 
         for p in [0.5, 0.7]:
             for config in configs:
@@ -1846,6 +1846,14 @@ def test_random_autocontrast(self):
             [{}]
         )
 
+    @unittest.skipIf(stats is None, 'scipy.stats not available')
+    def test_random_equalize(self):
+        self._test_randomness(
+            F.equalize,
+            transforms.RandomEqualize,
+            [{}]
+        )
+
 
 if __name__ == '__main__':
     unittest.main()

test/test_transforms_tensor.py

Lines changed: 3 additions & 0 deletions
@@ -107,6 +107,9 @@ def test_random_solarize(self):
     def test_random_autocontrast(self):
         self._test_op('autocontrast', 'RandomAutocontrast')
 
+    def test_random_equalize(self):
+        self._test_op('equalize', 'RandomEqualize')
+
     def test_color_jitter(self):
 
         tol = 1.0 + 1e-10

torchvision/transforms/functional.py

Lines changed: 20 additions & 0 deletions
@@ -1273,3 +1273,23 @@ def autocontrast(img: Tensor) -> Tensor:
         return F_pil.autocontrast(img)
 
     return F_t.autocontrast(img)
+
+
+def equalize(img: Tensor) -> Tensor:
+    """Equalize the histogram of a PIL Image or torch Tensor by applying
+    a non-linear mapping to the input in order to create a uniform
+    distribution of grayscale values in the output.
+
+    Args:
+        img (PIL Image or Tensor): Image on which equalize is applied.
+            If img is a Tensor, it is expected to be in [..., H, W] format,
+            where ... means it can have an arbitrary number of trailing
+            dimensions.
+
+    Returns:
+        PIL Image or Tensor: An image that was equalized.
+    """
+    if not isinstance(img, torch.Tensor):
+        return F_pil.equalize(img)
+
+    return F_t.equalize(img)
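
The new functional.equalize entry point dispatches on the input type: PIL Images go to F_pil.equalize, tensors to F_t.equalize. A minimal usage sketch for the tensor path (not part of the commit; shapes and values are illustrative):

    import torch
    from torchvision.transforms import functional as F

    # uint8 image with 1 or 3 channels in [..., H, W] layout
    img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
    out = F.equalize(img)          # dispatches to functional_tensor.equalize
    print(out.shape, out.dtype)    # torch.Size([3, 32, 32]) torch.uint8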

torchvision/transforms/functional_pil.py

Lines changed: 7 additions & 0 deletions
@@ -644,3 +644,10 @@ def autocontrast(img):
     if not _is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
     return ImageOps.autocontrast(img)
+
+
+@torch.jit.unused
+def equalize(img):
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+    return ImageOps.equalize(img)
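
The PIL backend is a thin wrapper: apart from the type check it simply forwards to PIL's ImageOps.equalize. A small sketch of that path (mine, not part of the commit; the random input avoids any file dependency):

    import torch
    from PIL import ImageOps
    from torchvision import transforms
    from torchvision.transforms import functional as F

    # Build a PIL image from random data so the sketch is self-contained.
    pil_img = transforms.ToPILImage()(torch.rand(3, 32, 32))

    a = F.equalize(pil_img)         # goes through functional_pil.equalize
    b = ImageOps.equalize(pil_img)  # which is exactly this call
    assert list(a.getdata()) == list(b.getdata())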

torchvision/transforms/functional_tensor.py

Lines changed: 37 additions & 0 deletions
@@ -1284,3 +1284,40 @@ def autocontrast(img: Tensor) -> Tensor:
     scale = bound / (maximum - minimum)
 
     return ((img.to(dtype) - minimum) * scale).clamp(0, bound).to(img.dtype)
+
+
+def _scale_channel(img_chan):
+    hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
+
+    nonzero_hist = hist[hist != 0]
+    if nonzero_hist.numel() > 0:
+        step = (nonzero_hist.sum() - nonzero_hist[-1]) // 255
+    else:
+        step = torch.tensor(0)
+    if step == 0:
+        return img_chan
+
+    lut = (torch.cumsum(hist, 0) + (step // 2)) // step
+    lut = torch.cat([torch.zeros(1), lut[:-1]]).clamp(0, 255)
+
+    return lut[img_chan.to(torch.int64)].to(torch.uint8)
+
+def _equalize_single_image(img: Tensor) -> Tensor:
+    return torch.stack([_scale_channel(img[c]) for c in range(img.size(0))])
+
+
+def equalize(img: Tensor) -> Tensor:
+    if not _is_tensor_a_torch_image(img):
+        raise TypeError('tensor is not a torch image.')
+
+    if not (3 <= img.ndim <= 4):
+        raise TypeError("Input image tensor should have 3 or 4 dimensions, but found {}".format(img.ndim))
+    if img.dtype != torch.uint8:
+        raise TypeError("Only torch.uint8 image tensors are supported, but found {}".format(img.dtype))
+
+    _assert_channels(img, [1, 3])
+
+    if img.ndim == 3:
+        return _equalize_single_image(img)
+
+    return torch.stack([_equalize_single_image(x) for x in img])
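
The tensor backend mirrors PIL's algorithm: per channel it builds a 256-bin histogram, derives the step size from all pixels except those in the last occupied bin, and remaps pixels through a lookup table built from the cumulative histogram; channels whose step rounds to zero are returned unchanged, and batched (N, C, H, W) inputs are processed image by image. A quick sanity-check sketch (mine, not part of the commit):

    import torch
    from torchvision.transforms import functional as F

    # A 16x16 single-channel ramp hits every intensity exactly once, so its
    # histogram is already flat and equalization should leave it unchanged.
    ramp = torch.arange(256, dtype=torch.uint8).reshape(1, 16, 16)
    assert torch.equal(F.equalize(ramp), ramp)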

torchvision/transforms/transforms.py

Lines changed: 41 additions & 1 deletion
@@ -22,7 +22,7 @@
            "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
            "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
            "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize",
-           "RandomSolarize", "RandomAutocontrast"]
+           "RandomSolarize", "RandomAutocontrast", "RandomEqualize"]
 
 
 class Compose:
@@ -1876,3 +1876,43 @@ def forward(self, img):
 
     def __repr__(self):
         return self.__class__.__name__ + '(p={})'.format(self.p)
+
+
+class RandomEqualize(torch.nn.Module):
+    """Equalize the histogram of the given image randomly with a given probability.
+    The image can be a PIL Image or a torch Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading
+    dimensions.
+
+    Args:
+        p (float): probability of the image being equalized. Default value is 0.5
+    """
+
+    def __init__(self, p=0.5):
+        super().__init__()
+        self.p = p
+
+    @staticmethod
+    def get_params() -> float:
+        """Choose a value for the random transformation.
+
+        Returns:
+            float: Random value which is used to determine whether the random transformation
+                should occur.
+        """
+        return torch.rand(1).item()
+
+    def forward(self, img):
+        """
+        Args:
+            img (PIL Image or Tensor): Image to be equalized.
+
+        Returns:
+            PIL Image or Tensor: Randomly equalized image.
+        """
+        if self.get_params() < self.p:
+            return F.equalize(img)
+        return img
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(p={})'.format(self.p)
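
RandomEqualize draws one uniform random number per call and applies F.equalize when it falls below p, so it drops straight into an existing transform pipeline. An illustrative usage sketch (the pipeline below is an assumption, not part of the commit):

    import torch
    from torchvision import transforms

    augment = transforms.Compose([
        transforms.RandomEqualize(p=0.5),   # equalize roughly half of the inputs
    ])

    img = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)  # uint8 tensor image
    out = augment(img)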
