Skip to content

Commit 60110cb

Browse files
committed
enforce pickleability for v2 transforms and wrapped datasets
1 parent 2c44eba commit 60110cb

File tree

4 files changed

+46
-23
lines changed

4 files changed

+46
-23
lines changed

test/datasets_utils.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import itertools
66
import os
77
import pathlib
8+
import pickle
89
import random
910
import shutil
1011
import string
@@ -572,35 +573,42 @@ def test_transforms_v2_wrapper(self, config):
572573

573574
try:
574575
with self.create_dataset(config) as (dataset, info):
575-
for target_keys in [None, "all"]:
576-
if target_keys is not None and self.DATASET_CLASS not in {
577-
torchvision.datasets.CocoDetection,
578-
torchvision.datasets.VOCDetection,
579-
torchvision.datasets.Kitti,
580-
torchvision.datasets.WIDERFace,
581-
}:
582-
with self.assertRaisesRegex(ValueError, "`target_keys` is currently only supported for"):
583-
wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
584-
continue
585-
586-
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
587-
assert isinstance(wrapped_dataset, self.DATASET_CLASS)
588-
assert len(wrapped_dataset) == info["num_examples"]
589-
590-
wrapped_sample = wrapped_dataset[0]
591-
assert tree_any(
592-
lambda item: isinstance(item, (datapoints.Datapoint, PIL.Image.Image)), wrapped_sample
593-
)
576+
wrap_dataset_for_transforms_v2(dataset)
594577
except TypeError as error:
595578
msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
596579
if str(error).startswith(msg):
597-
pytest.skip(msg)
580+
return
598581
raise error
599582
except RuntimeError as error:
600583
if "currently not supported by this wrapper" in str(error):
601-
pytest.skip("Config is currently not supported by this wrapper")
584+
return
602585
raise error
603586

587+
for target_keys, de_serialize in itertools.product(
588+
[None, "all"], [lambda d: d, lambda d: pickle.loads(pickle.dumps(d))]
589+
):
590+
591+
with self.create_dataset(config) as (dataset, info):
592+
if target_keys is not None and self.DATASET_CLASS not in {
593+
torchvision.datasets.CocoDetection,
594+
torchvision.datasets.VOCDetection,
595+
torchvision.datasets.Kitti,
596+
torchvision.datasets.WIDERFace,
597+
}:
598+
with self.assertRaisesRegex(ValueError, "`target_keys` is currently only supported for"):
599+
wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
600+
continue
601+
602+
wrapped_dataset = de_serialize(wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys))
603+
604+
assert isinstance(wrapped_dataset, self.DATASET_CLASS)
605+
assert len(wrapped_dataset) == info["num_examples"]
606+
607+
wrapped_sample = wrapped_dataset[0]
608+
assert tree_any(
609+
lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample
610+
)
611+
604612

605613
class ImageDatasetTestCase(DatasetTestCase):
606614
"""Abstract base class for image dataset testcases.

test/test_transforms_v2.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import itertools
22
import pathlib
3+
import pickle
34
import random
45
import warnings
56

@@ -169,8 +170,11 @@ class TestSmoke:
169170
next(make_vanilla_tensor_images()),
170171
],
171172
)
173+
@pytest.mark.parametrize("de_serialize", [lambda t: t, lambda t: pickle.loads(pickle.dumps(t))])
172174
@pytest.mark.parametrize("device", cpu_and_cuda())
173-
def test_common(self, transform, adapter, container_type, image_or_video, device):
175+
def test_common(self, transform, adapter, container_type, image_or_video, de_serialize, device):
176+
transform = de_serialize(transform)
177+
174178
canvas_size = F.get_size(image_or_video)
175179
input = dict(
176180
image_or_video=image_or_video,
@@ -234,6 +238,7 @@ def test_common(self, transform, adapter, container_type, image_or_video, device
234238
tensor=torch.empty(5),
235239
array=np.empty(5),
236240
)
241+
237242
if adapter is not None:
238243
input = adapter(transform, input, device)
239244

test/test_transforms_v2_refactored.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import decimal
33
import inspect
44
import math
5+
import pickle
56
import re
67
from pathlib import Path
78
from unittest import mock
@@ -247,6 +248,8 @@ def _check_transform_v1_compatibility(transform, input):
247248
def check_transform(transform_cls, input, *args, **kwargs):
248249
transform = transform_cls(*args, **kwargs)
249250

251+
pickle.loads(pickle.dumps(transform))
252+
250253
output = transform(input)
251254
assert isinstance(output, type(input))
252255

torchvision/datapoints/_dataset_wrapper.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from __future__ import annotations
44

55
import collections.abc
6-
76
import contextlib
87
from collections import defaultdict
98

@@ -97,6 +96,10 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
9796
f"but got {target_keys}"
9897
)
9998

99+
return _make_wrapped_dataset(dataset, target_keys)
100+
101+
102+
def _make_wrapped_dataset(dataset, target_keys):
100103
# Imagine we have isinstance(dataset, datasets.ImageNet). This will create a new class with the name
101104
# "WrappedImageNet" at runtime that doubly inherits from VisionDatasetDatapointWrapper (see below) as well as the
102105
# original ImageNet class. This allows the user to do regular isinstance(wrapped_dataset, datasets.ImageNet) checks,
@@ -162,6 +165,7 @@ def __init__(self, dataset, target_keys):
162165
raise TypeError(msg)
163166

164167
self._dataset = dataset
168+
self._target_keys = target_keys
165169
self._wrapper = wrapper_factory(dataset, target_keys)
166170

167171
# We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply them.
@@ -197,6 +201,9 @@ def __getitem__(self, idx):
197201
def __len__(self):
198202
return len(self._dataset)
199203

204+
def __reduce__(self):
205+
return _make_wrapped_dataset, (self._dataset, self._target_keys)
206+
200207

201208
def raise_not_supported(description):
202209
raise RuntimeError(

0 commit comments

Comments (0)