diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 903798d19e3..e5e547c20d7 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -1,3 +1,4 @@ +from __future__ import division import torch import torchvision.transforms as transforms import torchvision.transforms.functional_tensor as F_t @@ -36,6 +37,37 @@ def test_crop(self): self.assertTrue(torch.equal(img_cropped, (img_cropped_GT * 255).to(torch.uint8)), "functional_tensor crop not working") + def test_adjustments(self): + fns = ((F.adjust_brightness, F_t.adjust_brightness), + (F.adjust_contrast, F_t.adjust_contrast), + (F.adjust_saturation, F_t.adjust_saturation)) + + for _ in range(20): + channels = 3 + dims = torch.randint(1, 50, (2,)) + shape = (channels, dims[0], dims[1]) + + if torch.randint(0, 2, (1,)) == 0: + img = torch.rand(*shape, dtype=torch.float) + else: + img = torch.randint(0, 256, shape, dtype=torch.uint8) + + factor = 3 * torch.rand(1) + for f, ft in fns: + + ft_img = ft(img, factor) + if not img.dtype.is_floating_point: + ft_img = ft_img.to(torch.float) / 255 + + img_pil = transforms.ToPILImage()(img) + f_img_pil = f(img_pil, factor) + f_img = transforms.ToTensor()(f_img_pil) + + # F uses uint8 and F_t uses float, so there is a small + # difference in values caused by (at most 5) truncations. + max_diff = (ft_img - f_img).abs().max() + self.assertLess(max_diff, 5 / 255 + 1e-5) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 5027958164c..7ef83c1086b 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -48,3 +48,69 @@ def crop(img, top, left, height, width): raise TypeError('tensor is not a torch image.') return img[..., top:top + height, left:left + width] + + +def adjust_brightness(img, brightness_factor): + """Adjust brightness of an RGB image. + + Args: + img (Tensor): Image to be adjusted. + brightness_factor (float): How much to adjust the brightness. Can be + any non negative number. 0 gives a black image, 1 gives the + original image while 2 increases the brightness by a factor of 2. + + Returns: + Tensor: Brightness adjusted image. + """ + if not F._is_tensor_image(img): + raise TypeError('tensor is not a torch image.') + + return _blend(img, 0, brightness_factor) + + +def adjust_contrast(img, contrast_factor): + """Adjust contrast of an RGB image. + + Args: + img (Tensor): Image to be adjusted. + contrast_factor (float): How much to adjust the contrast. Can be any + non negative number. 0 gives a solid gray image, 1 gives the + original image while 2 increases the contrast by a factor of 2. + + Returns: + Tensor: Contrast adjusted image. + """ + if not F._is_tensor_image(img): + raise TypeError('tensor is not a torch image.') + + mean = torch.mean(_rgb_to_grayscale(img).to(torch.float)) + + return _blend(img, mean, contrast_factor) + + +def adjust_saturation(img, saturation_factor): + """Adjust color saturation of an RGB image. + + Args: + img (Tensor): Image to be adjusted. + saturation_factor (float): How much to adjust the saturation. 0 will + give a black and white image, 1 will give the original image while + 2 will enhance the saturation by a factor of 2. + + Returns: + Tensor: Saturation adjusted image. + """ + if not F._is_tensor_image(img): + raise TypeError('tensor is not a torch image.') + + return _blend(img, _rgb_to_grayscale(img), saturation_factor) + + +def _blend(img1, img2, ratio): + bound = 1 if img1.dtype.is_floating_point else 255 + return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype) + + +def _rgb_to_grayscale(img): + # ITU-R 601-2 luma transform, as used in PIL. + return (0.2989 * img[0] + 0.5870 * img[1] + 0.1140 * img[2]).to(img.dtype)