Skip to content

Commit 9f12ef4

Browse files
authored
fix prototype datasets data loading tests (#5711)
* reenable serialization test * cleanup * fix dill test * trigger CI * patch DILL_AVAILABLE for pickle serialization * revert CI changes * remove dill test and traversable test * add data loader test * parametrize over only_datapipe * draw one sample rather than exhaust data loader * cleanup * trigger CI
1 parent 49655b2 commit 9f12ef4

File tree

1 file changed

+24
-3
lines changed

1 file changed

+24
-3
lines changed

test/test_prototype_builtin_datasets.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch
88
from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
99
from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair
10+
from torch.utils.data import DataLoader
1011
from torch.utils.data.graph import traverse
1112
from torch.utils.data.graph_settings import get_all_graph_pipes
1213
from torchdata.datapipes.iter import Shuffler, ShardingFilter
@@ -109,19 +110,39 @@ def test_transformable(self, test_home, dataset_mock, config):
109110

110111
next(iter(dataset.map(transforms.Identity())))
111112

112-
@pytest.mark.xfail(reason="See https://github.com/pytorch/data/issues/237")
113+
@pytest.mark.parametrize("only_datapipe", [False, True])
113114
@parametrize_dataset_mocks(DATASET_MOCKS)
114-
def test_serializable(self, test_home, dataset_mock, config):
115+
def test_traversable(self, test_home, dataset_mock, config, only_datapipe):
115116
dataset_mock.prepare(test_home, config)
117+
dataset = datasets.load(dataset_mock.name, **config)
116118

119+
traverse(dataset, only_datapipe=only_datapipe)
120+
121+
@parametrize_dataset_mocks(DATASET_MOCKS)
122+
def test_serializable(self, test_home, dataset_mock, config):
123+
dataset_mock.prepare(test_home, config)
117124
dataset = datasets.load(dataset_mock.name, **config)
118125

119126
pickle.dumps(dataset)
120127

128+
@pytest.mark.parametrize("num_workers", [0, 1])
129+
@parametrize_dataset_mocks(DATASET_MOCKS)
130+
def test_data_loader(self, test_home, dataset_mock, config, num_workers):
131+
dataset_mock.prepare(test_home, config)
132+
dataset = datasets.load(dataset_mock.name, **config)
133+
134+
dl = DataLoader(
135+
dataset,
136+
batch_size=2,
137+
num_workers=num_workers,
138+
collate_fn=lambda batch: batch,
139+
)
140+
141+
next(iter(dl))
142+
121143
# TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
122144
# that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
123145
# contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now.
124-
@pytest.mark.xfail(reason="See https://github.com/pytorch/data/issues/237")
125146
@parametrize_dataset_mocks(DATASET_MOCKS)
126147
@pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter))
127148
def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):

0 commit comments

Comments
 (0)