Skip to content

Commit 94fc573

Browse files
authored
Implement the autocontrast transform. (#3117)
1 parent b4e9a2f commit 94fc573

7 files changed

+114
-4
lines changed

test/test_functional_tensor.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,16 @@ def test_adjust_sharpness(self):
912912
[{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]
913913
)
914914

915+
def test_autocontrast(self):
916+
self._test_adjust_fn(
917+
F.autocontrast,
918+
F_pil.autocontrast,
919+
F_t.autocontrast,
920+
[{}],
921+
tol=1.0,
922+
agg_method="max"
923+
)
924+
915925

916926
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
917927
class CUDATester(Tester):

test/test_transforms.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1838,6 +1838,14 @@ def test_random_solarize(self):
18381838
[{"threshold": 192}]
18391839
)
18401840

1841+
@unittest.skipIf(stats is None, 'scipy.stats not available')
1842+
def test_random_autocontrast(self):
1843+
self._test_randomness(
1844+
F.autocontrast,
1845+
transforms.RandomAutocontrast,
1846+
[{}]
1847+
)
1848+
18411849

18421850
if __name__ == '__main__':
18431851
unittest.main()

test/test_transforms_tensor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ def test_random_solarize(self):
104104
'solarize', 'RandomSolarize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
105105
)
106106

107+
def test_random_autocontrast(self):
108+
self._test_op('autocontrast', 'RandomAutocontrast')
109+
107110
def test_color_jitter(self):
108111

109112
tol = 1.0 + 1e-10

torchvision/transforms/functional.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1190,7 +1190,7 @@ def invert(img: Tensor) -> Tensor:
11901190
dimensions.
11911191
11921192
Returns:
1193-
PIL Image: Color inverted image.
1193+
PIL Image or Tensor: Color inverted image.
11941194
"""
11951195
if not isinstance(img, torch.Tensor):
11961196
return F_pil.invert(img)
@@ -1208,7 +1208,7 @@ def posterize(img: Tensor, bits: int) -> Tensor:
12081208
it can have an arbitrary number of trailing dimensions.
12091209
bits (int): The number of bits to keep for each channel (0-8).
12101210
Returns:
1211-
PIL Image: Posterized image.
1211+
PIL Image or Tensor: Posterized image.
12121212
"""
12131213
if not (0 <= bits <= 8):
12141214
raise ValueError('The number if bits should be between 0 and 8. Got {}'.format(bits))
@@ -1229,7 +1229,7 @@ def solarize(img: Tensor, threshold: float) -> Tensor:
12291229
dimensions.
12301230
threshold (float): All pixels equal or above this value are inverted.
12311231
Returns:
1232-
PIL Image: Solarized image.
1232+
PIL Image or Tensor: Solarized image.
12331233
"""
12341234
if not isinstance(img, torch.Tensor):
12351235
return F_pil.solarize(img, threshold)
@@ -1253,3 +1253,23 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
12531253
return F_pil.adjust_sharpness(img, sharpness_factor)
12541254

12551255
return F_t.adjust_sharpness(img, sharpness_factor)
1256+
1257+
1258+
def autocontrast(img: Tensor) -> Tensor:
1259+
"""Maximize contrast of a PIL Image or torch Tensor by remapping its
1260+
pixels per channel so that the lowest becomes black and the lightest
1261+
becomes white.
1262+
1263+
Args:
1264+
img (PIL Image or Tensor): Image on which autocontrast is applied.
1265+
If img is a Tensor, it is expected to be in [..., H, W] format,
1266+
where ... means it can have an arbitrary number of trailing
1267+
dimensions.
1268+
1269+
Returns:
1270+
PIL Image or Tensor: An image that was autocontrasted.
1271+
"""
1272+
if not isinstance(img, torch.Tensor):
1273+
return F_pil.autocontrast(img)
1274+
1275+
return F_t.autocontrast(img)

torchvision/transforms/functional_pil.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,3 +637,10 @@ def adjust_sharpness(img, sharpness_factor):
637637
enhancer = ImageEnhance.Sharpness(img)
638638
img = enhancer.enhance(sharpness_factor)
639639
return img
640+
641+
642+
@torch.jit.unused
643+
def autocontrast(img):
644+
if not _is_pil_image(img):
645+
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
646+
return ImageOps.autocontrast(img)

torchvision/transforms/functional_tensor.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,3 +1262,25 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
12621262
return img
12631263

12641264
return _blend(img, _blur_image(img), sharpness_factor)
1265+
1266+
1267+
def autocontrast(img: Tensor) -> Tensor:
1268+
if not _is_tensor_a_torch_image(img):
1269+
raise TypeError('tensor is not a torch image.')
1270+
1271+
if img.ndim < 3:
1272+
raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
1273+
1274+
_assert_channels(img, [1, 3])
1275+
1276+
bound = 1.0 if img.is_floating_point() else 255.0
1277+
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
1278+
1279+
minimum = img.amin(dim=(-2, -1)).unsqueeze(-1).unsqueeze(-1).to(dtype)
1280+
maximum = img.amax(dim=(-2, -1)).unsqueeze(-1).unsqueeze(-1).to(dtype)
1281+
eq_idxs = torch.where(minimum == maximum)[0]
1282+
minimum[eq_idxs] = 0
1283+
maximum[eq_idxs] = bound
1284+
scale = bound / (maximum - minimum)
1285+
1286+
return ((img.to(dtype) - minimum) * scale).clamp(0, bound).to(img.dtype)

torchvision/transforms/transforms.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
"RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
2323
"LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
2424
"RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize",
25-
"RandomSolarize"]
25+
"RandomSolarize", "RandomAutocontrast"]
2626

2727

2828
class Compose:
@@ -1836,3 +1836,43 @@ def forward(self, img):
18361836

18371837
def __repr__(self):
18381838
return self.__class__.__name__ + '(threshold={},p={})'.format(self.threshold, self.p)
1839+
1840+
1841+
class RandomAutocontrast(torch.nn.Module):
1842+
"""Autocontrast the pixels of the given image randomly with a given probability.
1843+
The image can be a PIL Image or a torch Tensor, in which case it is expected
1844+
to have [..., H, W] shape, where ... means an arbitrary number of leading
1845+
dimensions.
1846+
1847+
Args:
1848+
p (float): probability of the image being autocontrasted. Default value is 0.5
1849+
"""
1850+
1851+
def __init__(self, p=0.5):
1852+
super().__init__()
1853+
self.p = p
1854+
1855+
@staticmethod
1856+
def get_params() -> float:
1857+
"""Choose a value for the random transformation.
1858+
1859+
Returns:
1860+
float: Random value which is used to determine whether the random transformation
1861+
should occur.
1862+
"""
1863+
return torch.rand(1).item()
1864+
1865+
def forward(self, img):
1866+
"""
1867+
Args:
1868+
img (PIL Image or Tensor): Image to be autocontrasted.
1869+
1870+
Returns:
1871+
PIL Image or Tensor: Randomly autocontrasted image.
1872+
"""
1873+
if self.get_params() < self.p:
1874+
return F.autocontrast(img)
1875+
return img
1876+
1877+
def __repr__(self):
1878+
return self.__class__.__name__ + '(p={})'.format(self.p)

0 commit comments

Comments
 (0)