Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 8 additions & 15 deletions torchtrain/datasets/alpaca.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torchtrain.datasets.tokenizer import TokenizerIf

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node


class AlpacaDataset(IterableDataset):
Expand Down Expand Up @@ -44,32 +45,24 @@ def __init__(
rank: int = 0,
**kwargs
) -> None:
self._data = load_dataset("tatsu-lab/alpaca", split="train")
# TODO: This is a temporary solution for small datasets like Alpaca.
# For larger datasets we need to use a more scalable approach.
# Setting `streaming=True` works for large dataset, but the speed is slow.
ds = load_dataset("tatsu-lab/alpaca", split="train")
self.data_iterator = iter(split_dataset_by_node(ds, rank, world_size))
self._tokenizer = tokenizer
self.data_iterator = iter(self._data)
self.seq_len = seq_len
self.world_size = world_size
self.rank = rank
self.response_tag = "\n\n### Response:\n"

def __len__(self):
return len(self._data)

def __iter__(self):
max_buffer_token_len = 1 + self.seq_len
all_tokens: List[int] = []

for idx, sample in enumerate(self.data_iterator):
# select samples to pack in a round-robin fashion
# TODO: This is a temporary solution for small datasets like Alpaca.
# For larger datasets we need to use a more scalable approach.
if idx % self.world_size != self.rank:
continue
for sample in self.data_iterator:
sample_text = sample["text"]
sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
all_tokens.extend(sample_tokens)

if len(all_tokens) >= max_buffer_token_len:
while len(all_tokens) >= max_buffer_token_len:
x = torch.LongTensor(all_tokens[:max_buffer_token_len])
# batched_x = x.reshape(self.batch_size, -1)
# update tokens to the remaining tokens
Expand Down