diff --git a/test/test_prototype_datasets_utils.py b/test/test_prototype_datasets_utils.py
index b1c95844574..8790b1638f9 100644
--- a/test/test_prototype_datasets_utils.py
+++ b/test/test_prototype_datasets_utils.py
@@ -1,11 +1,15 @@
+import gzip
+import pathlib
 import sys
 
 import numpy as np
 import pytest
 import torch
-from datasets_utils import make_fake_flo_file
+from datasets_utils import make_fake_flo_file, make_tar
+from torchdata.datapipes.iter import FileOpener, TarArchiveLoader
 from torchvision.datasets._optical_flow import _read_flo as read_flo_ref
-from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset
+from torchvision.datasets.utils import _decompress
+from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset, OnlineResource
 from torchvision.prototype.datasets.utils._internal import read_flo, fromfile
 
 
@@ -48,6 +52,183 @@ def test_read_flo(tmpdir):
     torch.testing.assert_close(actual, expected)
 
 
+class TestOnlineResource:
+    class DummyResource(OnlineResource):
+        def __init__(self, download_fn=None, **kwargs):
+            super().__init__(**kwargs)
+            self._download_fn = download_fn
+
+        def _download(self, root):
+            if self._download_fn is None:
+                raise pytest.UsageError(
+                    "`_download()` was called, but `DummyResource(...)` was constructed without `download_fn`."
+                )
+
+            return self._download_fn(self, root)
+
+    def _make_file(self, root, *, content, name="file.txt"):
+        file = root / name
+        with open(file, "w") as fh:
+            fh.write(content)
+
+        return file
+
+    def _make_folder(self, root, *, name="folder"):
+        folder = root / name
+        subfolder = folder / "subfolder"
+        subfolder.mkdir(parents=True)
+
+        files = {}
+        for idx, root in enumerate([folder, folder, subfolder]):
+            content = f"sentinel{idx}"
+            file = self._make_file(root, name=f"file{idx}.txt", content=content)
+            files[str(file)] = content
+
+        return folder, files
+
+    def _make_tar(self, root, *, name="archive.tar", remove=True):
+        folder, files = self._make_folder(root, name=name.split(".")[0])
+        archive = make_tar(root, name, folder, remove=remove)
+        files = {str(archive / pathlib.Path(file).relative_to(root)): content for file, content in files.items()}
+        return archive, files
+
+    def test_load_file(self, tmp_path):
+        content = "sentinel"
+        file = self._make_file(tmp_path, content=content)
+
+        resource = self.DummyResource(file_name=file.name)
+
+        dp = resource.load(tmp_path)
+        assert isinstance(dp, FileOpener)
+
+        data = list(dp)
+        assert len(data) == 1
+
+        path, buffer = data[0]
+        assert path == str(file)
+        assert buffer.read().decode() == content
+
+    def test_load_folder(self, tmp_path):
+        folder, files = self._make_folder(tmp_path)
+
+        resource = self.DummyResource(file_name=folder.name)
+
+        dp = resource.load(tmp_path)
+        assert isinstance(dp, FileOpener)
+        assert {path: buffer.read().decode() for path, buffer in dp} == files
+
+    def test_load_archive(self, tmp_path):
+        archive, files = self._make_tar(tmp_path)
+
+        resource = self.DummyResource(file_name=archive.name)
+
+        dp = resource.load(tmp_path)
+        assert isinstance(dp, TarArchiveLoader)
+        assert {path: buffer.read().decode() for path, buffer in dp} == files
+
+    def test_priority_decompressed_gt_raw(self, tmp_path):
+        # We don't need to actually compress here. Adding the suffix is sufficient.
+        self._make_file(tmp_path, content="raw_sentinel", name="file.txt.gz")
+        file = self._make_file(tmp_path, content="decompressed_sentinel", name="file.txt")
+
+        resource = self.DummyResource(file_name=file.name)
+
+        dp = resource.load(tmp_path)
+        path, buffer = next(iter(dp))
+
+        assert path == str(file)
+        assert buffer.read().decode() == "decompressed_sentinel"
+
+    def test_priority_extracted_gt_decompressed(self, tmp_path):
+        archive, _ = self._make_tar(tmp_path, remove=False)
+
+        resource = self.DummyResource(file_name=archive.name)
+
+        dp = resource.load(tmp_path)
+        # If the archive had been selected, this would be a `TarArchiveLoader`
+        assert isinstance(dp, FileOpener)
+
+    def test_download(self, tmp_path):
+        download_fn_was_called = False
+
+        def download_fn(resource, root):
+            nonlocal download_fn_was_called
+            download_fn_was_called = True
+
+            return self._make_file(root, content="_", name=resource.file_name)
+
+        resource = self.DummyResource(
+            file_name="file.txt",
+            download_fn=download_fn,
+        )
+
+        resource.load(tmp_path)
+
+        assert download_fn_was_called, "`download_fn()` was never called"
+
+    # This tests the `"decompress"` literal as well as a custom callable
+    @pytest.mark.parametrize(
+        "preprocess",
+        [
+            "decompress",
+            lambda path: _decompress(str(path), remove_finished=True),
+        ],
+    )
+    def test_preprocess_decompress(self, tmp_path, preprocess):
+        file_name = "file.txt.gz"
+        content = "sentinel"
+
+        def download_fn(resource, root):
+            file = root / resource.file_name
+            with gzip.open(file, "wb") as fh:
+                fh.write(content.encode())
+            return file
+
+        resource = self.DummyResource(file_name=file_name, preprocess=preprocess, download_fn=download_fn)
+
+        dp = resource.load(tmp_path)
+        data = list(dp)
+        assert len(data) == 1
+
+        path, buffer = data[0]
+        assert path == str(tmp_path / file_name).replace(".gz", "")
+        assert buffer.read().decode() == content
+
+    def test_preprocess_extract(self, tmp_path):
+        files = None
+
+        def download_fn(resource, root):
+            nonlocal files
+            archive, files = self._make_tar(root, name=resource.file_name)
+            return archive
+
+        resource = self.DummyResource(file_name="folder.tar", preprocess="extract", download_fn=download_fn)
+
+        dp = resource.load(tmp_path)
+        assert files is not None, "`download_fn()` was never called"
+        assert isinstance(dp, FileOpener)
+
+        actual = {path: buffer.read().decode() for path, buffer in dp}
+        expected = {
+            path.replace(resource.file_name, resource.file_name.split(".")[0]): content
+            for path, content in files.items()
+        }
+        assert actual == expected
+
+    def test_preprocess_only_after_download(self, tmp_path):
+        file = self._make_file(tmp_path, content="_")
+
+        def preprocess(path):
+            raise AssertionError("`preprocess` was called although the file was already present.")
+
+        resource = self.DummyResource(
+            file_name=file.name,
+            preprocess=preprocess,
+        )
+
+        resource.load(tmp_path)
+
+
 class TestHttpResource:
     def test_resolve_to_http(self, mocker):
         file_name = "data.tar"
diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py
index 507428a98d3..3c9b95cb498 100644
--- a/torchvision/prototype/datasets/utils/_resource.py
+++ b/torchvision/prototype/datasets/utils/_resource.py
@@ -2,7 +2,7 @@
 import hashlib
 import itertools
 import pathlib
-from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn
+from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn, Set
 from urllib.parse import urlparse
 
 from torchdata.datapipes.iter import (
@@ -32,7 +32,7 @@ def __init__(
         *,
         file_name: str,
         sha256: Optional[str] = None,
-        preprocess: Optional[Union[Literal["decompress", "extract"], Callable[[pathlib.Path], pathlib.Path]]] = None,
+        preprocess: Optional[Union[Literal["decompress", "extract"], Callable[[pathlib.Path], None]]] = None,
     ) -> None:
         self.file_name = file_name
         self.sha256 = sha256
@@ -50,14 +50,12 @@ def __init__(
         self._preprocess = preprocess
 
     @staticmethod
-    def _extract(file: pathlib.Path) -> pathlib.Path:
-        return pathlib.Path(
-            extract_archive(str(file), to_path=str(file).replace("".join(file.suffixes), ""), remove_finished=False)
-        )
+    def _extract(file: pathlib.Path) -> None:
+        extract_archive(str(file), to_path=str(file).replace("".join(file.suffixes), ""), remove_finished=False)
 
     @staticmethod
-    def _decompress(file: pathlib.Path) -> pathlib.Path:
-        return pathlib.Path(_decompress(str(file), remove_finished=True))
+    def _decompress(file: pathlib.Path) -> None:
+        _decompress(str(file), remove_finished=True)
 
     def _loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
         if path.is_dir():
@@ -91,32 +89,38 @@ def load(
     ) -> IterDataPipe[Tuple[str, IO]]:
         root = pathlib.Path(root)
         path = root / self.file_name
+
         # Instead of the raw file, there might also be files with fewer suffixes after decompression or directories
-        # with no suffixes at all.
+        # with no suffixes at all. `pathlib.Path().stem` will only give us the name with the last suffix removed, which
+        # is not sufficient for files with multiple suffixes, e.g. foo.tar.gz.
         stem = path.name.replace("".join(path.suffixes), "")
 
-        # In a first step, we check for a folder with the same stem as the raw file. If it exists, we use it since
-        # extracted files give the best I/O performance. Note that OnlineResource._extract() makes sure that an archive
-        # is always extracted in a folder with the corresponding file name.
-        folder_candidate = path.parent / stem
-        if folder_candidate.exists() and folder_candidate.is_dir():
-            return self._loader(folder_candidate)
-
-        # If there is no folder, we look for all files that share the same stem as the raw file, but might have a
-        # different suffix.
-        file_candidates = {file for file in path.parent.glob(stem + ".*")}
-        # If we don't find anything, we download the raw file.
-        if not file_candidates:
-            file_candidates = {self.download(root, skip_integrity_check=skip_integrity_check)}
-        # If the only thing we find is the raw file, we use it and optionally perform some preprocessing steps.
-        if file_candidates == {path}:
+        def find_candidates() -> Set[pathlib.Path]:
+            # Although it looks like we could glob for f"{stem}*" to find the file candidates as well as the folder
+            # candidate simultaneously, that would also pick up other files that share the same prefix. For example, the
+            # test split of the stanford-cars dataset uses the files
+            # - cars_test.tgz
+            # - cars_test_annos_withlabels.mat
+            # Globbing for `"cars_test*"` picks up both.
+            candidates = {file for file in path.parent.glob(f"{stem}.*")}
+            folder_candidate = path.parent / stem
+            if folder_candidate.exists():
+                candidates.add(folder_candidate)
+
+            return candidates
+
+        candidates = find_candidates()
+
+        if not candidates:
+            self.download(root, skip_integrity_check=skip_integrity_check)
             if self._preprocess is not None:
-                path = self._preprocess(path)
-        # Otherwise, we use the path with the fewest suffixes. This gives us the decompressed > raw priority that we
-        # want for the best I/O performance.
-        else:
-            path = min(file_candidates, key=lambda path: len(path.suffixes))
-        return self._loader(path)
+                self._preprocess(path)
+            candidates = find_candidates()
+
+        # We use the path with the fewest suffixes. This gives us the
+        # extracted > decompressed > raw
+        # priority that we want for the best I/O performance.
+        return self._loader(min(candidates, key=lambda candidate: len(candidate.suffixes)))
 
     @abc.abstractmethod
     def _download(self, root: pathlib.Path) -> None:
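
For reviewers who want to see the candidate-selection rule in isolation: the following standalone sketch (not part of the patch; the `folder.tar.gz` file names are made up for illustration) mimics how the new `load()` picks the candidate with the fewest suffixes, which is what yields the extracted > decompressed > raw priority.

```python
import pathlib
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    root = pathlib.Path(tmp)
    (root / "folder.tar.gz").touch()  # raw download: two suffixes
    (root / "folder.tar").touch()     # decompressed: one suffix
    (root / "folder").mkdir()         # extracted: no suffixes

    stem = "folder"
    # Mirrors `find_candidates()`: glob for files sharing the stem, then
    # add the extracted folder candidate explicitly.
    candidates = set(root.glob(f"{stem}.*"))
    folder_candidate = root / stem
    if folder_candidate.exists():
        candidates.add(folder_candidate)

    # The candidate with the fewest suffixes wins.
    assert min(candidates, key=lambda c: len(c.suffixes)) == root / "folder"
```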