Skip to content

[proto] Added dict support for fill arg for remaining transforms #6599

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Sep 19, 2022
6 changes: 3 additions & 3 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def test__transform_image_mask(self, fill, mocker):
if isinstance(fill, int):
calls = [
mocker.call(image, padding=1, fill=fill, padding_mode="constant"),
mocker.call(mask, padding=1, fill=0, padding_mode="constant"),
mocker.call(mask, padding=1, fill=fill, padding_mode="constant"),
]
else:
calls = [
Expand Down Expand Up @@ -467,7 +467,7 @@ def test__transform_image_mask(self, fill, mocker):
if isinstance(fill, int):
calls = [
mocker.call(image, **params, fill=fill),
mocker.call(mask, **params, fill=0),
mocker.call(mask, **params, fill=fill),
]
else:
calls = [
Expand Down Expand Up @@ -1555,7 +1555,7 @@ def test__get_params(self, mocker):

@pytest.mark.parametrize("needs", list(itertools.product((False, True), repeat=2)))
def test__transform(self, mocker, needs):
fill_sentinel = mocker.MagicMock()
fill_sentinel = 12
padding_mode_sentinel = mocker.MagicMock()

transform = transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel)
Expand Down
9 changes: 8 additions & 1 deletion test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,16 @@ def pad_image_tensor():
for image, padding, fill, padding_mode in itertools.product(
make_images(),
[[1], [1, 1], [1, 1, 2, 2]], # padding
[None, 12, 12.0], # fill
[None, 128.0, 128, [12.0], [12.0, 13.0, 14.0]], # fill
["constant", "symmetric", "edge", "reflect"], # padding mode,
):
if padding_mode != "constant" and fill is not None:
# ValueError: Padding mode 'reflect' is not supported if fill is not scalar
continue

if isinstance(fill, list) and len(fill) != image.shape[-3]:
continue

yield ArgsKwargs(image, padding=padding, fill=fill, padding_mode=padding_mode)


Expand Down
65 changes: 32 additions & 33 deletions torchvision/prototype/transforms/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,12 @@ def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:


def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]:
_check_fill_arg(fill)

if isinstance(fill, dict):
return fill
else:
return defaultdict(lambda: fill, {features.Mask: 0}) # type: ignore[arg-type, return-value]

return defaultdict(lambda: fill) # type: ignore[arg-type, return-value]


