Skip to content

Commit 264af37

Browse files
NicolasHugvfdev-5datumbox
authored andcommitted
RandomRotation and fill (#3303)
Summary: * initial fix * fill=0 * docstrings * fill type check * fill type check * set None to zero * unit tests * set instead of NotImplemented * fix W293 Reviewed By: datumbox Differential Revision: D26226611 fbshipit-source-id: acee01300be03ad94eab3beb1c711f5e6050f632 Co-authored-by: vfdev <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent b0482f5 commit 264af37

File tree

2 files changed

+47
-7
lines changed

2 files changed

+47
-7
lines changed

test/test_transforms.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,14 @@ def test_randomperspective(self):
180180
torch.nn.functional.mse_loss(tr_img2, F.to_tensor(img)))
181181

182182
def test_randomperspective_fill(self):
183+
184+
# assert fill being either a Sequence or a Number
185+
with self.assertRaises(TypeError):
186+
transforms.RandomPerspective(fill={})
187+
188+
t = transforms.RandomPerspective(fill=None)
189+
self.assertTrue(t.fill == 0)
190+
183191
height = 100
184192
width = 100
185193
img = torch.ones(3, height, width)
@@ -1531,6 +1539,13 @@ def test_random_rotation(self):
15311539
transforms.RandomRotation([-0.7])
15321540
transforms.RandomRotation([-0.7, 0, 0.7])
15331541

1542+
# assert fill being either a Sequence or a Number
1543+
with self.assertRaises(TypeError):
1544+
transforms.RandomRotation(0, fill={})
1545+
1546+
t = transforms.RandomRotation(0, fill=None)
1547+
self.assertTrue(t.fill == 0)
1548+
15341549
t = transforms.RandomRotation(10)
15351550
angle = t.get_params(t.degrees)
15361551
self.assertTrue(angle > -10 and angle < 10)
@@ -1573,6 +1588,13 @@ def test_random_affine(self):
15731588
transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10, 0, 10])
15741589
transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10, 0, 10, 0, 10])
15751590

1591+
# assert fill being either a Sequence or a Number
1592+
with self.assertRaises(TypeError):
1593+
transforms.RandomAffine(0, fill={})
1594+
1595+
t = transforms.RandomAffine(0, fill=None)
1596+
self.assertTrue(t.fill == 0)
1597+
15761598
x = np.zeros((100, 100, 3), dtype=np.uint8)
15771599
img = F.to_pil_image(x)
15781600

torchvision/transforms/transforms.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -673,8 +673,8 @@ class RandomPerspective(torch.nn.Module):
673673
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
674674
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
675675
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
676-
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
677-
image. If given a number, the value is used for all bands respectively.
676+
fill (sequence or number): Pixel fill value for the area outside the transformed
677+
image. Default is ``0``. If given a number, the value is used for all bands respectively.
678678
If input is PIL Image, the options is only available for ``Pillow>=5.0.0``.
679679
"""
680680

@@ -692,6 +692,12 @@ def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode.
692692

693693
self.interpolation = interpolation
694694
self.distortion_scale = distortion_scale
695+
696+
if fill is None:
697+
fill = 0
698+
elif not isinstance(fill, (Sequence, numbers.Number)):
699+
raise TypeError("Fill should be either a sequence or a number.")
700+
695701
self.fill = fill
696702

697703
def forward(self, img):
@@ -1175,8 +1181,8 @@ class RandomRotation(torch.nn.Module):
11751181
Note that the expand flag assumes rotation around the center and no translation.
11761182
center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
11771183
Default is the center of the image.
1178-
fill (sequence or number, optional): Pixel fill value for the area outside the rotated
1179-
image. If given a number, the value is used for all bands respectively.
1184+
fill (sequence or number): Pixel fill value for the area outside the rotated
1185+
image. Default is ``0``. If given a number, the value is used for all bands respectively.
11801186
If input is PIL Image, the options is only available for ``Pillow>=5.2.0``.
11811187
resample (int, optional): deprecated argument and will be removed since v0.10.0.
11821188
Please use the ``interpolation`` parameter instead.
@@ -1186,7 +1192,7 @@ class RandomRotation(torch.nn.Module):
11861192
"""
11871193

11881194
def __init__(
1189-
self, degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=None, resample=None
1195+
self, degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=0, resample=None
11901196
):
11911197
super().__init__()
11921198
if resample is not None:
@@ -1212,6 +1218,12 @@ def __init__(
12121218

12131219
self.resample = self.interpolation = interpolation
12141220
self.expand = expand
1221+
1222+
if fill is None:
1223+
fill = 0
1224+
elif not isinstance(fill, (Sequence, numbers.Number)):
1225+
raise TypeError("Fill should be either a sequence or a number.")
1226+
12151227
self.fill = fill
12161228

12171229
@staticmethod
@@ -1280,8 +1292,8 @@ class RandomAffine(torch.nn.Module):
12801292
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
12811293
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
12821294
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
1283-
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
1284-
image. If given a number, the value is used for all bands respectively.
1295+
fill (sequence or number): Pixel fill value for the area outside the transformed
1296+
image. Default is ``0``. If given a number, the value is used for all bands respectively.
12851297
If input is PIL Image, the options is only available for ``Pillow>=5.0.0``.
12861298
fillcolor (sequence or number, optional): deprecated argument and will be removed since v0.10.0.
12871299
Please use the ``fill`` parameter instead.
@@ -1339,6 +1351,12 @@ def __init__(
13391351
self.shear = shear
13401352

13411353
self.resample = self.interpolation = interpolation
1354+
1355+
if fill is None:
1356+
fill = 0
1357+
elif not isinstance(fill, (Sequence, numbers.Number)):
1358+
raise TypeError("Fill should be either a sequence or a number.")
1359+
13421360
self.fillcolor = self.fill = fill
13431361

13441362
@staticmethod

0 commit comments

Comments
 (0)