Skip to content

Commit 8f4f8d8

Browse files
pmeierfmassa
authored andcommitted
Fix fill in rotate (#1760)
* initial fix * outsourced num bands lookup * fix doc * added pillow version requirement * simplify number of bands extraction * remove unrelated change * remove indirect dependency on pillow>=5.2.0 * extend docstring to transform * bug fix * added test
1 parent d9a3018 commit 8f4f8d8

File tree

3 files changed

+50
-9
lines changed

3 files changed

+50
-9
lines changed

test/test_transforms.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import division
22
import os
3+
import mock
34
import torch
45
import torchvision.transforms as transforms
56
import torchvision.transforms.functional as F
@@ -1074,6 +1075,26 @@ def test_rotate(self):
10741075

10751076
self.assertTrue(np.all(np.array(result_a) == np.array(result_b)))
10761077

1078+
def test_rotate_fill(self):
1079+
img = F.to_pil_image(np.ones((100, 100, 3), dtype=np.uint8) * 255, "RGB")
1080+
1081+
modes = ("L", "RGB")
1082+
nums_bands = [len(mode) for mode in modes]
1083+
fill = 127
1084+
1085+
for mode, num_bands in zip(modes, nums_bands):
1086+
img_conv = img.convert(mode)
1087+
img_rot = F.rotate(img_conv, 45.0, fill=fill)
1088+
pixel = img_rot.getpixel((0, 0))
1089+
1090+
if not isinstance(pixel, tuple):
1091+
pixel = (pixel,)
1092+
self.assertTupleEqual(pixel, tuple([fill] * num_bands))
1093+
1094+
for wrong_num_bands in set(nums_bands) - {num_bands}:
1095+
with self.assertRaises(ValueError):
1096+
F.rotate(img_conv, 45.0, fill=tuple([fill] * wrong_num_bands))
1097+
10771098
def test_affine(self):
10781099
input_img = np.zeros((40, 40, 3), dtype=np.uint8)
10791100
pts = []

torchvision/transforms/functional.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,7 @@ def adjust_gamma(img, gamma, gain=1):
696696
return img
697697

698698

699-
def rotate(img, angle, resample=False, expand=False, center=None, fill=0):
699+
def rotate(img, angle, resample=False, expand=False, center=None, fill=None):
700700
"""Rotate the image by angle.
701701
702702
@@ -713,20 +713,39 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=0):
713713
center (2-tuple, optional): Optional center of rotation.
714714
Origin is the upper left corner.
715715
Default is the center of the image.
716-
fill (3-tuple or int): RGB pixel fill value for area outside the rotated image.
717-
If int, it is used for all channels respectively.
716+
fill (n-tuple or int or float): Pixel fill value for area outside the rotated
717+
image. If int or float, the value is used for all bands respectively.
718+
Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``.
718719
719720
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
720721
721722
"""
723+
def parse_fill(fill, num_bands):
724+
if PILLOW_VERSION < "5.2.0":
725+
if fill is None:
726+
return {}
727+
else:
728+
msg = ("The option to fill background area of the rotated image, "
729+
"requires pillow>=5.2.0")
730+
raise RuntimeError(msg)
731+
732+
if fill is None:
733+
fill = 0
734+
if isinstance(fill, (int, float)):
735+
fill = tuple([fill] * num_bands)
736+
if len(fill) != num_bands:
737+
msg = ("The number of elements in 'fill' does not match the number of "
738+
"bands of the image ({} != {})")
739+
raise ValueError(msg.format(len(fill), num_bands))
740+
741+
return {"fillcolor": fill}
722742

723743
if not _is_pil_image(img):
724744
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
725745

726-
if isinstance(fill, int):
727-
fill = tuple([fill] * 3)
746+
opts = parse_fill(fill, len(img.getbands()))
728747

729-
return img.rotate(angle, resample, expand, center, fillcolor=fill)
748+
return img.rotate(angle, resample, expand, center, **opts)
730749

731750

732751
def _get_inverse_affine_matrix(center, angle, translate, scale, shear):

torchvision/transforms/transforms.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -956,14 +956,15 @@ class RandomRotation(object):
956956
center (2-tuple, optional): Optional center of rotation.
957957
Origin is the upper left corner.
958958
Default is the center of the image.
959-
fill (3-tuple or int): RGB pixel fill value for area outside the rotated image.
960-
If int, it is used for all channels respectively.
959+
fill (n-tuple or int or float): Pixel fill value for area outside the rotated
960+
image. If int or float, the value is used for all bands respectively.
961+
Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``.
961962
962963
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
963964
964965
"""
965966

966-
def __init__(self, degrees, resample=False, expand=False, center=None, fill=0):
967+
def __init__(self, degrees, resample=False, expand=False, center=None, fill=None):
967968
if isinstance(degrees, numbers.Number):
968969
if degrees < 0:
969970
raise ValueError("If degrees is a single number, it must be positive.")

0 commit comments

Comments
 (0)