Skip to content

Commit 2718f73

Browse files
authored
Updated fill arg typehint for affine, perspective and elastic ops (#6595)
* Updated fill arg typehint for affine, perspective and elastic ops * Updated pad op on prototype side * Code updates * Few other minor updates
1 parent 9c660c6 commit 2718f73

File tree

6 files changed

+59
-33
lines changed

6 files changed

+59
-33
lines changed

test/prototype_transforms_kernel_infos.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def sample_inputs_affine_image_tensor():
226226
],
227227
[None, (0, 0)],
228228
):
229-
for fill in [None, [0.5] * image_loader.num_channels]:
229+
for fill in [None, 128.0, 128, [12.0], [0.5] * image_loader.num_channels]:
230230
yield ArgsKwargs(
231231
image_loader,
232232
interpolation=interpolation_mode,

test/test_prototype_transforms_functional.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,12 @@ def perspective_image_tensor():
228228
[1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
229229
[0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
230230
],
231-
[None, [128], [12.0]], # fill
231+
[None, 128.0, 128, [12.0], [1.0, 2.0, 3.0]], # fill
232232
):
233+
if isinstance(fill, list) and len(fill) == 3 and image.shape[1] != 3:
234+
# skip the test with non-broadcastable fill value
235+
continue
236+
233237
yield ArgsKwargs(image, perspective_coeffs=perspective_coeffs, fill=fill)
234238

235239

@@ -268,8 +272,12 @@ def perspective_mask():
268272
def elastic_image_tensor():
269273
for image, fill in itertools.product(
270274
make_images(extra_dims=((), (4,))),
271-
[None, [128], [12.0]], # fill
275+
[None, 128.0, 128, [12.0], [1.0, 2.0, 3.0]], # fill
272276
):
277+
if isinstance(fill, list) and len(fill) == 3 and image.shape[1] != 3:
278+
# skip the test with non-broadcastable fill value
279+
continue
280+
273281
h, w = image.shape[-2:]
274282
displacement = torch.rand(1, h, w, 2)
275283
yield ArgsKwargs(image, displacement=displacement, fill=fill)

torchvision/prototype/features/_image.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -177,12 +177,9 @@ def pad(
177177
if not isinstance(padding, int):
178178
padding = list(padding)
179179

180-
# PyTorch's pad supports only scalars on fill. So we need to overwrite the colour
181-
if isinstance(fill, (int, float)) or fill is None:
182-
output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
183-
else:
184-
output = self._F._geometry._pad_with_vector_fill(self, padding, fill=fill, padding_mode=padding_mode)
180+
fill = self._F._geometry._convert_fill_arg(fill)
185181

182+
output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
186183
return Image.new_like(self, output)
187184

188185
def rotate(

torchvision/prototype/features/_mask.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,9 @@ def pad(
5858
if not isinstance(padding, int):
5959
padding = list(padding)
6060

61-
if isinstance(fill, (int, float)) or fill is None:
62-
if fill is None:
63-
fill = 0
64-
output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill)
65-
else:
66-
# Let's raise an error for vector fill on masks
67-
raise ValueError("Non-scalar fill value is not supported")
61+
fill = self._F._geometry._convert_fill_arg(fill)
6862

63+
output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill)
6964
return Mask.new_like(self, output)
7065

7166
def rotate(

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def affine_image_tensor(
232232
scale: float,
233233
shear: List[float],
234234
interpolation: InterpolationMode = InterpolationMode.NEAREST,
235-
fill: Optional[List[float]] = None,
235+
fill: Optional[Union[int, float, List[float]]] = None,
236236
center: Optional[List[float]] = None,
237237
) -> torch.Tensor:
238238
if img.numel() == 0:
@@ -405,7 +405,9 @@ def affine_mask(
405405
return output
406406

407407

408-
def _convert_fill_arg(fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]) -> Optional[List[float]]:
408+
def _convert_fill_arg(
409+
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]
410+
) -> Optional[Union[int, float, List[float]]]:
409411
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
410412
# So, we can't reassign fill to 0
411413
# if fill is None:
@@ -416,9 +418,6 @@ def _convert_fill_arg(fill: Optional[Union[int, float, Sequence[int], Sequence[f
416418
# This cast does Sequence -> List[float] to please mypy and torch.jit.script
417419
if not isinstance(fill, (int, float)):
418420
fill = [float(v) for v in list(fill)]
419-
else:
420-
# It is OK to cast int to float as later we use inpt.dtype
421-
fill = [float(fill)]
422421
return fill
423422

424423

@@ -591,7 +590,23 @@ def rotate(
591590
def pad_image_tensor(
592591
img: torch.Tensor,
593592
padding: Union[int, List[int]],
594-
fill: Optional[Union[int, float]] = 0,
593+
fill: Optional[Union[int, float, List[float]]] = None,
594+
padding_mode: str = "constant",
595+
) -> torch.Tensor:
596+
if fill is None:
597+
# This is a JIT workaround
598+
return _pad_with_scalar_fill(img, padding, fill=None, padding_mode=padding_mode)
599+
elif isinstance(fill, (int, float)) or len(fill) == 1:
600+
fill_number = fill[0] if isinstance(fill, list) else fill
601+
return _pad_with_scalar_fill(img, padding, fill=fill_number, padding_mode=padding_mode)
602+
else:
603+
return _pad_with_vector_fill(img, padding, fill=fill, padding_mode=padding_mode)
604+
605+
606+
def _pad_with_scalar_fill(
607+
img: torch.Tensor,
608+
padding: Union[int, List[int]],
609+
fill: Union[int, float, None],
595610
padding_mode: str = "constant",
596611
) -> torch.Tensor:
597612
num_channels, height, width = img.shape[-3:]
@@ -614,13 +629,13 @@ def pad_image_tensor(
614629
def _pad_with_vector_fill(
615630
img: torch.Tensor,
616631
padding: Union[int, List[int]],
617-
fill: Sequence[float] = [0.0],
632+
fill: List[float],
618633
padding_mode: str = "constant",
619634
) -> torch.Tensor:
620635
if padding_mode != "constant":
621636
raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")
622637

623-
output = pad_image_tensor(img, padding, fill=0, padding_mode="constant")
638+
output = _pad_with_scalar_fill(img, padding, fill=0, padding_mode="constant")
624639
left, right, top, bottom = _parse_pad_padding(padding)
625640
fill = torch.tensor(fill, dtype=img.dtype, device=img.device).view(-1, 1, 1)
626641

@@ -639,8 +654,14 @@ def pad_mask(
639654
mask: torch.Tensor,
640655
padding: Union[int, List[int]],
641656
padding_mode: str = "constant",
642-
fill: Optional[Union[int, float]] = 0,
657+
fill: Optional[Union[int, float, List[float]]] = None,
643658
) -> torch.Tensor:
659+
if fill is None:
660+
fill = 0
661+
662+
if isinstance(fill, list):
663+
raise ValueError("Non-scalar fill value is not supported")
664+
644665
if mask.ndim < 3:
645666
mask = mask.unsqueeze(0)
646667
needs_squeeze = True
@@ -693,10 +714,9 @@ def pad(
693714
if not isinstance(padding, int):
694715
padding = list(padding)
695716

696-
# TODO: PyTorch's pad supports only scalars on fill. So we need to overwrite the colour
697-
if isinstance(fill, (int, float)) or fill is None:
698-
return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
699-
return _pad_with_vector_fill(inpt, padding, fill=fill, padding_mode=padding_mode)
717+
fill = _convert_fill_arg(fill)
718+
719+
return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
700720

701721

702722
crop_image_tensor = _FT.crop
@@ -739,7 +759,7 @@ def perspective_image_tensor(
739759
img: torch.Tensor,
740760
perspective_coeffs: List[float],
741761
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
742-
fill: Optional[List[float]] = None,
762+
fill: Optional[Union[int, float, List[float]]] = None,
743763
) -> torch.Tensor:
744764
return _FT.perspective(img, perspective_coeffs, interpolation=interpolation.value, fill=fill)
745765

@@ -878,7 +898,7 @@ def elastic_image_tensor(
878898
img: torch.Tensor,
879899
displacement: torch.Tensor,
880900
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
881-
fill: Optional[List[float]] = None,
901+
fill: Optional[Union[int, float, List[float]]] = None,
882902
) -> torch.Tensor:
883903
return _FT.elastic_transform(img, displacement, interpolation=interpolation.value, fill=fill)
884904

torchvision/transforms/functional_tensor.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,10 @@ def _gen_affine_grid(
600600

601601

602602
def affine(
603-
img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[List[float]] = None
603+
img: Tensor,
604+
matrix: List[float],
605+
interpolation: str = "nearest",
606+
fill: Optional[Union[int, float, List[float]]] = None,
604607
) -> Tensor:
605608
_assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
606609

@@ -693,7 +696,10 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
693696

694697

695698
def perspective(
696-
img: Tensor, perspective_coeffs: List[float], interpolation: str = "bilinear", fill: Optional[List[float]] = None
699+
img: Tensor,
700+
perspective_coeffs: List[float],
701+
interpolation: str = "bilinear",
702+
fill: Optional[Union[int, float, List[float]]] = None,
697703
) -> Tensor:
698704
if not (isinstance(img, torch.Tensor)):
699705
raise TypeError("Input img should be Tensor.")
@@ -950,7 +956,7 @@ def elastic_transform(
950956
img: Tensor,
951957
displacement: Tensor,
952958
interpolation: str = "bilinear",
953-
fill: Optional[List[float]] = None,
959+
fill: Optional[Union[int, float, List[float]]] = None,
954960
) -> Tensor:
955961

956962
if not (isinstance(img, torch.Tensor)):

0 commit comments

Comments
 (0)