Skip to content

Commit 17f3c25

Browse files
committed
Implement the solarize transform.
1 parent 4b800b9 commit 17f3c25

File tree

7 files changed

+128
-7
lines changed

7 files changed

+128
-7
lines changed

test/test_functional_tensor.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,26 @@ def test_posterize(self):
884884
dts=(None,)
885885
)
886886

887+
def test_solarize(self):
    """Compare ``solarize`` across the generic, PIL and tensor backends.

    Two passes are made through ``self._test_adjust_fn``: one with
    thresholds on the 0-255 scale and ``dts=(None,)``, and one with
    fractional thresholds for ``torch.float32`` / ``torch.float64``.
    """
    byte_configs = [{"threshold": value} for value in (0, 64, 128, 192, 255)]
    self._test_adjust_fn(
        F.solarize,
        F_pil.solarize,
        F_t.solarize,
        byte_configs,
        tol=1.0,
        agg_method="max",
        dts=(None,),
    )

    def pil_solarize_scaled(img, threshold):
        # The PIL reference is driven with the threshold rescaled by 255 so
        # that fractional thresholds map onto PIL's 0-255 value range.
        return F_pil.solarize(img, 255 * threshold)

    float_configs = [{"threshold": value} for value in (0.0, 0.25, 0.5, 0.75, 1.0)]
    self._test_adjust_fn(
        F.solarize,
        pil_solarize_scaled,
        F_t.solarize,
        float_configs,
        tol=1.0,
        agg_method="max",
        dts=(torch.float32, torch.float64),
    )
906+
887907

888908
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
889909
class CUDATester(Tester):

test/test_transforms.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1787,6 +1787,14 @@ def test_random_posterize(self):
17871787
[{"bits": 4}]
17881788
)
17891789

1790+
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_solarize(self):
    """Run the shared randomness check for ``RandomSolarize`` (threshold 192)."""
    configs = [{"threshold": 192}]
    self._test_randomness(F.solarize, transforms.RandomSolarize, configs)
1797+
17901798

17911799
if __name__ == '__main__':
17921800
unittest.main()

test/test_transforms_tensor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,12 @@ def test_random_posterize(self):
9898
'posterize', 'RandomPosterize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
9999
)
100100

101+
def test_random_solarize(self):
    """Exercise ``solarize`` / ``RandomSolarize`` through ``self._test_op``."""
    # The functional call and the transform constructor receive the very
    # same kwargs dict, exactly as in the chained assignment this replaces.
    shared_kwargs = {"threshold": 192.0}
    self._test_op(
        'solarize',
        'RandomSolarize',
        fn_kwargs=shared_kwargs,
        meth_kwargs=shared_kwargs,
    )
106+
101107
def test_color_jitter(self):
102108

103109
tol = 1.0 + 1e-10

torchvision/transforms/functional.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,9 +1203,9 @@ def posterize(img: Tensor, bits: int) -> Tensor:
12031203
12041204
Args:
12051205
img (PIL Image or Tensor): Image to have its colors posterized.
1206-
If img is a Tensor, it is expected to be in [..., H, W] format,
1207-
where ... means it can have an arbitrary number of trailing
1208-
dimensions.
1206+
If img is a Tensor, it should be of type torch.uint8 and
1207+
it is expected to be in [..., H, W] format, where ... means
1208+
it can have an arbitrary number of leading dimensions.
12091209
bits (int): The number of bits to keep for each channel (0-8).
12101210
Returns:
12111211
PIL Image: Posterized image.
@@ -1217,3 +1217,21 @@ def posterize(img: Tensor, bits: int) -> Tensor:
12171217
return F_pil.posterize(img, bits)
12181218

12191219
return F_t.posterize(img, bits)
1220+
1221+
1222+
def solarize(img: Tensor, threshold: float) -> Tensor:
    """Solarize an image by inverting every pixel value at or above a threshold.

    Args:
        img (PIL Image or Tensor): Image to be solarized.
            A Tensor input is expected in [..., H, W] format, where ... stands
            for an arbitrary number of leading dimensions.
        threshold (float): Pixels whose value is equal to or above this
            threshold are inverted.

    Returns:
        PIL Image or Tensor: Solarized image, produced by the backend that
        matches the type of ``img``.
    """
    # Tensors are handled by the tensor backend; every other input is handed
    # to the PIL backend.
    if isinstance(img, torch.Tensor):
        return F_t.solarize(img, threshold)

    return F_pil.solarize(img, threshold)

torchvision/transforms/functional_pil.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,3 +620,10 @@ def posterize(img, bits):
620620
if not _is_pil_image(img):
621621
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
622622
return ImageOps.posterize(img, bits)
623+
624+
625+
@torch.jit.unused
def solarize(img, threshold):
    """Solarize a PIL image by delegating to ``ImageOps.solarize``.

    Args:
        img (PIL Image): Image to be solarized.
        threshold: Threshold forwarded unchanged to ``ImageOps.solarize``.

    Returns:
        The result of ``ImageOps.solarize(img, threshold)``.

    Raises:
        TypeError: If ``img`` fails the ``_is_pil_image`` check.
    """
    if _is_pil_image(img):
        return ImageOps.solarize(img, threshold)

    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

