@@ -19,6 +19,7 @@
 import pathlib
 import pickle
 from torchvision import datasets
+import torch


 try:
@@ -560,5 +561,122 @@ class CIFAR100(CIFAR10TestCase):
     )


+class CelebATestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.CelebA
+    FEATURE_TYPES = (PIL.Image.Image, (torch.Tensor, int, tuple, type(None)))
+
+    CONFIGS = datasets_utils.combinations_grid(
+        split=("train", "valid", "test", "all"),
+        target_type=("attr", "identity", "bbox", "landmarks", ["attr", "identity"]),
+    )
+    REQUIRED_PACKAGES = ("pandas",)
+
+    _SPLIT_TO_IDX = dict(train=0, valid=1, test=2)
+
+    def inject_fake_data(self, tmpdir, config):
+        base_folder = pathlib.Path(tmpdir) / "celeba"
+        os.makedirs(base_folder)
+
+        num_images, num_images_per_split = self._create_split_txt(base_folder)
+
+        datasets_utils.create_image_folder(
+            base_folder, "img_align_celeba", lambda idx: f"{idx + 1:06d}.jpg", num_images
+        )
+        attr_names = self._create_attr_txt(base_folder, num_images)
+        self._create_identity_txt(base_folder, num_images)
+        self._create_bbox_txt(base_folder, num_images)
+        self._create_landmarks_txt(base_folder, num_images)
+
+        return dict(num_examples=num_images_per_split[config["split"]], attr_names=attr_names)
+
+    def _create_split_txt(self, root):
+        num_images_per_split = dict(train=3, valid=2, test=1)
+
+        data = [
+            [self._SPLIT_TO_IDX[split]] for split, num_images in num_images_per_split.items() for _ in range(num_images)
+        ]
+        self._create_txt(root, "list_eval_partition.txt", data)
+
+        num_images_per_split["all"] = num_images = sum(num_images_per_split.values())
+        return num_images, num_images_per_split
+
+    def _create_attr_txt(self, root, num_images):
+        header = ("5_o_Clock_Shadow", "Young")
+        data = torch.rand((num_images, len(header))).ge(0.5).int().mul(2).sub(1).tolist()
+        self._create_txt(root, "list_attr_celeba.txt", data, header=header, add_num_examples=True)
+        return header
+
+    def _create_identity_txt(self, root, num_images):
+        data = torch.randint(1, 4, size=(num_images, 1)).tolist()
+        self._create_txt(root, "identity_CelebA.txt", data)
+
+    def _create_bbox_txt(self, root, num_images):
+        header = ("x_1", "y_1", "width", "height")
+        data = torch.randint(10, size=(num_images, len(header))).tolist()
+        self._create_txt(
+            root, "list_bbox_celeba.txt", data, header=header, add_num_examples=True, add_image_id_to_header=True
+        )
+
+    def _create_landmarks_txt(self, root, num_images):
+        header = ("lefteye_x", "rightmouth_y")
+        data = torch.randint(10, size=(num_images, len(header))).tolist()
+        self._create_txt(root, "list_landmarks_align_celeba.txt", data, header=header, add_num_examples=True)
+
+    def _create_txt(self, root, name, data, header=None, add_num_examples=False, add_image_id_to_header=False):
+        with open(pathlib.Path(root) / name, "w") as fh:
+            if add_num_examples:
+                fh.write(f"{len(data)}\n")
+
+            if header:
+                if add_image_id_to_header:
+                    header = ("image_id", *header)
+                fh.write(f"{' '.join(header)}\n")
+
+            for idx, line in enumerate(data, 1):
+                fh.write(f"{' '.join((f'{idx:06d}.jpg', *[str(value) for value in line]))}\n")
+
+    def test_combined_targets(self):
+        target_types = ["attr", "identity", "bbox", "landmarks"]
+
+        individual_targets = []
+        for target_type in target_types:
+            with self.create_dataset(target_type=target_type) as (dataset, _):
+                _, target = dataset[0]
+                individual_targets.append(target)
+
+        with self.create_dataset(target_type=target_types) as (dataset, _):
+            _, combined_targets = dataset[0]
+
+        actual = len(individual_targets)
+        expected = len(combined_targets)
+        self.assertEqual(
+            actual,
+            expected,
+            f"The number of returned combined targets does not match the number of targets when requested "
+            f"individually: {actual} != {expected}",
+        )
+
+        for target_type, combined_target, individual_target in zip(target_types, combined_targets, individual_targets):
+            with self.subTest(target_type=target_type):
+                actual = type(combined_target)
+                expected = type(individual_target)
+                self.assertIs(
+                    actual,
+                    expected,
+                    f"Type of the combined target does not match the type of the corresponding individual target: "
+                    f"{actual} is not {expected}",
+                )
+
+    def test_no_target(self):
+        with self.create_dataset(target_type=[]) as (dataset, _):
+            _, target = dataset[0]
+
+        self.assertIsNone(target)
+
+    def test_attr_names(self):
+        with self.create_dataset() as (dataset, info):
+            self.assertEqual(tuple(dataset.attr_names), info["attr_names"])
+
+
 if __name__ == "__main__":
     unittest.main()
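Note (not part of the diff above): as a rough orientation, the fake data that inject_fake_data writes under tmpdir mirrors the on-disk layout that datasets.CelebA reads; with the 3 train + 2 valid + 1 test images created here, it looks approximately like:

    celeba/
        img_align_celeba/000001.jpg ... 000006.jpg
        list_eval_partition.txt
        list_attr_celeba.txt
        identity_CelebA.txt
        list_bbox_celeba.txt
        list_landmarks_align_celeba.txt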