From 183029d88f4ab8c3d0747125e5fdf99488ff3254 Mon Sep 17 00:00:00 2001 From: Mason M Date: Fri, 31 Jan 2025 14:37:23 -0400 Subject: [PATCH 01/69] Add tools option to llama-cli --- common/arg.cpp | 9 +++++++++ common/common.cpp | 27 ++++++++++++++++++++++----- common/common.h | 8 +++++--- examples/main/main.cpp | 19 +++++++++++++------ 4 files changed, 49 insertions(+), 14 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index f5e9b294f3048..b916dcfccbf93 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1993,6 +1993,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex std::back_inserter(params.chat_template)); } ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE")); + add_opt(common_arg( + {"--tools"}, "JINJA_TOOLS", + string_format( + "set to a JSON array of tool definitions used for assistant function-calling " + "(requires --jinja)"), + [](common_params ¶ms, const std::string & value) { + params.jinja_tools = value; + } + ).set_examples({LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA_TOOLS")); add_opt(common_arg( {"-sps", "--slot-prompt-similarity"}, "SIMILARITY", string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity), diff --git a/common/common.cpp b/common/common.cpp index 6c81d18f91c43..de4a529050a17 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1797,13 +1797,27 @@ std::string common_chat_apply_template( const common_chat_template & tmpl, const std::vector & msgs, bool add_ass, - bool use_jinja) { + bool use_jinja, + std::string tools_json_arr) +{ if (use_jinja) { + common_chat_inputs inputs; + auto messages = json::array(); for (const auto & msg : msgs) { messages.push_back({{"role", msg.role}, {"content", msg.content}}); } - common_chat_inputs inputs; + + if (! tools_json_arr.empty()) { + try { + inputs.tools = tools_json_arr; + + } catch (const json::exception & err) { + LOG_WRN("Failed to parse tools JSON array \"%s\": \"%s\". Ignoring tools...\n", + tools_json_arr.c_str(), err.what()); + } + } + inputs.messages = messages; inputs.add_generation_prompt = add_ass; return common_chat_params_init(tmpl, inputs).prompt; @@ -1843,9 +1857,13 @@ std::string common_chat_format_single( const std::vector & past_msg, const common_chat_msg & new_msg, bool add_ass, - bool use_jinja) { + bool use_jinja, + std::string tools_json_arr) +{ std::ostringstream ss; - auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja); + auto fmt_past_msg = past_msg.empty() ? "" + : common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools_json_arr); + std::vector chat_new(past_msg); // if the past_msg ends with a newline, we must preserve it in the formatted version if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { @@ -2182,4 +2200,3 @@ common_control_vector_data common_control_vector_load(const std::vector api_keys; std::string ssl_file_key = ""; // NOLINT @@ -645,7 +645,8 @@ std::string common_chat_apply_template( const common_chat_template & tmpl, const std::vector & chat, bool add_ass, - bool use_jinja); + bool use_jinja, + std::string tools_json_arr = std::string()); // Format single message, while taking into account the position of that message in chat history std::string common_chat_format_single( @@ -653,7 +654,8 @@ std::string common_chat_format_single( const std::vector & past_msg, const common_chat_msg & new_msg, bool add_ass, - bool use_jinja); + bool use_jinja, + std::string tools_json_arr = std::string()); // Returns an example of formatted chat std::string common_chat_format_example( diff --git a/examples/main/main.cpp b/examples/main/main.cpp index e654d3542c6c3..dd0f65bfe282d 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -263,20 +263,27 @@ int main(int argc, char ** argv) { std::vector embd_inp; - auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) { + auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, + const std::string & content, + const std::string & tools = std::string()) { common_chat_msg new_msg{role, content, {}}; - auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja); + auto formatted = common_chat_format_single(*chat_templates.template_default, + chat_msgs, new_msg, + role == "user", + g_params->use_jinja, tools); + chat_msgs.push_back({role, content, {}}); LOG_DBG("formatted: '%s'\n", formatted.c_str()); return formatted; }; { - auto prompt = (params.conversation_mode && params.enable_chat_template) - // format the system prompt in conversation mode (fallback to default if empty) - ? chat_add_and_format("system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt) - // otherwise use the prompt as is + std::string system_prompt (params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt); + bool use_conversation_prompt (params.conversation_mode && params.enable_chat_template); + auto prompt = use_conversation_prompt ? + chat_add_and_format("system", system_prompt, params.jinja_tools) : params.prompt; + if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) { LOG_DBG("tokenize the prompt\n"); embd_inp = common_tokenize(ctx, prompt, true, true); From 4ad82587eeccd72206845589c2056993f8b8ad51 Mon Sep 17 00:00:00 2001 From: Mason M Date: Fri, 31 Jan 2025 17:57:00 -0400 Subject: [PATCH 02/69] tools_json_arr now properly passed to apply-template --- common/common.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/common.cpp b/common/common.cpp index de4a529050a17..e546f78c21da8 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1871,7 +1871,7 @@ std::string common_chat_format_single( }; // format chat with new_msg chat_new.push_back(new_msg); - auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja); + auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools_json_arr); // get the diff part ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); return ss.str(); From becf9b4003809ba8e9e8d8b1f2daab665e218953 Mon Sep 17 00:00:00 2001 From: Mason M Date: Tue, 4 Feb 2025 11:40:29 -0400 Subject: [PATCH 03/69] add tool-choice parameter --- common/arg.cpp | 20 +++++++++---- common/common.cpp | 65 ++++++++++++++++++++++++++++++++---------- common/common.h | 35 +++++++++++++++++++++-- examples/main/main.cpp | 14 ++++----- 4 files changed, 103 insertions(+), 31 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index b916dcfccbf93..dd0cc51d9afb8 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1993,15 +1993,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex std::back_inserter(params.chat_template)); } ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE")); + add_opt(common_arg( {"--tools"}, "JINJA_TOOLS", - string_format( - "set to a JSON array of tool definitions used for assistant function-calling " - "(requires --jinja)"), + "set to JSON array of tool definitions used for assistant function-calling (requires --jinja)", [](common_params ¶ms, const std::string & value) { - params.jinja_tools = value; - } - ).set_examples({LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA_TOOLS")); + params.jinja_tools.tools(value); + + }).set_examples({LLAMA_EXAMPLE_MAIN})); + + add_opt(common_arg( + {"--tool-choice"}, "JINJA_TOOL_CHOICE", + "set to \"auto\", \"required\", \"none\" or a JSON object specifying a tool function (default: \"auto\")", + [](common_params ¶ms, const std::string & value) { + params.jinja_tools.choice(value); + + }).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( {"-sps", "--slot-prompt-similarity"}, "SIMILARITY", string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity), diff --git a/common/common.cpp b/common/common.cpp index e546f78c21da8..d9657bc859ca0 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -7,9 +7,6 @@ #include "common.h" #include "log.h" -// Change JSON_ASSERT from assert() to GGML_ASSERT: -#define JSON_ASSERT GGML_ASSERT -#include "json.hpp" #include "json-schema-to-grammar.h" #include "llama.h" #include "chat.hpp" @@ -1772,6 +1769,42 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto // Chat template utils // +common_params_tools::common_params_tools(std::string tools, std::string choice) { + this->tools(tools); + this->choice(choice); +} + +void common_params_tools::tools(std::string tools) { + try { + tools_ = std::make_shared(json::parse(tools)); + if (! tools_->is_array()) { + throw std::invalid_argument("tools must be a valid JSON array"); + } + + } catch (const json::exception & err) { + throw std::invalid_argument(err.what()); + } +} + +void common_params_tools::choice(std::string choice) { + try { + if (choice == "auto" || choice == "required" || choice == "none") { + tool_choice_ = std::move(choice); + + } else { + auto choice_ptr = std::make_shared(json::parse(choice)); + tool_choice_ = choice_ptr; + if (! choice_ptr->is_object()) { + throw std::invalid_argument( + "tool choice must be a valid JSON object, \"auto\", \"required\", or \"none\""); + } + } + + } catch (const json::exception & err) { + throw std::invalid_argument(err.what()); + } +} + bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { if (use_jinja) { try { @@ -1798,7 +1831,7 @@ std::string common_chat_apply_template( const std::vector & msgs, bool add_ass, bool use_jinja, - std::string tools_json_arr) + const common_params_tools & tools) { if (use_jinja) { common_chat_inputs inputs; @@ -1807,17 +1840,19 @@ std::string common_chat_apply_template( for (const auto & msg : msgs) { messages.push_back({{"role", msg.role}, {"content", msg.content}}); } + if (tools.tools() != nullptr) { + inputs.tools = *tools.tools(); + } + auto choice = tools.choice(); + if (std::holds_alternative(choice)) { + inputs.tool_choice = std::get(choice); - if (! tools_json_arr.empty()) { - try { - inputs.tools = tools_json_arr; - - } catch (const json::exception & err) { - LOG_WRN("Failed to parse tools JSON array \"%s\": \"%s\". Ignoring tools...\n", - tools_json_arr.c_str(), err.what()); + } else { + auto choice_ptr = std::get(choice); + if (choice_ptr != nullptr) { + inputs.tool_choice = *choice_ptr; } } - inputs.messages = messages; inputs.add_generation_prompt = add_ass; return common_chat_params_init(tmpl, inputs).prompt; @@ -1858,11 +1893,11 @@ std::string common_chat_format_single( const common_chat_msg & new_msg, bool add_ass, bool use_jinja, - std::string tools_json_arr) + const common_params_tools & tools) { std::ostringstream ss; auto fmt_past_msg = past_msg.empty() ? "" - : common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools_json_arr); + : common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools); std::vector chat_new(past_msg); // if the past_msg ends with a newline, we must preserve it in the formatted version @@ -1871,7 +1906,7 @@ std::string common_chat_format_single( }; // format chat with new_msg chat_new.push_back(new_msg); - auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools_json_arr); + auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools); // get the diff part ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); return ss.str(); diff --git a/common/common.h b/common/common.h index 5058d654c9d91..a1925d7747da9 100644 --- a/common/common.h +++ b/common/common.h @@ -8,6 +8,9 @@ #include #include #include +// Change JSON_ASSERT from assert() to GGML_ASSERT: +#define JSON_ASSERT GGML_ASSERT +#include "json.hpp" #ifdef _WIN32 #define DIRECTORY_SEPARATOR '\\' @@ -202,6 +205,31 @@ struct common_params_vocoder { bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT }; +class common_params_tools { +public: + using json = nlohmann::ordered_json; + using json_ptr = std::shared_ptr; + using tool_choice_t = std::variant; + + common_params_tools(std::string tools = "", + std::string choice = "auto"); + + common_params_tools(const common_params_tools & other) = default; + common_params_tools(common_params_tools && other) noexcept = default; + common_params_tools & operator=(const common_params_tools & other) = default; + common_params_tools & operator=(common_params_tools && other) noexcept = default; + + void tools(std::string tools); + const json * tools() const { return tools_.get(); } + + void choice(std::string choice); + const tool_choice_t & choice() const { return tool_choice_; } + +private: + json_ptr tools_; + tool_choice_t tool_choice_; +}; + struct common_params { int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 4096; // context size @@ -346,7 +374,8 @@ struct common_params { std::string chat_template = ""; // NOLINT bool use_jinja = false; // NOLINT bool enable_chat_template = true; - std::string jinja_tools = ""; + common_params_tools jinja_tools; + std::vector api_keys; std::string ssl_file_key = ""; // NOLINT @@ -649,7 +678,7 @@ std::string common_chat_apply_template( const std::vector & chat, bool add_ass, bool use_jinja, - std::string tools_json_arr = std::string()); + const common_params_tools & tools = common_params_tools()); // Format single message, while taking into account the position of that message in chat history std::string common_chat_format_single( @@ -658,7 +687,7 @@ std::string common_chat_format_single( const common_chat_msg & new_msg, bool add_ass, bool use_jinja, - std::string tools_json_arr = std::string()); + const common_params_tools & tools = common_params_tools()); // Returns an example of formatted chat std::string common_chat_format_example( diff --git a/examples/main/main.cpp b/examples/main/main.cpp index dd0f65bfe282d..d19a24331b6f8 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -263,14 +263,14 @@ int main(int argc, char ** argv) { std::vector embd_inp; - auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, - const std::string & content, - const std::string & tools = std::string()) { + auto chat_add_and_format = [&chat_msgs, &chat_templates]( + const std::string & role, const std::string & content, + const common_params_tools & tools = common_params_tools()) + { common_chat_msg new_msg{role, content, {}}; - auto formatted = common_chat_format_single(*chat_templates.template_default, - chat_msgs, new_msg, - role == "user", - g_params->use_jinja, tools); + auto formatted = common_chat_format_single( + *chat_templates.template_default, chat_msgs, new_msg, role == "user", + g_params->use_jinja, tools); chat_msgs.push_back({role, content, {}}); LOG_DBG("formatted: '%s'\n", formatted.c_str()); From cd16957f7d938bd5a334d2c23e4d1a7e604436cf Mon Sep 17 00:00:00 2001 From: Mason M Date: Tue, 4 Feb 2025 11:54:48 -0400 Subject: [PATCH 04/69] Add variant include --- common/common.h | 1 + 1 file changed, 1 insertion(+) diff --git a/common/common.h b/common/common.h index a1925d7747da9..417eb546cc739 100644 --- a/common/common.h +++ b/common/common.h @@ -8,6 +8,7 @@ #include #include #include +#include // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT #include "json.hpp" From 4e8beb0c533949e7433375444344f3777deae1ed Mon Sep 17 00:00:00 2001 From: Mason M Date: Tue, 4 Feb 2025 13:21:28 -0400 Subject: [PATCH 05/69] Reset tools when empty string provided --- common/common.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/common/common.cpp b/common/common.cpp index d9657bc859ca0..de2adba1fd999 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1775,6 +1775,10 @@ common_params_tools::common_params_tools(std::string tools, std::string choice) } void common_params_tools::tools(std::string tools) { + if (tools.empty()) { + tools_.reset(); + return; + } try { tools_ = std::make_shared(json::parse(tools)); if (! tools_->is_array()) { From 34370803fbbb84a25ca1763831aba1e45d7f3dfd Mon Sep 17 00:00:00 2001 From: Mason M Date: Tue, 4 Feb 2025 15:00:58 -0400 Subject: [PATCH 06/69] Pass template group to common_chat_apply_template --- common/common.cpp | 18 ++++++++++++------ common/common.h | 6 +++--- examples/main/main.cpp | 8 ++++---- examples/server/server.cpp | 2 +- examples/server/utils.hpp | 4 ++-- tests/test-chat-template.cpp | 6 ++++-- 6 files changed, 26 insertions(+), 18 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index de2adba1fd999..8d6e8b0cb18d3 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1831,12 +1831,15 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { } std::string common_chat_apply_template( - const common_chat_template & tmpl, + const common_chat_templates & tmpl, const std::vector & msgs, bool add_ass, bool use_jinja, const common_params_tools & tools) { + const auto & tmpl_selected = + tools.tools() && tmpl.template_tool_use ? *tmpl.template_tool_use : *tmpl.template_default; + if (use_jinja) { common_chat_inputs inputs; @@ -1844,9 +1847,11 @@ std::string common_chat_apply_template( for (const auto & msg : msgs) { messages.push_back({{"role", msg.role}, {"content", msg.content}}); } + if (tools.tools() != nullptr) { inputs.tools = *tools.tools(); } + auto choice = tools.choice(); if (std::holds_alternative(choice)) { inputs.tool_choice = std::get(choice); @@ -1857,9 +1862,10 @@ std::string common_chat_apply_template( inputs.tool_choice = *choice_ptr; } } + inputs.messages = messages; inputs.add_generation_prompt = add_ass; - return common_chat_params_init(tmpl, inputs).prompt; + return common_chat_params_init(tmpl_selected, inputs).prompt; } int alloc_size = 0; @@ -1872,7 +1878,7 @@ std::string common_chat_apply_template( std::vector buf(alloc_size); // run the first time to get the total output length - int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + int32_t res = llama_chat_apply_template(tmpl_selected.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size()); // error: chat template is not supported if (res < 0) { @@ -1884,7 +1890,7 @@ std::string common_chat_apply_template( // if it turns out that our buffer is too small, we resize it if ((size_t) res > buf.size()) { buf.resize(res); - res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + res = llama_chat_apply_template(tmpl_selected.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size()); } std::string formatted_chat(buf.data(), res); @@ -1892,7 +1898,7 @@ std::string common_chat_apply_template( } std::string common_chat_format_single( - const common_chat_template & tmpl, + const common_chat_templates & tmpl, const std::vector & past_msg, const common_chat_msg & new_msg, bool add_ass, @@ -1916,7 +1922,7 @@ std::string common_chat_format_single( return ss.str(); } -std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) { +std::string common_chat_format_example(const common_chat_templates & tmpl, bool use_jinja) { std::vector msgs = { {"system", "You are a helpful assistant", {}}, {"user", "Hello", {}}, diff --git a/common/common.h b/common/common.h index 417eb546cc739..c3f4f66ee1f63 100644 --- a/common/common.h +++ b/common/common.h @@ -675,7 +675,7 @@ struct common_chat_templates { // If the built-in template is not supported, we default to chatml // If the custom "tmpl" is not supported, we throw an error std::string common_chat_apply_template( - const common_chat_template & tmpl, + const common_chat_templates & tmpl, const std::vector & chat, bool add_ass, bool use_jinja, @@ -683,7 +683,7 @@ std::string common_chat_apply_template( // Format single message, while taking into account the position of that message in chat history std::string common_chat_format_single( - const common_chat_template & tmpl, + const common_chat_templates & tmpl, const std::vector & past_msg, const common_chat_msg & new_msg, bool add_ass, @@ -692,7 +692,7 @@ std::string common_chat_format_single( // Returns an example of formatted chat std::string common_chat_format_example( - const common_chat_template & tmpl, bool use_jinja); + const common_chat_templates & tmpl, bool use_jinja); common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index d19a24331b6f8..d562522a4363b 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -219,7 +219,7 @@ int main(int argc, char ** argv) { // print chat template example in conversation mode if (params.conversation_mode) { if (params.enable_chat_template) { - LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str()); + LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates, params.use_jinja).c_str()); } else { LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__); } @@ -268,9 +268,9 @@ int main(int argc, char ** argv) { const common_params_tools & tools = common_params_tools()) { common_chat_msg new_msg{role, content, {}}; - auto formatted = common_chat_format_single( - *chat_templates.template_default, chat_msgs, new_msg, role == "user", - g_params->use_jinja, tools); + + auto formatted = common_chat_format_single(chat_templates, chat_msgs, + new_msg, role == "user", g_params->use_jinja, tools); chat_msgs.push_back({role, content, {}}); LOG_DBG("formatted: '%s'\n", formatted.c_str()); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index e0acc47059656..3ceba0558548c 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -4468,7 +4468,7 @@ int main(int argc, char ** argv) { // print sample chat example to make it clear which template is used LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, ctx_server.chat_templates.template_default->source().c_str(), - common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str()); + common_chat_format_example(ctx_server.chat_templates, ctx_server.params_base.use_jinja).c_str()); ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) { ctx_server.process_single_task(task); diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 5f97df5fde639..f1b4ee5b593f4 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -348,7 +348,7 @@ static llama_tokens format_infill( } // Format given chat. If tmpl is empty, we take the template from model metadata -inline std::string format_chat(const common_chat_template & tmpl, const std::vector & messages) { +inline std::string format_chat(const common_chat_templates & tmpl, const std::vector & messages) { std::vector chat; for (size_t i = 0; i < messages.size(); ++i) { @@ -663,7 +663,7 @@ static json oaicompat_completion_params_parse( llama_params["stop"].push_back(stop); } } else { - llama_params["prompt"] = format_chat(tmpl, body.at("messages")); + llama_params["prompt"] = format_chat(chat_templates, body.at("messages")); } // Handle "n" field diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index e0314ae1d6296..022205b7b2009 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -339,7 +339,8 @@ int main(void) { common_chat_msg sys_msg{"system", "You are a helpful assistant", {}}; auto fmt_sys = [&](std::string tmpl_str) { - minja::chat_template tmpl(tmpl_str, "", ""); + common_chat_templates tmpl; + tmpl.template_default.reset(new common_chat_template(tmpl_str, "", "")); auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false); printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str()); printf("-------------------------\n"); @@ -366,7 +367,8 @@ int main(void) { common_chat_msg new_msg{"user", "How are you", {}}; auto fmt_single = [&](std::string tmpl_str) { - minja::chat_template tmpl(tmpl_str, "", ""); + common_chat_templates tmpl; + tmpl.template_default.reset(new common_chat_template(tmpl_str, "", "")); auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false); printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str()); printf("-------------------------\n"); From a726adaef7be99eafbfc873420e337d6f9f79b1d Mon Sep 17 00:00:00 2001 From: Mason M Date: Wed, 5 Feb 2025 12:07:22 -0400 Subject: [PATCH 07/69] Copy sampler parameters from chat template --- common/common.cpp | 54 ++++++++++++++++++++++++++++++++++++++---- common/common.h | 11 +++++++-- examples/main/main.cpp | 10 +++++--- 3 files changed, 65 insertions(+), 10 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 15d0258469963..13e03e88f5246 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1830,12 +1830,51 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { return res >= 0; } +static void copy_chat_params(const common_chat_params & src, common_chat_sampling_updater * update_sparams) +{ + GGML_ASSERT(update_sparams && update_sparams->sparams && update_sparams->vocab); + + auto & dst = *update_sparams->sparams; + auto vocab = update_sparams->vocab; + + dst.grammar = src.grammar; + dst.grammar_lazy = src.grammar_lazy; + + for (const auto & trigger : src.grammar_triggers) { + auto ids = common_tokenize(vocab, trigger.word, false, true); + + if (ids.size() == 1) { + LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str()); + dst.grammar_trigger_tokens.push_back(ids[0]); + dst.preserved_tokens.insert(ids[0]); + continue; + } + LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str()); + dst.grammar_trigger_words.push_back(trigger); + } + + for (const auto & preserved : src.preserved_tokens) { + auto ids = common_tokenize(vocab, preserved, false, true); + if (ids.size() == 1) { + LOG_DBG("Preserved token: %d\n", ids[0]); + dst.preserved_tokens.insert(ids[0]); + + } else { + // This may happen when using a tool call style meant for a model + // with special tokens to preserve on a model without said tokens. + LOG_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", + preserved.c_str()); + } + } +} + std::string common_chat_apply_template( const common_chat_templates & tmpl, const std::vector & msgs, bool add_ass, bool use_jinja, - const common_params_tools & tools) + const common_params_tools & tools, + common_chat_sampling_updater * update_sparams) { const auto & tmpl_selected = tools.tools() && tmpl.template_tool_use ? *tmpl.template_tool_use : *tmpl.template_default; @@ -1865,7 +1904,11 @@ std::string common_chat_apply_template( inputs.messages = messages; inputs.add_generation_prompt = add_ass; - return common_chat_params_init(tmpl_selected, inputs).prompt; + auto chat_params = common_chat_params_init(tmpl_selected, inputs); + if (update_sparams) { + copy_chat_params(chat_params, update_sparams); + } + return chat_params.prompt; } int alloc_size = 0; @@ -1903,11 +1946,12 @@ std::string common_chat_format_single( const common_chat_msg & new_msg, bool add_ass, bool use_jinja, - const common_params_tools & tools) + const common_params_tools & tools, + common_chat_sampling_updater * update_sparams) { std::ostringstream ss; auto fmt_past_msg = past_msg.empty() ? "" - : common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools); + : common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools, update_sparams); std::vector chat_new(past_msg); // if the past_msg ends with a newline, we must preserve it in the formatted version @@ -1916,7 +1960,7 @@ std::string common_chat_format_single( }; // format chat with new_msg chat_new.push_back(new_msg); - auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools); + auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools, update_sparams); // get the diff part ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); return ss.str(); diff --git a/common/common.h b/common/common.h index c3f4f66ee1f63..f7bdb4e69cd5c 100644 --- a/common/common.h +++ b/common/common.h @@ -671,6 +671,11 @@ struct common_chat_templates { std::unique_ptr template_tool_use; }; +struct common_chat_sampling_updater { + common_params_sampling * sparams; + const llama_vocab * vocab; +}; + // CPP wrapper for llama_chat_apply_template // If the built-in template is not supported, we default to chatml // If the custom "tmpl" is not supported, we throw an error @@ -679,7 +684,8 @@ std::string common_chat_apply_template( const std::vector & chat, bool add_ass, bool use_jinja, - const common_params_tools & tools = common_params_tools()); + const common_params_tools & tools = common_params_tools(), + common_chat_sampling_updater * update_sparams = nullptr); // Format single message, while taking into account the position of that message in chat history std::string common_chat_format_single( @@ -688,7 +694,8 @@ std::string common_chat_format_single( const common_chat_msg & new_msg, bool add_ass, bool use_jinja, - const common_params_tools & tools = common_params_tools()); + const common_params_tools & tools = common_params_tools(), + common_chat_sampling_updater * update_sparams = nullptr); // Returns an example of formatted chat std::string common_chat_format_example( diff --git a/examples/main/main.cpp b/examples/main/main.cpp index d562522a4363b..0ea614bcd5489 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -263,14 +263,18 @@ int main(int argc, char ** argv) { std::vector embd_inp; - auto chat_add_and_format = [&chat_msgs, &chat_templates]( + auto chat_add_and_format = [&chat_msgs, &chat_templates, &sparams, vocab]( const std::string & role, const std::string & content, const common_params_tools & tools = common_params_tools()) { + bool add_ass = (role == "user"); + common_chat_msg new_msg{role, content, {}}; - auto formatted = common_chat_format_single(chat_templates, chat_msgs, - new_msg, role == "user", g_params->use_jinja, tools); + common_chat_sampling_updater updater{&sparams, vocab}; + auto formatted = + common_chat_format_single(chat_templates, chat_msgs, new_msg, add_ass, g_params->use_jinja, + tools, &updater); chat_msgs.push_back({role, content, {}}); LOG_DBG("formatted: '%s'\n", formatted.c_str()); From 1dd2e3be301b75739e9ed20d6f6db7331b7d3bf3 Mon Sep 17 00:00:00 2001 From: Mason M Date: Thu, 13 Feb 2025 15:54:21 -0400 Subject: [PATCH 08/69] Add handler and MCP message types --- common/CMakeLists.txt | 4 + common/common.cpp | 96 ++++------- common/common.h | 50 ++---- common/toolcall/handler.cpp | 110 +++++++++++++ common/toolcall/handler.hpp | 103 ++++++++++++ common/toolcall/mcp_messages.cpp | 268 +++++++++++++++++++++++++++++++ common/toolcall/mcp_messages.hpp | 160 ++++++++++++++++++ examples/main/main.cpp | 25 ++- 8 files changed, 709 insertions(+), 107 deletions(-) create mode 100644 common/toolcall/handler.cpp create mode 100644 common/toolcall/handler.hpp create mode 100644 common/toolcall/mcp_messages.cpp create mode 100644 common/toolcall/mcp_messages.hpp diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index e61015d2ad7f9..55628ce413977 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -75,6 +75,10 @@ add_library(${TARGET} STATIC sampling.h speculative.cpp speculative.h + ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/handler.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/handler.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_messages.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_messages.hpp ) if (BUILD_SHARED_LIBS) diff --git a/common/common.cpp b/common/common.cpp index 13e03e88f5246..97418d78e52e4 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -7,6 +7,9 @@ #include "common.h" #include "log.h" +// Change JSON_ASSERT from assert() to GGML_ASSERT: +#define JSON_ASSERT GGML_ASSERT +#include "json.hpp" #include "json-schema-to-grammar.h" #include "llama.h" #include "chat.hpp" @@ -1769,46 +1772,6 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto // Chat template utils // -common_params_tools::common_params_tools(std::string tools, std::string choice) { - this->tools(tools); - this->choice(choice); -} - -void common_params_tools::tools(std::string tools) { - if (tools.empty()) { - tools_.reset(); - return; - } - try { - tools_ = std::make_shared(json::parse(tools)); - if (! tools_->is_array()) { - throw std::invalid_argument("tools must be a valid JSON array"); - } - - } catch (const json::exception & err) { - throw std::invalid_argument(err.what()); - } -} - -void common_params_tools::choice(std::string choice) { - try { - if (choice == "auto" || choice == "required" || choice == "none") { - tool_choice_ = std::move(choice); - - } else { - auto choice_ptr = std::make_shared(json::parse(choice)); - tool_choice_ = choice_ptr; - if (! choice_ptr->is_object()) { - throw std::invalid_argument( - "tool choice must be a valid JSON object, \"auto\", \"required\", or \"none\""); - } - } - - } catch (const json::exception & err) { - throw std::invalid_argument(err.what()); - } -} - bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { if (use_jinja) { try { @@ -1830,7 +1793,7 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { return res >= 0; } -static void copy_chat_params(const common_chat_params & src, common_chat_sampling_updater * update_sparams) +static void copy_chat_params(const common_chat_params & src, toolcall::sampling_updater * update_sparams) { GGML_ASSERT(update_sparams && update_sparams->sparams && update_sparams->vocab); @@ -1873,11 +1836,11 @@ std::string common_chat_apply_template( const std::vector & msgs, bool add_ass, bool use_jinja, - const common_params_tools & tools, - common_chat_sampling_updater * update_sparams) + toolcall::handler::ptr handler, + toolcall::sampling_updater * update_sparams) { const auto & tmpl_selected = - tools.tools() && tmpl.template_tool_use ? *tmpl.template_tool_use : *tmpl.template_default; + handler != nullptr && tmpl.template_tool_use ? *tmpl.template_tool_use : *tmpl.template_default; if (use_jinja) { common_chat_inputs inputs; @@ -1886,29 +1849,38 @@ std::string common_chat_apply_template( for (const auto & msg : msgs) { messages.push_back({{"role", msg.role}, {"content", msg.content}}); } + inputs.messages = messages; + inputs.add_generation_prompt = add_ass; - if (tools.tools() != nullptr) { - inputs.tools = *tools.tools(); - } - - auto choice = tools.choice(); - if (std::holds_alternative(choice)) { - inputs.tool_choice = std::get(choice); + if (handler != nullptr) { + auto choice = handler->tool_choice(); + if (std::holds_alternative(choice)) { + inputs.tool_choice = std::get(choice); - } else { - auto choice_ptr = std::get(choice); - if (choice_ptr != nullptr) { - inputs.tool_choice = *choice_ptr; + } else { + auto choice_ptr = std::get(choice); + if (choice_ptr != nullptr) { + inputs.tool_choice = *choice_ptr; + } } + + inputs.tools = handler->tool_list(); } - inputs.messages = messages; - inputs.add_generation_prompt = add_ass; auto chat_params = common_chat_params_init(tmpl_selected, inputs); if (update_sparams) { copy_chat_params(chat_params, update_sparams); } - return chat_params.prompt; + + auto prompt = chat_params.prompt; + if (handler != nullptr) { + json response; + handler->call(prompt, response); + return response; // Caller will determine what to do based upon last_action + + } else { + return prompt; + } } int alloc_size = 0; @@ -1946,12 +1918,12 @@ std::string common_chat_format_single( const common_chat_msg & new_msg, bool add_ass, bool use_jinja, - const common_params_tools & tools, - common_chat_sampling_updater * update_sparams) + toolcall::handler::ptr handler, + toolcall::sampling_updater * update_sparams) { std::ostringstream ss; auto fmt_past_msg = past_msg.empty() ? "" - : common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools, update_sparams); + : common_chat_apply_template(tmpl, past_msg, false, use_jinja, handler, update_sparams); std::vector chat_new(past_msg); // if the past_msg ends with a newline, we must preserve it in the formatted version @@ -1960,7 +1932,7 @@ std::string common_chat_format_single( }; // format chat with new_msg chat_new.push_back(new_msg); - auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools, update_sparams); + auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, handler, update_sparams); // get the diff part ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); return ss.str(); diff --git a/common/common.h b/common/common.h index a254f0d0c6818..c2aa4fad9f3b3 100644 --- a/common/common.h +++ b/common/common.h @@ -3,15 +3,12 @@ #pragma once #include "llama-cpp.h" - +#include "toolcall/handler.hpp" #include #include #include #include #include -// Change JSON_ASSERT from assert() to GGML_ASSERT: -#define JSON_ASSERT GGML_ASSERT -#include "json.hpp" #ifdef _WIN32 #define DIRECTORY_SEPARATOR '\\' @@ -206,31 +203,6 @@ struct common_params_vocoder { bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT }; -class common_params_tools { -public: - using json = nlohmann::ordered_json; - using json_ptr = std::shared_ptr; - using tool_choice_t = std::variant; - - common_params_tools(std::string tools = "", - std::string choice = "auto"); - - common_params_tools(const common_params_tools & other) = default; - common_params_tools(common_params_tools && other) noexcept = default; - common_params_tools & operator=(const common_params_tools & other) = default; - common_params_tools & operator=(common_params_tools && other) noexcept = default; - - void tools(std::string tools); - const json * tools() const { return tools_.get(); } - - void choice(std::string choice); - const tool_choice_t & choice() const { return tool_choice_; } - -private: - json_ptr tools_; - tool_choice_t tool_choice_; -}; - struct common_params { int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 4096; // context size @@ -375,7 +347,7 @@ struct common_params { std::string chat_template = ""; // NOLINT bool use_jinja = false; // NOLINT bool enable_chat_template = true; - common_params_tools jinja_tools; + toolcall::params jinja_tools; std::vector api_keys; @@ -671,10 +643,12 @@ struct common_chat_templates { std::unique_ptr template_tool_use; }; -struct common_chat_sampling_updater { - common_params_sampling * sparams; - const llama_vocab * vocab; -}; +namespace toolcall { + struct sampling_updater { + common_params_sampling * sparams; + const llama_vocab * vocab; + }; +} // CPP wrapper for llama_chat_apply_template // If the built-in template is not supported, we default to chatml @@ -684,8 +658,8 @@ std::string common_chat_apply_template( const std::vector & chat, bool add_ass, bool use_jinja, - const common_params_tools & tools = common_params_tools(), - common_chat_sampling_updater * update_sparams = nullptr); + toolcall::handler::ptr handler = nullptr, + toolcall::sampling_updater * update_sparams = nullptr); // Format single message, while taking into account the position of that message in chat history std::string common_chat_format_single( @@ -694,8 +668,8 @@ std::string common_chat_format_single( const common_chat_msg & new_msg, bool add_ass, bool use_jinja, - const common_params_tools & tools = common_params_tools(), - common_chat_sampling_updater * update_sparams = nullptr); + toolcall::handler::ptr handler = nullptr, + toolcall::sampling_updater * update_sparams = nullptr); // Returns an example of formatted chat std::string common_chat_format_example( diff --git a/common/toolcall/handler.cpp b/common/toolcall/handler.cpp new file mode 100644 index 0000000000000..389a377b8c892 --- /dev/null +++ b/common/toolcall/handler.cpp @@ -0,0 +1,110 @@ + +#include "handler.hpp" + +using json = toolcall::json; + +toolcall::params::params(std::string tools, std::string choice) { + this->tools(tools); + this->choice(choice); +} + +static bool starts_with(const std::string & str, const std::string & prefix) { + return str.size() >= prefix.size() + && str.compare(0, prefix.size(), prefix) == 0; +} + +std::shared_ptr toolcall::create_handler(const toolcall::params & params) { + std::shared_ptr result; + + auto tools = params.tools(); + auto choice = params.choice(); + bool has_uri = std::holds_alternative(tools); + if (has_uri) { + auto tools_str = std::get(tools); + result.reset(new toolcall::handler(std::make_unique(tools_str, choice))); + + } else { + auto tools_ptr = std::get(tools); + if (tools_ptr != nullptr) { + result.reset(new toolcall::handler(std::make_unique(*tools_ptr, choice))); + } + } + + return result; +} + +void toolcall::params::tools(std::string tools) { + try { + if (tools.empty() || starts_with(tools, "mcp+http")) { + tools_ = std::move(tools); + + } else { + tools_ = std::make_shared(json::parse(tools)); + auto tools_ptr = std::get>(tools_); + if (! tools_ptr->is_array()) { + throw std::invalid_argument("tools must be a valid JSON array"); + } + } + + } catch (const json::exception & err) { + throw std::invalid_argument(err.what()); + } +} + +void toolcall::params::choice(std::string choice) { + try { + if (choice == "auto" || choice == "required" || choice == "none") { + tool_choice_ = std::move(choice); + + } else { + auto choice_ptr = std::make_shared(json::parse(choice)); + tool_choice_ = choice_ptr; + if (! choice_ptr->is_object()) { + throw std::invalid_argument( + "tool choice must be a valid JSON object, \"auto\", \"required\", or \"none\""); + } + } + + } catch (const json::exception & err) { + throw std::invalid_argument(err.what()); + } +} + +toolcall::params::operator bool() const { + if (std::holds_alternative(tools_)) { + return ! std::get(tools_).empty(); + + } else { + return std::get(tools_) != nullptr; + } +} + +json toolcall::handler::tool_list() { + return impl_->tool_list(); +} + +toolcall::action toolcall::handler::call(const json & request, json & response) { + last_action_ = impl_->call(request, response); + return last_action_; +} + +const toolcall::tool_choice_t & toolcall::handler::tool_choice() const { + return impl_->tool_choice(); +} +toolcall::action toolcall::handler::last_action() const { + return last_action_; +} + +toolcall::mcp_impl::mcp_impl(std::string server_uri, tool_choice_t tool_choice) + : handler_impl(tool_choice) +{ + // TODO +} + +json toolcall::mcp_impl::tool_list() { + return json{};// TODO +} + +toolcall::action toolcall::mcp_impl::call(const json & request, json & response) { + return toolcall::ACCEPT; // TODO +} diff --git a/common/toolcall/handler.hpp b/common/toolcall/handler.hpp new file mode 100644 index 0000000000000..dc104d664ae20 --- /dev/null +++ b/common/toolcall/handler.hpp @@ -0,0 +1,103 @@ +#pragma once + +#include +#include +#include + +#include "../json.hpp" + +namespace toolcall +{ + using json = nlohmann::ordered_json; + using json_ptr = std::shared_ptr; + using tools_t = std::variant; + using tool_choice_t = std::variant; + + enum action { + ACCEPT, + PENDING, + DEFER + }; + + class handler_impl; + class handler { + public: + using ptr = std::shared_ptr; + + handler(std::unique_ptr impl) : impl_(std::move(impl)) {} + + json tool_list(); + action call(const json & request, json & response); + const tool_choice_t & tool_choice() const; + action last_action() const; + + private: + std::unique_ptr impl_; + action last_action_; + }; + + class params { + public: + params(std::string tools = "", std::string choice = "auto"); + + params(const params & other) = default; + params(params && other) noexcept = default; + params & operator=(const params & other) = default; + params & operator=(params && other) noexcept = default; + + operator bool() const; + + void tools(std::string tools); + const tools_t tools() const { return tools_; } + + void choice(std::string choice); + const tool_choice_t & choice() const { return tool_choice_; } + + private: + tools_t tools_; + tool_choice_t tool_choice_; + }; + + std::shared_ptr create_handler(const toolcall::params & params); + + class handler_impl { + public: + handler_impl(tool_choice_t tool_choice) + : tool_choice_(std::move(tool_choice)) {} + + virtual ~handler_impl() = default; + virtual json tool_list() = 0; + virtual action call(const json & request, json & response) = 0; + + const tool_choice_t & tool_choice() const { return tool_choice_; } + + protected: + tool_choice_t tool_choice_; + }; + + class loopback_impl : public handler_impl { + public: + loopback_impl(json tools, tool_choice_t tool_choice) + : handler_impl(tool_choice), tools_(std::move(tools)) {} + + virtual json tool_list() override { + return tools_; + } + + virtual action call(const json & request, json & response) override { + response = request; + return toolcall::DEFER; + } + + private: + json tools_; + }; + + class mcp_impl : public handler_impl { + public: + mcp_impl(std::string server_uri, tool_choice_t tool_choice); + + virtual json tool_list() override; + virtual action call(const json & request, json & response) override; + }; +} diff --git a/common/toolcall/mcp_messages.cpp b/common/toolcall/mcp_messages.cpp new file mode 100644 index 0000000000000..3a2051a3a6f3f --- /dev/null +++ b/common/toolcall/mcp_messages.cpp @@ -0,0 +1,268 @@ +#include "mcp_messages.hpp" +#include + +using json = nlohmann::json; + +const std::string mcp::JsonRpcVersion = "2.0"; +const std::string mcp::McpVersion = "2024-11-05"; +const std::string mcp::ClientVersion = "1.0.0"; +const std::string mcp::ClientName = "llama.cpp"; + +mcp::message::message(std::optional id) : id_(std::move(id)) +{ +} + +void mcp::message::id(std::optional id) { + id_ = std::move(id); +} + +const std::optional & mcp::message::id() const { + return id_; +} + +mcp::request::request(std::optional id, + std::string method, + std::optional params) + + : message(id), method_(std::move(method)), params_(std::move(params)) +{ +} + +json mcp::request::toJson() const { + json j; + j["jsonrpc"] = JsonRpcVersion; + j["method"] = method(); + if (id()) { + j["id"] = id().value(); + } + if (params()) { + j["params"] = params().value(); + } + return j; +} + +void mcp::request::method(std::string method) { + method_ = std::move(method); +} + +const std::string & mcp::request::method() const { + return method_; +} + +void mcp::request::params(std::optional params) { + params_ = std::move(params); +} + +const std::optional & mcp::request::params() const { + return params_; +} + +mcp::response::response(std::optional id, + std::optional result, + std::optional error) + + : message(id), result_(result), error_(error) +{ +} + +json mcp::response::error::toJson() const { + json j; + j["code"] = code; + j["message"] = message; + if (data) { + j["data"] = data.value(); + } + return j; +} + +json mcp::response::toJson() const { + json j; + j["jsonrpc"] = JsonRpcVersion; + if (id()) { + j["id"] = id().value(); + } + if (result()) { + j["result"] = result().value(); + } else if (getError()) { + j["error"] = getError()->toJson(); + } + return j; +} + +void mcp::response::result(std::optional result) { + result_ = std::move(result); +} + +const std::optional & mcp::response::result() const { + return result_; +} + +void mcp::response::setError(std::optional error) { + error_ = std::move(error); +} + +const std::optional & mcp::response::getError() const { + return error_; +} + +mcp::notification::notification( + std::string method, std::optional params) + : message(), method_(method), params_(params) +{ +} + +json mcp::notification::toJson() const { + json j; + j["jsonrpc"] = JsonRpcVersion; + j["method"] = method(); + if (params()) { + j["params"] = params().value(); + } + return j; +} + +void mcp::notification::method(std::string method) { + method_ = std::move(method); +} + +const std::string & mcp::notification::method() const { + return method_; +} + +void mcp::notification::params(std::optional params) { + params_ = std::move(params); +} + +const std::optional & mcp::notification::params() const { + return params_; +} + +mcp::initialize_request::initialize_request(nlohmann::json id, mcp::capabilities caps) + : request(id, "initialize"), caps_(std::move(caps)) +{ + refreshParams(); +} + +void mcp::initialize_request::refreshParams() { + json params; + params["protocolVersion"] = protoVersion(); + params["clientInfo"]["name"] = name(); + params["clientInfo"]["version"] = version(); + params["capabilities"] = {}; + + for (auto cap = caps_.cbegin(); cap != caps_.cend(); ++cap) { + json cap_json; + + if (cap->subscribe) { + cap_json["subscribe"] = true; + } + if (cap->listChanged) { + cap_json["listChanged"] = true; + } + + params["capabilities"][cap->name] = cap_json; + } + + this->params(std::move(params)); +} + +void mcp::initialize_request::capabilities(mcp::capabilities caps) { + caps_ = std::move(caps); + refreshParams(); +} + +const mcp::capabilities & mcp::initialize_request::capabilities() const { + return caps_; +} + +mcp::initialize_response::initialize_response( + nlohmann::json id, std::string name, std::string version, std::string protoVersion, + mcp::capabilities caps) + : response(id), name_(std::move(name)), version_(std::move(version)), + protoVersion_(std::move(protoVersion)), caps_(std::move(caps)) +{ + refreshResult(); +} + +void mcp::initialize_response::refreshResult() { + json result; + result["protocolVersion"] = protoVersion(); + result["serverInfo"]["name"] = name(); + result["serverInfo"]["version"] = version(); + result["capabilities"] = {}; + + for (auto cap = caps_.cbegin(); cap != caps_.cend(); ++cap) { + json cap_json; + + if (cap->subscribe) { + cap_json["subscribe"] = true; + } + if (cap->listChanged) { + cap_json["listChanged"] = true; + } + + result["capabilities"][cap->name] = cap_json; + } + + this->result(std::move(result)); +} + +void mcp::initialize_response::name(std::string name) { + name_ = std::move(name); + refreshResult(); +} + +const std::string & mcp::initialize_response::name() const { + return name_; +} + +void mcp::initialize_response::version(std::string version) { + version_ = std::move(version); + refreshResult(); +} + +const std::string & mcp::initialize_response::version() const { + return version_; +} + +void mcp::initialize_response::protoVersion(std::string protoVersion) { + protoVersion_ = std::move(protoVersion); + refreshResult(); +} + +const std::string & mcp::initialize_response::protoVersion() const { + return protoVersion_; +} + +void mcp::initialize_response::capabilities(mcp::capabilities caps) { + caps_ = std::move(caps); + refreshResult(); +} + +const mcp::capabilities & mcp::initialize_response::capabilities() const { + return caps_; +} + +mcp::initialize_response mcp::initialize_response::fromJson(const nlohmann::json& j) { + std::string name = j["result"]["serverInfo"]["name"]; + std::string version = j["result"]["serverInfo"]["version"]; + std::string protoVersion = j["result"]["protocolVersion"]; + + mcp::capabilities caps; + if (j["result"].contains("capabilities")) { + for (const auto& [key, value] : j["result"]["capabilities"].items()) { + capability cap; + cap.name = key; + cap.subscribe = value.value("subscribe", false); + cap.listChanged = value.value("listChanged", false); + caps.push_back(cap); + } + } + + return initialize_response(j["id"], name, version, protoVersion, caps); +} + +mcp::initialized_notification::initialized_notification() + : notification("notifications/initialized") +{ +} diff --git a/common/toolcall/mcp_messages.hpp b/common/toolcall/mcp_messages.hpp new file mode 100644 index 0000000000000..eb60a9fda3d3d --- /dev/null +++ b/common/toolcall/mcp_messages.hpp @@ -0,0 +1,160 @@ +#include +#include +#include +#include "../json.hpp" + +namespace mcp +{ + extern const std::string JsonRpcVersion; + extern const std::string McpVersion; + extern const std::string ClientVersion; + extern const std::string ClientName; + + class message { + public: + message(std::optional id = std::nullopt); + + virtual ~message() = default; + virtual nlohmann::json toJson() const = 0; + + void id(std::optional id); + const std::optional & id() const; + + private: + std::optional id_; + }; + + + class request : public message { + public: + request(std::optional id, + std::string method, + std::optional params = std::nullopt); + + virtual ~request() = default; + nlohmann::json toJson() const override; + + void method(std::string method); + const std::string & method() const; + + void params(std::optional params); + const std::optional & params() const; + + private: + std::string method_; + std::optional params_; + }; + + + class response : public message { + public: + struct error { + int code; + std::string message; + std::optional data; + nlohmann::json toJson() const; + }; + + response(std::optional id, + std::optional result = std::nullopt, + std::optional error = std::nullopt); + + virtual ~response() = default; + virtual nlohmann::json toJson() const override; + + void result(std::optional result); + const std::optional & result() const; + + void setError(std::optional error); + const std::optional & getError() const; + + private: + std::optional result_; + std::optional error_; + }; + + + class notification : public message { + public: + notification(std::string method, + std::optional params = std::nullopt); + + virtual ~notification() = default; + virtual nlohmann::json toJson() const override; + + void method(std::string method); + const std::string & method() const; + + void params(std::optional params); + const std::optional & params() const; + + private: + std::string method_; + std::optional params_; + }; + + + struct capability { + std::string name; + bool subscribe = false; + bool listChanged = false; + }; + + using capabilities = std::vector; + + class initialize_request : public request { + public: + initialize_request(nlohmann::json id, mcp::capabilities caps); + + const std::string & name() const { return ClientName; } + const std::string & version() const { return ClientVersion; } + const std::string & protoVersion() const { return McpVersion; } + + void capabilities(mcp::capabilities capabilities); + const mcp::capabilities & capabilities() const; + + private: + void refreshParams(); + + mcp::capabilities caps_; + }; + + + class initialize_response : public response { + public: + initialize_response(nlohmann::json id, + std::string name, + std::string version, + std::string protoVersion, + mcp::capabilities caps); + + void name(std::string name); + const std::string & name() const; + + void version(std::string version); + const std::string & version() const; + + void protoVersion(std::string protoVersion); + const std::string & protoVersion() const; + + void capabilities(mcp::capabilities capabilities); + const mcp::capabilities & capabilities() const; + + static initialize_response fromJson(const nlohmann::json& j); + + private: + void refreshResult(); + + std::string name_; + std::string version_; + std::string protoVersion_; + mcp::capabilities caps_; + }; + + + class initialized_notification : public notification { + public: + initialized_notification(); + }; +} + diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 0ea614bcd5489..2dcf571e6ef17 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -263,18 +263,20 @@ int main(int argc, char ** argv) { std::vector embd_inp; + auto toolcall_handler = toolcall::create_handler(params.jinja_tools); + auto chat_add_and_format = [&chat_msgs, &chat_templates, &sparams, vocab]( const std::string & role, const std::string & content, - const common_params_tools & tools = common_params_tools()) + toolcall::handler::ptr handler = nullptr) { bool add_ass = (role == "user"); common_chat_msg new_msg{role, content, {}}; - common_chat_sampling_updater updater{&sparams, vocab}; + toolcall::sampling_updater updater{&sparams, vocab}; auto formatted = common_chat_format_single(chat_templates, chat_msgs, new_msg, add_ass, g_params->use_jinja, - tools, &updater); + handler, &updater); chat_msgs.push_back({role, content, {}}); LOG_DBG("formatted: '%s'\n", formatted.c_str()); @@ -285,7 +287,7 @@ int main(int argc, char ** argv) { std::string system_prompt (params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt); bool use_conversation_prompt (params.conversation_mode && params.enable_chat_template); auto prompt = use_conversation_prompt ? - chat_add_and_format("system", system_prompt, params.jinja_tools) + chat_add_and_format("system", system_prompt, toolcall_handler) : params.prompt; if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) { @@ -791,10 +793,19 @@ int main(int argc, char ** argv) { } if (params.enable_chat_template) { - chat_add_and_format("assistant", assistant_ss.str()); + chat_add_and_format("assistant", assistant_ss.str(), toolcall_handler); + } + if (toolcall_handler != nullptr) { + auto action = toolcall_handler->last_action(); + if (action == toolcall::PENDING || action == toolcall::DEFER) { + is_interacting = true; + LOG("\n"); + } + + } else { + is_interacting = true; + LOG("\n"); } - is_interacting = true; - LOG("\n"); } } From b41f57cfa89039cd03f7d86b11181aa4156df093 Mon Sep 17 00:00:00 2001 From: Mason M Date: Thu, 13 Feb 2025 16:15:09 -0400 Subject: [PATCH 09/69] Comment out unused parameters --- common/toolcall/handler.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/toolcall/handler.cpp b/common/toolcall/handler.cpp index 389a377b8c892..620901ffd67e7 100644 --- a/common/toolcall/handler.cpp +++ b/common/toolcall/handler.cpp @@ -95,7 +95,7 @@ toolcall::action toolcall::handler::last_action() const { return last_action_; } -toolcall::mcp_impl::mcp_impl(std::string server_uri, tool_choice_t tool_choice) +toolcall::mcp_impl::mcp_impl(std::string /*server_uri*/, tool_choice_t tool_choice) : handler_impl(tool_choice) { // TODO @@ -105,6 +105,6 @@ json toolcall::mcp_impl::tool_list() { return json{};// TODO } -toolcall::action toolcall::mcp_impl::call(const json & request, json & response) { +toolcall::action toolcall::mcp_impl::call(const json & /*request*/, json & /*response*/) { return toolcall::ACCEPT; // TODO } From e7efd7c4954ef5396e047a63e89c7705a79794e8 Mon Sep 17 00:00:00 2001 From: Mason M Date: Thu, 13 Feb 2025 16:17:30 -0400 Subject: [PATCH 10/69] Remove tabs --- common/common.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/common.h b/common/common.h index 05c5d90468f65..d00c0f23d9bfd 100644 --- a/common/common.h +++ b/common/common.h @@ -653,8 +653,8 @@ struct common_chat_templates { namespace toolcall { struct sampling_updater { - common_params_sampling * sparams; - const llama_vocab * vocab; + common_params_sampling * sparams; + const llama_vocab * vocab; }; } From 2c07ce751a7c3e4303fbdffc08a2014d4d1e000b Mon Sep 17 00:00:00 2001 From: Mason M Date: Thu, 13 Feb 2025 16:47:51 -0400 Subject: [PATCH 11/69] Only use MCP handler with non-empty string --- common/common.cpp | 4 ++-- common/toolcall/handler.cpp | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 97418d78e52e4..eff848c03d1cc 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1839,8 +1839,8 @@ std::string common_chat_apply_template( toolcall::handler::ptr handler, toolcall::sampling_updater * update_sparams) { - const auto & tmpl_selected = - handler != nullptr && tmpl.template_tool_use ? *tmpl.template_tool_use : *tmpl.template_default; + bool use_tool_template = (use_jinja && handler != nullptr) && tmpl.template_tool_use; + const auto & tmpl_selected = use_tool_template ? *tmpl.template_tool_use : *tmpl.template_default; if (use_jinja) { common_chat_inputs inputs; diff --git a/common/toolcall/handler.cpp b/common/toolcall/handler.cpp index 620901ffd67e7..e2c590d0f8783 100644 --- a/common/toolcall/handler.cpp +++ b/common/toolcall/handler.cpp @@ -21,7 +21,9 @@ std::shared_ptr toolcall::create_handler(const toolcall::para bool has_uri = std::holds_alternative(tools); if (has_uri) { auto tools_str = std::get(tools); - result.reset(new toolcall::handler(std::make_unique(tools_str, choice))); + if (! tools_str.empty()) { + result.reset(new toolcall::handler(std::make_unique(tools_str, choice))); + } } else { auto tools_ptr = std::get(tools); From b67a04cc91a4dcbb72cf4edbb0d4027deddae0d1 Mon Sep 17 00:00:00 2001 From: Mason M Date: Fri, 14 Feb 2025 10:20:02 -0400 Subject: [PATCH 12/69] Switch to compile-time polymorphic message types --- common/toolcall/mcp_messages.cpp | 89 +------------------------------- common/toolcall/mcp_messages.hpp | 84 +++++++++++++++++------------- 2 files changed, 48 insertions(+), 125 deletions(-) diff --git a/common/toolcall/mcp_messages.cpp b/common/toolcall/mcp_messages.cpp index 3a2051a3a6f3f..33c8f3bbcb0f4 100644 --- a/common/toolcall/mcp_messages.cpp +++ b/common/toolcall/mcp_messages.cpp @@ -8,63 +8,19 @@ const std::string mcp::McpVersion = "2024-11-05"; const std::string mcp::ClientVersion = "1.0.0"; const std::string mcp::ClientName = "llama.cpp"; -mcp::message::message(std::optional id) : id_(std::move(id)) -{ -} - -void mcp::message::id(std::optional id) { - id_ = std::move(id); -} - -const std::optional & mcp::message::id() const { - return id_; -} - -mcp::request::request(std::optional id, - std::string method, - std::optional params) - - : message(id), method_(std::move(method)), params_(std::move(params)) -{ -} - json mcp::request::toJson() const { json j; j["jsonrpc"] = JsonRpcVersion; - j["method"] = method(); if (id()) { j["id"] = id().value(); } + j["method"] = method(); if (params()) { j["params"] = params().value(); } return j; } -void mcp::request::method(std::string method) { - method_ = std::move(method); -} - -const std::string & mcp::request::method() const { - return method_; -} - -void mcp::request::params(std::optional params) { - params_ = std::move(params); -} - -const std::optional & mcp::request::params() const { - return params_; -} - -mcp::response::response(std::optional id, - std::optional result, - std::optional error) - - : message(id), result_(result), error_(error) -{ -} - json mcp::response::error::toJson() const { json j; j["code"] = code; @@ -89,28 +45,6 @@ json mcp::response::toJson() const { return j; } -void mcp::response::result(std::optional result) { - result_ = std::move(result); -} - -const std::optional & mcp::response::result() const { - return result_; -} - -void mcp::response::setError(std::optional error) { - error_ = std::move(error); -} - -const std::optional & mcp::response::getError() const { - return error_; -} - -mcp::notification::notification( - std::string method, std::optional params) - : message(), method_(method), params_(params) -{ -} - json mcp::notification::toJson() const { json j; j["jsonrpc"] = JsonRpcVersion; @@ -121,22 +55,6 @@ json mcp::notification::toJson() const { return j; } -void mcp::notification::method(std::string method) { - method_ = std::move(method); -} - -const std::string & mcp::notification::method() const { - return method_; -} - -void mcp::notification::params(std::optional params) { - params_ = std::move(params); -} - -const std::optional & mcp::notification::params() const { - return params_; -} - mcp::initialize_request::initialize_request(nlohmann::json id, mcp::capabilities caps) : request(id, "initialize"), caps_(std::move(caps)) { @@ -261,8 +179,3 @@ mcp::initialize_response mcp::initialize_response::fromJson(const nlohmann::json return initialize_response(j["id"], name, version, protoVersion, caps); } - -mcp::initialized_notification::initialized_notification() - : notification("notifications/initialized") -{ -} diff --git a/common/toolcall/mcp_messages.hpp b/common/toolcall/mcp_messages.hpp index eb60a9fda3d3d..bc41bdbda7cd3 100644 --- a/common/toolcall/mcp_messages.hpp +++ b/common/toolcall/mcp_messages.hpp @@ -10,43 +10,52 @@ namespace mcp extern const std::string ClientVersion; extern const std::string ClientName; + template class message { public: - message(std::optional id = std::nullopt); + message(std::optional id = std::nullopt) + : id_(std::move(id)) {} - virtual ~message() = default; - virtual nlohmann::json toJson() const = 0; + nlohmann::json toJson() const { + return static_cast(this)->toJson(); + } - void id(std::optional id); - const std::optional & id() const; + void id(std::optional id) { + id_ = std::move(id); + } + + const std::optional & id() const { + return id_; + } private: std::optional id_; }; - - class request : public message { + class request : public message { public: request(std::optional id, std::string method, - std::optional params = std::nullopt); + std::optional params = std::nullopt) + + : message(id), + method_(std::move(method)), + params_(std::move(params)) {} - virtual ~request() = default; - nlohmann::json toJson() const override; + void method(std::string method) { method_ = std::move(method); } + const std::string & method() const { return method_; } - void method(std::string method); - const std::string & method() const; + void params(std::optional params) { params_ = std::move(params); } + const std::optional & params() const { return params_; } - void params(std::optional params); - const std::optional & params() const; + nlohmann::json toJson() const; private: std::string method_; std::optional params_; }; - - class response : public message { + class response : public message { public: struct error { int code; @@ -57,43 +66,47 @@ namespace mcp response(std::optional id, std::optional result = std::nullopt, - std::optional error = std::nullopt); + std::optional error = std::nullopt) + + : message(id), + result_(std::move(result)), + error_(std::move(error)) {} - virtual ~response() = default; - virtual nlohmann::json toJson() const override; + void result(std::optional result) { result_ = std::move(result); } + const std::optional & result() const { return result_; } - void result(std::optional result); - const std::optional & result() const; + void setError(std::optional error) { error_ = std::move(error); } + const std::optional & getError() const { return error_; } - void setError(std::optional error); - const std::optional & getError() const; + nlohmann::json toJson() const; private: std::optional result_; std::optional error_; }; - - class notification : public message { + class notification : public message { public: notification(std::string method, - std::optional params = std::nullopt); + std::optional params = std::nullopt) - virtual ~notification() = default; - virtual nlohmann::json toJson() const override; + : message(), + method_(method), + params_(params) {} - void method(std::string method); - const std::string & method() const; + void method(std::string method) { method_ = std::move(method); } + const std::string & method() const { return method_; } - void params(std::optional params); - const std::optional & params() const; + void params(std::optional params) { params_ = std::move(params); } + const std::optional & params() const { return params_; } + + nlohmann::json toJson() const; private: std::string method_; std::optional params_; }; - struct capability { std::string name; bool subscribe = false; @@ -119,7 +132,6 @@ namespace mcp mcp::capabilities caps_; }; - class initialize_response : public response { public: initialize_response(nlohmann::json id, @@ -151,10 +163,8 @@ namespace mcp mcp::capabilities caps_; }; - class initialized_notification : public notification { public: - initialized_notification(); + initialized_notification() : notification("notifications/initialized") {} }; } - From 20a19f8a317ddf6efb526c56ea56ffdd13f3ddf7 Mon Sep 17 00:00:00 2001 From: Mason M Date: Fri, 14 Feb 2025 10:46:02 -0400 Subject: [PATCH 13/69] Add tools/list request --- common/toolcall/mcp_messages.cpp | 20 ++++++++++++++++++++ common/toolcall/mcp_messages.hpp | 15 ++++++++++++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/common/toolcall/mcp_messages.cpp b/common/toolcall/mcp_messages.cpp index 33c8f3bbcb0f4..491dd804810a7 100644 --- a/common/toolcall/mcp_messages.cpp +++ b/common/toolcall/mcp_messages.cpp @@ -179,3 +179,23 @@ mcp::initialize_response mcp::initialize_response::fromJson(const nlohmann::json return initialize_response(j["id"], name, version, protoVersion, caps); } + +mcp::tools_list_request::tools_list_request(std::optional id, std::string cursor) + : request(id, "tools/list"), + cursor_(std::move(cursor)) +{ + refreshParams(); +} + +void mcp::tools_list_request::cursor(std::string cursor) { + cursor_ = std::move(cursor); + refreshParams(); +} + +void mcp::tools_list_request::refreshParams() { + if (! cursor_.empty()) { + json params; + params["cursor"] = cursor_; + this->params(params); + } +} diff --git a/common/toolcall/mcp_messages.hpp b/common/toolcall/mcp_messages.hpp index bc41bdbda7cd3..7e8dc7f7057ed 100644 --- a/common/toolcall/mcp_messages.hpp +++ b/common/toolcall/mcp_messages.hpp @@ -165,6 +165,19 @@ namespace mcp class initialized_notification : public notification { public: - initialized_notification() : notification("notifications/initialized") {} + initialized_notification() + : notification("notifications/initialized") {} + }; + + class tools_list_request : public request { + public: + tools_list_request(std::optional id, std::string cursor = ""); + + void cursor(std::string cursor); + const std::string & cursor() { return cursor_; } + + private: + void refreshParams(); + std::string cursor_; }; } From 9dbe42fe11a72fcb319334334e526f2f79c0784b Mon Sep 17 00:00:00 2001 From: Mason M Date: Fri, 14 Feb 2025 12:00:44 -0400 Subject: [PATCH 14/69] Add tools/list response --- common/toolcall/mcp_messages.cpp | 57 ++++++++++++++++++++++++++++++++ common/toolcall/mcp_messages.hpp | 32 ++++++++++++++++++ 2 files changed, 89 insertions(+) diff --git a/common/toolcall/mcp_messages.cpp b/common/toolcall/mcp_messages.cpp index 491dd804810a7..424c01bb7e558 100644 --- a/common/toolcall/mcp_messages.cpp +++ b/common/toolcall/mcp_messages.cpp @@ -199,3 +199,60 @@ void mcp::tools_list_request::refreshParams() { this->params(params); } } + +mcp::tools_list_response::tools_list_response(nlohmann::json id, + mcp::tools_list tools, + std::string next_cursor) + : response(id), + tools_(std::move(tools)), + next_cursor_(std::move(next_cursor)) +{ + refreshResult(); +} + +void mcp::tools_list_response::tools(mcp::tools_list tools) { + tools_ = std::move(tools); + refreshResult(); +} + +void mcp::tools_list_response::next_cursor(std::string next_cursor) { + next_cursor_ = std::move(next_cursor); + refreshResult(); +} + +void mcp::tools_list_response::refreshResult() { + json result; + + json tools = json::array(); + for (const auto & tool : tools_) { + json t; + + t["name"] = tool.tool_name; + t["description"] = tool.tool_description; + t["inputSchema"]["type"] = "object"; + + json props; + for (const auto & param : tool.params) { + props[param.name] = { + {"type"}, {param.type}, + {"description"}, {param.description} + }; + } + t["inputSchema"]["properties"] = props; + + json required = json::array(); + for (const auto & req_param : tool.required_params) { + required.push_back(req_param); + } + t["inputSchema"]["required"] = required; + + tools.push_back(t); + } + result["tools"] = tools; + + if (! next_cursor_.empty()) { + result["nextCursor"] = next_cursor_; + } + + this->result(result); +} diff --git a/common/toolcall/mcp_messages.hpp b/common/toolcall/mcp_messages.hpp index 7e8dc7f7057ed..fa4a45ef33f0c 100644 --- a/common/toolcall/mcp_messages.hpp +++ b/common/toolcall/mcp_messages.hpp @@ -180,4 +180,36 @@ namespace mcp void refreshParams(); std::string cursor_; }; + + struct tool { + struct param { + std::string name; + std::string type; + std::string description; + }; + std::string tool_name; + std::string tool_description; + std::vector params; + std::vector required_params; + }; + + using tools_list = std::vector; + + class tools_list_response : public response { + public: + tools_list_response(nlohmann::json id, + tools_list tools = tools_list(), + std::string next_cursor = ""); + + void tools(tools_list tools); + const tools_list & tools() const { return tools_; } + + void next_cursor(std::string next_cursor); + const std::string & next_cursor() { return next_cursor_; } + + private: + void refreshResult(); + tools_list tools_; + std::string next_cursor_; + }; } From 93b54e4687e9b31b0d288d07071b9d7d7ba9e5c4 Mon Sep 17 00:00:00 2001 From: Mason M Date: Fri, 14 Feb 2025 13:34:08 -0400 Subject: [PATCH 15/69] Tokenize output from toolcall response --- examples/main/main.cpp | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 2dcf571e6ef17..54fec7f5e8447 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -793,18 +793,23 @@ int main(int argc, char ** argv) { } if (params.enable_chat_template) { - chat_add_and_format("assistant", assistant_ss.str(), toolcall_handler); - } - if (toolcall_handler != nullptr) { - auto action = toolcall_handler->last_action(); - if (action == toolcall::PENDING || action == toolcall::DEFER) { + auto output = chat_add_and_format("assistant", assistant_ss.str(), toolcall_handler); + if (toolcall_handler != nullptr) { + auto action = toolcall_handler->last_action(); + if (action == toolcall::ACCEPT) { + LOG_DBG("tokenizing toolcall response"); + auto response = common_tokenize(ctx, output, false, true); + embd_inp.insert(embd_inp.end(), response.begin(), response.end()); + + } else { + is_interacting = true; + LOG("\n"); + } + + } else { is_interacting = true; LOG("\n"); } - - } else { - is_interacting = true; - LOG("\n"); } } } From 99f2fe360ad548b5c346bff0ed716e4409e7b644 Mon Sep 17 00:00:00 2001 From: Mason M Date: Fri, 14 Feb 2025 16:09:32 -0400 Subject: [PATCH 16/69] Add MCP sse/stdio transport types --- common/CMakeLists.txt | 5 +++++ common/toolcall/handler.cpp | 18 +++++++++++++++--- common/toolcall/handler.hpp | 5 +++++ common/toolcall/mcp_messages.hpp | 5 +++++ common/toolcall/mcp_sse_transport.cpp | 17 +++++++++++++++++ common/toolcall/mcp_sse_transport.hpp | 18 ++++++++++++++++++ common/toolcall/mcp_stdio_transport.cpp | 17 +++++++++++++++++ common/toolcall/mcp_stdio_transport.hpp | 21 +++++++++++++++++++++ common/toolcall/mcp_transport.hpp | 22 ++++++++++++++++++++++ 9 files changed, 125 insertions(+), 3 deletions(-) create mode 100644 common/toolcall/mcp_sse_transport.cpp create mode 100644 common/toolcall/mcp_sse_transport.hpp create mode 100644 common/toolcall/mcp_stdio_transport.cpp create mode 100644 common/toolcall/mcp_stdio_transport.hpp create mode 100644 common/toolcall/mcp_transport.hpp diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 55628ce413977..69289bfced25a 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -79,6 +79,11 @@ add_library(${TARGET} STATIC ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/handler.hpp ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_messages.cpp ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_messages.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_transport.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_sse_transport.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_sse_transport.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_stdio_transport.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_stdio_transport.hpp ) if (BUILD_SHARED_LIBS) diff --git a/common/toolcall/handler.cpp b/common/toolcall/handler.cpp index e2c590d0f8783..dfea132b6031f 100644 --- a/common/toolcall/handler.cpp +++ b/common/toolcall/handler.cpp @@ -1,5 +1,7 @@ #include "handler.hpp" +#include "mcp_sse_transport.hpp" +#include "mcp_stdio_transport.hpp" using json = toolcall::json; @@ -97,16 +99,26 @@ toolcall::action toolcall::handler::last_action() const { return last_action_; } -toolcall::mcp_impl::mcp_impl(std::string /*server_uri*/, tool_choice_t tool_choice) - : handler_impl(tool_choice) +toolcall::mcp_impl::mcp_impl(std::string server_uri, tool_choice_t tool_choice) + : handler_impl(tool_choice), + transport_(new mcp_sse_transport(server_uri)) { - // TODO + transport_->start(); +} + +toolcall::mcp_impl::mcp_impl(std::vector argv, tool_choice_t tool_choice) + : handler_impl(tool_choice), + transport_(new mcp_stdio_transport(argv)) +{ + transport_->start(); } json toolcall::mcp_impl::tool_list() { + // Construct tools/list call and send to transport return json{};// TODO } toolcall::action toolcall::mcp_impl::call(const json & /*request*/, json & /*response*/) { + // Construct tool call and send to transport return toolcall::ACCEPT; // TODO } diff --git a/common/toolcall/handler.hpp b/common/toolcall/handler.hpp index dc104d664ae20..cf52f9e81d464 100644 --- a/common/toolcall/handler.hpp +++ b/common/toolcall/handler.hpp @@ -93,11 +93,16 @@ namespace toolcall json tools_; }; + class mcp_transport; class mcp_impl : public handler_impl { public: mcp_impl(std::string server_uri, tool_choice_t tool_choice); + mcp_impl(std::vector argv, tool_choice_t tool_choice); virtual json tool_list() override; virtual action call(const json & request, json & response) override; + + private: + std::unique_ptr transport_; }; } diff --git a/common/toolcall/mcp_messages.hpp b/common/toolcall/mcp_messages.hpp index fa4a45ef33f0c..a96e3b50668c9 100644 --- a/common/toolcall/mcp_messages.hpp +++ b/common/toolcall/mcp_messages.hpp @@ -1,6 +1,7 @@ #include #include #include +#include #include "../json.hpp" namespace mcp @@ -212,4 +213,8 @@ namespace mcp tools_list tools_; std::string next_cursor_; }; + + using message_variant = std::variant< + initialize_request, initialize_response, initialized_notification, + tools_list_request, tools_list_response>; } diff --git a/common/toolcall/mcp_sse_transport.cpp b/common/toolcall/mcp_sse_transport.cpp new file mode 100644 index 0000000000000..0336175ab2aed --- /dev/null +++ b/common/toolcall/mcp_sse_transport.cpp @@ -0,0 +1,17 @@ + +#include "mcp_sse_transport.hpp" + +toolcall::mcp_sse_transport::mcp_sse_transport(std::string server_uri) + : server_uri_(std::move(server_uri)) +{ +} + +void toolcall::mcp_sse_transport::start() { +} + +void toolcall::mcp_sse_transport::stop() { +} + +bool toolcall::mcp_sse_transport::send(const mcp::message_variant & /*request*/) { + return false; +} diff --git a/common/toolcall/mcp_sse_transport.hpp b/common/toolcall/mcp_sse_transport.hpp new file mode 100644 index 0000000000000..cc44706d83451 --- /dev/null +++ b/common/toolcall/mcp_sse_transport.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include "mcp_transport.hpp" + +namespace toolcall +{ + class mcp_sse_transport : public mcp_transport { + public: + mcp_sse_transport(std::string server_uri); + + virtual void start() override; + virtual void stop() override; + virtual bool send(const mcp::message_variant & request) override; + + private: + std::string server_uri_; + }; +} diff --git a/common/toolcall/mcp_stdio_transport.cpp b/common/toolcall/mcp_stdio_transport.cpp new file mode 100644 index 0000000000000..009ef81e13247 --- /dev/null +++ b/common/toolcall/mcp_stdio_transport.cpp @@ -0,0 +1,17 @@ + +#include "mcp_stdio_transport.hpp" + +toolcall::mcp_stdio_transport::mcp_stdio_transport(std::vector argv) + : argv_(std::move(argv)) +{ +} + +void toolcall::mcp_stdio_transport::start() { +} + +void toolcall::mcp_stdio_transport::stop() { +} + +bool toolcall::mcp_stdio_transport::send(const mcp::message_variant & /*request*/) { + return false; +} diff --git a/common/toolcall/mcp_stdio_transport.hpp b/common/toolcall/mcp_stdio_transport.hpp new file mode 100644 index 0000000000000..130efd0c869bc --- /dev/null +++ b/common/toolcall/mcp_stdio_transport.hpp @@ -0,0 +1,21 @@ +#pragma once + +#include "mcp_transport.hpp" + +#include +#include + +namespace toolcall +{ + class mcp_stdio_transport : public mcp_transport { + public: + mcp_stdio_transport(std::vector argv); + + virtual void start() override; + virtual void stop() override; + virtual bool send(const mcp::message_variant & request) override; + + private: + std::vector argv_; + }; +} diff --git a/common/toolcall/mcp_transport.hpp b/common/toolcall/mcp_transport.hpp new file mode 100644 index 0000000000000..ecb2bf82bb3e0 --- /dev/null +++ b/common/toolcall/mcp_transport.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include "mcp_messages.hpp" + +namespace toolcall +{ + class mcp_transport { + public: + using on_message_callback = std::function; + + virtual ~mcp_transport() = default; + virtual void start() = 0; + virtual void stop() = 0; + virtual bool send(const mcp::message_variant & request) = 0; + + void on_received(on_message_callback callback) { callback_ = std::move(callback); } + const on_message_callback & on_received() const { return callback_; } + + protected: + on_message_callback callback_; + }; +} From 3309b585bcddfb21aca57402c9094de2f8571be4 Mon Sep 17 00:00:00 2001 From: Mason M Date: Fri, 14 Feb 2025 16:13:33 -0400 Subject: [PATCH 17/69] Fix indent --- common/toolcall/mcp_stdio_transport.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/toolcall/mcp_stdio_transport.hpp b/common/toolcall/mcp_stdio_transport.hpp index 130efd0c869bc..db40687a5360f 100644 --- a/common/toolcall/mcp_stdio_transport.hpp +++ b/common/toolcall/mcp_stdio_transport.hpp @@ -9,7 +9,7 @@ namespace toolcall { class mcp_stdio_transport : public mcp_transport { public: - mcp_stdio_transport(std::vector argv); + mcp_stdio_transport(std::vector argv); virtual void start() override; virtual void stop() override; From 376fbba01a7f904629b8de2bcdce96e8c1c02339 Mon Sep 17 00:00:00 2001 From: Mason M Date: Fri, 14 Feb 2025 16:29:10 -0400 Subject: [PATCH 18/69] throw exceptions in stdio transport for now --- common/toolcall/mcp_stdio_transport.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/common/toolcall/mcp_stdio_transport.cpp b/common/toolcall/mcp_stdio_transport.cpp index 009ef81e13247..c46002026c6bd 100644 --- a/common/toolcall/mcp_stdio_transport.cpp +++ b/common/toolcall/mcp_stdio_transport.cpp @@ -1,17 +1,21 @@ #include "mcp_stdio_transport.hpp" +#include + toolcall::mcp_stdio_transport::mcp_stdio_transport(std::vector argv) : argv_(std::move(argv)) { } void toolcall::mcp_stdio_transport::start() { + throw std::logic_error(std::string("Function not implemented: ") + __func__); } void toolcall::mcp_stdio_transport::stop() { + throw std::logic_error(std::string("Function not implemented: ") + __func__); } bool toolcall::mcp_stdio_transport::send(const mcp::message_variant & /*request*/) { - return false; + throw std::logic_error(std::string("Function not implemented: ") + __func__); } From 80e679094db2e9c879f41daa61b04a59facb3622 Mon Sep 17 00:00:00 2001 From: Mason M Date: Fri, 14 Feb 2025 17:28:43 -0400 Subject: [PATCH 19/69] Only include SSE transport when LLAMA_CURL is set --- common/toolcall/handler.cpp | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/common/toolcall/handler.cpp b/common/toolcall/handler.cpp index dfea132b6031f..690e6324aa4a6 100644 --- a/common/toolcall/handler.cpp +++ b/common/toolcall/handler.cpp @@ -1,6 +1,10 @@ #include "handler.hpp" -#include "mcp_sse_transport.hpp" + +#ifdef LLAMA_USE_CURL +# include "mcp_sse_transport.hpp" +#endif + #include "mcp_stdio_transport.hpp" using json = toolcall::json; @@ -22,11 +26,12 @@ std::shared_ptr toolcall::create_handler(const toolcall::para auto choice = params.choice(); bool has_uri = std::holds_alternative(tools); if (has_uri) { +#ifdef LLAMA_USE_CURL auto tools_str = std::get(tools); if (! tools_str.empty()) { result.reset(new toolcall::handler(std::make_unique(tools_str, choice))); } - +#endif } else { auto tools_ptr = std::get(tools); if (tools_ptr != nullptr) { @@ -39,14 +44,24 @@ std::shared_ptr toolcall::create_handler(const toolcall::para void toolcall::params::tools(std::string tools) { try { - if (tools.empty() || starts_with(tools, "mcp+http")) { + + if (tools.empty()) { tools_ = std::move(tools); + } else if (starts_with(tools, "mcp+http")) { +#ifdef LLAMA_USE_CURL + tools_ = std::move(tools); +#else + throw std::invalid_argument( + "Model Context Protocol (MCP) only works when llama.cpp is compiled with libcurl"); +#endif } else { tools_ = std::make_shared(json::parse(tools)); auto tools_ptr = std::get>(tools_); if (! tools_ptr->is_array()) { - throw std::invalid_argument("tools must be a valid JSON array"); + throw std::invalid_argument( + "tools must be a URL of the form \"mcp+http(s)://hostname[:port]/\"" + ", or a valid JSON array containing tool definitions"); } } @@ -99,12 +114,19 @@ toolcall::action toolcall::handler::last_action() const { return last_action_; } +#ifdef LLAMA_USE_CURL toolcall::mcp_impl::mcp_impl(std::string server_uri, tool_choice_t tool_choice) : handler_impl(tool_choice), transport_(new mcp_sse_transport(server_uri)) { transport_->start(); } +#else +toolcall::mcp_impl::mcp_impl(std::string /*server_uri*/, tool_choice_t tool_choice) + : handler_impl(tool_choice) +{ +} +#endif toolcall::mcp_impl::mcp_impl(std::vector argv, tool_choice_t tool_choice) : handler_impl(tool_choice), From ff447629c6cce198e17206bc9f2802d3164c5efe Mon Sep 17 00:00:00 2001 From: Mason M Date: Sat, 15 Feb 2025 13:41:05 -0400 Subject: [PATCH 20/69] Split toolcall params into separate files --- common/CMakeLists.txt | 2 ++ common/toolcall/handler.cpp | 68 ++--------------------------------- common/toolcall/handler.hpp | 26 ++------------ common/toolcall/params.cpp | 71 +++++++++++++++++++++++++++++++++++++ common/toolcall/params.hpp | 37 +++++++++++++++++++ 5 files changed, 114 insertions(+), 90 deletions(-) create mode 100644 common/toolcall/params.cpp create mode 100644 common/toolcall/params.hpp diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 69289bfced25a..e8a047315c71d 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -75,6 +75,8 @@ add_library(${TARGET} STATIC sampling.h speculative.cpp speculative.h + ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/params.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/params.hpp ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/handler.cpp ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/handler.hpp ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_messages.cpp diff --git a/common/toolcall/handler.cpp b/common/toolcall/handler.cpp index 690e6324aa4a6..73d216aff68df 100644 --- a/common/toolcall/handler.cpp +++ b/common/toolcall/handler.cpp @@ -1,5 +1,7 @@ +#include "../json.hpp" #include "handler.hpp" +#include "params.hpp" #ifdef LLAMA_USE_CURL # include "mcp_sse_transport.hpp" @@ -9,16 +11,6 @@ using json = toolcall::json; -toolcall::params::params(std::string tools, std::string choice) { - this->tools(tools); - this->choice(choice); -} - -static bool starts_with(const std::string & str, const std::string & prefix) { - return str.size() >= prefix.size() - && str.compare(0, prefix.size(), prefix) == 0; -} - std::shared_ptr toolcall::create_handler(const toolcall::params & params) { std::shared_ptr result; @@ -42,62 +34,6 @@ std::shared_ptr toolcall::create_handler(const toolcall::para return result; } -void toolcall::params::tools(std::string tools) { - try { - - if (tools.empty()) { - tools_ = std::move(tools); - - } else if (starts_with(tools, "mcp+http")) { -#ifdef LLAMA_USE_CURL - tools_ = std::move(tools); -#else - throw std::invalid_argument( - "Model Context Protocol (MCP) only works when llama.cpp is compiled with libcurl"); -#endif - } else { - tools_ = std::make_shared(json::parse(tools)); - auto tools_ptr = std::get>(tools_); - if (! tools_ptr->is_array()) { - throw std::invalid_argument( - "tools must be a URL of the form \"mcp+http(s)://hostname[:port]/\"" - ", or a valid JSON array containing tool definitions"); - } - } - - } catch (const json::exception & err) { - throw std::invalid_argument(err.what()); - } -} - -void toolcall::params::choice(std::string choice) { - try { - if (choice == "auto" || choice == "required" || choice == "none") { - tool_choice_ = std::move(choice); - - } else { - auto choice_ptr = std::make_shared(json::parse(choice)); - tool_choice_ = choice_ptr; - if (! choice_ptr->is_object()) { - throw std::invalid_argument( - "tool choice must be a valid JSON object, \"auto\", \"required\", or \"none\""); - } - } - - } catch (const json::exception & err) { - throw std::invalid_argument(err.what()); - } -} - -toolcall::params::operator bool() const { - if (std::holds_alternative(tools_)) { - return ! std::get(tools_).empty(); - - } else { - return std::get(tools_) != nullptr; - } -} - json toolcall::handler::tool_list() { return impl_->tool_list(); } diff --git a/common/toolcall/handler.hpp b/common/toolcall/handler.hpp index cf52f9e81d464..6f7fca8302734 100644 --- a/common/toolcall/handler.hpp +++ b/common/toolcall/handler.hpp @@ -1,11 +1,11 @@ #pragma once +#include "../json.hpp" +#include "params.hpp" // TODO: make foreward decl. #include #include #include -#include "../json.hpp" - namespace toolcall { using json = nlohmann::ordered_json; @@ -36,28 +36,6 @@ namespace toolcall action last_action_; }; - class params { - public: - params(std::string tools = "", std::string choice = "auto"); - - params(const params & other) = default; - params(params && other) noexcept = default; - params & operator=(const params & other) = default; - params & operator=(params && other) noexcept = default; - - operator bool() const; - - void tools(std::string tools); - const tools_t tools() const { return tools_; } - - void choice(std::string choice); - const tool_choice_t & choice() const { return tool_choice_; } - - private: - tools_t tools_; - tool_choice_t tool_choice_; - }; - std::shared_ptr create_handler(const toolcall::params & params); class handler_impl { diff --git a/common/toolcall/params.cpp b/common/toolcall/params.cpp new file mode 100644 index 0000000000000..bc30421085192 --- /dev/null +++ b/common/toolcall/params.cpp @@ -0,0 +1,71 @@ + +#include "../json.hpp" // Must come before params due to forward decl. +#include "params.hpp" +#include + +using json = nlohmann::ordered_json; + +static bool starts_with(const std::string & str, const std::string & prefix) { + return str.size() >= prefix.size() + && str.compare(0, prefix.size(), prefix) == 0; +} + +toolcall::params::params(std::string tools, std::string choice) { + this->tools(tools); + this->choice(choice); +} + +void toolcall::params::tools(std::string tools) { + try { + if (tools.empty()) { + tools_ = std::move(tools); + + } else if (starts_with(tools, "mcp+http")) { +#ifdef LLAMA_USE_CURL + tools_ = std::move(tools); +#else + throw std::invalid_argument( + "Model Context Protocol (MCP) only works when llama.cpp is compiled with libcurl"); +#endif + } else { + tools_ = std::make_shared(json::parse(tools)); + auto tools_ptr = std::get>(tools_); + if (! tools_ptr->is_array()) { + throw std::invalid_argument( + "tools must be a URL of the form \"mcp+http(s)://hostname[:port]/\"" + ", or a valid JSON array containing tool definitions"); + } + } + + } catch (const json::exception & err) { + throw std::invalid_argument(err.what()); + } +} + +void toolcall::params::choice(std::string choice) { + try { + if (choice == "auto" || choice == "required" || choice == "none") { + tool_choice_ = std::move(choice); + + } else { + auto choice_ptr = std::make_shared(json::parse(choice)); + tool_choice_ = choice_ptr; + if (! choice_ptr->is_object()) { + throw std::invalid_argument( + "tool choice must be a valid JSON object, \"auto\", \"required\", or \"none\""); + } + } + + } catch (const json::exception & err) { + throw std::invalid_argument(err.what()); + } +} + +toolcall::params::operator bool() const { + if (std::holds_alternative(tools_)) { + return ! std::get(tools_).empty(); + + } else { + return std::get(tools_) != nullptr; + } +} diff --git a/common/toolcall/params.hpp b/common/toolcall/params.hpp new file mode 100644 index 0000000000000..fbe50aa51202a --- /dev/null +++ b/common/toolcall/params.hpp @@ -0,0 +1,37 @@ +#pragma once + +#include +#include +#include + +#include "../json.hpp" // TODO: switch to foreward decl. +// namespace nlohmann { class ordered_json; } + +namespace toolcall +{ + class params { + public: + using json_ptr = std::shared_ptr; + using tools_t = std::variant; + using tool_choice_t = std::variant; + + params(std::string tools = "", std::string choice = "auto"); + + params(const params & other) = default; + params(params && other) noexcept = default; + params & operator=(const params & other) = default; + params & operator=(params && other) noexcept = default; + + operator bool() const; + + void tools(std::string tools); + const tools_t tools() const { return tools_; } + + void choice(std::string choice); + const tool_choice_t & choice() const { return tool_choice_; } + + private: + tools_t tools_; + tool_choice_t tool_choice_; + }; +} From a9e3404a4c886628303cffe92e0acf1280a03023 Mon Sep 17 00:00:00 2001 From: Mason M Date: Sat, 15 Feb 2025 16:19:14 -0400 Subject: [PATCH 21/69] Separate tool-call from template application --- common/common.cpp | 114 +++++++++++++++++------------------------ common/common.h | 22 ++++---- examples/main/main.cpp | 29 ++++++++++- 3 files changed, 84 insertions(+), 81 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index eff848c03d1cc..63d4b83682a63 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1768,42 +1768,18 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto return text; } -// -// Chat template utils -// - -bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { - if (use_jinja) { - try { - auto chat_template = common_chat_template(tmpl, "", ""); - common_chat_inputs inputs; - inputs.messages = json::array({{ - {"role", "user"}, - {"content", "test"}, - }}); - common_chat_params_init(chat_template, inputs); - return true; - } catch (const std::exception & e) { - LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what()); - return false; - } - } - llama_chat_message chat[] = {{"user", "test"}}; - const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0); - return res >= 0; -} - -static void copy_chat_params(const common_chat_params & src, toolcall::sampling_updater * update_sparams) +void common_chat_grammar_to_sampler(const common_chat_params * src, + const llama_vocab * vocab, + common_params_sampling * sparams) { - GGML_ASSERT(update_sparams && update_sparams->sparams && update_sparams->vocab); + GGML_ASSERT(src && vocab && sparams); - auto & dst = *update_sparams->sparams; - auto vocab = update_sparams->vocab; + auto & dst = *sparams; - dst.grammar = src.grammar; - dst.grammar_lazy = src.grammar_lazy; + dst.grammar = src->grammar; + dst.grammar_lazy = src->grammar_lazy; - for (const auto & trigger : src.grammar_triggers) { + for (const auto & trigger : src->grammar_triggers) { auto ids = common_tokenize(vocab, trigger.word, false, true); if (ids.size() == 1) { @@ -1816,7 +1792,7 @@ static void copy_chat_params(const common_chat_params & src, toolcall::sampling_ dst.grammar_trigger_words.push_back(trigger); } - for (const auto & preserved : src.preserved_tokens) { + for (const auto & preserved : src->preserved_tokens) { auto ids = common_tokenize(vocab, preserved, false, true); if (ids.size() == 1) { LOG_DBG("Preserved token: %d\n", ids[0]); @@ -1831,19 +1807,45 @@ static void copy_chat_params(const common_chat_params & src, toolcall::sampling_ } } + +// +// Chat template utils +// + +bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { + if (use_jinja) { + try { + auto chat_template = common_chat_template(tmpl, "", ""); + common_chat_inputs inputs; + inputs.messages = json::array({{ + {"role", "user"}, + {"content", "test"}, + }}); + common_chat_params_init(chat_template, inputs); + return true; + } catch (const std::exception & e) { + LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what()); + return false; + } + } + llama_chat_message chat[] = {{"user", "test"}}; + const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0); + return res >= 0; +} + std::string common_chat_apply_template( const common_chat_templates & tmpl, const std::vector & msgs, bool add_ass, bool use_jinja, - toolcall::handler::ptr handler, - toolcall::sampling_updater * update_sparams) + const common_chat_inputs * inputs_, + common_chat_params * out_params) { - bool use_tool_template = (use_jinja && handler != nullptr) && tmpl.template_tool_use; + bool use_tool_template = use_jinja && tmpl.template_tool_use; const auto & tmpl_selected = use_tool_template ? *tmpl.template_tool_use : *tmpl.template_default; if (use_jinja) { - common_chat_inputs inputs; + common_chat_inputs inputs = inputs_ ? *inputs_ : common_chat_inputs(); auto messages = json::array(); for (const auto & msg : msgs) { @@ -1852,35 +1854,11 @@ std::string common_chat_apply_template( inputs.messages = messages; inputs.add_generation_prompt = add_ass; - if (handler != nullptr) { - auto choice = handler->tool_choice(); - if (std::holds_alternative(choice)) { - inputs.tool_choice = std::get(choice); - - } else { - auto choice_ptr = std::get(choice); - if (choice_ptr != nullptr) { - inputs.tool_choice = *choice_ptr; - } - } - - inputs.tools = handler->tool_list(); - } - auto chat_params = common_chat_params_init(tmpl_selected, inputs); - if (update_sparams) { - copy_chat_params(chat_params, update_sparams); - } - - auto prompt = chat_params.prompt; - if (handler != nullptr) { - json response; - handler->call(prompt, response); - return response; // Caller will determine what to do based upon last_action - - } else { - return prompt; + if (out_params != nullptr) { + *out_params = chat_params; } + return chat_params.prompt; } int alloc_size = 0; @@ -1918,12 +1896,12 @@ std::string common_chat_format_single( const common_chat_msg & new_msg, bool add_ass, bool use_jinja, - toolcall::handler::ptr handler, - toolcall::sampling_updater * update_sparams) + const common_chat_inputs * inputs, + common_chat_params * out_params) { std::ostringstream ss; auto fmt_past_msg = past_msg.empty() ? "" - : common_chat_apply_template(tmpl, past_msg, false, use_jinja, handler, update_sparams); + : common_chat_apply_template(tmpl, past_msg, false, use_jinja, inputs); std::vector chat_new(past_msg); // if the past_msg ends with a newline, we must preserve it in the formatted version @@ -1932,7 +1910,7 @@ std::string common_chat_format_single( }; // format chat with new_msg chat_new.push_back(new_msg); - auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, handler, update_sparams); + auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, inputs, out_params); // get the diff part ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); return ss.str(); diff --git a/common/common.h b/common/common.h index d00c0f23d9bfd..d1c2443b5b5b6 100644 --- a/common/common.h +++ b/common/common.h @@ -618,6 +618,13 @@ std::string common_detokenize( const std::vector & tokens, bool special = true); +struct common_chat_params; +struct common_chat_inputs; +void common_chat_grammar_to_sampler(const common_chat_params * src, + const llama_vocab * vocab, + common_params_sampling * sparams); + + // // Chat template utils // @@ -651,13 +658,6 @@ struct common_chat_templates { std::unique_ptr template_tool_use; }; -namespace toolcall { - struct sampling_updater { - common_params_sampling * sparams; - const llama_vocab * vocab; - }; -} - // CPP wrapper for llama_chat_apply_template // If the built-in template is not supported, we default to chatml // If the custom "tmpl" is not supported, we throw an error @@ -666,8 +666,8 @@ std::string common_chat_apply_template( const std::vector & chat, bool add_ass, bool use_jinja, - toolcall::handler::ptr handler = nullptr, - toolcall::sampling_updater * update_sparams = nullptr); + const common_chat_inputs * inputs = nullptr, + common_chat_params * out_params = nullptr); // Format single message, while taking into account the position of that message in chat history std::string common_chat_format_single( @@ -676,8 +676,8 @@ std::string common_chat_format_single( const common_chat_msg & new_msg, bool add_ass, bool use_jinja, - toolcall::handler::ptr handler = nullptr, - toolcall::sampling_updater * update_sparams = nullptr); + const common_chat_inputs * inputs = nullptr, + common_chat_params * out_params = nullptr); // Returns an example of formatted chat std::string common_chat_format_example( diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 54fec7f5e8447..55bd4b154feb4 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -1,4 +1,5 @@ #include "arg.h" +#include "chat.hpp" #include "common.h" #include "console.h" #include "log.h" @@ -273,13 +274,37 @@ int main(int argc, char ** argv) { common_chat_msg new_msg{role, content, {}}; - toolcall::sampling_updater updater{&sparams, vocab}; + common_chat_inputs cinputs; + if (handler != nullptr) { + auto choice = handler->tool_choice(); + if (std::holds_alternative(choice)) { + cinputs.tool_choice = std::get(choice); + + } else { + auto choice_ptr = std::get(choice); + if (choice_ptr != nullptr) { + cinputs.tool_choice = *choice_ptr; + } + } + cinputs.tools = handler->tool_list(); + } + + common_chat_params cparams; auto formatted = common_chat_format_single(chat_templates, chat_msgs, new_msg, add_ass, g_params->use_jinja, - handler, &updater); + &cinputs, &cparams); chat_msgs.push_back({role, content, {}}); LOG_DBG("formatted: '%s'\n", formatted.c_str()); + + if (g_params->use_jinja) { + common_chat_grammar_to_sampler(&cparams, vocab, &sparams); + if (handler != nullptr) { + json response; + handler->call(formatted, response); + return std::string(response); + } + } return formatted; }; From 7b93c318a86ce5df386e7b07ea5628db7d8ebb31 Mon Sep 17 00:00:00 2001 From: Mason M Date: Sat, 15 Feb 2025 16:53:41 -0400 Subject: [PATCH 22/69] Add noreturn to stdio transport methods --- common/toolcall/mcp_stdio_transport.cpp | 6 +++--- common/toolcall/mcp_stdio_transport.hpp | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/common/toolcall/mcp_stdio_transport.cpp b/common/toolcall/mcp_stdio_transport.cpp index c46002026c6bd..74b02fc5d8fd9 100644 --- a/common/toolcall/mcp_stdio_transport.cpp +++ b/common/toolcall/mcp_stdio_transport.cpp @@ -8,14 +8,14 @@ toolcall::mcp_stdio_transport::mcp_stdio_transport(std::vector argv { } -void toolcall::mcp_stdio_transport::start() { +[[noreturn]] void toolcall::mcp_stdio_transport::start() { throw std::logic_error(std::string("Function not implemented: ") + __func__); } -void toolcall::mcp_stdio_transport::stop() { +[[noreturn]] void toolcall::mcp_stdio_transport::stop() { throw std::logic_error(std::string("Function not implemented: ") + __func__); } -bool toolcall::mcp_stdio_transport::send(const mcp::message_variant & /*request*/) { +[[noreturn]] bool toolcall::mcp_stdio_transport::send(const mcp::message_variant & /*request*/) { throw std::logic_error(std::string("Function not implemented: ") + __func__); } diff --git a/common/toolcall/mcp_stdio_transport.hpp b/common/toolcall/mcp_stdio_transport.hpp index db40687a5360f..94742323d02f6 100644 --- a/common/toolcall/mcp_stdio_transport.hpp +++ b/common/toolcall/mcp_stdio_transport.hpp @@ -11,9 +11,9 @@ namespace toolcall public: mcp_stdio_transport(std::vector argv); - virtual void start() override; - virtual void stop() override; - virtual bool send(const mcp::message_variant & request) override; + [[noreturn]] virtual void start() override; + [[noreturn]] virtual void stop() override; + [[noreturn]] virtual bool send(const mcp::message_variant & request) override; private: std::vector argv_; From 60bca9c25a8bde1e4eb58cdcd4357e3990274948 Mon Sep 17 00:00:00 2001 From: Mason M Date: Wed, 19 Feb 2025 11:08:11 -0400 Subject: [PATCH 23/69] Squashed commit of the following: commit 98c4a8d7f056da5c16a875fd33c68ad5bfb16a71 Author: Mason M Date: Wed Feb 19 11:05:18 2025 -0400 Refactor MCP transport callback mechanism commit 2b3d1f6b994eec33fdd0a083032634a0f307804c Author: Mason M Date: Tue Feb 18 18:01:54 2025 -0400 add message_to_json function commit 3c7ae27bae00f572ef746042b4ffe21309f101ad Author: Mason M Date: Tue Feb 18 15:32:26 2025 -0400 Implement send routine commit 9cec1e01c4df6955958e874cdc6b66260dc4d231 Author: Mason M Date: Tue Feb 18 10:12:17 2025 -0400 Fix include paths commit b5642f0514edd84c210a8fe7936d2809bf05872f Author: Mason M Date: Tue Feb 18 09:56:52 2025 -0400 Use log API commit 7a83b2b563d908cc4ef4eda43ee237cebeb33c08 Author: Mason M Date: Mon Feb 17 19:32:48 2025 -0400 Fix build errors commit cc7fd66fa9b87181443dda6abc02f5d03b99cd92 Author: Mason M Date: Mon Feb 17 19:03:43 2025 -0400 Use condition variable to wait for endpoint event commit 73ccdd10f55a3d378f619074947562b9342900d3 Author: Mason M Date: Mon Feb 17 17:18:09 2025 -0400 Process SSE data asynchronously commit e9c37a36f539af3732e2655889b87f408eab8bdc Author: Mason M Date: Mon Feb 17 14:01:56 2025 -0400 Add keep-alive header to sse handler commit 57f84e653de77fff8be6ed0247fc11d7e3481700 Author: Mason M Date: Mon Feb 17 13:37:59 2025 -0400 Add methods for handling endpoint/message events commit 5c160f6badacece8106b23bde3ee94f157498050 Author: Mason M Date: Mon Feb 17 13:25:18 2025 -0400 Process sse values commit f51b493c34f2d6d174968fd8ac03eb6cad89d09f Author: Mason M Date: Mon Feb 17 13:07:38 2025 -0400 Clean up sse_read algorithm commit 29d6875585ec5bfed92eb21d5ec4714e02b6b9f8 Author: Mason M Date: Sun Feb 16 19:04:39 2025 -0400 WIP: implementing SSE protocol --- common/toolcall/handler.cpp | 2 +- common/toolcall/handler.hpp | 2 +- common/toolcall/mcp_messages.cpp | 16 ++ common/toolcall/mcp_messages.hpp | 14 +- common/toolcall/mcp_sse_transport.cpp | 249 +++++++++++++++++++++++- common/toolcall/mcp_sse_transport.hpp | 33 +++- common/toolcall/mcp_stdio_transport.cpp | 2 +- common/toolcall/mcp_stdio_transport.hpp | 2 +- common/toolcall/mcp_transport.hpp | 56 +++++- common/toolcall/params.hpp | 2 +- 10 files changed, 356 insertions(+), 22 deletions(-) diff --git a/common/toolcall/handler.cpp b/common/toolcall/handler.cpp index 73d216aff68df..c0edeb5355379 100644 --- a/common/toolcall/handler.cpp +++ b/common/toolcall/handler.cpp @@ -1,5 +1,5 @@ -#include "../json.hpp" +#include #include "handler.hpp" #include "params.hpp" diff --git a/common/toolcall/handler.hpp b/common/toolcall/handler.hpp index 6f7fca8302734..438b9e5b06a3a 100644 --- a/common/toolcall/handler.hpp +++ b/common/toolcall/handler.hpp @@ -1,6 +1,6 @@ #pragma once -#include "../json.hpp" +#include // TODO: remove dependence on this #include "params.hpp" // TODO: make foreward decl. #include #include diff --git a/common/toolcall/mcp_messages.cpp b/common/toolcall/mcp_messages.cpp index 424c01bb7e558..0919e42d87f10 100644 --- a/common/toolcall/mcp_messages.cpp +++ b/common/toolcall/mcp_messages.cpp @@ -256,3 +256,19 @@ void mcp::tools_list_response::refreshResult() { this->result(result); } + +static bool has_initialized_response(const nlohmann::json & data) { + return data["result"].contains("serverInfo"); +} + +bool mcp::create_message(const std::string & data, mcp::message_variant & message) { + json j = json::parse(data); + + if (has_initialized_response(j)) { + message = mcp::initialize_response::fromJson(j); + + } else { + return false; + } + return true; +} diff --git a/common/toolcall/mcp_messages.hpp b/common/toolcall/mcp_messages.hpp index a96e3b50668c9..ab68972522e71 100644 --- a/common/toolcall/mcp_messages.hpp +++ b/common/toolcall/mcp_messages.hpp @@ -2,7 +2,7 @@ #include #include #include -#include "../json.hpp" +#include namespace mcp { @@ -214,7 +214,13 @@ namespace mcp std::string next_cursor_; }; - using message_variant = std::variant< - initialize_request, initialize_response, initialized_notification, - tools_list_request, tools_list_response>; + using message_variant = + std::variant; + + bool create_message(const std::string & data, message_variant & message); } diff --git a/common/toolcall/mcp_sse_transport.cpp b/common/toolcall/mcp_sse_transport.cpp index 0336175ab2aed..1751fb8f3d009 100644 --- a/common/toolcall/mcp_sse_transport.cpp +++ b/common/toolcall/mcp_sse_transport.cpp @@ -1,17 +1,260 @@ +#include +#include #include "mcp_sse_transport.hpp" +#include +#include + +toolcall::mcp_sse_transport::~mcp_sse_transport() { + if (endpoint_headers_) { + curl_slist_free_all(endpoint_headers_); + } + if (endpoint_) { + curl_easy_cleanup(endpoint_); + } +} toolcall::mcp_sse_transport::mcp_sse_transport(std::string server_uri) - : server_uri_(std::move(server_uri)) + : server_uri_(std::move(server_uri)), + running_(false), + sse_thread_(), + endpoint_(nullptr), + endpoint_headers_(nullptr), + endpoint_errbuf_(CURL_ERROR_SIZE), + event_{"", "", ""}, + sse_buffer_(""), + sse_cursor_(0), + sse_last_id_(""), + initializing_mutex_(), + initializing_() { + curl_global_init(CURL_GLOBAL_DEFAULT); } void toolcall::mcp_sse_transport::start() { + if (running_) return; + running_ = true; + + std::unique_lock lock(initializing_mutex_); + sse_thread_ = std::thread(&toolcall::mcp_sse_transport::sse_run, this); + initializing_.wait(lock); + + if (endpoint_ == nullptr) { + running_ = false; + LOG_ERR("SSE: Connection to \"%s\" failed", server_uri_.c_str()); + throw std::runtime_error("Connection to \"" + server_uri_ + "\" failed"); + } } void toolcall::mcp_sse_transport::stop() { + running_ = false; +} + +bool toolcall::mcp_sse_transport::send(const std::string & request_json) { + if (! running_ || endpoint_ == nullptr) { + return false; + } + + curl_easy_setopt(endpoint_, CURLOPT_POSTFIELDS, request_json.c_str()); + + CURLcode code = curl_easy_perform(endpoint_); + if (code != CURLE_OK) { + size_t len = strlen(&endpoint_errbuf_[0]); + LOG_ERR("%s", (len > 0 ? &endpoint_errbuf_[0] : curl_easy_strerror(code))); + return false; + } + return true; +} + +static size_t sse_callback(char * data, size_t size, size_t nmemb, void * clientp) { + auto transport = static_cast(clientp); + size_t len = size * nmemb; + return transport->sse_read(data, len); } -bool toolcall::mcp_sse_transport::send(const mcp::message_variant & /*request*/) { - return false; +void toolcall::mcp_sse_transport::parse_field_value(std::string field, std::string value) { + if (field == "event") { + // Set the event type buffer to field value. + event_.type = std::move(value); + + } else if (field == "data") { + // Append the field value to the data buffer, + // then append a single U+000A LINE FEED (LF) + // character to the data buffer. + value += '\n'; + event_.data.insert(event_.data.end(), value.begin(), value.end()); + + } else if (field == "id") { + // If the field value does not contain U+0000 NULL, + // then set the last event ID buffer to the field value. + // Otherwise, ignore the field. + if (! value.empty()) { + event_.id = std::move(value); + } + + } else if (field == "retry") { + // If the field value consists of only ASCII digits, + // then interpret the field value as an integer in base + // ten, and set the event stream's reconnection time to + // that integer. Otherwise, ignore the field. + + LOG_INF("SSE: Retry field is not currently implemented"); + + } else { + LOG_WRN("SSE: Unsupported field \"%s\" received", field.c_str()); + } +} + +void toolcall::mcp_sse_transport::on_endpoint_event() { + endpoint_ = curl_easy_init(); + if (! endpoint_) { + LOG_ERR("SSE: Failed to create endpoint handle"); + running_ = false; + return; + } + + curl_easy_setopt(endpoint_, CURLOPT_URL, event_.data.c_str()); + + endpoint_headers_ = + curl_slist_append(endpoint_headers_, "Content-Type: application/json"); + curl_slist_append(endpoint_headers_, "Connection: keep-alive"); + curl_easy_setopt(endpoint_, CURLOPT_HTTPHEADER, endpoint_headers_); + curl_easy_setopt(endpoint_, CURLOPT_ERRORBUFFER, &endpoint_errbuf_[0]); + + // Later calls to send will reuse the endpoint_ handle +} + +void toolcall::mcp_sse_transport::on_message_event() { + mcp::message_variant message; + if (mcp::create_message(event_.data, message)) { + notify_if(message); + notify_if(message); + } +} + +size_t toolcall::mcp_sse_transport::sse_read(const char * data, size_t len) { + sse_buffer_.insert(sse_buffer_.end(), data, data + len); + + for (; sse_cursor_ < sse_buffer_.length(); ++sse_cursor_) { + if (sse_buffer_[sse_cursor_] == '\r' || sse_buffer_[sse_cursor_] == '\n') { + auto last = sse_buffer_.begin() + sse_cursor_; + + std::string line(sse_buffer_.begin(), last); + if (line.empty()) { // Dispatch event + if (event_.type == "endpoint") { + on_endpoint_event(); + + } else if(event_.type == "message") { + on_message_event(); + + } else { + LOG_WRN("SSE: Unsupported event \"%s\" received", event_.type.c_str()); + } + + sse_last_id_ = event_.id; + event_ = {"", "", ""}; + + } else if(line[0] != ':') { // : denotes a comment + // Set field/value + auto sep_index = line.find(':'); + if (sep_index != std::string::npos) { + auto sep_i = line.begin() + sep_index; + + std::string field (line.begin(), sep_i); + std::string value (sep_i + 1, line.end()); + + parse_field_value(std::move(field), std::move(value)); + } + } + + if (last++ != sse_buffer_.end()) { // Consume line-end + if (*last == '\n') { + last ++; // In the CRLF case consume one more + } + sse_buffer_ = std::string(last, sse_buffer_.end()); + + } else { + sse_buffer_.clear(); + } + sse_cursor_ = 0; // Prepare to scan for next line-end + } + } + return len; +} + +void toolcall::mcp_sse_transport::sse_run() { + std::unique_lock lock(initializing_mutex_); + char errbuf[CURL_ERROR_SIZE]; + size_t errlen; + CURLMcode mcode; + int num_handles; + struct CURLMsg * m; + int msgs_in_queue = 0; + CURLM * async_handle = nullptr; + struct curl_slist * headers = nullptr; + CURL * sse = nullptr; + + sse = curl_easy_init(); + if (! sse) { + LOG_ERR("SSE: Failed to initialize handle"); + goto cleanup; + } + + headers = curl_slist_append(headers, "Connection: keep-alive"); + + curl_easy_setopt(sse, CURLOPT_HTTPHEADER, headers); + curl_easy_setopt(sse, CURLOPT_ERRORBUFFER, errbuf); + curl_easy_setopt(sse, CURLOPT_URL, server_uri_.c_str()); + curl_easy_setopt(sse, CURLOPT_TCP_KEEPALIVE, 1L); + curl_easy_setopt(sse, CURLOPT_WRITEFUNCTION, sse_callback); + curl_easy_setopt(sse, CURLOPT_WRITEDATA, this); + + async_handle = curl_multi_init(); + if (! async_handle) { + LOG_ERR("SSE: Failed to initialize async handle"); + goto cleanup; + } + curl_multi_add_handle(async_handle, sse); + + do { + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + + mcode = curl_multi_perform(async_handle, &num_handles); + if (mcode != CURLM_OK) { + LOG_ERR("SSE: %s", curl_multi_strerror(mcode)); + break; + } + while ((m = curl_multi_info_read(async_handle, &msgs_in_queue)) != nullptr) { + if (m->msg == CURLMSG_DONE) { + if (m->data.result != CURLE_OK) { + errlen = strlen(errbuf); + if (errlen) { + LOG_ERR("SSE: %s", errbuf); + + } else { + LOG_ERR("SSE: %s", curl_easy_strerror(m->data.result)); + } + running_ = false; + break; + } + } + } + if (endpoint_ && lock.owns_lock()) { // TODO: timeout if endpoint not received + lock.unlock(); + initializing_.notify_one(); + } + + } while (running_); + + cleanup: + if (headers) { + curl_slist_free_all(headers); + } + if (async_handle) { + curl_multi_remove_handle(async_handle, sse); + curl_multi_cleanup(async_handle); + } + if (sse) { + curl_easy_cleanup(sse); + } } diff --git a/common/toolcall/mcp_sse_transport.hpp b/common/toolcall/mcp_sse_transport.hpp index cc44706d83451..a46963d44c2d7 100644 --- a/common/toolcall/mcp_sse_transport.hpp +++ b/common/toolcall/mcp_sse_transport.hpp @@ -1,18 +1,49 @@ #pragma once #include "mcp_transport.hpp" +#include +#include +#include +#include namespace toolcall { class mcp_sse_transport : public mcp_transport { public: + ~mcp_sse_transport(); + mcp_sse_transport(std::string server_uri); virtual void start() override; virtual void stop() override; - virtual bool send(const mcp::message_variant & request) override; + virtual bool send(const std::string & request_json) override; + + size_t sse_read(const char * data, size_t len); private: + void sse_run(); + void parse_field_value(std::string field, std::string value); + void on_endpoint_event(); + void on_message_event(); + std::string server_uri_; + bool running_; + std::thread sse_thread_; + CURL * endpoint_; + struct curl_slist * endpoint_headers_; + std::vector endpoint_errbuf_; + + struct sse_event { + std::string type; + std::string data; + std::string id; + } event_; + + std::string sse_buffer_; + size_t sse_cursor_; + std::string sse_last_id_; + + std::mutex initializing_mutex_; + std::condition_variable initializing_; }; } diff --git a/common/toolcall/mcp_stdio_transport.cpp b/common/toolcall/mcp_stdio_transport.cpp index 74b02fc5d8fd9..c90f8091dbe2b 100644 --- a/common/toolcall/mcp_stdio_transport.cpp +++ b/common/toolcall/mcp_stdio_transport.cpp @@ -16,6 +16,6 @@ toolcall::mcp_stdio_transport::mcp_stdio_transport(std::vector argv throw std::logic_error(std::string("Function not implemented: ") + __func__); } -[[noreturn]] bool toolcall::mcp_stdio_transport::send(const mcp::message_variant & /*request*/) { +[[noreturn]] bool toolcall::mcp_stdio_transport::send(const std::string & /*request_json*/) { throw std::logic_error(std::string("Function not implemented: ") + __func__); } diff --git a/common/toolcall/mcp_stdio_transport.hpp b/common/toolcall/mcp_stdio_transport.hpp index 94742323d02f6..ea16700be10b4 100644 --- a/common/toolcall/mcp_stdio_transport.hpp +++ b/common/toolcall/mcp_stdio_transport.hpp @@ -13,7 +13,7 @@ namespace toolcall [[noreturn]] virtual void start() override; [[noreturn]] virtual void stop() override; - [[noreturn]] virtual bool send(const mcp::message_variant & request) override; + [[noreturn]] virtual bool send(const std::string & request_json) override; private: std::vector argv_; diff --git a/common/toolcall/mcp_transport.hpp b/common/toolcall/mcp_transport.hpp index ecb2bf82bb3e0..84324de8a877a 100644 --- a/common/toolcall/mcp_transport.hpp +++ b/common/toolcall/mcp_transport.hpp @@ -1,22 +1,60 @@ #pragma once #include "mcp_messages.hpp" +#include +#include +#include namespace toolcall { - class mcp_transport { + template + using callback = std::function; + + template + class mcp_transport_t { public: - using on_message_callback = std::function; + template + void subscribe(callback callback) { + auto& vec = std::get>>(subscribers_); + vec.push_back(std::move(callback)); + } + + template + void notify(const T & message) const { + const auto& vec = std::get>>(subscribers_); + for (const auto& callback : vec) { + callback(message); + } + } + + template + void notify_if(const mcp::message_variant & message) { + if (std::holds_alternative(message)) { + notify(std::get(message)); + } + } + + template + bool send(const T & message) { + return static_cast(this)->send(message.toJson()); + } + + private: + std::tuple>...> subscribers_; + }; + + class mcp_transport : public mcp_transport_t + { + public: virtual ~mcp_transport() = default; virtual void start() = 0; virtual void stop() = 0; - virtual bool send(const mcp::message_variant & request) = 0; - - void on_received(on_message_callback callback) { callback_ = std::move(callback); } - const on_message_callback & on_received() const { return callback_; } - - protected: - on_message_callback callback_; + virtual bool send(const std::string & request_json) = 0; }; } diff --git a/common/toolcall/params.hpp b/common/toolcall/params.hpp index fbe50aa51202a..5224078398cb4 100644 --- a/common/toolcall/params.hpp +++ b/common/toolcall/params.hpp @@ -4,7 +4,7 @@ #include #include -#include "../json.hpp" // TODO: switch to foreward decl. +#include // TODO: switch to foreward decl. // namespace nlohmann { class ordered_json; } namespace toolcall From f2af859904455f37ef8a187e2c4a8b7d976948d7 Mon Sep 17 00:00:00 2001 From: Mason M Date: Wed, 19 Feb 2025 14:00:14 -0400 Subject: [PATCH 24/69] Post-Merge refactoring --- common/chat.cpp | 12 ++++++--- common/chat.h | 4 ++- common/common.cpp | 1 + common/toolcall/handler.cpp | 36 ++++++++++---------------- common/toolcall/handler.hpp | 41 +++++++++++++---------------- common/toolcall/params.cpp | 51 +++++++++++++++---------------------- common/toolcall/params.hpp | 18 +++++-------- examples/main/main.cpp | 23 +++++++++-------- 8 files changed, 85 insertions(+), 101 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index 9ebe4c5784cbc..60ab532b31117 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -291,9 +291,11 @@ std::string common_chat_format_single( const std::vector & past_msg, const common_chat_msg & new_msg, bool add_ass, - bool use_jinja) { + bool use_jinja, + const struct common_chat_templates_inputs * input_extra, + struct common_chat_params * out_params) { - common_chat_templates_inputs inputs; + common_chat_templates_inputs inputs = input_extra ? *input_extra : common_chat_templates_inputs(); inputs.use_jinja = use_jinja; std::string fmt_past_msg; @@ -310,9 +312,13 @@ std::string common_chat_format_single( // format chat with new_msg inputs.messages.push_back(new_msg); inputs.add_generation_prompt = add_ass; - auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt; + auto chat_params = common_chat_templates_apply(tmpls, inputs); + auto fmt_new_msg = chat_params.prompt; // get the diff part ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); + if (out_params) { + *out_params = std::move(chat_params); + } return ss.str(); } diff --git a/common/chat.h b/common/chat.h index e77bef82b9edd..a18d83d09d240 100644 --- a/common/chat.h +++ b/common/chat.h @@ -111,7 +111,9 @@ std::string common_chat_format_single( const std::vector & past_msg, const common_chat_msg & new_msg, bool add_ass, - bool use_jinja); + bool use_jinja, + const struct common_chat_templates_inputs * input_extra = nullptr, + struct common_chat_params * out_params = nullptr); // Returns an example of formatted chat std::string common_chat_format_example( diff --git a/common/common.cpp b/common/common.cpp index 727b12ec99d9b..5d6e867ee153b 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -12,6 +12,7 @@ #include "json.hpp" #include "json-schema-to-grammar.h" #include "llama.h" +#include "chat.h" #include #include diff --git a/common/toolcall/handler.cpp b/common/toolcall/handler.cpp index c0edeb5355379..ccf1b34003c1e 100644 --- a/common/toolcall/handler.cpp +++ b/common/toolcall/handler.cpp @@ -1,7 +1,6 @@ #include #include "handler.hpp" -#include "params.hpp" #ifdef LLAMA_USE_CURL # include "mcp_sse_transport.hpp" @@ -9,41 +8,34 @@ #include "mcp_stdio_transport.hpp" -using json = toolcall::json; +using json = nlohmann::json; std::shared_ptr toolcall::create_handler(const toolcall::params & params) { - std::shared_ptr result; + std::shared_ptr handler; auto tools = params.tools(); auto choice = params.choice(); - bool has_uri = std::holds_alternative(tools); - if (has_uri) { + if (params.has_uri()) { #ifdef LLAMA_USE_CURL - auto tools_str = std::get(tools); - if (! tools_str.empty()) { - result.reset(new toolcall::handler(std::make_unique(tools_str, choice))); - } + handler.reset(new toolcall::handler(std::make_unique(tools, choice))); #endif } else { - auto tools_ptr = std::get(tools); - if (tools_ptr != nullptr) { - result.reset(new toolcall::handler(std::make_unique(*tools_ptr, choice))); - } + handler.reset(new toolcall::handler(std::make_unique(tools, choice))); } - return result; + return handler; } -json toolcall::handler::tool_list() { +std::string toolcall::handler::tool_list() { return impl_->tool_list(); } -toolcall::action toolcall::handler::call(const json & request, json & response) { +toolcall::action toolcall::handler::call(const std::string & request, std::string & response) { last_action_ = impl_->call(request, response); return last_action_; } -const toolcall::tool_choice_t & toolcall::handler::tool_choice() const { +const std::string & toolcall::handler::tool_choice() const { return impl_->tool_choice(); } toolcall::action toolcall::handler::last_action() const { @@ -51,32 +43,32 @@ toolcall::action toolcall::handler::last_action() const { } #ifdef LLAMA_USE_CURL -toolcall::mcp_impl::mcp_impl(std::string server_uri, tool_choice_t tool_choice) +toolcall::mcp_impl::mcp_impl(std::string server_uri, std::string tool_choice) : handler_impl(tool_choice), transport_(new mcp_sse_transport(server_uri)) { transport_->start(); } #else -toolcall::mcp_impl::mcp_impl(std::string /*server_uri*/, tool_choice_t tool_choice) +toolcall::mcp_impl::mcp_impl(std::string /*server_uri*/, std::string tool_choice) : handler_impl(tool_choice) { } #endif -toolcall::mcp_impl::mcp_impl(std::vector argv, tool_choice_t tool_choice) +toolcall::mcp_impl::mcp_impl(std::vector argv, std::string tool_choice) : handler_impl(tool_choice), transport_(new mcp_stdio_transport(argv)) { transport_->start(); } -json toolcall::mcp_impl::tool_list() { +std::string toolcall::mcp_impl::tool_list() { // Construct tools/list call and send to transport return json{};// TODO } -toolcall::action toolcall::mcp_impl::call(const json & /*request*/, json & /*response*/) { +toolcall::action toolcall::mcp_impl::call(const std::string & /*request*/, std::string & /*response*/) { // Construct tool call and send to transport return toolcall::ACCEPT; // TODO } diff --git a/common/toolcall/handler.hpp b/common/toolcall/handler.hpp index 438b9e5b06a3a..8d36c79f3c5e4 100644 --- a/common/toolcall/handler.hpp +++ b/common/toolcall/handler.hpp @@ -1,18 +1,13 @@ #pragma once -#include // TODO: remove dependence on this -#include "params.hpp" // TODO: make foreward decl. +#include "params.hpp" #include #include #include +#include namespace toolcall { - using json = nlohmann::ordered_json; - using json_ptr = std::shared_ptr; - using tools_t = std::variant; - using tool_choice_t = std::variant; - enum action { ACCEPT, PENDING, @@ -26,9 +21,9 @@ namespace toolcall handler(std::unique_ptr impl) : impl_(std::move(impl)) {} - json tool_list(); - action call(const json & request, json & response); - const tool_choice_t & tool_choice() const; + std::string tool_list(); + action call(const std::string & request, std::string & response); + const std::string & tool_choice() const; action last_action() const; private: @@ -40,45 +35,45 @@ namespace toolcall class handler_impl { public: - handler_impl(tool_choice_t tool_choice) + handler_impl(std::string tool_choice) : tool_choice_(std::move(tool_choice)) {} virtual ~handler_impl() = default; - virtual json tool_list() = 0; - virtual action call(const json & request, json & response) = 0; + virtual std::string tool_list() = 0; + virtual action call(const std::string & request, std::string & response) = 0; - const tool_choice_t & tool_choice() const { return tool_choice_; } + const std::string & tool_choice() const { return tool_choice_; } protected: - tool_choice_t tool_choice_; + std::string tool_choice_; }; class loopback_impl : public handler_impl { public: - loopback_impl(json tools, tool_choice_t tool_choice) + loopback_impl(std::string tools, std::string tool_choice) : handler_impl(tool_choice), tools_(std::move(tools)) {} - virtual json tool_list() override { + virtual std::string tool_list() override { return tools_; } - virtual action call(const json & request, json & response) override { + virtual action call(const std::string & request, std::string & response) override { response = request; return toolcall::DEFER; } private: - json tools_; + std::string tools_; }; class mcp_transport; class mcp_impl : public handler_impl { public: - mcp_impl(std::string server_uri, tool_choice_t tool_choice); - mcp_impl(std::vector argv, tool_choice_t tool_choice); + mcp_impl(std::string server_uri, std::string tool_choice); + mcp_impl(std::vector argv, std::string tool_choice); - virtual json tool_list() override; - virtual action call(const json & request, json & response) override; + virtual std::string tool_list() override; + virtual action call(const std::string & request, std::string & response) override; private: std::unique_ptr transport_; diff --git a/common/toolcall/params.cpp b/common/toolcall/params.cpp index bc30421085192..654f4d4210a46 100644 --- a/common/toolcall/params.cpp +++ b/common/toolcall/params.cpp @@ -1,9 +1,9 @@ -#include "../json.hpp" // Must come before params due to forward decl. #include "params.hpp" #include +#include -using json = nlohmann::ordered_json; +using json = nlohmann::json; static bool starts_with(const std::string & str, const std::string & prefix) { return str.size() >= prefix.size() @@ -17,25 +17,25 @@ toolcall::params::params(std::string tools, std::string choice) { void toolcall::params::tools(std::string tools) { try { - if (tools.empty()) { - tools_ = std::move(tools); - - } else if (starts_with(tools, "mcp+http")) { -#ifdef LLAMA_USE_CURL - tools_ = std::move(tools); -#else - throw std::invalid_argument( - "Model Context Protocol (MCP) only works when llama.cpp is compiled with libcurl"); -#endif - } else { - tools_ = std::make_shared(json::parse(tools)); - auto tools_ptr = std::get>(tools_); - if (! tools_ptr->is_array()) { + if (! tools.empty()) { + if (starts_with(tools, "mcp+http")) { +#ifndef LLAMA_USE_CURL throw std::invalid_argument( - "tools must be a URL of the form \"mcp+http(s)://hostname[:port]/\"" - ", or a valid JSON array containing tool definitions"); + "Model Context Protocol (MCP) only works when llama.cpp is compiled with libcurl"); +#endif + has_uri_ = true; + + } else { + json j = json::parse(tools); // Just for early validation + if (! j.is_array()) { + throw std::invalid_argument( + "tools must be a URL of the form \"mcp+http(s)://hostname[:port]/\"" + ", or a valid JSON array containing tool definitions"); + } + has_uri_ = false; } } + tools_ = std::move(tools); } catch (const json::exception & err) { throw std::invalid_argument(err.what()); @@ -48,12 +48,8 @@ void toolcall::params::choice(std::string choice) { tool_choice_ = std::move(choice); } else { - auto choice_ptr = std::make_shared(json::parse(choice)); - tool_choice_ = choice_ptr; - if (! choice_ptr->is_object()) { - throw std::invalid_argument( - "tool choice must be a valid JSON object, \"auto\", \"required\", or \"none\""); - } + throw std::invalid_argument( + "tool choice must be set to \"auto\", \"required\", or \"none\""); } } catch (const json::exception & err) { @@ -62,10 +58,5 @@ void toolcall::params::choice(std::string choice) { } toolcall::params::operator bool() const { - if (std::holds_alternative(tools_)) { - return ! std::get(tools_).empty(); - - } else { - return std::get(tools_) != nullptr; - } + return ! tools_.empty(); } diff --git a/common/toolcall/params.hpp b/common/toolcall/params.hpp index 5224078398cb4..d8ff16e3a0fcf 100644 --- a/common/toolcall/params.hpp +++ b/common/toolcall/params.hpp @@ -4,17 +4,10 @@ #include #include -#include // TODO: switch to foreward decl. -// namespace nlohmann { class ordered_json; } - namespace toolcall { class params { public: - using json_ptr = std::shared_ptr; - using tools_t = std::variant; - using tool_choice_t = std::variant; - params(std::string tools = "", std::string choice = "auto"); params(const params & other) = default; @@ -25,13 +18,16 @@ namespace toolcall operator bool() const; void tools(std::string tools); - const tools_t tools() const { return tools_; } + const std::string & tools() const { return tools_; } void choice(std::string choice); - const tool_choice_t & choice() const { return tool_choice_; } + const std::string & choice() const { return tool_choice_; } + + bool has_uri() const { return has_uri_; } private: - tools_t tools_; - tool_choice_t tool_choice_; + std::string tools_; + std::string tool_choice_; + bool has_uri_; }; } diff --git a/examples/main/main.cpp b/examples/main/main.cpp index d01015c896fcd..3a4a5997f87eb 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -1,5 +1,4 @@ #include "arg.h" -#include "chat.hpp" #include "common.h" #include "console.h" #include "log.h" @@ -276,19 +275,21 @@ int main(int argc, char ** argv) { new_msg.role = role; new_msg.content = content; - common_chat_inputs cinputs; + common_chat_templates_inputs cinputs; if (handler != nullptr) { auto choice = handler->tool_choice(); - if (std::holds_alternative(choice)) { - cinputs.tool_choice = std::get(choice); + if (choice == "auto") { + cinputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; - } else { - auto choice_ptr = std::get(choice); - if (choice_ptr != nullptr) { - cinputs.tool_choice = *choice_ptr; - } + } else if (choice == "required") { + cinputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_REQUIRED; + + } else if (choice == "none") { + cinputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE; } - cinputs.tools = handler->tool_list(); + + // TODO + //cinputs.tools = handler->tool_list(); } common_chat_params cparams; @@ -302,7 +303,7 @@ int main(int argc, char ** argv) { if (g_params->use_jinja) { common_chat_grammar_to_sampler(&cparams, vocab, &sparams); if (handler != nullptr) { - json response; + std::string response; handler->call(formatted, response); return std::string(response); } From 90efb9095dc56554164bd966f0c5b5fae497bc18 Mon Sep 17 00:00:00 2001 From: Mason M Date: Wed, 19 Feb 2025 16:03:49 -0400 Subject: [PATCH 25/69] Rearrange the furniture! --- common/CMakeLists.txt | 17 ++++++++++------- common/common.h | 4 +--- common/toolcall/handler.cpp | 6 +++--- common/toolcall/{handler.hpp => handler.h} | 2 +- common/toolcall/mcp_messages.cpp | 2 +- .../{mcp_messages.hpp => mcp_messages.h} | 0 common/toolcall/mcp_sse_transport.cpp | 2 +- ...cp_sse_transport.hpp => mcp_sse_transport.h} | 2 +- common/toolcall/mcp_stdio_transport.cpp | 2 +- ...tdio_transport.hpp => mcp_stdio_transport.h} | 2 +- .../{mcp_transport.hpp => mcp_transport.h} | 2 +- common/toolcall/params.cpp | 2 +- common/toolcall/{params.hpp => params.h} | 0 13 files changed, 22 insertions(+), 21 deletions(-) rename common/toolcall/{handler.hpp => handler.h} (98%) rename common/toolcall/{mcp_messages.hpp => mcp_messages.h} (100%) rename common/toolcall/{mcp_sse_transport.hpp => mcp_sse_transport.h} (97%) rename common/toolcall/{mcp_stdio_transport.hpp => mcp_stdio_transport.h} (93%) rename common/toolcall/{mcp_transport.hpp => mcp_transport.h} (98%) rename common/toolcall/{params.hpp => params.h} (100%) diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index d335d91fff912..a2e2310157faf 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -76,16 +76,14 @@ add_library(${TARGET} STATIC speculative.cpp speculative.h ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/params.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/params.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/params.h ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/handler.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/handler.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/handler.h ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_messages.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_messages.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_transport.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_sse_transport.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_sse_transport.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_messages.h + ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_transport.h ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_stdio_transport.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_stdio_transport.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_stdio_transport.h ) if (BUILD_SHARED_LIBS) @@ -101,6 +99,11 @@ if (LLAMA_CURL) include_directories(${CURL_INCLUDE_DIRS}) find_library(CURL_LIBRARY curl REQUIRED) set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARY}) + + target_sources(${TARGET} + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_sse_transport.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_sse_transport.h) + endif () if (LLAMA_LLGUIDANCE) diff --git a/common/common.h b/common/common.h index f67eb894ea1fc..c4a2bcd858d9b 100644 --- a/common/common.h +++ b/common/common.h @@ -3,12 +3,11 @@ #pragma once #include "llama-cpp.h" -#include "toolcall/handler.hpp" +#include "toolcall/handler.h" #include #include #include #include -#include #ifdef _WIN32 #define DIRECTORY_SEPARATOR '\\' @@ -619,7 +618,6 @@ std::string common_detokenize( bool special = true); struct common_chat_params; -struct common_chat_inputs; void common_chat_grammar_to_sampler(const common_chat_params * src, const llama_vocab * vocab, common_params_sampling * sparams); diff --git a/common/toolcall/handler.cpp b/common/toolcall/handler.cpp index ccf1b34003c1e..a003f41f124e5 100644 --- a/common/toolcall/handler.cpp +++ b/common/toolcall/handler.cpp @@ -1,12 +1,12 @@ #include -#include "handler.hpp" +#include "handler.h" #ifdef LLAMA_USE_CURL -# include "mcp_sse_transport.hpp" +# include "mcp_sse_transport.h" #endif -#include "mcp_stdio_transport.hpp" +#include "mcp_stdio_transport.h" using json = nlohmann::json; diff --git a/common/toolcall/handler.hpp b/common/toolcall/handler.h similarity index 98% rename from common/toolcall/handler.hpp rename to common/toolcall/handler.h index 8d36c79f3c5e4..3c3568dd0ee09 100644 --- a/common/toolcall/handler.hpp +++ b/common/toolcall/handler.h @@ -1,6 +1,6 @@ #pragma once -#include "params.hpp" +#include "params.h" #include #include #include diff --git a/common/toolcall/mcp_messages.cpp b/common/toolcall/mcp_messages.cpp index 0919e42d87f10..41d6b7ad1f6a9 100644 --- a/common/toolcall/mcp_messages.cpp +++ b/common/toolcall/mcp_messages.cpp @@ -1,4 +1,4 @@ -#include "mcp_messages.hpp" +#include "mcp_messages.h" #include using json = nlohmann::json; diff --git a/common/toolcall/mcp_messages.hpp b/common/toolcall/mcp_messages.h similarity index 100% rename from common/toolcall/mcp_messages.hpp rename to common/toolcall/mcp_messages.h diff --git a/common/toolcall/mcp_sse_transport.cpp b/common/toolcall/mcp_sse_transport.cpp index 1751fb8f3d009..bd8230c29ee33 100644 --- a/common/toolcall/mcp_sse_transport.cpp +++ b/common/toolcall/mcp_sse_transport.cpp @@ -1,7 +1,7 @@ #include #include -#include "mcp_sse_transport.hpp" +#include "mcp_sse_transport.h" #include #include diff --git a/common/toolcall/mcp_sse_transport.hpp b/common/toolcall/mcp_sse_transport.h similarity index 97% rename from common/toolcall/mcp_sse_transport.hpp rename to common/toolcall/mcp_sse_transport.h index a46963d44c2d7..5d424818a734f 100644 --- a/common/toolcall/mcp_sse_transport.hpp +++ b/common/toolcall/mcp_sse_transport.h @@ -1,6 +1,6 @@ #pragma once -#include "mcp_transport.hpp" +#include "mcp_transport.h" #include #include #include diff --git a/common/toolcall/mcp_stdio_transport.cpp b/common/toolcall/mcp_stdio_transport.cpp index c90f8091dbe2b..d7b98c5391ec8 100644 --- a/common/toolcall/mcp_stdio_transport.cpp +++ b/common/toolcall/mcp_stdio_transport.cpp @@ -1,5 +1,5 @@ -#include "mcp_stdio_transport.hpp" +#include "mcp_stdio_transport.h" #include diff --git a/common/toolcall/mcp_stdio_transport.hpp b/common/toolcall/mcp_stdio_transport.h similarity index 93% rename from common/toolcall/mcp_stdio_transport.hpp rename to common/toolcall/mcp_stdio_transport.h index ea16700be10b4..98e6e9295e0ec 100644 --- a/common/toolcall/mcp_stdio_transport.hpp +++ b/common/toolcall/mcp_stdio_transport.h @@ -1,6 +1,6 @@ #pragma once -#include "mcp_transport.hpp" +#include "mcp_transport.h" #include #include diff --git a/common/toolcall/mcp_transport.hpp b/common/toolcall/mcp_transport.h similarity index 98% rename from common/toolcall/mcp_transport.hpp rename to common/toolcall/mcp_transport.h index 84324de8a877a..5420bd6ff926f 100644 --- a/common/toolcall/mcp_transport.hpp +++ b/common/toolcall/mcp_transport.h @@ -1,6 +1,6 @@ #pragma once -#include "mcp_messages.hpp" +#include "mcp_messages.h" #include #include #include diff --git a/common/toolcall/params.cpp b/common/toolcall/params.cpp index 654f4d4210a46..4e1f81703ba15 100644 --- a/common/toolcall/params.cpp +++ b/common/toolcall/params.cpp @@ -1,5 +1,5 @@ -#include "params.hpp" +#include "params.h" #include #include diff --git a/common/toolcall/params.hpp b/common/toolcall/params.h similarity index 100% rename from common/toolcall/params.hpp rename to common/toolcall/params.h From 78a8d906d1e26a983421ba247ba4252db1fdabae Mon Sep 17 00:00:00 2001 From: Mason M Date: Wed, 19 Feb 2025 16:40:33 -0400 Subject: [PATCH 26/69] Fix input processing --- common/toolcall/handler.cpp | 13 +++++++------ examples/main/main.cpp | 15 ++------------- 2 files changed, 9 insertions(+), 19 deletions(-) diff --git a/common/toolcall/handler.cpp b/common/toolcall/handler.cpp index a003f41f124e5..f9e8d60514292 100644 --- a/common/toolcall/handler.cpp +++ b/common/toolcall/handler.cpp @@ -15,14 +15,15 @@ std::shared_ptr toolcall::create_handler(const toolcall::para auto tools = params.tools(); auto choice = params.choice(); - if (params.has_uri()) { + if (params) { + if (params.has_uri()) { #ifdef LLAMA_USE_CURL - handler.reset(new toolcall::handler(std::make_unique(tools, choice))); + handler.reset(new toolcall::handler(std::make_unique(tools, choice))); #endif - } else { - handler.reset(new toolcall::handler(std::make_unique(tools, choice))); + } else { + handler.reset(new toolcall::handler(std::make_unique(tools, choice))); + } } - return handler; } @@ -65,7 +66,7 @@ toolcall::mcp_impl::mcp_impl(std::vector argv, std::string tool_cho std::string toolcall::mcp_impl::tool_list() { // Construct tools/list call and send to transport - return json{};// TODO + return "[]";// TODO } toolcall::action toolcall::mcp_impl::call(const std::string & /*request*/, std::string & /*response*/) { diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 3a4a5997f87eb..c462662e9c4b7 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -277,19 +277,8 @@ int main(int argc, char ** argv) { common_chat_templates_inputs cinputs; if (handler != nullptr) { - auto choice = handler->tool_choice(); - if (choice == "auto") { - cinputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; - - } else if (choice == "required") { - cinputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_REQUIRED; - - } else if (choice == "none") { - cinputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE; - } - - // TODO - //cinputs.tools = handler->tool_list(); + cinputs.tool_choice = common_chat_tool_choice_parse_oaicompat(handler->tool_choice()); + cinputs.tools = common_chat_tools_parse_oaicompat(handler->tool_list()); } common_chat_params cparams; From 4d81086ef2467cfdcd9c0404a7e42bc76c50e1cd Mon Sep 17 00:00:00 2001 From: Mason M Date: Wed, 19 Feb 2025 17:09:15 -0400 Subject: [PATCH 27/69] Clean up some header inclusions --- common/common.h | 2 +- common/toolcall/params.h | 2 -- examples/main/main.cpp | 1 + 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/common/common.h b/common/common.h index c4a2bcd858d9b..72a57e5a07c00 100644 --- a/common/common.h +++ b/common/common.h @@ -3,7 +3,7 @@ #pragma once #include "llama-cpp.h" -#include "toolcall/handler.h" +#include "toolcall/params.h" #include #include #include diff --git a/common/toolcall/params.h b/common/toolcall/params.h index d8ff16e3a0fcf..302880230461f 100644 --- a/common/toolcall/params.h +++ b/common/toolcall/params.h @@ -1,8 +1,6 @@ #pragma once #include -#include -#include namespace toolcall { diff --git a/examples/main/main.cpp b/examples/main/main.cpp index c462662e9c4b7..72ff5034e3b0d 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -5,6 +5,7 @@ #include "sampling.h" #include "llama.h" #include "chat.h" +#include "toolcall/handler.h" #include #include From 3b0dd4e0784214c08f3fbbfc2b08835b68caf47f Mon Sep 17 00:00:00 2001 From: Mason M Date: Thu, 20 Feb 2025 19:09:52 -0400 Subject: [PATCH 28/69] Split toolcall into separate library --- CMakeLists.txt | 7 ++ common/CMakeLists.txt | 18 ++--- common/arg.cpp | 7 +- common/common.h | 10 ++- examples/main/main.cpp | 66 +++++++++++-------- toolcall/CMakeLists.txt | 36 ++++++++++ {common/toolcall => toolcall}/handler.cpp | 2 +- .../toolcall => toolcall}/mcp_messages.cpp | 0 {common/toolcall => toolcall}/mcp_messages.h | 0 .../mcp_sse_transport.cpp | 0 .../toolcall => toolcall}/mcp_sse_transport.h | 0 .../mcp_stdio_transport.cpp | 0 .../mcp_stdio_transport.h | 0 {common/toolcall => toolcall}/mcp_transport.h | 0 {common/toolcall => toolcall}/params.cpp | 2 +- .../handler.h => toolcall/toolcall-handler.h | 2 +- .../params.h => toolcall/toolcall-params.h | 0 17 files changed, 101 insertions(+), 49 deletions(-) create mode 100644 toolcall/CMakeLists.txt rename {common/toolcall => toolcall}/handler.cpp (98%) rename {common/toolcall => toolcall}/mcp_messages.cpp (100%) rename {common/toolcall => toolcall}/mcp_messages.h (100%) rename {common/toolcall => toolcall}/mcp_sse_transport.cpp (100%) rename {common/toolcall => toolcall}/mcp_sse_transport.h (100%) rename {common/toolcall => toolcall}/mcp_stdio_transport.cpp (100%) rename {common/toolcall => toolcall}/mcp_stdio_transport.h (100%) rename {common/toolcall => toolcall}/mcp_transport.h (100%) rename {common/toolcall => toolcall}/params.cpp (98%) rename common/toolcall/handler.h => toolcall/toolcall-handler.h (98%) rename common/toolcall/params.h => toolcall/toolcall-params.h (100%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 7b2a1845e5c7c..1af1a842e43c0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -82,6 +82,9 @@ option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE}) option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF) option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF) +# Toolcall support - needs LLAMA_CURL support to connect with SSE endpoints +option(LLAMA_TOOLCALL "llama: add toolcall support via Model Context Protocol" OFF) + # Required for relocatable CMake package include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake) include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/common.cmake) @@ -160,6 +163,10 @@ add_subdirectory(src) # utils, programs, examples and tests # +if (LLAMA_TOOLCALL) + add_subdirectory(toolcall) +endif() + if (LLAMA_BUILD_COMMON) add_subdirectory(common) endif() diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index a2e2310157faf..8575f9e73085f 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -75,15 +75,6 @@ add_library(${TARGET} STATIC sampling.h speculative.cpp speculative.h - ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/params.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/params.h - ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/handler.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/handler.h - ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_messages.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_messages.h - ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_transport.h - ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_stdio_transport.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_stdio_transport.h ) if (BUILD_SHARED_LIBS) @@ -99,11 +90,6 @@ if (LLAMA_CURL) include_directories(${CURL_INCLUDE_DIRS}) find_library(CURL_LIBRARY curl REQUIRED) set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARY}) - - target_sources(${TARGET} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_sse_transport.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/toolcall/mcp_sse_transport.h) - endif () if (LLAMA_LLGUIDANCE) @@ -153,3 +139,7 @@ endif () target_include_directories(${TARGET} PUBLIC .) target_compile_features (${TARGET} PUBLIC cxx_std_17) target_link_libraries (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads) + +if (LLAMA_TOOLCALL) + target_link_libraries(${TARGET} PUBLIC toolcall) +endif() diff --git a/common/arg.cpp b/common/arg.cpp index d5e92a60fa21d..8685369d3651e 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2143,9 +2143,11 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE")); +#ifdef LLAMA_USE_TOOLCALL add_opt(common_arg( {"--tools"}, "JINJA_TOOLS", - "set to JSON array of tool definitions used for assistant function-calling (requires --jinja)", + "set to URI of a Model Context Protocol server, or " + "a JSON array containing tool definitions (requires --jinja)", [](common_params ¶ms, const std::string & value) { params.jinja_tools.tools(value); @@ -2153,11 +2155,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex add_opt(common_arg( {"--tool-choice"}, "JINJA_TOOL_CHOICE", - "set to \"auto\", \"required\", \"none\" or a JSON object specifying a tool function (default: \"auto\")", + "set to \"auto\", \"required\", or \"none\" (default: \"auto\")", [](common_params ¶ms, const std::string & value) { params.jinja_tools.choice(value); }).set_examples({LLAMA_EXAMPLE_MAIN})); +#endif add_opt(common_arg( {"-sps", "--slot-prompt-similarity"}, "SIMILARITY", diff --git a/common/common.h b/common/common.h index 72a57e5a07c00..d6413b3e6d692 100644 --- a/common/common.h +++ b/common/common.h @@ -3,7 +3,11 @@ #pragma once #include "llama-cpp.h" -#include "toolcall/params.h" + +#ifdef LLAMA_USE_TOOLCALL +# include "toolcall-params.h" +#endif + #include #include #include @@ -353,7 +357,11 @@ struct common_params { std::string chat_template = ""; // NOLINT bool use_jinja = false; // NOLINT bool enable_chat_template = true; + +#ifdef LLAMA_USE_TOOLCALL toolcall::params jinja_tools; +#endif + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; std::vector api_keys; diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 72ff5034e3b0d..6aedb7bd723f2 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -5,7 +5,6 @@ #include "sampling.h" #include "llama.h" #include "chat.h" -#include "toolcall/handler.h" #include #include @@ -16,6 +15,10 @@ #include #include +#ifdef LLAMA_USE_TOOLCALL +# include "toolcall-handler.h" +#endif + #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) #include #include @@ -264,11 +267,14 @@ int main(int argc, char ** argv) { std::vector embd_inp; - auto toolcall_handler = toolcall::create_handler(params.jinja_tools); +#ifdef LLAMA_USE_TOOLCALL + auto tc_handler = toolcall::create_handler(params.jinja_tools); +#else + void * tc_handler = nullptr; // placeholder +#endif - auto chat_add_and_format = [&chat_msgs, &chat_templates, &sparams, vocab]( - const std::string & role, const std::string & content, - toolcall::handler::ptr handler = nullptr) + auto chat_add_and_format = [&chat_msgs, &chat_templates, &sparams, vocab, tc_handler]( + const std::string & role, const std::string & content, bool use_toolcalls = false) { bool add_ass = (role == "user"); @@ -277,11 +283,13 @@ int main(int argc, char ** argv) { new_msg.content = content; common_chat_templates_inputs cinputs; - if (handler != nullptr) { - cinputs.tool_choice = common_chat_tool_choice_parse_oaicompat(handler->tool_choice()); - cinputs.tools = common_chat_tools_parse_oaicompat(handler->tool_list()); - } +#ifdef LLAMA_USE_TOOLCALL + if (tc_handler != nullptr && use_toolcalls) { + cinputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tc_handler->tool_choice()); + cinputs.tools = common_chat_tools_parse_oaicompat(tc_handler->tool_list()); + } +#endif common_chat_params cparams; auto formatted = common_chat_format_single(chat_templates.get(), chat_msgs, new_msg, add_ass, g_params->use_jinja, @@ -290,23 +298,24 @@ int main(int argc, char ** argv) { chat_msgs.push_back(new_msg); LOG_DBG("formatted: '%s'\n", formatted.c_str()); +#ifdef LLAMA_USE_TOOLCALL if (g_params->use_jinja) { common_chat_grammar_to_sampler(&cparams, vocab, &sparams); - if (handler != nullptr) { + if (tc_handler != nullptr) { std::string response; - handler->call(formatted, response); + tc_handler->call(formatted, response); return std::string(response); } } +#endif return formatted; }; { - std::string system_prompt (params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt); - bool use_conversation_prompt (params.conversation_mode && params.enable_chat_template); + std::string system_prompt = params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt; + bool use_conversation_prompt = params.conversation_mode && params.enable_chat_template; auto prompt = use_conversation_prompt ? - chat_add_and_format("system", system_prompt, toolcall_handler) - : params.prompt; + chat_add_and_format("system", system_prompt, true) : params.prompt; if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) { LOG_DBG("tokenize the prompt\n"); @@ -814,23 +823,22 @@ int main(int argc, char ** argv) { } if (params.enable_chat_template) { - auto output = chat_add_and_format("assistant", assistant_ss.str(), toolcall_handler); - if (toolcall_handler != nullptr) { - auto action = toolcall_handler->last_action(); - if (action == toolcall::ACCEPT) { - LOG_DBG("tokenizing toolcall response"); - auto response = common_tokenize(ctx, output, false, true); - embd_inp.insert(embd_inp.end(), response.begin(), response.end()); - - } else { - is_interacting = true; - LOG("\n"); - } - - } else { +#ifdef LLAMA_USE_TOOLCALL + auto output = chat_add_and_format("assistant", assistant_ss.str(), true); + if (tc_handler == nullptr || tc_handler->last_action() != toolcall::ACCEPT) { is_interacting = true; LOG("\n"); + + } else { + LOG_DBG("tokenizing toolcall response"); + auto response = common_tokenize(ctx, output, false, true); + embd_inp.insert(embd_inp.end(), response.begin(), response.end()); } +#else + chat_add_and_format("assistant", assistant_ss.str()); + is_interacting = true; + LOG("\n"); +#endif } } } diff --git a/toolcall/CMakeLists.txt b/toolcall/CMakeLists.txt new file mode 100644 index 0000000000000..b1d88a34d7953 --- /dev/null +++ b/toolcall/CMakeLists.txt @@ -0,0 +1,36 @@ + +set(TARGET toolcall) + +set(SOURCES + handler.cpp + mcp_messages.cpp + mcp_stdio_transport.cpp + params.cpp +) + +set(HEADERS + toolcall-params.h + toolcall-handler.h + mcp_transport.h + mcp_messages.h + mcp_stdio_transport.h +) + +add_library(${TARGET} STATIC ${SOURCES} ${HEADERS}) + +target_include_directories(${TARGET} # Right now only for "json.hpp" + PRIVATE $) + +if (LLAMA_CURL) + find_package(CURL REQUIRED) + target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_CURL) + include_directories(${CURL_INCLUDE_DIRS}) + find_library(CURL_LIBRARY curl REQUIRED) + + target_link_libraries(${TARGET} PRIVATE ${CURL_LIBRARY}) + target_sources(${TARGET} PRIVATE mcp_sse_transport.cpp mcp_sse_transport.h) + +endif() + +target_compile_definitions(${TARGET} INTERFACE LLAMA_USE_TOOLCALL) +target_include_directories(${TARGET} PUBLIC $) diff --git a/common/toolcall/handler.cpp b/toolcall/handler.cpp similarity index 98% rename from common/toolcall/handler.cpp rename to toolcall/handler.cpp index f9e8d60514292..5d65079137611 100644 --- a/common/toolcall/handler.cpp +++ b/toolcall/handler.cpp @@ -1,6 +1,6 @@ #include -#include "handler.h" +#include "toolcall-handler.h" #ifdef LLAMA_USE_CURL # include "mcp_sse_transport.h" diff --git a/common/toolcall/mcp_messages.cpp b/toolcall/mcp_messages.cpp similarity index 100% rename from common/toolcall/mcp_messages.cpp rename to toolcall/mcp_messages.cpp diff --git a/common/toolcall/mcp_messages.h b/toolcall/mcp_messages.h similarity index 100% rename from common/toolcall/mcp_messages.h rename to toolcall/mcp_messages.h diff --git a/common/toolcall/mcp_sse_transport.cpp b/toolcall/mcp_sse_transport.cpp similarity index 100% rename from common/toolcall/mcp_sse_transport.cpp rename to toolcall/mcp_sse_transport.cpp diff --git a/common/toolcall/mcp_sse_transport.h b/toolcall/mcp_sse_transport.h similarity index 100% rename from common/toolcall/mcp_sse_transport.h rename to toolcall/mcp_sse_transport.h diff --git a/common/toolcall/mcp_stdio_transport.cpp b/toolcall/mcp_stdio_transport.cpp similarity index 100% rename from common/toolcall/mcp_stdio_transport.cpp rename to toolcall/mcp_stdio_transport.cpp diff --git a/common/toolcall/mcp_stdio_transport.h b/toolcall/mcp_stdio_transport.h similarity index 100% rename from common/toolcall/mcp_stdio_transport.h rename to toolcall/mcp_stdio_transport.h diff --git a/common/toolcall/mcp_transport.h b/toolcall/mcp_transport.h similarity index 100% rename from common/toolcall/mcp_transport.h rename to toolcall/mcp_transport.h diff --git a/common/toolcall/params.cpp b/toolcall/params.cpp similarity index 98% rename from common/toolcall/params.cpp rename to toolcall/params.cpp index 4e1f81703ba15..473e6be5c4017 100644 --- a/common/toolcall/params.cpp +++ b/toolcall/params.cpp @@ -1,5 +1,5 @@ -#include "params.h" +#include "toolcall-params.h" #include #include diff --git a/common/toolcall/handler.h b/toolcall/toolcall-handler.h similarity index 98% rename from common/toolcall/handler.h rename to toolcall/toolcall-handler.h index 3c3568dd0ee09..46d6bec5f28a3 100644 --- a/common/toolcall/handler.h +++ b/toolcall/toolcall-handler.h @@ -1,6 +1,6 @@ #pragma once -#include "params.h" +#include "toolcall-params.h" #include #include #include diff --git a/common/toolcall/params.h b/toolcall/toolcall-params.h similarity index 100% rename from common/toolcall/params.h rename to toolcall/toolcall-params.h From 8668d89f2ad9bc1172b39064b50bea3dd9510260 Mon Sep 17 00:00:00 2001 From: Mason M Date: Thu, 20 Feb 2025 21:56:31 -0400 Subject: [PATCH 29/69] Convert chat_add_and_format to functor --- examples/main/main.cpp | 101 +++++++++++++++++++++++++---------------- 1 file changed, 62 insertions(+), 39 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 6aedb7bd723f2..1b5fb41f8a929 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -89,6 +89,66 @@ static void sigint_handler(int signo) { } #endif +class chat_formatter { +public: + chat_formatter(common_params & params, std::vector & chat_msgs, struct common_chat_templates * chat_templates) + : params_(params), chat_msgs_(chat_msgs), chat_templates_(chat_templates) {} + +#ifdef LLAMA_USE_TOOLCALL + chat_formatter(common_params & params, + std::vector & chat_msgs, + struct common_chat_templates * chat_templates, + const llama_vocab * vocab, + toolcall::handler::ptr tc_handler) + + : params_(params), chat_msgs_(chat_msgs), chat_templates_(chat_templates), vocab_(vocab), tc_handler_(tc_handler) {} +#endif + + std::string operator () (const std::string & role, const std::string & content, [[maybe_unused]] bool use_toolcalls = false) { + common_chat_msg new_msg; + new_msg.role = role; + new_msg.content = content; + + common_chat_params cparams; + common_chat_templates_inputs cinputs; +#ifdef LLAMA_USE_TOOLCALL + if (tc_handler_ != nullptr && use_toolcalls) { + cinputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tc_handler_->tool_choice()); + cinputs.tools = common_chat_tools_parse_oaicompat(tc_handler_->tool_list()); + } +#endif + bool add_ass = role == "user"; + auto formatted = + common_chat_format_single(chat_templates_, chat_msgs_, new_msg, add_ass, params_.use_jinja, + &cinputs, &cparams); + + chat_msgs_.push_back(new_msg); + LOG_DBG("formatted: '%s'\n", formatted.c_str()); + +#ifdef LLAMA_USE_TOOLCALL + if (params_.use_jinja) { + common_chat_grammar_to_sampler(&cparams, vocab_, ¶ms_.sampling); + if (tc_handler_ != nullptr) { + std::string response; + tc_handler_->call(formatted, response); + return std::string(response); + } + } +#endif + return formatted; + } + +private: + common_params & params_; + std::vector & chat_msgs_; + struct common_chat_templates * chat_templates_; + +#ifdef LLAMA_USE_TOOLCALL + const llama_vocab * vocab_; + toolcall::handler::ptr tc_handler_; +#endif +}; + int main(int argc, char ** argv) { common_params params; g_params = ¶ms; @@ -269,47 +329,10 @@ int main(int argc, char ** argv) { #ifdef LLAMA_USE_TOOLCALL auto tc_handler = toolcall::create_handler(params.jinja_tools); + chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get(), vocab, tc_handler); #else - void * tc_handler = nullptr; // placeholder -#endif - - auto chat_add_and_format = [&chat_msgs, &chat_templates, &sparams, vocab, tc_handler]( - const std::string & role, const std::string & content, bool use_toolcalls = false) - { - bool add_ass = (role == "user"); - - common_chat_msg new_msg; - new_msg.role = role; - new_msg.content = content; - - common_chat_templates_inputs cinputs; - -#ifdef LLAMA_USE_TOOLCALL - if (tc_handler != nullptr && use_toolcalls) { - cinputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tc_handler->tool_choice()); - cinputs.tools = common_chat_tools_parse_oaicompat(tc_handler->tool_list()); - } -#endif - common_chat_params cparams; - auto formatted = - common_chat_format_single(chat_templates.get(), chat_msgs, new_msg, add_ass, g_params->use_jinja, - &cinputs, &cparams); - - chat_msgs.push_back(new_msg); - LOG_DBG("formatted: '%s'\n", formatted.c_str()); - -#ifdef LLAMA_USE_TOOLCALL - if (g_params->use_jinja) { - common_chat_grammar_to_sampler(&cparams, vocab, &sparams); - if (tc_handler != nullptr) { - std::string response; - tc_handler->call(formatted, response); - return std::string(response); - } - } + chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get()); #endif - return formatted; - }; { std::string system_prompt = params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt; From a19ed471ae9590d31f4ad51954e630333d0c1c31 Mon Sep 17 00:00:00 2001 From: Mason M Date: Thu, 20 Feb 2025 22:00:28 -0400 Subject: [PATCH 30/69] Enable LLAMA_TOOLCALL by default (for now) --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1af1a842e43c0..613fc7e747b59 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -83,7 +83,7 @@ option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF) option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF) # Toolcall support - needs LLAMA_CURL support to connect with SSE endpoints -option(LLAMA_TOOLCALL "llama: add toolcall support via Model Context Protocol" OFF) +option(LLAMA_TOOLCALL "llama: add toolcall support via Model Context Protocol" ON) # Required for relocatable CMake package include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake) From 5c0b0cb8749e7854d7b2ea478859b22c0cafe522 Mon Sep 17 00:00:00 2001 From: Mason M Date: Thu, 20 Feb 2025 22:13:01 -0400 Subject: [PATCH 31/69] Use cxx_std_17 --- toolcall/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/toolcall/CMakeLists.txt b/toolcall/CMakeLists.txt index b1d88a34d7953..953a0c6f20f43 100644 --- a/toolcall/CMakeLists.txt +++ b/toolcall/CMakeLists.txt @@ -34,3 +34,4 @@ endif() target_compile_definitions(${TARGET} INTERFACE LLAMA_USE_TOOLCALL) target_include_directories(${TARGET} PUBLIC $) +target_compile_features (${TARGET} PUBLIC cxx_std_17) From 3e46978e5e7503942c395521b65333eb0f12019a Mon Sep 17 00:00:00 2001 From: Mason M Date: Fri, 21 Feb 2025 13:33:34 -0400 Subject: [PATCH 32/69] Impl. initialize and tool_list routines --- toolcall/handler.cpp | 116 ++++++++++++++++++++++++++++++--- toolcall/mcp_messages.h | 13 +++- toolcall/mcp_sse_transport.cpp | 4 +- toolcall/mcp_transport.h | 13 +++- toolcall/toolcall-handler.h | 27 +++++++- 5 files changed, 156 insertions(+), 17 deletions(-) diff --git a/toolcall/handler.cpp b/toolcall/handler.cpp index 5d65079137611..f2ed6fab81509 100644 --- a/toolcall/handler.cpp +++ b/toolcall/handler.cpp @@ -1,9 +1,11 @@ #include #include "toolcall-handler.h" +#include +#include #ifdef LLAMA_USE_CURL -# include "mcp_sse_transport.h" +# include "mcp_sse_transport.h" #endif #include "mcp_stdio_transport.h" @@ -18,10 +20,12 @@ std::shared_ptr toolcall::create_handler(const toolcall::para if (params) { if (params.has_uri()) { #ifdef LLAMA_USE_CURL - handler.reset(new toolcall::handler(std::make_unique(tools, choice))); + handler.reset(new toolcall::handler( + std::make_unique(tools, choice))); #endif } else { - handler.reset(new toolcall::handler(std::make_unique(tools, choice))); + handler.reset(new toolcall::handler( + std::make_unique(tools, choice))); } } return handler; @@ -31,6 +35,10 @@ std::string toolcall::handler::tool_list() { return impl_->tool_list(); } +bool toolcall::handler::tool_list_dirty() const { + return impl_->tool_list_dirty(); +} + toolcall::action toolcall::handler::call(const std::string & request, std::string & response) { last_action_ = impl_->call(request, response); return last_action_; @@ -39,20 +47,33 @@ toolcall::action toolcall::handler::call(const std::string & request, std::strin const std::string & toolcall::handler::tool_choice() const { return impl_->tool_choice(); } + toolcall::action toolcall::handler::last_action() const { return last_action_; } +void toolcall::handler::initialize() { + impl_->initialize(); +} + #ifdef LLAMA_USE_CURL toolcall::mcp_impl::mcp_impl(std::string server_uri, std::string tool_choice) : handler_impl(tool_choice), - transport_(new mcp_sse_transport(server_uri)) + transport_(new mcp_sse_transport(server_uri)), + tools_("[]"), + tools_mutex_(), + tools_populating_(), + next_id_(1) { - transport_->start(); } #else toolcall::mcp_impl::mcp_impl(std::string /*server_uri*/, std::string tool_choice) - : handler_impl(tool_choice) + : handler_impl(tool_choice), + transport_(nullptr), + tools_("[]"), + tools_mutex_(), + tools_populating_(), + next_id_(1) { } #endif @@ -61,15 +82,94 @@ toolcall::mcp_impl::mcp_impl(std::vector argv, std::string tool_cho : handler_impl(tool_choice), transport_(new mcp_stdio_transport(argv)) { +} + +void toolcall::mcp_impl::initialize() { + using on_response = toolcall::callback; + using on_list_changed = toolcall::callback; + + if (transport_ == nullptr) return; + std::unique_lock lock(tools_mutex_); + transport_->start(); + + mcp::capabilities caps; + on_response set_caps = [this, &caps] (const mcp::initialize_response & resp) { + std::unique_lock lock(tools_mutex_); + caps = resp.capabilities(); + tools_populating_.notify_one(); + }; + + transport_->subscribe(set_caps); + + mcp::initialize_request req(next_id_++); + transport_->send(req.toJson()); + + tools_populating_.wait_for(lock, std::chrono::seconds(15)); + transport_->unsubscribe(set_caps); + + on_list_changed update_dirty = [this] (const mcp::tools_list_changed_notification &) { + tool_list_dirty_ = true; + }; + + bool has_tools = false; + for (const auto & cap : caps) { + if (cap.name == "tools") { + has_tools = true; + if (cap.listChanged) { + transport_->subscribe(update_dirty); + } + break; + } + } + if (! has_tools) { + throw std::runtime_error("MCP server does not support toolcalls!"); + } +} + +static std::string tools_list_to_oai_json(const mcp::tools_list & tools) { + return "[]"; // TODO } std::string toolcall::mcp_impl::tool_list() { - // Construct tools/list call and send to transport - return "[]";// TODO + using on_response = toolcall::callback; + + if (tool_list_dirty_) { + std::unique_lock lock(tools_mutex_); + + mcp::tools_list tools; + on_response set_tools = [this, &tools] (const mcp::tools_list_response & resp) { + std::unique_lock lock(tools_mutex_); + + tools.insert(tools.end(), resp.tools().begin(), resp.tools().end()); + auto cursor = resp.next_cursor(); + if (! cursor.empty()) { + mcp::tools_list_request req(std::to_string(next_id_++), cursor); + transport_->send(req.toJson()); + return; + } + tool_list_dirty_ = false; + lock.unlock(); + tools_populating_.notify_one(); + }; + + transport_->subscribe(set_tools); + + mcp::tools_list_request req(std::to_string(next_id_++)); + transport_->send(req.toJson()); + + tools_populating_.wait_for(lock, std::chrono::seconds(15)); + transport_->unsubscribe(set_tools); + + tools_ = tools_list_to_oai_json(tools); + } + return tools_; } toolcall::action toolcall::mcp_impl::call(const std::string & /*request*/, std::string & /*response*/) { + if (transport_ == nullptr) { + return toolcall::DEFER; + } // Construct tool call and send to transport return toolcall::ACCEPT; // TODO } diff --git a/toolcall/mcp_messages.h b/toolcall/mcp_messages.h index ab68972522e71..31d62e3469e19 100644 --- a/toolcall/mcp_messages.h +++ b/toolcall/mcp_messages.h @@ -118,7 +118,7 @@ namespace mcp class initialize_request : public request { public: - initialize_request(nlohmann::json id, mcp::capabilities caps); + initialize_request(nlohmann::json id, mcp::capabilities caps = mcp::capabilities{}); const std::string & name() const { return ClientName; } const std::string & version() const { return ClientVersion; } @@ -206,7 +206,7 @@ namespace mcp const tools_list & tools() const { return tools_; } void next_cursor(std::string next_cursor); - const std::string & next_cursor() { return next_cursor_; } + const std::string & next_cursor() const { return next_cursor_; } private: void refreshResult(); @@ -214,13 +214,20 @@ namespace mcp std::string next_cursor_; }; + class tools_list_changed_notification : public notification { + public: + tools_list_changed_notification() + : notification("notifications/tools/list_changed") {} + }; + using message_variant = std::variant; + tools_list_response, + tools_list_changed_notification>; bool create_message(const std::string & data, message_variant & message); } diff --git a/toolcall/mcp_sse_transport.cpp b/toolcall/mcp_sse_transport.cpp index bd8230c29ee33..4deda675409e8 100644 --- a/toolcall/mcp_sse_transport.cpp +++ b/toolcall/mcp_sse_transport.cpp @@ -1,6 +1,4 @@ -#include -#include #include "mcp_sse_transport.h" #include #include @@ -73,6 +71,8 @@ static size_t sse_callback(char * data, size_t size, size_t nmemb, void * client } void toolcall::mcp_sse_transport::parse_field_value(std::string field, std::string value) { + LOG_DBG("SSE: field \"%s\"; value \"%s\"", field.c_str(), value.c_str()); + if (field == "event") { // Set the event type buffer to field value. event_.type = std::move(value); diff --git a/toolcall/mcp_transport.h b/toolcall/mcp_transport.h index 5420bd6ff926f..bee13a6b63e70 100644 --- a/toolcall/mcp_transport.h +++ b/toolcall/mcp_transport.h @@ -20,6 +20,15 @@ namespace toolcall vec.push_back(std::move(callback)); } + template + void unsubscribe(callback callback) { + auto& vec = std::get>>(subscribers_); + auto found = std::find(vec.begin(), vec.end(), callback); + if (found != vec.end()) { + vec.erase(found); + } + } + template void notify(const T & message) const { const auto& vec = std::get>>(subscribers_); @@ -49,8 +58,8 @@ namespace toolcall mcp::initialize_response, mcp::initialized_notification, mcp::tools_list_request, - mcp::tools_list_response> - { + mcp::tools_list_response, + mcp::tools_list_changed_notification> { public: virtual ~mcp_transport() = default; virtual void start() = 0; diff --git a/toolcall/toolcall-handler.h b/toolcall/toolcall-handler.h index 46d6bec5f28a3..c3e97ae24a69a 100644 --- a/toolcall/toolcall-handler.h +++ b/toolcall/toolcall-handler.h @@ -5,6 +5,8 @@ #include #include #include +#include +#include namespace toolcall { @@ -21,11 +23,16 @@ namespace toolcall handler(std::unique_ptr impl) : impl_(std::move(impl)) {} - std::string tool_list(); action call(const std::string & request, std::string & response); + + std::string tool_list(); + bool tool_list_dirty() const; + const std::string & tool_choice() const; action last_action() const; + void initialize(); + private: std::unique_ptr impl_; action last_action_; @@ -36,16 +43,25 @@ namespace toolcall class handler_impl { public: handler_impl(std::string tool_choice) - : tool_choice_(std::move(tool_choice)) {} + : tool_choice_(std::move(tool_choice)), tool_list_dirty_(true) {} virtual ~handler_impl() = default; + virtual std::string tool_list() = 0; + + virtual bool tool_list_dirty() const { + return tool_list_dirty_; + } + virtual action call(const std::string & request, std::string & response) = 0; const std::string & tool_choice() const { return tool_choice_; } + virtual void initialize() {} + protected: std::string tool_choice_; + bool tool_list_dirty_; }; class loopback_impl : public handler_impl { @@ -54,6 +70,7 @@ namespace toolcall : handler_impl(tool_choice), tools_(std::move(tools)) {} virtual std::string tool_list() override { + tool_list_dirty_ = false; return tools_; } @@ -75,7 +92,13 @@ namespace toolcall virtual std::string tool_list() override; virtual action call(const std::string & request, std::string & response) override; + virtual void initialize() override; + private: std::unique_ptr transport_; + std::string tools_; + std::mutex tools_mutex_; + std::condition_variable tools_populating_; + int next_id_; }; } From 5d6a0580ff91df255b7b411dce12b58b564561b2 Mon Sep 17 00:00:00 2001 From: Mason M Date: Fri, 21 Feb 2025 14:05:42 -0400 Subject: [PATCH 33/69] Store callbacks in map --- toolcall/handler.cpp | 10 +++++----- toolcall/mcp_transport.h | 35 ++++++++++++++++++++--------------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/toolcall/handler.cpp b/toolcall/handler.cpp index f2ed6fab81509..e63e550625ff2 100644 --- a/toolcall/handler.cpp +++ b/toolcall/handler.cpp @@ -100,13 +100,13 @@ void toolcall::mcp_impl::initialize() { tools_populating_.notify_one(); }; - transport_->subscribe(set_caps); + transport_->subscribe("set_caps", set_caps); mcp::initialize_request req(next_id_++); transport_->send(req.toJson()); tools_populating_.wait_for(lock, std::chrono::seconds(15)); - transport_->unsubscribe(set_caps); + transport_->unsubscribe("set_caps"); on_list_changed update_dirty = [this] (const mcp::tools_list_changed_notification &) { tool_list_dirty_ = true; @@ -117,7 +117,7 @@ void toolcall::mcp_impl::initialize() { if (cap.name == "tools") { has_tools = true; if (cap.listChanged) { - transport_->subscribe(update_dirty); + transport_->subscribe("update_dirty", update_dirty); } break; } @@ -153,13 +153,13 @@ std::string toolcall::mcp_impl::tool_list() { tools_populating_.notify_one(); }; - transport_->subscribe(set_tools); + transport_->subscribe("set_tools", set_tools); mcp::tools_list_request req(std::to_string(next_id_++)); transport_->send(req.toJson()); tools_populating_.wait_for(lock, std::chrono::seconds(15)); - transport_->unsubscribe(set_tools); + transport_->unsubscribe("set_tools"); tools_ = tools_list_to_oai_json(tools); } diff --git a/toolcall/mcp_transport.h b/toolcall/mcp_transport.h index bee13a6b63e70..708b866ae27e5 100644 --- a/toolcall/mcp_transport.h +++ b/toolcall/mcp_transport.h @@ -3,7 +3,7 @@ #include "mcp_messages.h" #include #include -#include +#include namespace toolcall { @@ -13,27 +13,32 @@ namespace toolcall template class mcp_transport_t { public: - template - void subscribe(callback callback) { - auto& vec = std::get>>(subscribers_); - vec.push_back(std::move(callback)); + void subscribe(std::string key, callback callback) { + auto& map = + std::get>>( + subscribers_); + + map.insert({key, callback}); } template - void unsubscribe(callback callback) { - auto& vec = std::get>>(subscribers_); - auto found = std::find(vec.begin(), vec.end(), callback); - if (found != vec.end()) { - vec.erase(found); - } + void unsubscribe(std::string key) { + auto& map = + std::get>>( + subscribers_); + + map.erase(key); } template void notify(const T & message) const { - const auto& vec = std::get>>(subscribers_); - for (const auto& callback : vec) { - callback(message); + const auto& map = + std::get>>( + subscribers_); + + for (const auto & pair : map) { + pair.second(message); } } @@ -50,7 +55,7 @@ namespace toolcall } private: - std::tuple>...> subscribers_; + std::tuple>...> subscribers_; }; class mcp_transport : public mcp_transport_t Date: Fri, 21 Feb 2025 14:16:09 -0400 Subject: [PATCH 34/69] No need to explicitly convert int to string --- toolcall/handler.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/toolcall/handler.cpp b/toolcall/handler.cpp index e63e550625ff2..f5950d2b39fe9 100644 --- a/toolcall/handler.cpp +++ b/toolcall/handler.cpp @@ -144,7 +144,7 @@ std::string toolcall::mcp_impl::tool_list() { tools.insert(tools.end(), resp.tools().begin(), resp.tools().end()); auto cursor = resp.next_cursor(); if (! cursor.empty()) { - mcp::tools_list_request req(std::to_string(next_id_++), cursor); + mcp::tools_list_request req(next_id_++, cursor); transport_->send(req.toJson()); return; } @@ -155,7 +155,7 @@ std::string toolcall::mcp_impl::tool_list() { transport_->subscribe("set_tools", set_tools); - mcp::tools_list_request req(std::to_string(next_id_++)); + mcp::tools_list_request req(next_id_++); transport_->send(req.toJson()); tools_populating_.wait_for(lock, std::chrono::seconds(15)); From ba57885b32ffdffc076fa76124f746f969663329 Mon Sep 17 00:00:00 2001 From: Mason M Date: Fri, 21 Feb 2025 14:25:07 -0400 Subject: [PATCH 35/69] Initialize tc_handler --- examples/main/main.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 1b5fb41f8a929..582a8ab474bff 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -329,6 +329,9 @@ int main(int argc, char ** argv) { #ifdef LLAMA_USE_TOOLCALL auto tc_handler = toolcall::create_handler(params.jinja_tools); + if (tc_handler) { + tc_handler->initialize(); + } chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get(), vocab, tc_handler); #else chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get()); From 88bace32c1085db66ae28c4572088e6cf60b920c Mon Sep 17 00:00:00 2001 From: Mason M Date: Fri, 21 Feb 2025 18:32:35 -0400 Subject: [PATCH 36/69] Impl. tools_list_to_oai_json --- toolcall/handler.cpp | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/toolcall/handler.cpp b/toolcall/handler.cpp index f5950d2b39fe9..a027bb4595a29 100644 --- a/toolcall/handler.cpp +++ b/toolcall/handler.cpp @@ -128,7 +128,33 @@ void toolcall::mcp_impl::initialize() { } static std::string tools_list_to_oai_json(const mcp::tools_list & tools) { - return "[]"; // TODO + json tool_list = json::array(); + for (const auto & tool : tools) { + json props; + for (const auto & param : tool.params) { + props[param.name]["type"] = param.type; + props[param.name]["description"] = param.description; + } + json required = json::array(); + for (const auto & name : tool.required_params) { + required.push_back(name); + } + tool_list.push_back({ + {"type", "function"}, + {"function", { + {"name", tool.tool_name}, + {"description", tool.tool_description}, + {"parameters", { + {"type", "object"}, + {"properties", props} + } + }, + {"required", required} + } + } + }); + } + return tool_list; } std::string toolcall::mcp_impl::tool_list() { From b2c340dae677197eed21daa536372a4ffa1773c1 Mon Sep 17 00:00:00 2001 From: Mason M Date: Fri, 21 Feb 2025 18:40:11 -0400 Subject: [PATCH 37/69] Remove mcp+ URI prefix --- toolcall/params.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/toolcall/params.cpp b/toolcall/params.cpp index 473e6be5c4017..3507e523cbeae 100644 --- a/toolcall/params.cpp +++ b/toolcall/params.cpp @@ -18,7 +18,7 @@ toolcall::params::params(std::string tools, std::string choice) { void toolcall::params::tools(std::string tools) { try { if (! tools.empty()) { - if (starts_with(tools, "mcp+http")) { + if (starts_with(tools, "http")) { #ifndef LLAMA_USE_CURL throw std::invalid_argument( "Model Context Protocol (MCP) only works when llama.cpp is compiled with libcurl"); From 7d58e9bf83123162ddb5b37a21c8cc418f61e739 Mon Sep 17 00:00:00 2001 From: Mason M Date: Sat, 22 Feb 2025 14:16:34 -0400 Subject: [PATCH 38/69] WIP: fixing SSE issues --- toolcall/handler.cpp | 13 ++++++++----- toolcall/mcp_sse_transport.cpp | 3 ++- toolcall/mcp_transport.h | 26 +++++++++++++------------- toolcall/params.cpp | 3 +-- 4 files changed, 24 insertions(+), 21 deletions(-) diff --git a/toolcall/handler.cpp b/toolcall/handler.cpp index a027bb4595a29..d76eccdcae001 100644 --- a/toolcall/handler.cpp +++ b/toolcall/handler.cpp @@ -93,6 +93,7 @@ void toolcall::mcp_impl::initialize() { transport_->start(); + bool caps_received = false; mcp::capabilities caps; on_response set_caps = [this, &caps] (const mcp::initialize_response & resp) { std::unique_lock lock(tools_mutex_); @@ -103,9 +104,9 @@ void toolcall::mcp_impl::initialize() { transport_->subscribe("set_caps", set_caps); mcp::initialize_request req(next_id_++); - transport_->send(req.toJson()); + transport_->send(req); - tools_populating_.wait_for(lock, std::chrono::seconds(15)); + tools_populating_.wait_for(lock, std::chrono::seconds(15), [&caps_received] { return caps_received; }); transport_->unsubscribe("set_caps"); on_list_changed update_dirty = [this] (const mcp::tools_list_changed_notification &) { @@ -125,6 +126,8 @@ void toolcall::mcp_impl::initialize() { if (! has_tools) { throw std::runtime_error("MCP server does not support toolcalls!"); } + + transport_->send(mcp::initialized_notification()); } static std::string tools_list_to_oai_json(const mcp::tools_list & tools) { @@ -171,7 +174,7 @@ std::string toolcall::mcp_impl::tool_list() { auto cursor = resp.next_cursor(); if (! cursor.empty()) { mcp::tools_list_request req(next_id_++, cursor); - transport_->send(req.toJson()); + transport_->send(req); return; } tool_list_dirty_ = false; @@ -182,9 +185,9 @@ std::string toolcall::mcp_impl::tool_list() { transport_->subscribe("set_tools", set_tools); mcp::tools_list_request req(next_id_++); - transport_->send(req.toJson()); + transport_->send(req); - tools_populating_.wait_for(lock, std::chrono::seconds(15)); + tools_populating_.wait_for(lock, std::chrono::seconds(15), [this] { return ! tool_list_dirty_; }); transport_->unsubscribe("set_tools"); tools_ = tools_list_to_oai_json(tools); diff --git a/toolcall/mcp_sse_transport.cpp b/toolcall/mcp_sse_transport.cpp index 4deda675409e8..d3e4c0c91fcfc 100644 --- a/toolcall/mcp_sse_transport.cpp +++ b/toolcall/mcp_sse_transport.cpp @@ -35,7 +35,8 @@ void toolcall::mcp_sse_transport::start() { std::unique_lock lock(initializing_mutex_); sse_thread_ = std::thread(&toolcall::mcp_sse_transport::sse_run, this); - initializing_.wait(lock); + initializing_.wait_for( + lock, std::chrono::seconds(15), [this] { return endpoint_ != nullptr; }); if (endpoint_ == nullptr) { running_ = false; diff --git a/toolcall/mcp_transport.h b/toolcall/mcp_transport.h index 708b866ae27e5..434fff70f074f 100644 --- a/toolcall/mcp_transport.h +++ b/toolcall/mcp_transport.h @@ -10,7 +10,7 @@ namespace toolcall template using callback = std::function; - template + template class mcp_transport_t { public: template @@ -49,24 +49,24 @@ namespace toolcall } } - template - bool send(const T & message) { - return static_cast(this)->send(message.toJson()); - } - private: std::tuple>...> subscribers_; }; - class mcp_transport : public mcp_transport_t { + class mcp_transport : public mcp_transport_t { public: virtual ~mcp_transport() = default; + + template + bool send(const T & message) { + return send(std::string(message.toJson())); + } + virtual void start() = 0; virtual void stop() = 0; virtual bool send(const std::string & request_json) = 0; diff --git a/toolcall/params.cpp b/toolcall/params.cpp index 3507e523cbeae..1ad64149aed68 100644 --- a/toolcall/params.cpp +++ b/toolcall/params.cpp @@ -29,8 +29,7 @@ void toolcall::params::tools(std::string tools) { json j = json::parse(tools); // Just for early validation if (! j.is_array()) { throw std::invalid_argument( - "tools must be a URL of the form \"mcp+http(s)://hostname[:port]/\"" - ", or a valid JSON array containing tool definitions"); + "tools must be a valid URL or a JSON array containing tool definitions"); } has_uri_ = false; } From ea4cc2fa8c36d3c19bf0a17ca0f1575b654507f9 Mon Sep 17 00:00:00 2001 From: Mason M Date: Sat, 22 Feb 2025 19:00:09 -0400 Subject: [PATCH 39/69] Add timeout to initialize routine --- toolcall/mcp_sse_transport.cpp | 26 +++++++++++++++++++++----- toolcall/mcp_sse_transport.h | 3 +++ toolcall/mcp_transport.h | 14 +++++++------- 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/toolcall/mcp_sse_transport.cpp b/toolcall/mcp_sse_transport.cpp index d3e4c0c91fcfc..b9dd08ae2fe13 100644 --- a/toolcall/mcp_sse_transport.cpp +++ b/toolcall/mcp_sse_transport.cpp @@ -3,6 +3,9 @@ #include #include +const int toolcall::mcp_sse_transport::EndpointReceivedTimoutSeconds = 5; +const int toolcall::mcp_sse_transport::StartTimeoutSeconds = 8; + toolcall::mcp_sse_transport::~mcp_sse_transport() { if (endpoint_headers_) { curl_slist_free_all(endpoint_headers_); @@ -36,7 +39,7 @@ void toolcall::mcp_sse_transport::start() { std::unique_lock lock(initializing_mutex_); sse_thread_ = std::thread(&toolcall::mcp_sse_transport::sse_run, this); initializing_.wait_for( - lock, std::chrono::seconds(15), [this] { return endpoint_ != nullptr; }); + lock, std::chrono::seconds(StartTimeoutSeconds), [this] { return endpoint_ != nullptr; }); if (endpoint_ == nullptr) { running_ = false; @@ -184,7 +187,10 @@ size_t toolcall::mcp_sse_transport::sse_read(const char * data, size_t len) { } void toolcall::mcp_sse_transport::sse_run() { + using namespace std::chrono; + std::unique_lock lock(initializing_mutex_); + char errbuf[CURL_ERROR_SIZE]; size_t errlen; CURLMcode mcode; @@ -194,6 +200,7 @@ void toolcall::mcp_sse_transport::sse_run() { CURLM * async_handle = nullptr; struct curl_slist * headers = nullptr; CURL * sse = nullptr; + steady_clock::time_point start; sse = curl_easy_init(); if (! sse) { @@ -217,8 +224,9 @@ void toolcall::mcp_sse_transport::sse_run() { } curl_multi_add_handle(async_handle, sse); + start = steady_clock::now(); do { - std::this_thread::sleep_for(std::chrono::milliseconds(50)); + std::this_thread::sleep_for(milliseconds(50)); mcode = curl_multi_perform(async_handle, &num_handles); if (mcode != CURLM_OK) { @@ -240,9 +248,17 @@ void toolcall::mcp_sse_transport::sse_run() { } } } - if (endpoint_ && lock.owns_lock()) { // TODO: timeout if endpoint not received - lock.unlock(); - initializing_.notify_one(); + + if (endpoint_) { + if (lock.owns_lock()) { + lock.unlock(); + initializing_.notify_one(); + } + + } else { + if (steady_clock::now() - start >= seconds(EndpointReceivedTimoutSeconds)) { + running_ = false; + } } } while (running_); diff --git a/toolcall/mcp_sse_transport.h b/toolcall/mcp_sse_transport.h index 5d424818a734f..8d6680de8d205 100644 --- a/toolcall/mcp_sse_transport.h +++ b/toolcall/mcp_sse_transport.h @@ -21,6 +21,9 @@ namespace toolcall size_t sse_read(const char * data, size_t len); private: + static const int EndpointReceivedTimoutSeconds; + static const int StartTimeoutSeconds; + void sse_run(); void parse_field_value(std::string field, std::string value); void on_endpoint_event(); diff --git a/toolcall/mcp_transport.h b/toolcall/mcp_transport.h index 434fff70f074f..49f144320241f 100644 --- a/toolcall/mcp_transport.h +++ b/toolcall/mcp_transport.h @@ -11,7 +11,7 @@ namespace toolcall using callback = std::function; template - class mcp_transport_t { + class mcp_message_observer { public: template void subscribe(std::string key, callback callback) { @@ -53,12 +53,12 @@ namespace toolcall std::tuple>...> subscribers_; }; - class mcp_transport : public mcp_transport_t { + class mcp_transport : public mcp_message_observer { public: virtual ~mcp_transport() = default; From 86f83f3947b9be16035dc9ff0953efbca113ddfa Mon Sep 17 00:00:00 2001 From: Mason M Date: Sat, 22 Feb 2025 19:15:57 -0400 Subject: [PATCH 40/69] Allow send routine to lock --- toolcall/mcp_sse_transport.cpp | 22 ++++++++++++++++------ toolcall/mcp_sse_transport.h | 4 ++-- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/toolcall/mcp_sse_transport.cpp b/toolcall/mcp_sse_transport.cpp index b9dd08ae2fe13..1a543e58de29c 100644 --- a/toolcall/mcp_sse_transport.cpp +++ b/toolcall/mcp_sse_transport.cpp @@ -26,8 +26,8 @@ toolcall::mcp_sse_transport::mcp_sse_transport(std::string server_uri) sse_buffer_(""), sse_cursor_(0), sse_last_id_(""), - initializing_mutex_(), - initializing_() + mutex_(), + cv_() { curl_global_init(CURL_GLOBAL_DEFAULT); } @@ -36,9 +36,9 @@ void toolcall::mcp_sse_transport::start() { if (running_) return; running_ = true; - std::unique_lock lock(initializing_mutex_); + std::unique_lock lock(mutex_); sse_thread_ = std::thread(&toolcall::mcp_sse_transport::sse_run, this); - initializing_.wait_for( + cv_.wait_for( lock, std::chrono::seconds(StartTimeoutSeconds), [this] { return endpoint_ != nullptr; }); if (endpoint_ == nullptr) { @@ -53,6 +53,7 @@ void toolcall::mcp_sse_transport::stop() { } bool toolcall::mcp_sse_transport::send(const std::string & request_json) { + std::lock_guard lock(mutex_); if (! running_ || endpoint_ == nullptr) { return false; } @@ -189,7 +190,7 @@ size_t toolcall::mcp_sse_transport::sse_read(const char * data, size_t len) { void toolcall::mcp_sse_transport::sse_run() { using namespace std::chrono; - std::unique_lock lock(initializing_mutex_); + std::unique_lock lock(mutex_); char errbuf[CURL_ERROR_SIZE]; size_t errlen; @@ -252,7 +253,7 @@ void toolcall::mcp_sse_transport::sse_run() { if (endpoint_) { if (lock.owns_lock()) { lock.unlock(); - initializing_.notify_one(); + cv_.notify_one(); } } else { @@ -274,4 +275,13 @@ void toolcall::mcp_sse_transport::sse_run() { if (sse) { curl_easy_cleanup(sse); } + + lock.lock(); // Wait for pending send calls to complete + + if (endpoint_headers_) { + curl_slist_free_all(endpoint_headers_); + } + if (endpoint_) { + curl_easy_cleanup(endpoint_); + } } diff --git a/toolcall/mcp_sse_transport.h b/toolcall/mcp_sse_transport.h index 8d6680de8d205..ebe40919a3f63 100644 --- a/toolcall/mcp_sse_transport.h +++ b/toolcall/mcp_sse_transport.h @@ -46,7 +46,7 @@ namespace toolcall size_t sse_cursor_; std::string sse_last_id_; - std::mutex initializing_mutex_; - std::condition_variable initializing_; + std::mutex mutex_; + std::condition_variable cv_; }; } From c07a45260ffebec1fcefbde46dcc65353a2dd1fd Mon Sep 17 00:00:00 2001 From: Mason M Date: Sat, 22 Feb 2025 20:00:52 -0400 Subject: [PATCH 41/69] Handle relative URI returned from SSE --- toolcall/mcp_sse_transport.cpp | 43 ++++++++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/toolcall/mcp_sse_transport.cpp b/toolcall/mcp_sse_transport.cpp index 1a543e58de29c..75a8912afa960 100644 --- a/toolcall/mcp_sse_transport.cpp +++ b/toolcall/mcp_sse_transport.cpp @@ -2,10 +2,16 @@ #include "mcp_sse_transport.h" #include #include +#include const int toolcall::mcp_sse_transport::EndpointReceivedTimoutSeconds = 5; const int toolcall::mcp_sse_transport::StartTimeoutSeconds = 8; +static bool starts_with(const std::string & str, const std::string & prefix) { + return str.size() >= prefix.size() + && str.compare(0, prefix.size(), prefix) == 0; +} + toolcall::mcp_sse_transport::~mcp_sse_transport() { if (endpoint_headers_) { curl_slist_free_all(endpoint_headers_); @@ -110,7 +116,17 @@ void toolcall::mcp_sse_transport::parse_field_value(std::string field, std::stri } } +static std::string uri_base (const std::string & uri) { + std::regex uri_regex(R"(^(https?:\/\/[^\/?#]+))"); + std::smatch match; + if (std::regex_search(uri, match, uri_regex)) { + return match[1]; + } + return uri; +} + void toolcall::mcp_sse_transport::on_endpoint_event() { + LOG_DBG("on_endpoint_event"); endpoint_ = curl_easy_init(); if (! endpoint_) { LOG_ERR("SSE: Failed to create endpoint handle"); @@ -118,7 +134,20 @@ void toolcall::mcp_sse_transport::on_endpoint_event() { return; } - curl_easy_setopt(endpoint_, CURLOPT_URL, event_.data.c_str()); + std::string endpoint_uri; + bool is_absolute = starts_with(event_.data, "http"); + if (is_absolute) { + endpoint_uri = event_.data; + + } else { + auto endpoint_uri = uri_base(server_uri_); + if (event_.data[0] != '/') { + endpoint_uri += '/'; + } + endpoint_uri += event_.data; + } + + curl_easy_setopt(endpoint_, CURLOPT_URL, endpoint_uri.c_str()); endpoint_headers_ = curl_slist_append(endpoint_headers_, "Content-Type: application/json"); @@ -147,7 +176,9 @@ size_t toolcall::mcp_sse_transport::sse_read(const char * data, size_t len) { std::string line(sse_buffer_.begin(), last); if (line.empty()) { // Dispatch event if (event_.type == "endpoint") { - on_endpoint_event(); + if (! endpoint_) { + on_endpoint_event(); + } } else if(event_.type == "message") { on_message_event(); @@ -164,9 +195,10 @@ size_t toolcall::mcp_sse_transport::sse_read(const char * data, size_t len) { auto sep_index = line.find(':'); if (sep_index != std::string::npos) { auto sep_i = line.begin() + sep_index; - + auto val_i = sep_i + (*(sep_i + 1) == ' ' ? 2 : 1); // If value starts with a U+0020 SPACE + // character, remove it from value. std::string field (line.begin(), sep_i); - std::string value (sep_i + 1, line.end()); + std::string value (val_i, line.end()); parse_field_value(std::move(field), std::move(value)); } @@ -276,7 +308,8 @@ void toolcall::mcp_sse_transport::sse_run() { curl_easy_cleanup(sse); } - lock.lock(); // Wait for pending send calls to complete + if (! lock.owns_lock()) + lock.lock(); // Wait for pending send calls to complete if (endpoint_headers_) { curl_slist_free_all(endpoint_headers_); From 3a11fa252235be02b9b73bd799443d10a6e9ef8b Mon Sep 17 00:00:00 2001 From: Mason M Date: Sun, 23 Feb 2025 11:18:21 -0400 Subject: [PATCH 42/69] Fix CRLF case --- toolcall/mcp_sse_transport.cpp | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/toolcall/mcp_sse_transport.cpp b/toolcall/mcp_sse_transport.cpp index 75a8912afa960..f5a7e097edbe9 100644 --- a/toolcall/mcp_sse_transport.cpp +++ b/toolcall/mcp_sse_transport.cpp @@ -27,7 +27,7 @@ toolcall::mcp_sse_transport::mcp_sse_transport(std::string server_uri) sse_thread_(), endpoint_(nullptr), endpoint_headers_(nullptr), - endpoint_errbuf_(CURL_ERROR_SIZE), + endpoint_errbuf_(CURL_ERROR_SIZE, '\0'), event_{"", "", ""}, sse_buffer_(""), sse_cursor_(0), @@ -169,8 +169,11 @@ void toolcall::mcp_sse_transport::on_message_event() { size_t toolcall::mcp_sse_transport::sse_read(const char * data, size_t len) { sse_buffer_.insert(sse_buffer_.end(), data, data + len); - for (; sse_cursor_ < sse_buffer_.length(); ++sse_cursor_) { - if (sse_buffer_[sse_cursor_] == '\r' || sse_buffer_[sse_cursor_] == '\n') { + while (sse_cursor_ < sse_buffer_.length()) { + bool last_was_cr = sse_buffer_[sse_cursor_] == '\r'; + bool last_was_lf = sse_buffer_[sse_cursor_] == '\n'; + + if (last_was_cr || last_was_lf) { auto last = sse_buffer_.begin() + sse_cursor_; std::string line(sse_buffer_.begin(), last); @@ -190,8 +193,10 @@ size_t toolcall::mcp_sse_transport::sse_read(const char * data, size_t len) { sse_last_id_ = event_.id; event_ = {"", "", ""}; - } else if(line[0] != ':') { // : denotes a comment - // Set field/value + } else if(line[0] != ':') { + // Comments begin with ":" and + // Field/Value pairs are delimited by ":" + auto sep_index = line.find(':'); if (sep_index != std::string::npos) { auto sep_i = line.begin() + sep_index; @@ -205,8 +210,8 @@ size_t toolcall::mcp_sse_transport::sse_read(const char * data, size_t len) { } if (last++ != sse_buffer_.end()) { // Consume line-end - if (*last == '\n') { - last ++; // In the CRLF case consume one more + if (last_was_cr && *last == '\n') { + last ++; } sse_buffer_ = std::string(last, sse_buffer_.end()); @@ -214,6 +219,9 @@ size_t toolcall::mcp_sse_transport::sse_read(const char * data, size_t len) { sse_buffer_.clear(); } sse_cursor_ = 0; // Prepare to scan for next line-end + + } else { + sse_cursor_ ++; } } return len; From 3418d3784ff34778749380427673c0f49f3f356a Mon Sep 17 00:00:00 2001 From: Mason M Date: Sun, 23 Feb 2025 11:45:15 -0400 Subject: [PATCH 43/69] Strip trailing NL from endpoint event URI --- toolcall/mcp_sse_transport.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/toolcall/mcp_sse_transport.cpp b/toolcall/mcp_sse_transport.cpp index f5a7e097edbe9..ab28c6dae4d93 100644 --- a/toolcall/mcp_sse_transport.cpp +++ b/toolcall/mcp_sse_transport.cpp @@ -82,7 +82,7 @@ static size_t sse_callback(char * data, size_t size, size_t nmemb, void * client } void toolcall::mcp_sse_transport::parse_field_value(std::string field, std::string value) { - LOG_DBG("SSE: field \"%s\"; value \"%s\"", field.c_str(), value.c_str()); + LOG_DBG("SSE: field \"%s\"; value \"%s\"\n", field.c_str(), value.c_str()); if (field == "event") { // Set the event type buffer to field value. @@ -126,7 +126,9 @@ static std::string uri_base (const std::string & uri) { } void toolcall::mcp_sse_transport::on_endpoint_event() { - LOG_DBG("on_endpoint_event"); + LOG_DBG("on_endpoint_event\n"); + event_.data.erase(event_.data.end() - 1); // Event data has trailing newline that will impact URI + endpoint_ = curl_easy_init(); if (! endpoint_) { LOG_ERR("SSE: Failed to create endpoint handle"); @@ -140,13 +142,14 @@ void toolcall::mcp_sse_transport::on_endpoint_event() { endpoint_uri = event_.data; } else { - auto endpoint_uri = uri_base(server_uri_); + endpoint_uri = uri_base(server_uri_); if (event_.data[0] != '/') { endpoint_uri += '/'; } endpoint_uri += event_.data; } + LOG_INF("SSE: using endpoint \"%s\"\n", endpoint_uri.c_str()); curl_easy_setopt(endpoint_, CURLOPT_URL, endpoint_uri.c_str()); endpoint_headers_ = From 7d6c29b2e0650e4897cc94826c2964a2b5ec8df4 Mon Sep 17 00:00:00 2001 From: Mason M Date: Sun, 23 Feb 2025 12:05:12 -0400 Subject: [PATCH 44/69] Explicitly create empty capability object --- toolcall/mcp_messages.cpp | 7 +++---- toolcall/mcp_transport.h | 3 ++- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/toolcall/mcp_messages.cpp b/toolcall/mcp_messages.cpp index 41d6b7ad1f6a9..c8398324d2602 100644 --- a/toolcall/mcp_messages.cpp +++ b/toolcall/mcp_messages.cpp @@ -66,20 +66,19 @@ void mcp::initialize_request::refreshParams() { params["protocolVersion"] = protoVersion(); params["clientInfo"]["name"] = name(); params["clientInfo"]["version"] = version(); - params["capabilities"] = {}; + json capabilities = json::object(); for (auto cap = caps_.cbegin(); cap != caps_.cend(); ++cap) { json cap_json; - if (cap->subscribe) { cap_json["subscribe"] = true; } if (cap->listChanged) { cap_json["listChanged"] = true; } - - params["capabilities"][cap->name] = cap_json; + capabilities[cap->name] = cap_json; } + params["capabilities"] = capabilities; this->params(std::move(params)); } diff --git a/toolcall/mcp_transport.h b/toolcall/mcp_transport.h index 49f144320241f..084c6bef28ef5 100644 --- a/toolcall/mcp_transport.h +++ b/toolcall/mcp_transport.h @@ -64,7 +64,8 @@ namespace toolcall template bool send(const T & message) { - return send(std::string(message.toJson())); + nlohmann::json json = message.toJson(); + return send(json.dump(-1)); } virtual void start() = 0; From 40156fff7cb031641a6fb4c4eef018b996ee72ea Mon Sep 17 00:00:00 2001 From: Mason M Date: Sun, 23 Feb 2025 18:50:04 -0400 Subject: [PATCH 45/69] Add tools_list_response::fromJson --- toolcall/handler.cpp | 3 ++- toolcall/mcp_messages.cpp | 42 ++++++++++++++++++++++++++++++++++----- toolcall/mcp_messages.h | 2 ++ 3 files changed, 41 insertions(+), 6 deletions(-) diff --git a/toolcall/handler.cpp b/toolcall/handler.cpp index d76eccdcae001..50e38a2ed5761 100644 --- a/toolcall/handler.cpp +++ b/toolcall/handler.cpp @@ -95,9 +95,10 @@ void toolcall::mcp_impl::initialize() { bool caps_received = false; mcp::capabilities caps; - on_response set_caps = [this, &caps] (const mcp::initialize_response & resp) { + on_response set_caps = [this, &caps, &caps_received] (const mcp::initialize_response & resp) { std::unique_lock lock(tools_mutex_); caps = resp.capabilities(); + caps_received = true; tools_populating_.notify_one(); }; diff --git a/toolcall/mcp_messages.cpp b/toolcall/mcp_messages.cpp index c8398324d2602..3ad196e53ba1c 100644 --- a/toolcall/mcp_messages.cpp +++ b/toolcall/mcp_messages.cpp @@ -1,5 +1,6 @@ #include "mcp_messages.h" #include +#include using json = nlohmann::json; @@ -106,20 +107,19 @@ void mcp::initialize_response::refreshResult() { result["protocolVersion"] = protoVersion(); result["serverInfo"]["name"] = name(); result["serverInfo"]["version"] = version(); - result["capabilities"] = {}; + json capabilities = json::object(); for (auto cap = caps_.cbegin(); cap != caps_.cend(); ++cap) { json cap_json; - if (cap->subscribe) { cap_json["subscribe"] = true; } if (cap->listChanged) { cap_json["listChanged"] = true; } - - result["capabilities"][cap->name] = cap_json; + capabilities[cap->name] = cap_json; } + result["capabilities"] = capabilities; this->result(std::move(result)); } @@ -256,8 +256,36 @@ void mcp::tools_list_response::refreshResult() { this->result(result); } +mcp::tools_list_response mcp::tools_list_response::fromJson(const nlohmann::json & j) { + mcp::tools_list tools; + for (const auto & t : j["result"]["tools"]) { + mcp::tool tool; + tool.tool_name = t["name"]; + tool.tool_description = t["description"]; + for (const auto & [key, value] : t["inputSchema"]["properties"].items()) { + mcp::tool::param param; + param.name = key; + param.type = value["type"]; + param.description = value["description"]; + tool.params.push_back(param); + } + if (t["inputSchema"].contains("required") && t["inputSchema"]["required"].is_array()) { + for (const auto & required : t["inputSchema"]["required"]) { + tool.required_params.push_back(required); + } + } + tools.push_back(std::move(tool)); + } + std::string next_cursor = j["result"].value("nextCursor", ""); + return tools_list_response(j["id"], std::move(tools), next_cursor); +} + static bool has_initialized_response(const nlohmann::json & data) { - return data["result"].contains("serverInfo"); + return data["result"].contains("capabilities"); +} + +static bool has_tools_list_response(const nlohmann::json & data) { + return data["result"].contains("tools"); } bool mcp::create_message(const std::string & data, mcp::message_variant & message) { @@ -266,7 +294,11 @@ bool mcp::create_message(const std::string & data, mcp::message_variant & messag if (has_initialized_response(j)) { message = mcp::initialize_response::fromJson(j); + } else if (has_tools_list_response(j)) { + message = mcp::tools_list_response::fromJson(j); + } else { + message = std::monostate(); return false; } return true; diff --git a/toolcall/mcp_messages.h b/toolcall/mcp_messages.h index 31d62e3469e19..b5fef1783af86 100644 --- a/toolcall/mcp_messages.h +++ b/toolcall/mcp_messages.h @@ -208,6 +208,8 @@ namespace mcp void next_cursor(std::string next_cursor); const std::string & next_cursor() const { return next_cursor_; } + static tools_list_response fromJson(const nlohmann::json & j); + private: void refreshResult(); tools_list tools_; From 8a3497b4de05e2c9ffa2ca0191e1bf6120401e9f Mon Sep 17 00:00:00 2001 From: Mason M Date: Sun, 23 Feb 2025 19:28:47 -0400 Subject: [PATCH 46/69] Convert tool list string --- toolcall/handler.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/toolcall/handler.cpp b/toolcall/handler.cpp index 50e38a2ed5761..1b8d2deaec765 100644 --- a/toolcall/handler.cpp +++ b/toolcall/handler.cpp @@ -158,7 +158,7 @@ static std::string tools_list_to_oai_json(const mcp::tools_list & tools) { } }); } - return tool_list; + return tool_list.dump(-1); } std::string toolcall::mcp_impl::tool_list() { From 1209b953848f71e3e49b759346151d9c35986555 Mon Sep 17 00:00:00 2001 From: Mason M Date: Sun, 23 Feb 2025 20:00:58 -0400 Subject: [PATCH 47/69] Only invoke toolcall with valid JSON --- examples/main/main.cpp | 14 ++++++++++---- toolcall/handler.cpp | 29 ++++++++++++++--------------- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 582a8ab474bff..7c23cd00c223b 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -5,6 +5,7 @@ #include "sampling.h" #include "llama.h" #include "chat.h" +#include #include #include @@ -126,12 +127,17 @@ class chat_formatter { LOG_DBG("formatted: '%s'\n", formatted.c_str()); #ifdef LLAMA_USE_TOOLCALL - if (params_.use_jinja) { + if (params_.use_jinja && use_toolcalls) { common_chat_grammar_to_sampler(&cparams, vocab_, ¶ms_.sampling); if (tc_handler_ != nullptr) { - std::string response; - tc_handler_->call(formatted, response); - return std::string(response); + if (nlohmann::json::accept(formatted)) { // May need a better way to ensure + std::string response; // this is intended for a tool-call. + tc_handler_->call(formatted, response); + return std::string(response); + + } else { + return formatted; + } } } #endif diff --git a/toolcall/handler.cpp b/toolcall/handler.cpp index 1b8d2deaec765..1285f8f06b525 100644 --- a/toolcall/handler.cpp +++ b/toolcall/handler.cpp @@ -134,30 +134,29 @@ void toolcall::mcp_impl::initialize() { static std::string tools_list_to_oai_json(const mcp::tools_list & tools) { json tool_list = json::array(); for (const auto & tool : tools) { - json props; + json t = json::object(); + + t["type"] = "function"; + t["function"]["name"] = tool.tool_name; + t["function"]["description"] = tool.tool_description; + + json props = json::object(); for (const auto & param : tool.params) { props[param.name]["type"] = param.type; props[param.name]["description"] = param.description; } + t["function"]["parameters"]["type"] = "object"; + t["function"]["parameters"]["properties"] = props; + json required = json::array(); for (const auto & name : tool.required_params) { required.push_back(name); } - tool_list.push_back({ - {"type", "function"}, - {"function", { - {"name", tool.tool_name}, - {"description", tool.tool_description}, - {"parameters", { - {"type", "object"}, - {"properties", props} - } - }, - {"required", required} - } - } - }); + t["function"]["required"] = required; + + tool_list.push_back(t); } + return tool_list.dump(-1); } From 52b98d8aab3d2228937ad5f12781b2bdb64285e0 Mon Sep 17 00:00:00 2001 From: Mason M Date: Mon, 24 Feb 2025 11:34:11 -0400 Subject: [PATCH 48/69] Add tool-call request/response types --- toolcall/mcp_messages.cpp | 76 +++++++++++++++++++++++++++++++++++++++ toolcall/mcp_messages.h | 52 ++++++++++++++++++++++++++- 2 files changed, 127 insertions(+), 1 deletion(-) diff --git a/toolcall/mcp_messages.cpp b/toolcall/mcp_messages.cpp index 3ad196e53ba1c..b49d3a2ff3ccd 100644 --- a/toolcall/mcp_messages.cpp +++ b/toolcall/mcp_messages.cpp @@ -280,6 +280,82 @@ mcp::tools_list_response mcp::tools_list_response::fromJson(const nlohmann::json return tools_list_response(j["id"], std::move(tools), next_cursor); } +mcp::tools_call_request::tools_call_request(nlohmann::json id, std::string name, tool_arg_list args) + : request(id, "tools/call"), name_(std::move(name)), args_(std::move(args)) +{ + refreshParams(); +} + +void mcp::tools_call_request::name(std::string name) { + name_ = std::move(name); + refreshParams(); +} + +void mcp::tools_call_request::args(mcp::tool_arg_list args) { + args_ = std::move(args); + refreshParams(); +} + +void mcp::tools_call_request::refreshParams() { + json params = json::object(); + params["name"] = name_; + if (! args_.empty()) { + json args = json::object(); + for (const auto & arg : args_) { + args[arg.name] = arg.value; + } + params["arguments"] = args; + } + this->params(params); +} + +mcp::tools_call_response::tools_call_response(nlohmann::json id, tool_result_list result, bool error) + : response(id), tool_result_(std::move(result)), error_(error) +{ + refreshResult(); +} + +void mcp::tools_call_response::tool_result(mcp::tool_result_list result) { + tool_result_ = std::move(result); + refreshResult(); +} + +void mcp::tools_call_response::tool_error(bool error) { + error_ = error; + refreshResult(); +} + +void mcp::tools_call_response::refreshResult() { + json result = json::object(); + result["isError"] = error_; + json content = json::array(); + for (const auto & res : tool_result_) { + json r; + r["type"] = res.type; + if (res.type == "text") { + r["text"] = res.value; + + } else if (res.type == "image" || res.type == "audio") { + r["data"] = res.value; + r["mimeType"] = res.mime_type.value(); // throws + + } else if (res.type == "resource") { + json rr; + rr["uri"] = res.uri.value(); // throws + rr["mimeType"] = res.mime_type.value(); //throws + rr["text"] = res.value; + + r["resource"] = rr; + + } else { + // throw + } + content.push_back(r); + } + result["content"] = content; + this->result(std::move(result)); +} + static bool has_initialized_response(const nlohmann::json & data) { return data["result"].contains("capabilities"); } diff --git a/toolcall/mcp_messages.h b/toolcall/mcp_messages.h index b5fef1783af86..0caa643e28bfe 100644 --- a/toolcall/mcp_messages.h +++ b/toolcall/mcp_messages.h @@ -222,6 +222,54 @@ namespace mcp : notification("notifications/tools/list_changed") {} }; + struct tool_arg { + std::string name; + std::string value; + }; + + using tool_arg_list = std::vector; + + class tools_call_request : public request { + public: + tools_call_request(nlohmann::json id, std::string name, tool_arg_list args = tool_arg_list()); + + void name(std::string name); + const std::string & name() const { return name_; } + + void args(tool_arg_list args); + const tool_arg_list args() const { return args_; } + + private: + void refreshParams(); + std::string name_; + tool_arg_list args_; + }; + + struct tool_result { + std::string type; // text, image, audio, or resource + std::string value; + std::optional mime_type; // Only for: image, audio, and resource + std::optional uri; // Only for: resource + }; + + using tool_result_list = std::vector; + + class tools_call_response : public response { + public: + tools_call_response(nlohmann::json id, tool_result_list result = tool_result_list(), bool error = false); + + void tool_result(tool_result_list result); + const tool_result_list & tool_result() const { return tool_result_; } + + void tool_error(bool error); + bool tool_error() const { return error_; } + + private: + void refreshResult(); + tool_result_list tool_result_; + bool error_; + }; + using message_variant = std::variant; + tools_list_changed_notification, + tools_call_request, + tools_call_response>; bool create_message(const std::string & data, message_variant & message); } From 606993d6dc5a66e88d9a3542047998a21a578788 Mon Sep 17 00:00:00 2001 From: Mason M Date: Mon, 24 Feb 2025 16:29:52 -0400 Subject: [PATCH 49/69] Move transport message-dispatch to base type --- toolcall/handler.cpp | 24 +++++----------- toolcall/mcp_messages.cpp | 33 ++++++---------------- toolcall/mcp_messages.h | 14 ++-------- toolcall/mcp_sse_transport.cpp | 10 ++++--- toolcall/mcp_transport.h | 51 +++++++++++++++++++++------------- 5 files changed, 56 insertions(+), 76 deletions(-) diff --git a/toolcall/handler.cpp b/toolcall/handler.cpp index 1285f8f06b525..250031a45b434 100644 --- a/toolcall/handler.cpp +++ b/toolcall/handler.cpp @@ -102,16 +102,12 @@ void toolcall::mcp_impl::initialize() { tools_populating_.notify_one(); }; - transport_->subscribe("set_caps", set_caps); - - mcp::initialize_request req(next_id_++); - transport_->send(req); - + transport_->send(mcp::initialize_request(next_id_++), set_caps); tools_populating_.wait_for(lock, std::chrono::seconds(15), [&caps_received] { return caps_received; }); - transport_->unsubscribe("set_caps"); - on_list_changed update_dirty = [this] (const mcp::tools_list_changed_notification &) { + on_list_changed update_dirty = [&update_dirty, this] (const mcp::tools_list_changed_notification &) { tool_list_dirty_ = true; + transport_->subscribe("notifications/tools/list_changed", update_dirty); }; bool has_tools = false; @@ -119,7 +115,7 @@ void toolcall::mcp_impl::initialize() { if (cap.name == "tools") { has_tools = true; if (cap.listChanged) { - transport_->subscribe("update_dirty", update_dirty); + transport_->subscribe("notifications/tools/list_changed", update_dirty); } break; } @@ -167,14 +163,13 @@ std::string toolcall::mcp_impl::tool_list() { std::unique_lock lock(tools_mutex_); mcp::tools_list tools; - on_response set_tools = [this, &tools] (const mcp::tools_list_response & resp) { + on_response set_tools = [this, &tools, &set_tools] (const mcp::tools_list_response & resp) { std::unique_lock lock(tools_mutex_); tools.insert(tools.end(), resp.tools().begin(), resp.tools().end()); auto cursor = resp.next_cursor(); if (! cursor.empty()) { - mcp::tools_list_request req(next_id_++, cursor); - transport_->send(req); + transport_->send(mcp::tools_list_request(next_id_++, cursor), set_tools); return; } tool_list_dirty_ = false; @@ -182,13 +177,8 @@ std::string toolcall::mcp_impl::tool_list() { tools_populating_.notify_one(); }; - transport_->subscribe("set_tools", set_tools); - - mcp::tools_list_request req(next_id_++); - transport_->send(req); - + transport_->send(mcp::tools_list_request(next_id_++), set_tools); tools_populating_.wait_for(lock, std::chrono::seconds(15), [this] { return ! tool_list_dirty_; }); - transport_->unsubscribe("set_tools"); tools_ = tools_list_to_oai_json(tools); } diff --git a/toolcall/mcp_messages.cpp b/toolcall/mcp_messages.cpp index b49d3a2ff3ccd..bc468b10dc520 100644 --- a/toolcall/mcp_messages.cpp +++ b/toolcall/mcp_messages.cpp @@ -1,6 +1,7 @@ #include "mcp_messages.h" #include #include +#include using json = nlohmann::json; @@ -280,6 +281,14 @@ mcp::tools_list_response mcp::tools_list_response::fromJson(const nlohmann::json return tools_list_response(j["id"], std::move(tools), next_cursor); } +mcp::tools_list_changed_notification mcp::tools_list_changed_notification::fromJson(const nlohmann::json & j) { + if (! (j.is_object() && j.contains("method") && + j["method"] == "notifications/tools/list_changed")) { + throw std::invalid_argument("Invalid tools_list_changed message"); + } + return tools_list_changed_notification(); +} + mcp::tools_call_request::tools_call_request(nlohmann::json id, std::string name, tool_arg_list args) : request(id, "tools/call"), name_(std::move(name)), args_(std::move(args)) { @@ -355,27 +364,3 @@ void mcp::tools_call_response::refreshResult() { result["content"] = content; this->result(std::move(result)); } - -static bool has_initialized_response(const nlohmann::json & data) { - return data["result"].contains("capabilities"); -} - -static bool has_tools_list_response(const nlohmann::json & data) { - return data["result"].contains("tools"); -} - -bool mcp::create_message(const std::string & data, mcp::message_variant & message) { - json j = json::parse(data); - - if (has_initialized_response(j)) { - message = mcp::initialize_response::fromJson(j); - - } else if (has_tools_list_response(j)) { - message = mcp::tools_list_response::fromJson(j); - - } else { - message = std::monostate(); - return false; - } - return true; -} diff --git a/toolcall/mcp_messages.h b/toolcall/mcp_messages.h index 0caa643e28bfe..2e32cdc8424b1 100644 --- a/toolcall/mcp_messages.h +++ b/toolcall/mcp_messages.h @@ -220,6 +220,8 @@ namespace mcp public: tools_list_changed_notification() : notification("notifications/tools/list_changed") {} + + static tools_list_changed_notification fromJson(const nlohmann::json & j); }; struct tool_arg { @@ -270,16 +272,4 @@ namespace mcp bool error_; }; - using message_variant = - std::variant; - - bool create_message(const std::string & data, message_variant & message); } diff --git a/toolcall/mcp_sse_transport.cpp b/toolcall/mcp_sse_transport.cpp index ab28c6dae4d93..3dfbffc49f780 100644 --- a/toolcall/mcp_sse_transport.cpp +++ b/toolcall/mcp_sse_transport.cpp @@ -162,10 +162,12 @@ void toolcall::mcp_sse_transport::on_endpoint_event() { } void toolcall::mcp_sse_transport::on_message_event() { - mcp::message_variant message; - if (mcp::create_message(event_.data, message)) { - notify_if(message); - notify_if(message); + try { + nlohmann::json message = nlohmann::json::parse(event_.data); + notify(message); + + } catch (const nlohmann::json::exception & err) { + LOG_WRN("SSE: Invalid message \"%s\" received: \"%s\"\n", event_.data.c_str(), err.what()); } } diff --git a/toolcall/mcp_transport.h b/toolcall/mcp_transport.h index 084c6bef28ef5..e7d0aeae6b346 100644 --- a/toolcall/mcp_transport.h +++ b/toolcall/mcp_transport.h @@ -31,39 +31,52 @@ namespace toolcall map.erase(key); } - template - void notify(const T & message) const { - const auto& map = - std::get>>( - subscribers_); + void notify(const nlohmann::json & message) { + std::string key; + if (message.contains("id")) { + key = message["id"].dump(); - for (const auto & pair : map) { - pair.second(message); - } - } + } else if (message.contains("method")) { + key = message["method"].dump(); - template - void notify_if(const mcp::message_variant & message) { - if (std::holds_alternative(message)) { - notify(std::get(message)); + } else { + return; } + std::apply([&key, &message, this](auto&... maps) { + (..., [&] { + auto it = maps.find(key); + if (it != maps.end()) { + using callback_type = decltype(it->second); + using T = typename std::decay::type; + + it->second(T::fromJson(message)); + maps.erase(it); + } + }()); + }, subscribers_); } private: std::tuple>...> subscribers_; }; - class mcp_transport : public mcp_message_observer { public: virtual ~mcp_transport() = default; - template - bool send(const T & message) { + template + bool send(const Req & message, callback on_response) { + if (message.id().has_value()) { + std::string id = message.id().value().dump(); + subscribe(id, on_response); + } + return send(message); + } + + template + bool send(const Req & message) { nlohmann::json json = message.toJson(); return send(json.dump(-1)); } From 67438a3d4b510a8d6e158b3f7a1767c2d61fa3c3 Mon Sep 17 00:00:00 2001 From: Mason M Date: Mon, 24 Feb 2025 18:25:46 -0400 Subject: [PATCH 50/69] Add tools_call_response fromJson --- toolcall/handler.cpp | 4 ++-- toolcall/mcp_messages.cpp | 34 ++++++++++++++++++++++++++++++---- toolcall/mcp_messages.h | 16 ++++++++++++---- toolcall/mcp_transport.h | 12 +++++++++++- 4 files changed, 55 insertions(+), 11 deletions(-) diff --git a/toolcall/handler.cpp b/toolcall/handler.cpp index 250031a45b434..1864dc19980c5 100644 --- a/toolcall/handler.cpp +++ b/toolcall/handler.cpp @@ -107,7 +107,7 @@ void toolcall::mcp_impl::initialize() { on_list_changed update_dirty = [&update_dirty, this] (const mcp::tools_list_changed_notification &) { tool_list_dirty_ = true; - transport_->subscribe("notifications/tools/list_changed", update_dirty); + transport_->subscribe(update_dirty); }; bool has_tools = false; @@ -115,7 +115,7 @@ void toolcall::mcp_impl::initialize() { if (cap.name == "tools") { has_tools = true; if (cap.listChanged) { - transport_->subscribe("notifications/tools/list_changed", update_dirty); + transport_->subscribe(update_dirty); } break; } diff --git a/toolcall/mcp_messages.cpp b/toolcall/mcp_messages.cpp index bc468b10dc520..3dc9c9a14759c 100644 --- a/toolcall/mcp_messages.cpp +++ b/toolcall/mcp_messages.cpp @@ -181,7 +181,7 @@ mcp::initialize_response mcp::initialize_response::fromJson(const nlohmann::json } mcp::tools_list_request::tools_list_request(std::optional id, std::string cursor) - : request(id, "tools/list"), + : request(id, Method), cursor_(std::move(cursor)) { refreshParams(); @@ -282,15 +282,14 @@ mcp::tools_list_response mcp::tools_list_response::fromJson(const nlohmann::json } mcp::tools_list_changed_notification mcp::tools_list_changed_notification::fromJson(const nlohmann::json & j) { - if (! (j.is_object() && j.contains("method") && - j["method"] == "notifications/tools/list_changed")) { + if (! (j.is_object() && j.contains("method") && j["method"] == Method)) { throw std::invalid_argument("Invalid tools_list_changed message"); } return tools_list_changed_notification(); } mcp::tools_call_request::tools_call_request(nlohmann::json id, std::string name, tool_arg_list args) - : request(id, "tools/call"), name_(std::move(name)), args_(std::move(args)) + : request(id, Method), name_(std::move(name)), args_(std::move(args)) { refreshParams(); } @@ -364,3 +363,30 @@ void mcp::tools_call_response::refreshResult() { result["content"] = content; this->result(std::move(result)); } + +mcp::tools_call_response mcp::tools_call_response::fromJson(const nlohmann::json & j) { + mcp::tool_result_list result_list; + for (const auto & content : j["result"]["content"]) { + mcp::tool_result result; + + result.type = content["type"]; + if (content["type"] == "text") { + result.value = content["text"]; + + } else if (content["type"] == "image" || content["type"] == "audio") { + result.value = content["data"]; + result.mime_type = content["mimeType"]; + + } else if (content["type"] == "resource") { + result.value = content["resource"]["text"]; + result.mime_type = content["resource"]["mimeType"]; + result.uri = content["resource"]["uri"]; + } + + result_list.push_back(std::move(result)); + } + + bool error = j["result"].value("isError", false); + + return mcp::tools_call_response(j["id"], std::move(result_list), error); +} diff --git a/toolcall/mcp_messages.h b/toolcall/mcp_messages.h index 2e32cdc8424b1..b11f17aa883fc 100644 --- a/toolcall/mcp_messages.h +++ b/toolcall/mcp_messages.h @@ -166,12 +166,15 @@ namespace mcp class initialized_notification : public notification { public: - initialized_notification() - : notification("notifications/initialized") {} + static inline const std::string Method = "notifications/initialized"; + + initialized_notification() : notification(Method) {} }; class tools_list_request : public request { public: + static inline const std::string Method = "tools/list"; + tools_list_request(std::optional id, std::string cursor = ""); void cursor(std::string cursor); @@ -218,8 +221,9 @@ namespace mcp class tools_list_changed_notification : public notification { public: - tools_list_changed_notification() - : notification("notifications/tools/list_changed") {} + static inline const std::string Method = "notifications/tools/list_changed"; + + tools_list_changed_notification() : notification(Method) {} static tools_list_changed_notification fromJson(const nlohmann::json & j); }; @@ -233,6 +237,8 @@ namespace mcp class tools_call_request : public request { public: + static inline const std::string Method = "tools/call"; + tools_call_request(nlohmann::json id, std::string name, tool_arg_list args = tool_arg_list()); void name(std::string name); @@ -266,6 +272,8 @@ namespace mcp void tool_error(bool error); bool tool_error() const { return error_; } + static tools_call_response fromJson(const nlohmann::json & j); + private: void refreshResult(); tool_result_list tool_result_; diff --git a/toolcall/mcp_transport.h b/toolcall/mcp_transport.h index e7d0aeae6b346..e4383434da541 100644 --- a/toolcall/mcp_transport.h +++ b/toolcall/mcp_transport.h @@ -22,6 +22,15 @@ namespace toolcall map.insert({key, callback}); } + template + void subscribe(callback callback) { + auto& map = + std::get>>( + subscribers_); + + map.insert({T::Method, callback}); + } + template void unsubscribe(std::string key) { auto& map = @@ -62,7 +71,8 @@ namespace toolcall class mcp_transport : public mcp_message_observer { + mcp::tools_list_changed_notification, + mcp::tools_call_response> { public: virtual ~mcp_transport() = default; From 7a23d06758fee519c696688e0381d2f5a97243a1 Mon Sep 17 00:00:00 2001 From: Mason M Date: Mon, 24 Feb 2025 19:19:24 -0400 Subject: [PATCH 51/69] Implement call routine --- toolcall/handler.cpp | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/toolcall/handler.cpp b/toolcall/handler.cpp index 1864dc19980c5..c5526fc00edf6 100644 --- a/toolcall/handler.cpp +++ b/toolcall/handler.cpp @@ -185,10 +185,35 @@ std::string toolcall::mcp_impl::tool_list() { return tools_; } -toolcall::action toolcall::mcp_impl::call(const std::string & /*request*/, std::string & /*response*/) { +static mcp::tools_call_request tools_call_request_from_local_json(nlohmann::json id, const std::string & local_json) { + nlohmann::json j = json::parse(local_json); + mcp::tool_arg_list args; + for (const auto & [key, val] : j["parameters"].items()) { + args.push_back({key, val.dump()}); + } + return mcp::tools_call_request(id, j["name"], args); +} + +static std::string tools_call_response_to_local_json(const mcp::tools_call_response & resp) { + return resp.toJson().dump(-1); // The AI will figure it out? +} + +toolcall::action toolcall::mcp_impl::call(const std::string & request, std::string & response) { + using on_response = toolcall::callback; + if (transport_ == nullptr) { return toolcall::DEFER; } - // Construct tool call and send to transport - return toolcall::ACCEPT; // TODO + std::unique_lock lock(tools_mutex_); + + response.clear(); + on_response set_response = [this, &response] (const mcp::tools_call_response & resp) { + std::unique_lock lock(tools_mutex_); + response = tools_call_response_to_local_json(resp); + tools_populating_.notify_one(); + }; + transport_->send(tools_call_request_from_local_json(next_id_++, request), set_response); + tools_populating_.wait_for(lock, std::chrono::seconds(15), [&response] { return ! response.empty(); }); + + return toolcall::ACCEPT; } From c2e531abf96df91019124301a90f19b06cc80260 Mon Sep 17 00:00:00 2001 From: Mason M Date: Mon, 24 Feb 2025 22:32:18 -0400 Subject: [PATCH 52/69] Fix whitespace --- toolcall/mcp_transport.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/toolcall/mcp_transport.h b/toolcall/mcp_transport.h index e4383434da541..f2a2fba24d880 100644 --- a/toolcall/mcp_transport.h +++ b/toolcall/mcp_transport.h @@ -22,14 +22,14 @@ namespace toolcall map.insert({key, callback}); } - template - void subscribe(callback callback) { - auto& map = + template + void subscribe(callback callback) { + auto& map = std::get>>( subscribers_); - map.insert({T::Method, callback}); - } + map.insert({T::Method, callback}); + } template void unsubscribe(std::string key) { From ce5c46ce46dc0f1033815e4c11491838a71772e7 Mon Sep 17 00:00:00 2001 From: Mason M Date: Mon, 24 Feb 2025 22:43:20 -0400 Subject: [PATCH 53/69] Preserver argument value --- toolcall/handler.cpp | 2 +- toolcall/mcp_messages.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/toolcall/handler.cpp b/toolcall/handler.cpp index c5526fc00edf6..4c3940929d7ab 100644 --- a/toolcall/handler.cpp +++ b/toolcall/handler.cpp @@ -189,7 +189,7 @@ static mcp::tools_call_request tools_call_request_from_local_json(nlohmann::json nlohmann::json j = json::parse(local_json); mcp::tool_arg_list args; for (const auto & [key, val] : j["parameters"].items()) { - args.push_back({key, val.dump()}); + args.push_back({key, val}); } return mcp::tools_call_request(id, j["name"], args); } diff --git a/toolcall/mcp_messages.h b/toolcall/mcp_messages.h index b11f17aa883fc..ad20b45fef5f9 100644 --- a/toolcall/mcp_messages.h +++ b/toolcall/mcp_messages.h @@ -230,7 +230,7 @@ namespace mcp struct tool_arg { std::string name; - std::string value; + nlohmann::json value; }; using tool_arg_list = std::vector; From 850e043382c0f59abfc941dfe64dd5e02943979f Mon Sep 17 00:00:00 2001 From: Mason M Date: Tue, 25 Feb 2025 07:31:17 -0400 Subject: [PATCH 54/69] Refactor tool/call response --- examples/main/main.cpp | 46 +++++++++++++++---------------------- toolcall/handler.cpp | 31 +++++++++++++------------ toolcall/toolcall-handler.h | 27 ++++++++++++---------- 3 files changed, 49 insertions(+), 55 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 7c23cd00c223b..2f14ab1eb1eaf 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -106,10 +106,26 @@ class chat_formatter { #endif std::string operator () (const std::string & role, const std::string & content, [[maybe_unused]] bool use_toolcalls = false) { + common_chat_msg new_msg; new_msg.role = role; new_msg.content = content; +#ifdef LLAMA_USE_TOOLCALL + if (params_.use_jinja && use_toolcalls) { + if (tc_handler_ != nullptr) { + if (nlohmann::json::accept(content)) { // Need a better way to know this is for a toolcall + toolcall::result_set res = tc_handler_->call(content); + std::string new_content; + for (const auto & r : res) { + new_content += (r.data + "\n"); + } + new_msg.content = new_content; // TODO: this is not wiring correctly into the prompt + } + } + } +#endif + common_chat_params cparams; common_chat_templates_inputs cinputs; #ifdef LLAMA_USE_TOOLCALL @@ -126,21 +142,8 @@ class chat_formatter { chat_msgs_.push_back(new_msg); LOG_DBG("formatted: '%s'\n", formatted.c_str()); -#ifdef LLAMA_USE_TOOLCALL - if (params_.use_jinja && use_toolcalls) { - common_chat_grammar_to_sampler(&cparams, vocab_, ¶ms_.sampling); - if (tc_handler_ != nullptr) { - if (nlohmann::json::accept(formatted)) { // May need a better way to ensure - std::string response; // this is intended for a tool-call. - tc_handler_->call(formatted, response); - return std::string(response); + common_chat_grammar_to_sampler(&cparams, vocab_, ¶ms_.sampling); - } else { - return formatted; - } - } - } -#endif return formatted; } @@ -855,22 +858,9 @@ int main(int argc, char ** argv) { } if (params.enable_chat_template) { -#ifdef LLAMA_USE_TOOLCALL - auto output = chat_add_and_format("assistant", assistant_ss.str(), true); - if (tc_handler == nullptr || tc_handler->last_action() != toolcall::ACCEPT) { - is_interacting = true; - LOG("\n"); - - } else { - LOG_DBG("tokenizing toolcall response"); - auto response = common_tokenize(ctx, output, false, true); - embd_inp.insert(embd_inp.end(), response.begin(), response.end()); - } -#else - chat_add_and_format("assistant", assistant_ss.str()); + chat_add_and_format("assistant", assistant_ss.str(), true); is_interacting = true; LOG("\n"); -#endif } } } diff --git a/toolcall/handler.cpp b/toolcall/handler.cpp index 4c3940929d7ab..e557eece89be8 100644 --- a/toolcall/handler.cpp +++ b/toolcall/handler.cpp @@ -5,7 +5,7 @@ #include #ifdef LLAMA_USE_CURL -# include "mcp_sse_transport.h" +# include "mcp_sse_transport.h" #endif #include "mcp_stdio_transport.h" @@ -39,19 +39,14 @@ bool toolcall::handler::tool_list_dirty() const { return impl_->tool_list_dirty(); } -toolcall::action toolcall::handler::call(const std::string & request, std::string & response) { - last_action_ = impl_->call(request, response); - return last_action_; +toolcall::result_set toolcall::handler::call(const std::string & request) { + return impl_->call(request); } const std::string & toolcall::handler::tool_choice() const { return impl_->tool_choice(); } -toolcall::action toolcall::handler::last_action() const { - return last_action_; -} - void toolcall::handler::initialize() { impl_->initialize(); } @@ -194,26 +189,32 @@ static mcp::tools_call_request tools_call_request_from_local_json(nlohmann::json return mcp::tools_call_request(id, j["name"], args); } -static std::string tools_call_response_to_local_json(const mcp::tools_call_response & resp) { - return resp.toJson().dump(-1); // The AI will figure it out? +static toolcall::result_set tools_call_response_to_result(const mcp::tools_call_response & resp) { + toolcall::result_set result; + for (const auto & res : resp.tool_result()) { + result.push_back(toolcall::result{ + res.type, res.value, res.mime_type.value_or("text/plain"), res.uri, resp.tool_error() + }); + } + return std::move(result); } -toolcall::action toolcall::mcp_impl::call(const std::string & request, std::string & response) { +toolcall::result_set toolcall::mcp_impl::call(const std::string & request) { using on_response = toolcall::callback; if (transport_ == nullptr) { - return toolcall::DEFER; + return toolcall::result_set(); } std::unique_lock lock(tools_mutex_); - response.clear(); + toolcall::result_set response; on_response set_response = [this, &response] (const mcp::tools_call_response & resp) { std::unique_lock lock(tools_mutex_); - response = tools_call_response_to_local_json(resp); + response = tools_call_response_to_result(resp); tools_populating_.notify_one(); }; transport_->send(tools_call_request_from_local_json(next_id_++, request), set_response); tools_populating_.wait_for(lock, std::chrono::seconds(15), [&response] { return ! response.empty(); }); - return toolcall::ACCEPT; + return response; } diff --git a/toolcall/toolcall-handler.h b/toolcall/toolcall-handler.h index c3e97ae24a69a..0da5820a855b8 100644 --- a/toolcall/toolcall-handler.h +++ b/toolcall/toolcall-handler.h @@ -10,12 +10,16 @@ namespace toolcall { - enum action { - ACCEPT, - PENDING, - DEFER + struct result { + std::string type; + std::string data; + std::string mime_type; + std::optional uri; + bool error; }; + using result_set = std::vector; + class handler_impl; class handler { public: @@ -23,19 +27,17 @@ namespace toolcall handler(std::unique_ptr impl) : impl_(std::move(impl)) {} - action call(const std::string & request, std::string & response); + result_set call(const std::string & request); std::string tool_list(); bool tool_list_dirty() const; const std::string & tool_choice() const; - action last_action() const; void initialize(); private: std::unique_ptr impl_; - action last_action_; }; std::shared_ptr create_handler(const toolcall::params & params); @@ -53,7 +55,7 @@ namespace toolcall return tool_list_dirty_; } - virtual action call(const std::string & request, std::string & response) = 0; + virtual result_set call(const std::string & request) = 0; const std::string & tool_choice() const { return tool_choice_; } @@ -74,9 +76,10 @@ namespace toolcall return tools_; } - virtual action call(const std::string & request, std::string & response) override { - response = request; - return toolcall::DEFER; + virtual result_set call(const std::string & request) override { + return { + {"text", request, "text/plain", std::nullopt, false} + }; } private: @@ -90,7 +93,7 @@ namespace toolcall mcp_impl(std::vector argv, std::string tool_choice); virtual std::string tool_list() override; - virtual action call(const std::string & request, std::string & response) override; + virtual result_set call(const std::string & request) override; virtual void initialize() override; From 0b52627ae8c949439ff79f5e41b0c242303d7c9c Mon Sep 17 00:00:00 2001 From: Mason M Date: Tue, 25 Feb 2025 07:41:04 -0400 Subject: [PATCH 55/69] Add missing header --- toolcall/toolcall-handler.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/toolcall/toolcall-handler.h b/toolcall/toolcall-handler.h index 0da5820a855b8..b6337297aedcc 100644 --- a/toolcall/toolcall-handler.h +++ b/toolcall/toolcall-handler.h @@ -2,7 +2,7 @@ #include "toolcall-params.h" #include -#include +#include #include #include #include @@ -77,9 +77,9 @@ namespace toolcall } virtual result_set call(const std::string & request) override { - return { - {"text", request, "text/plain", std::nullopt, false} - }; + return result_set { + {"text", request, "text/plain", std::nullopt, false} + }; } private: From 66eff76b90dad2ec5d25f7bcfdb2983ce0a887f4 Mon Sep 17 00:00:00 2001 From: Mason M Date: Tue, 25 Feb 2025 09:45:34 -0400 Subject: [PATCH 56/69] Move tool-call invocation into main loop --- examples/main/main.cpp | 37 +++++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 2f14ab1eb1eaf..b626c4fef6b75 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -111,21 +111,6 @@ class chat_formatter { new_msg.role = role; new_msg.content = content; -#ifdef LLAMA_USE_TOOLCALL - if (params_.use_jinja && use_toolcalls) { - if (tc_handler_ != nullptr) { - if (nlohmann::json::accept(content)) { // Need a better way to know this is for a toolcall - toolcall::result_set res = tc_handler_->call(content); - std::string new_content; - for (const auto & r : res) { - new_content += (r.data + "\n"); - } - new_msg.content = new_content; // TODO: this is not wiring correctly into the prompt - } - } - } -#endif - common_chat_params cparams; common_chat_templates_inputs cinputs; #ifdef LLAMA_USE_TOOLCALL @@ -142,8 +127,9 @@ class chat_formatter { chat_msgs_.push_back(new_msg); LOG_DBG("formatted: '%s'\n", formatted.c_str()); +#ifdef LLAMA_USE_TOOLCALL common_chat_grammar_to_sampler(&cparams, vocab_, ¶ms_.sampling); - +#endif return formatted; } @@ -859,8 +845,27 @@ int main(int argc, char ** argv) { if (params.enable_chat_template) { chat_add_and_format("assistant", assistant_ss.str(), true); +#ifdef LLAMA_USE_TOOLCALL + if (! params.use_jinja || tc_handler == nullptr || ! nlohmann::json::accept(assistant_ss.str())) { + is_interacting = true; + LOG("\n"); + + } else { + toolcall::result_set res = tc_handler->call(assistant_ss.str()); + if (! res.empty()) { + std::string toolcall_result_str; + for (const auto & r : res) { + toolcall_result_str += (r.data + "\n"); + } + auto toolcall_result_tok = common_tokenize(ctx, toolcall_result_str, false, true); + embd_inp.insert(embd_inp.end(), toolcall_result_tok.begin(), toolcall_result_tok.end()); + } + } +#else + is_interacting = true; LOG("\n"); +#endif } } } From a097b4f1914cb7e69c02f63cc64e8c675dee1d45 Mon Sep 17 00:00:00 2001 From: Mason M Date: Tue, 25 Feb 2025 10:41:31 -0400 Subject: [PATCH 57/69] Add tighter check before running toolcalls --- examples/main/main.cpp | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index b626c4fef6b75..9a14212c49e64 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -846,20 +846,33 @@ int main(int argc, char ** argv) { if (params.enable_chat_template) { chat_add_and_format("assistant", assistant_ss.str(), true); #ifdef LLAMA_USE_TOOLCALL - if (! params.use_jinja || tc_handler == nullptr || ! nlohmann::json::accept(assistant_ss.str())) { - is_interacting = true; - LOG("\n"); + auto should_use_toolcall = [¶ms, tc_handler] (const std::string & asst_msg) { + if (! params.use_jinja || tc_handler == nullptr) { + return false; + } + try { + nlohmann::json j = nlohmann::json::parse(asst_msg); + return (j.contains("name") && j.contains("parameters")); - } else { + } catch (const nlohmann::json::exception & err) { + return false; + } + }; + + if (should_use_toolcall(assistant_ss.str())) { toolcall::result_set res = tc_handler->call(assistant_ss.str()); if (! res.empty()) { std::string toolcall_result_str; for (const auto & r : res) { - toolcall_result_str += (r.data + "\n"); + toolcall_result_str += ("\n" + r.data); } auto toolcall_result_tok = common_tokenize(ctx, toolcall_result_str, false, true); embd_inp.insert(embd_inp.end(), toolcall_result_tok.begin(), toolcall_result_tok.end()); } + + } else { + is_interacting = true; + LOG("\n"); } #else From 9db9686140c92a0f52a140fdbf1d6542c2976842 Mon Sep 17 00:00:00 2001 From: Mason M Date: Tue, 25 Feb 2025 14:27:00 -0400 Subject: [PATCH 58/69] Remove toolcall dependency from common --- common/CMakeLists.txt | 4 ---- common/arg.cpp | 6 ++---- common/common.h | 13 ++++++------- examples/main/CMakeLists.txt | 5 +++++ examples/main/main.cpp | 7 ++++++- 5 files changed, 19 insertions(+), 16 deletions(-) diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 8575f9e73085f..17146fffc1168 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -139,7 +139,3 @@ endif () target_include_directories(${TARGET} PUBLIC .) target_compile_features (${TARGET} PUBLIC cxx_std_17) target_link_libraries (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads) - -if (LLAMA_TOOLCALL) - target_link_libraries(${TARGET} PUBLIC toolcall) -endif() diff --git a/common/arg.cpp b/common/arg.cpp index 8685369d3651e..f1daa91e4c084 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2143,13 +2143,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE")); -#ifdef LLAMA_USE_TOOLCALL add_opt(common_arg( {"--tools"}, "JINJA_TOOLS", "set to URI of a Model Context Protocol server, or " "a JSON array containing tool definitions (requires --jinja)", [](common_params ¶ms, const std::string & value) { - params.jinja_tools.tools(value); + params.toolcall.tools = value; }).set_examples({LLAMA_EXAMPLE_MAIN})); @@ -2157,10 +2156,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--tool-choice"}, "JINJA_TOOL_CHOICE", "set to \"auto\", \"required\", or \"none\" (default: \"auto\")", [](common_params ¶ms, const std::string & value) { - params.jinja_tools.choice(value); + params.toolcall.choice = value; }).set_examples({LLAMA_EXAMPLE_MAIN})); -#endif add_opt(common_arg( {"-sps", "--slot-prompt-similarity"}, "SIMILARITY", diff --git a/common/common.h b/common/common.h index d6413b3e6d692..011c40417eccb 100644 --- a/common/common.h +++ b/common/common.h @@ -4,10 +4,6 @@ #include "llama-cpp.h" -#ifdef LLAMA_USE_TOOLCALL -# include "toolcall-params.h" -#endif - #include #include #include @@ -212,6 +208,11 @@ enum common_reasoning_format { COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content` }; +struct common_toolcall_params { + std::string tools = ""; + std::string choice = "auto"; +}; + struct common_params { int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 4096; // context size @@ -358,9 +359,7 @@ struct common_params { bool use_jinja = false; // NOLINT bool enable_chat_template = true; -#ifdef LLAMA_USE_TOOLCALL - toolcall::params jinja_tools; -#endif + struct common_toolcall_params toolcall; common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; diff --git a/examples/main/CMakeLists.txt b/examples/main/CMakeLists.txt index af3d9150f8640..6cc0b23b11029 100644 --- a/examples/main/CMakeLists.txt +++ b/examples/main/CMakeLists.txt @@ -2,4 +2,9 @@ set(TARGET llama-cli) add_executable(${TARGET} main.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) + +if (LLAMA_TOOLCALL) + target_link_libraries(${TARGET} PRIVATE toolcall) +endif() + target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 9a14212c49e64..6ece1206186f3 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -155,6 +155,11 @@ int main(int argc, char ** argv) { auto & sparams = params.sampling; +#ifdef LLAMA_USE_TOOLCALL + // Ensure parameters are validated before the model loads + toolcall::params tc_params(params.toolcall.tools, params.toolcall.choice); +#endif + // save choice to use color for later // (note for later: this is a slightly awkward choice) console::init(params.simple_io, params.use_color); @@ -323,7 +328,7 @@ int main(int argc, char ** argv) { std::vector embd_inp; #ifdef LLAMA_USE_TOOLCALL - auto tc_handler = toolcall::create_handler(params.jinja_tools); + auto tc_handler = toolcall::create_handler(tc_params); if (tc_handler) { tc_handler->initialize(); } From e8dd857bd96131d4ef522fcfba8a0e6ca5406f06 Mon Sep 17 00:00:00 2001 From: Mason M Date: Tue, 25 Feb 2025 14:55:32 -0400 Subject: [PATCH 59/69] Rename handler -> client to reflect MCP terminology --- examples/main/main.cpp | 28 +++++++++---------- toolcall/CMakeLists.txt | 4 +-- toolcall/{handler.cpp => client.cpp} | 28 +++++++++---------- .../{toolcall-handler.h => toolcall-client.h} | 24 ++++++++-------- 4 files changed, 42 insertions(+), 42 deletions(-) rename toolcall/{handler.cpp => client.cpp} (90%) rename toolcall/{toolcall-handler.h => toolcall-client.h} (78%) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 6ece1206186f3..9627a80964cb3 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -17,7 +17,7 @@ #include #ifdef LLAMA_USE_TOOLCALL -# include "toolcall-handler.h" +# include "toolcall-client.h" #endif #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) @@ -100,9 +100,9 @@ class chat_formatter { std::vector & chat_msgs, struct common_chat_templates * chat_templates, const llama_vocab * vocab, - toolcall::handler::ptr tc_handler) + toolcall::client::ptr tc_client) - : params_(params), chat_msgs_(chat_msgs), chat_templates_(chat_templates), vocab_(vocab), tc_handler_(tc_handler) {} + : params_(params), chat_msgs_(chat_msgs), chat_templates_(chat_templates), vocab_(vocab), tc_client_(tc_client) {} #endif std::string operator () (const std::string & role, const std::string & content, [[maybe_unused]] bool use_toolcalls = false) { @@ -114,9 +114,9 @@ class chat_formatter { common_chat_params cparams; common_chat_templates_inputs cinputs; #ifdef LLAMA_USE_TOOLCALL - if (tc_handler_ != nullptr && use_toolcalls) { - cinputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tc_handler_->tool_choice()); - cinputs.tools = common_chat_tools_parse_oaicompat(tc_handler_->tool_list()); + if (tc_client_ != nullptr && use_toolcalls) { + cinputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tc_client_->tool_choice()); + cinputs.tools = common_chat_tools_parse_oaicompat(tc_client_->tool_list()); } #endif bool add_ass = role == "user"; @@ -140,7 +140,7 @@ class chat_formatter { #ifdef LLAMA_USE_TOOLCALL const llama_vocab * vocab_; - toolcall::handler::ptr tc_handler_; + toolcall::client::ptr tc_client_; #endif }; @@ -328,11 +328,11 @@ int main(int argc, char ** argv) { std::vector embd_inp; #ifdef LLAMA_USE_TOOLCALL - auto tc_handler = toolcall::create_handler(tc_params); - if (tc_handler) { - tc_handler->initialize(); + auto tc_client = toolcall::create_client(tc_params); + if (tc_client) { + tc_client->initialize(); } - chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get(), vocab, tc_handler); + chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get(), vocab, tc_client); #else chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get()); #endif @@ -851,8 +851,8 @@ int main(int argc, char ** argv) { if (params.enable_chat_template) { chat_add_and_format("assistant", assistant_ss.str(), true); #ifdef LLAMA_USE_TOOLCALL - auto should_use_toolcall = [¶ms, tc_handler] (const std::string & asst_msg) { - if (! params.use_jinja || tc_handler == nullptr) { + auto should_use_toolcall = [¶ms, tc_client] (const std::string & asst_msg) { + if (! params.use_jinja || tc_client == nullptr) { return false; } try { @@ -865,7 +865,7 @@ int main(int argc, char ** argv) { }; if (should_use_toolcall(assistant_ss.str())) { - toolcall::result_set res = tc_handler->call(assistant_ss.str()); + toolcall::result_set res = tc_client->call(assistant_ss.str()); if (! res.empty()) { std::string toolcall_result_str; for (const auto & r : res) { diff --git a/toolcall/CMakeLists.txt b/toolcall/CMakeLists.txt index 953a0c6f20f43..5fca0aee94fde 100644 --- a/toolcall/CMakeLists.txt +++ b/toolcall/CMakeLists.txt @@ -2,7 +2,7 @@ set(TARGET toolcall) set(SOURCES - handler.cpp + client.cpp mcp_messages.cpp mcp_stdio_transport.cpp params.cpp @@ -10,7 +10,7 @@ set(SOURCES set(HEADERS toolcall-params.h - toolcall-handler.h + toolcall-client.h mcp_transport.h mcp_messages.h mcp_stdio_transport.h diff --git a/toolcall/handler.cpp b/toolcall/client.cpp similarity index 90% rename from toolcall/handler.cpp rename to toolcall/client.cpp index e557eece89be8..d078f04794421 100644 --- a/toolcall/handler.cpp +++ b/toolcall/client.cpp @@ -1,6 +1,6 @@ #include -#include "toolcall-handler.h" +#include "toolcall-client.h" #include #include @@ -12,48 +12,48 @@ using json = nlohmann::json; -std::shared_ptr toolcall::create_handler(const toolcall::params & params) { - std::shared_ptr handler; +std::shared_ptr toolcall::create_client(const toolcall::params & params) { + std::shared_ptr client; auto tools = params.tools(); auto choice = params.choice(); if (params) { if (params.has_uri()) { #ifdef LLAMA_USE_CURL - handler.reset(new toolcall::handler( + client.reset(new toolcall::client( std::make_unique(tools, choice))); #endif } else { - handler.reset(new toolcall::handler( + client.reset(new toolcall::client( std::make_unique(tools, choice))); } } - return handler; + return client; } -std::string toolcall::handler::tool_list() { +std::string toolcall::client::tool_list() { return impl_->tool_list(); } -bool toolcall::handler::tool_list_dirty() const { +bool toolcall::client::tool_list_dirty() const { return impl_->tool_list_dirty(); } -toolcall::result_set toolcall::handler::call(const std::string & request) { +toolcall::result_set toolcall::client::call(const std::string & request) { return impl_->call(request); } -const std::string & toolcall::handler::tool_choice() const { +const std::string & toolcall::client::tool_choice() const { return impl_->tool_choice(); } -void toolcall::handler::initialize() { +void toolcall::client::initialize() { impl_->initialize(); } #ifdef LLAMA_USE_CURL toolcall::mcp_impl::mcp_impl(std::string server_uri, std::string tool_choice) - : handler_impl(tool_choice), + : client_impl(tool_choice), transport_(new mcp_sse_transport(server_uri)), tools_("[]"), tools_mutex_(), @@ -63,7 +63,7 @@ toolcall::mcp_impl::mcp_impl(std::string server_uri, std::string tool_choice) } #else toolcall::mcp_impl::mcp_impl(std::string /*server_uri*/, std::string tool_choice) - : handler_impl(tool_choice), + : client_impl(tool_choice), transport_(nullptr), tools_("[]"), tools_mutex_(), @@ -74,7 +74,7 @@ toolcall::mcp_impl::mcp_impl(std::string /*server_uri*/, std::string tool_choice #endif toolcall::mcp_impl::mcp_impl(std::vector argv, std::string tool_choice) - : handler_impl(tool_choice), + : client_impl(tool_choice), transport_(new mcp_stdio_transport(argv)) { } diff --git a/toolcall/toolcall-handler.h b/toolcall/toolcall-client.h similarity index 78% rename from toolcall/toolcall-handler.h rename to toolcall/toolcall-client.h index b6337297aedcc..936352ab4efe5 100644 --- a/toolcall/toolcall-handler.h +++ b/toolcall/toolcall-client.h @@ -20,12 +20,12 @@ namespace toolcall using result_set = std::vector; - class handler_impl; - class handler { + class client_impl; + class client { public: - using ptr = std::shared_ptr; + using ptr = std::shared_ptr; - handler(std::unique_ptr impl) : impl_(std::move(impl)) {} + client(std::unique_ptr impl) : impl_(std::move(impl)) {} result_set call(const std::string & request); @@ -37,17 +37,17 @@ namespace toolcall void initialize(); private: - std::unique_ptr impl_; + std::unique_ptr impl_; }; - std::shared_ptr create_handler(const toolcall::params & params); + std::shared_ptr create_client(const toolcall::params & params); - class handler_impl { + class client_impl { public: - handler_impl(std::string tool_choice) + client_impl(std::string tool_choice) : tool_choice_(std::move(tool_choice)), tool_list_dirty_(true) {} - virtual ~handler_impl() = default; + virtual ~client_impl() = default; virtual std::string tool_list() = 0; @@ -66,10 +66,10 @@ namespace toolcall bool tool_list_dirty_; }; - class loopback_impl : public handler_impl { + class loopback_impl : public client_impl { public: loopback_impl(std::string tools, std::string tool_choice) - : handler_impl(tool_choice), tools_(std::move(tools)) {} + : client_impl(tool_choice), tools_(std::move(tools)) {} virtual std::string tool_list() override { tool_list_dirty_ = false; @@ -87,7 +87,7 @@ namespace toolcall }; class mcp_transport; - class mcp_impl : public handler_impl { + class mcp_impl : public client_impl { public: mcp_impl(std::string server_uri, std::string tool_choice); mcp_impl(std::vector argv, std::string tool_choice); From f354ff92a3fcc452cbdb39c13ce7053687fa32dd Mon Sep 17 00:00:00 2001 From: Mason M Date: Sun, 2 Mar 2025 18:14:35 -0400 Subject: [PATCH 60/69] Ensure toolcalls are registered when no -sys provided --- examples/main/main.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 8747d7e8c0894..1c206f2cf92ac 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -346,11 +346,8 @@ int main(int argc, char ** argv) { if (params.conversation_mode && params.enable_chat_template) { // format the system prompt in conversation mode (will use template default if empty) - prompt = params.system_prompt; + prompt = chat_add_and_format("system", params.system_prompt, true); - if (!prompt.empty()) { - prompt = chat_add_and_format("system", prompt, true); - } } else { // otherwise use the prompt as is prompt = params.prompt; From 8871c8d0b93155180c022073cba679f924403b43 Mon Sep 17 00:00:00 2001 From: Mason M Date: Tue, 4 Mar 2025 14:10:22 -0400 Subject: [PATCH 61/69] Add toolcall output after single-turn run --- examples/main/main.cpp | 84 +++++++++++++++++++++++++++--------------- 1 file changed, 55 insertions(+), 29 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index f49248220e11d..2f3e9d95bbbe7 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -142,6 +142,43 @@ class chat_formatter { #endif }; +#ifdef LLAMA_USE_TOOLCALL +static bool call_tool(common_params & params, const std::string & assistant_msg, + llama_context * ctx, toolcall::client::ptr tc_client, std::vector & embd_inp) +{ + auto should_use_toolcall = [¶ms, tc_client] (const std::string & asst_msg) { + if (! params.use_jinja || tc_client == nullptr) { + return false; + } + try { + nlohmann::json j = nlohmann::json::parse(asst_msg); + return (j.contains("name") && j.contains("parameters")); + + } catch (const nlohmann::json::exception & err) { + return false; + } + }; + + if (should_use_toolcall(assistant_msg)) { + toolcall::result_set res = tc_client->call(assistant_msg); + if (! res.empty()) { + std::string toolcall_result_str; + for (const auto & r : res) { + toolcall_result_str += ("\n" + r.data); // Although more complex results can be + // returned (resources, images, etc.), + // for now simply append the data. Later + // on support for specific models may + // allow for unpacking Base64 data. + } + auto toolcall_result_tok = common_tokenize(ctx, toolcall_result_str, false, true); + embd_inp.insert(embd_inp.end(), toolcall_result_tok.begin(), toolcall_result_tok.end()); + } + return true; + } + return false; +} +#endif + int main(int argc, char ** argv) { common_params params; g_params = ¶ms; @@ -361,6 +398,12 @@ int main(int argc, char ** argv) { inputs.messages = chat_msgs; inputs.add_generation_prompt = !params.prompt.empty(); +#ifdef LLAMA_USE_TOOLCALL + if (tc_client != nullptr) { + inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tc_client->tool_choice()); + inputs.tools = common_chat_tools_parse_oaicompat(tc_client->tool_list()); + } +#endif prompt = common_chat_templates_apply(chat_templates.get(), inputs).prompt; } } else { @@ -881,36 +924,9 @@ int main(int argc, char ** argv) { if (params.enable_chat_template) { chat_add_and_format("assistant", assistant_ss.str(), true); #ifdef LLAMA_USE_TOOLCALL - auto should_use_toolcall = [¶ms, tc_client] (const std::string & asst_msg) { - if (! params.use_jinja || tc_client == nullptr) { - return false; - } - try { - nlohmann::json j = nlohmann::json::parse(asst_msg); - return (j.contains("name") && j.contains("parameters")); - - } catch (const nlohmann::json::exception & err) { - return false; - } - }; - - if (should_use_toolcall(assistant_ss.str())) { - toolcall::result_set res = tc_client->call(assistant_ss.str()); - if (! res.empty()) { - std::string toolcall_result_str; - for (const auto & r : res) { - toolcall_result_str += ("\n" + r.data); - } - auto toolcall_result_tok = common_tokenize(ctx, toolcall_result_str, false, true); - embd_inp.insert(embd_inp.end(), toolcall_result_tok.begin(), toolcall_result_tok.end()); - } - - } else { - is_interacting = true; - LOG("\n"); - } + is_interacting = ! call_tool(params, assistant_ss.str(), ctx, tc_client, embd_inp); + LOG("\n"); #else - is_interacting = true; LOG("\n"); #endif @@ -1033,6 +1049,16 @@ int main(int argc, char ** argv) { } } +#ifdef LLAMA_USE_TOOLCALL + if (params.single_turn) { + size_t last_len = embd_inp.size(); + bool was_toolcall = call_tool(params, assistant_ss.str(), ctx, tc_client, embd_inp); + if (was_toolcall && last_len < embd_inp.size()) { + LOG("%s", common_token_to_piece(ctx, embd_inp[last_len]).c_str()); + } + } +#endif + // end of generation if (!embd.empty() && llama_vocab_is_eog(vocab, embd.back()) && !(params.interactive)) { LOG(" [end of text]\n"); From 46766c1134796c36d0df49c6baa2d80e6e3057aa Mon Sep 17 00:00:00 2001 From: Mason M Date: Wed, 5 Mar 2025 10:27:21 -0400 Subject: [PATCH 62/69] Update grammar_trigger processing --- common/common.cpp | 46 +++++++++++++++++++++++++++++++++------------- 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 8649662a74a83..6f5738e580f21 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1782,19 +1782,6 @@ void common_chat_grammar_to_sampler(const common_chat_params * src, dst.grammar = src->grammar; dst.grammar_lazy = src->grammar_lazy; - for (const auto & trigger : src->grammar_triggers) { - auto ids = common_tokenize(vocab, trigger.word, false, true); - - if (ids.size() == 1) { - LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str()); - dst.grammar_trigger_tokens.push_back(ids[0]); - dst.preserved_tokens.insert(ids[0]); - continue; - } - LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str()); - dst.grammar_trigger_words.push_back(trigger); - } - for (const auto & preserved : src->preserved_tokens) { auto ids = common_tokenize(vocab, preserved, false, true); if (ids.size() == 1) { @@ -1808,6 +1795,39 @@ void common_chat_grammar_to_sampler(const common_chat_params * src, preserved.c_str()); } } + + for (const auto & trigger : src->grammar_triggers) { + if (trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { + const auto & word = trigger.value; + auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true); + + if (ids.size() == 1) { + auto token = ids[0]; + auto found = std::find(dst.preserved_tokens.begin(), dst.preserved_tokens.end(), + (llama_token) token); + + if (found == dst.preserved_tokens.end()) { + throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word); + } + + LOG_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str()); + common_grammar_trigger trigger; + trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN; + trigger.value = (llama_token) token; + dst.grammar_triggers.push_back(trigger); + + } else { + LOG_DBG("Grammar trigger word: `%s`\n", word.c_str()); + dst.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word}); + } + + } else { + dst.grammar_triggers.push_back(trigger); + } + } + if (dst.grammar_lazy && dst.grammar_triggers.empty()) { + throw std::runtime_error("Error: no triggers set for lazy grammar!"); + } } From ac1fc3182081c497d76c15e2690ffc67033c1d14 Mon Sep 17 00:00:00 2001 From: Mason M Date: Wed, 5 Mar 2025 12:54:13 -0400 Subject: [PATCH 63/69] WIP: use common_chat_parse for toolcall --- examples/main/main.cpp | 81 +++++++++++++++++++----------------------- 1 file changed, 37 insertions(+), 44 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 2f3e9d95bbbe7..6b4b205bffbdd 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -143,39 +143,37 @@ class chat_formatter { }; #ifdef LLAMA_USE_TOOLCALL -static bool call_tool(common_params & params, const std::string & assistant_msg, - llama_context * ctx, toolcall::client::ptr tc_client, std::vector & embd_inp) +static bool call_tool(const std::string & assistant_msg, llama_context * ctx, toolcall::client::ptr tc_client, std::vector & embd_inp) { - auto should_use_toolcall = [¶ms, tc_client] (const std::string & asst_msg) { - if (! params.use_jinja || tc_client == nullptr) { - return false; - } - try { - nlohmann::json j = nlohmann::json::parse(asst_msg); - return (j.contains("name") && j.contains("parameters")); - - } catch (const nlohmann::json::exception & err) { - return false; - } - }; - - if (should_use_toolcall(assistant_msg)) { - toolcall::result_set res = tc_client->call(assistant_msg); - if (! res.empty()) { - std::string toolcall_result_str; - for (const auto & r : res) { - toolcall_result_str += ("\n" + r.data); // Although more complex results can be - // returned (resources, images, etc.), - // for now simply append the data. Later - // on support for specific models may - // allow for unpacking Base64 data. + bool tool_was_called = false; + common_chat_msg msg = common_chat_parse(assistant_msg, COMMON_CHAT_FORMAT_GENERIC); + if (! msg.tool_calls.empty()) { + for (const auto & tc : msg.tool_calls) { + nlohmann::json tc_oai_json { + {"type", "function"}, + {"function", { + {"name", tc.name}, + {"arguments", tc.arguments}, + }}, + {"id", tc.id}, + }; + toolcall::result_set res = tc_client->call(tc_oai_json); + if (! res.empty()) { + std::string toolcall_result_str; + for (const auto & r : res) { + toolcall_result_str += ("\n" + r.data); // Although more complex results can be + // returned (resources, images, etc.), + // for now simply append the data. Later + // on support for specific models may + // allow for unpacking Base64 data. + } + auto toolcall_result_tok = common_tokenize(ctx, toolcall_result_str, false, true); + embd_inp.insert(embd_inp.end(), toolcall_result_tok.begin(), toolcall_result_tok.end()); } - auto toolcall_result_tok = common_tokenize(ctx, toolcall_result_str, false, true); - embd_inp.insert(embd_inp.end(), toolcall_result_tok.begin(), toolcall_result_tok.end()); + tool_was_called = true; } - return true; } - return false; + return tool_was_called; } #endif @@ -923,13 +921,8 @@ int main(int argc, char ** argv) { if (params.enable_chat_template) { chat_add_and_format("assistant", assistant_ss.str(), true); -#ifdef LLAMA_USE_TOOLCALL - is_interacting = ! call_tool(params, assistant_ss.str(), ctx, tc_client, embd_inp); - LOG("\n"); -#else is_interacting = true; LOG("\n"); -#endif } } } @@ -945,6 +938,16 @@ int main(int argc, char ** argv) { } } +#ifdef LLAMA_USE_TOOLCALL + if ((tc_client && n_past > 0) && (waiting_for_first_input || is_interacting)) { + size_t last_len = embd_inp.size(); + bool was_toolcall = call_tool(assistant_ss.str(), ctx, tc_client, embd_inp); + if (was_toolcall && last_len < embd_inp.size()) { + LOG("%s", common_token_to_piece(ctx, embd_inp[last_len]).c_str()); + } + } +#endif + if ((n_past > 0 || waiting_for_first_input) && is_interacting) { LOG_DBG("waiting for user input\n"); @@ -1049,16 +1052,6 @@ int main(int argc, char ** argv) { } } -#ifdef LLAMA_USE_TOOLCALL - if (params.single_turn) { - size_t last_len = embd_inp.size(); - bool was_toolcall = call_tool(params, assistant_ss.str(), ctx, tc_client, embd_inp); - if (was_toolcall && last_len < embd_inp.size()) { - LOG("%s", common_token_to_piece(ctx, embd_inp[last_len]).c_str()); - } - } -#endif - // end of generation if (!embd.empty() && llama_vocab_is_eog(vocab, embd.back()) && !(params.interactive)) { LOG(" [end of text]\n"); From ba098afb6cd79b156e6c1d4249b7a664e2c1cf01 Mon Sep 17 00:00:00 2001 From: Mason M Date: Wed, 5 Mar 2025 14:25:08 -0400 Subject: [PATCH 64/69] Extract toolcall format from model --- examples/main/main.cpp | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 6b4b205bffbdd..456f9ed8e3205 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -98,9 +98,10 @@ class chat_formatter { std::vector & chat_msgs, struct common_chat_templates * chat_templates, const llama_vocab * vocab, - toolcall::client::ptr tc_client) + toolcall::client::ptr tc_client, + common_chat_format chat_format) - : params_(params), chat_msgs_(chat_msgs), chat_templates_(chat_templates), vocab_(vocab), tc_client_(tc_client) {} + : params_(params), chat_msgs_(chat_msgs), chat_templates_(chat_templates), vocab_(vocab), tc_client_(tc_client), chat_format_(chat_format) {} #endif std::string operator () (const std::string & role, const std::string & content, [[maybe_unused]] bool use_toolcalls = false) { @@ -126,6 +127,7 @@ class chat_formatter { LOG_DBG("formatted: '%s'\n", formatted.c_str()); #ifdef LLAMA_USE_TOOLCALL + if (chat_format_) *chat_format_ = cparams.format; common_chat_grammar_to_sampler(&cparams, vocab_, ¶ms_.sampling); #endif return formatted; @@ -139,14 +141,16 @@ class chat_formatter { #ifdef LLAMA_USE_TOOLCALL const llama_vocab * vocab_; toolcall::client::ptr tc_client_; + common_chat_format * chat_format_; #endif }; #ifdef LLAMA_USE_TOOLCALL -static bool call_tool(const std::string & assistant_msg, llama_context * ctx, toolcall::client::ptr tc_client, std::vector & embd_inp) +static bool call_tool(common_chat_format chat_format, const std::string & assistant_msg, llama_context * ctx, + toolcall::client::ptr tc_client, std::vector & embd_inp) { bool tool_was_called = false; - common_chat_msg msg = common_chat_parse(assistant_msg, COMMON_CHAT_FORMAT_GENERIC); + common_chat_msg msg = common_chat_parse(assistant_msg, chat_format); if (! msg.tool_calls.empty()) { for (const auto & tc : msg.tool_calls) { nlohmann::json tc_oai_json { @@ -371,7 +375,8 @@ int main(int argc, char ** argv) { if (tc_client) { tc_client->initialize(); } - chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get(), vocab, tc_client); + common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get(), vocab, tc_client, &chat_format); #else chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get()); #endif @@ -941,7 +946,7 @@ int main(int argc, char ** argv) { #ifdef LLAMA_USE_TOOLCALL if ((tc_client && n_past > 0) && (waiting_for_first_input || is_interacting)) { size_t last_len = embd_inp.size(); - bool was_toolcall = call_tool(assistant_ss.str(), ctx, tc_client, embd_inp); + bool was_toolcall = call_tool(chat_format, assistant_ss.str(), ctx, tc_client, embd_inp); if (was_toolcall && last_len < embd_inp.size()) { LOG("%s", common_token_to_piece(ctx, embd_inp[last_len]).c_str()); } From 787fa89ed1f5d1c8c77a0e57937f5e52e2f23a05 Mon Sep 17 00:00:00 2001 From: Mason M Date: Wed, 5 Mar 2025 14:30:47 -0400 Subject: [PATCH 65/69] Oops --- examples/main/main.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 456f9ed8e3205..7ab4fd0e1943b 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -99,7 +99,7 @@ class chat_formatter { struct common_chat_templates * chat_templates, const llama_vocab * vocab, toolcall::client::ptr tc_client, - common_chat_format chat_format) + common_chat_format * chat_format) : params_(params), chat_msgs_(chat_msgs), chat_templates_(chat_templates), vocab_(vocab), tc_client_(tc_client), chat_format_(chat_format) {} #endif From c36c7e6d027ad16c2358a98ef1d2ee88c248429f Mon Sep 17 00:00:00 2001 From: Mason M Date: Thu, 6 Mar 2025 17:21:04 -0400 Subject: [PATCH 66/69] Squashed commit of the following: commit 7adfa186455d8a38f4f7e28dba151e6873b42ed5 Author: Mason M Date: Thu Mar 6 17:19:09 2025 -0400 Re-Prompt after toolcall commit c8843da4439740656b2a5cf3ea6e6e592b407cdc Author: Mason M Date: Thu Mar 6 13:41:45 2025 -0400 Use format to extract toolcalls --- examples/main/main.cpp | 137 ++++++++++++++++++------------------- toolcall/client.cpp | 29 ++++---- toolcall/toolcall-client.h | 19 +++-- 3 files changed, 96 insertions(+), 89 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 7ab4fd0e1943b..0bc7145e4ce64 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -90,7 +90,16 @@ static void sigint_handler(int signo) { class chat_formatter { public: - chat_formatter(common_params & params, std::vector & chat_msgs, struct common_chat_templates * chat_templates) + + struct result { + std::string formatted; + bool tool_was_called; + }; + + chat_formatter(common_params & params, + std::vector & chat_msgs, + struct common_chat_templates * chat_templates) + : params_(params), chat_msgs_(chat_msgs), chat_templates_(chat_templates) {} #ifdef LLAMA_USE_TOOLCALL @@ -98,39 +107,65 @@ class chat_formatter { std::vector & chat_msgs, struct common_chat_templates * chat_templates, const llama_vocab * vocab, - toolcall::client::ptr tc_client, - common_chat_format * chat_format) + toolcall::client::ptr tc_client) - : params_(params), chat_msgs_(chat_msgs), chat_templates_(chat_templates), vocab_(vocab), tc_client_(tc_client), chat_format_(chat_format) {} + : params_(params), chat_msgs_(chat_msgs), chat_templates_(chat_templates), + vocab_(vocab), tc_client_(tc_client), + chat_format_(COMMON_CHAT_FORMAT_CONTENT_ONLY), + formatted_() {} #endif - std::string operator () (const std::string & role, const std::string & content, [[maybe_unused]] bool use_toolcalls = false) { + chat_formatter::result operator() (const std::string & role, const std::string & content) { - common_chat_msg new_msg; + common_chat_msg new_msg = common_chat_parse(content, chat_format_); new_msg.role = role; - new_msg.content = content; - common_chat_params cparams; common_chat_templates_inputs cinputs; + cinputs.use_jinja = params_.use_jinja; + cinputs.add_generation_prompt = (role == "user"); #ifdef LLAMA_USE_TOOLCALL - if (tc_client_ != nullptr && use_toolcalls) { + if (tc_client_ != nullptr) { cinputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tc_client_->tool_choice()); cinputs.tools = common_chat_tools_parse_oaicompat(tc_client_->tool_list()); } #endif - bool add_ass = role == "user"; - auto formatted = - common_chat_format_single(chat_templates_, chat_msgs_, new_msg, add_ass, params_.use_jinja, - &cinputs, &cparams); - + cinputs.messages.assign(chat_msgs_.cbegin(), chat_msgs_.cend()); + cinputs.messages.push_back(new_msg); chat_msgs_.push_back(new_msg); + + bool tool_was_called = false; + if (! new_msg.tool_calls.empty()) { // Call tool and re-prompt + nlohmann::json result_array = nlohmann::json::array(); + for (const auto & tc : new_msg.tool_calls) { + toolcall::result_set res = tc_client_->call(tc.name, tc.arguments, tc.id); + if (! res.empty()) { + for (const auto & r : res) { + result_array.push_back(r.data); + } + } + } + common_chat_msg toolcall_msg; + toolcall_msg.role = "tool"; + toolcall_msg.content = result_array.dump(-1); + + cinputs.add_generation_prompt = true; + cinputs.messages.push_back(toolcall_msg); + chat_msgs_.push_back(toolcall_msg); + + tool_was_called = true; + } + + common_chat_params cparams = common_chat_templates_apply(chat_templates_, cinputs); + std::string formatted = cparams.prompt.substr(formatted_.size(), cparams.prompt.size()); + formatted_ = cparams.prompt; + LOG_DBG("formatted: '%s'\n", formatted.c_str()); #ifdef LLAMA_USE_TOOLCALL - if (chat_format_) *chat_format_ = cparams.format; + chat_format_ = cparams.format; common_chat_grammar_to_sampler(&cparams, vocab_, ¶ms_.sampling); #endif - return formatted; + return chat_formatter::result{std::move(formatted), tool_was_called}; } private: @@ -141,46 +176,11 @@ class chat_formatter { #ifdef LLAMA_USE_TOOLCALL const llama_vocab * vocab_; toolcall::client::ptr tc_client_; - common_chat_format * chat_format_; + common_chat_format chat_format_; + std::string formatted_; #endif }; -#ifdef LLAMA_USE_TOOLCALL -static bool call_tool(common_chat_format chat_format, const std::string & assistant_msg, llama_context * ctx, - toolcall::client::ptr tc_client, std::vector & embd_inp) -{ - bool tool_was_called = false; - common_chat_msg msg = common_chat_parse(assistant_msg, chat_format); - if (! msg.tool_calls.empty()) { - for (const auto & tc : msg.tool_calls) { - nlohmann::json tc_oai_json { - {"type", "function"}, - {"function", { - {"name", tc.name}, - {"arguments", tc.arguments}, - }}, - {"id", tc.id}, - }; - toolcall::result_set res = tc_client->call(tc_oai_json); - if (! res.empty()) { - std::string toolcall_result_str; - for (const auto & r : res) { - toolcall_result_str += ("\n" + r.data); // Although more complex results can be - // returned (resources, images, etc.), - // for now simply append the data. Later - // on support for specific models may - // allow for unpacking Base64 data. - } - auto toolcall_result_tok = common_tokenize(ctx, toolcall_result_str, false, true); - embd_inp.insert(embd_inp.end(), toolcall_result_tok.begin(), toolcall_result_tok.end()); - } - tool_was_called = true; - } - } - return tool_was_called; -} -#endif - int main(int argc, char ** argv) { common_params params; g_params = ¶ms; @@ -375,8 +375,7 @@ int main(int argc, char ** argv) { if (tc_client) { tc_client->initialize(); } - common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get(), vocab, tc_client, &chat_format); + chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get(), vocab, tc_client); #else chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get()); #endif @@ -386,12 +385,12 @@ int main(int argc, char ** argv) { if (params.conversation_mode && params.enable_chat_template) { if (!params.system_prompt.empty()) { // format the system prompt (will use template default if empty) - chat_add_and_format("system", params.system_prompt, true); + chat_add_and_format("system", params.system_prompt); } if (!params.prompt.empty()) { // format and append the user prompt - chat_add_and_format("user", params.prompt, true); + chat_add_and_format("user", params.prompt); } else { waiting_for_first_input = true; } @@ -925,9 +924,15 @@ int main(int argc, char ** argv) { } if (params.enable_chat_template) { - chat_add_and_format("assistant", assistant_ss.str(), true); - is_interacting = true; - LOG("\n"); + auto format_res = chat_add_and_format("assistant", assistant_ss.str()); + if (format_res.tool_was_called) { + auto format_res_tok = common_tokenize(ctx, format_res.formatted, false, true); + embd_inp.insert(embd_inp.end(), format_res_tok.begin(), format_res_tok.end()); + + } else { + is_interacting = true; + LOG("\n"); + } } } } @@ -943,16 +948,6 @@ int main(int argc, char ** argv) { } } -#ifdef LLAMA_USE_TOOLCALL - if ((tc_client && n_past > 0) && (waiting_for_first_input || is_interacting)) { - size_t last_len = embd_inp.size(); - bool was_toolcall = call_tool(chat_format, assistant_ss.str(), ctx, tc_client, embd_inp); - if (was_toolcall && last_len < embd_inp.size()) { - LOG("%s", common_token_to_piece(ctx, embd_inp[last_len]).c_str()); - } - } -#endif - if ((n_past > 0 || waiting_for_first_input) && is_interacting) { LOG_DBG("waiting for user input\n"); @@ -1005,7 +1000,7 @@ int main(int argc, char ** argv) { bool format_chat = params.conversation_mode && params.enable_chat_template; std::string user_inp = format_chat - ? chat_add_and_format("user", std::move(buffer)) + ? chat_add_and_format("user", std::move(buffer)).formatted : std::move(buffer); // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix) const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true); diff --git a/toolcall/client.cpp b/toolcall/client.cpp index d078f04794421..74394d16c7218 100644 --- a/toolcall/client.cpp +++ b/toolcall/client.cpp @@ -39,8 +39,10 @@ bool toolcall::client::tool_list_dirty() const { return impl_->tool_list_dirty(); } -toolcall::result_set toolcall::client::call(const std::string & request) { - return impl_->call(request); +toolcall::result_set toolcall::client::call(const std::string & name, + const std::string & arguments, + const std::string & id) { + return impl_->call(name, arguments, id); } const std::string & toolcall::client::tool_choice() const { @@ -180,15 +182,6 @@ std::string toolcall::mcp_impl::tool_list() { return tools_; } -static mcp::tools_call_request tools_call_request_from_local_json(nlohmann::json id, const std::string & local_json) { - nlohmann::json j = json::parse(local_json); - mcp::tool_arg_list args; - for (const auto & [key, val] : j["parameters"].items()) { - args.push_back({key, val}); - } - return mcp::tools_call_request(id, j["name"], args); -} - static toolcall::result_set tools_call_response_to_result(const mcp::tools_call_response & resp) { toolcall::result_set result; for (const auto & res : resp.tool_result()) { @@ -199,7 +192,10 @@ static toolcall::result_set tools_call_response_to_result(const mcp::tools_call_ return std::move(result); } -toolcall::result_set toolcall::mcp_impl::call(const std::string & request) { +toolcall::result_set toolcall::mcp_impl::call(const std::string & name, + const std::string & arguments, + const std::string & id) +{ using on_response = toolcall::callback; if (transport_ == nullptr) { @@ -213,7 +209,14 @@ toolcall::result_set toolcall::mcp_impl::call(const std::string & request) { response = tools_call_response_to_result(resp); tools_populating_.notify_one(); }; - transport_->send(tools_call_request_from_local_json(next_id_++, request), set_response); + std::string req_id = id.empty() ? std::to_string(next_id_++) : id; + mcp::tool_arg_list req_args; + auto json_args = json::parse(arguments); // TODO check errors + for (const auto & [key, val] : json_args.items()) { + req_args.push_back({key, val}); + } + + transport_->send(mcp::tools_call_request(req_id, name, req_args), set_response); tools_populating_.wait_for(lock, std::chrono::seconds(15), [&response] { return ! response.empty(); }); return response; diff --git a/toolcall/toolcall-client.h b/toolcall/toolcall-client.h index 936352ab4efe5..d2259335851d9 100644 --- a/toolcall/toolcall-client.h +++ b/toolcall/toolcall-client.h @@ -27,7 +27,9 @@ namespace toolcall client(std::unique_ptr impl) : impl_(std::move(impl)) {} - result_set call(const std::string & request); + result_set call(const std::string & name, + const std::string & arguments, + const std::string & id = ""); std::string tool_list(); bool tool_list_dirty() const; @@ -55,7 +57,9 @@ namespace toolcall return tool_list_dirty_; } - virtual result_set call(const std::string & request) = 0; + virtual result_set call(const std::string & name, + const std::string & arguments, + const std::string & id = "") = 0; const std::string & tool_choice() const { return tool_choice_; } @@ -76,9 +80,11 @@ namespace toolcall return tools_; } - virtual result_set call(const std::string & request) override { + virtual result_set call(const std::string & /* name */, + const std::string & /* arguments */, + const std::string & /* id = "" */) override { return result_set { - {"text", request, "text/plain", std::nullopt, false} + {"text", "", "text/plain", std::nullopt, false} }; } @@ -93,7 +99,10 @@ namespace toolcall mcp_impl(std::vector argv, std::string tool_choice); virtual std::string tool_list() override; - virtual result_set call(const std::string & request) override; + + virtual result_set call(const std::string & name, + const std::string & arguments, + const std::string & id = "") override; virtual void initialize() override; From f5c209ff3081779f9c490a826b150b3c945fa255 Mon Sep 17 00:00:00 2001 From: Mason M Date: Mon, 10 Mar 2025 10:44:29 -0300 Subject: [PATCH 67/69] Sync trigger-token fix ggml-org#12291 --- common/common.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 6f5738e580f21..dd58d49235cf8 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1813,8 +1813,9 @@ void common_chat_grammar_to_sampler(const common_chat_params * src, LOG_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str()); common_grammar_trigger trigger; trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN; - trigger.value = (llama_token) token; - dst.grammar_triggers.push_back(trigger); + trigger.value = word; + trigger.token = token; + dst.grammar_triggers.push_back(std::move(trigger)); } else { LOG_DBG("Grammar trigger word: `%s`\n", word.c_str()); From 4e378fbf96a3bdb603fa963f6499df2207e21160 Mon Sep 17 00:00:00 2001 From: Mason M Date: Mon, 10 Mar 2025 11:46:01 -0300 Subject: [PATCH 68/69] Clear assistant_ss before returning control to loop --- examples/main/main.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 0bc7145e4ce64..f876514911ac9 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -928,6 +928,7 @@ int main(int argc, char ** argv) { if (format_res.tool_was_called) { auto format_res_tok = common_tokenize(ctx, format_res.formatted, false, true); embd_inp.insert(embd_inp.end(), format_res_tok.begin(), format_res_tok.end()); + assistant_ss.str(""); } else { is_interacting = true; From ff18e245da8d2781a899b05ba52f65f449ae8b33 Mon Sep 17 00:00:00 2001 From: Mason M Date: Mon, 10 Mar 2025 11:59:05 -0300 Subject: [PATCH 69/69] Revert changes to common_chat_format_single --- common/chat.cpp | 12 +++--------- common/chat.h | 4 +--- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index 341c83f3dd098..62ca26ad7609c 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -294,11 +294,9 @@ std::string common_chat_format_single( const std::vector & past_msg, const common_chat_msg & new_msg, bool add_ass, - bool use_jinja, - const struct common_chat_templates_inputs * input_extra, - struct common_chat_params * out_params) { + bool use_jinja) { - common_chat_templates_inputs inputs = input_extra ? *input_extra : common_chat_templates_inputs(); + common_chat_templates_inputs inputs; inputs.use_jinja = use_jinja; std::string fmt_past_msg; @@ -315,13 +313,9 @@ std::string common_chat_format_single( // format chat with new_msg inputs.messages.push_back(new_msg); inputs.add_generation_prompt = add_ass; - auto chat_params = common_chat_templates_apply(tmpls, inputs); - auto fmt_new_msg = chat_params.prompt; + auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt; // get the diff part ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); - if (out_params) { - *out_params = std::move(chat_params); - } return ss.str(); } diff --git a/common/chat.h b/common/chat.h index 895d6cfb65d1a..9aad84e880448 100644 --- a/common/chat.h +++ b/common/chat.h @@ -112,9 +112,7 @@ std::string common_chat_format_single( const std::vector & past_msg, const common_chat_msg & new_msg, bool add_ass, - bool use_jinja, - const struct common_chat_templates_inputs * input_extra = nullptr, - struct common_chat_params * out_params = nullptr); + bool use_jinja); // Returns an example of formatted chat std::string common_chat_format_example(