@@ -181,54 +181,45 @@ int main(int argc, char ** argv) {
         GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
 
         n_past    += ids.size() - 1;
-        n_drafted += batch_tgt.n_tokens - 1;
+        n_drafted += draft.size(); // note: we ignore the discarded small drafts
         n_accept  += ids.size() - 1;
+        n_predict += ids.size();
 
         // process the accepted tokens and update contexts
         //
         // this is the standard token post-processing that we normally do
         // in this case, we do it for a group of accepted tokens at once
         //
-        {
-            llama_token id;
-            std::string token_str;
-
-            for (size_t i = 0; i < ids.size(); ++i) {
-                id = ids[i];
-
-                ++n_predict;
-
-                if (llama_token_is_eog(model_tgt, id)) {
-                    has_eos = true;
-                    break;
-                }
+        for (size_t i = 0; i < ids.size(); ++i) {
+            const llama_token id = ids[i];
 
-                token_str = common_token_to_piece(ctx_tgt, id);
-
-                if (params.use_color && i + 1 < ids.size()) {
-                    LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
-                } else {
-                    LOG("%s", token_str.c_str());
-                }
-            }
+            prompt_tgt.push_back(id_last);
+            id_last = id;
 
-            if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+            if (llama_token_is_eog(model_tgt, id)) {
+                has_eos = true;
                 break;
             }
 
-            LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, (int) draft.size(), id, token_str.c_str());
+            const std::string token_str = common_token_to_piece(ctx_tgt, id);
 
-            {
-                LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
-
-                llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+            if (params.use_color && i + 1 < ids.size()) {
+                LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
+            } else {
+                LOG("%s", token_str.c_str());
             }
+        }
 
-            prompt_tgt.push_back(id_last);
-            prompt_tgt.insert(prompt_tgt.end(), ids.begin(), ids.end() - 1);
+        LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last);
 
-            // remember the last accepted token for the next iteration
-            id_last = id;
+        {
+            LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
+
+            llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+        }
+
+        if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+            break;
         }
     }
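To see the control flow of the new accept loop in isolation, here is a minimal standalone sketch. It is not the llama.cpp implementation: llama_token_is_eog and common_token_to_piece are replaced with hypothetical stand-ins (is_eog, token_to_piece) so it compiles without the library, and the token values are made up.

    // standalone sketch of the flattened accept loop; stand-ins replace llama.cpp calls
    #include <cstdio>
    #include <string>
    #include <vector>

    using llama_token = int;

    static bool is_eog(llama_token id) { return id == 2; }  // stand-in for llama_token_is_eog
    static std::string token_to_piece(llama_token id) {     // stand-in for common_token_to_piece
        return "<" + std::to_string(id) + ">";
    }

    int main() {
        std::vector<llama_token> prompt_tgt = {0, 1};
        std::vector<llama_token> ids        = {5, 6, 2};    // accepted tokens; 2 plays the EOG role here
        llama_token id_last = prompt_tgt.back();
        bool has_eos = false;

        for (size_t i = 0; i < ids.size(); ++i) {
            const llama_token id = ids[i];

            // the token accepted in the previous iteration moves into the target
            // prompt; the current one is remembered for the next iteration
            prompt_tgt.push_back(id_last);
            id_last = id;

            if (is_eog(id)) {
                has_eos = true;
                break;  // stop printing once end-of-generation is reached
            }

            printf("%s", token_to_piece(id).c_str());
        }

        printf("\naccepted %d tokens, has_eos = %d, id_last = %d\n",
               (int) ids.size() - 1, (int) has_eos, id_last);
        return 0;
    }

Running the sketch prints the pieces for the two tokens before the end-of-generation token, then the summary line. It mirrors how the refactored loop keeps prompt_tgt one step behind the current token via id_last, instead of bulk-inserting the accepted tokens at the end as the old code did.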