
Commit 9ca7808

65a65a authored and committed
server: Support multimodal completion prompts in JSON format
- Use server_tokens in more places in server and util.cpp
- Convert most functions that used llama_tokens to server_tokens
- Modify input tokenizer to handle JSON objects as subprompts
- Break out MTMD prompt parsing into utility function
- Support JSON objects with multimodal_data arrays for MTMD prompts along with other existing types
- Add tests
1 parent 9515c61 commit 9ca7808
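As a rough illustration of the new prompt shapes (a sketch only, not part of the commit; the media marker and base64 payload are placeholders), a `/completion` request body can now carry JSON-object subprompts alongside the existing string and token forms:

```python
# Sketch of request bodies accepted by POST /completion after this change.
# The marker string and base64 payload below are placeholders.

# A single JSON-object prompt, optionally carrying base64-encoded media for MTMD models:
json_object_prompt = {
    "prompt": {
        "prompt": "Describe this image: <__media__>",           # one marker per multimodal_data entry
        "multimodal_data": ["<base64-encoded image or audio>"],
    },
}

# A mixed array of subprompts (tokens, strings, JSON objects); the result is an array of completions.
mixed_prompts = {
    "prompt": [
        [12, 34, 56],                        # token IDs
        "a plain string prompt",             # string
        {"prompt": "a JSON object prompt"},  # JSON object without media
    ],
}
```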

4 files changed: +215 −127 lines


tools/server/README.md

Lines changed: 6 additions & 2 deletions
```diff
@@ -226,6 +226,7 @@ services:
 ### Multimodal support
 
 Multimodal support was added in [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) and is currently an experimental feature.
+It is currently available in the non-OAI-compatible completion endpoint and the OAI-compatible chat endpoint.
 
 For more details, please refer to [multimodal documentation](../../docs/multimodal.md)
 
@@ -400,12 +401,15 @@ These input shapes and data type are allowed for `prompt`:
 - Single string: `"string"`
 - Single sequence of tokens: `[12, 34, 56]`
 - Mixed tokens and strings: `[12, 34, "string", 56, 78]`
+- A JSON object which optionally contains multimodal data: `{ "prompt": "string", "multimodal_data": ["base64"] }`
 
 Multiple prompts are also supported. In this case, the completion result will be an array.
 
 - Only strings: `["string1", "string2"]`
-- Strings and sequences of tokens: `["string1", [12, 34, 56]]`
-- Mixed types: `[[12, 34, "string", 56, 78], [12, 34, 56], "string"]`
+- Strings, JSON objects, and sequences of tokens: `["string1", [12, 34, 56], { "prompt": "string", "multimodal_data": ["base64"] }]`
+- Mixed types: `[[12, 34, "string", 56, 78], [12, 34, 56], "string", { "prompt": "string" }]`
+
+Note on `multimodal_data` in JSON object prompts: it must be an array of strings containing base64-encoded multimodal data such as images or audio. The string prompt must contain an identical number of MTMD media markers, which act as placeholders for the data provided in this parameter; the media files are substituted in order. The marker string (e.g. `<__media__>`) can be obtained by calling `mtmd_default_marker()`, defined in [the MTMD C API](https://github.com/ggml-org/llama.cpp/blob/5fd160bbd9d70b94b5b11b0001fd7f477005e4a0/tools/mtmd/mtmd.h#L87).
 
 `temperature`: Adjust the randomness of the generated text. Default: `0.8`
 
```
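To make the documented format concrete, here is a minimal client sketch (not part of the commit) that posts a JSON-object prompt with one image to the completion endpoint. It assumes a server built with multimodal support listening on 127.0.0.1:8080, the `requests` package, a local `photo.png`, and the default `<__media__>` marker; all of these are illustrative assumptions.

```python
import base64

import requests  # assumed HTTP client; any other client works the same way

# Encode the media file as base64 text, as expected by multimodal_data.
with open("photo.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("ascii")

payload = {
    "prompt": {
        # Exactly one <__media__> marker per multimodal_data entry, substituted in order.
        "prompt": "Here is my photo: <__media__> Please describe it.",
        "multimodal_data": [image_b64],
    },
    "n_predict": 64,
}

res = requests.post("http://127.0.0.1:8080/completion", json=payload)
res.raise_for_status()
print(res.json()["content"])
```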

tools/server/server.cpp

Lines changed: 18 additions & 54 deletions
```diff
@@ -4228,56 +4228,15 @@ int main(int argc, char ** argv) {
         // TODO: this log can become very long, put it behind a flag or think about a more compact format
         //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
 
-        // process files
-        mtmd::bitmaps bitmaps;
-        const bool has_mtmd = ctx_server.mctx != nullptr;
-        {
-            if (!has_mtmd && !files.empty()) {
-                throw std::runtime_error("This server does not support multimodal");
-            }
-            for (auto & file : files) {
-                mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(ctx_server.mctx, file.data(), file.size()));
-                if (!bmp.ptr) {
-                    throw std::runtime_error("Failed to load image or audio file");
-                }
-                // calculate bitmap hash (for KV caching)
-                std::string hash = fnv_hash(bmp.data(), bmp.n_bytes());
-                bmp.set_id(hash.c_str());
-                bitmaps.entries.push_back(std::move(bmp));
-            }
-        }
-
         // process prompt
         std::vector<server_tokens> inputs;
 
-        if (oaicompat && has_mtmd) {
-            // multimodal
-            std::string prompt_str = prompt.get<std::string>();
-            mtmd_input_text inp_txt = {
-                prompt_str.c_str(),
-                /* add_special */   true,
-                /* parse_special */ true,
-            };
-            mtmd::input_chunks chunks(mtmd_input_chunks_init());
-            auto bitmaps_c_ptr = bitmaps.c_ptr();
-            int32_t tokenized = mtmd_tokenize(ctx_server.mctx,
-                                              chunks.ptr.get(),
-                                              &inp_txt,
-                                              bitmaps_c_ptr.data(),
-                                              bitmaps_c_ptr.size());
-            if (tokenized != 0) {
-                throw std::runtime_error("Failed to tokenize prompt");
-            }
-
-            server_tokens tmp(chunks, true);
-            inputs.push_back(std::move(tmp));
+        if (oaicompat && ctx_server.mctx != nullptr) {
+            // This is the case used by the OAI-compatible chat path with MTMD. TODO: it can be moved to the path below.
+            inputs.push_back(std::move(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files)));
         } else {
-            // non-multimodal version
-            auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
-            for (auto & p : tokenized_prompts) {
-                auto tmp = server_tokens(p, ctx_server.mctx != nullptr);
-                inputs.push_back(std::move(tmp));
-            }
+            // Everything else, including multimodal completions.
+            inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
         }
 
         tasks.reserve(inputs.size());
@@ -4369,7 +4328,12 @@ int main(int argc, char ** argv) {
 
     const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
         json data = json::parse(req.body);
-        std::vector<raw_buffer> files; // dummy
+        std::vector<raw_buffer> files;
+        if (data.find("multimodal_data") != data.end()) {
+            for (const auto& entry : data.at("multimodal_data")) {
+                files.push_back(base64_decode(entry));
+            }
+        }
         handle_completions_impl(
             SERVER_TASK_TYPE_COMPLETION,
             data,
@@ -4446,7 +4410,7 @@ int main(int argc, char ** argv) {
         data["input_extra"] = input_extra; // default to empty array if it's not exist
 
         std::string prompt = json_value(data, "prompt", std::string());
-        std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, false, true);
+        std::vector<server_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, false, true);
         SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
         data["prompt"] = format_infill(
             ctx_server.vocab,
@@ -4457,7 +4421,7 @@ int main(int argc, char ** argv) {
             ctx_server.params_base.n_predict,
             ctx_server.slots[0].n_ctx, // TODO: there should be a better way
             ctx_server.params_base.spm_infill,
-            tokenized_prompts[0]
+            tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal.
         );
 
         std::vector<raw_buffer> files; // dummy
@@ -4635,7 +4599,7 @@ int main(int argc, char ** argv) {
            }
         }
 
-        auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
+        auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
         for (const auto & tokens : tokenized_prompts) {
             // this check is necessary for models that do not add BOS token to the input
             if (tokens.empty()) {
@@ -4663,7 +4627,7 @@ int main(int argc, char ** argv) {
 
             task.id = ctx_server.queue_tasks.get_new_id();
             task.index = i;
-            task.prompt_tokens = server_tokens(tokenized_prompts[i], ctx_server.mctx != nullptr);
+            task.prompt_tokens = std::move(tokenized_prompts[i]);
 
             // OAI-compat
             task.params.oaicompat = oaicompat;
@@ -4750,22 +4714,22 @@ int main(int argc, char ** argv) {
             return;
         }
 
-        llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.vocab, query, /* add_special */ false, true)[0];
+        server_tokens tokenized_query = std::move(tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, query, /* add_special */ false, true)[0]);
 
         // create and queue the task
         json responses = json::array();
         bool error = false;
         std::unordered_set<int> task_ids;
         {
            std::vector<server_task> tasks;
-            auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
+            auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, documents, /* add_special */ false, true);
            tasks.reserve(tokenized_docs.size());
            for (size_t i = 0; i < tokenized_docs.size(); i++) {
                auto tmp = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
                server_task task = server_task(SERVER_TASK_TYPE_RERANK);
                task.id = ctx_server.queue_tasks.get_new_id();
                task.index = i;
-                task.prompt_tokens = server_tokens(tmp, ctx_server.mctx != nullptr);
+                task.prompt_tokens = std::move(tmp);
                tasks.push_back(std::move(task));
            }
 
```
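As a companion to the server-side decoding above, the following client-side helper is a hedged sketch (not part of this commit) that builds a JSON-object subprompt from local media files and enforces the documented rule that the prompt must contain one media marker per `multimodal_data` entry, with files substituted in order. The marker value is assumed to be the default returned by `mtmd_default_marker()`.

```python
import base64
from typing import List

MEDIA_MARKER = "<__media__>"  # assumed default marker from mtmd_default_marker()

def build_multimodal_subprompt(text: str, media_paths: List[str]) -> dict:
    """Build a JSON-object subprompt for /completion from text plus media files.

    Mirrors the documented contract on the client side: the number of markers
    in the text must equal the number of multimodal_data entries, and the
    files are substituted in the order they are listed.
    """
    n_markers = text.count(MEDIA_MARKER)
    if n_markers != len(media_paths):
        raise ValueError(
            f"prompt contains {n_markers} markers but {len(media_paths)} media files were provided"
        )

    multimodal_data = []
    for path in media_paths:
        with open(path, "rb") as f:
            multimodal_data.append(base64.b64encode(f.read()).decode("ascii"))

    return {"prompt": text, "multimodal_data": multimodal_data}

# Example: usable on its own as the prompt, or as one element of a mixed prompt array.
subprompt = build_multimodal_subprompt("Compare <__media__> and <__media__>.", ["a.png", "b.png"])
```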

tools/server/tests/unit/test_completion.py

Lines changed: 35 additions & 0 deletions
```diff
@@ -231,6 +231,27 @@ def test_nocache_long_input_prompt():
     })
     assert res.status_code == 200
 
+def test_nocache_json_prompt():
+    global server
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "prompt": { "prompt": "I believe the meaning of life is" },
+        "seed": 42,
+        "temperature": 1.0,
+        "cache_prompt": False,
+    })
+    assert res.status_code == 200
+
+def test_nocache_multimodal_prompt():
+    global server
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "prompt": { "prompt": "I believe the meaning of life is <__media__>", "multimodal_data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=" },
+        "seed": 42,
+        "temperature": 1.0,
+        "cache_prompt": False,
+    })
+    assert res.status_code == 200
 
 def test_completion_with_tokens_input():
     global server
@@ -269,6 +290,20 @@ def test_completion_with_tokens_input():
     assert len(res.body) == 2
     assert res.body[0]["content"] == res.body[1]["content"]
 
+    # mixed multimodal and tokens works. Does not assert equality.
+    res = server.make_request("POST", "/completion", data={
+        "prompt": [
+            tokens,
+            {
+                "prompt": "Here is my photo: <__media__>",
+                "multimodal_data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII="
+            },
+        ],
+    })
+    assert res.status_code == 200
+    assert type(res.body) == list
+    assert len(res.body) == 2
+
     # mixed string and tokens in one sequence
     res = server.make_request("POST", "/completion", data={
         "prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
```
