Skip to content

Commit 2612c4c

Browse files
authored
migrate CelebA prototype dataset (#5750)
* migrate CelebA prototype dataset
* inline split_id
1 parent 217616b commit 2612c4c

File tree

2 files changed

+48
-31
lines changed

2 files changed

+48
-31
lines changed

test/builtin_dataset_mocks.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -905,9 +905,9 @@ def generate(cls, root):
905905
return num_samples_map
906906

907907

908-
# @register_mock
909-
def celeba(info, root, config):
910-
return CelebAMockData.generate(root)[config.split]
908+
@register_mock(configs=combinations_grid(split=("train", "val", "test")))
909+
def celeba(root, config):
910+
return CelebAMockData.generate(root)[config["split"]]
911911

912912

913913
@register_mock(configs=combinations_grid(split=("train", "val", "test")))

torchvision/prototype/datasets/_builtin/celeba.py

Lines changed: 45 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import csv
2-
import functools
3-
from typing import Any, Dict, List, Optional, Tuple, Iterator, Sequence, BinaryIO
2+
import pathlib
3+
from typing import Any, Dict, List, Optional, Tuple, Iterator, Sequence, BinaryIO, Union
44

55
from torchdata.datapipes.iter import (
66
IterDataPipe,
@@ -10,9 +10,7 @@
1010
IterKeyZipper,
1111
)
1212
from torchvision.prototype.datasets.utils import (
13-
Dataset,
14-
DatasetConfig,
15-
DatasetInfo,
13+
Dataset2,
1614
GDriveResource,
1715
OnlineResource,
1816
)
@@ -25,6 +23,7 @@
2523
)
2624
from torchvision.prototype.features import EncodedImage, _Feature, Label, BoundingBox
2725

26+
from .._api import register_dataset, register_info
2827

2928
csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)
3029

@@ -60,15 +59,32 @@ def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]:
6059
yield line.pop("image_id"), line
6160

6261

63-
class CelebA(Dataset):
64-
def _make_info(self) -> DatasetInfo:
65-
return DatasetInfo(
66-
"celeba",
67-
homepage="https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html",
68-
valid_options=dict(split=("train", "val", "test")),
69-
)
62+
NAME = "celeba"
63+
64+
65+
@register_info(NAME)
66+
def _info() -> Dict[str, Any]:
67+
return dict()
68+
7069

71-
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
70+
@register_dataset(NAME)
71+
class CelebA(Dataset2):
72+
"""
73+
- **homepage**: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
74+
"""
75+
76+
def __init__(
77+
self,
78+
root: Union[str, pathlib.Path],
79+
*,
80+
split: str = "train",
81+
skip_integrity_check: bool = False,
82+
) -> None:
83+
self._split = self._verify_str_arg(split, "split", ("train", "val", "test"))
84+
85+
super().__init__(root, skip_integrity_check=skip_integrity_check)
86+
87+
def _resources(self) -> List[OnlineResource]:
7288
splits = GDriveResource(
7389
"0B7EVK8r0v71pY0NSMzRuSXJEVkk",
7490
sha256="fc955bcb3ef8fbdf7d5640d9a8693a8431b5f2ee291a5c1449a1549e7e073fe7",
@@ -101,14 +117,13 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
101117
)
102118
return [splits, images, identities, attributes, bounding_boxes, landmarks]
103119

104-
_SPLIT_ID_TO_NAME = {
105-
"0": "train",
106-
"1": "val",
107-
"2": "test",
108-
}
109-
110-
def _filter_split(self, data: Tuple[str, Dict[str, str]], *, split: str) -> bool:
111-
return self._SPLIT_ID_TO_NAME[data[1]["split_id"]] == split
120+
def _filter_split(self, data: Tuple[str, Dict[str, str]]) -> bool:
121+
split_id = {
122+
"train": "0",
123+
"val": "1",
124+
"test": "2",
125+
}[self._split]
126+
return data[1]["split_id"] == split_id
112127

113128
def _prepare_sample(
114129
self,
@@ -145,16 +160,11 @@ def _prepare_sample(
145160
},
146161
)
147162

148-
def _make_datapipe(
149-
self,
150-
resource_dps: List[IterDataPipe],
151-
*,
152-
config: DatasetConfig,
153-
) -> IterDataPipe[Dict[str, Any]]:
163+
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
154164
splits_dp, images_dp, identities_dp, attributes_dp, bounding_boxes_dp, landmarks_dp = resource_dps
155165

156166
splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))
157-
splits_dp = Filter(splits_dp, functools.partial(self._filter_split, split=config.split))
167+
splits_dp = Filter(splits_dp, self._filter_split)
158168
splits_dp = hint_shuffling(splits_dp)
159169
splits_dp = hint_sharding(splits_dp)
160170

@@ -186,3 +196,10 @@ def _make_datapipe(
186196
buffer_size=INFINITE_BUFFER_SIZE,
187197
)
188198
return Mapper(dp, self._prepare_sample)
199+
200+
def __len__(self) -> int:
201+
return {
202+
"train": 162_770,
203+
"val": 19_867,
204+
"test": 19_962,
205+
}[self._split]

0 commit comments

Comments
 (0)