Skip to content

Commit 2d91c26

Browse files
committed
add API for new style datasets (#4473)
* add API for new style datasets * cleanup Co-authored-by: Francisco Massa <[email protected]> [ghstack-poisoned]
1 parent ab321fc commit 2d91c26

File tree

8 files changed

+403
-2
lines changed

8 files changed

+403
-2
lines changed
Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,6 @@
1-
from . import decoder
1+
from ._home import home
2+
from . import decoder, utils
3+
4+
# Load this last, since some parts depend on the above being loaded first
5+
from ._api import register, list, info, load
26
from ._folder import from_data_folder, from_image_folder
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import io
2+
from typing import Any, Callable, Dict, List, Optional
3+
4+
import torch
5+
from torch.utils.data import IterDataPipe
6+
7+
from torchvision.prototype.datasets import home
8+
from torchvision.prototype.datasets.decoder import pil
9+
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo
10+
from torchvision.prototype.datasets.utils._internal import add_suggestion
11+
12+
13+
DATASETS: Dict[str, Dataset] = {}
14+
15+
16+
def register(dataset: Dataset) -> None:
17+
DATASETS[dataset.name] = dataset
18+
19+
20+
def list() -> List[str]:
21+
return sorted(DATASETS.keys())
22+
23+
24+
def find(name: str) -> Dataset:
25+
name = name.lower()
26+
try:
27+
return DATASETS[name]
28+
except KeyError as error:
29+
raise ValueError(
30+
add_suggestion(
31+
f"Unknown dataset '{name}'.",
32+
word=name,
33+
possibilities=DATASETS.keys(),
34+
alternative_hint=lambda _: (
35+
"You can use torchvision.datasets.list() to get a list of all available datasets."
36+
),
37+
)
38+
) from error
39+
40+
41+
def info(name: str) -> DatasetInfo:
42+
return find(name).info
43+
44+
45+
def load(
46+
name: str,
47+
*,
48+
shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]] = None,
49+
decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = pil,
50+
split: str = "train",
51+
**options: Any,
52+
) -> IterDataPipe[Dict[str, Any]]:
53+
dataset = find(name)
54+
55+
config = dataset.info.make_config(split=split, **options)
56+
root = home() / name
57+
58+
return dataset.to_datapipe(root, config=config, shuffler=shuffler, decoder=decoder)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import os
2+
import pathlib
3+
from typing import Optional, Union
4+
5+
from torch.hub import _get_torch_home
6+
7+
HOME = pathlib.Path(_get_torch_home()) / "datasets" / "vision"
8+
9+
10+
def home(home: Optional[Union[str, pathlib.Path]] = None) -> pathlib.Path:
11+
global HOME
12+
if home is not None:
13+
HOME = pathlib.Path(home).expanduser().resolve()
14+
return HOME
15+
16+
home = os.getenv("TORCHVISION_DATASETS_HOME")
17+
if home is not None:
18+
return pathlib.Path(home)
19+
20+
return HOME

torchvision/prototype/datasets/decoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@
77
__all__ = ["pil"]
88

99

