@@ -183,6 +183,9 @@ def test_combined_targets(self):
183
183
), "Type of the combined target does not match the type of the corresponding individual target: "
184
184
f"{ actual } is not { expected } " ,
185
185
186
+ def test_transforms_v2_wrapper (self ):
187
+ datasets_utils .check_transforms_v2_wrapper (self , config = dict (target_type = "category" ))
188
+
186
189
187
190
class Caltech256TestCase (datasets_utils .ImageDatasetTestCase ):
188
191
DATASET_CLASS = datasets .Caltech256
@@ -203,6 +206,9 @@ def inject_fake_data(self, tmpdir, config):
203
206
204
207
return num_images_per_category * len (categories )
205
208
209
+ def test_transforms_v2_wrapper (self ):
210
+ datasets_utils .check_transforms_v2_wrapper (self )
211
+
206
212
207
213
class WIDERFaceTestCase (datasets_utils .ImageDatasetTestCase ):
208
214
DATASET_CLASS = datasets .WIDERFace
@@ -258,6 +264,9 @@ def inject_fake_data(self, tmpdir, config):
258
264
259
265
return split_to_num_examples [config ["split" ]]
260
266
267
+ def test_transforms_v2_wrapper (self ):
268
+ datasets_utils .check_transforms_v2_wrapper (self , supports_target_keys = True )
269
+
261
270
262
271
class CityScapesTestCase (datasets_utils .ImageDatasetTestCase ):
263
272
DATASET_CLASS = datasets .Cityscapes
@@ -382,6 +391,10 @@ def test_feature_types_target_polygon(self):
382
391
assert isinstance (polygon_img , PIL .Image .Image )
383
392
(polygon_target , info ["expected_polygon_target" ])
384
393
394
+ def test_transforms_v2_wrapper (self ):
395
+ for target_type in ["instance" , "semantic" , ["instance" , "semantic" ]]:
396
+ datasets_utils .check_transforms_v2_wrapper (self , config = dict (target_type = target_type ))
397
+
385
398
386
399
class ImageNetTestCase (datasets_utils .ImageDatasetTestCase ):
387
400
DATASET_CLASS = datasets .ImageNet
@@ -413,6 +426,9 @@ def inject_fake_data(self, tmpdir, config):
413
426
torch .save ((wnid_to_classes , None ), tmpdir / "meta.bin" )
414
427
return num_examples
415
428
429
+ def test_transforms_v2_wrapper (self ):
430
+ datasets_utils .check_transforms_v2_wrapper (self )
431
+
416
432
417
433
class CIFAR10TestCase (datasets_utils .ImageDatasetTestCase ):
418
434
DATASET_CLASS = datasets .CIFAR10
@@ -470,6 +486,9 @@ def test_class_to_idx(self):
470
486
actual = dataset .class_to_idx
471
487
assert actual == expected
472
488
489
+ def test_transforms_v2_wrapper (self ):
490
+ datasets_utils .check_transforms_v2_wrapper (self )
491
+
473
492
474
493
class CIFAR100 (CIFAR10TestCase ):
475
494
DATASET_CLASS = datasets .CIFAR100
@@ -484,6 +503,9 @@ class CIFAR100(CIFAR10TestCase):
484
503
categories_key = "fine_label_names" ,
485
504
)
486
505
506
+ def test_transforms_v2_wrapper (self ):
507
+ datasets_utils .check_transforms_v2_wrapper (self )
508
+
487
509
488
510
class CelebATestCase (datasets_utils .ImageDatasetTestCase ):
489
511
DATASET_CLASS = datasets .CelebA
@@ -607,6 +629,10 @@ def test_images_names_split(self):
607
629
608
630
assert merged_imgs_names == all_imgs_names
609
631
632
+ def test_transforms_v2_wrapper (self ):
633
+ for target_type in ["identity" , "bbox" , ["identity" , "bbox" ]]:
634
+ datasets_utils .check_transforms_v2_wrapper (self , config = dict (target_type = target_type ))
635
+
610
636
611
637
class VOCSegmentationTestCase (datasets_utils .ImageDatasetTestCase ):
612
638
DATASET_CLASS = datasets .VOCSegmentation
@@ -694,6 +720,9 @@ def add_bndbox(obj, bndbox=None):
694
720
695
721
return data
696
722
723
+ def test_transforms_v2_wrapper (self ):
724
+ datasets_utils .check_transforms_v2_wrapper (self )
725
+
697
726
698
727
class VOCDetectionTestCase (VOCSegmentationTestCase ):
699
728
DATASET_CLASS = datasets .VOCDetection
@@ -714,6 +743,10 @@ def test_annotations(self):
714
743
715
744
assert object == info ["annotation" ]
716
745
746
+ def test_transforms_v2_wrapper (self ):
747
+ for target_type in ["identity" , "bbox" , ["identity" , "bbox" ]]:
748
+ datasets_utils .check_transforms_v2_wrapper (self , supports_target_keys = True )
749
+
717
750
718
751
class CocoDetectionTestCase (datasets_utils .ImageDatasetTestCase ):
719
752
DATASET_CLASS = datasets .CocoDetection
@@ -784,6 +817,9 @@ def _create_json(self, root, name, content):
784
817
json .dump (content , fh )
785
818
return file
786
819
820
+ def test_transforms_v2_wrapper (self ):
821
+ datasets_utils .check_transforms_v2_wrapper (self , supports_target_keys = True )
822
+
787
823
788
824
class CocoCaptionsTestCase (CocoDetectionTestCase ):
789
825
DATASET_CLASS = datasets .CocoCaptions
@@ -800,6 +836,11 @@ def test_captions(self):
800
836
_ , captions = dataset [0 ]
801
837
assert tuple (captions ) == tuple (info ["captions" ])
802
838
839
+ def test_transforms_v2_wrapper (self ):
840
+ # We need to define this method, because otherwise the test from the super class will
841
+ # be run
842
+ pytest .skip ("CocoCaptions is currently not supported by the v2 wrapper." )
843
+
803
844
804
845
class UCF101TestCase (datasets_utils .VideoDatasetTestCase ):
805
846
DATASET_CLASS = datasets .UCF101
@@ -860,6 +901,9 @@ def _create_annotation_file(self, root, name, video_files):
860
901
with open (pathlib .Path (root ) / name , "w" ) as fh :
861
902
fh .writelines (f"{ str (file ).replace (os .sep , '/' )} \n " for file in sorted (video_files ))
862
903
904
+ def test_transforms_v2_wrapper (self ):
905
+ datasets_utils .check_transforms_v2_wrapper (self , config = dict (output_format = "TCHW" ))
906
+
863
907
864
908
class LSUNTestCase (datasets_utils .ImageDatasetTestCase ):
865
909
DATASET_CLASS = datasets .LSUN
@@ -966,6 +1010,9 @@ def inject_fake_data(self, tmpdir, config):
966
1010
)
967
1011
return num_videos_per_class * len (classes )
968
1012
1013
+ def test_transforms_v2_wrapper (self ):
1014
+ datasets_utils .check_transforms_v2_wrapper (self , config = dict (output_format = "TCHW" ))
1015
+
969
1016
970
1017
class HMDB51TestCase (datasets_utils .VideoDatasetTestCase ):
971
1018
DATASET_CLASS = datasets .HMDB51
@@ -1026,6 +1073,9 @@ def _create_split_files(self, root, video_files, fold, train):
1026
1073
1027
1074
return num_train_videos if train else (num_videos - num_train_videos )
1028
1075
1076
+ def test_transforms_v2_wrapper (self ):
1077
+ datasets_utils .check_transforms_v2_wrapper (self , config = dict (output_format = "TCHW" ))
1078
+
1029
1079
1030
1080
class OmniglotTestCase (datasets_utils .ImageDatasetTestCase ):
1031
1081
DATASET_CLASS = datasets .Omniglot
@@ -1193,6 +1243,9 @@ def _create_segmentation(self, size):
1193
1243
def _file_stem (self , idx ):
1194
1244
return f"2008_{ idx :06d} "
1195
1245
1246
+ def test_transforms_v2_wrapper (self ):
1247
+ datasets_utils .check_transforms_v2_wrapper (self , config = dict (mode = "segmentation" ))
1248
+
1196
1249
1197
1250
class FakeDataTestCase (datasets_utils .ImageDatasetTestCase ):
1198
1251
DATASET_CLASS = datasets .FakeData
@@ -1434,6 +1487,9 @@ def _magic(self, dtype, dims):
1434
1487
def _encode (self , v ):
1435
1488
return torch .tensor (v , dtype = torch .int32 ).numpy ().tobytes ()[::- 1 ]
1436
1489
1490
+ def test_transforms_v2_wrapper (self ):
1491
+ datasets_utils .check_transforms_v2_wrapper (self )
1492
+
1437
1493
1438
1494
class FashionMNISTTestCase (MNISTTestCase ):
1439
1495
DATASET_CLASS = datasets .FashionMNIST
@@ -1585,6 +1641,9 @@ def test_classes(self, config):
1585
1641
assert len (dataset .classes ) == len (info ["classes" ])
1586
1642
assert all ([a == b for a , b in zip (dataset .classes , info ["classes" ])])
1587
1643
1644
+ def test_transforms_v2_wrapper (self ):
1645
+ datasets_utils .check_transforms_v2_wrapper (self )
1646
+
1588
1647
1589
1648
class ImageFolderTestCase (datasets_utils .ImageDatasetTestCase ):
1590
1649
DATASET_CLASS = datasets .ImageFolder
@@ -1606,6 +1665,9 @@ def test_classes(self, config):
1606
1665
assert len (dataset .classes ) == len (info ["classes" ])
1607
1666
assert all ([a == b for a , b in zip (dataset .classes , info ["classes" ])])
1608
1667
1668
+ def test_transforms_v2_wrapper (self ):
1669
+ datasets_utils .check_transforms_v2_wrapper (self )
1670
+
1609
1671
1610
1672
class KittiTestCase (datasets_utils .ImageDatasetTestCase ):
1611
1673
DATASET_CLASS = datasets .Kitti
@@ -1642,6 +1704,9 @@ def inject_fake_data(self, tmpdir, config):
1642
1704
1643
1705
return split_to_num_examples [config ["train" ]]
1644
1706
1707
+ def test_transforms_v2_wrapper (self ):
1708
+ datasets_utils .check_transforms_v2_wrapper (self , supports_target_keys = True )
1709
+
1645
1710
1646
1711
class SvhnTestCase (datasets_utils .ImageDatasetTestCase ):
1647
1712
DATASET_CLASS = datasets .SVHN
@@ -2516,6 +2581,9 @@ def _meta_to_split_and_classification_ann(self, meta, idx):
2516
2581
breed_id = "-1"
2517
2582
return (image_id , class_id , species , breed_id )
2518
2583
2584
+ def test_transforms_v2_wrapper (self ):
2585
+ datasets_utils .check_transforms_v2_wrapper (self )
2586
+
2519
2587
2520
2588
class StanfordCarsTestCase (datasets_utils .ImageDatasetTestCase ):
2521
2589
DATASET_CLASS = datasets .StanfordCars
0 commit comments