diff --git a/CMakeLists.txt b/CMakeLists.txt index ac3e9090336d9..c533285987ebc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -85,6 +85,9 @@ option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE}) option(LLAMA_CURL "llama: use libcurl to download model from an URL" ON) option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF) +# Toolcall support - needs LLAMA_CURL support to connect with SSE endpoints +option(LLAMA_TOOLCALL "llama: add toolcall support via Model Context Protocol" ON) + # Required for relocatable CMake package include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake) include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/common.cmake) @@ -172,6 +175,11 @@ add_subdirectory(src) if (NOT LLAMA_BUILD_COMMON) message(STATUS "LLAMA_BUILD_COMMON is OFF, disabling LLAMA_CURL") set(LLAMA_CURL OFF) + set(LLAMA_TOOLCALL OFF) +endif() + +if (LLAMA_TOOLCALL) + add_subdirectory(toolcall) endif() if (LLAMA_BUILD_COMMON) diff --git a/common/arg.cpp b/common/arg.cpp index 5ed5a23903332..c1e4ff9dbd007 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2890,6 +2890,24 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.chat_template = read_file(value); } ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE")); + + add_opt(common_arg( + {"--tools"}, "JINJA_TOOLS", + "set to URI of a Model Context Protocol server, or " + "a JSON array containing tool definitions (requires --jinja)", + [](common_params &params, const std::string & value) { + params.toolcall.tools = value; + + }).set_examples({LLAMA_EXAMPLE_MAIN})); + + add_opt(common_arg( + {"--tool-choice"}, "JINJA_TOOL_CHOICE", + "set to \"auto\", \"required\", or \"none\" (default: \"auto\")", + [](common_params &params, const std::string & value) { + params.toolcall.choice = value; + + }).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( {"--no-prefill-assistant"}, 
string_format( diff --git a/common/common.cpp b/common/common.cpp index 2afa9b2d641d4..f8b63676e82da 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -8,6 +8,7 @@ #include "common.h" #include "log.h" #include "llama.h" +#include "chat.h" #include #include @@ -1329,6 +1330,67 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto return text; } +void common_chat_grammar_to_sampler(const common_chat_params * src, + const llama_vocab * vocab, + common_params_sampling * sparams) +{ + GGML_ASSERT(src && vocab && sparams); + + auto & dst = *sparams; + + dst.grammar = src->grammar; + dst.grammar_lazy = src->grammar_lazy; + + for (const auto & preserved : src->preserved_tokens) { + auto ids = common_tokenize(vocab, preserved, false, true); + if (ids.size() == 1) { + LOG_DBG("Preserved token: %d\n", ids[0]); + dst.preserved_tokens.insert(ids[0]); + + } else { + // This may happen when using a tool call style meant for a model + // with special tokens to preserve on a model without said tokens. 
+ LOG_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", + preserved.c_str()); + } + } + + for (const auto & trigger : src->grammar_triggers) { + if (trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { + const auto & word = trigger.value; + auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true); + + if (ids.size() == 1) { + auto token = ids[0]; + auto found = std::find(dst.preserved_tokens.begin(), dst.preserved_tokens.end(), + (llama_token) token); + + if (found == dst.preserved_tokens.end()) { + throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word); + } + + LOG_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str()); + common_grammar_trigger trigger; + trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN; + trigger.value = word; + trigger.token = token; + dst.grammar_triggers.push_back(std::move(trigger)); + + } else { + LOG_DBG("Grammar trigger word: `%s`\n", word.c_str()); + dst.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word}); + } + + } else { + dst.grammar_triggers.push_back(trigger); + } + } + if (dst.grammar_lazy && dst.grammar_triggers.empty()) { + throw std::runtime_error("Error: no triggers set for lazy grammar!"); + } +} + + // // Embedding utils // diff --git a/common/common.h b/common/common.h index 92b9533fc2948..e1a442394066c 100644 --- a/common/common.h +++ b/common/common.h @@ -218,6 +218,11 @@ enum common_reasoning_format { COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content` }; +struct common_toolcall_params { + std::string tools = ""; + std::string choice = "auto"; +}; + struct common_params { int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 4096; // context size @@ -367,6 +372,9 @@ struct common_params { std::string chat_template = ""; // NOLINT bool use_jinja = false; // NOLINT bool enable_chat_template = true; + + struct 
common_toolcall_params toolcall; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; int reasoning_budget = -1; bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response @@ -622,6 +630,12 @@ std::string common_detokenize( const std::vector & tokens, bool special = true); +struct common_chat_params; +void common_chat_grammar_to_sampler(const common_chat_params * src, + const llama_vocab * vocab, + common_params_sampling * sparams); + + // // Embedding utils // diff --git a/toolcall/CMakeLists.txt b/toolcall/CMakeLists.txt new file mode 100644 index 0000000000000..5fca0aee94fde --- /dev/null +++ b/toolcall/CMakeLists.txt @@ -0,0 +1,37 @@ + +set(TARGET toolcall) + +set(SOURCES + client.cpp + mcp_messages.cpp + mcp_stdio_transport.cpp + params.cpp +) + +set(HEADERS + toolcall-params.h + toolcall-client.h + mcp_transport.h + mcp_messages.h + mcp_stdio_transport.h +) + +add_library(${TARGET} STATIC ${SOURCES} ${HEADERS}) + +target_include_directories(${TARGET} # Right now only for "json.hpp" + PRIVATE $) + +if (LLAMA_CURL) + find_package(CURL REQUIRED) + target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_CURL) + include_directories(${CURL_INCLUDE_DIRS}) + find_library(CURL_LIBRARY curl REQUIRED) + + target_link_libraries(${TARGET} PRIVATE ${CURL_LIBRARY}) + target_sources(${TARGET} PRIVATE mcp_sse_transport.cpp mcp_sse_transport.h) + +endif() + +target_compile_definitions(${TARGET} INTERFACE LLAMA_USE_TOOLCALL) +target_include_directories(${TARGET} PUBLIC $) +target_compile_features (${TARGET} PUBLIC cxx_std_17) diff --git a/toolcall/client.cpp b/toolcall/client.cpp new file mode 100644 index 0000000000000..74394d16c7218 --- /dev/null +++ b/toolcall/client.cpp @@ -0,0 +1,223 @@ + +#include +#include "toolcall-client.h" +#include +#include + +#ifdef LLAMA_USE_CURL +# include "mcp_sse_transport.h" +#endif + +#include "mcp_stdio_transport.h" + +using json = nlohmann::json; + +std::shared_ptr 
toolcall::create_client(const toolcall::params & params) { + std::shared_ptr client; + + auto tools = params.tools(); + auto choice = params.choice(); + if (params) { + if (params.has_uri()) { +#ifdef LLAMA_USE_CURL + client.reset(new toolcall::client( + std::make_unique(tools, choice))); +#endif + } else { + client.reset(new toolcall::client( + std::make_unique(tools, choice))); + } + } + return client; +} + +std::string toolcall::client::tool_list() { + return impl_->tool_list(); +} + +bool toolcall::client::tool_list_dirty() const { + return impl_->tool_list_dirty(); +} + +toolcall::result_set toolcall::client::call(const std::string & name, + const std::string & arguments, + const std::string & id) { + return impl_->call(name, arguments, id); +} + +const std::string & toolcall::client::tool_choice() const { + return impl_->tool_choice(); +} + +void toolcall::client::initialize() { + impl_->initialize(); +} + +#ifdef LLAMA_USE_CURL +toolcall::mcp_impl::mcp_impl(std::string server_uri, std::string tool_choice) + : client_impl(tool_choice), + transport_(new mcp_sse_transport(server_uri)), + tools_("[]"), + tools_mutex_(), + tools_populating_(), + next_id_(1) +{ +} +#else +toolcall::mcp_impl::mcp_impl(std::string /*server_uri*/, std::string tool_choice) + : client_impl(tool_choice), + transport_(nullptr), + tools_("[]"), + tools_mutex_(), + tools_populating_(), + next_id_(1) +{ +} +#endif + +toolcall::mcp_impl::mcp_impl(std::vector argv, std::string tool_choice) + : client_impl(tool_choice), + transport_(new mcp_stdio_transport(argv)) +{ +} + +void toolcall::mcp_impl::initialize() { + using on_response = toolcall::callback; + using on_list_changed = toolcall::callback; + + if (transport_ == nullptr) return; + std::unique_lock lock(tools_mutex_); + + transport_->start(); + + bool caps_received = false; + mcp::capabilities caps; + on_response set_caps = [this, &caps, &caps_received] (const mcp::initialize_response & resp) { + std::unique_lock lock(tools_mutex_); + 
caps = resp.capabilities(); + caps_received = true; + tools_populating_.notify_one(); + }; + + transport_->send(mcp::initialize_request(next_id_++), set_caps); + tools_populating_.wait_for(lock, std::chrono::seconds(15), [&caps_received] { return caps_received; }); + + on_list_changed update_dirty = [&update_dirty, this] (const mcp::tools_list_changed_notification &) { + tool_list_dirty_ = true; + transport_->subscribe(update_dirty); + }; + + bool has_tools = false; + for (const auto & cap : caps) { + if (cap.name == "tools") { + has_tools = true; + if (cap.listChanged) { + transport_->subscribe(update_dirty); + } + break; + } + } + if (! has_tools) { + throw std::runtime_error("MCP server does not support toolcalls!"); + } + + transport_->send(mcp::initialized_notification()); +} + +static std::string tools_list_to_oai_json(const mcp::tools_list & tools) { + json tool_list = json::array(); + for (const auto & tool : tools) { + json t = json::object(); + + t["type"] = "function"; + t["function"]["name"] = tool.tool_name; + t["function"]["description"] = tool.tool_description; + + json props = json::object(); + for (const auto & param : tool.params) { + props[param.name]["type"] = param.type; + props[param.name]["description"] = param.description; + } + t["function"]["parameters"]["type"] = "object"; + t["function"]["parameters"]["properties"] = props; + + json required = json::array(); + for (const auto & name : tool.required_params) { + required.push_back(name); + } + t["function"]["required"] = required; + + tool_list.push_back(t); + } + + return tool_list.dump(-1); +} + +std::string toolcall::mcp_impl::tool_list() { + using on_response = toolcall::callback; + + if (tool_list_dirty_) { + std::unique_lock lock(tools_mutex_); + + mcp::tools_list tools; + on_response set_tools = [this, &tools, &set_tools] (const mcp::tools_list_response & resp) { + std::unique_lock lock(tools_mutex_); + + tools.insert(tools.end(), resp.tools().begin(), resp.tools().end()); + auto 
cursor = resp.next_cursor(); + if (! cursor.empty()) { + transport_->send(mcp::tools_list_request(next_id_++, cursor), set_tools); + return; + } + tool_list_dirty_ = false; + lock.unlock(); + tools_populating_.notify_one(); + }; + + transport_->send(mcp::tools_list_request(next_id_++), set_tools); + tools_populating_.wait_for(lock, std::chrono::seconds(15), [this] { return ! tool_list_dirty_; }); + + tools_ = tools_list_to_oai_json(tools); + } + return tools_; +} + +static toolcall::result_set tools_call_response_to_result(const mcp::tools_call_response & resp) { + toolcall::result_set result; + for (const auto & res : resp.tool_result()) { + result.push_back(toolcall::result{ + res.type, res.value, res.mime_type.value_or("text/plain"), res.uri, resp.tool_error() + }); + } + return std::move(result); +} + +toolcall::result_set toolcall::mcp_impl::call(const std::string & name, + const std::string & arguments, + const std::string & id) +{ + using on_response = toolcall::callback; + + if (transport_ == nullptr) { + return toolcall::result_set(); + } + std::unique_lock lock(tools_mutex_); + + toolcall::result_set response; + on_response set_response = [this, &response] (const mcp::tools_call_response & resp) { + std::unique_lock lock(tools_mutex_); + response = tools_call_response_to_result(resp); + tools_populating_.notify_one(); + }; + std::string req_id = id.empty() ? std::to_string(next_id_++) : id; + mcp::tool_arg_list req_args; + auto json_args = json::parse(arguments); // TODO check errors + for (const auto & [key, val] : json_args.items()) { + req_args.push_back({key, val}); + } + + transport_->send(mcp::tools_call_request(req_id, name, req_args), set_response); + tools_populating_.wait_for(lock, std::chrono::seconds(15), [&response] { return ! 
response.empty(); }); + + return response; +} diff --git a/toolcall/mcp_messages.cpp b/toolcall/mcp_messages.cpp new file mode 100644 index 0000000000000..3dc9c9a14759c --- /dev/null +++ b/toolcall/mcp_messages.cpp @@ -0,0 +1,392 @@ +#include "mcp_messages.h" +#include +#include +#include + +using json = nlohmann::json; + +const std::string mcp::JsonRpcVersion = "2.0"; +const std::string mcp::McpVersion = "2024-11-05"; +const std::string mcp::ClientVersion = "1.0.0"; +const std::string mcp::ClientName = "llama.cpp"; + +json mcp::request::toJson() const { + json j; + j["jsonrpc"] = JsonRpcVersion; + if (id()) { + j["id"] = id().value(); + } + j["method"] = method(); + if (params()) { + j["params"] = params().value(); + } + return j; +} + +json mcp::response::error::toJson() const { + json j; + j["code"] = code; + j["message"] = message; + if (data) { + j["data"] = data.value(); + } + return j; +} + +json mcp::response::toJson() const { + json j; + j["jsonrpc"] = JsonRpcVersion; + if (id()) { + j["id"] = id().value(); + } + if (result()) { + j["result"] = result().value(); + } else if (getError()) { + j["error"] = getError()->toJson(); + } + return j; +} + +json mcp::notification::toJson() const { + json j; + j["jsonrpc"] = JsonRpcVersion; + j["method"] = method(); + if (params()) { + j["params"] = params().value(); + } + return j; +} + +mcp::initialize_request::initialize_request(nlohmann::json id, mcp::capabilities caps) + : request(id, "initialize"), caps_(std::move(caps)) +{ + refreshParams(); +} + +void mcp::initialize_request::refreshParams() { + json params; + params["protocolVersion"] = protoVersion(); + params["clientInfo"]["name"] = name(); + params["clientInfo"]["version"] = version(); + + json capabilities = json::object(); + for (auto cap = caps_.cbegin(); cap != caps_.cend(); ++cap) { + json cap_json; + if (cap->subscribe) { + cap_json["subscribe"] = true; + } + if (cap->listChanged) { + cap_json["listChanged"] = true; + } + capabilities[cap->name] = 
cap_json; + } + params["capabilities"] = capabilities; + + this->params(std::move(params)); +} + +void mcp::initialize_request::capabilities(mcp::capabilities caps) { + caps_ = std::move(caps); + refreshParams(); +} + +const mcp::capabilities & mcp::initialize_request::capabilities() const { + return caps_; +} + +mcp::initialize_response::initialize_response( + nlohmann::json id, std::string name, std::string version, std::string protoVersion, + mcp::capabilities caps) + : response(id), name_(std::move(name)), version_(std::move(version)), + protoVersion_(std::move(protoVersion)), caps_(std::move(caps)) +{ + refreshResult(); +} + +void mcp::initialize_response::refreshResult() { + json result; + result["protocolVersion"] = protoVersion(); + result["serverInfo"]["name"] = name(); + result["serverInfo"]["version"] = version(); + + json capabilities = json::object(); + for (auto cap = caps_.cbegin(); cap != caps_.cend(); ++cap) { + json cap_json; + if (cap->subscribe) { + cap_json["subscribe"] = true; + } + if (cap->listChanged) { + cap_json["listChanged"] = true; + } + capabilities[cap->name] = cap_json; + } + result["capabilities"] = capabilities; + + this->result(std::move(result)); +} + +void mcp::initialize_response::name(std::string name) { + name_ = std::move(name); + refreshResult(); +} + +const std::string & mcp::initialize_response::name() const { + return name_; +} + +void mcp::initialize_response::version(std::string version) { + version_ = std::move(version); + refreshResult(); +} + +const std::string & mcp::initialize_response::version() const { + return version_; +} + +void mcp::initialize_response::protoVersion(std::string protoVersion) { + protoVersion_ = std::move(protoVersion); + refreshResult(); +} + +const std::string & mcp::initialize_response::protoVersion() const { + return protoVersion_; +} + +void mcp::initialize_response::capabilities(mcp::capabilities caps) { + caps_ = std::move(caps); + refreshResult(); +} + +const mcp::capabilities & 
mcp::initialize_response::capabilities() const { + return caps_; +} + +mcp::initialize_response mcp::initialize_response::fromJson(const nlohmann::json& j) { + std::string name = j["result"]["serverInfo"]["name"]; + std::string version = j["result"]["serverInfo"]["version"]; + std::string protoVersion = j["result"]["protocolVersion"]; + + mcp::capabilities caps; + if (j["result"].contains("capabilities")) { + for (const auto& [key, value] : j["result"]["capabilities"].items()) { + capability cap; + cap.name = key; + cap.subscribe = value.value("subscribe", false); + cap.listChanged = value.value("listChanged", false); + caps.push_back(cap); + } + } + + return initialize_response(j["id"], name, version, protoVersion, caps); +} + +mcp::tools_list_request::tools_list_request(std::optional id, std::string cursor) + : request(id, Method), + cursor_(std::move(cursor)) +{ + refreshParams(); +} + +void mcp::tools_list_request::cursor(std::string cursor) { + cursor_ = std::move(cursor); + refreshParams(); +} + +void mcp::tools_list_request::refreshParams() { + if (! 
cursor_.empty()) { + json params; + params["cursor"] = cursor_; + this->params(params); + } +} + +mcp::tools_list_response::tools_list_response(nlohmann::json id, + mcp::tools_list tools, + std::string next_cursor) + : response(id), + tools_(std::move(tools)), + next_cursor_(std::move(next_cursor)) +{ + refreshResult(); +} + +void mcp::tools_list_response::tools(mcp::tools_list tools) { + tools_ = std::move(tools); + refreshResult(); +} + +void mcp::tools_list_response::next_cursor(std::string next_cursor) { + next_cursor_ = std::move(next_cursor); + refreshResult(); +} + +void mcp::tools_list_response::refreshResult() { + json result; + + json tools = json::array(); + for (const auto & tool : tools_) { + json t; + + t["name"] = tool.tool_name; + t["description"] = tool.tool_description; + t["inputSchema"]["type"] = "object"; + + json props; + for (const auto & param : tool.params) { + props[param.name] = { + {"type"}, {param.type}, + {"description"}, {param.description} + }; + } + t["inputSchema"]["properties"] = props; + + json required = json::array(); + for (const auto & req_param : tool.required_params) { + required.push_back(req_param); + } + t["inputSchema"]["required"] = required; + + tools.push_back(t); + } + result["tools"] = tools; + + if (! 
next_cursor_.empty()) { + result["nextCursor"] = next_cursor_; + } + + this->result(result); +} + +mcp::tools_list_response mcp::tools_list_response::fromJson(const nlohmann::json & j) { + mcp::tools_list tools; + for (const auto & t : j["result"]["tools"]) { + mcp::tool tool; + tool.tool_name = t["name"]; + tool.tool_description = t["description"]; + for (const auto & [key, value] : t["inputSchema"]["properties"].items()) { + mcp::tool::param param; + param.name = key; + param.type = value["type"]; + param.description = value["description"]; + tool.params.push_back(param); + } + if (t["inputSchema"].contains("required") && t["inputSchema"]["required"].is_array()) { + for (const auto & required : t["inputSchema"]["required"]) { + tool.required_params.push_back(required); + } + } + tools.push_back(std::move(tool)); + } + std::string next_cursor = j["result"].value("nextCursor", ""); + return tools_list_response(j["id"], std::move(tools), next_cursor); +} + +mcp::tools_list_changed_notification mcp::tools_list_changed_notification::fromJson(const nlohmann::json & j) { + if (! (j.is_object() && j.contains("method") && j["method"] == Method)) { + throw std::invalid_argument("Invalid tools_list_changed message"); + } + return tools_list_changed_notification(); +} + +mcp::tools_call_request::tools_call_request(nlohmann::json id, std::string name, tool_arg_list args) + : request(id, Method), name_(std::move(name)), args_(std::move(args)) +{ + refreshParams(); +} + +void mcp::tools_call_request::name(std::string name) { + name_ = std::move(name); + refreshParams(); +} + +void mcp::tools_call_request::args(mcp::tool_arg_list args) { + args_ = std::move(args); + refreshParams(); +} + +void mcp::tools_call_request::refreshParams() { + json params = json::object(); + params["name"] = name_; + if (! 
args_.empty()) { + json args = json::object(); + for (const auto & arg : args_) { + args[arg.name] = arg.value; + } + params["arguments"] = args; + } + this->params(params); +} + +mcp::tools_call_response::tools_call_response(nlohmann::json id, tool_result_list result, bool error) + : response(id), tool_result_(std::move(result)), error_(error) +{ + refreshResult(); +} + +void mcp::tools_call_response::tool_result(mcp::tool_result_list result) { + tool_result_ = std::move(result); + refreshResult(); +} + +void mcp::tools_call_response::tool_error(bool error) { + error_ = error; + refreshResult(); +} + +void mcp::tools_call_response::refreshResult() { + json result = json::object(); + result["isError"] = error_; + json content = json::array(); + for (const auto & res : tool_result_) { + json r; + r["type"] = res.type; + if (res.type == "text") { + r["text"] = res.value; + + } else if (res.type == "image" || res.type == "audio") { + r["data"] = res.value; + r["mimeType"] = res.mime_type.value(); // throws + + } else if (res.type == "resource") { + json rr; + rr["uri"] = res.uri.value(); // throws + rr["mimeType"] = res.mime_type.value(); //throws + rr["text"] = res.value; + + r["resource"] = rr; + + } else { + // throw + } + content.push_back(r); + } + result["content"] = content; + this->result(std::move(result)); +} + +mcp::tools_call_response mcp::tools_call_response::fromJson(const nlohmann::json & j) { + mcp::tool_result_list result_list; + for (const auto & content : j["result"]["content"]) { + mcp::tool_result result; + + result.type = content["type"]; + if (content["type"] == "text") { + result.value = content["text"]; + + } else if (content["type"] == "image" || content["type"] == "audio") { + result.value = content["data"]; + result.mime_type = content["mimeType"]; + + } else if (content["type"] == "resource") { + result.value = content["resource"]["text"]; + result.mime_type = content["resource"]["mimeType"]; + result.uri = content["resource"]["uri"]; + } 
+ + result_list.push_back(std::move(result)); + } + + bool error = j["result"].value("isError", false); + + return mcp::tools_call_response(j["id"], std::move(result_list), error); +} diff --git a/toolcall/mcp_messages.h b/toolcall/mcp_messages.h new file mode 100644 index 0000000000000..ad20b45fef5f9 --- /dev/null +++ b/toolcall/mcp_messages.h @@ -0,0 +1,283 @@ +#include +#include +#include +#include +#include + +namespace mcp +{ + extern const std::string JsonRpcVersion; + extern const std::string McpVersion; + extern const std::string ClientVersion; + extern const std::string ClientName; + + template + class message { + public: + message(std::optional id = std::nullopt) + : id_(std::move(id)) {} + + nlohmann::json toJson() const { + return static_cast(this)->toJson(); + } + + void id(std::optional id) { + id_ = std::move(id); + } + + const std::optional & id() const { + return id_; + } + + private: + std::optional id_; + }; + + class request : public message { + public: + request(std::optional id, + std::string method, + std::optional params = std::nullopt) + + : message(id), + method_(std::move(method)), + params_(std::move(params)) {} + + void method(std::string method) { method_ = std::move(method); } + const std::string & method() const { return method_; } + + void params(std::optional params) { params_ = std::move(params); } + const std::optional & params() const { return params_; } + + nlohmann::json toJson() const; + + private: + std::string method_; + std::optional params_; + }; + + class response : public message { + public: + struct error { + int code; + std::string message; + std::optional data; + nlohmann::json toJson() const; + }; + + response(std::optional id, + std::optional result = std::nullopt, + std::optional error = std::nullopt) + + : message(id), + result_(std::move(result)), + error_(std::move(error)) {} + + void result(std::optional result) { result_ = std::move(result); } + const std::optional & result() const { return result_; } + + 
void setError(std::optional error) { error_ = std::move(error); } + const std::optional & getError() const { return error_; } + + nlohmann::json toJson() const; + + private: + std::optional result_; + std::optional error_; + }; + + class notification : public message { + public: + notification(std::string method, + std::optional params = std::nullopt) + + : message(), + method_(method), + params_(params) {} + + void method(std::string method) { method_ = std::move(method); } + const std::string & method() const { return method_; } + + void params(std::optional params) { params_ = std::move(params); } + const std::optional & params() const { return params_; } + + nlohmann::json toJson() const; + + private: + std::string method_; + std::optional params_; + }; + + struct capability { + std::string name; + bool subscribe = false; + bool listChanged = false; + }; + + using capabilities = std::vector; + + class initialize_request : public request { + public: + initialize_request(nlohmann::json id, mcp::capabilities caps = mcp::capabilities{}); + + const std::string & name() const { return ClientName; } + const std::string & version() const { return ClientVersion; } + const std::string & protoVersion() const { return McpVersion; } + + void capabilities(mcp::capabilities capabilities); + const mcp::capabilities & capabilities() const; + + private: + void refreshParams(); + + mcp::capabilities caps_; + }; + + class initialize_response : public response { + public: + initialize_response(nlohmann::json id, + std::string name, + std::string version, + std::string protoVersion, + mcp::capabilities caps); + + void name(std::string name); + const std::string & name() const; + + void version(std::string version); + const std::string & version() const; + + void protoVersion(std::string protoVersion); + const std::string & protoVersion() const; + + void capabilities(mcp::capabilities capabilities); + const mcp::capabilities & capabilities() const; + + static initialize_response 
fromJson(const nlohmann::json& j); + + private: + void refreshResult(); + + std::string name_; + std::string version_; + std::string protoVersion_; + mcp::capabilities caps_; + }; + + class initialized_notification : public notification { + public: + static inline const std::string Method = "notifications/initialized"; + + initialized_notification() : notification(Method) {} + }; + + class tools_list_request : public request { + public: + static inline const std::string Method = "tools/list"; + + tools_list_request(std::optional id, std::string cursor = ""); + + void cursor(std::string cursor); + const std::string & cursor() { return cursor_; } + + private: + void refreshParams(); + std::string cursor_; + }; + + struct tool { + struct param { + std::string name; + std::string type; + std::string description; + }; + std::string tool_name; + std::string tool_description; + std::vector params; + std::vector required_params; + }; + + using tools_list = std::vector; + + class tools_list_response : public response { + public: + tools_list_response(nlohmann::json id, + tools_list tools = tools_list(), + std::string next_cursor = ""); + + void tools(tools_list tools); + const tools_list & tools() const { return tools_; } + + void next_cursor(std::string next_cursor); + const std::string & next_cursor() const { return next_cursor_; } + + static tools_list_response fromJson(const nlohmann::json & j); + + private: + void refreshResult(); + tools_list tools_; + std::string next_cursor_; + }; + + class tools_list_changed_notification : public notification { + public: + static inline const std::string Method = "notifications/tools/list_changed"; + + tools_list_changed_notification() : notification(Method) {} + + static tools_list_changed_notification fromJson(const nlohmann::json & j); + }; + + struct tool_arg { + std::string name; + nlohmann::json value; + }; + + using tool_arg_list = std::vector; + + class tools_call_request : public request { + public: + static inline const 
std::string Method = "tools/call"; + + tools_call_request(nlohmann::json id, std::string name, tool_arg_list args = tool_arg_list()); + + void name(std::string name); + const std::string & name() const { return name_; } + + void args(tool_arg_list args); + const tool_arg_list args() const { return args_; } + + private: + void refreshParams(); + std::string name_; + tool_arg_list args_; + }; + + struct tool_result { + std::string type; // text, image, audio, or resource + std::string value; + std::optional mime_type; // Only for: image, audio, and resource + std::optional uri; // Only for: resource + }; + + using tool_result_list = std::vector; + + class tools_call_response : public response { + public: + tools_call_response(nlohmann::json id, tool_result_list result = tool_result_list(), bool error = false); + + void tool_result(tool_result_list result); + const tool_result_list & tool_result() const { return tool_result_; } + + void tool_error(bool error); + bool tool_error() const { return error_; } + + static tools_call_response fromJson(const nlohmann::json & j); + + private: + void refreshResult(); + tool_result_list tool_result_; + bool error_; + }; + +} diff --git a/toolcall/mcp_sse_transport.cpp b/toolcall/mcp_sse_transport.cpp new file mode 100644 index 0000000000000..3dfbffc49f780 --- /dev/null +++ b/toolcall/mcp_sse_transport.cpp @@ -0,0 +1,333 @@ + +#include "mcp_sse_transport.h" +#include +#include +#include + +const int toolcall::mcp_sse_transport::EndpointReceivedTimoutSeconds = 5; +const int toolcall::mcp_sse_transport::StartTimeoutSeconds = 8; + +static bool starts_with(const std::string & str, const std::string & prefix) { + return str.size() >= prefix.size() + && str.compare(0, prefix.size(), prefix) == 0; +} + +toolcall::mcp_sse_transport::~mcp_sse_transport() { + if (endpoint_headers_) { + curl_slist_free_all(endpoint_headers_); + } + if (endpoint_) { + curl_easy_cleanup(endpoint_); + } +} + 
+toolcall::mcp_sse_transport::mcp_sse_transport(std::string server_uri) + : server_uri_(std::move(server_uri)), + running_(false), + sse_thread_(), + endpoint_(nullptr), + endpoint_headers_(nullptr), + endpoint_errbuf_(CURL_ERROR_SIZE, '\0'), + event_{"", "", ""}, + sse_buffer_(""), + sse_cursor_(0), + sse_last_id_(""), + mutex_(), + cv_() +{ + curl_global_init(CURL_GLOBAL_DEFAULT); +} + +void toolcall::mcp_sse_transport::start() { + if (running_) return; + running_ = true; + + std::unique_lock lock(mutex_); + sse_thread_ = std::thread(&toolcall::mcp_sse_transport::sse_run, this); + cv_.wait_for( + lock, std::chrono::seconds(StartTimeoutSeconds), [this] { return endpoint_ != nullptr; }); + + if (endpoint_ == nullptr) { + running_ = false; + LOG_ERR("SSE: Connection to \"%s\" failed", server_uri_.c_str()); + throw std::runtime_error("Connection to \"" + server_uri_ + "\" failed"); + } +} + +void toolcall::mcp_sse_transport::stop() { + running_ = false; +} + +bool toolcall::mcp_sse_transport::send(const std::string & request_json) { + std::lock_guard lock(mutex_); + if (! running_ || endpoint_ == nullptr) { + return false; + } + + curl_easy_setopt(endpoint_, CURLOPT_POSTFIELDS, request_json.c_str()); + + CURLcode code = curl_easy_perform(endpoint_); + if (code != CURLE_OK) { + size_t len = strlen(&endpoint_errbuf_[0]); + LOG_ERR("%s", (len > 0 ? &endpoint_errbuf_[0] : curl_easy_strerror(code))); + return false; + } + return true; +} + +static size_t sse_callback(char * data, size_t size, size_t nmemb, void * clientp) { + auto transport = static_cast(clientp); + size_t len = size * nmemb; + return transport->sse_read(data, len); +} + +void toolcall::mcp_sse_transport::parse_field_value(std::string field, std::string value) { + LOG_DBG("SSE: field \"%s\"; value \"%s\"\n", field.c_str(), value.c_str()); + + if (field == "event") { + // Set the event type buffer to field value. 
+ event_.type = std::move(value); + + } else if (field == "data") { + // Append the field value to the data buffer, + // then append a single U+000A LINE FEED (LF) + // character to the data buffer. + value += '\n'; + event_.data.insert(event_.data.end(), value.begin(), value.end()); + + } else if (field == "id") { + // If the field value does not contain U+0000 NULL, + // then set the last event ID buffer to the field value. + // Otherwise, ignore the field. + if (! value.empty()) { + event_.id = std::move(value); + } + + } else if (field == "retry") { + // If the field value consists of only ASCII digits, + // then interpret the field value as an integer in base + // ten, and set the event stream's reconnection time to + // that integer. Otherwise, ignore the field. + + LOG_INF("SSE: Retry field is not currently implemented"); + + } else { + LOG_WRN("SSE: Unsupported field \"%s\" received", field.c_str()); + } +} + +static std::string uri_base (const std::string & uri) { + std::regex uri_regex(R"(^(https?:\/\/[^\/?#]+))"); + std::smatch match; + if (std::regex_search(uri, match, uri_regex)) { + return match[1]; + } + return uri; +} + +void toolcall::mcp_sse_transport::on_endpoint_event() { + LOG_DBG("on_endpoint_event\n"); + event_.data.erase(event_.data.end() - 1); // Event data has trailing newline that will impact URI + + endpoint_ = curl_easy_init(); + if (! 
endpoint_) { + LOG_ERR("SSE: Failed to create endpoint handle"); + running_ = false; + return; + } + + std::string endpoint_uri; + bool is_absolute = starts_with(event_.data, "http"); + if (is_absolute) { + endpoint_uri = event_.data; + + } else { + endpoint_uri = uri_base(server_uri_); + if (event_.data[0] != '/') { + endpoint_uri += '/'; + } + endpoint_uri += event_.data; + } + + LOG_INF("SSE: using endpoint \"%s\"\n", endpoint_uri.c_str()); + curl_easy_setopt(endpoint_, CURLOPT_URL, endpoint_uri.c_str()); + + endpoint_headers_ = + curl_slist_append(endpoint_headers_, "Content-Type: application/json"); + curl_slist_append(endpoint_headers_, "Connection: keep-alive"); + curl_easy_setopt(endpoint_, CURLOPT_HTTPHEADER, endpoint_headers_); + curl_easy_setopt(endpoint_, CURLOPT_ERRORBUFFER, &endpoint_errbuf_[0]); + + // Later calls to send will reuse the endpoint_ handle +} + +void toolcall::mcp_sse_transport::on_message_event() { + try { + nlohmann::json message = nlohmann::json::parse(event_.data); + notify(message); + + } catch (const nlohmann::json::exception & err) { + LOG_WRN("SSE: Invalid message \"%s\" received: \"%s\"\n", event_.data.c_str(), err.what()); + } +} + +size_t toolcall::mcp_sse_transport::sse_read(const char * data, size_t len) { + sse_buffer_.insert(sse_buffer_.end(), data, data + len); + + while (sse_cursor_ < sse_buffer_.length()) { + bool last_was_cr = sse_buffer_[sse_cursor_] == '\r'; + bool last_was_lf = sse_buffer_[sse_cursor_] == '\n'; + + if (last_was_cr || last_was_lf) { + auto last = sse_buffer_.begin() + sse_cursor_; + + std::string line(sse_buffer_.begin(), last); + if (line.empty()) { // Dispatch event + if (event_.type == "endpoint") { + if (! 
endpoint_) { + on_endpoint_event(); + } + + } else if(event_.type == "message") { + on_message_event(); + + } else { + LOG_WRN("SSE: Unsupported event \"%s\" received", event_.type.c_str()); + } + + sse_last_id_ = event_.id; + event_ = {"", "", ""}; + + } else if(line[0] != ':') { + // Comments begin with ":" and + // Field/Value pairs are delimited by ":" + + auto sep_index = line.find(':'); + if (sep_index != std::string::npos) { + auto sep_i = line.begin() + sep_index; + auto val_i = sep_i + (*(sep_i + 1) == ' ' ? 2 : 1); // If value starts with a U+0020 SPACE + // character, remove it from value. + std::string field (line.begin(), sep_i); + std::string value (val_i, line.end()); + + parse_field_value(std::move(field), std::move(value)); + } + } + + if (last++ != sse_buffer_.end()) { // Consume line-end + if (last_was_cr && *last == '\n') { + last ++; + } + sse_buffer_ = std::string(last, sse_buffer_.end()); + + } else { + sse_buffer_.clear(); + } + sse_cursor_ = 0; // Prepare to scan for next line-end + + } else { + sse_cursor_ ++; + } + } + return len; +} + +void toolcall::mcp_sse_transport::sse_run() { + using namespace std::chrono; + + std::unique_lock lock(mutex_); + + char errbuf[CURL_ERROR_SIZE]; + size_t errlen; + CURLMcode mcode; + int num_handles; + struct CURLMsg * m; + int msgs_in_queue = 0; + CURLM * async_handle = nullptr; + struct curl_slist * headers = nullptr; + CURL * sse = nullptr; + steady_clock::time_point start; + + sse = curl_easy_init(); + if (! 
sse) { + LOG_ERR("SSE: Failed to initialize handle"); + goto cleanup; + } + + headers = curl_slist_append(headers, "Connection: keep-alive"); + + curl_easy_setopt(sse, CURLOPT_HTTPHEADER, headers); + curl_easy_setopt(sse, CURLOPT_ERRORBUFFER, errbuf); + curl_easy_setopt(sse, CURLOPT_URL, server_uri_.c_str()); + curl_easy_setopt(sse, CURLOPT_TCP_KEEPALIVE, 1L); + curl_easy_setopt(sse, CURLOPT_WRITEFUNCTION, sse_callback); + curl_easy_setopt(sse, CURLOPT_WRITEDATA, this); + + async_handle = curl_multi_init(); + if (! async_handle) { + LOG_ERR("SSE: Failed to initialize async handle"); + goto cleanup; + } + curl_multi_add_handle(async_handle, sse); + + start = steady_clock::now(); + do { + std::this_thread::sleep_for(milliseconds(50)); + + mcode = curl_multi_perform(async_handle, &num_handles); + if (mcode != CURLM_OK) { + LOG_ERR("SSE: %s", curl_multi_strerror(mcode)); + break; + } + while ((m = curl_multi_info_read(async_handle, &msgs_in_queue)) != nullptr) { + if (m->msg == CURLMSG_DONE) { + if (m->data.result != CURLE_OK) { + errlen = strlen(errbuf); + if (errlen) { + LOG_ERR("SSE: %s", errbuf); + + } else { + LOG_ERR("SSE: %s", curl_easy_strerror(m->data.result)); + } + running_ = false; + break; + } + } + } + + if (endpoint_) { + if (lock.owns_lock()) { + lock.unlock(); + cv_.notify_one(); + } + + } else { + if (steady_clock::now() - start >= seconds(EndpointReceivedTimoutSeconds)) { + running_ = false; + } + } + + } while (running_); + + cleanup: + if (headers) { + curl_slist_free_all(headers); + } + if (async_handle) { + curl_multi_remove_handle(async_handle, sse); + curl_multi_cleanup(async_handle); + } + if (sse) { + curl_easy_cleanup(sse); + } + + if (! 
lock.owns_lock()) + lock.lock(); // Wait for pending send calls to complete + + if (endpoint_headers_) { + curl_slist_free_all(endpoint_headers_); + } + if (endpoint_) { + curl_easy_cleanup(endpoint_); + } +} diff --git a/toolcall/mcp_sse_transport.h b/toolcall/mcp_sse_transport.h new file mode 100644 index 0000000000000..ebe40919a3f63 --- /dev/null +++ b/toolcall/mcp_sse_transport.h @@ -0,0 +1,52 @@ +#pragma once + +#include "mcp_transport.h" +#include +#include +#include +#include + +namespace toolcall +{ + class mcp_sse_transport : public mcp_transport { + public: + ~mcp_sse_transport(); + + mcp_sse_transport(std::string server_uri); + + virtual void start() override; + virtual void stop() override; + virtual bool send(const std::string & request_json) override; + + size_t sse_read(const char * data, size_t len); + + private: + static const int EndpointReceivedTimoutSeconds; + static const int StartTimeoutSeconds; + + void sse_run(); + void parse_field_value(std::string field, std::string value); + void on_endpoint_event(); + void on_message_event(); + + std::string server_uri_; + bool running_; + std::thread sse_thread_; + CURL * endpoint_; + struct curl_slist * endpoint_headers_; + std::vector endpoint_errbuf_; + + struct sse_event { + std::string type; + std::string data; + std::string id; + } event_; + + std::string sse_buffer_; + size_t sse_cursor_; + std::string sse_last_id_; + + std::mutex mutex_; + std::condition_variable cv_; + }; +} diff --git a/toolcall/mcp_stdio_transport.cpp b/toolcall/mcp_stdio_transport.cpp new file mode 100644 index 0000000000000..d7b98c5391ec8 --- /dev/null +++ b/toolcall/mcp_stdio_transport.cpp @@ -0,0 +1,21 @@ + +#include "mcp_stdio_transport.h" + +#include + +toolcall::mcp_stdio_transport::mcp_stdio_transport(std::vector argv) + : argv_(std::move(argv)) +{ +} + +[[noreturn]] void toolcall::mcp_stdio_transport::start() { + throw std::logic_error(std::string("Function not implemented: ") + __func__); +} + +[[noreturn]] void 
toolcall::mcp_stdio_transport::stop() { + throw std::logic_error(std::string("Function not implemented: ") + __func__); +} + +[[noreturn]] bool toolcall::mcp_stdio_transport::send(const std::string & /*request_json*/) { + throw std::logic_error(std::string("Function not implemented: ") + __func__); +} diff --git a/toolcall/mcp_stdio_transport.h b/toolcall/mcp_stdio_transport.h new file mode 100644 index 0000000000000..98e6e9295e0ec --- /dev/null +++ b/toolcall/mcp_stdio_transport.h @@ -0,0 +1,21 @@ +#pragma once + +#include "mcp_transport.h" + +#include +#include + +namespace toolcall +{ + class mcp_stdio_transport : public mcp_transport { + public: + mcp_stdio_transport(std::vector argv); + + [[noreturn]] virtual void start() override; + [[noreturn]] virtual void stop() override; + [[noreturn]] virtual bool send(const std::string & request_json) override; + + private: + std::vector argv_; + }; +} diff --git a/toolcall/mcp_transport.h b/toolcall/mcp_transport.h new file mode 100644 index 0000000000000..f2a2fba24d880 --- /dev/null +++ b/toolcall/mcp_transport.h @@ -0,0 +1,98 @@ +#pragma once + +#include "mcp_messages.h" +#include +#include +#include + +namespace toolcall +{ + template + using callback = std::function; + + template + class mcp_message_observer { + public: + template + void subscribe(std::string key, callback callback) { + auto& map = + std::get>>( + subscribers_); + + map.insert({key, callback}); + } + + template + void subscribe(callback callback) { + auto& map = + std::get>>( + subscribers_); + + map.insert({T::Method, callback}); + } + + template + void unsubscribe(std::string key) { + auto& map = + std::get>>( + subscribers_); + + map.erase(key); + } + + void notify(const nlohmann::json & message) { + std::string key; + if (message.contains("id")) { + key = message["id"].dump(); + + } else if (message.contains("method")) { + key = message["method"].dump(); + + } else { + return; + } + std::apply([&key, &message, this](auto&... 
maps) { + (..., [&] { + auto it = maps.find(key); + if (it != maps.end()) { + using callback_type = decltype(it->second); + using T = typename std::decay::type; + + it->second(T::fromJson(message)); + maps.erase(it); + } + }()); + }, subscribers_); + } + + private: + std::tuple>...> subscribers_; + }; + + class mcp_transport : public mcp_message_observer { + public: + virtual ~mcp_transport() = default; + + template + bool send(const Req & message, callback on_response) { + if (message.id().has_value()) { + std::string id = message.id().value().dump(); + subscribe(id, on_response); + } + return send(message); + } + + template + bool send(const Req & message) { + nlohmann::json json = message.toJson(); + return send(json.dump(-1)); + } + + virtual void start() = 0; + virtual void stop() = 0; + virtual bool send(const std::string & request_json) = 0; + }; +} diff --git a/toolcall/params.cpp b/toolcall/params.cpp new file mode 100644 index 0000000000000..1ad64149aed68 --- /dev/null +++ b/toolcall/params.cpp @@ -0,0 +1,61 @@ + +#include "toolcall-params.h" +#include +#include + +using json = nlohmann::json; + +static bool starts_with(const std::string & str, const std::string & prefix) { + return str.size() >= prefix.size() + && str.compare(0, prefix.size(), prefix) == 0; +} + +toolcall::params::params(std::string tools, std::string choice) { + this->tools(tools); + this->choice(choice); +} + +void toolcall::params::tools(std::string tools) { + try { + if (! tools.empty()) { + if (starts_with(tools, "http")) { +#ifndef LLAMA_USE_CURL + throw std::invalid_argument( + "Model Context Protocol (MCP) only works when llama.cpp is compiled with libcurl"); +#endif + has_uri_ = true; + + } else { + json j = json::parse(tools); // Just for early validation + if (! 
j.is_array()) { + throw std::invalid_argument( + "tools must be a valid URL or a JSON array containing tool definitions"); + } + has_uri_ = false; + } + } + tools_ = std::move(tools); + + } catch (const json::exception & err) { + throw std::invalid_argument(err.what()); + } +} + +void toolcall::params::choice(std::string choice) { + try { + if (choice == "auto" || choice == "required" || choice == "none") { + tool_choice_ = std::move(choice); + + } else { + throw std::invalid_argument( + "tool choice must be set to \"auto\", \"required\", or \"none\""); + } + + } catch (const json::exception & err) { + throw std::invalid_argument(err.what()); + } +} + +toolcall::params::operator bool() const { + return ! tools_.empty(); +} diff --git a/toolcall/toolcall-client.h b/toolcall/toolcall-client.h new file mode 100644 index 0000000000000..d2259335851d9 --- /dev/null +++ b/toolcall/toolcall-client.h @@ -0,0 +1,116 @@ +#pragma once + +#include "toolcall-params.h" +#include +#include +#include +#include +#include +#include + +namespace toolcall +{ + struct result { + std::string type; + std::string data; + std::string mime_type; + std::optional uri; + bool error; + }; + + using result_set = std::vector; + + class client_impl; + class client { + public: + using ptr = std::shared_ptr; + + client(std::unique_ptr impl) : impl_(std::move(impl)) {} + + result_set call(const std::string & name, + const std::string & arguments, + const std::string & id = ""); + + std::string tool_list(); + bool tool_list_dirty() const; + + const std::string & tool_choice() const; + + void initialize(); + + private: + std::unique_ptr impl_; + }; + + std::shared_ptr create_client(const toolcall::params & params); + + class client_impl { + public: + client_impl(std::string tool_choice) + : tool_choice_(std::move(tool_choice)), tool_list_dirty_(true) {} + + virtual ~client_impl() = default; + + virtual std::string tool_list() = 0; + + virtual bool tool_list_dirty() const { + return tool_list_dirty_; + } 
+ + virtual result_set call(const std::string & name, + const std::string & arguments, + const std::string & id = "") = 0; + + const std::string & tool_choice() const { return tool_choice_; } + + virtual void initialize() {} + + protected: + std::string tool_choice_; + bool tool_list_dirty_; + }; + + class loopback_impl : public client_impl { + public: + loopback_impl(std::string tools, std::string tool_choice) + : client_impl(tool_choice), tools_(std::move(tools)) {} + + virtual std::string tool_list() override { + tool_list_dirty_ = false; + return tools_; + } + + virtual result_set call(const std::string & /* name */, + const std::string & /* arguments */, + const std::string & /* id = "" */) override { + return result_set { + {"text", "", "text/plain", std::nullopt, false} + }; + } + + private: + std::string tools_; + }; + + class mcp_transport; + class mcp_impl : public client_impl { + public: + mcp_impl(std::string server_uri, std::string tool_choice); + mcp_impl(std::vector argv, std::string tool_choice); + + virtual std::string tool_list() override; + + virtual result_set call(const std::string & name, + const std::string & arguments, + const std::string & id = "") override; + + virtual void initialize() override; + + private: + std::unique_ptr transport_; + std::string tools_; + std::mutex tools_mutex_; + std::condition_variable tools_populating_; + int next_id_; + }; +} diff --git a/toolcall/toolcall-params.h b/toolcall/toolcall-params.h new file mode 100644 index 0000000000000..302880230461f --- /dev/null +++ b/toolcall/toolcall-params.h @@ -0,0 +1,31 @@ +#pragma once + +#include + +namespace toolcall +{ + class params { + public: + params(std::string tools = "", std::string choice = "auto"); + + params(const params & other) = default; + params(params && other) noexcept = default; + params & operator=(const params & other) = default; + params & operator=(params && other) noexcept = default; + + operator bool() const; + + void tools(std::string tools); + 
const std::string & tools() const { return tools_; } + + void choice(std::string choice); + const std::string & choice() const { return tool_choice_; } + + bool has_uri() const { return has_uri_; } + + private: + std::string tools_; + std::string tool_choice_; + bool has_uri_; + }; +} diff --git a/tools/main/CMakeLists.txt b/tools/main/CMakeLists.txt index af3d9150f8640..6cc0b23b11029 100644 --- a/tools/main/CMakeLists.txt +++ b/tools/main/CMakeLists.txt @@ -2,4 +2,9 @@ set(TARGET llama-cli) add_executable(${TARGET} main.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) + +if (LLAMA_TOOLCALL) + target_link_libraries(${TARGET} PRIVATE toolcall) +endif() + target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/tools/main/main.cpp b/tools/main/main.cpp index 1bd2be2d94f51..473397fa4fddf 100644 --- a/tools/main/main.cpp +++ b/tools/main/main.cpp @@ -5,6 +5,7 @@ #include "sampling.h" #include "llama.h" #include "chat.h" +#include #include #include @@ -15,6 +16,10 @@ #include #include +#ifdef LLAMA_USE_TOOLCALL +# include "toolcall-client.h" +#endif + #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) #include #include @@ -83,6 +88,99 @@ static void sigint_handler(int signo) { } #endif +class chat_formatter { +public: + + struct result { + std::string formatted; + bool tool_was_called; + }; + + chat_formatter(common_params & params, + std::vector & chat_msgs, + struct common_chat_templates * chat_templates) + + : params_(params), chat_msgs_(chat_msgs), chat_templates_(chat_templates) {} + +#ifdef LLAMA_USE_TOOLCALL + chat_formatter(common_params & params, + std::vector & chat_msgs, + struct common_chat_templates * chat_templates, + const llama_vocab * vocab, + toolcall::client::ptr tc_client) + + : params_(params), chat_msgs_(chat_msgs), chat_templates_(chat_templates), + vocab_(vocab), tc_client_(tc_client), + chat_format_(COMMON_CHAT_FORMAT_CONTENT_ONLY), + formatted_() 
{} +#endif + + chat_formatter::result operator() (const std::string & role, const std::string & content) { + + common_chat_msg new_msg = common_chat_parse(content, chat_format_); + new_msg.role = role; + + common_chat_templates_inputs cinputs; + cinputs.use_jinja = params_.use_jinja; + cinputs.add_generation_prompt = (role == "user"); +#ifdef LLAMA_USE_TOOLCALL + if (tc_client_ != nullptr) { + cinputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tc_client_->tool_choice()); + cinputs.tools = common_chat_tools_parse_oaicompat(tc_client_->tool_list()); + } +#endif + cinputs.messages.assign(chat_msgs_.cbegin(), chat_msgs_.cend()); + cinputs.messages.push_back(new_msg); + chat_msgs_.push_back(new_msg); + + bool tool_was_called = false; + if (! new_msg.tool_calls.empty()) { // Call tool and re-prompt + nlohmann::json result_array = nlohmann::json::array(); + for (const auto & tc : new_msg.tool_calls) { + toolcall::result_set res = tc_client_->call(tc.name, tc.arguments, tc.id); + if (! 
res.empty()) { + for (const auto & r : res) { + result_array.push_back(r.data); + } + } + } + common_chat_msg toolcall_msg; + toolcall_msg.role = "tool"; + toolcall_msg.content = result_array.dump(-1); + + cinputs.add_generation_prompt = true; + cinputs.messages.push_back(toolcall_msg); + chat_msgs_.push_back(toolcall_msg); + + tool_was_called = true; + } + + common_chat_params cparams = common_chat_templates_apply(chat_templates_, cinputs); + std::string formatted = cparams.prompt.substr(formatted_.size(), cparams.prompt.size()); + formatted_ = cparams.prompt; + + LOG_DBG("formatted: '%s'\n", formatted.c_str()); + +#ifdef LLAMA_USE_TOOLCALL + chat_format_ = cparams.format; + common_chat_grammar_to_sampler(&cparams, vocab_, ¶ms_.sampling); +#endif + return chat_formatter::result{std::move(formatted), tool_was_called}; + } + +private: + common_params & params_; + std::vector & chat_msgs_; + struct common_chat_templates * chat_templates_; + +#ifdef LLAMA_USE_TOOLCALL + const llama_vocab * vocab_; + toolcall::client::ptr tc_client_; + common_chat_format chat_format_; + std::string formatted_; +#endif +}; + int main(int argc, char ** argv) { common_params params; g_params = ¶ms; @@ -94,6 +192,11 @@ int main(int argc, char ** argv) { auto & sparams = params.sampling; +#ifdef LLAMA_USE_TOOLCALL + // Ensure parameters are validated before the model loads + toolcall::params tc_params(params.toolcall.tools, params.toolcall.choice); +#endif + // save choice to use color for later // (note for later: this is a slightly awkward choice) console::init(params.simple_io, params.use_color); @@ -263,15 +366,16 @@ int main(int argc, char ** argv) { std::vector embd_inp; bool waiting_for_first_input = false; - auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) { - common_chat_msg new_msg; - new_msg.role = role; - new_msg.content = content; - auto formatted = common_chat_format_single(chat_templates.get(), chat_msgs, new_msg, 
role == "user", g_params->use_jinja); - chat_msgs.push_back(new_msg); - LOG_DBG("formatted: '%s'\n", formatted.c_str()); - return formatted; - }; + +#ifdef LLAMA_USE_TOOLCALL + auto tc_client = toolcall::create_client(tc_params); + if (tc_client) { + tc_client->initialize(); + } + chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get(), vocab, tc_client); +#else + chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get()); +#endif std::string prompt; { @@ -293,6 +397,12 @@ int main(int argc, char ** argv) { inputs.messages = chat_msgs; inputs.add_generation_prompt = !params.prompt.empty(); +#ifdef LLAMA_USE_TOOLCALL + if (tc_client != nullptr) { + inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tc_client->tool_choice()); + inputs.tools = common_chat_tools_parse_oaicompat(tc_client->tool_list()); + } +#endif prompt = common_chat_templates_apply(chat_templates.get(), inputs).prompt; } } else { @@ -811,10 +921,17 @@ int main(int argc, char ** argv) { } if (params.enable_chat_template) { - chat_add_and_format("assistant", assistant_ss.str()); + auto format_res = chat_add_and_format("assistant", assistant_ss.str()); + if (format_res.tool_was_called) { + auto format_res_tok = common_tokenize(ctx, format_res.formatted, false, true); + embd_inp.insert(embd_inp.end(), format_res_tok.begin(), format_res_tok.end()); + assistant_ss.str(""); + + } else { + is_interacting = true; + LOG("\n"); + } } - is_interacting = true; - LOG("\n"); } } @@ -894,7 +1011,7 @@ int main(int argc, char ** argv) { bool format_chat = params.conversation_mode && params.enable_chat_template; std::string user_inp = format_chat - ? chat_add_and_format("user", std::move(buffer)) + ? 
chat_add_and_format("user", std::move(buffer)).formatted : std::move(buffer); // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix) const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true);