Add tests for the STL10 dataset #3345

Merged (5 commits, Feb 4, 2021)

165 changes: 129 additions & 36 deletions test/fakedata_generation.py
@@ -13,6 +13,48 @@
import unittest.mock
import hashlib
from distutils import dir_util
import re


def mock_class_attribute(stack, target, new):
mock = unittest.mock.patch(target, new_callable=unittest.mock.PropertyMock, return_value=new)
stack.enter_context(mock)
return mock


def compute_md5(file):
with open(file, "rb") as fh:
return hashlib.md5(fh.read()).hexdigest()


def make_tar(root, name, *files, compression=None):
ext = ".tar"
mode = "w"
if compression is not None:
ext = f"{ext}.{compression}"
mode = f"{mode}:{compression}"

name = os.path.splitext(name)[0] + ext
archive = os.path.join(root, name)

with tarfile.open(archive, mode) as fh:
for file in files:
fh.add(os.path.join(root, file), arcname=file)

return name, compute_md5(archive)


def clean_dir(root, *keep):
pattern = re.compile(f"({f')|('.join(keep)})")
for file_or_dir in os.listdir(root):
if pattern.search(file_or_dir):
continue

file_or_dir = os.path.join(root, file_or_dir)
if os.path.isfile(file_or_dir):
os.remove(file_or_dir)
else:
dir_util.remove_tree(file_or_dir)


@contextlib.contextmanager
@@ -385,7 +427,7 @@ def ucf101_root():


@contextlib.contextmanager
def places365_root(split="train-standard", small=False, extract_images=True):
def places365_root(split="train-standard", small=False):
VARIANTS = {
"train-standard": "standard",
"train-challenge": "challenge",
@@ -425,15 +467,6 @@ def places365_root(split="train-standard", small=False, extract_images=True):
def mock_target(attr, partial="torchvision.datasets.places365.Places365"):
return f"{partial}.{attr}"

def mock_class_attribute(stack, attr, new):
mock = unittest.mock.patch(mock_target(attr), new_callable=unittest.mock.PropertyMock, return_value=new)
stack.enter_context(mock)
return mock

def compute_md5(file):
with open(file, "rb") as fh:
return hashlib.md5(fh.read()).hexdigest()

def make_txt(root, name, seq):
file = os.path.join(root, name)
with open(file, "w") as fh:
@@ -451,37 +484,20 @@ def make_image(file, size):
os.makedirs(os.path.dirname(file), exist_ok=True)
PIL.Image.fromarray(np.zeros((*size, 3), dtype=np.uint8)).save(file)

def make_tar(root, name, *files, remove_files=True):
name = f"{os.path.splitext(name)[0]}.tar"
archive = os.path.join(root, name)

with tarfile.open(archive, "w") as fh:
for file in files:
fh.add(os.path.join(root, file), arcname=file)

if remove_files:
for file in [os.path.join(root, file) for file in files]:
if os.path.isdir(file):
dir_util.remove_tree(file)
else:
os.remove(file)

return name, compute_md5(archive)

def make_devkit_archive(stack, root, split):
archive = DEVKITS[split]
files = []

meta = make_categories_txt(root, CATEGORIES)
mock_class_attribute(stack, "_CATEGORIES_META", meta)
mock_class_attribute(stack, mock_target("_CATEGORIES_META"), meta)
files.append(meta[0])

meta = {split: make_file_list_txt(root, FILE_LISTS[split])}
mock_class_attribute(stack, "_FILE_LIST_META", meta)
mock_class_attribute(stack, mock_target("_FILE_LIST_META"), meta)
files.extend([item[0] for item in meta.values()])

meta = {VARIANTS[split]: make_tar(root, archive, *files)}
mock_class_attribute(stack, "_DEVKIT_META", meta)
mock_class_attribute(stack, mock_target("_DEVKIT_META"), meta)

def make_images_archive(stack, root, split, small):
archive, folder_default, folder_renamed = IMAGES[(split, small)]
@@ -493,20 +509,97 @@ def make_images_archive(stack, root, split, small):
make_image(os.path.join(root, folder_default, image), image_size)

meta = {(split, small): make_tar(root, archive, folder_default)}
mock_class_attribute(stack, "_IMAGES_META", meta)
mock_class_attribute(stack, mock_target("_IMAGES_META"), meta)

return [(os.path.join(root, folder_renamed, image), idx) for image, idx in zip(images, idcs)]

with contextlib.ExitStack() as stack, get_tmp_dir() as root:
make_devkit_archive(stack, root, split)
class_to_idx = dict(CATEGORIES_CONTENT)
classes = list(class_to_idx.keys())

data = {"class_to_idx": class_to_idx, "classes": classes}
data["imgs"] = make_images_archive(stack, root, split, small)

if extract_images:
data["imgs"] = make_images_archive(stack, root, split, small)
else:
stack.enter_context(unittest.mock.patch(mock_target("download_images")))
data["imgs"] = None
clean_dir(root, ".tar$")

yield root, data


@contextlib.contextmanager
def stl10_root(_extracted=False):
CLASS_NAMES = ("airplane", "bird")
ARCHIVE_NAME = "stl10_binary"
NUM_FOLDS = 10

def mock_target(attr, partial="torchvision.datasets.stl10.STL10"):
return f"{partial}.{attr}"

def make_binary_file(num_elements, root, name):
file = os.path.join(root, name)
np.zeros(num_elements, dtype=np.uint8).tofile(file)
return name, compute_md5(file)

def make_image_file(num_images, root, name, num_channels=3, height=96, width=96):
return make_binary_file(num_images * num_channels * height * width, root, name)

def make_label_file(num_images, root, name):
return make_binary_file(num_images, root, name)

def make_class_names_file(root, name="class_names.txt"):
with open(os.path.join(root, name), "w") as fh:
for name in CLASS_NAMES:
fh.write(f"{name}\n")

def make_fold_indices_file(root):
offset = 0
with open(os.path.join(root, "fold_indices.txt"), "w") as fh:
for fold in range(NUM_FOLDS):
line = " ".join([str(idx) for idx in range(offset, offset + fold + 1)])
fh.write(f"{line}\n")
offset += fold + 1

return tuple(range(1, NUM_FOLDS + 1))

def make_train_files(stack, root, num_unlabeled_images=1):
num_images_in_fold = make_fold_indices_file(root)
num_train_images = sum(num_images_in_fold)

train_list = [
list(make_image_file(num_train_images, root, "train_X.bin")),
list(make_label_file(num_train_images, root, "train_y.bin")),
list(make_image_file(1, root, "unlabeled_X.bin"))
]
mock_class_attribute(stack, target=mock_target("train_list"), new=train_list)

return num_images_in_fold, dict(train=num_train_images, unlabeled=num_unlabeled_images)

def make_test_files(stack, root, num_images=2):
test_list = [
list(make_image_file(num_images, root, "test_X.bin")),
list(make_label_file(num_images, root, "test_y.bin")),
]
mock_class_attribute(stack, target=mock_target("test_list"), new=test_list)

return dict(test=num_images)

def make_archive(stack, root, name):
archive, md5 = make_tar(root, name, name, compression="gz")
mock_class_attribute(stack, target=mock_target("tgz_md5"), new=md5)
return archive

with contextlib.ExitStack() as stack, get_tmp_dir() as root:
archive_folder = os.path.join(root, ARCHIVE_NAME)
os.mkdir(archive_folder)

num_images_in_folds, num_images_in_split = make_train_files(stack, archive_folder)
num_images_in_split.update(make_test_files(stack, archive_folder))

make_class_names_file(archive_folder)

archive = make_archive(stack, root, ARCHIVE_NAME)

dir_util.remove_tree(archive_folder)
data = dict(num_images_in_folds=num_images_in_folds, num_images_in_split=num_images_in_split, archive=archive)

yield root, data
84 changes: 80 additions & 4 deletions test/test_datasets.py
@@ -1,3 +1,4 @@
import contextlib
import sys
import os
import unittest
@@ -7,9 +8,10 @@
from PIL import Image
from torch._utils_internal import get_file_path_2
import torchvision
from torchvision.datasets import utils
from common_utils import get_tmp_dir
from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root, widerface_root
cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root, widerface_root, stl10_root
import xml.etree.ElementTree as ET
from urllib.request import Request, urlopen
import itertools
@@ -28,7 +30,7 @@
HAS_PYAV = False


class Tester(unittest.TestCase):
class DatasetTestcase(unittest.TestCase):
Review comment (Collaborator Author):

This was necessary since STL10Tester needs access to generic_classification_dataset_test without subclassing Tester directly; otherwise, running STL10Tester would also run all of the inherited tests.
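As a minimal sketch of the point above (simplified, hypothetical class bodies; not part of this diff), unittest discovers every test_* method a class inherits, so the shared helpers live in a base class that defines no tests of its own:

import unittest


class DatasetTestcase(unittest.TestCase):
    # Shared helpers only; no test_* methods, so subclasses inherit nothing runnable.
    def generic_classification_dataset_test(self, dataset, num_images=1):
        self.assertEqual(len(dataset), num_images)


class Tester(DatasetTestcase):
    def test_imagefolder(self):  # collected only when Tester runs
        ...


class STL10Tester(DatasetTestcase):
    def test_splits(self):  # collected only when STL10Tester runs; Tester's tests are not re-run
        ...
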

def generic_classification_dataset_test(self, dataset, num_images=1):
self.assertEqual(len(dataset), num_images)
img, target = dataset[0]
@@ -41,6 +43,8 @@ def generic_segmentation_dataset_test(self, dataset, num_images=1):
self.assertTrue(isinstance(img, PIL.Image.Image))
self.assertTrue(isinstance(target, PIL.Image.Image))


class Tester(DatasetTestcase):
def test_imagefolder(self):
# TODO: create the fake data on-the-fly
FAKEDATA_DIR = get_file_path_2(
@@ -354,7 +358,7 @@ def test_places365_devkit_download(self):
def test_places365_devkit_no_download(self):
for split in ("train-standard", "train-challenge", "val"):
with self.subTest(split=split):
with places365_root(split=split, extract_images=False) as places365:
with places365_root(split=split) as places365:
root, data = places365

with self.assertRaises(RuntimeError):
@@ -383,12 +387,84 @@ def test_places365_images_download_preexisting(self):
torchvision.datasets.Places365(root, split=split, small=small, download=True)

def test_places365_repr_smoke(self):
with places365_root(extract_images=False) as places365:
with places365_root() as places365:
root, data = places365

dataset = torchvision.datasets.Places365(root, download=True)
self.assertIsInstance(repr(dataset), str)


class STL10Tester(DatasetTestcase):
Review comment (Collaborator Author):

After adding the tests for Places365, we agreed to create a separate test case for a dataset whenever more than the generic checks are exercised.
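As an illustrative usage note (not part of this diff), a dedicated test case like this can be run in isolation; the snippet below is equivalent to invoking python -m unittest test_datasets.STL10Tester -v from the test directory:

import unittest

# Load and run only the STL10 test case, leaving the other dataset tests untouched.
suite = unittest.defaultTestLoader.loadTestsFromName("test_datasets.STL10Tester")
unittest.TextTestRunner(verbosity=2).run(suite)
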

@contextlib.contextmanager
def mocked_root(self):
with stl10_root() as (root, data):
yield root, data

@contextlib.contextmanager
def mocked_dataset(self, pre_extract=False, download=True, **kwargs):
with self.mocked_root() as (root, data):
if pre_extract:
utils.extract_archive(os.path.join(root, data["archive"]))
dataset = torchvision.datasets.STL10(root, download=download, **kwargs)
yield dataset, data

def test_not_found(self):
with self.assertRaises(RuntimeError):
with self.mocked_dataset(download=False):
pass

def test_splits(self):
for split in ('train', 'train+unlabeled', 'unlabeled', 'test'):
with self.mocked_dataset(split=split) as (dataset, data):
num_images = sum([data["num_images_in_split"][part] for part in split.split("+")])
self.generic_classification_dataset_test(dataset, num_images=num_images)

def test_folds(self):
for fold in range(10):
with self.mocked_dataset(split="train", folds=fold) as (dataset, data):
num_images = data["num_images_in_folds"][fold]
self.assertEqual(len(dataset), num_images)

def test_invalid_folds1(self):
with self.assertRaises(ValueError):
with self.mocked_dataset(folds=10):
pass

def test_invalid_folds2(self):
with self.assertRaises(ValueError):
with self.mocked_dataset(folds="0"):
pass

def test_transforms(self):
expected_image = "image"
expected_target = "target"

def transform(image):
return expected_image

def target_transform(target):
return expected_target

with self.mocked_dataset(transform=transform, target_transform=target_transform) as (dataset, _):
actual_image, actual_target = dataset[0]

self.assertEqual(actual_image, expected_image)
self.assertEqual(actual_target, expected_target)

def test_unlabeled(self):
with self.mocked_dataset(split="unlabeled") as (dataset, _):
labels = [dataset[idx][1] for idx in range(len(dataset))]
self.assertTrue(all([label == -1 for label in labels]))

@unittest.mock.patch("torchvision.datasets.stl10.download_and_extract_archive")
def test_download_preexisting(self, mock):
with self.mocked_dataset(pre_extract=True) as (dataset, data):
mock.assert_not_called()

def test_repr_smoke(self):
with self.mocked_dataset() as (dataset, _):
self.assertIsInstance(repr(dataset), str)


if __name__ == '__main__':
unittest.main()