10-
def pil(file: io.IOBase, mode="RGB") -> torch.Tensor:
10+
def pil(file: io.IOBase, mode: str = "RGB") -> torch.Tensor:
1111
image = PIL.Image.open(file).convert(mode.upper())
1212
return torch.from_numpy(np.array(image, copy=True)).permute((2, 0, 1))
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from . import _internal
2+
from ._dataset import DatasetConfig, DatasetInfo, Dataset
3+
from ._resource import LocalResource, OnlineResource, HttpResource, GDriveResource
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
import abc
2+
import io
3+
import os
4+
import pathlib
5+
import textwrap
6+
from collections import Mapping
7+
from typing import (
8+
Any,
9+
Callable,
10+
Dict,
11+
List,
12+
Optional,
13+
Sequence,
14+
Union,
15+
NoReturn,
16+
Iterable,
17+
Tuple,
18+
)
19+
20+
import torch
21+
from torch.utils.data import IterDataPipe
22+
23+
from torchvision.prototype.datasets.utils._internal import (
24+
add_suggestion,
25+
sequence_to_str,
26+
)
27+
from ._resource import OnlineResource
28+
29+
30+
def make_repr(name: str, items: Iterable[Tuple[str, Any]]):
31+
def to_str(sep: str) -> str:
32+
return sep.join([f"{key}={value}" for key, value in items])
33+
34+
prefix = f"{name}("
35+
postfix = ")"
36+
body = to_str(", ")
37+
38+
line_length = int(os.environ.get("COLUMNS", 80))
39+
body_too_long = (len(prefix) + len(body) + len(postfix)) > line_length
40+
multiline_body = len(str(body).splitlines()) > 1
41+
if not (body_too_long or multiline_body):
42+
return prefix + body + postfix
43+
44+
body = textwrap.indent(to_str(",\n"), " " * 2)
45+
return f"{prefix}\n{body}\n{postfix}"
46+
47+
48+
class DatasetConfig(Mapping):
49+
def __init__(self, *args, **kwargs):
50+
data = dict(*args, **kwargs)
51+
self.__dict__["__data__"] = data
52+
self.__dict__["__final_hash__"] = hash(tuple(data.items()))
53+
54+
def __getitem__(self, name: str) -> Any:
55+
return self.__dict__["__data__"][name]
56+
57+
def __iter__(self):
58+
return iter(self.__dict__["__data__"].keys())
59+
60+
def __len__(self):
61+
return len(self.__dict__["__data__"])
62+
63+
def __getattr__(self, name: str) -> Any:
64+
try:
65+
return self[name]
66+
except KeyError as error:
67+
raise AttributeError(
68+
f"'{type(self).__name__}' object has no attribute '{name}'"
69+
) from error
70+
71+
def __setitem__(self, key: Any, value: Any) -> NoReturn:
72+
raise RuntimeError(f"'{type(self).__name__}' object is immutable")
73+
74+
def __setattr__(self, key: Any, value: Any) -> NoReturn:
75+
raise RuntimeError(f"'{type(self).__name__}' object is immutable")
76+
77+
def __delitem__(self, key: Any) -> NoReturn:
78+
raise RuntimeError(f"'{type(self).__name__}' object is immutable")
79+
80+
def __delattr__(self, item: Any) -> NoReturn:
81+
raise RuntimeError(f"'{type(self).__name__}' object is immutable")
82+
83+
def __hash__(self) -> int:
84+
return self.__dict__["__final_hash__"]
85+
86+
def __eq__(self, other: Any) -> bool:
87+
if not isinstance(other, DatasetConfig):
88+
return NotImplemented
89+
90+
return hash(self) == hash(other)
91+
92+
def __repr__(self) -> str:
93+
return make_repr(type(self).__name__, self.items())
94+
95+
96+
class DatasetInfo:
97+
def __init__(
98+
self,
99+
name: str,
100+
*,
101+
categories: Union[int, Sequence[str], str, pathlib.Path],
102+
citation: Optional[str] = None,
103+
homepage: Optional[str] = None,
104+
license: Optional[str] = None,
105+
valid_options: Optional[Dict[str, Sequence]] = None,
106+
) -> None:
107+
self.name = name.lower()
108+
109+
if isinstance(categories, int):
110+
categories = [str(label) for label in range(categories)]
111+
elif isinstance(categories, (str, pathlib.Path)):
112+
with open(pathlib.Path(categories).expanduser().resolve(), "r") as fh:
113+
categories = fh.readlines()
114+
self.categories = categories
115+
116+
self.citation = citation
117+
self.homepage = homepage
118+
self.license = license
119+
120+
valid_split: Dict[str, Sequence] = dict(split=["train"])
121+
if valid_options is None:
122+
valid_options = valid_split
123+
elif "split" not in valid_options:
124+
valid_options.update(valid_split)
125+
elif "train" not in valid_options["split"]:
126+
raise ValueError(
127+
f"'train' has to be a valid argument for option 'split', "
128+
f"but found only {sequence_to_str(valid_options['split'], separate_last='and ')}."
129+
)
130+
self._valid_options: Dict[str, Sequence] = valid_options
131+
132+
@property
133+
def default_config(self) -> DatasetConfig:
134+
return DatasetConfig(
135+
{name: valid_args[0] for name, valid_args in self._valid_options.items()}
136+
)
137+
138+
def make_config(self, **options: Any) -> DatasetConfig:
139+
for name, arg in options.items():
140+
if name not in self._valid_options:
141+
raise ValueError(
142+
add_suggestion(
143+
f"Unknown option '{name}' of dataset {self.name}.",
144+
word=name,
145+
possibilities=sorted(self._valid_options.keys()),
146+
)
147+
)
148+
149+
valid_args = self._valid_options[name]
150+
151+
if arg not in valid_args:
152+
raise ValueError(
153+
add_suggestion(
154+
f"Invalid argument '{arg}' for option '{name}' of dataset {self.name}.",
155+
word=arg,
156+
possibilities=valid_args,
157+
)
158+
)
159+
160+
return DatasetConfig(self.default_config, **options)
161+
162+
def __repr__(self) -> str:
163+
items = [("name", self.name)]
164+
for key in ("citation", "homepage", "license"):
165+
value = getattr(self, key)
166+
if value is not None:
167+
items.append((key, value))
168+
items.extend(
169+
sorted(
170+
(key, sequence_to_str(value))
171+
for key, value in self._valid_options.items()
172+
)
173+
)
174+
return make_repr(type(self).__name__, items)
175+
176+
177+
class Dataset(abc.ABC):
178+
@property
179+
@abc.abstractmethod
180+
def info(self) -> DatasetInfo:
181+
pass
182+
183+
@property
184+
def name(self) -> str:
185+
return self.info.name
186+
187+
@property
188+
def default_config(self) -> DatasetConfig:
189+
return self.info.default_config
190+
191+
@abc.abstractmethod
192+
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
193+
pass
194+
195+
@abc.abstractmethod
196+
def _make_datapipe(
197+
self,
198+
resource_dps: List[IterDataPipe],
199+
*,
200+
config: DatasetConfig,
201+
shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]],
202+
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
203+
) -> IterDataPipe[Dict[str, Any]]:
204+
pass
205+
206+
def to_datapipe(
207+
self,
208+
root: Union[str, pathlib.Path],
209+
*,
210+
config: Optional[DatasetConfig] = None,
211+
shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]] = None,
212+
decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None,
213+
) -> IterDataPipe[Dict[str, Any]]:
214+
if not config:
215+
config = self.info.default_config
216+
217+
resource_dps = [
218+
resource.to_datapipe(root) for resource in self.resources(config)
219+
]
220+
return self._make_datapipe(resource_dps, config=config, shuffler=shuffler, decoder=decoder)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import collections.abc
2+
import difflib
3+
from typing import Collection, Sequence, Callable
4+
5+
6+
__all__ = [
7+
"sequence_to_str",
8+
"add_suggestion",
9+
]
10+
11+
12+
def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
13+
if len(seq) == 1:
14+
return f"'{seq[0]}'"
15+
16+
return (
17+
f"""'{"', '".join([str(item) for item in seq[:-1]])}', """
18+
f"""{separate_last}'{seq[-1]}'."""
19+
)
20+
21+
22+
def add_suggestion(
23+
msg: str,
24+
*,
25+
word: str,
26+
possibilities: Collection[str],
27+
close_match_hint: Callable[
28+
[str], str
29+
] = lambda close_match: f"Did you mean '{close_match}'?",
30+
alternative_hint: Callable[
31+
[Sequence[str]], str
32+
] = lambda possibilities: f"Can be {sequence_to_str(possibilities, separate_last='or ')}.",
33+
) -> str:
34+
if not isinstance(possibilities, collections.abc.Sequence):
35+
possibilities = sorted(possibilities)
36+
suggestions = difflib.get_close_matches(word, possibilities, 1)
37+
hint = (
38+
close_match_hint(suggestions[0])
39+
if suggestions
40+
else alternative_hint(possibilities)
41+
)
42+
return f"{msg.strip()} {hint}"

0 commit comments

Comments
 (0)