Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 26 additions & 15 deletions torchtrain/datasets/alpaca.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class AlpacaDataset(IterableDataset):
seq_len (int): max sequence length
world_size (int): number of data parallel processes participating in training
rank (int): rank of the current data parallel process
infinite: whether to loop infinitely over the dataset

Data input format:
{
Expand All @@ -43,38 +44,48 @@ def __init__(
seq_len: int = 2048,
world_size: int = 1,
rank: int = 0,
infinite: bool = False,
**kwargs
) -> None:
# TODO: This is a temporary solution for small datasets like Alpaca.
# For larger datasets we need to use a more scalable approach.
# Setting `streaming=True` works for large dataset, but the speed is slow.
ds = load_dataset("tatsu-lab/alpaca", split="train")
self.data_iterator = iter(split_dataset_by_node(ds, rank, world_size))
self._data = split_dataset_by_node(ds, rank, world_size)
self._tokenizer = tokenizer
self.seq_len = seq_len
self.infinite = infinite

def __iter__(self):
max_buffer_token_len = 1 + self.seq_len
all_tokens: List[int] = []

for sample in self.data_iterator:
sample_text = sample["text"]
sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
all_tokens.extend(sample_tokens)
while True:
for sample in iter(self._data):
sample_text = sample["text"]
sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
all_tokens.extend(sample_tokens)

while len(all_tokens) >= max_buffer_token_len:
x = torch.LongTensor(all_tokens[:max_buffer_token_len])
# batched_x = x.reshape(self.batch_size, -1)
# update tokens to the remaining tokens
all_tokens = all_tokens[max_buffer_token_len:]
input = x[:-1]
label = x[1:]
yield input, label
while len(all_tokens) >= max_buffer_token_len:
x = torch.LongTensor(all_tokens[:max_buffer_token_len])
# batched_x = x.reshape(self.batch_size, -1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we delete the stale comment?

# update tokens to the remaining tokens
all_tokens = all_tokens[max_buffer_token_len:]
input = x[:-1]
label = x[1:]
yield input, label
if not self.infinite:
break
Comment on lines +77 to +78
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add some mechanism to allow stopping? `self.infinite` is constant after it is initialized.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think Ctrl+C (a keyboard interrupt) should be sufficient?



def build_alpaca_data_loader(
tokenizer: TokenizerIf, batch_size: int, seq_len: int, world_size, rank
tokenizer: TokenizerIf,
batch_size: int,
seq_len: int,
world_size: int,
rank: int,
infinite: bool = True,
):
alpaca_ds = AlpacaDataset(tokenizer, seq_len, world_size, rank)
alpaca_ds = AlpacaDataset(tokenizer, seq_len, world_size, rank, infinite)

return DataLoader(alpaca_ds, batch_size=batch_size)
4 changes: 3 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ def main(args):
)
checkpoint.load()

data_iterator = iter(data_loader)

with maybe_run_profiler() as torch_profiler:
checkpoint.reset()
# variables used to keep info for metrics logging
Expand All @@ -167,7 +169,7 @@ def main(args):
while train_state.step < args.steps or args.steps == -1:
train_state.step += 1
# get batch
batch = next(iter(data_loader))
batch = next(data_iterator)
input_ids, labels = batch
input_ids = input_ids.cuda()
labels = labels.cuda()
Expand Down