Skip to content

[WIP] Add tests for datasets #966

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import PIL
import shutil
import tempfile
import unittest

import torchvision


class Tester(unittest.TestCase):
    """Smoke tests for the MNIST-family dataset constructors.

    Each test downloads the dataset into a fresh temporary directory and
    checks the basic ``__getitem__`` contract: item 0 is a
    (PIL image, int label) pair.
    """

    def setUp(self):
        # One fresh download root per test. addCleanup removes it even when
        # an assertion fails — the previous version leaked the directory on
        # any test failure because shutil.rmtree was only reached on success.
        self.tmp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.tmp_dir)

    def _check_first_item(self, dataset):
        # Shared contract check for all three datasets.
        img, target = dataset[0]
        self.assertIsInstance(img, PIL.Image.Image)
        self.assertIsInstance(target, int)

    def test_mnist(self):
        dataset = torchvision.datasets.MNIST(self.tmp_dir, download=True)
        # MNIST's training split is documented as 60k examples.
        self.assertEqual(len(dataset), 60000)
        self._check_first_item(dataset)

    def test_kmnist(self):
        dataset = torchvision.datasets.KMNIST(self.tmp_dir, download=True)
        self._check_first_item(dataset)

    def test_fashionmnist(self):
        dataset = torchvision.datasets.FashionMNIST(self.tmp_dir, download=True)
        self._check_first_item(dataset)


if __name__ == '__main__':
    unittest.main()
44 changes: 44 additions & 0 deletions test/test_datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import tempfile
import torchvision.datasets.utils as utils
import unittest
import zipfile
import tarfile
import gzip

TEST_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'assets', 'grace_hopper_517x606.jpg')
Expand Down Expand Up @@ -41,6 +44,47 @@ def test_download_url_retry_http(self):
assert not len(os.listdir(temp_dir)) == 0, 'The downloaded root directory is empty after download.'
shutil.rmtree(temp_dir)

def test_extract_zip(self):
    """extract_file on a .zip restores the archived file and its content."""
    temp_dir = tempfile.mkdtemp()
    try:
        with tempfile.NamedTemporaryFile(suffix='.zip') as f:
            with zipfile.ZipFile(f, 'w') as zf:
                zf.writestr('file.tst', 'this is the content')
            # ZipFile wrote through the buffered temp-file handle; flush so
            # the archive is actually on disk before we reopen it by name.
            f.flush()
            utils.extract_file(f.name, temp_dir)
            extracted = os.path.join(temp_dir, 'file.tst')
            assert os.path.exists(extracted)
            with open(extracted, 'r') as nf:
                data = nf.read()
            assert data == 'this is the content'
    finally:
        # Clean up even when an assertion fails (the old code leaked
        # temp_dir on failure).
        shutil.rmtree(temp_dir)

def test_extract_tar(self):
    """extract_file handles both plain .tar and gzip-compressed .tar.gz."""
    for ext, mode in zip(['.tar', '.tar.gz'], ['w', 'w:gz']):
        temp_dir = tempfile.mkdtemp()
        try:
            with tempfile.NamedTemporaryFile() as bf:
                bf.write("this is the content".encode())
                # seek(0) also flushes the buffered writer, so tarfile can
                # read the payload back from disk by name.
                bf.seek(0)
                with tempfile.NamedTemporaryFile(suffix=ext) as f:
                    # tarfile opens f.name itself and closes it, so the
                    # archive is fully written before extraction.
                    with tarfile.open(f.name, mode=mode) as tar:
                        tar.add(bf.name, arcname='file.tst')
                    utils.extract_file(f.name, temp_dir)
                    extracted = os.path.join(temp_dir, 'file.tst')
                    assert os.path.exists(extracted)
                    with open(extracted, 'r') as nf:
                        data = nf.read()
                    assert data == 'this is the content', data
        finally:
            # Clean up even when an assertion fails (the old code leaked
            # temp_dir on failure).
            shutil.rmtree(temp_dir)

def test_extract_gzip(self):
    """extract_file on a plain .gz decompresses next to the target dir."""
    temp_dir = tempfile.mkdtemp()
    try:
        with tempfile.NamedTemporaryFile(suffix='.gz') as f:
            # GzipFile opens f.name independently and closes on __exit__,
            # so the compressed data is on disk before extraction.
            with gzip.GzipFile(f.name, 'wb') as zf:
                zf.write('this is the content'.encode())
            utils.extract_file(f.name, temp_dir)
            # extract_file drops the final extension from the basename.
            extracted = os.path.join(
                temp_dir, os.path.splitext(os.path.basename(f.name))[0])
            assert os.path.exists(extracted)
            with open(extracted, 'r') as nf:
                data = nf.read()
            assert data == 'this is the content', data
    finally:
        # Clean up even when an assertion fails (the old code leaked
        # temp_dir on failure).
        shutil.rmtree(temp_dir)


