From d9c96e65240360345e4edffe244079f653eb7591 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Mon, 30 Nov 2020 20:08:39 +0000
Subject: [PATCH 1/3] Adding invert operator.

---
 test/test_functional_tensor.py              | 15 ++++++++
 test/test_transforms.py                     | 32 ++++++++++++++++
 test/test_transforms_tensor.py              |  3 ++
 torchvision/transforms/functional.py        | 18 +++++++++
 torchvision/transforms/functional_pil.py    | 20 ++++++++++
 torchvision/transforms/functional_tensor.py | 28 ++++++++++++++
 torchvision/transforms/transforms.py        | 42 ++++++++++++++++++++-
 7 files changed, 157 insertions(+), 1 deletion(-)

diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 38a565310d0..6bea112c3b4 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -857,6 +857,21 @@ def test_gaussian_blur(self):
                     msg="{}, {}".format(ksize, sigma)
                 )
 
+    def test_invert(self):
+        script_invert = torch.jit.script(F.invert)
+
+        img_tensor, pil_img = self._create_data(16, 18, device=self.device)
+        inverted_img = F.invert(img_tensor)
+        inverted_pil_img = F.invert(pil_img)
+        self.compareTensorToPIL(inverted_img, inverted_pil_img)
+
+        # scriptable function test
+        inverted_img_script = script_invert(img_tensor)
+        self.assertTrue(inverted_img.equal(inverted_img_script))
+
+        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
+        self._test_fn_on_batch(batch_tensors, F.invert)
+
 
 @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
 class CUDATester(Tester):
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 30749772d6a..d6b8f48959c 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -1749,6 +1749,38 @@ def test_gaussian_blur_asserts(self):
         with self.assertRaisesRegex(ValueError, r"sigma should be a single number or a list/tuple with length 2"):
             transforms.GaussianBlur(3, "sigma_string")
 
+    @unittest.skipIf(stats is None, 'scipy.stats not available')
+    def test_random_invert(self):
+        random_state = random.getstate()
+        random.seed(42)
+        img = transforms.ToPILImage()(torch.rand(3, 10, 10))
+        inv_img = F.invert(img)
+
+        num_samples = 250
+        num_inverts = 0
+        for _ in range(num_samples):
+            out = transforms.RandomInvert()(img)
+            if out == inv_img:
+                num_inverts += 1
+
+        p_value = stats.binom_test(num_inverts, num_samples, p=0.5)
+        random.setstate(random_state)
+        self.assertGreater(p_value, 0.0001)
+
+        num_samples = 250
+        num_inverts = 0
+        for _ in range(num_samples):
+            out = transforms.RandomInvert(p=0.7)(img)
+            if out == inv_img:
+                num_inverts += 1
+
+        p_value = stats.binom_test(num_inverts, num_samples, p=0.7)
+        random.setstate(random_state)
+        self.assertGreater(p_value, 0.0001)
+
+        # Checking if RandomInvert can be printed as string
+        transforms.RandomInvert().__repr__()
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py
index c5c4a7f09e0..30606b1548b 100644
--- a/test/test_transforms_tensor.py
+++ b/test/test_transforms_tensor.py
@@ -89,6 +89,9 @@ def test_random_horizontal_flip(self):
     def test_random_vertical_flip(self):
         self._test_op('vflip', 'RandomVerticalFlip')
 
+    def test_random_invert(self):
+        self._test_op('invert', 'RandomInvert')
+
     def test_color_jitter(self):
         tol = 1.0 + 1e-10
 
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 3f6548357d5..0ebdf967d13 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -1174,3 +1174,21 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa
     if not isinstance(img, torch.Tensor):
         output = to_pil_image(output)
     return output
+
+
+def invert(img: Tensor) -> Tensor:
+    """Invert the colors of a PIL Image or torch Tensor.
+
+    Args:
+        img (PIL Image or Tensor): Image to have its colors inverted.
+            If img is a Tensor, it is expected to be in [..., H, W] format,
+            where ... means it can have an arbitrary number of trailing
+            dimensions.
+
+    Returns:
+        PIL Image: Color inverted image.
+    """
+    if not isinstance(img, torch.Tensor):
+        return F_pil.invert(img)
+
+    return F_t.invert(img)
diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py
index 7e3989f0288..f04b10be201 100644
--- a/torchvision/transforms/functional_pil.py
+++ b/torchvision/transforms/functional_pil.py
@@ -603,3 +603,23 @@ def to_grayscale(img, num_output_channels):
         raise ValueError('num_output_channels should be either 1 or 3')
 
     return img
+
+
+@torch.jit.unused
+def invert(img):
+    """PRIVATE METHOD. Invert the colors of an image.
+
+    .. warning::
+
+        Module ``transforms.functional_pil`` is private and should not be used in user application.
+        Please, consider instead using methods from `transforms.functional` module.
+
+    Args:
+        img (PIL Image): Image to have its colors inverted.
+
+    Returns:
+        PIL Image: Color inverted image Tensor.
+    """
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+    return ImageOps.invert(img)
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 4f3e72a62ce..3a61b44173c 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -1138,3 +1138,31 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te
 
     img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
     return img
+
+
+def invert(img: Tensor) -> Tensor:
+    """PRIVATE METHOD. Invert the colors of a grayscale or RGB image.
+
+    .. warning::
+
+        Module ``transforms.functional_tensor`` is private and should not be used in user application.
+        Please, consider instead using methods from `transforms.functional` module.
+
+    Args:
+        img (Tensor): Image to have its colors inverted in the form [C, H, W].
+
+    Returns:
+        Tensor: Color inverted image Tensor.
+    """
+    if not _is_tensor_a_torch_image(img):
+        raise TypeError('tensor is not a torch image.')
+
+    if img.ndim < 3:
+        raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
+    c = img.shape[-3]
+    if c != 1 and c != 3:
+        raise TypeError("Input image tensor should have 1 or 3 channels, but found {}".format(c))
+
+    max_val = _max_value(img.dtype)
+    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
+    return (max_val - img.to(dtype)).to(img.dtype)
diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index 25886d59a60..15ee09623e9 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -21,7 +21,7 @@
            "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
            "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
            "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
-           "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode"]
+           "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert"]
 
 
 class Compose:
@@ -1677,3 +1677,43 @@ def _setup_angle(x, name, req_sizes=(2, )):
     _check_sequence_input(x, name, req_sizes)
 
     return [float(d) for d in x]
+
+
+class RandomInvert(torch.nn.Module):
+    """Inverts the colors of the given image randomly with a given probability.
+    The image can be a PIL Image or a torch Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading
+    dimensions
+
+    Args:
+        p (float): probability of the image being color inverted. Default value is 0.5
+    """
+
+    def __init__(self, p=0.5):
+        super().__init__()
+        self.p = p
+
+    @staticmethod
+    def get_params() -> float:
+        """Choose value for random color inversion.
+
+        Returns:
+            float: Random value which is used to determine whether the random color inversion
+            should occur.
+        """
+        return torch.rand(1).item()
+
+    def forward(self, img):
+        """
+        Args:
+            img (PIL Image or Tensor): Image to be inverted.
+
+        Returns:
+            PIL Image or Tensor: Randomly color inverted image.
+        """
+        if self.get_params() < self.p:
+            return F.invert(img)
+        return img
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(p={})'.format(self.p)

From 72daf4d8ce08c6d9933070fa829dd28fc1178f4d Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Wed, 2 Dec 2020 18:28:07 +0000
Subject: [PATCH 2/3] Make use of the _assert_channels().

---
 torchvision/transforms/functional_tensor.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 81354c0ea18..a4c9935ab2a 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -1200,9 +1200,8 @@ def invert(img: Tensor) -> Tensor:
 
     if img.ndim < 3:
         raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
-    c = img.shape[-3]
-    if c != 1 and c != 3:
-        raise TypeError("Input image tensor should have 1 or 3 channels, but found {}".format(c))
+
+    _assert_channels(img, [1, 3])
 
     max_val = _max_value(img.dtype)
     dtype = img.dtype if torch.is_floating_point(img) else torch.float32

From 60fe2cd836be611043ed3210bf431cac5f59eab0 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Thu, 3 Dec 2020 11:08:39 +0000
Subject: [PATCH 3/3] Update upper bound value.

---
 torchvision/transforms/functional_tensor.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index a4c9935ab2a..ce899efbabf 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -1203,6 +1203,6 @@ def invert(img: Tensor) -> Tensor:
 
     _assert_channels(img, [1, 3])
 
-    max_val = _max_value(img.dtype)
+    bound = 1.0 if img.is_floating_point() else 255.0
     dtype = img.dtype if torch.is_floating_point(img) else torch.float32
-    return (max_val - img.to(dtype)).to(img.dtype)
+    return (bound - img.to(dtype)).to(img.dtype)
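For reference, below is a minimal usage sketch of the API this series introduces. It is not part of the patches; it assumes a torchvision checkout with all three commits applied, and the tensor semantics follow the final bound-based implementation from PATCH 3/3.

```python
import torch
from PIL import Image
from torchvision.transforms import RandomInvert
from torchvision.transforms import functional as F

# Functional form on a float tensor in [0, 1]: invert() computes bound - img with bound = 1.0.
img = torch.rand(3, 16, 18)
assert torch.allclose(F.invert(img), 1.0 - img)

# Integer tensors are inverted against a fixed bound of 255 (rather than the dtype's
# maximum, which is what the dropped _max_value() call computed).
img_u8 = (img * 255).to(torch.uint8)
assert torch.equal(F.invert(img_u8), 255 - img_u8)

# PIL inputs are routed to F_pil.invert(), which delegates to ImageOps.invert().
pil_img = Image.new("RGB", (18, 16), color=(10, 20, 30))
pil_out = F.invert(pil_img)

# Transform form: the inversion is applied with probability p (default 0.5).
out = RandomInvert(p=0.7)(img)
```

Because `get_params()` draws a fresh `torch.rand(1)` on every call, `RandomInvert` has no deterministic output for a given input; this is why the new `test_random_invert` exercises the transform repeatedly and checks the observed inversion frequency with `stats.binom_test` instead of asserting a fixed result.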