torchvision/transforms/functional_tensor.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1192,7 +1192,7 @@ def invert(img: Tensor) -> Tensor:
11921192

11931193
bound = 1.0 if img.is_floating_point() else 255.0
11941194
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
1195-
return (bound - img.to(dtype)).to(img.dtype)
1195+
return (bound - img.to(dtype)).clamp(0, bound).to(img.dtype)
11961196

11971197

11981198
def posterize(img: Tensor, bits: int) -> Tensor:
@@ -1207,3 +1207,22 @@ def posterize(img: Tensor, bits: int) -> Tensor:
12071207
_assert_channels(img, [1, 3])
12081208
mask = -int(2**(8 - bits)) # JIT-friendly for: ~(2 ** (8 - bits) - 1)
12091209
return img & mask
1210+
1211+
1212+
def solarize(img: Tensor, threshold: float) -> Tensor:
    """Solarize a tensor image by inverting all values at or above a threshold.

    Args:
        img (Tensor): Image tensor with at least 3 dimensions and 1 or 3
            channels (enforced by the checks below).
        threshold (float): Elements with ``img >= threshold`` are replaced by
            their inverted value; all other elements are kept unchanged.

    Returns:
        Tensor: New tensor with the same shape and dtype as ``img``.

    Raises:
        TypeError: If ``img`` is not a torch image or has fewer than 3
            dimensions.
    """
    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    if img.ndim < 3:
        raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))

    _assert_channels(img, [1, 3])

    # Inversion is done in floating point and clamped to [0, bound] before
    # casting back, so integer images cannot wrap around.
    bound = 1.0 if img.is_floating_point() else 255.0
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    inverted = (bound - img.to(dtype)).clamp(0, bound).to(img.dtype)

    # Element-wise selection works for any memory layout. The previous
    # ``img.clone().view(-1)`` approach raised on dense but non-contiguous
    # inputs (e.g. permuted tensors) because ``clone`` preserves strides.
    return torch.where(img >= threshold, inverted, img)

torchvision/transforms/transforms.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
"CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
2222
"RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
2323
"LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
24-
"RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize"]
24+
"RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize",
25+
"RandomSolarize"]
2526

2627

2728
class Compose:
@@ -1705,7 +1706,7 @@ class RandomInvert(torch.nn.Module):
17051706
"""Inverts the colors of the given image randomly with a given probability.
17061707
The image can be a PIL Image or a torch Tensor, in which case it is expected
17071708
to have [..., H, W] shape, where ... means an arbitrary number of leading
1708-
dimensions
1709+
dimensions.
17091710
17101711
Args:
17111712
p (float): probability of the image being color inverted. Default value is 0.5
@@ -1745,7 +1746,7 @@ class RandomPosterize(torch.nn.Module):
17451746
"""Posterize the image randomly with a given probability by reducing the
17461747
number of bits for each color channel. The image can be a PIL Image or a torch
17471748
Tensor, in which case it is expected to have [..., H, W] shape, where ... means
1748-
an arbitrary number of leading dimensions
1749+
an arbitrary number of leading dimensions.
17491750
17501751
Args:
17511752
bits (int): number of bits to keep for each channel (0-8)
@@ -1781,3 +1782,45 @@ def forward(self, img):
17811782

17821783
def __repr__(self):
17831784
return self.__class__.__name__ + '(bits={},p={})'.format(self.bits, self.p)
1785+
1786+
1787+
class RandomSolarize(torch.nn.Module):
    """Randomly solarize an image with a given probability.

    Solarizing inverts all pixel values at or above a threshold. The image can
    be a PIL Image or a torch Tensor, in which case it is expected to have
    [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions.

    Args:
        threshold (float): all pixels equal or above this value are inverted.
        p (float): probability of the image being solarized. Default value is 0.5
    """

    def __init__(self, threshold, p=0.5):
        super().__init__()
        self.p = p
        self.threshold = threshold

    @staticmethod
    def get_params() -> float:
        """Draw the random number that decides whether the transform is applied.

        Returns:
            float: Single sample from ``torch.rand``; ``forward`` compares it
            against ``p``.
        """
        sample = torch.rand(1)
        return sample.item()

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be solarized.

        Returns:
            PIL Image or Tensor: Randomly solarized image.
        """
        should_apply = self.get_params() < self.p
        if not should_apply:
            # Transform skipped: the input object is returned as-is.
            return img
        return F.solarize(img, self.threshold)

    def __repr__(self):
        return '{}(threshold={},p={})'.format(self.__class__.__name__, self.threshold, self.p)

0 commit comments

Comments
 (0)