def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
Expand Down Expand Up @@ -242,7 +244,6 @@ def __init__(
super().__init__()

_check_padding_arg(padding)
_check_fill_arg(fill)
_check_padding_mode_arg(padding_mode)

self.padding = padding
Expand All @@ -263,7 +264,6 @@ def __init__(
) -> None:
super().__init__(p=p)

_check_fill_arg(fill)
self.fill = _setup_fill_arg(fill)

_check_sequence_input(side_range, "side_range", req_sizes=(2,))
Expand Down Expand Up @@ -299,17 +299,15 @@ def __init__(
degrees: Union[numbers.Number, Sequence],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
fill: Union[FillType, Dict[Type, FillType]] = 0,
center: Optional[List[float]] = None,
) -> None:
super().__init__()
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
self.interpolation = interpolation
self.expand = expand

_check_fill_arg(fill)

self.fill = fill
self.fill = _setup_fill_arg(fill)

if center is not None:
_check_sequence_input(center, "center", req_sizes=(2,))
Expand All @@ -321,12 +319,13 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(angle=angle)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.rotate(
inpt,
**params,
interpolation=self.interpolation,
expand=self.expand,
fill=self.fill,
fill=fill,
center=self.center,
)

Expand All @@ -339,7 +338,7 @@ def __init__(
scale: Optional[Sequence[float]] = None,
shear: Optional[Union[float, Sequence[float]]] = None,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
fill: Union[FillType, Dict[Type, FillType]] = 0,
center: Optional[List[float]] = None,
) -> None:
super().__init__()
Expand All @@ -363,10 +362,7 @@ def __init__(
self.shear = shear

self.interpolation = interpolation

_check_fill_arg(fill)

self.fill = fill
self.fill = _setup_fill_arg(fill)

if center is not None:
_check_sequence_input(center, "center", req_sizes=(2,))
Expand Down Expand Up @@ -404,11 +400,12 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(angle=angle, translate=translate, scale=scale, shear=shear)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.affine(
inpt,
**params,
interpolation=self.interpolation,
fill=self.fill,
fill=fill,
center=self.center,
)

Expand All @@ -419,7 +416,7 @@ def __init__(
size: Union[int, Sequence[int]],
padding: Optional[Union[int, Sequence[int]]] = None,
pad_if_needed: bool = False,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
fill: Union[FillType, Dict[Type, FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
Expand All @@ -429,12 +426,11 @@ def __init__(
if pad_if_needed or padding is not None:
if padding is not None:
_check_padding_arg(padding)
_check_fill_arg(fill)
_check_padding_mode_arg(padding_mode)

self.padding = padding
self.pad_if_needed = pad_if_needed
self.fill = fill
self.fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode

def _get_params(self, sample: Any) -> Dict[str, Any]:
Expand Down Expand Up @@ -483,17 +479,18 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# TODO: (PERF) check for speed optimization if we avoid repeated pad calls
fill = self.fill[type(inpt)]
if self.padding is not None:
inpt = F.pad(inpt, padding=self.padding, fill=self.fill, padding_mode=self.padding_mode)
inpt = F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)

if self.pad_if_needed:
input_width, input_height = params["input_width"], params["input_height"]
if input_width < self.size[1]:
padding = [self.size[1] - input_width, 0]
inpt = F.pad(inpt, padding=padding, fill=self.fill, padding_mode=self.padding_mode)
inpt = F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)
if input_height < self.size[0]:
padding = [0, self.size[0] - input_height]
inpt = F.pad(inpt, padding=padding, fill=self.fill, padding_mode=self.padding_mode)
inpt = F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)

return F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])

Expand All @@ -502,19 +499,18 @@ class RandomPerspective(_RandomApplyTransform):
def __init__(
self,
distortion_scale: float = 0.5,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
fill: Union[FillType, Dict[Type, FillType]] = 0,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
p: float = 0.5,
) -> None:
super().__init__(p=p)

_check_fill_arg(fill)
if not (0 <= distortion_scale <= 1):
raise ValueError("Argument distortion_scale value should be between 0 and 1")

self.distortion_scale = distortion_scale
self.interpolation = interpolation
self.fill = fill
self.fill = _setup_fill_arg(fill)

def _get_params(self, sample: Any) -> Dict[str, Any]:
# Get image size
Expand Down Expand Up @@ -546,10 +542,11 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(startpoints=startpoints, endpoints=endpoints)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.perspective(
inpt,
**params,
fill=self.fill,
fill=fill,
interpolation=self.interpolation,
)

Expand All @@ -576,17 +573,15 @@ def __init__(
self,
alpha: Union[float, Sequence[float]] = 50.0,
sigma: Union[float, Sequence[float]] = 5.0,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
fill: Union[FillType, Dict[Type, FillType]] = 0,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> None:
super().__init__()
self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
self.sigma = _setup_float_or_seq(sigma, "sigma", 2)

_check_fill_arg(fill)

self.interpolation = interpolation
self.fill = fill
self.fill = _setup_fill_arg(fill)

def _get_params(self, sample: Any) -> Dict[str, Any]:
# Get image size
Expand Down Expand Up @@ -614,10 +609,11 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(displacement=displacement)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.elastic(
inpt,
**params,
fill=self.fill,
fill=fill,
interpolation=self.interpolation,
)

Expand Down Expand Up @@ -789,14 +785,16 @@ class FixedSizeCrop(Transform):
def __init__(
self,
size: Union[int, Sequence[int]],
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
fill: Union[FillType, Dict[Type, FillType]] = 0,
padding_mode: str = "constant",
) -> None:
super().__init__()
size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
self.crop_height = size[0]
self.crop_width = size[1]
self.fill = fill # TODO: Fill is currently respected only on PIL. Apply tensor patch.

self.fill = _setup_fill_arg(fill)

self.padding_mode = padding_mode

def _get_params(self, sample: Any) -> Dict[str, Any]:
Expand Down Expand Up @@ -869,7 +867,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
)

if params["needs_pad"]:
inpt = F.pad(inpt, params["padding"], fill=self.fill, padding_mode=self.padding_mode)
fill = self.fill[type(inpt)]
inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)

return inpt

Expand Down
2 changes: 1 addition & 1 deletion torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def _assert_grid_transform_inputs(

# Check fill
num_channels = get_dimensions(img)[0]
if fill is not None and isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels):
if fill is not None and isinstance(fill, (tuple, list)) and len(fill) > 1 and len(fill) != num_channels:
msg = (
"The number of elements in 'fill' cannot broadcast to match the number of "
"channels of the image ({} != {})"
Expand Down