Commit b4e9a2f

Implement the adjust_sharpness transform (#3114)
* Adding functional operator for sharpness.
* Adding transforms for sharpness.
* Handling tiny images and adding a test.
1 parent 63b8a27 commit b4e9a2f

7 files changed: +145, -12 lines

test/test_functional_tensor.py

Lines changed: 8 additions & 0 deletions
@@ -904,6 +904,14 @@ def test_solarize(self):
             dts=(torch.float32, torch.float64)
         )
 
+    def test_adjust_sharpness(self):
+        self._test_adjust_fn(
+            F.adjust_sharpness,
+            F_pil.adjust_sharpness,
+            F_t.adjust_sharpness,
+            [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]
+        )
+
 
 @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
 class CUDATester(Tester):

test/test_transforms.py

Lines changed: 44 additions & 1 deletion
@@ -1232,6 +1232,48 @@ def test_adjust_hue(self):
         y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
         self.assertTrue(np.allclose(y_np, y_ans))
 
+    def test_adjust_sharpness(self):
+        x_shape = [4, 4, 3]
+        x_data = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 0,
+                  0, 65, 108, 101, 120, 97, 110, 100, 101, 114, 32, 86, 114, 121, 110, 105,
+                  111, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117]
+        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
+        x_pil = Image.fromarray(x_np, mode='RGB')
+
+        # test 0
+        y_pil = F.adjust_sharpness(x_pil, 1)
+        y_np = np.array(y_pil)
+        self.assertTrue(np.allclose(y_np, x_np))
+
+        # test 1
+        y_pil = F.adjust_sharpness(x_pil, 0.5)
+        y_np = np.array(y_pil)
+        y_ans = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 30,
+                 30, 74, 103, 96, 114, 97, 110, 100, 101, 114, 32, 81, 103, 108, 102, 101,
+                 107, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117]
+        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
+        self.assertTrue(np.allclose(y_np, y_ans))
+
+        # test 2
+        y_pil = F.adjust_sharpness(x_pil, 2)
+        y_np = np.array(y_pil)
+        y_ans = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 0,
+                 0, 46, 118, 111, 132, 97, 110, 100, 101, 114, 32, 95, 135, 146, 126, 112,
+                 119, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117]
+        y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
+        self.assertTrue(np.allclose(y_np, y_ans))
+
+        # test 3
+        x_shape = [2, 2, 3]
+        x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
+        x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
+        x_pil = Image.fromarray(x_np, mode='RGB')
+        x_th = torch.tensor(x_np.transpose(2, 0, 1))
+        y_pil = F.adjust_sharpness(x_pil, 2)
+        y_np = np.array(y_pil).transpose(2, 0, 1)
+        y_th = F.adjust_sharpness(x_th, 2)
+        self.assertTrue(np.allclose(y_np, y_th.numpy()))
+
     def test_adjust_gamma(self):
         x_shape = [2, 2, 3]
         x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
@@ -1268,10 +1310,11 @@ def test_adjusts_L_mode(self):
         self.assertEqual(F.adjust_saturation(x_l, 2).mode, 'L')
         self.assertEqual(F.adjust_contrast(x_l, 2).mode, 'L')
         self.assertEqual(F.adjust_hue(x_l, 0.4).mode, 'L')
+        self.assertEqual(F.adjust_sharpness(x_l, 2).mode, 'L')
         self.assertEqual(F.adjust_gamma(x_l, 0.5).mode, 'L')
 
     def test_color_jitter(self):
-        color_jitter = transforms.ColorJitter(2, 2, 2, 0.1)
+        color_jitter = transforms.ColorJitter(2, 2, 2, 0.1, 2)
 
         x_shape = [2, 2, 3]
         x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]

test/test_transforms_tensor.py

Lines changed: 7 additions & 1 deletion
@@ -131,8 +131,14 @@ def test_color_jitter(self):
                 "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=16.1, agg_method="max"
             )
 
+        for f in [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]:
+            meth_kwargs = {"sharpness": f}
+            self._test_class_op(
+                "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
+            )
+
         # All 4 parameters together
-        meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
+        meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2, "sharpness": 0.2}
         self._test_class_op(
             "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=12.1, agg_method="max"
         )

torchvision/transforms/functional.py

Lines changed: 18 additions & 0 deletions
@@ -1235,3 +1235,21 @@ def solarize(img: Tensor, threshold: float) -> Tensor:
         return F_pil.solarize(img, threshold)
 
     return F_t.solarize(img, threshold)
+
+
+def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
+    """Adjust the sharpness of an Image.
+
+    Args:
+        img (PIL Image or Tensor): Image to be adjusted.
+        sharpness_factor (float): How much to adjust the sharpness. Can be
+            any non negative number. 0 gives a blurred image, 1 gives the
+            original image while 2 increases the sharpness by a factor of 2.
+
+    Returns:
+        PIL Image or Tensor: Sharpness adjusted image.
+    """
+    if not isinstance(img, torch.Tensor):
+        return F_pil.adjust_sharpness(img, sharpness_factor)
+
+    return F_t.adjust_sharpness(img, sharpness_factor)

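For reference, a minimal usage sketch of the new functional, assuming a torchvision build that includes this commit; the input tensor below is a made-up CHW test image:

import torch
from torchvision.transforms import functional as F

img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)  # hypothetical CHW uint8 image
blurred = F.adjust_sharpness(img, 0)     # factor 0 gives a blurred image
unchanged = F.adjust_sharpness(img, 1)   # factor 1 returns the image unchanged
sharpened = F.adjust_sharpness(img, 2)   # factor 2 increases the sharpness by a factor of 2
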
torchvision/transforms/functional_pil.py

Lines changed: 10 additions & 0 deletions
@@ -627,3 +627,13 @@ def solarize(img, threshold):
     if not _is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
     return ImageOps.solarize(img, threshold)
+
+
+@torch.jit.unused
+def adjust_sharpness(img, sharpness_factor):
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+    enhancer = ImageEnhance.Sharpness(img)
+    img = enhancer.enhance(sharpness_factor)
+    return img

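The PIL backend is a thin wrapper around Pillow's ImageEnhance.Sharpness, so for PIL inputs the functional and the direct Pillow call should agree. A small sanity-check sketch (the test image below is made up for illustration):

from PIL import Image, ImageEnhance
from torchvision.transforms import functional as F

pil_img = Image.new('RGB', (8, 8), color=(120, 60, 200))  # hypothetical 8x8 test image
a = F.adjust_sharpness(pil_img, 1.5)              # dispatches to the PIL backend above
b = ImageEnhance.Sharpness(pil_img).enhance(1.5)  # the underlying Pillow call
assert list(a.getdata()) == list(b.getdata())
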
torchvision/transforms/functional_tensor.py

Lines changed: 36 additions & 0 deletions
@@ -1226,3 +1226,39 @@ def solarize(img: Tensor, threshold: float) -> Tensor:
     result[invert_idx] = (bound - result[invert_idx].to(dtype=dtype)).clamp(0, bound).to(dtype=img.dtype)
 
     return result.view(img.shape)
+
+
+def _blur_image(img: Tensor) -> Tensor:
+    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
+
+    kernel = torch.ones((3, 3), dtype=dtype, device=img.device)
+    kernel[1, 1] = 5.0
+    kernel /= kernel.sum()
+    kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
+
+    result, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype, ])
+    result = conv2d(result, kernel, groups=result.shape[-3])
+    result = torch_pad(result, [1, 1, 1, 1])
+    result = _cast_squeeze_out(result, need_cast, need_squeeze, out_dtype)
+
+    result[..., 0, :] = img[..., 0, :]
+    result[..., -1, :] = img[..., -1, :]
+    result[..., :, 0] = img[..., :, 0]
+    result[..., :, -1] = img[..., :, -1]
+
+    return result
+
+
+def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
+    if sharpness_factor < 0:
+        raise ValueError('sharpness_factor ({}) is not non-negative.'.format(sharpness_factor))
+
+    if not _is_tensor_a_torch_image(img):
+        raise TypeError('tensor is not a torch image.')
+
+    _assert_channels(img, [1, 3])
+
+    if img.size(-1) <= 2 or img.size(-2) <= 2:
+        return img
+
+    return _blend(img, _blur_image(img), sharpness_factor)

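To make the tensor path easier to follow, here is a self-contained sketch of the same idea without torchvision's internal dtype/batch helpers (_cast_squeeze_in and friends): blend the image with a centre-weighted smoothed copy while keeping the original border pixels. It assumes a float CHW tensor in [0, 1] and is illustrative only, not the library implementation:

import torch
from torch.nn.functional import conv2d

def sharpness_sketch(img: torch.Tensor, factor: float) -> torch.Tensor:
    # img: float tensor of shape (C, H, W) with values in [0, 1]
    c = img.shape[0]
    kernel = torch.ones((3, 3), dtype=img.dtype, device=img.device)
    kernel[1, 1] = 5.0            # centre-weighted 3x3 smoothing kernel
    kernel /= kernel.sum()        # normalise so overall brightness is preserved
    kernel = kernel.expand(c, 1, 3, 3)

    blurred = conv2d(img.unsqueeze(0), kernel, padding=1, groups=c).squeeze(0)
    # the border is not properly smoothed; keep the original pixels there, as _blur_image does
    blurred[..., 0, :] = img[..., 0, :]
    blurred[..., -1, :] = img[..., -1, :]
    blurred[..., :, 0] = img[..., :, 0]
    blurred[..., :, -1] = img[..., :, -1]

    # blend, as _blend does: factor 0 -> blurred, 1 -> original, 2 -> sharpened
    return (factor * img + (1.0 - factor) * blurred).clamp(0.0, 1.0)
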
torchvision/transforms/transforms.py

Lines changed: 22 additions & 10 deletions
@@ -1039,7 +1039,7 @@ def __repr__(self):
 
 
 class ColorJitter(torch.nn.Module):
-    """Randomly change the brightness, contrast and saturation of an image.
+    """Randomly change the brightness, contrast, saturation, hue and sharpness of an image.
 
     Args:
         brightness (float or tuple of float (min, max)): How much to jitter brightness.
@@ -1054,15 +1054,19 @@ class ColorJitter(torch.nn.Module):
         hue (float or tuple of float (min, max)): How much to jitter hue.
             hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
             Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
+        sharpness (float or tuple of float (min, max)): How much to jitter sharpness.
+            sharpness_factor is chosen uniformly from [max(0, 1 - sharpness), 1 + sharpness]
+            or the given [min, max]. Should be non negative numbers.
     """
 
-    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
+    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, sharpness=0):
         super().__init__()
         self.brightness = self._check_input(brightness, 'brightness')
         self.contrast = self._check_input(contrast, 'contrast')
         self.saturation = self._check_input(saturation, 'saturation')
         self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
                                      clip_first_on_zero=False)
+        self.sharpness = self._check_input(sharpness, 'sharpness')
 
     @torch.jit.unused
     def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
@@ -1078,7 +1082,7 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
         else:
             raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name))
 
-        # if value is 0 or (1., 1.) for brightness/contrast/saturation
+        # if value is 0 or (1., 1.) for brightness/contrast/saturation/sharpness
         # or (0., 0.) for hue, do nothing
         if value[0] == value[1] == center:
             value = None
@@ -1088,8 +1092,10 @@ def get_params(brightness: Optional[List[float]],
     def get_params(brightness: Optional[List[float]],
                    contrast: Optional[List[float]],
                    saturation: Optional[List[float]],
-                   hue: Optional[List[float]]
-                   ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
+                   hue: Optional[List[float]],
+                   sharpness: Optional[List[float]]
+                   ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float],
+                              Optional[float]]:
         """Get the parameters for the randomized transform to be applied on image.
 
         Args:
@@ -1101,19 +1107,22 @@ def get_params(brightness: Optional[List[float]],
                 uniformly. Pass None to turn off the transformation.
             hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
                 Pass None to turn off the transformation.
+            sharpness (tuple of float (min, max), optional): The range from which the sharpness is chosen
+                uniformly. Pass None to turn off the transformation.
 
         Returns:
             tuple: The parameters used to apply the randomized transform
            along with their random order.
         """
-        fn_idx = torch.randperm(4)
+        fn_idx = torch.randperm(5)
 
         b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
        c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
         s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
         h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))
+        sp = None if sharpness is None else float(torch.empty(1).uniform_(sharpness[0], sharpness[1]))
 
-        return fn_idx, b, c, s, h
+        return fn_idx, b, c, s, h, sp
 
     def forward(self, img):
         """
@@ -1123,8 +1132,8 @@ def forward(self, img):
         Returns:
             PIL Image or Tensor: Color jittered image.
         """
-        fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \
-            self.get_params(self.brightness, self.contrast, self.saturation, self.hue)
+        fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor, sharpness_factor = \
+            self.get_params(self.brightness, self.contrast, self.saturation, self.hue, self.sharpness)
 
         for fn_id in fn_idx:
             if fn_id == 0 and brightness_factor is not None:
@@ -1135,6 +1144,8 @@ def forward(self, img):
                 img = F.adjust_saturation(img, saturation_factor)
             elif fn_id == 3 and hue_factor is not None:
                 img = F.adjust_hue(img, hue_factor)
+            elif fn_id == 4 and sharpness_factor is not None:
+                img = F.adjust_sharpness(img, sharpness_factor)
 
         return img
 
@@ -1143,7 +1154,8 @@ def __repr__(self):
         format_string += 'brightness={0}'.format(self.brightness)
         format_string += ', contrast={0}'.format(self.contrast)
         format_string += ', saturation={0}'.format(self.saturation)
-        format_string += ', hue={0})'.format(self.hue)
+        format_string += ', hue={0}'.format(self.hue)
+        format_string += ', sharpness={0})'.format(self.sharpness)
         return format_string
 
 

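Finally, with this change ColorJitter accepts a fifth sharpness argument. A minimal usage sketch, assuming a build that includes this commit (the argument is not present in earlier releases) and with illustrative factor values:

import torch
from torchvision import transforms

# sharpness=0.5 draws sharpness_factor uniformly from [max(0, 1 - 0.5), 1 + 0.5] = [0.5, 1.5]
jitter = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, sharpness=0.5)

img = torch.rand(3, 64, 64)  # hypothetical float CHW image
out = jitter(img)            # applies the five adjustments in a random order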