diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 0210a4dacec..a244cf5d28a 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -1386,8 +1386,8 @@ def cub200(info, root, config): return num_samples_map[config.split] -# @register_mock -def eurosat(info, root, config): +@register_mock(configs=[dict()]) +def eurosat(root, config): data_folder = root / "2750" data_folder.mkdir(parents=True) diff --git a/torchvision/prototype/datasets/_builtin/eurosat.py b/torchvision/prototype/datasets/_builtin/eurosat.py index 336f35de968..00d6a04f320 100644 --- a/torchvision/prototype/datasets/_builtin/eurosat.py +++ b/torchvision/prototype/datasets/_builtin/eurosat.py @@ -1,31 +1,44 @@ import pathlib -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Tuple, Union from torchdata.datapipes.iter import IterDataPipe, Mapper -from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset2, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling from torchvision.prototype.features import EncodedImage, Label +from .._api import register_dataset, register_info -class EuroSAT(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "eurosat", - homepage="https://github.com/phelber/eurosat", - categories=( - "AnnualCrop", - "Forest", - "HerbaceousVegetation", - "Highway", - "Industrial," "Pasture", - "PermanentCrop", - "Residential", - "River", - "SeaLake", - ), +NAME = "eurosat" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict( + categories=( + "AnnualCrop", + "Forest", + "HerbaceousVegetation", + "Highway", + "Industrial", "Pasture", + "PermanentCrop", + "Residential", + "River", + "SeaLake", ) + ) + - def resources(self, config: DatasetConfig) -> List[OnlineResource]: +@register_dataset(NAME) 
+class EuroSAT(Dataset2): + """EuroSAT Dataset. + homepage="https://github.com/phelber/eurosat", + """ + + def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None: + self._categories = _info()["categories"] + super().__init__(root, skip_integrity_check=skip_integrity_check) + + def _resources(self) -> List[OnlineResource]: return [ HttpResource( "https://madm.dfki.de/files/sentinel/EuroSAT.zip", @@ -37,15 +50,16 @@ def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]: path, buffer = data category = pathlib.Path(path).parent.name return dict( - label=Label.from_category(category, categories=self.categories), + label=Label.from_category(category, categories=self._categories), path=path, image=EncodedImage.from_file(buffer), ) - def _make_datapipe( - self, resource_dps: List[IterDataPipe], *, config: DatasetConfig - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] dp = hint_shuffling(dp) dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return 27_000