if __name__ == '__main__':
unittest.main()
44 changes: 16 additions & 28 deletions torchvision/datasets/caltech.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os.path

from .vision import VisionDataset
from .utils import download_url, makedir_exist_ok
from .utils import download_and_extract, makedir_exist_ok


class Caltech101(VisionDataset):
Expand Down Expand Up @@ -109,27 +109,20 @@ def __len__(self):
return len(self.index)

def download(self):
import tarfile

if self._check_integrity():
print('Files already downloaded and verified')
return

download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
self.root,
"101_ObjectCategories.tar.gz",
"b224c7392d521a49829488ab0f1120d9")
download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
self.root,
"101_Annotations.tar",
"6f83eeb1f24d99cab4eb377263132c91")

# extract file
with tarfile.open(os.path.join(self.root, "101_ObjectCategories.tar.gz"), "r:gz") as tar:
tar.extractall(path=self.root)

with tarfile.open(os.path.join(self.root, "101_Annotations.tar"), "r:") as tar:
tar.extractall(path=self.root)
download_and_extract(
"http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
self.root,
"101_ObjectCategories.tar.gz",
"b224c7392d521a49829488ab0f1120d9")
download_and_extract(
"http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
self.root,
"101_Annotations.tar",
"6f83eeb1f24d99cab4eb377263132c91")

def extra_repr(self):
return "Target type: {target_type}".format(**self.__dict__)
Expand Down Expand Up @@ -204,17 +197,12 @@ def __len__(self):
return len(self.index)

def download(self):
import tarfile

if self._check_integrity():
print('Files already downloaded and verified')
return

download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
self.root,
"256_ObjectCategories.tar",
"67b4f42ca05d46448c6bb8ecd2220f6d")

# extract file
with tarfile.open(os.path.join(self.root, "256_ObjectCategories.tar"), "r:") as tar:
tar.extractall(path=self.root)
download_and_extract(
"http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
self.root,
"256_ObjectCategories.tar",
"67b4f42ca05d46448c6bb8ecd2220f6d")
11 changes: 2 additions & 9 deletions torchvision/datasets/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pickle

from .vision import VisionDataset
from .utils import download_url, check_integrity
from .utils import check_integrity, download_and_extract


class CIFAR10(VisionDataset):
Expand Down Expand Up @@ -144,17 +144,10 @@ def _check_integrity(self):
return True

def download(self):
import tarfile

if self._check_integrity():
print('Files already downloaded and verified')
return

download_url(self.url, self.root, self.filename, self.tgz_md5)

# extract file
with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
tar.extractall(path=self.root)
download_and_extract(self.url, self.root, self.filename, self.tgz_md5)

def extra_repr(self):
return "Split: {}".format("Train" if self.train is True else "Test")
Expand Down
29 changes: 5 additions & 24 deletions torchvision/datasets/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
from PIL import Image
import os
import os.path
import gzip
import numpy as np
import torch
import codecs
from .utils import download_url, makedir_exist_ok
from .utils import download_and_extract, extract_file, makedir_exist_ok


class MNIST(VisionDataset):
Expand Down Expand Up @@ -120,15 +119,6 @@ def _check_exists(self):
os.path.exists(os.path.join(self.processed_folder,
self.test_file)))

@staticmethod
def extract_gzip(gzip_path, remove_finished=False):
print('Extracting {}'.format(gzip_path))
with open(gzip_path.replace('.gz', ''), 'wb') as out_f, \
gzip.GzipFile(gzip_path) as zip_f:
out_f.write(zip_f.read())
if remove_finished:
os.unlink(gzip_path)

def download(self):
"""Download the MNIST data if it doesn't exist in processed_folder already."""

Expand All @@ -141,9 +131,7 @@ def download(self):
# download files
for url in self.urls:
filename = url.rpartition('/')[2]
file_path = os.path.join(self.raw_folder, filename)
download_url(url, root=self.raw_folder, filename=filename, md5=None)
self.extract_gzip(gzip_path=file_path, remove_finished=True)
download_and_extract(url, root=self.raw_folder, filename=filename)

# process and save as torch files
print('Processing...')
Expand Down Expand Up @@ -262,7 +250,6 @@ def _test_file(split):
def download(self):
"""Download the EMNIST data if it doesn't exist in processed_folder already."""
import shutil
import zipfile

if self._check_exists():
return
Expand All @@ -271,18 +258,12 @@ def download(self):
makedir_exist_ok(self.processed_folder)

# download files
filename = self.url.rpartition('/')[2]
file_path = os.path.join(self.raw_folder, filename)
download_url(self.url, root=self.raw_folder, filename=filename, md5=None)

