From dc411451a8125e53b6805287038b53f83739fae3 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 5 Apr 2022 16:26:43 +0100 Subject: [PATCH 1/6] Migrate GTSRB prototype dataset --- test/builtin_dataset_mocks.py | 6 +-- .../prototype/datasets/_builtin/gtsrb.py | 47 +++++++++++-------- 2 files changed, 31 insertions(+), 22 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index c4f51463e34..1ede75de076 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -1030,9 +1030,9 @@ def fer2013(info, root, config): return num_samples -# @register_mock -def gtsrb(info, root, config): - num_examples_per_class = 5 if config.split == "train" else 3 +@register_mock(configs=combinations_grid(split=("train", "test"))) +def gtsrb(root, config): + num_examples_per_class = 5 if config["split"] == "train" else 3 classes = ("00000", "00042", "00012") num_examples = num_examples_per_class * len(classes) diff --git a/torchvision/prototype/datasets/_builtin/gtsrb.py b/torchvision/prototype/datasets/_builtin/gtsrb.py index c08d8947292..8d01afce928 100644 --- a/torchvision/prototype/datasets/_builtin/gtsrb.py +++ b/torchvision/prototype/datasets/_builtin/gtsrb.py @@ -1,11 +1,9 @@ import pathlib -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, CSVDictParser, Zipper, Demultiplexer from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, - DatasetInfo, + Dataset2, OnlineResource, HttpResource, ) @@ -16,17 +14,26 @@ INFINITE_BUFFER_SIZE, ) from torchvision.prototype.features import Label, BoundingBox, EncodedImage +from .._api import register_dataset, register_info +NAME = "gtsrb" -class GTSRB(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "gtsrb", - homepage="https://benchmark.ini.rub.de", - categories=[f"{label:05d}" for label in range(43)], - valid_options=dict(split=("train", "test")), +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict( + categories=[f"{label:05d}" for label in range(43)], ) +@register_dataset(NAME) +class GTSRB(Dataset2): + """GTSRB Dataset + + homepage="https://benchmark.ini.rub.de" + """ + def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False) -> None: + self._split = self._verify_str_arg(split, "split", {"train", "test"}) + super().__init__(root, skip_integrity_check=skip_integrity_check) + _URL_ROOT = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/" _URLS = { "train": f"{_URL_ROOT}GTSRB-Training_fixed.zip", @@ -39,10 +46,10 @@ def _make_info(self) -> DatasetInfo: "test_ground_truth": "f94e5a7614d75845c74c04ddb26b8796b9e483f43541dd95dd5b726504e16d6d", } - def resources(self, config: DatasetConfig) -> List[OnlineResource]: - rsrcs: List[OnlineResource] = [HttpResource(self._URLS[config.split], sha256=self._CHECKSUMS[config.split])] + def _resources(self) -> List[OnlineResource]: + rsrcs: List[OnlineResource] = [HttpResource(self._URLS[self._split], sha256=self._CHECKSUMS[self._split])] - if config.split == "test": + if self._split == "test": rsrcs.append( HttpResource( self._URLS["test_ground_truth"], @@ -74,14 +81,12 @@ def _prepare_sample(self, data: Tuple[Tuple[str, Any], Dict[str, Any]]) -> Dict[ return { "path": path, "image": EncodedImage.from_file(buffer), - "label": Label(label, categories=self.categories), + "label": Label(label), "bounding_box": bounding_box, } - def _make_datapipe( - self, resource_dps: List[IterDataPipe], *, config: DatasetConfig - ) -> IterDataPipe[Dict[str, Any]]: - if config.split == "train": + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: + if self._split == "train": images_dp, ann_dp = Demultiplexer( resource_dps[0], 2, self._classify_train_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE ) @@ -98,3 +103,7 @@ def _make_datapipe( dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) + + def __len__(self): + # TODO: Implement len + return 4 \ No newline at end of file From 1f0d84404f6a6d69edd92e17f2f569a586f54dcb Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 5 Apr 2022 16:29:32 +0100 Subject: [PATCH 2/6] ufmt --- torchvision/prototype/datasets/_builtin/gtsrb.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/gtsrb.py b/torchvision/prototype/datasets/_builtin/gtsrb.py index 8d01afce928..c3b44f2e424 100644 --- a/torchvision/prototype/datasets/_builtin/gtsrb.py +++ b/torchvision/prototype/datasets/_builtin/gtsrb.py @@ -14,15 +14,18 @@ INFINITE_BUFFER_SIZE, ) from torchvision.prototype.features import Label, BoundingBox, EncodedImage + from .._api import register_dataset, register_info NAME = "gtsrb" + @register_info(NAME) def _info() -> Dict[str, Any]: return dict( categories=[f"{label:05d}" for label in range(43)], - ) + ) + @register_dataset(NAME) class GTSRB(Dataset2): @@ -30,7 +33,10 @@ class GTSRB(Dataset2): homepage="https://benchmark.ini.rub.de" """ - def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False) -> None: + + def __init__( + self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False + ) -> None: self._split = self._verify_str_arg(split, "split", {"train", "test"}) super().__init__(root, skip_integrity_check=skip_integrity_check) @@ -106,4 +112,4 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, def __len__(self): # TODO: Implement len - return 4 \ No newline at end of file + return 4 From 71a9fc673522e41a5c80dd8e9d4ad784beda1e35 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 6 Apr 2022 10:28:03 +0100 Subject: [PATCH 3/6] Address comments --- torchvision/prototype/datasets/_builtin/gtsrb.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/gtsrb.py b/torchvision/prototype/datasets/_builtin/gtsrb.py index c3b44f2e424..02399b67bfa 100644 --- a/torchvision/prototype/datasets/_builtin/gtsrb.py +++ b/torchvision/prototype/datasets/_builtin/gtsrb.py @@ -38,6 +38,7 @@ def __init__( self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False ) -> None: self._split = self._verify_str_arg(split, "split", {"train", "test"}) + self._categories = _info()["categories"] super().__init__(root, skip_integrity_check=skip_integrity_check) _URL_ROOT = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/" @@ -87,7 +88,7 @@ def _prepare_sample(self, data: Tuple[Tuple[str, Any], Dict[str, Any]]) -> Dict[ return { "path": path, "image": EncodedImage.from_file(buffer), - "label": Label(label), + "label": Label(label, categories=self._categories), "bounding_box": bounding_box, } @@ -111,5 +112,7 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, return Mapper(dp, self._prepare_sample) def __len__(self): - # TODO: Implement len - return 4 + return { + "train": 26_640, + "test": 12_630, + }[self._split] From 39ef6b211945899bab15f005b68c36668f85868d Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 6 Apr 2022 10:31:05 +0100 Subject: [PATCH 4/6] Apparently mypy doesn't know that __len__ returns ints. How cute. --- torchvision/prototype/datasets/_builtin/gtsrb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/datasets/_builtin/gtsrb.py b/torchvision/prototype/datasets/_builtin/gtsrb.py index 02399b67bfa..bd98e3b7c47 100644 --- a/torchvision/prototype/datasets/_builtin/gtsrb.py +++ b/torchvision/prototype/datasets/_builtin/gtsrb.py @@ -111,7 +111,7 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, return Mapper(dp, self._prepare_sample) - def __len__(self): + def __len__(self) -> int: return { "train": 26_640, "test": 12_630, From 22d57a65e98c2f11d9794646fc3a352fb156705e Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 6 Apr 2022 10:35:51 +0100 Subject: [PATCH 5/6] why is the CI not triggered?? From b50f7ba4efdd5de17dc0b9c8cdab6e2eb6945cc5 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 6 Apr 2022 14:03:28 +0100 Subject: [PATCH 6/6] Update torchvision/prototype/datasets/_builtin/gtsrb.py Co-authored-by: Philip Meier --- torchvision/prototype/datasets/_builtin/gtsrb.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/gtsrb.py b/torchvision/prototype/datasets/_builtin/gtsrb.py index bd98e3b7c47..fa29f3be780 100644 --- a/torchvision/prototype/datasets/_builtin/gtsrb.py +++ b/torchvision/prototype/datasets/_builtin/gtsrb.py @@ -112,7 +112,4 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, return Mapper(dp, self._prepare_sample) def __len__(self) -> int: - return { - "train": 26_640, - "test": 12_630, - }[self._split] + return 26_640 if self._split == "train" else 12_630