1
- import os
2
- import os .path
1
+ import pathlib
3
2
from typing import Callable , Optional , Any , Tuple
4
3
5
4
from PIL import Image
6
5
7
- from .utils import download_and_extract_archive , download_url
6
+ from .utils import download_and_extract_archive , download_url , verify_str_arg
8
7
from .vision import VisionDataset
9
8
10
9
11
10
class StanfordCars (VisionDataset ):
12
11
"""`Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ Dataset
13
12
14
- .. warning::
13
+ The Cars dataset contains 16,185 images of 196 classes of cars. The data is
14
+ split into 8,144 training images and 8,041 testing images, where each class
15
+ has been split roughly 50-50.
16
+
17
+ .. note::
15
18
16
19
This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
17
20
18
21
Args:
19
22
root (string): Root directory of dataset
20
- train (bool , optional):If True, creates dataset from training set, otherwise creates from test set
23
+ split (string , optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
21
24
transform (callable, optional): A function/transform that takes in a PIL image
22
25
and returns a transformed version. E.g., ``transforms.RandomCrop``
23
26
target_transform (callable, optional): A function/transform that takes in the
@@ -26,30 +29,10 @@ class StanfordCars(VisionDataset):
26
29
puts it in root directory. If dataset is already downloaded, it is not
27
30
downloaded again."""
28
31
29
- urls = (
30
- "https://ai.stanford.edu/~jkrause/car196/cars_test.tgz" ,
31
- "https://ai.stanford.edu/~jkrause/car196/cars_train.tgz" ,
32
- ) # test and train image urls
33
-
34
- md5s = (
35
- "4ce7ebf6a94d07f1952d94dd34c4d501" ,
36
- "065e5b463ae28d29e77c1b4b166cfe61" ,
37
- ) # md5checksum for test and train data
38
-
39
- annot_urls = (
40
- "https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat" ,
41
- "https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz" ,
42
- ) # annotations and labels for test and train
43
-
44
- annot_md5s = (
45
- "b0a2b23655a3edd16d84508592a98d10" ,
46
- "c3b158d763b6e2245038c8ad08e45376" ,
47
- ) # md5 checksum for annotations
48
-
49
32
def __init__ (
50
33
self ,
51
34
root : str ,
52
- train : bool = True ,
35
+ split : str = "train" ,
53
36
transform : Optional [Callable ] = None ,
54
37
target_transform : Optional [Callable ] = None ,
55
38
download : bool = False ,
@@ -62,7 +45,16 @@ def __init__(
62
45
63
46
super ().__init__ (root , transform = transform , target_transform = target_transform )
64
47
65
- self .train = train
48
+ self ._split = verify_str_arg (split , "split" , ("train" , "test" ))
49
+ self ._base_folder = pathlib .Path (root ) / "stanford_cars"
50
+ devkit = self ._base_folder / "devkit"
51
+
52
+ if self ._split == "train" :
53
+ self ._annotations_mat_path = devkit / "cars_train_annos.mat"
54
+ self ._images_base_path = self ._base_folder / "cars_train"
55
+ else :
56
+ self ._annotations_mat_path = self ._base_folder / "cars_test_annos_withlabels.mat"
57
+ self ._images_base_path = self ._base_folder / "cars_test"
66
58
67
59
if download :
68
60
self .download ()
@@ -72,22 +64,13 @@ def __init__(
72
64
73
65
self ._samples = [
74
66
(
75
- os .path .join (self .root , f"cars_{ 'train' if self .train else 'test' } " , annotation ["fname" ]),
76
- annotation ["class" ] - 1 ,
77
- # Beware stanford cars target mapping starts from 1
67
+ str (self ._images_base_path / annotation ["fname" ]),
68
+ annotation ["class" ] - 1 , # Original target mapping starts from 1, hence -1
78
69
)
79
- for annotation in sio .loadmat (
80
- os .path .join (
81
- self .root ,
82
- * ["devkit" , "cars_train_annos.mat" ] if self .train else ["cars_test_annos_withlabels.mat" ],
83
- ),
84
- squeeze_me = True ,
85
- )["annotations" ]
70
+ for annotation in sio .loadmat (self ._annotations_mat_path , squeeze_me = True )["annotations" ]
86
71
]
87
72
88
- self .classes = sio .loadmat (os .path .join (self .root , "devkit" , "cars_meta.mat" ), squeeze_me = True )[
89
- "class_names"
90
- ].tolist ()
73
+ self .classes = sio .loadmat (str (devkit / "cars_meta.mat" ), squeeze_me = True )["class_names" ].tolist ()
91
74
self .class_to_idx = {cls : i for i , cls in enumerate (self .classes )}
92
75
93
76
def __len__ (self ) -> int :
@@ -108,20 +91,31 @@ def download(self) -> None:
108
91
if self ._check_exists ():
109
92
return
110
93
111
- download_and_extract_archive (url = self .urls [self .train ], download_root = self .root , md5 = self .md5s [self .train ])
112
- download_and_extract_archive (url = self .annot_urls [1 ], download_root = self .root , md5 = self .annot_md5s [1 ])
113
- if not self .train :
94
+ download_and_extract_archive (
95
+ url = "https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz" ,
96
+ download_root = self ._base_folder ,
97
+ md5 = "c3b158d763b6e2245038c8ad08e45376" ,
98
+ )
99
+ if self ._split == "train" :
100
+ download_and_extract_archive (
101
+ url = "https://ai.stanford.edu/~jkrause/car196/cars_train.tgz" ,
102
+ download_root = self ._base_folder ,
103
+ md5 = "065e5b463ae28d29e77c1b4b166cfe61" ,
104
+ )
105
+ else :
106
+ download_and_extract_archive (
107
+ url = "https://ai.stanford.edu/~jkrause/car196/cars_test.tgz" ,
108
+ download_root = self ._base_folder ,
109
+ md5 = "4ce7ebf6a94d07f1952d94dd34c4d501" ,
110
+ )
114
111
download_url (
115
- url = self . annot_urls [ 0 ] ,
116
- root = self .root ,
117
- md5 = self . annot_md5s [ 0 ] ,
112
+ url = "https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat" ,
113
+ root = self ._base_folder ,
114
+ md5 = "b0a2b23655a3edd16d84508592a98d10" ,
118
115
)
119
116
120
117
def _check_exists (self ) -> bool :
121
- return (
122
- os .path .exists (os .path .join (self .root , f"cars_{ 'train' if self .train else 'test' } " ))
123
- and os .path .isdir (os .path .join (self .root , f"cars_{ 'train' if self .train else 'test' } " ))
124
- and os .path .exists (os .path .join (self .root , "devkit" , "cars_meta.mat" ))
125
- if self .train
126
- else os .path .exists (os .path .join (self .root , "cars_test_annos_withlabels.mat" ))
127
- )
118
+ if not (self ._base_folder / "devkit" ).is_dir ():
119
+ return False
120
+
121
+ return self ._annotations_mat_path .exists () and self ._images_base_path .is_dir ()
0 commit comments