Skip to content

Commit 8194b17

Browse files
authored
Migrate EuroSAT prototype dataset (#5760)
1 parent ebe9006 commit 8194b17

File tree

2 files changed

+39
-25
lines changed

2 files changed

+39
-25
lines changed

test/builtin_dataset_mocks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1393,8 +1393,8 @@ def cub200(info, root, config):
13931393
return num_samples_map[config.split]
13941394

13951395

1396-
# @register_mock
1397-
def eurosat(info, root, config):
1396+
@register_mock(configs=[dict()])
1397+
def eurosat(root, config):
13981398
data_folder = root / "2750"
13991399
data_folder.mkdir(parents=True)
14001400

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,44 @@
11
import pathlib
2-
from typing import Any, Dict, List, Tuple
2+
from typing import Any, Dict, List, Tuple, Union
33

44
from torchdata.datapipes.iter import IterDataPipe, Mapper
5-
from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource
5+
from torchvision.prototype.datasets.utils import Dataset2, HttpResource, OnlineResource
66
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
77
from torchvision.prototype.features import EncodedImage, Label
88

9+
from .._api import register_dataset, register_info
910

10-
class EuroSAT(Dataset):
11-
def _make_info(self) -> DatasetInfo:
12-
return DatasetInfo(
13-
"eurosat",
14-
homepage="https://github.com/phelber/eurosat",
15-
categories=(
16-
"AnnualCrop",
17-
"Forest",
18-
"HerbaceousVegetation",
19-
"Highway",
20-
"Industrial," "Pasture",
21-
"PermanentCrop",
22-
"Residential",
23-
"River",
24-
"SeaLake",
25-
),
11+
NAME = "eurosat"
12+
13+
14+
@register_info(NAME)
15+
def _info() -> Dict[str, Any]:
16+
return dict(
17+
categories=(
18+
"AnnualCrop",
19+
"Forest",
20+
"HerbaceousVegetation",
21+
"Highway",
22+
"Industrial," "Pasture",
23+
"PermanentCrop",
24+
"Residential",
25+
"River",
26+
"SeaLake",
2627
)
28+
)
29+
2730

28-
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
31+
@register_dataset(NAME)
32+
class EuroSAT(Dataset2):
33+
"""EuroSAT Dataset.
34+
homepage="https://github.com/phelber/eurosat",
35+
"""
36+
37+
def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None:
38+
self._categories = _info()["categories"]
39+
super().__init__(root, skip_integrity_check=skip_integrity_check)
40+
41+
def _resources(self) -> List[OnlineResource]:
2942
return [
3043
HttpResource(
3144
"https://madm.dfki.de/files/sentinel/EuroSAT.zip",
@@ -37,15 +50,16 @@ def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]:
3750
path, buffer = data
3851
category = pathlib.Path(path).parent.name
3952
return dict(
40-
label=Label.from_category(category, categories=self.categories),
53+
label=Label.from_category(category, categories=self._categories),
4154
path=path,
4255
image=EncodedImage.from_file(buffer),
4356
)
4457

45-
def _make_datapipe(
46-
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
47-
) -> IterDataPipe[Dict[str, Any]]:
58+
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
4859
dp = resource_dps[0]
4960
dp = hint_shuffling(dp)
5061
dp = hint_sharding(dp)
5162
return Mapper(dp, self._prepare_sample)
63+
64+
def __len__(self) -> int:
65+
return 27_000

0 commit comments

Comments
 (0)