From 6a73c83485cd5e98b90f903e9932350fb4ab961f Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 19 Mar 2019 12:28:47 +0100 Subject: [PATCH 1/7] WIP --- test/test_datasets.py | 30 ++++++++++++++++++++++++++++++ torchvision/datasets/mnist.py | 17 +++++++++-------- torchvision/datasets/utils.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 8 deletions(-) create mode 100644 test/test_datasets.py diff --git a/test/test_datasets.py b/test/test_datasets.py new file mode 100644 index 00000000000..eeac7cc00cf --- /dev/null +++ b/test/test_datasets.py @@ -0,0 +1,30 @@ +import torchvision + +import unittest +from unittest import mock +from unittest.mock import patch + + + +def download_url(url, root, filename=None, md5=None): + print("Downloaded {} to {} with filename {} and md5 {}".format(url, root, filename, md5)) + + + +class Tester(unittest.TestCase): + #@mock.patch('torchvision.datasets.utils.download_url') + @mock.patch('torchvision.datasets.mnist.download_url') + @mock.patch('torchvision.datasets.mnist.MNIST._check_exists') + @mock.patch('torchvision.datasets.mnist.read_image_file') + @mock.patch('torchvision.datasets.mnist.read_label_file') + def test_mnist(self, check_fn, download_fn): + dataset = torchvision.datasets.MNIST('.') + + + + + +if __name__ == '__main__': + #unittest.main() + with patch('torchvision.datasets.utils.download_url', side_effect=download_url) as f: + torchvision.datasets.utils.download_url('http://google.com', '.') diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index e1a277d2de7..197fe7234f8 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -9,6 +9,7 @@ import torch import codecs from .utils import download_url, makedir_exist_ok +from .utils import download_and_extract, extract_file class MNIST(VisionDataset): @@ -141,9 +142,7 @@ def download(self): # download files for url in self.urls: filename = url.rpartition('/')[2] - file_path = os.path.join(self.raw_folder, filename) - download_url(url, root=self.raw_folder, filename=filename, md5=None) - self.extract_gzip(gzip_path=file_path, remove_finished=True) + download_and_extract(url, root=self.raw_folder, filename=filename, md5=None) # process and save as torch files print('Processing...') @@ -273,16 +272,18 @@ def download(self): # download files filename = self.url.rpartition('/')[2] file_path = os.path.join(self.raw_folder, filename) - download_url(self.url, root=self.raw_folder, filename=filename, md5=None) + #download_url(self.url, root=self.raw_folder, filename=filename, md5=None) print('Extracting zip archive') - with zipfile.ZipFile(file_path) as zip_f: - zip_f.extractall(self.raw_folder) - os.unlink(file_path) + #with zipfile.ZipFile(file_path) as zip_f: + # zip_f.extractall(self.raw_folder) + #os.unlink(file_path) + download_and_extract(self.url, root=self.raw_folder, filename="kmnist.zip", md5=None) gzip_folder = os.path.join(self.raw_folder, 'gzip') for gzip_file in os.listdir(gzip_folder): if gzip_file.endswith('.gz'): - self.extract_gzip(gzip_path=os.path.join(gzip_folder, gzip_file)) + extract_file(os.path.join(gzip_folder, gzip_file), gzip_folder) + # self.extract_gzip(gzip_path=os.path.join(gzip_folder, gzip_file)) # process and save as torch files for split in self.splits: diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index f0011603f57..c58ade6ffbf 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -189,3 +189,32 @@ def _save_response_content(response, destination, chunk_size=32768): progress += len(chunk) pbar.update(progress - pbar.n) pbar.close() + + +import tarfile +import zipfile +import gzip +def extract_file(from_path, to_path, remove_finished=False): + # TODO make it more robust wrt tar.gz + if from_path.endswith(".tar"): + with tarfile.open(from_path, 'r:') as tar: + tar.extractall(path=to_path) + elif from_path.endswith(".tar.gz"): + with tarfile.open(from_path, 'r:gz') as tar: + tar.extractall(path=to_path) + elif from_path.endswith(".gz"): + to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0]) + with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f: + out_f.write(zip_f.read()) + elif from_path.endswith("zip"): + with zipfile.ZipFile(from_path, 'r') as z: + z.extractall(to_path) + else: + raise ValueError("Not supported") + + if remove_finished: + os.unlink(from_path) + +def download_and_extract(url, root, filename, md5=None): + download_url(url, root, filename, md5) + extract_file(os.path.join(root, filename), root) From bce0176f107ccb73f7c7b752706baf30b6a617f8 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 28 May 2019 16:34:39 +0200 Subject: [PATCH 2/7] WIP: minor improvements --- test/test_datasets.py | 30 ++++++++++++++++++++---------- torchvision/datasets/mnist.py | 21 +++------------------ torchvision/datasets/utils.py | 25 +++++++++++++++++++------ 3 files changed, 42 insertions(+), 34 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index eeac7cc00cf..dfa84b129d1 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -1,4 +1,6 @@ +import PIL import torchvision +import tempfile import unittest from unittest import mock @@ -13,18 +15,26 @@ def download_url(url, root, filename=None, md5=None): class Tester(unittest.TestCase): #@mock.patch('torchvision.datasets.utils.download_url') - @mock.patch('torchvision.datasets.mnist.download_url') - @mock.patch('torchvision.datasets.mnist.MNIST._check_exists') - @mock.patch('torchvision.datasets.mnist.read_image_file') - @mock.patch('torchvision.datasets.mnist.read_label_file') - def test_mnist(self, check_fn, download_fn): - dataset = torchvision.datasets.MNIST('.') + #@mock.patch('torchvision.datasets.mnist.download_url') + #@mock.patch('torchvision.datasets.mnist.MNIST._check_exists') + #@mock.patch('torchvision.datasets.mnist.read_image_file') + #@mock.patch('torchvision.datasets.mnist.read_label_file') + #def test_mnist(self, check_fn, download_fn): + # dataset = torchvision.datasets.MNIST('.') + + + def test_mnist(self): + tmp_dir = tempfile.gettempdir() + dataset = torchvision.datasets.MNIST(tmp_dir, download=True) + self.assertEqual(len(dataset), 60000) + img, target = dataset[0] + self.assertTrue(isinstance(img, PIL.Image.Image)) + self.assertTrue(isinstance(target, int)) - if __name__ == '__main__': - #unittest.main() - with patch('torchvision.datasets.utils.download_url', side_effect=download_url) as f: - torchvision.datasets.utils.download_url('http://google.com', '.') + unittest.main() + # with patch('torchvision.datasets.utils.download_url', side_effect=download_url) as f: + # torchvision.datasets.utils.download_url('http://google.com', '.') diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index 197fe7234f8..7afc07be4c9 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -8,8 +8,7 @@ import numpy as np import torch import codecs -from .utils import download_url, makedir_exist_ok -from .utils import download_and_extract, extract_file +from .utils import download_and_extract, extract_file, makedir_exist_ok class MNIST(VisionDataset): @@ -121,15 +120,6 @@ def _check_exists(self): os.path.exists(os.path.join(self.processed_folder, self.test_file))) - @staticmethod - def extract_gzip(gzip_path, remove_finished=False): - print('Extracting {}'.format(gzip_path)) - with open(gzip_path.replace('.gz', ''), 'wb') as out_f, \ - gzip.GzipFile(gzip_path) as zip_f: - out_f.write(zip_f.read()) - if remove_finished: - os.unlink(gzip_path) - def download(self): """Download the MNIST data if it doesn't exist in processed_folder already.""" @@ -142,7 +132,7 @@ def download(self): # download files for url in self.urls: filename = url.rpartition('/')[2] - download_and_extract(url, root=self.raw_folder, filename=filename, md5=None) + download_and_extract(url, root=self.raw_folder, filename=filename) # process and save as torch files print('Processing...') @@ -272,18 +262,13 @@ def download(self): # download files filename = self.url.rpartition('/')[2] file_path = os.path.join(self.raw_folder, filename) - #download_url(self.url, root=self.raw_folder, filename=filename, md5=None) print('Extracting zip archive') - #with zipfile.ZipFile(file_path) as zip_f: - # zip_f.extractall(self.raw_folder) - #os.unlink(file_path) - download_and_extract(self.url, root=self.raw_folder, filename="kmnist.zip", md5=None) + download_and_extract(self.url, root=self.raw_folder, filename="kmnist.zip", remove_finished=True) gzip_folder = os.path.join(self.raw_folder, 'gzip') for gzip_file in os.listdir(gzip_folder): if gzip_file.endswith('.gz'): extract_file(os.path.join(gzip_folder, gzip_file), gzip_folder) - # self.extract_gzip(gzip_path=os.path.join(gzip_folder, gzip_file)) # process and save as torch files for split in self.splits: diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index c58ade6ffbf..4f759621ed2 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -194,19 +194,32 @@ def _save_response_content(response, destination, chunk_size=32768): import tarfile import zipfile import gzip + +def _is_tar(filename): + return filename.endswith(".tar") + +def _is_targz(filename): + return filename.endswith(".tar.gz") + +def _is_gzip(filename): + return filename.endswith(".gz") and not filename.endswith(".tar.gz") + +def _is_zip(filename): + return filename.endswith(".zip") + def extract_file(from_path, to_path, remove_finished=False): # TODO make it more robust wrt tar.gz - if from_path.endswith(".tar"): + if _is_tar(from_path): with tarfile.open(from_path, 'r:') as tar: tar.extractall(path=to_path) - elif from_path.endswith(".tar.gz"): + elif _is_targz(from_path): with tarfile.open(from_path, 'r:gz') as tar: tar.extractall(path=to_path) - elif from_path.endswith(".gz"): + elif _is_gzip(from_path): to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0]) with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f: out_f.write(zip_f.read()) - elif from_path.endswith("zip"): + elif _is_zip(from_path): with zipfile.ZipFile(from_path, 'r') as z: z.extractall(to_path) else: @@ -215,6 +228,6 @@ def extract_file(from_path, to_path, remove_finished=False): if remove_finished: os.unlink(from_path) -def download_and_extract(url, root, filename, md5=None): +def download_and_extract(url, root, filename, md5=None, remove_finished=False): download_url(url, root, filename, md5) - extract_file(os.path.join(root, filename), root) + extract_file(os.path.join(root, filename), root, remove_finished) From c3c708b337b4930426d05eeb4c9ea423135f9db8 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 28 May 2019 19:04:27 +0200 Subject: [PATCH 3/7] Add tests --- test/test_datasets.py | 46 ++++++++++++++++++++--------------- test/test_datasets_utils.py | 44 +++++++++++++++++++++++++++++++++ torchvision/datasets/mnist.py | 9 ++----- torchvision/datasets/utils.py | 16 +++++++----- 4 files changed, 82 insertions(+), 33 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index dfa84b129d1..8e41e2e25cb 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -1,40 +1,46 @@ import PIL -import torchvision +import shutil import tempfile - import unittest -from unittest import mock -from unittest.mock import patch - - - -def download_url(url, root, filename=None, md5=None): - print("Downloaded {} to {} with filename {} and md5 {}".format(url, root, filename, md5)) +import torchvision class Tester(unittest.TestCase): - #@mock.patch('torchvision.datasets.utils.download_url') - #@mock.patch('torchvision.datasets.mnist.download_url') - #@mock.patch('torchvision.datasets.mnist.MNIST._check_exists') - #@mock.patch('torchvision.datasets.mnist.read_image_file') - #@mock.patch('torchvision.datasets.mnist.read_label_file') - #def test_mnist(self, check_fn, download_fn): - # dataset = torchvision.datasets.MNIST('.') - def test_mnist(self): - tmp_dir = tempfile.gettempdir() + tmp_dir = tempfile.mkdtemp() dataset = torchvision.datasets.MNIST(tmp_dir, download=True) self.assertEqual(len(dataset), 60000) img, target = dataset[0] self.assertTrue(isinstance(img, PIL.Image.Image)) self.assertTrue(isinstance(target, int)) + shutil.rmtree(temp_dir) + def test_emnist(self): + tmp_dir = tempfile.mkdtemp() + dataset = torchvision.datasets.EMNIST(tmp_dir, split='byclass', download=True) + img, target = dataset[0] + self.assertTrue(isinstance(img, PIL.Image.Image)) + self.assertTrue(isinstance(target, int)) + shutil.rmtree(temp_dir) + + def test_kmnist(self): + tmp_dir = tempfile.mkdtemp() + dataset = torchvision.datasets.KMNIST(tmp_dir, download=True) + img, target = dataset[0] + self.assertTrue(isinstance(img, PIL.Image.Image)) + self.assertTrue(isinstance(target, int)) + shutil.rmtree(temp_dir) + def test_fashionmnist(self): + tmp_dir = tempfile.mkdtemp() + dataset = torchvision.datasets.FashionMNIST(tmp_dir, download=True) + img, target = dataset[0] + self.assertTrue(isinstance(img, PIL.Image.Image)) + self.assertTrue(isinstance(target, int)) + shutil.rmtree(temp_dir) if __name__ == '__main__': unittest.main() - # with patch('torchvision.datasets.utils.download_url', side_effect=download_url) as f: - # torchvision.datasets.utils.download_url('http://google.com', '.') diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index 1f48b88ac74..79a1e3992c2 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -3,6 +3,9 @@ import tempfile import torchvision.datasets.utils as utils import unittest +import zipfile +import tarfile +import gzip TEST_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg') @@ -41,6 +44,47 @@ def test_download_url_retry_http(self): assert not len(os.listdir(temp_dir)) == 0, 'The downloaded root directory is empty after download.' shutil.rmtree(temp_dir) + def test_extract_zip(self): + temp_dir = tempfile.mkdtemp() + with tempfile.NamedTemporaryFile(suffix='.zip') as f: + with zipfile.ZipFile(f, 'w') as zf: + zf.writestr('file.tst', 'this is the content') + utils.extract_file(f.name, temp_dir) + assert os.path.exists(os.path.join(temp_dir, 'file.tst')) + with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf: + data = nf.read() + assert data == 'this is the content' + shutil.rmtree(temp_dir) + + def test_extract_tar(self): + for ext, mode in zip(['.tar', '.tar.gz'], ['w', 'w:gz']): + temp_dir = tempfile.mkdtemp() + with tempfile.NamedTemporaryFile() as bf: + bf.write("this is the content".encode()) + bf.seek(0) + with tempfile.NamedTemporaryFile(suffix=ext) as f: + with tarfile.open(f.name, mode=mode) as zf: + zf.add(bf.name, arcname='file.tst') + utils.extract_file(f.name, temp_dir) + assert os.path.exists(os.path.join(temp_dir, 'file.tst')) + with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf: + data = nf.read() + assert data == 'this is the content', data + shutil.rmtree(temp_dir) + + def test_extract_gzip(self): + temp_dir = tempfile.mkdtemp() + with tempfile.NamedTemporaryFile(suffix='.gz') as f: + with gzip.GzipFile(f.name, 'wb') as zf: + zf.write('this is the content'.encode()) + utils.extract_file(f.name, temp_dir) + f_name = os.path.join(temp_dir, os.path.splitext(os.path.basename(f.name))[0]) + assert os.path.exists(f_name) + with open(os.path.join(f_name), 'r') as nf: + data = nf.read() + assert data == 'this is the content', data + shutil.rmtree(temp_dir) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index 7afc07be4c9..51a3cb19df4 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -4,7 +4,6 @@ from PIL import Image import os import os.path -import gzip import numpy as np import torch import codecs @@ -251,7 +250,6 @@ def _test_file(split): def download(self): """Download the EMNIST data if it doesn't exist in processed_folder already.""" import shutil - import zipfile if self._check_exists(): return @@ -260,11 +258,8 @@ def download(self): makedir_exist_ok(self.processed_folder) # download files - filename = self.url.rpartition('/')[2] - file_path = os.path.join(self.raw_folder, filename) - - print('Extracting zip archive') - download_and_extract(self.url, root=self.raw_folder, filename="kmnist.zip", remove_finished=True) + print('Downloading and extracting zip archive') + download_and_extract(self.url, root=self.raw_folder, filename="emnist.zip", remove_finished=True) gzip_folder = os.path.join(self.raw_folder, 'gzip') for gzip_file in os.listdir(gzip_folder): if gzip_file.endswith('.gz'): diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index 4f759621ed2..5119a5e4193 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -1,7 +1,11 @@ import os import os.path import hashlib +import gzip import errno +import tarfile +import zipfile + from torch.utils.model_zoo import tqdm @@ -191,24 +195,23 @@ def _save_response_content(response, destination, chunk_size=32768): pbar.close() -import tarfile -import zipfile -import gzip - def _is_tar(filename): return filename.endswith(".tar") + def _is_targz(filename): return filename.endswith(".tar.gz") + def _is_gzip(filename): return filename.endswith(".gz") and not filename.endswith(".tar.gz") + def _is_zip(filename): return filename.endswith(".zip") + def extract_file(from_path, to_path, remove_finished=False): - # TODO make it more robust wrt tar.gz if _is_tar(from_path): with tarfile.open(from_path, 'r:') as tar: tar.extractall(path=to_path) @@ -223,11 +226,12 @@ def extract_file(from_path, to_path, remove_finished=False): with zipfile.ZipFile(from_path, 'r') as z: z.extractall(to_path) else: - raise ValueError("Not supported") + raise ValueError("Extraction of {} not supported".format(from_path)) if remove_finished: os.unlink(from_path) + def download_and_extract(url, root, filename, md5=None, remove_finished=False): download_url(url, root, filename, md5) extract_file(os.path.join(root, filename), root, remove_finished) From d8c0d9fe39b803595bf2652ff138b77dcb40d924 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 28 May 2019 21:21:58 +0200 Subject: [PATCH 4/7] Fix typo --- test/test_datasets.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 8e41e2e25cb..ece50fb8e0c 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -15,7 +15,7 @@ def test_mnist(self): img, target = dataset[0] self.assertTrue(isinstance(img, PIL.Image.Image)) self.assertTrue(isinstance(target, int)) - shutil.rmtree(temp_dir) + shutil.rmtree(tmp_dir) def test_emnist(self): tmp_dir = tempfile.mkdtemp() @@ -23,7 +23,7 @@ def test_emnist(self): img, target = dataset[0] self.assertTrue(isinstance(img, PIL.Image.Image)) self.assertTrue(isinstance(target, int)) - shutil.rmtree(temp_dir) + shutil.rmtree(tmp_dir) def test_kmnist(self): tmp_dir = tempfile.mkdtemp() @@ -31,7 +31,7 @@ def test_kmnist(self): img, target = dataset[0] self.assertTrue(isinstance(img, PIL.Image.Image)) self.assertTrue(isinstance(target, int)) - shutil.rmtree(temp_dir) + shutil.rmtree(tmp_dir) def test_fashionmnist(self): tmp_dir = tempfile.mkdtemp() @@ -39,7 +39,7 @@ def test_fashionmnist(self): img, target = dataset[0] self.assertTrue(isinstance(img, PIL.Image.Image)) self.assertTrue(isinstance(target, int)) - shutil.rmtree(temp_dir) + shutil.rmtree(tmp_dir) if __name__ == '__main__': From ce05448519bf74d33e7653166d425cfd5646ea79 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 29 May 2019 14:08:21 +0200 Subject: [PATCH 5/7] Use download_and_extract on caltech, cifar and omniglot --- torchvision/datasets/caltech.py | 44 ++++++++++++-------------------- torchvision/datasets/cifar.py | 11 ++------ torchvision/datasets/omniglot.py | 9 ++----- 3 files changed, 20 insertions(+), 44 deletions(-) diff --git a/torchvision/datasets/caltech.py b/torchvision/datasets/caltech.py index 8c477e64810..43e8d7caf8d 100644 --- a/torchvision/datasets/caltech.py +++ b/torchvision/datasets/caltech.py @@ -4,7 +4,7 @@ import os.path from .vision import VisionDataset -from .utils import download_url, makedir_exist_ok +from .utils import download_and_extract, makedir_exist_ok class Caltech101(VisionDataset): @@ -109,27 +109,20 @@ def __len__(self): return len(self.index) def download(self): - import tarfile - if self._check_integrity(): print('Files already downloaded and verified') return - download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz", - self.root, - "101_ObjectCategories.tar.gz", - "b224c7392d521a49829488ab0f1120d9") - download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar", - self.root, - "101_Annotations.tar", - "6f83eeb1f24d99cab4eb377263132c91") - - # extract file - with tarfile.open(os.path.join(self.root, "101_ObjectCategories.tar.gz"), "r:gz") as tar: - tar.extractall(path=self.root) - - with tarfile.open(os.path.join(self.root, "101_Annotations.tar"), "r:") as tar: - tar.extractall(path=self.root) + download_and_extract( + "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz", + self.root, + "101_ObjectCategories.tar.gz", + "b224c7392d521a49829488ab0f1120d9") + download_and_extract( + "http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar", + self.root, + "101_Annotations.tar", + "6f83eeb1f24d99cab4eb377263132c91") def extra_repr(self): return "Target type: {target_type}".format(**self.__dict__) @@ -204,17 +197,12 @@ def __len__(self): return len(self.index) def download(self): - import tarfile - if self._check_integrity(): print('Files already downloaded and verified') return - download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar", - self.root, - "256_ObjectCategories.tar", - "67b4f42ca05d46448c6bb8ecd2220f6d") - - # extract file - with tarfile.open(os.path.join(self.root, "256_ObjectCategories.tar"), "r:") as tar: - tar.extractall(path=self.root) + download_and_extract( + "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar", + self.root, + "256_ObjectCategories.tar", + "67b4f42ca05d46448c6bb8ecd2220f6d") diff --git a/torchvision/datasets/cifar.py b/torchvision/datasets/cifar.py index 307e8f60a06..59ecda5cf4d 100644 --- a/torchvision/datasets/cifar.py +++ b/torchvision/datasets/cifar.py @@ -11,7 +11,7 @@ import pickle from .vision import VisionDataset -from .utils import download_url, check_integrity +from .utils import check_integrity, download_and_extract class CIFAR10(VisionDataset): @@ -144,17 +144,10 @@ def _check_integrity(self): return True def download(self): - import tarfile - if self._check_integrity(): print('Files already downloaded and verified') return - - download_url(self.url, self.root, self.filename, self.tgz_md5) - - # extract file - with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar: - tar.extractall(path=self.root) + download_and_extract(self.url, self.root, self.filename, self.tgz_md5) def extra_repr(self): return "Split: {}".format("Train" if self.train is True else "Test") diff --git a/torchvision/datasets/omniglot.py b/torchvision/datasets/omniglot.py index 9e2af0157a0..b5f6d64f12e 100644 --- a/torchvision/datasets/omniglot.py +++ b/torchvision/datasets/omniglot.py @@ -3,7 +3,7 @@ from os.path import join import os from .vision import VisionDataset -from .utils import download_url, check_integrity, list_dir, list_files +from .utils import download_and_extract, check_integrity, list_dir, list_files class Omniglot(VisionDataset): @@ -81,8 +81,6 @@ def _check_integrity(self): return True def download(self): - import zipfile - if self._check_integrity(): print('Files already downloaded and verified') return @@ -90,10 +88,7 @@ def download(self): filename = self._get_target_folder() zip_filename = filename + '.zip' url = self.download_url_prefix + '/' + zip_filename - download_url(url, self.root, zip_filename, self.zips_md5[filename]) - print('Extracting downloaded file: ' + join(self.root, zip_filename)) - with zipfile.ZipFile(join(self.root, zip_filename), 'r') as zip_file: - zip_file.extractall(self.root) + download_and_extract(url, self.root, zip_filename, self.zips_md5[filename]) def _get_target_folder(self): return 'images_background' if self.background else 'images_evaluation' From 05e64b0f55ae64bab05402c8257d49c0bc9bb673 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 29 May 2019 14:15:53 +0200 Subject: [PATCH 6/7] Add a print message during extraction --- torchvision/datasets/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index 5119a5e4193..86a2af48d52 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -234,4 +234,5 @@ def extract_file(from_path, to_path, remove_finished=False): def download_and_extract(url, root, filename, md5=None, remove_finished=False): download_url(url, root, filename, md5) + print("Extracting {} to {}".format(os.path.join(root, filename), root)) extract_file(os.path.join(root, filename), root, remove_finished) From b51e0b3e56a1ac8862a1c56928f22c652b2e7e92 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 29 May 2019 14:48:50 +0200 Subject: [PATCH 7/7] Remove EMNIST from test --- test/test_datasets.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index ece50fb8e0c..bc6474c96b8 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -17,14 +17,6 @@ def test_mnist(self): self.assertTrue(isinstance(target, int)) shutil.rmtree(tmp_dir) - def test_emnist(self): - tmp_dir = tempfile.mkdtemp() - dataset = torchvision.datasets.EMNIST(tmp_dir, split='byclass', download=True) - img, target = dataset[0] - self.assertTrue(isinstance(img, PIL.Image.Image)) - self.assertTrue(isinstance(target, int)) - shutil.rmtree(tmp_dir) - def test_kmnist(self): tmp_dir = tempfile.mkdtemp() dataset = torchvision.datasets.KMNIST(tmp_dir, download=True)