From 9a12648b11027aac075d5c35a8b46afa896ec102 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 4 Dec 2020 00:30:17 +0000 Subject: [PATCH 1/3] Adding functional operator for sharpness. --- test/test_functional_tensor.py | 8 +++++ torchvision/transforms/functional.py | 18 +++++++++++ torchvision/transforms/functional_pil.py | 10 +++++++ torchvision/transforms/functional_tensor.py | 33 +++++++++++++++++++++ 4 files changed, 69 insertions(+) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 63e8271a858..715450b5cfd 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -904,6 +904,14 @@ def test_solarize(self): dts=(torch.float32, torch.float64) ) + def test_adjust_sharpness(self): + self._test_adjust_fn( + F.adjust_sharpness, + F_pil.adjust_sharpness, + F_t.adjust_sharpness, + [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]] + ) + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") class CUDATester(Tester): diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index e3b0a9bd98a..8383b08364b 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1235,3 +1235,21 @@ def solarize(img: Tensor, threshold: float) -> Tensor: return F_pil.solarize(img, threshold) return F_t.solarize(img, threshold) + + +def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: + """Adjust the sharpness of an Image. + + Args: + img (PIL Image or Tensor): Image to be adjusted. + sharpness_factor (float): How much to adjust the sharpness. Can be + any non negative number. 0 gives a blurred image, 1 gives the + original image while 2 increases the sharpness by a factor of 2. + + Returns: + PIL Image or Tensor: Sharpness adjusted image. + """ + if not isinstance(img, torch.Tensor): + return F_pil.adjust_sharpness(img, sharpness_factor) + + return F_t.adjust_sharpness(img, sharpness_factor) diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index d60588fd138..72eafe37a2c 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -627,3 +627,13 @@ def solarize(img, threshold): if not _is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) return ImageOps.solarize(img, threshold) + + +@torch.jit.unused +def adjust_sharpness(img, sharpness_factor): + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Sharpness(img) + img = enhancer.enhance(sharpness_factor) + return img diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 5eb70988f90..8038e24fb71 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1226,3 +1226,36 @@ def solarize(img: Tensor, threshold: float) -> Tensor: result[invert_idx] = (bound - result[invert_idx].to(dtype=dtype)).clamp(0, bound).to(dtype=img.dtype) return result.view(img.shape) + + +def _blur_image(img: Tensor) -> Tensor: + dtype = img.dtype if torch.is_floating_point(img) else torch.float32 + + kernel = torch.ones((3, 3), dtype=dtype, device=img.device) + kernel[1, 1] = 5.0 + kernel /= kernel.sum() + kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1]) + + result, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype, ]) + result = conv2d(result, kernel, groups=result.shape[-3]) + result = torch_pad(result, [1, 1, 1, 1]) + result = _cast_squeeze_out(result, need_cast, need_squeeze, out_dtype) + + result[..., 0, :] = img[..., 0, :] + result[..., -1, :] = img[..., -1, :] + result[..., :, 0] = img[..., :, 0] + result[..., :, -1] = img[..., :, -1] + + return result + + +def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: + if sharpness_factor < 0: + raise ValueError('sharpness_factor ({}) is not non-negative.'.format(sharpness_factor)) + + if not _is_tensor_a_torch_image(img): + raise TypeError('tensor is not a torch image.') + + _assert_channels(img, [1, 3]) + + return _blend(img, _blur_image(img), sharpness_factor) From 6b829be20153b166a20bf669812d220f78746585 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 4 Dec 2020 01:17:46 +0000 Subject: [PATCH 2/3] Adding transforms for sharpness. --- test/test_transforms.py | 34 +++++++++++++++++++++++++++- test/test_transforms_tensor.py | 8 ++++++- torchvision/transforms/transforms.py | 32 ++++++++++++++++++-------- 3 files changed, 62 insertions(+), 12 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index fc52fc66686..5617a786489 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1232,6 +1232,37 @@ def test_adjust_hue(self): y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) self.assertTrue(np.allclose(y_np, y_ans)) + def test_adjust_sharpness(self): + x_shape = [4, 4, 3] + x_data = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 0, + 0, 65, 108, 101, 120, 97, 110, 100, 101, 114, 32, 86, 114, 121, 110, 105, + 111, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_pil = Image.fromarray(x_np, mode='RGB') + + # test 0 + y_pil = F.adjust_sharpness(x_pil, 1) + y_np = np.array(y_pil) + self.assertTrue(np.allclose(y_np, x_np)) + + # test 1 + y_pil = F.adjust_sharpness(x_pil, 0.5) + y_np = np.array(y_pil) + y_ans = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 30, + 30, 74, 103, 96, 114, 97, 110, 100, 101, 114, 32, 81, 103, 108, 102, 101, + 107, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + self.assertTrue(np.allclose(y_np, y_ans)) + + # test 2 + y_pil = F.adjust_sharpness(x_pil, 2) + y_np = np.array(y_pil) + y_ans = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 0, + 0, 46, 118, 111, 132, 97, 110, 100, 101, 114, 32, 95, 135, 146, 126, 112, + 119, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + self.assertTrue(np.allclose(y_np, y_ans)) + def test_adjust_gamma(self): x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] @@ -1268,10 +1299,11 @@ def test_adjusts_L_mode(self): self.assertEqual(F.adjust_saturation(x_l, 2).mode, 'L') self.assertEqual(F.adjust_contrast(x_l, 2).mode, 'L') self.assertEqual(F.adjust_hue(x_l, 0.4).mode, 'L') + self.assertEqual(F.adjust_sharpness(x_l, 2).mode, 'L') self.assertEqual(F.adjust_gamma(x_l, 0.5).mode, 'L') def test_color_jitter(self): - color_jitter = transforms.ColorJitter(2, 2, 2, 0.1) + color_jitter = transforms.ColorJitter(2, 2, 2, 0.1, 2) x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 331f8a2eb4f..30c5b885bb8 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -131,8 +131,14 @@ def test_color_jitter(self): "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=16.1, agg_method="max" ) + for f in [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]: + meth_kwargs = {"sharpness": f} + self._test_class_op( + "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max" + ) + # All 4 parameters together - meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2} + meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2, "sharpness": 0.2} self._test_class_op( "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=12.1, agg_method="max" ) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 66ccb42e525..2a854933ebd 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1039,7 +1039,7 @@ def __repr__(self): class ColorJitter(torch.nn.Module): - """Randomly change the brightness, contrast and saturation of an image. + """Randomly change the brightness, contrast, saturation, hue and sharpness of an image. Args: brightness (float or tuple of float (min, max)): How much to jitter brightness. @@ -1054,15 +1054,19 @@ class ColorJitter(torch.nn.Module): hue (float or tuple of float (min, max)): How much to jitter hue. hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. + sharpness (float or tuple of float (min, max)): How much to jitter sharpness. + sharpness_factor is chosen uniformly from [max(0, 1 - sharpness), 1 + sharpness] + or the given [min, max]. Should be non negative numbers. """ - def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, sharpness=0): super().__init__() self.brightness = self._check_input(brightness, 'brightness') self.contrast = self._check_input(contrast, 'contrast') self.saturation = self._check_input(saturation, 'saturation') self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) + self.sharpness = self._check_input(sharpness, 'sharpness') @torch.jit.unused def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): @@ -1078,7 +1082,7 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs else: raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name)) - # if value is 0 or (1., 1.) for brightness/contrast/saturation + # if value is 0 or (1., 1.) for brightness/contrast/saturation/sharpness # or (0., 0.) for hue, do nothing if value[0] == value[1] == center: value = None @@ -1088,8 +1092,10 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs def get_params(brightness: Optional[List[float]], contrast: Optional[List[float]], saturation: Optional[List[float]], - hue: Optional[List[float]] - ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]: + hue: Optional[List[float]], + sharpness: Optional[List[float]] + ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float], + Optional[float]]: """Get the parameters for the randomized transform to be applied on image. Args: @@ -1101,19 +1107,22 @@ def get_params(brightness: Optional[List[float]], uniformly. Pass None to turn off the transformation. hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly. Pass None to turn off the transformation. + sharpness (tuple of float (min, max), optional): The range from which the sharpness is chosen + uniformly. Pass None to turn off the transformation. Returns: tuple: The parameters used to apply the randomized transform along with their random order. """ - fn_idx = torch.randperm(4) + fn_idx = torch.randperm(5) b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1])) c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1])) s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1])) h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1])) + sp = None if sharpness is None else float(torch.empty(1).uniform_(sharpness[0], sharpness[1])) - return fn_idx, b, c, s, h + return fn_idx, b, c, s, h, sp def forward(self, img): """ @@ -1123,8 +1132,8 @@ def forward(self, img): Returns: PIL Image or Tensor: Color jittered image. """ - fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \ - self.get_params(self.brightness, self.contrast, self.saturation, self.hue) + fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor, sharpness_factor = \ + self.get_params(self.brightness, self.contrast, self.saturation, self.hue, self.sharpness) for fn_id in fn_idx: if fn_id == 0 and brightness_factor is not None: @@ -1135,6 +1144,8 @@ def forward(self, img): img = F.adjust_saturation(img, saturation_factor) elif fn_id == 3 and hue_factor is not None: img = F.adjust_hue(img, hue_factor) + elif fn_id == 4 and sharpness_factor is not None: + img = F.adjust_sharpness(img, sharpness_factor) return img @@ -1143,7 +1154,8 @@ def __repr__(self): format_string += 'brightness={0}'.format(self.brightness) format_string += ', contrast={0}'.format(self.contrast) format_string += ', saturation={0}'.format(self.saturation) - format_string += ', hue={0})'.format(self.hue) + format_string += ', hue={0}'.format(self.hue) + format_string += ', sharpness={0})'.format(self.sharpness) return format_string From f5741dd3dcb4905e2cb9f19cd46cc3d011471c75 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 4 Dec 2020 01:42:08 +0000 Subject: [PATCH 3/3] Handling tiny images and adding a test. --- test/test_transforms.py | 11 +++++++++++ torchvision/transforms/functional_tensor.py | 3 +++ 2 files changed, 14 insertions(+) diff --git a/test/test_transforms.py b/test/test_transforms.py index 5617a786489..58ffa93f6e2 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1263,6 +1263,17 @@ def test_adjust_sharpness(self): y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) self.assertTrue(np.allclose(y_np, y_ans)) + # test 3 + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_pil = Image.fromarray(x_np, mode='RGB') + x_th = torch.tensor(x_np.transpose(2, 0, 1)) + y_pil = F.adjust_sharpness(x_pil, 2) + y_np = np.array(y_pil).transpose(2, 0, 1) + y_th = F.adjust_sharpness(x_th, 2) + self.assertTrue(np.allclose(y_np, y_th.numpy())) + def test_adjust_gamma(self): x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 8038e24fb71..70e1cdc7833 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1258,4 +1258,7 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: _assert_channels(img, [1, 3]) + if img.size(-1) <= 2 or img.size(-2) <= 2: + return img + return _blend(img, _blur_image(img), sharpness_factor)