1
1
import pathlib
2
- from typing import Any , Dict , List , Tuple , Iterator , BinaryIO
2
+ from typing import Any , Dict , List , Tuple , Iterator , BinaryIO , Union
3
3
4
4
from torchdata .datapipes .iter import Filter , IterDataPipe , Mapper , Zipper
5
- from torchvision .prototype .datasets .utils import Dataset , DatasetConfig , DatasetInfo , HttpResource , OnlineResource
6
- from torchvision .prototype .datasets .utils ._internal import hint_sharding , hint_shuffling , path_comparator , read_mat
5
+ from torchvision .prototype .datasets .utils import Dataset2 , DatasetInfo , HttpResource , OnlineResource
6
+ from torchvision .prototype .datasets .utils ._internal import (
7
+ hint_sharding ,
8
+ hint_shuffling ,
9
+ path_comparator ,
10
+ read_mat ,
11
+ BUILTIN_DIR ,
12
+ )
7
13
from torchvision .prototype .features import BoundingBox , EncodedImage , Label
8
14
15
+ from .._api import register_dataset , register_info
16
+
9
17
10
18
class StanfordCarsLabelReader (IterDataPipe [Tuple [int , int , int , int , int , str ]]):
11
19
def __init__ (self , datapipe : IterDataPipe [Dict [str , Any ]]) -> None :
@@ -18,16 +26,33 @@ def __iter__(self) -> Iterator[Tuple[int, int, int, int, int, str]]:
18
26
yield tuple (ann ) # type: ignore[misc]
19
27
20
28
21
- class StanfordCars (Dataset ):
22
- def _make_info (self ) -> DatasetInfo :
23
- return DatasetInfo (
24
- name = "stanford-cars" ,
25
- homepage = "https://ai.stanford.edu/~jkrause/cars/car_dataset.html" ,
26
- dependencies = ("scipy" ,),
27
- valid_options = dict (
28
- split = ("test" , "train" ),
29
- ),
30
- )
29
+ NAME = "stanford-cars"
30
+
31
+
32
+ @register_info (NAME )
33
+ def _info () -> Dict [str , Any ]:
34
+ categories = DatasetInfo .read_categories_file (BUILTIN_DIR / f"{ NAME } .categories" )
35
+ categories = [c [0 ] for c in categories ]
36
+ return dict (categories = categories )
37
+
38
+
39
+ @register_dataset (NAME )
40
+ class StanfordCars (Dataset2 ):
41
+ """Stanford Cars dataset.
42
+ homepage="https://ai.stanford.edu/~jkrause/cars/car_dataset.html",
43
+ dependencies=scipy
44
+ """
45
+
46
+ def __init__ (
47
+ self ,
48
+ root : Union [str , pathlib .Path ],
49
+ * ,
50
+ split : str = "train" ,
51
+ skip_integrity_check : bool = False ,
52
+ ) -> None :
53
+ self ._split = self ._verify_str_arg (split , "split" , {"train" , "test" })
54
+ self ._categories = _info ()["categories" ]
55
+ super ().__init__ (root , skip_integrity_check = skip_integrity_check , dependencies = ("scipy" ,))
31
56
32
57
_URL_ROOT = "https://ai.stanford.edu/~jkrause/"
33
58
_URLS = {
@@ -44,9 +69,9 @@ def _make_info(self) -> DatasetInfo:
44
69
"car_devkit" : "512b227b30e2f0a8aab9e09485786ab4479582073a144998da74d64b801fd288" ,
45
70
}
46
71
47
- def resources (self , config : DatasetConfig ) -> List [OnlineResource ]:
48
- resources : List [OnlineResource ] = [HttpResource (self ._URLS [config . split ], sha256 = self ._CHECKSUM [config . split ])]
49
- if config . split == "train" :
72
+ def _resources (self ) -> List [OnlineResource ]:
73
+ resources : List [OnlineResource ] = [HttpResource (self ._URLS [self . _split ], sha256 = self ._CHECKSUM [self . _split ])]
74
+ if self . _split == "train" :
50
75
resources .append (HttpResource (url = self ._URLS ["car_devkit" ], sha256 = self ._CHECKSUM ["car_devkit" ]))
51
76
52
77
else :
@@ -65,32 +90,29 @@ def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Tuple[int, int, int,
65
90
return dict (
66
91
path = path ,
67
92
image = image ,
68
- label = Label (target [4 ] - 1 , categories = self .categories ),
93
+ label = Label (target [4 ] - 1 , categories = self ._categories ),
69
94
bounding_box = BoundingBox (target [:4 ], format = "xyxy" , image_size = image .image_size ),
70
95
)
71
96
72
- def _make_datapipe (
73
- self ,
74
- resource_dps : List [IterDataPipe ],
75
- * ,
76
- config : DatasetConfig ,
77
- ) -> IterDataPipe [Dict [str , Any ]]:
97
+ def _datapipe (self , resource_dps : List [IterDataPipe ]) -> IterDataPipe [Dict [str , Any ]]:
78
98
79
99
images_dp , targets_dp = resource_dps
80
- if config . split == "train" :
100
+ if self . _split == "train" :
81
101
targets_dp = Filter (targets_dp , path_comparator ("name" , "cars_train_annos.mat" ))
82
102
targets_dp = StanfordCarsLabelReader (targets_dp )
83
103
dp = Zipper (images_dp , targets_dp )
84
104
dp = hint_shuffling (dp )
85
105
dp = hint_sharding (dp )
86
106
return Mapper (dp , self ._prepare_sample )
87
107
88
- def _generate_categories (self , root : pathlib .Path ) -> List [str ]:
89
- config = self .info .make_config (split = "train" )
90
- resources = self .resources (config )
108
+ def _generate_categories (self ) -> List [str ]:
109
+ resources = self ._resources ()
91
110
92
- devkit_dp = resources [1 ].load (root )
111
+ devkit_dp = resources [1 ].load (self . _root )
93
112
meta_dp = Filter (devkit_dp , path_comparator ("name" , "cars_meta.mat" ))
94
113
_ , meta_file = next (iter (meta_dp ))
95
114
96
115
return list (read_mat (meta_file , squeeze_me = True )["class_names" ])
116
+
117
+ def __len__ (self ) -> int :
118
+ return 8_144 if self ._split == "train" else 8_041
0 commit comments