diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index f0edbaba08f..b1a8e1eda0f 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -1,6 +1,4 @@ import os -import sys -import tempfile import torchvision.datasets.utils as utils import unittest import unittest.mock @@ -102,62 +100,95 @@ def test_download_url_dispatch_download_from_google_drive(self, mock): mock.assert_called_once_with(id, root, filename, md5) - @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows') def test_extract_zip(self): + def create_archive(root, content="this is the content"): + file = os.path.join(root, "dst.txt") + archive = os.path.join(root, "archive.zip") + + with zipfile.ZipFile(archive, "w") as zf: + zf.writestr(os.path.basename(file), content) + + return archive, file, content + with get_tmp_dir() as temp_dir: - with tempfile.NamedTemporaryFile(suffix='.zip') as f: - with zipfile.ZipFile(f, 'w') as zf: - zf.writestr('file.tst', 'this is the content') - utils.extract_archive(f.name, temp_dir) - self.assertTrue(os.path.exists(os.path.join(temp_dir, 'file.tst'))) - with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf: - data = nf.read() - self.assertEqual(data, 'this is the content') - - @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows') + archive, file, content = create_archive(temp_dir) + + utils.extract_archive(archive, temp_dir) + + self.assertTrue(os.path.exists(file)) + + with open(file, "r") as fh: + self.assertEqual(fh.read(), content) + def test_extract_tar(self): + def create_archive(root, ext, mode, content="this is the content"): + src = os.path.join(root, "src.txt") + dst = os.path.join(root, "dst.txt") + archive = os.path.join(root, f"archive{ext}") + + with open(src, "w") as fh: + fh.write(content) + + with tarfile.open(archive, mode=mode) as fh: + fh.add(src, arcname=os.path.basename(dst)) + + return archive, dst, content + for ext, mode in zip(['.tar', '.tar.gz', '.tgz'], ['w', 'w:gz', 'w:gz']): with get_tmp_dir() as temp_dir: - with tempfile.NamedTemporaryFile() as bf: - bf.write("this is the content".encode()) - bf.seek(0) - with tempfile.NamedTemporaryFile(suffix=ext) as f: - with tarfile.open(f.name, mode=mode) as zf: - zf.add(bf.name, arcname='file.tst') - utils.extract_archive(f.name, temp_dir) - self.assertTrue(os.path.exists(os.path.join(temp_dir, 'file.tst'))) - with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf: - data = nf.read() - self.assertEqual(data, 'this is the content') - - @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows') + archive, file, content = create_archive(temp_dir, ext, mode) + + utils.extract_archive(archive, temp_dir) + + self.assertTrue(os.path.exists(file)) + + with open(file, "r") as fh: + self.assertEqual(fh.read(), content) + def test_extract_tar_xz(self): + def create_archive(root, ext, mode, content="this is the content"): + src = os.path.join(root, "src.txt") + dst = os.path.join(root, "dst.txt") + archive = os.path.join(root, f"archive{ext}") + + with open(src, "w") as fh: + fh.write(content) + + with tarfile.open(archive, mode=mode) as fh: + fh.add(src, arcname=os.path.basename(dst)) + + return archive, dst, content + for ext, mode in zip(['.tar.xz'], ['w:xz']): with get_tmp_dir() as temp_dir: - with tempfile.NamedTemporaryFile() as bf: - bf.write("this is the content".encode()) - bf.seek(0) - with tempfile.NamedTemporaryFile(suffix=ext) as f: - with tarfile.open(f.name, mode=mode) as zf: - zf.add(bf.name, arcname='file.tst') - utils.extract_archive(f.name, temp_dir) - self.assertTrue(os.path.exists(os.path.join(temp_dir, 'file.tst'))) - with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf: - data = nf.read() - self.assertEqual(data, 'this is the content') - - @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows') + archive, file, content = create_archive(temp_dir, ext, mode) + + utils.extract_archive(archive, temp_dir) + + self.assertTrue(os.path.exists(file)) + + with open(file, "r") as fh: + self.assertEqual(fh.read(), content) + def test_extract_gzip(self): + def create_compressed(root, content="this is the content"): + file = os.path.join(root, "file") + compressed = f"{file}.gz" + + with gzip.GzipFile(compressed, "wb") as fh: + fh.write(content.encode()) + + return compressed, file, content + with get_tmp_dir() as temp_dir: - with tempfile.NamedTemporaryFile(suffix='.gz') as f: - with gzip.GzipFile(f.name, 'wb') as zf: - zf.write('this is the content'.encode()) - utils.extract_archive(f.name, temp_dir) - f_name = os.path.join(temp_dir, os.path.splitext(os.path.basename(f.name))[0]) - self.assertTrue(os.path.exists(f_name)) - with open(os.path.join(f_name), 'r') as nf: - data = nf.read() - self.assertEqual(data, 'this is the content') + compressed, file, content = create_compressed(temp_dir) + + utils.extract_archive(compressed, temp_dir) + + self.assertTrue(os.path.exists(file)) + + with open(file, "r") as fh: + self.assertEqual(fh.read(), content) def test_verify_str_arg(self): self.assertEqual("a", utils.verify_str_arg("a", "arg", ("a",)))