
Commit f948d79

pmeier and NicolasHug authored
add CLEVR dataset (#5130)
* add prototype dataset
* add old-style dataset
* appease mypy
* simplify prototype scenes
* Update torchvision/datasets/clevr.py

Co-authored-by: Nicolas Hug <[email protected]>
Co-authored-by: Nicolas Hug <[email protected]>
1 parent 7120024 commit f948d79

File tree (6 files changed, +241 / -2 lines)

  test/test_datasets.py
  torchvision/datasets/__init__.py
  torchvision/datasets/clevr.py
  torchvision/prototype/datasets/_builtin/__init__.py
  torchvision/prototype/datasets/_builtin/clevr.py
  torchvision/prototype/datasets/utils/_internal.py

test/test_datasets.py

Lines changed: 32 additions & 0 deletions
@@ -2325,5 +2325,37 @@ def inject_fake_data(self, tmpdir: str, config):
         return total_number_of_examples


+class CLEVRClassificationTestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.CLEVRClassification
+    FEATURE_TYPES = (PIL.Image.Image, (int, type(None)))
+
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test"))
+
+    def inject_fake_data(self, tmpdir, config):
+        data_folder = pathlib.Path(tmpdir) / "clevr" / "CLEVR_v1.0"
+
+        images_folder = data_folder / "images"
+        image_files = datasets_utils.create_image_folder(
+            images_folder, config["split"], lambda idx: f"CLEVR_{config['split']}_{idx:06d}.png", num_examples=5
+        )
+
+        scenes_folder = data_folder / "scenes"
+        scenes_folder.mkdir()
+        if config["split"] != "test":
+            with open(scenes_folder / f"CLEVR_{config['split']}_scenes.json", "w") as file:
+                json.dump(
+                    dict(
+                        info=dict(),
+                        scenes=[
+                            dict(image_filename=image_file.name, objects=[dict()] * int(torch.randint(10, ())))
+                            for image_file in image_files
+                        ],
+                    ),
+                    file,
+                )
+
+        return len(image_files)
+
+
 if __name__ == "__main__":
     unittest.main()

torchvision/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -3,6 +3,7 @@
 from .celeba import CelebA
 from .cifar import CIFAR10, CIFAR100
 from .cityscapes import Cityscapes
+from .clevr import CLEVRClassification
 from .coco import CocoCaptions, CocoDetection
 from .dtd import DTD
 from .fakedata import FakeData
@@ -85,4 +86,5 @@
     "DTD",
     "FER2013",
     "GTSRB",
+    "CLEVRClassification",
 )

torchvision/datasets/clevr.py

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
+import json
+import pathlib
+from typing import Any, Callable, Optional, Tuple, List
+from urllib.parse import urlparse
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class CLEVRClassification(VisionDataset):
+    """`CLEVR <https://cs.stanford.edu/people/jcjohns/clevr/>`_ classification dataset.
+
+    The number of objects in a scene is used as the label.
+
+    Args:
+        root (string): Root directory of dataset where directory ``root/clevr`` exists or will be saved to if download
+            is set to True.
+        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
+        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+            version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and puts it in the root directory.
+            If the dataset is already downloaded, it is not downloaded again.
+    """
+
+    _URL = "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip"
+    _MD5 = "b11922020e72d0cd9154779b2d3d07d2"
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = True,
+    ) -> None:
+        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self._base_folder = pathlib.Path(self.root) / "clevr"
+        self._data_folder = self._base_folder / pathlib.Path(urlparse(self._URL).path).stem
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        self._image_files = sorted(self._data_folder.joinpath("images", self._split).glob("*"))
+
+        self._labels: List[Optional[int]]
+        if self._split != "test":
+            with open(self._data_folder / "scenes" / f"CLEVR_{self._split}_scenes.json") as file:
+                content = json.load(file)
+            num_objects = {scene["image_filename"]: len(scene["objects"]) for scene in content["scenes"]}
+            self._labels = [num_objects[image_file.name] for image_file in self._image_files]
+        else:
+            self._labels = [None] * len(self._image_files)
+
+    def __len__(self) -> int:
+        return len(self._image_files)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        image_file = self._image_files[idx]
+        label = self._labels[idx]
+
+        image = Image.open(image_file).convert("RGB")
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.target_transform:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def _check_exists(self) -> bool:
+        return self._data_folder.exists() and self._data_folder.is_dir()
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+
+        download_and_extract_archive(self._URL, str(self._base_folder), md5=self._MD5)
+
+    def extra_repr(self) -> str:
+        return f"split={self._split}"

torchvision/prototype/datasets/_builtin/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 from .caltech import Caltech101, Caltech256
 from .celeba import CelebA
 from .cifar import Cifar10, Cifar100
+from .clevr import CLEVR
 from .coco import Coco
 from .dtd import DTD
 from .fer2013 import FER2013

torchvision/prototype/datasets/_builtin/clevr.py

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
+import functools
+import io
+import pathlib
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import torch
+from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, JsonParser, UnBatcher
+from torchvision.prototype.datasets.utils import (
+    Dataset,
+    DatasetConfig,
+    DatasetInfo,
+    HttpResource,
+    OnlineResource,
+    DatasetType,
+)
+from torchvision.prototype.datasets.utils._internal import (
+    INFINITE_BUFFER_SIZE,
+    hint_sharding,
+    hint_shuffling,
+    path_comparator,
+    path_accessor,
+    getitem,
+)
+from torchvision.prototype.features import Label
+
+
+class CLEVR(Dataset):
+    def _make_info(self) -> DatasetInfo:
+        return DatasetInfo(
+            "clevr",
+            type=DatasetType.IMAGE,
+            homepage="https://cs.stanford.edu/people/jcjohns/clevr/",
+            valid_options=dict(split=("train", "val", "test")),
+        )
+
+    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+        archive = HttpResource(
+            "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip",
+            sha256="5cd61cf1096ed20944df93c9adb31e74d189b8459a94f54ba00090e5c59936d1",
+        )
+        return [archive]
+
+    def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
+        path = pathlib.Path(data[0])
+        if path.parents[1].name == "images":
+            return 0
+        elif path.parent.name == "scenes":
+            return 1
+        else:
+            return None
+
+    def _filter_scene_anns(self, data: Tuple[str, Any]) -> bool:
+        key, _ = data
+        return key == "scenes"
+
+    def _add_empty_anns(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[str, io.IOBase], None]:
+        return data, None
+
+    def _collate_and_decode_sample(
+        self,
+        data: Tuple[Tuple[str, io.IOBase], Optional[Dict[str, Any]]],
+        *,
+        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    ) -> Dict[str, Any]:
+        image_data, scenes_data = data
+        path, buffer = image_data
+
+        return dict(
+            path=path,
+            image=decoder(buffer) if decoder else buffer,
+            label=Label(len(scenes_data["objects"])) if scenes_data else None,
+        )
+
+    def _make_datapipe(
+        self,
+        resource_dps: List[IterDataPipe],
+        *,
+        config: DatasetConfig,
+        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    ) -> IterDataPipe[Dict[str, Any]]:
+        archive_dp = resource_dps[0]
+        images_dp, scenes_dp = Demultiplexer(
+            archive_dp,
+            2,
+            self._classify_archive,
+            drop_none=True,
+            buffer_size=INFINITE_BUFFER_SIZE,
+        )
+
+        images_dp = Filter(images_dp, path_comparator("parent.name", config.split))
+        images_dp = hint_sharding(images_dp)
+        images_dp = hint_shuffling(images_dp)
+
+        if config.split != "test":
+            scenes_dp = Filter(scenes_dp, path_comparator("name", f"CLEVR_{config.split}_scenes.json"))
+            scenes_dp = JsonParser(scenes_dp)
+            scenes_dp = Mapper(scenes_dp, getitem(1, "scenes"))
+            scenes_dp = UnBatcher(scenes_dp)
+
+            dp = IterKeyZipper(
+                images_dp,
+                scenes_dp,
+                key_fn=path_accessor("name"),
+                ref_key_fn=getitem("image_filename"),
+                buffer_size=INFINITE_BUFFER_SIZE,
+            )
+        else:
+            dp = Mapper(images_dp, self._add_empty_anns)
+
+        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
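
For reference, a small sketch of how the Demultiplexer's classifier above routes archive entries; the paths are illustrative examples of the CLEVR_v1.0 archive layout, not taken from the commit:

import pathlib
from typing import Any, Optional, Tuple

# Same routing logic as CLEVR._classify_archive: bucket 0 -> images, bucket 1 -> scene
# annotations, None -> entry is dropped (drop_none=True in the Demultiplexer).
def classify_archive(data: Tuple[str, Any]) -> Optional[int]:
    path = pathlib.Path(data[0])
    if path.parents[1].name == "images":
        return 0
    elif path.parent.name == "scenes":
        return 1
    else:
        return None

print(classify_archive(("CLEVR_v1.0/images/train/CLEVR_train_000000.png", None)))  # 0
print(classify_archive(("CLEVR_v1.0/scenes/CLEVR_train_scenes.json", None)))       # 1
print(classify_archive(("CLEVR_v1.0/README.txt", None)))                           # None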

torchvision/prototype/datasets/utils/_internal.py

Lines changed: 8 additions & 2 deletions
@@ -108,7 +108,7 @@ def __iter__(self) -> Iterator[Tuple[int, D]]:
         yield from enumerate(self.datapipe, self.start)


-def _getitem_closure(obj: Any, *, items: Tuple[Any, ...]) -> Any:
+def _getitem_closure(obj: Any, *, items: Sequence[Any]) -> Any:
     for item in items:
         obj = obj[item]
     return obj
@@ -118,8 +118,14 @@ def getitem(*items: Any) -> Callable[[Any], Any]:
     return functools.partial(_getitem_closure, items=items)


+def _getattr_closure(obj: Any, *, attrs: Sequence[str]) -> Any:
+    for attr in attrs:
+        obj = getattr(obj, attr)
+    return obj
+
+
 def _path_attribute_accessor(path: pathlib.Path, *, name: str) -> D:
-    return cast(D, getattr(path, name))
+    return cast(D, _getattr_closure(path, attrs=name.split(".")))


 def _path_accessor_closure(data: Tuple[str, Any], *, getter: Callable[[pathlib.Path], D]) -> D:
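
The dotted-name support added here is what lets the CLEVR datapipe above filter images with path_comparator("parent.name", config.split). A small sketch of the attribute traversal it performs (the example path is illustrative):

import pathlib
from typing import Any, Sequence

# Same traversal as _getattr_closure above: follow each attribute name in turn.
def getattr_closure(obj: Any, *, attrs: Sequence[str]) -> Any:
    for attr in attrs:
        obj = getattr(obj, attr)
    return obj

path = pathlib.Path("CLEVR_v1.0/images/train/CLEVR_train_000000.png")
print(getattr_closure(path, attrs="name".split(".")))         # CLEVR_train_000000.png
print(getattr_closure(path, attrs="parent.name".split(".")))  # train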

0 commit comments