Commit ac05cc6

add tests for CelebA
1 parent 22c548b

test/test_datasets.py

Lines changed: 118 additions & 0 deletions
@@ -19,6 +19,7 @@
 import pathlib
 import pickle
 from torchvision import datasets
+import torch
 
 
 try:
@@ -560,5 +561,122 @@ class CIFAR100(CIFAR10TestCase):
         )
 
 
+class CelebATestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.CelebA
+    FEATURE_TYPES = (PIL.Image.Image, (torch.Tensor, int, tuple, type(None)))
+
+    CONFIGS = datasets_utils.combinations_grid(
+        split=("train", "valid", "test", "all"),
+        target_type=("attr", "identity", "bbox", "landmarks", ["attr", "identity"]),
+    )
+    REQUIRED_PACKAGES = ("pandas",)
+
+    _SPLIT_TO_IDX = dict(train=0, valid=1, test=2)
+
+    def inject_fake_data(self, tmpdir, config):
+        base_folder = pathlib.Path(tmpdir) / "celeba"
+        os.makedirs(base_folder)
+
+        num_images, num_images_per_split = self._create_split_txt(base_folder)
+
+        datasets_utils.create_image_folder(
+            base_folder, "img_align_celeba", lambda idx: f"{idx + 1:06d}.jpg", num_images
+        )
+        attr_names = self._create_attr_txt(base_folder, num_images)
+        self._create_identity_txt(base_folder, num_images)
+        self._create_bbox_txt(base_folder, num_images)
+        self._create_landmarks_txt(base_folder, num_images)
+
+        return dict(num_examples=num_images_per_split[config["split"]], attr_names=attr_names)
+
+    def _create_split_txt(self, root):
+        num_images_per_split = dict(train=3, valid=2, test=1)
+
+        data = [
+            [self._SPLIT_TO_IDX[split]] for split, num_images in num_images_per_split.items() for _ in range(num_images)
+        ]
+        self._create_txt(root, "list_eval_partition.txt", data)
+
+        num_images_per_split["all"] = num_images = sum(num_images_per_split.values())
+        return num_images, num_images_per_split
+
+    def _create_attr_txt(self, root, num_images):
+        header = ("5_o_Clock_Shadow", "Young")
+        data = torch.rand((num_images, len(header))).ge(0.5).int().mul(2).sub(1).tolist()
+        self._create_txt(root, "list_attr_celeba.txt", data, header=header, add_num_examples=True)
+        return header
+
+    def _create_identity_txt(self, root, num_images):
+        data = torch.randint(1, 4, size=(num_images, 1)).tolist()
+        self._create_txt(root, "identity_CelebA.txt", data)
+
+    def _create_bbox_txt(self, root, num_images):
+        header = ("x_1", "y_1", "width", "height")
+        data = torch.randint(10, size=(num_images, len(header))).tolist()
+        self._create_txt(
+            root, "list_bbox_celeba.txt", data, header=header, add_num_examples=True, add_image_id_to_header=True
+        )
+
+    def _create_landmarks_txt(self, root, num_images):
+        header = ("lefteye_x", "rightmouth_y")
+        data = torch.randint(10, size=(num_images, len(header))).tolist()
+        self._create_txt(root, "list_landmarks_align_celeba.txt", data, header=header, add_num_examples=True)
+
+    def _create_txt(self, root, name, data, header=None, add_num_examples=False, add_image_id_to_header=False):
+        with open(pathlib.Path(root) / name, "w") as fh:
+            if add_num_examples:
+                fh.write(f"{len(data)}\n")
+
+            if header:
+                if add_image_id_to_header:
+                    header = ("image_id", *header)
+                fh.write(f"{' '.join(header)}\n")
+
+            for idx, line in enumerate(data, 1):
+                fh.write(f"{' '.join((f'{idx:06d}.jpg', *[str(value) for value in line]))}\n")
+
+    def test_combined_targets(self):
+        target_types = ["attr", "identity", "bbox", "landmarks"]
+
+        individual_targets = []
+        for target_type in target_types:
+            with self.create_dataset(target_type=target_type) as (dataset, _):
+                _, target = dataset[0]
+                individual_targets.append(target)
+
+        with self.create_dataset(target_type=target_types) as (dataset, _):
+            _, combined_targets = dataset[0]
+
+        actual = len(individual_targets)
+        expected = len(combined_targets)
+        self.assertEqual(
+            actual,
+            expected,
+            f"The number of returned combined targets does not match the number of targets requested "
+            f"individually: {actual} != {expected}",
+        )
+
+        for target_type, combined_target, individual_target in zip(target_types, combined_targets, individual_targets):
+            with self.subTest(target_type=target_type):
+                actual = type(combined_target)
+                expected = type(individual_target)
+                self.assertIs(
+                    actual,
+                    expected,
+                    f"Type of the combined target does not match the type of the corresponding individual target: "
+                    f"{actual} is not {expected}",
+                )
+
+    def test_no_target(self):
+        with self.create_dataset(target_type=[]) as (dataset, _):
+            _, target = dataset[0]
+
+        self.assertIsNone(target)
+
+    def test_attr_names(self):
+        with self.create_dataset() as (dataset, info):
+            self.assertEqual(tuple(dataset.attr_names), info["attr_names"])
+
+
 if __name__ == "__main__":
     unittest.main()
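
For reference, the _create_*_txt helpers above write plain whitespace-separated annotation files into the fake celeba folder. With the "all" split (3 + 2 + 1 = 6 images), _create_attr_txt produces a list_attr_celeba.txt roughly like the sketch below; the -1/1 attribute values are drawn at random on each run, so the ones shown are only illustrative:

    6
    5_o_Clock_Shadow Young
    000001.jpg 1 -1
    000002.jpg -1 1
    000003.jpg 1 1
    000004.jpg -1 -1
    000005.jpg 1 -1
    000006.jpg -1 1

The bbox file follows the same layout but, because add_image_id_to_header=True, its header row starts with image_id, mirroring the column layout of the real CelebA annotation files.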

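Since the module still ends in unittest.main(), the whole suite can be run directly with python test/test_datasets.py. Assuming pytest is installed, a narrower run such as

    pytest test/test_datasets.py -k CelebA

would select only the new CelebA test case.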