From 9e159349bf5091c23faadee6c8b2e741baa9f04f Mon Sep 17 00:00:00 2001 From: Brian Date: Fri, 10 Jul 2020 16:47:06 -0400 Subject: [PATCH 1/3] add torchscriptable adjust_gamma transform https://github.com/pytorch/vision/issues/1375 --- test/test_functional_tensor.py | 29 ++++++++++++++++ torchvision/transforms/functional.py | 24 ++++--------- torchvision/transforms/functional_pil.py | 37 ++++++++++++++++++++ torchvision/transforms/functional_tensor.py | 38 +++++++++++++++++++++ 4 files changed, 111 insertions(+), 17 deletions(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index d23930e7313..48604ec287b 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -23,6 +23,8 @@ def _create_data(self, height=3, width=3, channels=3): def compareTensorToPIL(self, tensor, pil_image, msg=None): pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))) + if msg is None: + msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor) self.assertTrue(tensor.equal(pil_tensor), msg) def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None): @@ -293,6 +295,33 @@ def test_pad(self): with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"): F_t.pad(tensor, (-2, -3), padding_mode="symmetric") + def test_adjust_gamma(self): + script_fn = torch.jit.script(F_t.adjust_gamma) + tensor, pil_img = self._create_data(26, 36) + + for dt in [torch.float64, torch.float32, None]: + + if dt is not None: + tensor = F.convert_image_dtype(tensor, dt) + + gammas = [0.8, 1.0, 1.2] + gains = [0.7, 1.0, 1.3] + for gamma, gain in zip(gammas, gains): + + adjusted_tensor = F_t.adjust_gamma(tensor, gamma, gain) + adjusted_pil = F_pil.adjust_gamma(pil_img, gamma, gain) + scripted_result = script_fn(tensor, gamma, gain) + self.assertEqual(adjusted_tensor.dtype, scripted_result.dtype) + self.assertEqual(adjusted_tensor.size()[1:], adjusted_pil.size[::-1]) + + rbg_tensor 
= adjusted_tensor + if adjusted_tensor.dtype != torch.uint8: + rbg_tensor = F.convert_image_dtype(adjusted_tensor, torch.uint8) + + self.compareTensorToPIL(rbg_tensor, adjusted_pil) + + self.assertTrue(adjusted_tensor.equal(scripted_result)) + def test_resize(self): script_fn = torch.jit.script(F_t.resize) tensor, pil_img = self._create_data(26, 36) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 801df42a187..59c5129af16 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -161,8 +161,8 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) - msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely." raise RuntimeError(msg) - eps = 1e-3 - return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype) + max = torch.iinfo(dtype).max + return image.mul(torch.iinfo(dtype).max).clamp(0, max).to(dtype) else: # int to float if dtype.is_floating_point: @@ -760,7 +760,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: raise TypeError('img should be PIL Image. Got {}'.format(type(img))) -def adjust_gamma(img, gamma, gain=1): +def adjust_gamma(img, gamma: float, gain: float = 1): r"""Perform gamma correction on an image. Also known as Power Law Transform. Intensities in RGB mode are adjusted @@ -774,26 +774,16 @@ def adjust_gamma(img, gamma, gain=1): .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction Args: - img (PIL Image): PIL Image to be adjusted. + img (PIL Image or Tensor): PIL Image to be adjusted. gamma (float): Non negative real number, same as :math:`\gamma` in the equation. gamma larger than 1 make the shadows darker, while gamma smaller than 1 make dark regions lighter. gain (float): The constant multiplier. """ - if not F_pil._is_pil_image(img): - raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) - - if gamma < 0: - raise ValueError('Gamma should be a non-negative real number') + if F_pil._is_pil_image(img): + return F_pil.adjust_gamma(img, gamma, gain) - input_mode = img.mode - img = img.convert('RGB') - - gamma_map = [255 * gain * pow(ele / 255., gamma) for ele in range(256)] * 3 - img = img.point(gamma_map) # use PIL's point-function to accelerate this part - - img = img.convert(input_mode) - return img + return F_t.adjust_gamma(img, gamma, gain) def rotate(img, angle, resample=False, expand=False, center=None, fill=None): diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 994988ce1f6..28c8584dc57 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -164,6 +164,43 @@ def adjust_hue(img, hue_factor): return img +@torch.jit.unused +def adjust_gamma(img, gamma, gain=1): + r"""Perform gamma correction on an image. + + Also known as Power Law Transform. Intensities in RGB mode are adjusted + based on the following equation: + + .. math:: + I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma} + + See `Gamma Correction`_ for more details. + + .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction + + Args: + img (PIL Image): PIL Image to be adjusted. + gamma (float): Non negative real number, same as :math:`\gamma` in the equation. + gamma larger than 1 make the shadows darker, + while gamma smaller than 1 make dark regions lighter. + gain (float): The constant multiplier. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) + + if gamma < 0: + raise ValueError('Gamma should be a non-negative real number') + + input_mode = img.mode + img = img.convert('RGB') + + gamma_map = [255 * gain * pow(ele / 255., gamma) for ele in range(256)] * 3 + img = img.point(gamma_map) # use PIL's point-function to accelerate this part + + img = img.convert(input_mode) + return img + + @torch.jit.unused def pad(img, padding, fill=0, padding_mode="constant"): r"""Pad the given PIL.Image on all sides with the given "pad" value. diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 59cf6bc2764..ad56bd891bd 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -194,6 +194,44 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: return _blend(img, rgb_to_grayscale(img), saturation_factor) +def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: + r"""Adjust gamma of an RGB image. + + Also known as Power Law Transform. Intensities in RGB mode are adjusted + based on the following equation: + + .. math:: + `I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}` + + See `Gamma Correction`_ for more details. + + .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction + + Args: + img (Tensor): PIL Image to be adjusted. + gamma (float): Non negative real number, same as :math:`\gamma` in the equation. + gamma larger than 1 make the shadows darker, + while gamma smaller than 1 make dark regions lighter. + gain (float): The constant multiplier. + """ + + if not isinstance(img, torch.Tensor): + raise TypeError('img should be a Tensor. 
Got {}'.format(type(img))) + + if gamma < 0: + raise ValueError('Gamma should be a non-negative real number') + + result = img + dtype = img.dtype + if torch.is_floating_point(img): + return gain * result ** gamma + + result = 255.0 * gain * (result / 255.0) ** gamma + # PIL clamps, to(torch.uint8) would wrap + result = result.clamp(0, 255).to(dtype) + return result + + def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor: """Crop the Image Tensor and resize it to desired size. From 67e70d8b88e855b5a4d1383243ce58fa7829ec2c Mon Sep 17 00:00:00 2001 From: Brian Date: Wed, 15 Jul 2020 14:47:58 -0400 Subject: [PATCH 2/3] changes based on code-review --- test/test_transforms.py | 4 ++-- torchvision/transforms/functional.py | 14 +++++++++++--- torchvision/transforms/functional_pil.py | 3 +-- torchvision/transforms/functional_tensor.py | 15 +++++++++------ 4 files changed, 23 insertions(+), 13 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index b0eb844fcf8..19caefcd788 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1179,14 +1179,14 @@ def test_adjust_gamma(self): # test 1 y_pil = F.adjust_gamma(x_pil, 0.5) y_np = np.array(y_pil) - y_ans = [0, 35, 57, 117, 185, 240, 97, 45, 244, 151, 255, 15] + y_ans = [0, 35, 57, 117, 186, 241, 97, 45, 245, 152, 255, 16] y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) self.assertTrue(np.allclose(y_np, y_ans)) # test 2 y_pil = F.adjust_gamma(x_pil, 2) y_np = np.array(y_pil) - y_ans = [0, 0, 0, 11, 71, 200, 5, 0, 214, 31, 255, 0] + y_ans = [0, 0, 0, 11, 71, 201, 5, 0, 215, 31, 255, 0] y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) self.assertTrue(np.allclose(y_np, y_ans)) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 59c5129af16..9ff35e26041 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -161,8 +161,14 @@ def convert_image_dtype(image: torch.Tensor, 
dtype: torch.dtype = torch.float) - msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely." raise RuntimeError(msg) - max = torch.iinfo(dtype).max - return image.mul(torch.iinfo(dtype).max).clamp(0, max).to(dtype) + # https://github.com/pytorch/vision/pull/2078#issuecomment-612045321 + # For data in the range 0-1, (float * 255).to(uint) is only 255 + # when float is exactly 1.0. + # `max + 1 - epsilon` provides more evenly distributed mapping of + # ranges of floats to ints. + eps = 1e-3 + result = image.mul(torch.iinfo(dtype).max + 1 - eps) + return result.to(dtype) else: # int to float if dtype.is_floating_point: @@ -779,8 +785,10 @@ def adjust_gamma(img, gamma: float, gain: float = 1): gamma larger than 1 make the shadows darker, while gamma smaller than 1 make dark regions lighter. gain (float): The constant multiplier. + Returns: + PIL Image or Tensor: Gamma correction adjusted image. """ - if F_pil._is_pil_image(img): + if not isinstance(img, torch.Tensor): return F_pil.adjust_gamma(img, gamma, gain) return F_t.adjust_gamma(img, gamma, gain) diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 28c8584dc57..753497f9b2d 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -193,8 +193,7 @@ def adjust_gamma(img, gamma, gain=1): input_mode = img.mode img = img.convert('RGB') - - gamma_map = [255 * gain * pow(ele / 255., gamma) for ele in range(256)] * 3 + gamma_map = [(255 + 1 - 1e-3) * gain * pow(ele / 255., gamma) for ele in range(256)] * 3 img = img.point(gamma_map) # use PIL's point-function to accelerate this part img = img.convert(input_mode) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index ad56bd891bd..b446fe37567 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -208,7 +208,7 @@ def adjust_gamma(img: Tensor, gamma: 
float, gain: float = 1) -> Tensor: .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction Args: - img (Tensor): PIL Image to be adjusted. + img (Tensor): Tensor of RBG values to be adjusted. gamma (float): Non negative real number, same as :math:`\gamma` in the equation. gamma larger than 1 make the shadows darker, while gamma smaller than 1 make dark regions lighter. @@ -223,12 +223,15 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: result = img dtype = img.dtype - if torch.is_floating_point(img): - return gain * result ** gamma + if not torch.is_floating_point(img): + result = result / 255.0 - result = 255.0 * gain * (result / 255.0) ** gamma - # PIL clamps, to(torch.uint8) would wrap - result = result.clamp(0, 255).to(dtype) + result = (gain * result ** gamma).clamp(0, 1) + + if result.dtype != dtype: + eps = 1e-3 + result = (255 + 1.0 - eps) * result + result = result.to(dtype) return result From 187eebce0cd43cd21c3f02cd30e748fe7399041a Mon Sep 17 00:00:00 2001 From: Brian Vaughan Date: Thu, 16 Jul 2020 11:20:15 -0400 Subject: [PATCH 3/3] Apply suggested change to add type hint Required by mypy, even though technically incorrect due to possible Image parameter. torchscript doesn't support a union based type hint. Co-authored-by: vfdev --- torchvision/transforms/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 9ff35e26041..659ea88b84a 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -766,7 +766,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: raise TypeError('img should be PIL Image. Got {}'.format(type(img))) -def adjust_gamma(img, gamma: float, gain: float = 1): +def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: r"""Perform gamma correction on an image. Also known as Power Law Transform. Intensities in RGB mode are adjusted