Skip to content

Commit 98804cb

Browse files
committed
Fill arg supports float values, scripted pad op
1 parent 045805f commit 98804cb

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

test/test_functional_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -972,7 +972,7 @@ def test_adjust_gamma(device, dtype, config, channels):
972972
[
973973
{"padding_mode": "constant", "fill": 0},
974974
{"padding_mode": "constant", "fill": 10},
975-
{"padding_mode": "constant", "fill": 20},
975+
{"padding_mode": "constant", "fill": 20.2},
976976
{"padding_mode": "edge"},
977977
{"padding_mode": "reflect"},
978978
{"padding_mode": "symmetric"},

torchvision/transforms/functional.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numbers
33
import warnings
44
from enum import Enum
5-
from typing import List, Tuple, Any, Optional
5+
from typing import List, Tuple, Any, Optional, Union
66

77
import numpy as np
88
import torch
@@ -474,7 +474,7 @@ def resize(
474474
return F_t.resize(img, size=output_size, interpolation=interpolation.value, antialias=antialias)
475475

476476

477-
def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
477+
def pad(img: Tensor, padding: List[int], fill: Union[int, float] = 0, padding_mode: str = "constant") -> Tensor:
478478
r"""Pad the given image on all sides with the given "pad" value.
479479
If the image is torch Tensor, it is expected
480480
to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric,

torchvision/transforms/functional_tensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import warnings
2-
from typing import Optional, Tuple, List
2+
from typing import Optional, Tuple, List, Union
33

44
import torch
55
from torch import Tensor
@@ -370,7 +370,7 @@ def _parse_pad_padding(padding: List[int]) -> List[int]:
370370
return [pad_left, pad_right, pad_top, pad_bottom]
371371

372372

373-
def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
373+
def pad(img: Tensor, padding: List[int], fill: Union[int, float] = 0, padding_mode: str = "constant") -> Tensor:
374374
_assert_image_tensor(img)
375375

376376
if not isinstance(padding, (int, tuple, list)):

0 commit comments

Comments
 (0)