
Commit 1de7a74

Added pathlib support to datasets/utils.py (#8200)
Parent: a00a72b

2 files changed, 75 insertions(+), 29 deletions(-)

test/test_datasets_utils.py

Lines changed: 29 additions & 6 deletions
@@ -58,8 +58,11 @@ def test_get_redirect_url_max_hops_exceeded(self, mocker):
         assert mock.call_count == 1
         assert mock.call_args[0][0].full_url == url

-    def test_check_md5(self):
+    @pytest.mark.parametrize("use_pathlib", (True, False))
+    def test_check_md5(self, use_pathlib):
         fpath = TEST_FILE
+        if use_pathlib:
+            fpath = pathlib.Path(fpath)
         correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc"
         false_md5 = ""
         assert utils.check_md5(fpath, correct_md5)
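This use_pathlib parametrization is the pattern applied throughout the file: every affected test now runs twice, once with a plain str path and once with the same path wrapped in pathlib.Path. A stripped-down sketch of the idiom (test name and file name are illustrative, not from the commit):

import pathlib

import pytest

@pytest.mark.parametrize("use_pathlib", (True, False))
def test_accepts_both_path_types(use_pathlib, tmp_path):
    fpath = str(tmp_path / "file.bin")
    if use_pathlib:
        fpath = pathlib.Path(fpath)  # second variant exercises the os.PathLike branch
    with open(fpath, "wb") as f:  # stand-in for the utility under test
        f.write(b"content")
    assert isinstance(fpath, pathlib.Path if use_pathlib else str)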
@@ -116,7 +119,8 @@ def test_detect_file_type_incompatible(self, file):
             utils._detect_file_type(file)

     @pytest.mark.parametrize("extension", [".bz2", ".gz", ".xz"])
-    def test_decompress(self, extension, tmpdir):
+    @pytest.mark.parametrize("use_pathlib", (True, False))
+    def test_decompress(self, extension, tmpdir, use_pathlib):
         def create_compressed(root, content="this is the content"):
             file = os.path.join(root, "file")
             compressed = f"{file}{extension}"
@@ -128,6 +132,8 @@ def create_compressed(root, content="this is the content"):
             return compressed, file, content

         compressed, file, content = create_compressed(tmpdir)
+        if use_pathlib:
+            compressed = pathlib.Path(compressed)

         utils._decompress(compressed)

@@ -140,7 +146,8 @@ def test_decompress_no_compression(self):
         with pytest.raises(RuntimeError):
             utils._decompress("foo.tar")

-    def test_decompress_remove_finished(self, tmpdir):
+    @pytest.mark.parametrize("use_pathlib", (True, False))
+    def test_decompress_remove_finished(self, tmpdir, use_pathlib):
         def create_compressed(root, content="this is the content"):
             file = os.path.join(root, "file")
             compressed = f"{file}.gz"
@@ -151,10 +158,20 @@ def create_compressed(root, content="this is the content"):
             return compressed, file, content

         compressed, file, content = create_compressed(tmpdir)
+        print(f"{type(compressed)=}")
+        if use_pathlib:
+            compressed = pathlib.Path(compressed)
+            tmpdir = pathlib.Path(tmpdir)

-        utils.extract_archive(compressed, tmpdir, remove_finished=True)
+        extracted_dir = utils.extract_archive(compressed, tmpdir, remove_finished=True)

         assert not os.path.exists(compressed)
+        if use_pathlib:
+            assert isinstance(extracted_dir, pathlib.Path)
+            assert isinstance(compressed, pathlib.Path)
+        else:
+            assert isinstance(extracted_dir, str)
+            assert isinstance(compressed, str)

     @pytest.mark.parametrize("extension", [".gz", ".xz"])
     @pytest.mark.parametrize("remove_finished", [True, False])
@@ -167,7 +184,8 @@ def test_extract_archive_defer_to_decompress(self, extension, remove_finished, mocker):

         mocked.assert_called_once_with(file, filename, remove_finished=remove_finished)

-    def test_extract_zip(self, tmpdir):
+    @pytest.mark.parametrize("use_pathlib", (True, False))
+    def test_extract_zip(self, tmpdir, use_pathlib):
         def create_archive(root, content="this is the content"):
             file = os.path.join(root, "dst.txt")
             archive = os.path.join(root, "archive.zip")
@@ -177,6 +195,8 @@ def create_archive(root, content="this is the content"):

             return archive, file, content

+        if use_pathlib:
+            tmpdir = pathlib.Path(tmpdir)
         archive, file, content = create_archive(tmpdir)

         utils.extract_archive(archive, tmpdir)
@@ -189,7 +209,8 @@ def create_archive(root, content="this is the content"):
     @pytest.mark.parametrize(
         "extension, mode", [(".tar", "w"), (".tar.gz", "w:gz"), (".tgz", "w:gz"), (".tar.xz", "w:xz")]
     )
-    def test_extract_tar(self, extension, mode, tmpdir):
+    @pytest.mark.parametrize("use_pathlib", (True, False))
+    def test_extract_tar(self, extension, mode, tmpdir, use_pathlib):
         def create_archive(root, extension, mode, content="this is the content"):
             src = os.path.join(root, "src.txt")
             dst = os.path.join(root, "dst.txt")
@@ -203,6 +224,8 @@ def create_archive(root, extension, mode, content="this is the content"):

             return archive, dst, content

+        if use_pathlib:
+            tmpdir = pathlib.Path(tmpdir)
         archive, file, content = create_archive(tmpdir, extension, mode)

         utils.extract_archive(archive, tmpdir)
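The new variants run as part of the existing test file; a -k expression can narrow pytest to the tests touched by this commit (sketch, executed from the repository root):

pytest test/test_datasets_utils.py -k "check_md5 or decompress or extract_zip or extract_tar"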

torchvision/datasets/utils.py

Lines changed: 46 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@

 def _save_response_content(
     content: Iterator[bytes],
-    destination: str,
+    destination: Union[str, pathlib.Path],
     length: Optional[int] = None,
 ) -> None:
     with open(destination, "wb") as fh, tqdm(total=length) as pbar:
@@ -43,12 +43,12 @@ def _save_response_content(
             pbar.update(len(chunk))


-def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
+def _urlretrieve(url: str, filename: Union[str, pathlib.Path], chunk_size: int = 1024 * 32) -> None:
     with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
         _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)


-def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
+def calculate_md5(fpath: Union[str, pathlib.Path], chunk_size: int = 1024 * 1024) -> str:
     # Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
     # not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without
     # it torchvision.datasets is unusable in these environments since we perform a MD5 check everywhere.
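Only the annotation needed widening here: calculate_md5 hands fpath straight to open(), which accepts any os.PathLike. A minimal sketch of the chunked digest, assuming Python >= 3.9 for the usedforsecurity flag (helper name is hypothetical; run as a script so __file__ is defined):

import hashlib
import pathlib

def md5_of(fpath, chunk_size=1024 * 1024):
    # usedforsecurity=False only declares intent (e.g. for FIPS builds); requires Python >= 3.9.
    md5 = hashlib.md5(usedforsecurity=False)
    with open(fpath, "rb") as f:  # open() accepts str and pathlib.Path alike
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()

assert md5_of(__file__) == md5_of(pathlib.Path(__file__))  # same digest either way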
@@ -62,11 +62,11 @@ def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
     return md5.hexdigest()


-def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
+def check_md5(fpath: Union[str, pathlib.Path], md5: str, **kwargs: Any) -> bool:
     return md5 == calculate_md5(fpath, **kwargs)


-def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
+def check_integrity(fpath: Union[str, pathlib.Path], md5: Optional[str] = None) -> bool:
     if not os.path.isfile(fpath):
         return False
     if md5 is None:
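Callers can now pass either path type to the checksum helpers, for example (file name and digest below are placeholders):

import pathlib

from torchvision.datasets import utils

fpath = pathlib.Path("data.bin")  # placeholder file
utils.check_integrity(fpath)  # True iff the file exists, since no md5 is given
utils.check_integrity(fpath, md5="d41d8cd98f00b204e9800998ecf8427e")  # existence + checksum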
@@ -106,7 +106,7 @@ def _get_google_drive_file_id(url: str) -> Optional[str]:
 def download_url(
     url: str,
     root: Union[str, pathlib.Path],
-    filename: Optional[str] = None,
+    filename: Optional[Union[str, pathlib.Path]] = None,
     md5: Optional[str] = None,
     max_redirect_hops: int = 3,
 ) -> None:
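With filename also widened, a download destination can be expressed with pathlib end to end. A hedged usage sketch (the URL is a placeholder):

import pathlib

from torchvision.datasets import utils

root = pathlib.Path("~/datasets").expanduser()  # a Path root was already supported; filename now is too
utils.download_url(
    "https://example.com/archive.tar.gz",  # placeholder URL
    root,
    filename=pathlib.Path("archive.tar.gz"),
    md5=None,  # no checksum verification in this sketch
)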
@@ -159,7 +159,7 @@ def download_url(
         raise RuntimeError("File not found or corrupted.")


-def list_dir(root: str, prefix: bool = False) -> List[str]:
+def list_dir(root: Union[str, pathlib.Path], prefix: bool = False) -> List[str]:
     """List all directories at a given root

     Args:
@@ -174,7 +174,7 @@ def list_dir(root: str, prefix: bool = False) -> List[str]:
     return directories


-def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
+def list_files(root: Union[str, pathlib.Path], suffix: str, prefix: bool = False) -> List[str]:
     """List all files ending with a suffix at a given root

     Args:
@@ -208,7 +208,10 @@ def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple


 def download_file_from_google_drive(
-    file_id: str, root: Union[str, pathlib.Path], filename: Optional[str] = None, md5: Optional[str] = None
+    file_id: str,
+    root: Union[str, pathlib.Path],
+    filename: Optional[Union[str, pathlib.Path]] = None,
+    md5: Optional[str] = None,
 ):
     """Download a Google Drive file from and place it in root.

@@ -278,7 +281,9 @@ def download_file_from_google_drive(
     )


-def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
+def _extract_tar(
+    from_path: Union[str, pathlib.Path], to_path: Union[str, pathlib.Path], compression: Optional[str]
+) -> None:
     with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
         tar.extractall(to_path)

@@ -289,14 +294,16 @@ def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
 }


-def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
+def _extract_zip(
+    from_path: Union[str, pathlib.Path], to_path: Union[str, pathlib.Path], compression: Optional[str]
+) -> None:
     with zipfile.ZipFile(
         from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
     ) as zip:
         zip.extractall(to_path)


-_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = {
+_ARCHIVE_EXTRACTORS: Dict[str, Callable[[Union[str, pathlib.Path], Union[str, pathlib.Path], Optional[str]], None]] = {
     ".tar": _extract_tar,
     ".zip": _extract_zip,
 }
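The Callable annotation mirrors how the table is used: extract_archive looks the extractor up by the detected archive type and calls it with both paths plus the optional compression. A self-contained sketch of that dispatch (it exercises private helpers, shown purely for illustration):

import pathlib
import tarfile

from torchvision.datasets import utils

# Build a tiny archive so the sketch is runnable.
root = pathlib.Path("demo")
root.mkdir(exist_ok=True)
(root / "payload.txt").write_text("hello")
with tarfile.open(root / "archive.tar.gz", "w:gz") as tar:
    tar.add(root / "payload.txt", arcname="payload.txt")

# The lookup extract_archive performs internally:
suffix, archive_type, compression = utils._detect_file_type("archive.tar.gz")
extractor = utils._ARCHIVE_EXTRACTORS[archive_type]  # _extract_tar for ".tar"
extractor(root / "archive.tar.gz", root, compression)  # compression is ".gz"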
@@ -312,7 +319,7 @@ def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
 }


-def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
+def _detect_file_type(file: Union[str, pathlib.Path]) -> Tuple[str, Optional[str], Optional[str]]:
     """Detect the archive type and/or compression of a file.

     Args:
@@ -355,7 +362,11 @@ def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
     raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.")


-def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
+def _decompress(
+    from_path: Union[str, pathlib.Path],
+    to_path: Optional[Union[str, pathlib.Path]] = None,
+    remove_finished: bool = False,
+) -> pathlib.Path:
     r"""Decompress a file.

     The compression is automatically detected from the file name.
@@ -373,7 +384,7 @@ def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished:
         raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")

     if to_path is None:
-        to_path = from_path.replace(suffix, archive_type if archive_type is not None else "")
+        to_path = pathlib.Path(os.fspath(from_path).replace(suffix, archive_type if archive_type is not None else ""))

     # We don't need to check for a missing key here, since this was already done in _detect_file_type()
     compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]
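The os.fspath call is the key fix in this hunk: on a str input, from_path.replace(...) was a substring replacement, but on a pathlib.Path, .replace(target) is a filesystem rename that takes a single argument, so the old code would have raised TypeError. Normalizing to str first keeps the suffix rewrite working for both input types. A quick illustration:

import os
import pathlib

from_path = pathlib.Path("file.txt.gz")
# from_path.replace(".gz", "") would hit Path.replace (a rename) and fail with TypeError.
to_path = pathlib.Path(os.fspath(from_path).replace(".gz", ""))
print(to_path)  # file.txt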
@@ -384,10 +395,14 @@ def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished:
     if remove_finished:
         os.remove(from_path)

-    return to_path
+    return pathlib.Path(to_path)


-def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
+def extract_archive(
+    from_path: Union[str, pathlib.Path],
+    to_path: Optional[Union[str, pathlib.Path]] = None,
+    remove_finished: bool = False,
+) -> Union[str, pathlib.Path]:
     """Extract an archive.

     The archive type and a possible compression is automatically detected from the file name. If the file is compressed
@@ -402,16 +417,24 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
     Returns:
         (str): Path to the directory the file was extracted to.
     """
+
+    def path_or_str(ret_path: pathlib.Path) -> Union[str, pathlib.Path]:
+        if isinstance(from_path, str):
+            return os.fspath(ret_path)
+        else:
+            return ret_path
+
     if to_path is None:
         to_path = os.path.dirname(from_path)

     suffix, archive_type, compression = _detect_file_type(from_path)
     if not archive_type:
-        return _decompress(
+        ret_path = _decompress(
             from_path,
             os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")),
             remove_finished=remove_finished,
         )
+        return path_or_str(ret_path)

     # We don't need to check for a missing key here, since this was already done in _detect_file_type()
     extractor = _ARCHIVE_EXTRACTORS[archive_type]
@@ -420,14 +443,14 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
     if remove_finished:
         os.remove(from_path)

-    return to_path
+    return path_or_str(pathlib.Path(to_path))


 def download_and_extract_archive(
     url: str,
-    download_root: str,
-    extract_root: Optional[str] = None,
-    filename: Optional[str] = None,
+    download_root: Union[str, pathlib.Path],
+    extract_root: Optional[Union[str, pathlib.Path]] = None,
+    filename: Optional[Union[str, pathlib.Path]] = None,
     md5: Optional[str] = None,
     remove_finished: bool = False,
 ) -> None:
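Net effect: extract_archive is type-preserving, str in and str out, Path in and Path out, which is exactly what the isinstance assertions in the test file check. A self-contained usage sketch (file names are illustrative):

import gzip
import pathlib

from torchvision.datasets import utils

root = pathlib.Path("demo")
root.mkdir(exist_ok=True)
with gzip.open(root / "notes.txt.gz", "wb") as f:
    f.write(b"hello")

out = utils.extract_archive(root / "notes.txt.gz", root)  # Path in...
assert isinstance(out, pathlib.Path)  # ...Path out
out = utils.extract_archive(str(root / "notes.txt.gz"), str(root))  # str in...
assert isinstance(out, str)  # ...str out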
@@ -479,7 +502,7 @@ def verify_str_arg(
     return value


-def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray:
+def _read_pfm(file_name: Union[str, pathlib.Path], slice_channels: int = 2) -> np.ndarray:
     """Read file in .pfm format. Might contain either 1 or 3 channels of data.

     Args:
