@@ -639,5 +639,122 @@ class CIFAR100(CIFAR10TestCase):
639
639
)
640
640
641
641
642
class CelebATestCase(datasets_utils.ImageDatasetTestCase):
    """Test case for torchvision's CelebA dataset.

    Injects a minimal fake CelebA layout — the aligned image folder plus the
    split/attr/identity/bbox/landmark annotation text files — and verifies
    target handling across all supported ``target_type`` configurations.
    """

    DATASET_CLASS = datasets.CelebA
    # A sample is (PIL image, target); the target may be a tensor, an int,
    # a tuple, or None depending on the requested target_type(s).
    FEATURE_TYPES = (PIL.Image.Image, (torch.Tensor, int, tuple, type(None)))

    CONFIGS = datasets_utils.combinations_grid(
        split=("train", "valid", "test", "all"),
        target_type=("attr", "identity", "bbox", "landmarks", ["attr", "identity"]),
    )
    # CelebA parses its annotation files with pandas.
    REQUIRED_PACKAGES = ("pandas",)

    # Split encoding used by list_eval_partition.txt: 0=train, 1=valid, 2=test.
    _SPLIT_TO_IDX = dict(train=0, valid=1, test=2)

    def inject_fake_data(self, tmpdir, config):
        """Create a fake CelebA directory tree under *tmpdir*.

        Returns a dict with the number of examples for the configured split
        and the attribute names written to ``list_attr_celeba.txt``.
        """
        base_folder = pathlib.Path(tmpdir) / "celeba"
        os.makedirs(base_folder)

        num_images, num_images_per_split = self._create_split_txt(base_folder)

        datasets_utils.create_image_folder(
            base_folder, "img_align_celeba", lambda idx: f"{idx + 1:06d}.jpg", num_images
        )
        attr_names = self._create_attr_txt(base_folder, num_images)
        self._create_identity_txt(base_folder, num_images)
        self._create_bbox_txt(base_folder, num_images)
        self._create_landmarks_txt(base_folder, num_images)

        return dict(num_examples=num_images_per_split[config["split"]], attr_names=attr_names)

    def _create_split_txt(self, root):
        # Write list_eval_partition.txt and return (total, per-split counts).
        num_images_per_split = dict(train=3, valid=2, test=1)

        data = [
            [self._SPLIT_TO_IDX[split]] for split, num_images in num_images_per_split.items() for _ in range(num_images)
        ]
        self._create_txt(root, "list_eval_partition.txt", data)

        # The "all" split spans every other split.
        num_images_per_split["all"] = num_images = sum(num_images_per_split.values())
        return num_images, num_images_per_split

    def _create_attr_txt(self, root, num_images):
        # Random binary attributes encoded as -1 / 1, matching the real file.
        header = ("5_o_Clock_Shadow", "Young")
        data = torch.rand((num_images, len(header))).ge(0.5).int().mul(2).sub(1).tolist()
        self._create_txt(root, "list_attr_celeba.txt", data, header=header, add_num_examples=True)
        return header

    def _create_identity_txt(self, root, num_images):
        # Identities are 1-based integer labels.
        data = torch.randint(1, 4, size=(num_images, 1)).tolist()
        self._create_txt(root, "identity_CelebA.txt", data)

    def _create_bbox_txt(self, root, num_images):
        header = ("x_1", "y_1", "width", "height")
        data = torch.randint(10, size=(num_images, len(header))).tolist()
        self._create_txt(
            root, "list_bbox_celeba.txt", data, header=header, add_num_examples=True, add_image_id_to_header=True
        )

    def _create_landmarks_txt(self, root, num_images):
        header = ("lefteye_x", "rightmouth_y")
        data = torch.randint(10, size=(num_images, len(header))).tolist()
        self._create_txt(root, "list_landmarks_align_celeba.txt", data, header=header, add_num_examples=True)

    def _create_txt(self, root, name, data, header=None, add_num_examples=False, add_image_id_to_header=False):
        """Write an annotation file in the CelebA text format.

        Optionally prepends the number of examples and a (possibly
        image_id-prefixed) header line; each data row is prefixed with its
        1-based image file name.
        """
        with open(pathlib.Path(root) / name, "w") as fh:
            if add_num_examples:
                fh.write(f"{len(data)}\n")

            if header:
                if add_image_id_to_header:
                    header = ("image_id", *header)
                fh.write(f"{' '.join(header)}\n")

            for idx, line in enumerate(data, 1):
                fh.write(f"{' '.join((f'{idx:06d}.jpg', *[str(value) for value in line]))}\n")

    def test_combined_targets(self):
        # Requesting several target types at once must yield the same targets,
        # in order and type, as requesting each type individually.
        target_types = ["attr", "identity", "bbox", "landmarks"]

        individual_targets = []
        for target_type in target_types:
            with self.create_dataset(target_type=target_type) as (dataset, _):
                _, target = dataset[0]
                individual_targets.append(target)

        with self.create_dataset(target_type=target_types) as (dataset, _):
            _, combined_targets = dataset[0]

        actual = len(individual_targets)
        expected = len(combined_targets)
        self.assertEqual(
            actual,
            expected,
            f"The number of the returned combined targets does not match the number of targets if requested "
            f"individually: {actual} != {expected}",
        )

        for target_type, combined_target, individual_target in zip(target_types, combined_targets, individual_targets):
            with self.subTest(target_type=target_type):
                actual = type(combined_target)
                expected = type(individual_target)
                self.assertIs(
                    actual,
                    expected,
                    f"Type of the combined target does not match the type of the corresponding individual target: "
                    f"{actual} is not {expected}",
                )

    def test_no_target(self):
        # An empty target_type list should produce a None target.
        with self.create_dataset(target_type=[]) as (dataset, _):
            _, target = dataset[0]

        self.assertIsNone(target)

    def test_attr_names(self):
        # The dataset's attr_names must round-trip the injected header.
        with self.create_dataset() as (dataset, info):
            self.assertEqual(tuple(dataset.attr_names), info["attr_names"])
642
759
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
0 commit comments