Skip to content

[WIP] Add test for ImageNet #976

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
11 commits merged into the base branch
on Jun 3, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ before_install:
fi
- pip install future
- pip install pytest pytest-cov codecov
- pip install mock


install:
Expand Down
Binary file added test/assets/fakedata/dummy.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/assets/fakedata/dummy.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
Binary file not shown.
118 changes: 97 additions & 21 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,113 @@
import PIL
import os
import shutil
import contextlib
import tempfile
import unittest

import mock
import PIL
import torchvision

FAKEDATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'assets', 'fakedata')


@contextlib.contextmanager
def tmp_dir(src=None, **kwargs):
    """Yield the path of a fresh temporary directory, removed on exit.

    If *src* is given, the temporary directory is a full copy of that
    directory tree instead of an empty one.  Extra keyword arguments are
    forwarded to ``tempfile.mkdtemp``.
    """
    path = tempfile.mkdtemp(**kwargs)
    if src is not None:
        # copytree insists that the destination does not exist yet, so
        # drop the freshly created (empty) directory first.
        os.rmdir(path)
        shutil.copytree(src, path)
    try:
        yield path
    finally:
        shutil.rmtree(path)


class Tester(unittest.TestCase):
    """Smoke tests for the torchvision dataset classes.

    The MNIST-family tests download the real archives and therefore need
    network access; the ImageFolder/ImageNet tests run against the fake
    data checked in under ``test/assets/fakedata``.
    """

    def test_imagefolder(self):
        with tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root:
            classes = sorted(['a', 'b'])
            class_a_image_files = [os.path.join(root, 'a', file)
                                   for file in ('a1.png', 'a2.png', 'a3.png')]
            class_b_image_files = [os.path.join(root, 'b', file)
                                   for file in ('b1.png', 'b2.png', 'b3.png', 'b4.png')]
            # identity loader: dataset items are (path, class_index) pairs,
            # which keeps the assertions below independent of PIL
            dataset = torchvision.datasets.ImageFolder(root, loader=lambda x: x)

            # test if all classes are present
            self.assertEqual(classes, sorted(dataset.classes))

            # test if combination of classes and class_to_idx functions correctly
            for cls in classes:
                self.assertEqual(cls, dataset.classes[dataset.class_to_idx[cls]])

            # test if all images were detected correctly
            class_a_idx = dataset.class_to_idx['a']
            class_b_idx = dataset.class_to_idx['b']
            imgs_a = [(img_file, class_a_idx) for img_file in class_a_image_files]
            imgs_b = [(img_file, class_b_idx) for img_file in class_b_image_files]
            imgs = sorted(imgs_a + imgs_b)
            self.assertEqual(imgs, dataset.imgs)

            # test if the dataset outputs all images correctly
            outputs = sorted([dataset[i] for i in range(len(dataset))])
            self.assertEqual(imgs, outputs)

            # redo all tests with specified valid image files
            dataset = torchvision.datasets.ImageFolder(root, loader=lambda x: x,
                                                       is_valid_file=lambda x: '3' in x)
            self.assertEqual(classes, sorted(dataset.classes))

            class_a_idx = dataset.class_to_idx['a']
            class_b_idx = dataset.class_to_idx['b']
            imgs_a = [(img_file, class_a_idx) for img_file in class_a_image_files
                      if '3' in img_file]
            imgs_b = [(img_file, class_b_idx) for img_file in class_b_image_files
                      if '3' in img_file]
            imgs = sorted(imgs_a + imgs_b)
            self.assertEqual(imgs, dataset.imgs)

            outputs = sorted([dataset[i] for i in range(len(dataset))])
            self.assertEqual(imgs, outputs)

    def test_mnist(self):
        # NOTE(review): downloads the real MNIST archive — requires network.
        with tmp_dir() as root:
            dataset = torchvision.datasets.MNIST(root, download=True)
            self.assertEqual(len(dataset), 60000)
            img, target = dataset[0]
            self.assertTrue(isinstance(img, PIL.Image.Image))
            self.assertTrue(isinstance(target, int))

    def test_kmnist(self):
        with tmp_dir() as root:
            dataset = torchvision.datasets.KMNIST(root, download=True)
            img, target = dataset[0]
            self.assertTrue(isinstance(img, PIL.Image.Image))
            self.assertTrue(isinstance(target, int))

    def test_fashionmnist(self):
        with tmp_dir() as root:
            dataset = torchvision.datasets.FashionMNIST(root, download=True)
            img, target = dataset[0]
            self.assertTrue(isinstance(img, PIL.Image.Image))
            self.assertTrue(isinstance(target, int))

    @mock.patch('torchvision.datasets.utils.download_url')
    def test_imagenet(self, mock_download):
        # download_url is patched out, so download=True only triggers the
        # local extraction of the fake archives copied into `root`.
        with tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagenet')) as root:
            dataset = torchvision.datasets.ImageNet(root, split='train', download=True)
            self.assertEqual(len(dataset), 3)
            img, target = dataset[0]
            self.assertTrue(isinstance(img, PIL.Image.Image))
            self.assertTrue(isinstance(target, int))
            self.assertEqual(dataset.class_to_idx['Tinca tinca'], target)

            dataset = torchvision.datasets.ImageNet(root, split='val', download=True)
            self.assertEqual(len(dataset), 3)
            img, target = dataset[0]
            self.assertTrue(isinstance(img, PIL.Image.Image))
            self.assertTrue(isinstance(target, int))
            self.assertEqual(dataset.class_to_idx['Tinca tinca'], target)


