
Commit 2d03f02: use StreamWrapper

Parent: db52e8e

1 file changed (+58, -21)

torchvision/prototype/datasets/utils/_resource.py

Lines changed: 58 additions & 21 deletions
```diff
@@ -1,8 +1,10 @@
 from __future__ import annotations

 import abc
+import functools
 import pathlib
 from typing import Optional, Tuple, Callable, BinaryIO, Any, Union, NoReturn, Set
+from typing import TypeVar, Iterator
 from urllib.parse import urlparse

 from torch.hub import tqdm
@@ -16,10 +18,29 @@
     RarArchiveLoader,
     OnlineReader,
     HashChecker,
+    StreamReader,
+    Saver,
+    Forker,
+    Zipper,
+    Mapper,
 )
 from torchvision.datasets.utils import _detect_file_type, extract_archive, _decompress
 from typing_extensions import Literal

+D = TypeVar("D")
+
+
+class ProgressBar(IterDataPipe[D]):
+    def __init__(self, datapipe: IterDataPipe[D]) -> None:
+        self.datapipe = datapipe
+
+    def __iter__(self) -> Iterator[D]:
+        with tqdm() as progress_bar:
+            for data in self.datapipe:
+                _, chunk = data
+                progress_bar.update(len(chunk))
+                yield data
+

 class OnlineResource(abc.ABC):
     def __init__(
@@ -62,30 +83,46 @@ def from_http(cls, url: str, *, file_name: Optional[str] = None, **kwargs: Any)
     def from_gdrive(cls, id: str, **kwargs: Any) -> OnlineResource:
         return cls(f"https://drive.google.com/uc?export=download&id={id}", **kwargs)

+    def _filepath_fn(self, root: pathlib.Path, file_name: str) -> str:
+        return str(root / file_name)
+
     def download(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> pathlib.Path:
         root = pathlib.Path(root).expanduser().resolve()
         root.mkdir(parents=True, exist_ok=True)
-        file = root / self.file_name
-
-        if not file.exists():
-            dp = IterableWrapper([self.url])
-            dp = OnlineReader(dp)
-            stream = list(dp)[0][1]
-
-            with open(file, "wb") as fh, tqdm() as progress_bar:
-                for chunk in iter(lambda: stream.read(1024 * 1024), b""):  # type: ignore[no-any-return]
-                    # filter out keep-alive new chunks
-                    if not chunk:
-                        continue
-
-                    fh.write(chunk)
-                    progress_bar.update(len(chunk))
-
-        if self.sha256 and not skip_integrity_check:
-            dp = IterableWrapper([str(file)])
-            dp = FileOpener(dp, mode="rb")
-            dp = HashChecker(dp, {str(file): self.sha256}, hash_type="sha256")
-            list(dp)
+
+        filepath_fn = functools.partial(self._filepath_fn, root)
+        file = pathlib.Path(filepath_fn(self.file_name))
+
+        if file.exists():
+            return file
+
+        dp = IterableWrapper([self.url])
+        dp = OnlineReader(dp)
+        # FIXME: this currently only works for GDrive
+        # See https://github.com/pytorch/data/issues/451 for details
+        dp = Mapper(dp, filepath_fn, input_col=0)
+        dp = StreamReader(dp, chunk=32 * 1024 * 1024)
+        dp: IterDataPipe[Tuple[str, bytes]] = ProgressBar(dp)
+
+        check_hash = self.sha256 and not skip_integrity_check
+        if check_hash:
+            # We can get away with a buffer_size of 1 since both datapipes are iterated at the same time. See the
+            # comment in the check_hash branch below for details.
+            dp, hash_checker_fork = Forker(dp, 2, buffer_size=1)
+            # FIXME: HashChecker does not work with chunks
+            # See https://github.com/pytorch/data/issues/452 for details
+            hash_checker_fork = HashChecker(hash_checker_fork, {str(file): self.sha256}, hash_type="sha256")
+
+        dp = Saver(dp, mode="wb")
+
+        if check_hash:
+            # This makes sure that both forks are iterated at the same time for two reasons:
+            # 1. Forker caches the items. Iterating separately would mean we load the full data into memory.
+            # 2. The first iteration would trigger the progress bar. Thus, if we for example at first only perform the
+            #    hash check, the progress bar is finished and the whole storing on disk part is not captured.
+            dp = Zipper(dp, hash_checker_fork)
+
+        list(dp)

         return file

```
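The Forker/Zipper comments above carry the key invariant: `buffer_size=1` is only safe because Zipper drains both branches in lockstep. Below is a minimal standalone sketch of that behavior, not part of the commit; the `(path, chunk)` tuples are made up stand-ins for StreamReader output, and `fork`/`zip` are the functional forms of `Forker`/`Zipper`:

```python
from torchdata.datapipes.iter import IterableWrapper

# Stand-ins for the (file path, chunk) pairs that StreamReader would emit.
dp = IterableWrapper([("file.bin", b"\x00" * 4), ("file.bin", b"\x01" * 4)])

# With buffer_size=1, one branch may only run a single item ahead of the
# other before the shared buffer overflows.
saver_branch, checker_branch = dp.fork(2, buffer_size=1)

# Zipper pulls one item from each branch per step, so the branches stay in
# lockstep and the Forker buffer never has to hold more than one item.
for saved, checked in saver_branch.zip(checker_branch):
    assert saved == checked
```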
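For orientation, a hedged usage sketch of the reworked `download`. `MyResource` is a hypothetical concrete `OnlineResource` subclass introduced purely for illustration, and the URL, file name, and digest are placeholders:

```python
# Purely illustrative; MyResource stands in for a concrete subclass that
# passes file_name and sha256 through to OnlineResource.__init__.
resource = MyResource(
    "https://drive.google.com/uc?export=download&id=<placeholder-id>",
    file_name="archive.tar.gz",
    sha256="<expected-hex-digest>",
)

# download() builds the datapipe chain from the diff: OnlineReader opens the
# remote stream, StreamReader chunks it, ProgressBar reports progress, Saver
# writes the chunks to disk, and (when a sha256 is set) a forked HashChecker
# branch is zipped alongside so hashing happens during the write.
path = resource.download("~/.cache/torchvision")  # local pathlib.Path
```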