
Commit c68a6de

use DataLoader for testing on select configs

1 parent 60110cb

File tree: 2 files changed, +101 -53

  test/datasets_utils.py   +33 -53
  test/test_datasets.py    +68 -0

test/datasets_utils.py (+33 -53)
@@ -3,9 +3,9 @@
 import importlib
 import inspect
 import itertools
+import multiprocessing
 import os
 import pathlib
-import pickle
 import random
 import shutil
 import string
@@ -171,6 +171,38 @@ def wrapper(self):
     return wrapper
 
 
+def _no_collate(batch):
+    return batch
+
+
+def check_transforms_v2_wrapper(dataset_test_case, *, config=None, supports_target_keys=False):
+    from torch.utils.data import DataLoader
+    from torchvision import datapoints
+    from torchvision.datasets import wrap_dataset_for_transforms_v2
+
+    target_keyss = [None]
+    if supports_target_keys:
+        target_keyss.append("all")
+
+    for target_keys, multiprocessing_context in itertools.product(
+        target_keyss, multiprocessing.get_all_start_methods()
+    ):
+        with dataset_test_case.create_dataset(config) as (dataset, info):
+            wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
+
+            assert isinstance(wrapped_dataset, type(dataset))
+            assert len(wrapped_dataset) == info["num_examples"]
+
+            dataloader = DataLoader(
+                wrapped_dataset, num_workers=2, multiprocessing_context=multiprocessing_context, collate_fn=_no_collate
+            )
+
+            for wrapped_sample in dataloader:
+                assert tree_any(
+                    lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample
+                )
+
+
 class DatasetTestCase(unittest.TestCase):
     """Abstract base class for all dataset testcases.
 
@@ -566,49 +598,6 @@ def test_transforms(self, config):
 
         mock.assert_called()
 
-    @test_all_configs
-    def test_transforms_v2_wrapper(self, config):
-        from torchvision import datapoints
-        from torchvision.datasets import wrap_dataset_for_transforms_v2
-
-        try:
-            with self.create_dataset(config) as (dataset, info):
-                wrap_dataset_for_transforms_v2(dataset)
-        except TypeError as error:
-            msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
-            if str(error).startswith(msg):
-                return
-            raise error
-        except RuntimeError as error:
-            if "currently not supported by this wrapper" in str(error):
-                return
-            raise error
-
-        for target_keys, de_serialize in itertools.product(
-            [None, "all"], [lambda d: d, lambda d: pickle.loads(pickle.dumps(d))]
-        ):
-
-            with self.create_dataset(config) as (dataset, info):
-                if target_keys is not None and self.DATASET_CLASS not in {
-                    torchvision.datasets.CocoDetection,
-                    torchvision.datasets.VOCDetection,
-                    torchvision.datasets.Kitti,
-                    torchvision.datasets.WIDERFace,
-                }:
-                    with self.assertRaisesRegex(ValueError, "`target_keys` is currently only supported for"):
-                        wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
-                    continue
-
-                wrapped_dataset = de_serialize(wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys))
-
-                assert isinstance(wrapped_dataset, self.DATASET_CLASS)
-                assert len(wrapped_dataset) == info["num_examples"]
-
-                wrapped_sample = wrapped_dataset[0]
-                assert tree_any(
-                    lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample
-                )
-
 
 class ImageDatasetTestCase(DatasetTestCase):
     """Abstract base class for image dataset testcases.
@@ -690,15 +679,6 @@ def wrapper(tmpdir, config):
 
         return wrapper
 
-    @test_all_configs
-    def test_transforms_v2_wrapper(self, config):
-        # `output_format == "THWC"` is not supported by the wrapper. Thus, we skip the `config` if it is set explicitly
-        # or use the supported `"TCHW"`
-        if config.setdefault("output_format", "TCHW") == "THWC":
-            return
-
-        super().test_transforms_v2_wrapper.__wrapped__(self, config)
-
 
 def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
     r"""Create a random uint8 tensor.

test/test_datasets.py (+68 -0)
@@ -183,6 +183,9 @@ def test_combined_targets(self):
                     ), "Type of the combined target does not match the type of the corresponding individual target: "
                     f"{actual} is not {expected}",
 
+    def test_transforms_v2_wrapper(self):
+        datasets_utils.check_transforms_v2_wrapper(self, config=dict(target_type="category"))
+
 
 class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.Caltech256
@@ -203,6 +206,9 @@ def inject_fake_data(self, tmpdir, config):
 
         return num_images_per_category * len(categories)
 
+    def test_transforms_v2_wrapper(self):
+        datasets_utils.check_transforms_v2_wrapper(self)
+
 
 class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.WIDERFace
@@ -258,6 +264,9 @@ def inject_fake_data(self, tmpdir, config):
 
         return split_to_num_examples[config["split"]]
 
+    def test_transforms_v2_wrapper(self):
+        datasets_utils.check_transforms_v2_wrapper(self, supports_target_keys=True)
+
 
 class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.Cityscapes
@@ -382,6 +391,10 @@ def test_feature_types_target_polygon(self):
             assert isinstance(polygon_img, PIL.Image.Image)
             (polygon_target, info["expected_polygon_target"])
 
+    def test_transforms_v2_wrapper(self):
+        for target_type in ["instance", "semantic", ["instance", "semantic"]]:
+            datasets_utils.check_transforms_v2_wrapper(self, config=dict(target_type=target_type))
+
 
 class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.ImageNet
@@ -413,6 +426,9 @@ def inject_fake_data(self, tmpdir, config):
         torch.save((wnid_to_classes, None), tmpdir / "meta.bin")
         return num_examples
 
+    def test_transforms_v2_wrapper(self):
+        datasets_utils.check_transforms_v2_wrapper(self)
+
 
 class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.CIFAR10
@@ -470,6 +486,9 @@ def test_class_to_idx(self):
             actual = dataset.class_to_idx
             assert actual == expected
 
+    def test_transforms_v2_wrapper(self):
+        datasets_utils.check_transforms_v2_wrapper(self)
+
 
 class CIFAR100(CIFAR10TestCase):
     DATASET_CLASS = datasets.CIFAR100
@@ -484,6 +503,9 @@ class CIFAR100(CIFAR10TestCase):
         categories_key="fine_label_names",
     )
 
+    def test_transforms_v2_wrapper(self):
+        datasets_utils.check_transforms_v2_wrapper(self)
+
 
 class CelebATestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.CelebA
@@ -607,6 +629,10 @@ def test_images_names_split(self):
 
         assert merged_imgs_names == all_imgs_names
 
+    def test_transforms_v2_wrapper(self):
+        for target_type in ["identity", "bbox", ["identity", "bbox"]]:
+            datasets_utils.check_transforms_v2_wrapper(self, config=dict(target_type=target_type))
+
 
 class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.VOCSegmentation
@@ -694,6 +720,9 @@ def add_bndbox(obj, bndbox=None):
 
         return data
 
+    def test_transforms_v2_wrapper(self):
+        datasets_utils.check_transforms_v2_wrapper(self)
+
 
 class VOCDetectionTestCase(VOCSegmentationTestCase):
     DATASET_CLASS = datasets.VOCDetection
@@ -714,6 +743,10 @@ def test_annotations(self):
 
             assert object == info["annotation"]
 
+    def test_transforms_v2_wrapper(self):
+        for target_type in ["identity", "bbox", ["identity", "bbox"]]:
+            datasets_utils.check_transforms_v2_wrapper(self, supports_target_keys=True)
+
 
 class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.CocoDetection
@@ -784,6 +817,9 @@ def _create_json(self, root, name, content):
             json.dump(content, fh)
         return file
 
+    def test_transforms_v2_wrapper(self):
+        datasets_utils.check_transforms_v2_wrapper(self, supports_target_keys=True)
+
 
 class CocoCaptionsTestCase(CocoDetectionTestCase):
     DATASET_CLASS = datasets.CocoCaptions
@@ -800,6 +836,11 @@ def test_captions(self):
             _, captions = dataset[0]
             assert tuple(captions) == tuple(info["captions"])
 
+    def test_transforms_v2_wrapper(self):
+        # We need to define this method, because otherwise the test from the super class will
+        # be run
+        pytest.skip("CocoCaptions is currently not supported by the v2 wrapper.")
+
 
 class UCF101TestCase(datasets_utils.VideoDatasetTestCase):
     DATASET_CLASS = datasets.UCF101
@@ -860,6 +901,9 @@ def _create_annotation_file(self, root, name, video_files):
         with open(pathlib.Path(root) / name, "w") as fh:
             fh.writelines(f"{str(file).replace(os.sep, '/')}\n" for file in sorted(video_files))
 
+    def test_transforms_v2_wrapper(self):
+        datasets_utils.check_transforms_v2_wrapper(self, config=dict(output_format="TCHW"))
+
 
 class LSUNTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.LSUN
@@ -966,6 +1010,9 @@ def inject_fake_data(self, tmpdir, config):
         )
         return num_videos_per_class * len(classes)
 
+    def test_transforms_v2_wrapper(self):
+        datasets_utils.check_transforms_v2_wrapper(self, config=dict(output_format="TCHW"))
+
 
 class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
     DATASET_CLASS = datasets.HMDB51
@@ -1026,6 +1073,9 @@ def _create_split_files(self, root, video_files, fold, train):
 
         return num_train_videos if train else (num_videos - num_train_videos)
 
+    def test_transforms_v2_wrapper(self):
+        datasets_utils.check_transforms_v2_wrapper(self, config=dict(output_format="TCHW"))
+
 
 class OmniglotTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.Omniglot
@@ -1193,6 +1243,9 @@ def _create_segmentation(self, size):
     def _file_stem(self, idx):
         return f"2008_{idx:06d}"
 
+    def test_transforms_v2_wrapper(self):
+        datasets_utils.check_transforms_v2_wrapper(self, config=dict(mode="segmentation"))
+
 
 class FakeDataTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.FakeData
@@ -1434,6 +1487,9 @@ def _magic(self, dtype, dims):
     def _encode(self, v):
         return torch.tensor(v, dtype=torch.int32).numpy().tobytes()[::-1]
 
+    def test_transforms_v2_wrapper(self):
+        datasets_utils.check_transforms_v2_wrapper(self)
+
 
 class FashionMNISTTestCase(MNISTTestCase):
     DATASET_CLASS = datasets.FashionMNIST
@@ -1585,6 +1641,9 @@ def test_classes(self, config):
             assert len(dataset.classes) == len(info["classes"])
             assert all([a == b for a, b in zip(dataset.classes, info["classes"])])
 
+    def test_transforms_v2_wrapper(self):
+        datasets_utils.check_transforms_v2_wrapper(self)
+
 
 class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.ImageFolder
@@ -1606,6 +1665,9 @@ def test_classes(self, config):
             assert len(dataset.classes) == len(info["classes"])
             assert all([a == b for a, b in zip(dataset.classes, info["classes"])])
 
+    def test_transforms_v2_wrapper(self):
+        datasets_utils.check_transforms_v2_wrapper(self)
+
 
 class KittiTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.Kitti
@@ -1642,6 +1704,9 @@ def inject_fake_data(self, tmpdir, config):
 
         return split_to_num_examples[config["train"]]
 
+    def test_transforms_v2_wrapper(self):
+        datasets_utils.check_transforms_v2_wrapper(self, supports_target_keys=True)
+
 
 class SvhnTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.SVHN
@@ -2516,6 +2581,9 @@ def _meta_to_split_and_classification_ann(self, meta, idx):
             breed_id = "-1"
         return (image_id, class_id, species, breed_id)
 
+    def test_transforms_v2_wrapper(self):
+        datasets_utils.check_transforms_v2_wrapper(self)
+
 
 class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.StanfordCars
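
A note on the identity collate_fn: torch.utils.data.default_collate raises a TypeError on sample types it cannot batch — PIL images in particular — which is why the helper passes batches through unchanged instead of collating them. A small illustration, assuming only torch and PIL are installed:

    import PIL.Image
    from torch.utils.data import default_collate

    samples = [
        (PIL.Image.new("RGB", (4, 4)), {"label": 0}),
        (PIL.Image.new("RGB", (4, 4)), {"label": 1}),
    ]

    try:
        default_collate(samples)
    except TypeError as error:
        print(error)  # default_collate refuses PIL images

    # an identity collate_fn, like _no_collate above, returns the batch unchanged
    assert len((lambda batch: batch)(samples)) == 2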
