Skip to content

Implement the adjust_sharpness transform #3114

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,14 @@ def test_solarize(self):
dts=(torch.float32, torch.float64)
)

def test_adjust_sharpness(self):
    """Run the shared adjust-function check for ``adjust_sharpness``.

    Passes the generic, PIL and tensor implementations to
    ``_test_adjust_fn`` together with one keyword-argument dict per
    sharpness factor (below, at and above the identity factor 1.0).
    """
    factors = [0.2, 0.5, 1.0, 1.5, 2.0]
    configs = [{"sharpness_factor": factor} for factor in factors]
    self._test_adjust_fn(
        F.adjust_sharpness,
        F_pil.adjust_sharpness,
        F_t.adjust_sharpness,
        configs
    )


@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
Expand Down
45 changes: 44 additions & 1 deletion test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1232,6 +1232,48 @@ def test_adjust_hue(self):
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans))

def test_adjust_sharpness(self):
    """Check ``F.adjust_sharpness`` on PIL images and against the tensor path.

    A 4x4 RGB image is adjusted with factors 1, 0.5 and 2 and compared with
    hard-coded reference pixels (factor 1 must reproduce the input). A 2x2
    RGB image is then adjusted through both the PIL and the tensor code
    paths and the two results are compared.
    """
    shape_4x4 = [4, 4, 3]
    source_pixels = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 0,
                     0, 65, 108, 101, 120, 97, 110, 100, 101, 114, 32, 86, 114, 121, 110, 105,
                     111, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117]
    source_np = np.array(source_pixels, dtype=np.uint8).reshape(shape_4x4)
    source_pil = Image.fromarray(source_np, mode='RGB')

    # (sharpness factor, expected flattened pixels) -- tests 0, 1 and 2
    reference_cases = [
        (1, source_pixels),
        (0.5, [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 30,
               30, 74, 103, 96, 114, 97, 110, 100, 101, 114, 32, 81, 103, 108, 102, 101,
               107, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117]),
        (2, [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 0,
             0, 46, 118, 111, 132, 97, 110, 100, 101, 114, 32, 95, 135, 146, 126, 112,
             119, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117]),
    ]
    for factor, expected_pixels in reference_cases:
        actual = np.array(F.adjust_sharpness(source_pil, factor))
        expected = np.array(expected_pixels, dtype=np.uint8).reshape(shape_4x4)
        self.assertTrue(np.allclose(actual, expected))

    # test 3: PIL and tensor implementations must agree on a 2x2 image
    shape_2x2 = [2, 2, 3]
    small_pixels = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    small_np = np.array(small_pixels, dtype=np.uint8).reshape(shape_2x2)
    small_pil = Image.fromarray(small_np, mode='RGB')
    # HWC -> CHW for the tensor code path
    small_th = torch.tensor(small_np.transpose(2, 0, 1))
    from_pil = np.array(F.adjust_sharpness(small_pil, 2)).transpose(2, 0, 1)
    from_tensor = F.adjust_sharpness(small_th, 2)
    self.assertTrue(np.allclose(from_pil, from_tensor.numpy()))

def test_adjust_gamma(self):
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
Expand Down Expand Up @@ -1268,10 +1310,11 @@ def test_adjusts_L_mode(self):
self.assertEqual(F.adjust_saturation(x_l, 2).mode, 'L')
self.assertEqual(F.adjust_contrast(x_l, 2).mode, 'L')
self.assertEqual(F.adjust_hue(x_l, 0.4).mode, 'L')
self.assertEqual(F.adjust_sharpness(x_l, 2).mode, 'L')
self.assertEqual(F.adjust_gamma(x_l, 0.5).mode, 'L')

def test_color_jitter(self):
color_jitter = transforms.ColorJitter(2, 2, 2, 0.1)
color_jitter = transforms.ColorJitter(2, 2, 2, 0.1, 2)

x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
Expand Down
8 changes: 7 additions & 1 deletion test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,14 @@ def test_color_jitter(self):
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=16.1, agg_method="max"
)

for f in [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]:
meth_kwargs = {"sharpness": f}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)

# All 4 parameters together
meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2, "sharpness": 0.2}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=12.1, agg_method="max"
)
Expand Down
18 changes: 18 additions & 0 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1235,3 +1235,21 @@ def solarize(img: Tensor, threshold: float) -> Tensor:
return F_pil.solarize(img, threshold)

