Skip to content

fix test_extract_(zip|tar|tar_xz|gzip) on windows #3542

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 11, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 79 additions & 48 deletions test/test_datasets_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import os
import sys
import tempfile
import torchvision.datasets.utils as utils
import unittest
import unittest.mock
Expand Down Expand Up @@ -102,62 +100,95 @@ def test_download_url_dispatch_download_from_google_drive(self, mock):

mock.assert_called_once_with(id, root, filename, md5)

@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
def test_extract_zip(self):
def create_archive(root, content="this is the content"):
file = os.path.join(root, "dst.txt")
archive = os.path.join(root, "archive.zip")

with zipfile.ZipFile(archive, "w") as zf:
zf.writestr(os.path.basename(file), content)

return archive, file, content

with get_tmp_dir() as temp_dir:
with tempfile.NamedTemporaryFile(suffix='.zip') as f:
with zipfile.ZipFile(f, 'w') as zf:
zf.writestr('file.tst', 'this is the content')
utils.extract_archive(f.name, temp_dir)
self.assertTrue(os.path.exists(os.path.join(temp_dir, 'file.tst')))
with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf:
data = nf.read()
self.assertEqual(data, 'this is the content')

@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
archive, file, content = create_archive(temp_dir)

utils.extract_archive(archive, temp_dir)

self.assertTrue(os.path.exists(file))

with open(file, "r") as fh:
self.assertEqual(fh.read(), content)

def test_extract_tar(self):
def create_archive(root, ext, mode, content="this is the content"):
src = os.path.join(root, "src.txt")
dst = os.path.join(root, "dst.txt")
archive = os.path.join(root, f"archive{ext}")

with open(src, "w") as fh:
fh.write(content)

with tarfile.open(archive, mode=mode) as fh:
fh.add(src, arcname=os.path.basename(dst))

return archive, dst, content

for ext, mode in zip(['.tar', '.tar.gz', '.tgz'], ['w', 'w:gz', 'w:gz']):
with get_tmp_dir() as temp_dir:
with tempfile.NamedTemporaryFile() as bf:
bf.write("this is the content".encode())
bf.seek(0)
with tempfile.NamedTemporaryFile(suffix=ext) as f:
with tarfile.open(f.name, mode=mode) as zf:
zf.add(bf.name, arcname='file.tst')
utils.extract_archive(f.name, temp_dir)
self.assertTrue(os.path.exists(os.path.join(temp_dir, 'file.tst')))
with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf:
data = nf.read()
self.assertEqual(data, 'this is the content')

@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
archive, file, content = create_archive(temp_dir, ext, mode)

utils.extract_archive(archive, temp_dir)

self.assertTrue(os.path.exists(file))

with open(file, "r") as fh:
self.assertEqual(fh.read(), content)

def test_extract_tar_xz(self):
def create_archive(root, ext, mode, content="this is the content"):
src = os.path.join(root, "src.txt")
dst = os.path.join(root, "dst.txt")
archive = os.path.join(root, f"archive{ext}")

with open(src, "w") as fh:
fh.write(content)

with tarfile.open(archive, mode=mode) as fh:
fh.add(src, arcname=os.path.basename(dst))

return archive, dst, content

for ext, mode in zip(['.tar.xz'], ['w:xz']):
with get_tmp_dir() as temp_dir:
with tempfile.NamedTemporaryFile() as bf:
bf.write("this is the content".encode())
bf.seek(0)
with tempfile.NamedTemporaryFile(suffix=ext) as f:
with tarfile.open(f.name, mode=mode) as zf:
zf.add(bf.name, arcname='file.tst')
utils.extract_archive(f.name, temp_dir)
self.assertTrue(os.path.exists(os.path.join(temp_dir, 'file.tst')))
with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf:
data = nf.read()
self.assertEqual(data, 'this is the content')

@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
archive, file, content = create_archive(temp_dir, ext, mode)

utils.extract_archive(archive, temp_dir)

self.assertTrue(os.path.exists(file))

with open(file, "r") as fh:
self.assertEqual(fh.read(), content)

def test_extract_gzip(self):
def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file")
compressed = f"{file}.gz"

with gzip.GzipFile(compressed, "wb") as fh:
fh.write(content.encode())

return compressed, file, content

with get_tmp_dir() as temp_dir:
with tempfile.NamedTemporaryFile(suffix='.gz') as f:
with gzip.GzipFile(f.name, 'wb') as zf:
zf.write('this is the content'.encode())
utils.extract_archive(f.name, temp_dir)
f_name = os.path.join(temp_dir, os.path.splitext(os.path.basename(f.name))[0])
self.assertTrue(os.path.exists(f_name))
with open(os.path.join(f_name), 'r') as nf:
data = nf.read()
self.assertEqual(data, 'this is the content')
compressed, file, content = create_compressed(temp_dir)

utils.extract_archive(compressed, temp_dir)

self.assertTrue(os.path.exists(file))

with open(file, "r") as fh:
self.assertEqual(fh.read(), content)

def test_verify_str_arg(self):
self.assertEqual("a", utils.verify_str_arg("a", "arg", ("a",)))
Expand Down