|
7 | 7 | import torch
|
8 | 8 | from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
|
9 | 9 | from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair
|
| 10 | +from torch.utils.data import DataLoader |
10 | 11 | from torch.utils.data.graph import traverse
|
11 | 12 | from torch.utils.data.graph_settings import get_all_graph_pipes
|
12 | 13 | from torchdata.datapipes.iter import Shuffler, ShardingFilter
|
@@ -109,19 +110,39 @@ def test_transformable(self, test_home, dataset_mock, config):
|
109 | 110 |
|
110 | 111 | next(iter(dataset.map(transforms.Identity())))
|
111 | 112 |
|
112 |
| - @pytest.mark.xfail(reason="See https://github.com/pytorch/data/issues/237") |
| 113 | + @pytest.mark.parametrize("only_datapipe", [False, True]) |
113 | 114 | @parametrize_dataset_mocks(DATASET_MOCKS)
|
114 |
| - def test_serializable(self, test_home, dataset_mock, config): |
| 115 | + def test_traversable(self, test_home, dataset_mock, config, only_datapipe): |
115 | 116 | dataset_mock.prepare(test_home, config)
|
| 117 | + dataset = datasets.load(dataset_mock.name, **config) |
116 | 118 |
|
| 119 | + traverse(dataset, only_datapipe=only_datapipe) |
| 120 | + |
| 121 | + @parametrize_dataset_mocks(DATASET_MOCKS) |
| 122 | + def test_serializable(self, test_home, dataset_mock, config): |
| 123 | + dataset_mock.prepare(test_home, config) |
117 | 124 | dataset = datasets.load(dataset_mock.name, **config)
|
118 | 125 |
|
119 | 126 | pickle.dumps(dataset)
|
120 | 127 |
|
| 128 | + @pytest.mark.parametrize("num_workers", [0, 1]) |
| 129 | + @parametrize_dataset_mocks(DATASET_MOCKS) |
| 130 | + def test_data_loader(self, test_home, dataset_mock, config, num_workers): |
| 131 | + dataset_mock.prepare(test_home, config) |
| 132 | + dataset = datasets.load(dataset_mock.name, **config) |
| 133 | + |
| 134 | + dl = DataLoader( |
| 135 | + dataset, |
| 136 | + batch_size=2, |
| 137 | + num_workers=num_workers, |
| 138 | + collate_fn=lambda batch: batch, |
| 139 | + ) |
| 140 | + |
| 141 | + next(iter(dl)) |
| 142 | + |
121 | 143 | # TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
|
122 | 144 | # that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
|
123 | 145 | # contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now.
|
124 |
| - @pytest.mark.xfail(reason="See https://github.com/pytorch/data/issues/237") |
125 | 146 | @parametrize_dataset_mocks(DATASET_MOCKS)
|
126 | 147 | @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter))
|
127 | 148 | def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):
|
|
0 commit comments