@@ -556,6 +556,7 @@ struct slot_params {
     std::vector<std::string> antiprompt;

     bool timings_per_token = false;
+    bool post_sampling_probs = false;
     json input_prefix;
     json input_suffix;

@@ -1545,6 +1546,8 @@ struct server_context {
         slot.sparams.n_probs  = json_value(data, "n_probs",  default_sparams.n_probs);
         slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);

+        slot.params.post_sampling_probs = json_value(data, "post_sampling_probs", default_params.post_sampling_probs);
+
         // speculative decoding parameters
         slot.params.speculative.n_max = json_value(data, "speculative.n_max", params.n_draft);
         slot.params.speculative.n_min = json_value(data, "speculative.n_min", params.n_draft_min);
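Note: json_value above is the server's existing helper for reading an optional request field with a fallback default; the new post_sampling_probs flag is parsed the same way as the other per-request options. A rough sketch of the assumed helper contract (not part of this diff, json being the nlohmann::json alias already used in this file):

// Assumed shape of the json_value helper used throughout server.cpp:
// return body[key] converted to T, or default_value if the key is absent,
// null, or has an incompatible type.
template <typename T>
static T json_value(const json & body, const std::string & key, const T & default_value) {
    if (body.contains(key) && !body.at(key).is_null()) {
        try {
            return body.at(key).get<T>();
        } catch (const json::exception &) {
            return default_value;
        }
    }
    return default_value;
}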
@@ -1947,26 +1950,7 @@ struct server_context {
         }

         // check if there is incomplete UTF-8 character at the end
-        bool incomplete = false;
-        for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
-            unsigned char c = slot.generated_text[slot.generated_text.size() - i];
-            if ((c & 0xC0) == 0x80) {
-                // continuation byte: 10xxxxxx
-                continue;
-            }
-            if ((c & 0xE0) == 0xC0) {
-                // 2-byte character: 110xxxxx ...
-                incomplete = i < 2;
-            } else if ((c & 0xF0) == 0xE0) {
-                // 3-byte character: 1110xxxx ...
-                incomplete = i < 3;
-            } else if ((c & 0xF8) == 0xF0) {
-                // 4-byte character: 11110xxx ...
-                incomplete = i < 4;
-            }
-            // else 1-byte character or invalid byte
-            break;
-        }
+        bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();

         if (!incomplete) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
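Note: the hand-rolled byte scan is replaced by a call to validate_utf8, which is not shown in this diff. From the way it is used, it is expected to return the length of the longest prefix of the string made of complete UTF-8 sequences, so a return value smaller than the string size means the generated text currently ends in a partial multi-byte character. A minimal sketch of that assumed contract:

#include <string>

// Sketch (assumption): number of leading bytes of `text` that form complete
// UTF-8 sequences; a trailing partial sequence is excluded from the count.
static size_t validate_utf8(const std::string & text) {
    const size_t len = text.size();
    // look back at most 4 bytes for the start of the last multi-byte sequence
    for (size_t i = 1; i <= 4 && i <= len; ++i) {
        const unsigned char c = text[len - i];
        if ((c & 0x80) == 0x00) {
            return len; // ends on a 1-byte (ASCII) character
        }
        if ((c & 0xC0) == 0xC0) { // lead byte 11xxxxxx
            const size_t need = (c & 0xE0) == 0xC0 ? 2
                              : (c & 0xF0) == 0xE0 ? 3
                              : (c & 0xF8) == 0xF0 ? 4 : 1;
            return i < need ? len - i : len; // drop an incomplete tail
        }
        // else continuation byte 10xxxxxx: keep scanning backwards
    }
    return len;
}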
@@ -2062,6 +2046,49 @@ struct server_context {
         return slot.has_next_token; // continue
     }

+    void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
+        size_t n_probs = slot.sparams.n_probs;
+        size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+        if (post_sampling) {
+            const auto * cur_p = llama_sampling_get_candidates(slot.ctx_sampling);
+            const size_t max_probs = cur_p->size;
+
+            // set probability for sampled token
+            for (size_t i = 0; i < max_probs; i++) {
+                if (cur_p->data[i].id == result.tok) {
+                    result.prob = cur_p->data[i].p;
+                    break;
+                }
+            }
+
+            // set probability for top n_probs tokens
+            result.probs.reserve(max_probs);
+            for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
+                result.probs.push_back({
+                    cur_p->data[i].id,
+                    llama_detokenize(ctx, {cur_p->data[i].id}, special),
+                    cur_p->data[i].p
+                });
+            }
+        } else {
+            auto && [sampled_token_p, cur] = get_token_probabilities(ctx, idx, result.tok, n_probs);
+
+            // set probability for sampled token
+            result.prob = sampled_token_p;
+
+            // set probability for top n_probs tokens
+            result.probs.reserve(n_probs);
+            for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
+                result.probs.push_back({
+                    cur[i].id,
+                    llama_detokenize(ctx, {cur[i].id}, special),
+                    cur[i].p
+                });
+            }
+        }
+    }
+
     json get_formated_generation(const server_slot & slot) const {
         const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
         const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
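Note: populate_token_probs covers both modes. With post_sampling it reads the candidates left in the sampler after the sampling chain ran; otherwise it calls get_token_probabilities, which is not shown in this diff. Given how its result is unpacked into (sampled_token_p, cur), it presumably softmaxes the raw logits at batch position idx over the full vocabulary and returns the sampled token's probability together with candidates sorted by probability. A hypothetical sketch, with the name and return type taken from the call site and everything else an assumption:

// Hypothetical sketch of get_token_probabilities, consistent with its use above.
// Assumes the usual server.cpp includes (<algorithm>, <cmath>, <vector>, <utility>).
static std::pair<float, std::vector<llama_token_data>> get_token_probabilities(
        llama_context * ctx, int idx, llama_token sampled, size_t n_probs) {
    const float * logits  = llama_get_logits_ith(ctx, idx);
    const size_t  n_vocab = llama_n_vocab(llama_get_model(ctx));

    std::vector<llama_token_data> cur(n_vocab);

    // softmax over the full vocabulary (pre-sampling probabilities)
    const float max_logit = *std::max_element(logits, logits + n_vocab);
    float sum = 0.0f;
    for (size_t i = 0; i < n_vocab; i++) {
        cur[i] = { (llama_token) i, logits[i], expf(logits[i] - max_logit) };
        sum += cur[i].p;
    }
    for (auto & d : cur) {
        d.p /= sum;
    }

    const float sampled_p = cur[sampled].p;

    // only the top n_probs entries need to be in order
    std::partial_sort(cur.begin(), cur.begin() + std::min(n_vocab, n_probs), cur.end(),
            [](const llama_token_data & a, const llama_token_data & b) { return a.p > b.p; });

    return { sampled_p, cur };
}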
@@ -2159,6 +2186,7 @@ struct server_context {
         res.stop    = false;
         res.stream  = slot.params.stream;
         res.content = tkn.text_to_send;
+        res.post_sampling_probs = slot.params.post_sampling_probs;
         res.oaicompat         = slot.params.oaicompat;
         res.oaicompat_model   = slot.params.oaicompat_model;
         res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
@@ -2171,26 +2199,18 @@ struct server_context {
             {"multimodal", false}
         };
         slot.update_chat_msg(res.oaicompat_msg_diffs);
-        if (slot.sparams.n_probs > 0) {
-            const std::vector<llama_token> to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false);
-            const size_t probs_pos      = std::min(slot.n_sent_token_probs,                       slot.generated_token_probs.size());
-            const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
-
-            std::vector<completion_token_output> probs_output;
-            if (probs_pos < probs_stop_pos) {
-                probs_output = std::vector<completion_token_output>(
-                    slot.generated_token_probs.begin() + probs_pos,
-                    slot.generated_token_probs.begin() + probs_stop_pos);
-            }
-            slot.n_sent_token_probs = probs_stop_pos;

-            res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output);
+        // populate res.probs_output
+        if (slot.sparams.n_probs > 0) {
+            res.probs_output = {tkn}; // copy the token probs
+            res.data["completion_probabilities"] = probs_vector_to_json(ctx, res.probs_output);
         }

         if (slot.oaicompat) {
             res.data["oaicompat_token_ctr"] = slot.n_decoded;
             res.data["model"] = slot.oaicompat_model;
         }
+
         // populate timings if this is final response or timings_per_token is enabled
         if (slot.params.timings_per_token) {
             res.timings = slot.get_timings();
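Note: a streamed partial response now carries exactly the probabilities of the token that produced this chunk (res.probs_output = {tkn}), instead of re-tokenizing text_to_send and slicing slot.generated_token_probs. probs_vector_to_json is the existing serializer and is unchanged here; a rough sketch of what it is assumed to emit, with field names being assumptions based on the /completion response format:

// Hypothetical sketch of probs_vector_to_json (not part of this diff): one entry
// per generated token, with its text and the top-n candidate probabilities.
static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> & probs) {
    (void) ctx; // the real helper may use ctx, e.g. to re-format token strings
    json out = json::array();
    for (const auto & p : probs) {
        json candidates = json::array();
        for (const auto & tp : p.probs) {
            candidates.push_back(json {
                {"tok_str", tp.tok_str}, // assumed field name for the detokenized text
                {"prob",    tp.prob},
            });
        }
        out.push_back(json {
            {"content", p.text_to_send},
            {"probs",   candidates},
        });
    }
    return out;
}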
@@ -2207,6 +2227,8 @@ struct server_context {
         res.stop    = true; // to do: set value
         res.stream  = slot.params.stream;
         res.content = slot.generated_text;
+        res.timings = slot.get_timings();
+        res.post_sampling_probs = slot.params.post_sampling_probs;
         res.oaicompat         = slot.params.oaicompat;
         res.oaicompat_model   = slot.params.oaicompat_model;
         res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
@@ -2234,26 +2256,23 @@ struct server_context {
             // {"oaicompat_chat_format", slot.params.oaicompat_chat_format},
         };

+        // populate res.probs_output
         if (slot.sparams.n_probs > 0) {
-            std::vector<completion_token_output> probs;
             if (!slot.params.stream && slot.stopped_word) {
                 const std::vector<llama_token> stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false);

                 size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
-                probs = std::vector<completion_token_output>(
+                res.probs_output = std::vector<completion_token_output>(
                     slot.generated_token_probs.begin(),
                     slot.generated_token_probs.end() - safe_offset);
             } else {
-                probs = std::vector<completion_token_output>(
+                res.probs_output = std::vector<completion_token_output>(
                     slot.generated_token_probs.begin(),
                     slot.generated_token_probs.end());
             }
-            // res.generation_params = slot.params;
-            res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs);
+            res.data["completion_probabilities"] = probs_vector_to_json(ctx, res.probs_output);
         }

-        res.timings = slot.get_timings();
-
         if (slot.oaicompat) {
             res.data["oaicompat_token_ctr"] = slot.n_decoded;
             res.data["model"] = slot.oaicompat_model;
@@ -3194,7 +3213,8 @@ struct server_context {
            }

            completion_token_output result;
-            const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
+            const int tok_idx = slot.i_batch - i;
+            const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, tok_idx);

            llama_sampling_accept(slot.ctx_sampling, ctx, id, true);

@@ -3210,35 +3230,12 @@ struct server_context {

            slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;

-            llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
            result.tok = id;
+            result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
            result.text_to_send = llama_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));

-            const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
-            if (n_probs > 0) {
-                const size_t n_valid = slot.ctx_sampling->n_valid;
-
-                // Make sure at least n_probs top tokens are at the front of the vector:
-                if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
-                    llama_sample_top_k(ctx, &cur_p, n_probs, 0);
-                }
-
-                if (slot.sparams.temp == 0.0f) {
-                    // With greedy sampling the probabilities have possibly not been calculated.
-                    for (size_t i = 0; i < n_probs; ++i) {
-                        result.probs.push_back({
-                            cur_p.data[i].id, llama_detokenize(ctx, {cur_p.data[i].id}, params.special),
-                            i == 0 ? 1.0f : 0.0f
-                        });
-                    }
-                } else {
-                    for (size_t i = 0; i < n_probs; ++i) {
-                        result.probs.push_back({
-                            cur_p.data[i].id, llama_detokenize(ctx, {cur_p.data[i].id}, params.special),
-                            i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
-                        });
-                    }
-                }
-            }
+            if (slot.sparams.n_probs > 0) {
+                populate_token_probs(slot, result, slot.params.post_sampling_probs, params.special, tok_idx);
            }

            if (!process_token(result, slot)) {
@@ -3343,7 +3340,11 @@ struct server_context {

            result.tok = ids[i];
            result.text_to_send = llama_token_to_piece(ctx, result.tok, params.special);
-            // result.prob = 1.0f; // set later
+            result.prob = 1.0f; // set later
+
+            if (slot.sparams.n_probs > 0) {
+                populate_token_probs(slot, result, slot.params.post_sampling_probs, params.special, i);
+            }

            if (!process_token(result, slot)) {
                // release slot because of stop condition