
Commit 7ed02e4

Fix logprobs
This commit is mostly a cherry-pick of ggml-org/llama.cpp#10783. That PR and its follow-ups were partially cherry-picked by ikawrakow#723, but the result was not yet in a working state. A couple of additional changes:

* Include timing information in the response, which has been done (unintentionally?) in mainline since ggml-org/llama.cpp#10643.
* Also return the actual logprobs for accepted draft tokens. This is still a TODO in mainline [1].

Note that there is a token-generation (TG) performance penalty for returning logprobs. Here are some numbers I got with Qwen2.5-Coder-32B-Instruct:

* no draft, no logprobs: 12.81 tok/s
* no draft, with logprobs: 12.02 tok/s (6.2% drop)
* with draft, no logprobs: 36.59 tok/s
* with draft, with logprobs: 29.08 tok/s (20.5% drop)

[1] https://github.com/ggml-org/llama.cpp/blob/b6548/tools/server/server.cpp#L4019
1 parent 6d2e7ca commit 7ed02e4
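
As a point of reference, clients opt into these probabilities through the completion request body. The sketch below is not part of this commit: it builds such a request with nlohmann::json, where the "n_probs" and "post_sampling_probs" fields are the ones wired up in this diff, while the prompt, the endpoint path, and how the body is sent are illustrative assumptions.

```cpp
// Hypothetical client-side helper: builds a request body asking the server to
// return the top-5 candidate tokens for every generated token. Field names
// come from this commit; the endpoint below is an assumption, not part of it.
#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    json request = {
        {"prompt", "Write a haiku about autumn."},
        {"n_predict", 64},
        {"n_probs", 5},                // top-5 candidates per generated token
        {"post_sampling_probs", true}, // probabilities after the sampling chain, not raw softmax
    };
    // e.g. pipe into: curl -d @- http://localhost:8080/completion (endpoint assumed)
    std::cout << request.dump(2) << std::endl;
    return 0;
}
```
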

File tree: 2 files changed, +105 -130 lines


examples/server/server.cpp

Lines changed: 73 additions & 102 deletions
@@ -556,6 +556,7 @@ struct slot_params {
     std::vector<std::string> antiprompt;
 
     bool timings_per_token = false;
+    bool post_sampling_probs = false;
     json input_prefix;
     json input_suffix;
 
@@ -1545,6 +1546,8 @@ struct server_context {
         slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
         slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
 
+        slot.params.post_sampling_probs = json_value(data, "post_sampling_probs", default_params.post_sampling_probs);
+
         // speculative decoding parameters
         slot.params.speculative.n_max = json_value(data, "speculative.n_max", params.n_draft);
         slot.params.speculative.n_min = json_value(data, "speculative.n_min", params.n_draft_min);
@@ -1947,26 +1950,7 @@ struct server_context {
         }
 
         // check if there is incomplete UTF-8 character at the end
-        bool incomplete = false;
-        for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
-            unsigned char c = slot.generated_text[slot.generated_text.size() - i];
-            if ((c & 0xC0) == 0x80) {
-                // continuation byte: 10xxxxxx
-                continue;
-            }
-            if ((c & 0xE0) == 0xC0) {
-                // 2-byte character: 110xxxxx ...
-                incomplete = i < 2;
-            } else if ((c & 0xF0) == 0xE0) {
-                // 3-byte character: 1110xxxx ...
-                incomplete = i < 3;
-            } else if ((c & 0xF8) == 0xF0) {
-                // 4-byte character: 11110xxx ...
-                incomplete = i < 4;
-            }
-            // else 1-byte character or invalid byte
-            break;
-        }
+        bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();
 
         if (!incomplete) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
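
The hand-rolled byte scan above is replaced by a call to validate_utf8(), which is not shown in this diff. Below is a minimal sketch of the contract the new call site relies on, assuming the helper returns the length of the longest prefix that does not end in a truncated multi-byte sequence; the byte patterns are taken from the removed code.

```cpp
// Sketch only: the real validate_utf8() is defined elsewhere in the server code.
#include <cstddef>
#include <string>

static size_t validate_utf8_sketch(const std::string & text) {
    const size_t len = text.size();
    // look backwards over at most 4 bytes for the start of a multi-byte character
    for (size_t i = 1; i <= 4 && i <= len; ++i) {
        const unsigned char c = text[len - i];
        if ((c & 0xC0) == 0x80) continue;          // continuation byte: 10xxxxxx
        size_t seq_len = 1;                        // default: 1-byte character or invalid byte
        if      ((c & 0xE0) == 0xC0) seq_len = 2;  // 110xxxxx
        else if ((c & 0xF0) == 0xE0) seq_len = 3;  // 1110xxxx
        else if ((c & 0xF8) == 0xF0) seq_len = 4;  // 11110xxx
        // if the trailing sequence is cut off, report only the prefix before it
        return seq_len > i ? len - i : len;
    }
    return len;
}
```

With this contract, validate_utf8(s) < s.size() is true exactly when the generated text currently ends mid-character, which is what the old incomplete flag tracked.
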
@@ -2062,6 +2046,56 @@ struct server_context {
         return slot.has_next_token; // continue
     }
 
+    void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
+        size_t n_probs = slot.sparams.n_probs;
+        size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+        if (post_sampling) {
+            const auto * cur_p = llama_sampling_get_candidates(slot.ctx_sampling);
+            const size_t max_probs = cur_p->size;
+
+            // set probability for sampled token
+            for (size_t i = 0; i < max_probs; i++) {
+                if (cur_p->data[i].id == result.tok) {
+                    result.prob = cur_p->data[i].p;
+                    break;
+                }
+            }
+
+            // set probability for top n_probs tokens
+            result.probs.reserve(max_probs);
+            for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
+                result.probs.push_back({
+                    cur_p->data[i].id,
+                    llama_detokenize(ctx, {cur_p->data[i].id}, special),
+                    cur_p->data[i].p
+                });
+            }
+        } else {
+            // TODO: optimize this with min-p optimization
+            std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
+
+            // set probability for sampled token
+            for (size_t i = 0; i < n_vocab; i++) {
+                // set probability for sampled token
+                if (cur[i].id == result.tok) {
+                    result.prob = cur[i].p;
+                    break;
+                }
+            }
+
+            // set probability for top n_probs tokens
+            result.probs.reserve(n_probs);
+            for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
+                result.probs.push_back({
+                    cur[i].id,
+                    llama_detokenize(ctx, {cur[i].id}, special),
+                    cur[i].p
+                });
+            }
+        }
+    }
+
     json get_formated_generation(const server_slot & slot) const {
         const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
         const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
@@ -2159,38 +2193,19 @@ struct server_context {
         res.stop = false;
         res.stream = slot.params.stream;
         res.content = tkn.text_to_send;
+        res.post_sampling_probs = slot.params.post_sampling_probs;
         res.oaicompat = slot.params.oaicompat;
         res.oaicompat_model = slot.params.oaicompat_model;
         res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
         res.n_decoded = slot.n_decoded;
         res.n_prompt_tokens = slot.n_prompt_tokens;
-        res.data = json {
-            {"content", tkn.text_to_send},
-            {"stop", false},
-            {"id_slot", slot.id},
-            {"multimodal", false}
-        };
         slot.update_chat_msg(res.oaicompat_msg_diffs);
-        if (slot.sparams.n_probs > 0) {
-            const std::vector<llama_token> to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false);
-            const size_t probs_pos      = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
-            const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
-
-            std::vector<completion_token_output> probs_output;
-            if (probs_pos < probs_stop_pos) {
-                probs_output = std::vector<completion_token_output>(
-                    slot.generated_token_probs.begin() + probs_pos,
-                    slot.generated_token_probs.begin() + probs_stop_pos);
-            }
-            slot.n_sent_token_probs = probs_stop_pos;
 
-            res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output);
+        // populate res.probs_output
+        if (slot.sparams.n_probs > 0) {
+            res.probs_output = {tkn}; // copy the token probs
         }
 
-        if (slot.oaicompat) {
-            res.data["oaicompat_token_ctr"] = slot.n_decoded;
-            res.data["model"] = slot.oaicompat_model;
-        }
         // populate timings if this is final response or timings_per_token is enabled
         if (slot.params.timings_per_token) {
             res.timings = slot.get_timings();
@@ -2207,56 +2222,30 @@ struct server_context {
         res.stop = true; // to do: set value
         res.stream = slot.params.stream;
         res.content = slot.generated_text;
+        res.timings = slot.get_timings();
+        res.post_sampling_probs = slot.params.post_sampling_probs;
         res.oaicompat = slot.params.oaicompat;
         res.oaicompat_model = slot.params.oaicompat_model;
         res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
         res.oaicompat_msg = slot.update_chat_msg(res.oaicompat_msg_diffs);
         res.n_decoded = slot.n_decoded;
         res.n_prompt_tokens = slot.n_prompt_tokens;
         res.oaicompat_model = slot.oaicompat_model;
-        res.data = json {
-            {"content", !slot.params.stream ? slot.generated_text : ""},
-            {"generated_text", slot.generated_text}, // Always include full text for finish_reason logic
-            {"id_slot", slot.id},
-            {"stop", true},
-            {"model", params.model_alias},
-            {"tokens_predicted", slot.n_decoded},
-            {"tokens_evaluated", slot.n_prompt_tokens},
-            {"generation_settings", get_formated_generation(slot)},
-            {"prompt", slot.prompt},
-            {"truncated", slot.truncated},
-            {"stopped_eos", slot.stopped_eos},
-            {"stopped_word", slot.stopped_word},
-            {"stopped_limit", slot.stopped_limit},
-            {"stopping_word", slot.stopping_word},
-            {"tokens_cached", slot.n_past},
-            {"timings", slot.get_formated_timings()},
-            //{"oaicompat_chat_format", slot.params.oaicompat_chat_format},
-        };
 
+        // populate res.probs_output
         if (slot.sparams.n_probs > 0) {
-            std::vector<completion_token_output> probs;
             if (!slot.params.stream && slot.stopped_word) {
                 const std::vector<llama_token> stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false);
 
                 size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
-                probs = std::vector<completion_token_output>(
+                res.probs_output = std::vector<completion_token_output>(
                     slot.generated_token_probs.begin(),
                     slot.generated_token_probs.end() - safe_offset);
             } else {
-                probs = std::vector<completion_token_output>(
+                res.probs_output = std::vector<completion_token_output>(
                     slot.generated_token_probs.begin(),
                     slot.generated_token_probs.end());
            }
-            //res.generation_params = slot.params;
-            res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs);
-        }
-
-        res.timings = slot.get_timings();
-
-        if (slot.oaicompat) {
-            res.data["oaicompat_token_ctr"] = slot.n_decoded;
-            res.data["model"] = slot.oaicompat_model;
         }
 
         queue_results.send(std::move(res));
@@ -3194,7 +3183,8 @@ struct server_context {
             }
 
             completion_token_output result;
-            const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
+            const int tok_idx = slot.i_batch - i;
+            const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, tok_idx);
 
             llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
 
@@ -3210,35 +3200,12 @@ struct server_context {
 
             slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
 
-            llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
             result.tok = id;
+            result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
             result.text_to_send = llama_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
 
-            const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
-            if (n_probs > 0) {
-                const size_t n_valid = slot.ctx_sampling->n_valid;
-
-                // Make sure at least n_probs top tokens are at the front of the vector:
-                if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
-                    llama_sample_top_k(ctx, &cur_p, n_probs, 0);
-                }
-
-                if (slot.sparams.temp == 0.0f) {
-                    // With greedy sampling the probabilities have possibly not been calculated.
-                    for (size_t i = 0; i < n_probs; ++i) {
-                        result.probs.push_back({
-                            cur_p.data[i].id, llama_detokenize(ctx, {cur_p.data[i].id}, params.special),
-                            i == 0 ? 1.0f : 0.0f
-                        });
-                    }
-                } else {
-                    for (size_t i = 0; i < n_probs; ++i) {
-                        result.probs.push_back({
-                            cur_p.data[i].id, llama_detokenize(ctx, {cur_p.data[i].id}, params.special),
-                            i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
-                        });
-                    }
-                }
-            }
+            if (slot.sparams.n_probs > 0) {
+                populate_token_probs(slot, result, slot.params.post_sampling_probs, params.special, tok_idx);
             }
 
             if (!process_token(result, slot)) {
@@ -3343,7 +3310,11 @@ struct server_context {
 
                 result.tok = ids[i];
                 result.text_to_send = llama_token_to_piece(ctx, result.tok, params.special);
-                // result.prob = 1.0f; // set later
+                result.prob = 1.0f; // set later
+
+                if (slot.sparams.n_probs > 0) {
+                    populate_token_probs(slot, result, slot.params.post_sampling_probs, params.special, i);
+                }
 
                 if (!process_token(result, slot)) {
                     // release slot because of stop condition

examples/server/utils.hpp

Lines changed: 32 additions & 28 deletions
@@ -381,31 +381,6 @@ struct completion_token_output {
     }
 };
 
-// convert a vector of completion_token_output to json
-static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> & probs) {
-    json out = json::array();
-
-    for (const auto & prob : probs) {
-        json probs_for_token = json::array();
-
-        for (const auto & p : prob.probs) {
-            const std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
-            probs_for_token.push_back(json {
-                {"tok_str", tok_str},
-                {"prob", p.prob},
-            });
-        }
-
-        const std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
-        out.push_back(json {
-            {"content", tok_str},
-            {"probs", probs_for_token},
-        });
-    }
-
-    return out;
-}
-
 
 //
 // OAI utils
@@ -616,13 +591,12 @@ static json oaicompat_chat_params_parse(
 
     // Handle "logprobs" field
     // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
-    if (body.contains("logprobs")) {
+    if (json_value(body, "logprobs", false)) {
         if (has_tools && stream) {
             throw std::runtime_error("logprobs is not supported with tools + stream");
         }
         llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
-    }
-    else if (body.contains("top_logprobs")) {
+    } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
        throw std::runtime_error("top_logprobs requires logprobs to be set to true");
     }
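
A standalone illustration of the rewritten branch above (the tools + stream restriction is omitted): with json_value(), "logprobs" must now be explicitly truthy rather than merely present, and a null "top_logprobs" no longer triggers the error. The sketch uses nlohmann::json directly instead of the server's json_value() helper.

```cpp
#include <nlohmann/json.hpp>
#include <stdexcept>

using json = nlohmann::json;

// returns the number of candidate probabilities to report per token (0 = none)
static int parse_n_probs(const json & body) {
    const bool logprobs = body.contains("logprobs") && body.at("logprobs").is_boolean()
                              ? body.at("logprobs").get<bool>()
                              : false;
    if (logprobs) {
        // "top_logprobs" selects how many candidates to return; 20 is the default used above
        return body.contains("top_logprobs") && body.at("top_logprobs").is_number_integer()
                   ? body.at("top_logprobs").get<int>()
                   : 20;
    }
    if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
        throw std::runtime_error("top_logprobs requires logprobs to be set to true");
    }
    return 0;
}
```

For example, {"logprobs": true} maps to n_probs = 20, {"logprobs": true, "top_logprobs": 5} maps to 5, and {"logprobs": false, "top_logprobs": 5} still throws.
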

@@ -715,3 +689,33 @@ static json format_error_response(const std::string & message, const enum error_
         {"type", type_str},
     };
 }
+
+static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
+    std::vector<llama_token_data> cur;
+    const auto * logits = llama_get_logits_ith(ctx, idx);
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+    cur.resize(n_vocab);
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+    }
+
+    // sort tokens by logits
+    std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) {
+        return a.logit > b.logit;
+    });
+
+    // apply softmax
+    float max_l = cur[0].logit;
+    float cum_sum = 0.0f;
+    for (size_t i = 0; i < cur.size(); ++i) {
+        float p = expf(cur[i].logit - max_l);
+        cur[i].p = p;
+        cum_sum += p;
+    }
+    for (size_t i = 0; i < cur.size(); ++i) {
+        cur[i].p /= cum_sum;
+    }
+
+    return cur;
+}
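
To make the behavior of get_token_probabilities() concrete, here is a self-contained toy version of the same sort-then-softmax conversion over a hard-coded logit array instead of llama_get_logits_ith(): subtract the maximum logit for numerical stability, exponentiate, and normalize so the resulting p values sum to 1, with cur[0] ending up as the most likely token.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

struct token_prob { int id; float logit; float p; };

int main() {
    std::vector<token_prob> cur = {
        {0, 1.2f, 0.0f}, {1, 3.4f, 0.0f}, {2, -0.5f, 0.0f}, {3, 2.8f, 0.0f},
    };

    // sort by logit, highest first, so cur[0] is the most likely token
    std::sort(cur.begin(), cur.end(), [](const token_prob & a, const token_prob & b) {
        return a.logit > b.logit;
    });

    // softmax over the sorted logits
    const float max_l = cur[0].logit;
    float cum_sum = 0.0f;
    for (auto & t : cur) {
        t.p = std::exp(t.logit - max_l);
        cum_sum += t.p;
    }
    for (auto & t : cur) {
        t.p /= cum_sum;
    }

    for (const auto & t : cur) {
        std::printf("token %d: logit %+.2f -> p %.4f\n", t.id, t.logit, t.p);
    }
    return 0;
}
```
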
