From 515b823351f7c99dce29fb2fa0aefb3c93593cb8 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Mon, 28 Nov 2022 12:07:54 +0000
Subject: [PATCH 1/2] Remove `_FT` usages from functional

---
 .../prototype/transforms/functional/_augment.py | 11 +++++++++--
 .../prototype/transforms/functional/_color.py   | 17 ++++++++++-------
 .../transforms/functional/_geometry.py          | 16 ++++++++++++----
 .../prototype/transforms/functional/_meta.py    | 11 ++++++-----
 4 files changed, 37 insertions(+), 18 deletions(-)

diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py
index baa3e157385..3d121eb33fc 100644
--- a/torchvision/prototype/transforms/functional/_augment.py
+++ b/torchvision/prototype/transforms/functional/_augment.py
@@ -4,10 +4,17 @@
 import torch
 
 from torchvision.prototype import features
-from torchvision.transforms import functional_tensor as _FT
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 
-erase_image_tensor = _FT.erase
+
+def erase_image_tensor(
+    image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
+) -> torch.Tensor:
+    if not inplace:
+        image = image.clone()
+
+    image[..., i : i + h, j : j + w] = v
+    return image
 
 
 @torch.jit.unused
diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py
index 66805339cc8..723f6397f5d 100644
--- a/torchvision/prototype/transforms/functional/_color.py
+++ b/torchvision/prototype/transforms/functional/_color.py
@@ -1,7 +1,8 @@
 import torch
 from torch.nn.functional import conv2d
 from torchvision.prototype import features
-from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
+from torchvision.transforms import functional_pil as _FP
+from torchvision.transforms.functional_tensor import _max_value
 
 from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor
@@ -9,7 +10,7 @@
 def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
     ratio = float(ratio)
     fp = image1.is_floating_point()
-    bound = _FT._max_value(image1.dtype)
+    bound = _max_value(image1.dtype)
     output = image1.mul(ratio).add_(image2, alpha=(1.0 - ratio)).clamp_(0, bound)
     return output if fp else output.to(image1.dtype)
@@ -18,10 +19,12 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float
     if brightness_factor < 0:
         raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
 
-    _FT._assert_channels(image, [1, 3])
+    c = image.shape[-3]
+    if c not in [1, 3]:
+        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
 
     fp = image.is_floating_point()
-    bound = _FT._max_value(image.dtype)
+    bound = _max_value(image.dtype)
     output = image.mul(brightness_factor).clamp_(0, bound)
     return output if fp else output.to(image.dtype)
@@ -121,7 +124,7 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
     if image.numel() == 0 or height <= 2 or width <= 2:
         return image
 
-    bound = _FT._max_value(image.dtype)
+    bound = _max_value(image.dtype)
     fp = image.is_floating_point()
 
     shape = image.shape
@@ -350,7 +353,7 @@ def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:
 
 
 def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor:
-    if threshold > _FT._max_value(image.dtype):
+    if threshold > _max_value(image.dtype):
         raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")
but got {threshold}") return torch.where(image >= threshold, invert_image_tensor(image), image) @@ -381,7 +384,7 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor: # exit earlier on empty images return image - bound = _FT._max_value(image.dtype) + bound = _max_value(image.dtype) fp = image.is_floating_point() float_image = image if fp else image.to(torch.float32) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 41262185b5d..56fcdf52319 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -8,7 +8,7 @@ from torch.nn.functional import grid_sample, interpolate, pad as torch_pad from torchvision.prototype import features -from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT +from torchvision.transforms import functional_pil as _FP from torchvision.transforms.functional import ( _compute_resized_output_size as __compute_resized_output_size, _get_perspective_coeffs, @@ -17,10 +17,15 @@ pil_to_tensor, to_pil_image, ) +from torchvision.transforms.functional_tensor import _pad_symmetric from ._meta import convert_format_bounding_box, get_spatial_size_image_pil -horizontal_flip_image_tensor = _FT.hflip + +def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor: + return image.flip(-1) + + horizontal_flip_image_pil = _FP.hflip @@ -58,7 +63,10 @@ def horizontal_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT: return horizontal_flip_image_pil(inpt) -vertical_flip_image_tensor = _FT.vflip +def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor: + return image.flip(-2) + + vertical_flip_image_pil = _FP.vflip @@ -975,7 +983,7 @@ def _pad_with_scalar_fill( if needs_cast: image = image.to(dtype) else: # padding_mode == "symmetric" - image = _FT._pad_symmetric(image, torch_padding) + image = _pad_symmetric(image, torch_padding) new_height, new_width = image.shape[-2:] diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index a2da77b1267..4f190d7b89e 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -4,7 +4,8 @@ import torch from torchvision.prototype import features from torchvision.prototype.features import BoundingBoxFormat, ColorSpace -from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT +from torchvision.transforms import functional_pil as _FP +from torchvision.transforms.functional_tensor import _max_value def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]: @@ -193,7 +194,7 @@ def clamp_bounding_box( def _strip_alpha(image: torch.Tensor) -> torch.Tensor: image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3) - if not torch.all(alpha == _FT._max_value(alpha.dtype)): + if not torch.all(alpha == _max_value(alpha.dtype)): raise RuntimeError( "Stripping the alpha channel if it contains values other than the max value is not supported." 
         )
@@ -204,7 +205,7 @@ def _add_alpha(image: torch.Tensor, alpha: Optional[torch.Tensor] = None) -> tor
     if alpha is None:
         shape = list(image.shape)
         shape[-3] = 1
-        alpha = torch.full(shape, _FT._max_value(image.dtype), dtype=image.dtype, device=image.device)
+        alpha = torch.full(shape, _max_value(image.dtype), dtype=image.dtype, device=image.device)
 
     return torch.cat((image, alpha), dim=-3)
@@ -363,14 +364,14 @@ def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.f
         # Instead, we can also multiply by the maximum value plus something close to `1`. See
         # https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
         eps = 1e-3
-        max_value = float(_FT._max_value(dtype))
+        max_value = float(_max_value(dtype))
         # We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
         # discrete set `{0, 1}`.
         return image.mul(max_value + 1.0 - eps).to(dtype)
     else:  # int to float
         if float_output:
-            return image.to(dtype).mul_(1.0 / _FT._max_value(image.dtype))
+            return image.to(dtype).mul_(1.0 / _max_value(image.dtype))
 
         # int to int
         num_value_bits_input = _num_value_bits(image.dtype)

From 270043dc25976948cf9f39d55b48c1e43b867e83 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Mon, 28 Nov 2022 13:28:57 +0000
Subject: [PATCH 2/2] Update error messages

---
 torchvision/prototype/transforms/functional/_color.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py
index 723f6397f5d..fe09d3ba7bc 100644
--- a/torchvision/prototype/transforms/functional/_color.py
+++ b/torchvision/prototype/transforms/functional/_color.py
@@ -21,7 +21,7 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float
 
     c = image.shape[-3]
     if c not in [1, 3]:
-        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
+        raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
 
     fp = image.is_floating_point()
     bound = _max_value(image.dtype)
@@ -51,7 +51,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float
 
     c = image.shape[-3]
     if c not in [1, 3]:
-        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
+        raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
 
     if c == 1:  # Match PIL behaviour
         return image
@@ -85,7 +85,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
     c = image.shape[-3]
     if c not in [1, 3]:
-        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
+        raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
 
     fp = image.is_floating_point()
     if c == 3:
         grayscale_image = _rgb_to_gray(image, cast=False)
@@ -251,7 +251,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
 
     c = image.shape[-3]
     if c not in [1, 3]:
-        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
+        raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
 
     if c == 1:  # Match PIL behaviour
         return image
@@ -378,7 +378,7 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
     c = image.shape[-3]
     if c not in [1, 3]:
-        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
+        raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
 
     if image.numel() == 0:
         # exit earlier on empty images
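
A quick equivalence check for the inlined helpers (a standalone sketch, not part of
the patches; it assumes a torchvision checkout from this period, where both the
prototype namespace and `torchvision.transforms.functional_tensor` are still
importable and the prototype functional package re-exports these helpers):

import torch
from torchvision.prototype.transforms import functional as F
from torchvision.transforms import functional_tensor as _FT

# A random uint8 CHW image.
image = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)

# The new pure-tensor flips should match the old _FT aliases bit for bit.
assert torch.equal(F.horizontal_flip_image_tensor(image), _FT.hflip(image))
assert torch.equal(F.vertical_flip_image_tensor(image), _FT.vflip(image))

# Likewise for the inlined erase: fill an 8x8 patch at (i=2, j=4) with zeros.
v = torch.zeros(3, 8, 8, dtype=torch.uint8)
assert torch.equal(
    F.erase_image_tensor(image, i=2, j=4, h=8, w=8, v=v),
    _FT.erase(image, i=2, j=4, h=8, w=8, v=v),
)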
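The reworded channel check from the second commit can be exercised directly; this
snippet is illustrative and makes the same assumptions as above:

import torch
from torchvision.prototype.transforms import functional as F

bad = torch.rand(2, 4, 4)  # two channels: neither 1 nor 3
try:
    F.adjust_brightness_image_tensor(bad, brightness_factor=1.5)
except TypeError as exc:
    print(exc)  # Input image tensor permitted channel values are 1 or 3, but found 2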