diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index d3e1d94c96c..95f620f68b7 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -4,6 +4,7 @@ import PIL.Image import torch +from torch.nn.functional import interpolate from torchvision.prototype import features from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT from torchvision.transforms.functional import ( @@ -115,6 +116,12 @@ def resize_image_tensor( max_size: Optional[int] = None, antialias: bool = False, ) -> torch.Tensor: + align_corners: Optional[bool] = None + if interpolation == InterpolationMode.BILINEAR or interpolation == InterpolationMode.BICUBIC: + align_corners = False + elif antialias: + raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only") + shape = image.shape num_channels, old_height, old_width = shape[-3:] new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size) @@ -122,13 +129,24 @@ def resize_image_tensor( if image.numel() > 0: image = image.reshape(-1, num_channels, old_height, old_width) - image = _FT.resize( + dtype = image.dtype + need_cast = dtype not in (torch.float32, torch.float64) + if need_cast: + image = image.to(dtype=torch.float32) + + image = interpolate( image, size=[new_height, new_width], - interpolation=interpolation.value, + mode=interpolation.value, + align_corners=align_corners, antialias=antialias, ) + if need_cast: + if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8: + image = image.clamp_(min=0, max=255) + image = image.round_().to(dtype=dtype) + return image.reshape(shape[:-3] + (num_channels, new_height, new_width)) @@ -1312,9 +1330,11 @@ def resized_crop( def _parse_five_crop_size(size: List[int]) -> List[int]: if isinstance(size, numbers.Number): - size = [int(size), int(size)] + s = int(size) + size = [s, s] elif isinstance(size, (tuple, list)) and len(size) == 1: - size = [size[0], size[0]] + s = size[0] + size = [s, s] if len(size) != 2: raise ValueError("Please provide only two dimensions (h, w) for size.")