diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py
index 4acc1d53b4d..1c94e522d65 100644
--- a/torchvision/prototype/datasets/_builtin/__init__.py
+++ b/torchvision/prototype/datasets/_builtin/__init__.py
@@ -1,22 +1,26 @@
 from .caltech import Caltech101, Caltech256
-from .celeba import CelebA
+
+# from .celeba import CelebA
 from .cifar import Cifar10, Cifar100
-from .clevr import CLEVR
-from .coco import Coco
-from .country211 import Country211
-from .cub200 import CUB200
-from .dtd import DTD
-from .eurosat import EuroSAT
+
+# from .clevr import CLEVR
+# from .coco import Coco
+# from .country211 import Country211
+# from .cub200 import CUB200
+# from .dtd import DTD
+# from .eurosat import EuroSAT
 from .fer2013 import FER2013
-from .food101 import Food101
-from .gtsrb import GTSRB
+
+# from .food101 import Food101
+# from .gtsrb import GTSRB
 from .imagenet import ImageNet
-from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
-from .oxford_iiit_pet import OxfordIIITPet
-from .pcam import PCAM
-from .sbd import SBD
-from .semeion import SEMEION
-from .stanford_cars import StanfordCars
-from .svhn import SVHN
-from .usps import USPS
-from .voc import VOC
+
+# from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
+# from .oxford_iiit_pet import OxfordIIITPet
+# from .pcam import PCAM
+# from .sbd import SBD
+# from .semeion import SEMEION
+# from .stanford_cars import StanfordCars
+# from .svhn import SVHN
+# from .usps import USPS
+# from .voc import VOC
diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py
index 3a9a8f08d41..cebe441f55b 100644
--- a/torchvision/prototype/datasets/_builtin/caltech.py
+++ b/torchvision/prototype/datasets/_builtin/caltech.py
@@ -9,7 +9,7 @@
     Filter,
     IterKeyZipper,
 )
-from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource
+from torchvision.prototype.datasets.utils import Dataset, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
     read_mat,
@@ -49,13 +49,13 @@ def __init__(
         )
 
     def _resources(self) -> List[OnlineResource]:
-        images = GDriveResource(
+        images = OnlineResource.from_gdrive(
             "137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp",
             file_name="101_ObjectCategories.tar.gz",
             sha256="af6ece2f339791ca20f855943d8b55dd60892c0a25105fcd631ee3d6430f9926",
             preprocess="decompress",
         )
-        anns = GDriveResource(
+        anns = OnlineResource.from_gdrive(
             "175kQy3UsZ0wUEHZjqkUDdNVssr7bgh_m",
             file_name="Annotations.tar",
             sha256="1717f4e10aa837b05956e3f4c94456527b143eec0d95e935028b30aff40663d8",
@@ -173,7 +173,7 @@ def __init__(
 
     def _resources(self) -> List[OnlineResource]:
         return [
-            GDriveResource(
+            OnlineResource.from_gdrive(
                 "1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK",
                 file_name="256_ObjectCategories.tar",
                 sha256="08ff01b03c65566014ae88eb0490dbe4419fc7ac4de726ee1163e39fd809543e",
diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py
index 514938d6e5f..4f68a30ee00 100644
--- a/torchvision/prototype/datasets/_builtin/cifar.py
+++ b/torchvision/prototype/datasets/_builtin/cifar.py
@@ -10,7 +10,7 @@
     Filter,
     Mapper,
 )
-from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
+from torchvision.prototype.datasets.utils import Dataset, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     hint_shuffling,
     path_comparator,
@@ -58,7 +58,7 @@ def _is_data_file(self, data: Tuple[str, BinaryIO]) -> Optional[int]:
 
     def _resources(self) -> List[OnlineResource]:
         return [
-            HttpResource(
+            OnlineResource.from_http(
                 f"https://www.cs.toronto.edu/~kriz/{self._FILE_NAME}",
                 sha256=self._SHA256,
             )
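The two classmethods used at the call sites above are defined in utils/_resource.py further down in this diff. As a quick orientation — a minimal sketch, not part of the patch, assuming _FILE_NAME resolves to e.g. "cifar-10-python.tar.gz" — this is what the constructors reduce to:

    import pathlib
    from urllib.parse import urlparse

    # from_http: the file name defaults to the last component of the URL path.
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    file_name = pathlib.Path(urlparse(url).path).name  # -> "cifar-10-python.tar.gz"

    # from_gdrive: the Google Drive file id is expanded into a direct-download URL.
    gdrive_id = "137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp"
    url = f"https://drive.google.com/uc?export=download&id={gdrive_id}"

In both cases the result is a plain OnlineResource, so the dedicated HttpResource and GDriveResource subclasses become unnecessary.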
diff --git a/torchvision/prototype/datasets/_builtin/fer2013.py b/torchvision/prototype/datasets/_builtin/fer2013.py
index c1a914c6f63..4e84425a0fb 100644
--- a/torchvision/prototype/datasets/_builtin/fer2013.py
+++ b/torchvision/prototype/datasets/_builtin/fer2013.py
@@ -3,11 +3,7 @@
 
 import torch
 from torchdata.datapipes.iter import IterDataPipe, Mapper, CSVDictParser
-from torchvision.prototype.datasets.utils import (
-    Dataset,
-    OnlineResource,
-    KaggleDownloadResource,
-)
+from torchvision.prototype.datasets.utils import Dataset, ManualDownloadResource
 from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,
     hint_shuffling,
@@ -43,8 +39,8 @@ def __init__(
         "test": "dec8dfe8021e30cd6704b85ec813042b4a5d99d81cb55e023291a94104f575c3",
     }
 
-    def _resources(self) -> List[OnlineResource]:
-        archive = KaggleDownloadResource(
+    def _resources(self) -> List[ManualDownloadResource]:
+        archive = ManualDownloadResource.from_kaggle(
             "https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge",
             file_name=f"{self._split}.csv.zip",
             sha256=self._CHECKSUMS[self._split],
diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py
index 1307757cef6..c81a698e341 100644
--- a/torchvision/prototype/datasets/_builtin/imagenet.py
+++ b/torchvision/prototype/datasets/_builtin/imagenet.py
@@ -14,11 +14,7 @@
     TarArchiveLoader,
     Enumerator,
 )
-from torchvision.prototype.datasets.utils import (
-    OnlineResource,
-    ManualDownloadResource,
-    Dataset,
-)
+from torchvision.prototype.datasets.utils import ManualDownloadResource, Dataset
 from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
     getitem,
@@ -41,11 +37,6 @@ def _info() -> Dict[str, Any]:
     return dict(categories=categories, wnids=wnids)
 
 
-class ImageNetResource(ManualDownloadResource):
-    def __init__(self, **kwargs: Any) -> None:
-        super().__init__("Register on https://image-net.org/ and follow the instructions there.", **kwargs)
-
-
 class ImageNetDemux(enum.IntEnum):
     META = 0
     LABEL = 1
@@ -80,16 +71,22 @@ def __init__(
         "test_v10102019": "9cf7f8249639510f17d3d8a0deb47cd22a435886ba8e29e2b3223e65a4079eb4",
     }
 
-    def _resources(self) -> List[OnlineResource]:
-        name = "test_v10102019" if self._split == "test" else self._split
-        images = ImageNetResource(
-            file_name=f"ILSVRC2012_img_{name}.tar",
-            sha256=self._IMAGES_CHECKSUMS[name],
+    def _imagenet_resource(self, *, file_name: str, sha256: str) -> ManualDownloadResource:
+        return ManualDownloadResource(
+            "https://image-net.org/",
+            instructions="Register on https://image-net.org/ and follow the instructions there.",
+            file_name=file_name,
+            sha256=sha256,
         )
-        resources: List[OnlineResource] = [images]
+
+    def _resources(self) -> List[ManualDownloadResource]:
+        name = "test_v10102019" if self._split == "test" else self._split
+        images = self._imagenet_resource(file_name=f"ILSVRC2012_img_{name}.tar", sha256=self._IMAGES_CHECKSUMS[name])
+
+        resources = [images]
 
         if self._split == "val":
-            devkit = ImageNetResource(
+            devkit = self._imagenet_resource(
                 file_name="ILSVRC2012_devkit_t12.tar.gz",
                 sha256="b59243268c0d266621fd587d2018f69e906fb22875aca0e295b48cafaa927953",
             )
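Since ManualDownloadResource now overrides download() directly (see utils/_resource.py below), the resources returned here never attempt a download themselves. A minimal usage sketch, with a hypothetical root path:

    from torchvision.prototype.datasets.utils import ManualDownloadResource

    resource = ManualDownloadResource(
        "https://image-net.org/",
        instructions="Register on https://image-net.org/ and follow the instructions there.",
        file_name="ILSVRC2012_img_val.tar",
    )
    # Always raises RuntimeError, echoing the instructions and the target directory:
    resource.download("~/datasets/imagenet")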
diff --git a/torchvision/prototype/datasets/utils/__init__.py b/torchvision/prototype/datasets/utils/__init__.py
index 94c5907b47d..3f83dfc0129 100644
--- a/torchvision/prototype/datasets/utils/__init__.py
+++ b/torchvision/prototype/datasets/utils/__init__.py
@@ -1,3 +1,3 @@
 from . import _internal  # usort: skip
 from ._dataset import Dataset
-from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource
+from ._resource import OnlineResource, ManualDownloadResource
diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py
index 528d0a0f25f..853cf9fbf94 100644
--- a/torchvision/prototype/datasets/utils/_dataset.py
+++ b/torchvision/prototype/datasets/utils/_dataset.py
@@ -42,7 +42,7 @@ def __iter__(self) -> Iterator[Dict[str, Any]]:
         yield from self._dp
 
     @abc.abstractmethod
-    def _resources(self) -> List[OnlineResource]:
+    def _resources(self) -> Sequence[OnlineResource]:
         pass
 
     @abc.abstractmethod
diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py
index 3c9b95cb498..37a0e64ee20 100644
--- a/torchvision/prototype/datasets/utils/_resource.py
+++ b/torchvision/prototype/datasets/utils/_resource.py
@@ -1,10 +1,13 @@
+from __future__ import annotations
+
 import abc
-import hashlib
-import itertools
+import functools
 import pathlib
-from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn, Set
+from typing import Optional, Tuple, Callable, BinaryIO, Any, Union, NoReturn, Set
+from typing import TypeVar, Iterator
 from urllib.parse import urlparse
 
+from torch.hub import tqdm
 from torchdata.datapipes.iter import (
     IterableWrapper,
     FileLister,
@@ -13,27 +16,42 @@
     ZipArchiveLoader,
     TarArchiveLoader,
     RarArchiveLoader,
+    OnlineReader,
+    HashChecker,
+    StreamReader,
+    Saver,
+    Forker,
+    Zipper,
+    Mapper,
 )
-from torchvision.datasets.utils import (
-    download_url,
-    _detect_file_type,
-    extract_archive,
-    _decompress,
-    download_file_from_google_drive,
-    _get_redirect_url,
-    _get_google_drive_file_id,
-)
+from torchvision.datasets.utils import _detect_file_type, extract_archive, _decompress
 from typing_extensions import Literal
 
+D = TypeVar("D")
+
+
+class ProgressBar(IterDataPipe[D]):
+    def __init__(self, datapipe: IterDataPipe[D]) -> None:
+        self.datapipe = datapipe
+
+    def __iter__(self) -> Iterator[D]:
+        with tqdm() as progress_bar:
+            for data in self.datapipe:
+                _, chunk = data
+                progress_bar.update(len(chunk))
+                yield data
+
 
 class OnlineResource(abc.ABC):
     def __init__(
         self,
+        url: str,
         *,
         file_name: str,
         sha256: Optional[str] = None,
         preprocess: Optional[Union[Literal["decompress", "extract"], Callable[[pathlib.Path], None]]] = None,
     ) -> None:
+        self.url = url
         self.file_name = file_name
         self.sha256 = sha256
 
@@ -57,7 +75,58 @@ def _decompress(file: pathlib.Path) -> None:
         _decompress(str(file), remove_finished=True)
 
-    def _loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
+    @classmethod
+    def from_http(cls, url: str, *, file_name: Optional[str] = None, **kwargs: Any) -> OnlineResource:
+        return cls(url, file_name=file_name or pathlib.Path(urlparse(url).path).name, **kwargs)
+
+    @classmethod
+    def from_gdrive(cls, id: str, **kwargs: Any) -> OnlineResource:
+        return cls(f"https://drive.google.com/uc?export=download&id={id}", **kwargs)
+
+    def _filepath_fn(self, root: pathlib.Path, file_name: str) -> str:
+        return str(root / file_name)
+
+    def download(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> pathlib.Path:
+        root = pathlib.Path(root).expanduser().resolve()
+        root.mkdir(parents=True, exist_ok=True)
+
+        filepath_fn = functools.partial(self._filepath_fn, root)
+        file = pathlib.Path(filepath_fn(self.file_name))
+
+        if file.exists():
+            return file
+
+        dp = IterableWrapper([self.url])
+        dp = OnlineReader(dp)
+        # FIXME: this currently only works for GDrive
+        # See https://github.com/pytorch/data/issues/451 for details
+        dp = Mapper(dp, filepath_fn, input_col=0)
+        dp = StreamReader(dp, chunk=32 * 1024 * 1024)
+        dp: IterDataPipe[Tuple[str, bytes]] = ProgressBar(dp)
+
+        check_hash = self.sha256 and not skip_integrity_check
+        if check_hash:
+            # We can get away with a buffer_size of 1 since both datapipes are iterated at the same time. See the
+            # comment in the check_hash branch below for details.
+            dp, hash_checker_fork = Forker(dp, 2, buffer_size=1)
+            # FIXME: HashChecker does not work with chunks
+            # See https://github.com/pytorch/data/issues/452 for details
+            hash_checker_fork = HashChecker(hash_checker_fork, {str(file): self.sha256}, hash_type="sha256")
+
+        dp = Saver(dp, mode="wb")
+
+        if check_hash:
+            # This makes sure that both forks are iterated at the same time for two reasons:
+            # 1. Forker caches the items. Iterating separately would mean we load the full data into memory.
+            # 2. The first iteration triggers the progress bar. If we performed only the hash check first, the
+            #    progress bar would already be finished and the storing-to-disk part would not be captured.
+            dp = Zipper(dp, hash_checker_fork)
+
+        list(dp)
+
+        return file
+
+    def _loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, BinaryIO]]:
         if path.is_dir():
             return FileOpener(FileLister(str(path), recursive=True), mode="rb")
 
@@ -77,7 +146,7 @@ def _loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
 
     def _guess_archive_loader(
         self, path: pathlib.Path
-    ) -> Optional[Callable[[IterDataPipe[Tuple[str, IO]]], IterDataPipe[Tuple[str, IO]]]]:
+    ) -> Optional[Callable[[IterDataPipe[Tuple[str, BinaryIO]]], IterDataPipe[Tuple[str, BinaryIO]]]]:
         try:
             _, archive_type, _ = _detect_file_type(path.name)
         except RuntimeError:
@@ -86,7 +155,7 @@ def _guess_archive_loader(
 
     def load(
         self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False
-    ) -> IterDataPipe[Tuple[str, IO]]:
+    ) -> IterDataPipe[Tuple[str, BinaryIO]]:
         root = pathlib.Path(root)
         path = root / self.file_name
 
@@ -122,108 +191,22 @@ def find_candidates() -> Set[pathlib.Path]:
         # priority that we want for the best I/O performance.
         return self._loader(min(candidates, key=lambda candidate: len(candidate.suffixes)))
 
-    @abc.abstractmethod
-    def _download(self, root: pathlib.Path) -> None:
-        pass
-
-    def download(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> pathlib.Path:
-        root = pathlib.Path(root)
-        self._download(root)
-        path = root / self.file_name
-        if self.sha256 and not skip_integrity_check:
-            self._check_sha256(path)
-        return path
-
-    def _check_sha256(self, path: pathlib.Path, *, chunk_size: int = 1024 * 1024) -> None:
-        hash = hashlib.sha256()
-        with open(path, "rb") as file:
-            for chunk in iter(lambda: file.read(chunk_size), b""):
-                hash.update(chunk)
-        sha256 = hash.hexdigest()
-        if sha256 != self.sha256:
-            raise RuntimeError(
-                f"After the download, the SHA256 checksum of {path} didn't match the expected one: "
-                f"{sha256} != {self.sha256}"
-            )
-
-
-class HttpResource(OnlineResource):
-    def __init__(
-        self, url: str, *, file_name: Optional[str] = None, mirrors: Sequence[str] = (), **kwargs: Any
-    ) -> None:
-        super().__init__(file_name=file_name or pathlib.Path(urlparse(url).path).name, **kwargs)
-        self.url = url
-        self.mirrors = mirrors
-        self._resolved = False
-
-    def resolve(self) -> OnlineResource:
-        if self._resolved:
-            return self
-
-        redirect_url = _get_redirect_url(self.url)
-        if redirect_url == self.url:
-            self._resolved = True
-            return self
-
-        meta = {
-            attr.lstrip("_"): getattr(self, attr)
-            for attr in (
-                "file_name",
-                "sha256",
-                "_preprocess",
-            )
-        }
-
-        gdrive_id = _get_google_drive_file_id(redirect_url)
-        if gdrive_id:
-            return GDriveResource(gdrive_id, **meta)
-
-        http_resource = HttpResource(redirect_url, **meta)
-        http_resource._resolved = True
-        return http_resource
-
-    def _download(self, root: pathlib.Path) -> None:
-        if not self._resolved:
-            return self.resolve()._download(root)
-
-        for url in itertools.chain((self.url,), self.mirrors):
-
-            try:
-                download_url(url, str(root), filename=self.file_name, md5=None)
-            # TODO: make this more precise
-            except Exception:
-                continue
-
-            return
-        else:
-            # TODO: make this more informative
-            raise RuntimeError("Download failed!")
-
-
-class GDriveResource(OnlineResource):
-    def __init__(self, id: str, **kwargs: Any) -> None:
-        super().__init__(**kwargs)
-        self.id = id
-
-    def _download(self, root: pathlib.Path) -> None:
-        download_file_from_google_drive(self.id, root=str(root), filename=self.file_name, md5=None)
-
 
 class ManualDownloadResource(OnlineResource):
-    def __init__(self, instructions: str, **kwargs: Any) -> None:
-        super().__init__(**kwargs)
-        self.instructions = instructions
+    def __init__(self, url: str, *, instructions: str, **kwargs: Any) -> None:
+        super().__init__(url, **kwargs)
+        self._instructions = instructions
 
-    def _download(self, root: pathlib.Path) -> NoReturn:
+    def download(self, root: Union[str, pathlib.Path], **_: Any) -> NoReturn:
+        root = pathlib.Path(root)
         raise RuntimeError(
             f"The file {self.file_name} cannot be downloaded automatically. "
             f"Please follow the instructions below and place it in {root}\n\n"
-            f"{self.instructions}"
+            f"{self._instructions}"
         )
 
-
-class KaggleDownloadResource(ManualDownloadResource):
-    def __init__(self, challenge_url: str, *, file_name: str, **kwargs: Any) -> None:
+    @classmethod
+    def from_kaggle(cls, challenge_url: str, *, file_name: str, **kwargs: Any) -> ManualDownloadResource:
         instructions = "\n".join(
             (
                 "1. Register and login at https://www.kaggle.com",
@@ -233,4 +216,4 @@ def __init__(self, challenge_url: str, *, file_name: str, **kwargs: Any) -> None:
                 f"5. Select {file_name} in the 'Data Explorer' and click the download button",
             )
         )
-        super().__init__(instructions, file_name=file_name, **kwargs)
+        return cls(challenge_url, instructions=instructions, file_name=file_name, **kwargs)
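For review convenience, here is the datapipe graph that the new OnlineResource.download() assembles, as a standalone sketch. The URL, target path, and digest are hypothetical placeholders; the real method derives all three from the resource, and the FIXMEs above apply here too (OnlineReader and HashChecker have open issues for non-GDrive URLs and chunked input):

    from torchdata.datapipes.iter import (
        IterableWrapper,
        OnlineReader,
        Mapper,
        StreamReader,
        Forker,
        HashChecker,
        Saver,
        Zipper,
    )

    url = "https://example.com/archive.tar.gz"  # hypothetical
    target = "/tmp/archive.tar.gz"  # hypothetical
    expected_sha256 = "0" * 64  # placeholder; substitute the real digest

    dp = IterableWrapper([url])
    dp = OnlineReader(dp)  # yields (url, byte stream)
    dp = Mapper(dp, lambda _: target, input_col=0)  # rewrite the key to the local path
    dp = StreamReader(dp, chunk=32 * 1024 * 1024)  # yields (path, bytes) chunks
    # Fork once for hashing, once for writing. buffer_size=1 suffices because
    # Zipper below pulls both forks in lockstep.
    dp, check = Forker(dp, 2, buffer_size=1)
    check = HashChecker(check, {target: expected_sha256}, hash_type="sha256")
    dp = Saver(dp, mode="wb")
    dp = Zipper(dp, check)
    list(dp)  # drain: downloads, hashes, and writes in a single pass

Zipping the two forks, rather than iterating them one after the other, is what keeps memory bounded and the progress bar accurate, as the comments in download() explain.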