Skip to content

[prototype] Minor speed and nit optimizations on Transform Classes #6837

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Oct 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def test__transform(self, padding, fill, padding_mode, mocker):
inpt = mocker.MagicMock(spec=features.Image)
_ = transform(inpt)

fill = transforms.functional._geometry._convert_fill_arg(fill)
fill = transforms._utils._convert_fill_arg(fill)
if isinstance(padding, tuple):
padding = list(padding)
fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
Expand All @@ -405,14 +405,14 @@ def test__transform_image_mask(self, fill, mocker):
_ = transform(inpt)

if isinstance(fill, int):
fill = transforms.functional._geometry._convert_fill_arg(fill)
fill = transforms._utils._convert_fill_arg(fill)
calls = [
mocker.call(image, padding=1, fill=fill, padding_mode="constant"),
mocker.call(mask, padding=1, fill=fill, padding_mode="constant"),
]
else:
fill_img = transforms.functional._geometry._convert_fill_arg(fill[type(image)])
fill_mask = transforms.functional._geometry._convert_fill_arg(fill[type(mask)])
fill_img = transforms._utils._convert_fill_arg(fill[type(image)])
fill_mask = transforms._utils._convert_fill_arg(fill[type(mask)])
calls = [
mocker.call(image, padding=1, fill=fill_img, padding_mode="constant"),
mocker.call(mask, padding=1, fill=fill_mask, padding_mode="constant"),
Expand Down Expand Up @@ -466,7 +466,7 @@ def test__transform(self, fill, side_range, mocker):
torch.rand(1) # random apply changes random state
params = transform._get_params([inpt])

fill = transforms.functional._geometry._convert_fill_arg(fill)
fill = transforms._utils._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, fill=fill)

@pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}])
Expand All @@ -485,14 +485,14 @@ def test__transform_image_mask(self, fill, mocker):
params = transform._get_params(inpt)

if isinstance(fill, int):
fill = transforms.functional._geometry._convert_fill_arg(fill)
fill = transforms._utils._convert_fill_arg(fill)
calls = [
mocker.call(image, **params, fill=fill),
mocker.call(mask, **params, fill=fill),
]
else:
fill_img = transforms.functional._geometry._convert_fill_arg(fill[type(image)])
fill_mask = transforms.functional._geometry._convert_fill_arg(fill[type(mask)])
fill_img = transforms._utils._convert_fill_arg(fill[type(image)])
fill_mask = transforms._utils._convert_fill_arg(fill[type(mask)])
calls = [
mocker.call(image, **params, fill=fill_img),
mocker.call(mask, **params, fill=fill_mask),
Expand Down Expand Up @@ -556,7 +556,7 @@ def test__transform(self, degrees, expand, fill, center, mocker):
torch.manual_seed(12)
params = transform._get_params(inpt)

fill = transforms.functional._geometry._convert_fill_arg(fill)
fill = transforms._utils._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, interpolation=interpolation, expand=expand, fill=fill, center=center)

@pytest.mark.parametrize("angle", [34, -87])
Expand Down Expand Up @@ -694,7 +694,7 @@ def test__transform(self, degrees, translate, scale, shear, fill, center, mocker
torch.manual_seed(12)
params = transform._get_params([inpt])

fill = transforms.functional._geometry._convert_fill_arg(fill)
fill = transforms._utils._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, interpolation=interpolation, fill=fill, center=center)


Expand Down Expand Up @@ -939,7 +939,7 @@ def test__transform(self, distortion_scale, mocker):
torch.rand(1) # random apply changes random state
params = transform._get_params([inpt])

fill = transforms.functional._geometry._convert_fill_arg(fill)
fill = transforms._utils._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)


Expand Down Expand Up @@ -1009,7 +1009,7 @@ def test__transform(self, alpha, sigma, mocker):
transform._get_params = mocker.MagicMock()
_ = transform(inpt)
params = transform._get_params([inpt])
fill = transforms.functional._geometry._convert_fill_arg(fill)
fill = transforms._utils._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)


Expand Down Expand Up @@ -1632,7 +1632,7 @@ def test__transform(self, mocker, needs):
if not needs_crop:
assert args[0] is inpt_sentinel
assert args[1] is padding_sentinel
fill_sentinel = transforms.functional._geometry._convert_fill_arg(fill_sentinel)
fill_sentinel = transforms._utils._convert_fill_arg(fill_sentinel)
assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel)
else:
mock_pad.assert_not_called()
Expand Down
2 changes: 0 additions & 2 deletions test/test_prototype_transforms_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,8 +983,6 @@ def _transform(self, inpt, params):
return inpt

fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)

return F.pad(inpt, padding=params["padding"], fill=fill)


Expand Down
34 changes: 14 additions & 20 deletions torchvision/prototype/transforms/_auto_augment.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import PIL.Image
import torch
Expand All @@ -11,9 +11,6 @@

from ._utils import _isinstance, _setup_fill_arg

K = TypeVar("K")
V = TypeVar("V")


class _AutoAugmentBase(Transform):
def __init__(
Expand All @@ -26,7 +23,7 @@ def __init__(
self.interpolation = interpolation
self.fill = _setup_fill_arg(fill)

def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
keys = tuple(dct.keys())
key = keys[int(torch.randint(len(keys), ()))]
return key, dct[key]
Expand Down Expand Up @@ -71,10 +68,9 @@ def _apply_image_or_video_transform(
transform_id: str,
magnitude: float,
interpolation: InterpolationMode,
fill: Dict[Type, features.FillType],
fill: Dict[Type, features.FillTypeJIT],
) -> Union[features.ImageType, features.VideoType]:
fill_ = fill[type(image)]
fill_ = F._geometry._convert_fill_arg(fill_)

if transform_id == "Identity":
return image
Expand Down Expand Up @@ -170,9 +166,7 @@ class AutoAugment(_AutoAugmentBase):
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Posterize": (
lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
.round()
.int(),
lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
False,
),
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
Expand Down Expand Up @@ -327,9 +321,7 @@ class RandAugment(_AutoAugmentBase):
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
"Posterize": (
lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
.round()
.int(),
lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
False,
),
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
Expand Down Expand Up @@ -383,9 +375,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
"Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
"Posterize": (
lambda num_bins, height, width: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)))
.round()
.int(),
lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))).round().int(),
False,
),
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
Expand Down Expand Up @@ -430,9 +420,7 @@ class AugMix(_AutoAugmentBase):
"TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True),
"Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
"Posterize": (
lambda num_bins, height, width: cast(torch.Tensor, 4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)))
.round()
.int(),
lambda num_bins, height, width: (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
False,
),
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
Expand Down Expand Up @@ -517,7 +505,13 @@ def forward(self, *inputs: Any) -> Any:
aug = self._apply_image_or_video_transform(
aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
mix.add_(
# The multiplication below could become in-place provided `aug is not batch and aug.is_floating_point()`
# Currently we can't do this because `aug` has to be `unint8` to support ops like `equalize`.
# TODO: change this once all ops in `F` support floats. https://github.com/pytorch/vision/issues/6840
combined_weights[:, i].reshape(batch_dims)
* aug
)
mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)

if isinstance(orig_image_or_video, (features.Image, features.Video)):
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _check_input(

@staticmethod
def _generate_value(left: float, right: float) -> float:
return float(torch.distributions.Uniform(left, right).sample())
return torch.empty(1).uniform_(left, right).item()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switching to this random generator we get a performance boost on GPU. Moreover this option is JIT-scriptable (if on the future we decide to add support) and doesn't require to constantly initialize a distribution object as before:

[--------- ColorJitter cpu torch.float32 ---------]
                     |   old random  |   new random
1 threads: ----------------------------------------
      (3, 400, 400)  |       17      |       17    
6 threads: ----------------------------------------
      (3, 400, 400)  |       21      |       21    

Times are in milliseconds (ms).

[--------- ColorJitter cuda torch.float32 --------]
                     |   old random  |   new random
1 threads: ----------------------------------------
      (3, 400, 400)  |      1090     |      883    
6 threads: ----------------------------------------
      (3, 400, 400)  |      1090     |      882    

Times are in microseconds (us).


def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
fn_idx = torch.randperm(4)
Expand Down
48 changes: 19 additions & 29 deletions torchvision/prototype/transforms/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,20 +223,16 @@ def __init__(
_check_padding_arg(padding)
_check_padding_mode_arg(padding_mode)

# This cast does Sequence[int] -> List[int] and is required to make mypy happy
if not isinstance(padding, int):
padding = list(padding)
self.padding = padding
self.fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]

# This cast does Sequence[int] -> List[int] and is required to make mypy happy
padding = self.padding
if not isinstance(padding, int):
padding = list(padding)

fill = F._geometry._convert_fill_arg(fill)
return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)
return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)


class RandomZoomOut(_RandomApplyTransform):
Expand Down Expand Up @@ -274,7 +270,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.pad(inpt, **params, fill=fill)


