Skip to content

add prototype features #4721

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Nov 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/builtin_dataset_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torchdata.datapipes.iter import IterDataPipe
from torchvision.prototype import datasets
from torchvision.prototype.datasets._api import find
from torchvision.prototype.datasets.utils._internal import add_suggestion
from torchvision.prototype.utils._internal import add_suggestion

make_tensor = functools.partial(_make_tensor, device="cpu")

Expand Down
2 changes: 1 addition & 1 deletion test/test_prototype_builtin_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
from torchdata.datapipes.iter import IterDataPipe
from torchvision.prototype import datasets
from torchvision.prototype.datasets.utils._internal import sequence_to_str
from torchvision.prototype.utils._internal import sequence_to_str


_loaders = []
Expand Down
179 changes: 179 additions & 0 deletions test/test_prototype_features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import functools
import itertools

import pytest
import torch
from torch.testing import make_tensor as _make_tensor, assert_close
from torchvision.prototype import features
from torchvision.prototype.utils._internal import sequence_to_str


make_tensor = functools.partial(_make_tensor, device="cpu", dtype=torch.float32)


def make_bounding_box(*, format="xyxy", image_size=(10, 10)):
if isinstance(format, str):
format = features.BoundingBoxFormat[format]

height, width = image_size

if format == features.BoundingBoxFormat.XYXY:
x1 = torch.randint(0, width // 2, ())
y1 = torch.randint(0, height // 2, ())
x2 = torch.randint(int(x1) + 1, width - int(x1), ()) + x1
y2 = torch.randint(int(y1) + 1, height - int(y1), ()) + y1
parts = (x1, y1, x2, y2)
elif format == features.BoundingBoxFormat.XYWH:
x = torch.randint(0, width // 2, ())
y = torch.randint(0, height // 2, ())
w = torch.randint(1, width - int(x), ())
h = torch.randint(1, height - int(y), ())
parts = (x, y, w, h)
elif format == features.BoundingBoxFormat.CXCYWH:
cx = torch.randint(1, width - 1, ())
cy = torch.randint(1, height - 1, ())
w = torch.randint(1, min(int(cx), width - int(cx)), ())
h = torch.randint(1, min(int(cy), height - int(cy)), ())
parts = (cx, cy, w, h)
else: # format == features.BoundingBoxFormat._SENTINEL:
parts = make_tensor((4,)).unbind()

return features.BoundingBox.from_parts(*parts, format=format, image_size=image_size)


MAKE_DATA_MAP = {
features.BoundingBox: make_bounding_box,
}


def make_feature(feature_type, **meta_data):
maker = MAKE_DATA_MAP.get(feature_type, lambda **meta_data: feature_type(make_tensor(()), **meta_data))
return maker(**meta_data)


class TestCommon:
FEATURE_TYPES, NON_DEFAULT_META_DATA = zip(
*(
(features.Image, dict(color_space=features.ColorSpace._SENTINEL)),
(features.Label, dict(category="category")),
(features.BoundingBox, dict(format=features.BoundingBoxFormat._SENTINEL, image_size=(-1, -1))),
)
)
feature_types = pytest.mark.parametrize(
"feature_type", FEATURE_TYPES, ids=lambda feature_type: feature_type.__name__
)
features = pytest.mark.parametrize(
"feature",
[
pytest.param(make_feature(feature_type, **meta_data), id=feature_type.__name__)
for feature_type, meta_data in zip(FEATURE_TYPES, NON_DEFAULT_META_DATA)
],
)

def test_consistency(self):
builtin_feature_types = {
name
for name, feature_type in features.__dict__.items()
if not name.startswith("_")
and isinstance(feature_type, type)
and issubclass(feature_type, features.Feature)
and feature_type is not features.Feature
}
untested_feature_types = builtin_feature_types - {feature_type.__name__ for feature_type in self.FEATURE_TYPES}
if untested_feature_types:
raise AssertionError(
f"The feature(s) {sequence_to_str(sorted(untested_feature_types), separate_last='and ')} "
f"is/are exposed at `torchvision.prototype.features`, but is/are not tested by `TestCommon`. "
f"Please add it/them to `TestCommon.FEATURE_TYPES`."
)

@features
def test_meta_data_attribute_access(self, feature):
for name, value in feature._meta_data.items():
assert getattr(feature, name) == feature._meta_data[name]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test for serialization of the feature? Like torch.save / torch.load, to make sure that the extra metadata added is not lost on the way?

@feature_types
def test_torch_function(self, feature_type):
input = make_feature(feature_type)
# This can be any Tensor operation besides clone
output = input + 1

assert type(output) is torch.Tensor
assert_close(output, input + 1)

@feature_types
def test_clone(self, feature_type):
input = make_feature(feature_type)
output = input.clone()

assert type(output) is feature_type
assert_close(output, input)
assert output._meta_data == input._meta_data

@features
def test_serialization(self, tmpdir, feature):
file = tmpdir / "test_serialization.pt"

torch.save(feature, str(file))
loaded_feature = torch.load(str(file))

assert isinstance(loaded_feature, type(feature))
assert_close(loaded_feature, feature)
assert loaded_feature._meta_data == feature._meta_data

@features
def test_repr(self, feature):
assert type(feature).__name__ in repr(feature)


class TestBoundingBox:
@pytest.mark.parametrize(("format", "intermediate_format"), itertools.permutations(("xyxy", "xywh"), 2))
def test_cycle_consistency(self, format, intermediate_format):
input = make_bounding_box(format=format)
output = input.convert(intermediate_format).convert(format)
assert_close(input, output)


# For now, tensor subclasses with additional meta data do not work with torchscript.
# See https://github.com/pytorch/vision/pull/4721#discussion_r741676037.
@pytest.mark.xfail
class TestJit:
def test_bounding_box(self):
def resize(input: features.BoundingBox, size: torch.Tensor) -> features.BoundingBox:
old_height, old_width = input.image_size
new_height, new_width = size

height_scale = new_height / old_height
width_scale = new_width / old_width

old_x1, old_y1, old_x2, old_y2 = input.convert("xyxy").to_parts()

new_x1 = old_x1 * width_scale
new_y1 = old_y1 * height_scale

new_x2 = old_x2 * width_scale
new_y2 = old_y2 * height_scale

return features.BoundingBox.from_parts(
new_x1, new_y1, new_x2, new_y2, like=input, format="xyxy", image_size=tuple(size.tolist())
)

def horizontal_flip(input: features.BoundingBox) -> features.BoundingBox:
x, y, w, h = input.convert("xywh").to_parts()
x = input.image_size[1] - (x + w)
return features.BoundingBox.from_parts(x, y, w, h, like=input, format="xywh")

def compose(input: features.BoundingBox, size: torch.Tensor) -> features.BoundingBox:
return horizontal_flip(resize(input, size)).convert("xyxy")

image_size = (8, 6)
input = features.BoundingBox([2, 4, 2, 4], format="cxcywh", image_size=image_size)
size = torch.tensor((4, 12))
expected = features.BoundingBox([6, 1, 10, 3], format="xyxy", image_size=image_size)

actual_eager = compose(input, size)
assert_close(actual_eager, expected)

sample_inputs = (features.BoundingBox(torch.zeros((4,)), image_size=(10, 10)), torch.tensor((20, 5)))
actual_jit = torch.jit.trace(compose, sample_inputs, check_trace=False)(input, size)
assert_close(actual_jit, expected)
2 changes: 2 additions & 0 deletions torchvision/prototype/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from . import datasets
from . import features
from . import models
from . import transforms
from . import utils
2 changes: 1 addition & 1 deletion torchvision/prototype/datasets/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torchvision.prototype.datasets import home
from torchvision.prototype.datasets.decoder import raw, pil
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetType
from torchvision.prototype.datasets.utils._internal import add_suggestion
from torchvision.prototype.utils._internal import add_suggestion

from . import _builtin

Expand Down
8 changes: 4 additions & 4 deletions torchvision/prototype/datasets/_builtin/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
Decompressor,
INFINITE_BUFFER_SIZE,
)
from torchvision.prototype.features import Image, Label


__all__ = ["MNIST", "FashionMNIST", "KMNIST", "EMNIST", "QMNIST"]
Expand Down Expand Up @@ -126,15 +127,14 @@ def _collate_and_decode(
image, label = data

if decoder is raw:
image = image.unsqueeze(0)
image = Image(image)
else:
image_buffer = image_buffer_from_array(image.numpy())
image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment]

category = self.info.categories[int(label)]
label = label.to(torch.int64)
label = Label(label, dtype=torch.int64, category=self.info.categories[int(label)])

return dict(image=image, category=category, label=label)
return dict(image=image, label=label)

def _make_datapipe(
self,
Expand Down
6 changes: 3 additions & 3 deletions torchvision/prototype/datasets/decoder.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import io
from typing import cast

import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.transforms.functional import pil_to_tensor

__all__ = ["raw", "pil"]
Expand All @@ -12,5 +12,5 @@ def raw(buffer: io.IOBase) -> torch.Tensor:
raise RuntimeError("This is just a sentinel and should never be called.")


def pil(buffer: io.IOBase, mode: str = "RGB") -> torch.Tensor:
return cast(torch.Tensor, pil_to_tensor(PIL.Image.open(buffer).convert(mode.upper())))
def pil(buffer: io.IOBase) -> features.Image:
return features.Image(pil_to_tensor(PIL.Image.open(buffer)))
5 changes: 1 addition & 4 deletions torchvision/prototype/datasets/utils/_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@

import torch
from torch.utils.data import IterDataPipe
from torchvision.prototype.datasets.utils._internal import (
add_suggestion,
sequence_to_str,
)
from torchvision.prototype.utils._internal import add_suggestion, sequence_to_str

from .._home import use_sharded_dataset
from ._internal import FrozenBunch, make_repr, BUILTIN_DIR, _make_sharded_datapipe
Expand Down
32 changes: 0 additions & 32 deletions torchvision/prototype/datasets/utils/_internal.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import collections.abc
import csv
import difflib
import enum
import gzip
import io
Expand All @@ -11,7 +9,6 @@
import pickle
import textwrap
from typing import (
Collection,
Sequence,
Callable,
Union,
Expand Down Expand Up @@ -41,8 +38,6 @@
__all__ = [
"INFINITE_BUFFER_SIZE",
"BUILTIN_DIR",
"sequence_to_str",
"add_suggestion",
"make_repr",
"FrozenMapping",
"FrozenBunch",
Expand All @@ -67,33 +62,6 @@
BUILTIN_DIR = pathlib.Path(__file__).parent.parent / "_builtin"


def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
if len(seq) == 1:
return f"'{seq[0]}'"

return f"""'{"', '".join([str(item) for item in seq[:-1]])}', {separate_last}'{seq[-1]}'."""


def add_suggestion(
msg: str,
*,
word: str,
possibilities: Collection[str],
close_match_hint: Callable[[str], str] = lambda close_match: f"Did you mean '{close_match}'?",
alternative_hint: Callable[
[Sequence[str]], str
] = lambda possibilities: f"Can be {sequence_to_str(possibilities, separate_last='or ')}.",
) -> str:
if not isinstance(possibilities, collections.abc.Sequence):
possibilities = sorted(possibilities)
suggestions = difflib.get_close_matches(word, possibilities, 1)
hint = close_match_hint(suggestions[0]) if suggestions else alternative_hint(possibilities)
if not hint:
return msg

return f"{msg.strip()} {hint}"


def make_repr(name: str, items: Iterable[Tuple[str, Any]]) -> str:
def to_str(sep: str) -> str:
return sep.join([f"{key}={value}" for key, value in items])
Expand Down
4 changes: 4 additions & 0 deletions torchvision/prototype/features/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from ._bounding_box import BoundingBoxFormat, BoundingBox
from ._feature import Feature
from ._image import Image, ColorSpace
from ._label import Label
Loading