Skip to content

Commit e79cadd

Browse files
pedrofreirefmassa
authored andcommitted
Add adjustment operations for RGB Tensor Images. (#1525)
* Add adjustment operations for RGB Tensor Images. Right now, we have operations on PIL images, but we want to have a version of the opeartions that act directly on Tensor images. Here, we add such operations for adjust_brightness, adjust_contrast and adjust_saturation. In PIL, those functions are implemented by generating an degenerate image from the first, and then interpolating them together. - https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageEnhance.py - https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Blend.c A few caveats: * Since PIL operates on uint8, and the tensor operations might be on float, we can get slightly different values because of int truncation. * We assume here the images are RGB; in particular, to handle an alpha channel, we need to check whether it is present, in which case we copy it to the final image. * Keep dtype and use broadcast in adjust operations - We make our operations have input.dtype == output.dtype, at the cost of adding a few type checks and branches. - By using Tensor broadcast, we can simplify the calls to _blend. * Use is_floating_point to check dtype. * Remove unpacking in tuple It seems Python 2 does not support this type of unpacking, so it broke Python 2 builds. This should fix it. * Add from __future__ import division for Python 2
1 parent 9e27356 commit e79cadd

File tree

2 files changed

+98
-0
lines changed

2 files changed

+98
-0
lines changed

test/test_functional_tensor.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from __future__ import division
12
import torch
23
import torchvision.transforms as transforms
34
import torchvision.transforms.functional_tensor as F_t
@@ -36,6 +37,37 @@ def test_crop(self):
3637
self.assertTrue(torch.equal(img_cropped, (img_cropped_GT * 255).to(torch.uint8)),
3738
"functional_tensor crop not working")
3839

40+
def test_adjustments(self):
41+
fns = ((F.adjust_brightness, F_t.adjust_brightness),
42+
(F.adjust_contrast, F_t.adjust_contrast),
43+
(F.adjust_saturation, F_t.adjust_saturation))
44+
45+
for _ in range(20):
46+
channels = 3
47+
dims = torch.randint(1, 50, (2,))
48+
shape = (channels, dims[0], dims[1])
49+
50+
if torch.randint(0, 2, (1,)) == 0:
51+
img = torch.rand(*shape, dtype=torch.float)
52+
else:
53+
img = torch.randint(0, 256, shape, dtype=torch.uint8)
54+
55+
factor = 3 * torch.rand(1)
56+
for f, ft in fns:
57+
58+
ft_img = ft(img, factor)
59+
if not img.dtype.is_floating_point:
60+
ft_img = ft_img.to(torch.float) / 255
61+
62+
img_pil = transforms.ToPILImage()(img)
63+
f_img_pil = f(img_pil, factor)
64+
f_img = transforms.ToTensor()(f_img_pil)
65+
66+
# F uses uint8 and F_t uses float, so there is a small
67+
# difference in values caused by (at most 5) truncations.
68+
max_diff = (ft_img - f_img).abs().max()
69+
self.assertLess(max_diff, 5 / 255 + 1e-5)
70+
3971

4072
if __name__ == '__main__':
4173
unittest.main()

torchvision/transforms/functional_tensor.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,69 @@ def crop(img, top, left, height, width):
4848
raise TypeError('tensor is not a torch image.')
4949

5050
return img[..., top:top + height, left:left + width]
51+
52+
53+
def adjust_brightness(img, brightness_factor):
54+
"""Adjust brightness of an RGB image.
55+
56+
Args:
57+
img (Tensor): Image to be adjusted.
58+
brightness_factor (float): How much to adjust the brightness. Can be
59+
any non negative number. 0 gives a black image, 1 gives the
60+
original image while 2 increases the brightness by a factor of 2.
61+
62+
Returns:
63+
Tensor: Brightness adjusted image.
64+
"""
65+
if not F._is_tensor_image(img):
66+
raise TypeError('tensor is not a torch image.')
67+
68+
return _blend(img, 0, brightness_factor)
69+
70+
71+
def adjust_contrast(img, contrast_factor):
72+
"""Adjust contrast of an RGB image.
73+
74+
Args:
75+
img (Tensor): Image to be adjusted.
76+
contrast_factor (float): How much to adjust the contrast. Can be any
77+
non negative number. 0 gives a solid gray image, 1 gives the
78+
original image while 2 increases the contrast by a factor of 2.
79+
80+
Returns:
81+
Tensor: Contrast adjusted image.
82+
"""
83+
if not F._is_tensor_image(img):
84+
raise TypeError('tensor is not a torch image.')
85+
86+
mean = torch.mean(_rgb_to_grayscale(img).to(torch.float))
87+
88+
return _blend(img, mean, contrast_factor)
89+
90+
91+
def adjust_saturation(img, saturation_factor):
92+
"""Adjust color saturation of an RGB image.
93+
94+
Args:
95+
img (Tensor): Image to be adjusted.
96+
saturation_factor (float): How much to adjust the saturation. 0 will
97+
give a black and white image, 1 will give the original image while
98+
2 will enhance the saturation by a factor of 2.
99+
100+
Returns:
101+
Tensor: Saturation adjusted image.
102+
"""
103+
if not F._is_tensor_image(img):
104+
raise TypeError('tensor is not a torch image.')
105+
106+
return _blend(img, _rgb_to_grayscale(img), saturation_factor)
107+
108+
109+
def _blend(img1, img2, ratio):
110+
bound = 1 if img1.dtype.is_floating_point else 255
111+
return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)
112+
113+
114+
def _rgb_to_grayscale(img):
115+
# ITU-R 601-2 luma transform, as used in PIL.
116+
return (0.2989 * img[0] + 0.5870 * img[1] + 0.1140 * img[2]).to(img.dtype)

0 commit comments

Comments
 (0)