
simplify OnlineResource.load #5990


Merged · 12 commits · May 17, 2022
185 changes: 183 additions & 2 deletions test/test_prototype_datasets_utils.py
@@ -1,11 +1,15 @@
+import gzip
import pathlib
import sys

import numpy as np
import pytest
import torch
-from datasets_utils import make_fake_flo_file
+from datasets_utils import make_fake_flo_file, make_tar
+from torchdata.datapipes.iter import FileOpener, TarArchiveLoader
from torchvision.datasets._optical_flow import _read_flo as read_flo_ref
-from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset
+from torchvision.datasets.utils import _decompress
+from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset, OnlineResource
from torchvision.prototype.datasets.utils._internal import read_flo, fromfile


@@ -48,6 +52,183 @@ def test_read_flo(tmpdir):
torch.testing.assert_close(actual, expected)


class TestOnlineResource:
class DummyResource(OnlineResource):
def __init__(self, download_fn=None, **kwargs):
super().__init__(**kwargs)
self._download_fn = download_fn

def _download(self, root):
if self._download_fn is None:
raise pytest.UsageError(
"`_download()` was called, but `DummyResource(...)` was constructed without `download_fn`."
)

return self._download_fn(self, root)

def _make_file(self, root, *, content, name="file.txt"):
file = root / name
with open(file, "w") as fh:
fh.write(content)

return file

def _make_folder(self, root, *, name="folder"):
folder = root / name
subfolder = folder / "subfolder"
subfolder.mkdir(parents=True)

files = {}
for idx, root in enumerate([folder, folder, subfolder]):
content = f"sentinel{idx}"
file = self._make_file(root, name=f"file{idx}.txt", content=content)
files[str(file)] = content

return folder, files

def _make_tar(self, root, *, name="archive.tar", remove=True):
folder, files = self._make_folder(root, name=name.split(".")[0])
archive = make_tar(root, name, folder, remove=remove)
files = {str(archive / pathlib.Path(file).relative_to(root)): content for file, content in files.items()}
return archive, files

def test_load_file(self, tmp_path):
content = "sentinel"
file = self._make_file(tmp_path, content=content)

resource = self.DummyResource(file_name=file.name)

dp = resource.load(tmp_path)
assert isinstance(dp, FileOpener)

data = list(dp)
assert len(data) == 1

path, buffer = data[0]
assert path == str(file)
assert buffer.read().decode() == content

def test_load_folder(self, tmp_path):
folder, files = self._make_folder(tmp_path)

resource = self.DummyResource(file_name=folder.name)

dp = resource.load(tmp_path)
assert isinstance(dp, FileOpener)
assert {path: buffer.read().decode() for path, buffer in dp} == files

def test_load_archive(self, tmp_path):
archive, files = self._make_tar(tmp_path)

resource = self.DummyResource(file_name=archive.name)

dp = resource.load(tmp_path)
assert isinstance(dp, TarArchiveLoader)
assert {path: buffer.read().decode() for path, buffer in dp} == files

def test_priority_decompressed_gt_raw(self, tmp_path):
# We don't need to actually compress here. Adding the suffix is sufficient
self._make_file(tmp_path, content="raw_sentinel", name="file.txt.gz")
file = self._make_file(tmp_path, content="decompressed_sentinel", name="file.txt")

resource = self.DummyResource(file_name=file.name)

dp = resource.load(tmp_path)
path, buffer = next(iter(dp))

assert path == str(file)
assert buffer.read().decode() == "decompressed_sentinel"

def test_priority_extracted_gt_decompressed(self, tmp_path):
archive, _ = self._make_tar(tmp_path, remove=False)

resource = self.DummyResource(file_name=archive.name)

dp = resource.load(tmp_path)
# If the archive had been selected, this would be a `TarArchiveLoader`
assert isinstance(dp, FileOpener)

def test_download(self, tmp_path):
download_fn_was_called = False

def download_fn(resource, root):
nonlocal download_fn_was_called
download_fn_was_called = True

return self._make_file(root, content="_", name=resource.file_name)

resource = self.DummyResource(
file_name="file.txt",
download_fn=download_fn,
)

resource.load(tmp_path)

assert download_fn_was_called, "`download_fn()` was never called"

# This tests the `"decompress"` literal as well as a custom callable
@pytest.mark.parametrize(
"preprocess",
[
"decompress",
lambda path: _decompress(str(path), remove_finished=True),
],
)
def test_preprocess_decompress(self, tmp_path, preprocess):
file_name = "file.txt.gz"
content = "sentinel"

def download_fn(resource, root):
file = root / resource.file_name
with gzip.open(file, "wb") as fh:
fh.write(content.encode())
return file

resource = self.DummyResource(file_name=file_name, preprocess=preprocess, download_fn=download_fn)

dp = resource.load(tmp_path)
data = list(dp)
assert len(data) == 1

path, buffer = data[0]
assert path == str(tmp_path / file_name).replace(".gz", "")
assert buffer.read().decode() == content

def test_preprocess_extract(self, tmp_path):
files = None

def download_fn(resource, root):
nonlocal files
archive, files = self._make_tar(root, name=resource.file_name)
return archive

resource = self.DummyResource(file_name="folder.tar", preprocess="extract", download_fn=download_fn)

dp = resource.load(tmp_path)
assert files is not None, "`download_fn()` was never called"
assert isinstance(dp, FileOpener)

actual = {path: buffer.read().decode() for path, buffer in dp}
expected = {
path.replace(resource.file_name, resource.file_name.split(".")[0]): content
for path, content in files.items()
}
assert actual == expected
Review comment on lines +197 to +216 (Member):
This is fairly complex TBH. The nonlocal logic, the fact that _make_tar returns a dict of files, etc.

IIUC the goal of this check is to make sure the file gets properly extracted. Surely there are simpler ways to assert that?

Reply from @pmeier (Collaborator, Author), May 16, 2022:
Of course there are other ways, but I'm not sure they are easier. We can't create the archive upfront in tmp_path, because in that case we would never trigger the preprocessing. One option would be to create the data upfront in a temporary directory and move it to tmp_path inside the download function similar to what we are doing with the actual resource loading in our dataset tests (#6010).

Given that you were ok with nonlocal in #5990 (comment), I don't think it will be much simpler. You choose.
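For reference, a minimal sketch of that staging alternative (a hypothetical test, not part of this diff; it assumes an extra `import shutil` at the top of the file and pytest's built-in `tmp_path_factory` fixture):

def test_preprocess_extract_via_staging(self, tmp_path, tmp_path_factory):
    # Stage the archive in a separate directory so that `tmp_path` starts
    # out empty and `load()` is forced to download and preprocess.
    staging = tmp_path_factory.mktemp("staging")
    archive, _ = self._make_tar(staging, name="folder.tar")

    def download_fn(resource, root):
        # "Download" by moving the staged archive into place.
        return pathlib.Path(shutil.move(str(archive), str(root / resource.file_name)))

    resource = self.DummyResource(file_name=archive.name, preprocess="extract", download_fn=download_fn)

    dp = resource.load(tmp_path)
    # If extraction ran, `load()` serves the extracted folder rather than the archive.
    assert isinstance(dp, FileOpener)
    assert (tmp_path / "folder").is_dir()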


def test_preprocess_only_after_download(self, tmp_path):
file = self._make_file(tmp_path, content="_")

def preprocess(path):
raise AssertionError("`preprocess` was called although the file was already present.")

resource = self.DummyResource(
file_name=file.name,
preprocess=preprocess,
)

resource.load(tmp_path)


class TestHttpResource:
def test_resolve_to_http(self, mocker):
file_name = "data.tar"
64 changes: 34 additions & 30 deletions torchvision/prototype/datasets/utils/_resource.py
@@ -2,7 +2,7 @@
import hashlib
import itertools
import pathlib
-from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn
+from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn, Set
from urllib.parse import urlparse

from torchdata.datapipes.iter import (
@@ -32,7 +32,7 @@ def __init__(
*,
file_name: str,
sha256: Optional[str] = None,
-preprocess: Optional[Union[Literal["decompress", "extract"], Callable[[pathlib.Path], pathlib.Path]]] = None,
+preprocess: Optional[Union[Literal["decompress", "extract"], Callable[[pathlib.Path], None]]] = None,
Review comment (Member):

For my own education, why did we need to specify None in the annotation here? I assume Optional[] would have been enough - is it just to be more explicit about it?

Reply from @pmeier (Collaborator, Author), May 17, 2022:

It becomes clearer when we let black explode the annotation:

Suggested change:
-preprocess: Optional[Union[Literal["decompress", "extract"], Callable[[pathlib.Path], None]]] = None,
+preprocess: Optional[
+    Union[
+        Literal["decompress", "extract"],
+        Callable[[pathlib.Path], None],
+    ]
+] = None,

I only changed the return type of the callable from pathlib.Path to None, since we refactored the loading logic so that the return value is no longer needed.

Still, in general you are right. Optional[Foo] is equivalent to Union[None, Foo]. Plus, Optional[Union[Foo, Bar]] can be flattened to Union[None, Foo, Bar]. It is just my personal preference to be explicit about Optional in case it actually means an optional value. In case None is just another valid value to pass, I prefer to merge it into a Union.
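As a quick standalone check of that flattening (assumes Python 3.8+ for typing.Literal; not part of this diff):

import pathlib
from typing import Callable, Literal, Optional, Union

# typing normalizes `Optional[Union[...]]` and the flat `Union[..., None]`
# to the same type, so the two spellings are interchangeable:
assert (
    Optional[Union[Literal["decompress", "extract"], Callable[[pathlib.Path], None]]]
    == Union[Literal["decompress", "extract"], Callable[[pathlib.Path], None], None]
)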

) -> None:
self.file_name = file_name
self.sha256 = sha256
@@ -50,14 +50,12 @@ def __init__(
self._preprocess = preprocess

@staticmethod
-def _extract(file: pathlib.Path) -> pathlib.Path:
-    return pathlib.Path(
-        extract_archive(str(file), to_path=str(file).replace("".join(file.suffixes), ""), remove_finished=False)
-    )
+def _extract(file: pathlib.Path) -> None:
+    extract_archive(str(file), to_path=str(file).replace("".join(file.suffixes), ""), remove_finished=False)

@staticmethod
-def _decompress(file: pathlib.Path) -> pathlib.Path:
-    return pathlib.Path(_decompress(str(file), remove_finished=True))
+def _decompress(file: pathlib.Path) -> None:
+    _decompress(str(file), remove_finished=True)

def _loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
if path.is_dir():
@@ -91,32 +89,38 @@ def load(
) -> IterDataPipe[Tuple[str, IO]]:
    root = pathlib.Path(root)
    path = root / self.file_name

    # Instead of the raw file, there might also be files with fewer suffixes after decompression or directories
-    # with no suffixes at all.
+    # with no suffixes at all. `pathlib.Path().stem` will only give us the name with the last suffix removed, which
+    # is not sufficient for files with multiple suffixes, e.g. foo.tar.gz.
    stem = path.name.replace("".join(path.suffixes), "")

-    # In a first step, we check for a folder with the same stem as the raw file. If it exists, we use it since
-    # extracted files give the best I/O performance. Note that OnlineResource._extract() makes sure that an archive
-    # is always extracted in a folder with the corresponding file name.
-    folder_candidate = path.parent / stem
-    if folder_candidate.exists() and folder_candidate.is_dir():
-        return self._loader(folder_candidate)
-
-    # If there is no folder, we look for all files that share the same stem as the raw file, but might have a
-    # different suffix.
-    file_candidates = {file for file in path.parent.glob(stem + ".*")}
-    # If we don't find anything, we download the raw file.
-    if not file_candidates:
-        file_candidates = {self.download(root, skip_integrity_check=skip_integrity_check)}
-    # If the only thing we find is the raw file, we use it and optionally perform some preprocessing steps.
-    if file_candidates == {path}:
+    def find_candidates() -> Set[pathlib.Path]:
+        # Although it looks like we could glob for f"{stem}*" to find the file candidates as well as the folder
+        # candidate simultaneously, that would also pick up other files that share the same prefix. For example, the
+        # test split of the stanford-cars dataset uses the files
+        # - cars_test.tgz
+        # - cars_test_annos_withlabels.mat
+        # Globbing for `"cars_test*"` picks up both.
+        candidates = {file for file in path.parent.glob(f"{stem}.*")}
+        folder_candidate = path.parent / stem
+        if folder_candidate.exists():
+            candidates.add(folder_candidate)
+
+        return candidates
+
+    candidates = find_candidates()
+
+    if not candidates:
+        self.download(root, skip_integrity_check=skip_integrity_check)
        if self._preprocess is not None:
-            path = self._preprocess(path)
-    # Otherwise, we use the path with the fewest suffixes. This gives us the decompressed > raw priority that we
-    # want for the best I/O performance.
-    else:
-        path = min(file_candidates, key=lambda path: len(path.suffixes))
-    return self._loader(path)
+            self._preprocess(path)
+        candidates = find_candidates()
+
+    # We use the path with the fewest suffixes. This gives us the
+    # extracted > decompressed > raw
+    # priority that we want for the best I/O performance.
+    return self._loader(min(candidates, key=lambda candidate: len(candidate.suffixes)))

@abc.abstractmethod
def _download(self, root: pathlib.Path) -> None:
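To make the selection rule at the end of load() concrete, here is a small standalone sketch of the fewest-suffixes priority (file names hypothetical, not part of this diff):

import pathlib

# After download, decompression, and extraction, all three artifacts can coexist
# next to each other in the dataset root:
candidates = {
    pathlib.Path("root/foo.tar.gz"),  # raw download, suffixes ['.tar', '.gz']
    pathlib.Path("root/foo.tar"),     # decompressed, suffixes ['.tar']
    pathlib.Path("root/foo"),         # extracted folder, no suffixes
}

# Fewest suffixes wins, yielding the extracted > decompressed > raw priority.
assert min(candidates, key=lambda candidate: len(candidate.suffixes)) == pathlib.Path("root/foo")

# This is also why the stem is computed by stripping all suffixes rather than
# using `pathlib.Path().stem`, which only removes the last one:
assert pathlib.Path("foo.tar.gz").stem == "foo.tar"
assert "foo.tar.gz".replace("".join(pathlib.Path("foo.tar.gz").suffixes), "") == "foo"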