
Commit 883f1fb

Make ColorJitter torchscriptable (#2298)

* feat: torchscriptable adjustments
* fix: tensor output type
* feat: ColorJitter torchscriptable
* fix: too many blank lines
* fix: documentation spacing and torchscript annotation
* refactor: list type for _check_input
* refactor: reverting to original syntax

Co-authored-by: clement.joudet <[email protected]>

1 parent 2cfc360 commit 883f1fb
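
In short, after this commit both the functional color ops and the ColorJitter transform can be compiled with torch.jit.script. A minimal sketch of what the change enables (not part of the diff; the image shape and jitter values are arbitrary, and hue is left out because the tensor path for adjust_hue is not implemented at this point):

    import torch
    import torchvision.transforms as transforms

    # ColorJitter is now an nn.Module, so it can be scripted directly.
    jitter = transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3)
    scripted_jitter = torch.jit.script(jitter)

    # Works on a CxHxW uint8 tensor; PIL Images keep working in eager mode.
    img = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)
    out = scripted_jitter(img)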

File tree: 5 files changed, +206 -53 lines

  test/test_functional_tensor.py
  test/test_transforms_tensor.py
  torchvision/transforms/functional.py
  torchvision/transforms/functional_pil.py
  torchvision/transforms/transforms.py

test/test_functional_tensor.py

Lines changed: 17 additions & 0 deletions

@@ -97,6 +97,23 @@ def test_adjustments(self):
             self.assertLess(max_diff_scripted, 5 / 255 + 1e-5)
             self.assertTrue(torch.equal(img, img_clone))
 
+            # test for class interface
+            f = transforms.ColorJitter(brightness=factor.item())
+            scripted_fn = torch.jit.script(f)
+            scripted_fn(img)
+
+            f = transforms.ColorJitter(contrast=factor.item())
+            scripted_fn = torch.jit.script(f)
+            scripted_fn(img)
+
+            f = transforms.ColorJitter(saturation=factor.item())
+            scripted_fn = torch.jit.script(f)
+            scripted_fn(img)
+
+            f = transforms.ColorJitter(brightness=1)
+            scripted_fn = torch.jit.script(f)
+            scripted_fn(img)
+
     def test_rgb_to_grayscale(self):
         script_rgb_to_grayscale = torch.jit.script(F_t.rgb_to_grayscale)
         img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)

test/test_transforms_tensor.py

Lines changed: 26 additions & 0 deletions

@@ -39,6 +39,32 @@ def test_random_horizontal_flip(self):
     def test_random_vertical_flip(self):
         self._test_flip('vflip', 'RandomVerticalFlip')
 
+    def test_adjustments(self):
+        fns = ['adjust_brightness', 'adjust_contrast', 'adjust_saturation']
+        for _ in range(20):
+            factor = 3 * torch.rand(1).item()
+            tensor, _ = self._create_data()
+            pil_img = T.ToPILImage()(tensor)
+
+            for func in fns:
+                adjusted_tensor = getattr(F, func)(tensor, factor)
+                adjusted_pil_img = getattr(F, func)(pil_img, factor)
+
+                adjusted_pil_tensor = T.ToTensor()(adjusted_pil_img)
+                scripted_fn = torch.jit.script(getattr(F, func))
+                adjusted_tensor_script = scripted_fn(tensor, factor)
+
+                if not tensor.dtype.is_floating_point:
+                    adjusted_tensor = adjusted_tensor.to(torch.float) / 255
+                    adjusted_tensor_script = adjusted_tensor_script.to(torch.float) / 255
+
+                # F uses uint8 and F_t uses float, so there is a small
+                # difference in values caused by (at most 5) truncations.
+                max_diff = (adjusted_tensor - adjusted_pil_tensor).abs().max()
+                max_diff_scripted = (adjusted_tensor - adjusted_tensor_script).abs().max()
+                self.assertLess(max_diff, 5 / 255 + 1e-5)
+                self.assertLess(max_diff_scripted, 5 / 255 + 1e-5)
+
 
 if __name__ == '__main__':
     unittest.main()

torchvision/transforms/functional.py

Lines changed: 22 additions & 44 deletions

@@ -633,67 +633,61 @@ def ten_crop(img, size, vertical_flip=False):
     return first_five + second_five
 
 
-def adjust_brightness(img, brightness_factor):
+def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
     """Adjust brightness of an Image.
 
     Args:
-        img (PIL Image): PIL Image to be adjusted.
+        img (PIL Image or Torch Tensor): Image to be adjusted.
         brightness_factor (float): How much to adjust the brightness. Can be
             any non negative number. 0 gives a black image, 1 gives the
             original image while 2 increases the brightness by a factor of 2.
 
     Returns:
-        PIL Image: Brightness adjusted image.
+        PIL Image or Torch Tensor: Brightness adjusted image.
     """
-    if not _is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+    if not isinstance(img, torch.Tensor):
+        return F_pil.adjust_brightness(img, brightness_factor)
 
-    enhancer = ImageEnhance.Brightness(img)
-    img = enhancer.enhance(brightness_factor)
-    return img
+    return F_t.adjust_brightness(img, brightness_factor)
 
 
-def adjust_contrast(img, contrast_factor):
+def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
     """Adjust contrast of an Image.
 
     Args:
-        img (PIL Image): PIL Image to be adjusted.
+        img (PIL Image or Torch Tensor): Image to be adjusted.
         contrast_factor (float): How much to adjust the contrast. Can be any
             non negative number. 0 gives a solid gray image, 1 gives the
             original image while 2 increases the contrast by a factor of 2.
 
     Returns:
-        PIL Image: Contrast adjusted image.
+        PIL Image or Torch Tensor: Contrast adjusted image.
     """
-    if not _is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+    if not isinstance(img, torch.Tensor):
+        return F_pil.adjust_contrast(img, contrast_factor)
 
-    enhancer = ImageEnhance.Contrast(img)
-    img = enhancer.enhance(contrast_factor)
-    return img
+    return F_t.adjust_contrast(img, contrast_factor)
 
 
-def adjust_saturation(img, saturation_factor):
+def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
     """Adjust color saturation of an image.
 
     Args:
-        img (PIL Image): PIL Image to be adjusted.
+        img (PIL Image or Torch Tensor): Image to be adjusted.
         saturation_factor (float): How much to adjust the saturation. 0 will
             give a black and white image, 1 will give the original image while
             2 will enhance the saturation by a factor of 2.
 
     Returns:
-        PIL Image: Saturation adjusted image.
+        PIL Image or Torch Tensor: Saturation adjusted image.
     """
-    if not _is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+    if not isinstance(img, torch.Tensor):
+        return F_pil.adjust_saturation(img, saturation_factor)
 
-    enhancer = ImageEnhance.Color(img)
-    img = enhancer.enhance(saturation_factor)
-    return img
+    return F_t.adjust_saturation(img, saturation_factor)
 
 
-def adjust_hue(img, hue_factor):
+def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
     """Adjust hue of an image.
 
     The image hue is adjusted by converting the image to HSV and
@@ -718,26 +712,10 @@ def adjust_hue(img, hue_factor):
     Returns:
         PIL Image: Hue adjusted image.
     """
-    if not(-0.5 <= hue_factor <= 0.5):
-        raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
-
-    if not _is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
-
-    input_mode = img.mode
-    if input_mode in {'L', '1', 'I', 'F'}:
-        return img
-
-    h, s, v = img.convert('HSV').split()
-
-    np_h = np.array(h, dtype=np.uint8)
-    # uint8 addition take cares of rotation across boundaries
-    with np.errstate(over='ignore'):
-        np_h += np.uint8(hue_factor * 255)
-    h = Image.fromarray(np_h, 'L')
+    if not isinstance(img, torch.Tensor):
+        return F_pil.adjust_hue(img, hue_factor)
 
-    img = Image.merge('HSV', (h, s, v)).convert(input_mode)
-    return img
+    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
 
 
 def adjust_gamma(img, gamma, gain=1):
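
The functional layer above now dispatches on the input type: PIL images are routed to F_pil, tensors to F_t, and the Tensor/float annotations make the functions themselves scriptable (adjust_hue keeps raising for tensor inputs, since no tensor implementation exists yet). A small sketch of the resulting behaviour, assuming an arbitrary uint8 tensor; this is an illustration, not code from the commit:

    import torch
    import torchvision.transforms.functional as F

    img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)

    out_eager = F.adjust_brightness(img, 1.5)           # tensor input -> F_t backend
    scripted = torch.jit.script(F.adjust_brightness)    # scriptable thanks to the annotations
    out_scripted = scripted(img, 1.5)

    # Eager and scripted paths run the same tensor code, so results match exactly.
    assert torch.equal(out_eager, out_scripted)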

torchvision/transforms/functional_pil.py

Lines changed: 108 additions & 0 deletions

@@ -4,6 +4,7 @@
 except ImportError:
     accimage = None
 from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION
+import numpy as np
 
 
 @torch.jit.unused
@@ -44,3 +45,110 @@ def vflip(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
 
     return img.transpose(Image.FLIP_TOP_BOTTOM)
+
+
+@torch.jit.unused
+def adjust_brightness(img, brightness_factor):
+    """Adjust brightness of an RGB image.
+
+    Args:
+        img (PIL Image): Image to be adjusted.
+        brightness_factor (float): How much to adjust the brightness. Can be
+            any non negative number. 0 gives a black image, 1 gives the
+            original image while 2 increases the brightness by a factor of 2.
+
+    Returns:
+        PIL Image: Brightness adjusted image.
+    """
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+    enhancer = ImageEnhance.Brightness(img)
+    img = enhancer.enhance(brightness_factor)
+    return img
+
+
+@torch.jit.unused
+def adjust_contrast(img, contrast_factor):
+    """Adjust contrast of an Image.
+    Args:
+        img (PIL Image): PIL Image to be adjusted.
+        contrast_factor (float): How much to adjust the contrast. Can be any
+            non negative number. 0 gives a solid gray image, 1 gives the
+            original image while 2 increases the contrast by a factor of 2.
+    Returns:
+        PIL Image: Contrast adjusted image.
+    """
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+    enhancer = ImageEnhance.Contrast(img)
+    img = enhancer.enhance(contrast_factor)
+    return img
+
+
+@torch.jit.unused
+def adjust_saturation(img, saturation_factor):
+    """Adjust color saturation of an image.
+    Args:
+        img (PIL Image): PIL Image to be adjusted.
+        saturation_factor (float): How much to adjust the saturation. 0 will
+            give a black and white image, 1 will give the original image while
+            2 will enhance the saturation by a factor of 2.
+    Returns:
+        PIL Image: Saturation adjusted image.
+    """
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+    enhancer = ImageEnhance.Color(img)
+    img = enhancer.enhance(saturation_factor)
+    return img
+
+
+@torch.jit.unused
+def adjust_hue(img, hue_factor):
+    """Adjust hue of an image.
+
+    The image hue is adjusted by converting the image to HSV and
+    cyclically shifting the intensities in the hue channel (H).
+    The image is then converted back to original image mode.
+
+    `hue_factor` is the amount of shift in H channel and must be in the
+    interval `[-0.5, 0.5]`.
+
+    See `Hue`_ for more details.
+
+    .. _Hue: https://en.wikipedia.org/wiki/Hue
+
+    Args:
+        img (PIL Image): PIL Image to be adjusted.
+        hue_factor (float): How much to shift the hue channel. Should be in
+            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
+            HSV space in positive and negative direction respectively.
+            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
+            with complementary colors while 0 gives the original image.
+
+    Returns:
+        PIL Image: Hue adjusted image.
+    """
+    if not(-0.5 <= hue_factor <= 0.5):
+        raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
+
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+    input_mode = img.mode
+    if input_mode in {'L', '1', 'I', 'F'}:
+        return img
+
+    h, s, v = img.convert('HSV').split()
+
+    np_h = np.array(h, dtype=np.uint8)
+    # uint8 addition take cares of rotation across boundaries
+    with np.errstate(over='ignore'):
+        np_h += np.uint8(hue_factor * 255)
+    h = Image.fromarray(np_h, 'L')
+
+    img = Image.merge('HSV', (h, s, v)).convert(input_mode)
+    return img

torchvision/transforms/transforms.py

Lines changed: 33 additions & 9 deletions

@@ -865,7 +865,7 @@ def __repr__(self):
         return format_string
 
 
-class ColorJitter(object):
+class ColorJitter(torch.nn.Module):
     """Randomly change the brightness, contrast and saturation of an image.
 
     Args:
@@ -882,20 +882,23 @@ class ColorJitter(object):
             hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
             Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
     """
+
     def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
+        super().__init__()
         self.brightness = self._check_input(brightness, 'brightness')
         self.contrast = self._check_input(contrast, 'contrast')
         self.saturation = self._check_input(saturation, 'saturation')
         self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
                                      clip_first_on_zero=False)
 
+    @torch.jit.unused
     def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
         if isinstance(value, numbers.Number):
             if value < 0:
                 raise ValueError("If {} is a single number, it must be non negative.".format(name))
-            value = [center - value, center + value]
+            value = [center - float(value), center + float(value)]
             if clip_first_on_zero:
-                value[0] = max(value[0], 0)
+                value[0] = max(value[0], 0.0)
         elif isinstance(value, (tuple, list)) and len(value) == 2:
             if not bound[0] <= value[0] <= value[1] <= bound[1]:
                 raise ValueError("{} values should be between {}".format(name, bound))
@@ -909,6 +912,7 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs
         return value
 
     @staticmethod
+    @torch.jit.unused
     def get_params(brightness, contrast, saturation, hue):
         """Get a randomized transform to be applied on image.
 
@@ -941,17 +945,37 @@ def get_params(brightness, contrast, saturation, hue):
 
         return transform
 
-    def __call__(self, img):
+    def forward(self, img):
         """
         Args:
-            img (PIL Image): Input image.
+            img (PIL Image or Tensor): Input image.
 
         Returns:
-            PIL Image: Color jittered image.
+            PIL Image or Tensor: Color jittered image.
         """
-        transform = self.get_params(self.brightness, self.contrast,
-                                    self.saturation, self.hue)
-        return transform(img)
+        fn_idx = torch.randperm(4)
+        for fn_id in fn_idx:
+            if fn_id == 0 and self.brightness is not None:
+                brightness = self.brightness
+                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
+                img = F.adjust_brightness(img, brightness_factor)
+
+            if fn_id == 1 and self.contrast is not None:
+                contrast = self.contrast
+                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
+                img = F.adjust_contrast(img, contrast_factor)
+
+            if fn_id == 2 and self.saturation is not None:
+                saturation = self.saturation
+                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
+                img = F.adjust_saturation(img, saturation_factor)
+
+            if fn_id == 3 and self.hue is not None:
+                hue = self.hue
+                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
+                img = F.adjust_hue(img, hue_factor)
+
+        return img
 
     def __repr__(self):
         format_string = self.__class__.__name__ + '('
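
Since ColorJitter is now a torch.nn.Module whose forward applies the four adjustments in a torch.randperm order, the scripted transform can also be serialized and reloaded like any other ScriptModule. A sketch of that workflow (not from the commit; the file name and factor values are made up, and hue is left at its default because the tensor path for adjust_hue is not implemented in this commit):

    import torch
    import torchvision.transforms as T

    scripted = torch.jit.script(T.ColorJitter(brightness=0.5, contrast=0.4, saturation=0.4))
    scripted.save("color_jitter.pt")                 # standard ScriptModule serialization

    restored = torch.jit.load("color_jitter.pt")
    img = torch.randint(0, 256, (3, 128, 128), dtype=torch.uint8)
    out = restored(img)                              # jittered tensor, same dtype and shape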
