Skip to content

Commit 7024490

Browse files
fmassa authored and facebook-github-bot committed
add tests for CelebA (#3413)
Summary: Co-authored-by: Francisco Massa <[email protected]> Reviewed By: NicolasHug Differential Revision: D26605312 fbshipit-source-id: 2731ae896f1c58f1376770e4f31b828b22c0a151
1 parent bef96d7 commit 7024490

File tree

1 file changed

+117
-0
lines changed

1 file changed

+117
-0
lines changed

test/test_datasets.py

Lines changed: 117 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -639,5 +639,122 @@ class CIFAR100(CIFAR10TestCase):
639639
)
640640

641641

642+
class CelebATestCase(datasets_utils.ImageDatasetTestCase):
    """Test case for ``datasets.CelebA`` backed by a small generated dataset.

    ``inject_fake_data`` writes a ``celeba`` folder containing an image folder
    plus the partition, attribute, identity, bounding-box and landmark text
    files, so the dataset can be exercised without the real download.
    """

    DATASET_CLASS = datasets.CelebA
    FEATURE_TYPES = (PIL.Image.Image, (torch.Tensor, int, tuple, type(None)))

    # Every split is combined with every single target type and with one
    # multi-target request.
    CONFIGS = datasets_utils.combinations_grid(
        split=("train", "valid", "test", "all"),
        target_type=("attr", "identity", "bbox", "landmarks", ["attr", "identity"]),
    )
    REQUIRED_PACKAGES = ("pandas",)

    # Integer code written to ``list_eval_partition.txt`` for each named split.
    _SPLIT_TO_IDX = dict(train=0, valid=1, test=2)

    def inject_fake_data(self, tmpdir, config):
        """Create the fake CelebA folder layout below ``tmpdir``.

        Args:
            tmpdir: Directory in which the ``celeba`` base folder is created.
            config: Mapping that must contain the ``"split"`` under test.

        Returns:
            dict: ``num_examples`` (number of images in the configured split)
            and ``attr_names`` (the attribute header written to disk).
        """
        base_folder = pathlib.Path(tmpdir) / "celeba"
        os.makedirs(base_folder)

        num_images, num_images_per_split = self._create_split_txt(base_folder)

        # File names are 1-based and zero-padded to six digits, matching the
        # image ids written by ``_create_txt``.
        datasets_utils.create_image_folder(
            base_folder, "img_align_celeba", lambda idx: f"{idx + 1:06d}.jpg", num_images
        )
        attr_names = self._create_attr_txt(base_folder, num_images)
        self._create_identity_txt(base_folder, num_images)
        self._create_bbox_txt(base_folder, num_images)
        self._create_landmarks_txt(base_folder, num_images)

        return dict(num_examples=num_images_per_split[config["split"]], attr_names=attr_names)

    def _create_split_txt(self, root):
        """Write ``list_eval_partition.txt`` and return the image counts.

        Returns:
            tuple: ``(num_images, num_images_per_split)`` where the mapping
            also contains an ``"all"`` entry equal to the total.
        """
        num_images_per_split = dict(train=3, valid=2, test=1)

        # One single-column row per image, holding the integer split code.
        data = [
            [self._SPLIT_TO_IDX[split]] for split, num_images in num_images_per_split.items() for _ in range(num_images)
        ]
        self._create_txt(root, "list_eval_partition.txt", data)

        # The sum is evaluated before "all" is inserted, so it only counts the
        # three real splits.
        num_images_per_split["all"] = num_images = sum(num_images_per_split.values())
        return num_images, num_images_per_split

    def _create_attr_txt(self, root, num_images):
        """Write ``list_attr_celeba.txt`` and return the attribute names."""
        header = ("5_o_Clock_Shadow", "Young")
        # Random booleans mapped to the values -1 / 1.
        data = torch.rand((num_images, len(header))).ge(0.5).int().mul(2).sub(1).tolist()
        self._create_txt(root, "list_attr_celeba.txt", data, header=header, add_num_examples=True)
        return header

    def _create_identity_txt(self, root, num_images):
        """Write ``identity_CelebA.txt`` with one random identity in [1, 3] per image."""
        data = torch.randint(1, 4, size=(num_images, 1)).tolist()
        self._create_txt(root, "identity_CelebA.txt", data)

    def _create_bbox_txt(self, root, num_images):
        """Write ``list_bbox_celeba.txt`` with random integers in [0, 9]."""
        header = ("x_1", "y_1", "width", "height")
        data = torch.randint(10, size=(num_images, len(header))).tolist()
        # Unlike the other files, the bbox header also names the image id column.
        self._create_txt(
            root, "list_bbox_celeba.txt", data, header=header, add_num_examples=True, add_image_id_to_header=True
        )

    def _create_landmarks_txt(self, root, num_images):
        """Write ``list_landmarks_align_celeba.txt`` with random integers in [0, 9]."""
        header = ("lefteye_x", "rightmouth_y")
        data = torch.randint(10, size=(num_images, len(header))).tolist()
        self._create_txt(root, "list_landmarks_align_celeba.txt", data, header=header, add_num_examples=True)

    def _create_txt(self, root, name, data, header=None, add_num_examples=False, add_image_id_to_header=False):
        """Write a space-separated annotation file.

        Args:
            root: Directory in which the file is created.
            name: File name.
            data: Sequence of rows; each row is a sequence of values.
            header: Optional column names written as a single line.
            add_num_examples: If true, write ``len(data)`` as the first line.
            add_image_id_to_header: If true, prepend ``"image_id"`` to ``header``.
        """
        with open(pathlib.Path(root) / name, "w") as fh:
            if add_num_examples:
                fh.write(f"{len(data)}\n")

            if header:
                if add_image_id_to_header:
                    header = ("image_id", *header)
                fh.write(f"{' '.join(header)}\n")

            # Each row is prefixed with its 1-based, six-digit image file name.
            for idx, line in enumerate(data, 1):
                fh.write(f"{' '.join((f'{idx:06d}.jpg', *[str(value) for value in line]))}\n")

    def test_combined_targets(self):
        """Requesting all target types at once must mirror the individual requests.

        Checks that the combined target has one entry per requested type and
        that each entry has the same type as the individually requested target.
        """
        target_types = ["attr", "identity", "bbox", "landmarks"]

        individual_targets = []
        for target_type in target_types:
            with self.create_dataset(target_type=target_type) as (dataset, _):
                _, target = dataset[0]
                individual_targets.append(target)

        with self.create_dataset(target_type=target_types) as (dataset, _):
            _, combined_targets = dataset[0]

        # The combined request is the behavior under test ("actual"); the
        # individual requests provide the reference ("expected").
        actual = len(combined_targets)
        expected = len(individual_targets)
        self.assertEqual(
            actual,
            expected,
            "The number of the returned combined targets does not match the number of targets if requested "
            f"individually: {actual} != {expected}",
        )

        for target_type, combined_target, individual_target in zip(target_types, combined_targets, individual_targets):
            with self.subTest(target_type=target_type):
                actual = type(combined_target)
                expected = type(individual_target)
                self.assertIs(
                    actual,
                    expected,
                    f"Type of the combined target does not match the type of the corresponding individual target: "
                    f"{actual} is not {expected}",
                )

    def test_no_target(self):
        """An empty ``target_type`` list must yield ``None`` as the target."""
        with self.create_dataset(target_type=[]) as (dataset, _):
            _, target = dataset[0]

        self.assertIsNone(target)

    def test_attr_names(self):
        """``dataset.attr_names`` must equal the ``attr_names`` entry of the info dict."""
        # NOTE(review): ``info`` is presumably the dict returned by
        # ``inject_fake_data`` - confirm against ``datasets_utils``.
        with self.create_dataset() as (dataset, info):
            self.assertEqual(tuple(dataset.attr_names), info["attr_names"])
757+
758+
642759
# Run the test suite only when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)