
Commit 8fceb0b

Some cleanups
1 parent cbd3d9b commit 8fceb0b

2 files changed: +74 -90 lines changed


test/test_datasets.py

Lines changed: 27 additions & 37 deletions
@@ -2517,53 +2517,43 @@ def _meta_to_split_and_classification_ann(self, meta, idx):
 class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.StanfordCars
     REQUIRED_PACKAGES = ("scipy",)
-    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
 
-    def _inject_fake_data(self, tmpdir, config):
+    def inject_fake_data(self, tmpdir, config):
         import scipy.io as io
         from numpy.core.records import fromarrays
 
-        train = config["train"]
-        num_examples = 5
-        root_folder = tmpdir
+        num_examples = {"train": 5, "test": 7}[config["split"]]
+        num_classes = 3
+        base_folder = pathlib.Path(tmpdir) / "stanford_cars"
 
-        class_name = np.random.randint(0, 100, num_examples, dtype=np.uint8)
-        bbox_x1 = np.random.randint(0, 100, num_examples, dtype=np.uint8)
-        bbox_x2 = np.random.randint(0, 100, num_examples, dtype=np.uint8)
+        devkit = base_folder / "devkit"
+        devkit.mkdir(parents=True)
 
-        bbox_y1 = np.random.randint(0, 100, num_examples, dtype=np.uint8)
-        bb1ox_y2 = np.random.randint(0, 100, num_examples, dtype=np.uint8)
-        fname = [f"{i:5d}.jpg" for i in range(num_examples)]
+        if config["split"] == "train":
+            images_folder_name = "cars_train"
+            annotations_mat_path = str(devkit / "cars_train_annos.mat")
+        else:
+            images_folder_name = "cars_test"
+            annotations_mat_path = str(base_folder / "cars_test_annos_withlabels.mat")
 
-        rec_array = fromarrays(
-            [bbox_x1, bbox_y1, bbox_x2, bb1ox_y2, class_name, fname],
-            names=["bbox_x1", "bbox_y1", "bbox_x2", "bbox_y2", "class", "fname"],
+        datasets_utils.create_image_folder(
+            root=base_folder,
+            name=images_folder_name,
+            file_name_fn=lambda image_index: f"{image_index:5d}.jpg",
+            num_examples=num_examples,
         )
-        devkit = os.path.join(root_folder, "devkit")
-        os.makedirs(devkit)
-
-        random_class_names = ["Tesla Model S Sedan 2012"] * 196
-
-        io.savemat(os.path.join(devkit, "cars_meta.mat"), {"class_names": random_class_names})
 
-        if train:
-            datasets_utils.create_image_folder(
-                root=root_folder,
-                name="cars_train",
-                file_name_fn=lambda image_index: f"{image_index:5d}.jpg",
-                num_examples=num_examples,
-            )
-
-            io.savemat(f"{devkit}/cars_train_annos.mat", {"annotations": rec_array})
-        else:
+        classes = np.random.randint(1, num_classes + 1, num_examples, dtype=np.uint8)
+        fnames = [f"{i:5d}.jpg" for i in range(num_examples)]
+        rec_array = fromarrays(
+            [classes, fnames],
+            names=["class", "fname"],
+        )
+        io.savemat(annotations_mat_path, {"annotations": rec_array})
 
-            datasets_utils.create_image_folder(
-                root=root_folder,
-                name="cars_test",
-                file_name_fn=lambda image_index: f"{image_index:5d}.jpg",
-                num_examples=num_examples,
-            )
-            io.savemat(f"{root_folder}/cars_test_annos_withlabels.mat", {"annotations": rec_array})
+        random_class_names = ["random_name"] * num_classes
+        io.savemat(str(devkit / "cars_meta.mat"), {"class_names": random_class_names})
 
         return num_examples
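For reference, the reworked inject_fake_data mimics the extracted dataset layout under <root>/stanford_cars. Below is a minimal standalone sketch of that layout for the train split (not part of the test suite; the scratch directory name "tmp" and the three example records are illustrative):

import pathlib

import numpy as np
import scipy.io as io
from numpy.core.records import fromarrays

# Fake the extracted archive layout: <root>/stanford_cars/{devkit, cars_train}.
base_folder = pathlib.Path("tmp") / "stanford_cars"
devkit = base_folder / "devkit"
devkit.mkdir(parents=True, exist_ok=True)
(base_folder / "cars_train").mkdir(exist_ok=True)  # fake JPEGs would go here

# Annotations are a record array of 1-based class labels and file names;
# the train annotations live inside the devkit folder.
rec_array = fromarrays(
    [np.array([1, 2, 3], dtype=np.uint8), ["00000.jpg", "00001.jpg", "00002.jpg"]],
    names=["class", "fname"],
)
io.savemat(str(devkit / "cars_train_annos.mat"), {"annotations": rec_array})

# Class names come from the devkit's cars_meta.mat.
io.savemat(str(devkit / "cars_meta.mat"), {"class_names": ["random_name"] * 3})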

torchvision/datasets/stanford_cars.py

Lines changed: 47 additions & 53 deletions
@@ -1,23 +1,26 @@
-import os
-import os.path
+import pathlib
 from typing import Callable, Optional, Any, Tuple
 
 from PIL import Image
 
-from .utils import download_and_extract_archive, download_url
+from .utils import download_and_extract_archive, download_url, verify_str_arg
 from .vision import VisionDataset
 
 
 class StanfordCars(VisionDataset):
     """`Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ Dataset
 
-    .. warning::
+    The Cars dataset contains 16,185 images of 196 classes of cars. The data is
+    split into 8,144 training images and 8,041 testing images, where each class
+    has been split roughly in a 50-50 split
+
+    .. note::
 
         This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
 
     Args:
         root (string): Root directory of dataset
-        train (bool, optional):If True, creates dataset from training set, otherwise creates from test set
+        split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
         transform (callable, optional): A function/transform that takes in an PIL image
             and returns a transformed version. E.g, ``transforms.RandomCrop``
         target_transform (callable, optional): A function/transform that takes in the
@@ -26,30 +29,10 @@ class StanfordCars(VisionDataset):
             puts it in root directory. If dataset is already downloaded, it is not
             downloaded again."""
 
-    urls = (
-        "https://ai.stanford.edu/~jkrause/car196/cars_test.tgz",
-        "https://ai.stanford.edu/~jkrause/car196/cars_train.tgz",
-    )  # test and train image urls
-
-    md5s = (
-        "4ce7ebf6a94d07f1952d94dd34c4d501",
-        "065e5b463ae28d29e77c1b4b166cfe61",
-    )  # md5checksum for test and train data
-
-    annot_urls = (
-        "https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat",
-        "https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz",
-    )  # annotations and labels for test and train
-
-    annot_md5s = (
-        "b0a2b23655a3edd16d84508592a98d10",
-        "c3b158d763b6e2245038c8ad08e45376",
-    )  # md5 checksum for annotations
-
     def __init__(
         self,
         root: str,
-        train: bool = True,
+        split: str = "train",
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
         download: bool = False,
@@ -62,7 +45,16 @@ def __init__(
 
         super().__init__(root, transform=transform, target_transform=target_transform)
 
-        self.train = train
+        self._split = verify_str_arg(split, "split", ("train", "test"))
+        self._base_folder = pathlib.Path(root) / "stanford_cars"
+        devkit = self._base_folder / "devkit"
+
+        if self._split == "train":
+            self._annotations_mat_path = devkit / "cars_train_annos.mat"
+            self._images_base_path = self._base_folder / "cars_train"
+        else:
+            self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
+            self._images_base_path = self._base_folder / "cars_test"
 
         if download:
             self.download()
@@ -72,22 +64,13 @@ def __init__(
 
         self._samples = [
             (
-                os.path.join(self.root, f"cars_{'train' if self.train else 'test'}", annotation["fname"]),
-                annotation["class"] - 1,
-                # Beware stanford cars target mapping starts from 1
+                str(self._images_base_path / annotation["fname"]),
+                annotation["class"] - 1,  # Original target mapping starts from 1, hence -1
             )
-            for annotation in sio.loadmat(
-                os.path.join(
-                    self.root,
-                    *["devkit", "cars_train_annos.mat"] if self.train else ["cars_test_annos_withlabels.mat"],
-                ),
-                squeeze_me=True,
-            )["annotations"]
+            for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]
         ]
 
-        self.classes = sio.loadmat(os.path.join(self.root, "devkit", "cars_meta.mat"), squeeze_me=True)[
-            "class_names"
-        ].tolist()
+        self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
         self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
 
     def __len__(self) -> int:
@@ -108,20 +91,31 @@ def download(self) -> None:
         if self._check_exists():
             return
 
-        download_and_extract_archive(url=self.urls[self.train], download_root=self.root, md5=self.md5s[self.train])
-        download_and_extract_archive(url=self.annot_urls[1], download_root=self.root, md5=self.annot_md5s[1])
-        if not self.train:
+        download_and_extract_archive(
+            url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz",
+            download_root=self._base_folder,
+            md5="c3b158d763b6e2245038c8ad08e45376",
+        )
+        if self._split == "train":
+            download_and_extract_archive(
+                url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz",
+                download_root=self._base_folder,
+                md5="065e5b463ae28d29e77c1b4b166cfe61",
+            )
+        else:
+            download_and_extract_archive(
+                url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz",
+                download_root=self._base_folder,
+                md5="4ce7ebf6a94d07f1952d94dd34c4d501",
+            )
             download_url(
-                url=self.annot_urls[0],
-                root=self.root,
-                md5=self.annot_md5s[0],
+                url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat",
+                root=self._base_folder,
+                md5="b0a2b23655a3edd16d84508592a98d10",
             )
 
     def _check_exists(self) -> bool:
-        return (
-            os.path.exists(os.path.join(self.root, f"cars_{'train' if self.train else 'test'}"))
-            and os.path.isdir(os.path.join(self.root, f"cars_{'train' if self.train else 'test'}"))
-            and os.path.exists(os.path.join(self.root, "devkit", "cars_meta.mat"))
-            if self.train
-            else os.path.exists(os.path.join(self.root, "cars_test_annos_withlabels.mat"))
-        )
+        if not (self._base_folder / "devkit").is_dir():
+            return False
+
+        return self._annotations_mat_path.exists() and self._images_base_path.is_dir()
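After this change, callers select the split by name instead of a boolean flag. A minimal usage sketch (the root directory "data" is illustrative, and download=True assumes the Stanford URLs above are still reachable):

from torchvision import datasets

train_set = datasets.StanfordCars("data", split="train", download=True)
test_set = datasets.StanfordCars("data", split="test", download=True)

# Archives are extracted under data/stanford_cars/ (devkit/, cars_train/ or cars_test/).
image, target = train_set[0]          # PIL image and 0-based class index
print(len(train_set), len(test_set))  # 8,144 train and 8,041 test images per the docstring
print(train_set.classes[target])      # class name loaded from devkit/cars_meta.mat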
