@@ -20,6 +20,7 @@ class AlpacaDataset(IterableDataset):
         seq_len (int): max sequence length
         world_size (int): number of data parallel processes participating in training
         rank (int): rank of the current data parallel process
+        infinite (bool): whether to loop infinitely over the dataset
 
     Data input format:
     {
@@ -43,38 +44,48 @@ def __init__(
         seq_len: int = 2048,
         world_size: int = 1,
         rank: int = 0,
+        infinite: bool = False,
         **kwargs
     ) -> None:
         # TODO: This is a temporary solution for small datasets like Alpaca.
         # For larger datasets we need to use a more scalable approach.
         # Setting `streaming=True` works for large datasets, but the speed is slow.
         ds = load_dataset("tatsu-lab/alpaca", split="train")
-        self.data_iterator = iter(split_dataset_by_node(ds, rank, world_size))
+        self._data = split_dataset_by_node(ds, rank, world_size)
         self._tokenizer = tokenizer
         self.seq_len = seq_len
+        self.infinite = infinite
 
     def __iter__(self):
         max_buffer_token_len = 1 + self.seq_len
         all_tokens: List[int] = []
 
-        for sample in self.data_iterator:
-            sample_text = sample["text"]
-            sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
-            all_tokens.extend(sample_tokens)
+        while True:
+            for sample in iter(self._data):
+                sample_text = sample["text"]
+                sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
+                all_tokens.extend(sample_tokens)
 
-            while len(all_tokens) >= max_buffer_token_len:
-                x = torch.LongTensor(all_tokens[:max_buffer_token_len])
-                # batched_x = x.reshape(self.batch_size, -1)
-                # update tokens to the remaining tokens
-                all_tokens = all_tokens[max_buffer_token_len:]
-                input = x[:-1]
-                label = x[1:]
-                yield input, label
+                while len(all_tokens) >= max_buffer_token_len:
+                    x = torch.LongTensor(all_tokens[:max_buffer_token_len])
+                    # batched_x = x.reshape(self.batch_size, -1)
+                    # update tokens to the remaining tokens
+                    all_tokens = all_tokens[max_buffer_token_len:]
+                    input = x[:-1]
+                    label = x[1:]
+                    yield input, label
+            if not self.infinite:
+                break
 
 
 def build_alpaca_data_loader(
-    tokenizer: TokenizerIf, batch_size: int, seq_len: int, world_size, rank
+    tokenizer: TokenizerIf,
+    batch_size: int,
+    seq_len: int,
+    world_size: int,
+    rank: int,
+    infinite: bool = True,
 ):
-    alpaca_ds = AlpacaDataset(tokenizer, seq_len, world_size, rank)
+    alpaca_ds = AlpacaDataset(tokenizer, seq_len, world_size, rank, infinite)
 
     return DataLoader(alpaca_ds, batch_size=batch_size)
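
The reworked `__iter__` packs consecutive samples into a single token stream and cuts it into fixed-size chunks, each yielding a next-token-prediction pair. A toy walk-through of that packing logic, using an assumed `seq_len = 4` and made-up token ids (not from the real tokenizer):

```python
# Illustrative values only; ids 1 and 2 stand in for bos/eos.
seq_len = 4
max_buffer_token_len = 1 + seq_len              # each chunk carries one extra token for the shift

all_tokens = [1, 5, 6, 7, 2, 1, 8, 9, 2]        # two encoded samples concatenated
x = all_tokens[:max_buffer_token_len]           # [1, 5, 6, 7, 2]
all_tokens = all_tokens[max_buffer_token_len:]  # [1, 8, 9, 2] stays buffered until more samples arrive
input, label = x[:-1], x[1:]                    # ([1, 5, 6, 7], [5, 6, 7, 2])
# label is input shifted left by one, so position i is trained to predict token i + 1.
```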
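With `infinite=True` (the builder's new default) the dataset restarts instead of exhausting, so a small dataset like Alpaca no longer ends training early. A minimal usage sketch, assuming `tokenizer` is any object satisfying `TokenizerIf`, i.e. one that provides `encode(text, bos=True, eos=True)`:

```python
# Sketch only: `tokenizer` is an assumed TokenizerIf-compatible tokenizer.
data_loader = build_alpaca_data_loader(
    tokenizer=tokenizer,
    batch_size=8,
    seq_len=2048,
    world_size=1,
    rank=0,
    infinite=True,
)

# An infinite loader never raises StopIteration, so the training loop
# must bound the number of steps itself.
data_iter = iter(data_loader)
for step in range(10):
    input_ids, labels = next(data_iter)  # each of shape (8, 2048)
```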