11
11
import urllib
12
12
import urllib .error
13
13
import urllib .request
14
+ import warnings
14
15
import zipfile
15
16
from typing import Any , Callable , List , Iterable , Optional , TypeVar , Dict , IO , Tuple , Iterator
16
17
from urllib .parse import urlparse
24
25
_is_remote_location_available ,
25
26
)
26
27
27
-
28
28
# Value sent in the HTTP "User-Agent" header for all downloads issued by this
# module, so servers can identify torchvision traffic.
USER_AGENT = "pytorch/vision"
29
29
30
30
31
def _save_response_content(
    content: Iterator[bytes],
    destination: str,
    length: Optional[int] = None,
) -> None:
    """Stream ``content`` chunk by chunk into the file ``destination``.

    Args:
        content: iterator yielding byte chunks of the payload.
        destination: path of the file to write.
        length: total payload size in bytes for the progress bar, if known.
    """
    with open(destination, "wb") as out_file:
        with tqdm(total=length) as progress:
            for block in content:
                # Empty chunks are keep-alive packets — skip them.
                if block:
                    out_file.write(block)
                    progress.update(len(block))
46
def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
    """Download ``url`` to ``filename``, sending the torchvision User-Agent."""
    request = urllib.request.Request(url, headers={"User-Agent": USER_AGENT})
    with urllib.request.urlopen(request) as response:
        # iter() with a b"" sentinel yields fixed-size chunks until EOF.
        chunks = iter(lambda: response.read(chunk_size), b"")
        _save_response_content(chunks, filename, length=response.length)
40
49
41
50
42
51
def gen_bar_updater () -> Callable [[int , int , int ], None ]:
52
+ warnings .warn ("The function `gen_bar_update` is deprecated since 0.13 and will be removed in 0.15." )
43
53
pbar = tqdm (total = None )
44
54
45
55
def bar_update (count , block_size , total_size ):
@@ -184,11 +194,20 @@ def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
184
194
return files
185
195
186
196
187
- def _quota_exceeded (first_chunk : bytes ) -> bool :
197
+ def _extract_gdrive_api_response (response , chunk_size : int = 32 * 1024 ) -> Tuple [bytes , Iterator [bytes ]]:
198
+ content = response .iter_content (chunk_size )
199
+ first_chunk = None
200
+ # filter out keep-alive new chunks
201
+ while not first_chunk :
202
+ first_chunk = next (content )
203
+ content = itertools .chain ([first_chunk ], content )
204
+
188
205
try :
189
- return "Google Drive - Quota exceeded" in first_chunk .decode ()
206
+ match = re .search ("<title>Google Drive - (?P<api_response>.+?)</title>" , first_chunk .decode ())
207
+ api_response = match ["api_response" ] if match is not None else None
190
208
except UnicodeDecodeError :
191
- return False
209
+ api_response = None
210
+ return api_response , content
192
211
193
212
194
213
def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
    """Download a Google Drive file object and place it in root.

    Args:
        file_id (str): id of file to be downloaded
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the id of the file.
        md5 (str, optional): MD5 checksum of the download. If None, do not check
    """
    # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
    root = os.path.expanduser(root)
    if not filename:
        filename = file_id
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    if check_integrity(fpath, md5):
        print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
        # Bug fix: without this early return the file would be re-downloaded
        # even though it is already present (and verified, when md5 is given).
        return

    url = "https://drive.google.com/uc"
    params = dict(id=file_id, export="download")
    with requests.Session() as session:
        response = session.get(url, params=params, stream=True)

        # A "download_warning" cookie carries the confirmation token for files
        # Google cannot virus-scan.
        for key, value in response.cookies.items():
            if key.startswith("download_warning"):
                token = value
                break
        else:
            # No cookie: peek at the payload. A "Virus scan warning" page means
            # the download must be confirmed with confirm=t.
            api_response, content = _extract_gdrive_api_response(response)
            token = "t" if api_response == "Virus scan warning" else None

        if token is not None:
            response = session.get(url, params=dict(params, confirm=token), stream=True)
            api_response, content = _extract_gdrive_api_response(response)

        if api_response == "Quota exceeded":
            raise RuntimeError(
                f"The daily quota of the file {filename} is exceeded and it "
                f"can't be downloaded. This is a limitation of Google Drive "
                f"and can only be overcome by trying again later."
            )

        _save_response_content(content, fpath)
269
259
270
260
271
261
def _extract_tar (from_path : str , to_path : str , compression : Optional [str ]) -> None :
0 commit comments