if __name__ == '__main__':
Expand Down
72 changes: 72 additions & 0 deletions test/test_datasets_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import os
import shutil
import contextlib
import tempfile
import unittest
from torchvision.datasets import ImageFolder

FAKEDATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'assets', 'fakedata')


@contextlib.contextmanager
def tmp_dir(src=None, **kwargs):
    """Context manager producing a throwaway directory.

    With ``src`` set, the directory starts out as a copy of that tree;
    otherwise it is empty.  It is deleted when the block exits, even on
    error.  Keyword arguments go to ``tempfile.mkdtemp``.
    """
    new_root = tempfile.mkdtemp(**kwargs)
    try:
        if src is not None:
            # The destination must not exist for copytree, so remove the
            # empty directory mkdtemp just made before copying.
            os.rmdir(new_root)
            shutil.copytree(src, new_root)
        yield new_root
    finally:
        shutil.rmtree(new_root)


def mock_transform(return_value, arg_list):
    """Build a stub transform for tests.

    The returned callable appends every argument it receives to
    *arg_list* (so the caller can inspect what was passed) and always
    returns the fixed *return_value*.
    """
    def _record_and_return(value):
        arg_list.append(value)
        return return_value
    return _record_and_return


class Tester(unittest.TestCase):
    """Checks that ImageFolder applies transform and target_transform."""

    def test_transform(self):
        with tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root:
            class_a_image_files = [os.path.join(root, 'a', file)
                                   for file in ('a1.png', 'a2.png', 'a3.png')]
            class_b_image_files = [os.path.join(root, 'b', file)
                                   for file in ('b1.png', 'b2.png', 'b3.png', 'b4.png')]
            return_value = os.path.join(root, 'a', 'a1.png')
            # The stub transform records every sample it is handed.
            args = []
            dataset = ImageFolder(root, loader=lambda x: x,
                                  transform=mock_transform(return_value, args))

            # every sample must come back as the transform's fixed value
            outputs = [dataset[i][0] for i in range(len(dataset))]
            self.assertEqual([return_value] * len(outputs), outputs)

            # and the transform must have seen each image path exactly once
            expected_inputs = sorted(class_a_image_files + class_b_image_files)
            self.assertEqual(expected_inputs, sorted(args))

    def test_target_transform(self):
        with tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root:
            class_a_image_files = [os.path.join(root, 'a', file)
                                   for file in ('a1.png', 'a2.png', 'a3.png')]
            class_b_image_files = [os.path.join(root, 'b', file)
                                   for file in ('b1.png', 'b2.png', 'b3.png', 'b4.png')]
            return_value = os.path.join(root, 'a', 'a1.png')
            args = []
            dataset = ImageFolder(root, loader=lambda x: x,
                                  target_transform=mock_transform(return_value, args))

            # every target must come back as the transform's fixed value
            outputs = [dataset[i][1] for i in range(len(dataset))]
            self.assertEqual([return_value] * len(outputs), outputs)

            # and the transform must have seen each class index once per image
            class_a_idx = dataset.class_to_idx['a']
            class_b_idx = dataset.class_to_idx['b']
            expected_targets = sorted([class_a_idx] * len(class_a_image_files) +
                                      [class_b_idx] * len(class_b_image_files))
            self.assertEqual(expected_targets, sorted(args))


