diff --git a/test/test_transforms.py b/test/test_transforms.py index a801360424c..3a76f3a1adb 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1,5 +1,6 @@ from __future__ import division import os +import mock import torch import torchvision.transforms as transforms import torchvision.transforms.functional as F @@ -1074,6 +1075,26 @@ def test_rotate(self): self.assertTrue(np.all(np.array(result_a) == np.array(result_b))) + def test_rotate_fill(self): + img = F.to_pil_image(np.ones((100, 100, 3), dtype=np.uint8) * 255, "RGB") + + modes = ("L", "RGB") + nums_bands = [len(mode) for mode in modes] + fill = 127 + + for mode, num_bands in zip(modes, nums_bands): + img_conv = img.convert(mode) + img_rot = F.rotate(img_conv, 45.0, fill=fill) + pixel = img_rot.getpixel((0, 0)) + + if not isinstance(pixel, tuple): + pixel = (pixel,) + self.assertTupleEqual(pixel, tuple([fill] * num_bands)) + + for wrong_num_bands in set(nums_bands) - {num_bands}: + with self.assertRaises(ValueError): + F.rotate(img_conv, 45.0, fill=tuple([fill] * wrong_num_bands)) + def test_affine(self): input_img = np.zeros((40, 40, 3), dtype=np.uint8) pts = [] diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 8ae75f84c5b..299b0203944 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -696,7 +696,7 @@ def adjust_gamma(img, gamma, gain=1): return img -def rotate(img, angle, resample=False, expand=False, center=None, fill=0): +def rotate(img, angle, resample=False, expand=False, center=None, fill=None): """Rotate the image by angle. @@ -713,20 +713,39 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=0): center (2-tuple, optional): Optional center of rotation. Origin is the upper left corner. Default is the center of the image. - fill (3-tuple or int): RGB pixel fill value for area outside the rotated image. - If int, it is used for all channels respectively. + fill (n-tuple or int or float): Pixel fill value for area outside the rotated + image. If int or float, the value is used for all bands respectively. + Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``. .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters """ + def parse_fill(fill, num_bands): + if PILLOW_VERSION < "5.2.0": + if fill is None: + return {} + else: + msg = ("The option to fill background area of the rotated image, " + "requires pillow>=5.2.0") + raise RuntimeError(msg) + + if fill is None: + fill = 0 + if isinstance(fill, (int, float)): + fill = tuple([fill] * num_bands) + if len(fill) != num_bands: + msg = ("The number of elements in 'fill' does not match the number of " + "bands of the image ({} != {})") + raise ValueError(msg.format(len(fill), num_bands)) + + return {"fillcolor": fill} if not _is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) - if isinstance(fill, int): - fill = tuple([fill] * 3) + opts = parse_fill(fill, len(img.getbands())) - return img.rotate(angle, resample, expand, center, fillcolor=fill) + return img.rotate(angle, resample, expand, center, **opts) def _get_inverse_affine_matrix(center, angle, translate, scale, shear): diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 3ec84aae84c..393e3c2db33 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -956,14 +956,15 @@ class RandomRotation(object): center (2-tuple, optional): Optional center of rotation. Origin is the upper left corner. Default is the center of the image. - fill (3-tuple or int): RGB pixel fill value for area outside the rotated image. - If int, it is used for all channels respectively. + fill (n-tuple or int or float): Pixel fill value for area outside the rotated + image. If int or float, the value is used for all bands respectively. + Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``. .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters """ - def __init__(self, degrees, resample=False, expand=False, center=None, fill=0): + def __init__(self, degrees, resample=False, expand=False, center=None, fill=None): if isinstance(degrees, numbers.Number): if degrees < 0: raise ValueError("If degrees is a single number, it must be positive.")