@@ -117,7 +117,8 @@ int main(int argc, char ** argv) {
     llama_token id_last = inp.back();
 
     // all tokens currently in the target context
-    auto prompt_tgt = std::vector<llama_token>(inp.begin(), inp.end() - 1);
+    llama_tokens prompt_tgt(inp.begin(), inp.end() - 1);
+    prompt_tgt.reserve(llama_n_ctx(ctx_tgt));
 
     int n_past = inp.size() - 1;
 
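Note on the hunk above: `prompt_tgt` now uses the `llama_tokens` alias (a vector of `llama_token`) and is reserved up front to the target context size, so the per-step appends of accepted tokens (see the second hunk) should not need to reallocate the buffer. A minimal stand-alone sketch of that effect, with a plain `std::vector<int>` and an assumed context size standing in for `llama_tokens` and `llama_n_ctx(ctx_tgt)`:

    #include <cstdio>
    #include <vector>

    int main() {
        const size_t n_ctx = 4096;               // assumed context size (stand-in for llama_n_ctx)

        std::vector<int> prompt = {1, 2, 3};     // tokens already in the target context
        prompt.reserve(n_ctx);                   // single allocation up front ...

        const size_t cap = prompt.capacity();
        for (int id = 4; id < 1024; ++id) {
            prompt.push_back(id);                // ... so later appends never reallocate
        }

        std::printf("capacity unchanged: %s\n", prompt.capacity() == cap ? "yes" : "no");
        return 0;
    }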
@@ -181,54 +182,44 @@ int main(int argc, char ** argv) {
         GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
 
         n_past    += ids.size() - 1;
-        n_drafted += batch_tgt.n_tokens - 1;
+        n_drafted += draft.size(); // note: we ignore the discarded small drafts
         n_accept  += ids.size() - 1;
+        n_predict += ids.size();
 
         // process the accepted tokens and update contexts
         //
         // this is the standard token post-processing that we normally do
         // in this case, we do it for a group of accepted tokens at once
         //
-        {
-            llama_token id;
-            std::string token_str;
-
-            for (size_t i = 0; i < ids.size(); ++i) {
-                id = ids[i];
-
-                ++n_predict;
-
-                if (llama_token_is_eog(model_tgt, id)) {
-                    has_eos = true;
-                    break;
-                }
-
-                token_str = common_token_to_piece(ctx_tgt, id);
+        for (size_t i = 0; i < ids.size(); ++i) {
+            prompt_tgt.push_back(id_last);
 
-                if (params.use_color && i + 1 < ids.size()) {
-                    LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
-                } else {
-                    LOG("%s", token_str.c_str());
-                }
-            }
+            id_last = ids[i];
 
-            if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+            if (llama_token_is_eog(model_tgt, id_last)) {
+                has_eos = true;
                 break;
             }
 
-            LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, (int) draft.size(), id, token_str.c_str());
+            const std::string token_str = common_token_to_piece(ctx_tgt, id_last);
 
-            {
-                LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
-
-                llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+            if (params.use_color && i + 1 < ids.size()) {
+                LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
+            } else {
+                LOG("%s", token_str.c_str());
             }
+        }
 
-            prompt_tgt.push_back(id_last);
-            prompt_tgt.insert(prompt_tgt.end(), ids.begin(), ids.end() - 1);
+        LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last);
+
+        {
+            LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
+
+            llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+        }
 
-            // remember the last accepted token for the next iteration
-            id_last = id;
+        if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+            break;
+        }
         }
     }
 
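For readers following the second hunk: the rewritten loop walks the accepted token ids once, appending the previous `id_last` to `prompt_tgt` before advancing `id_last` to the newly accepted token, so the target prompt always trails one token behind the token that seeds the next step; printing, the EOG check, and the KV-cache trim happen along the way. A stand-alone sketch of just that bookkeeping, with plain ints standing in for `llama_token` and a made-up end-of-generation id:

    #include <cstdio>
    #include <vector>

    int main() {
        // stand-ins: accepted ids from one decoding step and an assumed end-of-generation id
        const std::vector<int> ids = {42, 43, 44};
        const int eog_id           = -1;

        std::vector<int> prompt_tgt = {1, 2, 3}; // tokens already in the target context
        int  id_last = 4;                        // last accepted token from the previous step
        bool has_eos = false;

        for (size_t i = 0; i < ids.size(); ++i) {
            prompt_tgt.push_back(id_last);       // the previous "last" token joins the prompt
            id_last = ids[i];                    // the newest accepted token seeds the next step

            if (id_last == eog_id) {             // stop once end-of-generation is accepted
                has_eos = true;
                break;
            }
        }

        std::printf("prompt size = %zu, id_last = %d, has_eos = %d\n",
                    prompt_tgt.size(), id_last, (int) has_eos);
        return 0;
    }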