From 0e8ed00154d3a1fab0d3955f2225d37e0a3bed90 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 10 Mar 2021 16:23:00 +0100 Subject: [PATCH 1/2] fix test_extract_(zip|tar|tar_xz|gzip) on windows --- test/test_datasets_utils.py | 127 ++++++++++++++++++++++-------------- 1 file changed, 79 insertions(+), 48 deletions(-) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index f0edbaba08f..244ae521d72 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -1,6 +1,4 @@ import os -import sys -import tempfile import torchvision.datasets.utils as utils import unittest import unittest.mock @@ -102,62 +100,95 @@ def test_download_url_dispatch_download_from_google_drive(self, mock): mock.assert_called_once_with(id, root, filename, md5) - @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows') def test_extract_zip(self): + def create_archive(root, content="this is the content"): + file = os.path.join(root, "dst.txt") + archive = os.path.join(root, f"archive.zip") + + with zipfile.ZipFile(archive, "w") as zf: + zf.writestr(os.path.basename(file), content) + + return archive, file, content + with get_tmp_dir() as temp_dir: - with tempfile.NamedTemporaryFile(suffix='.zip') as f: - with zipfile.ZipFile(f, 'w') as zf: - zf.writestr('file.tst', 'this is the content') - utils.extract_archive(f.name, temp_dir) - self.assertTrue(os.path.exists(os.path.join(temp_dir, 'file.tst'))) - with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf: - data = nf.read() - self.assertEqual(data, 'this is the content') - - @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows') + archive, file, content = create_archive(temp_dir) + + utils.extract_archive(archive, temp_dir) + + self.assertTrue(os.path.exists(file)) + + with open(file, "r") as fh: + self.assertEqual(fh.read(), content) + def test_extract_tar(self): + def create_archive(root, ext, mode, content="this is the content"): + src = os.path.join(root, "src.txt") + dst = os.path.join(root, "dst.txt") + archive = os.path.join(root, f"archive{ext}") + + with open(src, "w") as fh: + fh.write(content) + + with tarfile.open(archive, mode=mode) as fh: + fh.add(src, arcname=os.path.basename(dst)) + + return archive, dst, content + for ext, mode in zip(['.tar', '.tar.gz', '.tgz'], ['w', 'w:gz', 'w:gz']): with get_tmp_dir() as temp_dir: - with tempfile.NamedTemporaryFile() as bf: - bf.write("this is the content".encode()) - bf.seek(0) - with tempfile.NamedTemporaryFile(suffix=ext) as f: - with tarfile.open(f.name, mode=mode) as zf: - zf.add(bf.name, arcname='file.tst') - utils.extract_archive(f.name, temp_dir) - self.assertTrue(os.path.exists(os.path.join(temp_dir, 'file.tst'))) - with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf: - data = nf.read() - self.assertEqual(data, 'this is the content') - - @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows') + archive, file, content = create_archive(temp_dir, ext, mode) + + utils.extract_archive(archive, temp_dir) + + self.assertTrue(os.path.exists(file)) + + with open(file, "r") as fh: + self.assertEqual(fh.read(), content) + def test_extract_tar_xz(self): + def create_archive(root, ext, mode, content="this is the content"): + src = os.path.join(root, "src.txt") + dst = os.path.join(root, "dst.txt") + archive = os.path.join(root, f"archive{ext}") + + with open(src, "w") as fh: + fh.write(content) + + with tarfile.open(archive, mode=mode) as fh: + fh.add(src, arcname=os.path.basename(dst)) + + return archive, dst, content + for ext, mode in zip(['.tar.xz'], ['w:xz']): with get_tmp_dir() as temp_dir: - with tempfile.NamedTemporaryFile() as bf: - bf.write("this is the content".encode()) - bf.seek(0) - with tempfile.NamedTemporaryFile(suffix=ext) as f: - with tarfile.open(f.name, mode=mode) as zf: - zf.add(bf.name, arcname='file.tst') - utils.extract_archive(f.name, temp_dir) - self.assertTrue(os.path.exists(os.path.join(temp_dir, 'file.tst'))) - with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf: - data = nf.read() - self.assertEqual(data, 'this is the content') - - @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows') + archive, file, content = create_archive(temp_dir, ext, mode) + + utils.extract_archive(archive, temp_dir) + + self.assertTrue(os.path.exists(file)) + + with open(file, "r") as fh: + self.assertEqual(fh.read(), content) + def test_extract_gzip(self): + def create_compressed(root, content="this is the content"): + file = os.path.join(root, "file") + compressed = f"{file}.gz" + + with gzip.GzipFile(compressed, "wb") as fh: + fh.write(content.encode()) + + return compressed, file, content + with get_tmp_dir() as temp_dir: - with tempfile.NamedTemporaryFile(suffix='.gz') as f: - with gzip.GzipFile(f.name, 'wb') as zf: - zf.write('this is the content'.encode()) - utils.extract_archive(f.name, temp_dir) - f_name = os.path.join(temp_dir, os.path.splitext(os.path.basename(f.name))[0]) - self.assertTrue(os.path.exists(f_name)) - with open(os.path.join(f_name), 'r') as nf: - data = nf.read() - self.assertEqual(data, 'this is the content') + compressed, file, content = create_compressed(temp_dir) + + utils.extract_archive(compressed, temp_dir) + + self.assertTrue(os.path.exists(file)) + + with open(file, "r") as fh: + self.assertEqual(fh.read(), content) def test_verify_str_arg(self): self.assertEqual("a", utils.verify_str_arg("a", "arg", ("a",))) From f5590e19af739ba6bb8c53c47c4e9a23dbe281d3 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 10 Mar 2021 16:37:18 +0100 Subject: [PATCH 2/2] lint --- test/test_datasets_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index 244ae521d72..b1a8e1eda0f 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -103,7 +103,7 @@ def test_download_url_dispatch_download_from_google_drive(self, mock): def test_extract_zip(self): def create_archive(root, content="this is the content"): file = os.path.join(root, "dst.txt") - archive = os.path.join(root, f"archive.zip") + archive = os.path.join(root, "archive.zip") with zipfile.ZipFile(archive, "w") as zf: zf.writestr(os.path.basename(file), content)