Skip to content

Commit a68d914

Browse files
authored
server: add exceed_context_size_error type (#15780)
* server: add exceed_context_size_error type * change error code to 400
1 parent badb80c commit a68d914

File tree

2 files changed

+45
-8
lines changed

2 files changed

+45
-8
lines changed

tools/server/server.cpp

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ enum error_type {
8686
ERROR_TYPE_PERMISSION,
8787
ERROR_TYPE_UNAVAILABLE, // custom error
8888
ERROR_TYPE_NOT_SUPPORTED, // custom error
89+
ERROR_TYPE_EXCEED_CONTEXT_SIZE, // custom error
8990
};
9091

9192
static bool server_task_type_need_embd(server_task_type task_type) {
@@ -1224,6 +1225,10 @@ static json format_error_response(const std::string & message, const enum error_
12241225
type_str = "unavailable_error";
12251226
code = 503;
12261227
break;
1228+
case ERROR_TYPE_EXCEED_CONTEXT_SIZE:
1229+
type_str = "exceed_context_size_error";
1230+
code = 400;
1231+
break;
12271232
}
12281233
return json {
12291234
{"code", code},
@@ -1237,12 +1242,21 @@ struct server_task_result_error : server_task_result {
12371242
error_type err_type = ERROR_TYPE_SERVER;
12381243
std::string err_msg;
12391244

1245+
// for ERROR_TYPE_EXCEED_CONTEXT_SIZE
1246+
int32_t n_prompt_tokens = 0;
1247+
int32_t n_ctx = 0;
1248+
12401249
virtual bool is_error() override {
12411250
return true;
12421251
}
12431252

12441253
virtual json to_json() override {
1245-
return format_error_response(err_msg, err_type);
1254+
json res = format_error_response(err_msg, err_type);
1255+
if (err_type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) {
1256+
res["n_prompt_tokens"] = n_prompt_tokens;
1257+
res["n_ctx"] = n_ctx;
1258+
}
1259+
return res;
12461260
}
12471261
};
12481262

@@ -2605,16 +2619,22 @@ struct server_context {
26052619
}
26062620

26072621
void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
2608-
send_error(slot.id_task, error, type);
2622+
send_error(slot.id_task, error, type, slot.n_prompt_tokens, slot.n_ctx);
26092623
}
26102624

2611-
void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
2625+
void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) {
26122626
SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str());
26132627

2628+
if (type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) {
2629+
GGML_ASSERT(n_ctx > 0 && n_prompt_tokens > 0);
2630+
}
2631+
26142632
auto res = std::make_unique<server_task_result_error>();
2615-
res->id = id_task;
2616-
res->err_type = type;
2617-
res->err_msg = error;
2633+
res->id = id_task;
2634+
res->err_type = type;
2635+
res->err_msg = error;
2636+
res->n_prompt_tokens = n_prompt_tokens;
2637+
res->n_ctx = n_ctx;
26182638

26192639
queue_results.send(std::move(res));
26202640
}
@@ -3286,7 +3306,7 @@ struct server_context {
32863306

32873307
if (slot.n_prompt_tokens > slot.n_ctx) {
32883308
slot.release();
3289-
send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER);
3309+
send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
32903310
continue;
32913311
}
32923312
} else {
@@ -3296,7 +3316,7 @@ struct server_context {
32963316
// context shift should be applied only during the generation phase
32973317
if (slot.n_prompt_tokens >= slot.n_ctx) {
32983318
slot.release();
3299-
send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
3319+
send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
33003320
continue;
33013321
}
33023322
}

tools/server/tests/unit/test_chat_completion.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,3 +385,20 @@ def test_logit_bias():
385385
output_text = res.choices[0].message.content
386386
assert output_text
387387
assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude)
388+
389+
def test_context_size_exceeded():
390+
global server
391+
server.start()
392+
res = server.make_request("POST", "/chat/completions", data={
393+
"messages": [
394+
{"role": "system", "content": "Book"},
395+
{"role": "user", "content": "What is the best book"},
396+
] * 100, # make the prompt too long
397+
})
398+
assert res.status_code == 400
399+
assert "error" in res.body
400+
assert res.body["error"]["type"] == "exceed_context_size_error"
401+
assert res.body["error"]["n_prompt_tokens"] > 0
402+
assert server.n_ctx is not None
403+
assert server.n_slots is not None
404+
assert res.body["error"]["n_ctx"] == server.n_ctx // server.n_slots

0 commit comments

Comments
 (0)