
Commit 1b736be

support infinite loop over alpaca dataset
ghstack-source-id: e9fa7fd
Pull Request resolved: #66

2 files changed: +29 -16 lines

torchtrain/datasets/alpaca.py

Lines changed: 26 additions & 15 deletions
@@ -20,6 +20,7 @@ class AlpacaDataset(IterableDataset):
         seq_len (int): max sequence length
         world_size (int): number of data parallel processes participating in training
         rank (int): rank of the current data parallel process
+        infinite: whether to loop infinitely over the dataset
 
     Data input format:
     {
@@ -43,38 +44,48 @@ def __init__(
         seq_len: int = 2048,
         world_size: int = 1,
         rank: int = 0,
+        infinite: bool = False,
         **kwargs
     ) -> None:
         # TODO: This is a temporary solution for small datasets like Alpaca.
         # For larger datasets we need to use a more scalable approach.
         # Setting `streaming=True` works for large dataset, but the speed is slow.
         ds = load_dataset("tatsu-lab/alpaca", split="train")
-        self.data_iterator = iter(split_dataset_by_node(ds, rank, world_size))
+        self._data = split_dataset_by_node(ds, rank, world_size)
         self._tokenizer = tokenizer
         self.seq_len = seq_len
+        self.infinite = infinite
 
     def __iter__(self):
         max_buffer_token_len = 1 + self.seq_len
         all_tokens: List[int] = []
 
-        for sample in self.data_iterator:
-            sample_text = sample["text"]
-            sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
-            all_tokens.extend(sample_tokens)
+        while True:
+            for sample in iter(self._data):
+                sample_text = sample["text"]
+                sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
+                all_tokens.extend(sample_tokens)
 
-            while len(all_tokens) >= max_buffer_token_len:
-                x = torch.LongTensor(all_tokens[:max_buffer_token_len])
-                # batched_x = x.reshape(self.batch_size, -1)
-                # update tokens to the remaining tokens
-                all_tokens = all_tokens[max_buffer_token_len:]
-                input = x[:-1]
-                label = x[1:]
-                yield input, label
+                while len(all_tokens) >= max_buffer_token_len:
+                    x = torch.LongTensor(all_tokens[:max_buffer_token_len])
+                    # batched_x = x.reshape(self.batch_size, -1)
+                    # update tokens to the remaining tokens
+                    all_tokens = all_tokens[max_buffer_token_len:]
+                    input = x[:-1]
+                    label = x[1:]
+                    yield input, label
+            if not self.infinite:
+                break
 
 
 def build_alpaca_data_loader(
-    tokenizer: TokenizerIf, batch_size: int, seq_len: int, world_size, rank
+    tokenizer: TokenizerIf,
+    batch_size: int,
+    seq_len: int,
+    world_size: int,
+    rank: int,
+    infinite: bool = True,
 ):
-    alpaca_ds = AlpacaDataset(tokenizer, seq_len, world_size, rank)
+    alpaca_ds = AlpacaDataset(tokenizer, seq_len, world_size, rank, infinite)
 
     return DataLoader(alpaca_ds, batch_size=batch_size)
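
For context, the core of this change is that the per-epoch loop in __iter__ is wrapped in a while True that only exits when infinite is False, so a small dataset like Alpaca keeps yielding fixed-length (input, label) windows past a single epoch. Below is a minimal, self-contained sketch of that pattern; the toy class, its pre-tokenized sample data, and all names in it are illustrative stand-ins, not code from this repository.

# Sketch of the infinite-iteration pattern, assuming a toy in-memory dataset
# of already-tokenized samples in place of the real Alpaca split.
from typing import Iterator, List

import torch
from torch.utils.data import IterableDataset


class ToyInfiniteDataset(IterableDataset):
    def __init__(self, samples: List[List[int]], seq_len: int, infinite: bool = False) -> None:
        self._samples = samples      # hypothetical pre-tokenized samples
        self.seq_len = seq_len
        self.infinite = infinite

    def __iter__(self) -> Iterator:
        max_buffer_token_len = 1 + self.seq_len
        all_tokens: List[int] = []
        while True:
            # One full pass over the dataset per outer iteration.
            for sample_tokens in self._samples:
                all_tokens.extend(sample_tokens)
                # Emit (input, label) pairs whenever the buffer holds
                # at least seq_len + 1 tokens.
                while len(all_tokens) >= max_buffer_token_len:
                    x = torch.LongTensor(all_tokens[:max_buffer_token_len])
                    all_tokens = all_tokens[max_buffer_token_len:]
                    yield x[:-1], x[1:]
            if not self.infinite:
                break  # finite mode: stop after a single epoch


if __name__ == "__main__":
    ds = ToyInfiniteDataset([[1, 2, 3, 4], [5, 6, 7, 8]], seq_len=3, infinite=True)
    it = iter(ds)
    for _ in range(4):               # keeps yielding past one epoch
        inp, label = next(it)
        print(inp.tolist(), label.tolist())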

train.py

Lines changed: 3 additions & 1 deletion
@@ -158,6 +158,8 @@ def main(args):
     )
     checkpoint.load()
 
+    data_iterator = iter(data_loader)
+
     with maybe_run_profiler() as torch_profiler:
         checkpoint.reset()
         # variables used to keep info for metrics logging
@@ -167,7 +169,7 @@ def main(args):
         while train_state.step < args.steps or args.steps == -1:
             train_state.step += 1
             # get batch
-            batch = next(iter(data_loader))
+            batch = next(data_iterator)
             input_ids, labels = batch
             input_ids = input_ids.cuda()
             labels = labels.cuda()
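
The train.py change matters because next(iter(data_loader)) builds a fresh iterator on every training step, so each step fetches the first batch again; hoisting iter(data_loader) out of the loop reuses one iterator and actually advances through the (now optionally infinite) dataset. The small sketch below illustrates the difference with a toy DataLoader rather than the project's loader; the names in it are illustrative only.

# Contrast the old and new batch-fetching patterns on a toy loader.
import torch
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(torch.arange(8)), batch_size=2)

# Old pattern: a fresh iterator every step, so every "batch" is batch 0.
old = [next(iter(loader))[0].tolist() for _ in range(3)]
print(old)   # [[0, 1], [0, 1], [0, 1]]

# New pattern: one iterator reused across steps advances through the data.
data_iterator = iter(loader)
new = [next(data_iterator)[0].tolist() for _ in range(3)]
print(new)   # [[0, 1], [2, 3], [4, 5]]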
