Skip to content

Commit 96f2c0d

Browse files
pmeierNicolasHug
andauthored
support confirming no virus scan on GDrive download (#5645)
* support confirming no virus scan on GDrive download * put gen_bar_updater back * Update torchvision/datasets/utils.py Co-authored-by: Nicolas Hug <[email protected]> Co-authored-by: Nicolas Hug <[email protected]>
1 parent b7c59a0 commit 96f2c0d

File tree

1 file changed

+55
-65
lines changed

1 file changed

+55
-65
lines changed

torchvision/datasets/utils.py

Lines changed: 55 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import urllib
1212
import urllib.error
1313
import urllib.request
14+
import warnings
1415
import zipfile
1516
from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator
1617
from urllib.parse import urlparse
@@ -24,22 +25,31 @@
2425
_is_remote_location_available,
2526
)
2627

27-
2828
USER_AGENT = "pytorch/vision"
2929

3030

31-
def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
32-
with open(filename, "wb") as fh:
33-
with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
34-
with tqdm(total=response.length) as pbar:
35-
for chunk in iter(lambda: response.read(chunk_size), ""):
36-
if not chunk:
37-
break
38-
pbar.update(chunk_size)
39-
fh.write(chunk)
31+
def _save_response_content(
32+
content: Iterator[bytes],
33+
destination: str,
34+
length: Optional[int] = None,
35+
) -> None:
36+
with open(destination, "wb") as fh, tqdm(total=length) as pbar:
37+
for chunk in content:
38+
# filter out keep-alive new chunks
39+
if not chunk:
40+
continue
41+
42+
fh.write(chunk)
43+
pbar.update(len(chunk))
44+
45+
46+
def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
47+
with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
48+
_save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)
4049

4150

4251
def gen_bar_updater() -> Callable[[int, int, int], None]:
52+
warnings.warn("The function `gen_bar_update` is deprecated since 0.13 and will be removed in 0.15.")
4353
pbar = tqdm(total=None)
4454

4555
def bar_update(count, block_size, total_size):
@@ -184,11 +194,20 @@ def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
184194
return files
185195

186196

187-
def _quota_exceeded(first_chunk: bytes) -> bool:
197+
def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[bytes, Iterator[bytes]]:
198+
content = response.iter_content(chunk_size)
199+
first_chunk = None
200+
# filter out keep-alive new chunks
201+
while not first_chunk:
202+
first_chunk = next(content)
203+
content = itertools.chain([first_chunk], content)
204+
188205
try:
189-
return "Google Drive - Quota exceeded" in first_chunk.decode()
206+
match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
207+
api_response = match["api_response"] if match is not None else None
190208
except UnicodeDecodeError:
191-
return False
209+
api_response = None
210+
return api_response, content
192211

193212

194213
def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
@@ -202,70 +221,41 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
202221
"""
203222
# Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
204223

205-
url = "https://docs.google.com/uc?export=download"
206-
207224
root = os.path.expanduser(root)
208225
if not filename:
209226
filename = file_id
210227
fpath = os.path.join(root, filename)
211228

212229
os.makedirs(root, exist_ok=True)
213230

214-
if os.path.isfile(fpath) and check_integrity(fpath, md5):
215-
print("Using downloaded and verified file: " + fpath)
216-
else:
217-
session = requests.Session()
218-
219-
response = session.get(url, params={"id": file_id}, stream=True)
220-
token = _get_confirm_token(response)
221-
222-
if token:
223-
params = {"id": file_id, "confirm": token}
224-
response = session.get(url, params=params, stream=True)
225-
226-
# Ideally, one would use response.status_code to check for quota limits, but google drive is not consistent
227-
# with their own API, refer https://github.com/pytorch/vision/issues/2992#issuecomment-730614517.
228-
# Should this be fixed at some place in future, one could refactor the following to no longer rely on decoding
229-
# the first_chunk of the payload
230-
response_content_generator = response.iter_content(32768)
231-
first_chunk = None
232-
while not first_chunk: # filter out keep-alive new chunks
233-
first_chunk = next(response_content_generator)
234-
235-
if _quota_exceeded(first_chunk):
236-
msg = (
237-
f"The daily quota of the file {filename} is exceeded and it "
238-
f"can't be downloaded. This is a limitation of Google Drive "
239-
f"and can only be overcome by trying again later."
240-
)
241-
raise RuntimeError(msg)
242-
243-
_save_response_content(itertools.chain((first_chunk,), response_content_generator), fpath)
244-
response.close()
231+
if check_integrity(fpath, md5):
232+
print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
245233

234+
url = "https://drive.google.com/uc"
235+
params = dict(id=file_id, export="download")
236+
with requests.Session() as session:
237+
response = session.get(url, params=params, stream=True)
246238

247-
def _get_confirm_token(response: requests.models.Response) -> Optional[str]:
248-
for key, value in response.cookies.items():
249-
if key.startswith("download_warning"):
250-
return value
239+
for key, value in response.cookies.items():
240+
if key.startswith("download_warning"):
241+
token = value
242+
break
243+
else:
244+
api_response, content = _extract_gdrive_api_response(response)
245+
token = "t" if api_response == "Virus scan warning" else None
251246

252-
return None
247+
if token is not None:
248+
response = session.get(url, params=dict(params, confirm=token), stream=True)
249+
api_response, content = _extract_gdrive_api_response(response)
253250

251+
if api_response == "Quota exceeded":
252+
raise RuntimeError(
253+
f"The daily quota of the file {filename} is exceeded and it "
254+
f"can't be downloaded. This is a limitation of Google Drive "
255+
f"and can only be overcome by trying again later."
256+
)
254257

255-
def _save_response_content(
256-
response_gen: Iterator[bytes],
257-
destination: str,
258-
) -> None:
259-
with open(destination, "wb") as f:
260-
pbar = tqdm(total=None)
261-
progress = 0
262-
263-
for chunk in response_gen:
264-
if chunk: # filter out keep-alive new chunks
265-
f.write(chunk)
266-
progress += len(chunk)
267-
pbar.update(progress - pbar.n)
268-
pbar.close()
258+
_save_response_content(content, fpath)
269259

270260

271261
def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:

0 commit comments

Comments
 (0)