Server: openai-style lookup decoding #12127

Draft · wants to merge 1 commit into base: master
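
A minimal sketch of how a client might exercise this, assuming a llama-server instance listening on localhost:8080 (the host, port, and example texts are placeholders). The `prediction` request field and the `usage.completion_tokens_details.accepted_prediction_tokens` counter correspond to the server changes in the diff below.

import requests  # any HTTP client works; `requests` is assumed here

resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Rewrite this function so it returns 42."}],
        "temperature": 0.0,
        # OpenAI-style predicted output: these tokens are used as a lookup
        # draft wherever they line up with what the model actually generates
        "prediction": {"content": "def answer():\n    return 42\n"},
    },
).json()

print(resp["choices"][0]["message"]["content"])
# number of drafted tokens accepted via lookup decoding
print(resp["usage"]["completion_tokens_details"]["accepted_prediction_tokens"])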
171 changes: 166 additions & 5 deletions examples/server/server.cpp
@@ -197,6 +197,7 @@ struct server_task {
// used by SERVER_TASK_TYPE_INFERENCE
slot_params params;
llama_tokens prompt_tokens;
llama_tokens prediction_tokens;
int id_selected_slot = -1;

// used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
@@ -604,6 +605,7 @@ struct server_task_result_cmpl_final : server_task_result {
int32_t n_decoded;
int32_t n_prompt_tokens;
int32_t n_tokens_cached;
int32_t n_lookup_used;
bool has_new_line;
std::string stopping_word;
stop_type stop = STOP_TYPE_NONE;
@@ -660,6 +662,7 @@ struct server_task_result_cmpl_final : server_task_result {
{"stopping_word", stopping_word},
{"tokens_cached", n_tokens_cached},
{"timings", timings.to_json()},
{"prediction_tokens_accepted", n_lookup_used},
};
if (!stream && !probs_output.empty()) {
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
@@ -695,7 +698,10 @@ struct server_task_result_cmpl_final : server_task_result {
{"usage", json {
{"completion_tokens", n_decoded},
{"prompt_tokens", n_prompt_tokens},
{"total_tokens", n_decoded + n_prompt_tokens}
{"total_tokens", n_decoded + n_prompt_tokens},
{"completion_tokens_details", json {
{"accepted_prediction_tokens", n_lookup_used },
}}
}},
{"id", oaicompat_cmpl_id}
};
@@ -771,11 +777,14 @@ struct server_task_result_cmpl_final : server_task_result {
{"usage", json {
{"completion_tokens", n_decoded},
{"prompt_tokens", n_prompt_tokens},
{"total_tokens", n_decoded + n_prompt_tokens}
{"total_tokens", n_decoded + n_prompt_tokens},
{"completion_tokens_details", json {
{"accepted_prediction_tokens", n_lookup_used },
}}
}},
{"id", oaicompat_cmpl_id}
};

// extra fields for debugging purposes
if (verbose) {
res["__verbose"] = to_json_non_oaicompat();
@@ -811,6 +820,9 @@ struct server_task_result_cmpl_final : server_task_result {
{"completion_tokens", n_decoded},
{"prompt_tokens", n_prompt_tokens},
{"total_tokens", n_decoded + n_prompt_tokens},
{"completion_tokens_details", json {
{"accepted_prediction_tokens", n_lookup_used },
}}
}},
};

@@ -1235,16 +1247,22 @@ struct server_slot {
int32_t n_ctx = 0; // context size per slot
int32_t n_past = 0;
int32_t n_decoded = 0;
int32_t n_lookup_used = 0;
int32_t n_remaining = -1;
int32_t i_batch = -1;
int32_t n_predict = -1; // TODO: disambiguate from params.n_predict

// for "predicted outputs"
int32_t lookup_n_adaptive = 1;
int32_t lookup_index = 0;

// n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
int32_t n_prompt_tokens = 0;
int32_t n_prompt_tokens_processed = 0;

// input prompt tokens
llama_tokens prompt_tokens;
llama_tokens prediction_tokens;

size_t last_nl_pos = 0;

@@ -1912,9 +1930,8 @@ struct server_context {
slot.n_ctx = n_ctx_slot;
slot.n_predict = params_base.n_predict;

slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
if (model_dft) {
slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);

slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
if (slot.ctx_dft == nullptr) {
SRV_ERR("%s", "failed to create draft context\n");
@@ -2034,6 +2051,7 @@ struct server_context {
slot.task_type = task.type;
slot.params = std::move(task.params);
slot.prompt_tokens = std::move(task.prompt_tokens);
slot.prediction_tokens = std::move(task.prediction_tokens);

if (!are_lora_equal(task.params.lora, slot.lora)) {
// if lora is changed, we cannot reuse cached tokens
@@ -2345,6 +2363,7 @@ struct server_context {
res->n_decoded = slot.n_decoded;
res->n_prompt_tokens = slot.n_prompt_tokens;
res->n_tokens_cached = slot.n_past;
res->n_lookup_used = slot.n_lookup_used;
res->has_new_line = slot.has_new_line;
res->stopping_word = slot.stopping_word;
res->stop = slot.stop;
@@ -3217,6 +3236,137 @@ struct server_context {
}
}

// apply "predicted outputs" i.e. user-specified speculation
// using a simple lookup decoding method
for (auto & slot : slots) {
// don't use lookup if we are also using a draft model
if (slot.can_speculate() || !slot.is_processing() || slot.prediction_tokens.size() < 2) {
continue;
}
if (slot.state != SLOT_STATE_GENERATING) {
continue;
}

// adaptive speculation window:
// increase the window size every time the full draft was accepted,
// otherwise reset it to one
auto draft_start_pos = 1;
bool found = false;
// first look for a match from the expected position
SLT_DBG(slot, "Looking up prediction tokens at index %d/%d\n", (int) slot.lookup_index, (int) slot.prediction_tokens.size());
if (slot.lookup_index > 0 &&
slot.lookup_index < static_cast<int32_t>(slot.prediction_tokens.size()) &&
slot.prediction_tokens[slot.lookup_index-1] == slot.sampled) {
found = true;
draft_start_pos = slot.lookup_index;
// TODO: what is a good scaling law here?
// growing the window too fast will likely make drafts fail,
// but starting with too small a window also hurts performance
slot.lookup_n_adaptive = std::max(16, slot.lookup_n_adaptive*2);
} else {
// find first match in prediction_tokens
slot.lookup_n_adaptive = 1; // default
for (; draft_start_pos < static_cast<int32_t>(slot.prediction_tokens.size()); draft_start_pos++) {
if (slot.prediction_tokens[draft_start_pos-1] == slot.sampled) {
found = true;
break;
}
}
}
if (!found) continue;

// we erase the accepted tokens later, so we're looking for the same position next time
// increment by one because the next token will be generated
slot.lookup_index = draft_start_pos + 1;

llama_tokens draft = std::vector(
slot.prediction_tokens.begin() + draft_start_pos,
slot.prediction_tokens.end()
);

// determine the max draft that fits the current slot state
int n_draft_max = slot.lookup_n_adaptive;
n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);

if (slot.n_remaining > 0) {
n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
}

n_draft_max = std::min(n_draft_max, static_cast<int>(draft.size()));
// NOTE: speculative.n_max is used as the upper limit here; unlike the
// draft-model path we generally want to allow large drafts, but the
// draft size is also tied to the size of `slot.batch_spec`.
n_draft_max = std::min(n_draft_max, slot.params.speculative.n_max);

SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);

draft.resize(n_draft_max);

llama_token id = slot.sampled;

// construct the speculation batch
common_batch_clear(slot.batch_spec);
common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);

for (size_t i = 0; i < draft.size(); ++i) {
common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
}

llama_decode(ctx, slot.batch_spec);

// the accepted tokens from the speculation
// TODO can we stream these? Would be nice to reduce jankiness in UIs
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);

const auto n_accepted = ids.size() - 1;
slot.n_lookup_used += n_accepted;

if (n_accepted > 0) {
// remove the prediction tokens that were used + the next token
// (because it will be generated)
slot.prediction_tokens.erase(
slot.prediction_tokens.begin() + draft_start_pos,
std::min(
slot.prediction_tokens.end(),
slot.prediction_tokens.begin() + draft_start_pos + n_accepted + 1
)
);
if (n_accepted < draft.size()) {
// reset speculation as we didn't use the full draft
slot.lookup_n_adaptive = 1;
}
}

for (size_t i = 0; i < ids.size(); ++i) {
// NOTE: we need to update these here to avoid stopping early
slot.n_past++;
slot.n_decoded++;
completion_token_output result;

result.tok = ids[i];
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
result.prob = 1.0f; // set later

// TODO: set result.probs
if (!process_token(result, slot)) {
// release slot because of stop condition
slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
break;
}
}

slot.cache_tokens.push_back(id);
slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);

llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);

SLT_DBG(slot, "accepted %d/%d prediction tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
}

// do speculative decoding
for (auto & slot : slots) {
if (!slot.is_processing() || !slot.can_speculate()) {
@@ -3838,10 +3988,17 @@ int main(int argc, char ** argv) {

try {
const auto & prompt = data.at("prompt");
const auto & prediction_obj = json_value(data, "prediction", json());
const auto & prediction = json_value(prediction_obj, "content", std::string());
// TODO: this log can become very long, put it behind a flag or think about a more compact format
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());

std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
std::vector<llama_tokens> tokenized_prediction;
if (!prediction.empty()) {
tokenized_prediction = tokenize_input_prompts(ctx_server.vocab, prediction, true, true);
}

tasks.reserve(tokenized_prompts.size());
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
server_task task = server_task(type);
@@ -3850,6 +4007,10 @@
task.index = i;

task.prompt_tokens = std::move(tokenized_prompts[i]);

if (!tokenized_prediction.empty()) {
task.prediction_tokens = std::vector(tokenized_prediction[0].begin(), tokenized_prediction[0].end());
}
task.params = server_task::params_from_json_cmpl(
ctx_server.ctx,
ctx_server.params_base,
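To recap the loop added above: on each generation step the slot tries to line up the last sampled token with the remaining user-supplied prediction, drafts the tokens that follow the match in a single batch, keeps the longest prefix the sampler agrees with, grows the draft window after a full accept, and resets it after a partial one. Below is a rough, self-contained sketch of that draft/verify/grow/reset bookkeeping; characters stand in for tokens, a fixed target string stands in for the model, and it deliberately ignores the batched decode, KV-cache handling, and the extra sampled token. Names are illustrative, not from the patch.

def lookup_decode(target: str, prediction: str) -> tuple[str, int]:
    out = ""
    window = 1          # adaptive draft size, like slot.lookup_n_adaptive
    accepted_total = 0  # like slot.n_lookup_used
    while len(out) < len(target):
        out += target[len(out)]          # "sample" one token the normal way
        if len(out) >= len(target):
            break
        pos = prediction.find(out[-1])   # look up the last token in the prediction
        if pos == -1:
            window = 1
            continue
        draft = prediction[pos + 1 : pos + 1 + window]
        # verify: keep the longest prefix of the draft the "model" agrees with
        n_ok = 0
        while (n_ok < len(draft)
               and len(out) + n_ok < len(target)
               and target[len(out) + n_ok] == draft[n_ok]):
            n_ok += 1
        out += draft[:n_ok]
        accepted_total += n_ok
        # full accept -> grow the window (the patch uses max(16, window * 2));
        # partial accept -> reset to 1
        window = max(16, window * 2) if n_ok == len(draft) else 1
    return out, accepted_total


print(lookup_decode("the meaning of life is 42", " meaning of life is 42"))
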
62 changes: 62 additions & 0 deletions examples/server/tests/unit/test_predicted_outputs.py
@@ -0,0 +1,62 @@
import pytest
from utils import *


@pytest.fixture(scope="module", autouse=True)
def create_server():
    global server
    server = ServerPreset.tinyllama2()
    server.draft_max = 1024
    server.debug = True


def test_with_and_without_predicted_outputs():
    global server
    server.start()
    res = server.make_request("POST", "/v1/chat/completions", data={
        "messages": [{"role": "user", "content": "I believe the meaning of life is"}],
        "temperature": 0.0,
        "top_k": 1,
    })
    assert res.status_code == 200
    assert res.body["usage"]["completion_tokens_details"]["accepted_prediction_tokens"] == 0
    content_no_pred = res.body["choices"][0]["message"]["content"]
    server.stop()

    server.start()
    res = server.make_request("POST", "/v1/chat/completions", data={
        "messages": [{"role": "user", "content": "I believe the meaning of life is"}],
        "temperature": 0.0,
        "top_k": 1,
        "prediction": {"content": '''"Here?" Annabyed.
"Okay, Annabyes!" Annabyed.
As Annagged, Annap came and said,'''}
    })
    assert res.status_code == 200
    assert res.body["usage"]["completion_tokens_details"]["accepted_prediction_tokens"] == 54
    content_pred = res.body["choices"][0]["message"]["content"]
    server.stop()

    assert content_no_pred == content_pred


@pytest.mark.parametrize("n_slots,n_requests", [
    (1, 2),
    (2, 2),
])
def test_multi_requests_parallel(n_slots: int, n_requests: int):
    global server
    server.n_slots = n_slots
    server.start()
    tasks = []
    for _ in range(n_requests):
        tasks.append((server.make_request, ("POST", "/v1/chat/completions", {
            "messages": [{"role": "user", "content": "I believe the meaning of life is"}],
            "temperature": 0.0,
            "top_k": 1,
            "prediction": {"content": " believe the meaning of life is"},
        })))
    results = parallel_function_calls(tasks)
    for res in results:
        assert res.status_code == 200
        assert match_regex("(wise|kind|owl|answer)+", res.body["choices"][0]["message"]["content"])