return F_t.solarize(img, threshold)


def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
    """Adjust the sharpness of an image.

    Dispatches on the input type: anything that is not a ``torch.Tensor``
    is forwarded to the PIL implementation, tensors are forwarded to the
    tensor implementation.

    Args:
        img (PIL Image or Tensor): Image to be adjusted.
        sharpness_factor (float): How much to adjust the sharpness. Can be
            any non-negative number. 0 gives a blurred image, 1 gives the
            original image while 2 increases the sharpness by a factor of 2.

    Returns:
        PIL Image or Tensor: Sharpness adjusted image.
    """
    # NOTE(review): the negated isinstance check followed by an early return
    # looks deliberate (presumably for TorchScript) -- confirm before
    # restructuring this dispatch.
    if not isinstance(img, torch.Tensor):
        return F_pil.adjust_sharpness(img, sharpness_factor)

    return F_t.adjust_sharpness(img, sharpness_factor)
10 changes: 10 additions & 0 deletions torchvision/transforms/functional_pil.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,3 +627,13 @@ def solarize(img, threshold):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.solarize(img, threshold)


@torch.jit.unused
def adjust_sharpness(img, sharpness_factor):
    """Adjust the sharpness of a PIL image via ``ImageEnhance.Sharpness``.

    Args:
        img (PIL Image): Image to be adjusted.
        sharpness_factor (float): Factor handed to the enhancer's
            ``enhance`` call; it is not validated here.

    Returns:
        PIL Image: The image returned by the enhancer.

    Raises:
        TypeError: If ``img`` is not a PIL image.
    """
    if _is_pil_image(img):
        return ImageEnhance.Sharpness(img).enhance(sharpness_factor)

    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
36 changes: 36 additions & 0 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,3 +1226,39 @@ def solarize(img: Tensor, threshold: float) -> Tensor:
result[invert_idx] = (bound - result[invert_idx].to(dtype=dtype)).clamp(0, bound).to(dtype=img.dtype)

return result.view(img.shape)


def _blur_image(img: Tensor) -> Tensor:
    """Return a blurred version of ``img`` whose border pixels are left untouched.

    Each channel is convolved independently with a normalized 3x3 kernel
    (all ones with a centre weight of 5). The convolution is unpadded, so
    its output is padded back by one pixel on every side and the outermost
    rows and columns are then overwritten with the original pixels.

    Args:
        img (Tensor): Image tensor; the channel dimension is read from
            ``img.shape[-3]``.

    Returns:
        Tensor: Blurred image.
    """
    # Integer images are convolved in float32; floating-point images keep their dtype.
    compute_dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    num_channels = img.shape[-3]

    smoothing = torch.ones((3, 3), dtype=compute_dtype, device=img.device)
    smoothing[1, 1] = 5.0
    smoothing /= smoothing.sum()
    # One identical single-input-channel filter per image channel.
    weight = smoothing.expand(num_channels, 1, smoothing.shape[0], smoothing.shape[1])

    # presumably casts to the kernel dtype and adds a batch dim if needed -- see _cast_squeeze_in
    batched, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [weight.dtype, ])
    # groups == channels: every channel is filtered on its own
    interior = conv2d(batched, weight, groups=batched.shape[-3])
    # restore the spatial size lost by the unpadded 3x3 convolution
    padded = torch_pad(interior, [1, 1, 1, 1])
    blurred = _cast_squeeze_out(padded, need_cast, need_squeeze, out_dtype)

    # The padded border holds no valid blur values; copy the original pixels there.
    blurred[..., 0, :] = img[..., 0, :]
    blurred[..., -1, :] = img[..., -1, :]
    blurred[..., :, 0] = img[..., :, 0]
    blurred[..., :, -1] = img[..., :, -1]

    return blurred


