
Commit 0b46370

wwoodsTM authored and arthw committed
server : samplers accept the prompt correctly (ggml-org#10019)
1 parent ac15559 commit 0b46370

1 file changed, +7 -11 lines changed


examples/server/server.cpp

Lines changed: 7 additions & 11 deletions
@@ -2163,17 +2163,10 @@ struct server_context {
             GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
         }
 
-        common_sampler_reset(slot.smpl);
-
         if (slot.params.cache_prompt) {
             // reuse any previously computed tokens that are common with the new prompt
             slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
 
-            // push the prompt into the sampling context (do not apply grammar)
-            for (int i = 0; i < slot.n_past; ++i) {
-                common_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
-            }
-
             // reuse chunks from the cached prompt by shifting their KV cache in the new position
             if (params.n_cache_reuse > 0) {
                 size_t head_c = slot.n_past; // cache
@@ -2206,8 +2199,6 @@
                 for (size_t i = 0; i < n_match; i++) {
                     slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
 
-                    common_sampler_accept(slot.smpl, slot.cache_tokens[head_p + i], false);
-
                     slot.n_past++;
                 }
 
@@ -2259,8 +2250,6 @@
 
             // there is no common part left
            slot.n_past = 0;
-
-            common_sampler_reset(slot.smpl);
         }
 
         SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
@@ -2288,6 +2277,13 @@
 
         GGML_ASSERT(batch.n_tokens > 0);
 
+        common_sampler_reset(slot.smpl);
+
+        // Process all prompt tokens through sampler system
+        for (int i = 0; i < slot.n_prompt_tokens; ++i) {
+            common_sampler_accept(slot.smpl, prompt_tokens[i], false);
+        }
+
         // extract the logits only for the last token
         batch.logits[batch.n_tokens - 1] = true;
 
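The change consolidates the sampler bookkeeping: instead of resetting the sampler early and feeding it cached tokens from several cache-reuse code paths, the server now resets it once and pushes the entire prompt through it right before decoding, so the sampler state matches the full new prompt regardless of how much of the KV cache was reused. Below is a minimal standalone sketch of that "reset once, then accept every prompt token" pattern; MockSampler and the llama_token alias are illustrative stand-ins, not the real common_sampler API from llama.cpp's common/sampling.h.

// Illustrative sketch only: MockSampler stands in for llama.cpp's common_sampler.
#include <cstdio>
#include <vector>

using llama_token = int; // stand-in for the real token id type

struct MockSampler {
    std::vector<llama_token> accepted; // tokens the sampler has been told about

    void reset() { accepted.clear(); }

    // apply_grammar = false mirrors how prompt tokens are fed in the commit
    void accept(llama_token tok, bool apply_grammar) {
        (void) apply_grammar;
        accepted.push_back(tok);
    }
};

int main() {
    // pretend these are the tokenized new prompt (the server's prompt_tokens)
    const std::vector<llama_token> prompt_tokens = {101, 7592, 2088, 102};

    MockSampler smpl;

    // New flow from the commit: right before the prompt batch is decoded,
    // reset the sampler and feed it the entire prompt exactly once, instead
    // of patching it incrementally while reusing the cached prefix.
    smpl.reset();
    for (size_t i = 0; i < prompt_tokens.size(); ++i) {
        smpl.accept(prompt_tokens[i], /*apply_grammar=*/false);
    }

    std::printf("sampler has seen %zu prompt tokens\n", smpl.accepted.size());
    return 0;
}

One consequence of this ordering is that stateful samplers (repetition penalties and similar) see the same token history whether the prompt was evaluated from scratch or partially restored from the cache.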