Skip to content

Commit d506308

Browse files
committed
Update on "Convert generator in Sampler back to lazy construction"
Fixes #63609 - Revert #63026 - Sampler is expected to be re-seeded if user specify seed before each epoch - Can not attach generator to self with `__iter__` because multiple iterators will ruin the use case - Add tests to prevent the same case for different Samplers Differential Revision: [D30451774](https://our.internmc.facebook.com/intern/diff/D30451774) [ghstack-poisoned]
2 parents 725eb13 + 0f695b8 commit d506308

File tree

2 files changed

+4
-33
lines changed

2 files changed

+4
-33
lines changed

test/test_dataloader.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1547,26 +1547,6 @@ def test_sampler_reproducibility(self):
15471547
ls[i].append(next(its[i]))
15481548
self.assertEqual(ls[0], ls[1])
15491549

1550-
def test_serialize_sampler(self):
1551-
from torch.utils.data import RandomSampler
1552-
sampler = RandomSampler(self.dataset, num_samples=5, replacement=True)
1553-
it1 = iter(sampler)
1554-
torch.manual_seed(0)
1555-
_ = next(it1)
1556-
1557-
torch.manual_seed(0)
1558-
seed = torch.empty((), dtype=torch.int64).random_()
1559-
self.assertEqual(list(sampler._iterators.values())[0].initial_seed(), seed)
1560-
1561-
it2 = iter(sampler)
1562-
torch.manual_seed(1)
1563-
_ = next(it2)
1564-
1565-
torch.manual_seed(1)
1566-
seed = torch.empty((), dtype=torch.int64).random_()
1567-
_ = list(it1)
1568-
self.assertEqual(list(sampler._iterators.values())[0].initial_seed(), seed)
1569-
15701550
def _test_sampler(self, **kwargs):
15711551
indices = range(2, 12) # using a regular iterable
15721552
dl = self._get_data_loader(self.dataset, sampler=indices, batch_size=2, **kwargs)

torch/utils/data/sampler.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,6 @@ def __init__(self, data_source: Sized, replacement: bool = False,
8989
self.replacement = replacement
9090
self._num_samples = num_samples
9191
self.generator = generator
92-
# Used to save state of RNG per iterator
93-
self._iterators = {}
9492

9593
if not isinstance(self.replacement, bool):
9694
raise TypeError("replacement should be a boolean value, but got "
@@ -112,27 +110,20 @@ def num_samples(self) -> int:
112110
return self._num_samples
113111

114112
def __iter__(self) -> Iterator[int]:
113+
n = len(self.data_source)
115114
if self.generator is None:
116115
seed = int(torch.empty((), dtype=torch.int64).random_().item())
117116
generator = torch.Generator()
118117
generator.manual_seed(seed)
119118
else:
120119
generator = self.generator
121120

122-
it = self.iter_fn(generator)
123-
self._iterators[it] = generator
124-
yield from it
125-
del self._iterators[it]
126-
127-
def iter_fn(self, rng) -> Iterator[int]:
128-
n = len(self.data_source)
129-
130121
if self.replacement:
131122
for _ in range(self.num_samples // 32):
132-
yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=rng).tolist()
133-
yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=rng).tolist()
123+
yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
124+
yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
134125
else:
135-
yield from torch.randperm(n, generator=rng).tolist()
126+
yield from torch.randperm(n, generator=generator).tolist()
136127

137128
def __len__(self) -> int:
138129
return self.num_samples

0 commit comments

Comments
 (0)