Expand All @@ -300,12 +295,11 @@ def __init__(
self.center = center

def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item())
angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
return dict(angle=angle)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.rotate(
inpt,
**params,
Expand Down Expand Up @@ -358,7 +352,7 @@ def __init__(
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
height, width = query_spatial_size(flat_inputs)

angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item())
angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
if self.translate is not None:
max_dx = float(self.translate[0] * width)
max_dy = float(self.translate[1] * height)
Expand All @@ -369,22 +363,21 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
translate = (0, 0)

if self.scale is not None:
scale = float(torch.empty(1).uniform_(self.scale[0], self.scale[1]).item())
scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
else:
scale = 1.0

shear_x = shear_y = 0.0
if self.shear is not None:
shear_x = float(torch.empty(1).uniform_(self.shear[0], self.shear[1]).item())
shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1]).item()
if len(self.shear) == 4:
shear_y = float(torch.empty(1).uniform_(self.shear[2], self.shear[3]).item())
shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3]).item()

shear = (shear_x, shear_y)
return dict(angle=angle, translate=translate, scale=scale, shear=shear)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.affine(
inpt,
**params,
Expand Down Expand Up @@ -478,8 +471,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["needs_pad"]:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)

inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)

if params["needs_crop"]:
Expand Down Expand Up @@ -512,21 +503,23 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

half_height = height // 2
half_width = width // 2
bound_height = int(distortion_scale * half_height) + 1
bound_width = int(distortion_scale * half_width) + 1
topleft = [
int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
int(torch.randint(0, bound_width, size=(1,))),
int(torch.randint(0, bound_height, size=(1,))),
]
topright = [
int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
int(torch.randint(width - bound_width, width, size=(1,))),
int(torch.randint(0, bound_height, size=(1,))),
]
botright = [
int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
int(torch.randint(width - bound_width, width, size=(1,))),
int(torch.randint(height - bound_height, height, size=(1,))),
]
botleft = [
int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
int(torch.randint(0, bound_width, size=(1,))),
int(torch.randint(height - bound_height, height, size=(1,))),
]
startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
endpoints = [topleft, topright, botright, botleft]
Expand All @@ -535,7 +528,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.perspective(
inpt,
**params,
Expand Down Expand Up @@ -584,7 +576,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
return F.elastic(
inpt,
**params,
Expand Down Expand Up @@ -855,7 +846,6 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:

if params["needs_pad"]:
fill = self.fill[type(inpt)]
fill = F._geometry._convert_fill_arg(fill)
inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)

return inpt
Expand Down
8 changes: 4 additions & 4 deletions torchvision/prototype/transforms/_type_conversion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, cast, Dict, Optional, Union
from typing import Any, Dict, Optional, Union

import numpy as np
import PIL.Image
Expand All @@ -13,7 +13,7 @@ class DecodeImage(Transform):
_transformed_types = (features.EncodedImage,)

def _transform(self, inpt: torch.Tensor, params: Dict[str, Any]) -> features.Image:
return cast(features.Image, F.decode_image_with_pil(inpt))
return F.decode_image_with_pil(inpt) # type: ignore[no-any-return]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has to be here, because it seems

@torch.jit.unused
def decode_image_with_pil(encoded_image: torch.Tensor) -> features.Image:

doesn't "forward" the type annotations 🙄

Copy link
Contributor Author

@datumbox datumbox Oct 26, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In all other places we took the decision to silence with ignore rather than cast, do we really need the cast here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nono, I was just explaining why we need the ignore for future me that is looking confused at the blame why we introduced it in the first place.



class LabelToOneHot(Transform):
Expand All @@ -27,7 +27,7 @@ def _transform(self, inpt: features.Label, params: Dict[str, Any]) -> features.O
num_categories = self.num_categories
if num_categories == -1 and inpt.categories is not None:
num_categories = len(inpt.categories)
output = one_hot(inpt, num_classes=num_categories)
output = one_hot(inpt.as_subclass(torch.Tensor), num_classes=num_categories)
return features.OneHotLabel(output, categories=inpt.categories)

def extra_repr(self) -> str:
Expand All @@ -50,7 +50,7 @@ class ToImageTensor(Transform):
def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
) -> features.Image:
return cast(features.Image, F.to_image_tensor(inpt))
return F.to_image_tensor(inpt) # type: ignore[no-any-return]


class ToImagePIL(Transform):
Expand Down
Loading