Skip to content

Commit 1dbc6f9

Browse files
datumboxfmassa
authored andcommitted
Add tests for the STL10 dataset (#3345)
Summary: * extract some functionality from places365 fakedata for common use * add a common DatasetTestcase * add fakedata generation and tests for STL10 * lint Reviewed By: fmassa Differential Revision: D26341418 fbshipit-source-id: 05f8a60c986c32f64339197ea377efc6c4d5b238 Co-authored-by: Francisco Massa <[email protected]>
1 parent 9df3022 commit 1dbc6f9

File tree

2 files changed

+209
-40
lines changed

2 files changed

+209
-40
lines changed

test/fakedata_generation.py

Lines changed: 129 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,48 @@
1313
import unittest.mock
1414
import hashlib
1515
from distutils import dir_util
16+
import re
17+
18+
19+
def mock_class_attribute(stack, target, new):
20+
mock = unittest.mock.patch(target, new_callable=unittest.mock.PropertyMock, return_value=new)
21+
stack.enter_context(mock)
22+
return mock
23+
24+
25+
def compute_md5(file):
26+
with open(file, "rb") as fh:
27+
return hashlib.md5(fh.read()).hexdigest()
28+
29+
30+
def make_tar(root, name, *files, compression=None):
31+
ext = ".tar"
32+
mode = "w"
33+
if compression is not None:
34+
ext = f"{ext}.{compression}"
35+
mode = f"{mode}:{compression}"
36+
37+
name = os.path.splitext(name)[0] + ext
38+
archive = os.path.join(root, name)
39+
40+
with tarfile.open(archive, mode) as fh:
41+
for file in files:
42+
fh.add(os.path.join(root, file), arcname=file)
43+
44+
return name, compute_md5(archive)
45+
46+
47+
def clean_dir(root, *keep):
48+
pattern = re.compile(f"({f')|('.join(keep)})")
49+
for file_or_dir in os.listdir(root):
50+
if pattern.search(file_or_dir):
51+
continue
52+
53+
file_or_dir = os.path.join(root, file_or_dir)
54+
if os.path.isfile(file_or_dir):
55+
os.remove(file_or_dir)
56+
else:
57+
dir_util.remove_tree(file_or_dir)
1658

1759

1860
@contextlib.contextmanager
@@ -385,7 +427,7 @@ def ucf101_root():
385427

386428

387429
@contextlib.contextmanager
388-
def places365_root(split="train-standard", small=False, extract_images=True):
430+
def places365_root(split="train-standard", small=False):
389431
VARIANTS = {
390432
"train-standard": "standard",
391433
"train-challenge": "challenge",
@@ -425,15 +467,6 @@ def places365_root(split="train-standard", small=False, extract_images=True):
425467
def mock_target(attr, partial="torchvision.datasets.places365.Places365"):
426468
return f"{partial}.{attr}"
427469

428-
def mock_class_attribute(stack, attr, new):
429-
mock = unittest.mock.patch(mock_target(attr), new_callable=unittest.mock.PropertyMock, return_value=new)
430-
stack.enter_context(mock)
431-
return mock
432-
433-
def compute_md5(file):
434-
with open(file, "rb") as fh:
435-
return hashlib.md5(fh.read()).hexdigest()
436-
437470
def make_txt(root, name, seq):
438471
file = os.path.join(root, name)
439472
with open(file, "w") as fh:
@@ -451,37 +484,20 @@ def make_image(file, size):
451484
os.makedirs(os.path.dirname(file), exist_ok=True)
452485
PIL.Image.fromarray(np.zeros((*size, 3), dtype=np.uint8)).save(file)
453486

454-
def make_tar(root, name, *files, remove_files=True):
455-
name = f"{os.path.splitext(name)[0]}.tar"
456-
archive = os.path.join(root, name)
457-
458-
with tarfile.open(archive, "w") as fh:
459-
for file in files:
460-
fh.add(os.path.join(root, file), arcname=file)
461-
462-
if remove_files:
463-
for file in [os.path.join(root, file) for file in files]:
464-
if os.path.isdir(file):
465-
dir_util.remove_tree(file)
466-
else:
467-
os.remove(file)
468-
469-
return name, compute_md5(archive)
470-
471487
def make_devkit_archive(stack, root, split):
472488
archive = DEVKITS[split]
473489
files = []
474490

475491
meta = make_categories_txt(root, CATEGORIES)
476-
mock_class_attribute(stack, "_CATEGORIES_META", meta)
492+
mock_class_attribute(stack, mock_target("_CATEGORIES_META"), meta)
477493
files.append(meta[0])
478494

479495
meta = {split: make_file_list_txt(root, FILE_LISTS[split])}
480-
mock_class_attribute(stack, "_FILE_LIST_META", meta)
496+
mock_class_attribute(stack, mock_target("_FILE_LIST_META"), meta)
481497
files.extend([item[0] for item in meta.values()])
482498

483499
meta = {VARIANTS[split]: make_tar(root, archive, *files)}
484-
mock_class_attribute(stack, "_DEVKIT_META", meta)
500+
mock_class_attribute(stack, mock_target("_DEVKIT_META"), meta)
485501

486502
def make_images_archive(stack, root, split, small):
487503
archive, folder_default, folder_renamed = IMAGES[(split, small)]
@@ -493,20 +509,97 @@ def make_images_archive(stack, root, split, small):
493509
make_image(os.path.join(root, folder_default, image), image_size)
494510

495511
meta = {(split, small): make_tar(root, archive, folder_default)}
496-
mock_class_attribute(stack, "_IMAGES_META", meta)
512+
mock_class_attribute(stack, mock_target("_IMAGES_META"), meta)
497513

498514
return [(os.path.join(root, folder_renamed, image), idx) for image, idx in zip(images, idcs)]
499515

500516
with contextlib.ExitStack() as stack, get_tmp_dir() as root:
501517
make_devkit_archive(stack, root, split)
502518
class_to_idx = dict(CATEGORIES_CONTENT)
503519
classes = list(class_to_idx.keys())
520+
504521
data = {"class_to_idx": class_to_idx, "classes": classes}
522+
data["imgs"] = make_images_archive(stack, root, split, small)
505523

506-
if extract_images:
507-
data["imgs"] = make_images_archive(stack, root, split, small)
508-
else:
509-
stack.enter_context(unittest.mock.patch(mock_target("download_images")))
510-
data["imgs"] = None
524+
clean_dir(root, ".tar$")
525+
526+
yield root, data
527+
528+
529+
@contextlib.contextmanager
530+
def stl10_root(_extracted=False):
531+
CLASS_NAMES = ("airplane", "bird")
532+
ARCHIVE_NAME = "stl10_binary"
533+
NUM_FOLDS = 10
534+
535+
def mock_target(attr, partial="torchvision.datasets.stl10.STL10"):
536+
return f"{partial}.{attr}"
537+
538+
def make_binary_file(num_elements, root, name):
539+
file = os.path.join(root, name)
540+
np.zeros(num_elements, dtype=np.uint8).tofile(file)
541+
return name, compute_md5(file)
542+
543+
def make_image_file(num_images, root, name, num_channels=3, height=96, width=96):
544+
return make_binary_file(num_images * num_channels * height * width, root, name)
545+
546+
def make_label_file(num_images, root, name):
547+
return make_binary_file(num_images, root, name)
548+
549+
def make_class_names_file(root, name="class_names.txt"):
550+
with open(os.path.join(root, name), "w") as fh:
551+
for name in CLASS_NAMES:
552+
fh.write(f"{name}\n")
553+
554+
def make_fold_indices_file(root):
555+
offset = 0
556+
with open(os.path.join(root, "fold_indices.txt"), "w") as fh:
557+
for fold in range(NUM_FOLDS):
558+
line = " ".join([str(idx) for idx in range(offset, offset + fold + 1)])
559+
fh.write(f"{line}\n")
560+
offset += fold + 1
561+
562+
return tuple(range(1, NUM_FOLDS + 1))
563+
564+
def make_train_files(stack, root, num_unlabeled_images=1):
565+
num_images_in_fold = make_fold_indices_file(root)
566+
num_train_images = sum(num_images_in_fold)
567+
568+
train_list = [
569+
list(make_image_file(num_train_images, root, "train_X.bin")),
570+
list(make_label_file(num_train_images, root, "train_y.bin")),
571+
list(make_image_file(1, root, "unlabeled_X.bin"))
572+
]
573+
mock_class_attribute(stack, target=mock_target("train_list"), new=train_list)
574+
575+
return num_images_in_fold, dict(train=num_train_images, unlabeled=num_unlabeled_images)
576+
577+
def make_test_files(stack, root, num_images=2):
578+
test_list = [
579+
list(make_image_file(num_images, root, "test_X.bin")),
580+
list(make_label_file(num_images, root, "test_y.bin")),
581+
]
582+
mock_class_attribute(stack, target=mock_target("test_list"), new=test_list)
583+
584+
return dict(test=num_images)
585+
586+
def make_archive(stack, root, name):
587+
archive, md5 = make_tar(root, name, name, compression="gz")
588+
mock_class_attribute(stack, target=mock_target("tgz_md5"), new=md5)
589+
return archive
590+
591+
with contextlib.ExitStack() as stack, get_tmp_dir() as root:
592+
archive_folder = os.path.join(root, ARCHIVE_NAME)
593+
os.mkdir(archive_folder)
594+
595+
num_images_in_folds, num_images_in_split = make_train_files(stack, archive_folder)
596+
num_images_in_split.update(make_test_files(stack, archive_folder))
597+
598+
make_class_names_file(archive_folder)
599+
600+
archive = make_archive(stack, root, ARCHIVE_NAME)
601+
602+
dir_util.remove_tree(archive_folder)
603+
data = dict(num_images_in_folds=num_images_in_folds, num_images_in_split=num_images_in_split, archive=archive)
511604

512605
yield root, data

test/test_datasets.py

Lines changed: 80 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import contextlib
12
import sys
23
import os
34
import unittest
@@ -7,9 +8,10 @@
78
from PIL import Image
89
from torch._utils_internal import get_file_path_2
910
import torchvision
11+
from torchvision.datasets import utils
1012
from common_utils import get_tmp_dir
1113
from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
12-
cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root, widerface_root
14+
cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root, widerface_root, stl10_root
1315
import xml.etree.ElementTree as ET
1416
from urllib.request import Request, urlopen
1517
import itertools
@@ -28,7 +30,7 @@
2830
HAS_PYAV = False
2931

3032

31-
class Tester(unittest.TestCase):
33+
class DatasetTestcase(unittest.TestCase):
3234
def generic_classification_dataset_test(self, dataset, num_images=1):
3335
self.assertEqual(len(dataset), num_images)
3436
img, target = dataset[0]
@@ -41,6 +43,8 @@ def generic_segmentation_dataset_test(self, dataset, num_images=1):
4143
self.assertTrue(isinstance(img, PIL.Image.Image))
4244
self.assertTrue(isinstance(target, PIL.Image.Image))
4345

46+
47+
class Tester(DatasetTestcase):
4448
def test_imagefolder(self):
4549
# TODO: create the fake data on-the-fly
4650
FAKEDATA_DIR = get_file_path_2(
@@ -354,7 +358,7 @@ def test_places365_devkit_download(self):
354358
def test_places365_devkit_no_download(self):
355359
for split in ("train-standard", "train-challenge", "val"):
356360
with self.subTest(split=split):
357-
with places365_root(split=split, extract_images=False) as places365:
361+
with places365_root(split=split) as places365:
358362
root, data = places365
359363

360364
with self.assertRaises(RuntimeError):
@@ -383,12 +387,84 @@ def test_places365_images_download_preexisting(self):
383387
torchvision.datasets.Places365(root, split=split, small=small, download=True)
384388

385389
def test_places365_repr_smoke(self):
386-
with places365_root(extract_images=False) as places365:
390+
with places365_root() as places365:
387391
root, data = places365
388392

389393
dataset = torchvision.datasets.Places365(root, download=True)
390394
self.assertIsInstance(repr(dataset), str)
391395

392396

397+
class STL10Tester(DatasetTestcase):
398+
@contextlib.contextmanager
399+
def mocked_root(self):
400+
with stl10_root() as (root, data):
401+
yield root, data
402+
403+
@contextlib.contextmanager
404+
def mocked_dataset(self, pre_extract=False, download=True, **kwargs):
405+
with self.mocked_root() as (root, data):
406+
if pre_extract:
407+
utils.extract_archive(os.path.join(root, data["archive"]))
408+
dataset = torchvision.datasets.STL10(root, download=download, **kwargs)
409+
yield dataset, data
410+
411+
def test_not_found(self):
412+
with self.assertRaises(RuntimeError):
413+
with self.mocked_dataset(download=False):
414+
pass
415+
416+
def test_splits(self):
417+
for split in ('train', 'train+unlabeled', 'unlabeled', 'test'):
418+
with self.mocked_dataset(split=split) as (dataset, data):
419+
num_images = sum([data["num_images_in_split"][part] for part in split.split("+")])
420+
self.generic_classification_dataset_test(dataset, num_images=num_images)
421+
422+
def test_folds(self):
423+
for fold in range(10):
424+
with self.mocked_dataset(split="train", folds=fold) as (dataset, data):
425+
num_images = data["num_images_in_folds"][fold]
426+
self.assertEqual(len(dataset), num_images)
427+
428+
def test_invalid_folds1(self):
429+
with self.assertRaises(ValueError):
430+
with self.mocked_dataset(folds=10):
431+
pass
432+
433+
def test_invalid_folds2(self):
434+
with self.assertRaises(ValueError):
435+
with self.mocked_dataset(folds="0"):
436+
pass
437+
438+
def test_transforms(self):
439+
expected_image = "image"
440+
expected_target = "target"
441+
442+
def transform(image):
443+
return expected_image
444+
445+
def target_transform(target):
446+
return expected_target
447+
448+
with self.mocked_dataset(transform=transform, target_transform=target_transform) as (dataset, _):
449+
actual_image, actual_target = dataset[0]
450+
451+
self.assertEqual(actual_image, expected_image)
452+
self.assertEqual(actual_target, expected_target)
453+
454+
def test_unlabeled(self):
455+
with self.mocked_dataset(split="unlabeled") as (dataset, _):
456+
labels = [dataset[idx][1] for idx in range(len(dataset))]
457+
self.assertTrue(all([label == -1 for label in labels]))
458+
459+
@unittest.mock.patch("torchvision.datasets.stl10.download_and_extract_archive")
460+
def test_download_preexisting(self, mock):
461+
with self.mocked_dataset(pre_extract=True) as (dataset, data):
462+
mock.assert_not_called()
463+
464+
def test_repr_smoke(self):
465+
with self.mocked_dataset() as (dataset, _):
466+
self.assertIsInstance(repr(dataset), str)
467+
468+
393469
if __name__ == '__main__':
394470
unittest.main()

0 commit comments

Comments
 (0)