
Commit d9a99ff

vincentqb authored and fmassa committed
[fbsync] Separate extraction and decompression logic in datasets.utils.extract_archive (#3443)
Summary:

* generalize extract_archive
* [test] re-enable extraction tests on windows
* add tests for detect_file_type
* add error messages to detect_file_type
* Revert "[test] re-enable extraction tests on windows"
  This reverts commit 7fafebb.
* add utility functions for better mock call checking
* add tests for decompress
* simplify logic by using pathlib
* lint
* Apply suggestions from code review
* make decompress private
* remove unnecessary checks
* add error message
* fix mocking
* add remaining tests
* lint

Reviewed By: fmassa

Differential Revision: D27128004

fbshipit-source-id: 73f7d8a43eca5dbc9c7e63d8b1ff6e0859915d92

Co-authored-by: Francisco Massa <[email protected]>
Co-authored-by: Francisco Massa <[email protected]>
1 parent abffef6 commit d9a99ff
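
In practice, the separation described in the summary means that a file which is only compressed (".gz" or ".xz" with no archive inside) is no longer handled by a branch inside extract_archive but dispatched to a new private _decompress helper. A minimal sketch of the resulting behaviour, assuming a torchvision checkout that contains this commit (the payload and temporary paths are made up for illustration):

    import gzip
    import os
    import tempfile

    from torchvision.datasets import utils

    with tempfile.TemporaryDirectory() as root:
        compressed = os.path.join(root, "bar.gz")
        with gzip.open(compressed, "wb") as fh:
            fh.write(b"payload")

        # ".gz" alone carries no archive type, so extract_archive defers to the
        # private _decompress helper and returns the path of the decompressed file.
        out = utils.extract_archive(compressed, root)
        with open(out) as fh:
            print(out, fh.read())  # .../bar payload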

File tree

3 files changed: +254 −54 lines changed

test/common_utils.py

Lines changed: 18 additions & 0 deletions
@@ -10,6 +10,7 @@
 import warnings
 import __main__
 import random
+import inspect
 
 from numbers import Number
 from torch._six import string_classes
@@ -401,3 +402,20 @@ def disable_console_output():
         stack.enter_context(contextlib.redirect_stdout(devnull))
         stack.enter_context(contextlib.redirect_stderr(devnull))
         yield
+
+
+def call_args_to_kwargs_only(call_args, *callable_or_arg_names):
+    callable_or_arg_name = callable_or_arg_names[0]
+    if callable(callable_or_arg_name):
+        argspec = inspect.getfullargspec(callable_or_arg_name)
+        arg_names = argspec.args
+        if isinstance(callable_or_arg_name, type):
+            # remove self
+            arg_names.pop(0)
+    else:
+        arg_names = callable_or_arg_names
+
+    args, kwargs = call_args
+    kwargs_only = kwargs.copy()
+    kwargs_only.update(dict(zip(arg_names, args)))
+    return kwargs_only

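The call_args_to_kwargs_only helper added above folds a mock's positional arguments into a single kwargs dict, using the signature of the callable under test, so positional and keyword invocations can be compared with one equality check. A rough, self-contained illustration with a toy function (not part of the torchvision code; it only assumes common_utils is importable, e.g. from the test directory):

    import unittest.mock

    from common_utils import call_args_to_kwargs_only  # the helper defined above


    def greet(name, greeting="hello"):
        return f"{greeting}, {name}"


    mock = unittest.mock.Mock()
    mock("world", greeting="hi")

    # The positional "world" is matched against greet's first parameter name.
    assert call_args_to_kwargs_only(mock.call_args, greet) == dict(name="world", greeting="hi")
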
test/test_datasets_utils.py

Lines changed: 111 additions & 21 deletions
@@ -8,8 +8,10 @@
 import warnings
 from torch._utils_internal import get_file_path_2
 from urllib.error import URLError
+import itertools
+import lzma
 
-from common_utils import get_tmp_dir
+from common_utils import get_tmp_dir, call_args_to_kwargs_only
 
 
 TEST_FILE = get_file_path_2(
@@ -100,6 +102,114 @@ def test_download_url_dispatch_download_from_google_drive(self, mock):
 
         mock.assert_called_once_with(id, root, filename, md5)
 
+    def test_detect_file_type(self):
+        for file, expected in [
+            ("foo.tar.xz", (".tar.xz", ".tar", ".xz")),
+            ("foo.tar", (".tar", ".tar", None)),
+            ("foo.tar.gz", (".tar.gz", ".tar", ".gz")),
+            ("foo.tgz", (".tgz", ".tar", ".gz")),
+            ("foo.gz", (".gz", None, ".gz")),
+            ("foo.zip", (".zip", ".zip", None)),
+            ("foo.xz", (".xz", None, ".xz")),
+        ]:
+            with self.subTest(file=file):
+                self.assertSequenceEqual(utils._detect_file_type(file), expected)
+
+    def test_detect_file_type_no_ext(self):
+        with self.assertRaises(RuntimeError):
+            utils._detect_file_type("foo")
+
+    def test_detect_file_type_to_many_exts(self):
+        with self.assertRaises(RuntimeError):
+            utils._detect_file_type("foo.bar.tar.gz")
+
+    def test_detect_file_type_unknown_archive_type(self):
+        with self.assertRaises(RuntimeError):
+            utils._detect_file_type("foo.bar.gz")
+
+    def test_detect_file_type_unknown_compression(self):
+        with self.assertRaises(RuntimeError):
+            utils._detect_file_type("foo.tar.baz")
+
+    def test_detect_file_type_unknown_partial_ext(self):
+        with self.assertRaises(RuntimeError):
+            utils._detect_file_type("foo.bar")
+
+    def test_decompress_gzip(self):
+        def create_compressed(root, content="this is the content"):
+            file = os.path.join(root, "file")
+            compressed = f"{file}.gz"
+
+            with gzip.open(compressed, "wb") as fh:
+                fh.write(content.encode())
+
+            return compressed, file, content
+
+        with get_tmp_dir() as temp_dir:
+            compressed, file, content = create_compressed(temp_dir)
+
+            utils._decompress(compressed)
+
+            self.assertTrue(os.path.exists(file))
+
+            with open(file, "r") as fh:
+                self.assertEqual(fh.read(), content)
+
+    def test_decompress_lzma(self):
+        def create_compressed(root, content="this is the content"):
+            file = os.path.join(root, "file")
+            compressed = f"{file}.xz"
+
+            with lzma.open(compressed, "wb") as fh:
+                fh.write(content.encode())
+
+            return compressed, file, content
+
+        with get_tmp_dir() as temp_dir:
+            compressed, file, content = create_compressed(temp_dir)
+
+            utils.extract_archive(compressed, temp_dir)
+
+            self.assertTrue(os.path.exists(file))
+
+            with open(file, "r") as fh:
+                self.assertEqual(fh.read(), content)
+
+    def test_decompress_no_compression(self):
+        with self.assertRaises(RuntimeError):
+            utils._decompress("foo.tar")
+
+    def test_decompress_remove_finished(self):
+        def create_compressed(root, content="this is the content"):
+            file = os.path.join(root, "file")
+            compressed = f"{file}.gz"
+
+            with gzip.open(compressed, "wb") as fh:
+                fh.write(content.encode())
+
+            return compressed, file, content
+
+        with get_tmp_dir() as temp_dir:
+            compressed, file, content = create_compressed(temp_dir)
+
+            utils.extract_archive(compressed, temp_dir, remove_finished=True)
+
+            self.assertFalse(os.path.exists(compressed))
+
+    def test_extract_archive_defer_to_decompress(self):
+        filename = "foo"
+        for ext, remove_finished in itertools.product((".gz", ".xz"), (True, False)):
+            with self.subTest(ext=ext, remove_finished=remove_finished):
+                with unittest.mock.patch("torchvision.datasets.utils._decompress") as mock:
+                    file = f"{filename}{ext}"
+                    utils.extract_archive(file, remove_finished=remove_finished)
+
+                mock.assert_called_once()
+                self.assertEqual(
+                    call_args_to_kwargs_only(mock.call_args, utils._decompress),
+                    dict(from_path=file, to_path=filename, remove_finished=remove_finished),
+                )
+
     def test_extract_zip(self):
         def create_archive(root, content="this is the content"):
             file = os.path.join(root, "dst.txt")
@@ -170,26 +280,6 @@ def create_archive(root, ext, mode, content="this is the content"):
             with open(file, "r") as fh:
                 self.assertEqual(fh.read(), content)
 
-    def test_extract_gzip(self):
-        def create_compressed(root, content="this is the content"):
-            file = os.path.join(root, "file")
-            compressed = f"{file}.gz"
-
-            with gzip.GzipFile(compressed, "wb") as fh:
-                fh.write(content.encode())
-
-            return compressed, file, content
-
-        with get_tmp_dir() as temp_dir:
-            compressed, file, content = create_compressed(temp_dir)
-
-            utils.extract_archive(compressed, temp_dir)
-
-            self.assertTrue(os.path.exists(file))
-
-            with open(file, "r") as fh:
-                self.assertEqual(fh.read(), content)
-
     def test_verify_str_arg(self):
         self.assertEqual("a", utils.verify_str_arg("a", "arg", ("a",)))
         self.assertRaises(ValueError, utils.verify_str_arg, 0, ("a",), "arg")

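The new test_extract_archive_defer_to_decompress combines unittest.mock.patch with the call_args_to_kwargs_only helper to check the dispatch without touching the filesystem. A condensed sketch of that pattern outside the test class, assuming it is run from torchvision's test directory (so that common_utils is importable) against a checkout containing this commit:

    import unittest.mock

    from common_utils import call_args_to_kwargs_only
    from torchvision.datasets import utils

    with unittest.mock.patch("torchvision.datasets.utils._decompress") as mock:
        utils.extract_archive("foo.gz", remove_finished=True)

    # extract_archive forwarded the work to _decompress; normalizing the recorded
    # call against _decompress's signature makes the assertion independent of
    # whether the arguments were passed positionally or by keyword.
    assert call_args_to_kwargs_only(mock.call_args, utils._decompress) == dict(
        from_path="foo.gz", to_path="foo", remove_finished=True
    )
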
torchvision/datasets/utils.py

Lines changed: 125 additions & 33 deletions
@@ -4,12 +4,15 @@
 import gzip
 import re
 import tarfile
-from typing import Any, Callable, List, Iterable, Optional, TypeVar
+from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple
 from urllib.parse import urlparse
 import zipfile
+import lzma
+import contextlib
 import urllib
 import urllib.request
 import urllib.error
+import pathlib
 
 import torch
 from torch.utils.model_zoo import tqdm
@@ -242,56 +245,145 @@ def _save_response_content(
         pbar.close()
 
 
-def _is_tarxz(filename: str) -> bool:
-    return filename.endswith(".tar.xz")
+def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
+    with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
+        tar.extractall(to_path)
 
 
-def _is_tar(filename: str) -> bool:
-    return filename.endswith(".tar")
+_ZIP_COMPRESSION_MAP: Dict[str, int] = {
+    ".xz": zipfile.ZIP_LZMA,
+}
 
 
-def _is_targz(filename: str) -> bool:
-    return filename.endswith(".tar.gz")
+def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
+    with zipfile.ZipFile(
+        from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
+    ) as zip:
+        zip.extractall(to_path)
 
 
-def _is_tgz(filename: str) -> bool:
-    return filename.endswith(".tgz")
+_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = {
+    ".tar": _extract_tar,
+    ".zip": _extract_zip,
+}
+_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = {".gz": gzip.open, ".xz": lzma.open}
+_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {".tgz": (".tar", ".gz")}
 
 
-def _is_gzip(filename: str) -> bool:
-    return filename.endswith(".gz") and not filename.endswith(".tar.gz")
+def _verify_archive_type(archive_type: str) -> None:
+    if archive_type not in _ARCHIVE_EXTRACTORS.keys():
+        valid_types = "', '".join(_ARCHIVE_EXTRACTORS.keys())
+        raise RuntimeError(f"Unknown archive type '{archive_type}'. Known archive types are '{valid_types}'.")
 
 
-def _is_zip(filename: str) -> bool:
-    return filename.endswith(".zip")
+def _verify_compression(compression: str) -> None:
+    if compression not in _COMPRESSED_FILE_OPENERS.keys():
+        valid_types = "', '".join(_COMPRESSED_FILE_OPENERS.keys())
+        raise RuntimeError(f"Unknown compression '{compression}'. Known compressions are '{valid_types}'.")
 
 
-def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> None:
+def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
+    path = pathlib.Path(file)
+    suffix = path.suffix
+    suffixes = pathlib.Path(file).suffixes
+    if not suffixes:
+        raise RuntimeError(
+            f"File '{file}' has no suffixes that could be used to detect the archive type and compression."
+        )
+    elif len(suffixes) > 2:
+        raise RuntimeError(
+            "Archive type and compression detection only works for 1 or 2 suffixes. " f"Got {len(suffixes)} instead."
+        )
+    elif len(suffixes) == 2:
+        # if we have exactly two suffixes we assume the first one is the archive type and the second on is the
+        # compression
+        archive_type, compression = suffixes
+        _verify_archive_type(archive_type)
+        _verify_compression(compression)
+        return "".join(suffixes), archive_type, compression
+
+    # check if the suffix is a known alias
+    with contextlib.suppress(KeyError):
+        return (suffix, *_FILE_TYPE_ALIASES[suffix])
+
+    # check if the suffix is an archive type
+    with contextlib.suppress(RuntimeError):
+        _verify_archive_type(suffix)
+        return suffix, suffix, None
+
+    # check if the suffix is a compression
+    with contextlib.suppress(RuntimeError):
+        _verify_compression(suffix)
+        return suffix, None, suffix
+
+    raise RuntimeError(f"Suffix '{suffix}' is neither recognized as archive type nor as compression.")
+
+
+def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
+    r"""Decompress a file.
+
+    The compression is automatically detected from the file name.
+
+    Args:
+        from_path (str): Path to the file to be decompressed.
+        to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used.
+        remove_finished (bool): If ``True``, remove the file after the extraction.
+
+    Returns:
+        (str): Path to the decompressed file.
+    """
+    suffix, archive_type, compression = _detect_file_type(from_path)
+    if not compression:
+        raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")
+
     if to_path is None:
-        to_path = os.path.dirname(from_path)
+        to_path = from_path.replace(suffix, archive_type if archive_type is not None else "")
 
-    if _is_tar(from_path):
-        with tarfile.open(from_path, 'r') as tar:
-            tar.extractall(path=to_path)
-    elif _is_targz(from_path) or _is_tgz(from_path):
-        with tarfile.open(from_path, 'r:gz') as tar:
-            tar.extractall(path=to_path)
-    elif _is_tarxz(from_path):
-        with tarfile.open(from_path, 'r:xz') as tar:
-            tar.extractall(path=to_path)
-    elif _is_gzip(from_path):
-        to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
-        with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
-            out_f.write(zip_f.read())
-    elif _is_zip(from_path):
-        with zipfile.ZipFile(from_path, 'r') as z:
-            z.extractall(to_path)
-    else:
-        raise ValueError("Extraction of {} not supported".format(from_path))
+    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
+    compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]
+
+    with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh:
+        wfh.write(rfh.read())
 
     if remove_finished:
         os.remove(from_path)
 
+    return to_path
+
+
+def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
+    """Extract an archive.
+
+    The archive type and a possible compression is automatically detected from the file name. If the file is compressed
+    but not an archive the call is dispatched to :func:`decompress`.
+
+    Args:
+        from_path (str): Path to the file to be extracted.
+        to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is
+            used.
+        remove_finished (bool): If ``True``, remove the file after the extraction.
+
+    Returns:
+        (str): Path to the directory the file was extracted to.
+    """
+    if to_path is None:
+        to_path = os.path.dirname(from_path)
+
+    suffix, archive_type, compression = _detect_file_type(from_path)
+    if not archive_type:
+        return _decompress(
+            from_path,
+            os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")),
+            remove_finished=remove_finished,
+        )
+
+    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
+    extractor = _ARCHIVE_EXTRACTORS[archive_type]
+
+    extractor(from_path, to_path, compression)
+
+    return to_path
+
 
 def download_and_extract_archive(
     url: str,

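Taken together, the refactor replaces the chain of _is_tar/_is_gzip/... predicates with a single suffix-based detection step plus two lookup tables (_ARCHIVE_EXTRACTORS and _COMPRESSED_FILE_OPENERS). The triples below, taken from the values asserted in the new tests, show what _detect_file_type reports; the function is private and only exists in checkouts containing this commit, so this is for illustration only:

    from torchvision.datasets.utils import _detect_file_type

    # (full suffix, archive type, compression)
    print(_detect_file_type("foo.tar.gz"))  # ('.tar.gz', '.tar', '.gz')  -> tar extractor, gz mode
    print(_detect_file_type("foo.tgz"))     # ('.tgz', '.tar', '.gz')     via _FILE_TYPE_ALIASES
    print(_detect_file_type("foo.zip"))     # ('.zip', '.zip', None)      -> zip extractor
    print(_detect_file_type("foo.xz"))      # ('.xz', None, '.xz')        -> handled by _decompress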