From 4b327baee99e1df89cd894e7b3978a3d85c8a3e7 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 25 Jan 2022 15:32:11 +0100 Subject: [PATCH 01/11] readd functional transforms --- .../prototype/features/_bounding_box.py | 11 +++ torchvision/prototype/features/_encoded.py | 9 ++- torchvision/prototype/features/_feature.py | 44 ++++++++++++ torchvision/prototype/features/_image.py | 21 +++--- .../transforms/functional/__init__.py | 29 ++++++++ .../transforms/functional/_augment.py | 41 +++++++++++ .../prototype/transforms/functional/_color.py | 22 ++++++ .../transforms/functional/_geometry.py | 70 +++++++++++++++++++ .../transforms/functional/_meta_conversion.py | 69 ++++++++++++++++++ .../prototype/transforms/functional/_misc.py | 5 ++ .../transforms/functional/_type_conversion.py | 25 +++++++ .../prototype/transforms/functional/utils.py | 51 ++++++++++++++ 12 files changed, 388 insertions(+), 9 deletions(-) create mode 100644 torchvision/prototype/transforms/functional/__init__.py create mode 100644 torchvision/prototype/transforms/functional/_augment.py create mode 100644 torchvision/prototype/transforms/functional/_color.py create mode 100644 torchvision/prototype/transforms/functional/_geometry.py create mode 100644 torchvision/prototype/transforms/functional/_meta_conversion.py create mode 100644 torchvision/prototype/transforms/functional/_misc.py create mode 100644 torchvision/prototype/transforms/functional/_type_conversion.py create mode 100644 torchvision/prototype/transforms/functional/utils.py diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index 2d0685c2088..6c5dac72d53 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -36,3 +36,14 @@ def __new__( bounding_box._metadata.update(dict(format=format, image_size=image_size)) return bounding_box + + def to_format(self, format: Union[str, BoundingBoxFormat]) -> "BoundingBox": + # import at runtime to avoid cyclic imports + from torchvision.prototype.transforms.functional import convert_bounding_box_format + + if isinstance(format, str): + format = BoundingBoxFormat[format] + + return BoundingBox.new_like( + self, convert_bounding_box_format(self, old_format=self.format, new_format=format), format=format + ) diff --git a/torchvision/prototype/features/_encoded.py b/torchvision/prototype/features/_encoded.py index 9160b5e36e1..76ea907067a 100644 --- a/torchvision/prototype/features/_encoded.py +++ b/torchvision/prototype/features/_encoded.py @@ -1,12 +1,13 @@ import os import sys -from typing import BinaryIO, Tuple, Type, TypeVar, Union +from typing import BinaryIO, Tuple, Type, TypeVar, cast, Union import PIL.Image import torch from torchvision.prototype.utils._internal import fromfile, ReadOnlyTensorBuffer from ._feature import Feature +from ._image import Image D = TypeVar("D", bound="EncodedData") @@ -37,6 +38,12 @@ def image_size(self) -> Tuple[int, int]: return self._image_size + def decode(self) -> Image: + # import at runtime to avoid cyclic imports + from torchvision.prototype.transforms.functional import decode_image_with_pil + + return cast(Image, decode_image_with_pil(self)) + class EncodedVideo(EncodedData): pass diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 38fff2da04a..57e118b5755 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -10,6 +10,7 @@ class Feature(torch.Tensor): _META_ATTRS: Set[str] = set() _metadata: Dict[str, Any] + _KERNELS: Dict[Callable, Callable] def __init_subclass__(cls): # In order to help static type checkers, we require subclasses of `Feature` to add the metadata attributes @@ -37,6 +38,8 @@ def __init_subclass__(cls): for name in meta_attrs: setattr(cls, name, property(lambda self, name=name: self._metadata[name])) + cls._KERNELS = {} + def __new__(cls, data, *, dtype=None, device=None): feature = torch.Tensor._make_subclass( cast(_TensorBase, cls), @@ -57,5 +60,46 @@ def new_like(cls, other, data, *, dtype=None, device=None, **metadata): metadata.setdefault(name, getattr(other, name)) return cls(data, dtype=dtype or other.dtype, device=device or other.device, **metadata) + _TORCH_FUNCTION_ALLOW_MAP = { + torch.Tensor.clone: (0,), + torch.stack: (0, 0), + torch.Tensor.to: (0,), + } + + _DTYPE_CONVERTERS = { + torch.Tensor.to, + } + + _DEVICE_CONVERTERS = { + torch.Tensor.to, + } + + @classmethod + def __torch_function__( + cls, + func: Callable[..., torch.Tensor], + types: Tuple[Type[torch.Tensor], ...], + args: Sequence[Any] = (), + kwargs: Optional[Mapping[str, Any]] = None, + ) -> torch.Tensor: + kwargs = kwargs or dict() + if cls is not Feature and func in cls._KERNELS: + return cls._KERNELS[func](*args, **kwargs) + + with DisableTorchFunction(): + output = func(*args, **kwargs) + + if func not in cls._TORCH_FUNCTION_ALLOW_MAP: + return output + + other = args + for item in cls._TORCH_FUNCTION_ALLOW_MAP[func]: + other = other[item] + + dtype = output.dtype if func in cls._DTYPE_CONVERTERS else None + device = output.device if func in cls._DTYPE_CONVERTERS else None + + return cls.new_like(other, output, dtype=dtype, device=device) + def __repr__(self): return torch.Tensor.__repr__(self).replace("tensor", type(self).__name__) diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 93a9b517235..aee3a5f946f 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -4,8 +4,10 @@ import torch from torchvision.prototype.utils._internal import StrEnum from torchvision.transforms.functional import to_pil_image +from torchvision.utils import draw_bounding_boxes from torchvision.utils import make_grid +from ._bounding_box import BoundingBox from ._feature import Feature @@ -51,14 +53,6 @@ def _to_tensor(cls, data, *, dtype, device): tensor = tensor.unsqueeze(0) return tensor - @property - def image_size(self) -> Tuple[int, int]: - return cast(Tuple[int, int], self.shape[-2:]) - - @property - def num_channels(self) -> int: - return self.shape[-3] - @staticmethod def guess_color_space(data: torch.Tensor) -> ColorSpace: if data.ndim < 2: @@ -76,3 +70,14 @@ def guess_color_space(data: torch.Tensor) -> ColorSpace: def show(self) -> None: to_pil_image(make_grid(self.view(-1, *self.shape[-3:]))).show() + + @property + def image_size(self) -> Tuple[int, int]: + return cast(Tuple[int, int], self.shape[-2:]) + + @property + def num_channels(self) -> int: + return self.shape[-3] + + def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> "Image": + return Image.new_like(self, draw_bounding_boxes(self, bounding_box.to_format("xyxy").view(-1, 4), **kwargs)) diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py new file mode 100644 index 00000000000..be8813e3fe0 --- /dev/null +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -0,0 +1,29 @@ +from . import utils # usort: skip + +from ._augment import erase_image, mixup_image, mixup_one_hot_label, cutmix_image, cutmix_one_hot_label +from ._color import ( + adjust_brightness_image, + adjust_contrast_image, + adjust_saturation_image, + adjust_sharpness_image, + posterize_image, + solarize_image, + autocontrast_image, + equalize_image, + invert_image, +) +from ._geometry import ( + horizontal_flip_bounding_box, + horizontal_flip_image, + resize_bounding_box, + resize_image, + resize_segmentation_mask, + center_crop_image, + resized_crop_image, + InterpolationMode, + affine_image, + rotate_image, +) +from ._meta_conversion import convert_color_space, convert_bounding_box_format +from ._misc import normalize_image +from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py new file mode 100644 index 00000000000..5f4fea5b446 --- /dev/null +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -0,0 +1,41 @@ +from typing import Tuple + +import torch +from torchvision.transforms import functional as _F + +from .utils import _from_legacy_kernel + +erase_image = _from_legacy_kernel(_F.erase) + + +def _mixup(input: torch.Tensor, batch_dim: int, lam: float, inplace: bool) -> torch.Tensor: + if not inplace: + input = input.clone() + + input_rolled = input.roll(1, batch_dim) + return input.mul_(lam).add_(input_rolled.mul_(1 - lam)) + + +def mixup_image(image_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor: + return _mixup(image_batch, -4, lam, inplace) + + +def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor: + return _mixup(one_hot_label_batch, -2, lam, inplace) + + +def cutmix_image(image: torch.Tensor, *, box: Tuple[int, int, int, int], inplace: bool = False) -> torch.Tensor: + if not inplace: + image = image.clone() + + x1, y1, x2, y2 = box + image_rolled = image.roll(1, -4) + + image[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2] + return image + + +def cutmix_one_hot_label( + one_hot_label_batch: torch.Tensor, *, lam_adjusted: float, inplace: bool = False +) -> torch.Tensor: + return mixup_one_hot_label(one_hot_label_batch, lam=lam_adjusted, inplace=inplace) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py new file mode 100644 index 00000000000..c9eeb0675af --- /dev/null +++ b/torchvision/prototype/transforms/functional/_color.py @@ -0,0 +1,22 @@ +from torchvision.transforms import functional_tensor as _FT + +from .utils import _from_legacy_kernel + + +adjust_brightness_image = _from_legacy_kernel(_FT.adjust_brightness) + +adjust_saturation_image = _from_legacy_kernel(_FT.adjust_saturation) + +adjust_contrast_image = _from_legacy_kernel(_FT.adjust_contrast) + +adjust_sharpness_image = _from_legacy_kernel(_FT.adjust_sharpness) + +posterize_image = _from_legacy_kernel(_FT.posterize) + +solarize_image = _from_legacy_kernel(_FT.solarize) + +autocontrast_image = _from_legacy_kernel(_FT.autocontrast) + +equalize_image = _from_legacy_kernel(_FT.equalize) + +invert_image = _from_legacy_kernel(_FT.invert) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py new file mode 100644 index 00000000000..68f44c31dfc --- /dev/null +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -0,0 +1,70 @@ +from typing import Tuple, List, Optional + +import torch +from torchvision.prototype.features import BoundingBoxFormat +from torchvision.transforms import ( # noqa: F401 + functional as _F, + functional_tensor as _FT, + InterpolationMode, +) + +from ._meta_conversion import convert_bounding_box_format +from .utils import _from_legacy_kernel + + +def horizontal_flip_image(image: torch.Tensor) -> torch.Tensor: + return image.flip((-1,)) + + +def horizontal_flip_bounding_box(bounding_box: torch.Tensor, *, image_size: Tuple[int, int]) -> torch.Tensor: + x, y, w, h = convert_bounding_box_format( + bounding_box, + old_format=BoundingBoxFormat.XYXY, + new_format=BoundingBoxFormat.XYWH, + ).unbind(-1) + x = image_size[1] - (x + w) + return convert_bounding_box_format( + torch.stack((x, y, w, h), dim=-1), + old_format=BoundingBoxFormat.XYWH, + new_format=BoundingBoxFormat.XYXY, + ) + + +resize_image = _from_legacy_kernel(_FT.resize) + + +def resize_segmentation_mask( + segmentation_mask: torch.Tensor, + size: List[int], + interpolation: str = "nearest", + max_size: Optional[int] = None, + antialias: Optional[bool] = None, +) -> torch.Tensor: + return resize_image( + segmentation_mask, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias + ) + + +# TODO: handle max_size +def resize_bounding_box( + bounding_box: torch.Tensor, + *, + old_image_size: List[int], + new_image_size: List[int], +) -> torch.Tensor: + old_height, old_width = old_image_size + new_height, new_width = new_image_size + return ( + bounding_box.view(-1, 2, 2) + .mul(torch.tensor([new_width / old_width, new_height / old_height])) + .view(bounding_box.shape) + ) + + +center_crop_image = _from_legacy_kernel(_F.center_crop) + +resized_crop_image = _from_legacy_kernel(_F.resized_crop) + +affine_image = _from_legacy_kernel(_F.affine) + +rotate_image = _from_legacy_kernel(_F.rotate) diff --git a/torchvision/prototype/transforms/functional/_meta_conversion.py b/torchvision/prototype/transforms/functional/_meta_conversion.py new file mode 100644 index 00000000000..484066a39ee --- /dev/null +++ b/torchvision/prototype/transforms/functional/_meta_conversion.py @@ -0,0 +1,69 @@ +import torch +from torchvision.prototype.features import BoundingBoxFormat, ColorSpace +from torchvision.transforms.functional_tensor import rgb_to_grayscale as _rgb_to_grayscale + + +def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor: + xyxy = xywh.clone() + xyxy[..., 2:] += xyxy[..., :2] + return xyxy + + +def _xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor: + xywh = xyxy.clone() + xywh[..., 2:] -= xywh[..., :2] + return xywh + + +def _cxcywh_to_xyxy(cxcywh: torch.Tensor) -> torch.Tensor: + cx, cy, w, h = torch.unbind(cxcywh, dim=-1) + x1 = cx - 0.5 * w + y1 = cy - 0.5 * h + x2 = cx + 0.5 * w + y2 = cy + 0.5 * h + return torch.stack((x1, y1, x2, y2), dim=-1) + + +def _xyxy_to_cxcywh(xyxy: torch.Tensor) -> torch.Tensor: + x1, y1, x2, y2 = torch.unbind(xyxy, dim=-1) + cx = (x1 + x2) / 2 + cy = (y1 + y2) / 2 + w = x2 - x1 + h = y2 - y1 + return torch.stack((cx, cy, w, h), dim=-1) + + +def convert_bounding_box_format( + bounding_box: torch.Tensor, *, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat +) -> torch.Tensor: + if new_format == old_format: + return bounding_box + + if old_format == BoundingBoxFormat.XYWH: + bounding_box = _xywh_to_xyxy(bounding_box) + elif old_format == BoundingBoxFormat.CXCYWH: + bounding_box = _cxcywh_to_xyxy(bounding_box) + + if new_format == BoundingBoxFormat.XYWH: + bounding_box = _xyxy_to_xywh(bounding_box) + elif new_format == BoundingBoxFormat.CXCYWH: + bounding_box = _xyxy_to_cxcywh(bounding_box) + + return bounding_box + + +def _grayscale_to_rgb(grayscale: torch.Tensor) -> torch.Tensor: + return grayscale.expand(3, 1, 1) + + +def convert_color_space(image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace) -> torch.Tensor: + if new_color_space == old_color_space: + return image + + if old_color_space == ColorSpace.GRAYSCALE: + image = _grayscale_to_rgb(image) + + if new_color_space == ColorSpace.GRAYSCALE: + image = _rgb_to_grayscale(image) + + return image diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py new file mode 100644 index 00000000000..0fa3d81e190 --- /dev/null +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -0,0 +1,5 @@ +from torchvision.transforms import functional as _F + +from .utils import _from_legacy_kernel + +normalize_image = _from_legacy_kernel(_F.normalize) diff --git a/torchvision/prototype/transforms/functional/_type_conversion.py b/torchvision/prototype/transforms/functional/_type_conversion.py new file mode 100644 index 00000000000..ed355ab5eae --- /dev/null +++ b/torchvision/prototype/transforms/functional/_type_conversion.py @@ -0,0 +1,25 @@ +import unittest.mock +from typing import Dict, Any, Tuple + +import numpy as np +import PIL.Image +import torch +from torch.nn.functional import one_hot +from torchvision.io.video import read_video +from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer + + +def decode_image_with_pil(encoded_image: torch.Tensor) -> torch.Tensor: + image = torch.as_tensor(np.array(PIL.Image.open(ReadOnlyTensorBuffer(encoded_image)), copy=True)) + if image.ndim == 2: + image = image.unsqueeze(2) + return image.permute(2, 0, 1) + + +def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]: + with unittest.mock.patch("torchvision.io.video.os.path.exists", return_value=True): + return read_video(ReadOnlyTensorBuffer(encoded_video)) # type: ignore[arg-type] + + +def label_to_one_hot(label: torch.Tensor, *, num_categories: int) -> torch.Tensor: + return one_hot(label, num_classes=num_categories) diff --git a/torchvision/prototype/transforms/functional/utils.py b/torchvision/prototype/transforms/functional/utils.py new file mode 100644 index 00000000000..3af6e841be2 --- /dev/null +++ b/torchvision/prototype/transforms/functional/utils.py @@ -0,0 +1,51 @@ +import types +from typing import Any, Type + +import torch +import torch.overrides +from torchvision.prototype import features + + +def is_supported(obj: Any, *types: Type) -> bool: + return (obj if isinstance(obj, type) else type(obj)) in types + + +class Dispatcher: + def __init__(self, dispatch_fn): + self._dispatch_fn = dispatch_fn + self._support = set() + + def supports(self, obj: Any) -> bool: + return is_supported(obj, *self._support) + + def implements(self, feature_type): + def wrapper(implement_fn): + feature_type._KERNELS[self._dispatch_fn] = implement_fn + self._support.add(feature_type) + return implement_fn + + return wrapper + + def __call__(self, input, *args, **kwargs): + if not isinstance(input, torch.Tensor): + raise ValueError("No tensor") + + if not (issubclass(type(input), features.Feature)): + input = features.Image(input) + + if not self.supports(input): + raise ValueError(f"No support for {type(input).__name__}") + + return torch.overrides.handle_torch_function(self._dispatch_fn, (input,), input, *args, **kwargs) + + +def _from_legacy_kernel(legacy_kernel, new_name=None): + kernel = types.FunctionType( + code=legacy_kernel.__code__, + globals=legacy_kernel.__globals__, + name=new_name or f"{legacy_kernel.__name__}_image", + argdefs=legacy_kernel.__defaults__, + closure=legacy_kernel.__closure__, + ) + kernel.__annotations__ = legacy_kernel.__annotations__ + return kernel From eaac4f229f9cc9da1ca319edc3911bca4c1e7025 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 27 Jan 2022 09:04:51 +0100 Subject: [PATCH 02/11] cleanup --- torchvision/prototype/features/_encoded.py | 4 +- torchvision/prototype/features/_image.py | 16 ++++---- .../prototype/transforms/functional/utils.py | 38 ------------------- 3 files changed, 10 insertions(+), 48 deletions(-) diff --git a/torchvision/prototype/features/_encoded.py b/torchvision/prototype/features/_encoded.py index 76ea907067a..338b2d2230d 100644 --- a/torchvision/prototype/features/_encoded.py +++ b/torchvision/prototype/features/_encoded.py @@ -1,6 +1,6 @@ import os import sys -from typing import BinaryIO, Tuple, Type, TypeVar, cast, Union +from typing import BinaryIO, Tuple, Type, TypeVar, Union import PIL.Image import torch @@ -42,7 +42,7 @@ def decode(self) -> Image: # import at runtime to avoid cyclic imports from torchvision.prototype.transforms.functional import decode_image_with_pil - return cast(Image, decode_image_with_pil(self)) + return Image(decode_image_with_pil(self)) class EncodedVideo(EncodedData): diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index aee3a5f946f..a07da277314 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -53,6 +53,14 @@ def _to_tensor(cls, data, *, dtype, device): tensor = tensor.unsqueeze(0) return tensor + @property + def image_size(self) -> Tuple[int, int]: + return cast(Tuple[int, int], self.shape[-2:]) + + @property + def num_channels(self) -> int: + return self.shape[-3] + @staticmethod def guess_color_space(data: torch.Tensor) -> ColorSpace: if data.ndim < 2: @@ -71,13 +79,5 @@ def guess_color_space(data: torch.Tensor) -> ColorSpace: def show(self) -> None: to_pil_image(make_grid(self.view(-1, *self.shape[-3:]))).show() - @property - def image_size(self) -> Tuple[int, int]: - return cast(Tuple[int, int], self.shape[-2:]) - - @property - def num_channels(self) -> int: - return self.shape[-3] - def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> "Image": return Image.new_like(self, draw_bounding_boxes(self, bounding_box.to_format("xyxy").view(-1, 4), **kwargs)) diff --git a/torchvision/prototype/transforms/functional/utils.py b/torchvision/prototype/transforms/functional/utils.py index 3af6e841be2..495e9efa502 100644 --- a/torchvision/prototype/transforms/functional/utils.py +++ b/torchvision/prototype/transforms/functional/utils.py @@ -1,42 +1,4 @@ import types -from typing import Any, Type - -import torch -import torch.overrides -from torchvision.prototype import features - - -def is_supported(obj: Any, *types: Type) -> bool: - return (obj if isinstance(obj, type) else type(obj)) in types - - -class Dispatcher: - def __init__(self, dispatch_fn): - self._dispatch_fn = dispatch_fn - self._support = set() - - def supports(self, obj: Any) -> bool: - return is_supported(obj, *self._support) - - def implements(self, feature_type): - def wrapper(implement_fn): - feature_type._KERNELS[self._dispatch_fn] = implement_fn - self._support.add(feature_type) - return implement_fn - - return wrapper - - def __call__(self, input, *args, **kwargs): - if not isinstance(input, torch.Tensor): - raise ValueError("No tensor") - - if not (issubclass(type(input), features.Feature)): - input = features.Image(input) - - if not self.supports(input): - raise ValueError(f"No support for {type(input).__name__}") - - return torch.overrides.handle_torch_function(self._dispatch_fn, (input,), input, *args, **kwargs) def _from_legacy_kernel(legacy_kernel, new_name=None): From 079a00548e658f7cbd4d59556d5a06a851eff2dd Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 27 Jan 2022 09:07:16 +0100 Subject: [PATCH 03/11] add missing imports --- torchvision/prototype/features/_feature.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 57e118b5755..59387cdb16b 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -1,7 +1,7 @@ -from typing import Any, cast, Dict, Set, TypeVar +from typing import Any, cast, Dict, Set, TypeVar, Callable, Tuple, Type, Sequence, Optional, Mapping import torch -from torch._C import _TensorBase +from torch._C import _TensorBase, DisableTorchFunction F = TypeVar("F", bound="Feature") From c78f6f7f842ee1cc7fca9466aa28b498670e0624 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 27 Jan 2022 09:09:43 +0100 Subject: [PATCH 04/11] remove __torch_function__ dispatch --- torchvision/prototype/features/_feature.py | 52 +--------------------- 1 file changed, 2 insertions(+), 50 deletions(-) diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 59387cdb16b..77648679250 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -1,8 +1,7 @@ -from typing import Any, cast, Dict, Set, TypeVar, Callable, Tuple, Type, Sequence, Optional, Mapping +from typing import Any, cast, Dict, Set, TypeVar import torch -from torch._C import _TensorBase, DisableTorchFunction - +from torch._C import _TensorBase F = TypeVar("F", bound="Feature") @@ -10,7 +9,6 @@ class Feature(torch.Tensor): _META_ATTRS: Set[str] = set() _metadata: Dict[str, Any] - _KERNELS: Dict[Callable, Callable] def __init_subclass__(cls): # In order to help static type checkers, we require subclasses of `Feature` to add the metadata attributes @@ -38,8 +36,6 @@ def __init_subclass__(cls): for name in meta_attrs: setattr(cls, name, property(lambda self, name=name: self._metadata[name])) - cls._KERNELS = {} - def __new__(cls, data, *, dtype=None, device=None): feature = torch.Tensor._make_subclass( cast(_TensorBase, cls), @@ -59,47 +55,3 @@ def new_like(cls, other, data, *, dtype=None, device=None, **metadata): for name in cls._META_ATTRS: metadata.setdefault(name, getattr(other, name)) return cls(data, dtype=dtype or other.dtype, device=device or other.device, **metadata) - - _TORCH_FUNCTION_ALLOW_MAP = { - torch.Tensor.clone: (0,), - torch.stack: (0, 0), - torch.Tensor.to: (0,), - } - - _DTYPE_CONVERTERS = { - torch.Tensor.to, - } - - _DEVICE_CONVERTERS = { - torch.Tensor.to, - } - - @classmethod - def __torch_function__( - cls, - func: Callable[..., torch.Tensor], - types: Tuple[Type[torch.Tensor], ...], - args: Sequence[Any] = (), - kwargs: Optional[Mapping[str, Any]] = None, - ) -> torch.Tensor: - kwargs = kwargs or dict() - if cls is not Feature and func in cls._KERNELS: - return cls._KERNELS[func](*args, **kwargs) - - with DisableTorchFunction(): - output = func(*args, **kwargs) - - if func not in cls._TORCH_FUNCTION_ALLOW_MAP: - return output - - other = args - for item in cls._TORCH_FUNCTION_ALLOW_MAP[func]: - other = other[item] - - dtype = output.dtype if func in cls._DTYPE_CONVERTERS else None - device = output.device if func in cls._DTYPE_CONVERTERS else None - - return cls.new_like(other, output, dtype=dtype, device=device) - - def __repr__(self): - return torch.Tensor.__repr__(self).replace("tensor", type(self).__name__) From 45ce1f90924c036de22cabf97860d9697838e27a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 27 Jan 2022 09:10:51 +0100 Subject: [PATCH 05/11] readd repr --- torchvision/prototype/features/_feature.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 77648679250..76b6a8f2d87 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -55,3 +55,6 @@ def new_like(cls, other, data, *, dtype=None, device=None, **metadata): for name in cls._META_ATTRS: metadata.setdefault(name, getattr(other, name)) return cls(data, dtype=dtype or other.dtype, device=device or other.device, **metadata) + + def __repr__(self): + return torch.Tensor.__repr__(self).replace("tensor", type(self).__name__) From 0b81a23f7359b846cf7b8cdff04029267e6c10c9 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 27 Jan 2022 09:11:19 +0100 Subject: [PATCH 06/11] readd empty line --- torchvision/prototype/features/_feature.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 76b6a8f2d87..38fff2da04a 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -3,6 +3,7 @@ import torch from torch._C import _TensorBase + F = TypeVar("F", bound="Feature") From 265b3d45626cff3322db6b9c7e75f25b87a0a77b Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 27 Jan 2022 12:02:03 +0100 Subject: [PATCH 07/11] add test for scriptability --- test/test_prototype_transforms_functional.py | 197 ++++++++++++++++++ .../transforms/functional/_geometry.py | 20 +- 2 files changed, 216 insertions(+), 1 deletion(-) create mode 100644 test/test_prototype_transforms_functional.py diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py new file mode 100644 index 00000000000..0627ed14d6e --- /dev/null +++ b/test/test_prototype_transforms_functional.py @@ -0,0 +1,197 @@ +import functools +import itertools + +import pytest +import torch.testing +import torchvision.prototype.transforms.functional as F +from torch import jit +from torchvision.prototype import features + +make_tensor = functools.partial(torch.testing.make_tensor, device="cpu") + + +def make_image(size=None, *, color_space, extra_dims=(), dtype=torch.float32): + size = size or torch.randint(16, 33, (2,)).tolist() + + if isinstance(color_space, str): + color_space = features.ColorSpace[color_space] + num_channels = { + features.ColorSpace.GRAYSCALE: 1, + features.ColorSpace.RGB: 3, + }[color_space] + + shape = (*extra_dims, num_channels, *size) + if dtype.is_floating_point: + data = torch.rand(shape, dtype=dtype) + else: + data = torch.randint(0, torch.iinfo(dtype).max, shape, dtype=dtype) + return features.Image(data, color_space=color_space) + + +make_grayscale_image = functools.partial(make_image, color_space=features.ColorSpace.GRAYSCALE) +make_rgb_image = functools.partial(make_image, color_space=features.ColorSpace.RGB) + + +def make_images( + sizes=((16, 16), (7, 33), (31, 9)), + color_spaces=(features.ColorSpace.GRAYSCALE, features.ColorSpace.RGB), + dtypes=(torch.float32, torch.uint8), + extra_dims=((4,), (2, 3)), +): + for size, color_space, dtype in itertools.product(sizes, color_spaces, dtypes): + yield make_image(size, color_space=color_space) + + for color_space, extra_dims_ in itertools.product(color_spaces, extra_dims): + yield make_image(color_space=color_space, extra_dims=extra_dims_) + + +def randint_with_tensor_bounds(arg1, arg2=None, **kwargs): + low, high = torch.broadcast_tensors( + *[torch.as_tensor(arg) for arg in ((0, arg1) if arg2 is None else (arg1, arg2))] + ) + try: + return torch.stack( + [ + torch.randint(low_scalar, high_scalar, (), **kwargs) + for low_scalar, high_scalar in zip(low.flatten().tolist(), high.flatten().tolist()) + ] + ).reshape(low.shape) + except RuntimeError as error: + raise error + + +def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch.int64): + if isinstance(format, str): + format = features.BoundingBoxFormat[format] + + height, width = image_size + + if format == features.BoundingBoxFormat.XYXY: + x1 = torch.randint(0, width // 2, extra_dims) + y1 = torch.randint(0, height // 2, extra_dims) + x2 = randint_with_tensor_bounds(x1 + 1, width - x1) + x1 + y2 = randint_with_tensor_bounds(y1 + 1, height - y1) + y1 + parts = (x1, y1, x2, y2) + elif format == features.BoundingBoxFormat.XYWH: + x = torch.randint(0, width // 2, extra_dims) + y = torch.randint(0, height // 2, extra_dims) + w = randint_with_tensor_bounds(1, width - x) + h = randint_with_tensor_bounds(1, height - y) + parts = (x, y, w, h) + elif format == features.BoundingBoxFormat.CXCYWH: + cx = torch.randint(1, width - 1, ()) + cy = torch.randint(1, height - 1, ()) + w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1) + h = randint_with_tensor_bounds(1, torch.minimum(cy, width - cy) + 1) + parts = (cx, cy, w, h) + else: # format == features.BoundingBoxFormat._SENTINEL: + raise ValueError() + + return features.BoundingBox(torch.stack(parts, dim=-1).to(dtype), format=format, image_size=image_size) + + +make_xyxy_bounding_box = functools.partial(make_bounding_box, format=features.BoundingBoxFormat.XYXY) + + +def make_bounding_boxes( + formats=(features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH), + image_sizes=((32, 32),), + dtypes=(torch.int64, torch.float32), + extra_dims=((4,), (2, 3)), +): + for format, image_size, dtype in itertools.product(formats, image_sizes, dtypes): + yield make_bounding_box(format=format, image_size=image_size, dtype=dtype) + + for format, extra_dims_ in itertools.product(formats, extra_dims): + yield make_bounding_box(format=format, extra_dims=extra_dims_) + + +class SampleInput: + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + +class KernelInfo: + def __init__(self, name, *, sample_inputs_fn): + self.name = name + self.kernel = getattr(F, name) + self._sample_inputs_fn = sample_inputs_fn + + def sample_inputs(self): + yield from self._sample_inputs_fn() + + def __call__(self, *args, **kwargs): + if len(args) == 1 and not kwargs and isinstance(args[0], SampleInput): + sample_input = args[0] + return self.kernel(*sample_input.args, **sample_input.kwargs) + + return self.kernel(*args, **kwargs) + + +KERNEL_INFOS = [] + + +def register_kernel_info_from_sample_inputs_fn(sample_inputs_fn): + KERNEL_INFOS.append(KernelInfo(sample_inputs_fn.__name__, sample_inputs_fn=sample_inputs_fn)) + return sample_inputs_fn + + +@register_kernel_info_from_sample_inputs_fn +def horizontal_flip_image(): + for image in make_images(): + yield SampleInput(image) + + +@register_kernel_info_from_sample_inputs_fn +def horizontal_flip_bounding_box(): + for bounding_box in make_bounding_boxes(formats=[features.BoundingBoxFormat.XYXY]): + yield SampleInput(bounding_box, image_size=bounding_box.image_size) + + +@register_kernel_info_from_sample_inputs_fn +def resize_image(): + for image, interpolation in itertools.product( + make_images(), + [ + "bilinear", + "nearest", + ], + ): + height, width = image.shape[-2:] + for size in [ + (height, width), + (int(height * 0.75), int(width * 1.25)), + ]: + yield SampleInput(image, size=size, interpolation=interpolation) + + +@register_kernel_info_from_sample_inputs_fn +def resize_bounding_box(): + for bounding_box in make_bounding_boxes(): + height, width = bounding_box.image_size + for new_image_size in [ + (height, width), + (int(height * 0.75), int(width * 1.25)), + ]: + yield SampleInput(bounding_box, old_image_size=bounding_box.image_size, new_image_size=new_image_size) + + +class TestKernelsCommon: + @pytest.mark.parametrize("kernel_info", KERNEL_INFOS, ids=lambda kernel_info: kernel_info.name) + def test_scriptable(self, kernel_info): + jit.script(kernel_info.kernel) + + @pytest.mark.parametrize( + ("kernel_info", "sample_input"), + [ + pytest.param(kernel_info, sample_input, id=f"{kernel_info.name}-{idx}") + for kernel_info in KERNEL_INFOS + for idx, sample_input in enumerate(kernel_info.sample_inputs()) + ], + ) + def test_eager_vs_scripted(self, kernel_info, sample_input): + eager = kernel_info(sample_input) + scripted = jit.script(kernel_info.kernel)(*sample_input.args, **sample_input.kwargs) + + torch.testing.assert_close(eager, scripted) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 68f44c31dfc..44d170fbbd0 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -30,7 +30,25 @@ def horizontal_flip_bounding_box(bounding_box: torch.Tensor, *, image_size: Tupl ) -resize_image = _from_legacy_kernel(_FT.resize) +_resize_image = _from_legacy_kernel(_FT.resize) + + +def resize_image( + image: torch.Tensor, + size: List[int], + interpolation: str = "bilinear", + max_size: Optional[int] = None, + antialias: Optional[bool] = None, +) -> torch.Tensor: + num_channels, old_height, old_width = image.shape[-3:] + batch_shape = image.shape[:-3] + return _resize_image( + image.reshape(-1, num_channels, old_height, old_width), + size=size, + interpolation=interpolation, + max_size=max_size, + antialias=antialias, + ).reshape(batch_shape + (num_channels,) + size) def resize_segmentation_mask( From 2a7983ee40681a46713bc711669bcbdd79f5931d Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 27 Jan 2022 15:07:14 +0100 Subject: [PATCH 08/11] remove function copy --- .../transforms/functional/_augment.py | 3 +-- .../prototype/transforms/functional/_color.py | 20 +++++++++---------- .../transforms/functional/_geometry.py | 11 +++++----- .../prototype/transforms/functional/_misc.py | 3 +-- .../prototype/transforms/functional/utils.py | 13 ------------ 5 files changed, 16 insertions(+), 34 deletions(-) delete mode 100644 torchvision/prototype/transforms/functional/utils.py diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index 5f4fea5b446..814c34e5b00 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -3,9 +3,8 @@ import torch from torchvision.transforms import functional as _F -from .utils import _from_legacy_kernel -erase_image = _from_legacy_kernel(_F.erase) +erase_image = _F.erase def _mixup(input: torch.Tensor, batch_dim: int, lam: float, inplace: bool) -> torch.Tensor: diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index c9eeb0675af..6dccbc6442e 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -1,22 +1,20 @@ from torchvision.transforms import functional_tensor as _FT -from .utils import _from_legacy_kernel +adjust_brightness_image = _FT.adjust_brightness -adjust_brightness_image = _from_legacy_kernel(_FT.adjust_brightness) +adjust_saturation_image = _FT.adjust_saturation -adjust_saturation_image = _from_legacy_kernel(_FT.adjust_saturation) +adjust_contrast_image = _FT.adjust_contrast -adjust_contrast_image = _from_legacy_kernel(_FT.adjust_contrast) +adjust_sharpness_image = _FT.adjust_sharpness -adjust_sharpness_image = _from_legacy_kernel(_FT.adjust_sharpness) +posterize_image = _FT.posterize -posterize_image = _from_legacy_kernel(_FT.posterize) +solarize_image = _FT.solarize -solarize_image = _from_legacy_kernel(_FT.solarize) +autocontrast_image = _FT.autocontrast -autocontrast_image = _from_legacy_kernel(_FT.autocontrast) +equalize_image = _FT.equalize -equalize_image = _from_legacy_kernel(_FT.equalize) - -invert_image = _from_legacy_kernel(_FT.invert) +invert_image = _FT.invert diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 44d170fbbd0..93811e888a7 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -9,7 +9,6 @@ ) from ._meta_conversion import convert_bounding_box_format -from .utils import _from_legacy_kernel def horizontal_flip_image(image: torch.Tensor) -> torch.Tensor: @@ -30,7 +29,7 @@ def horizontal_flip_bounding_box(bounding_box: torch.Tensor, *, image_size: Tupl ) -_resize_image = _from_legacy_kernel(_FT.resize) +_resize_image = _FT.resize def resize_image( @@ -79,10 +78,10 @@ def resize_bounding_box( ) -center_crop_image = _from_legacy_kernel(_F.center_crop) +center_crop_image = _F.center_crop -resized_crop_image = _from_legacy_kernel(_F.resized_crop) +resized_crop_image = _F.resized_crop -affine_image = _from_legacy_kernel(_F.affine) +affine_image = _F.affine -rotate_image = _from_legacy_kernel(_F.rotate) +rotate_image = _F.rotate diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 0fa3d81e190..de148ab194a 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -1,5 +1,4 @@ from torchvision.transforms import functional as _F -from .utils import _from_legacy_kernel -normalize_image = _from_legacy_kernel(_F.normalize) +normalize_image = _F.normalize diff --git a/torchvision/prototype/transforms/functional/utils.py b/torchvision/prototype/transforms/functional/utils.py deleted file mode 100644 index 495e9efa502..00000000000 --- a/torchvision/prototype/transforms/functional/utils.py +++ /dev/null @@ -1,13 +0,0 @@ -import types - - -def _from_legacy_kernel(legacy_kernel, new_name=None): - kernel = types.FunctionType( - code=legacy_kernel.__code__, - globals=legacy_kernel.__globals__, - name=new_name or f"{legacy_kernel.__name__}_image", - argdefs=legacy_kernel.__defaults__, - closure=legacy_kernel.__closure__, - ) - kernel.__annotations__ = legacy_kernel.__annotations__ - return kernel From 1392e786a074b1ac7ac5f2628212251a44807e72 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 27 Jan 2022 15:08:22 +0100 Subject: [PATCH 09/11] change import from functional tensor transforms to just functional --- .../prototype/transforms/functional/_color.py | 20 +++++++++---------- .../transforms/functional/_geometry.py | 1 - 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 6dccbc6442e..f2529166d9a 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -1,20 +1,20 @@ -from torchvision.transforms import functional_tensor as _FT +from torchvision.transforms import functional as _F -adjust_brightness_image = _FT.adjust_brightness +adjust_brightness_image = _F.adjust_brightness -adjust_saturation_image = _FT.adjust_saturation +adjust_saturation_image = _F.adjust_saturation -adjust_contrast_image = _FT.adjust_contrast +adjust_contrast_image = _F.adjust_contrast -adjust_sharpness_image = _FT.adjust_sharpness +adjust_sharpness_image = _F.adjust_sharpness -posterize_image = _FT.posterize +posterize_image = _F.posterize -solarize_image = _FT.solarize +solarize_image = _F.solarize -autocontrast_image = _FT.autocontrast +autocontrast_image = _F.autocontrast -equalize_image = _FT.equalize +equalize_image = _F.equalize -invert_image = _FT.invert +invert_image = _F.invert diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 93811e888a7..27764cf9fb6 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -4,7 +4,6 @@ from torchvision.prototype.features import BoundingBoxFormat from torchvision.transforms import ( # noqa: F401 functional as _F, - functional_tensor as _FT, InterpolationMode, ) From 249cfb3b023edc695c01fc6ebf89a8b384d8dc6f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 27 Jan 2022 15:59:32 +0100 Subject: [PATCH 10/11] fix import --- torchvision/prototype/transforms/functional/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index be8813e3fe0..087f2fb2ac0 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -1,5 +1,3 @@ -from . import utils # usort: skip - from ._augment import erase_image, mixup_image, mixup_one_hot_label, cutmix_image, cutmix_one_hot_label from ._color import ( adjust_brightness_image, From a77bf1f0d095a8cb67e9fb222b1dc641f76c0c3e Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 27 Jan 2022 17:46:29 +0100 Subject: [PATCH 11/11] fix test --- test/test_prototype_transforms_functional.py | 4 ++-- torchvision/prototype/transforms/__init__.py | 4 +++- .../prototype/transforms/functional/_geometry.py | 11 ++++++----- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 0627ed14d6e..53776e1a8a4 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -154,8 +154,8 @@ def resize_image(): for image, interpolation in itertools.product( make_images(), [ - "bilinear", - "nearest", + F.InterpolationMode.BILINEAR, + F.InterpolationMode.NEAREST, ], ): height, width = image.shape[-2:] diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 56cca7b0402..963bdebc7ed 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -1,6 +1,8 @@ +from . import functional +from .functional import InterpolationMode # usort: skip + from ._transform import Transform from ._container import Compose, RandomApply, RandomChoice, RandomOrder # usort: skip - from ._geometry import Resize, RandomResize, HorizontalFlip, Crop, CenterCrop, RandomCrop from ._misc import Identity, Normalize from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 27764cf9fb6..c8142742fa8 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -28,31 +28,32 @@ def horizontal_flip_bounding_box(bounding_box: torch.Tensor, *, image_size: Tupl ) -_resize_image = _FT.resize +_resize_image = _F.resize def resize_image( image: torch.Tensor, size: List[int], - interpolation: str = "bilinear", + interpolation: InterpolationMode = InterpolationMode.BILINEAR, max_size: Optional[int] = None, antialias: Optional[bool] = None, ) -> torch.Tensor: + new_height, new_width = size num_channels, old_height, old_width = image.shape[-3:] batch_shape = image.shape[:-3] return _resize_image( - image.reshape(-1, num_channels, old_height, old_width), + image.reshape((-1, num_channels, old_height, old_width)), size=size, interpolation=interpolation, max_size=max_size, antialias=antialias, - ).reshape(batch_shape + (num_channels,) + size) + ).reshape(batch_shape + (num_channels, new_height, new_width)) def resize_segmentation_mask( segmentation_mask: torch.Tensor, size: List[int], - interpolation: str = "nearest", + interpolation: InterpolationMode = InterpolationMode.NEAREST, max_size: Optional[int] = None, antialias: Optional[bool] = None, ) -> torch.Tensor: