36 changes: 28 additions & 8 deletions tools/server/server.cpp
@@ -86,6 +86,7 @@ enum error_type {
     ERROR_TYPE_PERMISSION,
     ERROR_TYPE_UNAVAILABLE, // custom error
     ERROR_TYPE_NOT_SUPPORTED, // custom error
+    ERROR_TYPE_EXCEED_CONTEXT_SIZE, // custom error
 };
 
 static bool server_task_type_need_embd(server_task_type task_type) {
@@ -1224,6 +1225,10 @@ static json format_error_response(const std::string & message, const enum error_
             type_str = "unavailable_error";
             code = 503;
             break;
+        case ERROR_TYPE_EXCEED_CONTEXT_SIZE:
+            type_str = "exceed_context_size_error";
+            code = 400;
+            break;
     }
     return json {
         {"code", code},
@@ -1237,12 +1242,21 @@ struct server_task_result_error : server_task_result {
     error_type err_type = ERROR_TYPE_SERVER;
     std::string err_msg;
 
+    // for ERROR_TYPE_EXCEED_CONTEXT_SIZE
+    int32_t n_prompt_tokens = 0;
+    int32_t n_ctx = 0;
+
     virtual bool is_error() override {
         return true;
     }
 
     virtual json to_json() override {
-        return format_error_response(err_msg, err_type);
+        json res = format_error_response(err_msg, err_type);
+        if (err_type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) {
+            res["n_prompt_tokens"] = n_prompt_tokens;
+            res["n_ctx"] = n_ctx;
+        }
+        return res;
     }
 };
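With to_json() extended, a rejected request's response body would look roughly like the sketch below; the top-level "error" wrapper and the field names follow the assertions in the new test, while the numeric values are purely illustrative:

```python
# Hedged sketch of the HTTP 400 response body a client receives.
# Field names mirror the test assertions; the numbers are made up.
response_body = {
    "error": {
        "code": 400,
        "type": "exceed_context_size_error",
        "message": "the request exceeds the available context size. "
                   "try increasing the context size or enable context shift",
        "n_prompt_tokens": 1024,  # size of the rejected prompt, in tokens
        "n_ctx": 512,             # context available to the slot that rejected it
    }
}
```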

@@ -2605,16 +2619,22 @@ struct server_context {
     }
 
     void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
-        send_error(slot.id_task, error, type);
+        send_error(slot.id_task, error, type, slot.n_prompt_tokens, slot.n_ctx);
     }
 
-    void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
+    void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) {
         SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str());
 
+        if (type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) {
+            GGML_ASSERT(n_ctx > 0 && n_prompt_tokens > 0);
+        }
+
         auto res = std::make_unique<server_task_result_error>();
-        res->id       = id_task;
-        res->err_type = type;
-        res->err_msg  = error;
+        res->id              = id_task;
+        res->err_type        = type;
+        res->err_msg         = error;
+        res->n_prompt_tokens = n_prompt_tokens;
+        res->n_ctx           = n_ctx;
 
         queue_results.send(std::move(res));
     }
@@ -3286,7 +3306,7 @@ struct server_context {
 
                 if (slot.n_prompt_tokens > slot.n_ctx) {
                     slot.release();
-                    send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER);
+                    send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
                     continue;
                 }
             } else {
@@ -3296,7 +3316,7 @@
                 // context shift should be applied only during the generation phase
                 if (slot.n_prompt_tokens >= slot.n_ctx) {
                     slot.release();
-                    send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
+                    send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
                     continue;
                 }
             }
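Because both rejection sites now report a machine-readable type, a caller can react to the condition without string-matching the message. A minimal client-side sketch, assuming a server listening on localhost:8080 and the `requests` package; the endpoint path is taken from the test below, everything else is illustrative:

```python
import requests

URL = "http://127.0.0.1:8080/chat/completions"  # assumed local server instance

def chat(messages):
    r = requests.post(URL, json={"messages": messages})
    if r.status_code == 400:
        err = r.json().get("error", {})
        if err.get("type") == "exceed_context_size_error":
            # The structured fields report how the prompt compares to the slot's context.
            print(f"prompt is {err['n_prompt_tokens']} tokens, slot context is {err['n_ctx']}")
            return None
    r.raise_for_status()
    return r.json()
```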
17 changes: 17 additions & 0 deletions tools/server/tests/unit/test_chat_completion.py
@@ -385,3 +385,20 @@ def test_logit_bias():
     output_text = res.choices[0].message.content
     assert output_text
     assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude)
+
+def test_context_size_exceeded():
+    global server
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "messages": [
+            {"role": "system", "content": "Book"},
+            {"role": "user", "content": "What is the best book"},
+        ] * 100,  # make the prompt too long
+    })
+    assert res.status_code == 400
+    assert "error" in res.body
+    assert res.body["error"]["type"] == "exceed_context_size_error"
+    assert res.body["error"]["n_prompt_tokens"] > 0
+    assert server.n_ctx is not None
+    assert server.n_slots is not None
+    assert res.body["error"]["n_ctx"] == server.n_ctx // server.n_slots
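The last assertion encodes the fact that the reported n_ctx is the per-slot share of the total context, not the value passed on the command line. A small worked example of that arithmetic (figures are illustrative):

```python
# Illustrative only: a 4096-token context split across 2 slots gives each
# slot 2048 tokens, which is the n_ctx echoed back in the error body.
total_ctx = 4096
n_slots = 2
per_slot_ctx = total_ctx // n_slots  # same arithmetic as server.n_ctx // server.n_slots
assert per_slot_ctx == 2048
```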