From 4a999d13c6f8a4400a9b4a81e990589983c21e40 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 23 Mar 2022 17:53:15 +0100 Subject: [PATCH 1/3] fix HttpResource.resolve() with preprocess set --- test/test_prototype_datasets_utils.py | 56 +++++++++++++++++++ .../prototype/datasets/_builtin/caltech.py | 2 +- .../prototype/datasets/_builtin/cub200.py | 10 ++-- .../prototype/datasets/_builtin/dtd.py | 2 +- .../datasets/_builtin/oxford_iiit_pet.py | 4 +- .../prototype/datasets/_builtin/pcam.py | 2 +- .../prototype/datasets/utils/_resource.py | 17 +++--- 7 files changed, 73 insertions(+), 20 deletions(-) diff --git a/test/test_prototype_datasets_utils.py b/test/test_prototype_datasets_utils.py index 7207299c9d4..bd857abf02f 100644 --- a/test/test_prototype_datasets_utils.py +++ b/test/test_prototype_datasets_utils.py @@ -5,6 +5,7 @@ import torch from datasets_utils import make_fake_flo_file from torchvision.datasets._optical_flow import _read_flo as read_flo_ref +from torchvision.prototype.datasets.utils import HttpResource, GDriveResource from torchvision.prototype.datasets.utils._internal import read_flo, fromfile @@ -45,3 +46,58 @@ def test_read_flo(tmpdir): expected = torch.from_numpy(read_flo_ref(path).astype("f4", copy=False)) torch.testing.assert_close(actual, expected) + + +class TestHttpResource: + def test_resolve_to_http(self, mocker): + file_name = "data.tar" + original_url = f"http://downloads.pytorch.org/{file_name}" + + redirected_url = original_url.replace("http", "https") + + sha256_sentinel = "sha256_sentinel" + + def preprocess_sentinel(path): + return path + + original_resource = HttpResource( + original_url, + sha256=sha256_sentinel, + preprocess=preprocess_sentinel, + ) + + mocker.patch("torchvision.prototype.datasets.utils._resource._get_redirect_url", return_value=redirected_url) + redirected_resource = original_resource.resolve() + + assert isinstance(redirected_resource, HttpResource) + assert redirected_resource.url == redirected_url + 
assert redirected_resource.file_name == file_name + assert redirected_resource.sha256 == sha256_sentinel + assert redirected_resource._preprocess is preprocess_sentinel + + def test_resolve_to_gdrive(self, mocker): + file_name = "data.tar" + original_url = f"http://downloads.pytorch.org/{file_name}" + + id_sentinel = "id-sentinel" + redirected_url = f"https://drive.google.com/file/d/{id_sentinel}/view" + + sha256_sentinel = "sha256_sentinel" + + def preprocess_sentinel(path): + return path + + original_resource = HttpResource( + original_url, + sha256=sha256_sentinel, + preprocess=preprocess_sentinel, + ) + + mocker.patch("torchvision.prototype.datasets.utils._resource._get_redirect_url", return_value=redirected_url) + redirected_resource = original_resource.resolve() + + assert isinstance(redirected_resource, GDriveResource) + assert redirected_resource.id == id_sentinel + assert redirected_resource.file_name == file_name + assert redirected_resource.sha256 == sha256_sentinel + assert redirected_resource._preprocess is preprocess_sentinel diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index 1a052860ebf..331ae4d6f11 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -32,7 +32,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: images = HttpResource( "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz", sha256="af6ece2f339791ca20f855943d8b55dd60892c0a25105fcd631ee3d6430f9926", - decompress=True, + preprocess="decompress", ) anns = HttpResource( "http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar", diff --git a/torchvision/prototype/datasets/_builtin/cub200.py b/torchvision/prototype/datasets/_builtin/cub200.py index ae34b48d191..9a5a8274cbf 100644 --- a/torchvision/prototype/datasets/_builtin/cub200.py +++ 
b/torchvision/prototype/datasets/_builtin/cub200.py @@ -51,29 +51,29 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: archive = HttpResource( "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz", sha256="0c685df5597a8b24909f6a7c9db6d11e008733779a671760afef78feb49bf081", - decompress=True, + preprocess="decompress", ) segmentations = HttpResource( "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/segmentations.tgz", sha256="dc77f6cffea0cbe2e41d4201115c8f29a6320ecb04fffd2444f51b8066e4b84f", - decompress=True, + preprocess="decompress", ) return [archive, segmentations] else: # config.year == "2010" split = HttpResource( "http://www.vision.caltech.edu/visipedia-data/CUB-200/lists.tgz", sha256="aeacbd5e3539ae84ea726e8a266a9a119c18f055cd80f3836d5eb4500b005428", - decompress=True, + preprocess="decompress", ) images = HttpResource( "http://www.vision.caltech.edu/visipedia-data/CUB-200/images.tgz", sha256="2a6d2246bbb9778ca03aa94e2e683ccb4f8821a36b7f235c0822e659d60a803e", - decompress=True, + preprocess="decompress", ) anns = HttpResource( "http://www.vision.caltech.edu/visipedia-data/CUB-200/annotations.tgz", sha256="c17b7841c21a66aa44ba8fe92369cc95dfc998946081828b1d7b8a4b716805c1", - decompress=True, + preprocess="decompress", ) return [split, images, anns] diff --git a/torchvision/prototype/datasets/_builtin/dtd.py b/torchvision/prototype/datasets/_builtin/dtd.py index 171861454d4..aba2359884a 100644 --- a/torchvision/prototype/datasets/_builtin/dtd.py +++ b/torchvision/prototype/datasets/_builtin/dtd.py @@ -49,7 +49,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: archive = HttpResource( "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz", sha256="e42855a52a4950a3b59612834602aa253914755c95b0cff9ead6d07395f8e205", - decompress=True, + preprocess="decompress", ) return [archive] diff --git a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py 
b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py index 1780b8829f4..237c22a7ced 100644 --- a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py +++ b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py @@ -40,12 +40,12 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: images = HttpResource( "https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", sha256="67195c5e1c01f1ab5f9b6a5d22b8c27a580d896ece458917e61d459337fa318d", - decompress=True, + preprocess="decompress", ) anns = HttpResource( "https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz", sha256="52425fb6de5c424942b7626b428656fcbd798db970a937df61750c0f1d358e91", - decompress=True, + preprocess="decompress", ) return [images, anns] diff --git a/torchvision/prototype/datasets/_builtin/pcam.py b/torchvision/prototype/datasets/_builtin/pcam.py index 988ff8d4138..ab95391b466 100644 --- a/torchvision/prototype/datasets/_builtin/pcam.py +++ b/torchvision/prototype/datasets/_builtin/pcam.py @@ -91,7 +91,7 @@ def _make_info(self) -> DatasetInfo: def resources(self, config: DatasetConfig) -> List[OnlineResource]: return [ # = [images resource, targets resource] - GDriveResource(file_name=file_name, id=gdrive_id, sha256=sha256, decompress=True) + GDriveResource(file_name=file_name, id=gdrive_id, sha256=sha256, preprocess="decompress") for file_name, gdrive_id, sha256 in self._RESOURCES[config.split] ] diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py index b2ae175c551..4e2214053f6 100644 --- a/torchvision/prototype/datasets/utils/_resource.py +++ b/torchvision/prototype/datasets/utils/_resource.py @@ -23,6 +23,7 @@ _get_redirect_url, _get_google_drive_file_id, ) +from typing_extensions import Literal class OnlineResource(abc.ABC): @@ -31,19 +32,16 @@ def __init__( *, file_name: str, sha256: Optional[str] = None, - decompress: bool = False, - extract: bool = False, + preprocess: 
Optional[Union[Literal["decompress", "extract"], Callable[[pathlib.Path], pathlib.Path]]] = None, ) -> None: self.file_name = file_name self.sha256 = sha256 - self._preprocess: Optional[Callable[[pathlib.Path], pathlib.Path]] - if extract: - self._preprocess = self._extract - elif decompress: - self._preprocess = self._decompress - else: - self._preprocess = None + if preprocess == "decompress": + preprocess = self._decompress + elif preprocess == "extract": + preprocess = self._extract + self._preprocess = preprocess @staticmethod def _extract(file: pathlib.Path) -> pathlib.Path: @@ -163,7 +161,6 @@ def resolve(self) -> OnlineResource: "file_name", "sha256", "_preprocess", - "_loader", ) } From 47dae97158913351ef3594c911cfd2d47e9e4dcf Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 23 Mar 2022 18:10:31 +0100 Subject: [PATCH 2/3] fix README --- torchvision/prototype/datasets/_builtin/README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/README.md b/torchvision/prototype/datasets/_builtin/README.md index b0280071027..83ba6eac149 100644 --- a/torchvision/prototype/datasets/_builtin/README.md +++ b/torchvision/prototype/datasets/_builtin/README.md @@ -231,7 +231,7 @@ To generate the `$NAME.categories` file, run `python -m torchvision.prototype.da ### What if a resource file forms an I/O bottleneck? In general, we are ok with small performance hits of iterating archives rather than their extracted content. However, if -the performance hit becomes significant, the archives can still be decompressed or extracted. To do this, the -`decompress: bool` and `extract: bool` flags can be used for every `OnlineResource` individually. For more complex -cases, each resource also accepts a `preprocess` callable that gets passed a `pathlib.Path` of the raw file and should -return `pathlib.Path` of the preprocessed file or folder. 
+the performance hit becomes significant, the archives can still be preprocessed. `OnlineResource` accepts the +`preprocess` parameter that can be a `Callable[[pathlib.Path], pathlib.Path]` where the input points to the file to be +preprocessed and the return value should be the result of the preprocessing to load. For convenience, `preprocess` also +accepts `"decompress"` and `"extract"` to handle these common scenarios. From 6e4f4ae4b117b67ea854ba27b751f16f76997556 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 23 Mar 2022 19:24:57 +0100 Subject: [PATCH 3/3] add safe guard for invalid str inputs --- torchvision/prototype/datasets/utils/_resource.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py index 4e2214053f6..e88a6966390 100644 --- a/torchvision/prototype/datasets/utils/_resource.py +++ b/torchvision/prototype/datasets/utils/_resource.py @@ -37,10 +37,16 @@ def __init__( self.file_name = file_name self.sha256 = sha256 - if preprocess == "decompress": - preprocess = self._decompress - elif preprocess == "extract": - preprocess = self._extract + if isinstance(preprocess, str): + if preprocess == "decompress": + preprocess = self._decompress + elif preprocess == "extract": + preprocess = self._extract + else: + raise ValueError( + f"Only `'decompress'` or `'extract'` are valid if `preprocess` is passed as string," + f" but got {preprocess} instead." + ) self._preprocess = preprocess @staticmethod