Skip to content

Commit 97acb26

Browse files
committed
int/float support for fill in pad op
1 parent 98804cb commit 97acb26

File tree

6 files changed

+23
-12
lines changed

6 files changed

+23
-12
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ def pad_image_tensor():
431431
for image, padding, fill, padding_mode in itertools.product(
432432
make_images(),
433433
[[1], [1, 1], [1, 1, 2, 2]], # padding
434-
[12], # fill
434+
[12, 12.0], # fill
435435
["constant", "symmetric", "edge", "reflect"], # padding mode,
436436
):
437437
yield SampleInput(image, padding=padding, fill=fill, padding_mode=padding_mode)

torchvision/prototype/features/_bounding_box.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Any, List, Tuple, Union, Optional
3+
from typing import Any, List, Tuple, Union, Optional, Sequence
44

55
import torch
66
from torchvision._utils import StrEnum
@@ -127,7 +127,9 @@ def resized_crop(
127127
image_size = (size[0], size[0]) if len(size) == 1 else (size[0], size[1])
128128
return BoundingBox.new_like(self, output, image_size=image_size, dtype=output.dtype)
129129

130-
def pad(self, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> BoundingBox:
130+
def pad(
131+
self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant"
132+
) -> BoundingBox:
131133
from torchvision.prototype.transforms import functional as _F
132134

133135
if padding_mode not in ["constant"]:

torchvision/prototype/features/_feature.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,9 @@ def resized_crop(
119119
) -> Any:
120120
return self
121121

122-
def pad(self, padding: List[int], fill: Union[float, Sequence[float]] = 0, padding_mode: str = "constant") -> Any:
122+
def pad(
123+
self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant"
124+
) -> Any:
123125
return self
124126

125127
def rotate(

torchvision/prototype/features/_image.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import warnings
4-
from typing import Any, List, Optional, Union, Tuple, cast
4+
from typing import Any, List, Optional, Union, Sequence, Tuple, cast
55

66
import torch
77
from torchvision._utils import StrEnum
@@ -163,7 +163,9 @@ def resized_crop(
163163
)
164164
return Image.new_like(self, output)
165165

166-
def pad(self, padding: List[int], fill: Union[float, List[float]] = 0.0, padding_mode: str = "constant") -> Image:
166+
def pad(
167+
self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant"
168+
) -> Image:
167169
from torchvision.prototype.transforms import functional as _F
168170

169171
# PyTorch's pad supports only scalars on fill. So we need to overwrite the colour

torchvision/prototype/features/_segmentation_mask.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Tuple, List, Optional
3+
from typing import Tuple, List, Optional, Union, Sequence
44

55
import torch
66
from torchvision.transforms import InterpolationMode
@@ -60,7 +60,9 @@ def resized_crop(
6060
output = _F.resized_crop_segmentation_mask(self, top, left, height, width, size=size)
6161
return SegmentationMask.new_like(self, output)
6262

63-
def pad(self, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> SegmentationMask:
63+
def pad(
64+
self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant"
65+
) -> SegmentationMask:
6466
from torchvision.prototype.transforms import functional as _F
6567

6668
output = _F.pad_segmentation_mask(self, padding, padding_mode=padding_mode)

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ def rotate(
507507

508508

509509
def pad_image_tensor(
510-
img: torch.Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant"
510+
img: torch.Tensor, padding: List[int], fill: Union[int, float] = 0, padding_mode: str = "constant"
511511
) -> torch.Tensor:
512512
num_masks, height, width = img.shape[-3:]
513513
extra_dims = img.shape[:-3]
@@ -522,8 +522,11 @@ def pad_image_tensor(
522522

523523
# TODO: This should be removed once pytorch pad supports non-scalar padding values
524524
def _pad_with_vector_fill(
525-
img: torch.Tensor, padding: List[int], fill: Union[float, List[float]] = 0.0, padding_mode: str = "constant"
526-
):
525+
img: torch.Tensor,
526+
padding: List[int],
527+
fill: Sequence[float] = [0.0],
528+
padding_mode: str = "constant",
529+
) -> torch.Tensor:
527530
if padding_mode != "constant":
528531
raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")
529532

@@ -573,7 +576,7 @@ def pad_bounding_box(
573576

574577

575578
def pad(
576-
inpt: Any, padding: List[int], fill: Union[float, Sequence[float]] = 0.0, padding_mode: str = "constant"
579+
inpt: Any, padding: List[int], fill: Union[int, float, Sequence[float]] = 0.0, padding_mode: str = "constant"
577580
) -> Any:
578581
if isinstance(inpt, features._Feature):
579582
return inpt.pad(padding, fill=fill, padding_mode=padding_mode)

0 commit comments

Comments
 (0)