
Commit 811872a

speculative : simplify the implementation (#10504)
ggml-ci
1 parent 9a4b79b, commit 811872a

File tree

1 file changed: +24 -33 lines changed

examples/speculative-simple/speculative-simple.cpp

Lines changed: 24 additions & 33 deletions
@@ -117,7 +117,8 @@ int main(int argc, char ** argv) {
     llama_token id_last = inp.back();
 
     // all tokens currently in the target context
-    auto prompt_tgt = std::vector<llama_token>(inp.begin(), inp.end() - 1);
+    llama_tokens prompt_tgt(inp.begin(), inp.end() - 1);
+    prompt_tgt.reserve(llama_n_ctx(ctx_tgt));
 
     int n_past = inp.size() - 1;
 
@@ -181,54 +182,44 @@ int main(int argc, char ** argv) {
         GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
 
         n_past    += ids.size() - 1;
-        n_drafted += batch_tgt.n_tokens - 1;
+        n_drafted += draft.size(); // note: we ignore the discarded small drafts
         n_accept  += ids.size() - 1;
+        n_predict += ids.size();
 
         // process the accepted tokens and update contexts
         //
         // this is the standard token post-processing that we normally do
         // in this case, we do it for a group of accepted tokens at once
         //
-        {
-            llama_token id;
-            std::string token_str;
-
-            for (size_t i = 0; i < ids.size(); ++i) {
-                id = ids[i];
-
-                ++n_predict;
-
-                if (llama_token_is_eog(model_tgt, id)) {
-                    has_eos = true;
-                    break;
-                }
-
-                token_str = common_token_to_piece(ctx_tgt, id);
+        for (size_t i = 0; i < ids.size(); ++i) {
+            prompt_tgt.push_back(id_last);
 
-                if (params.use_color && i + 1 < ids.size()) {
-                    LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
-                } else {
-                    LOG("%s", token_str.c_str());
-                }
-            }
+            id_last = ids[i];
 
-            if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+            if (llama_token_is_eog(model_tgt, id_last)) {
+                has_eos = true;
                 break;
             }
 
-            LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, (int) draft.size(), id, token_str.c_str());
+            const std::string token_str = common_token_to_piece(ctx_tgt, id_last);
 
-            {
-                LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
-
-                llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+            if (params.use_color && i + 1 < ids.size()) {
+                LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
+            } else {
+                LOG("%s", token_str.c_str());
             }
+        }
 
-            prompt_tgt.push_back(id_last);
-            prompt_tgt.insert(prompt_tgt.end(), ids.begin(), ids.end() - 1);
+        LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last);
+
+        {
+            LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
+
+            llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+        }
 
-            // remember the last accepted token for the next iteration
-            id_last = id;
+        if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+            break;
         }
     }
 