# Allow running this test module directly (e.g. `python test_datasets_transforms.py`).
if __name__ == '__main__':
    unittest.main()
6 changes: 3 additions & 3 deletions test/test_datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_extract_zip(self):
with tempfile.NamedTemporaryFile(suffix='.zip') as f:
with zipfile.ZipFile(f, 'w') as zf:
zf.writestr('file.tst', 'this is the content')
utils.extract_file(f.name, temp_dir)
utils.extract_archive(f.name, temp_dir)
assert os.path.exists(os.path.join(temp_dir, 'file.tst'))
with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf:
data = nf.read()
Expand All @@ -65,7 +65,7 @@ def test_extract_tar(self):
with tempfile.NamedTemporaryFile(suffix=ext) as f:
with tarfile.open(f.name, mode=mode) as zf:
zf.add(bf.name, arcname='file.tst')
utils.extract_file(f.name, temp_dir)
utils.extract_archive(f.name, temp_dir)
assert os.path.exists(os.path.join(temp_dir, 'file.tst'))
with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf:
data = nf.read()
Expand All @@ -77,7 +77,7 @@ def test_extract_gzip(self):
with tempfile.NamedTemporaryFile(suffix='.gz') as f:
with gzip.GzipFile(f.name, 'wb') as zf:
zf.write('this is the content'.encode())
utils.extract_file(f.name, temp_dir)
utils.extract_archive(f.name, temp_dir)
f_name = os.path.join(temp_dir, os.path.splitext(os.path.basename(f.name))[0])
assert os.path.exists(f_name)
with open(os.path.join(f_name), 'r') as nf:
Expand Down
84 changes: 0 additions & 84 deletions test/test_folder.py

This file was deleted.

8 changes: 4 additions & 4 deletions torchvision/datasets/caltech.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os.path

from .vision import VisionDataset
from .utils import download_and_extract, makedir_exist_ok
from .utils import download_and_extract_archive, makedir_exist_ok


class Caltech101(VisionDataset):
Expand Down Expand Up @@ -113,12 +113,12 @@ def download(self):
print('Files already downloaded and verified')
return

download_and_extract(
download_and_extract_archive(
"http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
self.root,
"101_ObjectCategories.tar.gz",
"b224c7392d521a49829488ab0f1120d9")
download_and_extract(
download_and_extract_archive(
"http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
self.root,
"101_Annotations.tar",
Expand Down Expand Up @@ -201,7 +201,7 @@ def download(self):
print('Files already downloaded and verified')
return

download_and_extract(
download_and_extract_archive(
"http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
self.root,
"256_ObjectCategories.tar",
Expand Down
4 changes: 2 additions & 2 deletions torchvision/datasets/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pickle

from .vision import VisionDataset
from .utils import check_integrity, download_and_extract
from .utils import check_integrity, download_and_extract_archive


class CIFAR10(VisionDataset):
Expand Down Expand Up @@ -147,7 +147,7 @@ def download(self):
if self._check_integrity():
print('Files already downloaded and verified')
return
download_and_extract(self.url, self.root, self.filename, self.tgz_md5)
download_and_extract_archive(self.url, self.root, self.filename, self.tgz_md5)

def extra_repr(self):
return "Split: {}".format("Train" if self.train is True else "Test")
Expand Down
Loading