Commit 2649e27

speculative : simplify the implementation
ggml-ci
1 parent 9fd8c26 commit 2649e27

File tree

1 file changed: +23 -32 lines changed

examples/speculative-simple/speculative-simple.cpp

Lines changed: 23 additions & 32 deletions
@@ -181,54 +181,45 @@ int main(int argc, char ** argv) {
         GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
 
         n_past    += ids.size() - 1;
-        n_drafted += batch_tgt.n_tokens - 1;
+        n_drafted += draft.size(); // note: we ignore the discarded small drafts
         n_accept  += ids.size() - 1;
+        n_predict += ids.size();
 
         // process the accepted tokens and update contexts
         //
         // this is the standard token post-processing that we normally do
         // in this case, we do it for a group of accepted tokens at once
         //
-        {
-            llama_token id;
-            std::string token_str;
-
-            for (size_t i = 0; i < ids.size(); ++i) {
-                id = ids[i];
-
-                ++n_predict;
-
-                if (llama_token_is_eog(model_tgt, id)) {
-                    has_eos = true;
-                    break;
-                }
+        for (size_t i = 0; i < ids.size(); ++i) {
+            const llama_token id = ids[i];
 
-                token_str = common_token_to_piece(ctx_tgt, id);
-
-                if (params.use_color && i + 1 < ids.size()) {
-                    LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
-                } else {
-                    LOG("%s", token_str.c_str());
-                }
-            }
+            prompt_tgt.push_back(id_last);
+            id_last = id;
 
-            if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+            if (llama_token_is_eog(model_tgt, id)) {
+                has_eos = true;
                 break;
             }
 
-            LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, (int) draft.size(), id, token_str.c_str());
+            const std::string token_str = common_token_to_piece(ctx_tgt, id);
 
-            {
-                LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
-
-                llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+            if (params.use_color && i + 1 < ids.size()) {
+                LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
+            } else {
+                LOG("%s", token_str.c_str());
             }
+        }
 
-            prompt_tgt.push_back(id_last);
-            prompt_tgt.insert(prompt_tgt.end(), ids.begin(), ids.end() - 1);
+        LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last);
 
-            // remember the last accepted token for the next iteration
-            id_last = id;
+        {
+            LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
+
+            llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+        }
+
+        if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+            break;
+        }
     }
 
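Below is a minimal, self-contained sketch of the accepted-token post-processing loop that this commit introduces. The names ids, id_last, prompt_tgt and has_eos mirror the diff; token_is_eog and token_to_piece are hypothetical stand-ins for llama_token_is_eog() and common_token_to_piece(), included only so the sketch compiles on its own.

#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

using llama_token = int32_t;

// hypothetical stand-ins for llama_token_is_eog() and common_token_to_piece()
static bool token_is_eog(llama_token id) { return id == 2; } // assume token 2 is end-of-generation
static std::string token_to_piece(llama_token id) { return " tok" + std::to_string(id); }

int main() {
    std::vector<llama_token> prompt_tgt = {1, 5, 9}; // tokens already processed by the target model
    llama_token id_last = 9;                         // last accepted token from the previous iteration

    // tokens accepted by the target model in this iteration:
    // the verified draft tokens plus the token sampled after them
    const std::vector<llama_token> ids = {11, 12, 13};

    bool has_eos = false;

    // same shape as the loop in the diff: each accepted token first flushes
    // the previous id_last into the history, then becomes the new id_last
    for (size_t i = 0; i < ids.size(); ++i) {
        const llama_token id = ids[i];

        prompt_tgt.push_back(id_last);
        id_last = id;

        if (token_is_eog(id)) {
            has_eos = true;
            break;
        }

        printf("%s", token_to_piece(id).c_str());
    }

    // id_last (here 13) is deliberately not in prompt_tgt yet - it seeds
    // the next draft/verify iteration
    printf("\nprompt_tgt.size() = %zu, id_last = %d, has_eos = %d\n",
           prompt_tgt.size(), id_last, (int) has_eos);

    return 0;
}

The invariant this maintains, as in the diff, is that id_last always trails the history by one token: the newest accepted token is never in prompt_tgt, so it can be fed back to the models as the starting token of the next iteration.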