
Commit 3d6b42c

NicolasHug authored and fmassa committed
[fbsync] cleanup prototype datasets (#4471)
Summary:
* cleanup image folder
* make shuffling mandatory
* rename parameter in home() function
* don't show builtin list
* make categories optional in dataset info
* use pseudo-infinite buffer size for shuffler

Reviewed By: datumbox

Differential Revision: D31268046

fbshipit-source-id: e7d66ecdc1c8250cabb8385d116daea32ef0899b

Co-authored-by: Francisco Massa <[email protected]>
1 parent 18fdaac commit 3d6b42c

File tree

7 files changed: +24 -24 lines changed


torchvision/prototype/datasets/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -2,5 +2,5 @@
 from . import decoder, utils

 # Load this last, since some parts depend on the above being loaded first
-from ._api import register, list, info, load
+from ._api import register, _list as list, info, load
 from ._folder import from_data_folder, from_image_folder
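
The alias keeps the public name stable: callers still use datasets.list(), while the private definition is now _list so the _api module no longer shadows Python's builtin. A minimal usage sketch, assuming this prototype package is available in your torchvision build:

from torchvision.prototype import datasets

# The package-level name is still `datasets.list`; only the underlying
# function was renamed to `_list` inside `_api.py`.
print(datasets.list())  # sorted names of every dataset passed to `register`
print(list("abc"))      # the builtin `list` is untouched at call sites: ['a', 'b', 'c']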

torchvision/prototype/datasets/_api.py

Lines changed: 3 additions & 3 deletions

@@ -17,7 +17,8 @@ def register(dataset: Dataset) -> None:
     DATASETS[dataset.name] = dataset


-def list() -> List[str]:
+# This is exposed as 'list', but we avoid that here to not shadow the built-in 'list'
+def _list() -> List[str]:
     return sorted(DATASETS.keys())


@@ -45,7 +46,6 @@ def info(name: str) -> DatasetInfo:
 def load(
     name: str,
     *,
-    shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]] = None,
     decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = pil,
     split: str = "train",
     **options: Any,
@@ -55,4 +55,4 @@ def load(
     config = dataset.info.make_config(split=split, **options)
     root = home() / name

-    return dataset.to_datapipe(root, config=config, shuffler=shuffler, decoder=decoder)
+    return dataset.to_datapipe(root, config=config, decoder=decoder)
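
With the shuffler argument gone, callers only choose the dataset, split, decoder, and dataset-specific options. A hedged sketch of a call site using the signature above; the dataset name is illustrative and the files are assumed to already sit under home() / name:

from torchvision.prototype import datasets

# "cifar10" is only an example name; it must be registered and its resources
# must be available locally. Shuffling now happens inside the dataset's datapipe.
dp = datasets.load("cifar10", split="train", decoder=None)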

torchvision/prototype/datasets/_folder.py

Lines changed: 2 additions & 6 deletions

@@ -10,13 +10,11 @@
 from torch.utils.data.datapipes.iter import FileLister, FileLoader, Mapper, Shuffler, Filter

 from torchvision.prototype.datasets.decoder import pil
+from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE


 __all__ = ["from_data_folder", "from_image_folder"]

-# pseudo-infinite buffer size until a true infinite buffer is supported
-INFINITE = 1_000_000_000
-

 def _is_not_top_level_file(path: str, *, root: pathlib.Path) -> bool:
     rel_path = pathlib.Path(path).relative_to(root)
@@ -45,7 +43,6 @@ def _collate_and_decode_data(
 def from_data_folder(
     root: Union[str, pathlib.Path],
     *,
-    shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]] = lambda dp: Shuffler(dp, buffer_size=INFINITE),
     decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None,
     valid_extensions: Optional[Collection[str]] = None,
     recursive: bool = True,
@@ -55,8 +52,7 @@
     masks: Union[List[str], str] = [f"*.{ext}" for ext in valid_extensions] if valid_extensions is not None else ""
     dp: IterDataPipe = FileLister(str(root), recursive=recursive, masks=masks)
     dp = Filter(dp, _is_not_top_level_file, fn_kwargs=dict(root=root))
-    if shuffler:
-        dp = shuffler(dp)
+    dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
     dp = FileLoader(dp)
     return (
         Mapper(dp, _collate_and_decode_data, fn_kwargs=dict(root=root, categories=categories, decoder=decoder)),
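
Callers of from_data_folder likewise lose the shuffler hook; every pipeline now goes through the unconditional Shuffler above. A hedged usage sketch (the path and extensions are illustrative, and the tuple layout beyond the first element is inferred from the truncated return statement above):

dp_and_extras = from_data_folder("/data/pets", valid_extensions=["jpg", "png"])
samples = dp_and_extras[0]  # first element is the Mapper datapipe per the return above
for sample in samples:
    ...  # each sample has already passed through the mandatory Shuffler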

torchvision/prototype/datasets/_home.py

Lines changed: 6 additions & 6 deletions

@@ -7,14 +7,14 @@
 HOME = pathlib.Path(_get_torch_home()) / "datasets" / "vision"


-def home(home: Optional[Union[str, pathlib.Path]] = None) -> pathlib.Path:
+def home(root: Optional[Union[str, pathlib.Path]] = None) -> pathlib.Path:
     global HOME
-    if home is not None:
-        HOME = pathlib.Path(home).expanduser().resolve()
+    if root is not None:
+        HOME = pathlib.Path(root).expanduser().resolve()
         return HOME

-    home = os.getenv("TORCHVISION_DATASETS_HOME")
-    if home is not None:
-        return pathlib.Path(home)
+    root = os.getenv("TORCHVISION_DATASETS_HOME")
+    if root is not None:
+        return pathlib.Path(root)

     return HOME
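
Renaming the parameter to root removes the home-shadows-home pattern and makes the lookup order easier to read: an explicit argument wins and updates the module-level HOME, then the TORCHVISION_DATASETS_HOME environment variable is consulted, then the default under the torch home directory. A small sketch, importing from the internal module shown above:

import os

from torchvision.prototype.datasets._home import home  # internal module at this commit

os.environ["TORCHVISION_DATASETS_HOME"] = "/tmp/vision-datasets"
print(home())                      # PosixPath('/tmp/vision-datasets'), from the env var
print(home(root="~/my-datasets"))  # expands, resolves, stores and returns the new HOME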
torchvision/prototype/datasets/decoder.py

Lines changed: 3 additions & 3 deletions

@@ -1,12 +1,12 @@
 import io

-import numpy as np
 import PIL.Image
 import torch

+from torchvision.transforms.functional import pil_to_tensor
+
 __all__ = ["pil"]


 def pil(file: io.IOBase, mode: str = "RGB") -> torch.Tensor:
-    image = PIL.Image.open(file).convert(mode.upper())
-    return torch.from_numpy(np.array(image, copy=True)).permute((2, 0, 1))
+    return pil_to_tensor(PIL.Image.open(file).convert(mode.upper()))
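
The new decoder defers to torchvision.transforms.functional.pil_to_tensor, which returns a uint8 CHW tensor without the intermediate NumPy copy. A self-contained sketch of the same conversion on an in-memory PNG (the 4x4 red image is only a stand-in for a real file handle):

import io

import PIL.Image
from torchvision.transforms.functional import pil_to_tensor

buffer = io.BytesIO()
PIL.Image.new("RGB", (4, 4), color=(255, 0, 0)).save(buffer, format="PNG")
buffer.seek(0)

# Equivalent to the new `pil` decoder body above.
image = pil_to_tensor(PIL.Image.open(buffer).convert("RGB"))
print(image.shape, image.dtype)  # torch.Size([3, 4, 4]) torch.uint8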

torchvision/prototype/datasets/utils/_dataset.py

Lines changed: 5 additions & 5 deletions

@@ -98,15 +98,17 @@ def __init__(
         self,
         name: str,
         *,
-        categories: Union[int, Sequence[str], str, pathlib.Path],
+        categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None,
         citation: Optional[str] = None,
         homepage: Optional[str] = None,
         license: Optional[str] = None,
         valid_options: Optional[Dict[str, Sequence]] = None,
     ) -> None:
         self.name = name.lower()

-        if isinstance(categories, int):
+        if categories is None:
+            categories = []
+        elif isinstance(categories, int):
             categories = [str(label) for label in range(categories)]
         elif isinstance(categories, (str, pathlib.Path)):
             with open(pathlib.Path(categories).expanduser().resolve(), "r") as fh:
@@ -198,7 +200,6 @@ def _make_datapipe(
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]],
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         pass
@@ -208,7 +209,6 @@ def to_datapipe(
         root: Union[str, pathlib.Path],
         *,
         config: Optional[DatasetConfig] = None,
-        shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]] = None,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None,
     ) -> IterDataPipe[Dict[str, Any]]:
         if not config:
@@ -217,4 +217,4 @@ def to_datapipe(
         resource_dps = [
             resource.to_datapipe(root) for resource in self.resources(config)
         ]
-        return self._make_datapipe(resource_dps, config=config, shuffler=shuffler, decoder=decoder)
+        return self._make_datapipe(resource_dps, config=config, decoder=decoder)
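
With categories now optional, DatasetInfo accepts nothing at all (an empty category list), an integer count, a sequence of names, or a path to a file of names. A standalone sketch of that normalization logic, mirroring the branches above (the function name is illustrative, and the file branch assumes one category per line, which the truncated hunk does not show):

import pathlib
from typing import List, Optional, Sequence, Union


def normalize_categories(
    categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None,
) -> List[str]:
    if categories is None:
        return []                                            # new default: no categories
    elif isinstance(categories, int):
        return [str(label) for label in range(categories)]   # count -> '0', '1', ...
    elif isinstance(categories, (str, pathlib.Path)):
        with open(pathlib.Path(categories).expanduser().resolve(), "r") as fh:
            return [line.strip() for line in fh]             # assumption: one name per line
    return list(categories)


print(normalize_categories())                # []
print(normalize_categories(3))               # ['0', '1', '2']
print(normalize_categories(["cat", "dog"]))  # ['cat', 'dog']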

torchvision/prototype/datasets/utils/_internal.py

Lines changed: 4 additions & 0 deletions

@@ -4,10 +4,14 @@


 __all__ = [
+    "INFINITE_BUFFER_SIZE",
     "sequence_to_str",
     "add_suggestion",
 ]

+# pseudo-infinite until a true infinite buffer is supported by all datapipes
+INFINITE_BUFFER_SIZE = 1_000_000_000
+

 def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
     if len(seq) == 1:
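
Moving the constant into utils._internal gives every datapipe one shared notion of "shuffle everything". A minimal sketch of the pattern _folder.py now applies, using a toy datapipe in place of the FileLister/Filter chain (assumes a PyTorch build where these classes are importable as in the diff above):

from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.iter import Shuffler

from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE


class RangePipe(IterDataPipe):
    # Toy stand-in for the file-listing datapipes used in _folder.py.
    def __init__(self, n: int) -> None:
        self.n = n

    def __iter__(self):
        yield from range(self.n)


dp = Shuffler(RangePipe(10), buffer_size=INFINITE_BUFFER_SIZE)
print(list(dp))  # 0-9 in shuffled order; the huge buffer can hold the whole stream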
