 from __future__ import annotations

 import abc
+import functools
 import pathlib
 from typing import Optional, Tuple, Callable, BinaryIO, Any, Union, NoReturn, Set
+from typing import TypeVar, Iterator
 from urllib.parse import urlparse

 from torch.hub import tqdm
@@ -16,10 +18,29 @@
     RarArchiveLoader,
     OnlineReader,
     HashChecker,
+    StreamReader,
+    Saver,
+    Forker,
+    Zipper,
+    Mapper,
 )
 from torchvision.datasets.utils import _detect_file_type, extract_archive, _decompress
 from typing_extensions import Literal

+D = TypeVar("D")
+
+
+class ProgressBar(IterDataPipe[D]):
+    def __init__(self, datapipe: IterDataPipe[D]) -> None:
+        self.datapipe = datapipe
+
+    def __iter__(self) -> Iterator[D]:
+        with tqdm() as progress_bar:
+            for data in self.datapipe:
+                _, chunk = data
+                progress_bar.update(len(chunk))
+                yield data
+

 class OnlineResource(abc.ABC):
     def __init__(
@@ -62,30 +83,46 @@ def from_http(cls, url: str, *, file_name: Optional[str] = None, **kwargs: Any)
     def from_gdrive(cls, id: str, **kwargs: Any) -> OnlineResource:
         return cls(f"https://drive.google.com/uc?export=download&id={id}", **kwargs)

+    def _filepath_fn(self, root: pathlib.Path, file_name: str) -> str:
+        return str(root / file_name)
+
     def download(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> pathlib.Path:
         root = pathlib.Path(root).expanduser().resolve()
         root.mkdir(parents=True, exist_ok=True)
-        file = root / self.file_name
-
-        if not file.exists():
-            dp = IterableWrapper([self.url])
-            dp = OnlineReader(dp)
-            stream = list(dp)[0][1]
-
-            with open(file, "wb") as fh, tqdm() as progress_bar:
-                for chunk in iter(lambda: stream.read(1024 * 1024), b""):  # type: ignore[no-any-return]
-                    # filter out keep-alive new chunks
-                    if not chunk:
-                        continue
-
-                    fh.write(chunk)
-                    progress_bar.update(len(chunk))
-
-        if self.sha256 and not skip_integrity_check:
-            dp = IterableWrapper([str(file)])
-            dp = FileOpener(dp, mode="rb")
-            dp = HashChecker(dp, {str(file): self.sha256}, hash_type="sha256")
-            list(dp)
+
+        filepath_fn = functools.partial(self._filepath_fn, root)
+        file = pathlib.Path(filepath_fn(self.file_name))
+
+        if file.exists():
+            return file
+
+        dp = IterableWrapper([self.url])
+        dp = OnlineReader(dp)
+        # FIXME: this currently only works for GDrive
+        # See https://github.com/pytorch/data/issues/451 for details
+        dp = Mapper(dp, filepath_fn, input_col=0)
+        dp = StreamReader(dp, chunk=32 * 1024 * 1024)
+        dp: IterDataPipe[Tuple[str, bytes]] = ProgressBar(dp)
+
+        check_hash = self.sha256 and not skip_integrity_check
+        if check_hash:
+            # We can get away with a buffer_size of 1 since both datapipes are iterated at the same time. See the
+            # comment in the check_hash branch below for details.
+            dp, hash_checker_fork = Forker(dp, 2, buffer_size=1)
+            # FIXME: HashChecker does not work with chunks
+            # See https://github.com/pytorch/data/issues/452 for details
+            hash_checker_fork = HashChecker(hash_checker_fork, {str(file): self.sha256}, hash_type="sha256")
+
+        dp = Saver(dp, mode="wb")
+
+        if check_hash:
+            # This makes sure that both forks are iterated at the same time, for two reasons:
+            # 1. Forker caches the items. Iterating the forks separately would load the full data into memory.
+            # 2. The first iteration drives the progress bar. If, for example, we performed only the hash check
+            #    first, the progress bar would already be finished and the storing-to-disk part would not be captured.
+            dp = Zipper(dp, hash_checker_fork)
+
+        list(dp)

         return file

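As an aside, here is a minimal, self-contained sketch (not part of the diff; the URL and file path are made up) of what `Mapper(dp, filepath_fn, input_col=0)` does to the `(url, stream)` tuples produced by `OnlineReader`: the function is applied to the first tuple element only, so the URL is replaced by the local target path while the stream passes through untouched, and `Saver` later uses that first element as the destination path for the chunks.

```python
from torchdata.datapipes.iter import IterableWrapper, Mapper

# Stand-in for the (url, stream) pairs yielded by OnlineReader; the bytes object
# plays the role of the stream here.
dp = IterableWrapper([("https://example.com/archive.tar", b"payload")])

# input_col=0 rewrites only the first element of each tuple, leaving the rest intact.
dp = Mapper(dp, lambda url: "/tmp/archive.tar", input_col=0)

print(list(dp))  # [('/tmp/archive.tar', b'payload')]
```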
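And a small illustrative example (again not from the PR, using placeholder chunks) of the `Forker`/`Zipper` pattern the hash-checking branch relies on: `Zipper` pulls one item from each fork per step, so the two branches never drift more than one element apart, which is why `buffer_size=1` is sufficient and no chunks accumulate in memory. Consuming one fork on its own would instead either overflow that one-element buffer or, with a larger buffer, cache the whole download.

```python
from torchdata.datapipes.iter import IterableWrapper, Forker, Zipper

# Placeholder chunks standing in for the (path, bytes) items of the real pipeline.
source = IterableWrapper([b"chunk-1", b"chunk-2", b"chunk-3"])

# Two forks sharing a single-element buffer, as in the download() rewrite.
saver_fork, hash_fork = Forker(source, num_instances=2, buffer_size=1)

# Zipper advances both forks in lockstep, one chunk at a time.
for saved_chunk, hashed_chunk in Zipper(saver_fork, hash_fork):
    assert saved_chunk == hashed_chunk
```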