-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Updated fill arg typehint for affine, perspective and elastic ops #6595
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e24d71a
8bf31ba
a36da7d
6d62612
fc42d08
430911e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -232,7 +232,7 @@ def affine_image_tensor( | |||||||||||||||||||||||||||
scale: float, | ||||||||||||||||||||||||||||
shear: List[float], | ||||||||||||||||||||||||||||
interpolation: InterpolationMode = InterpolationMode.NEAREST, | ||||||||||||||||||||||||||||
fill: Optional[List[float]] = None, | ||||||||||||||||||||||||||||
fill: Optional[Union[int, float, List[float]]] = None, | ||||||||||||||||||||||||||||
center: Optional[List[float]] = None, | ||||||||||||||||||||||||||||
) -> torch.Tensor: | ||||||||||||||||||||||||||||
if img.numel() == 0: | ||||||||||||||||||||||||||||
|
@@ -405,7 +405,9 @@ def affine_mask( | |||||||||||||||||||||||||||
return output | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
def _convert_fill_arg(fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]) -> Optional[List[float]]: | ||||||||||||||||||||||||||||
def _convert_fill_arg( | ||||||||||||||||||||||||||||
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] | ||||||||||||||||||||||||||||
) -> Optional[Union[int, float, List[float]]]: | ||||||||||||||||||||||||||||
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517 | ||||||||||||||||||||||||||||
# So, we can't reassign fill to 0 | ||||||||||||||||||||||||||||
# if fill is None: | ||||||||||||||||||||||||||||
|
@@ -416,9 +418,6 @@ def _convert_fill_arg(fill: Optional[Union[int, float, Sequence[int], Sequence[f | |||||||||||||||||||||||||||
# This cast does Sequence -> List[float] to please mypy and torch.jit.script | ||||||||||||||||||||||||||||
if not isinstance(fill, (int, float)): | ||||||||||||||||||||||||||||
fill = [float(v) for v in list(fill)] | ||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||
# It is OK to cast int to float as later we use inpt.dtype | ||||||||||||||||||||||||||||
fill = [float(fill)] | ||||||||||||||||||||||||||||
return fill | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
|
@@ -591,7 +590,23 @@ def rotate( | |||||||||||||||||||||||||||
def pad_image_tensor( | ||||||||||||||||||||||||||||
img: torch.Tensor, | ||||||||||||||||||||||||||||
padding: Union[int, List[int]], | ||||||||||||||||||||||||||||
fill: Optional[Union[int, float]] = 0, | ||||||||||||||||||||||||||||
fill: Optional[Union[int, float, List[float]]] = None, | ||||||||||||||||||||||||||||
padding_mode: str = "constant", | ||||||||||||||||||||||||||||
) -> torch.Tensor: | ||||||||||||||||||||||||||||
if fill is None: | ||||||||||||||||||||||||||||
# This is a JIT workaround | ||||||||||||||||||||||||||||
return _pad_with_scalar_fill(img, padding, fill=None, padding_mode=padding_mode) | ||||||||||||||||||||||||||||
elif isinstance(fill, (int, float)) or len(fill) == 1: | ||||||||||||||||||||||||||||
fill_number = fill[0] if isinstance(fill, list) else fill | ||||||||||||||||||||||||||||
return _pad_with_scalar_fill(img, padding, fill=fill_number, padding_mode=padding_mode) | ||||||||||||||||||||||||||||
Comment on lines
+596
to
+601
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just wonder if we could do something like:
Suggested change
To avoid the duplicate dispatch to scalar. I assume that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Functional tensor is already doing vision/torchvision/transforms/functional_tensor.py Lines 381 to 382 in c0911e3
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was thinking that below code is valid, but JIT does not agree if fill is None or isinstance(fill, (int, float)) or len(fill) == 1:
fill_number = fill[0] if isinstance(fill, list) else fill |
||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||
return _pad_with_vector_fill(img, padding, fill=fill, padding_mode=padding_mode) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
def _pad_with_scalar_fill( | ||||||||||||||||||||||||||||
img: torch.Tensor, | ||||||||||||||||||||||||||||
padding: Union[int, List[int]], | ||||||||||||||||||||||||||||
fill: Union[int, float, None], | ||||||||||||||||||||||||||||
padding_mode: str = "constant", | ||||||||||||||||||||||||||||
) -> torch.Tensor: | ||||||||||||||||||||||||||||
num_channels, height, width = img.shape[-3:] | ||||||||||||||||||||||||||||
|
@@ -614,13 +629,13 @@ def pad_image_tensor( | |||||||||||||||||||||||||||
def _pad_with_vector_fill( | ||||||||||||||||||||||||||||
img: torch.Tensor, | ||||||||||||||||||||||||||||
padding: Union[int, List[int]], | ||||||||||||||||||||||||||||
fill: Sequence[float] = [0.0], | ||||||||||||||||||||||||||||
fill: List[float], | ||||||||||||||||||||||||||||
padding_mode: str = "constant", | ||||||||||||||||||||||||||||
) -> torch.Tensor: | ||||||||||||||||||||||||||||
if padding_mode != "constant": | ||||||||||||||||||||||||||||
raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar") | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
output = pad_image_tensor(img, padding, fill=0, padding_mode="constant") | ||||||||||||||||||||||||||||
output = _pad_with_scalar_fill(img, padding, fill=0, padding_mode="constant") | ||||||||||||||||||||||||||||
left, right, top, bottom = _parse_pad_padding(padding) | ||||||||||||||||||||||||||||
fill = torch.tensor(fill, dtype=img.dtype, device=img.device).view(-1, 1, 1) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
|
@@ -639,8 +654,14 @@ def pad_mask( | |||||||||||||||||||||||||||
mask: torch.Tensor, | ||||||||||||||||||||||||||||
padding: Union[int, List[int]], | ||||||||||||||||||||||||||||
padding_mode: str = "constant", | ||||||||||||||||||||||||||||
fill: Optional[Union[int, float]] = 0, | ||||||||||||||||||||||||||||
fill: Optional[Union[int, float, List[float]]] = None, | ||||||||||||||||||||||||||||
) -> torch.Tensor: | ||||||||||||||||||||||||||||
if fill is None: | ||||||||||||||||||||||||||||
fill = 0 | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
if isinstance(fill, list): | ||||||||||||||||||||||||||||
raise ValueError("Non-scalar fill value is not supported") | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
if mask.ndim < 3: | ||||||||||||||||||||||||||||
mask = mask.unsqueeze(0) | ||||||||||||||||||||||||||||
needs_squeeze = True | ||||||||||||||||||||||||||||
|
@@ -693,10 +714,9 @@ def pad( | |||||||||||||||||||||||||||
if not isinstance(padding, int): | ||||||||||||||||||||||||||||
padding = list(padding) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# TODO: PyTorch's pad supports only scalars on fill. So we need to overwrite the colour | ||||||||||||||||||||||||||||
if isinstance(fill, (int, float)) or fill is None: | ||||||||||||||||||||||||||||
return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode) | ||||||||||||||||||||||||||||
return _pad_with_vector_fill(inpt, padding, fill=fill, padding_mode=padding_mode) | ||||||||||||||||||||||||||||
fill = _convert_fill_arg(fill) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
crop_image_tensor = _FT.crop | ||||||||||||||||||||||||||||
|
@@ -739,7 +759,7 @@ def perspective_image_tensor( | |||||||||||||||||||||||||||
img: torch.Tensor, | ||||||||||||||||||||||||||||
perspective_coeffs: List[float], | ||||||||||||||||||||||||||||
interpolation: InterpolationMode = InterpolationMode.BILINEAR, | ||||||||||||||||||||||||||||
fill: Optional[List[float]] = None, | ||||||||||||||||||||||||||||
fill: Optional[Union[int, float, List[float]]] = None, | ||||||||||||||||||||||||||||
) -> torch.Tensor: | ||||||||||||||||||||||||||||
return _FT.perspective(img, perspective_coeffs, interpolation=interpolation.value, fill=fill) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
|
@@ -878,7 +898,7 @@ def elastic_image_tensor( | |||||||||||||||||||||||||||
img: torch.Tensor, | ||||||||||||||||||||||||||||
displacement: torch.Tensor, | ||||||||||||||||||||||||||||
interpolation: InterpolationMode = InterpolationMode.BILINEAR, | ||||||||||||||||||||||||||||
fill: Optional[List[float]] = None, | ||||||||||||||||||||||||||||
fill: Optional[Union[int, float, List[float]]] = None, | ||||||||||||||||||||||||||||
) -> torch.Tensor: | ||||||||||||||||||||||||||||
return _FT.elastic_transform(img, displacement, interpolation=interpolation.value, fill=fill) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know this kernel is private on stable and I agree with the change to ensure consistency. Any BC breakages that can creep upwards due to the change on the default value?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pad_image_tensor
is low-level new API so I think there is no BC to keep. In any case fill=None will be transformed to fill=0 insidevision/torchvision/transforms/functional_tensor.py
Lines 381 to 382 in c0911e3