
[prototype] Remove _FT aliases from functional #6983


Merged: 3 commits, Nov 28, 2022
11 changes: 9 additions & 2 deletions torchvision/prototype/transforms/functional/_augment.py
@@ -4,10 +4,17 @@
 
 import torch
 from torchvision.prototype import features
-from torchvision.transforms import functional_tensor as _FT
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 
-erase_image_tensor = _FT.erase
+
+def erase_image_tensor(
+    image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
+) -> torch.Tensor:
+    if not inplace:
+        image = image.clone()
+
+    image[..., i : i + h, j : j + w] = v
+    return image
 
 
 @torch.jit.unused
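
The inlined kernel just clones the input (unless inplace=True) and overwrites the h x w window whose top-left corner is (i, j) with v. A minimal sketch of the behavior, with illustrative shapes and assuming the prototype functional namespace re-exports the kernel:

import torch
from torchvision.prototype.transforms.functional import erase_image_tensor

img = torch.ones(3, 8, 8)  # CHW image
# Erase a 2x4 window whose top-left corner is at row 1, column 2.
out = erase_image_tensor(img, i=1, j=2, h=2, w=4, v=torch.zeros(3, 2, 4))
assert out[..., 1:3, 2:6].eq(0).all()  # the window is overwritten with v
assert img.eq(1).all()  # inplace defaults to False, so the input is untouched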
25 changes: 14 additions & 11 deletions torchvision/prototype/transforms/functional/_color.py
@@ -1,15 +1,16 @@
 import torch
 from torch.nn.functional import conv2d
 from torchvision.prototype import features
-from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
+from torchvision.transforms import functional_pil as _FP
+from torchvision.transforms.functional_tensor import _max_value
 
 from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor
 
 
 def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
     ratio = float(ratio)
     fp = image1.is_floating_point()
-    bound = _FT._max_value(image1.dtype)
+    bound = _max_value(image1.dtype)
     output = image1.mul(ratio).add_(image2, alpha=(1.0 - ratio)).clamp_(0, bound)
     return output if fp else output.to(image1.dtype)

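_blend computes ratio * image1 + (1 - ratio) * image2 and clamps to the dtype's value range, so only the bound lookup changes here. A quick illustration of the formula (the values are made up for the example):

import torch

img1 = torch.full((3, 2, 2), 200, dtype=torch.uint8)
img2 = torch.full((3, 2, 2), 100, dtype=torch.uint8)
# ratio * img1 + (1 - ratio) * img2, clamped to [0, 255] for uint8
blended = img1.mul(0.5).add_(img2, alpha=0.5).clamp_(0, 255).to(torch.uint8)
assert blended.eq(150).all()
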
@@ -18,10 +19,12 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor:
     if brightness_factor < 0:
         raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
 
-    _FT._assert_channels(image, [1, 3])
+    c = image.shape[-3]
+    if c not in [1, 3]:
+        raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
 
     fp = image.is_floating_point()
-    bound = _FT._max_value(image.dtype)
+    bound = _max_value(image.dtype)
     output = image.mul(brightness_factor).clamp_(0, bound)
     return output if fp else output.to(image.dtype)

@@ -48,7 +51,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor:
 
     c = image.shape[-3]
     if c not in [1, 3]:
-        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
+        raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
 
     if c == 1:  # Match PIL behaviour
         return image
@@ -82,7 +85,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> torch.Tensor:
 
     c = image.shape[-3]
     if c not in [1, 3]:
-        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
+        raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
     fp = image.is_floating_point()
     if c == 3:
         grayscale_image = _rgb_to_gray(image, cast=False)
@@ -121,7 +124,7 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
     if image.numel() == 0 or height <= 2 or width <= 2:
         return image
 
-    bound = _FT._max_value(image.dtype)
+    bound = _max_value(image.dtype)
     fp = image.is_floating_point()
     shape = image.shape
 
@@ -248,7 +251,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
 
     c = image.shape[-3]
     if c not in [1, 3]:
-        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
+        raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
 
     if c == 1:  # Match PIL behaviour
         return image
@@ -350,7 +353,7 @@ def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:
 
 
 def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor:
-    if threshold > _FT._max_value(image.dtype):
+    if threshold > _max_value(image.dtype):
         raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")
 
     return torch.where(image >= threshold, invert_image_tensor(image), image)
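
solarize_image_tensor inverts every pixel at or above the threshold and leaves the rest untouched; the guard compares the threshold against the dtype maximum via the directly imported _max_value. A sketch of the semantics with illustrative values (for uint8, invert_image_tensor(x) amounts to 255 - x):

import torch

image = torch.tensor([[[0, 100, 200]]], dtype=torch.uint8)
threshold = 128
solarized = torch.where(image >= threshold, 255 - image, image)
print(solarized)  # tensor([[[  0, 100,  55]]], dtype=torch.uint8)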
@@ -375,13 +378,13 @@ def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTypeJIT:
 def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
     c = image.shape[-3]
     if c not in [1, 3]:
-        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
+        raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
 
     if image.numel() == 0:
         # exit earlier on empty images
         return image
 
-    bound = _FT._max_value(image.dtype)
+    bound = _max_value(image.dtype)
     fp = image.is_floating_point()
     float_image = image if fp else image.to(torch.float32)
 
16 changes: 12 additions & 4 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -8,7 +8,7 @@
 from torch.nn.functional import grid_sample, interpolate, pad as torch_pad
 
 from torchvision.prototype import features
-from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
+from torchvision.transforms import functional_pil as _FP
 from torchvision.transforms.functional import (
     _compute_resized_output_size as __compute_resized_output_size,
     _get_perspective_coeffs,
@@ -17,10 +17,15 @@
     pil_to_tensor,
     to_pil_image,
 )
+from torchvision.transforms.functional_tensor import _pad_symmetric
 
 from ._meta import convert_format_bounding_box, get_spatial_size_image_pil
 
+
+def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
+    return image.flip(-1)
+
+
 horizontal_flip_image_pil = _FP.hflip
 
 
@@ -58,7 +63,10 @@ def horizontal_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
         return horizontal_flip_image_pil(inpt)
 
 
-vertical_flip_image_tensor = _FT.vflip
+def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
+    return image.flip(-2)
+
+
 vertical_flip_image_pil = _FP.vflip
 
 
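Both flips now reduce to a single tensor op on the trailing spatial dims of a (..., C, H, W) tensor: flip(-1) reverses the width axis (horizontal flip) and flip(-2) the height axis (vertical flip), which is all the former _FT.hflip/_FT.vflip aliases did. For example:

import torch

image = torch.arange(6).reshape(1, 2, 3)  # (C, H, W)
print(image.flip(-1))  # horizontal flip: rows become [2, 1, 0] and [5, 4, 3]
print(image.flip(-2))  # vertical flip: row order reversed to [3, 4, 5], [0, 1, 2]
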
@@ -975,7 +983,7 @@ def _pad_with_scalar_fill(
         if needs_cast:
             image = image.to(dtype)
     else:  # padding_mode == "symmetric"
-        image = _FT._pad_symmetric(image, torch_padding)
+        image = _pad_symmetric(image, torch_padding)
 
     new_height, new_width = image.shape[-2:]
 
11 changes: 6 additions & 5 deletions torchvision/prototype/transforms/functional/_meta.py
@@ -4,7 +4,8 @@
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
-from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
+from torchvision.transforms import functional_pil as _FP
+from torchvision.transforms.functional_tensor import _max_value
 
 
 def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
@@ -193,7 +194,7 @@ def clamp_bounding_box(
 
 def _strip_alpha(image: torch.Tensor) -> torch.Tensor:
     image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
-    if not torch.all(alpha == _FT._max_value(alpha.dtype)):
+    if not torch.all(alpha == _max_value(alpha.dtype)):
         raise RuntimeError(
             "Stripping the alpha channel if it contains values other than the max value is not supported."
         )
@@ -204,7 +205,7 @@ def _add_alpha(image: torch.Tensor, alpha: Optional[torch.Tensor] = None) -> torch.Tensor:
     if alpha is None:
         shape = list(image.shape)
         shape[-3] = 1
-        alpha = torch.full(shape, _FT._max_value(image.dtype), dtype=image.dtype, device=image.device)
+        alpha = torch.full(shape, _max_value(image.dtype), dtype=image.dtype, device=image.device)
     return torch.cat((image, alpha), dim=-3)


@@ -363,14 +364,14 @@ def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
         # Instead, we can also multiply by the maximum value plus something close to `1`. See
         # https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
         eps = 1e-3
-        max_value = float(_FT._max_value(dtype))
+        max_value = float(_max_value(dtype))
         # We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
         # discrete set `{0, 1}`.
         return image.mul(max_value + 1.0 - eps).to(dtype)
     else:
         # int to float
         if float_output:
-            return image.to(dtype).mul_(1.0 / _FT._max_value(image.dtype))
+            return image.to(dtype).mul_(1.0 / _max_value(image.dtype))
 
         # int to int
         num_value_bits_input = _num_value_bits(image.dtype)
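
The float-to-int path multiplies by max_value + 1 - eps instead of max_value so that the whole top slice of the float range, not just exactly 1.0, lands in the maximum integer bucket after truncation (see the linked comment in pytorch/vision#2078). A worked uint8 example of that scaling:

import torch

eps = 1e-3
max_value = 255.0  # _max_value(torch.uint8)
x = torch.tensor([0.0, 0.5, 1.0])
y = x.mul(max_value + 1.0 - eps).to(torch.uint8)
print(y)  # tensor([  0, 127, 255], dtype=torch.uint8)
# With plain x.mul(255), only exactly 1.0 would map to 255; here every
# float in roughly [255/256, 1.0] reaches the top bucket.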