Commit 52aed2b

4569 consistent rng for data loader (#5067)
Signed-off-by: Wenqi Li <[email protected]>

Fixes #4569

### Description

With different `num_workers`, the batch sampler's RNG should be consistent; this PR keeps the initial random seed to achieve it.

### Status

**Ready**

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change). # the previous results for num_workers==0 might change
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: Wenqi Li <[email protected]>
1 parent f9e8e10 commit 52aed2b
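Not part of the commit, just an illustrative sketch of the behaviour the updated tests check: under `set_determinism`, the batches produced by `monai.data.DataLoader` are expected to match across `num_workers` settings (here 0 vs 1, mirroring the new `tests/test_dataloader.py` case).

```python
import torch
from monai.data import DataLoader, Dataset
from monai.utils import set_determinism

if __name__ == "__main__":  # guard needed for multi-process workers on spawn platforms
    # toy dataset: five scalar tensors, no random transforms
    ds = Dataset(data=[torch.tensor(float(i)) for i in range(5)])

    batches = {}
    for workers in (0, 1):
        set_determinism(seed=0)  # fix the global RNG before constructing the loader
        loader = DataLoader(ds, batch_size=2, shuffle=True, num_workers=workers)
        batches[workers] = [b.clone() for b in loader]
        set_determinism(None)

    # with this fix, the shuffled batch order is expected to match for 0 vs 1 workers
    for b0, b1 in zip(batches[0], batches[1]):
        assert torch.equal(b0, b1)
```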

File tree: 7 files changed (+50, -28 lines)

monai/data/dataloader.py

Lines changed: 8 additions & 3 deletions
@@ -66,6 +66,8 @@ def __len__(self):
         num_workers: how many subprocesses to use for data
             loading. ``0`` means that the data will be loaded in the main process.
             (default: ``0``)
+        collate_fn: default to :py:func:`monai.data.utils.list_data_collate`.
+        worker_init_fn: default to :py:func:`monai.data.utils.worker_init_fn`.
         kwargs: other parameters for PyTorch DataLoader.
     """

@@ -74,11 +76,14 @@ def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs) -> None:
         # when num_workers > 0, random states are determined by worker_init_fn
         # this is to make the behavior consistent when num_workers == 0
         # torch.int64 doesn't work well on some versions of windows
-        _seed = torch.empty((), dtype=torch.int32).random_(generator=None).item()
+        _g = torch.random.default_generator if kwargs.get("generator", None) is None else kwargs["generator"]
+        init_seed = _g.initial_seed()
+        _seed = torch.empty((), dtype=torch.int64).random_(generator=_g).item()
         set_rnd(dataset, int(_seed))
+        _g.manual_seed(init_seed)
         if "collate_fn" not in kwargs:
-            kwargs.update({"collate_fn": list_data_collate})
+            kwargs["collate_fn"] = list_data_collate
         if "worker_init_fn" not in kwargs:
-            kwargs.update({"worker_init_fn": worker_init_fn})
+            kwargs["worker_init_fn"] = worker_init_fn

         super().__init__(dataset=dataset, num_workers=num_workers, **kwargs)
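The core of the change is to consume one seed for the dataset's random transforms without permanently advancing the loader's generator. A standalone sketch of that pattern (the helper name and usage comments are illustrative, not MONAI API):

```python
import torch

def draw_seed_without_advancing(generator=None):
    """Draw one int64 seed from ``generator`` (default: torch's global generator),
    then reset the generator to its initial seed so later consumers, e.g. the
    batch sampler, see the same RNG sequence regardless of ``num_workers``."""
    g = torch.random.default_generator if generator is None else generator
    init_seed = g.initial_seed()  # seed the generator was originally created/seeded with
    seed = torch.empty((), dtype=torch.int64).random_(generator=g).item()  # consumes one draw
    g.manual_seed(init_seed)  # reset to the initial seed so the draw leaves no trace
    return int(seed)

# usage sketch mirroring DataLoader.__init__ in the diff above:
#   _seed = draw_seed_without_advancing(kwargs.get("generator"))
#   set_rnd(dataset, _seed)   # monai.utils.set_rnd seeds the dataset's random transforms
```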

tests/test_compose.py

Lines changed: 6 additions & 6 deletions
@@ -169,16 +169,16 @@ def test_data_loader(self):
         set_determinism(seed=123)
         train_loader = DataLoader(train_ds, num_workers=0)
         out_1 = next(iter(train_loader))
-        self.assertAlmostEqual(out_1.cpu().item(), 0.84291356)
+        self.assertAlmostEqual(out_1.cpu().item(), 0.0409280)

         if sys.platform != "win32":  # skip multi-worker tests on win32
             train_loader = DataLoader(train_ds, num_workers=1)
             out_1 = next(iter(train_loader))
-            self.assertAlmostEqual(out_1.cpu().item(), 0.180814653)
+            self.assertAlmostEqual(out_1.cpu().item(), 0.78663897075)

             train_loader = DataLoader(train_ds, num_workers=2)
             out_1 = next(iter(train_loader))
-            self.assertAlmostEqual(out_1.cpu().item(), 0.04293707)
+            self.assertAlmostEqual(out_1.cpu().item(), 0.785907334)
         set_determinism(None)

     def test_data_loader_2(self):

@@ -191,16 +191,16 @@ def test_data_loader_2(self):

         train_loader = DataLoader(train_ds, num_workers=0)
         out_2 = next(iter(train_loader))
-        self.assertAlmostEqual(out_2.cpu().item(), 0.7858843729)
+        self.assertAlmostEqual(out_2.cpu().item(), 0.98921915918)

         if sys.platform != "win32":  # skip multi-worker tests on win32
             train_loader = DataLoader(train_ds, num_workers=1)
             out_2 = next(iter(train_loader))
-            self.assertAlmostEqual(out_2.cpu().item(), 0.305763411)
+            self.assertAlmostEqual(out_2.cpu().item(), 0.32985207)

             train_loader = DataLoader(train_ds, num_workers=2)
             out_1 = next(iter(train_loader))
-            self.assertAlmostEqual(out_1.cpu().item(), 0.131966779)
+            self.assertAlmostEqual(out_1.cpu().item(), 0.28602141572)
         set_determinism(None)

     def test_flatten_and_len(self):

tests/test_dataloader.py

Lines changed: 7 additions & 4 deletions
@@ -75,14 +75,17 @@ def setUp(self):
     def tearDown(self):
         set_determinism(None)

-    def test_randomize(self):
+    @parameterized.expand([[1], [0]])
+    def test_randomize(self, workers):
+        set_determinism(0)
         dataset = _RandomDataset()
-        dataloader = DataLoader(dataset, batch_size=2, num_workers=3)
+        dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=workers)
         output = []
-        for _ in range(2):
+        for _ in range(1):  # need persistent workers for reproducibility of num_workers 0, 1
            for batch in dataloader:
                output.extend(batch.data.numpy().flatten().tolist())
-        self.assertListEqual(output, [594, 170, 524, 778, 370, 906, 292, 589, 762, 763, 156, 886, 42, 405, 221, 166])
+        set_determinism(None)
+        self.assertListEqual(output, [594, 170, 292, 589, 153, 811, 21, 550])

     def test_zipdataset(self):
         dataset = ZipDataset([_RandomDataset(), ZipDataset([_RandomDataset(), _RandomDataset()])])
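For context, the multi-worker path was already reproducible: with `num_workers > 0`, `monai.data.utils.worker_init_fn` re-seeds each worker's copy of the dataset from the per-worker seed, so the commit only had to align the `num_workers == 0` path and leave the sampler's generator untouched. A simplified sketch of what such a worker init function does (function name and details are illustrative, not the MONAI source):

```python
from torch.utils.data import get_worker_info

def example_worker_init_fn(worker_id: int) -> None:
    # worker_id is part of the required signature but is not needed here;
    # each worker process seeds the random transforms of its own dataset copy
    # from the seed PyTorch assigns to that worker.
    info = get_worker_info()
    if info is not None and hasattr(info.dataset, "set_random_state"):
        info.dataset.set_random_state(seed=info.seed % (2**32))
```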

tests/test_grid_dataset.py

Lines changed: 8 additions & 4 deletions
@@ -60,7 +60,7 @@ def test_loading_array(self):
         np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
         np.testing.assert_allclose(
             item[0],
-            np.array([[[[8.0577, 9.0577], [12.0577, 13.0577]]], [[[10.5540, 11.5540], [14.5540, 15.5540]]]]),
+            np.array([[[[8.240326, 9.240326], [12.240326, 13.240326]]], [[[10.1624, 11.1624], [14.1624, 15.1624]]]]),
             rtol=1e-4,
         )
         np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)

@@ -69,7 +69,9 @@ def test_loading_array(self):
         np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
         np.testing.assert_allclose(
             item[0],
-            np.array([[[[7.6533, 8.6533], [11.6533, 12.6533]]], [[[9.8524, 10.8524], [13.8524, 14.8524]]]]),
+            np.array(
+                [[[[7.723618, 8.723618], [11.723618, 12.723618]]], [[[10.7175, 11.7175], [14.7175, 15.7175]]]]
+            ),
             rtol=1e-3,
         )
         np.testing.assert_allclose(

@@ -102,7 +104,7 @@ def test_loading_dict(self):
         self.assertListEqual(item[0]["metadata"], ["test string", "test string"])
         np.testing.assert_allclose(
             item[0]["image"],
-            np.array([[[[8.0577, 9.0577], [12.0577, 13.0577]]], [[[10.5540, 11.5540], [14.5540, 15.5540]]]]),
+            np.array([[[[8.240326, 9.240326], [12.240326, 13.240326]]], [[[10.1624, 11.1624], [14.1624, 15.1624]]]]),
             rtol=1e-4,
         )
         np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)

@@ -111,7 +113,9 @@ def test_loading_dict(self):
         np.testing.assert_equal(item[0]["image"].shape, (2, 1, 2, 2))
         np.testing.assert_allclose(
             item[0]["image"],
-            np.array([[[[7.6533, 8.6533], [11.6533, 12.6533]]], [[[9.8524, 10.8524], [13.8524, 14.8524]]]]),
+            np.array(
+                [[[[7.723618, 8.723618], [11.723618, 12.723618]]], [[[10.7175, 11.7175], [14.7175, 15.7175]]]]
+            ),
             rtol=1e-3,
         )
         np.testing.assert_allclose(

tests/test_patch_dataset.py

Lines changed: 2 additions & 8 deletions
@@ -59,7 +59,7 @@ def test_loading_array(self):
         np.testing.assert_allclose(
             item[0],
             np.array(
-                [[[1.338681, 2.338681, 3.338681], [5.338681, 6.338681, 7.338681], [9.338681, 10.338681, 11.338681]]]
+                [[[-0.593095, 0.406905, 1.406905], [3.406905, 4.406905, 5.406905], [7.406905, 8.406905, 9.406905]]]
             ),
             rtol=1e-5,
         )

@@ -69,13 +69,7 @@ def test_loading_array(self):
         np.testing.assert_allclose(
             item[0],
             np.array(
-                [
-                    [
-                        [4.957847, 5.957847, 6.957847],
-                        [8.957847, 9.957847, 10.957847],
-                        [12.957847, 13.957847, 14.957847],
-                    ]
-                ]
+                [[[0.234308, 1.234308, 2.234308], [4.234308, 5.234308, 6.234308], [8.234308, 9.234308, 10.234308]]]
             ),
             rtol=1e-5,
         )

tests/test_thread_buffer.py

Lines changed: 17 additions & 1 deletion
@@ -13,9 +13,12 @@
 import time
 import unittest

+import torch
+
 from monai.data import DataLoader, Dataset, ThreadBuffer, ThreadDataLoader
 from monai.transforms import Compose, SimulateDelayd
-from monai.utils import PerfContext
+from monai.utils import PerfContext, set_determinism
+from tests.utils import assert_allclose


 class TestDataLoader(unittest.TestCase):

@@ -53,6 +56,19 @@ def test_dataloader(self):
         self.assertEqual(d["label"][0], "spleen_label_19.nii.gz")
         self.assertEqual(d["label"][1], "spleen_label_31.nii.gz")

+    def test_deterministic(self):
+        set_determinism(0)
+        res_1 = list(ThreadDataLoader(torch.arange(5), batch_size=2, buffer_size=2, shuffle=True, num_workers=0))
+
+        set_determinism(0)
+        num_workers = 2 if sys.platform == "linux" else 1
+        res_2 = list(
+            ThreadDataLoader(torch.arange(5), batch_size=2, buffer_size=3, shuffle=True, num_workers=num_workers)
+        )
+
+        set_determinism(None)
+        assert_allclose(torch.cat(res_1), torch.cat(res_2), type_test=False)
+
     def test_time(self):
         dataset = Dataset(data=self.datalist * 2, transform=self.transform)  # contains data for 2 batches
         dataloader = DataLoader(dataset=dataset, batch_size=2, num_workers=0)

tests/testing_data/transform_metatensor_cases.yaml

Lines changed: 2 additions & 2 deletions
@@ -180,7 +180,7 @@ TEST_CASE_3:

 TEST_CASE_1_answer:
   load_shape: [1, 1, 33, 45, 54]
-  affine: "$np.array([[-2, 0, 0, 34], [0, 2, 0, -64], [0, 0, 2, -54], [0, 0, 0, 1]], dtype=np.float64)"
+  affine: "$np.array([[-2, 0, 0, 30], [0, 2, 0, -62], [0, 0, 2, -48], [0, 0, 0, 1]], dtype=np.float64)"
   inv_affine: "@init_affine"
   inv_shape: "@init_shape"

@@ -192,6 +192,6 @@ TEST_CASE_2_answer:

 TEST_CASE_3_answer:
   load_shape: [1, 1, 72, 57, 82]
-  affine: "$np.array([[-1.343816, -0.682904, -0.234832, 76.01494], [0.309004, 0.653211, -1.734872, 24.511358], [-0.104049, 1.617199, 0.584171, -56.521294], [0, 0, 0, 1]], dtype=np.float64)"
+  affine: "$np.array([[1.300558, -0.700765, -0.511861, -3.739605], [0.479723, -1.171149, 1.193079, -50.087933], [0.395736, 1.183532, 0.984201, -80.496605], [0, 0, 0, 1]], dtype=np.float64)"
   inv_affine: "@init_affine"
   inv_shape: "@init_shape"
