Skip to content

cleanup prototype datasets #4471

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Sep 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion torchvision/prototype/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
from . import decoder, utils

# Load this last, since some parts depend on the above being loaded first
from ._api import register, list, info, load
from ._api import register, _list as list, info, load
from ._folder import from_data_folder, from_image_folder
6 changes: 3 additions & 3 deletions torchvision/prototype/datasets/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def register(dataset: Dataset) -> None:
DATASETS[dataset.name] = dataset


def list() -> List[str]:
# This is exposed publicly as 'list', but it is defined as '_list' here to avoid shadowing the built-in 'list'
def _list() -> List[str]:
return sorted(DATASETS.keys())


Expand Down Expand Up @@ -45,7 +46,6 @@ def info(name: str) -> DatasetInfo:
def load(
name: str,
*,
shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]] = None,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = pil,
split: str = "train",
**options: Any,
Expand All @@ -55,4 +55,4 @@ def load(
config = dataset.info.make_config(split=split, **options)
root = home() / name

return dataset.to_datapipe(root, config=config, shuffler=shuffler, decoder=decoder)
return dataset.to_datapipe(root, config=config, decoder=decoder)
8 changes: 2 additions & 6 deletions torchvision/prototype/datasets/_folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,11 @@
from torch.utils.data.datapipes.iter import FileLister, FileLoader, Mapper, Shuffler, Filter

from torchvision.prototype.datasets.decoder import pil
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE


__all__ = ["from_data_folder", "from_image_folder"]

# pseudo-infinite buffer size until a true infinite buffer is supported
INFINITE = 1_000_000_000


def _is_not_top_level_file(path: str, *, root: pathlib.Path) -> bool:
rel_path = pathlib.Path(path).relative_to(root)
Expand Down Expand Up @@ -45,7 +43,6 @@ def _collate_and_decode_data(
def from_data_folder(
root: Union[str, pathlib.Path],
*,
shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]] = lambda dp: Shuffler(dp, buffer_size=INFINITE),
decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None,
valid_extensions: Optional[Collection[str]] = None,
recursive: bool = True,
Expand All @@ -55,8 +52,7 @@ def from_data_folder(
masks: Union[List[str], str] = [f"*.{ext}" for ext in valid_extensions] if valid_extensions is not None else ""
dp: IterDataPipe = FileLister(str(root), recursive=recursive, masks=masks)
dp = Filter(dp, _is_not_top_level_file, fn_kwargs=dict(root=root))
if shuffler:
dp = shuffler(dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
dp = FileLoader(dp)
return (
Mapper(dp, _collate_and_decode_data, fn_kwargs=dict(root=root, categories=categories, decoder=decoder)),
Expand Down
12 changes: 6 additions & 6 deletions torchvision/prototype/datasets/_home.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
HOME = pathlib.Path(_get_torch_home()) / "datasets" / "vision"


def home(home: Optional[Union[str, pathlib.Path]] = None) -> pathlib.Path:
def home(root: Optional[Union[str, pathlib.Path]] = None) -> pathlib.Path:
global HOME
if home is not None:
HOME = pathlib.Path(home).expanduser().resolve()
if root is not None:
HOME = pathlib.Path(root).expanduser().resolve()
return HOME

home = os.getenv("TORCHVISION_DATASETS_HOME")
if home is not None:
return pathlib.Path(home)
root = os.getenv("TORCHVISION_DATASETS_HOME")
if root is not None:
return pathlib.Path(root)

return HOME
6 changes: 3 additions & 3 deletions torchvision/prototype/datasets/decoder.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import io

import numpy as np
import PIL.Image
import torch

from torchvision.transforms.functional import pil_to_tensor

__all__ = ["pil"]


def pil(file: io.IOBase, mode: str = "RGB") -> torch.Tensor:
image = PIL.Image.open(file).convert(mode.upper())
return torch.from_numpy(np.array(image, copy=True)).permute((2, 0, 1))
return pil_to_tensor(PIL.Image.open(file).convert(mode.upper()))
10 changes: 5 additions & 5 deletions torchvision/prototype/datasets/utils/_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,17 @@ def __init__(
self,
name: str,
*,
categories: Union[int, Sequence[str], str, pathlib.Path],
categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None,
citation: Optional[str] = None,
homepage: Optional[str] = None,
license: Optional[str] = None,
valid_options: Optional[Dict[str, Sequence]] = None,
) -> None:
self.name = name.lower()

if isinstance(categories, int):
if categories is None:
categories = []
elif isinstance(categories, int):
categories = [str(label) for label in range(categories)]
elif isinstance(categories, (str, pathlib.Path)):
with open(pathlib.Path(categories).expanduser().resolve(), "r") as fh:
Expand Down Expand Up @@ -198,7 +200,6 @@ def _make_datapipe(
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]],
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
pass
Expand All @@ -208,7 +209,6 @@ def to_datapipe(
root: Union[str, pathlib.Path],
*,
config: Optional[DatasetConfig] = None,
shuffler: Optional[Callable[[IterDataPipe], IterDataPipe]] = None,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None,
) -> IterDataPipe[Dict[str, Any]]:
if not config:
Expand All @@ -217,4 +217,4 @@ def to_datapipe(
resource_dps = [
resource.to_datapipe(root) for resource in self.resources(config)
]
return self._make_datapipe(resource_dps, config=config, shuffler=shuffler, decoder=decoder)
return self._make_datapipe(resource_dps, config=config, decoder=decoder)
4 changes: 4 additions & 0 deletions torchvision/prototype/datasets/utils/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@


__all__ = [
"INFINITE_BUFFER_SIZE",
"sequence_to_str",
"add_suggestion",
]

# pseudo-infinite until a true infinite buffer is supported by all datapipes
INFINITE_BUFFER_SIZE = 1_000_000_000


def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
if len(seq) == 1:
Expand Down