30
30
31
31
def _save_response_content (
32
32
content : Iterator [bytes ],
33
- destination : str ,
33
+ destination : Union [ str , pathlib . Path ] ,
34
34
length : Optional [int ] = None ,
35
35
) -> None :
36
36
with open (destination , "wb" ) as fh , tqdm (total = length ) as pbar :
@@ -43,12 +43,12 @@ def _save_response_content(
43
43
pbar .update (len (chunk ))
44
44
45
45
46
- def _urlretrieve (url : str , filename : str , chunk_size : int = 1024 * 32 ) -> None :
46
+ def _urlretrieve (url : str , filename : Union [ str , pathlib . Path ] , chunk_size : int = 1024 * 32 ) -> None :
47
47
with urllib .request .urlopen (urllib .request .Request (url , headers = {"User-Agent" : USER_AGENT })) as response :
48
48
_save_response_content (iter (lambda : response .read (chunk_size ), b"" ), filename , length = response .length )
49
49
50
50
51
- def calculate_md5 (fpath : str , chunk_size : int = 1024 * 1024 ) -> str :
51
+ def calculate_md5 (fpath : Union [ str , pathlib . Path ] , chunk_size : int = 1024 * 1024 ) -> str :
52
52
# Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
53
53
# not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without
54
54
# it torchvision.datasets is unusable in these environments since we perform a MD5 check everywhere.
@@ -62,11 +62,11 @@ def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
62
62
return md5 .hexdigest ()
63
63
64
64
65
- def check_md5 (fpath : str , md5 : str , ** kwargs : Any ) -> bool :
65
+ def check_md5 (fpath : Union [ str , pathlib . Path ] , md5 : str , ** kwargs : Any ) -> bool :
66
66
return md5 == calculate_md5 (fpath , ** kwargs )
67
67
68
68
69
- def check_integrity (fpath : str , md5 : Optional [str ] = None ) -> bool :
69
+ def check_integrity (fpath : Union [ str , pathlib . Path ] , md5 : Optional [str ] = None ) -> bool :
70
70
if not os .path .isfile (fpath ):
71
71
return False
72
72
if md5 is None :
@@ -106,7 +106,7 @@ def _get_google_drive_file_id(url: str) -> Optional[str]:
106
106
def download_url (
107
107
url : str ,
108
108
root : Union [str , pathlib .Path ],
109
- filename : Optional [str ] = None ,
109
+ filename : Optional [Union [ str , pathlib . Path ] ] = None ,
110
110
md5 : Optional [str ] = None ,
111
111
max_redirect_hops : int = 3 ,
112
112
) -> None :
@@ -159,7 +159,7 @@ def download_url(
159
159
raise RuntimeError ("File not found or corrupted." )
160
160
161
161
162
- def list_dir (root : str , prefix : bool = False ) -> List [str ]:
162
+ def list_dir (root : Union [ str , pathlib . Path ] , prefix : bool = False ) -> List [str ]:
163
163
"""List all directories at a given root
164
164
165
165
Args:
@@ -174,7 +174,7 @@ def list_dir(root: str, prefix: bool = False) -> List[str]:
174
174
return directories
175
175
176
176
177
- def list_files (root : str , suffix : str , prefix : bool = False ) -> List [str ]:
177
+ def list_files (root : Union [ str , pathlib . Path ] , suffix : str , prefix : bool = False ) -> List [str ]:
178
178
"""List all files ending with a suffix at a given root
179
179
180
180
Args:
@@ -208,7 +208,10 @@ def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple
208
208
209
209
210
210
def download_file_from_google_drive (
211
- file_id : str , root : Union [str , pathlib .Path ], filename : Optional [str ] = None , md5 : Optional [str ] = None
211
+ file_id : str ,
212
+ root : Union [str , pathlib .Path ],
213
+ filename : Optional [Union [str , pathlib .Path ]] = None ,
214
+ md5 : Optional [str ] = None ,
212
215
):
213
216
"""Download a Google Drive file from and place it in root.
214
217
@@ -278,7 +281,9 @@ def download_file_from_google_drive(
278
281
)
279
282
280
283
281
- def _extract_tar (from_path : str , to_path : str , compression : Optional [str ]) -> None :
284
+ def _extract_tar (
285
+ from_path : Union [str , pathlib .Path ], to_path : Union [str , pathlib .Path ], compression : Optional [str ]
286
+ ) -> None :
282
287
with tarfile .open (from_path , f"r:{ compression [1 :]} " if compression else "r" ) as tar :
283
288
tar .extractall (to_path )
284
289
@@ -289,14 +294,16 @@ def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> No
289
294
}
290
295
291
296
292
- def _extract_zip (from_path : str , to_path : str , compression : Optional [str ]) -> None :
297
+ def _extract_zip (
298
+ from_path : Union [str , pathlib .Path ], to_path : Union [str , pathlib .Path ], compression : Optional [str ]
299
+ ) -> None :
293
300
with zipfile .ZipFile (
294
301
from_path , "r" , compression = _ZIP_COMPRESSION_MAP [compression ] if compression else zipfile .ZIP_STORED
295
302
) as zip :
296
303
zip .extractall (to_path )
297
304
298
305
299
- _ARCHIVE_EXTRACTORS : Dict [str , Callable [[str , str , Optional [str ]], None ]] = {
306
+ _ARCHIVE_EXTRACTORS : Dict [str , Callable [[Union [ str , pathlib . Path ], Union [ str , pathlib . Path ] , Optional [str ]], None ]] = {
300
307
".tar" : _extract_tar ,
301
308
".zip" : _extract_zip ,
302
309
}
@@ -312,7 +319,7 @@ def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> No
312
319
}
313
320
314
321
315
- def _detect_file_type (file : str ) -> Tuple [str , Optional [str ], Optional [str ]]:
322
+ def _detect_file_type (file : Union [ str , pathlib . Path ] ) -> Tuple [str , Optional [str ], Optional [str ]]:
316
323
"""Detect the archive type and/or compression of a file.
317
324
318
325
Args:
@@ -355,7 +362,11 @@ def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
355
362
raise RuntimeError (f"Unknown compression or archive type: '{ suffix } '.\n Known suffixes are: '{ valid_suffixes } '." )
356
363
357
364
358
- def _decompress (from_path : str , to_path : Optional [str ] = None , remove_finished : bool = False ) -> str :
365
+ def _decompress (
366
+ from_path : Union [str , pathlib .Path ],
367
+ to_path : Optional [Union [str , pathlib .Path ]] = None ,
368
+ remove_finished : bool = False ,
369
+ ) -> pathlib .Path :
359
370
r"""Decompress a file.
360
371
361
372
The compression is automatically detected from the file name.
@@ -373,7 +384,7 @@ def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished:
373
384
raise RuntimeError (f"Couldn't detect a compression from suffix { suffix } ." )
374
385
375
386
if to_path is None :
376
- to_path = from_path .replace (suffix , archive_type if archive_type is not None else "" )
387
+ to_path = pathlib . Path ( os . fspath ( from_path ) .replace (suffix , archive_type if archive_type is not None else "" ) )
377
388
378
389
# We don't need to check for a missing key here, since this was already done in _detect_file_type()
379
390
compressed_file_opener = _COMPRESSED_FILE_OPENERS [compression ]
@@ -384,10 +395,14 @@ def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished:
384
395
if remove_finished :
385
396
os .remove (from_path )
386
397
387
- return to_path
398
+ return pathlib . Path ( to_path )
388
399
389
400
390
- def extract_archive (from_path : str , to_path : Optional [str ] = None , remove_finished : bool = False ) -> str :
401
+ def extract_archive (
402
+ from_path : Union [str , pathlib .Path ],
403
+ to_path : Optional [Union [str , pathlib .Path ]] = None ,
404
+ remove_finished : bool = False ,
405
+ ) -> Union [str , pathlib .Path ]:
391
406
"""Extract an archive.
392
407
393
408
The archive type and a possible compression are automatically detected from the file name. If the file is compressed
@@ -402,16 +417,24 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finish
402
417
Returns:
403
418
(str or pathlib.Path): Path to the directory the file was extracted to. A str is returned if from_path is a str, otherwise a pathlib.Path.
404
419
"""
420
+
421
+ def path_or_str (ret_path : pathlib .Path ) -> Union [str , pathlib .Path ]:
422
+ if isinstance (from_path , str ):
423
+ return os .fspath (ret_path )
424
+ else :
425
+ return ret_path
426
+
405
427
if to_path is None :
406
428
to_path = os .path .dirname (from_path )
407
429
408
430
suffix , archive_type , compression = _detect_file_type (from_path )
409
431
if not archive_type :
410
- return _decompress (
432
+ ret_path = _decompress (
411
433
from_path ,
412
434
os .path .join (to_path , os .path .basename (from_path ).replace (suffix , "" )),
413
435
remove_finished = remove_finished ,
414
436
)
437
+ return path_or_str (ret_path )
415
438
416
439
# We don't need to check for a missing key here, since this was already done in _detect_file_type()
417
440
extractor = _ARCHIVE_EXTRACTORS [archive_type ]
@@ -420,14 +443,14 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finish
420
443
if remove_finished :
421
444
os .remove (from_path )
422
445
423
- return to_path
446
+ return path_or_str ( pathlib . Path ( to_path ))
424
447
425
448
426
449
def download_and_extract_archive (
427
450
url : str ,
428
- download_root : str ,
429
- extract_root : Optional [str ] = None ,
430
- filename : Optional [str ] = None ,
451
+ download_root : Union [ str , pathlib . Path ] ,
452
+ extract_root : Optional [Union [ str , pathlib . Path ] ] = None ,
453
+ filename : Optional [Union [ str , pathlib . Path ] ] = None ,
431
454
md5 : Optional [str ] = None ,
432
455
remove_finished : bool = False ,
433
456
) -> None :
@@ -479,7 +502,7 @@ def verify_str_arg(
479
502
return value
480
503
481
504
482
- def _read_pfm (file_name : str , slice_channels : int = 2 ) -> np .ndarray :
505
+ def _read_pfm (file_name : Union [ str , pathlib . Path ] , slice_channels : int = 2 ) -> np .ndarray :
483
506
"""Read file in .pfm format. Might contain either 1 or 3 channels of data.
484
507
485
508
Args:
0 commit comments