
Commit 8cc6d52

add prototype dataset for CelebA (#4514)

* add prototype dataset for CelebA
* fix code format
* fix mypy
* hardcode fmtparams
* fix mypy
* replace KeyZipper with Zipper for annotations

1 parent 493d301 commit 8cc6d52

File tree: 2 files changed, +185 −0 lines changed
Lines changed: 1 addition & 0 deletions

@@ -1,4 +1,5 @@
 from .caltech import Caltech101, Caltech256
+from .celeba import CelebA
 from .cifar import Cifar10, Cifar100
 from .sbd import SBD
 from .voc import VOC
Lines changed: 184 additions & 0 deletions

@@ -0,0 +1,184 @@
import csv
import io
from typing import Any, Callable, Dict, List, Optional, Tuple, Mapping, Union

import torch
from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.iter import (
    Mapper,
    Shuffler,
    Filter,
    ZipArchiveReader,
    Zipper,
)
from torchdata.datapipes.iter import KeyZipper
from torchvision.prototype.datasets.utils import (
    Dataset,
    DatasetConfig,
    DatasetInfo,
    GDriveResource,
    OnlineResource,
    DatasetType,
)
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, getitem, path_accessor


class CelebACSVParser(IterDataPipe):
    def __init__(
        self,
        datapipe,
        *,
        has_header,
    ):
        self.datapipe = datapipe
        self.has_header = has_header
        self._fmtparams = dict(delimiter=" ", skipinitialspace=True)

    def __iter__(self):
        for _, file in self.datapipe:
            file = (line.decode() for line in file)

            if self.has_header:
                # The first row is skipped, because it only contains the number of samples
                next(file)

                # Empty field names are filtered out, because some files have an extra white space after the
                # header line, which is recognized as an extra column
                fieldnames = [name for name in next(csv.reader([next(file)], **self._fmtparams)) if name]
                # Some files do not include a label for the image ID column
                if fieldnames[0] != "image_id":
                    fieldnames.insert(0, "image_id")

                for line in csv.DictReader(file, fieldnames=fieldnames, **self._fmtparams):
                    yield line.pop("image_id"), line
            else:
                for line in csv.reader(file, **self._fmtparams):
                    yield line[0], line[1:]
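

# Illustrative sketch (added for clarity, not part of the commit): the two file
# layouts CelebACSVParser handles. Files without a header, e.g.
# list_eval_partition.txt:
#
#     000001.jpg 0
#     000002.jpg 0
#
# are yielded as ("000001.jpg", ["0"]). Files with a header, e.g.
# list_attr_celeba.txt, start with the number of samples and a line of field
# names:
#
#     202599
#     5_o_Clock_Shadow Arched_Eyebrows ...
#     000001.jpg -1 1 ...
#
# and are yielded as ("000001.jpg", {"5_o_Clock_Shadow": "-1", ...}).

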
class CelebA(Dataset):
    @property
    def info(self) -> DatasetInfo:
        return DatasetInfo(
            "celeba",
            type=DatasetType.IMAGE,
            homepage="https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html",
        )

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        splits = GDriveResource(
            "0B7EVK8r0v71pY0NSMzRuSXJEVkk",
            sha256="fc955bcb3ef8fbdf7d5640d9a8693a8431b5f2ee291a5c1449a1549e7e073fe7",
            file_name="list_eval_partition.txt",
        )
        images = GDriveResource(
            "0B7EVK8r0v71pZjFTYXZWM3FlRnM",
            sha256="46fb89443c578308acf364d7d379fe1b9efb793042c0af734b6112e4fd3a8c74",
            file_name="img_align_celeba.zip",
        )
        identities = GDriveResource(
            "1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS",
            sha256="c6143857c3e2630ac2da9f782e9c1232e5e59be993a9d44e8a7916c78a6158c0",
            file_name="identity_CelebA.txt",
        )
        attributes = GDriveResource(
            "0B7EVK8r0v71pblRyaVFSWGxPY0U",
            sha256="f0e5da289d5ccf75ffe8811132694922b60f2af59256ed362afa03fefba324d0",
            file_name="list_attr_celeba.txt",
        )
        bboxes = GDriveResource(
            "0B7EVK8r0v71pbThiMVRxWXZ4dU0",
            sha256="7487a82e57c4bb956c5445ae2df4a91ffa717e903c5fa22874ede0820c8ec41b",
            file_name="list_bbox_celeba.txt",
        )
        landmarks = GDriveResource(
            "0B7EVK8r0v71pd0FJY3Blby1HUTQ",
            sha256="6c02a87569907f6db2ba99019085697596730e8129f67a3d61659f198c48d43b",
            file_name="list_landmarks_align_celeba.txt",
        )
        return [splits, images, identities, attributes, bboxes, landmarks]
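
    # Note (added for clarity, not part of the commit): the order of the returned
    # list matters, since `_make_datapipe` below unpacks `resource_dps`
    # positionally in exactly this order.
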
    _SPLIT_ID_TO_NAME = {
        "0": "train",
        "1": "valid",
        "2": "test",
    }

    def _filter_split(self, data: Tuple[str, str], *, split):
        _, split_id = data
        return self._SPLIT_ID_TO_NAME[split_id[0]] == split
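
    # Illustrative example (not part of the commit): the split file is parsed
    # without a header, so each item arrives as e.g. ("000001.jpg", ["0"]);
    # split_id[0] is the raw partition id, mapped to "train"/"valid"/"test"
    # via _SPLIT_ID_TO_NAME above.
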
    def _collate_anns(
        self, data: Tuple[Tuple[str, Union[List[str], Mapping[str, str]]], ...]
    ) -> Tuple[str, Dict[str, Union[List[str], Mapping[str, str]]]]:
        (image_id, identity), (_, attributes), (_, bbox), (_, landmarks) = data
        return image_id, dict(identity=identity, attributes=attributes, bbox=bbox, landmarks=landmarks)
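
    # Illustrative input/output (not part of the commit): Zipper emits one
    # 4-tuple per image, e.g.
    #
    #     (("000001.jpg", ["2880"]), ("000001.jpg", {<attrs>}),
    #      ("000001.jpg", {<bbox>}), ("000001.jpg", {<landmarks>}))
    #
    # which _collate_anns flattens into
    #
    #     ("000001.jpg", {"identity": ["2880"], "attributes": {...},
    #                     "bbox": {...}, "landmarks": {...}})
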
    def _collate_and_decode_sample(
        self,
        data: Tuple[Tuple[str, Tuple[str, List[str]], Tuple[str, io.IOBase]], Tuple[str, Dict[str, Any]]],
        *,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> Dict[str, Any]:
        split_and_image_data, ann_data = data
        _, _, image_data = split_and_image_data
        path, buffer = image_data
        _, ann = ann_data

        image = decoder(buffer) if decoder else buffer

        identity = torch.tensor(int(ann["identity"][0]))
        attributes = {attr: value == "1" for attr, value in ann["attributes"].items()}
        bbox = torch.tensor([int(ann["bbox"][key]) for key in ("x_1", "y_1", "width", "height")])
        landmarks = {
            landmark: torch.tensor((int(ann["landmarks"][f"{landmark}_x"]), int(ann["landmarks"][f"{landmark}_y"])))
            for landmark in {key[:-2] for key in ann["landmarks"].keys()}
        }

        return dict(
            path=path,
            image=image,
            identity=identity,
            attributes=attributes,
            bbox=bbox,
            landmarks=landmarks,
        )
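
    # Resulting sample layout (illustrative, not part of the commit):
    #
    #     {
    #         "path": ".../img_align_celeba/000001.jpg",
    #         "image": <decoded tensor, or the raw buffer if decoder is None>,
    #         "identity": tensor(...),
    #         "attributes": {"5_o_Clock_Shadow": False, ...},  # "1" -> True, else False
    #         "bbox": tensor([x_1, y_1, width, height]),
    #         "landmarks": {"lefteye": tensor([x, y]), ...},
    #     }
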
    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        splits_dp, images_dp, identities_dp, attributes_dp, bboxes_dp, landmarks_dp = resource_dps

        splits_dp = CelebACSVParser(splits_dp, has_header=False)
        splits_dp: IterDataPipe = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
        splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE)

        images_dp = ZipArchiveReader(images_dp)

        anns_dp: IterDataPipe = Zipper(
            *[
                CelebACSVParser(dp, has_header=has_header)
                for dp, has_header in (
                    (identities_dp, False),
                    (attributes_dp, True),
                    (bboxes_dp, True),
                    (landmarks_dp, True),
                )
            ]
        )
        anns_dp: IterDataPipe = Mapper(anns_dp, self._collate_anns)

        dp = KeyZipper(
            splits_dp,
            images_dp,
            key_fn=getitem(0),
            ref_key_fn=path_accessor("name"),
            buffer_size=INFINITE_BUFFER_SIZE,
            keep_key=True,
        )
        dp = KeyZipper(dp, anns_dp, key_fn=getitem(0), buffer_size=INFINITE_BUFFER_SIZE)
        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
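
A minimal usage sketch, assuming the prototype registry exposes this dataset through torchvision.prototype.datasets.load (that entry point and its split keyword are assumptions here, not part of this commit):

    from torchvision.prototype import datasets

    # Hypothetical: load the prototype CelebA dataset and pull one sample.
    celeba = datasets.load("celeba", split="train")
    sample = next(iter(celeba))
    print(sample["path"], sample["identity"], sample["bbox"])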
