Skip to content

Commit 9aac9ee

Browse files
committed
Simpler collate_fn
1 parent e1919dc commit 9aac9ee

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

test/test_prototype_builtin_datasets.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torch
99
from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
1010
from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair
11-
from torch.utils.data import DataLoader, default_collate
11+
from torch.utils.data import DataLoader
1212
from torch.utils.data.graph import traverse
1313
from torch.utils.data.graph_settings import get_all_graph_pipes
1414
from torchdata.datapipes.iter import Shuffler, ShardingFilter
@@ -143,11 +143,7 @@ def test_ddp(self, test_home, dataset_mock, config, ddp_fixture):
143143

144144
dataset = datasets.load(dataset_mock.name, **config)
145145

146-
# Ugly hack: custom collate_fn because the default one doesn't handle None values
147-
def collate_fn(batch):
148-
return default_collate([x["image"] for x in batch])
149-
150-
dl = DataLoader(dataset, collate_fn=collate_fn)
146+
dl = DataLoader(dataset, collate_fn=lambda batch: batch)
151147

152148
next(iter(dl))
153149

0 commit comments

Comments
 (0)