Skip to content

Commit f9b682c

Browse files
committed
add data loader test
1 parent d8aeb6d commit f9b682c

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

test/test_prototype_builtin_datasets.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch
88
from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
99
from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair
10+
from torch.utils.data import DataLoader
1011
from torch.utils.data.graph import traverse
1112
from torch.utils.data.graph_settings import get_all_graph_pipes
1213
from torchdata.datapipes.iter import Shuffler, ShardingFilter
@@ -40,7 +41,7 @@ def test_coverage():
4041
)
4142

4243

43-
@pytest.mark.filterwarnings("error")
44+
# @pytest.mark.filterwarnings("error")
4445
class TestCommon:
4546
@pytest.mark.parametrize("name", datasets.list_datasets())
4647
def test_info(self, name):
@@ -123,6 +124,22 @@ def test_serializable(self, test_home, dataset_mock, config):
123124

124125
pickle.dumps(dataset)
125126

127+
@pytest.mark.parametrize("num_workers", [0, 1])
128+
@parametrize_dataset_mocks(DATASET_MOCKS)
129+
def test_data_loader(self, test_home, dataset_mock, config, num_workers):
130+
dataset_mock.prepare(test_home, config)
131+
dataset = datasets.load(dataset_mock.name, **config)
132+
133+
dl = DataLoader(
134+
dataset,
135+
batch_size=2,
136+
num_workers=num_workers,
137+
collate_fn=lambda batch: batch,
138+
)
139+
140+
for _ in dl:
141+
pass
142+
126143
# TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
127144
# that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
128145
# contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now.

0 commit comments

Comments
 (0)