@@ -1,6 +1,6 @@
 import csv
-import functools
-from typing import Any, Dict, List, Optional, Tuple, Iterator, Sequence, BinaryIO
+import pathlib
+from typing import Any, Dict, List, Optional, Tuple, Iterator, Sequence, BinaryIO, Union
 
 from torchdata.datapipes.iter import (
     IterDataPipe,
@@ -10,9 +10,7 @@
     IterKeyZipper,
 )
 from torchvision.prototype.datasets.utils import (
-    Dataset,
-    DatasetConfig,
-    DatasetInfo,
+    Dataset2,
     GDriveResource,
     OnlineResource,
 )
@@ -25,6 +23,7 @@
 )
 from torchvision.prototype.features import EncodedImage, _Feature, Label, BoundingBox
 
+from .._api import register_dataset, register_info
 
 csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)
 
@@ -60,15 +59,32 @@ def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]:
                 yield line.pop("image_id"), line
 
 
-class CelebA(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "celeba",
-            homepage="https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html",
-            valid_options=dict(split=("train", "val", "test")),
-        )
+NAME = "celeba"
+
+
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    return dict()
+
 
-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+@register_dataset(NAME)
+class CelebA(Dataset2):
+    """
+    - **homepage**: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
+    """
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        *,
+        split: str = "train",
+        skip_integrity_check: bool = False,
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", ("train", "val", "test"))
+
+        super().__init__(root, skip_integrity_check=skip_integrity_check)
+
+    def _resources(self) -> List[OnlineResource]:
         splits = GDriveResource(
             "0B7EVK8r0v71pY0NSMzRuSXJEVkk",
             sha256="fc955bcb3ef8fbdf7d5640d9a8693a8431b5f2ee291a5c1449a1549e7e073fe7",
@@ -101,14 +117,13 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
         )
         return [splits, images, identities, attributes, bounding_boxes, landmarks]
 
-    _SPLIT_ID_TO_NAME = {
-        "0": "train",
-        "1": "val",
-        "2": "test",
-    }
-
-    def _filter_split(self, data: Tuple[str, Dict[str, str]], *, split: str) -> bool:
-        return self._SPLIT_ID_TO_NAME[data[1]["split_id"]] == split
+    def _filter_split(self, data: Tuple[str, Dict[str, str]]) -> bool:
+        split_id = {
+            "train": "0",
+            "val": "1",
+            "test": "2",
+        }[self._split]
+        return data[1]["split_id"] == split_id
 
     def _prepare_sample(
         self,
@@ -145,16 +160,11 @@ def _prepare_sample(
             },
        )
 
-    def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         splits_dp, images_dp, identities_dp, attributes_dp, bounding_boxes_dp, landmarks_dp = resource_dps
 
         splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))
-        splits_dp = Filter(splits_dp, functools.partial(self._filter_split, split=config.split))
+        splits_dp = Filter(splits_dp, self._filter_split)
         splits_dp = hint_shuffling(splits_dp)
         splits_dp = hint_sharding(splits_dp)
 
@@ -186,3 +196,10 @@ def _make_datapipe(
             buffer_size=INFINITE_BUFFER_SIZE,
         )
         return Mapper(dp, self._prepare_sample)
+
+    def __len__(self) -> int:
+        return {
+            "train": 162_770,
+            "val": 19_867,
+            "test": 19_962,
+        }[self._split]
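For context, a minimal usage sketch of the refactored class, not part of the diff. It assumes the module lives at torchvision.prototype.datasets._builtin.celeba, that the CelebA files are already available under root (the GDrive resources would otherwise attempt a download), and that Dataset2 is itself an IterDataPipe whose iteration yields the sample dicts built by _prepare_sample; only the constructor signature, the split handling, and __len__ are actually shown in the diff above. Since the class is now registered under NAME via register_dataset, it could presumably also be obtained through the datasets registry, but that entry point is not part of this change.

    # Hedged sketch under the assumptions stated above.
    from torchvision.prototype.datasets._builtin.celeba import CelebA

    # split is validated against ("train", "val", "test"); integrity checks can be skipped.
    dataset = CelebA("path/to/celeba", split="val", skip_integrity_check=True)

    print(len(dataset))  # 19_867 for the "val" split, per the new __len__ table
    sample = next(iter(dataset))  # one sample dict produced by _prepare_sample
    print(sorted(sample.keys()))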