|
3 | 3 | import importlib
|
4 | 4 | import inspect
|
5 | 5 | import itertools
|
6 |
| -import multiprocessing |
7 | 6 | import os
|
8 | 7 | import pathlib
|
9 | 8 | import random
|
@@ -180,27 +179,30 @@ def check_transforms_v2_wrapper(dataset_test_case, *, config=None, supports_targ
|
180 | 179 | from torchvision import datapoints
|
181 | 180 | from torchvision.datasets import wrap_dataset_for_transforms_v2
|
182 | 181 |
|
| 182 | + def check_wrapped_samples(dataset): |
| 183 | + for wrapped_sample in dataset: |
| 184 | + assert tree_any( |
| 185 | + lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample |
| 186 | + ) |
| 187 | + |
183 | 188 | target_keyss = [None]
|
184 | 189 | if supports_target_keys:
|
185 | 190 | target_keyss.append("all")
|
186 | 191 |
|
187 |
| - for target_keys, multiprocessing_context in itertools.product( |
188 |
| - target_keyss, multiprocessing.get_all_start_methods() |
189 |
| - ): |
| 192 | + for target_keys in target_keyss: |
190 | 193 | with dataset_test_case.create_dataset(config) as (dataset, info):
|
191 | 194 | wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
|
192 | 195 |
|
193 | 196 | assert isinstance(wrapped_dataset, type(dataset))
|
194 | 197 | assert len(wrapped_dataset) == info["num_examples"]
|
195 | 198 |
|
196 |
| - dataloader = DataLoader( |
197 |
| - wrapped_dataset, num_workers=2, multiprocessing_context=multiprocessing_context, collate_fn=_no_collate |
198 |
| - ) |
| 199 | + check_wrapped_samples(wrapped_dataset) |
199 | 200 |
|
200 |
| - for wrapped_sample in dataloader: |
201 |
| - assert tree_any( |
202 |
| - lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample |
203 |
| - ) |
| 201 | + with dataset_test_case.create_dataset(config) as (dataset, _): |
| 202 | + wrapped_dataset = wrap_dataset_for_transforms_v2(dataset) |
| 203 | + dataloader = DataLoader(wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate) |
| 204 | + |
| 205 | + check_wrapped_samples(dataloader) |
204 | 206 |
|
205 | 207 |
|
206 | 208 | class DatasetTestCase(unittest.TestCase):
|
|
0 commit comments