def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
    """Adjust the sharpness of an image tensor.

    The image is blended with a blurred copy of itself, weighted by
    ``sharpness_factor``.

    Args:
        img (Tensor): Image tensor with 1 or 3 channels.
        sharpness_factor (float): Non-negative blend factor.

    Returns:
        Tensor: Sharpness adjusted image. If either of the last two
        dimensions has size 2 or less, ``img`` itself is returned unchanged.

    Raises:
        ValueError: If ``sharpness_factor`` is negative.
        TypeError: If ``img`` is not a torch image tensor.
    """
    if sharpness_factor < 0:
        raise ValueError('sharpness_factor ({}) is not non-negative.'.format(sharpness_factor))

    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    _assert_channels(img, [1, 3])

    # Too small for the blur to have an interior: every pixel is a border pixel.
    height, width = img.size(-2), img.size(-1)
    if width <= 2 or height <= 2:
        return img

    blurred = _blur_image(img)
    return _blend(img, blurred, sharpness_factor)
32 changes: 22 additions & 10 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,7 +1039,7 @@ def __repr__(self):


class ColorJitter(torch.nn.Module):
"""Randomly change the brightness, contrast and saturation of an image.
"""Randomly change the brightness, contrast, saturation, hue and sharpness of an image.

Args:
brightness (float or tuple of float (min, max)): How much to jitter brightness.
Expand All @@ -1054,15 +1054,19 @@ class ColorJitter(torch.nn.Module):
hue (float or tuple of float (min, max)): How much to jitter hue.
hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
sharpness (float or tuple of float (min, max)): How much to jitter sharpness.
sharpness_factor is chosen uniformly from [max(0, 1 - sharpness), 1 + sharpness]
or the given [min, max]. Should be non negative numbers.
"""

def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, sharpness=0):
super().__init__()
self.brightness = self._check_input(brightness, 'brightness')
self.contrast = self._check_input(contrast, 'contrast')
self.saturation = self._check_input(saturation, 'saturation')
self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
clip_first_on_zero=False)
self.sharpness = self._check_input(sharpness, 'sharpness')

@torch.jit.unused
def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
Expand All @@ -1078,7 +1082,7 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs
else:
raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name))

# if value is 0 or (1., 1.) for brightness/contrast/saturation
# if value is 0 or (1., 1.) for brightness/contrast/saturation/sharpness
# or (0., 0.) for hue, do nothing
if value[0] == value[1] == center:
value = None
Expand All @@ -1088,8 +1092,10 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs
def get_params(brightness: Optional[List[float]],
contrast: Optional[List[float]],
saturation: Optional[List[float]],
hue: Optional[List[float]]
) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
hue: Optional[List[float]],
sharpness: Optional[List[float]]
) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float],
Optional[float]]:
"""Get the parameters for the randomized transform to be applied on image.

Args:
Expand All @@ -1101,19 +1107,22 @@ def get_params(brightness: Optional[List[float]],
uniformly. Pass None to turn off the transformation.
hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
Pass None to turn off the transformation.
sharpness (tuple of float (min, max), optional): The range from which the sharpness is chosen
uniformly. Pass None to turn off the transformation.

Returns:
tuple: The parameters used to apply the randomized transform
along with their random order.
"""
fn_idx = torch.randperm(4)
fn_idx = torch.randperm(5)

b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))
sp = None if sharpness is None else float(torch.empty(1).uniform_(sharpness[0], sharpness[1]))

return fn_idx, b, c, s, h
return fn_idx, b, c, s, h, sp

def forward(self, img):
"""
Expand All @@ -1123,8 +1132,8 @@ def forward(self, img):
Returns:
PIL Image or Tensor: Color jittered image.
"""
fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \
self.get_params(self.brightness, self.contrast, self.saturation, self.hue)
fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor, sharpness_factor = \
self.get_params(self.brightness, self.contrast, self.saturation, self.hue, self.sharpness)

for fn_id in fn_idx:
if fn_id == 0 and brightness_factor is not None:
Expand All @@ -1135,6 +1144,8 @@ def forward(self, img):
img = F.adjust_saturation(img, saturation_factor)
elif fn_id == 3 and hue_factor is not None:
img = F.adjust_hue(img, hue_factor)
elif fn_id == 4 and sharpness_factor is not None:
img = F.adjust_sharpness(img, sharpness_factor)

return img

Expand All @@ -1143,7 +1154,8 @@ def __repr__(self):
format_string += 'brightness={0}'.format(self.brightness)
format_string += ', contrast={0}'.format(self.contrast)
format_string += ', saturation={0}'.format(self.saturation)
format_string += ', hue={0})'.format(self.hue)
format_string += ', hue={0}'.format(self.hue)
format_string += ', sharpness={0})'.format(self.sharpness)
return format_string


Expand Down