-
Notifications
You must be signed in to change notification settings - Fork 0
RFC: transforms #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
7003c1e
734214f
96f6ffc
17a80b7
6b1e327
dcfcddb
c4f578e
8fd8d97
9611403
3e411a9
1c0b75e
dd66bf1
d2ffc2d
15f2f01
c240308
1e1596c
14ffb8c
d5c0f45
c7dde4e
012ba3b
373fa13
b3a990f
3b1a45a
8901f6a
6ede737
82485a3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from transforms import get_transform

import torch

from torchvision.features import Image, Segmentation

# Build a toy (image, segmentation-mask) sample: a random CHW image and a
# mask of the same spatial shape with uint8 class ids.
image = Image(torch.rand(3, 480, 640))
seg = Segmentation(torch.randint(0, 256, size=image.shape, dtype=torch.uint8))
sample = image, seg

# Training pipeline: random resize / flip / crop followed by normalization.
transform = get_transform(train=True)
train_image, train_seg = transform(sample)

# Evaluation pipeline: deterministic resize followed by normalization.
transform = get_transform(train=False)
eval_image, eval_seg = transform(sample)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from typing import Sequence | ||
|
||
from torchvision import transforms as T | ||
|
||
|
||
def get_transform(
    *,
    train: bool,
    base_size: int = 520,
    crop_size: int = 480,
    horizontal_flip_probability: float = 0.5,
    mean: Sequence[float] = (0.485, 0.456, 0.406),
    std: Sequence[float] = (0.229, 0.224, 0.225),
):
    """Build the transform pipeline for segmentation samples.

    Args:
        train: if true, use random resize / flip / crop augmentation;
            otherwise a deterministic resize to ``base_size``.
        base_size: reference size; training rescales between half and
            twice this value.
        crop_size: side length of the random crop applied in training.
        horizontal_flip_probability: chance of flipping; ``0`` disables
            the flip transform entirely.
        mean: per-channel normalization mean (ImageNet defaults).
        std: per-channel normalization std (ImageNet defaults).

    Returns:
        A composed transform ending in normalization.

    NOTE(review): relies on the prototype ``T`` API (``RandomResize``,
    ``HorizontalFlip``, varargs ``Compose``), not classic torchvision.
    """
    if not train:
        return T.Compose(T.Resize(base_size), T.Normalize(mean, std))

    # Random rescale between base_size // 2 and base_size * 2, optional
    # horizontal flip, then a fixed-size crop.
    steps = [T.RandomResize(base_size // 2, base_size * 2)]
    if horizontal_flip_probability > 0:
        steps.append(T.RandomApply(T.HorizontalFlip(), p=horizontal_flip_probability))
    steps.append(T.RandomCrop(crop_size))

    return T.Compose(T.Compose(*steps), T.Normalize(mean, std))
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from ._bunch import * | ||
from ._query import * | ||
from ._resource import * |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import collections.abc | ||
from typing import Any, Callable, Iterator, Optional, Set, Tuple, TypeVar, Union | ||
|
||
import torch | ||
|
||
from torchvision.features import BoundingBox, Image | ||
|
||
T = TypeVar("T") | ||
|
||
__all__ = ["Query"] | ||
|
||
|
||
class Query: | ||
def __init__(self, sample: Any) -> None: | ||
self.sample = sample | ||
|
||
@staticmethod | ||
def _query_recursively(sample: Any, fn: Callable[[Any], Optional[T]]) -> Iterator[T]: | ||
if isinstance(sample, (collections.abc.Sequence, collections.abc.Mapping)): | ||
for item in sample.values() if isinstance(sample, collections.abc.Mapping) else sample: | ||
yield from Query._query_recursively(item, fn) | ||
else: | ||
result = fn(sample) | ||
if result is not None: | ||
yield result | ||
|
||
def query(self, fn: Callable[[Any], Optional[T]], *, unique: bool = True) -> Union[T, Set[T]]: | ||
results = set(self._query_recursively(self.sample, fn)) | ||
if not results: | ||
raise RuntimeError("Query turned up empty.") | ||
|
||
if not unique: | ||
return results | ||
|
||
if len(results) > 1: | ||
raise RuntimeError(f"Found more than one result: {sorted(results)}") | ||
|
||
return results.pop() | ||
|
||
def image_size(self) -> Optional[Tuple[int, int]]: | ||
def fn(sample: Any) -> Optional[Tuple[int, int]]: | ||
if not isinstance(sample, torch.Tensor): | ||
return None | ||
elif type(sample) is torch.Tensor: | ||
return sample.shape[-2:] | ||
elif isinstance(sample, (Image, BoundingBox)): | ||
return sample.image_size | ||
else: | ||
return None | ||
|
||
return self.query(fn) | ||
|
||
def batch_size(self) -> Optional[int]: | ||
def fn(sample: Any) -> Optional[int]: | ||
if not isinstance(sample, torch.Tensor): | ||
return None | ||
elif isinstance(sample, Image): | ||
return sample.batch_size | ||
else: | ||
return None | ||
|
||
return self.query(fn) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from ._bounding_box import * | ||
from ._core import * | ||
from ._image import * |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
import enum | ||
from typing import Any, Optional, Tuple, Union | ||
|
||
import torch | ||
|
||
from ._core import TensorFeature | ||
|
||
__all__ = ["BoundingBox", "BoundingBoxFormat"] | ||
|
||
|
||
class BoundingBoxFormat(enum.Enum):
    """Coordinate layouts a bounding box tensor can store.

    XYXY: two corner points ``(x1, y1, x2, y2)``.
    XYWH: origin corner plus extent ``(x, y, w, h)``.
    CXCYWH: center point plus extent ``(cx, cy, w, h)``.
    """

    XYXY = "XYXY"
    XYWH = "XYWH"
    CXCYWH = "CXCYWH"
|
||
|
||
class BoundingBox(TensorFeature):
    """Tensor feature holding bounding box coordinates.

    The last dimension of the underlying tensor holds the four coordinate
    parts, interpreted according to ``format``.  ``image_size`` is the
    (height, width) of the image the box belongs to.
    """

    formats = BoundingBoxFormat

    @staticmethod
    def _parse_format(format: Union[str, BoundingBoxFormat]) -> BoundingBoxFormat:
        # Accept case-insensitive strings ("xyxy") as well as enum members.
        if isinstance(format, str):
            format = format.upper()
        return BoundingBox.formats(format)

    def __init__(
        self,
        data: Any = None,
        *,
        image_size: Tuple[int, int],
        format: Union[str, BoundingBoxFormat],
    ):
        super().__init__()
        self._image_size = image_size
        self._format = self._parse_format(format)

        # Conversions always route through XYXY as the canonical format.
        self._convert_to_xyxy = {
            self.formats.XYWH: self._xywh_to_xyxy,
            self.formats.CXCYWH: self._cxcywh_to_xyxy,
        }
        self._convert_from_xyxy = {
            self.formats.XYWH: self._xyxy_to_xywh,
            self.formats.CXCYWH: self._xyxy_to_cxcywh,
        }

    def __new__(
        cls,
        data: Any = None,
        *,
        image_size: Tuple[int, int],
        format: Union[str, BoundingBoxFormat],
    ):
        # Since torch.Tensor defines both __new__ and __init__, we also need to do that since we change the signature
        return super().__new__(cls, data)

    @classmethod
    def from_tensor(
        cls,
        tensor: torch.Tensor,
        *,
        like: Optional["BoundingBox"] = None,
        image_size: Optional[Tuple[int, int]] = None,
        format: Optional[Union[str, BoundingBoxFormat]] = None,
    ) -> "BoundingBox":
        """Wrap ``tensor`` as a BoundingBox, taking missing meta data from ``like``."""
        params = cls._parse_from_tensor_args(like=like, image_size=image_size, format=format)

        # Default to the canonical XYXY layout when no format is known.
        format = params.get("format") or "xyxy"

        image_size = params.get("image_size")
        if image_size is None:
            # TODO: compute minimum image size needed to hold this bounding box depending on format
            image_size = (0, 0)

        return cls(tensor, image_size=image_size, format=format)

    @property
    def image_size(self) -> Tuple[int, int]:
        return self._image_size

    @property
    def format(self) -> BoundingBoxFormat:
        return self._format

    @classmethod
    def from_parts(
        cls,
        a,
        b,
        c,
        d,
        *,
        format: Union[str, BoundingBoxFormat],
        like: Optional["BoundingBox"] = None,
        image_size: Optional[Tuple[int, int]] = None,
    ) -> "BoundingBox":
        """Stack four coordinate parts (broadcast against each other) into a box."""
        parts = torch.broadcast_tensors(
            *[part if isinstance(part, torch.Tensor) else torch.as_tensor(part) for part in (a, b, c, d)]
        )
        return cls.from_tensor(torch.stack(parts, dim=-1), like=like, image_size=image_size, format=format)

    def to_parts(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split the last dimension into the four coordinate parts."""
        return self.unbind(-1)

    def convert(self, format: Union[str, BoundingBoxFormat]) -> "BoundingBox":
        """Return this box converted into ``format`` (no-op if already there)."""
        format = self._parse_format(format)
        # FIXME: cloning does not preserve custom attributes such as image_size or format
        # bounding_box = self.clone()
        bounding_box = self

        if format == self.format:
            return bounding_box

        # Route through XYXY: source -> XYXY -> target.
        if self.format != self.formats.XYXY:
            bounding_box = self._convert_to_xyxy[self.format](bounding_box)

        if format != self.formats.XYXY:
            bounding_box = self._convert_from_xyxy[format](bounding_box)

        return bounding_box

    @staticmethod
    def _xywh_to_xyxy(input: "BoundingBox") -> "BoundingBox":
        x, y, w, h = input.to_parts()

        x1 = x
        y1 = y
        x2 = x + w
        y2 = y + h

        return BoundingBox.from_parts(x1, y1, x2, y2, like=input, format="xyxy")

    @staticmethod
    def _xyxy_to_xywh(input: "BoundingBox") -> "BoundingBox":
        x1, y1, x2, y2 = input.to_parts()

        x = x1
        y = y1
        w = x2 - x1
        h = y2 - y1

        # BUGFIX: ``like=input`` was missing here, silently dropping the
        # source box's image_size — the other three converters all pass it.
        return BoundingBox.from_parts(x, y, w, h, like=input, format="xywh")

    @staticmethod
    def _cxcywh_to_xyxy(input: "BoundingBox") -> "BoundingBox":
        cx, cy, w, h = input.to_parts()

        x1 = cx - 0.5 * w
        y1 = cy - 0.5 * h
        x2 = cx + 0.5 * w
        y2 = cy + 0.5 * h

        return BoundingBox.from_parts(x1, y1, x2, y2, like=input, format="xyxy")

    @staticmethod
    def _xyxy_to_cxcywh(input: "BoundingBox") -> "BoundingBox":
        x1, y1, x2, y2 = input.to_parts()

        cx = (x1 + x2) / 2
        cy = (y1 + y2) / 2
        w = x2 - x1
        h = y2 - y1

        return BoundingBox.from_parts(cx, cy, w, h, like=input, format="cxcywh")
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from typing import Any, Dict, Optional, Type, TypeVar | ||
|
||
import torch | ||
|
||
__all__ = ["Feature", "TensorFeature"] | ||
|
||
TF = TypeVar("TF", bound="TensorFeature") | ||
|
||
|
||
# A Feature might not necessarily be a Tensor. Think text. | ||
class Feature: | ||
pass | ||
|
||
|
||
class TensorFeature(torch.Tensor, Feature): | ||
def __new__(cls, data: Any = None): | ||
if data is None: | ||
data = torch.tensor([]) | ||
requires_grad = False | ||
return torch.Tensor._make_subclass(cls, data, requires_grad) | ||
|
||
@staticmethod | ||
def _parse_from_tensor_args(*, like: Optional["TensorFeature"], **attrs: Any) -> Dict[str, Any]: | ||
if not attrs: | ||
raise ValueError() | ||
|
||
params = {name: getattr(like, name) for name in attrs.keys()} if like is not None else {} | ||
params.update({name: value for name, value in attrs.items() if value is not None}) | ||
return params | ||
|
||
@classmethod | ||
def from_tensor(cls: Type[TF], tensor: torch.Tensor, *, like: Optional[TF] = None) -> TF: | ||
return cls(tensor) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from typing import Tuple | ||
|
||
from ._core import TensorFeature | ||
|
||
__all__ = ["Image", "Segmentation"] | ||
|
||
|
||
class Image(TensorFeature):
    """Tensor feature holding image data.

    The last two dimensions are treated as (height, width); a 4d tensor is
    interpreted as a batch of images with the batch dimension leading.
    """

    @property
    def image_size(self) -> Tuple[int, int]:
        # Spatial size, i.e. (height, width).
        return self.shape[-2:]

    @property
    def batch_size(self) -> int:
        # 0 signals an unbatched image; otherwise the leading dim of a 4d tensor.
        return self.shape[0] if self.ndim == 4 else 0

    def batch(self) -> "Image":
        """Return a batched (4d) version of this image; no-op if already batched."""
        if self.batch_size > 0:
            return self
        return Image.from_tensor(self.unsqueeze(0), like=self)

    def unbatch(self) -> "Image":
        """Drop a singleton batch dimension; no-op if already unbatched.

        Raises:
            RuntimeError: if the batch holds more than one image.
        """
        if self.batch_size == 0:
            return self
        if self.batch_size == 1:
            return Image.from_tensor(self.squeeze(0), like=self)
        raise RuntimeError("Cannot unbatch an image tensor if batch contains more than one image.")
|
||
|
||
class Segmentation(Image):
    """Segmentation mask; shares all behavior with the Image feature."""
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from ._transform import * | ||
|
||
from ._augmentation import * | ||
from ._container import * | ||
from ._geometry import * | ||
from ._misc import * |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For when we will be discussing specifics of the data classes: it might be good to add a
`__torch_function__`
that forbids applying operations to two bounding boxes if they have different `image_size`,
while allowing `BoundingBox` + `Tensor` to work.