RFC: transforms #1

Open · wants to merge 26 commits into master

Commits (26)

7003c1e
initial draft
pmeier Jun 25, 2021
734214f
rewrite proposal to use custom feature classes
pmeier Jul 5, 2021
96f6ffc
fix wording
pmeier Jul 5, 2021
17a80b7
misc test fixes
pmeier Jul 12, 2021
6b1e327
make features self-contained
pmeier Aug 17, 2021
dcfcddb
cleanup
pmeier Aug 17, 2021
c4f578e
add compose
pmeier Aug 18, 2021
8fd8d97
add auto register and better random support
pmeier Aug 18, 2021
9611403
add automatic functional dispatch
pmeier Aug 19, 2021
3e411a9
fix functional interface
pmeier Aug 19, 2021
1c0b75e
update proposal
pmeier Aug 20, 2021
dd66bf1
introduce from_tensor creation method
pmeier Aug 20, 2021
d2ffc2d
factor out recursive apply
pmeier Aug 23, 2021
15f2f01
improve auto registering
pmeier Sep 1, 2021
c240308
add README on how to implement a new transform
pmeier Sep 1, 2021
1e1596c
add implementation for MixUp
pmeier Sep 1, 2021
14ffb8c
add support for creating transforms from callables and other transforms
pmeier Sep 2, 2021
d5c0f45
add better support for containers
pmeier Sep 2, 2021
c7dde4e
add documentation
pmeier Sep 2, 2021
012ba3b
update RFC document
pmeier Sep 3, 2021
373fa13
improve query
pmeier Sep 3, 2021
b3a990f
add example of what feature transform requirements can look like
pmeier Sep 3, 2021
3b1a45a
add minimal implementation for segmentation training
pmeier Sep 7, 2021
8901f6a
refactor get_params to be able to return other parameters based on type
pmeier Sep 7, 2021
6ede737
simplify transform wrapping
pmeier Sep 7, 2021
82485a3
add RandomOrder container
pmeier Sep 7, 2021
7 changes: 1 addition & 6 deletions pyproject.toml
@@ -28,6 +28,7 @@ sections= ["FUTURE", "STDLIB", "THIRDPARTY", "PYTORCH", "FIRSTPARTY", "LOCALFOLDER"]

 skip = [
     "torchvision/datasets/__init__.py",
+    "torchvision/transforms/__init__.py",
 ]

 [tool.black]
@@ -36,9 +37,3 @@ skip = [

 line-length = 120
 target-version = ["py36"]
-exclude = '''
-/(
-    \.git
-    | __pycache__
-)/
-'''
15 changes: 15 additions & 0 deletions references/segmentation/main.py
@@ -0,0 +1,15 @@
from transforms import get_transform

import torch

from torchvision.features import Image, Segmentation

image = Image(torch.rand(3, 480, 640))
seg = Segmentation(torch.randint(0, 256, size=image.shape, dtype=torch.uint8))
sample = image, seg

transform = get_transform(train=True)
train_image, train_seg = transform(sample)

transform = get_transform(train=False)
eval_image, eval_seg = transform(sample)
30 changes: 30 additions & 0 deletions references/segmentation/transforms.py
@@ -0,0 +1,30 @@
from typing import Sequence

from torchvision import transforms as T


def get_transform(
    *,
    train: bool,
    base_size: int = 520,
    crop_size: int = 480,
    horizontal_flip_probability: float = 0.5,
    mean: Sequence[float] = (0.485, 0.456, 0.406),
    std: Sequence[float] = (0.229, 0.224, 0.225),
):
    if train:
        min_size = base_size // 2
        max_size = base_size * 2
        transforms = [T.RandomResize(min_size, max_size)]

        if horizontal_flip_probability > 0:
            transforms.append(T.RandomApply(T.HorizontalFlip(), p=horizontal_flip_probability))

        transforms.append(T.RandomCrop(crop_size))

        augmentation = T.Compose(*transforms)
    else:
        augmentation = T.Resize(base_size)

    return T.Compose(augmentation, T.Normalize(mean, std))
1 change: 1 addition & 0 deletions torchvision/datasets/utils/__init__.py
@@ -1,2 +1,3 @@
 from ._bunch import *
+from ._query import *
 from ._resource import *
62 changes: 62 additions & 0 deletions torchvision/datasets/utils/_query.py
@@ -0,0 +1,62 @@
import collections.abc
from typing import Any, Callable, Iterator, Optional, Set, Tuple, TypeVar, Union

import torch

from torchvision.features import BoundingBox, Image

T = TypeVar("T")

__all__ = ["Query"]


class Query:
    def __init__(self, sample: Any) -> None:
        self.sample = sample

    @staticmethod
    def _query_recursively(sample: Any, fn: Callable[[Any], Optional[T]]) -> Iterator[T]:
        if isinstance(sample, (collections.abc.Sequence, collections.abc.Mapping)):
            for item in sample.values() if isinstance(sample, collections.abc.Mapping) else sample:
                yield from Query._query_recursively(item, fn)
        else:
            result = fn(sample)
            if result is not None:
                yield result

    def query(self, fn: Callable[[Any], Optional[T]], *, unique: bool = True) -> Union[T, Set[T]]:
        results = set(self._query_recursively(self.sample, fn))
        if not results:
            raise RuntimeError("Query turned up empty.")

        if not unique:
            return results

        if len(results) > 1:
            raise RuntimeError(f"Found more than one result: {sorted(results)}")

        return results.pop()

    def image_size(self) -> Tuple[int, int]:
        def fn(sample: Any) -> Optional[Tuple[int, int]]:
            if not isinstance(sample, torch.Tensor):
                return None
            elif type(sample) is torch.Tensor:
                return sample.shape[-2:]
            elif isinstance(sample, (Image, BoundingBox)):
                return sample.image_size
            else:
                return None

        return self.query(fn)

    def batch_size(self) -> int:
        def fn(sample: Any) -> Optional[int]:
            if not isinstance(sample, torch.Tensor):
                return None
            elif isinstance(sample, Image):
                return sample.batch_size
            else:
                return None

        return self.query(fn)
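
To illustrate how Query is used, here is a sketch with a made-up nested sample:

```python
import torch

from torchvision.datasets.utils import Query
from torchvision.features import BoundingBox, Image

# A nested sample as a dataset might return it; Query recurses through
# sequences and mappings and checks that all answers agree.
sample = {
    "image": Image(torch.rand(3, 480, 640)),
    "target": {
        "boxes": BoundingBox(torch.rand(4), image_size=(480, 640), format="xyxy"),
    },
}

# Both features report the same image size, so the result is unique.
assert Query(sample).image_size() == (480, 640)
```
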
3 changes: 3 additions & 0 deletions torchvision/features/__init__.py
@@ -0,0 +1,3 @@
from ._bounding_box import *
from ._core import *
from ._image import *
163 changes: 163 additions & 0 deletions torchvision/features/_bounding_box.py
@@ -0,0 +1,163 @@
import enum
from typing import Any, Optional, Tuple, Union

import torch

from ._core import TensorFeature

__all__ = ["BoundingBox", "BoundingBoxFormat"]


class BoundingBoxFormat(enum.Enum):
    XYXY = "XYXY"
    XYWH = "XYWH"
    CXCYWH = "CXCYWH"


class BoundingBox(TensorFeature):
    formats = BoundingBoxFormat

    @staticmethod
    def _parse_format(format: Union[str, BoundingBoxFormat]) -> BoundingBoxFormat:
        if isinstance(format, str):
            format = format.upper()
        return BoundingBox.formats(format)

    def __init__(
        self,
        data: Any = None,
        *,
        image_size: Tuple[int, int],
        format: Union[str, BoundingBoxFormat],
    ):
        super().__init__()
        self._image_size = image_size
        self._format = self._parse_format(format)

        self._convert_to_xyxy = {
            self.formats.XYWH: self._xywh_to_xyxy,
            self.formats.CXCYWH: self._cxcywh_to_xyxy,
        }
        self._convert_from_xyxy = {
            self.formats.XYWH: self._xyxy_to_xywh,
            self.formats.CXCYWH: self._xyxy_to_cxcywh,
        }

Collaborator: For when we discuss the specifics of the data classes, it might be good to add a __torch_function__ that forbids operations between two bounding boxes with different image_size, while still allowing BoundingBox + Tensor to work.
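
A rough sketch of what such a guard could look like (the class body below is a minimal stand-in for the one in this diff, and only top-level positional arguments are inspected):

```python
import torch

from torchvision.features import TensorFeature


class BoundingBox(TensorFeature):
    # Minimal stand-in: the real class also carries format, from_parts,
    # convert, etc.
    def __new__(cls, data, *, image_size):
        return super().__new__(cls, data)

    def __init__(self, data, *, image_size):
        super().__init__()
        self._image_size = image_size

    @property
    def image_size(self):
        return self._image_size

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Boxes that went through a plain tensor op may have lost their
        # metadata (see the FIXME in convert below), hence the hasattr guard.
        image_sizes = {
            a.image_size
            for a in args
            if isinstance(a, BoundingBox) and hasattr(a, "_image_size")
        }
        if len(image_sizes) > 1:
            raise ValueError(f"Got bounding boxes with different image sizes: {image_sizes}")
        # Plain tensors are not checked, so BoundingBox + Tensor keeps working.
        return super().__torch_function__(func, types, args, kwargs or {})


a = BoundingBox(torch.rand(4), image_size=(480, 640))
b = BoundingBox(torch.rand(4), image_size=(240, 320))
a + torch.rand(4)  # works
a + b  # raises ValueError
```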

    def __new__(
        cls,
        data: Any = None,
        *,
        image_size: Tuple[int, int],
        format: Union[str, BoundingBoxFormat],
    ):
        # torch.Tensor defines both __new__ and __init__; since we change the
        # signature, we need to override __new__ in addition to __init__.
        return super().__new__(cls, data)

    @classmethod
    def from_tensor(
        cls,
        tensor: torch.Tensor,
        *,
        like: Optional["BoundingBox"] = None,
        image_size: Optional[Tuple[int, int]] = None,
        format: Optional[Union[str, BoundingBoxFormat]] = None,
    ) -> "BoundingBox":
        params = cls._parse_from_tensor_args(like=like, image_size=image_size, format=format)

        format = params.get("format") or "xyxy"

        image_size = params.get("image_size")
        if image_size is None:
            # TODO: compute minimum image size needed to hold this bounding box depending on format
            image_size = (0, 0)

        return cls(tensor, image_size=image_size, format=format)

    @property
    def image_size(self) -> Tuple[int, int]:
        return self._image_size

    @property
    def format(self) -> BoundingBoxFormat:
        return self._format

    @classmethod
    def from_parts(
        cls,
        a,
        b,
        c,
        d,
        *,
        format: Union[str, BoundingBoxFormat],
        like: Optional["BoundingBox"] = None,
        image_size: Optional[Tuple[int, int]] = None,
    ) -> "BoundingBox":
        parts = torch.broadcast_tensors(
            *[part if isinstance(part, torch.Tensor) else torch.as_tensor(part) for part in (a, b, c, d)]
        )
        return cls.from_tensor(torch.stack(parts, dim=-1), like=like, image_size=image_size, format=format)

    def to_parts(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        return self.unbind(-1)

    def convert(self, format: Union[str, BoundingBoxFormat]) -> "BoundingBox":

Collaborator: Nice! One other thing to think about: is there a minimum common API across all our data types that we might want to enforce?

Owner Author: Currently, I don't see anything. Happy to hear ideas, though. If something comes up later, it should be retrofittable.

Owner Author: I've added a TensorFeature.from_tensor method that should make it easy to create a new feature from just a tensor.
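
For illustration (the box values are made up), a plain tensor op on a BoundingBox keeps the type but loses the metadata; re-wrapping with like= restores it:

```python
import torch

from torchvision.features import BoundingBox

box = BoundingBox(
    torch.tensor([10.0, 15.0, 20.0, 25.0]), image_size=(480, 640), format="xyxy"
)

# The result of `box + 5` is BoundingBox-typed but has lost image_size and
# format; re-wrapping with like=box carries them over.
moved = BoundingBox.from_tensor(box + 5, like=box)
assert moved.image_size == (480, 640)
assert moved.format is BoundingBox.formats.XYXY
```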

        format = self._parse_format(format)
        # FIXME: cloning does not preserve custom attributes such as image_size or format
        # bounding_box = self.clone()
        bounding_box = self

        if format == self.format:
            return bounding_box

        if self.format != self.formats.XYXY:
            bounding_box = self._convert_to_xyxy[self.format](bounding_box)

        if format != self.formats.XYXY:
            bounding_box = self._convert_from_xyxy[format](bounding_box)

        return bounding_box

    @staticmethod
    def _xywh_to_xyxy(input: "BoundingBox") -> "BoundingBox":
        x, y, w, h = input.to_parts()

        x1 = x
        y1 = y
        x2 = x + w
        y2 = y + h

        return BoundingBox.from_parts(x1, y1, x2, y2, like=input, format="xyxy")

    @staticmethod
    def _xyxy_to_xywh(input: "BoundingBox") -> "BoundingBox":
        x1, y1, x2, y2 = input.to_parts()

        x = x1
        y = y1
        w = x2 - x1
        h = y2 - y1

        return BoundingBox.from_parts(x, y, w, h, like=input, format="xywh")

    @staticmethod
    def _cxcywh_to_xyxy(input: "BoundingBox") -> "BoundingBox":
        cx, cy, w, h = input.to_parts()

        x1 = cx - 0.5 * w
        y1 = cy - 0.5 * h
        x2 = cx + 0.5 * w
        y2 = cy + 0.5 * h

        return BoundingBox.from_parts(x1, y1, x2, y2, like=input, format="xyxy")

    @staticmethod
    def _xyxy_to_cxcywh(input: "BoundingBox") -> "BoundingBox":
        x1, y1, x2, y2 = input.to_parts()

        cx = (x1 + x2) / 2
        cy = (y1 + y2) / 2
        w = x2 - x1
        h = y2 - y1

        return BoundingBox.from_parts(cx, cy, w, h, like=input, format="cxcywh")
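
For reference, a short round trip through convert, with made-up numbers:

```python
import torch

from torchvision.features import BoundingBox

xywh = BoundingBox(
    torch.tensor([10.0, 20.0, 30.0, 40.0]), image_size=(480, 640), format="xywh"
)

# (x, y, w, h) = (10, 20, 30, 40) becomes (x1, y1, x2, y2) = (10, 20, 40, 60)
xyxy = xywh.convert("xyxy")
x1, y1, x2, y2 = xyxy.to_parts()
assert x2.item() == 40 and y2.item() == 60
```
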
33 changes: 33 additions & 0 deletions torchvision/features/_core.py
@@ -0,0 +1,33 @@
from typing import Any, Dict, Optional, Type, TypeVar

import torch

__all__ = ["Feature", "TensorFeature"]

TF = TypeVar("TF", bound="TensorFeature")


# A Feature might not necessarily be a Tensor. Think text.
class Feature:
    pass


class TensorFeature(torch.Tensor, Feature):
    def __new__(cls, data: Any = None):
        if data is None:
            data = torch.tensor([])
        requires_grad = False
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    @staticmethod
    def _parse_from_tensor_args(*, like: Optional["TensorFeature"], **attrs: Any) -> Dict[str, Any]:
        if not attrs:
            raise ValueError("At least one attribute besides `like` has to be passed.")

        params = {name: getattr(like, name) for name in attrs.keys()} if like is not None else {}
        params.update({name: value for name, value in attrs.items() if value is not None})
        return params

    @classmethod
    def from_tensor(cls: Type[TF], tensor: torch.Tensor, *, like: Optional[TF] = None) -> TF:
        return cls(tensor)
33 changes: 33 additions & 0 deletions torchvision/features/_image.py
@@ -0,0 +1,33 @@
from typing import Tuple

from ._core import TensorFeature

__all__ = ["Image", "Segmentation"]


class Image(TensorFeature):
    @property
    def image_size(self) -> Tuple[int, int]:
        return self.shape[-2:]

    @property
    def batch_size(self) -> int:
        return self.shape[0] if self.ndim == 4 else 0

    def batch(self) -> "Image":
        if self.batch_size > 0:
            return self

        return Image.from_tensor(self.unsqueeze(0), like=self)

    def unbatch(self) -> "Image":
        if self.batch_size == 0:
            return self
        elif self.batch_size == 1:
            return Image.from_tensor(self.squeeze(0), like=self)
        else:
            raise RuntimeError("Cannot unbatch an image tensor if batch contains more than one image.")


class Segmentation(Image):
    pass
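
A quick illustration of the batching helpers above (shapes are examples):

```python
import torch

from torchvision.features import Image

image = Image(torch.rand(3, 480, 640))
assert image.batch_size == 0

batched = image.batch()  # unsqueezes to shape (1, 3, 480, 640)
assert batched.batch_size == 1
assert batched.unbatch().shape == image.shape
```
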
6 changes: 6 additions & 0 deletions torchvision/transforms/__init__.py
@@ -0,0 +1,6 @@
from ._transform import *

from ._augmentation import *
from ._container import *
from ._geometry import *
from ._misc import *