diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index 4334d157e40..fab4cc0ddd6 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -389,7 +389,7 @@ def test__transform(self, padding, fill, padding_mode, mocker):
         inpt = mocker.MagicMock(spec=features.Image)
         _ = transform(inpt)
 
-        fill = transforms.functional._geometry._convert_fill_arg(fill)
+        fill = transforms._utils._convert_fill_arg(fill)
         if isinstance(padding, tuple):
             padding = list(padding)
         fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
@@ -405,14 +405,14 @@ def test__transform_image_mask(self, fill, mocker):
         _ = transform(inpt)
 
         if isinstance(fill, int):
-            fill = transforms.functional._geometry._convert_fill_arg(fill)
+            fill = transforms._utils._convert_fill_arg(fill)
             calls = [
                 mocker.call(image, padding=1, fill=fill, padding_mode="constant"),
                 mocker.call(mask, padding=1, fill=fill, padding_mode="constant"),
             ]
         else:
-            fill_img = transforms.functional._geometry._convert_fill_arg(fill[type(image)])
-            fill_mask = transforms.functional._geometry._convert_fill_arg(fill[type(mask)])
+            fill_img = transforms._utils._convert_fill_arg(fill[type(image)])
+            fill_mask = transforms._utils._convert_fill_arg(fill[type(mask)])
             calls = [
                 mocker.call(image, padding=1, fill=fill_img, padding_mode="constant"),
                 mocker.call(mask, padding=1, fill=fill_mask, padding_mode="constant"),
@@ -466,7 +466,7 @@ def test__transform(self, fill, side_range, mocker):
         torch.rand(1)  # random apply changes random state
         params = transform._get_params([inpt])
 
-        fill = transforms.functional._geometry._convert_fill_arg(fill)
+        fill = transforms._utils._convert_fill_arg(fill)
         fn.assert_called_once_with(inpt, **params, fill=fill)
 
     @pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}])
@@ -485,14 +485,14 @@ def test__transform_image_mask(self, fill, mocker):
         params = transform._get_params(inpt)
 
         if isinstance(fill, int):
-            fill = transforms.functional._geometry._convert_fill_arg(fill)
+            fill = transforms._utils._convert_fill_arg(fill)
             calls = [
                 mocker.call(image, **params, fill=fill),
                 mocker.call(mask, **params, fill=fill),
             ]
         else:
-            fill_img = transforms.functional._geometry._convert_fill_arg(fill[type(image)])
-            fill_mask = transforms.functional._geometry._convert_fill_arg(fill[type(mask)])
+            fill_img = transforms._utils._convert_fill_arg(fill[type(image)])
+            fill_mask = transforms._utils._convert_fill_arg(fill[type(mask)])
             calls = [
                 mocker.call(image, **params, fill=fill_img),
                 mocker.call(mask, **params, fill=fill_mask),
@@ -556,7 +556,7 @@ def test__transform(self, degrees, expand, fill, center, mocker):
         torch.manual_seed(12)
         params = transform._get_params(inpt)
 
-        fill = transforms.functional._geometry._convert_fill_arg(fill)
+        fill = transforms._utils._convert_fill_arg(fill)
         fn.assert_called_once_with(inpt, **params, interpolation=interpolation, expand=expand, fill=fill, center=center)
 
     @pytest.mark.parametrize("angle", [34, -87])
@@ -694,7 +694,7 @@ def test__transform(self, degrees, translate, scale, shear, fill, center, mocker
         torch.manual_seed(12)
         params = transform._get_params([inpt])
 
-        fill = transforms.functional._geometry._convert_fill_arg(fill)
+        fill = transforms._utils._convert_fill_arg(fill)
 
         fn.assert_called_once_with(inpt, **params, interpolation=interpolation, fill=fill, center=center)
 
@@ -939,7 +939,7 @@ def test__transform(self, distortion_scale, mocker):
         torch.rand(1)  # random apply changes random state
         params = transform._get_params([inpt])
 
-        fill = transforms.functional._geometry._convert_fill_arg(fill)
+        fill = transforms._utils._convert_fill_arg(fill)
         fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
 
@@ -1009,7 +1009,7 @@ def test__transform(self, alpha, sigma, mocker):
         transform._get_params = mocker.MagicMock()
         _ = transform(inpt)
         params = transform._get_params([inpt])
 
-        fill = transforms.functional._geometry._convert_fill_arg(fill)
+        fill = transforms._utils._convert_fill_arg(fill)
         fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
 
@@ -1632,7 +1632,7 @@ def test__transform(self, mocker, needs):
         if not needs_crop:
             assert args[0] is inpt_sentinel
             assert args[1] is padding_sentinel
-            fill_sentinel = transforms.functional._geometry._convert_fill_arg(fill_sentinel)
+            fill_sentinel = transforms._utils._convert_fill_arg(fill_sentinel)
             assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel)
         else:
             mock_pad.assert_not_called()
diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py
index b0022baaa37..a23783b0037 100644
--- a/test/test_prototype_transforms_consistency.py
+++ b/test/test_prototype_transforms_consistency.py
@@ -983,8 +983,6 @@ def _transform(self, inpt, params):
             return inpt
 
         fill = self.fill[type(inpt)]
-        fill = F._geometry._convert_fill_arg(fill)
-
         return F.pad(inpt, padding=params["padding"], fill=fill)
 
diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py
index 56d581eff9e..3714fc13682 100644
--- a/torchvision/prototype/transforms/_auto_augment.py
+++ b/torchvision/prototype/transforms/_auto_augment.py
@@ -1,5 +1,5 @@
 import math
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 
 import PIL.Image
 import torch
@@ -11,9 +11,6 @@
 
 from ._utils import _isinstance, _setup_fill_arg
 
-K = TypeVar("K")
-V = TypeVar("V")
-
 
 class _AutoAugmentBase(Transform):
     def __init__(
@@ -26,7 +23,7 @@ def __init__(
         self.interpolation = interpolation
         self.fill = _setup_fill_arg(fill)
 
-    def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
+    def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
         keys = tuple(dct.keys())
         key = keys[int(torch.randint(len(keys), ()))]
         return key, dct[key]
@@ -71,10 +68,9 @@ def _apply_image_or_video_transform(
         transform_id: str,
         magnitude: float,
         interpolation: InterpolationMode,
-        fill: Dict[Type, features.FillType],
+        fill: Dict[Type, features.FillTypeJIT],
     ) -> Union[features.ImageType, features.VideoType]:
         fill_ = fill[type(image)]
-        fill_ = F._geometry._convert_fill_arg(fill_)
 
         if transform_id == "Identity":
             return image
@@ -170,9 +166,7 @@ class AutoAugment(_AutoAugmentBase):
         "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
         "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
         "Posterize": (
-            lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
-            .round()
-            .int(),
+            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
             False,
         ),
         "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
@@ -327,9 +321,7 @@ class RandAugment(_AutoAugmentBase):
         "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
         "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
         "Posterize": (
-            lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
-            .round()
-            .int(),
+            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
             False,
         ),
         "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
@@ -383,9 +375,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
         "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
         "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
         "Posterize": (
-            lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)))
-            .round()
-            .int(),
+            lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))).round().int(),
             False,
         ),
         "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
@@ -430,9 +420,7 @@ class AugMix(_AutoAugmentBase):
         "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True),
         "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
         "Posterize": (
-            lambda num_bins, height, width: cast(torch.Tensor, 4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
-            .round()
-            .int(),
+            lambda num_bins, height, width: (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
             False,
         ),
         "Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
@@ -517,7 +505,13 @@ def forward(self, *inputs: Any) -> Any:
                 aug = self._apply_image_or_video_transform(
                     aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
                 )
-            mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
+            mix.add_(
+                # The multiplication below could become in-place provided `aug is not batch and aug.is_floating_point()`
+                # Currently we can't do this because `aug` has to be `uint8` to support ops like `equalize`.
+                # TODO: change this once all ops in `F` support floats. https://github.com/pytorch/vision/issues/6840
+                combined_weights[:, i].reshape(batch_dims)
+                * aug
+            )
         mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)
 
         if isinstance(orig_image_or_video, (features.Image, features.Video)):
diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py
index 3647365c3fb..0dcf636c3db 100644
--- a/torchvision/prototype/transforms/_color.py
+++ b/torchvision/prototype/transforms/_color.py
@@ -51,7 +51,7 @@ def _check_input(
 
     @staticmethod
     def _generate_value(left: float, right: float) -> float:
-        return float(torch.distributions.Uniform(left, right).sample())
+        return torch.empty(1).uniform_(left, right).item()
 
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         fn_idx = torch.randperm(4)
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index 440e23ab631..c5ab38d8418 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -223,20 +223,16 @@ def __init__(
         _check_padding_arg(padding)
         _check_padding_mode_arg(padding_mode)
 
+        # This cast does Sequence[int] -> List[int] and is required to make mypy happy
+        if not isinstance(padding, int):
+            padding = list(padding)
         self.padding = padding
         self.fill = _setup_fill_arg(fill)
         self.padding_mode = padding_mode
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = self.fill[type(inpt)]
-
-        # This cast does Sequence[int] -> List[int] and is required to make mypy happy
-        padding = self.padding
-        if not isinstance(padding, int):
-            padding = list(padding)
-
-        fill = F._geometry._convert_fill_arg(fill)
-        return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)
+        return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)
 
 
 class RandomZoomOut(_RandomApplyTransform):
@@ -274,7 +270,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = self.fill[type(inpt)]
-        fill = F._geometry._convert_fill_arg(fill)
 
         return F.pad(inpt, **params, fill=fill)
 
@@ -300,12 +295,11 @@ def __init__(
         self.center = center
 
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item())
+        angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
         return dict(angle=angle)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = self.fill[type(inpt)]
-        fill = F._geometry._convert_fill_arg(fill)
         return F.rotate(
             inpt,
             **params,
@@ -358,7 +352,7 @@ def __init__(
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         height, width = query_spatial_size(flat_inputs)
 
-        angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item())
+        angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
         if self.translate is not None:
             max_dx = float(self.translate[0] * width)
             max_dy = float(self.translate[1] * height)
@@ -369,22 +363,21 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
             translate = (0, 0)
 
         if self.scale is not None:
-            scale = float(torch.empty(1).uniform_(self.scale[0], self.scale[1]).item())
+            scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
         else:
             scale = 1.0
 
         shear_x = shear_y = 0.0
         if self.shear is not None:
-            shear_x = float(torch.empty(1).uniform_(self.shear[0], self.shear[1]).item())
+            shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1]).item()
             if len(self.shear) == 4:
-                shear_y = float(torch.empty(1).uniform_(self.shear[2], self.shear[3]).item())
+                shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3]).item()
 
         shear = (shear_x, shear_y)
         return dict(angle=angle, translate=translate, scale=scale, shear=shear)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = self.fill[type(inpt)]
-        fill = F._geometry._convert_fill_arg(fill)
         return F.affine(
             inpt,
             **params,
@@ -478,8 +471,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if params["needs_pad"]:
             fill = self.fill[type(inpt)]
-            fill = F._geometry._convert_fill_arg(fill)
-
             inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)
 
         if params["needs_crop"]:
@@ -512,21 +503,23 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         half_height = height // 2
         half_width = width // 2
+        bound_height = int(distortion_scale * half_height) + 1
+        bound_width = int(distortion_scale * half_width) + 1
         topleft = [
-            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
-            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
+            int(torch.randint(0, bound_width, size=(1,))),
+            int(torch.randint(0, bound_height, size=(1,))),
         ]
         topright = [
-            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
-            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
+            int(torch.randint(width - bound_width, width, size=(1,))),
+            int(torch.randint(0, bound_height, size=(1,))),
         ]
         botright = [
-            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
-            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
+            int(torch.randint(width - bound_width, width, size=(1,))),
+            int(torch.randint(height - bound_height, height, size=(1,))),
         ]
         botleft = [
-            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
-            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
+            int(torch.randint(0, bound_width, size=(1,))),
+            int(torch.randint(height - bound_height, height, size=(1,))),
         ]
         startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
         endpoints = [topleft, topright, botright, botleft]
@@ -535,7 +528,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = self.fill[type(inpt)]
-        fill = F._geometry._convert_fill_arg(fill)
         return F.perspective(
             inpt,
             **params,
@@ -584,7 +576,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = self.fill[type(inpt)]
-        fill = F._geometry._convert_fill_arg(fill)
         return F.elastic(
             inpt,
             **params,
@@ -855,7 +846,6 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 
         if params["needs_pad"]:
             fill = self.fill[type(inpt)]
-            fill = F._geometry._convert_fill_arg(fill)
             inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
 
         return inpt
diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py
index d4ee7387126..d0b11d53a8f 100644
--- a/torchvision/prototype/transforms/_type_conversion.py
+++ b/torchvision/prototype/transforms/_type_conversion.py
@@ -1,4 +1,4 @@
-from typing import Any, cast, Dict, Optional, Union
+from typing import Any, Dict, Optional, Union
 
 import numpy as np
 import PIL.Image
@@ -13,7 +13,7 @@ class DecodeImage(Transform):
     _transformed_types = (features.EncodedImage,)
 
     def _transform(self, inpt: torch.Tensor, params: Dict[str, Any]) -> features.Image:
-        return cast(features.Image, F.decode_image_with_pil(inpt))
+        return F.decode_image_with_pil(inpt)  # type: ignore[no-any-return]
 
 
 class LabelToOneHot(Transform):
@@ -27,7 +27,7 @@ def _transform(self, inpt: features.Label, params: Dict[str, Any]) -> features.O
         num_categories = self.num_categories
         if num_categories == -1 and inpt.categories is not None:
             num_categories = len(inpt.categories)
-        output = one_hot(inpt, num_classes=num_categories)
+        output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=num_categories)
         return features.OneHotLabel(output, categories=inpt.categories)
 
     def extra_repr(self) -> str:
@@ -50,7 +50,7 @@ class ToImageTensor(Transform):
     def _transform(
         self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
    ) -> features.Image:
-        return cast(features.Image, F.to_image_tensor(inpt))
+        return F.to_image_tensor(inpt)  # type: ignore[no-any-return]
 
 
 class ToImagePIL(Transform):
diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py
index cff439b8872..2272396f766 100644
--- a/torchvision/prototype/transforms/_utils.py
+++ b/torchvision/prototype/transforms/_utils.py
@@ -7,7 +7,7 @@
 
 from torchvision._utils import sequence_to_str
 from torchvision.prototype import features
-from torchvision.prototype.features._feature import FillType
+from torchvision.prototype.features._feature import FillType, FillTypeJIT
 from torchvision.prototype.transforms.functional._meta import get_dimensions, get_spatial_size
 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size  # noqa: F401
 
@@ -37,9 +37,12 @@ def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
         for key, value in fill.items():
             # Check key for type
             _check_fill_arg(value)
+        if isinstance(fill, defaultdict) and callable(fill.default_factory):
+            default_value = fill.default_factory()
+            _check_fill_arg(default_value)
     else:
         if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
-            raise TypeError("Got inappropriate fill arg")
+            raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.")
 
 
 T = TypeVar("T")
@@ -55,13 +58,33 @@ def _get_defaultdict(default: T) -> Dict[Any, T]:
     return defaultdict(functools.partial(_default_arg, default))
 
 
-def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]:
+def _convert_fill_arg(fill: features.FillType) -> features.FillTypeJIT:
+    # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
+    # So, we can't reassign fill to 0
+    # if fill is None:
+    #     fill = 0
+    if fill is None:
+        return fill
+
+    # This cast does Sequence -> List[float] to please mypy and torch.jit.script
+    if not isinstance(fill, (int, float)):
+        fill = [float(v) for v in list(fill)]
+    return fill
+
+
+def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillTypeJIT]:
     _check_fill_arg(fill)
 
     if isinstance(fill, dict):
-        return fill
+        for k, v in fill.items():
+            fill[k] = _convert_fill_arg(v)
+        if isinstance(fill, defaultdict) and callable(fill.default_factory):
+            default_value = fill.default_factory()
+            sanitized_default = _convert_fill_arg(default_value)
+            fill.default_factory = functools.partial(_default_arg, sanitized_default)
+        return fill  # type: ignore[return-value]
 
-    return _get_defaultdict(fill)
+    return _get_defaultdict(_convert_fill_arg(fill))
 
 
 def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
@@ -80,7 +103,7 @@ def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect",
 
 
 def query_bounding_box(flat_inputs: List[Any]) -> features.BoundingBox:
-    bounding_boxes = {inpt for inpt in flat_inputs if isinstance(inpt, features.BoundingBox)}
+    bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, features.BoundingBox)]
     if not bounding_boxes:
         raise TypeError("No bounding box was found in the sample")
     elif len(bounding_boxes) > 1:
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index a112db7e127..7f709b73b4b 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -470,20 +470,6 @@ def affine_video(
     )
 
 
-def _convert_fill_arg(fill: features.FillType) -> features.FillTypeJIT:
-    # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
-    # So, we can't reassign fill to 0
-    # if fill is None:
-    #     fill = 0
-    if fill is None:
-        return fill
-
-    # This cast does Sequence -> List[float] to please mypy and torch.jit.script
-    if not isinstance(fill, (int, float)):
-        fill = [float(v) for v in list(fill)]
-    return fill
-
-
 def affine(
     inpt: features.InputTypeJIT,
     angle: Union[int, float],
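
Reviewer note, not part of the patch: below is a minimal sketch of how the consolidated fill handling behaves after this change. It assumes only the `_convert_fill_arg`/`_setup_fill_arg` helpers exactly as they land in `torchvision/prototype/transforms/_utils.py` above; the concrete fill values are illustrative.

    from torchvision.prototype import features
    from torchvision.prototype.transforms._utils import _convert_fill_arg, _setup_fill_arg

    # Scalars pass through, sequences are cast to List[float] to satisfy
    # torch.jit.script, and None is preserved (None is not equivalent to 0,
    # see https://github.com/pytorch/vision/issues/6517).
    assert _convert_fill_arg(12) == 12
    assert _convert_fill_arg((0.5, 0.5, 0.5)) == [0.5, 0.5, 0.5]
    assert _convert_fill_arg(None) is None

    # _setup_fill_arg now converts eagerly, at transform construction time:
    # a scalar becomes a defaultdict that answers for every feature type ...
    fill = _setup_fill_arg(12)
    assert fill[features.Image] == 12

    # ... and a per-type dict is converted value by value.
    fill = _setup_fill_arg({features.Image: (0.5, 0.5, 0.5), features.Mask: 34})
    assert fill[features.Image] == [0.5, 0.5, 0.5]
    assert fill[features.Mask] == 34

This is why the repeated per-call `fill = F._geometry._convert_fill_arg(fill)` lines could be deleted from the transforms: `self.fill[type(inpt)]` already yields a converted `FillTypeJIT` value.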