|
7 | 7 | import torch
|
8 | 8 | from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
|
9 | 9 | from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair
|
| 10 | +from torch.utils.data import DataLoader |
10 | 11 | from torch.utils.data.graph import traverse
|
11 | 12 | from torch.utils.data.graph_settings import get_all_graph_pipes
|
12 | 13 | from torchdata.datapipes.iter import Shuffler, ShardingFilter
|
@@ -40,7 +41,7 @@ def test_coverage():
|
40 | 41 | )
|
41 | 42 |
|
42 | 43 |
|
43 |
| -@pytest.mark.filterwarnings("error") |
| 44 | +# @pytest.mark.filterwarnings("error") |
44 | 45 | class TestCommon:
|
45 | 46 | @pytest.mark.parametrize("name", datasets.list_datasets())
|
46 | 47 | def test_info(self, name):
|
@@ -123,6 +124,22 @@ def test_serializable(self, test_home, dataset_mock, config):
|
123 | 124 |
|
124 | 125 | pickle.dumps(dataset)
|
125 | 126 |
|
| 127 | + @pytest.mark.parametrize("num_workers", [0, 1]) |
| 128 | + @parametrize_dataset_mocks(DATASET_MOCKS) |
| 129 | + def test_data_loader(self, test_home, dataset_mock, config, num_workers): |
| 130 | + dataset_mock.prepare(test_home, config) |
| 131 | + dataset = datasets.load(dataset_mock.name, **config) |
| 132 | + |
| 133 | + dl = DataLoader( |
| 134 | + dataset, |
| 135 | + batch_size=2, |
| 136 | + num_workers=num_workers, |
| 137 | + collate_fn=lambda batch: batch, |
| 138 | + ) |
| 139 | + |
| 140 | + for _ in dl: |
| 141 | + pass |
| 142 | + |
126 | 143 | # TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
|
127 | 144 | # that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
|
128 | 145 | # contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now.
|
|
0 commit comments