print('Extracting zip archive')
with zipfile.ZipFile(file_path) as zip_f:
zip_f.extractall(self.raw_folder)
os.unlink(file_path)
print('Downloading and extracting zip archive')
download_and_extract(self.url, root=self.raw_folder, filename="emnist.zip", remove_finished=True)
gzip_folder = os.path.join(self.raw_folder, 'gzip')
for gzip_file in os.listdir(gzip_folder):
if gzip_file.endswith('.gz'):
self.extract_gzip(gzip_path=os.path.join(gzip_folder, gzip_file))
extract_file(os.path.join(gzip_folder, gzip_file), gzip_folder)

# process and save as torch files
for split in self.splits:
Expand Down
9 changes: 2 additions & 7 deletions torchvision/datasets/omniglot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from os.path import join
import os
from .vision import VisionDataset
from .utils import download_url, check_integrity, list_dir, list_files
from .utils import download_and_extract, check_integrity, list_dir, list_files


class Omniglot(VisionDataset):
Expand Down Expand Up @@ -81,19 +81,14 @@ def _check_integrity(self):
return True

def download(self):
import zipfile

if self._check_integrity():
print('Files already downloaded and verified')
return

filename = self._get_target_folder()
zip_filename = filename + '.zip'
url = self.download_url_prefix + '/' + zip_filename
download_url(url, self.root, zip_filename, self.zips_md5[filename])
print('Extracting downloaded file: ' + join(self.root, zip_filename))
with zipfile.ZipFile(join(self.root, zip_filename), 'r') as zip_file:
zip_file.extractall(self.root)
download_and_extract(url, self.root, zip_filename, self.zips_md5[filename])

def _get_target_folder(self):
return 'images_background' if self.background else 'images_evaluation'
47 changes: 47 additions & 0 deletions torchvision/datasets/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import os
import os.path
import hashlib
import gzip
import errno
import tarfile
import zipfile
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we do a lazy import in extract_file()? AFAIK we do this now to prevent becoming dependent on these packages at import.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe tarfile, zipfile and gzip are part of the python standard library, so I think this should be ok


from torch.utils.model_zoo import tqdm


Expand Down Expand Up @@ -189,3 +193,46 @@ def _save_response_content(response, destination, chunk_size=32768):
progress += len(chunk)
pbar.update(progress - pbar.n)
pbar.close()


def _is_tar(filename):
return filename.endswith(".tar")


def _is_targz(filename):
return filename.endswith(".tar.gz")


def _is_gzip(filename):
return filename.endswith(".gz") and not filename.endswith(".tar.gz")


def _is_zip(filename):
return filename.endswith(".zip")


def extract_file(from_path, to_path, remove_finished=False):
    """Extract the archive at ``from_path`` into the directory ``to_path``.

    Supported formats: ``.tar``, ``.tar.gz``, ``.zip`` and plain ``.gz``.
    A plain gzip file is decompressed into ``to_path`` under the archive's
    basename with its final extension stripped.

    Args:
        from_path (str): path to the archive file.
        to_path (str): existing directory to extract into.
        remove_finished (bool): if True, delete the archive after extraction.

    Raises:
        ValueError: if the file extension is not a supported archive type.
    """
    if _is_tar(from_path):
        with tarfile.open(from_path, 'r:') as tar:
            tar.extractall(path=to_path)
    elif _is_targz(from_path):
        with tarfile.open(from_path, 'r:gz') as tar:
            tar.extractall(path=to_path)
    elif _is_gzip(from_path):
        # e.g. /root/train-images.gz -> <to_path>/train-images
        to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
        with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
            # Stream in fixed-size chunks: the old zip_f.read() loaded the
            # entire decompressed payload into memory at once.
            for chunk in iter(lambda: zip_f.read(1024 * 1024), b''):
                out_f.write(chunk)
    elif _is_zip(from_path):
        with zipfile.ZipFile(from_path, 'r') as z:
            z.extractall(to_path)
    else:
        raise ValueError("Extraction of {} not supported".format(from_path))

    if remove_finished:
        os.unlink(from_path)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pure curiosity: Why did you use os.unlink instead of os.remove? I'm only now aware that they provide the same functionality. I think os.remove would be clearer since the flag is also called remove_finished and not unlink_finished.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No other reason than just because it was what was used in MNIST before, so I decided to do the same here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough.



def download_and_extract(url, root, filename, md5=None, remove_finished=False):
    """Download ``url`` to ``root/filename`` (verifying ``md5`` if given)
    and extract the archive into ``root``.

    Args:
        url (str): URL to download.
        root (str): directory the file is saved to and extracted into.
        filename (str): name to save the downloaded file under.
        md5 (str, optional): expected MD5 checksum of the download.
        remove_finished (bool): if True, delete the archive after extraction.
    """
    archive = os.path.join(root, filename)
    download_url(url, root, filename, md5)
    print("Extracting {} to {}".format(archive, root))
    extract_file(archive, root, remove_finished)