simplify OnlineResource.load #5990
Changes from all commits: cb774b7, 5db94a8, d627479, 99c2daf, 65198d1, 232c6a9, 89df201, a836a83, e8ca146, c0ecb46, a703d44, 82af340
@@ -1,11 +1,15 @@
+import gzip
+import pathlib
 import sys
 
 import numpy as np
 import pytest
 import torch
-from datasets_utils import make_fake_flo_file
+from datasets_utils import make_fake_flo_file, make_tar
+from torchdata.datapipes.iter import FileOpener, TarArchiveLoader
 from torchvision.datasets._optical_flow import _read_flo as read_flo_ref
-from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset
+from torchvision.datasets.utils import _decompress
+from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset, OnlineResource
 from torchvision.prototype.datasets.utils._internal import read_flo, fromfile
@@ -48,6 +52,183 @@ def test_read_flo(tmpdir):
     torch.testing.assert_close(actual, expected)

The entire TestOnlineResource class below is newly added:
class TestOnlineResource:
    class DummyResource(OnlineResource):
        def __init__(self, download_fn=None, **kwargs):
            super().__init__(**kwargs)
            self._download_fn = download_fn

        def _download(self, root):
            if self._download_fn is None:
                raise pytest.UsageError(
                    "`_download()` was called, but `DummyResource(...)` was constructed without `download_fn`."
                )

            return self._download_fn(self, root)

    def _make_file(self, root, *, content, name="file.txt"):
        file = root / name
        with open(file, "w") as fh:
            fh.write(content)

        return file

    def _make_folder(self, root, *, name="folder"):
        folder = root / name
        subfolder = folder / "subfolder"
        subfolder.mkdir(parents=True)

        files = {}
        for idx, root in enumerate([folder, folder, subfolder]):
            content = f"sentinel{idx}"
            file = self._make_file(root, name=f"file{idx}.txt", content=content)
            files[str(file)] = content

        return folder, files

    def _make_tar(self, root, *, name="archive.tar", remove=True):
        folder, files = self._make_folder(root, name=name.split(".")[0])
        archive = make_tar(root, name, folder, remove=remove)
        files = {str(archive / pathlib.Path(file).relative_to(root)): content for file, content in files.items()}
        return archive, files
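A note on the path bookkeeping in _make_tar(): IIUC, TarArchiveLoader names each entry "<archive path>/<member name>", so the helper rewrites the on-disk file paths into archive-relative ones. A standalone sketch of that path arithmetic, with made-up POSIX paths:

import pathlib

root = pathlib.PurePosixPath("/root")
archive = root / "archive.tar"         # make_tar() tars the folder "archive" into this file
file = root / "archive" / "file0.txt"  # file that existed on disk before tarring
# The expected datapipe key nests the member path under the archive path.
assert str(archive / file.relative_to(root)) == "/root/archive.tar/archive/file0.txt"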
    def test_load_file(self, tmp_path):
        content = "sentinel"
        file = self._make_file(tmp_path, content=content)

        resource = self.DummyResource(file_name=file.name)

        dp = resource.load(tmp_path)
        assert isinstance(dp, FileOpener)

        data = list(dp)
        assert len(data) == 1

        path, buffer = data[0]
        assert path == str(file)
        assert buffer.read().decode() == content

    def test_load_folder(self, tmp_path):
        folder, files = self._make_folder(tmp_path)

        resource = self.DummyResource(file_name=folder.name)

        dp = resource.load(tmp_path)
        assert isinstance(dp, FileOpener)
        assert {path: buffer.read().decode() for path, buffer in dp} == files

    def test_load_archive(self, tmp_path):
        archive, files = self._make_tar(tmp_path)

        resource = self.DummyResource(file_name=archive.name)

        dp = resource.load(tmp_path)
        assert isinstance(dp, TarArchiveLoader)
        assert {path: buffer.read().decode() for path, buffer in dp} == files
    def test_priority_decompressed_gt_raw(self, tmp_path):
        # We don't need to actually compress here. Adding the suffix is sufficient.
        self._make_file(tmp_path, content="raw_sentinel", name="file.txt.gz")
        file = self._make_file(tmp_path, content="decompressed_sentinel", name="file.txt")

        resource = self.DummyResource(file_name=file.name)

        dp = resource.load(tmp_path)
        path, buffer = next(iter(dp))

        assert path == str(file)
        assert buffer.read().decode() == "decompressed_sentinel"

    def test_priority_extracted_gt_decompressed(self, tmp_path):
        archive, _ = self._make_tar(tmp_path, remove=False)

        resource = self.DummyResource(file_name=archive.name)

        dp = resource.load(tmp_path)
        # If the archive had been selected, this would be a `TarArchiveLoader`
        assert isinstance(dp, FileOpener)
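These two tests pin down the candidate-selection order of load(). The ordering falls out of counting suffixes, since every processing step (decompress, extract) removes one or more of them; a minimal standalone illustration with made-up paths:

import pathlib

candidates = [
    pathlib.PurePosixPath("root/archive.tar.gz"),  # raw download: 2 suffixes
    pathlib.PurePosixPath("root/archive.tar"),     # decompressed: 1 suffix
    pathlib.PurePosixPath("root/archive"),         # extracted folder: 0 suffixes
]
# Fewer suffixes means "more processed", so the extracted folder wins.
assert min(candidates, key=lambda p: len(p.suffixes)) == pathlib.PurePosixPath("root/archive")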
    def test_download(self, tmp_path):
        download_fn_was_called = False

        def download_fn(resource, root):
            nonlocal download_fn_was_called
            download_fn_was_called = True

            return self._make_file(root, content="_", name=resource.file_name)

        resource = self.DummyResource(
            file_name="file.txt",
            download_fn=download_fn,
        )

        resource.load(tmp_path)

        assert download_fn_was_called, "`download_fn()` was never called"
    # This tests the `"decompress"` literal as well as a custom callable
    @pytest.mark.parametrize(
        "preprocess",
        [
            "decompress",
            lambda path: _decompress(str(path), remove_finished=True),
        ],
    )
    def test_preprocess_decompress(self, tmp_path, preprocess):
        file_name = "file.txt.gz"
        content = "sentinel"

        def download_fn(resource, root):
            file = root / resource.file_name
            with gzip.open(file, "wb") as fh:
                fh.write(content.encode())
            return file

        resource = self.DummyResource(file_name=file_name, preprocess=preprocess, download_fn=download_fn)

        dp = resource.load(tmp_path)
        data = list(dp)
        assert len(data) == 1

        path, buffer = data[0]
        assert path == str(tmp_path / file_name).replace(".gz", "")
        assert buffer.read().decode() == content
    def test_preprocess_extract(self, tmp_path):
        files = None

        def download_fn(resource, root):
            nonlocal files
            archive, files = self._make_tar(root, name=resource.file_name)
            return archive

        resource = self.DummyResource(file_name="folder.tar", preprocess="extract", download_fn=download_fn)

        dp = resource.load(tmp_path)
        assert files is not None, "`download_fn()` was never called"
        assert isinstance(dp, FileOpener)

        actual = {path: buffer.read().decode() for path, buffer in dp}
        expected = {
            path.replace(resource.file_name, resource.file_name.split(".")[0]): content
            for path, content in files.items()
        }
        assert actual == expected
Comment on lines +197 to +216

Review comment: This is fairly complex TBH. IIUC, the goal of this check is to make sure the file gets properly extracted. Surely there are simpler ways to assert that?

Reply: Of course there are other ways, but I'm not sure they are easier. We can't create the archive upfront in […]. Given that you were ok with […]
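Concretely, the expected-dict rewrite in the test swaps the archive name for its stem in every path reported by _make_tar(), because extraction moves the files from under "folder.tar/..." to under "folder/...". A standalone sketch with made-up paths:

reported = {"/root/folder.tar/folder/file0.txt": "sentinel0"}
expected = {path.replace("folder.tar", "folder"): content for path, content in reported.items()}
assert expected == {"/root/folder/folder/file0.txt": "sentinel0"}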
    def test_preprocess_only_after_download(self, tmp_path):
        file = self._make_file(tmp_path, content="_")

        def preprocess(path):
            raise AssertionError("`preprocess` was called although the file was already present.")

        resource = self.DummyResource(
            file_name=file.name,
            preprocess=preprocess,
        )

        resource.load(tmp_path)
class TestHttpResource:
    def test_resolve_to_http(self, mocker):
        file_name = "data.tar"
The second file in the diff changes the OnlineResource implementation:

@@ -2,7 +2,7 @@
 import hashlib
 import itertools
 import pathlib
-from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn
+from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn, Set
 from urllib.parse import urlparse
 
 from torchdata.datapipes.iter import (
@@ -32,7 +32,7 @@ def __init__(
         *,
         file_name: str,
         sha256: Optional[str] = None,
-        preprocess: Optional[Union[Literal["decompress", "extract"], Callable[[pathlib.Path], pathlib.Path]]] = None,
+        preprocess: Optional[Union[Literal["decompress", "extract"], Callable[[pathlib.Path], None]]] = None,
Review comment: For my own education, why did we need to specify […]?

Review comment: It becomes clearer when we let […].

Reply: I only changed the return type of the callable from `pathlib.Path` to `None`. Still, in general you are right.
     ) -> None:
         self.file_name = file_name
         self.sha256 = sha256
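For illustration, a custom preprocess hook under the new signature transforms the file on disk and returns nothing; load() re-globs for candidates afterwards instead of relying on a returned path. A minimal sketch, mirroring the lambda used in the tests above:

import pathlib

from torchvision.datasets.utils import _decompress

def preprocess(path: pathlib.Path) -> None:
    # Decompress in place; the resulting file is rediscovered by load()'s glob.
    _decompress(str(path), remove_finished=True)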
@@ -50,14 +50,12 @@ def __init__(
         self._preprocess = preprocess
 
     @staticmethod
-    def _extract(file: pathlib.Path) -> pathlib.Path:
-        return pathlib.Path(
-            extract_archive(str(file), to_path=str(file).replace("".join(file.suffixes), ""), remove_finished=False)
-        )
+    def _extract(file: pathlib.Path) -> None:
+        extract_archive(str(file), to_path=str(file).replace("".join(file.suffixes), ""), remove_finished=False)
 
     @staticmethod
-    def _decompress(file: pathlib.Path) -> pathlib.Path:
-        return pathlib.Path(_decompress(str(file), remove_finished=True))
+    def _decompress(file: pathlib.Path) -> None:
+        _decompress(str(file), remove_finished=True)
 
     def _loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
         if path.is_dir():
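Since _extract() computes to_path by stripping all suffixes, an archive is always extracted into a sibling folder named after its stem, which is exactly what load() later picks up as the folder candidate. A standalone illustration of that computation with a made-up path:

import pathlib

file = pathlib.PurePosixPath("/root/archive.tar.gz")
to_path = str(file).replace("".join(file.suffixes), "")  # "".join(file.suffixes) == ".tar.gz"
assert to_path == "/root/archive"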
@@ -91,32 +89,38 @@ def load(
     ) -> IterDataPipe[Tuple[str, IO]]:
         root = pathlib.Path(root)
         path = root / self.file_name
 
         # Instead of the raw file, there might also be files with fewer suffixes after decompression or directories
-        # with no suffixes at all.
+        # with no suffixes at all. `pathlib.Path().stem` will only give us the name with the last suffix removed, which
+        # is not sufficient for files with multiple suffixes, e.g. foo.tar.gz.
         stem = path.name.replace("".join(path.suffixes), "")
 
-        # In a first step, we check for a folder with the same stem as the raw file. If it exists, we use it since
-        # extracted files give the best I/O performance. Note that OnlineResource._extract() makes sure that an archive
-        # is always extracted in a folder with the corresponding file name.
-        folder_candidate = path.parent / stem
-        if folder_candidate.exists() and folder_candidate.is_dir():
-            return self._loader(folder_candidate)
-
-        # If there is no folder, we look for all files that share the same stem as the raw file, but might have a
-        # different suffix.
-        file_candidates = {file for file in path.parent.glob(stem + ".*")}
-        # If we don't find anything, we download the raw file.
-        if not file_candidates:
-            file_candidates = {self.download(root, skip_integrity_check=skip_integrity_check)}
-        # If the only thing we find is the raw file, we use it and optionally perform some preprocessing steps.
-        if file_candidates == {path}:
+        def find_candidates() -> Set[pathlib.Path]:
+            # Although it looks like we could glob for f"{stem}*" to find the file candidates as well as the folder
+            # candidate simultaneously, that would also pick up other files that share the same prefix. For example, the
+            # test split of the stanford-cars dataset uses the files
+            # - cars_test.tgz
+            # - cars_test_annos_withlabels.mat
+            # Globbing for `"cars_test*"` picks up both.
+            candidates = {file for file in path.parent.glob(f"{stem}.*")}
+            folder_candidate = path.parent / stem
+            if folder_candidate.exists():
+                candidates.add(folder_candidate)
+
+            return candidates
+
+        candidates = find_candidates()
+
+        if not candidates:
+            self.download(root, skip_integrity_check=skip_integrity_check)
             if self._preprocess is not None:
-                path = self._preprocess(path)
-        # Otherwise, we use the path with the fewest suffixes. This gives us the decompressed > raw priority that we
-        # want for the best I/O performance.
-        else:
-            path = min(file_candidates, key=lambda path: len(path.suffixes))
-        return self._loader(path)
+                self._preprocess(path)
+            candidates = find_candidates()
+
+        # We use the path with the fewest suffixes. This gives us the
+        #   extracted > decompressed > raw
+        # priority that we want for the best I/O performance.
+        return self._loader(min(candidates, key=lambda candidate: len(candidate.suffixes)))
 
     @abc.abstractmethod
     def _download(self, root: pathlib.Path) -> None:
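The globbing pitfall described in the find_candidates() comment can be checked in isolation; a small sketch using the file names from the comment:

import pathlib
import tempfile

root = pathlib.Path(tempfile.mkdtemp())
(root / "cars_test.tgz").touch()
(root / "cars_test_annos_withlabels.mat").touch()

# The prefix glob also matches the annotation file; the suffix glob does not.
assert {p.name for p in root.glob("cars_test*")} == {"cars_test.tgz", "cars_test_annos_withlabels.mat"}
assert {p.name for p in root.glob("cars_test.*")} == {"cars_test.tgz"}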
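Taken together, the simplified load() reduces to: glob for candidates; if there are none, download, optionally preprocess, and glob again; then open whichever candidate has the fewest suffixes. A sketch of the resulting behavior, reusing the DummyResource and download_fn defined in the tests above (illustrative, not part of the diff):

# First call: nothing on disk, so _download() runs, "extract" unpacks the
# archive into a folder, and the folder (zero suffixes) wins the min().
resource = DummyResource(file_name="folder.tar", preprocess="extract", download_fn=download_fn)
dp = resource.load(tmp_path)

# Second call: the extracted folder is found immediately, so neither
# _download() nor the preprocess hook runs again.
dp = resource.load(tmp_path)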