From 0cf10a9ae27b32872939efd8069e07ae51dc45d1 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 22 Oct 2021 11:38:45 +0200 Subject: [PATCH 01/10] add prototype features --- torchvision/prototype/__init__.py | 1 + .../prototype/datasets/_builtin/mnist.py | 12 +- torchvision/prototype/features.py | 104 ++++++++++++++++++ 3 files changed, 111 insertions(+), 6 deletions(-) create mode 100644 torchvision/prototype/features.py diff --git a/torchvision/prototype/__init__.py b/torchvision/prototype/__init__.py index 23a2912b00b..d4aa1883e78 100644 --- a/torchvision/prototype/__init__.py +++ b/torchvision/prototype/__init__.py @@ -1,3 +1,4 @@ from . import datasets +from . import features from . import models from . import transforms diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py index af22199ce39..d2aaa2968a1 100644 --- a/torchvision/prototype/datasets/_builtin/mnist.py +++ b/torchvision/prototype/datasets/_builtin/mnist.py @@ -31,6 +31,7 @@ Decompressor, INFINITE_BUFFER_SIZE, ) +from torchvision.prototype.features import Image, Label __all__ = ["MNIST", "FashionMNIST", "KMNIST", "EMNIST", "QMNIST"] @@ -112,17 +113,16 @@ def _collate_and_decode( ) -> Dict[str, Any]: image_array, label_array = data - image: Union[torch.Tensor, io.BytesIO] + image: Union[Image, io.IOBase] if decoder is raw: - image = torch.from_numpy(image_array) + image = Image(torch.from_numpy(image_array)) else: image_buffer = image_buffer_from_array(image_array) - image = decoder(image_buffer) if decoder else image_buffer + image = Image(decoder(image_buffer)) if decoder else image_buffer - label = torch.tensor(label_array, dtype=torch.int64) - category = self.info.categories[int(label)] + label = Label(torch.tensor(label_array, dtype=torch.int64), category=self.info.categories[int(label_array)]) - return dict(image=image, label=label, category=category) + return dict(image=image, label=label) def _make_datapipe( self, diff --git a/torchvision/prototype/features.py b/torchvision/prototype/features.py new file mode 100644 index 00000000000..0d9d51b6c24 --- /dev/null +++ b/torchvision/prototype/features.py @@ -0,0 +1,104 @@ +import enum +from typing import Any, Sequence, Mapping, Optional, Tuple, Type, Callable, TypeVar, Union, cast + +import torch +from torch._C import DisableTorchFunction, _TensorBase +from torchvision.prototype.datasets.utils._internal import FrozenMapping + +__all__ = ["Feature", "ColorSpace", "Image", "Label"] + +T = TypeVar("T", bound="Feature") + + +class Feature(torch.Tensor): + _meta_data: FrozenMapping[str, Any] + + def __new__( + cls: Type[T], + data: Any = None, + *, + like: Optional[T] = None, + meta_data: FrozenMapping[str, Any] = FrozenMapping(), + ) -> T: + if data is None: + data = torch.empty(0) + requires_grad = False + self = cast(T, torch.Tensor._make_subclass(cast(_TensorBase, cls), data, requires_grad)) + + _meta_data = dict(like._meta_data) if like is not None else dict() + _meta_data.update(meta_data) + self._meta_data = FrozenMapping(_meta_data) + + for name in self._meta_data: + setattr(cls, name, property(lambda self: self._meta_data[name])) + + return self + + @classmethod + def __torch_function__( + cls, + func: Callable[..., torch.Tensor], + types: Tuple[Type[torch.Tensor], ...], + args: Sequence[Any] = (), + kwargs: Optional[Mapping[str, Any]] = None, + ) -> torch.Tensor: + with DisableTorchFunction(): + output = func(*args, **(kwargs or dict())) + if func is not torch.Tensor.clone: + return output + + return cls(output, like=args[0]) + + +class ColorSpace(enum.Enum): + OTHER = 0 + GRAYSCALE = 1 + RGB = 3 + + +class Image(Feature): + color_space: ColorSpace + + def __new__( + cls, + data: Any = None, + *, + like: Optional["Image"] = None, + color_space: Optional[Union[str, ColorSpace]] = None, + ) -> "Image": + if color_space is None: + color_space = cls.guess_color_space(data) if data is not None else ColorSpace.OTHER + elif isinstance(color_space, str): + color_space = ColorSpace[color_space.upper()] + + meta_data: FrozenMapping[str, Any] = FrozenMapping(color_space=color_space) + + return Feature.__new__(cls, data, like=like, meta_data=meta_data) + + @staticmethod + def guess_color_space(image: torch.Tensor) -> ColorSpace: + if image.ndim < 2: + return ColorSpace.OTHER + elif image.ndim == 2: + return ColorSpace.GRAYSCALE + + num_channels = image.shape[-3] + if num_channels == 1: + return ColorSpace.GRAYSCALE + elif num_channels == 3: + return ColorSpace.RGB + else: + return ColorSpace.OTHER + + +class Label(Feature): + category: Optional[str] + + def __new__( + cls, + data: Any = None, + *, + like: Optional["Label"] = None, + category: Optional[str] = None, + ) -> "Label": + return Feature.__new__(cls, data, like=like, meta_data=FrozenMapping(category=category)) From f421c071389eaaff02fe0b02b17e355f761f1112 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 25 Oct 2021 08:06:11 +0200 Subject: [PATCH 02/10] add some JIT tests --- test/test_prototype_features.py | 63 +++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 test/test_prototype_features.py diff --git a/test/test_prototype_features.py b/test/test_prototype_features.py new file mode 100644 index 00000000000..d58dfb1ad7a --- /dev/null +++ b/test/test_prototype_features.py @@ -0,0 +1,63 @@ +import functools + +import pytest +import torch +from torch import jit +from torch.testing import make_tensor as _make_tensor, assert_close +from torchvision.prototype import features + +make_tensor = functools.partial(_make_tensor, device="cpu", dtype=torch.float32) + + +class TestJIT: + FEATURE_TYPES = { + feature_type + for name, feature_type in features.__dict__.items() + if not name.startswith("_") + and isinstance(feature_type, type) + and issubclass(feature_type, features.Feature) + and feature_type is not features.Feature + } + feature_types = pytest.mark.parametrize( + "feature_type", FEATURE_TYPES, ids=lambda feature_type: feature_type.__name__ + ) + + @feature_types + def test_identity(self, feature_type): + def identity(input): + return input + + identity.__annotations__ = {"input": feature_type, "return": feature_type} + + scripted_fn = jit.script(identity) + input = feature_type(make_tensor(())) + output = scripted_fn(input) + + assert output is input + + @feature_types + def test_torch_function(self, feature_type): + def any_operation_except_clone(input) -> torch.Tensor: + return input + 0 + + any_operation_except_clone.__annotations__["input"] = feature_type + + scripted_fn = jit.script(any_operation_except_clone) + input = feature_type(make_tensor(())) + output = scripted_fn(input) + + assert type(output) is torch.Tensor + + @feature_types + def test_clone(self, feature_type): + def clone(input): + return input.clone() + + clone.__annotations__ = {"input": feature_type, "return": feature_type} + + scripted_fn = jit.script(clone) + input = feature_type(make_tensor(())) + output = scripted_fn(input) + + assert type(output) is feature_type + assert_close(output, input) From a96667fc013d4623ce7ec27adf94d592d251e585 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 25 Oct 2021 08:35:28 +0200 Subject: [PATCH 03/10] refactor input data handling --- .../prototype/datasets/_builtin/mnist.py | 6 ++--- torchvision/prototype/datasets/decoder.py | 6 ++--- torchvision/prototype/features.py | 24 +++++++++++++------ 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py index d2aaa2968a1..75f9e935040 100644 --- a/torchvision/prototype/datasets/_builtin/mnist.py +++ b/torchvision/prototype/datasets/_builtin/mnist.py @@ -115,12 +115,12 @@ def _collate_and_decode( image: Union[Image, io.IOBase] if decoder is raw: - image = Image(torch.from_numpy(image_array)) + image = Image(image_array) else: image_buffer = image_buffer_from_array(image_array) - image = Image(decoder(image_buffer)) if decoder else image_buffer + image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment] - label = Label(torch.tensor(label_array, dtype=torch.int64), category=self.info.categories[int(label_array)]) + label = Label(label_array, dtype=torch.int64, category=self.info.categories[int(label_array)]) return dict(image=image, label=label) diff --git a/torchvision/prototype/datasets/decoder.py b/torchvision/prototype/datasets/decoder.py index 3c8fb824223..8e12df6f568 100644 --- a/torchvision/prototype/datasets/decoder.py +++ b/torchvision/prototype/datasets/decoder.py @@ -1,8 +1,8 @@ import io -from typing import cast import PIL.Image import torch +from torchvision.prototype import features from torchvision.transforms.functional import pil_to_tensor __all__ = ["raw", "pil"] @@ -12,5 +12,5 @@ def raw(buffer: io.IOBase) -> torch.Tensor: raise RuntimeError("This is just a sentinel and should never be called.") -def pil(buffer: io.IOBase, mode: str = "RGB") -> torch.Tensor: - return cast(torch.Tensor, pil_to_tensor(PIL.Image.open(buffer).convert(mode.upper()))) +def pil(buffer: io.IOBase, mode: str = "RGB") -> features.Image: + return features.Image(pil_to_tensor(PIL.Image.open(buffer).convert(mode.upper()))) diff --git a/torchvision/prototype/features.py b/torchvision/prototype/features.py index 0d9d51b6c24..3794d05bd2c 100644 --- a/torchvision/prototype/features.py +++ b/torchvision/prototype/features.py @@ -15,13 +15,17 @@ class Feature(torch.Tensor): def __new__( cls: Type[T], - data: Any = None, + data: Any, *, like: Optional[T] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, meta_data: FrozenMapping[str, Any] = FrozenMapping(), ) -> T: - if data is None: - data = torch.empty(0) + if like is not None: + dtype = dtype or like.dtype + device = device or like.device + data = torch.as_tensor(data, dtype=dtype, device=device) requires_grad = False self = cast(T, torch.Tensor._make_subclass(cast(_TensorBase, cls), data, requires_grad)) @@ -61,9 +65,11 @@ class Image(Feature): def __new__( cls, - data: Any = None, + data: Any, *, like: Optional["Image"] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, color_space: Optional[Union[str, ColorSpace]] = None, ) -> "Image": if color_space is None: @@ -73,7 +79,7 @@ def __new__( meta_data: FrozenMapping[str, Any] = FrozenMapping(color_space=color_space) - return Feature.__new__(cls, data, like=like, meta_data=meta_data) + return Feature.__new__(cls, data, like=like, dtype=dtype, device=device, meta_data=meta_data) @staticmethod def guess_color_space(image: torch.Tensor) -> ColorSpace: @@ -96,9 +102,13 @@ class Label(Feature): def __new__( cls, - data: Any = None, + data: Any, *, like: Optional["Label"] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, category: Optional[str] = None, ) -> "Label": - return Feature.__new__(cls, data, like=like, meta_data=FrozenMapping(category=category)) + return Feature.__new__( + cls, data, like=like, dtype=dtype, device=device, meta_data=FrozenMapping(category=category) + ) From 565b9ba9fc16254af94fe504f0a55b1375f0676b Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 25 Oct 2021 13:37:06 +0200 Subject: [PATCH 04/10] refactor tests --- test/test_prototype_features.py | 101 ++++++++++++++++++++---------- torchvision/prototype/features.py | 5 ++ 2 files changed, 73 insertions(+), 33 deletions(-) diff --git a/test/test_prototype_features.py b/test/test_prototype_features.py index d58dfb1ad7a..2204e283118 100644 --- a/test/test_prototype_features.py +++ b/test/test_prototype_features.py @@ -2,62 +2,97 @@ import pytest import torch -from torch import jit from torch.testing import make_tensor as _make_tensor, assert_close from torchvision.prototype import features +from torchvision.prototype.datasets.utils._internal import FrozenMapping make_tensor = functools.partial(_make_tensor, device="cpu", dtype=torch.float32) -class TestJIT: - FEATURE_TYPES = { - feature_type - for name, feature_type in features.__dict__.items() - if not name.startswith("_") - and isinstance(feature_type, type) - and issubclass(feature_type, features.Feature) - and feature_type is not features.Feature - } +class TestCommon: + FEATURE_TYPES, NON_DEFAULT_META_DATA = zip( + *( + (features.Image, FrozenMapping(color_space=features.ColorSpace._SENTINEL)), + (features.Label, FrozenMapping(category="category")), + ) + ) feature_types = pytest.mark.parametrize( "feature_type", FEATURE_TYPES, ids=lambda feature_type: feature_type.__name__ ) + feature_types_with_non_default_meta_data = pytest.mark.parametrize( + ("feature_type", "meta_data"), + [ + pytest.param(feature_type, FrozenMapping(meta_data), id=feature_type.__name__) + for feature_type, meta_data in zip(FEATURE_TYPES, NON_DEFAULT_META_DATA) + ], + ) - @feature_types - def test_identity(self, feature_type): - def identity(input): - return input + def test_consistency(self): + builtin_feature_types = { + feature_type + for name, feature_type in features.__dict__.items() + if not name.startswith("_") + and isinstance(feature_type, type) + and issubclass(feature_type, features.Feature) + and feature_type is not features.Feature + } + untested_feature_types = builtin_feature_types - set(self.FEATURE_TYPES) + if untested_feature_types: + raise AssertionError("FIXME") + + jit_fns = pytest.mark.parametrize( + "jit_fn", + [ + pytest.param(lambda fn, example_inputs: fn, id="no_jit"), + pytest.param(lambda fn, example_inputs: torch.jit.trace(fn, example_inputs), id="torch.jit.trace"), + pytest.param(lambda fn, example_inputs: torch.jit.script(fn), id="torch.jit.script"), + ], + ) - identity.__annotations__ = {"input": feature_type, "return": feature_type} + @jit_fns + @feature_types + def test_torch_function(self, jit_fn, feature_type): + def fn(input): + return input + 1 - scripted_fn = jit.script(identity) input = feature_type(make_tensor(())) - output = scripted_fn(input) - - assert output is input - - @feature_types - def test_torch_function(self, feature_type): - def any_operation_except_clone(input) -> torch.Tensor: - return input + 0 - any_operation_except_clone.__annotations__["input"] = feature_type + fn.__annotations__ = {"input": feature_type, "return": torch.Tensor} + fn = jit_fn(fn, input) - scripted_fn = jit.script(any_operation_except_clone) - input = feature_type(make_tensor(())) - output = scripted_fn(input) + output = fn(input) assert type(output) is torch.Tensor + assert_close(output, input + 1) + @jit_fns @feature_types - def test_clone(self, feature_type): - def clone(input): + def test_clone(self, jit_fn, feature_type): + def fn(input): return input.clone() - clone.__annotations__ = {"input": feature_type, "return": feature_type} + fn.__annotations__ = {"input": feature_type, "return": feature_type} + fn = jit_fn(fn) - scripted_fn = jit.script(clone) input = feature_type(make_tensor(())) - output = scripted_fn(input) + output = fn(input) assert type(output) is feature_type assert_close(output, input) + assert output._meta_data == input._meta_data + + @feature_types_with_non_default_meta_data + def test_serialization(self, tmpdir, feature_type, meta_data): + feature = feature_type(make_tensor(()), **meta_data) + file = tmpdir / "test_serialization.pt" + + torch.save(feature, str(file)) + loaded_feature = torch.load(str(file)) + + assert isinstance(loaded_feature, feature_type) + assert_close(loaded_feature, feature) + assert loaded_feature._meta_data == meta_data + + @feature_types + def test_repr(self, feature_type): + assert feature_type.__name__ in repr(feature_type(make_tensor(()))) diff --git a/torchvision/prototype/features.py b/torchvision/prototype/features.py index 3794d05bd2c..e7455770248 100644 --- a/torchvision/prototype/features.py +++ b/torchvision/prototype/features.py @@ -53,8 +53,13 @@ def __torch_function__( return cls(output, like=args[0]) + def __repr__(self) -> str: + return super().__repr__().replace("tensor", type(self).__name__) + class ColorSpace(enum.Enum): + # this is just for test purposes + _SENTINEL = -1 OTHER = 0 GRAYSCALE = 1 RGB = 3 From f2ac961525ab0d56f3e5d372fe6e76c2df546c2d Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 25 Oct 2021 13:58:04 +0200 Subject: [PATCH 05/10] cleanup tests --- test/test_prototype_features.py | 55 +++++++++++++++++++-------------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/test/test_prototype_features.py b/test/test_prototype_features.py index 2204e283118..5a1c7e09e0b 100644 --- a/test/test_prototype_features.py +++ b/test/test_prototype_features.py @@ -19,14 +19,29 @@ class TestCommon: feature_types = pytest.mark.parametrize( "feature_type", FEATURE_TYPES, ids=lambda feature_type: feature_type.__name__ ) - feature_types_with_non_default_meta_data = pytest.mark.parametrize( - ("feature_type", "meta_data"), + features = pytest.mark.parametrize( + "feature", [ - pytest.param(feature_type, FrozenMapping(meta_data), id=feature_type.__name__) + pytest.param(feature_type(make_tensor(()), **meta_data), id=feature_type.__name__) for feature_type, meta_data in zip(FEATURE_TYPES, NON_DEFAULT_META_DATA) ], ) + jit_fns = pytest.mark.parametrize( + "jit_fn", + [ + pytest.param(lambda fn, example_inputs: fn, id="no_jit"), + pytest.param( + lambda fn, example_inputs: torch.jit.trace(fn, example_inputs, check_trace=False), id="torch.jit.trace" + ), + pytest.param(lambda fn, example_inputs: torch.jit.script(fn), id="torch.jit.script"), + ], + ) + + @pytest.fixture + def data(self): + return make_tensor(()) + def test_consistency(self): builtin_feature_types = { feature_type @@ -40,14 +55,10 @@ def test_consistency(self): if untested_feature_types: raise AssertionError("FIXME") - jit_fns = pytest.mark.parametrize( - "jit_fn", - [ - pytest.param(lambda fn, example_inputs: fn, id="no_jit"), - pytest.param(lambda fn, example_inputs: torch.jit.trace(fn, example_inputs), id="torch.jit.trace"), - pytest.param(lambda fn, example_inputs: torch.jit.script(fn), id="torch.jit.script"), - ], - ) + @features + def test_meta_data_attribute_access(self, feature): + for name, value in feature._meta_data.items(): + assert getattr(feature, name) == feature._meta_data[name] @jit_fns @feature_types @@ -55,11 +66,10 @@ def test_torch_function(self, jit_fn, feature_type): def fn(input): return input + 1 + fn.__annotations__ = {"input": feature_type, "return": torch.Tensor} input = feature_type(make_tensor(())) - fn.__annotations__ = {"input": feature_type, "return": torch.Tensor} fn = jit_fn(fn, input) - output = fn(input) assert type(output) is torch.Tensor @@ -72,27 +82,26 @@ def fn(input): return input.clone() fn.__annotations__ = {"input": feature_type, "return": feature_type} - fn = jit_fn(fn) - input = feature_type(make_tensor(())) + + fn = jit_fn(fn, input) output = fn(input) assert type(output) is feature_type assert_close(output, input) assert output._meta_data == input._meta_data - @feature_types_with_non_default_meta_data - def test_serialization(self, tmpdir, feature_type, meta_data): - feature = feature_type(make_tensor(()), **meta_data) + @features + def test_serialization(self, tmpdir, feature): file = tmpdir / "test_serialization.pt" torch.save(feature, str(file)) loaded_feature = torch.load(str(file)) - assert isinstance(loaded_feature, feature_type) + assert isinstance(loaded_feature, type(feature)) assert_close(loaded_feature, feature) - assert loaded_feature._meta_data == meta_data + assert loaded_feature._meta_data == feature._meta_data - @feature_types - def test_repr(self, feature_type): - assert feature_type.__name__ in repr(feature_type(make_tensor(()))) + @features + def test_repr(self, feature): + assert type(feature).__name__ in repr(feature) From 1d2ee59289037176fb60c7bd32894c1fbfb6c15b Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 3 Nov 2021 08:39:58 +0100 Subject: [PATCH 06/10] add BoundingBox feature --- test/test_prototype_features.py | 133 ++++++++++++---- torchvision/prototype/__init__.py | 1 + torchvision/prototype/datasets/_api.py | 2 +- .../prototype/datasets/utils/_dataset.py | 5 +- .../prototype/datasets/utils/_internal.py | 29 ---- torchvision/prototype/features.py | 119 -------------- torchvision/prototype/features/__init__.py | 4 + .../prototype/features/_bounding_box.py | 149 ++++++++++++++++++ torchvision/prototype/features/_feature.py | 85 ++++++++++ torchvision/prototype/features/_image.py | 41 +++++ torchvision/prototype/features/_label.py | 15 ++ torchvision/prototype/utils/__init__.py | 1 + torchvision/prototype/utils/_internal.py | 39 +++++ 13 files changed, 435 insertions(+), 188 deletions(-) delete mode 100644 torchvision/prototype/features.py create mode 100644 torchvision/prototype/features/__init__.py create mode 100644 torchvision/prototype/features/_bounding_box.py create mode 100644 torchvision/prototype/features/_feature.py create mode 100644 torchvision/prototype/features/_image.py create mode 100644 torchvision/prototype/features/_label.py create mode 100644 torchvision/prototype/utils/__init__.py create mode 100644 torchvision/prototype/utils/_internal.py diff --git a/test/test_prototype_features.py b/test/test_prototype_features.py index 5a1c7e09e0b..abd3610e061 100644 --- a/test/test_prototype_features.py +++ b/test/test_prototype_features.py @@ -1,10 +1,12 @@ import functools +import itertools import pytest import torch from torch.testing import make_tensor as _make_tensor, assert_close from torchvision.prototype import features -from torchvision.prototype.datasets.utils._internal import FrozenMapping +from torchvision.prototype.utils._internal import sequence_to_str + make_tensor = functools.partial(_make_tensor, device="cpu", dtype=torch.float32) @@ -12,8 +14,9 @@ class TestCommon: FEATURE_TYPES, NON_DEFAULT_META_DATA = zip( *( - (features.Image, FrozenMapping(color_space=features.ColorSpace._SENTINEL)), - (features.Label, FrozenMapping(category="category")), + (features.Image, dict(color_space=features.ColorSpace._SENTINEL)), + (features.Label, dict(category="category")), + (features.BoundingBox, dict(format=features.BoundingBoxFormat._SENTINEL, image_size=(-1, -1))), ) ) feature_types = pytest.mark.parametrize( @@ -27,65 +30,45 @@ class TestCommon: ], ) - jit_fns = pytest.mark.parametrize( - "jit_fn", - [ - pytest.param(lambda fn, example_inputs: fn, id="no_jit"), - pytest.param( - lambda fn, example_inputs: torch.jit.trace(fn, example_inputs, check_trace=False), id="torch.jit.trace" - ), - pytest.param(lambda fn, example_inputs: torch.jit.script(fn), id="torch.jit.script"), - ], - ) - @pytest.fixture def data(self): return make_tensor(()) def test_consistency(self): builtin_feature_types = { - feature_type + name for name, feature_type in features.__dict__.items() if not name.startswith("_") and isinstance(feature_type, type) and issubclass(feature_type, features.Feature) and feature_type is not features.Feature } - untested_feature_types = builtin_feature_types - set(self.FEATURE_TYPES) + untested_feature_types = builtin_feature_types - {feature_type.__name__ for feature_type in self.FEATURE_TYPES} if untested_feature_types: - raise AssertionError("FIXME") + raise AssertionError( + f"The feature(s) {sequence_to_str(sorted(untested_feature_types), separate_last='and ')} " + f"is/are exposed at `torchvision.prototype.features`, but is/are not tested by `TestCommon`. " + f"Please add it/them to `TestCommon.FEATURE_TYPES`." + ) @features def test_meta_data_attribute_access(self, feature): for name, value in feature._meta_data.items(): assert getattr(feature, name) == feature._meta_data[name] - @jit_fns @feature_types - def test_torch_function(self, jit_fn, feature_type): - def fn(input): - return input + 1 - - fn.__annotations__ = {"input": feature_type, "return": torch.Tensor} + def test_torch_function(self, feature_type): input = feature_type(make_tensor(())) - - fn = jit_fn(fn, input) - output = fn(input) + # This can be any Tensor operation besides clone + output = input + 1 assert type(output) is torch.Tensor assert_close(output, input + 1) - @jit_fns @feature_types - def test_clone(self, jit_fn, feature_type): - def fn(input): - return input.clone() - - fn.__annotations__ = {"input": feature_type, "return": feature_type} + def test_clone(self, feature_type): input = feature_type(make_tensor(())) - - fn = jit_fn(fn, input) - output = fn(input) + output = input.clone() assert type(output) is feature_type assert_close(output, input) @@ -105,3 +88,83 @@ def test_serialization(self, tmpdir, feature): @features def test_repr(self, feature): assert type(feature).__name__ in repr(feature) + + +class TestBoundingBox: + def make_bounding_box(self, *, format, image_size=(10, 10)): + if isinstance(format, str): + format = features.BoundingBoxFormat[format] + + height, width = image_size + + if format == features.BoundingBoxFormat.XYXY: + x1 = torch.randint(0, width // 2, ()) + y1 = torch.randint(0, height // 2, ()) + x2 = torch.randint(int(x1) + 1, width - int(x1), ()) + x1 + y2 = torch.randint(int(y1) + 1, height - int(y1), ()) + y1 + parts = (x1, y1, x2, y2) + elif format == features.BoundingBoxFormat.XYWH: + x = torch.randint(0, width // 2, ()) + y = torch.randint(0, height // 2, ()) + w = torch.randint(1, width - int(x), ()) + h = torch.randint(1, height - int(y), ()) + parts = (x, y, w, h) + elif format == features.BoundingBoxFormat.CXCYWH: + cx = torch.randint(1, width - 1, ()) + cy = torch.randint(1, height - 1, ()) + w = torch.randint(1, min(int(cx), width - int(cx)), ()) + h = torch.randint(1, min(int(cy), height - int(cy)), ()) + parts = (cx, cy, w, h) + + return features.BoundingBox.from_parts(*parts, format=format, image_size=image_size) + + @pytest.mark.parametrize(("format", "intermediate_format"), itertools.permutations(("xyxy", "xywh"), 2)) + def test_cycle_consistency(self, format, intermediate_format): + input = self.make_bounding_box(format=format) + a = input.convert(intermediate_format) + output = a.convert(format) + assert_close(input, output) + + +class TestJit: + def test_bounding_box(self): + def resize(input: features.BoundingBox, size: torch.Tensor) -> features.BoundingBox: + old_height, old_width = input.image_size + new_height, new_width = size + + height_scale = new_height / old_height + width_scale = new_width / old_width + + old_x1, old_y1, old_x2, old_y2 = input.convert("xyxy").to_parts() + + new_x1 = old_x1 * width_scale + new_y1 = old_y1 * height_scale + + new_x2 = old_x2 * width_scale + new_y2 = old_y2 * height_scale + + return features.BoundingBox.from_parts( + new_x1, new_y1, new_x2, new_y2, like=input, format="xyxy", image_size=tuple(size.tolist()) + ) + + def horizontal_flip(input: features.BoundingBox) -> features.BoundingBox: + x, y, w, h = input.convert("xywh").to_parts() + x = input.image_size[1] - (x + w) + return features.BoundingBox.from_parts(x, y, w, h, like=input, format="xywh") + + def compose(input: features.BoundingBox, size: torch.Tensor) -> features.BoundingBox: + return horizontal_flip(resize(input, size)).convert("xyxy") + + image_size = (8, 6) + input = features.BoundingBox([2, 4, 2, 4], format="cxcywh", image_size=image_size) + size = torch.tensor((4, 12)) + expected = features.BoundingBox([6, 1, 10, 3], format="xyxy", image_size=image_size) + + actual_eager = compose(input, size) + assert_close(actual_eager, expected) + + sample_inputs = (features.BoundingBox(torch.zeros((4,)), image_size=(10, 10)), torch.tensor((20, 5))) + actual_jit = torch.jit.trace(compose, sample_inputs, check_trace=False)(input, size) + print(actual_jit) + print(expected) + assert_close(actual_jit, expected) diff --git a/torchvision/prototype/__init__.py b/torchvision/prototype/__init__.py index d4aa1883e78..e1be6c81f59 100644 --- a/torchvision/prototype/__init__.py +++ b/torchvision/prototype/__init__.py @@ -2,3 +2,4 @@ from . import features from . import models from . import transforms +from . import utils diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py index 8d6796b2c32..502e2ae5b60 100644 --- a/torchvision/prototype/datasets/_api.py +++ b/torchvision/prototype/datasets/_api.py @@ -6,7 +6,7 @@ from torchvision.prototype.datasets import home from torchvision.prototype.datasets.decoder import raw, pil from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetType -from torchvision.prototype.datasets.utils._internal import add_suggestion +from torchvision.prototype.utils._internal import add_suggestion from . import _builtin diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index fa1b5e6478b..9f61c614e71 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -7,10 +7,7 @@ import torch from torch.utils.data import IterDataPipe -from torchvision.prototype.datasets.utils._internal import ( - add_suggestion, - sequence_to_str, -) +from torchvision.prototype.utils._internal import add_suggestion, sequence_to_str from ._internal import FrozenBunch, make_repr from ._resource import OnlineResource diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py index fdb6a17c46f..360cf7cdca6 100644 --- a/torchvision/prototype/datasets/utils/_internal.py +++ b/torchvision/prototype/datasets/utils/_internal.py @@ -1,6 +1,4 @@ -import collections.abc import csv -import difflib import enum import gzip import io @@ -10,7 +8,6 @@ import pathlib import textwrap from typing import ( - Collection, Sequence, Callable, Union, @@ -33,8 +30,6 @@ __all__ = [ "INFINITE_BUFFER_SIZE", "BUILTIN_DIR", - "sequence_to_str", - "add_suggestion", "make_repr", "FrozenMapping", "FrozenBunch", @@ -59,30 +54,6 @@ BUILTIN_DIR = pathlib.Path(__file__).parent.parent / "_builtin" -def sequence_to_str(seq: Sequence, separate_last: str = "") -> str: - if len(seq) == 1: - return f"'{seq[0]}'" - - return f"""'{"', '".join([str(item) for item in seq[:-1]])}', """ f"""{separate_last}'{seq[-1]}'.""" - - -def add_suggestion( - msg: str, - *, - word: str, - possibilities: Collection[str], - close_match_hint: Callable[[str], str] = lambda close_match: f"Did you mean '{close_match}'?", - alternative_hint: Callable[ - [Sequence[str]], str - ] = lambda possibilities: f"Can be {sequence_to_str(possibilities, separate_last='or ')}.", -) -> str: - if not isinstance(possibilities, collections.abc.Sequence): - possibilities = sorted(possibilities) - suggestions = difflib.get_close_matches(word, possibilities, 1) - hint = close_match_hint(suggestions[0]) if suggestions else alternative_hint(possibilities) - return f"{msg.strip()} {hint}" - - def make_repr(name: str, items: Iterable[Tuple[str, Any]]) -> str: def to_str(sep: str) -> str: return sep.join([f"{key}={value}" for key, value in items]) diff --git a/torchvision/prototype/features.py b/torchvision/prototype/features.py deleted file mode 100644 index e7455770248..00000000000 --- a/torchvision/prototype/features.py +++ /dev/null @@ -1,119 +0,0 @@ -import enum -from typing import Any, Sequence, Mapping, Optional, Tuple, Type, Callable, TypeVar, Union, cast - -import torch -from torch._C import DisableTorchFunction, _TensorBase -from torchvision.prototype.datasets.utils._internal import FrozenMapping - -__all__ = ["Feature", "ColorSpace", "Image", "Label"] - -T = TypeVar("T", bound="Feature") - - -class Feature(torch.Tensor): - _meta_data: FrozenMapping[str, Any] - - def __new__( - cls: Type[T], - data: Any, - *, - like: Optional[T] = None, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - meta_data: FrozenMapping[str, Any] = FrozenMapping(), - ) -> T: - if like is not None: - dtype = dtype or like.dtype - device = device or like.device - data = torch.as_tensor(data, dtype=dtype, device=device) - requires_grad = False - self = cast(T, torch.Tensor._make_subclass(cast(_TensorBase, cls), data, requires_grad)) - - _meta_data = dict(like._meta_data) if like is not None else dict() - _meta_data.update(meta_data) - self._meta_data = FrozenMapping(_meta_data) - - for name in self._meta_data: - setattr(cls, name, property(lambda self: self._meta_data[name])) - - return self - - @classmethod - def __torch_function__( - cls, - func: Callable[..., torch.Tensor], - types: Tuple[Type[torch.Tensor], ...], - args: Sequence[Any] = (), - kwargs: Optional[Mapping[str, Any]] = None, - ) -> torch.Tensor: - with DisableTorchFunction(): - output = func(*args, **(kwargs or dict())) - if func is not torch.Tensor.clone: - return output - - return cls(output, like=args[0]) - - def __repr__(self) -> str: - return super().__repr__().replace("tensor", type(self).__name__) - - -class ColorSpace(enum.Enum): - # this is just for test purposes - _SENTINEL = -1 - OTHER = 0 - GRAYSCALE = 1 - RGB = 3 - - -class Image(Feature): - color_space: ColorSpace - - def __new__( - cls, - data: Any, - *, - like: Optional["Image"] = None, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - color_space: Optional[Union[str, ColorSpace]] = None, - ) -> "Image": - if color_space is None: - color_space = cls.guess_color_space(data) if data is not None else ColorSpace.OTHER - elif isinstance(color_space, str): - color_space = ColorSpace[color_space.upper()] - - meta_data: FrozenMapping[str, Any] = FrozenMapping(color_space=color_space) - - return Feature.__new__(cls, data, like=like, dtype=dtype, device=device, meta_data=meta_data) - - @staticmethod - def guess_color_space(image: torch.Tensor) -> ColorSpace: - if image.ndim < 2: - return ColorSpace.OTHER - elif image.ndim == 2: - return ColorSpace.GRAYSCALE - - num_channels = image.shape[-3] - if num_channels == 1: - return ColorSpace.GRAYSCALE - elif num_channels == 3: - return ColorSpace.RGB - else: - return ColorSpace.OTHER - - -class Label(Feature): - category: Optional[str] - - def __new__( - cls, - data: Any, - *, - like: Optional["Label"] = None, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - category: Optional[str] = None, - ) -> "Label": - return Feature.__new__( - cls, data, like=like, dtype=dtype, device=device, meta_data=FrozenMapping(category=category) - ) diff --git a/torchvision/prototype/features/__init__.py b/torchvision/prototype/features/__init__.py new file mode 100644 index 00000000000..dfb455d8f5a --- /dev/null +++ b/torchvision/prototype/features/__init__.py @@ -0,0 +1,4 @@ +from ._bounding_box import BoundingBoxFormat, BoundingBox +from ._feature import Feature +from ._image import Image, ColorSpace +from ._label import Label diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py new file mode 100644 index 00000000000..a8e13544b83 --- /dev/null +++ b/torchvision/prototype/features/_bounding_box.py @@ -0,0 +1,149 @@ +import enum +import functools +from typing import Callable +from typing import Dict, Any, Optional +from typing import Union, Tuple + +import torch +from torchvision.prototype.utils._internal import StrEnum + +from ._feature import Feature, DEFAULT + + +class BoundingBoxFormat(StrEnum): + # this is just for test purposes + _SENTINEL = -1 + XYXY = enum.auto() + XYWH = enum.auto() + CXCYWH = enum.auto() + + +def to_parts(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return input.unbind(dim=-1) + + +def from_parts(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, d: torch.Tensor) -> torch.Tensor: + return torch.stack((a, b, c, d), dim=-1) + + +# FIXME: kwargs and name +def foo(part_converter: Callable[[Tuple[torch.Tensor, ...]], Tuple[torch.Tensor, ...]]): + def wrapper(input: torch.Tensor) -> torch.Tensor: + return from_parts(*part_converter(*to_parts(input))) + + return wrapper + + +@foo +def xywh_to_xyxy( + x: torch.Tensor, y: torch.Tensor, w: torch.Tensor, h: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + x1 = x + y1 = y + x2 = x + w + y2 = y + h + return x1, y1, x2, y2 + + +@foo +def xyxy_to_xywh( + x1: torch.Tensor, y1: torch.Tensor, x2: torch.Tensor, y2: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + x = x1 + y = y1 + w = x2 - x1 + h = y2 - y1 + return x, y, w, h + + +@foo +def cxcywh_to_xyxy( + cx: torch.Tensor, cy: torch.Tensor, w: torch.Tensor, h: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + x1 = cx - 0.5 * w + y1 = cy - 0.5 * h + x2 = cx + 0.5 * w + y2 = cy + 0.5 * h + return x1, y1, x2, y2 + + +@foo +def xyxy_to_cxcywh( + x1: torch.Tensor, y1: torch.Tensor, x2: torch.Tensor, y2: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + cx = (x1 + x2) / 2 + cy = (y1 + y2) / 2 + w = x2 - x1 + h = y2 - y1 + return cx, cy, w, h + + +class BoundingBox(Feature): + formats = BoundingBoxFormat + format: BoundingBoxFormat + image_size: Tuple[int, int] + + @classmethod + def _parse_meta_data( + cls, + format: Union[str, BoundingBoxFormat] = DEFAULT, + image_size: Optional[Tuple[int, int]] = DEFAULT, + ) -> Dict[str, Tuple[Any, Any]]: + if isinstance(format, str): + format = BoundingBoxFormat[format] + return dict( + format=(format, BoundingBoxFormat.XYXY), + image_size=(image_size, functools.partial(cls._guess_image_size, format=format)), + ) + + _TO_XYXY_MAP = { + BoundingBoxFormat.XYWH: xywh_to_xyxy, + BoundingBoxFormat.CXCYWH: cxcywh_to_xyxy, + } + _FROM_XYXY_MAP = { + BoundingBoxFormat.XYWH: xyxy_to_xywh, + BoundingBoxFormat.CXCYWH: xyxy_to_cxcywh, + } + + @classmethod + def _guess_image_size(cls, data: torch.Tensor, *, format: BoundingBoxFormat) -> Tuple[int, int]: + if format not in (BoundingBoxFormat.XYWH, BoundingBoxFormat.CXCYWH): + if format != BoundingBoxFormat.XYXY: + data = cls._TO_XYXY_MAP[format](data) + data = cls._FROM_XYXY_MAP[BoundingBoxFormat.XYWH](data) + *_, w, h = to_parts(data) + return int(h.ceil()), int(w.ceil()) + + @classmethod + def from_parts( + cls, + a, + b, + c, + d, + *, + like: Optional["BoundingBox"] = None, + format: Union[str, BoundingBoxFormat] = DEFAULT, + image_size: Optional[Tuple[int, int]] = DEFAULT, + ) -> "BoundingBox": + return cls(from_parts(a, b, c, d), like=like, image_size=image_size, format=format) + + def to_parts(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return to_parts(self) + + def convert(self, format: Union[str, BoundingBoxFormat]) -> "BoundingBox": + if isinstance(format, str): + format = BoundingBoxFormat[format] + + if format == self.format: + return self.clone() + + data = self + + if self.format != BoundingBoxFormat.XYXY: + data = self._TO_XYXY_MAP[self.format](data) + + if format != BoundingBoxFormat.XYXY: + data = self._FROM_XYXY_MAP[format](data) + + return BoundingBox(data, like=self, format=format) diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py new file mode 100644 index 00000000000..e3090e4d857 --- /dev/null +++ b/torchvision/prototype/features/_feature.py @@ -0,0 +1,85 @@ +from typing import Tuple +from typing import cast, TypeVar, Set, Dict, Any, Callable, Optional, Mapping, Type, Sequence + +import torch +from torch._C import _TensorBase, DisableTorchFunction +from torchvision.prototype.utils._internal import add_suggestion + + +F = TypeVar("F", bound="Feature") + + +DEFAULT = object() + + +class Feature(torch.Tensor): + _META_ATTRS: Set[str] + _meta_data: Dict[str, Any] + + def __init_subclass__(cls): + if not hasattr(cls, "_META_ATTRS"): + cls._META_ATTRS = { + attr for attr in cls.__annotations__.keys() - cls.__dict__.keys() if not attr.startswith("_") + } + + for attr in cls._META_ATTRS: + if not hasattr(cls, attr): + setattr(cls, attr, property(lambda self, attr=attr: self._meta_data[attr])) + + def __new__(cls, data, *, dtype=None, device=None, like=None, **kwargs): + unknown_meta_attrs = kwargs.keys() - cls._META_ATTRS + if unknown_meta_attrs: + unknown_meta_attr = sorted(unknown_meta_attrs)[0] + raise TypeError( + add_suggestion( + f"{cls.__name__}() got unexpected keyword '{unknown_meta_attr}'.", + word=unknown_meta_attr, + possibilities=cls._META_ATTRS, + ) + ) + + if like is not None: + dtype = dtype or like.dtype + device = device or like.device + data = cls._to_tensor(data, dtype=dtype, device=device) + requires_grad = False + self = cast(F, torch.Tensor._make_subclass(cast(_TensorBase, cls), data, requires_grad)) + + meta_data = dict() + for attr, (explicit, fallback) in cls._parse_meta_data(**kwargs).items(): + if explicit is not DEFAULT: + value = explicit + elif like is not None: + value = getattr(like, attr) + else: + value = fallback(data) if callable(fallback) else fallback + meta_data[attr] = value + self._meta_data = meta_data + + return self + + @classmethod + def _to_tensor(cls, data, *, dtype, device): + return torch.as_tensor(data, dtype=dtype, device=device) + + @classmethod + def _parse_meta_data(cls) -> Dict[str, Tuple[Any, Any]]: + return dict() + + @classmethod + def __torch_function__( + cls, + func: Callable[..., torch.Tensor], + types: Tuple[Type[torch.Tensor], ...], + args: Sequence[Any] = (), + kwargs: Optional[Mapping[str, Any]] = None, + ) -> torch.Tensor: + with DisableTorchFunction(): + output = func(*args, **(kwargs or dict())) + if func is not torch.Tensor.clone: + return output + + return cls(output, like=args[0]) + + def __repr__(self): + return torch.Tensor.__repr__(self).replace("tensor", type(self).__name__) diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py new file mode 100644 index 00000000000..48be7bd0de2 --- /dev/null +++ b/torchvision/prototype/features/_image.py @@ -0,0 +1,41 @@ +from typing import Dict, Any +from typing import Union, Tuple + +import torch +from torchvision.prototype.utils._internal import StrEnum + +from ._feature import Feature, DEFAULT + + +class ColorSpace(StrEnum): + # this is just for test purposes + _SENTINEL = -1 + OTHER = 0 + GRAYSCALE = 1 + RGB = 3 + + +class Image(Feature): + color_spaces = ColorSpace + color_space: ColorSpace + + @classmethod + def _parse_meta_data(cls, color_space: Union[str, ColorSpace] = DEFAULT) -> Dict[str, Tuple[Any, Any]]: + if isinstance(color_space, str): + color_space = ColorSpace[color_space] + return dict(color_space=(color_space, cls.guess_color_space)) + + @staticmethod + def guess_color_space(data: torch.Tensor) -> ColorSpace: + if data.ndim < 2: + return ColorSpace.OTHER + elif data.ndim == 2: + return ColorSpace.GRAYSCALE + + num_channels = data.shape[-3] + if num_channels == 1: + return ColorSpace.GRAYSCALE + elif num_channels == 3: + return ColorSpace.RGB + else: + return ColorSpace.OTHER diff --git a/torchvision/prototype/features/_label.py b/torchvision/prototype/features/_label.py new file mode 100644 index 00000000000..cb23d46fa15 --- /dev/null +++ b/torchvision/prototype/features/_label.py @@ -0,0 +1,15 @@ +from typing import Dict, Any, Optional +from typing import Tuple + +from ._feature import Feature, DEFAULT + + +class Label(Feature): + category: Optional[str] + + @classmethod + def _parse_meta_data( + cls, + category: Optional[str] = DEFAULT, + ) -> Dict[str, Tuple[Any, Any]]: + return dict(category=(category, None)) diff --git a/torchvision/prototype/utils/__init__.py b/torchvision/prototype/utils/__init__.py new file mode 100644 index 00000000000..e85a582b483 --- /dev/null +++ b/torchvision/prototype/utils/__init__.py @@ -0,0 +1 @@ +from . import _internal diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py new file mode 100644 index 00000000000..d4408fd413c --- /dev/null +++ b/torchvision/prototype/utils/_internal.py @@ -0,0 +1,39 @@ +import collections.abc +import difflib +import enum +from typing import Sequence, Collection, Callable + +__all__ = ["StrEnum", "sequence_to_str", "add_suggestion"] + + +class StrEnumMeta(enum.EnumMeta): + def __getitem__(self, item): + return super().__getitem__(item.upper() if isinstance(item, str) else item) + + +class StrEnum(enum.Enum, metaclass=StrEnumMeta): + pass + + +def sequence_to_str(seq: Sequence, separate_last: str = "") -> str: + if len(seq) == 1: + return f"'{seq[0]}'" + + return f"""'{"', '".join([str(item) for item in seq[:-1]])}', """ f"""{separate_last}'{seq[-1]}'.""" + + +def add_suggestion( + msg: str, + *, + word: str, + possibilities: Collection[str], + close_match_hint: Callable[[str], str] = lambda close_match: f"Did you mean '{close_match}'?", + alternative_hint: Callable[ + [Sequence[str]], str + ] = lambda possibilities: f"Can be {sequence_to_str(possibilities, separate_last='or ')}.", +) -> str: + if not isinstance(possibilities, collections.abc.Sequence): + possibilities = sorted(possibilities) + suggestions = difflib.get_close_matches(word, possibilities, 1) + hint = close_match_hint(suggestions[0]) if suggestions else alternative_hint(possibilities) + return f"{msg.strip()} {hint}" From 7e775fa2baa57f812fa4f3c6398ff2e5a6432c5b Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 3 Nov 2021 10:16:32 +0100 Subject: [PATCH 07/10] mypy --- .../prototype/datasets/_builtin/mnist.py | 3 +-- .../prototype/features/_bounding_box.py | 23 +++++++++++-------- torchvision/prototype/features/_feature.py | 5 ++-- torchvision/prototype/features/_image.py | 8 ++++--- torchvision/prototype/features/_label.py | 5 ++-- 5 files changed, 23 insertions(+), 21 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py index e3ba703ecd0..bdd71f86df5 100644 --- a/torchvision/prototype/datasets/_builtin/mnist.py +++ b/torchvision/prototype/datasets/_builtin/mnist.py @@ -6,7 +6,7 @@ import pathlib import string import sys -from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast import torch from torchdata.datapipes.iter import ( @@ -119,7 +119,6 @@ def _collate_and_decode( ) -> Dict[str, Any]: image, label = data - image: Union[Image, io.IOBase] if decoder is raw: image = Image(image) else: diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index a8e13544b83..be86a8f28c8 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -1,8 +1,6 @@ import enum import functools -from typing import Callable -from typing import Dict, Any, Optional -from typing import Union, Tuple +from typing import Callable, Union, Tuple, Dict, Any, Optional, cast import torch from torchvision.prototype.utils._internal import StrEnum @@ -19,7 +17,7 @@ class BoundingBoxFormat(StrEnum): def to_parts(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - return input.unbind(dim=-1) + return input.unbind(dim=-1) # type: ignore[return-value] def from_parts(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, d: torch.Tensor) -> torch.Tensor: @@ -27,7 +25,12 @@ def from_parts(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, d: torch.Tenso # FIXME: kwargs and name -def foo(part_converter: Callable[[Tuple[torch.Tensor, ...]], Tuple[torch.Tensor, ...]]): +def foo( + part_converter: Callable[ + [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + ] +): def wrapper(input: torch.Tensor) -> torch.Tensor: return from_parts(*part_converter(*to_parts(input))) @@ -86,8 +89,8 @@ class BoundingBox(Feature): @classmethod def _parse_meta_data( cls, - format: Union[str, BoundingBoxFormat] = DEFAULT, - image_size: Optional[Tuple[int, int]] = DEFAULT, + format: Union[str, BoundingBoxFormat] = DEFAULT, # type: ignore[assignment] + image_size: Optional[Tuple[int, int]] = DEFAULT, # type: ignore[assignment] ) -> Dict[str, Tuple[Any, Any]]: if isinstance(format, str): format = BoundingBoxFormat[format] @@ -123,8 +126,8 @@ def from_parts( d, *, like: Optional["BoundingBox"] = None, - format: Union[str, BoundingBoxFormat] = DEFAULT, - image_size: Optional[Tuple[int, int]] = DEFAULT, + format: Union[str, BoundingBoxFormat] = DEFAULT, # type: ignore[assignment] + image_size: Optional[Tuple[int, int]] = DEFAULT, # type: ignore[assignment] ) -> "BoundingBox": return cls(from_parts(a, b, c, d), like=like, image_size=image_size, format=format) @@ -136,7 +139,7 @@ def convert(self, format: Union[str, BoundingBoxFormat]) -> "BoundingBox": format = BoundingBoxFormat[format] if format == self.format: - return self.clone() + return cast(BoundingBox, self.clone()) data = self diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index e3090e4d857..9202e1fa986 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -1,5 +1,4 @@ -from typing import Tuple -from typing import cast, TypeVar, Set, Dict, Any, Callable, Optional, Mapping, Type, Sequence +from typing import Tuple, cast, TypeVar, Set, Dict, Any, Callable, Optional, Mapping, Type, Sequence import torch from torch._C import _TensorBase, DisableTorchFunction @@ -43,7 +42,7 @@ def __new__(cls, data, *, dtype=None, device=None, like=None, **kwargs): device = device or like.device data = cls._to_tensor(data, dtype=dtype, device=device) requires_grad = False - self = cast(F, torch.Tensor._make_subclass(cast(_TensorBase, cls), data, requires_grad)) + self = torch.Tensor._make_subclass(cast(_TensorBase, cls), data, requires_grad) meta_data = dict() for attr, (explicit, fallback) in cls._parse_meta_data(**kwargs).items(): diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 48be7bd0de2..a8eab249997 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -1,5 +1,4 @@ -from typing import Dict, Any -from typing import Union, Tuple +from typing import Dict, Any, Union, Tuple import torch from torchvision.prototype.utils._internal import StrEnum @@ -20,7 +19,10 @@ class Image(Feature): color_space: ColorSpace @classmethod - def _parse_meta_data(cls, color_space: Union[str, ColorSpace] = DEFAULT) -> Dict[str, Tuple[Any, Any]]: + def _parse_meta_data( + cls, + color_space: Union[str, ColorSpace] = DEFAULT, # type: ignore[assignment] + ) -> Dict[str, Tuple[Any, Any]]: if isinstance(color_space, str): color_space = ColorSpace[color_space] return dict(color_space=(color_space, cls.guess_color_space)) diff --git a/torchvision/prototype/features/_label.py b/torchvision/prototype/features/_label.py index cb23d46fa15..ebdc6bbbc26 100644 --- a/torchvision/prototype/features/_label.py +++ b/torchvision/prototype/features/_label.py @@ -1,5 +1,4 @@ -from typing import Dict, Any, Optional -from typing import Tuple +from typing import Dict, Any, Optional, Tuple from ._feature import Feature, DEFAULT @@ -10,6 +9,6 @@ class Label(Feature): @classmethod def _parse_meta_data( cls, - category: Optional[str] = DEFAULT, + category: Optional[str] = DEFAULT, # type: ignore[assignment] ) -> Dict[str, Tuple[Any, Any]]: return dict(category=(category, None)) From 36ac55b0997f4982949d8f4a92a831bd93c12473 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 4 Nov 2021 10:06:53 +0100 Subject: [PATCH 08/10] xfail torchscript tests for now --- test/test_prototype_features.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_prototype_features.py b/test/test_prototype_features.py index abd3610e061..592cfb86f7c 100644 --- a/test/test_prototype_features.py +++ b/test/test_prototype_features.py @@ -126,6 +126,9 @@ def test_cycle_consistency(self, format, intermediate_format): assert_close(input, output) +# For now, tensor subclasses with additional meta data do not work with torchscript. +# See https://github.com/pytorch/vision/pull/4721#discussion_r741676037. +@pytest.mark.xfail class TestJit: def test_bounding_box(self): def resize(input: features.BoundingBox, size: torch.Tensor) -> features.BoundingBox: @@ -165,6 +168,4 @@ def compose(input: features.BoundingBox, size: torch.Tensor) -> features.Boundin sample_inputs = (features.BoundingBox(torch.zeros((4,)), image_size=(10, 10)), torch.tensor((20, 5))) actual_jit = torch.jit.trace(compose, sample_inputs, check_trace=False)(input, size) - print(actual_jit) - print(expected) assert_close(actual_jit, expected) From 42b6925792fe79d39261923ec36f0af68775c0ff Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 4 Nov 2021 10:27:34 +0100 Subject: [PATCH 09/10] cleanup --- test/test_prototype_features.py | 82 ++++++++++--------- .../prototype/features/_bounding_box.py | 18 ++-- 2 files changed, 54 insertions(+), 46 deletions(-) diff --git a/test/test_prototype_features.py b/test/test_prototype_features.py index 592cfb86f7c..e4a178e3594 100644 --- a/test/test_prototype_features.py +++ b/test/test_prototype_features.py @@ -11,6 +11,46 @@ make_tensor = functools.partial(_make_tensor, device="cpu", dtype=torch.float32) +def make_bounding_box(*, format="xyxy", image_size=(10, 10)): + if isinstance(format, str): + format = features.BoundingBoxFormat[format] + + height, width = image_size + + if format == features.BoundingBoxFormat.XYXY: + x1 = torch.randint(0, width // 2, ()) + y1 = torch.randint(0, height // 2, ()) + x2 = torch.randint(int(x1) + 1, width - int(x1), ()) + x1 + y2 = torch.randint(int(y1) + 1, height - int(y1), ()) + y1 + parts = (x1, y1, x2, y2) + elif format == features.BoundingBoxFormat.XYWH: + x = torch.randint(0, width // 2, ()) + y = torch.randint(0, height // 2, ()) + w = torch.randint(1, width - int(x), ()) + h = torch.randint(1, height - int(y), ()) + parts = (x, y, w, h) + elif format == features.BoundingBoxFormat.CXCYWH: + cx = torch.randint(1, width - 1, ()) + cy = torch.randint(1, height - 1, ()) + w = torch.randint(1, min(int(cx), width - int(cx)), ()) + h = torch.randint(1, min(int(cy), height - int(cy)), ()) + parts = (cx, cy, w, h) + else: # format == features.BoundingBoxFormat._SENTINEL: + parts = make_tensor((4,)).unbind() + + return features.BoundingBox.from_parts(*parts, format=format, image_size=image_size) + + +MAKE_DATA_MAP = { + features.BoundingBox: make_bounding_box, +} + + +def make_feature(feature_type, **meta_data): + maker = MAKE_DATA_MAP.get(feature_type, lambda **meta_data: feature_type(make_tensor(()), **meta_data)) + return maker(**meta_data) + + class TestCommon: FEATURE_TYPES, NON_DEFAULT_META_DATA = zip( *( @@ -25,15 +65,11 @@ class TestCommon: features = pytest.mark.parametrize( "feature", [ - pytest.param(feature_type(make_tensor(()), **meta_data), id=feature_type.__name__) + pytest.param(make_feature(feature_type, **meta_data), id=feature_type.__name__) for feature_type, meta_data in zip(FEATURE_TYPES, NON_DEFAULT_META_DATA) ], ) - @pytest.fixture - def data(self): - return make_tensor(()) - def test_consistency(self): builtin_feature_types = { name @@ -58,7 +94,7 @@ def test_meta_data_attribute_access(self, feature): @feature_types def test_torch_function(self, feature_type): - input = feature_type(make_tensor(())) + input = make_feature(feature_type) # This can be any Tensor operation besides clone output = input + 1 @@ -67,7 +103,7 @@ def test_torch_function(self, feature_type): @feature_types def test_clone(self, feature_type): - input = feature_type(make_tensor(())) + input = make_feature(feature_type) output = input.clone() assert type(output) is feature_type @@ -91,38 +127,10 @@ def test_repr(self, feature): class TestBoundingBox: - def make_bounding_box(self, *, format, image_size=(10, 10)): - if isinstance(format, str): - format = features.BoundingBoxFormat[format] - - height, width = image_size - - if format == features.BoundingBoxFormat.XYXY: - x1 = torch.randint(0, width // 2, ()) - y1 = torch.randint(0, height // 2, ()) - x2 = torch.randint(int(x1) + 1, width - int(x1), ()) + x1 - y2 = torch.randint(int(y1) + 1, height - int(y1), ()) + y1 - parts = (x1, y1, x2, y2) - elif format == features.BoundingBoxFormat.XYWH: - x = torch.randint(0, width // 2, ()) - y = torch.randint(0, height // 2, ()) - w = torch.randint(1, width - int(x), ()) - h = torch.randint(1, height - int(y), ()) - parts = (x, y, w, h) - elif format == features.BoundingBoxFormat.CXCYWH: - cx = torch.randint(1, width - 1, ()) - cy = torch.randint(1, height - 1, ()) - w = torch.randint(1, min(int(cx), width - int(cx)), ()) - h = torch.randint(1, min(int(cy), height - int(cy)), ()) - parts = (cx, cy, w, h) - - return features.BoundingBox.from_parts(*parts, format=format, image_size=image_size) - @pytest.mark.parametrize(("format", "intermediate_format"), itertools.permutations(("xyxy", "xywh"), 2)) def test_cycle_consistency(self, format, intermediate_format): - input = self.make_bounding_box(format=format) - a = input.convert(intermediate_format) - output = a.convert(format) + input = make_bounding_box(format=format) + output = input.convert(intermediate_format).convert(format) assert_close(input, output) diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index be86a8f28c8..c3638e41c10 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -24,8 +24,7 @@ def from_parts(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, d: torch.Tenso return torch.stack((a, b, c, d), dim=-1) -# FIXME: kwargs and name -def foo( +def format_converter_wrapper( part_converter: Callable[ [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], @@ -37,7 +36,7 @@ def wrapper(input: torch.Tensor) -> torch.Tensor: return wrapper -@foo +@format_converter_wrapper def xywh_to_xyxy( x: torch.Tensor, y: torch.Tensor, w: torch.Tensor, h: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: @@ -48,7 +47,7 @@ def xywh_to_xyxy( return x1, y1, x2, y2 -@foo +@format_converter_wrapper def xyxy_to_xywh( x1: torch.Tensor, y1: torch.Tensor, x2: torch.Tensor, y2: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: @@ -59,7 +58,7 @@ def xyxy_to_xywh( return x, y, w, h -@foo +@format_converter_wrapper def cxcywh_to_xyxy( cx: torch.Tensor, cy: torch.Tensor, w: torch.Tensor, h: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: @@ -70,7 +69,7 @@ def cxcywh_to_xyxy( return x1, y1, x2, y2 -@foo +@format_converter_wrapper def xyxy_to_cxcywh( x1: torch.Tensor, y1: torch.Tensor, x2: torch.Tensor, y2: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: @@ -94,9 +93,10 @@ def _parse_meta_data( ) -> Dict[str, Tuple[Any, Any]]: if isinstance(format, str): format = BoundingBoxFormat[format] + format_fallback = BoundingBoxFormat.XYXY return dict( - format=(format, BoundingBoxFormat.XYXY), - image_size=(image_size, functools.partial(cls._guess_image_size, format=format)), + format=(format, format_fallback), + image_size=(image_size, functools.partial(cls.guess_image_size, format=format_fallback)), ) _TO_XYXY_MAP = { @@ -109,7 +109,7 @@ def _parse_meta_data( } @classmethod - def _guess_image_size(cls, data: torch.Tensor, *, format: BoundingBoxFormat) -> Tuple[int, int]: + def guess_image_size(cls, data: torch.Tensor, *, format: BoundingBoxFormat) -> Tuple[int, int]: if format not in (BoundingBoxFormat.XYWH, BoundingBoxFormat.CXCYWH): if format != BoundingBoxFormat.XYXY: data = cls._TO_XYXY_MAP[format](data) From 21dd6ea6d19272b741e9a4e7534867f8dd3a33fc Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 4 Nov 2021 10:42:18 +0100 Subject: [PATCH 10/10] fix imports --- test/builtin_dataset_mocks.py | 2 +- test/test_prototype_builtin_datasets.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 1dd25998678..c6dfbff35cf 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -15,7 +15,7 @@ from torchdata.datapipes.iter import IterDataPipe from torchvision.prototype import datasets from torchvision.prototype.datasets._api import find -from torchvision.prototype.datasets.utils._internal import add_suggestion +from torchvision.prototype.utils._internal import add_suggestion make_tensor = functools.partial(_make_tensor, device="cpu") diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 24a79ef4e46..24120a73856 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -5,7 +5,7 @@ import pytest from torchdata.datapipes.iter import IterDataPipe from torchvision.prototype import datasets -from torchvision.prototype.datasets.utils._internal import sequence_to_str +from torchvision.prototype.utils._internal import sequence_to_str _loaders = []