diff --git a/.travis.yml b/.travis.yml index a80f6fbdd27..dbe33d271f3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -34,6 +34,7 @@ before_install: fi - pip install future - pip install pytest pytest-cov codecov + - pip install mock install: diff --git a/test/assets/fakedata/dummy.jpg b/test/assets/fakedata/dummy.jpg new file mode 100644 index 00000000000..5517c39bbb4 Binary files /dev/null and b/test/assets/fakedata/dummy.jpg differ diff --git a/test/assets/fakedata/dummy.png b/test/assets/fakedata/dummy.png new file mode 100644 index 00000000000..a3eb8092d64 Binary files /dev/null and b/test/assets/fakedata/dummy.png differ diff --git a/test/assets/dataset/a/a1.png b/test/assets/fakedata/imagefolder/a/a1.png similarity index 100% rename from test/assets/dataset/a/a1.png rename to test/assets/fakedata/imagefolder/a/a1.png diff --git a/test/assets/dataset/a/a2.png b/test/assets/fakedata/imagefolder/a/a2.png similarity index 100% rename from test/assets/dataset/a/a2.png rename to test/assets/fakedata/imagefolder/a/a2.png diff --git a/test/assets/dataset/a/a3.png b/test/assets/fakedata/imagefolder/a/a3.png similarity index 100% rename from test/assets/dataset/a/a3.png rename to test/assets/fakedata/imagefolder/a/a3.png diff --git a/test/assets/dataset/b/b1.png b/test/assets/fakedata/imagefolder/b/b1.png similarity index 100% rename from test/assets/dataset/b/b1.png rename to test/assets/fakedata/imagefolder/b/b1.png diff --git a/test/assets/dataset/b/b2.png b/test/assets/fakedata/imagefolder/b/b2.png similarity index 100% rename from test/assets/dataset/b/b2.png rename to test/assets/fakedata/imagefolder/b/b2.png diff --git a/test/assets/dataset/b/b3.png b/test/assets/fakedata/imagefolder/b/b3.png similarity index 100% rename from test/assets/dataset/b/b3.png rename to test/assets/fakedata/imagefolder/b/b3.png diff --git a/test/assets/dataset/b/b4.png b/test/assets/fakedata/imagefolder/b/b4.png similarity index 100% rename from test/assets/dataset/b/b4.png rename to test/assets/fakedata/imagefolder/b/b4.png diff --git a/test/assets/fakedata/imagenet/ILSVRC2012_devkit_t12.tar.gz b/test/assets/fakedata/imagenet/ILSVRC2012_devkit_t12.tar.gz new file mode 100644 index 00000000000..1a7a587202d Binary files /dev/null and b/test/assets/fakedata/imagenet/ILSVRC2012_devkit_t12.tar.gz differ diff --git a/test/assets/fakedata/imagenet/ILSVRC2012_img_train.tar b/test/assets/fakedata/imagenet/ILSVRC2012_img_train.tar new file mode 100644 index 00000000000..3905c8e4bcb Binary files /dev/null and b/test/assets/fakedata/imagenet/ILSVRC2012_img_train.tar differ diff --git a/test/assets/fakedata/imagenet/ILSVRC2012_img_val.tar b/test/assets/fakedata/imagenet/ILSVRC2012_img_val.tar new file mode 100644 index 00000000000..dd65e058b4b Binary files /dev/null and b/test/assets/fakedata/imagenet/ILSVRC2012_img_val.tar differ diff --git a/test/test_datasets.py b/test/test_datasets.py index bc6474c96b8..ae8e43e702e 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -1,37 +1,113 @@ -import PIL +import os import shutil +import contextlib import tempfile import unittest - +import mock +import PIL import torchvision +FAKEDATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), + 'assets', 'fakedata') + + +@contextlib.contextmanager +def tmp_dir(src=None, **kwargs): + tmp_dir = tempfile.mkdtemp(**kwargs) + if src is not None: + os.rmdir(tmp_dir) + shutil.copytree(src, tmp_dir) + try: + yield tmp_dir + finally: + shutil.rmtree(tmp_dir) + class Tester(unittest.TestCase): + def test_imagefolder(self): + with tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root: + classes = sorted(['a', 'b']) + class_a_image_files = [os.path.join(root, 'a', file) + for file in ('a1.png', 'a2.png', 'a3.png')] + class_b_image_files = [os.path.join(root, 'b', file) + for file in ('b1.png', 'b2.png', 'b3.png', 'b4.png')] + dataset = torchvision.datasets.ImageFolder(root, loader=lambda x: x) + + # test if all classes are present + self.assertEqual(classes, sorted(dataset.classes)) + + # test if combination of classes and class_to_index functions correctly + for cls in classes: + self.assertEqual(cls, dataset.classes[dataset.class_to_idx[cls]]) + + # test if all images were detected correctly + class_a_idx = dataset.class_to_idx['a'] + class_b_idx = dataset.class_to_idx['b'] + imgs_a = [(img_file, class_a_idx) for img_file in class_a_image_files] + imgs_b = [(img_file, class_b_idx) for img_file in class_b_image_files] + imgs = sorted(imgs_a + imgs_b) + self.assertEqual(imgs, dataset.imgs) + + # test if the datasets outputs all images correctly + outputs = sorted([dataset[i] for i in range(len(dataset))]) + self.assertEqual(imgs, outputs) + + # redo all tests with specified valid image files + dataset = torchvision.datasets.ImageFolder(root, loader=lambda x: x, + is_valid_file=lambda x: '3' in x) + self.assertEqual(classes, sorted(dataset.classes)) + + class_a_idx = dataset.class_to_idx['a'] + class_b_idx = dataset.class_to_idx['b'] + imgs_a = [(img_file, class_a_idx) for img_file in class_a_image_files + if '3' in img_file] + imgs_b = [(img_file, class_b_idx) for img_file in class_b_image_files + if '3' in img_file] + imgs = sorted(imgs_a + imgs_b) + self.assertEqual(imgs, dataset.imgs) + + outputs = sorted([dataset[i] for i in range(len(dataset))]) + self.assertEqual(imgs, outputs) + def test_mnist(self): - tmp_dir = tempfile.mkdtemp() - dataset = torchvision.datasets.MNIST(tmp_dir, download=True) - self.assertEqual(len(dataset), 60000) - img, target = dataset[0] - self.assertTrue(isinstance(img, PIL.Image.Image)) - self.assertTrue(isinstance(target, int)) - shutil.rmtree(tmp_dir) + with tmp_dir() as root: + dataset = torchvision.datasets.MNIST(root, download=True) + self.assertEqual(len(dataset), 60000) + img, target = dataset[0] + self.assertTrue(isinstance(img, PIL.Image.Image)) + self.assertTrue(isinstance(target, int)) def test_kmnist(self): - tmp_dir = tempfile.mkdtemp() - dataset = torchvision.datasets.KMNIST(tmp_dir, download=True) - img, target = dataset[0] - self.assertTrue(isinstance(img, PIL.Image.Image)) - self.assertTrue(isinstance(target, int)) - shutil.rmtree(tmp_dir) + with tmp_dir() as root: + dataset = torchvision.datasets.KMNIST(root, download=True) + img, target = dataset[0] + self.assertTrue(isinstance(img, PIL.Image.Image)) + self.assertTrue(isinstance(target, int)) def test_fashionmnist(self): - tmp_dir = tempfile.mkdtemp() - dataset = torchvision.datasets.FashionMNIST(tmp_dir, download=True) - img, target = dataset[0] - self.assertTrue(isinstance(img, PIL.Image.Image)) - self.assertTrue(isinstance(target, int)) - shutil.rmtree(tmp_dir) + with tmp_dir() as root: + dataset = torchvision.datasets.FashionMNIST(root, download=True) + img, target = dataset[0] + self.assertTrue(isinstance(img, PIL.Image.Image)) + self.assertTrue(isinstance(target, int)) + + @mock.patch('torchvision.datasets.utils.download_url') + def test_imagenet(self, mock_download): + with tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagenet')) as root: + dataset = torchvision.datasets.ImageNet(root, split='train', download=True) + self.assertEqual(len(dataset), 3) + img, target = dataset[0] + self.assertTrue(isinstance(img, PIL.Image.Image)) + self.assertTrue(isinstance(target, int)) + self.assertEqual(dataset.class_to_idx['Tinca tinca'], target) + + dataset = torchvision.datasets.ImageNet(root, split='val', download=True) + self.assertEqual(len(dataset), 3) + img, target = dataset[0] + self.assertTrue(isinstance(img, PIL.Image.Image)) + self.assertTrue(isinstance(target, int)) + self.assertEqual(dataset.class_to_idx['Tinca tinca'], target) if __name__ == '__main__': diff --git a/test/test_datasets_transforms.py b/test/test_datasets_transforms.py new file mode 100644 index 00000000000..6cffd4f76a9 --- /dev/null +++ b/test/test_datasets_transforms.py @@ -0,0 +1,72 @@ +import os +import shutil +import contextlib +import tempfile +import unittest +from torchvision.datasets import ImageFolder + +FAKEDATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), + 'assets', 'fakedata') + + +@contextlib.contextmanager +def tmp_dir(src=None, **kwargs): + tmp_dir = tempfile.mkdtemp(**kwargs) + if src is not None: + os.rmdir(tmp_dir) + shutil.copytree(src, tmp_dir) + try: + yield tmp_dir + finally: + shutil.rmtree(tmp_dir) + + +def mock_transform(return_value, arg_list): + def mock(arg): + arg_list.append(arg) + return return_value + return mock + + +class Tester(unittest.TestCase): + def test_transform(self): + with tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root: + class_a_image_files = [os.path.join(root, 'a', file) + for file in ('a1.png', 'a2.png', 'a3.png')] + class_b_image_files = [os.path.join(root, 'b', file) + for file in ('b1.png', 'b2.png', 'b3.png', 'b4.png')] + return_value = os.path.join(root, 'a', 'a1.png') + args = [] + transform = mock_transform(return_value, args) + dataset = ImageFolder(root, loader=lambda x: x, transform=transform) + + outputs = [dataset[i][0] for i in range(len(dataset))] + self.assertEqual([return_value] * len(outputs), outputs) + + imgs = sorted(class_a_image_files + class_b_image_files) + self.assertEqual(imgs, sorted(args)) + + def test_target_transform(self): + with tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root: + class_a_image_files = [os.path.join(root, 'a', file) + for file in ('a1.png', 'a2.png', 'a3.png')] + class_b_image_files = [os.path.join(root, 'b', file) + for file in ('b1.png', 'b2.png', 'b3.png', 'b4.png')] + return_value = os.path.join(root, 'a', 'a1.png') + args = [] + target_transform = mock_transform(return_value, args) + dataset = ImageFolder(root, loader=lambda x: x, + target_transform=target_transform) + + outputs = [dataset[i][1] for i in range(len(dataset))] + self.assertEqual([return_value] * len(outputs), outputs) + + class_a_idx = dataset.class_to_idx['a'] + class_b_idx = dataset.class_to_idx['b'] + targets = sorted([class_a_idx] * len(class_a_image_files) + + [class_b_idx] * len(class_b_image_files)) + self.assertEqual(targets, sorted(args)) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index 79a1e3992c2..8bc377efd8e 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -49,7 +49,7 @@ def test_extract_zip(self): with tempfile.NamedTemporaryFile(suffix='.zip') as f: with zipfile.ZipFile(f, 'w') as zf: zf.writestr('file.tst', 'this is the content') - utils.extract_file(f.name, temp_dir) + utils.extract_archive(f.name, temp_dir) assert os.path.exists(os.path.join(temp_dir, 'file.tst')) with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf: data = nf.read() @@ -65,7 +65,7 @@ def test_extract_tar(self): with tempfile.NamedTemporaryFile(suffix=ext) as f: with tarfile.open(f.name, mode=mode) as zf: zf.add(bf.name, arcname='file.tst') - utils.extract_file(f.name, temp_dir) + utils.extract_archive(f.name, temp_dir) assert os.path.exists(os.path.join(temp_dir, 'file.tst')) with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf: data = nf.read() @@ -77,7 +77,7 @@ def test_extract_gzip(self): with tempfile.NamedTemporaryFile(suffix='.gz') as f: with gzip.GzipFile(f.name, 'wb') as zf: zf.write('this is the content'.encode()) - utils.extract_file(f.name, temp_dir) + utils.extract_archive(f.name, temp_dir) f_name = os.path.join(temp_dir, os.path.splitext(os.path.basename(f.name))[0]) assert os.path.exists(f_name) with open(os.path.join(f_name), 'r') as nf: diff --git a/test/test_folder.py b/test/test_folder.py deleted file mode 100644 index 8b49518b419..00000000000 --- a/test/test_folder.py +++ /dev/null @@ -1,84 +0,0 @@ -import unittest - -import os - -from torchvision.datasets import ImageFolder -from torch._utils_internal import get_file_path_2 - - -def mock_transform(return_value, arg_list): - def mock(arg): - arg_list.append(arg) - return return_value - return mock - - -class Tester(unittest.TestCase): - root = os.path.normpath(get_file_path_2('test/assets/dataset/')) - classes = ['a', 'b'] - class_a_images = [os.path.normpath(get_file_path_2(os.path.join('test/assets/dataset/a/', path))) - for path in ['a1.png', 'a2.png', 'a3.png']] - class_b_images = [os.path.normpath(get_file_path_2(os.path.join('test/assets/dataset/b/', path))) - for path in ['b1.png', 'b2.png', 'b3.png', 'b4.png']] - - def test_image_folder(self): - dataset = ImageFolder(Tester.root, loader=lambda x: x) - self.assertEqual(sorted(Tester.classes), sorted(dataset.classes)) - for cls in Tester.classes: - self.assertEqual(cls, dataset.classes[dataset.class_to_idx[cls]]) - class_a_idx = dataset.class_to_idx['a'] - class_b_idx = dataset.class_to_idx['b'] - imgs_a = [(img_path, class_a_idx)for img_path in Tester.class_a_images] - imgs_b = [(img_path, class_b_idx)for img_path in Tester.class_b_images] - imgs = sorted(imgs_a + imgs_b) - self.assertEqual(imgs, dataset.imgs) - - outputs = sorted([dataset[i] for i in range(len(dataset))]) - self.assertEqual(imgs, outputs) - - dataset = ImageFolder(Tester.root, loader=lambda x: x, is_valid_file=lambda x: '3' in x) - self.assertEqual(sorted(Tester.classes), sorted(dataset.classes)) - for cls in Tester.classes: - self.assertEqual(cls, dataset.classes[dataset.class_to_idx[cls]]) - class_a_idx = dataset.class_to_idx['a'] - class_b_idx = dataset.class_to_idx['b'] - imgs_a = [(img_path, class_a_idx)for img_path in Tester.class_a_images if '3' in img_path] - imgs_b = [(img_path, class_b_idx)for img_path in Tester.class_b_images if '3' in img_path] - imgs = sorted(imgs_a + imgs_b) - self.assertEqual(imgs, dataset.imgs) - - outputs = sorted([dataset[i] for i in range(len(dataset))]) - self.assertEqual(imgs, outputs) - - def test_transform(self): - return_value = os.path.normpath(get_file_path_2('test/assets/dataset/a/a1.png')) - - args = [] - transform = mock_transform(return_value, args) - - dataset = ImageFolder(Tester.root, loader=lambda x: x, transform=transform) - outputs = [dataset[i][0] for i in range(len(dataset))] - self.assertEqual([return_value] * len(outputs), outputs) - - imgs = sorted(Tester.class_a_images + Tester.class_b_images) - self.assertEqual(imgs, sorted(args)) - - def test_target_transform(self): - return_value = 1 - - args = [] - target_transform = mock_transform(return_value, args) - - dataset = ImageFolder(Tester.root, loader=lambda x: x, target_transform=target_transform) - outputs = [dataset[i][1] for i in range(len(dataset))] - self.assertEqual([return_value] * len(outputs), outputs) - - class_a_idx = dataset.class_to_idx['a'] - class_b_idx = dataset.class_to_idx['b'] - targets = sorted([class_a_idx] * len(Tester.class_a_images) + - [class_b_idx] * len(Tester.class_b_images)) - self.assertEqual(targets, sorted(args)) - - -if __name__ == '__main__': - unittest.main() diff --git a/torchvision/datasets/caltech.py b/torchvision/datasets/caltech.py index 43e8d7caf8d..d8ee66c1930 100644 --- a/torchvision/datasets/caltech.py +++ b/torchvision/datasets/caltech.py @@ -4,7 +4,7 @@ import os.path from .vision import VisionDataset -from .utils import download_and_extract, makedir_exist_ok +from .utils import download_and_extract_archive, makedir_exist_ok class Caltech101(VisionDataset): @@ -113,12 +113,12 @@ def download(self): print('Files already downloaded and verified') return - download_and_extract( + download_and_extract_archive( "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz", self.root, "101_ObjectCategories.tar.gz", "b224c7392d521a49829488ab0f1120d9") - download_and_extract( + download_and_extract_archive( "http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar", self.root, "101_Annotations.tar", @@ -201,7 +201,7 @@ def download(self): print('Files already downloaded and verified') return - download_and_extract( + download_and_extract_archive( "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar", self.root, "256_ObjectCategories.tar", diff --git a/torchvision/datasets/cifar.py b/torchvision/datasets/cifar.py index 59ecda5cf4d..a25c41eef7a 100644 --- a/torchvision/datasets/cifar.py +++ b/torchvision/datasets/cifar.py @@ -11,7 +11,7 @@ import pickle from .vision import VisionDataset -from .utils import check_integrity, download_and_extract +from .utils import check_integrity, download_and_extract_archive class CIFAR10(VisionDataset): @@ -147,7 +147,7 @@ def download(self): if self._check_integrity(): print('Files already downloaded and verified') return - download_and_extract(self.url, self.root, self.filename, self.tgz_md5) + download_and_extract_archive(self.url, self.root, self.filename, self.tgz_md5) def extra_repr(self): return "Split: {}".format("Train" if self.train is True else "Test") diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py index af14d178b8e..85bb7c759be 100644 --- a/torchvision/datasets/imagenet.py +++ b/torchvision/datasets/imagenet.py @@ -1,9 +1,10 @@ from __future__ import print_function import os import shutil +import tempfile import torch from .folder import ImageFolder -from .utils import check_integrity, download_url +from .utils import check_integrity, download_and_extract_archive, extract_archive ARCHIVE_DICT = { 'train': { @@ -66,23 +67,23 @@ def __init__(self, root, split='train', download=False, **kwargs): def download(self): if not check_integrity(self.meta_file): - tmpdir = os.path.join(self.root, 'tmp') + tmp_dir = tempfile.mkdtemp() archive_dict = ARCHIVE_DICT['devkit'] - download_and_extract_tar(archive_dict['url'], self.root, - extract_root=tmpdir, - md5=archive_dict['md5']) + download_and_extract_archive(archive_dict['url'], self.root, + extract_root=tmp_dir, + md5=archive_dict['md5']) devkit_folder = _splitexts(os.path.basename(archive_dict['url']))[0] - meta = parse_devkit(os.path.join(tmpdir, devkit_folder)) + meta = parse_devkit(os.path.join(tmp_dir, devkit_folder)) self._save_meta_file(*meta) - shutil.rmtree(tmpdir) + shutil.rmtree(tmp_dir) if not os.path.isdir(self.split_folder): archive_dict = ARCHIVE_DICT[self.split] - download_and_extract_tar(archive_dict['url'], self.root, - extract_root=self.split_folder, - md5=archive_dict['md5']) + download_and_extract_archive(archive_dict['url'], self.root, + extract_root=self.split_folder, + md5=archive_dict['md5']) if self.split == 'train': prepare_train_folder(self.split_folder) @@ -128,36 +129,6 @@ def extra_repr(self): return "Split: {split}".format(**self.__dict__) -def extract_tar(src, dest=None, gzip=None, delete=False): - import tarfile - - if dest is None: - dest = os.path.dirname(src) - if gzip is None: - gzip = src.lower().endswith('.gz') - - mode = 'r:gz' if gzip else 'r' - with tarfile.open(src, mode) as tarfh: - tarfh.extractall(path=dest) - - if delete: - os.remove(src) - - -def download_and_extract_tar(url, download_root, extract_root=None, filename=None, - md5=None, **kwargs): - download_root = os.path.expanduser(download_root) - if extract_root is None: - extract_root = download_root - if filename is None: - filename = os.path.basename(url) - - if not check_integrity(os.path.join(download_root, filename), md5): - download_url(url, download_root, filename=filename, md5=md5) - - extract_tar(os.path.join(download_root, filename), extract_root, **kwargs) - - def parse_devkit(root): idx_to_wnid, wnid_to_classes = parse_meta(root) val_idcs = parse_val_groundtruth(root) @@ -189,7 +160,7 @@ def parse_val_groundtruth(devkit_root, path='data', def prepare_train_folder(folder): for archive in [os.path.join(folder, archive) for archive in os.listdir(folder)]: - extract_tar(archive, os.path.splitext(archive)[0], delete=True) + extract_archive(archive, os.path.splitext(archive)[0], remove_finished=True) def prepare_val_folder(folder, wnids): diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index 51a3cb19df4..fe861ea5930 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -7,7 +7,7 @@ import numpy as np import torch import codecs -from .utils import download_and_extract, extract_file, makedir_exist_ok +from .utils import download_and_extract_archive, extract_archive, makedir_exist_ok class MNIST(VisionDataset): @@ -131,7 +131,7 @@ def download(self): # download files for url in self.urls: filename = url.rpartition('/')[2] - download_and_extract(url, root=self.raw_folder, filename=filename) + download_and_extract_archive(url, download_root=self.raw_folder, filename=filename) # process and save as torch files print('Processing...') @@ -259,11 +259,12 @@ def download(self): # download files print('Downloading and extracting zip archive') - download_and_extract(self.url, root=self.raw_folder, filename="emnist.zip", remove_finished=True) + download_and_extract_archive(self.url, download_root=self.raw_folder, filename="emnist.zip", + remove_finished=True) gzip_folder = os.path.join(self.raw_folder, 'gzip') for gzip_file in os.listdir(gzip_folder): if gzip_file.endswith('.gz'): - extract_file(os.path.join(gzip_folder, gzip_file), gzip_folder) + extract_archive(os.path.join(gzip_folder, gzip_file), gzip_folder) # process and save as torch files for split in self.splits: diff --git a/torchvision/datasets/omniglot.py b/torchvision/datasets/omniglot.py index b5f6d64f12e..2a0ab06285f 100644 --- a/torchvision/datasets/omniglot.py +++ b/torchvision/datasets/omniglot.py @@ -3,7 +3,7 @@ from os.path import join import os from .vision import VisionDataset -from .utils import download_and_extract, check_integrity, list_dir, list_files +from .utils import download_and_extract_archive, check_integrity, list_dir, list_files class Omniglot(VisionDataset): @@ -88,7 +88,7 @@ def download(self): filename = self._get_target_folder() zip_filename = filename + '.zip' url = self.download_url_prefix + '/' + zip_filename - download_and_extract(url, self.root, zip_filename, self.zips_md5[filename]) + download_and_extract_archive(url, self.root, zip_filename, self.zips_md5[filename]) def _get_target_folder(self): return 'images_background' if self.background else 'images_evaluation' diff --git a/torchvision/datasets/stl10.py b/torchvision/datasets/stl10.py index a5ceedbca2b..f6846ee8c10 100644 --- a/torchvision/datasets/stl10.py +++ b/torchvision/datasets/stl10.py @@ -5,7 +5,7 @@ import numpy as np from .vision import VisionDataset -from .utils import check_integrity, download_and_extract +from .utils import check_integrity, download_and_extract_archive class STL10(VisionDataset): @@ -152,7 +152,7 @@ def download(self): if self._check_integrity(): print('Files already downloaded and verified') return - download_and_extract(self.url, self.root, self.filename, self.tgz_md5) + download_and_extract_archive(self.url, self.root, self.filename, self.tgz_md5) def extra_repr(self): return "Split: {split}".format(**self.__dict__) diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index 86a2af48d52..3e3c674ede9 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -38,8 +38,7 @@ def check_integrity(fpath, md5=None): return False if md5 is None: return True - else: - return check_md5(fpath, md5) + return check_md5(fpath, md5) def makedir_exist_ok(dirpath): @@ -74,7 +73,7 @@ def download_url(url, root, filename=None, md5=None): makedir_exist_ok(root) # downloads file - if os.path.isfile(fpath) and check_integrity(fpath, md5): + if check_integrity(fpath, md5): print('Using downloaded and verified file: ' + fpath) else: try: @@ -211,9 +210,12 @@ def _is_zip(filename): return filename.endswith(".zip") -def extract_file(from_path, to_path, remove_finished=False): +def extract_archive(from_path, to_path=None, remove_finished=False): + if to_path is None: + to_path = os.path.dirname(from_path) + if _is_tar(from_path): - with tarfile.open(from_path, 'r:') as tar: + with tarfile.open(from_path, 'r') as tar: tar.extractall(path=to_path) elif _is_targz(from_path): with tarfile.open(from_path, 'r:gz') as tar: @@ -229,10 +231,19 @@ def extract_file(from_path, to_path, remove_finished=False): raise ValueError("Extraction of {} not supported".format(from_path)) if remove_finished: - os.unlink(from_path) + os.remove(from_path) + + +def download_and_extract_archive(url, download_root, extract_root=None, filename=None, + md5=None, remove_finished=False): + download_root = os.path.expanduser(download_root) + if extract_root is None: + extract_root = download_root + if not filename: + filename = os.path.basename(url) + download_url(url, download_root, filename, md5) -def download_and_extract(url, root, filename, md5=None, remove_finished=False): - download_url(url, root, filename, md5) - print("Extracting {} to {}".format(os.path.join(root, filename), root)) - extract_file(os.path.join(root, filename), root, remove_finished) + archive = os.path.join(download_root, filename) + print("Extracting {} to {}".format(archive, extract_root)) + extract_archive(archive, extract_root, remove_finished)