1
1
import pathlib
2
- from typing import Any , Dict , List , Tuple
2
+ from typing import Any , Dict , List , Tuple , Union
3
3
4
4
from torchdata .datapipes .iter import IterDataPipe , Mapper
5
- from torchvision .prototype .datasets .utils import Dataset , DatasetConfig , DatasetInfo , HttpResource , OnlineResource
5
+ from torchvision .prototype .datasets .utils import Dataset2 , HttpResource , OnlineResource
6
6
from torchvision .prototype .datasets .utils ._internal import hint_sharding , hint_shuffling
7
7
from torchvision .prototype .features import EncodedImage , Label
8
8
9
+ from .._api import register_dataset , register_info
9
10
10
- class EuroSAT ( Dataset ):
11
- def _make_info ( self ) -> DatasetInfo :
12
- return DatasetInfo (
13
- "eurosat" ,
14
- homepage = "https://github.com/phelber/eurosat" ,
15
- categories = (
16
- "AnnualCrop" ,
17
- "Forest " ,
18
- "HerbaceousVegetation " ,
19
- "Highway " ,
20
- "Industrial," "Pasture " ,
21
- "PermanentCrop " ,
22
- "Residential " ,
23
- "River " ,
24
- "SeaLake " ,
25
- ) ,
11
+ NAME = "eurosat"
12
+
13
+
14
+ @ register_info ( NAME )
15
+ def _info () -> Dict [ str , Any ]:
16
+ return dict (
17
+ categories = (
18
+ "AnnualCrop " ,
19
+ "Forest " ,
20
+ "HerbaceousVegetation " ,
21
+ "Highway " ,
22
+ "Industrial," "Pasture " ,
23
+ "PermanentCrop " ,
24
+ "Residential " ,
25
+ "River " ,
26
+ "SeaLake" ,
26
27
)
28
+ )
29
+
27
30
28
- def resources (self , config : DatasetConfig ) -> List [OnlineResource ]:
31
+ @register_dataset (NAME )
32
+ class EuroSAT (Dataset2 ):
33
+ """EuroSAT Dataset.
34
+ homepage="https://github.com/phelber/eurosat",
35
+ """
36
+
37
+ def __init__ (self , root : Union [str , pathlib .Path ], * , skip_integrity_check : bool = False ) -> None :
38
+ self ._categories = _info ()["categories" ]
39
+ super ().__init__ (root , skip_integrity_check = skip_integrity_check )
40
+
41
+ def _resources (self ) -> List [OnlineResource ]:
29
42
return [
30
43
HttpResource (
31
44
"https://madm.dfki.de/files/sentinel/EuroSAT.zip" ,
@@ -37,15 +50,16 @@ def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]:
37
50
path , buffer = data
38
51
category = pathlib .Path (path ).parent .name
39
52
return dict (
40
- label = Label .from_category (category , categories = self .categories ),
53
+ label = Label .from_category (category , categories = self ._categories ),
41
54
path = path ,
42
55
image = EncodedImage .from_file (buffer ),
43
56
)
44
57
45
- def _make_datapipe (
46
- self , resource_dps : List [IterDataPipe ], * , config : DatasetConfig
47
- ) -> IterDataPipe [Dict [str , Any ]]:
58
+ def _datapipe (self , resource_dps : List [IterDataPipe ]) -> IterDataPipe [Dict [str , Any ]]:
48
59
dp = resource_dps [0 ]
49
60
dp = hint_shuffling (dp )
50
61
dp = hint_sharding (dp )
51
62
return Mapper (dp , self ._prepare_sample )
63
+
64
+ def __len__ (self ) -> int :
65
+ return 27_000
0 commit comments