diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index cc8568154ed..3c5c51f612d 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -1044,9 +1044,9 @@ def fer2013(info, root, config): return num_samples -# @register_mock -def gtsrb(info, root, config): - num_examples_per_class = 5 if config.split == "train" else 3 +@register_mock(configs=combinations_grid(split=("train", "test"))) +def gtsrb(root, config): + num_examples_per_class = 5 if config["split"] == "train" else 3 classes = ("00000", "00042", "00012") num_examples = num_examples_per_class * len(classes) diff --git a/torchvision/prototype/datasets/_builtin/gtsrb.py b/torchvision/prototype/datasets/_builtin/gtsrb.py index c08d8947292..fa29f3be780 100644 --- a/torchvision/prototype/datasets/_builtin/gtsrb.py +++ b/torchvision/prototype/datasets/_builtin/gtsrb.py @@ -1,11 +1,9 @@ import pathlib -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, CSVDictParser, Zipper, Demultiplexer from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, - DatasetInfo, + Dataset2, OnlineResource, HttpResource, ) @@ -17,15 +15,31 @@ ) from torchvision.prototype.features import Label, BoundingBox, EncodedImage +from .._api import register_dataset, register_info + +NAME = "gtsrb" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict( + categories=[f"{label:05d}" for label in range(43)], + ) -class GTSRB(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "gtsrb", - homepage="https://benchmark.ini.rub.de", - categories=[f"{label:05d}" for label in range(43)], - valid_options=dict(split=("train", "test")), - ) + +@register_dataset(NAME) +class GTSRB(Dataset2): + """GTSRB Dataset + + homepage="https://benchmark.ini.rub.de" + """ + + def __init__( + self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False + ) -> None: + self._split = self._verify_str_arg(split, "split", {"train", "test"}) + self._categories = _info()["categories"] + super().__init__(root, skip_integrity_check=skip_integrity_check) _URL_ROOT = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/" _URLS = { @@ -39,10 +53,10 @@ def _make_info(self) -> DatasetInfo: "test_ground_truth": "f94e5a7614d75845c74c04ddb26b8796b9e483f43541dd95dd5b726504e16d6d", } - def resources(self, config: DatasetConfig) -> List[OnlineResource]: - rsrcs: List[OnlineResource] = [HttpResource(self._URLS[config.split], sha256=self._CHECKSUMS[config.split])] + def _resources(self) -> List[OnlineResource]: + rsrcs: List[OnlineResource] = [HttpResource(self._URLS[self._split], sha256=self._CHECKSUMS[self._split])] - if config.split == "test": + if self._split == "test": rsrcs.append( HttpResource( self._URLS["test_ground_truth"], @@ -74,14 +88,12 @@ def _prepare_sample(self, data: Tuple[Tuple[str, Any], Dict[str, Any]]) -> Dict[ return { "path": path, "image": EncodedImage.from_file(buffer), - "label": Label(label, categories=self.categories), + "label": Label(label, categories=self._categories), "bounding_box": bounding_box, } - def _make_datapipe( - self, resource_dps: List[IterDataPipe], *, config: DatasetConfig - ) -> IterDataPipe[Dict[str, Any]]: - if config.split == "train": + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: + if self._split == "train": images_dp, ann_dp = Demultiplexer( resource_dps[0], 2, self._classify_train_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE ) @@ -98,3 +110,6 @@ def _make_datapipe( dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return 26_640 if self._split == "train" else 12_630