
Commit 21deb4d

voldemortX and vfdev-5 authored
Fill color support for tensor affine transforms (#2904)
* Fill color support for tensor affine transforms
* PEP fix
* Docstring changes and float support
* Docstring update for transforms and float type cast
* Cast only for Tensor
* Temporary patch for lack of Union type support, plus an extra unit test
* More plausible bilinear filling for tensors
* Keep things simple & New docstrings
* Fix lint and other issues after merge
* make it in one line
* Docstring and some code modifications
* More tests and corresponding changes for transforms and docstring changes
* Simplify test configs
* Update test_functional_tensor.py
* Update test_functional_tensor.py
* Move assertions

Co-authored-by: vfdev <[email protected]>
1 parent df4003f commit 21deb4d
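
As a quick orientation (editorial, not part of the commit): the `fill` argument of the affine-style ops in torchvision.transforms.functional is now honored for Tensor inputs instead of being supported only for PIL Images. A minimal usage sketch, assuming this revision of torchvision; the image tensor and fill values below are arbitrary:

    import torch
    import torchvision.transforms.functional as F

    img = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8)  # arbitrary 3-channel image

    # Per the docstrings updated below, `fill` now applies to Tensor inputs as well.
    rotated = F.rotate(img, angle=45, fill=[255, 255, 255])            # per-channel fill
    sheared = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0,
                       shear=[15.0, 0.0], fill=[1, ])                  # 1-element sequences are accepted too (see tests)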

File tree

6 files changed: +197 -132 lines changed

test/test_functional_tensor.py

+65-60
@@ -552,24 +552,25 @@ def _test_affine_translations(self, tensor, pil_img, scripted_affine):
     def _test_affine_all_ops(self, tensor, pil_img, scripted_affine):
         # 4) Test rotation + translation + scale + share
         test_configs = [
-            (45, [5, 6], 1.0, [0.0, 0.0]),
-            (33, (5, -4), 1.0, [0.0, 0.0]),
-            (45, [-5, 4], 1.2, [0.0, 0.0]),
-            (33, (-4, -8), 2.0, [0.0, 0.0]),
-            (85, (10, -10), 0.7, [0.0, 0.0]),
-            (0, [0, 0], 1.0, [35.0, ]),
-            (-25, [0, 0], 1.2, [0.0, 15.0]),
-            (-45, [-10, 0], 0.7, [2.0, 5.0]),
-            (-45, [-10, -10], 1.2, [4.0, 5.0]),
-            (-90, [0, 0], 1.0, [0.0, 0.0]),
+            (45.5, [5, 6], 1.0, [0.0, 0.0], None),
+            (33, (5, -4), 1.0, [0.0, 0.0], [0, 0, 0]),
+            (45, [-5, 4], 1.2, [0.0, 0.0], (1, 2, 3)),
+            (33, (-4, -8), 2.0, [0.0, 0.0], [255, 255, 255]),
+            (85, (10, -10), 0.7, [0.0, 0.0], [1, ]),
+            (0, [0, 0], 1.0, [35.0, ], (2.0, )),
+            (-25, [0, 0], 1.2, [0.0, 15.0], None),
+            (-45, [-10, 0], 0.7, [2.0, 5.0], None),
+            (-45, [-10, -10], 1.2, [4.0, 5.0], None),
+            (-90, [0, 0], 1.0, [0.0, 0.0], None),
         ]
         for r in [NEAREST, ]:
-            for a, t, s, sh in test_configs:
-                out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, interpolation=r)
+            for a, t, s, sh, f in test_configs:
+                f_pil = int(f[0]) if f is not None and len(f) == 1 else f
+                out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, interpolation=r, fill=f_pil)
                 out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
 
                 for fn in [F.affine, scripted_affine]:
-                    out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, interpolation=r).cpu()
+                    out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, interpolation=r, fill=f).cpu()
 
                     if out_tensor.dtype != torch.uint8:
                         out_tensor = out_tensor.to(torch.uint8)
@@ -582,7 +583,7 @@ def _test_affine_all_ops(self, tensor, pil_img, scripted_affine):
                         ratio_diff_pixels,
                         tol,
                         msg="{}: {}\n{} vs \n{}".format(
-                            (r, a, t, s, sh), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
+                            (r, a, t, s, sh, f), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
                         )
                     )
 
@@ -643,35 +644,36 @@ def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers):
         for a in range(-180, 180, 17):
             for e in [True, False]:
                 for c in centers:
-
-                    out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c)
-                    out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
-                    for fn in [F.rotate, scripted_rotate]:
-                        out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c).cpu()
-
-                        if out_tensor.dtype != torch.uint8:
-                            out_tensor = out_tensor.to(torch.uint8)
-
-                        self.assertEqual(
-                            out_tensor.shape,
-                            out_pil_tensor.shape,
-                            msg="{}: {} vs {}".format(
-                                (img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape
-                            )
-                        )
-                        num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
-                        ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
-                        # Tolerance : less than 3% of different pixels
-                        self.assertLess(
-                            ratio_diff_pixels,
-                            0.03,
-                            msg="{}: {}\n{} vs \n{}".format(
-                                (img_size, r, dt, a, e, c),
+                    for f in [None, [0, 0, 0], (1, 2, 3), [255, 255, 255], [1, ], (2.0, )]:
+                        f_pil = int(f[0]) if f is not None and len(f) == 1 else f
+                        out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c, fill=f_pil)
+                        out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+                        for fn in [F.rotate, scripted_rotate]:
+                            out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c, fill=f).cpu()
+
+                            if out_tensor.dtype != torch.uint8:
+                                out_tensor = out_tensor.to(torch.uint8)
+
+                            self.assertEqual(
+                                out_tensor.shape,
+                                out_pil_tensor.shape,
+                                msg="{}: {} vs {}".format(
+                                    (img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape
+                                ))
+
+                            num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
+                            ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
+                            # Tolerance : less than 3% of different pixels
+                            self.assertLess(
                                 ratio_diff_pixels,
-                                out_tensor[0, :7, :7],
-                                out_pil_tensor[0, :7, :7]
+                                0.03,
+                                msg="{}: {}\n{} vs \n{}".format(
+                                    (img_size, r, dt, a, e, c, f),
+                                    ratio_diff_pixels,
+                                    out_tensor[0, :7, :7],
+                                    out_pil_tensor[0, :7, :7]
+                                )
                             )
-                        )
 
     def test_rotate(self):
         # Tests on square image
@@ -721,30 +723,33 @@ def test_rotate(self):
 
     def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs):
         dt = tensor.dtype
-        for r in [NEAREST, ]:
-            for spoints, epoints in test_configs:
-                out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r)
-                out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+        for f in [None, [0, 0, 0], [1, 2, 3], [255, 255, 255], [1, ], (2.0, )]:
+            for r in [NEAREST, ]:
+                for spoints, epoints in test_configs:
+                    f_pil = int(f[0]) if f is not None and len(f) == 1 else f
+                    out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r,
+                                                fill=f_pil)
+                    out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
 
-                for fn in [F.perspective, scripted_transform]:
-                    out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r).cpu()
+                    for fn in [F.perspective, scripted_transform]:
+                        out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r, fill=f).cpu()
 
-                    if out_tensor.dtype != torch.uint8:
-                        out_tensor = out_tensor.to(torch.uint8)
+                        if out_tensor.dtype != torch.uint8:
+                            out_tensor = out_tensor.to(torch.uint8)
 
-                    num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
-                    ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
-                    # Tolerance : less than 5% of different pixels
-                    self.assertLess(
-                        ratio_diff_pixels,
-                        0.05,
-                        msg="{}: {}\n{} vs \n{}".format(
-                            (r, dt, spoints, epoints),
+                        num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
+                        ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
+                        # Tolerance : less than 5% of different pixels
+                        self.assertLess(
                             ratio_diff_pixels,
-                            out_tensor[0, :7, :7],
-                            out_pil_tensor[0, :7, :7]
+                            0.05,
+                            msg="{}: {}\n{} vs \n{}".format(
+                                (f, r, dt, spoints, epoints),
+                                ratio_diff_pixels,
+                                out_tensor[0, :7, :7],
+                                out_pil_tensor[0, :7, :7]
+                            )
                         )
-                    )
 
     def test_perspective(self):
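
The tests above collapse a one-element fill sequence to a scalar before calling the PIL reference implementation, because PIL takes a scalar per-band fill while the tensor path takes the sequence as-is. The inline expression can be read as a small helper; the helper name below is ours, not the commit's:

    def _fill_for_pil(f):
        # Same expression as in the tests: collapse a 1-element sequence to a scalar for PIL,
        # leave None and multi-element sequences untouched.
        return int(f[0]) if f is not None and len(f) == 1 else f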

test/test_transforms_tensor.py

+24-20
@@ -349,14 +349,15 @@ def test_random_affine(self):
                 for translate in [(0.1, 0.2), [0.2, 0.1]]:
                     for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
                         for interpolation in [NEAREST, BILINEAR]:
-                            transform = T.RandomAffine(
-                                degrees=degrees, translate=translate,
-                                scale=scale, shear=shear, interpolation=interpolation
-                            )
-                            s_transform = torch.jit.script(transform)
+                            for fill in [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
+                                transform = T.RandomAffine(
+                                    degrees=degrees, translate=translate,
+                                    scale=scale, shear=shear, interpolation=interpolation, fill=fill
+                                )
+                                s_transform = torch.jit.script(transform)
 
-                            self._test_transform_vs_scripted(transform, s_transform, tensor)
-                            self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
+                                self._test_transform_vs_scripted(transform, s_transform, tensor)
+                                self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
 
         with get_tmp_dir() as tmp_dir:
             s_transform.save(os.path.join(tmp_dir, "t_random_affine.pt"))
@@ -369,13 +370,14 @@ def test_random_rotate(self):
             for expand in [True, False]:
                 for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
                     for interpolation in [NEAREST, BILINEAR]:
-                        transform = T.RandomRotation(
-                            degrees=degrees, interpolation=interpolation, expand=expand, center=center
-                        )
-                        s_transform = torch.jit.script(transform)
+                        for fill in [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
+                            transform = T.RandomRotation(
+                                degrees=degrees, interpolation=interpolation, expand=expand, center=center, fill=fill
+                            )
+                            s_transform = torch.jit.script(transform)
 
-                        self._test_transform_vs_scripted(transform, s_transform, tensor)
-                        self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
+                            self._test_transform_vs_scripted(transform, s_transform, tensor)
+                            self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
 
         with get_tmp_dir() as tmp_dir:
             s_transform.save(os.path.join(tmp_dir, "t_random_rotate.pt"))
@@ -386,14 +388,16 @@ def test_random_perspective(self):
 
         for distortion_scale in np.linspace(0.1, 1.0, num=20):
             for interpolation in [NEAREST, BILINEAR]:
-                transform = T.RandomPerspective(
-                    distortion_scale=distortion_scale,
-                    interpolation=interpolation
-                )
-                s_transform = torch.jit.script(transform)
+                for fill in [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
+                    transform = T.RandomPerspective(
+                        distortion_scale=distortion_scale,
+                        interpolation=interpolation,
+                        fill=fill
+                    )
+                    s_transform = torch.jit.script(transform)
 
-                self._test_transform_vs_scripted(transform, s_transform, tensor)
-                self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
+                    self._test_transform_vs_scripted(transform, s_transform, tensor)
+                    self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
 
         with get_tmp_dir() as tmp_dir:
             s_transform.save(os.path.join(tmp_dir, "t_perspective.pt"))
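
Outside the test harness, the same `fill` values can be passed to the transform classes directly and survive scripting; a hedged sketch (the constructor arguments are illustrative, not taken from the commit):

    import torch
    import torchvision.transforms as T

    transform = T.RandomAffine(degrees=(-45, 45), translate=(0.1, 0.2), fill=[0.0, 0.0, 0.0])
    scripted = torch.jit.script(transform)   # scriptable, as exercised by the tests above

    img = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8)
    out = scripted(img)                      # fill is applied to the area outside the transform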

torchvision/transforms/functional.py

+21-17
@@ -557,7 +557,7 @@ def perspective(
         startpoints: List[List[int]],
         endpoints: List[List[int]],
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-        fill: Optional[int] = None
+        fill: Optional[List[float]] = None
 ) -> Tensor:
     """Perform perspective transform of the given image.
     The image can be a PIL Image or a Tensor, in which case it is expected
@@ -573,10 +573,12 @@ def perspective(
             :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
             If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
             For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
-        fill (n-tuple or int or float): Pixel fill value for area outside the rotated
+        fill (sequence or int or float, optional): Pixel fill value for the area outside the transformed
             image. If int or float, the value is used for all bands respectively.
-            This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor
-            input. Fill value for the area outside the transform in the output image is always 0.
+            This option is supported for PIL image and Tensor inputs.
+            In torchscript mode single int/float value is not supported, please use a tuple
+            or list of length 1: ``[value, ]``.
+            If input is PIL Image, the options is only available for ``Pillow>=5.0.0``.
 
     Returns:
         PIL Image or Tensor: transformed Image.
@@ -871,7 +873,7 @@ def _get_inverse_affine_matrix(
 def rotate(
         img: Tensor, angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False, center: Optional[List[int]] = None,
-        fill: Optional[int] = None, resample: Optional[int] = None
+        fill: Optional[List[float]] = None, resample: Optional[int] = None
 ) -> Tensor:
     """Rotate the image by angle.
     The image can be a PIL Image or a Tensor, in which case it is expected
@@ -890,13 +892,12 @@ def rotate(
             Note that the expand flag assumes rotation around the center and no translation.
         center (list or tuple, optional): Optional center of rotation. Origin is the upper left corner.
             Default is the center of the image.
-        fill (n-tuple or int or float): Pixel fill value for area outside the rotated
+        fill (sequence or int or float, optional): Pixel fill value for the area outside the transformed
             image. If int or float, the value is used for all bands respectively.
-            Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``.
-            This option is not supported for Tensor input. Fill value for the area outside the transform in the output
-            image is always 0.
-        resample (int, optional): deprecated argument and will be removed since v0.10.0.
-            Please use `arg`:interpolation: instead.
+            This option is supported for PIL image and Tensor inputs.
+            In torchscript mode single int/float value is not supported, please use a tuple
+            or list of length 1: ``[value, ]``.
+            If input is PIL Image, the options is only available for ``Pillow>=5.2.0``.
 
     Returns:
         PIL Image or Tensor: Rotated image.
@@ -945,8 +946,8 @@
 
 def affine(
         img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float],
-        interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[int] = None,
-        resample: Optional[int] = None, fillcolor: Optional[int] = None
+        interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None,
+        resample: Optional[int] = None, fillcolor: Optional[List[float]] = None
 ) -> Tensor:
     """Apply affine transformation on the image keeping image center invariant.
     The image can be a PIL Image or a Tensor, in which case it is expected
@@ -964,10 +965,13 @@ def affine(
             :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
             If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
             For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
-        fill (int): Optional fill color for the area outside the transform in the output image (Pillow>=5.0.0).
-            This option is not supported for Tensor input. Fill value for the area outside the transform in the output
-            image is always 0.
-        fillcolor (tuple or int, optional): deprecated argument and will be removed since v0.10.0.
+        fill (sequence or int or float, optional): Pixel fill value for the area outside the transformed
+            image. If int or float, the value is used for all bands respectively.
+            This option is supported for PIL image and Tensor inputs.
+            In torchscript mode single int/float value is not supported, please use a tuple
+            or list of length 1: ``[value, ]``.
+            If input is PIL Image, the options is only available for ``Pillow>=5.0.0``.
+        fillcolor (sequence, int, float): deprecated argument and will be removed since v0.10.0.
             Please use `arg`:fill: instead.
         resample (int, optional): deprecated argument and will be removed since v0.10.0.
             Please use `arg`:interpolation: instead.
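
The torchscript caveat spelled out in the docstrings above can be made concrete; a minimal sketch, assuming this revision of torchvision (the angle and fill value are arbitrary):

    import torch
    import torchvision.transforms.functional as F
    from torchvision.transforms.functional import InterpolationMode

    scripted_rotate = torch.jit.script(F.rotate)
    img = torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8)

    # A bare int/float fill is not supported under torchscript, so a 1-element list is used instead.
    out = scripted_rotate(img, angle=30.0, interpolation=InterpolationMode.NEAREST,
                          expand=False, center=None, fill=[128.0, ])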

torchvision/transforms/functional_pil.py

+7-4
@@ -465,10 +465,13 @@ def _parse_fill(fill, img, min_pil_version, name="fillcolor"):
         fill = 0
     if isinstance(fill, (int, float)) and num_bands > 1:
         fill = tuple([fill] * num_bands)
-    if not isinstance(fill, (int, float)) and len(fill) != num_bands:
-        msg = ("The number of elements in 'fill' does not match the number of "
-               "bands of the image ({} != {})")
-        raise ValueError(msg.format(len(fill), num_bands))
+    if isinstance(fill, (list, tuple)):
+        if len(fill) != num_bands:
+            msg = ("The number of elements in 'fill' does not match the number of "
+                   "bands of the image ({} != {})")
+            raise ValueError(msg.format(len(fill), num_bands))
+
+        fill = tuple(fill)
 
     return {name: fill}
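
Restated outside the library (a standalone sketch of the branch above, not `functional_pil._parse_fill` itself; the surrounding `if fill is None` handling is assumed from context and the Pillow version check is omitted), the PIL-side normalization behaves roughly as follows for a 3-band image:

    def parse_fill_sketch(fill, num_bands=3, name="fillcolor"):
        if fill is None:
            fill = 0
        if isinstance(fill, (int, float)) and num_bands > 1:
            fill = tuple([fill] * num_bands)   # broadcast a scalar to every band
        if isinstance(fill, (list, tuple)):
            if len(fill) != num_bands:
                raise ValueError("The number of elements in 'fill' does not match the number of "
                                 "bands of the image ({} != {})".format(len(fill), num_bands))
            fill = tuple(fill)                 # PIL expects a tuple, not a list
        return {name: fill}

    # parse_fill_sketch([255, 255, 255]) -> {'fillcolor': (255, 255, 255)}
    # parse_fill_sketch(2.0)             -> {'fillcolor': (2.0, 2.0, 2.0)}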
