|
| 1 | +import abc |
| 2 | +import io |
| 3 | +import os |
| 4 | +import pathlib |
| 5 | +import textwrap |
| 6 | +from collections import Mapping |
| 7 | +from typing import ( |
| 8 | + Any, |
| 9 | + Callable, |
| 10 | + Dict, |
| 11 | + List, |
| 12 | + Optional, |
| 13 | + Sequence, |
| 14 | + Union, |
| 15 | + NoReturn, |
| 16 | + Iterable, |
| 17 | + Tuple, |
| 18 | +) |
| 19 | + |
| 20 | +import torch |
| 21 | +from torch.utils.data import IterDataPipe |
| 22 | + |
| 23 | +from torchvision.prototype.datasets.utils._internal import ( |
| 24 | + add_suggestion, |
| 25 | + sequence_to_str, |
| 26 | +) |
| 27 | +from ._resource import OnlineResource |
| 28 | + |
| 29 | + |
| 30 | +def make_repr(name: str, items: Iterable[Tuple[str, Any]]): |
| 31 | + def to_str(sep: str) -> str: |
| 32 | + return sep.join([f"{key}={value}" for key, value in items]) |
| 33 | + |
| 34 | + prefix = f"{name}(" |
| 35 | + postfix = ")" |
| 36 | + body = to_str(", ") |
| 37 | + |
| 38 | + line_length = int(os.environ.get("COLUMNS", 80)) |
| 39 | + body_too_long = (len(prefix) + len(body) + len(postfix)) > line_length |
| 40 | + multiline_body = len(str(body).splitlines()) > 1 |
| 41 | + if not (body_too_long or multiline_body): |
| 42 | + return prefix + body + postfix |
| 43 | + |
| 44 | + body = textwrap.indent(to_str(",\n"), " " * 2) |
| 45 | + return f"{prefix}\n{body}\n{postfix}" |
| 46 | + |
| 47 | + |
| 48 | +class DatasetConfig(Mapping): |
| 49 | + def __init__(self, *args, **kwargs): |
| 50 | + data = dict(*args, **kwargs) |
| 51 | + self.__dict__["__data__"] = data |
| 52 | + self.__dict__["__final_hash__"] = hash(tuple(data.items())) |
| 53 | + |
| 54 | + def __getitem__(self, name: str) -> Any: |
| 55 | + return self.__dict__["__data__"][name] |
| 56 | + |
| 57 | + def __iter__(self): |
| 58 | + return iter(self.__dict__["__data__"].keys()) |
| 59 | + |
| 60 | + def __len__(self): |
| 61 | + return len(self.__dict__["__data__"]) |
| 62 | + |
| 63 | + def __getattr__(self, name: str) -> Any: |
| 64 | + try: |
| 65 | + return self[name] |
| 66 | + except KeyError as error: |
| 67 | + raise AttributeError( |
| 68 | + f"'{type(self).__name__}' object has no attribute '{name}'" |
| 69 | + ) from error |
| 70 | + |
| 71 | + def __setitem__(self, key: Any, value: Any) -> NoReturn: |
| 72 | + raise RuntimeError(f"'{type(self).__name__}' object is immutable") |
| 73 | + |
| 74 | + def __setattr__(self, key: Any, value: Any) -> NoReturn: |
| 75 | + raise RuntimeError(f"'{type(self).__name__}' object is immutable") |
| 76 | + |
| 77 | + def __delitem__(self, key: Any) -> NoReturn: |
| 78 | + raise RuntimeError(f"'{type(self).__name__}' object is immutable") |
| 79 | + |
| 80 | + def __delattr__(self, item: Any) -> NoReturn: |
| 81 | + raise RuntimeError(f"'{type(self).__name__}' object is immutable") |
| 82 | + |
| 83 | + def __hash__(self) -> int: |
| 84 | + return self.__dict__["__final_hash__"] |
| 85 | + |
| 86 | + def __eq__(self, other: Any) -> bool: |
| 87 | + if not isinstance(other, DatasetConfig): |
| 88 | + return NotImplemented |
| 89 | + |
| 90 | + return hash(self) == hash(other) |
| 91 | + |
| 92 | + def __repr__(self) -> str: |
| 93 | + return make_repr(type(self).__name__, self.items()) |
| 94 | + |
| 95 | + |
| 96 | +class DatasetInfo: |
| 97 | + def __init__( |
| 98 | + self, |
| 99 | + name: str, |
| 100 | + *, |
| 101 | + categories: Union[int, Sequence[str], str, pathlib.Path], |
| 102 | + citation: Optional[str] = None, |
| 103 | + homepage: Optional[str] = None, |
| 104 | + license: Optional[str] = None, |
| 105 | + valid_options: Optional[Dict[str, Sequence]] = None, |
| 106 | + ) -> None: |
| 107 | + self.name = name.lower() |
| 108 | + |
| 109 | + if isinstance(categories, int): |
| 110 | + categories = [str(label) for label in range(categories)] |
| 111 | + elif isinstance(categories, (str, pathlib.Path)): |
| 112 | + with open(pathlib.Path(categories).expanduser().resolve(), "r") as fh: |
| 113 | + categories = fh.readlines() |
| 114 | + self.categories = categories |
| 115 | + |
| 116 | + self.citation = citation |
| 117 | + self.homepage = homepage |
| 118 | + self.license = license |
| 119 | + |
| 120 | + valid_split: Dict[str, Sequence] = dict(split=["train"]) |
| 121 | + if valid_options is None: |
| 122 | + valid_options = valid_split |
| 123 | + elif "split" not in valid_options: |
| 124 | + valid_options.update(valid_split) |
| 125 | + elif "train" not in valid_options["split"]: |
| 126 | + raise ValueError( |
| 127 | + f"'train' has to be a valid argument for option 'split', " |
| 128 | + f"but found only {sequence_to_str(valid_options['split'], separate_last='and ')}." |
| 129 | + ) |
| 130 | + self._valid_options: Dict[str, Sequence] = valid_options |
| 131 | + |
| 132 | + @property |
| 133 | + def default_config(self) -> DatasetConfig: |
| 134 | + return DatasetConfig( |
| 135 | + {name: valid_args[0] for name, valid_args in self._valid_options.items()} |
| 136 | + ) |
| 137 | + |
| 138 | + def make_config(self, **options: Any) -> DatasetConfig: |
| 139 | + for name, arg in options.items(): |
| 140 | + if name not in self._valid_options: |
| 141 | + raise ValueError( |
| 142 | + add_suggestion( |
| 143 | + f"Unknown option '{name}' of dataset {self.name}.", |
| 144 | + word=name, |
| 145 | + possibilities=sorted(self._valid_options.keys()), |
| 146 | + ) |
| 147 | + ) |
| 148 | + |
| 149 | + valid_args = self._valid_options[name] |
| 150 | + |
| 151 | + if arg not in valid_args: |
| 152 | + raise ValueError( |
| 153 | + add_suggestion( |
| 154 | + f"Invalid argument '{arg}' for option '{name}' of dataset {self.name}.", |
| 155 | + word=arg, |
| 156 | + possibilities=valid_args, |
| 157 | + ) |
| 158 | + ) |
| 159 | + |
| 160 | + return DatasetConfig(self.default_config, **options) |
| 161 | + |
| 162 | + def __repr__(self) -> str: |
| 163 | + items = [("name", self.name)] |
| 164 | + for key in ("citation", "homepage", "license"): |
| 165 | + value = getattr(self, key) |
| 166 | + if value is not None: |
| 167 | + items.append((key, value)) |
| 168 | + items.extend( |
| 169 | + sorted( |
| 170 | + (key, sequence_to_str(value)) |
| 171 | + for key, value in self._valid_options.items() |
| 172 | + ) |
| 173 | + ) |
| 174 | + return make_repr(type(self).__name__, items) |
| 175 | + |
| 176 | + |
| 177 | +class Dataset(abc.ABC): |
| 178 | + @property |
| 179 | + @abc.abstractmethod |
| 180 | + def info(self) -> DatasetInfo: |
| 181 | + pass |
| 182 | + |
| 183 | + @property |
| 184 | + def name(self) -> str: |
| 185 | + return self.info.name |
| 186 | + |
| 187 | + @property |
| 188 | + def default_config(self) -> DatasetConfig: |
| 189 | + return self.info.default_config |
| 190 | + |
| 191 | + @abc.abstractmethod |
| 192 | + def resources(self, config: DatasetConfig) -> List[OnlineResource]: |
| 193 | + pass |
| 194 | + |
| 195 | + @abc.abstractmethod |
| 196 | + def _make_datapipe( |
| 197 | + self, |
| 198 | + resource_dps: List[IterDataPipe], |
| 199 | + *, |
| 200 | + config: DatasetConfig, |
| 201 | + shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]], |
| 202 | + decoder: Optional[Callable[[io.IOBase], torch.Tensor]], |
| 203 | + ) -> IterDataPipe[Dict[str, Any]]: |
| 204 | + pass |
| 205 | + |
| 206 | + def to_datapipe( |
| 207 | + self, |
| 208 | + root: Union[str, pathlib.Path], |
| 209 | + *, |
| 210 | + config: Optional[DatasetConfig] = None, |
| 211 | + shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]] = None, |
| 212 | + decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None, |
| 213 | + ) -> IterDataPipe[Dict[str, Any]]: |
| 214 | + if not config: |
| 215 | + config = self.info.default_config |
| 216 | + |
| 217 | + resource_dps = [ |
| 218 | + resource.to_datapipe(root) for resource in self.resources(config) |
| 219 | + ] |
| 220 | + return self._make_datapipe(resource_dps, config=config, shuffler=shuffler, decoder=decoder) |
0 commit comments