Skip to content

Commit 5358620

Browse files
committed
streamline v2 check
1 parent 1efe583 commit 5358620

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

test/datasets_utils.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import importlib
44
import inspect
55
import itertools
6-
import multiprocessing
76
import os
87
import pathlib
98
import random
@@ -180,27 +179,30 @@ def check_transforms_v2_wrapper(dataset_test_case, *, config=None, supports_targ
180179
from torchvision import datapoints
181180
from torchvision.datasets import wrap_dataset_for_transforms_v2
182181

182+
def check_wrapped_samples(dataset):
183+
for wrapped_sample in dataset:
184+
assert tree_any(
185+
lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample
186+
)
187+
183188
target_keyss = [None]
184189
if supports_target_keys:
185190
target_keyss.append("all")
186191

187-
for target_keys, multiprocessing_context in itertools.product(
188-
target_keyss, multiprocessing.get_all_start_methods()
189-
):
192+
for target_keys in target_keyss:
190193
with dataset_test_case.create_dataset(config) as (dataset, info):
191194
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
192195

193196
assert isinstance(wrapped_dataset, type(dataset))
194197
assert len(wrapped_dataset) == info["num_examples"]
195198

196-
dataloader = DataLoader(
197-
wrapped_dataset, num_workers=2, multiprocessing_context=multiprocessing_context, collate_fn=_no_collate
198-
)
199+
check_wrapped_samples(wrapped_dataset)
199200

200-
for wrapped_sample in dataloader:
201-
assert tree_any(
202-
lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample
203-
)
201+
with dataset_test_case.create_dataset(config) as (dataset, _):
202+
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)
203+
dataloader = DataLoader(wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate)
204+
205+
check_wrapped_samples(dataloader)
204206

205207

206208
class DatasetTestCase(unittest.TestCase):

0 commit comments

Comments
 (0)