
Commit 9ea341a

Migrate Stanford Cars prototype dataset (#5767)
* Migrate Stanford Cars prototype dataset
* Address comments
1 parent ccfcaa5 commit 9ea341a

File tree

3 files changed (+59, -35 lines)


test/builtin_dataset_mocks.py

Lines changed: 7 additions & 6 deletions
@@ -1473,18 +1473,19 @@ def pcam(root, config):
     return num_images


-# @register_mock
-def stanford_cars(info, root, config):
+@register_mock(name="stanford-cars", configs=combinations_grid(split=("train", "test")))
+def stanford_cars(root, config):
     import scipy.io as io
     from numpy.core.records import fromarrays

-    num_samples = {"train": 5, "test": 7}[config["split"]]
+    split = config["split"]
+    num_samples = {"train": 5, "test": 7}[split]
     num_categories = 3

     devkit = root / "devkit"
     devkit.mkdir(parents=True)

-    if config["split"] == "train":
+    if split == "train":
         images_folder_name = "cars_train"
         annotations_mat_path = devkit / "cars_train_annos.mat"
     else:
@@ -1498,7 +1499,7 @@ def stanford_cars(info, root, config):
         num_examples=num_samples,
     )

-    make_tar(root, f"cars_{config.split}.tgz", images_folder_name)
+    make_tar(root, f"cars_{split}.tgz", images_folder_name)
     bbox = np.random.randint(1, 200, num_samples, dtype=np.uint8)
     classes = np.random.randint(1, num_categories + 1, num_samples, dtype=np.uint8)
     fnames = [f"{i:5d}.jpg" for i in range(num_samples)]
@@ -1508,7 +1509,7 @@ def stanford_cars(info, root, config):
     )

     io.savemat(annotations_mat_path, {"annotations": rec_array})
-    if config.split == "train":
+    if split == "train":
         make_tar(root, "car_devkit.tgz", devkit, compression="gz")

     return num_samples
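
For orientation, the mock generator is now registered under the dataset name with an explicit grid of split configs instead of receiving an `info` object. Below is a rough sketch of what a direct call would look like; it is illustrative only: `tmp_path` is a made-up location, and the helpers the mock relies on (`make_tar`, the image-folder builder, `np`) are defined in the same test module, which normally drives registered mocks through its own machinery.

    import pathlib
    import tempfile

    # Fake data root for the mock to populate; any writable directory works.
    tmp_path = pathlib.Path(tempfile.mkdtemp())

    # The registered configs come from combinations_grid(split=("train", "test")),
    # so the test harness would call the mock once per split.
    num_samples = stanford_cars(tmp_path, config={"split": "train"})
    assert num_samples == 5  # the mock writes 5 fake training samples under tmp_path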

torchvision/prototype/datasets/_builtin/dtd.py

Lines changed: 2 additions & 1 deletion
@@ -41,8 +41,9 @@ def _info() -> Dict[str, Any]:
 @register_dataset(NAME)
 class DTD(Dataset2):
     """DTD Dataset.
-    homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/",
+    homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/",
     """
+
     def __init__(
         self,
         root: Union[str, pathlib.Path],
torchvision/prototype/datasets/_builtin/stanford_cars.py

Lines changed: 50 additions & 28 deletions
@@ -1,11 +1,19 @@
 import pathlib
-from typing import Any, Dict, List, Tuple, Iterator, BinaryIO
+from typing import Any, Dict, List, Tuple, Iterator, BinaryIO, Union

 from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper, Zipper
-from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource
-from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, path_comparator, read_mat
+from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource
+from torchvision.prototype.datasets.utils._internal import (
+    hint_sharding,
+    hint_shuffling,
+    path_comparator,
+    read_mat,
+    BUILTIN_DIR,
+)
 from torchvision.prototype.features import BoundingBox, EncodedImage, Label

+from .._api import register_dataset, register_info
+

 class StanfordCarsLabelReader(IterDataPipe[Tuple[int, int, int, int, int, str]]):
     def __init__(self, datapipe: IterDataPipe[Dict[str, Any]]) -> None:
@@ -18,16 +26,33 @@ def __iter__(self) -> Iterator[Tuple[int, int, int, int, int, str]]:
         yield tuple(ann)  # type: ignore[misc]


-class StanfordCars(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            name="stanford-cars",
-            homepage="https://ai.stanford.edu/~jkrause/cars/car_dataset.html",
-            dependencies=("scipy",),
-            valid_options=dict(
-                split=("test", "train"),
-            ),
-        )
+NAME = "stanford-cars"
+
+
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    categories = DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")
+    categories = [c[0] for c in categories]
+    return dict(categories=categories)
+
+
+@register_dataset(NAME)
+class StanfordCars(Dataset2):
+    """Stanford Cars dataset.
+    homepage="https://ai.stanford.edu/~jkrause/cars/car_dataset.html",
+    dependencies=scipy
+    """
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        *,
+        split: str = "train",
+        skip_integrity_check: bool = False,
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", {"train", "test"})
+        self._categories = _info()["categories"]
+        super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("scipy",))

     _URL_ROOT = "https://ai.stanford.edu/~jkrause/"
     _URLS = {
@@ -44,9 +69,9 @@ def _make_info(self) -> DatasetInfo:
         "car_devkit": "512b227b30e2f0a8aab9e09485786ab4479582073a144998da74d64b801fd288",
     }

-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
-        resources: List[OnlineResource] = [HttpResource(self._URLS[config.split], sha256=self._CHECKSUM[config.split])]
-        if config.split == "train":
+    def _resources(self) -> List[OnlineResource]:
+        resources: List[OnlineResource] = [HttpResource(self._URLS[self._split], sha256=self._CHECKSUM[self._split])]
+        if self._split == "train":
             resources.append(HttpResource(url=self._URLS["car_devkit"], sha256=self._CHECKSUM["car_devkit"]))

         else:
@@ -65,32 +90,29 @@ def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Tuple[int, int, int,
         return dict(
             path=path,
             image=image,
-            label=Label(target[4] - 1, categories=self.categories),
+            label=Label(target[4] - 1, categories=self._categories),
             bounding_box=BoundingBox(target[:4], format="xyxy", image_size=image.image_size),
         )

-    def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:

         images_dp, targets_dp = resource_dps
-        if config.split == "train":
+        if self._split == "train":
             targets_dp = Filter(targets_dp, path_comparator("name", "cars_train_annos.mat"))
         targets_dp = StanfordCarsLabelReader(targets_dp)
         dp = Zipper(images_dp, targets_dp)
         dp = hint_shuffling(dp)
         dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)

-    def _generate_categories(self, root: pathlib.Path) -> List[str]:
-        config = self.info.make_config(split="train")
-        resources = self.resources(config)
+    def _generate_categories(self) -> List[str]:
+        resources = self._resources()

-        devkit_dp = resources[1].load(root)
+        devkit_dp = resources[1].load(self._root)
         meta_dp = Filter(devkit_dp, path_comparator("name", "cars_meta.mat"))
         _, meta_file = next(iter(meta_dp))

         return list(read_mat(meta_file, squeeze_me=True)["class_names"])
+
+    def __len__(self) -> int:
+        return 8_144 if self._split == "train" else 8_041
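
From the user side, the migrated dataset is looked up by its registered name and configured through constructor-style keyword arguments rather than a `DatasetConfig`. A rough usage sketch follows, assuming the prototype `datasets.load` entry point resolves registered datasets by name; the exact invocation in the test suite may differ.

    from torchvision.prototype import datasets

    # "stanford-cars" is the NAME registered via @register_dataset; split is now
    # validated in __init__ instead of being drawn from a DatasetConfig.
    dataset = datasets.load("stanford-cars", split="train")

    sample = next(iter(dataset))
    print(sample["label"], sample["bounding_box"])  # Label and BoundingBox features
    print(len(dataset))                             # 8_144 for the train split, per __len__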
