
Commit 62f3bae

65a authored and committed
server: Support multimodal completion prompts in JSON format
- Use server_tokens in more places in server and util.cpp
- Convert most functions that used llama_tokens to server_tokens
- Modify input tokenizer to handle JSON objects as subprompts
- Break out MTMD prompt parsing into utility function
- Support JSON objects with multimodal_data arrays for MTMD prompts along with other existing types
- Add tests
1 parent 9515c61 commit 62f3bae

5 files changed (+258, −127 lines)

tools/server/README.md

Lines changed: 6 additions & 2 deletions
@@ -226,6 +226,7 @@ services:
 ### Multimodal support
 
 Multimodal support was added in [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) and is currently an experimental feature.
+It is currently available in the non-OAI-compatible completion endpoint, and the OAI-compatible chat endpoint.
 
 For more details, please refer to [multimodal documentation](../../docs/multimodal.md)
@@ -400,12 +401,15 @@ These input shapes and data type are allowed for `prompt`:
 - Single string: `"string"`
 - Single sequence of tokens: `[12, 34, 56]`
 - Mixed tokens and strings: `[12, 34, "string", 56, 78]`
+- A JSON object which optionally contains multimodal data: `{ "prompt": "string", "multimodal_data": ["base64"] }`
 
 Multiple prompts are also supported. In this case, the completion result will be an array.
 
 - Only strings: `["string1", "string2"]`
-- Strings and sequences of tokens: `["string1", [12, 34, 56]]`
-- Mixed types: `[[12, 34, "string", 56, 78], [12, 34, 56], "string"]`
+- Strings, JSON objects, and sequences of tokens: `["string1", [12, 34, 56], { "prompt": "string", "multimodal_data": ["base64"] }]`
+- Mixed types: `[[12, 34, "string", 56, 78], [12, 34, 56], "string", { "prompt": "string" }]`
+
+Note on `multimodal_data` in JSON object prompts: it must be an array of strings, each containing base64-encoded multimodal data such as an image or audio file. The string prompt must contain an identical number of MTMD media markers, which act as placeholders for the data provided in this parameter; the multimodal data files are substituted in order. The marker string (e.g. `<__media__>`) can be found by calling `mtmd_default_marker()`, defined in [the MTMD C API](https://github.com/ggml-org/llama.cpp/blob/5fd160bbd9d70b94b5b11b0001fd7f477005e4a0/tools/mtmd/mtmd.h#L87).
 
 `temperature`: Adjust the randomness of the generated text. Default: `0.8`
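For illustration (not part of the diff), here is a minimal client-side sketch of the JSON object prompt described above. It assumes a llama-server instance listening on localhost:8080 with a multimodal model and its mmproj loaded; the file name cat.png is a placeholder.

import base64
import requests

# Sketch only: send a completion request using the new JSON object prompt
# format with an image passed via multimodal_data. Assumes llama-server is
# running on localhost:8080 with a multimodal model loaded.
with open("cat.png", "rb") as f:  # placeholder image path
    img_b64 = base64.b64encode(f.read()).decode("utf-8")

res = requests.post(
    "http://localhost:8080/completion",
    json={
        "prompt": {
            # one <__media__> marker per entry in multimodal_data
            "prompt": "Describe this image: <__media__>",
            "multimodal_data": [img_b64],
        },
        "temperature": 0.0,
    },
)
res.raise_for_status()
print(res.json()["content"])

A single JSON object prompt returns a single completion object; sending an array of prompts returns an array of results instead, as documented above.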

tools/server/server.cpp

Lines changed: 13 additions & 54 deletions
@@ -4228,56 +4228,15 @@ int main(int argc, char ** argv) {
         // TODO: this log can become very long, put it behind a flag or think about a more compact format
         //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
 
-        // process files
-        mtmd::bitmaps bitmaps;
-        const bool has_mtmd = ctx_server.mctx != nullptr;
-        {
-            if (!has_mtmd && !files.empty()) {
-                throw std::runtime_error("This server does not support multimodal");
-            }
-            for (auto & file : files) {
-                mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(ctx_server.mctx, file.data(), file.size()));
-                if (!bmp.ptr) {
-                    throw std::runtime_error("Failed to load image or audio file");
-                }
-                // calculate bitmap hash (for KV caching)
-                std::string hash = fnv_hash(bmp.data(), bmp.n_bytes());
-                bmp.set_id(hash.c_str());
-                bitmaps.entries.push_back(std::move(bmp));
-            }
-        }
-
         // process prompt
         std::vector<server_tokens> inputs;
 
-        if (oaicompat && has_mtmd) {
-            // multimodal
-            std::string prompt_str = prompt.get<std::string>();
-            mtmd_input_text inp_txt = {
-                prompt_str.c_str(),
-                /* add_special */ true,
-                /* parse_special */ true,
-            };
-            mtmd::input_chunks chunks(mtmd_input_chunks_init());
-            auto bitmaps_c_ptr = bitmaps.c_ptr();
-            int32_t tokenized = mtmd_tokenize(ctx_server.mctx,
-                                              chunks.ptr.get(),
-                                              &inp_txt,
-                                              bitmaps_c_ptr.data(),
-                                              bitmaps_c_ptr.size());
-            if (tokenized != 0) {
-                throw std::runtime_error("Failed to tokenize prompt");
-            }
-
-            server_tokens tmp(chunks, true);
-            inputs.push_back(std::move(tmp));
+        if (oaicompat && ctx_server.mctx != nullptr) {
+            // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below.
+            inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files));
         } else {
-            // non-multimodal version
-            auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
-            for (auto & p : tokenized_prompts) {
-                auto tmp = server_tokens(p, ctx_server.mctx != nullptr);
-                inputs.push_back(std::move(tmp));
-            }
+            // Everything else, including multimodal completions.
+            inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
         }
 
         tasks.reserve(inputs.size());
@@ -4369,7 +4328,7 @@ int main(int argc, char ** argv) {
 
     const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
         json data = json::parse(req.body);
-        std::vector<raw_buffer> files; // dummy
+        std::vector<raw_buffer> files; //dummy
         handle_completions_impl(
             SERVER_TASK_TYPE_COMPLETION,
             data,
@@ -4446,7 +4405,7 @@ int main(int argc, char ** argv) {
         data["input_extra"] = input_extra; // default to empty array if it's not exist
 
         std::string prompt = json_value(data, "prompt", std::string());
-        std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, false, true);
+        std::vector<server_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, false, true);
         SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
         data["prompt"] = format_infill(
             ctx_server.vocab,
@@ -4457,7 +4416,7 @@ int main(int argc, char ** argv) {
             ctx_server.params_base.n_predict,
             ctx_server.slots[0].n_ctx, // TODO: there should be a better way
             ctx_server.params_base.spm_infill,
-            tokenized_prompts[0]
+            tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal.
         );
 
         std::vector<raw_buffer> files; // dummy
@@ -4635,7 +4594,7 @@ int main(int argc, char ** argv) {
             }
         }
 
-        auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
+        auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
         for (const auto & tokens : tokenized_prompts) {
             // this check is necessary for models that do not add BOS token to the input
             if (tokens.empty()) {
@@ -4663,7 +4622,7 @@ int main(int argc, char ** argv) {
 
             task.id = ctx_server.queue_tasks.get_new_id();
             task.index = i;
-            task.prompt_tokens = server_tokens(tokenized_prompts[i], ctx_server.mctx != nullptr);
+            task.prompt_tokens = std::move(tokenized_prompts[i]);
 
             // OAI-compat
             task.params.oaicompat = oaicompat;
@@ -4750,22 +4709,22 @@ int main(int argc, char ** argv) {
            return;
        }
 
-        llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.vocab, query, /* add_special */ false, true)[0];
+        server_tokens tokenized_query = std::move(tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, query, /* add_special */ false, true)[0]);
 
         // create and queue the task
         json responses = json::array();
         bool error = false;
         std::unordered_set<int> task_ids;
         {
             std::vector<server_task> tasks;
-            auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
+            auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, documents, /* add_special */ false, true);
             tasks.reserve(tokenized_docs.size());
             for (size_t i = 0; i < tokenized_docs.size(); i++) {
                 auto tmp = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
                 server_task task = server_task(SERVER_TASK_TYPE_RERANK);
                 task.id = ctx_server.queue_tasks.get_new_id();
                 task.index = i;
-                task.prompt_tokens = server_tokens(tmp, ctx_server.mctx != nullptr);
+                task.prompt_tokens = std::move(tmp);
                 tasks.push_back(std::move(task));
             }
 

tools/server/tests/unit/test_completion.py

Lines changed: 35 additions & 0 deletions
@@ -231,6 +231,27 @@ def test_nocache_long_input_prompt():
     })
     assert res.status_code == 200
 
+def test_nocache_json_prompt():
+    global server
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "prompt": { "prompt": "I believe the meaning of life is" },
+        "seed": 42,
+        "temperature": 1.0,
+        "cache_prompt": False,
+    })
+    assert res.status_code == 200
+
+def test_nocache_multimodal_prompt():
+    global server
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "prompt": { "prompt": "I believe the meaning of life is <__media__>", "multimodal_data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=" },
+        "seed": 42,
+        "temperature": 1.0,
+        "cache_prompt": False,
+    })
+    assert res.status_code == 200
 
 def test_completion_with_tokens_input():
     global server
@@ -269,6 +290,20 @@ def test_completion_with_tokens_input():
     assert len(res.body) == 2
     assert res.body[0]["content"] == res.body[1]["content"]
 
+    # mixed multimodal and tokens works. Does not assert equality.
+    res = server.make_request("POST", "/completion", data={
+        "prompt": [
+            tokens,
+            {
+                "prompt": "Here is my photo: <__media__>",
+                "multimodal_data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII="
+            },
+        ],
+    })
+    assert res.status_code == 200
+    assert type(res.body) == list
+    assert len(res.body) == 2
+
     # mixed string and tokens in one sequence
     res = server.make_request("POST", "/completion", data={
         "prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+import pytest
+from utils import *
+import base64
+import requests
+
+server: ServerProcess
+
+IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
+IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
+
+response = requests.get(IMG_URL_0)
+response.raise_for_status() # Raise an exception for bad status codes
+IMG_BASE64_0 = base64.b64encode(response.content).decode("utf-8")
+response = requests.get(IMG_URL_1)
+response.raise_for_status() # Raise an exception for bad status codes
+IMG_BASE64_1 = base64.b64encode(response.content).decode("utf-8")
+
+@pytest.fixture(autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.tinygemma3()
+
+
+@pytest.mark.parametrize(
+    "prompt, image_data, success, re_content",
+    [
+        # test model is trained on CIFAR-10, but it's quite dumb due to small size
+        ("What is this: <__media__>\n", IMG_BASE64_0, True, "(cat)+"), # exceptional, so that we don't cog up the log
+        ("What is this: <__media__>\n", IMG_BASE64_1, True, "(frog)+"),
+        ("What is this: <__media__>\n", "malformed", False, None),
+        ("What is this:\n", "base64", False, None), # non-image data
+    ]
+)
+def test_vision_completion(prompt, image_data, success, re_content):
+    global server
+    server.start(timeout_seconds=60) # vision model may take longer to load due to download size
+    res = server.make_request("POST", "/completions", data={
+        "temperature": 0.0,
+        "top_k": 1,
+        "prompt": { "prompt": prompt, "multimodal_data": [ image_data ] },
+    })
+    if success:
+        assert res.status_code == 200
+        content = res.body["content"]
+        assert match_regex(re_content, content)
+    else:
+        assert res.status_code != 200
