From 51af21a683e475b27c064a40894516430db13ddb Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sat, 13 Sep 2025 12:12:34 -0300 Subject: [PATCH 01/17] Add Streamable HTTP Client and multiple refactoring and improvements --- .release-manifest.json | 20 +- Cargo.lock | 94 +- Cargo.toml | 17 +- README.md | 27 +- crates/rust-mcp-sdk/Cargo.toml | 41 +- crates/rust-mcp-sdk/README.md | 28 +- crates/rust-mcp-sdk/src/error.rs | 37 +- .../src/hyper_servers/app_state.rs | 11 +- .../src/hyper_servers/routes/hyper_utils.rs | 100 ++- .../src/hyper_servers/routes/sse_routes.rs | 22 +- .../routes/streamable_http_routes.rs | 7 +- .../rust-mcp-sdk/src/hyper_servers/server.rs | 13 +- .../src/hyper_servers/session_store.rs | 24 - crates/rust-mcp-sdk/src/id_generator.rs | 5 + .../src/id_generator/fast_id_generator.rs | 53 ++ .../src/id_generator/uuid_generator.rs | 18 + crates/rust-mcp-sdk/src/lib.rs | 9 +- .../src/mcp_handlers/mcp_server_handler.rs | 51 +- .../mcp_handlers/mcp_server_handler_core.rs | 17 +- .../src/mcp_runtimes/client_runtime.rs | 531 ++++++++--- .../client_runtime/mcp_client_runtime.rs | 24 +- .../client_runtime/mcp_client_runtime_core.rs | 33 +- .../src/mcp_runtimes/server_runtime.rs | 283 ++++-- .../server_runtime/mcp_server_runtime.rs | 21 +- .../server_runtime/mcp_server_runtime_core.rs | 15 +- crates/rust-mcp-sdk/src/mcp_traits.rs | 2 + .../src/mcp_traits/id_generator.rs | 12 + .../rust-mcp-sdk/src/mcp_traits/mcp_client.rs | 49 +- .../src/mcp_traits/mcp_handler.rs | 11 +- .../rust-mcp-sdk/src/mcp_traits/mcp_server.rs | 7 +- crates/rust-mcp-sdk/src/utils.rs | 43 +- crates/rust-mcp-sdk/tests/check_imports.rs | 5 +- crates/rust-mcp-sdk/tests/common/common.rs | 57 +- .../rust-mcp-sdk/tests/common/mock_server.rs | 528 +++++++++++ .../rust-mcp-sdk/tests/common/test_client.rs | 163 ++++ .../rust-mcp-sdk/tests/common/test_server.rs | 31 +- .../tests/test_protocol_compatibility.rs | 2 +- .../tests/test_streamable_http_client.rs | 823 ++++++++++++++++++ ...http.rs => test_streamable_http_server.rs} | 7 +- crates/rust-mcp-transport/Cargo.toml | 4 +- crates/rust-mcp-transport/README.md | 4 +- crates/rust-mcp-transport/src/client_sse.rs | 101 ++- .../src/client_streamable_http.rs | 515 +++++++++++ crates/rust-mcp-transport/src/constants.rs | 3 + crates/rust-mcp-transport/src/error.rs | 71 +- crates/rust-mcp-transport/src/lib.rs | 17 +- crates/rust-mcp-transport/src/mcp_stream.rs | 37 + .../src/message_dispatcher.rs | 82 +- crates/rust-mcp-transport/src/sse.rs | 4 +- crates/rust-mcp-transport/src/stdio.rs | 74 +- crates/rust-mcp-transport/src/transport.rs | 35 +- crates/rust-mcp-transport/src/utils.rs | 28 +- .../src/utils/http_utils.rs | 125 ++- .../src/utils/sse_parser.rs | 320 +++++++ .../src/utils/streamable_http_stream.rs | 374 ++++++++ .../rust-mcp-transport/tests/check_imports.rs | 5 +- development.md | 6 +- doc/getting-started-mcp-server.md | 4 +- .../.gitignore | 0 .../Cargo.toml | 5 +- .../README.md | 8 +- .../src/handler.rs | 8 +- .../src/main.rs | 0 .../src/tools.rs | 0 .../Cargo.toml | 7 +- .../README.md | 8 +- .../src/handler.rs | 5 +- .../src/main.rs | 5 +- .../src/tools.rs | 0 .../.gitignore | 0 .../Cargo.toml | 5 +- .../README.md | 4 +- .../src/handler.rs | 8 +- .../src/main.rs | 0 .../src/tools.rs | 0 .../Cargo.toml | 1 + .../README.md | 2 +- .../src/handler.rs | 11 +- .../Cargo.toml | 5 +- .../README.md | 2 +- .../src/handler.rs | 0 .../src/inquiry_utils.rs | 0 .../src/main.rs | 1 + examples/simple-mcp-client-sse/Cargo.toml | 2 + examples/simple-mcp-client-sse/src/main.rs | 13 +- .../Cargo.toml | 5 +- .../README.md | 2 +- .../src/handler.rs | 0 .../src/inquiry_utils.rs | 0 .../src/main.rs | 0 .../Cargo.toml | 5 +- .../README.md | 2 +- .../src/handler.rs | 0 .../src/inquiry_utils.rs | 0 .../src/main.rs | 0 .../Cargo.toml | 29 + .../README.md | 40 + .../src/handler.rs | 72 ++ .../src/inquiry_utils.rs | 222 +++++ .../src/main.rs | 95 ++ .../Cargo.toml | 29 + .../README.md | 40 + .../src/handler.rs | 10 + .../src/inquiry_utils.rs | 222 +++++ .../src/main.rs | 99 +++ 105 files changed, 5330 insertions(+), 692 deletions(-) create mode 100644 crates/rust-mcp-sdk/src/id_generator.rs create mode 100644 crates/rust-mcp-sdk/src/id_generator/fast_id_generator.rs create mode 100644 crates/rust-mcp-sdk/src/id_generator/uuid_generator.rs create mode 100644 crates/rust-mcp-sdk/src/mcp_traits/id_generator.rs create mode 100644 crates/rust-mcp-sdk/tests/common/mock_server.rs create mode 100644 crates/rust-mcp-sdk/tests/common/test_client.rs create mode 100644 crates/rust-mcp-sdk/tests/test_streamable_http_client.rs rename crates/rust-mcp-sdk/tests/{test_streamable_http.rs => test_streamable_http_server.rs} (99%) create mode 100644 crates/rust-mcp-transport/src/client_streamable_http.rs create mode 100644 crates/rust-mcp-transport/src/constants.rs create mode 100644 crates/rust-mcp-transport/src/utils/sse_parser.rs create mode 100644 crates/rust-mcp-transport/src/utils/streamable_http_stream.rs rename examples/{hello-world-mcp-server-core => hello-world-mcp-server-stdio-core}/.gitignore (100%) rename examples/{hello-world-mcp-server-core => hello-world-mcp-server-stdio-core}/Cargo.toml (83%) rename examples/{hello-world-mcp-server-core => hello-world-mcp-server-stdio-core}/README.md (81%) rename examples/{hello-world-mcp-server-core => hello-world-mcp-server-stdio-core}/src/handler.rs (97%) rename examples/{hello-world-mcp-server-core => hello-world-mcp-server-stdio-core}/src/main.rs (100%) rename examples/{hello-world-mcp-server-core => hello-world-mcp-server-stdio-core}/src/tools.rs (100%) rename examples/{hello-world-mcp-server => hello-world-mcp-server-stdio}/Cargo.toml (85%) rename examples/{hello-world-mcp-server => hello-world-mcp-server-stdio}/README.md (84%) rename examples/{hello-world-mcp-server => hello-world-mcp-server-stdio}/src/handler.rs (94%) rename examples/{hello-world-mcp-server => hello-world-mcp-server-stdio}/src/main.rs (92%) rename examples/{hello-world-mcp-server => hello-world-mcp-server-stdio}/src/tools.rs (100%) rename examples/{hello-world-server-core-streamable-http => hello-world-server-streamable-http-core}/.gitignore (100%) rename examples/{hello-world-server-core-streamable-http => hello-world-server-streamable-http-core}/Cargo.toml (84%) rename examples/{hello-world-server-core-streamable-http => hello-world-server-streamable-http-core}/README.md (95%) rename examples/{hello-world-server-core-streamable-http => hello-world-server-streamable-http-core}/src/handler.rs (97%) rename examples/{hello-world-server-core-streamable-http => hello-world-server-streamable-http-core}/src/main.rs (100%) rename examples/{hello-world-server-core-streamable-http => hello-world-server-streamable-http-core}/src/tools.rs (100%) rename examples/{simple-mcp-client-core-sse => simple-mcp-client-sse-core}/Cargo.toml (88%) rename examples/{simple-mcp-client-core-sse => simple-mcp-client-sse-core}/README.md (97%) rename examples/{simple-mcp-client-core-sse => simple-mcp-client-sse-core}/src/handler.rs (100%) rename examples/{simple-mcp-client-core-sse => simple-mcp-client-sse-core}/src/inquiry_utils.rs (100%) rename examples/{simple-mcp-client-core-sse => simple-mcp-client-sse-core}/src/main.rs (99%) rename examples/{simple-mcp-client-core => simple-mcp-client-stdio-core}/Cargo.toml (86%) rename examples/{simple-mcp-client-core => simple-mcp-client-stdio-core}/README.md (97%) rename examples/{simple-mcp-client-core => simple-mcp-client-stdio-core}/src/handler.rs (100%) rename examples/{simple-mcp-client-core => simple-mcp-client-stdio-core}/src/inquiry_utils.rs (100%) rename examples/{simple-mcp-client-core => simple-mcp-client-stdio-core}/src/main.rs (100%) rename examples/{simple-mcp-client => simple-mcp-client-stdio}/Cargo.toml (87%) rename examples/{simple-mcp-client => simple-mcp-client-stdio}/README.md (97%) rename examples/{simple-mcp-client => simple-mcp-client-stdio}/src/handler.rs (100%) rename examples/{simple-mcp-client => simple-mcp-client-stdio}/src/inquiry_utils.rs (100%) rename examples/{simple-mcp-client => simple-mcp-client-stdio}/src/main.rs (100%) create mode 100644 examples/simple-mcp-client-streamable-http-core/Cargo.toml create mode 100644 examples/simple-mcp-client-streamable-http-core/README.md create mode 100644 examples/simple-mcp-client-streamable-http-core/src/handler.rs create mode 100644 examples/simple-mcp-client-streamable-http-core/src/inquiry_utils.rs create mode 100644 examples/simple-mcp-client-streamable-http-core/src/main.rs create mode 100644 examples/simple-mcp-client-streamable-http/Cargo.toml create mode 100644 examples/simple-mcp-client-streamable-http/README.md create mode 100644 examples/simple-mcp-client-streamable-http/src/handler.rs create mode 100644 examples/simple-mcp-client-streamable-http/src/inquiry_utils.rs create mode 100644 examples/simple-mcp-client-streamable-http/src/main.rs diff --git a/.release-manifest.json b/.release-manifest.json index 97a0f63..a645da6 100644 --- a/.release-manifest.json +++ b/.release-manifest.json @@ -1,13 +1,15 @@ { "crates/rust-mcp-sdk": "0.6.3", "crates/rust-mcp-macros": "0.5.1", - "crates/rust-mcp-transport": "0.5.1", - "examples/hello-world-mcp-server": "0.1.31", - "examples/hello-world-mcp-server-core": "0.1.22", - "examples/simple-mcp-client": "0.1.31", - "examples/simple-mcp-client-core": "0.1.31", - "examples/hello-world-server-core-streamable-http": "0.1.22", - "examples/hello-world-server-streamable-http": "0.1.31", - "examples/simple-mcp-client-core-sse": "0.1.22", - "examples/simple-mcp-client-sse": "0.1.22" + "crates/rust-mcp-transport": "0.5.0", + "examples/hello-world-mcp-server-stdio": "0.1.28", + "examples/hello-world-mcp-server-stdio-core": "0.1.19", + "examples/simple-mcp-client-stdio": "0.1.28", + "examples/simple-mcp-client-stdio-core": "0.1.28", + "examples/hello-world-server-streamable-http-core": "0.1.19", + "examples/hello-world-server-streamable-http": "0.1.28", + "examples/simple-mcp-client-sse-core": "0.1.19", + "examples/simple-mcp-client-sse": "0.1.19", + "examples/simple-mcp-client-streamable-http": "0.1.0", + "examples/simple-mcp-client-streamable-http-core": "0.1.0" } diff --git a/Cargo.lock b/Cargo.lock index c10e354..c3c4462 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -257,10 +257,11 @@ checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" [[package]] name = "cc" -version = "1.2.34" +version = "1.2.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42bc4aea80032b7bf409b0bc7ccad88853858911b7713a8062fdc0623867bedc" +checksum = "590f9024a68a8c40351881787f1934dc11afd69090f5edb6831464694d836ea3" dependencies = [ + "find-msvc-tools", "jobserver", "libc", "shlex", @@ -381,9 +382,9 @@ checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" [[package]] name = "deranged" -version = "0.4.0" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c9e6a11ca8224451684bc0d7d5a7adbf8f2fd6887261a1cfc3c0432f9d4068e" +checksum = "d630bccd429a5bb5a64b5e94f693bfc48c9f8566418fda4c494cc94f911f87cc" dependencies = [ "powerfmt", ] @@ -451,6 +452,12 @@ dependencies = [ "instant", ] +[[package]] +name = "find-msvc-tools" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e178e4fba8a2726903f6ba98a6d221e76f9c12c650d5dc0e6afdc50677b49650" + [[package]] name = "fnv" version = "1.0.7" @@ -687,8 +694,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" [[package]] -name = "hello-world-mcp-server" -version = "0.1.31" +name = "hello-world-mcp-server-stdio" +version = "0.1.28" dependencies = [ "async-trait", "futures", @@ -701,8 +708,8 @@ dependencies = [ ] [[package]] -name = "hello-world-mcp-server-core" -version = "0.1.22" +name = "hello-world-mcp-server-stdio-core" +version = "0.1.19" dependencies = [ "async-trait", "futures", @@ -713,8 +720,8 @@ dependencies = [ ] [[package]] -name = "hello-world-server-core-streamable-http" -version = "0.1.22" +name = "hello-world-server-streamable-http" +version = "0.1.31" dependencies = [ "async-trait", "futures", @@ -727,8 +734,8 @@ dependencies = [ ] [[package]] -name = "hello-world-server-streamable-http" -version = "0.1.31" +name = "hello-world-server-streamable-http-core" +version = "0.1.19" dependencies = [ "async-trait", "futures", @@ -1684,6 +1691,7 @@ dependencies = [ "async-trait", "axum", "axum-server", + "base64 0.22.1", "futures", "hyper 1.7.0", "reqwest", @@ -1698,6 +1706,7 @@ dependencies = [ "tracing", "tracing-subscriber", "uuid", + "wiremock", ] [[package]] @@ -1903,8 +1912,8 @@ dependencies = [ ] [[package]] -name = "simple-mcp-client" -version = "0.1.31" +name = "simple-mcp-client-sse" +version = "0.1.22" dependencies = [ "async-trait", "colored", @@ -1914,11 +1923,13 @@ dependencies = [ "serde_json", "thiserror 2.0.16", "tokio", + "tracing", + "tracing-subscriber", ] [[package]] -name = "simple-mcp-client-core" -version = "0.1.31" +name = "simple-mcp-client-sse-core" +version = "0.1.19" dependencies = [ "async-trait", "colored", @@ -1928,11 +1939,41 @@ dependencies = [ "serde_json", "thiserror 2.0.16", "tokio", + "tracing", + "tracing-subscriber", ] [[package]] -name = "simple-mcp-client-core-sse" -version = "0.1.22" +name = "simple-mcp-client-stdio" +version = "0.1.28" +dependencies = [ + "async-trait", + "colored", + "futures", + "rust-mcp-sdk", + "serde", + "serde_json", + "thiserror 2.0.16", + "tokio", +] + +[[package]] +name = "simple-mcp-client-stdio-core" +version = "0.1.28" +dependencies = [ + "async-trait", + "colored", + "futures", + "rust-mcp-sdk", + "serde", + "serde_json", + "thiserror 2.0.16", + "tokio", +] + +[[package]] +name = "simple-mcp-client-streamable-http" +version = "0.1.0" dependencies = [ "async-trait", "colored", @@ -1947,8 +1988,8 @@ dependencies = [ ] [[package]] -name = "simple-mcp-client-sse" -version = "0.1.22" +name = "simple-mcp-client-streamable-http-core" +version = "0.1.0" dependencies = [ "async-trait", "colored", @@ -2088,12 +2129,11 @@ dependencies = [ [[package]] name = "time" -version = "0.3.41" +version = "0.3.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40" +checksum = "8ca967379f9d8eb8058d86ed467d81d03e81acd45757e4ca341c24affbe8e8e3" dependencies = [ "deranged", - "itoa", "num-conv", "powerfmt", "serde", @@ -2103,15 +2143,15 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c" +checksum = "a9108bb380861b07264b950ded55a44a14a4adc68b9f5efd85aafc3aa4d40a68" [[package]] name = "time-macros" -version = "0.2.22" +version = "0.2.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3526739392ec93fd8b359c8e98514cb3e8e021beb4e5f597b00a0221f8ed8a49" +checksum = "7182799245a7264ce590b349d90338f1c1affad93d2639aed5f8f69c090b334c" dependencies = [ "num-conv", "time-core", diff --git a/Cargo.toml b/Cargo.toml index b4f7cca..711204d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,14 +4,17 @@ members = [ "crates/rust-mcp-macros", "crates/rust-mcp-sdk", "crates/rust-mcp-transport", - "examples/simple-mcp-client", - "examples/simple-mcp-client-core", - "examples/hello-world-mcp-server", - "examples/hello-world-mcp-server-core", + "examples/simple-mcp-client-stdio", + "examples/simple-mcp-client-stdio-core", + "examples/hello-world-mcp-server-stdio", + "examples/hello-world-mcp-server-stdio-core", "examples/hello-world-server-streamable-http", - "examples/hello-world-server-core-streamable-http", + "examples/hello-world-server-streamable-http-core", "examples/simple-mcp-client-sse", - "examples/simple-mcp-client-core-sse", + "examples/simple-mcp-client-sse-core", + "examples/simple-mcp-client-streamable-http", + "examples/simple-mcp-client-streamable-http-core", + ] [workspace.dependencies] @@ -39,7 +42,7 @@ tracing-subscriber = { version = "0.3", features = [ "std", "fmt", ] } - +base64 = "0.22" axum = "0.8" rustls = "0.23" tokio-rustls = "0.26" diff --git a/README.md b/README.md index 1581d1d..b1af670 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ [build status ](https://github.com/rust-mcp-stack/rust-mcp-sdk/actions/workflows/ci.yml) [Hello World MCP Server -](examples/hello-world-mcp-server) +](examples/hello-world-mcp-server-stdio) A high-performance, asynchronous toolkit for building MCP servers and clients. Focus on your app's logic while **rust-mcp-sdk** takes care of the rest! @@ -42,6 +42,17 @@ This project supports following transports: - ⬜ Resumability - ⬜ Authentication / Oauth + + +**MCP Streamable HTTP Support** +- [x] Streamable HTTP Support for MCP Servers +- [x] DNS Rebinding Protection +- [x] Batch Messages +- [x] Streaming & non-streaming JSON response +- [ ] Streamable HTTP Support for MCP Clients +- [ ] Resumability +- [ ] Authentication / Oauth + **⚠️** Project is currently under development and should be used at your own risk. ## Table of Contents @@ -110,7 +121,7 @@ async fn main() -> SdkResult<()> { } ``` -See hello-world-mcp-server example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : +See hello-world-mcp-server-stdio example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : ![mcp-server in rust](assets/examples/hello-world-mcp-server.gif) @@ -180,7 +191,7 @@ pub struct MyServerHandler; #[async_trait] impl ServerHandler for MyServerHandler { // Handle ListToolsRequest, return list of available tools as ListToolsResult - async fn handle_list_tools_request(&self, request: ListToolsRequest, runtime: &dyn McpServer) -> Result { + async fn handle_list_tools_request(&self, request: ListToolsRequest, runtime: Arc) -> Result { Ok(ListToolsResult { tools: vec![SayHelloTool::tool()], @@ -191,7 +202,7 @@ impl ServerHandler for MyServerHandler { } /// Handles requests to call a specific tool. - async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: &dyn McpServer, ) -> Result { + async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: Arc, ) -> Result { if request.tool_name() == SayHelloTool::tool_name() { Ok( CallToolResult::text_content( vec![TextContent::from("Hello World!".to_string())] )) @@ -205,7 +216,7 @@ impl ServerHandler for MyServerHandler { --- -👉 For a more detailed example of a [Hello World MCP](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server) Server that supports multiple tools and provides more type-safe handling of `CallToolRequest`, check out: **[examples/hello-world-mcp-server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server)** +👉 For a more detailed example of a [Hello World MCP](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio) Server that supports multiple tools and provides more type-safe handling of `CallToolRequest`, check out: **[examples/hello-world-mcp-server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server)** See hello-world-server-streamable-http example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : @@ -477,10 +488,10 @@ Learn when to use the `mcp_*_handler` traits versus the lower-level `mcp_*_hand [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) provides two type of handler traits that you can chose from: - **ServerHandler**: This is the recommended trait for your MCP project, offering a default implementation for all types of MCP messages. It includes predefined implementations within the trait, such as handling initialization or responding to ping requests, so you only need to override and customize the handler functions relevant to your specific needs. - Refer to [examples/hello-world-mcp-server/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server/src/handler.rs) for an example. + Refer to [examples/hello-world-mcp-server-stdio/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio/src/handler.rs) for an example. - **ServerHandlerCore**: If you need more control over MCP messages, consider using `ServerHandlerCore`. It offers three primary methods to manage the three MCP message types: `request`, `notification`, and `error`. While still providing type-safe objects in these methods, it allows you to determine how to handle each message based on its type and parameters. - Refer to [examples/hello-world-mcp-server-core/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-core/src/handler.rs) for an example. + Refer to [examples/hello-world-mcp-server-stdio-core/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio-core/src/handler.rs) for an example. --- @@ -509,7 +520,7 @@ Both functions create an MCP client instance. -Check out the corresponding examples at: [examples/simple-mcp-client](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client) and [examples/simple-mcp-client-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-core). +Check out the corresponding examples at: [examples/simple-mcp-client-stdio](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio) and [examples/simple-mcp-client-stdio-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio-core). ## Projects using Rust MCP SDK diff --git a/crates/rust-mcp-sdk/Cargo.toml b/crates/rust-mcp-sdk/Cargo.toml index 48ea665..3fd9ec2 100644 --- a/crates/rust-mcp-sdk/Cargo.toml +++ b/crates/rust-mcp-sdk/Cargo.toml @@ -24,15 +24,17 @@ futures = { workspace = true } thiserror = { workspace = true } axum = { workspace = true, optional = true } -uuid = { workspace = true, features = ["v4"], optional = true } +uuid = { workspace = true, features = ["v4"] } tokio-stream = { workspace = true, optional = true } axum-server = { version = "0.7", features = [], optional = true } tracing.workspace = true +base64.workspace = true # rustls = { workspace = true, optional = true } hyper = { version = "1.6.0", optional = true } [dev-dependencies] +wiremock = "0.5" reqwest = { workspace = true, default-features = false, features = [ "stream", "rustls-tls", @@ -51,47 +53,54 @@ default = [ "client", "server", "macros", + "stdio", + "sse", + "streamable-http", "hyper-server", "ssl", "2025_06_18", ] # All features enabled by default -server = ["rust-mcp-transport/stdio"] # Server feature -client = ["rust-mcp-transport/stdio", "rust-mcp-transport/sse"] # Client feature -hyper-server = [ - "axum", - "axum-server", - "hyper", - "server", - "uuid", - "tokio-stream", - "rust-mcp-transport/sse", -] + +sse = ["rust-mcp-transport/sse"] +streamable-http = ["rust-mcp-transport/streamable-http"] +stdio = ["rust-mcp-transport/stdio"] + +server = [] # Server feature +client = [] # Client feature +hyper-server = ["axum", "axum-server", "hyper", "server", "tokio-stream"] ssl = ["axum-server/tls-rustls"] macros = ["rust-mcp-macros/sdk"] -# enables mcp protocol version 2025_06_18 -2025_06_18 = [ +# enables mcp protocol version 2025-06-18 +2025-06-18 = [ "rust-mcp-schema/2025_06_18", "rust-mcp-macros/2025_06_18", "rust-mcp-transport/2025_06_18", "rust-mcp-schema/schema_utils", ] +# Alias: allow users to use underscores instead of hyphens +2025_06_18 = ["2025-06-18"] # enables mcp protocol version 2025_03_26 -2025_03_26 = [ +2025-03-26 = [ "rust-mcp-schema/2025_03_26", "rust-mcp-macros/2025_03_26", "rust-mcp-transport/2025_03_26", "rust-mcp-schema/schema_utils", ] +# Alias: allow users to use underscores instead of hyphens +2025_03_26 = ["2025-03-26"] + # enables mcp protocol version 2024_11_05 -2024_11_05 = [ +2024-11-05 = [ "rust-mcp-schema/2024_11_05", "rust-mcp-macros/2024_11_05", "rust-mcp-transport/2024_11_05", "rust-mcp-schema/schema_utils", ] +# Alias: allow users to use underscores instead of hyphens +2024_11_05 = ["2024-11-05"] [lints] workspace = true diff --git a/crates/rust-mcp-sdk/README.md b/crates/rust-mcp-sdk/README.md index 1581d1d..9df027d 100644 --- a/crates/rust-mcp-sdk/README.md +++ b/crates/rust-mcp-sdk/README.md @@ -9,7 +9,7 @@ [build status ](https://github.com/rust-mcp-stack/rust-mcp-sdk/actions/workflows/ci.yml) [Hello World MCP Server -](examples/hello-world-mcp-server) +](examples/hello-world-mcp-server-stdio) A high-performance, asynchronous toolkit for building MCP servers and clients. Focus on your app's logic while **rust-mcp-sdk** takes care of the rest! @@ -32,7 +32,6 @@ This project supports following transports: 🚀 The **rust-mcp-sdk** includes a lightweight [Axum](https://github.com/tokio-rs/axum) based server that handles all core functionality seamlessly. Switching between `stdio` and `Streamable HTTP` is straightforward, requiring minimal code changes. The server is designed to efficiently handle multiple concurrent client connections and offers built-in support for SSL. - **MCP Streamable HTTP Support** - ✅ Streamable HTTP Support for MCP Servers - ✅ DNS Rebinding Protection @@ -42,6 +41,17 @@ This project supports following transports: - ⬜ Resumability - ⬜ Authentication / Oauth + + +**MCP Streamable HTTP Support** +- [x] Streamable HTTP Support for MCP Servers +- [x] DNS Rebinding Protection +- [x] Batch Messages +- [x] Streaming & non-streaming JSON response +- [ ] Streamable HTTP Support for MCP Clients +- [ ] Resumability +- [ ] Authentication / Oauth + **⚠️** Project is currently under development and should be used at your own risk. ## Table of Contents @@ -110,7 +120,7 @@ async fn main() -> SdkResult<()> { } ``` -See hello-world-mcp-server example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : +See hello-world-mcp-server-stdio example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : ![mcp-server in rust](assets/examples/hello-world-mcp-server.gif) @@ -180,7 +190,7 @@ pub struct MyServerHandler; #[async_trait] impl ServerHandler for MyServerHandler { // Handle ListToolsRequest, return list of available tools as ListToolsResult - async fn handle_list_tools_request(&self, request: ListToolsRequest, runtime: &dyn McpServer) -> Result { + async fn handle_list_tools_request(&self, request: ListToolsRequest, runtime: Arc) -> Result { Ok(ListToolsResult { tools: vec![SayHelloTool::tool()], @@ -191,7 +201,7 @@ impl ServerHandler for MyServerHandler { } /// Handles requests to call a specific tool. - async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: &dyn McpServer, ) -> Result { + async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: Arc ) -> Result { if request.tool_name() == SayHelloTool::tool_name() { Ok( CallToolResult::text_content( vec![TextContent::from("Hello World!".to_string())] )) @@ -205,7 +215,7 @@ impl ServerHandler for MyServerHandler { --- -👉 For a more detailed example of a [Hello World MCP](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server) Server that supports multiple tools and provides more type-safe handling of `CallToolRequest`, check out: **[examples/hello-world-mcp-server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server)** +👉 For a more detailed example of a [Hello World MCP](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio) Server that supports multiple tools and provides more type-safe handling of `CallToolRequest`, check out: **[examples/hello-world-mcp-server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server)** See hello-world-server-streamable-http example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : @@ -477,10 +487,10 @@ Learn when to use the `mcp_*_handler` traits versus the lower-level `mcp_*_hand [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) provides two type of handler traits that you can chose from: - **ServerHandler**: This is the recommended trait for your MCP project, offering a default implementation for all types of MCP messages. It includes predefined implementations within the trait, such as handling initialization or responding to ping requests, so you only need to override and customize the handler functions relevant to your specific needs. - Refer to [examples/hello-world-mcp-server/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server/src/handler.rs) for an example. + Refer to [examples/hello-world-mcp-server-stdio/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio/src/handler.rs) for an example. - **ServerHandlerCore**: If you need more control over MCP messages, consider using `ServerHandlerCore`. It offers three primary methods to manage the three MCP message types: `request`, `notification`, and `error`. While still providing type-safe objects in these methods, it allows you to determine how to handle each message based on its type and parameters. - Refer to [examples/hello-world-mcp-server-core/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-core/src/handler.rs) for an example. + Refer to [examples/hello-world-mcp-server-stdio-core/src/handler.rs](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio-core/src/handler.rs) for an example. --- @@ -509,7 +519,7 @@ Both functions create an MCP client instance. -Check out the corresponding examples at: [examples/simple-mcp-client](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client) and [examples/simple-mcp-client-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-core). +Check out the corresponding examples at: [examples/simple-mcp-client-stdio](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio) and [examples/simple-mcp-client-stdio-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio-core). ## Projects using Rust MCP SDK diff --git a/crates/rust-mcp-sdk/src/error.rs b/crates/rust-mcp-sdk/src/error.rs index 3de8d98..3879526 100644 --- a/crates/rust-mcp-sdk/src/error.rs +++ b/crates/rust-mcp-sdk/src/error.rs @@ -11,25 +11,36 @@ pub type SdkResult = core::result::Result; #[derive(Debug, Error)] pub enum McpSdkError { + #[error("Transport error: {0}")] + Transport(#[from] TransportError), + + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + #[error("{0}")] RpcError(#[from] RpcError), + #[error("{0}")] - IoError(#[from] std::io::Error), - #[error("{0}")] - TransportError(#[from] TransportError), - #[error("{0}")] - JoinError(#[from] JoinError), - #[error("{0}")] - AnyError(Box<(dyn std::error::Error + Send + Sync)>), - #[error("{0}")] - SdkError(#[from] crate::schema::schema_utils::SdkError), + Join(#[from] JoinError), + #[cfg(feature = "hyper-server")] #[error("{0}")] - TransportServerError(#[from] TransportServerError), - #[error("Incompatible mcp protocol version: requested:{0} current:{1}")] - IncompatibleProtocolVersion(String, String), + HyperServer(#[from] TransportServerError), + #[error("{0}")] - ParseProtocolVersionError(#[from] ParseProtocolVersionError), + SdkError(#[from] crate::schema::schema_utils::SdkError), + + #[error("Protocol error: {kind}")] + Protocol { kind: ProtocolErrorKind }, +} + +// Sub-enum for protocol-related errors +#[derive(Debug, Error)] +pub enum ProtocolErrorKind { + #[error("Incompatible protocol version: requested {requested}, current {current}")] + IncompatibleVersion { requested: String, current: String }, + #[error("Failed to parse protocol version: {0}")] + ParseError(#[from] ParseProtocolVersionError), } impl McpSdkError { diff --git a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs b/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs index 0c1dcf3..ff6d5b2 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs @@ -1,11 +1,9 @@ use std::{sync::Arc, time::Duration}; -use crate::schema::InitializeResult; -use rust_mcp_transport::TransportOptions; - +use super::session_store::SessionStore; use crate::mcp_traits::mcp_handler::McpServerHandler; - -use super::{session_store::SessionStore, IdGenerator}; +use crate::{id_generator::FastIdGenerator, mcp_traits::IdGenerator, schema::InitializeResult}; +use rust_mcp_transport::{SessionId, TransportOptions}; /// Application state struct for the Hyper server /// @@ -14,7 +12,8 @@ use super::{session_store::SessionStore, IdGenerator}; #[derive(Clone)] pub struct AppState { pub session_store: Arc, - pub id_generator: Arc, + pub id_generator: Arc>, + pub stream_id_gen: Arc, pub server_details: Arc, pub handler: Arc, pub ping_interval: Duration, diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs index 0a77913..da69c67 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs @@ -6,7 +6,7 @@ use crate::{ }, mcp_runtimes::server_runtime::DEFAULT_STREAM_ID, mcp_server::{server_runtime, ServerRuntime}, - mcp_traits::mcp_handler::McpServerHandler, + mcp_traits::{mcp_handler::McpServerHandler, IdGenerator}, utils::validate_mcp_protocol_version, }; @@ -22,13 +22,12 @@ use axum::{ }; use futures::stream; use hyper::{header, HeaderMap, StatusCode}; -use rust_mcp_transport::{SessionId, SseTransport}; +use rust_mcp_transport::{ + SessionId, SseTransport, StreamId, MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, +}; use std::{sync::Arc, time::Duration}; use tokio::io::{duplex, AsyncBufReadExt, BufReader}; -pub const MCP_SESSION_ID_HEADER: &str = "Mcp-Session-Id"; -pub const MCP_PROTOCOL_VERSION_HEADER: &str = "Mcp-Protocol-Version"; - const DUPLEX_BUFFER_SIZE: usize = 8192; async fn create_sse_stream( @@ -41,11 +40,11 @@ async fn create_sse_stream( let payload_string = payload.map(|p| p.to_string()); // TODO: this logic should be moved out after refactoing the mcp_stream.rs - let result = payload_string + let payload_contains_request = payload_string .as_ref() .map(|json_str| contains_request(json_str)) .unwrap_or(Ok(false)); - let Ok(payload_contains_request) = result else { + let Ok(payload_contains_request) = payload_contains_request else { return Ok((StatusCode::BAD_REQUEST, Json(SdkError::parse_error())).into_response()); }; @@ -54,18 +53,20 @@ async fn create_sse_stream( // writable stream to deliver message to the client let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); - let transport = SseTransport::::new( - read_rx, - write_tx, - read_tx, - Arc::clone(&state.transport_options), - ) - .map_err(|err| TransportServerError::TransportError(err.to_string()))?; + let transport = Arc::new( + SseTransport::::new( + read_rx, + write_tx, + read_tx, + Arc::clone(&state.transport_options), + ) + .map_err(|err| TransportServerError::TransportError(err.to_string()))?, + ); - let stream_id = if standalone { + let stream_id: StreamId = if standalone { DEFAULT_STREAM_ID.to_string() } else { - state.id_generator.generate() + state.stream_id_gen.generate() }; let ping_interval = state.ping_interval; let runtime_clone = Arc::clone(&runtime); @@ -85,6 +86,7 @@ async fn create_sse_stream( // Construct SSE stream let reader = BufReader::new(write_rx); + // outgoing messages from server to the client let message_stream = stream::unfold(reader, |mut reader| async move { let mut line = String::new(); @@ -117,12 +119,12 @@ async fn create_sse_stream( // TODO: this function will be removed after refactoring the readable stream of the transports // so we would deserialize the string syncronousely and have more control over the flow -// this function could potentially add a 20-250 ns overhead which could be avoided +// this function may incur a slight runtime cost which could be avoided after refactoring fn contains_request(json_str: &str) -> Result { let value: serde_json::Value = serde_json::from_str(json_str)?; match value { serde_json::Value::Object(obj) => Ok(obj.contains_key("id") && obj.contains_key("method")), - serde_json::Value::Array(arr) => Ok(arr.iter().all(|item| { + serde_json::Value::Array(arr) => Ok(arr.iter().any(|item| { item.as_object() .map(|obj| obj.contains_key("id") && obj.contains_key("method")) .unwrap_or(false) @@ -131,6 +133,19 @@ fn contains_request(json_str: &str) -> Result { } } +fn is_result(json_str: &str) -> Result { + let value: serde_json::Value = serde_json::from_str(json_str)?; + match value { + serde_json::Value::Object(obj) => Ok(obj.contains_key("result")), + serde_json::Value::Array(arr) => Ok(arr.iter().all(|item| { + item.as_object() + .map(|obj| obj.contains_key("result")) + .unwrap_or(false) + })), + _ => Ok(false), + } +} + pub async fn create_standalone_stream( session_id: SessionId, state: Arc, @@ -166,11 +181,11 @@ pub async fn start_new_session( let h: Arc = state.handler.clone(); // create a new server instance with unique session_id and - let runtime: Arc = Arc::new(server_runtime::create_server_instance( + let runtime: Arc = server_runtime::create_server_instance( Arc::clone(&state.server_details), h, session_id.to_owned(), - )); + ); tracing::info!("a new client joined : {}", &session_id); @@ -224,7 +239,12 @@ async fn single_shot_stream( tokio::spawn(async move { match runtime_clone - .start_stream(transport, &stream_id, ping_interval, payload_string) + .start_stream( + Arc::new(transport), + &stream_id, + ping_interval, + payload_string, + ) .await { Ok(_) => tracing::info!("stream {} exited gracefully.", &stream_id), @@ -233,7 +253,6 @@ async fn single_shot_stream( let _ = runtime.remove_transport(&stream_id).await; }); - // Construct SSE stream let mut reader = BufReader::new(write_rx); let mut line = String::new(); let response = match reader.read_line(&mut line).await { @@ -310,15 +329,34 @@ pub async fn process_incoming_message( match state.session_store.get(&session_id).await { Some(runtime) => { let runtime = runtime.lock().await.to_owned(); - - create_sse_stream( - runtime.clone(), - session_id.clone(), - state.clone(), - Some(payload), - false, - ) - .await + // when receiving a result in a streamable_http server, that means it was sent by the standalone sse transport + // it should be processed by the same transport , therefore no need to call create_sse_stream + let Ok(is_result) = is_result(payload) else { + return Ok((StatusCode::BAD_REQUEST, Json(SdkError::parse_error())).into_response()); + }; + + if is_result { + match runtime + .consume_payload_string(DEFAULT_STREAM_ID, payload) + .await + { + Ok(()) => Ok((StatusCode::ACCEPTED, Json(())).into_response()), + Err(err) => Ok(( + StatusCode::BAD_REQUEST, + Json(SdkError::internal_error().with_message(err.to_string().as_ref())), + ) + .into_response()), + } + } else { + create_sse_stream( + runtime.clone(), + session_id.clone(), + state.clone(), + Some(payload), + false, + ) + .await + } } None => { let error = SdkError::session_not_found(); diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs index e1c00f8..27a16b2 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs @@ -1,3 +1,4 @@ +use crate::mcp_server::error::TransportServerError; use crate::schema::schema_utils::ClientMessage; use crate::{ hyper_servers::{ @@ -90,20 +91,24 @@ pub async fn handle_sse( let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); // create a transport for sending/receiving messages - let transport = SseTransport::new( + let Ok(transport) = SseTransport::new( read_rx, write_tx, read_tx, Arc::clone(&state.transport_options), - ) - .unwrap(); + ) else { + return Err(TransportServerError::TransportError( + "Failed to create SSE transport".to_string(), + )); + }; + let h: Arc = state.handler.clone(); // create a new server instance with unique session_id and - let server: Arc = Arc::new(server_runtime::create_server_instance( + let server: Arc = server_runtime::create_server_instance( Arc::clone(&state.server_details), h, session_id.to_owned(), - )); + ); state .session_store @@ -115,7 +120,12 @@ pub async fn handle_sse( // Start the server tokio::spawn(async move { match server - .start_stream(transport, DEFAULT_STREAM_ID, state.ping_interval, None) + .start_stream( + Arc::new(transport), + DEFAULT_STREAM_ID, + state.ping_interval, + None, + ) .await { Ok(_) => tracing::info!("server {} exited gracefully.", session_id.to_owned()), diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs index 83cc372..00d46c0 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs @@ -1,4 +1,4 @@ -use super::hyper_utils::{start_new_session, MCP_SESSION_ID_HEADER}; +use super::hyper_utils::start_new_session; use crate::schema::schema_utils::SdkError; use crate::{ error::McpSdkError, @@ -14,6 +14,7 @@ use crate::{ }, utils::valid_initialize_method, }; +use axum::routing::get; use axum::{ extract::{Query, State}, middleware, @@ -22,11 +23,9 @@ use axum::{ Json, Router, }; use hyper::{HeaderMap, StatusCode}; -use rust_mcp_transport::SessionId; +use rust_mcp_transport::{SessionId, MCP_SESSION_ID_HEADER}; use std::{collections::HashMap, sync::Arc}; -use axum::routing::get; - pub fn routes(state: Arc, streamable_http_endpoint: &str) -> Router> { Router::new() .route(streamable_http_endpoint, get(handle_streamable_http_get)) diff --git a/crates/rust-mcp-sdk/src/hyper_servers/server.rs b/crates/rust-mcp-sdk/src/hyper_servers/server.rs index f093da3..1c3b3cf 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/server.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/server.rs @@ -1,6 +1,8 @@ use crate::{ - error::SdkResult, mcp_server::hyper_runtime::HyperRuntime, - mcp_traits::mcp_handler::McpServerHandler, + error::SdkResult, + id_generator::{FastIdGenerator, UuidGenerator}, + mcp_server::hyper_runtime::HyperRuntime, + mcp_traits::{mcp_handler::McpServerHandler, IdGenerator}, }; #[cfg(feature = "ssl")] use axum_server::tls_rustls::RustlsConfig; @@ -17,11 +19,11 @@ use super::{ app_state::AppState, error::{TransportServerError, TransportServerResult}, routes::app_routes, - IdGenerator, InMemorySessionStore, UuidGenerator, + InMemorySessionStore, }; use crate::schema::InitializeResult; use axum::Router; -use rust_mcp_transport::TransportOptions; +use rust_mcp_transport::{SessionId, TransportOptions}; // Default client ping interval (12 seconds) const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_secs(12); @@ -43,7 +45,7 @@ pub struct HyperServerOptions { pub port: u16, /// Optional thread-safe session id generator to generate unique session IDs. - pub session_id_generator: Option>, + pub session_id_generator: Option>>, /// Optional custom path for the Streamable HTTP endpoint (default: `/mcp`) pub custom_streamable_http_endpoint: Option, @@ -258,6 +260,7 @@ impl HyperServer { .session_id_generator .take() .map_or(Arc::new(UuidGenerator {}), |g| Arc::clone(&g)), + stream_id_gen: Arc::new(FastIdGenerator::new(Some("s_"))), server_details: Arc::new(server_details), handler, ping_interval: server_options.ping_interval, diff --git a/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs b/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs index 95b2158..4384b1a 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs @@ -5,7 +5,6 @@ use async_trait::async_trait; pub use in_memory::*; use rust_mcp_transport::SessionId; use tokio::sync::Mutex; -use uuid::Uuid; use crate::mcp_server::ServerRuntime; @@ -46,26 +45,3 @@ pub trait SessionStore: Send + Sync { async fn has(&self, session: &SessionId) -> bool; } - -/// Trait for generating session identifiers -/// -/// Implementors must be Send and Sync to support concurrent access. -pub trait IdGenerator: Send + Sync { - fn generate(&self) -> SessionId; -} - -/// Struct implementing the IdGenerator trait using UUID v4 -/// -/// This is a simple wrapper around the uuid crate's Uuid::new_v4 function -/// to generate unique session identifiers. -pub struct UuidGenerator {} - -impl IdGenerator for UuidGenerator { - /// Generates a new UUID v4-based session identifier - /// - /// # Returns - /// * `SessionId` - A new UUID-based session identifier as a String - fn generate(&self) -> SessionId { - Uuid::new_v4().to_string() - } -} diff --git a/crates/rust-mcp-sdk/src/id_generator.rs b/crates/rust-mcp-sdk/src/id_generator.rs new file mode 100644 index 0000000..54f0e72 --- /dev/null +++ b/crates/rust-mcp-sdk/src/id_generator.rs @@ -0,0 +1,5 @@ +mod fast_id_generator; +mod uuid_generator; +pub use crate::mcp_traits::IdGenerator; +pub use fast_id_generator::*; +pub use uuid_generator::*; diff --git a/crates/rust-mcp-sdk/src/id_generator/fast_id_generator.rs b/crates/rust-mcp-sdk/src/id_generator/fast_id_generator.rs new file mode 100644 index 0000000..fc2e976 --- /dev/null +++ b/crates/rust-mcp-sdk/src/id_generator/fast_id_generator.rs @@ -0,0 +1,53 @@ +use crate::mcp_traits::IdGenerator; +use base64::Engine; +use std::sync::atomic::{AtomicU64, Ordering}; + +/// An [`IdGenerator`] implementation optimized for lightweight, locally-scoped identifiers. +/// +/// This generator produces short, incrementing identifiers that are Base64-encoded. +/// This makes it well-suited for cases such as `StreamId` generation, where: +/// - IDs only need to be unique within a single process or session +/// - Predictability is acceptable +/// - Shorter, more human-readable identifiers are desirable +/// +pub struct FastIdGenerator { + counter: AtomicU64, + ///Optional prefix for readability + prefix: &'static str, +} + +impl FastIdGenerator { + /// Creates a new ID generator with an optional prefix. + /// + /// # Arguments + /// * `prefix` - A static string to prepend to IDs (e.g., "sid_"). + pub fn new(prefix: Option<&'static str>) -> Self { + FastIdGenerator { + counter: AtomicU64::new(0), + prefix: prefix.unwrap_or_default(), + } + } +} + +impl IdGenerator for FastIdGenerator +where + T: From, +{ + /// Generates a new session ID as a short Base64-encoded string. + /// + /// Increments an internal counter atomically and encodes it in Base64 URL-safe format. + /// The resulting ID is prefixed (if provided) and typically 8–12 characters long. + /// + /// # Returns + /// * `SessionId` - A short, unique session ID (e.g., "sid_BBBB" or "BBBB"). + fn generate(&self) -> T { + let id = self.counter.fetch_add(1, Ordering::Relaxed); + let bytes = id.to_le_bytes(); + let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes); + if self.prefix.is_empty() { + T::from(encoded) + } else { + T::from(format!("{}{}", self.prefix, encoded)) + } + } +} diff --git a/crates/rust-mcp-sdk/src/id_generator/uuid_generator.rs b/crates/rust-mcp-sdk/src/id_generator/uuid_generator.rs new file mode 100644 index 0000000..2f0dc21 --- /dev/null +++ b/crates/rust-mcp-sdk/src/id_generator/uuid_generator.rs @@ -0,0 +1,18 @@ +use crate::mcp_traits::IdGenerator; +use uuid::Uuid; + +/// An [`IdGenerator`] implementation that uses UUID v4 to create unique identifiers. +/// +/// This generator produces random UUIDs (version 4), which are highly unlikely +/// to collide and difficult to predict. It is therefore well-suited for +/// generating identifiers such as `SessionId` or other values where uniqueness is important. +pub struct UuidGenerator; + +impl IdGenerator for UuidGenerator +where + T: From, +{ + fn generate(&self) -> T { + T::from(Uuid::new_v4().to_string()) + } +} diff --git a/crates/rust-mcp-sdk/src/lib.rs b/crates/rust-mcp-sdk/src/lib.rs index 1ea23df..a33f889 100644 --- a/crates/rust-mcp-sdk/src/lib.rs +++ b/crates/rust-mcp-sdk/src/lib.rs @@ -21,7 +21,7 @@ pub mod mcp_client { //! responding to ping requests, so you only need to override and customize the handler //! functions relevant to your specific needs. //! - //! Refer to [examples/simple-mcp-client](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client) for an example. + //! Refer to [examples/simple-mcp-client-stdio](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio) for an example. //! //! //! - **client_runtime_core**: If you need more control over MCP messages, consider using @@ -30,7 +30,7 @@ pub mod mcp_client { //! While still providing type-safe objects in these methods, it allows you to determine how to //! handle each message based on its type and parameters. //! - //! Refer to [examples/simple-mcp-client-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-core) for an example. + //! Refer to [examples/simple-mcp-client-stdio-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio-core) for an example. pub use super::mcp_handlers::mcp_client_handler::ClientHandler; pub use super::mcp_handlers::mcp_client_handler_core::ClientHandlerCore; pub use super::mcp_runtimes::client_runtime::mcp_client_runtime as client_runtime; @@ -53,7 +53,7 @@ pub mod mcp_server { //! responding to ping requests, so you only need to override and customize the handler //! functions relevant to your specific needs. //! - //! Refer to [examples/hello-world-mcp-server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server) for an example. + //! Refer to [examples/hello-world-mcp-server-stdio](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio) for an example. //! //! //! - **server_runtime_core**: If you need more control over MCP messages, consider using @@ -62,7 +62,7 @@ pub mod mcp_server { //! While still providing type-safe objects in these methods, it allows you to determine how to //! handle each message based on its type and parameters. //! - //! Refer to [examples/hello-world-mcp-server-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-core) for an example. + //! Refer to [examples/hello-world-mcp-server-stdio-core](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio-core) for an example. pub use super::mcp_handlers::mcp_server_handler::ServerHandler; pub use super::mcp_handlers::mcp_server_handler_core::ServerHandlerCore; @@ -93,4 +93,5 @@ pub mod macros { pub use rust_mcp_macros::*; } +pub mod id_generator; pub mod schema; diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs index 89aebf5..9b9577e 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs @@ -1,6 +1,7 @@ use crate::schema::{schema_utils::CallToolError, *}; use async_trait::async_trait; use serde_json::Value; +use std::sync::Arc; use crate::{mcp_traits::mcp_server::McpServer, utils::enforce_compatible_protocol_version}; @@ -15,7 +16,7 @@ pub trait ServerHandler: Send + Sync + 'static { /// The `runtime` parameter provides access to the server's runtime environment, allowing /// interaction with the server's capabilities. /// The default implementation does nothing. - async fn on_initialized(&self, runtime: &dyn McpServer) {} + async fn on_initialized(&self, runtime: Arc) {} /// Handles the InitializeRequest from a client. /// @@ -29,7 +30,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_initialize_request( &self, initialize_request: InitializeRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { let mut server_info = runtime.server_info().to_owned(); // Provide compatibility for clients using older MCP protocol versions. @@ -65,7 +66,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_ping_request( &self, _: PingRequest, - _: &dyn McpServer, + _: Arc, ) -> std::result::Result { Ok(Result::default()) } @@ -77,7 +78,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_list_resources_request( &self, request: ListResourcesRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -93,7 +94,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_list_resource_templates_request( &self, request: ListResourceTemplatesRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -109,7 +110,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_read_resource_request( &self, request: ReadResourceRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -125,7 +126,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_subscribe_request( &self, request: SubscribeRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -141,7 +142,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_unsubscribe_request( &self, request: UnsubscribeRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -157,7 +158,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_list_prompts_request( &self, request: ListPromptsRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -173,7 +174,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_get_prompt_request( &self, request: GetPromptRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -189,7 +190,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_list_tools_request( &self, request: ListToolsRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -205,7 +206,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_call_tool_request( &self, request: CallToolRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime .assert_server_request_capabilities(request.method()) @@ -220,7 +221,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_set_level_request( &self, request: SetLevelRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -236,7 +237,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_complete_request( &self, request: CompleteRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -252,7 +253,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_custom_request( &self, request: Value, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { Err(RpcError::method_not_found() .with_message("No handler is implemented for custom requests.".to_string())) @@ -265,7 +266,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_initialized_notification( &self, notification: InitializedNotification, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -275,7 +276,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_cancelled_notification( &self, notification: CancelledNotification, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -285,7 +286,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_progress_notification( &self, notification: ProgressNotification, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -295,7 +296,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_roots_list_changed_notification( &self, notification: RootsListChangedNotification, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -320,18 +321,8 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_error( &self, error: &RpcError, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } - - /// Called when the server has successfully started. - /// - /// Sends a "Server started successfully" message to stderr. - /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. - async fn on_server_started(&self, runtime: &dyn McpServer) { - let _ = runtime - .stderr_message("Server started successfully".into()) - .await; - } } diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs index e7b0e6d..9275da7 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs @@ -1,8 +1,8 @@ +use crate::mcp_traits::mcp_server::McpServer; use crate::schema::schema_utils::*; use crate::schema::*; use async_trait::async_trait; - -use crate::mcp_traits::mcp_server::McpServer; +use std::sync::Arc; /// Defines the `ServerHandlerCore` trait for handling Model Context Protocol (MCP) server operations. /// Unlike `ServerHandler`, this trait offers no default implementations, providing full control over MCP message handling @@ -14,7 +14,7 @@ pub trait ServerHandlerCore: Send + Sync + 'static { /// The `runtime` parameter provides access to the server's runtime environment, allowing /// interaction with the server's capabilities. /// The default implementation does nothing. - async fn on_initialized(&self, _runtime: &dyn McpServer) {} + async fn on_initialized(&self, _runtime: Arc) {} /// Asynchronously handles an incoming request from the client. /// @@ -26,7 +26,7 @@ pub trait ServerHandlerCore: Send + Sync + 'static { async fn handle_request( &self, request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result; /// Asynchronously handles an incoming notification from the client. @@ -36,7 +36,7 @@ pub trait ServerHandlerCore: Send + Sync + 'static { async fn handle_notification( &self, notification: NotificationFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError>; /// Asynchronously handles an error received from the client. @@ -46,11 +46,6 @@ pub trait ServerHandlerCore: Send + Sync + 'static { async fn handle_error( &self, error: &RpcError, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError>; - async fn on_server_started(&self, runtime: &dyn McpServer) { - let _ = runtime - .stderr_message("Server started successfully".into()) - .await; - } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs index 7ee0815..9961b84 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs @@ -1,12 +1,17 @@ pub mod mcp_client_runtime; pub mod mcp_client_runtime_core; - +use crate::error::{McpSdkError, SdkResult}; +use crate::id_generator::FastIdGenerator; +use crate::mcp_traits::mcp_client::McpClient; +use crate::mcp_traits::mcp_handler::McpClientHandler; +use crate::mcp_traits::IdGenerator; +use crate::utils::ensure_server_protocole_compatibility; use crate::{ mcp_traits::{RequestIdGen, RequestIdGenNumeric}, schema::{ schema_utils::{ - self, ClientMessage, ClientMessages, FromMessage, MessageFromClient, ServerMessage, - ServerMessages, + self, ClientMessage, ClientMessages, FromMessage, McpMessage, MessageFromClient, + ServerMessage, ServerMessages, }, InitializeRequest, InitializeRequestParams, InitializeResult, InitializedNotification, RequestId, RpcError, ServerResult, @@ -16,63 +21,100 @@ use async_trait::async_trait; use futures::future::{join_all, try_join_all}; use futures::StreamExt; -use rust_mcp_transport::{IoStream, McpDispatch, MessageDispatcher, Transport}; -use std::{ - sync::{Arc, RwLock}, - time::Duration, -}; +#[cfg(feature = "streamable-http")] +use rust_mcp_transport::{ClientStreamableTransport, StreamableTransportOptions}; +use rust_mcp_transport::{IoStream, SessionId, StreamId, Transport, TransportDispatcher}; +use std::{collections::HashMap, sync::Arc, time::Duration}; use tokio::io::{AsyncBufReadExt, BufReader}; -use tokio::sync::Mutex; +use tokio::sync::{watch, Mutex}; -use crate::error::{McpSdkError, SdkResult}; -use crate::mcp_traits::mcp_client::McpClient; -use crate::mcp_traits::mcp_handler::McpClientHandler; -use crate::utils::ensure_server_protocole_compatibility; +pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM"; + +// Define a type alias for the TransportDispatcher trait object +type TransportDispatcherType = dyn TransportDispatcher< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, +>; +type TransportType = Arc; pub struct ClientRuntime { - // The transport interface for handling messages between client and server - transport: Arc< - dyn Transport< - ServerMessages, - MessageFromClient, - ServerMessage, - ClientMessages, - ClientMessage, - >, - >, + // A thread-safe map storing transport types + transport_map: tokio::sync::RwLock>, // The handler for processing MCP messages handler: Box, - // // Information about the server + // Information about the server client_details: InitializeRequestParams, - // Details about the connected server - server_details: Arc>>, handlers: Mutex>>>, + // Generator for unique request IDs request_id_gen: Box, + // Generator for stream IDs + stream_id_gen: FastIdGenerator, + #[cfg(feature = "streamable-http")] + // Optional configuration for streamable transport + transport_options: Option, + // Flag indicating whether the client has been shut down + is_shut_down: Mutex, + // Session ID + session_id: tokio::sync::RwLock>, + // Details about the connected server + server_details_tx: watch::Sender>, + server_details_rx: watch::Receiver>, } impl ClientRuntime { pub(crate) fn new( client_details: InitializeRequestParams, - transport: impl Transport< - ServerMessages, - MessageFromClient, - ServerMessage, - ClientMessages, - ClientMessage, - >, + transport: TransportType, handler: Box, ) -> Self { + let mut map: HashMap = HashMap::new(); + map.insert(DEFAULT_STREAM_ID.to_string(), transport); + let (server_details_tx, server_details_rx) = + watch::channel::>(None); Self { - transport: Arc::new(transport), + transport_map: tokio::sync::RwLock::new(map), handler, client_details, - server_details: Arc::new(RwLock::new(None)), handlers: Mutex::new(vec![]), request_id_gen: Box::new(RequestIdGenNumeric::new(None)), + #[cfg(feature = "streamable-http")] + transport_options: None, + is_shut_down: Mutex::new(false), + session_id: tokio::sync::RwLock::new(None), + stream_id_gen: FastIdGenerator::new(Some("s_")), + server_details_tx, + server_details_rx, } } - async fn initialize_request(&self) -> SdkResult<()> { + #[cfg(feature = "streamable-http")] + pub(crate) fn new_instance( + client_details: InitializeRequestParams, + transport_options: StreamableTransportOptions, + handler: Box, + ) -> Self { + let map: HashMap = HashMap::new(); + let (server_details_tx, server_details_rx) = + watch::channel::>(None); + Self { + transport_map: tokio::sync::RwLock::new(map), + handler, + client_details, + handlers: Mutex::new(vec![]), + transport_options: Some(transport_options), + is_shut_down: Mutex::new(false), + session_id: tokio::sync::RwLock::new(None), + request_id_gen: Box::new(RequestIdGenNumeric::new(None)), + stream_id_gen: FastIdGenerator::new(Some("s_")), + server_details_tx, + server_details_rx, + } + } + + async fn initialize_request(self: Arc) -> SdkResult<()> { let request = InitializeRequest::new(self.client_details.clone()); let result: ServerResult = self.request(request.into(), None).await?.try_into()?; @@ -81,9 +123,15 @@ impl ClientRuntime { &self.client_details.protocol_version, &initialize_result.protocol_version, )?; - // store server details self.set_server_details(initialize_result)?; + + #[cfg(feature = "streamable-http")] + // try to create a sse stream for server initiated messages , if supported by the server + if let Err(error) = self.clone().create_sse_stream().await { + tracing::warn!("{error}"); + } + // send a InitializedNotification to the server self.send_notification(InitializedNotification::new(None).into()) .await?; @@ -92,21 +140,14 @@ impl ClientRuntime { .with_message("Incorrect response to InitializeRequest!".into()) .into()); } + Ok(()) } pub(crate) async fn handle_message( &self, message: ServerMessage, - transport: &Arc< - dyn Transport< - ServerMessages, - MessageFromClient, - ServerMessage, - ClientMessages, - ClientMessage, - >, - >, + transport: &TransportType, ) -> SdkResult> { let response = match message { ServerMessage::Request(jsonrpc_request) => { @@ -162,28 +203,26 @@ impl ClientRuntime { }; Ok(response) } -} -#[async_trait] -impl McpClient for ClientRuntime { - fn sender(&self) -> Arc>>> - where - MessageDispatcher: - McpDispatch, - { - (self.transport.message_sender().clone()) as _ - } + async fn start_standalone(self: Arc) -> SdkResult<()> { + let self_clone = self.clone(); + let transport_map = self_clone.transport_map.read().await; + let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( + RpcError::internal_error() + .with_message("transport stream does not exists or is closed!".to_string()), + )?; - async fn start(self: Arc) -> SdkResult<()> { //TODO: improve the flow - let mut stream = self.transport.start().await?; - let transport = self.transport.clone(); + let mut stream = transport.start().await?; + + let transport_clone = transport.clone(); let mut error_io_stream = transport.error_stream().write().await; let error_io_stream = error_io_stream.take(); let self_clone = Arc::clone(&self); let self_clone_err = Arc::clone(&self); + // task reading from the error stream let err_task = tokio::spawn(async move { let self_ref = &*self_clone_err; @@ -191,7 +230,7 @@ impl McpClient for ClientRuntime { let mut reader = BufReader::new(error_input).lines(); loop { tokio::select! { - should_break = self_ref.transport.is_shut_down() =>{ + should_break = transport_clone.is_shut_down() =>{ if should_break { break; } @@ -221,14 +260,10 @@ impl McpClient for ClientRuntime { Ok::<(), McpSdkError>(()) }); - let transport = self.transport.clone(); + let transport = transport.clone(); + // main task reading from mcp_message stream let main_task = tokio::spawn(async move { - let sender = self_clone.sender(); - let sender = sender.read().await; - let sender = sender - .as_ref() - .ok_or(schema_utils::SdkError::connection_closed())?; while let Some(mcp_messages) = stream.next().await { let self_ref = &*self_clone; @@ -239,7 +274,7 @@ impl McpClient for ClientRuntime { match result { Ok(result) => { if let Some(result) = result { - sender + transport .send_message(ClientMessages::Single(result), None) .await?; } @@ -260,7 +295,7 @@ impl McpClient for ClientRuntime { let results: Vec<_> = results.into_iter().flatten().collect(); if !results.is_empty() { - sender + transport .send_message(ClientMessages::Batch(results), None) .await?; } @@ -271,71 +306,349 @@ impl McpClient for ClientRuntime { }); // send initialize request to the MCP server - self.initialize_request().await?; + self.clone().initialize_request().await?; let mut lock = self.handlers.lock().await; lock.push(main_task); lock.push(err_task); + Ok(()) + } + pub(crate) async fn store_transport( + &self, + stream_id: &str, + transport: TransportType, + ) -> SdkResult<()> { + let mut transport_map = self.transport_map.write().await; + tracing::trace!("save transport for stream id : {}", stream_id); + transport_map.insert(stream_id.to_string(), transport); Ok(()) } - fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()> { - match self.server_details.write() { - Ok(mut details) => { - *details = Some(server_details); - Ok(()) - } - // Failed to acquire read lock, likely due to PoisonError from a thread panic. Returning None. - Err(_) => Err(RpcError::internal_error() - .with_message("Internal Error: Failed to acquire write lock.".to_string()) - .into()), - } + pub(crate) async fn transport_by_stream(&self, stream_id: &str) -> SdkResult { + let transport_map = self.transport_map.read().await; + transport_map.get(stream_id).cloned().ok_or_else(|| { + RpcError::internal_error() + .with_message(format!("Transport for key {stream_id} not found")) + .into() + }) } - fn client_info(&self) -> &InitializeRequestParams { - &self.client_details + + #[cfg(feature = "streamable-http")] + pub(crate) async fn new_transport( + &self, + session_id: Option, + standalone: bool, + ) -> SdkResult< + impl TransportDispatcher< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, + >, + > { + let options = self + .transport_options + .as_ref() + .ok_or(schema_utils::SdkError::connection_closed())?; + let transport = ClientStreamableTransport::new(options, session_id, standalone)?; + + Ok(transport) } - fn server_info(&self) -> Option { - if let Ok(details) = self.server_details.read() { - details.clone() - } else { - // Failed to acquire read lock, likely due to PoisonError from a thread panic. Returning None. - None + + #[cfg(feature = "streamable-http")] + pub(crate) async fn create_sse_stream(self: Arc) -> SdkResult<()> { + let stream_id: StreamId = DEFAULT_STREAM_ID.into(); + let session_id = self.session_id.read().await.clone(); + let transport: Arc< + dyn TransportDispatcher< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, + >, + > = Arc::new(self.new_transport(session_id, true).await?); + let mut stream = transport.start().await?; + self.store_transport(&stream_id, transport.clone()).await?; + + let self_clone = Arc::clone(&self); + + let main_task = tokio::spawn(async move { + loop { + if let Some(mcp_messages) = stream.next().await { + match mcp_messages { + ServerMessages::Single(server_message) => { + let result = self.handle_message(server_message, &transport).await?; + + if let Some(result) = result { + transport + .send_message(ClientMessages::Single(result), None) + .await?; + } + } + ServerMessages::Batch(server_messages) => { + let handling_tasks: Vec<_> = server_messages + .into_iter() + .map(|server_message| { + self.handle_message(server_message, &transport) + }) + .collect(); + + let results: Vec<_> = try_join_all(handling_tasks).await?; + + let results: Vec<_> = results.into_iter().flatten().collect(); + + if !results.is_empty() { + transport + .send_message(ClientMessages::Batch(results), None) + .await?; + } + } + } + // close the stream after all messages are sent, unless it is a standalone stream + if !stream_id.eq(DEFAULT_STREAM_ID) { + return Ok::<_, McpSdkError>(()); + } + } else { + // end of stream + return Ok::<_, McpSdkError>(()); + } + } + }); + + let mut lock = self_clone.handlers.lock().await; + lock.push(main_task); + + Ok(()) + } + + #[cfg(feature = "streamable-http")] + pub(crate) async fn start_stream( + &self, + messages: ClientMessages, + timeout: Option, + ) -> SdkResult> { + use futures::stream::{AbortHandle, Abortable}; + let stream_id: StreamId = self.stream_id_gen.generate(); + let session_id = self.session_id.read().await.clone(); + let no_session_id = session_id.is_none(); + + let has_request = match &messages { + ClientMessages::Single(client_message) => client_message.is_request(), + ClientMessages::Batch(client_messages) => { + client_messages.iter().any(|m| m.is_request()) + } + }; + + let transport = Arc::new(self.new_transport(session_id, false).await?); + + let mut stream = transport.start().await?; + + self.store_transport(&stream_id, transport).await?; + + let transport = self.transport_by_stream(&stream_id).await?; //TODO: remove + + let send_task = async { + let result = transport.send_message(messages, timeout).await?; + + if no_session_id { + if let Some(resquest_id) = transport.session_id().await.clone() { + let mut guard = self.session_id.write().await; + *guard = Some(resquest_id) + } + } + + Ok::<_, McpSdkError>(result) + }; + + if !has_request { + return send_task.await; } + + let (abort_recv_handle, abort_recv_reg) = AbortHandle::new_pair(); + + let receive_task = async { + loop { + tokio::select! { + Some(mcp_messages) = stream.next() =>{ + + match mcp_messages { + ServerMessages::Single(server_message) => { + let result = self.handle_message(server_message, &transport).await?; + if let Some(result) = result { + transport.send_message(ClientMessages::Single(result), None).await?; + } + } + ServerMessages::Batch(server_messages) => { + + let handling_tasks: Vec<_> = server_messages + .into_iter() + .map(|server_message| self.handle_message(server_message, &transport)) + .collect(); + + let results: Vec<_> = try_join_all(handling_tasks).await?; + + let results: Vec<_> = results.into_iter().flatten().collect(); + + if !results.is_empty() { + transport.send_message(ClientMessages::Batch(results), None).await?; + } + } + } + // close the stream after all messages are sent, unless it is a standalone stream + if !stream_id.eq(DEFAULT_STREAM_ID){ + return Ok::<_, McpSdkError>(()); + } + } + } + } + }; + + let receive_task = Abortable::new(receive_task, abort_recv_reg); + + // Pin the tasks to ensure they are not moved + tokio::pin!(send_task); + tokio::pin!(receive_task); + + // Run both tasks with cancellation logic + let (send_res, _) = tokio::select! { + res = &mut send_task => { + // cancel the receive_task task, to cover the case where sned_task returns with error + abort_recv_handle.abort(); + (res, receive_task.await) // Wait for receive_task to finish (it should exit due to cancellation) + } + res = &mut receive_task => { + (send_task.await, res) + } + }; + send_res } +} +#[async_trait] +impl McpClient for ClientRuntime { async fn send( &self, message: MessageFromClient, request_id: Option, - timeout: Option, + request_timeout: Option, ) -> SdkResult> { - let sender = self.sender(); - let sender = sender.read().await; - let sender = sender - .as_ref() - .ok_or(schema_utils::SdkError::connection_closed())?; + #[cfg(feature = "streamable-http")] + { + if self.transport_options.is_some() { + let outgoing_request_id = self + .request_id_gen + .request_id_for_message(&message, request_id); + let mcp_message = ClientMessage::from_message(message, outgoing_request_id)?; + + let response = self + .start_stream(ClientMessages::Single(mcp_message), request_timeout) + .await?; + return response + .map(|r| r.as_single()) + .transpose() + .map_err(|err| err.into()); + } + } + + let transport_map = self.transport_map.read().await; + + let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( + RpcError::internal_error() + .with_message("transport stream does not exists or is closed!".to_string()), + )?; let outgoing_request_id = self .request_id_gen .request_id_for_message(&message, request_id); let mcp_message = ClientMessage::from_message(message, outgoing_request_id)?; + let response = transport + .send_message(ClientMessages::Single(mcp_message), request_timeout) + .await?; + response + .map(|r| r.as_single()) + .transpose() + .map_err(|err| err.into()) + } - let response = sender - .send_message(ClientMessages::Single(mcp_message), timeout) - .await? - .map(|res| res.as_single()) - .transpose()?; + async fn send_batch( + &self, + messages: Vec, + timeout: Option, + ) -> SdkResult>> { + #[cfg(feature = "streamable-http")] + { + if self.transport_options.is_some() { + let result = self + .start_stream(ClientMessages::Batch(messages), timeout) + .await?; + // let response = self.start_stream(&stream_id, request_id, message).await?; + return result + .map(|r| r.as_batch()) + .transpose() + .map_err(|err| err.into()); + } + } - Ok(response) + let transport_map = self.transport_map.read().await; + let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( + RpcError::internal_error() + .with_message("transport stream does not exists or is closed!".to_string()), + )?; + transport + .send_batch(messages, timeout) + .await + .map_err(|err| err.into()) + } + + async fn start(self: Arc) -> SdkResult<()> { + #[cfg(feature = "streamable-http")] + { + if self.transport_options.is_some() { + self.initialize_request().await?; + return Ok(()); + } + } + + self.start_standalone().await + } + + fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()> { + self.server_details_tx + .send(Some(server_details)) + .map_err(|_| { + RpcError::internal_error() + .with_message("Failed to set server details".to_string()) + .into() + }) + } + + fn client_info(&self) -> &InitializeRequestParams { + &self.client_details + } + + fn server_info(&self) -> Option { + self.server_details_rx.borrow().clone() } async fn is_shut_down(&self) -> bool { - self.transport.is_shut_down().await + let result = self.is_shut_down.lock().await; + *result } + async fn shut_down(&self) -> SdkResult<()> { - self.transport.shut_down().await?; + let mut is_shut_down_lock = self.is_shut_down.lock().await; + *is_shut_down_lock = true; + + let mut transport_map = self.transport_map.write().await; + let transports: Vec<_> = transport_map.drain().map(|(_, v)| v).collect(); + drop(transport_map); + for transport in transports { + let _ = transport.shut_down().await; + } // wait for tasks let mut tasks_lock = self.handlers.lock().await; @@ -344,4 +657,18 @@ impl McpClient for ClientRuntime { Ok(()) } + + async fn terminate_session(&self) { + #[cfg(feature = "streamable-http")] + { + if let Some(transport_options) = self.transport_options.as_ref() { + let session_id = self.session_id.read().await.clone(); + transport_options + .terminate_session(session_id.as_ref()) + .await; + let _ = self.shut_down().await; + } + } + let _ = self.shut_down().await; + } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs index 7925f07..43a7079 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs @@ -8,7 +8,10 @@ use crate::schema::{ InitializeRequestParams, RpcError, ServerNotification, ServerRequest, }; use async_trait::async_trait; -use rust_mcp_transport::Transport; + +#[cfg(feature = "streamable-http")] +use rust_mcp_transport::StreamableTransportOptions; +use rust_mcp_transport::TransportDispatcher; use crate::{ error::SdkResult, mcp_client::ClientHandler, mcp_traits::mcp_handler::McpClientHandler, @@ -37,10 +40,10 @@ use super::ClientRuntime; /// # Examples /// You can find a detailed example of how to use this function in the repository: /// -/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client) +/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio) pub fn create_client( client_details: InitializeRequestParams, - transport: impl Transport< + transport: impl TransportDispatcher< ServerMessages, MessageFromClient, ServerMessage, @@ -51,7 +54,20 @@ pub fn create_client( ) -> Arc { Arc::new(ClientRuntime::new( client_details, - transport, + Arc::new(transport), + Box::new(ClientInternalHandler::new(Box::new(handler))), + )) +} + +#[cfg(feature = "streamable-http")] +pub fn with_transport_options( + client_details: InitializeRequestParams, + transport_options: StreamableTransportOptions, + handler: impl ClientHandler, +) -> Arc { + Arc::new(ClientRuntime::new_instance( + client_details, + transport_options, Box::new(ClientInternalHandler::new(Box::new(handler))), )) } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs index 8cb8cff..884de9d 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs @@ -1,5 +1,4 @@ -use std::sync::Arc; - +use super::ClientRuntime; use crate::schema::{ schema_utils::{ ClientMessage, ClientMessages, MessageFromClient, NotificationFromServer, @@ -7,17 +6,16 @@ use crate::schema::{ }, InitializeRequestParams, RpcError, }; -use async_trait::async_trait; - -use rust_mcp_transport::Transport; - use crate::{ error::SdkResult, mcp_handlers::mcp_client_handler_core::ClientHandlerCore, mcp_traits::{mcp_client::McpClient, mcp_handler::McpClientHandler}, }; - -use super::ClientRuntime; +use async_trait::async_trait; +#[cfg(feature = "streamable-http")] +use rust_mcp_transport::StreamableTransportOptions; +use rust_mcp_transport::TransportDispatcher; +use std::sync::Arc; /// Creates a new MCP client runtime with the specified configuration. /// @@ -39,10 +37,10 @@ use super::ClientRuntime; /// # Examples /// You can find a detailed example of how to use this function in the repository: /// -/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-core) +/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio-core) pub fn create_client( client_details: InitializeRequestParams, - transport: impl Transport< + transport: impl TransportDispatcher< ServerMessages, MessageFromClient, ServerMessage, @@ -53,7 +51,20 @@ pub fn create_client( ) -> Arc { Arc::new(ClientRuntime::new( client_details, - transport, + Arc::new(transport), + Box::new(ClientCoreInternalHandler::new(Box::new(handler))), + )) +} + +#[cfg(feature = "streamable-http")] +pub fn with_transport_options( + client_details: InitializeRequestParams, + transport_options: StreamableTransportOptions, + handler: impl ClientHandlerCore, +) -> Arc { + Arc::new(ClientRuntime::new_instance( + client_details, + transport_options, Box::new(ClientCoreInternalHandler::new(Box::new(handler))), )) } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index 44f3e53..1b24b57 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -19,12 +19,15 @@ use futures::{StreamExt, TryFutureExt}; use rust_mcp_transport::SessionId; use rust_mcp_transport::{IoStream, TransportDispatcher}; use std::collections::HashMap; +use std::panic; use std::sync::Arc; use std::time::Duration; use tokio::io::AsyncWriteExt; -use tokio::sync::{oneshot, watch}; + +use tokio::sync::{mpsc, oneshot, watch}; pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM"; +const TASK_CHANNEL_CAPACITY: usize = 500; // Define a type alias for the TransportDispatcher trait object type TransportType = Arc< @@ -45,7 +48,7 @@ pub struct ServerRuntime { server_details: Arc, #[cfg(feature = "hyper-server")] session_id: Option, - transport_map: tokio::sync::RwLock>, + transport_map: tokio::sync::RwLock>, //TODO: remove the transport_map, we do not need a hashmap for it request_id_gen: Box, client_details_tx: watch::Sender>, client_details_rx: watch::Receiver>, @@ -55,8 +58,6 @@ pub struct ServerRuntime { impl McpServer for ServerRuntime { /// Set the client details, storing them in client_details async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()> { - self.handler.on_server_started(self).await; - self.client_details_tx .send(Some(client_details)) .map_err(|_| { @@ -132,8 +133,9 @@ impl McpServer for ServerRuntime { } /// Main runtime loop, processes incoming messages and handles requests - async fn start(&self) -> SdkResult<()> { - let transport_map = self.transport_map.read().await; + async fn start(self: Arc) -> SdkResult<()> { + let self_clone = self.clone(); + let transport_map = self_clone.transport_map.read().await; let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( RpcError::internal_error() @@ -142,43 +144,88 @@ impl McpServer for ServerRuntime { let mut stream = transport.start().await?; + // Create a channel to collect results from spawned tasks + let (tx, mut rx) = mpsc::channel(TASK_CHANNEL_CAPACITY); + // Process incoming messages from the client while let Some(mcp_messages) = stream.next().await { match mcp_messages { ClientMessages::Single(client_message) => { - let result = self.handle_message(client_message, transport).await; - - match result { - Ok(result) => { - if let Some(result) = result { - transport - .send_message(ServerMessages::Single(result), None) - .await?; + let transport = transport.clone(); + let self = self.clone(); + let tx = tx.clone(); + + // Handle incoming messages in a separate task to avoid blocking the stream. + tokio::spawn(async move { + let result = self.handle_message(client_message, &transport).await; + + let send_result: SdkResult<_> = match result { + Ok(result) => { + if let Some(result) = result { + transport + .send_message(ServerMessages::Single(result), None) + .map_err(|e| e.into()) + .await + } else { + Ok(None) + } } + Err(error) => { + tracing::error!("Error handling message : {}", error); + Ok(None) + } + }; + // Send result to the main loop + if let Err(error) = tx.send(send_result).await { + tracing::error!("Failed to send result to channel: {}", error); } - Err(error) => { - tracing::error!("Error handling message : {}", error) - } - } + }); } ClientMessages::Batch(client_messages) => { - let handling_tasks: Vec<_> = client_messages - .into_iter() - .map(|client_message| self.handle_message(client_message, transport)) - .collect(); - - let results: Vec<_> = try_join_all(handling_tasks).await?; - - let results: Vec<_> = results.into_iter().flatten().collect(); + let transport = transport.clone(); + let self = self_clone.clone(); + let tx = tx.clone(); + + tokio::spawn(async move { + let handling_tasks: Vec<_> = client_messages + .into_iter() + .map(|client_message| self.handle_message(client_message, &transport)) + .collect(); + + let send_result = match try_join_all(handling_tasks).await { + Ok(results) => { + let results: Vec<_> = results.into_iter().flatten().collect(); + if !results.is_empty() { + transport + .send_message(ServerMessages::Batch(results), None) + .map_err(|e| e.into()) + .await + } else { + Ok(None) + } + } + Err(error) => Err(error), + }; - if !results.is_empty() { - transport - .send_message(ServerMessages::Batch(results), None) - .await?; - } + if let Err(error) = tx.send(send_result).await { + tracing::error!("Failed to send batch result to channel: {}", error); + } + }); } } + + // Check for results from spawned tasks to propagate errors + while let Ok(result) = rx.try_recv() { + result?; // Propagate errors + } } + + // Drop tx to close the channel and collect remaining results + drop(tx); + while let Some(result) = rx.recv().await { + result?; // Propagate errors + } + return Ok(()); } @@ -223,7 +270,7 @@ impl ServerRuntime { } pub(crate) async fn handle_message( - &self, + self: &Arc, message: ClientMessage, transport: &Arc< dyn TransportDispatcher< @@ -240,7 +287,7 @@ impl ServerRuntime { ClientMessage::Request(client_jsonrpc_request) => { let result = self .handler - .handle_request(client_jsonrpc_request.request, self) + .handle_request(client_jsonrpc_request.request, self.clone()) .await; // create a response to send back to the client let response: MessageFromServer = match result { @@ -262,13 +309,13 @@ impl ServerRuntime { } ClientMessage::Notification(client_jsonrpc_notification) => { self.handler - .handle_notification(client_jsonrpc_notification.notification, self) + .handle_notification(client_jsonrpc_notification.notification, self.clone()) .await?; None } ClientMessage::Error(jsonrpc_error) => { self.handler - .handle_error(&jsonrpc_error.error, self) + .handle_error(&jsonrpc_error.error, self.clone()) .await?; if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await { tx_response @@ -282,7 +329,6 @@ impl ServerRuntime { } None } - // The response is the result of a request, it is processed at the transport level. ClientMessage::Response(response) => { if let Some(tx_response) = transport.pending_request_tx(&response.id).await { tx_response @@ -313,6 +359,9 @@ impl ServerRuntime { >, >, ) -> SdkResult<()> { + if stream_id != DEFAULT_STREAM_ID { + return Ok(()); + } let mut transport_map = self.transport_map.write().await; tracing::trace!("save transport for stream id : {}", stream_id); transport_map.insert(stream_id.to_string(), transport); @@ -320,34 +369,18 @@ impl ServerRuntime { } pub(crate) async fn remove_transport(&self, stream_id: &str) -> SdkResult<()> { + if stream_id != DEFAULT_STREAM_ID { + return Ok(()); + } let mut transport_map = self.transport_map.write().await; tracing::trace!("removing transport for stream id : {}", stream_id); + if let Some(transport) = transport_map.get(stream_id) { + transport.shut_down().await?; + } transport_map.remove(stream_id); Ok(()) } - pub(crate) async fn transport_by_stream( - &self, - stream_id: &str, - ) -> SdkResult< - Arc< - dyn TransportDispatcher< - ClientMessages, - MessageFromServer, - ClientMessage, - ServerMessages, - ServerMessage, - >, - >, - > { - let transport_map = self.transport_map.read().await; - transport_map.get(stream_id).cloned().ok_or_else(|| { - RpcError::internal_error() - .with_message(format!("Transport for key {stream_id} not found")) - .into() - }) - } - pub(crate) async fn shutdown(&self) { let mut transport_map = self.transport_map.write().await; let items: Vec<_> = transport_map.drain().map(|(_, v)| v).collect(); @@ -359,17 +392,24 @@ impl ServerRuntime { pub(crate) async fn stream_id_exists(&self, stream_id: &str) -> bool { let transport_map = self.transport_map.read().await; - transport_map.contains_key(stream_id) + let live_transport = if let Some(t) = transport_map.get(stream_id) { + !t.is_shut_down().await + } else { + false + }; + live_transport } pub(crate) async fn start_stream( self: Arc, - transport: impl TransportDispatcher< - ClientMessages, - MessageFromServer, - ClientMessage, - ServerMessages, - ServerMessage, + transport: Arc< + dyn TransportDispatcher< + ClientMessages, + MessageFromServer, + ClientMessage, + ServerMessages, + ServerMessage, + >, >, stream_id: &str, ping_interval: Duration, @@ -377,9 +417,11 @@ impl ServerRuntime { ) -> SdkResult<()> { let mut stream = transport.start().await?; - self.store_transport(stream_id, Arc::new(transport)).await?; + if stream_id == DEFAULT_STREAM_ID { + self.store_transport(stream_id, transport.clone()).await?; + } - let transport = self.transport_by_stream(stream_id).await?; + let self_clone = self.clone(); let (disconnect_tx, mut disconnect_rx) = oneshot::channel::<()>(); let abort_alive_task = transport @@ -394,43 +436,102 @@ impl ServerRuntime { // in case there is a payload, we consume it by transport to get processed if let Some(payload) = payload { - transport.consume_string_payload(&payload).await?; + if let Err(err) = transport.consume_string_payload(&payload).await { + let _ = self.remove_transport(stream_id).await; + return Err(err.into()); + } } + // Create a channel to collect results from spawned tasks + let (tx, mut rx) = mpsc::channel(TASK_CHANNEL_CAPACITY); + loop { tokio::select! { Some(mcp_messages) = stream.next() =>{ match mcp_messages { ClientMessages::Single(client_message) => { - let result = self.handle_message(client_message, &transport).await?; - if let Some(result) = result { - transport.send_message(ServerMessages::Single(result), None).await?; - } + let transport = transport.clone(); + let self_clone = self.clone(); + let tx = tx.clone(); + tokio::spawn(async move { + + let result = self_clone.handle_message(client_message, &transport).await; + + let send_result: SdkResult<_> = match result { + Ok(result) => { + if let Some(result) = result { + transport + .send_message(ServerMessages::Single(result), None) + .map_err(|e| e.into()) + .await + } else { + Ok(None) + } + } + Err(error) => { + tracing::error!("Error handling message : {}", error); + Ok(None) + } + }; + if let Err(error) = tx.send(send_result).await { + tracing::error!("Failed to send batch result to channel: {}", error); + } + }); } ClientMessages::Batch(client_messages) => { - let handling_tasks: Vec<_> = client_messages - .into_iter() - .map(|client_message| self.handle_message(client_message, &transport)) - .collect(); - - let results: Vec<_> = try_join_all(handling_tasks).await?; - - let results: Vec<_> = results.into_iter().flatten().collect(); - - - if !results.is_empty() { - transport.send_message(ServerMessages::Batch(results), None).await?; - } + let transport = transport.clone(); + let self_clone = self_clone.clone(); + let tx = tx.clone(); + + tokio::spawn(async move { + let handling_tasks: Vec<_> = client_messages + .into_iter() + .map(|client_message| self_clone.handle_message(client_message, &transport)) + .collect(); + + let send_result = match try_join_all(handling_tasks).await { + Ok(results) => { + let results: Vec<_> = results.into_iter().flatten().collect(); + if !results.is_empty() { + transport.send_message(ServerMessages::Batch(results), None) + .map_err(|e| e.into()) + .await + }else { + Ok(None) + } + }, + Err(error) => Err(error), + }; + if let Err(error) = tx.send(send_result).await { + tracing::error!("Failed to send batch result to channel: {}", error); + } + }); } } + + // Check for results from spawned tasks to propagate errors + while let Ok(result) = rx.try_recv() { + result?; // Propagate errors + } + // close the stream after all messages are sent, unless it is a standalone stream if !stream_id.eq(DEFAULT_STREAM_ID){ + // Drop tx to close the channel and collect remaining results + drop(tx); + while let Some(result) = rx.recv().await { + result?; // Propagate errors + } return Ok(()); } } _ = &mut disconnect_rx => { + // Drop tx to close the channel and collect remaining results + drop(tx); + while let Some(result) = rx.recv().await { + result?; // Propagate errors + } self.remove_transport(stream_id).await?; // Disconnection detected by keep-alive task return Err(SdkError::connection_closed().into()); @@ -445,10 +546,10 @@ impl ServerRuntime { server_details: Arc, handler: Arc, session_id: SessionId, - ) -> Self { + ) -> Arc { let (client_details_tx, client_details_rx) = watch::channel::>(None); - Self { + Arc::new(Self { server_details, handler, session_id: Some(session_id), @@ -456,7 +557,7 @@ impl ServerRuntime { client_details_tx, client_details_rx, request_id_gen: Box::new(RequestIdGenNumeric::new(None)), - } + }) } pub(crate) fn new( @@ -469,12 +570,12 @@ impl ServerRuntime { ServerMessage, >, handler: Arc, - ) -> Self { + ) -> Arc { let mut map: HashMap = HashMap::new(); map.insert(DEFAULT_STREAM_ID.to_string(), Arc::new(transport)); let (client_details_tx, client_details_rx) = watch::channel::>(None); - Self { + Arc::new(Self { server_details: Arc::new(server_details), handler, #[cfg(feature = "hyper-server")] @@ -483,6 +584,6 @@ impl ServerRuntime { client_details_tx, client_details_rx, request_id_gen: Box::new(RequestIdGenNumeric::new(None)), - } + }) } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs index ea19e19..62fd31f 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs @@ -38,7 +38,7 @@ use crate::{ /// # Examples /// You can find a detailed example of how to use this function in the repository: /// -/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server) +/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio) pub fn create_server( server_details: InitializeResult, transport: impl TransportDispatcher< @@ -49,7 +49,7 @@ pub fn create_server( ServerMessage, >, handler: impl ServerHandler, -) -> ServerRuntime { +) -> Arc { ServerRuntime::new( server_details, transport, @@ -62,7 +62,7 @@ pub(crate) fn create_server_instance( server_details: Arc, handler: Arc, session_id: SessionId, -) -> ServerRuntime { +) -> Arc { ServerRuntime::new_instance(server_details, handler, session_id) } @@ -80,7 +80,7 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { async fn handle_request( &self, client_jsonrpc_request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { match client_jsonrpc_request { schema_utils::RequestFromClient::ClientRequest(client_request) => { @@ -178,7 +178,7 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { async fn handle_error( &self, jsonrpc_error: &RpcError, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; Ok(()) @@ -187,7 +187,7 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { async fn handle_notification( &self, client_jsonrpc_notification: NotificationFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()> { match client_jsonrpc_notification { schema_utils::NotificationFromClient::ClientNotification(client_notification) => { @@ -199,7 +199,10 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { } ClientNotification::InitializedNotification(initialized_notification) => { self.handler - .handle_initialized_notification(initialized_notification, runtime) + .handle_initialized_notification( + initialized_notification, + runtime.clone(), + ) .await?; self.handler.on_initialized(runtime).await; } @@ -226,8 +229,4 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { } Ok(()) } - - async fn on_server_started(&self, runtime: &dyn McpServer) { - self.handler.on_server_started(runtime).await; - } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs index e0e7108..110b20b 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs @@ -32,7 +32,7 @@ use std::sync::Arc; /// # Examples /// You can find a detailed example of how to use this function in the repository: /// -/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-core) +/// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio-core) pub fn create_server( server_details: InitializeResult, transport: impl TransportDispatcher< @@ -43,7 +43,7 @@ pub fn create_server( ServerMessage, >, handler: impl ServerHandlerCore, -) -> ServerRuntime { +) -> Arc { ServerRuntime::new( server_details, transport, @@ -66,7 +66,7 @@ impl McpServerHandler for RuntimeCoreInternalHandler> async fn handle_request( &self, client_jsonrpc_request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { // store the client details if the request is a client initialization request if let schema_utils::RequestFromClient::ClientRequest(ClientRequest::InitializeRequest( @@ -88,7 +88,7 @@ impl McpServerHandler for RuntimeCoreInternalHandler> async fn handle_error( &self, jsonrpc_error: &RpcError, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; Ok(()) @@ -96,11 +96,11 @@ impl McpServerHandler for RuntimeCoreInternalHandler> async fn handle_notification( &self, client_jsonrpc_notification: NotificationFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()> { // Trigger the `on_initialized()` callback if an `initialized_notification` is received from the client. if client_jsonrpc_notification.is_initialized_notification() { - self.handler.on_initialized(runtime).await; + self.handler.on_initialized(runtime.clone()).await; } // handle notification @@ -109,7 +109,4 @@ impl McpServerHandler for RuntimeCoreInternalHandler> .await?; Ok(()) } - async fn on_server_started(&self, runtime: &dyn McpServer) { - self.handler.on_server_started(runtime).await; - } } diff --git a/crates/rust-mcp-sdk/src/mcp_traits.rs b/crates/rust-mcp-sdk/src/mcp_traits.rs index 2b155fa..b66ba93 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits.rs @@ -1,3 +1,4 @@ +pub(super) mod id_generator; #[cfg(feature = "client")] pub mod mcp_client; pub mod mcp_handler; @@ -5,4 +6,5 @@ pub mod mcp_handler; pub mod mcp_server; mod request_id_gen; +pub use id_generator::*; pub use request_id_gen::*; diff --git a/crates/rust-mcp-sdk/src/mcp_traits/id_generator.rs b/crates/rust-mcp-sdk/src/mcp_traits/id_generator.rs new file mode 100644 index 0000000..e7cb8d3 --- /dev/null +++ b/crates/rust-mcp-sdk/src/mcp_traits/id_generator.rs @@ -0,0 +1,12 @@ +/// Trait for generating unique identifiers. +/// +/// This trait is generic over the target ID type, allowing it to be used for +/// generating different kinds of identifiers such as `SessionId` or +/// transport-scoped `StreamId`. +/// +pub trait IdGenerator: Send + Sync +where + T: From, +{ + fn generate(&self) -> T; +} diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs index 1883581..5fe3fba 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs @@ -1,9 +1,7 @@ -use std::{sync::Arc, time::Duration}; - use crate::schema::{ schema_utils::{ - self, ClientMessage, ClientMessages, FromMessage, McpMessage, MessageFromClient, - NotificationFromClient, RequestFromClient, ResultFromServer, ServerMessage, ServerMessages, + ClientMessage, McpMessage, MessageFromClient, NotificationFromClient, RequestFromClient, + ResultFromServer, ServerMessage, }, CallToolRequest, CallToolRequestParams, CallToolResult, CompleteRequest, CompleteRequestParams, CreateMessageRequest, GetPromptRequest, GetPromptRequestParams, Implementation, @@ -17,21 +15,18 @@ use crate::schema::{ }; use crate::{error::SdkResult, utils::format_assertion_message}; use async_trait::async_trait; -use rust_mcp_transport::{McpDispatch, MessageDispatcher}; +use std::{sync::Arc, time::Duration}; #[async_trait] pub trait McpClient: Sync + Send { async fn start(self: Arc) -> SdkResult<()>; fn set_server_details(&self, server_details: InitializeResult) -> SdkResult<()>; + async fn terminate_session(&self); + async fn shut_down(&self) -> SdkResult<()>; async fn is_shut_down(&self) -> bool; - fn sender(&self) -> Arc>>> - where - MessageDispatcher: - McpDispatch; - fn client_info(&self) -> &InitializeRequestParams; fn server_info(&self) -> Option; @@ -170,48 +165,20 @@ pub trait McpClient: Sync + Send { &self, message: MessageFromClient, request_id: Option, - timeout: Option, + request_timeout: Option, ) -> SdkResult>; async fn send_batch( &self, messages: Vec, timeout: Option, - ) -> SdkResult>> { - let sender = self.sender(); - let sender = sender.read().await; - let sender = sender - .as_ref() - .ok_or(schema_utils::SdkError::connection_closed())?; - - let response = sender - .send_message(ClientMessages::Batch(messages), timeout) - .await?; - - match response { - Some(res) => { - let server_results = res.as_batch()?; - Ok(Some(server_results)) - } - None => Ok(None), - } - } + ) -> SdkResult>>; /// Sends a notification. This is a one-way message that is not expected /// to return any response. The method asynchronously sends the notification using /// the transport layer and does not wait for any acknowledgement or result. async fn send_notification(&self, notification: NotificationFromClient) -> SdkResult<()> { - let sender = self.sender(); - let sender = sender.read().await; - let sender = sender - .as_ref() - .ok_or(schema_utils::SdkError::connection_closed())?; - - let mcp_message = ClientMessage::from_message(MessageFromClient::from(notification), None)?; - - sender - .send_message(ClientMessages::Single(mcp_message), None) - .await?; + self.send(notification.into(), None, None).await?; Ok(()) } diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs index 2974bfc..cb37f2a 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs @@ -6,9 +6,9 @@ use crate::schema::schema_utils::{NotificationFromClient, RequestFromClient, Res #[cfg(feature = "client")] use crate::schema::schema_utils::{NotificationFromServer, RequestFromServer, ResultFromClient}; -use crate::schema::RpcError; - use crate::error::SdkResult; +use crate::schema::RpcError; +use std::sync::Arc; #[cfg(feature = "client")] use super::mcp_client::McpClient; @@ -18,21 +18,20 @@ use super::mcp_server::McpServer; #[cfg(feature = "server")] #[async_trait] pub trait McpServerHandler: Send + Sync { - async fn on_server_started(&self, runtime: &dyn McpServer); async fn handle_request( &self, client_jsonrpc_request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result; async fn handle_error( &self, jsonrpc_error: &RpcError, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()>; async fn handle_notification( &self, client_jsonrpc_notification: NotificationFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()>; } diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs index 2eab9db..dc860b6 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs @@ -13,16 +13,15 @@ use crate::schema::{ ResourceUpdatedNotification, ResourceUpdatedNotificationParams, RpcError, ServerCapabilities, SetLevelRequest, ToolListChangedNotification, ToolListChangedNotificationParams, }; +use crate::{error::SdkResult, utils::format_assertion_message}; use async_trait::async_trait; use rust_mcp_transport::SessionId; -use std::time::Duration; - -use crate::{error::SdkResult, utils::format_assertion_message}; +use std::{sync::Arc, time::Duration}; //TODO: support options , such as enforceStrictCapabilities #[async_trait] pub trait McpServer: Sync + Send { - async fn start(&self) -> SdkResult<()>; + async fn start(self: Arc) -> SdkResult<()>; async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()>; fn server_info(&self) -> &InitializeResult; fn client_info(&self) -> Option; diff --git a/crates/rust-mcp-sdk/src/utils.rs b/crates/rust-mcp-sdk/src/utils.rs index e98a1ed..16fe7c7 100644 --- a/crates/rust-mcp-sdk/src/utils.rs +++ b/crates/rust-mcp-sdk/src/utils.rs @@ -1,6 +1,6 @@ use crate::schema::schema_utils::{ClientMessages, SdkError}; -use crate::error::{McpSdkError, SdkResult}; +use crate::error::{McpSdkError, ProtocolErrorKind, SdkResult}; use crate::schema::ProtocolVersion; use std::cmp::Ordering; @@ -71,20 +71,20 @@ pub fn format_assertion_message(entity: &str, capability: &str, method_name: &st /// let result = ensure_server_protocole_compatibility("2024_11_05", "2024_11_05"); /// assert!(result.is_ok()); /// -/// // Incompatible versions (client < server) +/// // Incompatible versions (requested < current) /// let result = ensure_server_protocole_compatibility("2024_11_05", "2025_03_26"); /// assert!(matches!( /// result, -/// Err(McpSdkError::IncompatibleProtocolVersion(client, server)) -/// if client == "2024_11_05" && server == "2025_03_26" +/// Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}}) +/// if requested == "2024_11_05" && current == "2025_03_26" /// )); /// -/// // Incompatible versions (client > server) +/// // Incompatible versions (requested > current) /// let result = ensure_server_protocole_compatibility("2025_03_26", "2024_11_05"); /// assert!(matches!( /// result, -/// Err(McpSdkError::IncompatibleProtocolVersion(client, server)) -/// if client == "2025_03_26" && server == "2024_11_05" +/// Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}}) +/// if requested == "2025_03_26" && current == "2024_11_05" /// )); /// ``` #[allow(unused)] @@ -93,10 +93,12 @@ pub fn ensure_server_protocole_compatibility( server_protocol_version: &str, ) -> SdkResult<()> { match client_protocol_version.cmp(server_protocol_version) { - Ordering::Less | Ordering::Greater => Err(McpSdkError::IncompatibleProtocolVersion( - client_protocol_version.to_string(), - server_protocol_version.to_string(), - )), + Ordering::Less | Ordering::Greater => Err(McpSdkError::Protocol { + kind: ProtocolErrorKind::IncompatibleVersion { + requested: client_protocol_version.to_string(), + current: server_protocol_version.to_string(), + }, + }), Ordering::Equal => Ok(()), } } @@ -140,8 +142,8 @@ pub fn ensure_server_protocole_compatibility( /// let result = enforce_compatible_protocol_version("2025_03_26", "2024_11_05"); /// assert!(matches!( /// result, -/// Err(McpSdkError::IncompatibleProtocolVersion(client, server)) -/// if client == "2025_03_26" && server == "2024_11_05" +/// Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}}) +/// if requested == "2025_03_26" && current == "2024_11_05" /// )); /// ``` #[allow(unused)] @@ -151,10 +153,12 @@ pub fn enforce_compatible_protocol_version( ) -> SdkResult> { match client_protocol_version.cmp(server_protocol_version) { // if client protocol version is higher - Ordering::Greater => Err(McpSdkError::IncompatibleProtocolVersion( - client_protocol_version.to_string(), - server_protocol_version.to_string(), - )), + Ordering::Greater => Err(McpSdkError::Protocol { + kind: ProtocolErrorKind::IncompatibleVersion { + requested: client_protocol_version.to_string(), + current: server_protocol_version.to_string(), + }, + }), Ordering::Equal => Ok(None), Ordering::Less => { // return the same version that was received from the client @@ -164,7 +168,10 @@ pub fn enforce_compatible_protocol_version( } pub fn validate_mcp_protocol_version(mcp_protocol_version: &str) -> SdkResult<()> { - let _mcp_protocol_version = ProtocolVersion::try_from(mcp_protocol_version)?; + let _mcp_protocol_version = + ProtocolVersion::try_from(mcp_protocol_version).map_err(|err| McpSdkError::Protocol { + kind: ProtocolErrorKind::ParseError(err), + })?; Ok(()) } diff --git a/crates/rust-mcp-sdk/tests/check_imports.rs b/crates/rust-mcp-sdk/tests/check_imports.rs index cda7d0c..207644e 100644 --- a/crates/rust-mcp-sdk/tests/check_imports.rs +++ b/crates/rust-mcp-sdk/tests/check_imports.rs @@ -37,13 +37,12 @@ mod tests { // Check for `use rust_mcp_schema` if content.contains("use rust_mcp_schema") { errors.push(format!( - "File {} contains `use rust_mcp_schema`. Use `use crate::schema` instead.", - abs_path + "File {abs_path} contains `use rust_mcp_schema`. Use `use crate::schema` instead." )); } } Err(e) => { - errors.push(format!("Failed to read file `{}`: {}", path_str, e)); + errors.push(format!("Failed to read file `{path_str}`: {e}")); } } } diff --git a/crates/rust-mcp-sdk/tests/common/common.rs b/crates/rust-mcp-sdk/tests/common/common.rs index 564db0d..f330dda 100644 --- a/crates/rust-mcp-sdk/tests/common/common.rs +++ b/crates/rust-mcp-sdk/tests/common/common.rs @@ -1,5 +1,8 @@ +mod mock_server; +mod test_client; mod test_server; use async_trait::async_trait; +pub use mock_server::*; use reqwest::{Client, Response, Url}; use rust_mcp_macros::{mcp_tool, JsonSchema}; use rust_mcp_schema::ProtocolVersion; @@ -8,9 +11,12 @@ use rust_mcp_sdk::mcp_client::ClientHandler; use rust_mcp_sdk::schema::{ClientCapabilities, Implementation, InitializeRequestParams}; use std::collections::HashMap; use std::process; -use std::time::{SystemTime, UNIX_EPOCH}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use tokio::time::timeout; use tokio_stream::StreamExt; +use wiremock::{MockServer, Request, ResponseTemplate}; +pub use test_client::*; pub use test_server::*; pub const NPX_SERVER_EVERYTHING: &str = "@modelcontextprotocol/server-everything"; @@ -337,3 +343,52 @@ pub mod sample_tools { } } } + +pub async fn wiremock_request(mock_server: &MockServer, index: usize) -> Request { + let requests = mock_server.received_requests().await.unwrap(); + requests[index].clone() +} + +pub async fn debug_wiremock(mock_server: &MockServer) { + let requests = mock_server.received_requests().await.unwrap(); + let len = requests.len(); + println!(">>> {len} request(s) received <<<"); + + for (index, request) in requests.iter().enumerate() { + println!("\n--- #{index} of {len} ---"); + println!("Method: {}", request.method); + println!("Path: {}", request.url.path()); + // println!("Headers: {:#?}", request.headers); + println!("---- headers ----"); + for (key, values) in &request.headers { + println!("{key}: {values:?}"); + } + + let body_str = String::from_utf8_lossy(&request.body); + println!("Body: {body_str}\n"); + } +} + +pub fn create_sse_response(payload: &str) -> ResponseTemplate { + let sse_body = format!(r#"data: {}{}"#, payload, "\n\n"); + ResponseTemplate::new(200).set_body_raw(sse_body.into_bytes(), "text/event-stream") +} + +pub async fn wait_for_n_requests( + mock_server: &MockServer, + num_requests: usize, + duration: Option, +) { + let duration = duration.unwrap_or(Duration::from_secs(1)); + timeout(duration, async { + loop { + let requests = mock_server.received_requests().await.unwrap(); + if requests.len() >= num_requests { + break; + } + tokio::time::sleep(Duration::from_millis(100)).await; + } + }) + .await + .unwrap(); +} diff --git a/crates/rust-mcp-sdk/tests/common/mock_server.rs b/crates/rust-mcp-sdk/tests/common/mock_server.rs new file mode 100644 index 0000000..f5b533a --- /dev/null +++ b/crates/rust-mcp-sdk/tests/common/mock_server.rs @@ -0,0 +1,528 @@ +use axum::{ + body::Body, + extract::Request, + http::{header::CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue, Method, StatusCode}, + response::{ + sse::{Event, KeepAlive}, + IntoResponse, Response, Sse, + }, + routing::any, + Router, +}; +use core::fmt; +use futures::stream; +use std::collections::VecDeque; +use std::{future::Future, net::SocketAddr, pin::Pin}; +use std::{ + sync::{Arc, Mutex}, + time::Duration, +}; +use tokio::net::TcpListener; + +pub struct SseEvent { + /// The optional event type (e.g., "message"). + pub event: Option, + /// The optional data payload of the event, stored as bytes. + pub data: Option, + /// The optional event ID for reconnection or tracking purposes. + pub id: Option, +} + +impl ToString for SseEvent { + fn to_string(&self) -> String { + let mut s = String::new(); + + if let Some(id) = &self.id { + s.push_str("id: "); + s.push_str(id); + s.push('\n'); + } + + if let Some(event) = &self.event { + s.push_str("event: "); + s.push_str(event); + s.push('\n'); + } + + if let Some(data) = &self.data { + // Convert bytes to string safely, fallback if invalid UTF-8 + for line in data.lines() { + s.push_str("data: "); + s.push_str(line); + s.push('\n'); + } + } + + s.push('\n'); // End of event + s + } +} + +impl fmt::Debug for SseEvent { + /// Formats the `SseEvent` for debugging, converting the `data` field to a UTF-8 string + /// (with lossy conversion if invalid UTF-8 is encountered). + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let data_str = self.data.as_ref(); + + f.debug_struct("SseEvent") + .field("event", &self.event) + .field("data", &data_str) + .field("id", &self.id) + .finish() + } +} + +// RequestRecord stores the history of incoming requests +#[derive(Clone, Debug)] +pub struct RequestRecord { + pub method: Method, + pub path: String, + pub headers: HeaderMap, + pub body: String, +} + +#[derive(Clone, Debug)] +pub struct ResponseRecord { + pub status: StatusCode, + pub headers: HeaderMap, + pub body: String, +} + +// pub type BoxedStream = +// Pin> + Send>>; +// pub type BoxedSseResponse = Sse; + +// pub type AsyncResponseFn = +// Box Pin + Send>> + Send + Sync>; + +type AsyncResponseFn = + Box Pin + Send>> + Send + Sync>; + +// Mock defines a single mock response configuration +// #[derive(Clone)] +pub struct Mock { + method: Method, + path: String, + response: String, + response_func: Option, + header_map: HeaderMap, + matcher: Option bool + Send + Sync>>, + remaining_calls: Option>>, + status: StatusCode, +} + +// MockBuilder is a factory for creating Mock instances +pub struct MockBuilder { + method: Method, + path: String, + response: String, + header_map: HeaderMap, + response_func: Option, + matcher: Option bool + Send + Sync>>, + remaining_calls: Option>>, + status: StatusCode, +} + +impl MockBuilder { + fn new(method: Method, path: String, response: String, header_map: HeaderMap) -> Self { + Self { + method, + path, + response, + response_func: None, + header_map, + matcher: None, + status: StatusCode::OK, + remaining_calls: None, // Default to unlimited calls + } + } + + fn new_with_func( + method: Method, + path: String, + response_func: AsyncResponseFn, + header_map: HeaderMap, + ) -> Self { + Self { + method, + path, + response: String::new(), + response_func: Some(response_func), + header_map, + matcher: None, + status: StatusCode::OK, + remaining_calls: None, // Default to unlimited calls + } + } + + pub fn new_breakable_sse( + method: Method, + path: String, + repeating_message: SseEvent, + interval: Duration, + repeat: usize, + ) -> Self { + let message = Arc::new(repeating_message); + let interval = interval; + let max_repeats = repeat; + + let response_fn: AsyncResponseFn = Box::new({ + let message = Arc::clone(&message); + move || { + let message = Arc::clone(&message); + + Box::pin(async move { + // Construct SSE stream with 10 static messages using unfold + let message_stream = stream::unfold(0, move |count| { + let message = Arc::clone(&message); + + async move { + if count >= max_repeats { + return Some(( + Err(std::io::Error::other("Message limit reached")), + count, + )); + } + tokio::time::sleep(interval).await; + + Some(( + Ok(Event::default() + .data(message.data.clone().unwrap_or("".into())) + .id(message.id.clone().unwrap_or(format!("msg-id_{count}"))) + .event(message.event.clone().unwrap_or("message".into()))), + count + 1, + )) + } + }); + + let sse_stream = Sse::new(message_stream) + .keep_alive(KeepAlive::new().interval(Duration::from_secs(10))); + + sse_stream.into_response() + }) + } + }); + + let mut header_map = HeaderMap::new(); + header_map.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); + Self::new_with_func(method, path, response_fn, header_map) + } + + pub fn with_matcher(mut self, matcher: F) -> Self + where + F: Fn(&str, &HeaderMap) -> bool + Send + Sync + 'static, + { + self.matcher = Some(Arc::new(matcher)); + self + } + + pub fn add_header(mut self, key: HeaderName, val: HeaderValue) -> Self { + self.header_map.insert(key, val); + self + } + + pub fn without_matcher(mut self) -> Self { + self.matcher = None; + self + } + + pub fn expect(mut self, num_calls: usize) -> Self { + self.remaining_calls = Some(Arc::new(Mutex::new(num_calls))); + self + } + + pub fn unlimited_calls(mut self) -> Self { + self.remaining_calls = None; + self + } + + pub fn with_status(mut self, status: StatusCode) -> Self { + self.status = status; + self + } + + pub fn build(self) -> Mock { + Mock { + method: self.method, + path: self.path, + response: self.response, + header_map: self.header_map, + matcher: self.matcher, + remaining_calls: self.remaining_calls, + status: self.status, + response_func: self.response_func, + } + } + + // add_string with text/plain + pub fn new_text(method: Method, path: String, response: impl Into) -> Self { + let mut header_map = HeaderMap::new(); + header_map.insert(CONTENT_TYPE, HeaderValue::from_static("text/plain")); + + Self::new(method, path, response.into(), header_map) + } + + /** + MockBuilder::new_response( + Method::GET, + "/mcp".to_string(), + Box::new(|| { + // tokio::time::sleep(Duration::from_secs(1)).await; + let json_response = Json(json!({ + "status": "ok", + "data": [1, 2, 3] + })) + .into_response(); + Box::pin(async move { json_response }) + }), + ) + .build(), + */ + pub fn new_response(method: Method, path: String, response_func: AsyncResponseFn) -> Self { + Self::new_with_func(method, path, response_func, HeaderMap::new()) + } + + // new_json with application/json + pub fn new_json(method: Method, path: String, response: impl Into) -> Self { + let mut header_map = HeaderMap::new(); + header_map.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + Self::new(method, path, response.into(), header_map) + } + + // new_sse with text/event-stream + pub fn new_sse(method: Method, path: String, response: impl Into) -> Self { + let response = format!(r#"data: {}{}"#, response.into(), '\n'); + + let mut header_map = HeaderMap::new(); + header_map.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); + // ensure message ends with a \n\n , if needed + let cr = if response.ends_with("\n\n") { + "" + } else { + "\n\n" + }; + Self::new(method, path, format!("{response}{cr}"), header_map) + } + + // new_raw with application/octet-stream + pub fn new_raw(method: Method, path: String, response: impl Into) -> Self { + let mut header_map = HeaderMap::new(); + header_map.insert( + CONTENT_TYPE, + HeaderValue::from_static("application/octet-stream"), + ); + Self::new(method, path, response.into(), header_map) + } +} + +// MockServerHandle provides access to the request history after the server starts +pub struct MockServerHandle { + history: Arc>>, +} + +impl MockServerHandle { + pub async fn get_history(&self) -> Vec<(RequestRecord, ResponseRecord)> { + let history = self.history.lock().unwrap(); + history.iter().cloned().collect() + } + + pub async fn print(&self) { + let requests = self.get_history().await; + + let len = requests.len(); + println!("\n>>> {len} request(s) received <<<"); + + for (index, (request, response)) in requests.iter().enumerate() { + println!( + "\n--- Request {} of {len} ------------------------------------", + index + 1 + ); + println!("Method: {}", request.method); + println!("Path: {}", request.path); + // println!("Headers: {:#?}", request.headers); + println!("> headers "); + for (key, values) in &request.headers { + println!("{key}: {values:?}"); + } + + println!("\n> Body"); + println!("{}\n", &request.body); + + println!(">>>>> Response <<<<<"); + println!("> status: {}", response.status); + println!("> headers"); + for (key, values) in &response.headers { + println!("{key}: {values:?}"); + } + println!("> Body"); + println!("{}", &response.body); + } + } +} + +// MockServer is the main struct for configuring and starting the mock server +pub struct SimpleMockServer { + mocks: Vec, + history: Arc>>, +} + +impl Default for SimpleMockServer { + fn default() -> Self { + Self::new() + } +} + +impl SimpleMockServer { + pub fn new() -> Self { + Self { + mocks: Vec::new(), + history: Arc::new(Mutex::new(VecDeque::new())), + } + } + + pub async fn start_with_mocks(mocks: Vec) -> (String, MockServerHandle) { + let mut server = SimpleMockServer::new(); + server.add_mocks(mocks); + server.start().await + } + + // Generic add function + pub fn add_mock_builder(&mut self, builder: MockBuilder) -> &mut Self { + self.mocks.push(builder.build()); + self + } + + pub fn add_mock(&mut self, mock: Mock) -> &mut Self { + self.mocks.push(mock); + self + } + + pub fn add_mocks(&mut self, mock: Vec) -> &mut Self { + mock.into_iter().for_each(|m| self.mocks.push(m)); + self + } + + pub async fn start(self) -> (String, MockServerHandle) { + let mocks = Arc::new(self.mocks); + let history = Arc::clone(&self.history); + + async fn handler( + mocks: Arc>, + history: Arc>>, + mut req: Request, + ) -> impl IntoResponse { + // Take ownership of the body using std::mem::take + let body = std::mem::take(req.body_mut()); + let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap(); + let body_str = String::from_utf8_lossy(&body_bytes).to_string(); + + let request_record = RequestRecord { + method: req.method().clone(), + path: req.uri().path().to_string(), + headers: req.headers().clone(), + body: body_str.clone(), + }; + + for m in mocks.iter() { + if m.method != *req.method() || m.path != req.uri().path() { + continue; + } + + if let Some(matcher) = &m.matcher { + if !(matcher)(&body_str, req.headers()) { + continue; + } + } + + if let Some(remaining) = &m.remaining_calls { + let mut rem = remaining.lock().unwrap(); + if *rem == 0 { + continue; + } + *rem -= 1; + } + + let mut resp = match m.response_func.as_ref() { + Some(get_response) => get_response().await.into_response(), + None => Response::new(Body::from(m.response.clone())), + }; + + // if let Some(resp_box) = &mut m.response_func.take() { + // let response = resp_box.into_response(); + // // *response.status_mut() = m.status; + // // m.response_func = Some(Box::new(response)); + // } + + // let mut resp = m.response_func.as_ref().unwrap().clone().to_owned(); + // let resp = *resp; + // *resp.into_response().status_mut() = m.status; + + // let mut response = m.response_func.as_ref().unwrap().clone(); + // let mut response = m.response_func.as_ref().unwrap().clone().to_owned(); + // let mut m = *response; + // *response.status_mut() = m.status; + // let resp = &*m.response_func.as_ref().unwrap().to_owned().clone().deref(); + + // let response = boxed_response.into_response(); + + // let mut resp = Response::new(Body::from(m.response.clone())); + *resp.status_mut() = m.status; + m.header_map.iter().for_each(|(k, v)| { + resp.headers_mut().insert(k, v.clone()); + }); + + let response_record = ResponseRecord { + status: resp.status(), + headers: resp.headers().clone(), + body: m.response.clone(), + }; + + { + let mut hist = history.lock().unwrap(); + hist.push_back((request_record, response_record)); + } + + return resp; + } + + let resp = Response::builder() + .status(StatusCode::NOT_FOUND) + .body(Body::empty()) + .unwrap(); + + let response_record = ResponseRecord { + status: resp.status(), + headers: resp.headers().clone(), + body: "".into(), + }; + + { + let mut hist = history.lock().unwrap(); + hist.push_back((request_record, response_record)); + } + + resp + } + + let app = Router::new().route( + "/{*path}", + any(move |req: Request| handler(Arc::clone(&mocks), Arc::clone(&history), req)), + ); + + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await.unwrap(); + let local_addr = listener.local_addr().unwrap(); + let url = format!("http://{local_addr}"); + + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + ( + url, + MockServerHandle { + history: self.history, + }, + ) + } +} diff --git a/crates/rust-mcp-sdk/tests/common/test_client.rs b/crates/rust-mcp-sdk/tests/common/test_client.rs new file mode 100644 index 0000000..21678c7 --- /dev/null +++ b/crates/rust-mcp-sdk/tests/common/test_client.rs @@ -0,0 +1,163 @@ +use async_trait::async_trait; +use rust_mcp_schema::{schema_utils::MessageFromServer, PingRequest, RpcError}; +use rust_mcp_sdk::{mcp_client::ClientHandler, McpClient}; +use serde_json::json; +use std::sync::Arc; +use tokio::sync::RwLock; + +#[cfg(feature = "hyper-server")] +pub mod test_client_common { + use rust_mcp_schema::{ + schema_utils::MessageFromServer, ClientCapabilities, Implementation, + InitializeRequestParams, LATEST_PROTOCOL_VERSION, + }; + use rust_mcp_sdk::{ + mcp_client::{client_runtime, ClientRuntime}, + McpClient, RequestOptions, SessionId, StreamableTransportOptions, + }; + use std::{collections::HashMap, sync::Arc, time::Duration}; + use tokio::sync::RwLock; + use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + use wiremock::{ + matchers::{body_json_string, method, path}, + Mock, MockServer, ResponseTemplate, + }; + + use crate::common::{ + create_sse_response, test_server_common::INITIALIZE_RESPONSE, wait_for_n_requests, + }; + + pub struct InitializedClient { + pub client: Arc, + pub mcp_url: String, + pub mock_server: MockServer, + } + + pub const TEST_SESSION_ID: &str = "test-session-id"; + pub const INITIALIZE_REQUEST: &str = r#"{"id":0,"jsonrpc":"2.0","method":"initialize","params":{"capabilities":{},"clientInfo":{"name":"simple-rust-mcp-client-sse","title":"Simple Rust MCP Client (SSE)","version":"0.1.0"},"protocolVersion":"2025-06-18"}}"#; + + pub fn test_client_details() -> InitializeRequestParams { + InitializeRequestParams { + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "simple-rust-mcp-client-sse".to_string(), + version: "0.1.0".to_string(), + title: Some("Simple Rust MCP Client (SSE)".to_string()), + }, + protocol_version: LATEST_PROTOCOL_VERSION.into(), + } + } + + pub async fn create_client( + mcp_url: &str, + custom_headers: Option>, + ) -> (Arc, Arc>>) { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "info".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let client_details: InitializeRequestParams = test_client_details(); + + let transport_options = StreamableTransportOptions { + mcp_url: mcp_url.to_string(), + request_options: RequestOptions { + request_timeout: Duration::from_secs(2), + custom_headers, + ..RequestOptions::default() + }, + }; + + let message_history = Arc::new(RwLock::new(vec![])); + let handler = super::TestClientHandler { + message_history: message_history.clone(), + }; + + let client = + client_runtime::with_transport_options(client_details, transport_options, handler); + + // client.clone().start().await.unwrap(); + (client, message_history) + } + + pub async fn initialize_client( + session_id: Option, + custom_headers: Option>, + ) -> InitializedClient { + let mock_server = MockServer::start().await; + + // intialize response + let mut response = create_sse_response(INITIALIZE_RESPONSE); + + if let Some(session_id) = session_id { + response = response.append_header("mcp-session-id", session_id.as_str()); + } + + // initialize request and response + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string(INITIALIZE_REQUEST)) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // receive initialized notification + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string( + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + )) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + let mcp_url = format!("{}/mcp", mock_server.uri()); + let (client, _) = create_client(&mcp_url, custom_headers).await; + + client.clone().start().await.unwrap(); + + wait_for_n_requests(&mock_server, 2, None).await; + + InitializedClient { + client, + mcp_url, + mock_server, + } + } +} + +// Custom responder for SSE with 10 ping messages +struct SsePingResponder; + +// Test handler +pub struct TestClientHandler { + message_history: Arc>>, +} + +impl TestClientHandler { + async fn register_message(&self, message: &MessageFromServer) { + let mut lock = self.message_history.write().await; + lock.push(message.clone()); + } +} + +#[async_trait] +impl ClientHandler for TestClientHandler { + async fn handle_ping_request( + &self, + request: PingRequest, + runtime: &dyn McpClient, + ) -> std::result::Result { + self.register_message(&request.into()).await; + + Ok(rust_mcp_schema::Result { + meta: Some(json!({"meta_number":1515}).as_object().unwrap().to_owned()), + extra: None, + }) + } +} diff --git a/crates/rust-mcp-sdk/tests/common/test_server.rs b/crates/rust-mcp-sdk/tests/common/test_server.rs index aa8e2fb..769f8c6 100644 --- a/crates/rust-mcp-sdk/tests/common/test_server.rs +++ b/crates/rust-mcp-sdk/tests/common/test_server.rs @@ -1,30 +1,30 @@ #[cfg(feature = "hyper-server")] pub mod test_server_common { + use crate::common::sample_tools::SayHelloTool; use async_trait::async_trait; use rust_mcp_schema::schema_utils::CallToolError; use rust_mcp_schema::{ CallToolRequest, CallToolResult, ListToolsRequest, ListToolsResult, ProtocolVersion, RpcError, }; + use rust_mcp_sdk::id_generator::IdGenerator; use rust_mcp_sdk::mcp_server::hyper_runtime::HyperRuntime; - use tokio_stream::StreamExt; - use rust_mcp_sdk::schema::{ ClientCapabilities, Implementation, InitializeRequest, InitializeRequestParams, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, }; use rust_mcp_sdk::{ - mcp_server::{hyper_server, HyperServer, HyperServerOptions, IdGenerator, ServerHandler}, + mcp_server::{hyper_server, HyperServer, HyperServerOptions, ServerHandler}, McpServer, SessionId, }; - use std::sync::RwLock; + use std::sync::{Arc, RwLock}; use std::time::Duration; use tokio::time::timeout; - - use crate::common::sample_tools::SayHelloTool; + use tokio_stream::StreamExt; pub const INITIALIZE_REQUEST: &str = r#"{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2025-06-18","capabilities":{"sampling":{},"roots":{"listChanged":true}},"clientInfo":{"name":"reqwest-test","version":"0.1.0"}}}"#; pub const PING_REQUEST: &str = r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#; + pub const INITIALIZE_RESPONSE: &str = r#"{"result":{"protocolVersion":"2025-06-18","capabilities":{"prompts":{},"resources":{"subscribe":true},"tools":{},"logging":{}},"serverInfo":{"name":"example-servers/everything","version":"1.0.0"}},"jsonrpc":"2.0","id":0}"#; pub struct LaunchedServer { pub hyper_runtime: HyperRuntime, @@ -71,16 +71,10 @@ pub mod test_server_common { #[async_trait] impl ServerHandler for TestServerHandler { - async fn on_server_started(&self, runtime: &dyn McpServer) { - let _ = runtime - .stderr_message("Server started successfully".into()) - .await; - } - async fn handle_list_tools_request( &self, request: ListToolsRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; @@ -94,7 +88,7 @@ pub mod test_server_common { async fn handle_call_tool_request( &self, request: CallToolRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime .assert_server_request_capabilities(request.method()) @@ -156,14 +150,17 @@ pub mod test_server_common { } } - impl IdGenerator for TestIdGenerator { - fn generate(&self) -> SessionId { + impl IdGenerator for TestIdGenerator + where + T: From, + { + fn generate(&self) -> T { let mut lock = self.generated.write().unwrap(); *lock += 1; if *lock > self.constant_ids.len() { *lock = 1; } - self.constant_ids[*lock - 1].to_owned() + T::from(self.constant_ids[*lock - 1].to_owned()) } } diff --git a/crates/rust-mcp-sdk/tests/test_protocol_compatibility.rs b/crates/rust-mcp-sdk/tests/test_protocol_compatibility.rs index 5c184cf..9f2fd95 100644 --- a/crates/rust-mcp-sdk/tests/test_protocol_compatibility.rs +++ b/crates/rust-mcp-sdk/tests/test_protocol_compatibility.rs @@ -30,7 +30,7 @@ mod protocol_compatibility_on_server { ); handler - .handle_initialize_request(InitializeRequest::new(initialize_request), &runtime) + .handle_initialize_request(InitializeRequest::new(initialize_request), runtime) .await } diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs b/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs new file mode 100644 index 0000000..a0a2804 --- /dev/null +++ b/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs @@ -0,0 +1,823 @@ +#[path = "common/common.rs"] +pub mod common; + +use common::test_client_common::create_client; +use hyper::{Method, StatusCode}; +use rust_mcp_schema::{ + schema_utils::{ + ClientJsonrpcRequest, ClientMessage, MessageFromServer, RequestFromClient, + RequestFromServer, ResultFromServer, RpcMessage, ServerMessage, + }, + RequestId, ServerRequest, ServerResult, +}; +use rust_mcp_sdk::{ + error::McpSdkError, mcp_server::HyperServerOptions, McpClient, TransportError, + MCP_LAST_EVENT_ID_HEADER, +}; +use serde_json::{json, Value}; +use std::{collections::HashMap, str::FromStr, sync::Arc, time::Duration}; +use wiremock::{ + http::{HeaderName, HeaderValue}, + matchers::{body_json_string, header, method, path}, + Mock, MockServer, ResponseTemplate, +}; + +use crate::common::{ + create_sse_response, debug_wiremock, random_port, + test_client_common::{ + initialize_client, InitializedClient, INITIALIZE_REQUEST, TEST_SESSION_ID, + }, + test_server_common::{ + create_start_server, LaunchedServer, TestIdGenerator, INITIALIZE_RESPONSE, + }, + wait_for_n_requests, wiremock_request, MockBuilder, SimpleMockServer, SseEvent, +}; + +// should send JSON-RPC messages via POST +#[tokio::test] +async fn should_send_json_rpc_messages_via_post() { + // Start a mock server + let mock_server = MockServer::start().await; + + // intialize response + let response = create_sse_response(INITIALIZE_RESPONSE); + + // initialize request and response + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string(INITIALIZE_REQUEST)) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // receive initialized notification + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string( + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + )) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + let mcp_url = format!("{}/mcp", mock_server.uri()); + let (client, _) = create_client(&mcp_url, None).await; + + client.clone().start().await.unwrap(); + + let received_request = wiremock_request(&mock_server, 0).await; + let header_values = received_request + .headers + .get(&HeaderName::from_str("accept").unwrap()) + .unwrap(); + + assert!(header_values.contains(&HeaderValue::from_str("application/json").unwrap())); + assert!(header_values.contains(&HeaderValue::from_str("text/event-stream").unwrap())); + + wait_for_n_requests(&mock_server, 2, None).await; +} + +// should send batch messages +#[tokio::test] +async fn should_send_batch_messages() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(None, None).await; + + let response = create_sse_response( + r#"[{"id":"id1","jsonrpc":"2.0", "result":{}},{"id":"id2","jsonrpc":"2.0", "result":{}}]"#, + ); + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(response) + // .expect(1) + .mount(&mock_server) + .await; + + let message_1: ClientMessage = ClientJsonrpcRequest::new( + RequestId::String("id1".to_string()), + RequestFromClient::CustomRequest(json!({"method": "test1", "params": {}})), + ) + .into(); + let message_2: ClientMessage = ClientJsonrpcRequest::new( + RequestId::String("id2".to_string()), + RequestFromClient::CustomRequest(json!({"method": "test2", "params": {}})), + ) + .into(); + + let result = client + .send_batch(vec![message_1, message_2], None) + .await + .unwrap() + .unwrap(); + + // two results for two requests + assert_eq!(result.len(), 2); + assert!(result.iter().all(|r| { + let id = r.request_id().unwrap(); + id == RequestId::String("id1".to_string()) || id == RequestId::String("id2".to_string()) + })); + + // not an Error + assert!(result + .iter() + .all(|r| matches!(r, ServerMessage::Response(_)))); + + // debug_wiremock(&mock_server).await; +} + +// should store session ID received during initialization +#[tokio::test] +async fn should_store_session_id_received_during_initialization() { + // Start a mock server + let mock_server = MockServer::start().await; + + // intialize response + let response = + create_sse_response(INITIALIZE_RESPONSE).append_header("mcp-session-id", "test-session-id"); + + // initialize request and response + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string(INITIALIZE_REQUEST)) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // receive initialized notification + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string( + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + )) + .and(header("mcp-session-id", "test-session-id")) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + let mcp_url = format!("{}/mcp", mock_server.uri()); + let (client, _) = create_client(&mcp_url, None).await; + + client.clone().start().await.unwrap(); + + let received_request = wiremock_request(&mock_server, 0).await; + let header_values = received_request + .headers + .get(&HeaderName::from_str("accept").unwrap()) + .unwrap(); + + assert!(header_values.contains(&HeaderValue::from_str("application/json").unwrap())); + assert!(header_values.contains(&HeaderValue::from_str("text/event-stream").unwrap())); + + wait_for_n_requests(&mock_server, 2, None).await; +} + +// should terminate session with DELETE request +#[tokio::test] +async fn should_terminate_session_with_delete_request() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(Some(TEST_SESSION_ID.to_string()), None).await; + + Mock::given(method("DELETE")) + .and(path("/mcp")) + .and(header("mcp-session-id", "test-session-id")) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + client.terminate_session().await; +} + +// should handle 405 response when server doesn't support session termination +#[tokio::test] +async fn should_handle_405_unsupported_session_termination() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(Some(TEST_SESSION_ID.to_string()), None).await; + + Mock::given(method("DELETE")) + .and(path("/mcp")) + .and(header("mcp-session-id", "test-session-id")) + .respond_with(ResponseTemplate::new(405)) + .expect(1) + .mount(&mock_server) + .await; + + client.terminate_session().await; +} + +// should handle 404 response when session expires +#[tokio::test] +async fn should_handle_404_response_when_session_expires() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(Some(TEST_SESSION_ID.to_string()), None).await; + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(ResponseTemplate::new(404)) + .expect(1) + .mount(&mock_server) + .await; + + let result = client.ping(None).await; + + matches!( + result, + Err(McpSdkError::Transport(TransportError::SessionExpired)) + ); +} + +// should handle non-streaming JSON response +#[tokio::test] +async fn should_handle_non_streaming_json_response() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(Some(TEST_SESSION_ID.to_string()), None).await; + + let response = ResponseTemplate::new(200) + .set_body_json(json!({ + "id":1,"jsonrpc":"2.0", "result":{"something":"good"} + })) + .insert_header("Content-Type", "application/json"); + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + let request = RequestFromClient::CustomRequest(json!({"method": "test1", "params": {}})); + + let result = client.request(request, None).await.unwrap(); + + let ResultFromServer::ServerResult(ServerResult::Result(result)) = result else { + panic!("Wrong result variant!") + }; + + let extra = result.extra.unwrap(); + assert_eq!(extra.get("something").unwrap(), "good"); +} + +// should handle successful initial GET connection for SSE +#[tokio::test] +async fn should_handle_successful_initial_get_connection_for_sse() { + // Start a mock server + let mock_server = MockServer::start().await; + + // intialize response + let response = create_sse_response(INITIALIZE_RESPONSE); + + // initialize request and response + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string(INITIALIZE_REQUEST)) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // receive initialized notification + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string( + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + )) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + // let payload = r#"{"jsonrpc": "2.0", "method": "serverNotification", "params": {}}"#; + // + let mut body = String::new(); + body.push_str(&"data: Connection established\n\n".to_string()); + + let response = ResponseTemplate::new(200) + .set_body_raw(body.into_bytes(), "text/event-stream") + .append_header("Connection", "keep-alive"); + + // Mount the mock for a GET request + Mock::given(method("GET")) + .and(path("/mcp")) + .respond_with(response) + .mount(&mock_server) + .await; + + let mcp_url = format!("{}/mcp", mock_server.uri()); + let (client, _) = create_client(&mcp_url, None).await; + + client.clone().start().await.unwrap(); + + let requests = mock_server.received_requests().await.unwrap(); + let get_request = requests + .iter() + .find(|r| r.method == wiremock::http::Method::Get); + + assert!(get_request.is_some()) +} + +#[tokio::test] +async fn should_receive_server_initiated_messaged() { + let server_options = HyperServerOptions { + port: random_port(), + session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![ + "AAA-BBB-CCC".to_string() + ]))), + enable_json_response: Some(false), + ..Default::default() + }; + let LaunchedServer { + hyper_runtime, + streamable_url, + sse_url, + sse_message_url, + } = create_start_server(server_options).await; + + let (client, message_history) = create_client(&streamable_url, None).await; + + client.clone().start().await.unwrap(); + + tokio::time::sleep(Duration::from_secs(1)).await; + + let result = hyper_runtime + .ping(&"AAA-BBB-CCC".to_string(), None) + .await + .unwrap(); + + let lock = message_history.read().await; + let ping_request = lock + .iter() + .find(|m| { + matches!( + m, + MessageFromServer::RequestFromServer(RequestFromServer::ServerRequest( + ServerRequest::PingRequest(_) + )) + ) + }) + .unwrap(); + let MessageFromServer::RequestFromServer(RequestFromServer::ServerRequest( + ServerRequest::PingRequest(_), + )) = ping_request + else { + panic!("Request is not a match!") + }; + assert!(result.meta.is_some()); + + let v = result.meta.unwrap().get("meta_number").unwrap().clone(); + + assert!(matches!(v, Value::Number(value) if value.as_i64().unwrap()==1515)) //1515 is passed from TestClientHandler +} + +// should attempt initial GET connection and handle 405 gracefully +#[tokio::test] +async fn should_attempt_initial_get_connection_and_handle_405_gracefully() { + // Start a mock server + let mock_server = MockServer::start().await; + + // intialize response + let response = create_sse_response(INITIALIZE_RESPONSE); + + // initialize request and response + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string(INITIALIZE_REQUEST)) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // Mount the mock for a GET request + Mock::given(method("GET")) + .and(path("/mcp")) + .respond_with(ResponseTemplate::new(405)) + .mount(&mock_server) + .await; + + // receive initialized notification + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string( + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + )) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + // let payload = r#"{"jsonrpc": "2.0", "method": "serverNotification", "params": {}}"#; + // + let mut body = String::new(); + body.push_str(&"data: Connection established\n\n".to_string()); + + let response = ResponseTemplate::new(405) + .set_body_raw(body.into_bytes(), "text/event-stream") + .append_header("Connection", "keep-alive"); + + let mcp_url = format!("{}/mcp", mock_server.uri()); + let (client, _) = create_client(&mcp_url, None).await; + + client.clone().start().await.unwrap(); + + let requests = mock_server.received_requests().await.unwrap(); + let get_request = requests + .iter() + .find(|r| r.method == wiremock::http::Method::Get); + + assert!(get_request.is_some()); + + // send a batch message, runtime should work as expected with no isse + + let response = create_sse_response( + r#"[{"id":"id1","jsonrpc":"2.0", "result":{}},{"id":"id2","jsonrpc":"2.0", "result":{}}]"#, + ); + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(response) + // .expect(1) + .mount(&mock_server) + .await; + + let message_1: ClientMessage = ClientJsonrpcRequest::new( + RequestId::String("id1".to_string()), + RequestFromClient::CustomRequest(json!({"method": "test1", "params": {}})), + ) + .into(); + let message_2: ClientMessage = ClientJsonrpcRequest::new( + RequestId::String("id2".to_string()), + RequestFromClient::CustomRequest(json!({"method": "test2", "params": {}})), + ) + .into(); + + let result = client + .send_batch(vec![message_1, message_2], None) + .await + .unwrap() + .unwrap(); + + // two results for two requests + assert_eq!(result.len(), 2); + assert!(result.iter().all(|r| { + let id = r.request_id().unwrap(); + id == RequestId::String("id1".to_string()) || id == RequestId::String("id2".to_string()) + })); +} + +// should handle multiple concurrent SSE streams +#[tokio::test] +async fn should_handle_multiple_concurrent_sse_streams() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(None, None).await; + + let message_1: ClientMessage = ClientJsonrpcRequest::new( + RequestId::String("id1".to_string()), + RequestFromClient::CustomRequest(json!({"method": "test1", "params": {}})), + ) + .into(); + let message_2: ClientMessage = ClientJsonrpcRequest::new( + RequestId::String("id2".to_string()), + RequestFromClient::CustomRequest(json!({"method": "test2", "params": {}})), + ) + .into(); + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(|req: &wiremock::Request| { + let body_string = String::from_utf8(req.body.clone()).unwrap(); + if body_string.contains("test3") { + create_sse_response(r#"{"id":1,"jsonrpc":"2.0", "result":{}}"#) + } else { + create_sse_response( + r#"[{"id":"id1","jsonrpc":"2.0", "result":{}},{"id":"id2","jsonrpc":"2.0", "result":{}}]"#, + ) + } + }) + .expect(2) + .mount(&mock_server) + .await; + + let message_3 = RequestFromClient::CustomRequest(json!({"method": "test3", "params": {}})); + let request1 = client.send_batch(vec![message_1, message_2], None); + let request2 = client.send(message_3.into(), None, None); + + // Run them concurrently and wait for both + let (res_batch, res_single) = tokio::join!(request1, request2); + + let res_batch = res_batch.unwrap().unwrap(); + // two results for two requests in the batch + assert_eq!(res_batch.len(), 2); + assert!(res_batch.iter().all(|r| { + let id = r.request_id().unwrap(); + id == RequestId::String("id1".to_string()) || id == RequestId::String("id2".to_string()) + })); + + // not an Error + assert!(res_batch + .iter() + .all(|r| matches!(r, ServerMessage::Response(_)))); + + let res_single = res_single.unwrap().unwrap(); + let ServerMessage::Response(res_single) = res_single else { + panic!("invalid respinse type, expected Result!") + }; + + assert!(matches!(res_single.id, RequestId::Integer(id) if id==1)); +} + +// should throw error when invalid content-type is received +#[tokio::test] +async fn should_throw_error_when_invalid_content_type_is_received() { + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(None, None).await; + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(ResponseTemplate::new(200).set_body_raw( + r#"{"id":0,"jsonrpc":"2.0", "result":{}}"#.to_string().into_bytes(), + "text/plain", + )) + .expect(1) + .mount(&mock_server) + .await; + + let result = client.ping(None).await; + + let Err(McpSdkError::Transport(TransportError::UnexpectedContentType(content_type))) = result + else { + panic!("Expected a TransportError::UnexpectedContentType error!"); + }; + + assert_eq!(content_type, "text/plain"); +} + +// should always send specified custom headers +#[tokio::test] +async fn should_always_send_specified_custom_headers() { + let mut headers = HashMap::new(); + headers.insert("X-Custom-Header".to_string(), "CustomValue".to_string()); + let InitializedClient { + client, + mcp_url, + mock_server, + } = initialize_client(None, Some(headers)).await; + + Mock::given(method("POST")) + .and(path("/mcp")) + .respond_with(ResponseTemplate::new(200).set_body_raw( + r#"{"id":1,"jsonrpc":"2.0", "result":{}}"#.to_string().into_bytes(), + "application/json", + )) + .expect(1) + .mount(&mock_server) + .await; + + let _result = client.ping(None).await; + + let requests = mock_server.received_requests().await.unwrap(); + + assert_eq!(requests.len(), 4); + assert!(requests + .iter() + .all(|r| r.headers.get(&"X-Custom-Header".into()).unwrap().as_str() == "CustomValue")); + + debug_wiremock(&mock_server).await +} + +// should reconnect a GET-initiated notification stream that fails + +#[tokio::test] +async fn should_reconnect_a_get_initiated_notification_stream_that_fails() { + // Start a mock server + let mock_server = MockServer::start().await; + + // intialize response + let response = create_sse_response(INITIALIZE_RESPONSE); + + // initialize request and response + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string(INITIALIZE_REQUEST)) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // two GET Mock, each expects one call , first time it fails, second retry it succeeds + let response = ResponseTemplate::new(502) + .set_body_raw("".to_string().into_bytes(), "text/event-stream") + .append_header("Connection", "keep-alive"); + + // Mount the mock for a GET request + Mock::given(method("GET")) + .and(path("/mcp")) + .respond_with(response) + .expect(1) + .up_to_n_times(1) + .mount(&mock_server) + .await; + + let response = ResponseTemplate::new(200) + .set_body_raw( + "data: Connection established\n\n".to_string().into_bytes(), + "text/event-stream", + ) + .append_header("Connection", "keep-alive"); + Mock::given(method("GET")) + .and(path("/mcp")) + .respond_with(response) + .expect(1) + .mount(&mock_server) + .await; + + // receive initialized notification + Mock::given(method("POST")) + .and(path("/mcp")) + .and(body_json_string( + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + )) + .respond_with(ResponseTemplate::new(202)) + .expect(1) + .mount(&mock_server) + .await; + + let mcp_url = format!("{}/mcp", mock_server.uri()); + let (client, _) = create_client(&mcp_url, None).await; + + client.clone().start().await.unwrap(); +} + +//****************** Resumability ****************** +// should pass lastEventId when reconnecting +#[tokio::test] +async fn should_pass_last_event_id_when_reconnecting() { + let msg = r#"{"jsonrpc":"2.0","method":"notifications/message","params":{"data":{},"level":"debug"}}"#; + + let mocks = vec![ + MockBuilder::new_sse(Method::POST, "/mcp".to_string(), INITIALIZE_RESPONSE).build(), + MockBuilder::new_breakable_sse( + Method::GET, + "/mcp".to_string(), + SseEvent { + data: Some(msg.into()), + event: Some("message".to_string()), + id: None, + }, + Duration::from_millis(100), + 5, + ) + .expect(2) + .build(), + MockBuilder::new_sse( + Method::POST, + "/mcp".to_string(), + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + ) + .build(), + ]; + + let (url, handle) = SimpleMockServer::start_with_mocks(mocks).await; + let mcp_url = format!("{url}/mcp"); + + let mut headers = HashMap::new(); + headers.insert("X-Custom-Header".to_string(), "CustomValue".to_string()); + let (client, _) = create_client(&mcp_url, Some(headers)).await; + + client.clone().start().await.unwrap(); + + assert!(client.is_initialized()); + + // give it time for re-connection + tokio::time::sleep(Duration::from_secs(2)).await; + + let request_history = handle.get_history().await; + + let get_requests: Vec<_> = request_history + .iter() + .filter(|r| r.0.method == Method::GET) + .collect(); + + // there should be more than one GET reueat, indicating reconnection + assert!(get_requests.len() > 1); + + let Some(last_get_request) = get_requests.last() else { + panic!("Unable to find last GET reuest!"); + }; + + let last_event_id = last_get_request + .0 + .headers + .get(axum::http::HeaderName::from_static( + MCP_LAST_EVENT_ID_HEADER, + )); + + // last-event-id should be sent + assert!( + matches!(last_event_id, Some(last_event_id) if last_event_id.to_str().unwrap().starts_with("msg-id")) + ); + + // custom headers should be passed for all GET requests + assert!(get_requests.iter().all(|r| r + .0 + .headers + .get(axum::http::HeaderName::from_str("X-Custom-Header").unwrap()) + .unwrap() + .to_str() + .unwrap() + == "CustomValue")); + + println!("last_event_id {:?} ", last_event_id.unwrap()); +} + +// should NOT reconnect a POST-initiated stream that fails +#[tokio::test] +async fn should_not_reconnect_a_post_initiated_stream_that_fails() { + let mocks = vec![ + MockBuilder::new_sse(Method::POST, "/mcp".to_string(), INITIALIZE_RESPONSE) + .expect(1) + .build(), + MockBuilder::new_sse(Method::GET, "/mcp".to_string(), "".to_string()) + .with_status(StatusCode::METHOD_NOT_ALLOWED) + .build(), + MockBuilder::new_sse( + Method::POST, + "/mcp".to_string(), + r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#, + ) + .expect(1) + .build(), + MockBuilder::new_breakable_sse( + Method::POST, + "/mcp".to_string(), + SseEvent { + data: Some("msg".to_string()), + event: None, + id: None, + }, + Duration::ZERO, + 0, + ) + .build(), + ]; + + let (url, handle) = SimpleMockServer::start_with_mocks(mocks).await; + let mcp_url = format!("{url}/mcp"); + + let mut headers = HashMap::new(); + headers.insert("X-Custom-Header".to_string(), "CustomValue".to_string()); + let (client, _) = create_client(&mcp_url, Some(headers)).await; + + client.clone().start().await.unwrap(); + + assert!(client.is_initialized()); + + let result = client.send_roots_list_changed(None).await; + + assert!(result.is_err()); + + tokio::time::sleep(Duration::from_secs(2)).await; + + let request_history = handle.get_history().await; + let post_requests: Vec<_> = request_history + .iter() + .filter(|r| r.0.method == Method::POST) + .collect(); + assert_eq!(post_requests.len(), 3); // initialize, initialized, root_list_changed +} + +//****************** Auth ****************** +// attempts auth flow on 401 during POST request +// invalidates all credentials on InvalidClientError during auth +// invalidates all credentials on UnauthorizedClientError during auth +//invalidates tokens on InvalidGrantError during auth + +//****************** Others ****************** +// custom fetch in auth code paths +// should support custom reconnection options +// uses custom fetch implementation if provided +// should have exponential backoff with configurable maxRetries diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http.rs b/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs similarity index 99% rename from crates/rust-mcp-sdk/tests/test_streamable_http.rs rename to crates/rust-mcp-sdk/tests/test_streamable_http_server.rs index 23ca27f..4809d6d 100644 --- a/crates/rust-mcp-sdk/tests/test_streamable_http.rs +++ b/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs @@ -8,13 +8,12 @@ use rust_mcp_schema::{ SdkErrorCodes, ServerJsonrpcNotification, ServerJsonrpcRequest, ServerJsonrpcResponse, ServerMessages, }, - CallToolRequest, CallToolRequestParams, ListPromptsRequestParams, ListRootsRequestParams, - ListRootsResult, ListToolsRequest, LoggingLevel, LoggingMessageNotificationParams, RequestId, - RootsListChangedNotification, ServerNotification, ServerRequest, ServerResult, + CallToolRequest, CallToolRequestParams, ListRootsResult, ListToolsRequest, LoggingLevel, + LoggingMessageNotificationParams, RequestId, RootsListChangedNotification, ServerNotification, + ServerRequest, ServerResult, }; use rust_mcp_sdk::mcp_server::HyperServerOptions; use serde_json::{json, Map, Value}; -use tokio_stream::StreamExt; use crate::common::{ random_port, read_sse_event, read_sse_event_from_stream, send_delete_request, send_get_request, diff --git a/crates/rust-mcp-transport/Cargo.toml b/crates/rust-mcp-transport/Cargo.toml index ec061bb..2f03580 100644 --- a/crates/rust-mcp-transport/Cargo.toml +++ b/crates/rust-mcp-transport/Cargo.toml @@ -42,10 +42,12 @@ workspace = true ### FEATURES ################################################################# [features] -default = ["stdio", "sse", "2025_06_18"] # Default features +default = ["stdio", "sse", "streamable-http", "2025_06_18"] # Default features stdio = [] sse = ["reqwest"] +streamable-http = ["reqwest"] + # enabled mcp protocol version 2025_06_18 2025_06_18 = ["rust-mcp-schema/2025_06_18", "rust-mcp-schema/schema_utils"] diff --git a/crates/rust-mcp-transport/README.md b/crates/rust-mcp-transport/README.md index 23b78bf..30cad83 100644 --- a/crates/rust-mcp-transport/README.md +++ b/crates/rust-mcp-transport/README.md @@ -14,7 +14,7 @@ let transport = StdioTransport::new(TransportOptions { timeout: 60_000 })?; ``` -Refer to the [Hello World MCP Server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server) example for a complete demonstration. +Refer to the [Hello World MCP Server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-stdio) example for a complete demonstration. ### For MCP Client @@ -51,7 +51,7 @@ let transport = StdioTransport::create_with_server_launch( )?; ``` -Refer to the [Simple MCP Client](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client) example for a complete demonstration. +Refer to the [Simple MCP Client](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-stdio) example for a complete demonstration. --- diff --git a/crates/rust-mcp-transport/src/client_sse.rs b/crates/rust-mcp-transport/src/client_sse.rs index f201aa0..8d55bd0 100644 --- a/crates/rust-mcp-transport/src/client_sse.rs +++ b/crates/rust-mcp-transport/src/client_sse.rs @@ -5,7 +5,7 @@ use crate::transport::Transport; use crate::utils::{ extract_origin, http_post, CancellationTokenSource, ReadableChannel, SseStream, WritableChannel, }; -use crate::{IoStream, McpDispatch, TransportOptions}; +use crate::{IoStream, McpDispatch, TransportDispatcher, TransportOptions}; use async_trait::async_trait; use bytes::Bytes; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; @@ -13,8 +13,13 @@ use reqwest::Client; use tokio::sync::oneshot::Sender; use tokio::task::JoinHandle; -use crate::schema::schema_utils::McpMessage; -use crate::schema::RequestId; +use crate::schema::{ + schema_utils::{ + ClientMessage, ClientMessages, McpMessage, MessageFromClient, SdkError, ServerMessage, + ServerMessages, + }, + RequestId, +}; use std::cmp::Ordering; use std::collections::HashMap; use std::pin::Pin; @@ -25,7 +30,7 @@ use tokio::sync::{mpsc, oneshot, Mutex}; const DEFAULT_CHANNEL_CAPACITY: usize = 64; const DEFAULT_MAX_RETRY: usize = 5; -const DEFAULT_RETRY_TIME_SECONDS: u64 = 3; +const DEFAULT_RETRY_TIME_SECONDS: u64 = 1; const SHUTDOWN_TIMEOUT_SECONDS: u64 = 5; /// Configuration options for the Client SSE Transport @@ -102,10 +107,9 @@ where let base_url = match extract_origin(server_url) { Some(url) => url, None => { - let error_message = - format!("Failed to extract origin from server URL: {server_url}"); - tracing::error!(error_message); - return Err(TransportError::InvalidOptions(error_message)); + let message = format!("Failed to extract origin from server URL: {server_url}"); + tracing::error!(message); + return Err(TransportError::Configuration { message }); } }; @@ -145,12 +149,15 @@ where let mut header_map = HeaderMap::new(); for (key, value) in headers { - let header_name = key - .parse::() - .map_err(|e| TransportError::InvalidOptions(format!("Invalid header name: {e}")))?; - let header_value = HeaderValue::from_str(value).map_err(|e| { - TransportError::InvalidOptions(format!("Invalid header value: {e}")) - })?; + let header_name = + key.parse::() + .map_err(|e| TransportError::Configuration { + message: format!("Invalid header name: {e}"), + })?; + let header_value = + HeaderValue::from_str(value).map_err(|e| TransportError::Configuration { + message: format!("Invalid header value: {e}"), + })?; header_map.insert(header_name, header_value); } @@ -172,10 +179,12 @@ where } if let Some(endpoint_origin) = extract_origin(&endpoint) { if endpoint_origin.cmp(&self.base_url) != Ordering::Equal { - return Err(TransportError::InvalidOptions(format!( + return Err(TransportError::Configuration { + message: format!( "Endpoint origin does not match connection origin. expected: {} , received: {}", self.base_url, endpoint_origin - ))); + ), + }); } return Ok(endpoint); } @@ -284,8 +293,8 @@ where Some(data) => { // trim the trailing \n before making a request let body = String::from_utf8_lossy(&data).trim().to_string(); - if let Err(e) = http_post(&client_clone, &post_url, body, &custom_headers).await { - tracing::error!("Failed to POST message: {e:?}"); + if let Err(e) = http_post(&client_clone, &post_url, body,None, custom_headers.as_ref()).await { + tracing::error!("Failed to POST message: {e}"); } }, None => break, // Exit if channel is closed @@ -335,7 +344,7 @@ where } async fn consume_string_payload(&self, _payload: &str) -> TransportResult<()> { - Err(TransportError::FromString( + Err(TransportError::Internal( "Invalid invocation of consume_string_payload() function for ClientSseTransport" .to_string(), )) @@ -346,7 +355,7 @@ where _: Duration, _: oneshot::Sender<()>, ) -> TransportResult> { - Err(TransportError::FromString( + Err(TransportError::Internal( "Invalid invocation of keep_alive() function for ClientSseTransport".to_string(), )) } @@ -413,3 +422,55 @@ where pending_requests.remove(request_id) } } + +#[async_trait] +impl McpDispatch + for ClientSseTransport +{ + async fn send_message( + &self, + message: ClientMessages, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send_message(message, request_timeout).await + } + + async fn send( + &self, + message: ClientMessage, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send(message, request_timeout).await + } + + async fn send_batch( + &self, + message: Vec, + request_timeout: Option, + ) -> TransportResult>> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send_batch(message, request_timeout).await + } + + async fn write_str(&self, payload: &str) -> TransportResult<()> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.write_str(payload).await + } +} + +impl + TransportDispatcher< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, + > for ClientSseTransport +{ +} diff --git a/crates/rust-mcp-transport/src/client_streamable_http.rs b/crates/rust-mcp-transport/src/client_streamable_http.rs new file mode 100644 index 0000000..c318649 --- /dev/null +++ b/crates/rust-mcp-transport/src/client_streamable_http.rs @@ -0,0 +1,515 @@ +use crate::error::TransportError; +use crate::mcp_stream::MCPStream; + +use crate::schema::{ + schema_utils::{ + ClientMessage, ClientMessages, McpMessage, MessageFromClient, SdkError, ServerMessage, + ServerMessages, + }, + RequestId, +}; +use crate::utils::{ + http_delete, http_post, CancellationTokenSource, ReadableChannel, StreamableHttpStream, + WritableChannel, +}; +use crate::{error::TransportResult, IoStream, McpDispatch, MessageDispatcher, Transport}; +use crate::{SessionId, TransportDispatcher, TransportOptions}; +use async_trait::async_trait; +use bytes::Bytes; +use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; +use reqwest::Client; +use std::collections::HashMap; +use std::pin::Pin; +use std::{sync::Arc, time::Duration}; +use tokio::io::{BufReader, BufWriter}; +use tokio::sync::oneshot::Sender; +use tokio::sync::{mpsc, oneshot, Mutex}; +use tokio::task::JoinHandle; + +const DEFAULT_CHANNEL_CAPACITY: usize = 64; +const DEFAULT_MAX_RETRY: usize = 5; +const DEFAULT_RETRY_TIME_SECONDS: u64 = 1; +const SHUTDOWN_TIMEOUT_SECONDS: u64 = 5; + +pub struct StreamableTransportOptions { + pub mcp_url: String, + pub request_options: RequestOptions, +} + +impl StreamableTransportOptions { + pub async fn terminate_session(&self, session_id: Option<&SessionId>) { + let client = Client::new(); + match http_delete(&client, &self.mcp_url, session_id, None).await { + Ok(_) => {} + Err(TransportError::Http(status_code)) => { + tracing::info!("Session termination failed with status code {status_code}",); + } + Err(error) => { + tracing::info!("Session termination failed with error :{error}"); + } + }; + } +} + +pub struct RequestOptions { + pub request_timeout: Duration, + pub retry_delay: Option, + pub max_retries: Option, + pub custom_headers: Option>, +} + +impl Default for RequestOptions { + fn default() -> Self { + Self { + request_timeout: TransportOptions::default().timeout, + retry_delay: None, + max_retries: None, + custom_headers: None, + } + } +} + +pub struct ClientStreamableTransport +where + R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, +{ + /// Optional cancellation token source for shutting down the transport + shutdown_source: tokio::sync::RwLock>, + /// Flag indicating if the transport is shut down + is_shut_down: Mutex, + /// Timeout duration for MCP messages + request_timeout: Duration, + /// HTTP client for making requests + client: Client, + /// URL for the SSE endpoint + mcp_server_url: String, + /// Delay between retry attempts + retry_delay: Duration, + /// Maximum number of retry attempts + max_retries: usize, + /// Optional custom HTTP headers + custom_headers: Option, + sse_task: tokio::sync::RwLock>>, + post_task: tokio::sync::RwLock>>, + message_sender: Arc>>>, + error_stream: tokio::sync::RwLock>, + pending_requests: Arc>>>, + session_id: Arc>>, + standalone: bool, +} + +impl ClientStreamableTransport +where + R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, +{ + pub fn new( + options: &StreamableTransportOptions, + session_id: Option, + standalone: bool, + ) -> TransportResult { + let client = Client::new(); + + let headers = match &options.request_options.custom_headers { + Some(h) => Some(Self::validate_headers(h)?), + None => None, + }; + + let mcp_server_url = options.mcp_url.to_owned(); + Ok(Self { + shutdown_source: tokio::sync::RwLock::new(None), + is_shut_down: Mutex::new(false), + request_timeout: options.request_options.request_timeout, + client, + mcp_server_url, + retry_delay: options + .request_options + .retry_delay + .unwrap_or(Duration::from_secs(DEFAULT_RETRY_TIME_SECONDS)), + max_retries: options + .request_options + .max_retries + .unwrap_or(DEFAULT_MAX_RETRY), + sse_task: tokio::sync::RwLock::new(None), + post_task: tokio::sync::RwLock::new(None), + custom_headers: headers, + message_sender: Arc::new(tokio::sync::RwLock::new(None)), + error_stream: tokio::sync::RwLock::new(None), + pending_requests: Arc::new(Mutex::new(HashMap::new())), + session_id: Arc::new(tokio::sync::RwLock::new(session_id)), + standalone, + }) + } + + fn validate_headers(headers: &HashMap) -> TransportResult { + let mut header_map = HeaderMap::new(); + for (key, value) in headers { + let header_name = + key.parse::() + .map_err(|e| TransportError::Configuration { + message: format!("Invalid header name: {e}"), + })?; + let header_value = + HeaderValue::from_str(value).map_err(|e| TransportError::Configuration { + message: format!("Invalid header value: {e}"), + })?; + header_map.insert(header_name, header_value); + } + Ok(header_map) + } + + pub(crate) async fn set_message_sender(&self, sender: MessageDispatcher) { + let mut lock = self.message_sender.write().await; + *lock = Some(sender); + } + + pub(crate) async fn set_error_stream( + &self, + error_stream: Pin>, + ) { + let mut lock = self.error_stream.write().await; + *lock = Some(IoStream::Readable(error_stream)); + } +} + +#[async_trait] +impl Transport for ClientStreamableTransport +where + R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + S: McpMessage + Clone + Send + Sync + serde::Serialize + 'static, + M: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + OR: Clone + Send + Sync + serde::Serialize + 'static, + OM: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, +{ + async fn start(&self) -> TransportResult> + where + MessageDispatcher: McpDispatch, + { + if self.standalone { + // Create CancellationTokenSource and token + let (cancellation_source, cancellation_token) = CancellationTokenSource::new(); + let mut lock = self.shutdown_source.write().await; + *lock = Some(cancellation_source); + + let (write_tx, mut write_rx) = mpsc::channel::(DEFAULT_CHANNEL_CAPACITY); + let (read_tx, read_rx) = mpsc::channel::(DEFAULT_CHANNEL_CAPACITY); + + let max_retries = self.max_retries; + let retry_delay = self.retry_delay; + + let post_url = self.mcp_server_url.clone(); + let custom_headers = self.custom_headers.clone(); + let cancellation_token_post = cancellation_token.clone(); + let cancellation_token_sse = cancellation_token.clone(); + + let session_id_clone = self.session_id.clone(); + + let mut streamable_http = StreamableHttpStream { + client: self.client.clone(), + mcp_url: post_url, + max_retries, + retry_delay, + read_tx, + session_id: session_id_clone, //Arc>> + }; + + let session_id = self.session_id.read().await.to_owned(); + + let sse_response = streamable_http + .make_standalone_stream_connection(&cancellation_token_sse, &custom_headers, None) + .await?; + + let sse_task_handle = tokio::spawn(async move { + if let Err(error) = streamable_http + .run_standalone(&cancellation_token_sse, &custom_headers, sse_response) + .await + { + if !matches!(error, TransportError::Cancelled(_)) { + tracing::warn!("{error}"); + } + } + }); + + let mut sse_task_lock = self.sse_task.write().await; + *sse_task_lock = Some(sse_task_handle); + + let post_url = self.mcp_server_url.clone(); + let client = self.client.clone(); + let custom_headers = self.custom_headers.clone(); + + // Initiate a task to process POST requests from messages received via the writable stream. + let post_task_handle = tokio::spawn(async move { + loop { + tokio::select! { + _ = cancellation_token_post.cancelled() => + { + break; + }, + data = write_rx.recv() => { + match data{ + Some(data) => { + // trim the trailing \n before making a request + let payload = String::from_utf8_lossy(&data).trim().to_string(); + + if let Err(e) = http_post( + &client, + &post_url, + payload.to_string(), + session_id.as_ref(), + custom_headers.as_ref(), + ) + .await{ + tracing::error!("Failed to POST message: {e}") + } + }, + None => break, // Exit if channel is closed + } + } + } + } + }); + let mut post_task_lock = self.post_task.write().await; + *post_task_lock = Some(post_task_handle); + + // Create writable stream + let writable: Mutex>> = + Mutex::new(Box::pin(BufWriter::new(WritableChannel { write_tx }))); + + // Create readable stream + let readable: Pin> = + Box::pin(BufReader::new(ReadableChannel { + read_rx, + buffer: Bytes::new(), + })); + + let (stream, sender, error_stream) = MCPStream::create( + readable, + writable, + IoStream::Writable(Box::pin(tokio::io::stderr())), + self.pending_requests.clone(), + self.request_timeout, + cancellation_token, + ); + + self.set_message_sender(sender).await; + + if let IoStream::Readable(error_stream) = error_stream { + self.set_error_stream(error_stream).await; + } + Ok(stream) + } else { + // Create CancellationTokenSource and token + let (cancellation_source, cancellation_token) = CancellationTokenSource::new(); + let mut lock = self.shutdown_source.write().await; + *lock = Some(cancellation_source); + + // let (write_tx, mut write_rx) = mpsc::channel::(DEFAULT_CHANNEL_CAPACITY); + let (write_tx, mut write_rx): ( + tokio::sync::mpsc::Sender<( + String, + tokio::sync::oneshot::Sender>, + )>, + tokio::sync::mpsc::Receiver<( + String, + tokio::sync::oneshot::Sender>, + )>, + ) = tokio::sync::mpsc::channel(DEFAULT_CHANNEL_CAPACITY); // Buffer size as needed + let (read_tx, read_rx) = mpsc::channel::(DEFAULT_CHANNEL_CAPACITY); + + let max_retries = self.max_retries; + let retry_delay = self.retry_delay; + + let post_url = self.mcp_server_url.clone(); + let custom_headers = self.custom_headers.clone(); + let cancellation_token_post = cancellation_token.clone(); + let cancellation_token_sse = cancellation_token.clone(); + + let session_id_clone = self.session_id.clone(); + + let mut streamable_http = StreamableHttpStream { + client: self.client.clone(), + mcp_url: post_url, + max_retries, + retry_delay, + read_tx, + session_id: session_id_clone, //Arc>> + }; + + // Initiate a task to process POST requests from messages received via the writable stream. + let post_task_handle = tokio::spawn(async move { + loop { + tokio::select! { + _ = cancellation_token_post.cancelled() => + { + break; + }, + data = write_rx.recv() => { + match data{ + Some((data, ack_tx)) => { + // trim the trailing \n before making a request + let payload = data.trim().to_string(); + let result = streamable_http.run(payload, &cancellation_token_sse, &custom_headers).await; + let _ = ack_tx.send(result);// Ignore error if receiver dropped + }, + None => break, // Exit if channel is closed + } + } + } + } + }); + let mut post_task_lock = self.post_task.write().await; + *post_task_lock = Some(post_task_handle); + + // Create readable stream + let readable: Pin> = + Box::pin(BufReader::new(ReadableChannel { + read_rx, + buffer: Bytes::new(), + })); + + let (stream, sender, error_stream) = MCPStream::create_with_ack( + readable, + write_tx, + IoStream::Writable(Box::pin(tokio::io::stderr())), + self.pending_requests.clone(), + self.request_timeout, + cancellation_token, + ); + + self.set_message_sender(sender).await; + + if let IoStream::Readable(error_stream) = error_stream { + self.set_error_stream(error_stream).await; + } + + Ok(stream) + } + } + + fn message_sender(&self) -> Arc>>> { + self.message_sender.clone() as _ + } + + fn error_stream(&self) -> &tokio::sync::RwLock> { + &self.error_stream as _ + } + async fn shut_down(&self) -> TransportResult<()> { + // Trigger cancellation + let mut cancellation_lock = self.shutdown_source.write().await; + if let Some(source) = cancellation_lock.as_ref() { + source.cancel()?; + } + *cancellation_lock = None; // Clear cancellation_source + + // Mark as shut down + let mut is_shut_down_lock = self.is_shut_down.lock().await; + *is_shut_down_lock = true; + + // Get task handle + let post_task = self.post_task.write().await.take(); + + // // Wait for tasks to complete with a timeout + let timeout = Duration::from_secs(SHUTDOWN_TIMEOUT_SECONDS); + let shutdown_future = async { + if let Some(post_handle) = post_task { + let _ = post_handle.await; + } + Ok::<(), TransportError>(()) + }; + + tokio::select! { + result = shutdown_future => { + result // result of task completion + } + _ = tokio::time::sleep(timeout) => { + tracing::warn!("Shutdown timed out after {:?}", timeout); + Err(TransportError::ShutdownTimeout) + } + } + } + async fn is_shut_down(&self) -> bool { + let result = self.is_shut_down.lock().await; + *result + } + async fn consume_string_payload(&self, _: &str) -> TransportResult<()> { + Err(TransportError::Internal( + "Invalid invocation of consume_string_payload() function for ClientStreamableTransport" + .to_string(), + )) + } + + async fn pending_request_tx(&self, request_id: &RequestId) -> Option> { + let mut pending_requests = self.pending_requests.lock().await; + pending_requests.remove(request_id) + } + + async fn keep_alive( + &self, + _: Duration, + _: oneshot::Sender<()>, + ) -> TransportResult> { + Err(TransportError::Internal( + "Invalid invocation of keep_alive() function for ClientStreamableTransport".to_string(), + )) + } + + async fn session_id(&self) -> Option { + let guard = self.session_id.read().await; + guard.clone() + } +} + +#[async_trait] +impl McpDispatch + for ClientStreamableTransport +{ + async fn send_message( + &self, + message: ClientMessages, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + + sender.send_message(message, request_timeout).await + } + + async fn send( + &self, + message: ClientMessage, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + + sender.send(message, request_timeout).await + } + + async fn send_batch( + &self, + message: Vec, + request_timeout: Option, + ) -> TransportResult>> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send_batch(message, request_timeout).await + } + + async fn write_str(&self, payload: &str) -> TransportResult<()> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.write_str(payload).await + } +} + +impl + TransportDispatcher< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, + > for ClientStreamableTransport +{ +} diff --git a/crates/rust-mcp-transport/src/constants.rs b/crates/rust-mcp-transport/src/constants.rs new file mode 100644 index 0000000..6ae0342 --- /dev/null +++ b/crates/rust-mcp-transport/src/constants.rs @@ -0,0 +1,3 @@ +pub const MCP_SESSION_ID_HEADER: &str = "Mcp-Session-Id"; +pub const MCP_PROTOCOL_VERSION_HEADER: &str = "Mcp-Protocol-Version"; +pub const MCP_LAST_EVENT_ID_HEADER: &str = "last-event-id"; diff --git a/crates/rust-mcp-transport/src/error.rs b/crates/rust-mcp-transport/src/error.rs index 8f8b62f..a244456 100644 --- a/crates/rust-mcp-transport/src/error.rs +++ b/crates/rust-mcp-transport/src/error.rs @@ -1,11 +1,14 @@ use crate::schema::{schema_utils::SdkError, RpcError}; -use thiserror::Error; - use crate::utils::CancellationError; use core::fmt; +#[cfg(any(feature = "sse", feature = "streamable-http"))] +use reqwest::Error as ReqwestError; +#[cfg(any(feature = "sse", feature = "streamable-http"))] +use reqwest::StatusCode; use std::any::Any; +use std::io::Error as IoError; +use thiserror::Error; use tokio::sync::{broadcast, mpsc}; - /// A wrapper around a broadcast send error. This structure allows for generic error handling /// by boxing the underlying error into a type-erased form. #[derive(Debug)] @@ -80,31 +83,53 @@ pub type TransportResult = core::result::Result; #[derive(Debug, Error)] pub enum TransportError { - #[error("{0}")] - InvalidOptions(String), + #[error("Session expired or not found")] + SessionExpired, + + #[error("Failed to open SSE stream: {0}")] + FailedToOpenSSEStream(String), + + #[error("Unexpected content type: '{0}'")] + UnexpectedContentType(String), + + #[error("Failed to send message: {0}")] + SendFailure(String), + + #[error("I/O error: {0}")] + Io(#[from] IoError), + + #[cfg(any(feature = "sse", feature = "streamable-http"))] + #[error("HTTP connection error: {0}")] + HttpConnection(#[from] ReqwestError), + + #[cfg(any(feature = "sse", feature = "streamable-http"))] + #[error("HTTP error: {0}")] + Http(StatusCode), + + #[error("SDK error: {0}")] + Sdk(#[from] SdkError), + + #[error("Operation cancelled: {0}")] + Cancelled(#[from] CancellationError), + + #[error("Channel closed: {0}")] + ChannelClosed(#[from] tokio::sync::oneshot::error::RecvError), + + #[error("Configuration error: {message}")] + Configuration { message: String }, + #[error("{0}")] SendError(#[from] GenericSendError), - #[error("{0}")] - WatchSendError(#[from] GenericWatchSendError), - #[error("Send Error: {0}")] - StdioError(#[from] std::io::Error), + #[error("{0}")] JsonrpcError(#[from] RpcError), - #[error("{0}")] - SdkError(#[from] SdkError), - #[error("Process error{0}")] + + #[error("Process error: {0}")] ProcessError(String), - #[error("{0}")] - FromString(String), - #[error("{0}")] - OneshotRecvError(#[from] tokio::sync::oneshot::error::RecvError), - #[cfg(feature = "sse")] - #[error("{0}")] - SendMessageError(#[from] reqwest::Error), - #[error("Http Error: {0}")] - HttpError(u16), + + #[error("Internal error: {0}")] + Internal(String), + #[error("Shutdown timed out")] ShutdownTimeout, - #[error("Cancellation error : {0}")] - CancellationError(#[from] CancellationError), } diff --git a/crates/rust-mcp-transport/src/lib.rs b/crates/rust-mcp-transport/src/lib.rs index 1634922..4a918db 100644 --- a/crates/rust-mcp-transport/src/lib.rs +++ b/crates/rust-mcp-transport/src/lib.rs @@ -1,25 +1,38 @@ // Copyright (c) 2025 mcp-rust-stack // Licensed under the MIT License. See LICENSE file for details. // Modifications to this file must be documented with a description of the changes made. + #[cfg(feature = "sse")] mod client_sse; +#[cfg(feature = "streamable-http")] +mod client_streamable_http; +mod constants; pub mod error; mod mcp_stream; mod message_dispatcher; mod schema; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] mod sse; +#[cfg(feature = "stdio")] mod stdio; mod transport; mod utils; #[cfg(feature = "sse")] pub use client_sse::*; +#[cfg(feature = "streamable-http")] +pub use client_streamable_http::*; +pub use constants::*; pub use message_dispatcher::*; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] pub use sse::*; +#[cfg(feature = "stdio")] pub use stdio::*; pub use transport::*; // Type alias for session identifier, represented as a String pub type SessionId = String; +// Type alias for stream identifier (that will be used at the transport scope), represented as a String +pub type StreamId = String; +// Type alias for event (MCP message) identifier, represented as a String +pub type EventId = String; diff --git a/crates/rust-mcp-transport/src/mcp_stream.rs b/crates/rust-mcp-transport/src/mcp_stream.rs index 08bdc21..0b10918 100644 --- a/crates/rust-mcp-transport/src/mcp_stream.rs +++ b/crates/rust-mcp-transport/src/mcp_stream.rs @@ -57,6 +57,43 @@ impl MCPStream { (stream, sender, error_io) } + pub fn create_with_ack( + readable: Pin>, + writable: tokio::sync::mpsc::Sender<( + String, + tokio::sync::oneshot::Sender>, + )>, + error_io: IoStream, + pending_requests: Arc>>>, + request_timeout: Duration, + cancellation_token: CancellationToken, + ) -> ( + tokio_stream::wrappers::ReceiverStream, + MessageDispatcher, + IoStream, + ) + where + R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + X: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + { + let (tx, rx) = tokio::sync::mpsc::channel::(CHANNEL_CAPACITY); + let stream = tokio_stream::wrappers::ReceiverStream::new(rx); + + // Clone cancellation_token for reader + let reader_token = cancellation_token.clone(); + + #[allow(clippy::let_underscore_future)] + let _ = Self::spawn_reader(readable, tx, reader_token); + + let sender = MessageDispatcher::new_with_acknowledgement( + pending_requests, + writable, + request_timeout, + ); + + (stream, sender, error_io) + } + /// Creates a new task that continuously reads from the readable stream. /// The received data is deserialized into a JsonrpcMessage. If the deserialization is successful, /// the object is transmitted. If the object is a response or error corresponding to a pending request, diff --git a/crates/rust-mcp-transport/src/message_dispatcher.rs b/crates/rust-mcp-transport/src/message_dispatcher.rs index ea1eb04..7c7c93e 100644 --- a/crates/rust-mcp-transport/src/message_dispatcher.rs +++ b/crates/rust-mcp-transport/src/message_dispatcher.rs @@ -29,7 +29,13 @@ use crate::McpDispatch; /// a configurable timeout mechanism for asynchronous responses. pub struct MessageDispatcher { pending_requests: Arc>>>, - writable_std: Mutex>>, + writable_std: Option>>>, + writable_tx: Option< + tokio::sync::mpsc::Sender<( + String, + tokio::sync::oneshot::Sender>, + )>, + >, request_timeout: Duration, } @@ -51,7 +57,24 @@ impl MessageDispatcher { ) -> Self { Self { pending_requests, - writable_std, + writable_std: Some(writable_std), + writable_tx: None, + request_timeout, + } + } + + pub fn new_with_acknowledgement( + pending_requests: Arc>>>, + writable_tx: tokio::sync::mpsc::Sender<( + String, + tokio::sync::oneshot::Sender>, + )>, + request_timeout: Duration, + ) -> Self { + Self { + pending_requests, + writable_tx: Some(writable_tx), + writable_std: None, request_timeout, } } @@ -125,7 +148,7 @@ impl McpDispatch match await_timeout(rx, request_timeout.unwrap_or(self.request_timeout)).await { Ok(response) => Ok(Some(ServerMessages::Single(response))), Err(error) => match error { - TransportError::OneshotRecvError(_) => { + TransportError::ChannelClosed(_) => { Err(schema_utils::SdkError::connection_closed().into()) } _ => Err(error), @@ -147,6 +170,9 @@ impl McpDispatch }) .unzip(); + // Ensure all request IDs are stored before sending the request + let tasks = join_all(pending_tasks).await; + // send the batch messages to the server let message_payload = serde_json::to_string(&client_messages).map_err(|_| { crate::error::TransportError::JsonrpcError(RpcError::parse_error()) @@ -154,12 +180,10 @@ impl McpDispatch self.write_str(message_payload.as_str()).await?; // no request in the batch, no need to wait for the result - if pending_tasks.is_empty() { + if request_ids.is_empty() { return Ok(None); } - let tasks = join_all(pending_tasks).await; - let timeout_wrapped_futures = tasks.into_iter().filter_map(|rx| { rx.map(|rx| await_timeout(rx, request_timeout.unwrap_or(self.request_timeout))) }); @@ -210,11 +234,24 @@ impl McpDispatch /// appending a newline character and flushing the stream afterward. /// async fn write_str(&self, payload: &str) -> TransportResult<()> { - let mut writable_std = self.writable_std.lock().await; - writable_std.write_all(payload.as_bytes()).await?; - writable_std.write_all(b"\n").await?; // new line - writable_std.flush().await?; - Ok(()) + if let Some(writable_std) = self.writable_std.as_ref() { + let mut writable_std = writable_std.lock().await; + writable_std.write_all(payload.as_bytes()).await?; + writable_std.write_all(b"\n").await?; // new line + writable_std.flush().await?; + return Ok(()); + }; + + if let Some(writable_tx) = self.writable_tx.as_ref() { + let (resp_tx, resp_rx) = oneshot::channel(); + writable_tx + .send((payload.to_string(), resp_tx)) + .await + .map_err(|err| TransportError::Internal(format!("{err}")))?; // Send fails if channel closed + return resp_rx.await?; // Await the POST result; propagates the error if POST failed + } + + Err(TransportError::Internal("Invalid dispatcher!".to_string())) } } @@ -339,10 +376,23 @@ impl McpDispatch /// appending a newline character and flushing the stream afterward. /// async fn write_str(&self, payload: &str) -> TransportResult<()> { - let mut writable_std = self.writable_std.lock().await; - writable_std.write_all(payload.as_bytes()).await?; - writable_std.write_all(b"\n").await?; // new line - writable_std.flush().await?; - Ok(()) + if let Some(writable_std) = self.writable_std.as_ref() { + let mut writable_std = writable_std.lock().await; + writable_std.write_all(payload.as_bytes()).await?; + writable_std.write_all(b"\n").await?; // new line + writable_std.flush().await?; + return Ok(()); + }; + + if let Some(writable_tx) = self.writable_tx.as_ref() { + let (resp_tx, resp_rx) = oneshot::channel(); + writable_tx + .send((payload.to_string(), resp_tx)) + .await + .map_err(|err| TransportError::Internal(err.to_string()))?; // Send fails if channel closed + return resp_rx.await?; // Await the POST result; propagates the error if POST failed + } + + Err(TransportError::Internal("Invalid dispatcher!".to_string())) } } diff --git a/crates/rust-mcp-transport/src/sse.rs b/crates/rust-mcp-transport/src/sse.rs index 50dbb32..09809e4 100644 --- a/crates/rust-mcp-transport/src/sse.rs +++ b/crates/rust-mcp-transport/src/sse.rs @@ -156,7 +156,7 @@ impl Transport {} - Err(TransportError::StdioError(error)) => { + Err(TransportError::Io(error)) => { if error.kind() == std::io::ErrorKind::BrokenPipe { let _ = disconnect_tx.send(()); break; diff --git a/crates/rust-mcp-transport/src/stdio.rs b/crates/rust-mcp-transport/src/stdio.rs index 582af5d..11bd0a6 100644 --- a/crates/rust-mcp-transport/src/stdio.rs +++ b/crates/rust-mcp-transport/src/stdio.rs @@ -1,5 +1,6 @@ use crate::schema::schema_utils::{ - ClientMessage, ClientMessages, MessageFromServer, SdkError, ServerMessage, ServerMessages, + ClientMessage, ClientMessages, MessageFromClient, MessageFromServer, SdkError, ServerMessage, + ServerMessages, }; use crate::schema::RequestId; use async_trait::async_trait; @@ -193,30 +194,29 @@ where #[cfg(unix)] command.process_group(0); - let mut process = command.spawn().map_err(TransportError::StdioError)?; + let mut process = command.spawn().map_err(TransportError::Io)?; let stdin = process .stdin .take() - .ok_or_else(|| TransportError::FromString("Unable to retrieve stdin.".into()))?; + .ok_or_else(|| TransportError::Internal("Unable to retrieve stdin.".into()))?; let stdout = process .stdout .take() - .ok_or_else(|| TransportError::FromString("Unable to retrieve stdout.".into()))?; + .ok_or_else(|| TransportError::Internal("Unable to retrieve stdout.".into()))?; let stderr = process .stderr .take() - .ok_or_else(|| TransportError::FromString("Unable to retrieve stderr.".into()))?; + .ok_or_else(|| TransportError::Internal("Unable to retrieve stderr.".into()))?; - let pending_requests_clone1 = self.pending_requests.clone(); - let pending_requests_clone2 = self.pending_requests.clone(); + let pending_requests_clone = self.pending_requests.clone(); tokio::spawn(async move { let _ = process.wait().await; // clean up pending requests to cancel waiting tasks - let mut pending_requests = pending_requests_clone1.lock().await; + let mut pending_requests = pending_requests_clone.lock().await; pending_requests.clear(); }); @@ -224,7 +224,7 @@ where Box::pin(stdout), Mutex::new(Box::pin(stdin)), IoStream::Readable(Box::pin(stderr)), - pending_requests_clone2, + self.pending_requests.clone(), self.options.timeout, cancellation_token, ); @@ -275,7 +275,7 @@ where } async fn consume_string_payload(&self, _payload: &str) -> TransportResult<()> { - Err(TransportError::FromString( + Err(TransportError::Internal( "Invalid invocation of consume_string_payload() function in StdioTransport".to_string(), )) } @@ -285,7 +285,7 @@ where _interval: Duration, _disconnect_tx: oneshot::Sender<()>, ) -> TransportResult> { - Err(TransportError::FromString( + Err(TransportError::Internal( "Invalid invocation of keep_alive() function for StdioTransport".to_string(), )) } @@ -365,3 +365,55 @@ impl > for StdioTransport { } + +#[async_trait] +impl McpDispatch + for StdioTransport +{ + async fn send_message( + &self, + message: ClientMessages, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send_message(message, request_timeout).await + } + + async fn send( + &self, + message: ClientMessage, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send(message, request_timeout).await + } + + async fn send_batch( + &self, + message: Vec, + request_timeout: Option, + ) -> TransportResult>> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send_batch(message, request_timeout).await + } + + async fn write_str(&self, payload: &str) -> TransportResult<()> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.write_str(payload).await + } +} + +impl + TransportDispatcher< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, + > for StdioTransport +{ +} diff --git a/crates/rust-mcp-transport/src/transport.rs b/crates/rust-mcp-transport/src/transport.rs index 3d17ebd..b8e3ddc 100644 --- a/crates/rust-mcp-transport/src/transport.rs +++ b/crates/rust-mcp-transport/src/transport.rs @@ -1,15 +1,12 @@ -use std::{pin::Pin, sync::Arc, time::Duration}; - -use crate::schema::RequestId; +use crate::{error::TransportResult, message_dispatcher::MessageDispatcher}; +use crate::{schema::RequestId, SessionId}; use async_trait::async_trait; - +use std::{pin::Pin, sync::Arc, time::Duration}; use tokio::{ sync::oneshot::{self, Sender}, task::JoinHandle, }; -use crate::{error::TransportResult, message_dispatcher::MessageDispatcher}; - /// Default Timeout in milliseconds const DEFAULT_TIMEOUT_MSEC: u64 = 60_000; @@ -125,6 +122,9 @@ where interval: Duration, disconnect_tx: oneshot::Sender<()>, ) -> TransportResult>; + async fn session_id(&self) -> Option { + None + } } /// A composite trait that combines both transport and dispatch capabilities for the MCP protocol. @@ -160,3 +160,26 @@ where OM: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, { } + +// pub trait IntoClientTransport { +// type TransportType: Transport< +// ServerMessages, +// MessageFromClient, +// ServerMessage, +// ClientMessages, +// ClientMessage, +// >; + +// fn into_transport(self, session_id: Option) -> TransportResult; +// } + +// impl IntoClientTransport for T +// where +// T: Transport, +// { +// type TransportType = T; + +// fn into_transport(self, _: Option) -> TransportResult { +// Ok(self) +// } +// } diff --git a/crates/rust-mcp-transport/src/utils.rs b/crates/rust-mcp-transport/src/utils.rs index 218d517..82d7326 100644 --- a/crates/rust-mcp-transport/src/utils.rs +++ b/crates/rust-mcp-transport/src/utils.rs @@ -1,21 +1,29 @@ mod cancellation_token; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] mod http_utils; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] mod readable_channel; +#[cfg(any(feature = "sse", feature = "streamable-http"))] +mod sse_parser; #[cfg(feature = "sse")] mod sse_stream; -#[cfg(feature = "sse")] +#[cfg(feature = "streamable-http")] +mod streamable_http_stream; +#[cfg(any(feature = "sse", feature = "streamable-http"))] mod writable_channel; pub(crate) use cancellation_token::*; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) use http_utils::*; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) use readable_channel::*; +#[cfg(any(feature = "sse", feature = "streamable-http"))] +pub(crate) use sse_parser::*; #[cfg(feature = "sse")] pub(crate) use sse_stream::*; -#[cfg(feature = "sse")] +#[cfg(feature = "streamable-http")] +pub(crate) use streamable_http_stream::*; +#[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) use writable_channel::*; use crate::schema::schema_utils::SdkError; @@ -23,16 +31,16 @@ use tokio::time::{timeout, Duration}; use crate::error::{TransportError, TransportResult}; -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] use crate::SessionId; pub async fn await_timeout(operation: F, timeout_duration: Duration) -> TransportResult where F: std::future::Future>, // The operation returns a Result - E: Into, // The error type must be convertible to TransportError + E: Into, { match timeout(timeout_duration, operation).await { - Ok(result) => result.map_err(|err| err.into()), // Convert the error type into TransportError + Ok(result) => result.map_err(|err| err.into()), Err(_) => Err(SdkError::request_timeout(timeout_duration.as_millis()).into()), // Timeout error } } @@ -46,7 +54,7 @@ where /// # Returns /// A String containing the endpoint with the session ID added as a query parameter /// -#[cfg(feature = "sse")] +#[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) fn endpoint_with_session_id(endpoint: &str, session_id: &SessionId) -> String { // Handle empty endpoint let base = if endpoint.is_empty() { "/" } else { endpoint }; diff --git a/crates/rust-mcp-transport/src/utils/http_utils.rs b/crates/rust-mcp-transport/src/utils/http_utils.rs index 701dcb0..84b62dd 100644 --- a/crates/rust-mcp-transport/src/utils/http_utils.rs +++ b/crates/rust-mcp-transport/src/utils/http_utils.rs @@ -1,7 +1,35 @@ use crate::error::{TransportError, TransportResult}; +use crate::{SessionId, MCP_SESSION_ID_HEADER}; -use reqwest::header::{HeaderMap, CONTENT_TYPE}; -use reqwest::Client; +use reqwest::header::{HeaderMap, HeaderName, HeaderValue, ACCEPT, CONTENT_TYPE}; +use reqwest::{Client, Response}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ResponseType { + EventStream, + Json, +} + +/// Determines the response type based on the `Content-Type` header. +pub async fn validate_response_type(response: &Response) -> TransportResult { + match response.headers().get(reqwest::header::CONTENT_TYPE) { + Some(content_type) => { + let content_type_str = content_type.to_str().map_err(|_| { + TransportError::UnexpectedContentType("".to_string()) + })?; + + // Normalize to lowercase for case-insensitive comparison + let content_type_normalized = content_type_str.to_ascii_lowercase(); + + match content_type_normalized.as_str() { + "text/event-stream" => Ok(ResponseType::EventStream), + "application/json" => Ok(ResponseType::Json), + other => Err(TransportError::UnexpectedContentType(other.to_string())), + } + } + None => Err(TransportError::UnexpectedContentType("".to_string())), + } +} /// Sends an HTTP POST request with the given body and headers /// @@ -17,21 +45,96 @@ pub async fn http_post( client: &Client, post_url: &str, body: String, - headers: &Option, -) -> TransportResult<()> { + session_id: Option<&SessionId>, + headers: Option<&HeaderMap>, +) -> TransportResult { let mut request = client .post(post_url) .header(CONTENT_TYPE, "application/json") + .header(ACCEPT, "application/json, text/event-stream") .body(body); if let Some(map) = headers { request = request.headers(map.clone()); } + + if let Some(session_id) = session_id { + request = request.header( + MCP_SESSION_ID_HEADER, + HeaderValue::from_str(session_id).unwrap(), + ); + } + let response = request.send().await?; if !response.status().is_success() { - return Err(TransportError::HttpError(response.status().as_u16())); + return Err(TransportError::Http(response.status())); } - Ok(()) + Ok(response) +} + +pub async fn http_get( + client: &Client, + url: &str, + session_id: Option<&SessionId>, + headers: Option<&HeaderMap>, +) -> TransportResult { + let mut request = client + .get(url) + .header(CONTENT_TYPE, "application/json") + .header(ACCEPT, "application/json, text/event-stream"); + + if let Some(map) = headers { + request = request.headers(map.clone()); + } + + if let Some(session_id) = session_id { + request = request.header( + MCP_SESSION_ID_HEADER, + HeaderValue::from_str(session_id).unwrap(), + ); + } + + let response = request.send().await?; + if !response.status().is_success() { + return Err(TransportError::Http(response.status())); + } + Ok(response) +} + +pub async fn http_delete( + client: &Client, + post_url: &str, + session_id: Option<&SessionId>, + headers: Option<&HeaderMap>, +) -> TransportResult { + let mut request = client + .delete(post_url) + .header(CONTENT_TYPE, "application/json") + .header(ACCEPT, "application/json, text/event-stream"); + + if let Some(map) = headers { + request = request.headers(map.clone()); + } + + if let Some(session_id) = session_id { + request = request.header( + MCP_SESSION_ID_HEADER, + HeaderValue::from_str(session_id).unwrap(), + ); + } + + let response = request.send().await?; + if !response.status().is_success() { + let status_code = response.status(); + return Err(TransportError::Http(status_code)); + } + Ok(response) +} + +#[allow(unused)] +pub fn get_header_value(response: &Response, header_name: HeaderName) -> Option { + let content_type = response.headers().get(header_name)?.to_str().ok()?; + Some(content_type.to_string()) } pub fn extract_origin(url: &str) -> Option { @@ -88,7 +191,7 @@ mod tests { let headers = None; // Perform the POST request - let result = http_post(&client, &url, body, &headers).await; + let result = http_post(&client, &url, body, None, headers.as_ref()).await; // Assert the result is Ok assert!(result.is_ok()); @@ -113,11 +216,11 @@ mod tests { let headers = None; // Perform the POST request - let result = http_post(&client, &url, body, &headers).await; + let result = http_post(&client, &url, body, None, headers.as_ref()).await; // Assert the result is an HttpError with status 400 match result { - Err(TransportError::HttpError(status)) => assert_eq!(status, 400), + Err(TransportError::Http(status)) => assert_eq!(status, 400), _ => panic!("Expected HttpError with status 400"), } } @@ -142,7 +245,7 @@ mod tests { let headers = Some(create_test_headers()); // Perform the POST request - let result = http_post(&client, &url, body, &headers).await; + let result = http_post(&client, &url, body, None, headers.as_ref()).await; // Assert the result is Ok assert!(result.is_ok()); @@ -157,7 +260,7 @@ mod tests { let headers = None; // Perform the POST request - let result = http_post(&client, url, body, &headers).await; + let result = http_post(&client, url, body, None, headers.as_ref()).await; // Assert the result is an error (likely a connection error) assert!(result.is_err()); diff --git a/crates/rust-mcp-transport/src/utils/sse_parser.rs b/crates/rust-mcp-transport/src/utils/sse_parser.rs new file mode 100644 index 0000000..064d3c3 --- /dev/null +++ b/crates/rust-mcp-transport/src/utils/sse_parser.rs @@ -0,0 +1,320 @@ +use core::fmt; +use std::collections::HashMap; + +use bytes::{Bytes, BytesMut}; +const BUFFER_CAPACITY: usize = 1024; + +/// Represents a single Server-Sent Event (SSE) as defined in the SSE protocol. +/// +/// Contains the event type, data payload, and optional event ID. +pub struct SseEvent { + /// The optional event type (e.g., "message"). + pub event: Option, + /// The optional data payload of the event, stored as bytes. + pub data: Option, + /// The optional event ID for reconnection or tracking purposes. + pub id: Option, +} + +impl std::fmt::Display for SseEvent { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(id) = &self.id { + writeln!(f, "id: {id}")?; + } + + if let Some(event) = &self.event { + writeln!(f, "event: {event}")?; + } + + if let Some(data) = &self.data { + match std::str::from_utf8(data) { + Ok(text) => { + for line in text.lines() { + writeln!(f, "data: {line}")?; + } + } + Err(_) => { + writeln!(f, "data: [binary data]")?; + } + } + } + + writeln!(f)?; // Trailing newline for SSE message end + Ok(()) + } +} + +impl fmt::Debug for SseEvent { + /// Formats the `SseEvent` for debugging, converting the `data` field to a UTF-8 string + /// (with lossy conversion if invalid UTF-8 is encountered). + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let data_str = self + .data + .as_ref() + .map(|b| String::from_utf8_lossy(b).to_string()); + + f.debug_struct("SseEvent") + .field("event", &self.event) + .field("data", &data_str) + .field("id", &self.id) + .finish() + } +} + +/// A parser for Server-Sent Events (SSE) that processes incoming byte chunks into `SseEvent`s. +/// This Parser is specificly designed for MCP messages and with no multi-line data support +/// +/// This struct maintains a buffer to accumulate incoming data and parses it into SSE events +/// based on the SSE protocol. It handles fields like `event`, `data`, and `id` as defined +/// in the SSE specification. +#[derive(Debug)] +pub struct SseParser { + pub buffer: BytesMut, +} + +impl SseParser { + /// Creates a new `SseParser` with an empty buffer pre-allocated to a default capacity. + /// + /// The buffer is initialized with a capacity of `BUFFER_CAPACITY` to + /// optimize for typical SSE message sizes. + /// + /// # Returns + /// A new `SseParser` instance with an empty buffer. + pub fn new() -> Self { + Self { + buffer: BytesMut::with_capacity(BUFFER_CAPACITY), + } + } + + /// Processes a new chunk of bytes and parses it into a vector of `SseEvent`s. + /// + /// This method appends the incoming `bytes` to the internal buffer, splits it into + /// complete lines (delimited by `\n`), and parses each line according to the SSE + /// protocol. It supports `event`, `id`, and `data` fields, as well as comments + /// (lines starting with `:`). Empty lines are skipped, and incomplete lines remain + /// in the buffer for future processing. + /// + /// # Parameters + /// - `bytes`: The incoming chunk of bytes to parse. + /// + /// # Returns + /// A vector of `SseEvent`s parsed from the complete lines in the buffer. If no + /// complete events are found, an empty vector is returned. + pub fn process_new_chunk(&mut self, bytes: Bytes) -> Vec { + self.buffer.extend_from_slice(&bytes); + + // Collect complete lines (ending in \n)—keep ALL lines, including empty ones for \n\n detection + let mut lines = Vec::new(); + while let Some(pos) = self.buffer.iter().position(|&b| b == b'\n') { + let line = self.buffer.split_to(pos + 1).freeze(); + lines.push(line); + } + + let mut events = Vec::new(); + let mut current_message_lines: Vec = Vec::new(); + + for line in lines { + current_message_lines.push(line); + + // Check if we've hit a double newline (end of message) + if current_message_lines.len() >= 2 + && current_message_lines + .last() + .is_some_and(|b| b.as_ref() == b"\n") + { + // Process the complete message (exclude the last empty lines for parsing) + let message_lines: Vec<_> = current_message_lines + .drain(..current_message_lines.len() - 1) + .filter(|l| l.as_ref() != b"\n") // Filter internal empties + .collect(); + + if let Some(event) = self.parse_sse_message(&message_lines) { + events.push(event); + } + } + } + + // Put back any incomplete message + if !current_message_lines.is_empty() { + self.buffer.clear(); + for line in current_message_lines { + self.buffer.extend_from_slice(&line); + } + } + + events + } + + fn parse_sse_message(&self, lines: &[Bytes]) -> Option { + let mut fields: HashMap = HashMap::new(); + let mut data_parts: Vec = Vec::new(); + + for line_bytes in lines { + let line_str = String::from_utf8_lossy(line_bytes); + + // Skip comments and empty lines + if line_str.is_empty() || line_str.starts_with(':') { + continue; + } + + let (key, value) = if let Some(value) = line_str.strip_prefix("data: ") { + ("data", value.trim_start().to_string()) + } else if let Some(value) = line_str.strip_prefix("event: ") { + ("event", value.trim().to_string()) + } else if let Some(value) = line_str.strip_prefix("id: ") { + ("id", value.trim().to_string()) + } else if let Some(value) = line_str.strip_prefix("retry: ") { + ("retry", value.trim().to_string()) + } else { + // Invalid line; skip + continue; + }; + + if key == "data" { + if !value.is_empty() { + data_parts.push(value); + } + } else { + fields.insert(key.to_string(), value); + } + } + + // Build data (concat multi-line data with \n) , should not occur in MCP tho + let data = if data_parts.is_empty() { + None + } else { + let full_data = data_parts.join("\n"); + Some(Bytes::copy_from_slice(full_data.as_bytes())) // Use copy_from_slice for efficiency + }; + + // Skip invalid message with no data + let data = data?; + + // Get event (default to None) + let event = fields.get("event").cloned(); + let id = fields.get("id").cloned(); + + Some(SseEvent { + event, + data: Some(data), + id, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + + #[test] + fn test_single_data_event() { + let mut parser = SseParser::new(); + let input = Bytes::from("data: hello\n\n"); + let events = parser.process_new_chunk(input); + + assert_eq!(events.len(), 1); + assert_eq!( + events[0].data.as_deref(), + Some(Bytes::from("hello\n").as_ref()) + ); + assert!(events[0].event.is_none()); + assert!(events[0].id.is_none()); + } + + #[test] + fn test_event_with_id_and_data() { + let mut parser = SseParser::new(); + let input = Bytes::from("event: message\nid: 123\ndata: hello\n\n"); + let events = parser.process_new_chunk(input); + + assert_eq!(events.len(), 1); + assert_eq!(events[0].event.as_deref(), Some("message")); + assert_eq!(events[0].id.as_deref(), Some("123")); + assert_eq!( + events[0].data.as_deref(), + Some(Bytes::from("hello\n").as_ref()) + ); + } + + #[test] + fn test_event_chunks_in_different_orders() { + let mut parser = SseParser::new(); + let input = Bytes::from("data: hello\nevent: message\nid: 123\n\n"); + let events = parser.process_new_chunk(input); + + assert_eq!(events.len(), 1); + assert_eq!(events[0].event.as_deref(), Some("message")); + assert_eq!(events[0].id.as_deref(), Some("123")); + assert_eq!( + events[0].data.as_deref(), + Some(Bytes::from("hello\n").as_ref()) + ); + } + + #[test] + fn test_comment_line_ignored() { + let mut parser = SseParser::new(); + let input = Bytes::from(": this is a comment\n\n"); + let events = parser.process_new_chunk(input); + assert_eq!(events.len(), 0); + } + + #[test] + fn test_event_with_empty_data() { + let mut parser = SseParser::new(); + let input = Bytes::from("data:\n\n"); + let events = parser.process_new_chunk(input); + // Your parser skips data lines with empty content + assert_eq!(events.len(), 0); + } + + #[test] + fn test_partial_chunks() { + let mut parser = SseParser::new(); + + let part1 = Bytes::from("data: hello"); + let part2 = Bytes::from(" world\n\n"); + + let events1 = parser.process_new_chunk(part1); + assert_eq!(events1.len(), 0); // incomplete + + let events2 = parser.process_new_chunk(part2); + assert_eq!(events2.len(), 1); + assert_eq!( + events2[0].data.as_deref(), + Some(Bytes::from("hello world\n").as_ref()) + ); + } + + #[test] + fn test_malformed_lines() { + let mut parser = SseParser::new(); + let input = Bytes::from("something invalid\ndata: ok\n\n"); + + let events = parser.process_new_chunk(input); + + assert_eq!(events.len(), 1); + assert_eq!( + events[0].data.as_deref(), + Some(Bytes::from("ok\n").as_ref()) + ); + } + + #[test] + fn test_multiple_events_in_one_chunk() { + let mut parser = SseParser::new(); + let input = Bytes::from("data: first\n\ndata: second\n\n"); + let events = parser.process_new_chunk(input); + + assert_eq!(events.len(), 2); + assert_eq!( + events[0].data.as_deref(), + Some(Bytes::from("first\n").as_ref()) + ); + assert_eq!( + events[1].data.as_deref(), + Some(Bytes::from("second\n").as_ref()) + ); + } +} diff --git a/crates/rust-mcp-transport/src/utils/streamable_http_stream.rs b/crates/rust-mcp-transport/src/utils/streamable_http_stream.rs new file mode 100644 index 0000000..ae9c69c --- /dev/null +++ b/crates/rust-mcp-transport/src/utils/streamable_http_stream.rs @@ -0,0 +1,374 @@ +use super::CancellationToken; +use crate::error::{TransportError, TransportResult}; +use crate::utils::SseParser; +use crate::utils::{http_get, validate_response_type, ResponseType}; +use crate::{utils::http_post, MCP_SESSION_ID_HEADER}; +use crate::{EventId, MCP_LAST_EVENT_ID_HEADER}; +use bytes::Bytes; +use reqwest::header::{HeaderMap, HeaderValue}; +use reqwest::{Client, Response, StatusCode}; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::{mpsc, RwLock}; +use tokio::time; +use tokio_stream::StreamExt; + +//-----------------------------------------------------------------------------------// +pub(crate) struct StreamableHttpStream { + /// HTTP client for making SSE requests + pub client: Client, + /// URL of the SSE endpoint + pub mcp_url: String, + /// Maximum number of retry attempts for failed connections + pub max_retries: usize, + /// Delay between retry attempts + pub retry_delay: Duration, + /// Sender for transmitting received data to the readable channel + pub read_tx: mpsc::Sender, + /// Session id will be received from the server in the http + pub session_id: Arc>>, +} + +impl StreamableHttpStream { + pub(crate) async fn run( + &mut self, + payload: String, + cancellation_token: &CancellationToken, + custom_headers: &Option, + ) -> TransportResult<()> { + let mut stream_parser = SseParser::new(); + let mut _last_event_id: Option = None; + + let session_id = self.session_id.read().await.clone(); + + // Check for cancellation before attempting connection + if cancellation_token.is_cancelled() { + tracing::info!( + "StreamableHttp cancelled before connection attempt {}", + payload + ); + return Err(TransportError::Cancelled( + crate::utils::CancellationError::ChannelClosed, + )); + } + + //TODO: simplify + let response = match http_post( + &self.client, + &self.mcp_url, + payload.to_string(), + session_id.as_ref(), + custom_headers.as_ref(), + ) + .await + { + Ok(response) => { + // if session_id_clone.read().await.is_none() { + let session_id = response + .headers() + .get(MCP_SESSION_ID_HEADER) + .and_then(|value| value.to_str().ok()) + .map(|s| s.to_string()); + + let mut guard = self.session_id.write().await; + *guard = session_id; + response + } + + Err(error) => { + tracing::error!("Failed to connect to MCP endpoint: {error}"); + return Err(error); + } + }; + + // return if status code != 200 and no result is expected + if response.status() != StatusCode::OK { + return Ok(()); + } + + let response_type = validate_response_type(&response).await?; + + // Handle non-streaming JSON response + if response_type == ResponseType::Json { + return match response.bytes().await { + Ok(bytes) => { + // Send the message + self.read_tx.send(bytes).await.map_err(|_| { + tracing::error!("Readable stream closed, shutting down MCP task"); + TransportError::SendFailure( + "Failed to send message: channel closed or full".to_string(), + ) + })?; + + // Send the newline + self.read_tx + .send(Bytes::from_static(b"\n")) + .await + .map_err(|_| { + tracing::error!( + "Failed to send newline, channel may be closed or full" + ); + TransportError::SendFailure( + "Failed to send newline: channel closed or full".to_string(), + ) + })?; + + Ok(()) + } + Err(error) => Err(error.into()), + }; + } + + // Create a stream from the response bytes + let mut stream = response.bytes_stream(); + + // Inner loop for processing stream chunks + loop { + let next_chunk = tokio::select! { + // Wait for the next stream chunk + chunk = stream.next() => { + match chunk { + Some(chunk) => chunk, + None => { + // stream ended, unline SSE, so no retry attempt here needed to reconnect + return Err(TransportError::Internal("Stream has ended.".to_string())); + } + } + } + // Wait for cancellation + _ = cancellation_token.cancelled() => { + return Err(TransportError::Cancelled( + crate::utils::CancellationError::ChannelClosed, + )); + } + }; + + match next_chunk { + Ok(bytes) => { + let events = stream_parser.process_new_chunk(bytes); + + if !events.is_empty() { + for event in events { + if let Some(bytes) = event.data { + if event.id.is_some() { + _last_event_id = event.id.clone(); + } + + if self.read_tx.send(bytes).await.is_err() { + tracing::error!( + "Readable stream closed, shutting down MCP task" + ); + return Err(TransportError::SendFailure( + "Failed to send message: stream closed".to_string(), + )); + } + } + } + // break after receiving the message(s) + return Ok(()); + } + } + Err(error) => { + tracing::error!("Error reading stream: {error}"); + return Err(error.into()); + } + } + } + } + + pub(crate) async fn make_standalone_stream_connection( + &self, + cancellation_token: &CancellationToken, + custom_headers: &Option, + last_event_id: Option, + ) -> TransportResult { + let mut retry_count = 0; + let session_id = self.session_id.read().await.clone(); + + let headers = if let Some(event_id) = last_event_id.as_ref() { + let mut headers = HeaderMap::new(); + if let Some(custom) = custom_headers { + headers.extend(custom.iter().map(|(k, v)| (k.clone(), v.clone()))); + } + if let Ok(event_id_value) = HeaderValue::from_str(event_id) { + headers.insert(MCP_LAST_EVENT_ID_HEADER, event_id_value); + } + &Some(headers) + } else { + custom_headers + }; + + loop { + // Check for cancellation before attempting connection + if cancellation_token.is_cancelled() { + tracing::info!("Standalone StreamableHttp cancelled."); + return Err(TransportError::Cancelled( + crate::utils::CancellationError::ChannelClosed, + )); + } + + match http_get( + &self.client, + &self.mcp_url, + session_id.as_ref(), + headers.as_ref(), + ) + .await + { + Ok(response) => { + let is_event_stream = validate_response_type(&response) + .await + .is_ok_and(|response_type| response_type == ResponseType::EventStream); + + if !is_event_stream { + let message = + "SSE stream response returned an unexpected Content-Type.".to_string(); + tracing::warn!("{message}"); + return Err(TransportError::FailedToOpenSSEStream(message)); + } + + return Ok(response); + } + + Err(error) => { + match error { + crate::error::TransportError::HttpConnection(_) => { + // A reqwest::Error happened, we do not return ans instead retry the operation + } + crate::error::TransportError::Http(status_code) => match status_code { + StatusCode::NOT_FOUND | StatusCode::METHOD_NOT_ALLOWED => { + return Err(crate::error::TransportError::FailedToOpenSSEStream( + format!("Not supported (code: {status_code})"), + )); + } + other => { + tracing::warn!( + "Failed to open SSE stream: {error} (code: {other})" + ); + } + }, + error => { + return Err(error); // return the error where the retry wont help + } + } + + if retry_count >= self.max_retries { + tracing::warn!("Max retries ({}) reached, giving up", self.max_retries); + return Err(error); + } + retry_count += 1; + time::sleep(self.retry_delay).await; + continue; + } + }; + } + } + + pub(crate) async fn run_standalone( + &mut self, + cancellation_token: &CancellationToken, + custom_headers: &Option, + response: Response, + ) -> TransportResult<()> { + let mut retry_count = 0; + let mut stream_parser = SseParser::new(); + let mut _last_event_id: Option = None; + + let mut response = Some(response); + + // Main loop for reconnection attempts + loop { + // Check for cancellation before attempting connection + if cancellation_token.is_cancelled() { + tracing::debug!("Standalone StreamableHttp cancelled."); + return Err(TransportError::Cancelled( + crate::utils::CancellationError::ChannelClosed, + )); + } + + // use initially passed response, otherwise try to make a new sse connection + let response = match response.take() { + Some(response) => response, + None => { + tracing::debug!( + "Reconnecting to SSE stream... (try {} of {})", + retry_count, + self.max_retries + ); + self.make_standalone_stream_connection( + cancellation_token, + custom_headers, + _last_event_id.clone(), + ) + .await? + } + }; + + // Create a stream from the response bytes + let mut stream = response.bytes_stream(); + + // Inner loop for processing stream chunks + loop { + let next_chunk = tokio::select! { + // Wait for the next stream chunk + chunk = stream.next() => { + match chunk { + Some(chunk) => chunk, + None => { + // stream ended, unline SSE, so no retry attempt here needed to reconnect + return Err(TransportError::Internal("Stream has ended.".to_string())); + } + } + } + // Wait for cancellation + _ = cancellation_token.cancelled() => { + return Err(TransportError::Cancelled( + crate::utils::CancellationError::ChannelClosed, + )); + } + }; + + match next_chunk { + Ok(bytes) => { + let events = stream_parser.process_new_chunk(bytes); + + if !events.is_empty() { + for event in events { + if let Some(bytes) = event.data { + if event.id.is_some() { + _last_event_id = event.id.clone(); + } + + if self.read_tx.send(bytes).await.is_err() { + tracing::error!( + "Readable stream closed, shutting down MCP task" + ); + return Err(TransportError::SendFailure( + "Failed to send message: stream closed".to_string(), + )); + } + } + } + } + retry_count = 0; // Reset retry count on successful chunk + } + Err(error) => { + if retry_count >= self.max_retries { + tracing::error!("Error reading stream: {error}"); + tracing::warn!("Max retries ({}) reached, giving up", self.max_retries); + return Err(error.into()); + } + + tracing::debug!( + "The standalone SSE stream encountered an error: '{}'", + error + ); + retry_count += 1; + time::sleep(self.retry_delay).await; + break; // Break inner loop to reconnect + } + } + } + } + } +} diff --git a/crates/rust-mcp-transport/tests/check_imports.rs b/crates/rust-mcp-transport/tests/check_imports.rs index cda7d0c..207644e 100644 --- a/crates/rust-mcp-transport/tests/check_imports.rs +++ b/crates/rust-mcp-transport/tests/check_imports.rs @@ -37,13 +37,12 @@ mod tests { // Check for `use rust_mcp_schema` if content.contains("use rust_mcp_schema") { errors.push(format!( - "File {} contains `use rust_mcp_schema`. Use `use crate::schema` instead.", - abs_path + "File {abs_path} contains `use rust_mcp_schema`. Use `use crate::schema` instead." )); } } Err(e) => { - errors.push(format!("Failed to read file `{}`: {}", path_str, e)); + errors.push(format!("Failed to read file `{path_str}`: {e}")); } } } diff --git a/development.md b/development.md index e3673cc..e17dd17 100644 --- a/development.md +++ b/development.md @@ -33,14 +33,14 @@ Build and run instructions are available in their respective README.md files. You can run examples by passing the example project name to Cargo using the `-p` argument, like this: ```sh -cargo run -p simple-mcp-client +cargo run -p simple-mcp-client-stdio ``` -You can build the examples in a similar way. The following command builds the project and generates the binary at `target/release/hello-world-mcp-server`: +You can build the examples in a similar way. The following command builds the project and generates the binary at `target/release/hello-world-mcp-server-stdio`: ```sh -cargo build -p hello-world-mcp-server --release +cargo build -p hello-world-mcp-server-stdio --release ``` ## Code Formatting diff --git a/doc/getting-started-mcp-server.md b/doc/getting-started-mcp-server.md index 358b1b4..6fac258 100644 --- a/doc/getting-started-mcp-server.md +++ b/doc/getting-started-mcp-server.md @@ -160,7 +160,7 @@ impl ServerHandler for MyServerHandler { async fn handle_list_tools_request( &self, _request: ListToolsRequest, - _runtime: &dyn McpServer, + _runtime: Arc, ) -> std::result::Result { Ok(ListToolsResult { meta: None, @@ -173,7 +173,7 @@ impl ServerHandler for MyServerHandler { async fn handle_call_tool_request( &self, request: CallToolRequest, - _runtime: &dyn McpServer, + _runtime: Arc, ) -> std::result::Result { // Attempt to convert request parameters into GreetingTools enum let tool_params: GreetingTools = diff --git a/examples/hello-world-mcp-server-core/.gitignore b/examples/hello-world-mcp-server-stdio-core/.gitignore similarity index 100% rename from examples/hello-world-mcp-server-core/.gitignore rename to examples/hello-world-mcp-server-stdio-core/.gitignore diff --git a/examples/hello-world-mcp-server-core/Cargo.toml b/examples/hello-world-mcp-server-stdio-core/Cargo.toml similarity index 83% rename from examples/hello-world-mcp-server-core/Cargo.toml rename to examples/hello-world-mcp-server-stdio-core/Cargo.toml index bbab301..14eb904 100644 --- a/examples/hello-world-mcp-server-core/Cargo.toml +++ b/examples/hello-world-mcp-server-stdio-core/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "hello-world-mcp-server-core" -version = "0.1.22" +name = "hello-world-mcp-server-stdio-core" +version = "0.1.19" edition = "2021" publish = false license = "MIT" @@ -10,6 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "server", "macros", + "stdio", "2025_06_18", ] } diff --git a/examples/hello-world-mcp-server-core/README.md b/examples/hello-world-mcp-server-stdio-core/README.md similarity index 81% rename from examples/hello-world-mcp-server-core/README.md rename to examples/hello-world-mcp-server-stdio-core/README.md index af9d703..cf57884 100644 --- a/examples/hello-world-mcp-server-core/README.md +++ b/examples/hello-world-mcp-server-stdio-core/README.md @@ -23,14 +23,14 @@ cd rust-mcp-sdk 2. Build the project: ```bash -cargo build -p hello-world-mcp-server-core --release +cargo build -p hello-world-mcp-server-stdio-core --release ``` -3. After building the project, the binary will be located at `target/release/hello-world-mcp-server-core` +3. After building the project, the binary will be located at `target/release/hello-world-mcp-server-stdio-core` You can test it with [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector), or alternatively, use it with any MCP client you prefer. ```bash -npx -y @modelcontextprotocol/inspector ./target/release/hello-world-mcp-server-core +npx -y @modelcontextprotocol/inspector ./target/release/hello-world-mcp-server-stdio-core ``` ``` @@ -41,4 +41,4 @@ Starting MCP inspector... Here you can see it in action : -![hello-world-mcp-server-core]![hello-world-mcp-server](../../assets/examples/hello-world-mcp-server.gif) +![hello-world-mcp-server-stdio-core]![hello-world-mcp-server](../../assets/examples/hello-world-mcp-server.gif) diff --git a/examples/hello-world-mcp-server-core/src/handler.rs b/examples/hello-world-mcp-server-stdio-core/src/handler.rs similarity index 97% rename from examples/hello-world-mcp-server-core/src/handler.rs rename to examples/hello-world-mcp-server-stdio-core/src/handler.rs index f0bdefe..acf55ea 100644 --- a/examples/hello-world-mcp-server-core/src/handler.rs +++ b/examples/hello-world-mcp-server-stdio-core/src/handler.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use async_trait::async_trait; use rust_mcp_sdk::schema::{ @@ -22,7 +24,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_request( &self, request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { let method_name = &request.method().to_owned(); match request { @@ -90,7 +92,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_notification( &self, notification: NotificationFromClient, - _: &dyn McpServer, + _: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -99,7 +101,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_error( &self, error: &RpcError, - _: &dyn McpServer, + _: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } diff --git a/examples/hello-world-mcp-server-core/src/main.rs b/examples/hello-world-mcp-server-stdio-core/src/main.rs similarity index 100% rename from examples/hello-world-mcp-server-core/src/main.rs rename to examples/hello-world-mcp-server-stdio-core/src/main.rs diff --git a/examples/hello-world-mcp-server-core/src/tools.rs b/examples/hello-world-mcp-server-stdio-core/src/tools.rs similarity index 100% rename from examples/hello-world-mcp-server-core/src/tools.rs rename to examples/hello-world-mcp-server-stdio-core/src/tools.rs diff --git a/examples/hello-world-mcp-server/Cargo.toml b/examples/hello-world-mcp-server-stdio/Cargo.toml similarity index 85% rename from examples/hello-world-mcp-server/Cargo.toml rename to examples/hello-world-mcp-server-stdio/Cargo.toml index 63a54af..9d15be3 100644 --- a/examples/hello-world-mcp-server/Cargo.toml +++ b/examples/hello-world-mcp-server-stdio/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "hello-world-mcp-server" -version = "0.1.31" +name = "hello-world-mcp-server-stdio" +version = "0.1.28" edition = "2021" publish = false license = "MIT" @@ -10,8 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "server", "macros", - "hyper-server", - "ssl", + "stdio", "2025_06_18", ] } diff --git a/examples/hello-world-mcp-server/README.md b/examples/hello-world-mcp-server-stdio/README.md similarity index 84% rename from examples/hello-world-mcp-server/README.md rename to examples/hello-world-mcp-server-stdio/README.md index 33a62af..9e0bdda 100644 --- a/examples/hello-world-mcp-server/README.md +++ b/examples/hello-world-mcp-server-stdio/README.md @@ -22,14 +22,14 @@ cd rust-mcp-sdk 2. Build the project: ```bash -cargo build -p hello-world-mcp-server --release +cargo build -p hello-world-mcp-server-stdio --release ``` -3. After building the project, the binary will be located at `target/release/hello-world-mcp-server` +3. After building the project, the binary will be located at `target/release/hello-world-mcp-server-stdio` You can test it with [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector), or alternatively, use it with any MCP client you prefer. ```bash -npx -y @modelcontextprotocol/inspector ./target/release/hello-world-mcp-server +npx -y @modelcontextprotocol/inspector ./target/release/hello-world-mcp-server-stdio ``` ``` @@ -40,4 +40,4 @@ Starting MCP inspector... Here you can see it in action : -![hello-world-mcp-server](../../assets/examples/hello-world-mcp-server.gif) +![hello-world-mcp-server-stdio](../../assets/examples/hello-world-mcp-server.gif) diff --git a/examples/hello-world-mcp-server/src/handler.rs b/examples/hello-world-mcp-server-stdio/src/handler.rs similarity index 94% rename from examples/hello-world-mcp-server/src/handler.rs rename to examples/hello-world-mcp-server-stdio/src/handler.rs index d9741a0..47925a0 100644 --- a/examples/hello-world-mcp-server/src/handler.rs +++ b/examples/hello-world-mcp-server-stdio/src/handler.rs @@ -4,6 +4,7 @@ use rust_mcp_sdk::schema::{ ListToolsResult, RpcError, }; use rust_mcp_sdk::{mcp_server::ServerHandler, McpServer}; +use std::sync::Arc; use crate::tools::GreetingTools; @@ -20,7 +21,7 @@ impl ServerHandler for MyServerHandler { async fn handle_list_tools_request( &self, request: ListToolsRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { Ok(ListToolsResult { meta: None, @@ -33,7 +34,7 @@ impl ServerHandler for MyServerHandler { async fn handle_call_tool_request( &self, request: CallToolRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { // Attempt to convert request parameters into GreetingTools enum let tool_params: GreetingTools = diff --git a/examples/hello-world-mcp-server/src/main.rs b/examples/hello-world-mcp-server-stdio/src/main.rs similarity index 92% rename from examples/hello-world-mcp-server/src/main.rs rename to examples/hello-world-mcp-server-stdio/src/main.rs index 00ca6a7..98ff6f0 100644 --- a/examples/hello-world-mcp-server/src/main.rs +++ b/examples/hello-world-mcp-server-stdio/src/main.rs @@ -1,6 +1,8 @@ mod handler; mod tools; +use std::sync::Arc; + use handler::MyServerHandler; use rust_mcp_sdk::schema::{ Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, @@ -40,7 +42,8 @@ async fn main() -> SdkResult<()> { let handler = MyServerHandler {}; // STEP 4: create a MCP server - let server: ServerRuntime = server_runtime::create_server(server_details, transport, handler); + let server: Arc = + server_runtime::create_server(server_details, transport, handler); // STEP 5: Start the server if let Err(start_error) = server.start().await { diff --git a/examples/hello-world-mcp-server/src/tools.rs b/examples/hello-world-mcp-server-stdio/src/tools.rs similarity index 100% rename from examples/hello-world-mcp-server/src/tools.rs rename to examples/hello-world-mcp-server-stdio/src/tools.rs diff --git a/examples/hello-world-server-core-streamable-http/.gitignore b/examples/hello-world-server-streamable-http-core/.gitignore similarity index 100% rename from examples/hello-world-server-core-streamable-http/.gitignore rename to examples/hello-world-server-streamable-http-core/.gitignore diff --git a/examples/hello-world-server-core-streamable-http/Cargo.toml b/examples/hello-world-server-streamable-http-core/Cargo.toml similarity index 84% rename from examples/hello-world-server-core-streamable-http/Cargo.toml rename to examples/hello-world-server-streamable-http-core/Cargo.toml index 99d1011..a762058 100644 --- a/examples/hello-world-server-core-streamable-http/Cargo.toml +++ b/examples/hello-world-server-streamable-http-core/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "hello-world-server-core-streamable-http" -version = "0.1.22" +name = "hello-world-server-streamable-http-core" +version = "0.1.19" edition = "2021" publish = false license = "MIT" @@ -10,6 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "server", "macros", + "streamable-http", "hyper-server", "2025_06_18", ] } diff --git a/examples/hello-world-server-core-streamable-http/README.md b/examples/hello-world-server-streamable-http-core/README.md similarity index 95% rename from examples/hello-world-server-core-streamable-http/README.md rename to examples/hello-world-server-streamable-http-core/README.md index cd37623..49af2c2 100644 --- a/examples/hello-world-server-core-streamable-http/README.md +++ b/examples/hello-world-server-streamable-http-core/README.md @@ -37,7 +37,7 @@ cd rust-mcp-sdk 2. Build and start the server: ```bash -cargo run -p hello-world-server-core-streamable-http --release +cargo run -p hello-world-server-streamable-http-core --release ``` By default, both the Streamable HTTP and SSE endpoints are displayed in the terminal: @@ -65,4 +65,4 @@ Then , to test the server, visit one of the following URLs based on the desired Here you can see it in action : -![hello-world-mcp-server-sse-core](../../assets/examples/hello-world-server-core-streamable-http.gif) +![hello-world-mcp-server-sse-core](../../assets/examples/hello-world-server-streamable-http-core.gif) diff --git a/examples/hello-world-server-core-streamable-http/src/handler.rs b/examples/hello-world-server-streamable-http-core/src/handler.rs similarity index 97% rename from examples/hello-world-server-core-streamable-http/src/handler.rs rename to examples/hello-world-server-streamable-http-core/src/handler.rs index 1c69e8c..7941075 100644 --- a/examples/hello-world-server-core-streamable-http/src/handler.rs +++ b/examples/hello-world-server-streamable-http-core/src/handler.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use async_trait::async_trait; use rust_mcp_sdk::schema::{ @@ -22,7 +24,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_request( &self, request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { let method_name = &request.method().to_owned(); match request { @@ -95,7 +97,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_notification( &self, notification: NotificationFromClient, - _: &dyn McpServer, + _: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -104,7 +106,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_error( &self, error: &RpcError, - _: &dyn McpServer, + _: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } diff --git a/examples/hello-world-server-core-streamable-http/src/main.rs b/examples/hello-world-server-streamable-http-core/src/main.rs similarity index 100% rename from examples/hello-world-server-core-streamable-http/src/main.rs rename to examples/hello-world-server-streamable-http-core/src/main.rs diff --git a/examples/hello-world-server-core-streamable-http/src/tools.rs b/examples/hello-world-server-streamable-http-core/src/tools.rs similarity index 100% rename from examples/hello-world-server-core-streamable-http/src/tools.rs rename to examples/hello-world-server-streamable-http-core/src/tools.rs diff --git a/examples/hello-world-server-streamable-http/Cargo.toml b/examples/hello-world-server-streamable-http/Cargo.toml index df4296d..17a87c8 100644 --- a/examples/hello-world-server-streamable-http/Cargo.toml +++ b/examples/hello-world-server-streamable-http/Cargo.toml @@ -10,6 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "server", "macros", + "streamable-http", "hyper-server", "2025_06_18", ] } diff --git a/examples/hello-world-server-streamable-http/README.md b/examples/hello-world-server-streamable-http/README.md index ac56a86..7e3f3b6 100644 --- a/examples/hello-world-server-streamable-http/README.md +++ b/examples/hello-world-server-streamable-http/README.md @@ -66,4 +66,4 @@ Then , to test the server, visit one of the following URLs based on the desired Here you can see it in action : -![hello-world-mcp-server-sse-core](../../assets/examples/hello-world-server-core-streamable-http.gif) +![hello-world-mcp-server-sse-core](../../assets/examples/hello-world-server-streamable-http-core.gif) diff --git a/examples/hello-world-server-streamable-http/src/handler.rs b/examples/hello-world-server-streamable-http/src/handler.rs index b8ce355..3939d86 100644 --- a/examples/hello-world-server-streamable-http/src/handler.rs +++ b/examples/hello-world-server-streamable-http/src/handler.rs @@ -1,12 +1,11 @@ +use crate::tools::GreetingTools; use async_trait::async_trait; use rust_mcp_sdk::schema::{ schema_utils::CallToolError, CallToolRequest, CallToolResult, ListToolsRequest, ListToolsResult, RpcError, }; use rust_mcp_sdk::{mcp_server::ServerHandler, McpServer}; - -use crate::tools::GreetingTools; - +use std::sync::Arc; // Custom Handler to handle MCP Messages pub struct MyServerHandler; @@ -20,7 +19,7 @@ impl ServerHandler for MyServerHandler { async fn handle_list_tools_request( &self, request: ListToolsRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { Ok(ListToolsResult { meta: None, @@ -33,7 +32,7 @@ impl ServerHandler for MyServerHandler { async fn handle_call_tool_request( &self, request: CallToolRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { // Attempt to convert request parameters into GreetingTools enum let tool_params: GreetingTools = @@ -45,6 +44,4 @@ impl ServerHandler for MyServerHandler { GreetingTools::SayGoodbyeTool(say_goodbye_tool) => say_goodbye_tool.call_tool(), } } - - async fn on_server_started(&self, runtime: &dyn McpServer) {} } diff --git a/examples/simple-mcp-client-core-sse/Cargo.toml b/examples/simple-mcp-client-sse-core/Cargo.toml similarity index 88% rename from examples/simple-mcp-client-core-sse/Cargo.toml rename to examples/simple-mcp-client-sse-core/Cargo.toml index 0e32790..25dcd7d 100644 --- a/examples/simple-mcp-client-core-sse/Cargo.toml +++ b/examples/simple-mcp-client-sse-core/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "simple-mcp-client-core-sse" -version = "0.1.22" +name = "simple-mcp-client-sse-core" +version = "0.1.19" edition = "2021" publish = false license = "MIT" @@ -10,6 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "client", "macros", + "sse", "2025_06_18", ] } diff --git a/examples/simple-mcp-client-core-sse/README.md b/examples/simple-mcp-client-sse-core/README.md similarity index 97% rename from examples/simple-mcp-client-core-sse/README.md rename to examples/simple-mcp-client-sse-core/README.md index e7e10d2..a0852fb 100644 --- a/examples/simple-mcp-client-core-sse/README.md +++ b/examples/simple-mcp-client-sse-core/README.md @@ -32,7 +32,7 @@ npx @modelcontextprotocol/server-everything sse 2. Open a new terminal and run the project with: ```bash -cargo run -p simple-mcp-client-core-sse +cargo run -p simple-mcp-client-sse-core ``` You can observe a sample output of the project; however, your results may vary slightly depending on the version of the MCP Server in use when you run it. diff --git a/examples/simple-mcp-client-core-sse/src/handler.rs b/examples/simple-mcp-client-sse-core/src/handler.rs similarity index 100% rename from examples/simple-mcp-client-core-sse/src/handler.rs rename to examples/simple-mcp-client-sse-core/src/handler.rs diff --git a/examples/simple-mcp-client-core-sse/src/inquiry_utils.rs b/examples/simple-mcp-client-sse-core/src/inquiry_utils.rs similarity index 100% rename from examples/simple-mcp-client-core-sse/src/inquiry_utils.rs rename to examples/simple-mcp-client-sse-core/src/inquiry_utils.rs diff --git a/examples/simple-mcp-client-core-sse/src/main.rs b/examples/simple-mcp-client-sse-core/src/main.rs similarity index 99% rename from examples/simple-mcp-client-core-sse/src/main.rs rename to examples/simple-mcp-client-sse-core/src/main.rs index 459f9ba..be8279b 100644 --- a/examples/simple-mcp-client-core-sse/src/main.rs +++ b/examples/simple-mcp-client-sse-core/src/main.rs @@ -44,6 +44,7 @@ async fn main() -> SdkResult<()> { // STEP 3: instantiate our custom handler that is responsible for handling MCP messages let handler = MyClientHandler {}; + // STEP 4: create the client let client = client_runtime_core::create_client(client_details, transport, handler); // STEP 5: start the MCP client diff --git a/examples/simple-mcp-client-sse/Cargo.toml b/examples/simple-mcp-client-sse/Cargo.toml index 14fd96b..bf7174d 100644 --- a/examples/simple-mcp-client-sse/Cargo.toml +++ b/examples/simple-mcp-client-sse/Cargo.toml @@ -9,6 +9,8 @@ license = "MIT" [dependencies] rust-mcp-sdk = { workspace = true, default-features = false, features = [ "client", + "sse", + "streamable-http", "macros", "2025_06_18", ] } diff --git a/examples/simple-mcp-client-sse/src/main.rs b/examples/simple-mcp-client-sse/src/main.rs index ce8850a..0a76caa 100644 --- a/examples/simple-mcp-client-sse/src/main.rs +++ b/examples/simple-mcp-client-sse/src/main.rs @@ -15,7 +15,9 @@ use std::sync::Arc; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; -const MCP_SERVER_URL: &str = "http://localhost:3001/sse"; +// Connect to a server started with the following command: +// npx @modelcontextprotocol/server-everything sse +const MCP_SERVER_URL: &str = "http://127.0.0.1:3001/sse"; #[tokio::main] async fn main() -> SdkResult<()> { @@ -44,6 +46,7 @@ async fn main() -> SdkResult<()> { // STEP 3: instantiate our custom handler that is responsible for handling MCP messages let handler = MyClientHandler {}; + // STEP 4: create the client let client = client_runtime::create_client(client_details, transport, handler); // STEP 5: start the MCP client @@ -57,6 +60,7 @@ async fn main() -> SdkResult<()> { let utils = InquiryUtils { client: Arc::clone(&client), }; + // Display server information (name and version) utils.print_server_info(); @@ -78,8 +82,11 @@ async fn main() -> SdkResult<()> { // Call add tool, and print the result utils.call_add_tool(100, 25).await?; - // Set the log level - utils.client.set_logging_level(LoggingLevel::Debug).await?; + // // Set the log level + match utils.client.set_logging_level(LoggingLevel::Debug).await { + Ok(_) => println!("Log level is set to \"Debug\""), + Err(err) => eprintln!("Error setting the Log level : {err}"), + } // Send 3 pings to the server, with a 2-second interval between each ping. utils.ping_n_times(3).await; diff --git a/examples/simple-mcp-client-core/Cargo.toml b/examples/simple-mcp-client-stdio-core/Cargo.toml similarity index 86% rename from examples/simple-mcp-client-core/Cargo.toml rename to examples/simple-mcp-client-stdio-core/Cargo.toml index 0dacc2d..6d95cf6 100644 --- a/examples/simple-mcp-client-core/Cargo.toml +++ b/examples/simple-mcp-client-stdio-core/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "simple-mcp-client-core" -version = "0.1.31" +name = "simple-mcp-client-stdio-core" +version = "0.1.28" edition = "2021" publish = false license = "MIT" @@ -10,6 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "client", "macros", + "stdio", "2025_06_18", ] } diff --git a/examples/simple-mcp-client-core/README.md b/examples/simple-mcp-client-stdio-core/README.md similarity index 97% rename from examples/simple-mcp-client-core/README.md rename to examples/simple-mcp-client-stdio-core/README.md index 52d8074..f3258aa 100644 --- a/examples/simple-mcp-client-core/README.md +++ b/examples/simple-mcp-client-stdio-core/README.md @@ -24,7 +24,7 @@ cd rust-mcp-sdk 2. RUn the project: ```bash -cargo run -p simple-mcp-client-core +cargo run -p simple-mcp-client-stdio-core ``` You can observe a sample output of the project; however, your results may vary slightly depending on the version of the MCP Server in use when you run it. diff --git a/examples/simple-mcp-client-core/src/handler.rs b/examples/simple-mcp-client-stdio-core/src/handler.rs similarity index 100% rename from examples/simple-mcp-client-core/src/handler.rs rename to examples/simple-mcp-client-stdio-core/src/handler.rs diff --git a/examples/simple-mcp-client-core/src/inquiry_utils.rs b/examples/simple-mcp-client-stdio-core/src/inquiry_utils.rs similarity index 100% rename from examples/simple-mcp-client-core/src/inquiry_utils.rs rename to examples/simple-mcp-client-stdio-core/src/inquiry_utils.rs diff --git a/examples/simple-mcp-client-core/src/main.rs b/examples/simple-mcp-client-stdio-core/src/main.rs similarity index 100% rename from examples/simple-mcp-client-core/src/main.rs rename to examples/simple-mcp-client-stdio-core/src/main.rs diff --git a/examples/simple-mcp-client/Cargo.toml b/examples/simple-mcp-client-stdio/Cargo.toml similarity index 87% rename from examples/simple-mcp-client/Cargo.toml rename to examples/simple-mcp-client-stdio/Cargo.toml index 9599c46..3597105 100644 --- a/examples/simple-mcp-client/Cargo.toml +++ b/examples/simple-mcp-client-stdio/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "simple-mcp-client" -version = "0.1.31" +name = "simple-mcp-client-stdio" +version = "0.1.28" edition = "2021" publish = false license = "MIT" @@ -10,6 +10,7 @@ license = "MIT" rust-mcp-sdk = { workspace = true, default-features = false, features = [ "client", "macros", + "stdio", "2025_06_18", ] } diff --git a/examples/simple-mcp-client/README.md b/examples/simple-mcp-client-stdio/README.md similarity index 97% rename from examples/simple-mcp-client/README.md rename to examples/simple-mcp-client-stdio/README.md index c56a933..be17f02 100644 --- a/examples/simple-mcp-client/README.md +++ b/examples/simple-mcp-client-stdio/README.md @@ -24,7 +24,7 @@ cd rust-mcp-sdk 2. RUn the project: ```bash -cargo run -p simple-mcp-client +cargo run -p simple-mcp-client-stdio ``` You can observe a sample output of the project; however, your results may vary slightly depending on the version of the MCP Server in use when you run it. diff --git a/examples/simple-mcp-client/src/handler.rs b/examples/simple-mcp-client-stdio/src/handler.rs similarity index 100% rename from examples/simple-mcp-client/src/handler.rs rename to examples/simple-mcp-client-stdio/src/handler.rs diff --git a/examples/simple-mcp-client/src/inquiry_utils.rs b/examples/simple-mcp-client-stdio/src/inquiry_utils.rs similarity index 100% rename from examples/simple-mcp-client/src/inquiry_utils.rs rename to examples/simple-mcp-client-stdio/src/inquiry_utils.rs diff --git a/examples/simple-mcp-client/src/main.rs b/examples/simple-mcp-client-stdio/src/main.rs similarity index 100% rename from examples/simple-mcp-client/src/main.rs rename to examples/simple-mcp-client-stdio/src/main.rs diff --git a/examples/simple-mcp-client-streamable-http-core/Cargo.toml b/examples/simple-mcp-client-streamable-http-core/Cargo.toml new file mode 100644 index 0000000..68356e1 --- /dev/null +++ b/examples/simple-mcp-client-streamable-http-core/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "simple-mcp-client-streamable-http-core" +version = "0.1.0" +edition = "2021" +publish = false +license = "MIT" + + +[dependencies] +rust-mcp-sdk = { workspace = true, default-features = false, features = [ + "client", + "macros", + "streamable-http", + "2025_06_18", +] } + +tokio = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +async-trait = { workspace = true } +futures = { workspace = true } +thiserror = { workspace = true } +colored = "3.0.0" +tracing-subscriber = { workspace = true } +tracing = { workspace = true } + + +[lints] +workspace = true diff --git a/examples/simple-mcp-client-streamable-http-core/README.md b/examples/simple-mcp-client-streamable-http-core/README.md new file mode 100644 index 0000000..a0852fb --- /dev/null +++ b/examples/simple-mcp-client-streamable-http-core/README.md @@ -0,0 +1,40 @@ +# Simple MCP Client Core (SSE) + +This is a simple MCP (Model Context Protocol) client implemented with the rust-mcp-sdk, dmeonstrating SSE transport, showcasing fundamental MCP client operations like fetching the MCP server's capabilities and executing a tool call. + +## Overview + +This project demonstrates a basic MCP client implementation, showcasing the features of the [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk). + +This example connects to a running instance of the [@modelcontextprotocol/server-everything](https://www.npmjs.com/package/@modelcontextprotocol/server-everything) server, which has already been started with the sse flag. + +It displays the server name and version, outlines the server's capabilities, and provides a list of available tools, prompts, templates, resources, and more offered by the server. Additionally, it will execute a tool call by utilizing the add tool from the server-everything package to sum two numbers and output the result. + +> Note that @modelcontextprotocol/server-everything is an npm package, so you must have Node.js and npm installed on your system, as this example attempts to start it. + +## Running the Example + +1. Clone the repository: + +```bash +git clone git@github.com:rust-mcp-stack/rust-mcp-sdk.git +cd rust-mcp-sdk +``` + +2- Start `@modelcontextprotocol/server-everything` with SSE argument: + +```bash +npx @modelcontextprotocol/server-everything sse +``` + +> It launches the server, making everything accessible via the SSE transport at http://localhost:3001/sse. + +2. Open a new terminal and run the project with: + +```bash +cargo run -p simple-mcp-client-sse-core +``` + +You can observe a sample output of the project; however, your results may vary slightly depending on the version of the MCP Server in use when you run it. + + diff --git a/examples/simple-mcp-client-streamable-http-core/src/handler.rs b/examples/simple-mcp-client-streamable-http-core/src/handler.rs new file mode 100644 index 0000000..ab86e9e --- /dev/null +++ b/examples/simple-mcp-client-streamable-http-core/src/handler.rs @@ -0,0 +1,72 @@ +use async_trait::async_trait; +use rust_mcp_sdk::schema::{ + self, + schema_utils::{NotificationFromServer, RequestFromServer, ResultFromClient}, + RpcError, ServerRequest, +}; +use rust_mcp_sdk::{mcp_client::ClientHandlerCore, McpClient}; +pub struct MyClientHandler; + +// To check out a list of all the methods in the trait that you can override, take a look at +// https://github.com/rust-mcp-stack/rust-mcp-sdk/blob/main/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs + +#[async_trait] +impl ClientHandlerCore for MyClientHandler { + async fn handle_request( + &self, + request: RequestFromServer, + _runtime: &dyn McpClient, + ) -> std::result::Result { + match request { + RequestFromServer::ServerRequest(server_request) => match server_request { + ServerRequest::PingRequest(_) => { + return Ok(schema::Result::default().into()); + } + ServerRequest::CreateMessageRequest(_create_message_request) => { + Err(RpcError::internal_error().with_message( + "CreateMessageRequest handler is not implemented".to_string(), + )) + } + ServerRequest::ListRootsRequest(_list_roots_request) => { + Err(RpcError::internal_error() + .with_message("ListRootsRequest handler is not implemented".to_string())) + } + ServerRequest::ElicitRequest(_elicit_request) => Err(RpcError::internal_error() + .with_message("ElicitRequest handler is not implemented".to_string())), + }, + RequestFromServer::CustomRequest(_value) => Err(RpcError::internal_error() + .with_message("CustomRequest handler is not implemented".to_string())), + } + } + + async fn handle_notification( + &self, + notification: NotificationFromServer, + _runtime: &dyn McpClient, + ) -> std::result::Result<(), RpcError> { + if let NotificationFromServer::ServerNotification( + schema::ServerNotification::LoggingMessageNotification(logging_message_notification), + ) = notification + { + println!( + "Notification from server: {}", + logging_message_notification.params.data + ); + } else { + println!( + "A {} notification received from the server", + notification.method() + ); + }; + + Ok(()) + } + + async fn handle_error( + &self, + _error: &RpcError, + _runtime: &dyn McpClient, + ) -> std::result::Result<(), RpcError> { + Err(RpcError::internal_error().with_message("handle_error() Not implemented".to_string())) + } +} diff --git a/examples/simple-mcp-client-streamable-http-core/src/inquiry_utils.rs b/examples/simple-mcp-client-streamable-http-core/src/inquiry_utils.rs new file mode 100644 index 0000000..a8e7c9c --- /dev/null +++ b/examples/simple-mcp-client-streamable-http-core/src/inquiry_utils.rs @@ -0,0 +1,222 @@ +//! This module contains utility functions for querying and displaying server capabilities. + +use colored::Colorize; +use rust_mcp_sdk::schema::CallToolRequestParams; +use rust_mcp_sdk::McpClient; +use rust_mcp_sdk::{error::SdkResult, mcp_client::ClientRuntime}; +use serde_json::json; +use std::io::Write; +use std::sync::Arc; +use std::time::Duration; +use tokio::time::sleep; + +const GREY_COLOR: (u8, u8, u8) = (90, 90, 90); +const HEADER_SIZE: usize = 31; + +pub struct InquiryUtils { + pub client: Arc, +} + +impl InquiryUtils { + fn print_header(&self, title: &str) { + let pad = ((HEADER_SIZE as f32 / 2.0) + (title.len() as f32 / 2.0)).floor() as usize; + println!("\n{}", "=".repeat(HEADER_SIZE).custom_color(GREY_COLOR)); + println!("{:>pad$}", title.custom_color(GREY_COLOR)); + println!("{}", "=".repeat(HEADER_SIZE).custom_color(GREY_COLOR)); + } + + fn print_list(&self, list_items: Vec<(String, String)>) { + list_items.iter().enumerate().for_each(|(index, item)| { + println!("{}. {}: {}", index + 1, item.0.yellow(), item.1.cyan(),); + }); + } + + pub fn print_server_info(&self) { + self.print_header("Server info"); + let server_version = self.client.server_version().unwrap(); + println!("{} {}", "Server name:".bold(), server_version.name.cyan()); + println!( + "{} {}", + "Server version:".bold(), + server_version.version.cyan() + ); + } + + pub fn print_server_capabilities(&self) { + self.print_header("Capabilities"); + let capability_vec = [ + ("tools", self.client.server_has_tools()), + ("prompts", self.client.server_has_prompts()), + ("resources", self.client.server_has_resources()), + ("logging", self.client.server_supports_logging()), + ("experimental", self.client.server_has_experimental()), + ]; + + capability_vec.iter().for_each(|(tool_name, opt)| { + println!( + "{}: {}", + tool_name.bold(), + opt.map(|b| if b { "Yes" } else { "No" }) + .unwrap_or("Unknown") + .cyan() + ); + }); + } + + pub async fn print_tool_list(&self) -> SdkResult<()> { + // Return if the MCP server does not support tools + if !self.client.server_has_tools().unwrap_or(false) { + return Ok(()); + } + + let tools = self.client.list_tools(None).await?; + self.print_header("Tools"); + self.print_list( + tools + .tools + .iter() + .map(|item| { + ( + item.name.clone(), + item.description.clone().unwrap_or_default(), + ) + }) + .collect(), + ); + + Ok(()) + } + + pub async fn print_prompts_list(&self) -> SdkResult<()> { + // Return if the MCP server does not support prompts + if !self.client.server_has_prompts().unwrap_or(false) { + return Ok(()); + } + + let prompts = self.client.list_prompts(None).await?; + + self.print_header("Prompts"); + self.print_list( + prompts + .prompts + .iter() + .map(|item| { + ( + item.name.clone(), + item.description.clone().unwrap_or_default(), + ) + }) + .collect(), + ); + Ok(()) + } + + pub async fn print_resource_list(&self) -> SdkResult<()> { + // Return if the MCP server does not support resources + if !self.client.server_has_resources().unwrap_or(false) { + return Ok(()); + } + + let resources = self.client.list_resources(None).await?; + + self.print_header("Resources"); + + self.print_list( + resources + .resources + .iter() + .map(|item| { + ( + item.name.clone(), + format!( + "( uri: {} , mime: {}", + item.uri, + item.mime_type.as_ref().unwrap_or(&"?".to_string()), + ), + ) + }) + .collect(), + ); + + Ok(()) + } + + pub async fn print_resource_templates(&self) -> SdkResult<()> { + // Return if the MCP server does not support resources + if !self.client.server_has_resources().unwrap_or(false) { + return Ok(()); + } + + let templates = self.client.list_resource_templates(None).await?; + + self.print_header("Resource Templates"); + + self.print_list( + templates + .resource_templates + .iter() + .map(|item| { + ( + item.name.clone(), + item.description.clone().unwrap_or_default(), + ) + }) + .collect(), + ); + Ok(()) + } + + pub async fn call_add_tool(&self, a: i64, b: i64) -> SdkResult<()> { + // Invoke the "add" tool with 100 and 25 as arguments, and display the result + println!( + "{}", + format!("\nCalling the \"add\" tool with {a} and {b} ...").magenta() + ); + + // Create a `Map` to represent the tool parameters + let params = json!({ + "a": a, + "b": b + }) + .as_object() + .unwrap() + .clone(); + + // invoke the tool + let result = self + .client + .call_tool(CallToolRequestParams { + name: "add".to_string(), + arguments: Some(params), + }) + .await?; + + // Retrieve the result content and print it to the stdout + let result_content = result.content.first().unwrap().as_text_content()?; + println!("{}", result_content.text.green()); + + Ok(()) + } + + pub async fn ping_n_times(&self, n: i32) { + let max_pings = n; + println!(); + for ping_index in 1..=max_pings { + print!("Ping the server ({ping_index} out of {max_pings})..."); + std::io::stdout().flush().unwrap(); + let ping_result = self.client.ping(None).await; + print!( + "\rPing the server ({} out of {}) : {}", + ping_index, + max_pings, + if ping_result.is_ok() { + "success".bright_green() + } else { + "failed".bright_red() + } + ); + println!(); + sleep(Duration::from_secs(2)).await; + } + } +} diff --git a/examples/simple-mcp-client-streamable-http-core/src/main.rs b/examples/simple-mcp-client-streamable-http-core/src/main.rs new file mode 100644 index 0000000..e1a5849 --- /dev/null +++ b/examples/simple-mcp-client-streamable-http-core/src/main.rs @@ -0,0 +1,95 @@ +mod handler; +mod inquiry_utils; + +use handler::MyClientHandler; + +use inquiry_utils::InquiryUtils; +use rust_mcp_sdk::error::SdkResult; +use rust_mcp_sdk::mcp_client::client_runtime_core; +use rust_mcp_sdk::schema::{ + ClientCapabilities, Implementation, InitializeRequestParams, LoggingLevel, + LATEST_PROTOCOL_VERSION, +}; +use rust_mcp_sdk::{McpClient, RequestOptions, StreamableTransportOptions}; +use std::sync::Arc; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; + +// Assuming @modelcontextprotocol/server-everything is launched with streamableHttp argument and listening on port 3001 +const MCP_SERVER_URL: &str = "http://127.0.0.1:3001/mcp"; + +#[tokio::main] +async fn main() -> SdkResult<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + // Step1 : Define client details and capabilities + let client_details: InitializeRequestParams = InitializeRequestParams { + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "simple-rust-mcp-client-core-sse".to_string(), + version: "0.1.0".to_string(), + title: Some("Simple Rust MCP Client (Core,SSE)".to_string()), + }, + protocol_version: LATEST_PROTOCOL_VERSION.into(), + }; + + // Step 2: Create transport options to connect to an MCP server via Streamable HTTP. + let transport_options = StreamableTransportOptions { + mcp_url: MCP_SERVER_URL.to_string(), + request_options: RequestOptions { + ..RequestOptions::default() + }, + }; + // STEP 3: instantiate our custom handler that is responsible for handling MCP messages + let handler = MyClientHandler {}; + + // STEP 4: create the client + let client = + client_runtime_core::with_transport_options(client_details, transport_options, handler); + + // STEP 5: start the MCP client + client.clone().start().await?; + + // You can utilize the client and its methods to interact with the MCP Server. + // The following demonstrates how to use client methods to retrieve server information, + // and print them in the terminal, set the log level, invoke a tool, and more. + + // Create a struct with utility functions for demonstration purpose, to utilize different client methods and display the information. + let utils = InquiryUtils { + client: Arc::clone(&client), + }; + // Display server information (name and version) + utils.print_server_info(); + + // Display server capabilities + utils.print_server_capabilities(); + + // Display the list of tools available on the server + utils.print_tool_list().await?; + + // Display the list of prompts available on the server + utils.print_prompts_list().await?; + + // Display the list of resources available on the server + utils.print_resource_list().await?; + + // Display the list of resource templates available on the server + utils.print_resource_templates().await?; + + // Call add tool, and print the result + utils.call_add_tool(100, 25).await?; + + // Set the log level + utils.client.set_logging_level(LoggingLevel::Debug).await?; + + // Send 3 pings to the server, with a 2-second interval between each ping. + utils.ping_n_times(3).await; + client.shut_down().await?; + + Ok(()) +} diff --git a/examples/simple-mcp-client-streamable-http/Cargo.toml b/examples/simple-mcp-client-streamable-http/Cargo.toml new file mode 100644 index 0000000..0638aab --- /dev/null +++ b/examples/simple-mcp-client-streamable-http/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "simple-mcp-client-streamable-http" +version = "0.1.0" +edition = "2021" +publish = false +license = "MIT" + + +[dependencies] +rust-mcp-sdk = { workspace = true, default-features = false, features = [ + "client", + "streamable-http", + "macros", + "2025_06_18", +] } + +tokio = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +async-trait = { workspace = true } +futures = { workspace = true } +thiserror = { workspace = true } +colored = "3.0.0" +tracing-subscriber = { workspace = true } +tracing = { workspace = true } + + +[lints] +workspace = true diff --git a/examples/simple-mcp-client-streamable-http/README.md b/examples/simple-mcp-client-streamable-http/README.md new file mode 100644 index 0000000..5b4488e --- /dev/null +++ b/examples/simple-mcp-client-streamable-http/README.md @@ -0,0 +1,40 @@ +# Simple MCP Client (SSE) + +This is a simple MCP (Model Context Protocol) client implemented with the rust-mcp-sdk, dmeonstrating SSE transport, showcasing fundamental MCP client operations like fetching the MCP server's capabilities and executing a tool call. + +## Overview + +This project demonstrates a basic MCP client implementation, showcasing the features of the [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk). + +This example connects to a running instance of the [@modelcontextprotocol/server-everything](https://www.npmjs.com/package/@modelcontextprotocol/server-everything) server, which has already been started with the sse flag. + +It displays the server name and version, outlines the server's capabilities, and provides a list of available tools, prompts, templates, resources, and more offered by the server. Additionally, it will execute a tool call by utilizing the add tool from the server-everything package to sum two numbers and output the result. + +> Note that @modelcontextprotocol/server-everything is an npm package, so you must have Node.js and npm installed on your system, as this example attempts to start it. + +## Running the Example + +1. Clone the repository: + +```bash +git clone git@github.com:rust-mcp-stack/rust-mcp-sdk.git +cd rust-mcp-sdk +``` + +2- Start `@modelcontextprotocol/server-everything` with SSE argument: + +```bash +npx @modelcontextprotocol/server-everything sse +``` + +> It launches the server, making everything accessible via the SSE transport at http://localhost:3001/sse. + +2. Open a new terminal and run the project with: + +```bash +cargo run -p simple-mcp-client-sse +``` + +You can observe a sample output of the project; however, your results may vary slightly depending on the version of the MCP Server in use when you run it. + + diff --git a/examples/simple-mcp-client-streamable-http/src/handler.rs b/examples/simple-mcp-client-streamable-http/src/handler.rs new file mode 100644 index 0000000..19360f6 --- /dev/null +++ b/examples/simple-mcp-client-streamable-http/src/handler.rs @@ -0,0 +1,10 @@ +use async_trait::async_trait; +use rust_mcp_sdk::mcp_client::ClientHandler; + +pub struct MyClientHandler; + +#[async_trait] +impl ClientHandler for MyClientHandler { + // To check out a list of all the methods in the trait that you can override, take a look at + // https://github.com/rust-mcp-stack/rust-mcp-sdk/blob/main/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs +} diff --git a/examples/simple-mcp-client-streamable-http/src/inquiry_utils.rs b/examples/simple-mcp-client-streamable-http/src/inquiry_utils.rs new file mode 100644 index 0000000..a8e7c9c --- /dev/null +++ b/examples/simple-mcp-client-streamable-http/src/inquiry_utils.rs @@ -0,0 +1,222 @@ +//! This module contains utility functions for querying and displaying server capabilities. + +use colored::Colorize; +use rust_mcp_sdk::schema::CallToolRequestParams; +use rust_mcp_sdk::McpClient; +use rust_mcp_sdk::{error::SdkResult, mcp_client::ClientRuntime}; +use serde_json::json; +use std::io::Write; +use std::sync::Arc; +use std::time::Duration; +use tokio::time::sleep; + +const GREY_COLOR: (u8, u8, u8) = (90, 90, 90); +const HEADER_SIZE: usize = 31; + +pub struct InquiryUtils { + pub client: Arc, +} + +impl InquiryUtils { + fn print_header(&self, title: &str) { + let pad = ((HEADER_SIZE as f32 / 2.0) + (title.len() as f32 / 2.0)).floor() as usize; + println!("\n{}", "=".repeat(HEADER_SIZE).custom_color(GREY_COLOR)); + println!("{:>pad$}", title.custom_color(GREY_COLOR)); + println!("{}", "=".repeat(HEADER_SIZE).custom_color(GREY_COLOR)); + } + + fn print_list(&self, list_items: Vec<(String, String)>) { + list_items.iter().enumerate().for_each(|(index, item)| { + println!("{}. {}: {}", index + 1, item.0.yellow(), item.1.cyan(),); + }); + } + + pub fn print_server_info(&self) { + self.print_header("Server info"); + let server_version = self.client.server_version().unwrap(); + println!("{} {}", "Server name:".bold(), server_version.name.cyan()); + println!( + "{} {}", + "Server version:".bold(), + server_version.version.cyan() + ); + } + + pub fn print_server_capabilities(&self) { + self.print_header("Capabilities"); + let capability_vec = [ + ("tools", self.client.server_has_tools()), + ("prompts", self.client.server_has_prompts()), + ("resources", self.client.server_has_resources()), + ("logging", self.client.server_supports_logging()), + ("experimental", self.client.server_has_experimental()), + ]; + + capability_vec.iter().for_each(|(tool_name, opt)| { + println!( + "{}: {}", + tool_name.bold(), + opt.map(|b| if b { "Yes" } else { "No" }) + .unwrap_or("Unknown") + .cyan() + ); + }); + } + + pub async fn print_tool_list(&self) -> SdkResult<()> { + // Return if the MCP server does not support tools + if !self.client.server_has_tools().unwrap_or(false) { + return Ok(()); + } + + let tools = self.client.list_tools(None).await?; + self.print_header("Tools"); + self.print_list( + tools + .tools + .iter() + .map(|item| { + ( + item.name.clone(), + item.description.clone().unwrap_or_default(), + ) + }) + .collect(), + ); + + Ok(()) + } + + pub async fn print_prompts_list(&self) -> SdkResult<()> { + // Return if the MCP server does not support prompts + if !self.client.server_has_prompts().unwrap_or(false) { + return Ok(()); + } + + let prompts = self.client.list_prompts(None).await?; + + self.print_header("Prompts"); + self.print_list( + prompts + .prompts + .iter() + .map(|item| { + ( + item.name.clone(), + item.description.clone().unwrap_or_default(), + ) + }) + .collect(), + ); + Ok(()) + } + + pub async fn print_resource_list(&self) -> SdkResult<()> { + // Return if the MCP server does not support resources + if !self.client.server_has_resources().unwrap_or(false) { + return Ok(()); + } + + let resources = self.client.list_resources(None).await?; + + self.print_header("Resources"); + + self.print_list( + resources + .resources + .iter() + .map(|item| { + ( + item.name.clone(), + format!( + "( uri: {} , mime: {}", + item.uri, + item.mime_type.as_ref().unwrap_or(&"?".to_string()), + ), + ) + }) + .collect(), + ); + + Ok(()) + } + + pub async fn print_resource_templates(&self) -> SdkResult<()> { + // Return if the MCP server does not support resources + if !self.client.server_has_resources().unwrap_or(false) { + return Ok(()); + } + + let templates = self.client.list_resource_templates(None).await?; + + self.print_header("Resource Templates"); + + self.print_list( + templates + .resource_templates + .iter() + .map(|item| { + ( + item.name.clone(), + item.description.clone().unwrap_or_default(), + ) + }) + .collect(), + ); + Ok(()) + } + + pub async fn call_add_tool(&self, a: i64, b: i64) -> SdkResult<()> { + // Invoke the "add" tool with 100 and 25 as arguments, and display the result + println!( + "{}", + format!("\nCalling the \"add\" tool with {a} and {b} ...").magenta() + ); + + // Create a `Map` to represent the tool parameters + let params = json!({ + "a": a, + "b": b + }) + .as_object() + .unwrap() + .clone(); + + // invoke the tool + let result = self + .client + .call_tool(CallToolRequestParams { + name: "add".to_string(), + arguments: Some(params), + }) + .await?; + + // Retrieve the result content and print it to the stdout + let result_content = result.content.first().unwrap().as_text_content()?; + println!("{}", result_content.text.green()); + + Ok(()) + } + + pub async fn ping_n_times(&self, n: i32) { + let max_pings = n; + println!(); + for ping_index in 1..=max_pings { + print!("Ping the server ({ping_index} out of {max_pings})..."); + std::io::stdout().flush().unwrap(); + let ping_result = self.client.ping(None).await; + print!( + "\rPing the server ({} out of {}) : {}", + ping_index, + max_pings, + if ping_result.is_ok() { + "success".bright_green() + } else { + "failed".bright_red() + } + ); + println!(); + sleep(Duration::from_secs(2)).await; + } + } +} diff --git a/examples/simple-mcp-client-streamable-http/src/main.rs b/examples/simple-mcp-client-streamable-http/src/main.rs new file mode 100644 index 0000000..ab580db --- /dev/null +++ b/examples/simple-mcp-client-streamable-http/src/main.rs @@ -0,0 +1,99 @@ +mod handler; +mod inquiry_utils; + +use handler::MyClientHandler; + +use rust_mcp_sdk::error::SdkResult; +use rust_mcp_sdk::mcp_client::client_runtime; +use rust_mcp_sdk::schema::{ + ClientCapabilities, Implementation, InitializeRequestParams, LoggingLevel, + LATEST_PROTOCOL_VERSION, +}; +use rust_mcp_sdk::{McpClient, RequestOptions, StreamableTransportOptions}; +use std::sync::Arc; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; + +use crate::inquiry_utils::InquiryUtils; + +const MCP_SERVER_URL: &str = "http://127.0.0.1:8080/mcp"; + +#[tokio::main] +async fn main() -> SdkResult<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + // Step1 : Define client details and capabilities + let client_details: InitializeRequestParams = InitializeRequestParams { + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "simple-rust-mcp-client-sse".to_string(), + version: "0.1.0".to_string(), + title: Some("Simple Rust MCP Client (SSE)".to_string()), + }, + protocol_version: LATEST_PROTOCOL_VERSION.into(), + }; + + // Step 2: Create transport options to connect to an MCP server via Streamable HTTP. + let transport_options = StreamableTransportOptions { + mcp_url: MCP_SERVER_URL.to_string(), + request_options: RequestOptions { + ..RequestOptions::default() + }, + }; + + // STEP 3: instantiate our custom handler that is responsible for handling MCP messages + let handler = MyClientHandler {}; + + // STEP 4: create the client with transport options and the handler + let client = client_runtime::with_transport_options(client_details, transport_options, handler); + + // STEP 5: start the MCP client + client.clone().start().await?; + + // You can utilize the client and its methods to interact with the MCP Server. + // The following demonstrates how to use client methods to retrieve server information, + // and print them in the terminal, set the log level, invoke a tool, and more. + + // Create a struct with utility functions for demonstration purpose, to utilize different client methods and display the information. + let utils = InquiryUtils { + client: Arc::clone(&client), + }; + + // Display server information (name and version) + utils.print_server_info(); + + // Display server capabilities + utils.print_server_capabilities(); + + // Display the list of tools available on the server + utils.print_tool_list().await?; + + // Display the list of prompts available on the server + utils.print_prompts_list().await?; + + // Display the list of resources available on the server + utils.print_resource_list().await?; + + // Display the list of resource templates available on the server + utils.print_resource_templates().await?; + + // Call add tool, and print the result + utils.call_add_tool(100, 25).await?; + + // Set the log level + match utils.client.set_logging_level(LoggingLevel::Debug).await { + Ok(_) => println!("Log level is set to \"Debug\""), + Err(err) => eprintln!("Error setting the Log level : {err}"), + } + + // Send 3 pings to the server, with a 2-second interval between each ping. + utils.ping_n_times(3).await; + client.shut_down().await?; + + Ok(()) +} From 39be611055bbe1dd95ecb2b25eb3b4878dab1cb4 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sat, 13 Sep 2025 12:28:41 -0300 Subject: [PATCH 02/17] chore: typos --- .../src/mcp_runtimes/client_runtime.rs | 6 +++--- crates/rust-mcp-sdk/tests/common/test_client.rs | 2 +- .../tests/test_streamable_http_client.rs | 14 +++++++------- crates/rust-mcp-transport/src/utils/sse_parser.rs | 2 +- .../src/utils/streamable_http_stream.rs | 4 ++-- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs index 9961b84..2093dc3 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs @@ -454,9 +454,9 @@ impl ClientRuntime { let result = transport.send_message(messages, timeout).await?; if no_session_id { - if let Some(resquest_id) = transport.session_id().await.clone() { + if let Some(request_id) = transport.session_id().await.clone() { let mut guard = self.session_id.write().await; - *guard = Some(resquest_id) + *guard = Some(request_id) } } @@ -515,7 +515,7 @@ impl ClientRuntime { // Run both tasks with cancellation logic let (send_res, _) = tokio::select! { res = &mut send_task => { - // cancel the receive_task task, to cover the case where sned_task returns with error + // cancel the receive_task task, to cover the case where send_task returns with error abort_recv_handle.abort(); (res, receive_task.await) // Wait for receive_task to finish (it should exit due to cancellation) } diff --git a/crates/rust-mcp-sdk/tests/common/test_client.rs b/crates/rust-mcp-sdk/tests/common/test_client.rs index 21678c7..46a8525 100644 --- a/crates/rust-mcp-sdk/tests/common/test_client.rs +++ b/crates/rust-mcp-sdk/tests/common/test_client.rs @@ -89,7 +89,7 @@ pub mod test_client_common { ) -> InitializedClient { let mock_server = MockServer::start().await; - // intialize response + // initialize response let mut response = create_sse_response(INITIALIZE_RESPONSE); if let Some(session_id) = session_id { diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs b/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs index a0a2804..cb82ff5 100644 --- a/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs +++ b/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs @@ -39,7 +39,7 @@ async fn should_send_json_rpc_messages_via_post() { // Start a mock server let mock_server = MockServer::start().await; - // intialize response + // initialize response let response = create_sse_response(INITIALIZE_RESPONSE); // initialize request and response @@ -137,7 +137,7 @@ async fn should_store_session_id_received_during_initialization() { // Start a mock server let mock_server = MockServer::start().await; - // intialize response + // initialize response let response = create_sse_response(INITIALIZE_RESPONSE).append_header("mcp-session-id", "test-session-id"); @@ -283,7 +283,7 @@ async fn should_handle_successful_initial_get_connection_for_sse() { // Start a mock server let mock_server = MockServer::start().await; - // intialize response + // initialize response let response = create_sse_response(INITIALIZE_RESPONSE); // initialize request and response @@ -394,7 +394,7 @@ async fn should_attempt_initial_get_connection_and_handle_405_gracefully() { // Start a mock server let mock_server = MockServer::start().await; - // intialize response + // initialize response let response = create_sse_response(INITIALIZE_RESPONSE); // initialize request and response @@ -445,7 +445,7 @@ async fn should_attempt_initial_get_connection_and_handle_405_gracefully() { assert!(get_request.is_some()); - // send a batch message, runtime should work as expected with no isse + // send a batch message, runtime should work as expected with no issue let response = create_sse_response( r#"[{"id":"id1","jsonrpc":"2.0", "result":{}},{"id":"id2","jsonrpc":"2.0", "result":{}}]"#, @@ -616,7 +616,7 @@ async fn should_reconnect_a_get_initiated_notification_stream_that_fails() { // Start a mock server let mock_server = MockServer::start().await; - // intialize response + // initialize response let response = create_sse_response(INITIALIZE_RESPONSE); // initialize request and response @@ -726,7 +726,7 @@ async fn should_pass_last_event_id_when_reconnecting() { assert!(get_requests.len() > 1); let Some(last_get_request) = get_requests.last() else { - panic!("Unable to find last GET reuest!"); + panic!("Unable to find last GET request!"); }; let last_event_id = last_get_request diff --git a/crates/rust-mcp-transport/src/utils/sse_parser.rs b/crates/rust-mcp-transport/src/utils/sse_parser.rs index 064d3c3..5933726 100644 --- a/crates/rust-mcp-transport/src/utils/sse_parser.rs +++ b/crates/rust-mcp-transport/src/utils/sse_parser.rs @@ -62,7 +62,7 @@ impl fmt::Debug for SseEvent { } /// A parser for Server-Sent Events (SSE) that processes incoming byte chunks into `SseEvent`s. -/// This Parser is specificly designed for MCP messages and with no multi-line data support +/// This Parser is specifically designed for MCP messages and with no multi-line data support /// /// This struct maintains a buffer to accumulate incoming data and parses it into SSE events /// based on the SSE protocol. It handles fields like `event`, `data`, and `id` as defined diff --git a/crates/rust-mcp-transport/src/utils/streamable_http_stream.rs b/crates/rust-mcp-transport/src/utils/streamable_http_stream.rs index ae9c69c..3362c71 100644 --- a/crates/rust-mcp-transport/src/utils/streamable_http_stream.rs +++ b/crates/rust-mcp-transport/src/utils/streamable_http_stream.rs @@ -130,7 +130,7 @@ impl StreamableHttpStream { match chunk { Some(chunk) => chunk, None => { - // stream ended, unline SSE, so no retry attempt here needed to reconnect + // stream ended, unlike SSE, so no retry attempt here needed to reconnect return Err(TransportError::Internal("Stream has ended.".to_string())); } } @@ -315,7 +315,7 @@ impl StreamableHttpStream { match chunk { Some(chunk) => chunk, None => { - // stream ended, unline SSE, so no retry attempt here needed to reconnect + // stream ended, unlike SSE, so no retry attempt here needed to reconnect return Err(TransportError::Internal("Stream has ended.".to_string())); } } From 6204d79232ad876a9050e42028b2b5b38ded7279 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sat, 13 Sep 2025 14:35:32 -0300 Subject: [PATCH 03/17] chore: update readme --- README.md | 145 +++++++++++++++++++++++++--------- crates/rust-mcp-sdk/README.md | 142 +++++++++++++++++++++++++-------- 2 files changed, 215 insertions(+), 72 deletions(-) diff --git a/README.md b/README.md index b1af670..ced3672 100644 --- a/README.md +++ b/README.md @@ -32,27 +32,15 @@ This project supports following transports: 🚀 The **rust-mcp-sdk** includes a lightweight [Axum](https://github.com/tokio-rs/axum) based server that handles all core functionality seamlessly. Switching between `stdio` and `Streamable HTTP` is straightforward, requiring minimal code changes. The server is designed to efficiently handle multiple concurrent client connections and offers built-in support for SSL. - **MCP Streamable HTTP Support** - ✅ Streamable HTTP Support for MCP Servers - ✅ DNS Rebinding Protection - ✅ Batch Messages - ✅ Streaming & non-streaming JSON response -- ⬜ Streamable HTTP Support for MCP Clients +- ✅ Streamable HTTP Support for MCP Clients - ⬜ Resumability - ⬜ Authentication / Oauth - - -**MCP Streamable HTTP Support** -- [x] Streamable HTTP Support for MCP Servers -- [x] DNS Rebinding Protection -- [x] Batch Messages -- [x] Streaming & non-streaming JSON response -- [ ] Streamable HTTP Support for MCP Clients -- [ ] Resumability -- [ ] Authentication / Oauth - **⚠️** Project is currently under development and should be used at your own risk. ## Table of Contents @@ -60,6 +48,7 @@ This project supports following transports: - [MCP Server (stdio)](#mcp-server-stdio) - [MCP Server (Streamable HTTP)](#mcp-server-streamable-http) - [MCP Client (stdio)](#mcp-client-stdio) + - [MCP Client (Streamable HTTP)](#mcp-client_streamable-http)) - [MCP Client (sse)](#mcp-client-sse) - [Getting Started](#getting-started) - [HyperServerOptions](#hyperserveroptions) @@ -202,7 +191,7 @@ impl ServerHandler for MyServerHandler { } /// Handles requests to call a specific tool. - async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: Arc, ) -> Result { + async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: Arc ) -> Result { if request.tool_name() == SayHelloTool::tool_name() { Ok( CallToolResult::text_content( vec![TextContent::from("Hello World!".to_string())] )) @@ -294,6 +283,8 @@ async fn main() -> SdkResult<()> { println!("{}",result.content.first().unwrap().as_text_content()?.text); + client.shut_down().await?; + Ok(()) } @@ -305,8 +296,82 @@ Here is the output : > your results may vary slightly depending on the version of the MCP Server in use when you run it. +### MCP Client (Streamable HTTP) +```rs + +// STEP 1: Custom Handler to handle incoming MCP Messages +pub struct MyClientHandler; + +#[async_trait] +impl ClientHandler for MyClientHandler { + // To check out a list of all the methods in the trait that you can override, take a look at https://github.com/rust-mcp-stack/rust-mcp-sdk/blob/main/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs +} + +#[tokio::main] +async fn main() -> SdkResult<()> { + + // Step2 : Define client details and capabilities + let client_details: InitializeRequestParams = InitializeRequestParams { + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "simple-rust-mcp-client-sse".to_string(), + version: "0.1.0".to_string(), + title: Some("Simple Rust MCP Client (SSE)".to_string()), + }, + protocol_version: LATEST_PROTOCOL_VERSION.into(), + }; + + // Step 3: Create transport options to connect to an MCP server via Streamable HTTP. + let transport_options = StreamableTransportOptions { + mcp_url: MCP_SERVER_URL.to_string(), + request_options: RequestOptions { + ..RequestOptions::default() + }, + }; + + // STEP 4: instantiate the custom handler that is responsible for handling MCP messages + let handler = MyClientHandler {}; + + // STEP 5: create the client with transport options and the handler + let client = client_runtime::with_transport_options(client_details, transport_options, handler); + + // STEP 6: start the MCP client + client.clone().start().await?; + + // STEP 7: use client methods to communicate with the MCP Server as you wish + + // Retrieve and display the list of tools available on the server + let server_version = client.server_version().unwrap(); + let tools = client.list_tools(None).await?.tools; + println!("List of tools for {}@{}", server_version.name, server_version.version); + + tools.iter().enumerate().for_each(|(tool_index, tool)| { + println!(" {}. {} : {}", + tool_index + 1, + tool.name, + tool.description.clone().unwrap_or_default() + ); + }); + + println!("Call \"add\" tool with 100 and 28 ..."); + // Create a `Map` to represent the tool parameters + let params = json!({"a": 100,"b": 28}).as_object().unwrap().clone(); + let request = CallToolRequestParams { name: "add".to_string(),arguments: Some(params)}; + + // invoke the tool + let result = client.call_tool(request).await?; + + println!("{}",result.content.first().unwrap().as_text_content()?.text); + + client.shut_down().await?; + + Ok(()) +``` +👉 see [examples/simple-mcp-client-streamable-http](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-streamable-http) for a complete working example. + + ### MCP Client (sse) -Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost identical, with one exception at `step 3`. Instead of creating a `StdioTransport`, you simply create a `ClientSseTransport`. The rest of the code remains the same: +Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost identical to the [stdio example](#mcp-client-stdio) , with one exception at `step 3`. Instead of creating a `StdioTransport`, you simply create a `ClientSseTransport`. The rest of the code remains the same: ```diff - let transport = StdioTransport::create_with_server_launch( @@ -317,6 +382,8 @@ Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost + let transport = ClientSseTransport::new(MCP_SERVER_URL, ClientSseTransportOptions::default())?; ``` +👉 see [examples/simple-mcp-client-sse](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-sse) for a complete working example. + ## Getting Started @@ -355,9 +422,15 @@ pub struct HyperServerOptions { /// Hostname or IP address the server will bind to (default: "8080") pub port: u16, + /// Optional thread-safe session id generator to generate unique session IDs. + pub session_id_generator: Option>>, + /// Optional custom path for the Streamable HTTP endpoint (default: `/mcp`) pub custom_streamable_http_endpoint: Option, + /// Shared transport configuration used by the server + pub transport_options: Arc, + /// This setting only applies to streamable HTTP. /// If true, the server will return JSON responses instead of starting an SSE stream. /// This can be useful for simple request/response scenarios without streaming. @@ -367,12 +440,6 @@ pub struct HyperServerOptions { /// Interval between automatic ping messages sent to clients to detect disconnects pub ping_interval: Duration, - /// Shared transport configuration used by the server - pub transport_options: Arc, - - /// Optional thread-safe session id generator to generate unique session IDs. - pub session_id_generator: Option>, - /// Enables SSL/TLS if set to `true` pub enable_ssl: bool, @@ -384,17 +451,6 @@ pub struct HyperServerOptions { /// Required if `enable_ssl` is `true`. pub ssl_key_path: Option, - /// If set to true, the SSE transport will also be supported for backward compatibility (default: true) - pub sse_support: bool, - - /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`) - /// Applicable only if sse_support is true - pub custom_sse_endpoint: Option, - - /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`) - /// Applicable only if sse_support is true - pub custom_messages_endpoint: Option, - /// List of allowed host header values for DNS rebinding protection. /// If not specified, host validation is disabled. pub allowed_hosts: Option>, @@ -406,6 +462,17 @@ pub struct HyperServerOptions { /// Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). /// Default is false for backwards compatibility. pub dns_rebinding_protection: bool, + + /// If set to true, the SSE transport will also be supported for backward compatibility (default: true) + pub sse_support: bool, + + /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`) + /// Applicable only if sse_support is true + pub custom_sse_endpoint: Option, + + /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`) + /// Applicable only if sse_support is true + pub custom_messages_endpoint: Option, } ``` @@ -427,9 +494,13 @@ The `rust-mcp-sdk` crate provides several features that can be enabled or disabl - `server`: Activates MCP server capabilities in `rust-mcp-sdk`, providing modules and APIs for building and managing MCP servers. - `client`: Activates MCP client capabilities, offering modules and APIs for client development and communicating with MCP servers. -- `hyper-server`: This feature enables the **sse** transport for MCP servers, supporting multiple simultaneous client connections out of the box. -- `ssl`: This feature enables TLS/SSL support for the **sse** transport when used with the `hyper-server`. +- `hyper-server`: This feature is necessary to enable `Streamable HTTP` or `Server-Sent Events (SSE)` transports for MCP servers. It must be used alongside the server feature to support the required server functionalities. +- `ssl`: This feature enables TLS/SSL support for the `Streamable HTTP` or `Server-Sent Events (SSE)` transport when used with the `hyper-server`. - `macros`: Provides procedural macros for simplifying the creation and manipulation of MCP Tool structures. +- `sse`: Enables support for the `Server-Sent Events (SSE)` transport. +- `streamable-http`: Enables support for the `Streamable HTTP` transport. +- `stdio`: Enables support for the `standard input/output (stdio)` transport.. + #### MCP Protocol Versions with Corresponding Features @@ -460,9 +531,9 @@ If you only need the MCP Server functionality, you can disable the default featu ```toml [dependencies] -rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["server","macros"] } +rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["server","macros","stdio"] } ``` -Optionally add `hyper-server` for **sse** transport, and `ssl` feature for tls/ssl support of the `hyper-server` +Optionally add `hyper-server` and `streamable-http` for **Streamable HTTP** transport, and `ssl` feature for tls/ssl support of the `hyper-server` @@ -475,7 +546,7 @@ Add the following to your Cargo.toml: ```toml [dependencies] -rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["client","2024_11_05"] } +rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["client","2024_11_05","stdio"] } ``` diff --git a/crates/rust-mcp-sdk/README.md b/crates/rust-mcp-sdk/README.md index 9df027d..ced3672 100644 --- a/crates/rust-mcp-sdk/README.md +++ b/crates/rust-mcp-sdk/README.md @@ -37,21 +37,10 @@ This project supports following transports: - ✅ DNS Rebinding Protection - ✅ Batch Messages - ✅ Streaming & non-streaming JSON response -- ⬜ Streamable HTTP Support for MCP Clients +- ✅ Streamable HTTP Support for MCP Clients - ⬜ Resumability - ⬜ Authentication / Oauth - - -**MCP Streamable HTTP Support** -- [x] Streamable HTTP Support for MCP Servers -- [x] DNS Rebinding Protection -- [x] Batch Messages -- [x] Streaming & non-streaming JSON response -- [ ] Streamable HTTP Support for MCP Clients -- [ ] Resumability -- [ ] Authentication / Oauth - **⚠️** Project is currently under development and should be used at your own risk. ## Table of Contents @@ -59,6 +48,7 @@ This project supports following transports: - [MCP Server (stdio)](#mcp-server-stdio) - [MCP Server (Streamable HTTP)](#mcp-server-streamable-http) - [MCP Client (stdio)](#mcp-client-stdio) + - [MCP Client (Streamable HTTP)](#mcp-client_streamable-http)) - [MCP Client (sse)](#mcp-client-sse) - [Getting Started](#getting-started) - [HyperServerOptions](#hyperserveroptions) @@ -293,6 +283,8 @@ async fn main() -> SdkResult<()> { println!("{}",result.content.first().unwrap().as_text_content()?.text); + client.shut_down().await?; + Ok(()) } @@ -304,8 +296,82 @@ Here is the output : > your results may vary slightly depending on the version of the MCP Server in use when you run it. +### MCP Client (Streamable HTTP) +```rs + +// STEP 1: Custom Handler to handle incoming MCP Messages +pub struct MyClientHandler; + +#[async_trait] +impl ClientHandler for MyClientHandler { + // To check out a list of all the methods in the trait that you can override, take a look at https://github.com/rust-mcp-stack/rust-mcp-sdk/blob/main/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs +} + +#[tokio::main] +async fn main() -> SdkResult<()> { + + // Step2 : Define client details and capabilities + let client_details: InitializeRequestParams = InitializeRequestParams { + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "simple-rust-mcp-client-sse".to_string(), + version: "0.1.0".to_string(), + title: Some("Simple Rust MCP Client (SSE)".to_string()), + }, + protocol_version: LATEST_PROTOCOL_VERSION.into(), + }; + + // Step 3: Create transport options to connect to an MCP server via Streamable HTTP. + let transport_options = StreamableTransportOptions { + mcp_url: MCP_SERVER_URL.to_string(), + request_options: RequestOptions { + ..RequestOptions::default() + }, + }; + + // STEP 4: instantiate the custom handler that is responsible for handling MCP messages + let handler = MyClientHandler {}; + + // STEP 5: create the client with transport options and the handler + let client = client_runtime::with_transport_options(client_details, transport_options, handler); + + // STEP 6: start the MCP client + client.clone().start().await?; + + // STEP 7: use client methods to communicate with the MCP Server as you wish + + // Retrieve and display the list of tools available on the server + let server_version = client.server_version().unwrap(); + let tools = client.list_tools(None).await?.tools; + println!("List of tools for {}@{}", server_version.name, server_version.version); + + tools.iter().enumerate().for_each(|(tool_index, tool)| { + println!(" {}. {} : {}", + tool_index + 1, + tool.name, + tool.description.clone().unwrap_or_default() + ); + }); + + println!("Call \"add\" tool with 100 and 28 ..."); + // Create a `Map` to represent the tool parameters + let params = json!({"a": 100,"b": 28}).as_object().unwrap().clone(); + let request = CallToolRequestParams { name: "add".to_string(),arguments: Some(params)}; + + // invoke the tool + let result = client.call_tool(request).await?; + + println!("{}",result.content.first().unwrap().as_text_content()?.text); + + client.shut_down().await?; + + Ok(()) +``` +👉 see [examples/simple-mcp-client-streamable-http](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-streamable-http) for a complete working example. + + ### MCP Client (sse) -Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost identical, with one exception at `step 3`. Instead of creating a `StdioTransport`, you simply create a `ClientSseTransport`. The rest of the code remains the same: +Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost identical to the [stdio example](#mcp-client-stdio) , with one exception at `step 3`. Instead of creating a `StdioTransport`, you simply create a `ClientSseTransport`. The rest of the code remains the same: ```diff - let transport = StdioTransport::create_with_server_launch( @@ -316,6 +382,8 @@ Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost + let transport = ClientSseTransport::new(MCP_SERVER_URL, ClientSseTransportOptions::default())?; ``` +👉 see [examples/simple-mcp-client-sse](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-sse) for a complete working example. + ## Getting Started @@ -354,9 +422,15 @@ pub struct HyperServerOptions { /// Hostname or IP address the server will bind to (default: "8080") pub port: u16, + /// Optional thread-safe session id generator to generate unique session IDs. + pub session_id_generator: Option>>, + /// Optional custom path for the Streamable HTTP endpoint (default: `/mcp`) pub custom_streamable_http_endpoint: Option, + /// Shared transport configuration used by the server + pub transport_options: Arc, + /// This setting only applies to streamable HTTP. /// If true, the server will return JSON responses instead of starting an SSE stream. /// This can be useful for simple request/response scenarios without streaming. @@ -366,12 +440,6 @@ pub struct HyperServerOptions { /// Interval between automatic ping messages sent to clients to detect disconnects pub ping_interval: Duration, - /// Shared transport configuration used by the server - pub transport_options: Arc, - - /// Optional thread-safe session id generator to generate unique session IDs. - pub session_id_generator: Option>, - /// Enables SSL/TLS if set to `true` pub enable_ssl: bool, @@ -383,17 +451,6 @@ pub struct HyperServerOptions { /// Required if `enable_ssl` is `true`. pub ssl_key_path: Option, - /// If set to true, the SSE transport will also be supported for backward compatibility (default: true) - pub sse_support: bool, - - /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`) - /// Applicable only if sse_support is true - pub custom_sse_endpoint: Option, - - /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`) - /// Applicable only if sse_support is true - pub custom_messages_endpoint: Option, - /// List of allowed host header values for DNS rebinding protection. /// If not specified, host validation is disabled. pub allowed_hosts: Option>, @@ -405,6 +462,17 @@ pub struct HyperServerOptions { /// Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). /// Default is false for backwards compatibility. pub dns_rebinding_protection: bool, + + /// If set to true, the SSE transport will also be supported for backward compatibility (default: true) + pub sse_support: bool, + + /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`) + /// Applicable only if sse_support is true + pub custom_sse_endpoint: Option, + + /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`) + /// Applicable only if sse_support is true + pub custom_messages_endpoint: Option, } ``` @@ -426,9 +494,13 @@ The `rust-mcp-sdk` crate provides several features that can be enabled or disabl - `server`: Activates MCP server capabilities in `rust-mcp-sdk`, providing modules and APIs for building and managing MCP servers. - `client`: Activates MCP client capabilities, offering modules and APIs for client development and communicating with MCP servers. -- `hyper-server`: This feature enables the **sse** transport for MCP servers, supporting multiple simultaneous client connections out of the box. -- `ssl`: This feature enables TLS/SSL support for the **sse** transport when used with the `hyper-server`. +- `hyper-server`: This feature is necessary to enable `Streamable HTTP` or `Server-Sent Events (SSE)` transports for MCP servers. It must be used alongside the server feature to support the required server functionalities. +- `ssl`: This feature enables TLS/SSL support for the `Streamable HTTP` or `Server-Sent Events (SSE)` transport when used with the `hyper-server`. - `macros`: Provides procedural macros for simplifying the creation and manipulation of MCP Tool structures. +- `sse`: Enables support for the `Server-Sent Events (SSE)` transport. +- `streamable-http`: Enables support for the `Streamable HTTP` transport. +- `stdio`: Enables support for the `standard input/output (stdio)` transport.. + #### MCP Protocol Versions with Corresponding Features @@ -459,9 +531,9 @@ If you only need the MCP Server functionality, you can disable the default featu ```toml [dependencies] -rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["server","macros"] } +rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["server","macros","stdio"] } ``` -Optionally add `hyper-server` for **sse** transport, and `ssl` feature for tls/ssl support of the `hyper-server` +Optionally add `hyper-server` and `streamable-http` for **Streamable HTTP** transport, and `ssl` feature for tls/ssl support of the `hyper-server` @@ -474,7 +546,7 @@ Add the following to your Cargo.toml: ```toml [dependencies] -rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["client","2024_11_05"] } +rust-mcp-sdk = { version = "0.2.0", default-features = false, features = ["client","2024_11_05","stdio"] } ``` From 4baadee6714f2bbb4a3ed7b572950079a3affd84 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sun, 14 Sep 2025 13:47:08 -0300 Subject: [PATCH 04/17] feat: introduce event-store --- crates/rust-mcp-transport/src/event_store.rs | 27 ++ .../src/event_store/in_memory_event_store.rs | 250 ++++++++++++++++++ .../src/utils/time_utils.rs | 8 + 3 files changed, 285 insertions(+) create mode 100644 crates/rust-mcp-transport/src/event_store.rs create mode 100644 crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs create mode 100644 crates/rust-mcp-transport/src/utils/time_utils.rs diff --git a/crates/rust-mcp-transport/src/event_store.rs b/crates/rust-mcp-transport/src/event_store.rs new file mode 100644 index 0000000..2af76c1 --- /dev/null +++ b/crates/rust-mcp-transport/src/event_store.rs @@ -0,0 +1,27 @@ +mod in_memory_event_store; +use async_trait::async_trait; +pub use in_memory_event_store::*; +use rust_mcp_schema::schema_utils::ServerMessages; + +use crate::{EventId, SessionId, StreamId}; + +#[derive(Debug, Clone)] +pub struct EventStoreMessages { + pub session_id: SessionId, + pub stream_id: StreamId, + pub messages: Vec, +} + +#[async_trait] +pub trait EventStore: Send + Sync { + async fn store_event( + &self, + session_id: SessionId, + stream_id: StreamId, + time_stamp: u128, + message: ServerMessages, + ) -> EventId; + async fn remove_by_session_id(&self, session_id: SessionId); + async fn remove_stream_in_session(&self, session_id: SessionId, stream_id: StreamId); + async fn events_after(&self, last_event_id: EventId) -> Option; +} diff --git a/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs b/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs new file mode 100644 index 0000000..a1c07dd --- /dev/null +++ b/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs @@ -0,0 +1,250 @@ +use async_trait::async_trait; +use rust_mcp_schema::schema_utils::ServerMessages; +use std::collections::HashMap; +use std::collections::VecDeque; +use tokio::sync::RwLock; + +use crate::{ + event_store::{EventStore, EventStoreMessages}, + EventId, SessionId, StreamId, +}; + +const MAX_EVENTS_PER_SESSION: usize = 32; +const ID_SEPERATOR: &str = "-.-"; + +#[derive(Debug, Clone)] +struct EventEntry { + pub stream_id: StreamId, + pub time_stamp: u128, + pub message: ServerMessages, +} + +#[derive(Debug)] +pub struct InMemoryEventStore { + max_events_per_session: usize, + storage_map: RwLock>>, +} + +/// In-memory implementation of the `EventStore` trait for MCP's Streamable HTTP transport. +/// +/// Stores events in a `HashMap` of session IDs to `VecDeque`s of events, with a per-session limit. +/// Events are identified by `event_id` (format: `session-.-stream-.-timestamp`) and used for SSE resumption. +/// Thread-safe via `RwLock` for concurrent access. +impl InMemoryEventStore { + /// Creates a new `InMemoryEventStore` with an optional maximum events per session. + /// + /// # Arguments + /// - `max_events_per_session`: Maximum number of events per session. Defaults to `MAX_EVENTS_PER_SESSION` (32) if `None`. + /// + /// # Returns + /// A new `InMemoryEventStore` instance with an empty `HashMap` wrapped in a `RwLock`. + /// + /// # Example + /// ``` + /// let store = InMemoryEventStore::new(Some(10)); + /// assert_eq!(store.max_events_per_session, 10); + /// ``` + pub fn new(max_events_per_session: Option) -> Self { + Self { + max_events_per_session: max_events_per_session.unwrap_or(MAX_EVENTS_PER_SESSION), + storage_map: RwLock::new(HashMap::new()), + } + } + + /// Generates an `event_id` string from session, stream, and timestamp components. + /// + /// Format: `session-.-stream-.-timestamp`, used as a resumption cursor in SSE (`Last-Event-ID`). + /// + /// # Arguments + /// - `session_id`: The session identifier. + /// - `stream_id`: The stream identifier. + /// - `time_stamp`: The event timestamp (u128). + /// + /// # Returns + /// A `String` in the format `session-.-stream-.-timestamp`. + fn generate_event_id( + &self, + session_id: &SessionId, + stream_id: &StreamId, + time_stamp: u128, + ) -> String { + format!("{session_id}{ID_SEPERATOR}{stream_id}{ID_SEPERATOR}{time_stamp}") + } + + /// Parses an event ID into its session, stream, and timestamp components. + /// + /// The event ID must follow the format `session-.-stream-.-timestamp`. + /// Returns `None` if the format is invalid, empty, or contains invalid characters (e.g., NULL). + /// + /// # Arguments + /// - `event_id`: The event ID string to parse. + /// + /// # Returns + /// An `Option` containing a tuple of `(session_id, stream_id, time_stamp)` as string slices, + /// or `None` if the format is invalid. + /// + /// # Example + /// ``` + /// let store = InMemoryEventStore::new(None); + /// let event_id = "session1-.-stream1-.-12345"; + /// assert_eq!( + /// store.parse_event_id(event_id), + /// Some(("session1", "stream1", "12345")) + /// ); + /// assert_eq!(store.parse_event_id("invalid"), None); + /// ``` + pub fn parse_event_id<'a>(&self, event_id: &'a str) -> Option<(&'a str, &'a str, &'a str)> { + // Check for empty input or invalid characters (e.g., NULL) + if event_id.is_empty() || event_id.contains('\0') { + return None; + } + + // Split into exactly three parts + let parts: Vec<&'a str> = event_id.split('.').collect(); + if parts.len() != 3 { + return None; + } + + let session_id = parts[0]; + let stream_id = parts[1]; + let time_stamp = parts[2]; + + // Ensure no part is empty + if session_id.is_empty() || stream_id.is_empty() || time_stamp.is_empty() { + return None; + } + + Some((session_id, stream_id, time_stamp)) + } +} + +#[async_trait] +impl EventStore for InMemoryEventStore { + /// Stores an event for a given session and stream, returning its `event_id`. + /// + /// Adds the event to the session’s `VecDeque`, removing the oldest event if the session + /// reaches `max_events_per_session`. + /// + /// # Arguments + /// - `session_id`: The session identifier. + /// - `stream_id`: The stream identifier. + /// - `time_stamp`: The event timestamp (u128). + /// - `message`: The `ServerMessages` payload. + /// + /// # Returns + /// The generated `EventId` for the stored event. + async fn store_event( + &self, + session_id: SessionId, + stream_id: StreamId, + time_stamp: u128, + message: ServerMessages, + ) -> EventId { + let event_id = self.generate_event_id(&session_id, &stream_id, time_stamp); + + let mut storage_map = self.storage_map.write().await; + + let session_map = storage_map + .entry(session_id) + .or_insert_with(|| VecDeque::with_capacity(self.max_events_per_session)); + + if session_map.len() == self.max_events_per_session { + session_map.pop_front(); // remove the oldest if full + } + + session_map.push_back(EventEntry { + stream_id, + time_stamp, + message, + }); + + event_id + } + + /// Removes all events associated with a given stream ID within a specific session. + /// + /// Removes events matching `stream_id` from the specified `session_id`’s event queue. + /// If the session’s queue becomes empty, it is removed from the store. + /// Idempotent if `session_id` or `stream_id` doesn’t exist. + /// + /// # Arguments + /// - `session_id`: The session identifier to target. + /// - `stream_id`: The stream identifier to remove. + async fn remove_stream_in_session(&self, session_id: SessionId, stream_id: StreamId) { + let mut storage_map = self.storage_map.write().await; + + // Check if session exists + if let Some(events) = storage_map.get_mut(&session_id) { + // Remove events with the given stream_id + events.retain(|event| event.stream_id != stream_id); + // Remove session if empty + if events.is_empty() { + storage_map.remove(&session_id); + } + } + // No action if session_id doesn’t exist (idempotent) + } + + /// Removes all events associated with a given session ID. + /// + /// Removes the entire session from the store. Idempotent if `session_id` doesn’t exist. + /// + /// # Arguments + /// - `session_id`: The session identifier to remove. + async fn remove_by_session_id(&self, session_id: SessionId) { + let mut storage_map = self.storage_map.write().await; + storage_map.remove(&session_id); + } + + /// Retrieves events after a given `event_id` for a specific session and stream. + /// + /// Parses `last_event_id` to extract `session_id`, `stream_id`, and `time_stamp`. + /// Returns events after the matching event in the session’s stream, sorted by timestamp + /// in ascending order (earliest to latest). Returns `None` if the `event_id` is invalid, + /// the session doesn’t exist, or the timestamp is non-numeric. + /// + /// # Arguments + /// - `last_event_id`: The event ID (format: `session-.-stream-.-timestamp`) to start after. + /// + /// # Returns + /// An `Option` containing `EventStoreMessages` with the session ID, stream ID, and sorted messages, + /// or `None` if no events are found or the input is invalid. + async fn events_after(&self, last_event_id: EventId) -> Option { + let Some((session_id, stream_id, time_stamp)) = self.parse_event_id(&last_event_id) else { + return None; + }; + + let storage_map = self.storage_map.read().await; + let Some(events) = storage_map.get(session_id) else { + return None; + }; + + let Ok(time_stamp) = time_stamp.parse::() else { + return None; + }; + + let events = match events + .iter() + .position(|e| e.stream_id == stream_id && e.time_stamp == time_stamp) + { + Some(index) if index + 1 < events.len() => { + // Collect subsequent events that match the stream_id + let mut subsequent: Vec<_> = events + .range(index + 1..) + .filter(|e| e.stream_id == stream_id) + .cloned() + .collect(); + + subsequent.sort_by(|a, b| a.time_stamp.cmp(&b.time_stamp)); + subsequent.iter().map(|e| e.message.clone()).collect() + } + _ => vec![], + }; + + Some(EventStoreMessages { + session_id: session_id.to_string(), + stream_id: stream_id.to_string(), + messages: events, + }) + } +} diff --git a/crates/rust-mcp-transport/src/utils/time_utils.rs b/crates/rust-mcp-transport/src/utils/time_utils.rs new file mode 100644 index 0000000..a94c305 --- /dev/null +++ b/crates/rust-mcp-transport/src/utils/time_utils.rs @@ -0,0 +1,8 @@ +use std::time::{SystemTime, UNIX_EPOCH}; + +pub fn current_timestamp() -> u128 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Invalid time") + .as_nanos() // or `.as_millis()` or `.as_nanos()` if you want higher precision +} From d8946dd6c57398060e8c0a89e9067f26aebecbc2 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sun, 14 Sep 2025 14:18:09 -0300 Subject: [PATCH 05/17] chore: add event store to the app state --- crates/rust-mcp-sdk/src/hyper_servers/app_state.rs | 3 +++ crates/rust-mcp-sdk/src/hyper_servers/server.rs | 6 +++++- crates/rust-mcp-transport/src/lib.rs | 1 + crates/rust-mcp-transport/src/utils.rs | 3 +++ 4 files changed, 12 insertions(+), 1 deletion(-) diff --git a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs b/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs index ff6d5b2..a520294 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs @@ -3,6 +3,7 @@ use std::{sync::Arc, time::Duration}; use super::session_store::SessionStore; use crate::mcp_traits::mcp_handler::McpServerHandler; use crate::{id_generator::FastIdGenerator, mcp_traits::IdGenerator, schema::InitializeResult}; +use rust_mcp_transport::event_store::EventStore; use rust_mcp_transport::{SessionId, TransportOptions}; /// Application state struct for the Hyper server @@ -30,6 +31,8 @@ pub struct AppState { /// Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). /// Default is false for backwards compatibility. pub dns_rebinding_protection: bool, + + pub event_store: Option>, } impl AppState { diff --git a/crates/rust-mcp-sdk/src/hyper_servers/server.rs b/crates/rust-mcp-sdk/src/hyper_servers/server.rs index 1c3b3cf..a1843b0 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/server.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/server.rs @@ -23,7 +23,7 @@ use super::{ }; use crate::schema::InitializeResult; use axum::Router; -use rust_mcp_transport::{SessionId, TransportOptions}; +use rust_mcp_transport::{event_store::EventStore, SessionId, TransportOptions}; // Default client ping interval (12 seconds) const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_secs(12); @@ -53,6 +53,8 @@ pub struct HyperServerOptions { /// Shared transport configuration used by the server pub transport_options: Arc, + pub event_store: Option>, + /// This setting only applies to streamable HTTP. /// If true, the server will return JSON responses instead of starting an SSE stream. /// This can be useful for simple request/response scenarios without streaming. @@ -225,6 +227,7 @@ impl Default for HyperServerOptions { allowed_hosts: None, allowed_origins: None, dns_rebinding_protection: false, + event_store: None, } } } @@ -271,6 +274,7 @@ impl HyperServer { allowed_hosts: server_options.allowed_hosts.take(), allowed_origins: server_options.allowed_origins.take(), dns_rebinding_protection: server_options.dns_rebinding_protection, + event_store: server_options.event_store.as_ref().map(Arc::clone), }); let app = app_routes(Arc::clone(&state), &server_options); Self { diff --git a/crates/rust-mcp-transport/src/lib.rs b/crates/rust-mcp-transport/src/lib.rs index 4a918db..d21e5dd 100644 --- a/crates/rust-mcp-transport/src/lib.rs +++ b/crates/rust-mcp-transport/src/lib.rs @@ -8,6 +8,7 @@ mod client_sse; mod client_streamable_http; mod constants; pub mod error; +pub mod event_store; mod mcp_stream; mod message_dispatcher; mod schema; diff --git a/crates/rust-mcp-transport/src/utils.rs b/crates/rust-mcp-transport/src/utils.rs index 82d7326..83abb5e 100644 --- a/crates/rust-mcp-transport/src/utils.rs +++ b/crates/rust-mcp-transport/src/utils.rs @@ -12,6 +12,9 @@ mod streamable_http_stream; #[cfg(any(feature = "sse", feature = "streamable-http"))] mod writable_channel; +mod time_utils; +pub(crate) use time_utils::*; + pub(crate) use cancellation_token::*; #[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) use http_utils::*; From 2504c9600248d6f31cecda8123b9954caeb3c495 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sun, 14 Sep 2025 16:06:26 -0300 Subject: [PATCH 06/17] chore: refactor event store integration --- README.md | 5 ++ crates/rust-mcp-sdk/README.md | 5 ++ .../src/hyper_servers/app_state.rs | 3 +- .../src/hyper_servers/routes/hyper_utils.rs | 75 ++++++++++++++----- .../rust-mcp-sdk/src/hyper_servers/server.rs | 2 + crates/rust-mcp-sdk/src/utils.rs | 8 ++ crates/rust-mcp-transport/src/event_store.rs | 6 +- .../src/event_store/in_memory_event_store.rs | 9 ++- crates/rust-mcp-transport/src/utils.rs | 3 - .../src/utils/time_utils.rs | 8 -- 10 files changed, 87 insertions(+), 37 deletions(-) delete mode 100644 crates/rust-mcp-transport/src/utils/time_utils.rs diff --git a/README.md b/README.md index ced3672..44b584e 100644 --- a/README.md +++ b/README.md @@ -415,6 +415,7 @@ server.start().await?; Here is a list of available options with descriptions for configuring the HyperServer: ```rs + pub struct HyperServerOptions { /// Hostname or IP address the server will bind to (default: "127.0.0.1") pub host: String, @@ -431,6 +432,10 @@ pub struct HyperServerOptions { /// Shared transport configuration used by the server pub transport_options: Arc, + /// Event store for resumability support + /// If provided, resumability will be enabled, allowing clients to reconnect and resume messages + pub event_store: Option>, + /// This setting only applies to streamable HTTP. /// If true, the server will return JSON responses instead of starting an SSE stream. /// This can be useful for simple request/response scenarios without streaming. diff --git a/crates/rust-mcp-sdk/README.md b/crates/rust-mcp-sdk/README.md index ced3672..44b584e 100644 --- a/crates/rust-mcp-sdk/README.md +++ b/crates/rust-mcp-sdk/README.md @@ -415,6 +415,7 @@ server.start().await?; Here is a list of available options with descriptions for configuring the HyperServer: ```rs + pub struct HyperServerOptions { /// Hostname or IP address the server will bind to (default: "127.0.0.1") pub host: String, @@ -431,6 +432,10 @@ pub struct HyperServerOptions { /// Shared transport configuration used by the server pub transport_options: Arc, + /// Event store for resumability support + /// If provided, resumability will be enabled, allowing clients to reconnect and resume messages + pub event_store: Option>, + /// This setting only applies to streamable HTTP. /// If true, the server will return JSON responses instead of starting an SSE stream. /// This can be useful for simple request/response scenarios without streaming. diff --git a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs b/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs index a520294..e7f8793 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs @@ -31,7 +31,8 @@ pub struct AppState { /// Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). /// Default is false for backwards compatibility. pub dns_rebinding_protection: bool, - + /// Event store for resumability support + /// If provided, resumability will be enabled, allowing clients to reconnect and resume messages pub event_store: Option>, } diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs index da69c67..ee1b533 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs @@ -7,7 +7,7 @@ use crate::{ mcp_runtimes::server_runtime::DEFAULT_STREAM_ID, mcp_server::{server_runtime, ServerRuntime}, mcp_traits::{mcp_handler::McpServerHandler, IdGenerator}, - utils::validate_mcp_protocol_version, + utils::{current_timestamp, validate_mcp_protocol_version}, }; use crate::schema::schema_utils::{ClientMessage, SdkError}; @@ -23,13 +23,23 @@ use axum::{ use futures::stream; use hyper::{header, HeaderMap, StatusCode}; use rust_mcp_transport::{ - SessionId, SseTransport, StreamId, MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, + event_store::EventStore, EventId, SessionId, SseTransport, StreamId, + MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, }; use std::{sync::Arc, time::Duration}; use tokio::io::{duplex, AsyncBufReadExt, BufReader}; const DUPLEX_BUFFER_SIZE: usize = 8192; +async fn store_event( + event_store: Arc, + session_id: SessionId, + stream_id: StreamId, + message_payload: &str, +) -> Option { + None +} + async fn create_sse_stream( runtime: Arc, session_id: SessionId, @@ -63,40 +73,65 @@ async fn create_sse_stream( .map_err(|err| TransportServerError::TransportError(err.to_string()))?, ); - let stream_id: StreamId = if standalone { - DEFAULT_STREAM_ID.to_string() + let session_id = Arc::new(session_id); + let stream_id: Arc = if standalone { + Arc::new(DEFAULT_STREAM_ID.to_string()) } else { - state.stream_id_gen.generate() + Arc::new(state.stream_id_gen.generate()) }; + let ping_interval = state.ping_interval; let runtime_clone = Arc::clone(&runtime); + let stream_id_clone = stream_id.clone(); //Start the server runtime tokio::spawn(async move { match runtime_clone - .start_stream(transport, &stream_id, ping_interval, payload_string) + .start_stream(transport, &stream_id_clone, ping_interval, payload_string) .await { - Ok(_) => tracing::trace!("stream {} exited gracefully.", &stream_id), - Err(err) => tracing::info!("stream {} exited with error : {}", &stream_id, err), + Ok(_) => tracing::trace!("stream {} exited gracefully.", &stream_id_clone), + Err(err) => tracing::info!("stream {} exited with error : {}", &stream_id_clone, err), } - let _ = runtime.remove_transport(&stream_id).await; + let _ = runtime.remove_transport(&stream_id_clone).await; }); + // let event_store = state.event_store.; + // Construct SSE stream let reader = BufReader::new(write_rx); - - // outgoing messages from server to the client - let message_stream = stream::unfold(reader, |mut reader| async move { - let mut line = String::new(); - - match reader.read_line(&mut line).await { - Ok(0) => None, // EOF - Ok(_) => { - let trimmed_line = line.trim_end_matches('\n').to_owned(); - Some((Ok(Event::default().data(trimmed_line)), reader)) + let session_id_clone = session_id.clone(); + let event_store = state.event_store.as_ref().map(Arc::clone); + + // send outgoing messages from server to the client over the sse stream + let message_stream = stream::unfold(reader, move |mut reader| { + let session_id = session_id_clone.clone(); + let stream_id = stream_id.clone(); + let event_store = event_store.clone(); + async move { + let mut line = String::new(); + + match reader.read_line(&mut line).await { + Ok(0) => None, // EOF + Ok(_) => { + let trimmed_line = line.trim_end_matches('\n').to_owned(); + + // store the event for resumption if it is supported + if let Some(event_store) = event_store { + event_store + .store_event( + (*session_id).clone(), + (*stream_id).clone(), + current_timestamp(), + trimmed_line.clone(), + ) + .await; + } + + Some((Ok(Event::default().data(trimmed_line)), reader)) + } + Err(e) => Some((Err(e), reader)), } - Err(e) => Some((Err(e), reader)), } }); diff --git a/crates/rust-mcp-sdk/src/hyper_servers/server.rs b/crates/rust-mcp-sdk/src/hyper_servers/server.rs index a1843b0..71bccee 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/server.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/server.rs @@ -53,6 +53,8 @@ pub struct HyperServerOptions { /// Shared transport configuration used by the server pub transport_options: Arc, + /// Event store for resumability support + /// If provided, resumability will be enabled, allowing clients to reconnect and resume messages pub event_store: Option>, /// This setting only applies to streamable HTTP. diff --git a/crates/rust-mcp-sdk/src/utils.rs b/crates/rust-mcp-sdk/src/utils.rs index 16fe7c7..cfdce16 100644 --- a/crates/rust-mcp-sdk/src/utils.rs +++ b/crates/rust-mcp-sdk/src/utils.rs @@ -3,6 +3,7 @@ use crate::schema::schema_utils::{ClientMessages, SdkError}; use crate::error::{McpSdkError, ProtocolErrorKind, SdkResult}; use crate::schema::ProtocolVersion; use std::cmp::Ordering; +use std::time::{SystemTime, UNIX_EPOCH}; /// A guard type that automatically aborts a Tokio task when dropped. /// @@ -234,6 +235,13 @@ pub fn valid_initialize_method(json_str: &str) -> SdkResult<()> { Ok(()) } +pub fn current_timestamp() -> u128 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Invalid time") + .as_nanos() // or `.as_millis()` or `.as_nanos()` if you want higher precision +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/rust-mcp-transport/src/event_store.rs b/crates/rust-mcp-transport/src/event_store.rs index 2af76c1..fdc0734 100644 --- a/crates/rust-mcp-transport/src/event_store.rs +++ b/crates/rust-mcp-transport/src/event_store.rs @@ -1,7 +1,6 @@ mod in_memory_event_store; use async_trait::async_trait; pub use in_memory_event_store::*; -use rust_mcp_schema::schema_utils::ServerMessages; use crate::{EventId, SessionId, StreamId}; @@ -9,7 +8,7 @@ use crate::{EventId, SessionId, StreamId}; pub struct EventStoreMessages { pub session_id: SessionId, pub stream_id: StreamId, - pub messages: Vec, + pub messages: Vec, } #[async_trait] @@ -19,9 +18,10 @@ pub trait EventStore: Send + Sync { session_id: SessionId, stream_id: StreamId, time_stamp: u128, - message: ServerMessages, + message: String, ) -> EventId; async fn remove_by_session_id(&self, session_id: SessionId); async fn remove_stream_in_session(&self, session_id: SessionId, stream_id: StreamId); + async fn clear(&self); async fn events_after(&self, last_event_id: EventId) -> Option; } diff --git a/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs b/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs index a1c07dd..8568536 100644 --- a/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs +++ b/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs @@ -16,7 +16,7 @@ const ID_SEPERATOR: &str = "-.-"; struct EventEntry { pub stream_id: StreamId, pub time_stamp: u128, - pub message: ServerMessages, + pub message: String, } #[derive(Debug)] @@ -138,7 +138,7 @@ impl EventStore for InMemoryEventStore { session_id: SessionId, stream_id: StreamId, time_stamp: u128, - message: ServerMessages, + message: String, ) -> EventId { let event_id = self.generate_event_id(&session_id, &stream_id, time_stamp); @@ -247,4 +247,9 @@ impl EventStore for InMemoryEventStore { messages: events, }) } + + async fn clear(&self) { + let mut storage_map = self.storage_map.write().await; + storage_map.clear(); + } } diff --git a/crates/rust-mcp-transport/src/utils.rs b/crates/rust-mcp-transport/src/utils.rs index 83abb5e..82d7326 100644 --- a/crates/rust-mcp-transport/src/utils.rs +++ b/crates/rust-mcp-transport/src/utils.rs @@ -12,9 +12,6 @@ mod streamable_http_stream; #[cfg(any(feature = "sse", feature = "streamable-http"))] mod writable_channel; -mod time_utils; -pub(crate) use time_utils::*; - pub(crate) use cancellation_token::*; #[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) use http_utils::*; diff --git a/crates/rust-mcp-transport/src/utils/time_utils.rs b/crates/rust-mcp-transport/src/utils/time_utils.rs deleted file mode 100644 index a94c305..0000000 --- a/crates/rust-mcp-transport/src/utils/time_utils.rs +++ /dev/null @@ -1,8 +0,0 @@ -use std::time::{SystemTime, UNIX_EPOCH}; - -pub fn current_timestamp() -> u128 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("Invalid time") - .as_nanos() // or `.as_millis()` or `.as_nanos()` if you want higher precision -} From c9d9f982cda2210fa2084c6e4ee0b17e15b8217b Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sun, 14 Sep 2025 16:16:52 -0300 Subject: [PATCH 07/17] chore: add tracing to inmemory store --- .../src/event_store/in_memory_event_store.rs | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs b/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs index 8568536..850bab8 100644 --- a/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs +++ b/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs @@ -1,4 +1,5 @@ use async_trait::async_trait; +use reqwest::header::Entry; use rust_mcp_schema::schema_utils::ServerMessages; use std::collections::HashMap; use std::collections::VecDeque; @@ -144,6 +145,10 @@ impl EventStore for InMemoryEventStore { let mut storage_map = self.storage_map.write().await; + tracing::trace!( + "Storing event for session: {session_id}\nstream_id: {stream_id}\nmessage: {message} ", + ); + let session_map = storage_map .entry(session_id) .or_insert_with(|| VecDeque::with_capacity(self.max_events_per_session)); @@ -152,11 +157,13 @@ impl EventStore for InMemoryEventStore { session_map.pop_front(); // remove the oldest if full } - session_map.push_back(EventEntry { + let entry = EventEntry { stream_id, time_stamp, message, - }); + }; + + session_map.push_back(entry); event_id } @@ -241,6 +248,8 @@ impl EventStore for InMemoryEventStore { _ => vec![], }; + tracing::trace!("{} messages after '{last_event_id}'", events.len()); + Some(EventStoreMessages { session_id: session_id.to_string(), stream_id: stream_id.to_string(), From 55a5fd65fa580240b2d1b15cb517e92b848f59f1 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sun, 14 Sep 2025 16:52:20 -0300 Subject: [PATCH 08/17] chore: update examples to use event store --- README.md | 1 + crates/rust-mcp-sdk/README.md | 1 + .../src/hyper_servers/routes/hyper_utils.rs | 31 +++++++++---------- .../src/event_store/in_memory_event_store.rs | 13 ++++++-- .../src/main.rs | 4 +++ .../src/main.rs | 3 ++ 6 files changed, 34 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 44b584e..adf1a58 100644 --- a/README.md +++ b/README.md @@ -153,6 +153,7 @@ let server = hyper_server::create_server( HyperServerOptions { host: "127.0.0.1".to_string(), sse_support: false, + event_store: Some(Arc::new(InMemoryEventStore::default())), // enable resumability ..Default::default() }, ); diff --git a/crates/rust-mcp-sdk/README.md b/crates/rust-mcp-sdk/README.md index 44b584e..adf1a58 100644 --- a/crates/rust-mcp-sdk/README.md +++ b/crates/rust-mcp-sdk/README.md @@ -153,6 +153,7 @@ let server = hyper_server::create_server( HyperServerOptions { host: "127.0.0.1".to_string(), sse_support: false, + event_store: Some(Arc::new(InMemoryEventStore::default())), // enable resumability ..Default::default() }, ); diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs index ee1b533..6724c9b 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs @@ -31,15 +31,6 @@ use tokio::io::{duplex, AsyncBufReadExt, BufReader}; const DUPLEX_BUFFER_SIZE: usize = 8192; -async fn store_event( - event_store: Arc, - session_id: SessionId, - stream_id: StreamId, - message_payload: &str, -) -> Option { - None -} - async fn create_sse_stream( runtime: Arc, session_id: SessionId, @@ -118,14 +109,16 @@ async fn create_sse_stream( // store the event for resumption if it is supported if let Some(event_store) = event_store { - event_store - .store_event( - (*session_id).clone(), - (*stream_id).clone(), - current_timestamp(), - trimmed_line.clone(), - ) - .await; + if !is_empty_sse_message(&trimmed_line) { + event_store + .store_event( + (*session_id).clone(), + (*stream_id).clone(), + current_timestamp(), + trimmed_line.clone(), + ) + .await; + } } Some((Ok(Event::default().data(trimmed_line)), reader)) @@ -400,6 +393,10 @@ pub async fn process_incoming_message( } } +pub fn is_empty_sse_message(sse_payload: &str) -> bool { + sse_payload.is_empty() || sse_payload.trim() == ":" +} + pub async fn delete_session( session_id: SessionId, state: Arc, diff --git a/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs b/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs index 850bab8..db8dbf1 100644 --- a/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs +++ b/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs @@ -10,7 +10,7 @@ use crate::{ EventId, SessionId, StreamId, }; -const MAX_EVENTS_PER_SESSION: usize = 32; +const MAX_EVENTS_PER_SESSION: usize = 64; const ID_SEPERATOR: &str = "-.-"; #[derive(Debug, Clone)] @@ -26,6 +26,15 @@ pub struct InMemoryEventStore { storage_map: RwLock>>, } +impl Default for InMemoryEventStore { + fn default() -> Self { + Self { + max_events_per_session: MAX_EVENTS_PER_SESSION, + storage_map: Default::default(), + } + } +} + /// In-memory implementation of the `EventStore` trait for MCP's Streamable HTTP transport. /// /// Stores events in a `HashMap` of session IDs to `VecDeque`s of events, with a per-session limit. @@ -146,7 +155,7 @@ impl EventStore for InMemoryEventStore { let mut storage_map = self.storage_map.write().await; tracing::trace!( - "Storing event for session: {session_id}\nstream_id: {stream_id}\nmessage: {message} ", + "Storing event for session: {session_id}, stream_id: {stream_id}, message: {message} ", ); let session_map = storage_map diff --git a/examples/hello-world-server-streamable-http-core/src/main.rs b/examples/hello-world-server-streamable-http-core/src/main.rs index 7b41c70..81a6ae5 100644 --- a/examples/hello-world-server-streamable-http-core/src/main.rs +++ b/examples/hello-world-server-streamable-http-core/src/main.rs @@ -1,7 +1,10 @@ mod handler; mod tools; +use std::sync::Arc; + use handler::MyServerHandler; +use rust_mcp_sdk::event_store::InMemoryEventStore; use rust_mcp_sdk::schema::{ Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, LATEST_PROTOCOL_VERSION, @@ -48,6 +51,7 @@ async fn main() -> SdkResult<()> { handler, HyperServerOptions { sse_support: true, + event_store: Some(Arc::new(InMemoryEventStore::default())), // enable resumability ..Default::default() }, ); diff --git a/examples/hello-world-server-streamable-http/src/main.rs b/examples/hello-world-server-streamable-http/src/main.rs index cd8c658..3923a6d 100644 --- a/examples/hello-world-server-streamable-http/src/main.rs +++ b/examples/hello-world-server-streamable-http/src/main.rs @@ -1,8 +1,10 @@ mod handler; mod tools; +use std::sync::Arc; use std::time::Duration; +use rust_mcp_sdk::event_store::InMemoryEventStore; use rust_mcp_sdk::mcp_server::{hyper_server, HyperServerOptions}; use handler::MyServerHandler; @@ -57,6 +59,7 @@ async fn main() -> SdkResult<()> { HyperServerOptions { host: "127.0.0.1".to_string(), ping_interval: Duration::from_secs(5), + event_store: Some(Arc::new(InMemoryEventStore::default())), // enable resumability ..Default::default() }, ); From 805e64f5ec4acdc9c19be63682a8448e89de4b98 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sun, 14 Sep 2025 17:31:30 -0300 Subject: [PATCH 09/17] chore: improve flow --- .../src/hyper_servers/routes/hyper_utils.rs | 29 ++++++++++++++----- .../routes/streamable_http_routes.rs | 9 ++++-- crates/rust-mcp-transport/src/sse.rs | 2 +- 3 files changed, 29 insertions(+), 11 deletions(-) diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs index 6724c9b..bc7d48b 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs @@ -23,8 +23,7 @@ use axum::{ use futures::stream; use hyper::{header, HeaderMap, StatusCode}; use rust_mcp_transport::{ - event_store::EventStore, EventId, SessionId, SseTransport, StreamId, - MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, + EventId, SessionId, SseTransport, StreamId, MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, }; use std::{sync::Arc, time::Duration}; use tokio::io::{duplex, AsyncBufReadExt, BufReader}; @@ -37,6 +36,7 @@ async fn create_sse_stream( state: Arc, payload: Option<&str>, standalone: bool, + last_event_id: Option, ) -> TransportServerResult> { let payload_string = payload.map(|p| p.to_string()); @@ -87,8 +87,6 @@ async fn create_sse_stream( let _ = runtime.remove_transport(&stream_id_clone).await; }); - // let event_store = state.event_store.; - // Construct SSE stream let reader = BufReader::new(write_rx); let session_id_clone = session_id.clone(); @@ -107,9 +105,15 @@ async fn create_sse_stream( Ok(_) => { let trimmed_line = line.trim_end_matches('\n').to_owned(); + // empty sse comment to keep-alive + if is_empty_sse_message(&trimmed_line) { + return Some((Ok(Event::default()), reader)); + } + + let mut event_id: Option = None; // store the event for resumption if it is supported if let Some(event_store) = event_store { - if !is_empty_sse_message(&trimmed_line) { + event_id = Some( event_store .store_event( (*session_id).clone(), @@ -117,11 +121,16 @@ async fn create_sse_stream( current_timestamp(), trimmed_line.clone(), ) - .await; - } + .await, + ); } - Some((Ok(Event::default().data(trimmed_line)), reader)) + let event = match event_id { + Some(id) => Event::default().data(trimmed_line).id(id), + None => Event::default().data(trimmed_line), + }; + + Some((Ok(event), reader)) } Err(e) => Some((Err(e), reader)), } @@ -176,6 +185,7 @@ fn is_result(json_str: &str) -> Result { pub async fn create_standalone_stream( session_id: SessionId, + last_event_id: Option, state: Arc, ) -> TransportServerResult> { let runtime = state.session_store.get(&session_id).await.ok_or( @@ -195,6 +205,7 @@ pub async fn create_standalone_stream( state.clone(), None, true, + last_event_id, ) .await?; *response.status_mut() = StatusCode::OK; @@ -223,6 +234,7 @@ pub async fn start_new_session( state.clone(), Some(payload), false, + None, ) .await; @@ -382,6 +394,7 @@ pub async fn process_incoming_message( state.clone(), Some(payload), false, + None, ) .await } diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs index 00d46c0..67f8679 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs @@ -23,7 +23,7 @@ use axum::{ Json, Router, }; use hyper::{HeaderMap, StatusCode}; -use rust_mcp_transport::{SessionId, MCP_SESSION_ID_HEADER}; +use rust_mcp_transport::{SessionId, MCP_LAST_EVENT_ID_HEADER, MCP_SESSION_ID_HEADER}; use std::{collections::HashMap, sync::Arc}; pub fn routes(state: Arc, streamable_http_endpoint: &str) -> Router> { @@ -60,9 +60,14 @@ pub async fn handle_streamable_http_get( .and_then(|value| value.to_str().ok()) .map(|s| s.to_string()); + let last_event_id: Option = headers + .get(MCP_LAST_EVENT_ID_HEADER) + .and_then(|value| value.to_str().ok()) + .map(|s| s.to_string()); + match session_id { Some(session_id) => { - let res = create_standalone_stream(session_id, state).await?; + let res = create_standalone_stream(session_id, last_event_id, state).await?; Ok(res.into_response()) } None => { diff --git a/crates/rust-mcp-transport/src/sse.rs b/crates/rust-mcp-transport/src/sse.rs index 09809e4..f019c60 100644 --- a/crates/rust-mcp-transport/src/sse.rs +++ b/crates/rust-mcp-transport/src/sse.rs @@ -239,7 +239,7 @@ impl Transport {} Err(TransportError::Io(error)) => { if error.kind() == std::io::ErrorKind::BrokenPipe { From cf156740912792f78d9412705184df6b8c440303 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Mon, 15 Sep 2025 08:40:30 -0300 Subject: [PATCH 10/17] chore: replay mechanism --- .../src/hyper_servers/routes/hyper_utils.rs | 32 ++++++++++++- .../src/mcp_runtimes/server_runtime.rs | 1 + crates/rust-mcp-sdk/tests/common/common.rs | 46 ++++++++++++++++++- .../src/event_store/in_memory_event_store.rs | 7 ++- 4 files changed, 81 insertions(+), 5 deletions(-) diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs index bc7d48b..dce0c66 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs @@ -23,7 +23,8 @@ use axum::{ use futures::stream; use hyper::{header, HeaderMap, StatusCode}; use rust_mcp_transport::{ - EventId, SessionId, SseTransport, StreamId, MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, + EventId, McpDispatch, SessionId, SseTransport, StreamId, MCP_PROTOCOL_VERSION_HEADER, + MCP_SESSION_ID_HEADER, }; use std::{sync::Arc, time::Duration}; use tokio::io::{duplex, AsyncBufReadExt, BufReader}; @@ -74,11 +75,17 @@ async fn create_sse_stream( let ping_interval = state.ping_interval; let runtime_clone = Arc::clone(&runtime); let stream_id_clone = stream_id.clone(); + let transport_clone = transport.clone(); //Start the server runtime tokio::spawn(async move { match runtime_clone - .start_stream(transport, &stream_id_clone, ping_interval, payload_string) + .start_stream( + transport_clone, + &stream_id_clone, + ping_interval, + payload_string, + ) .await { Ok(_) => tracing::trace!("stream {} exited gracefully.", &stream_id_clone), @@ -148,6 +155,20 @@ async fn create_sse_stream( HeaderValue::from_str(&session_id).unwrap(), ); + // if last_event_id exists we replay messages from the event-store + tokio::spawn(async move { + if let Some(last_event_id) = last_event_id { + if let Some(event_store) = state.event_store.as_ref() { + if let Some(events) = event_store.events_after(last_event_id).await { + for message_payload in events.messages { + let err = transport.write_str(&message_payload).await; + tracing::trace!("Error replaying message...") + } + } + } + } + }); + if !payload_contains_request { *response.status_mut() = StatusCode::ACCEPTED; } @@ -199,6 +220,13 @@ pub async fn create_standalone_stream( return Ok((StatusCode::CONFLICT, Json(error)).into_response()); } + if let Some(last_event_id) = last_event_id.as_ref() { + tracing::trace!( + "SSE stream rec-connected with last-event-id: {}", + last_event_id + ); + } + let mut response = create_sse_stream( runtime.clone(), session_id.clone(), diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index 1b24b57..10ad896 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -435,6 +435,7 @@ impl ServerRuntime { }; // in case there is a payload, we consume it by transport to get processed + // payload would be message payload coming from the client if let Some(payload) = payload { if let Err(err) = transport.consume_string_payload(&payload).await { let _ = self.remove_transport(stream_id).await; diff --git a/crates/rust-mcp-sdk/tests/common/common.rs b/crates/rust-mcp-sdk/tests/common/common.rs index f330dda..31f9e9f 100644 --- a/crates/rust-mcp-sdk/tests/common/common.rs +++ b/crates/rust-mcp-sdk/tests/common/common.rs @@ -280,9 +280,16 @@ pub fn random_port_old() -> u16 { } pub mod sample_tools { + use std::{sync::Arc, time::Duration}; + + use rust_mcp_schema::{LoggingMessageNotificationParams, TextContent}; #[cfg(feature = "2025_06_18")] use rust_mcp_sdk::macros::{mcp_tool, JsonSchema}; - use rust_mcp_sdk::schema::{schema_utils::CallToolError, CallToolResult}; + use rust_mcp_sdk::{ + schema::{schema_utils::CallToolError, CallToolResult}, + McpServer, + }; + use serde_json::json; //****************// // SayHelloTool // @@ -342,6 +349,43 @@ pub mod sample_tools { return Ok(CallToolResult::text_content(goodbye_message, None)); } } + + //****************************// + // StartNotificationStream // + //****************************// + #[mcp_tool( + name = "start-notification-stream", + description = "Accepts a person's name and says a personalized \"Goodbye\" to that person." + )] + #[derive(Debug, ::serde::Deserialize, ::serde::Serialize, JsonSchema)] + pub struct StartNotificationStream { + /// Interval in milliseconds between notifications + interval: u64, + /// Number of notifications to send (0 for 100) + count: u32, + } + impl StartNotificationStream { + pub async fn call_tool( + &self, + runtime: Arc, + ) -> Result { + for i in 0..self.count { + let _ = runtime + .send_logging_message(LoggingMessageNotificationParams { + data: json!({"id":format!("message {} of {}",i,self.count)}), + level: rust_mcp_sdk::schema::LoggingLevel::Emergency, + logger: None, + }) + .await; + tokio::time::sleep(Duration::from_millis(self.interval)).await; + } + + let message = format!("so many messages sent"); + Ok(CallToolResult::text_content(vec![TextContent::from( + message, + )])) + } + } } pub async fn wiremock_request(mock_server: &MockServer, index: usize) -> Request { diff --git a/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs b/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs index db8dbf1..272fed0 100644 --- a/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs +++ b/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs @@ -110,7 +110,7 @@ impl InMemoryEventStore { } // Split into exactly three parts - let parts: Vec<&'a str> = event_id.split('.').collect(); + let parts: Vec<&'a str> = event_id.split(ID_SEPERATOR).collect(); if parts.len() != 3 { return None; } @@ -155,7 +155,7 @@ impl EventStore for InMemoryEventStore { let mut storage_map = self.storage_map.write().await; tracing::trace!( - "Storing event for session: {session_id}, stream_id: {stream_id}, message: {message} ", + "Storing event for session: {session_id}, stream_id: {stream_id}, message: {message}, {time_stamp} ", ); let session_map = storage_map @@ -227,15 +227,18 @@ impl EventStore for InMemoryEventStore { /// or `None` if no events are found or the input is invalid. async fn events_after(&self, last_event_id: EventId) -> Option { let Some((session_id, stream_id, time_stamp)) = self.parse_event_id(&last_event_id) else { + tracing::warn!("error parsing last event id: '{last_event_id}'"); return None; }; let storage_map = self.storage_map.read().await; let Some(events) = storage_map.get(session_id) else { + tracing::warn!("could not find the session_id in the store : '{session_id}'"); return None; }; let Ok(time_stamp) = time_stamp.parse::() else { + tracing::warn!("could not parse the timestamp: '{time_stamp}'"); return None; }; From fc4d1a58bd34b2ca5e9d262cefa57b1ed76a5c0c Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Mon, 15 Sep 2025 08:44:52 -0300 Subject: [PATCH 11/17] cleanup --- crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs | 6 ++++-- .../src/event_store/in_memory_event_store.rs | 2 -- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs index dce0c66..f7f0874 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs @@ -161,8 +161,10 @@ async fn create_sse_stream( if let Some(event_store) = state.event_store.as_ref() { if let Some(events) = event_store.events_after(last_event_id).await { for message_payload in events.messages { - let err = transport.write_str(&message_payload).await; - tracing::trace!("Error replaying message...") + let error = transport.write_str(&message_payload).await; + if let Err(error) = error { + tracing::trace!("Error replaying message: {error}") + } } } } diff --git a/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs b/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs index 272fed0..c31902d 100644 --- a/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs +++ b/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs @@ -1,6 +1,4 @@ use async_trait::async_trait; -use reqwest::header::Entry; -use rust_mcp_schema::schema_utils::ServerMessages; use std::collections::HashMap; use std::collections::VecDeque; use tokio::sync::RwLock; From 4ac720f78187538db5512d701a8acfaf1e79178c Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Mon, 15 Sep 2025 19:12:19 -0300 Subject: [PATCH 12/17] test: add new test for event-store --- crates/rust-mcp-sdk/tests/common/common.rs | 51 ++++--- .../rust-mcp-sdk/tests/common/test_server.rs | 4 + .../tests/test_streamable_http_client.rs | 1 + .../tests/test_streamable_http_server.rs | 130 ++++++++++++++++-- 4 files changed, 152 insertions(+), 34 deletions(-) diff --git a/crates/rust-mcp-sdk/tests/common/common.rs b/crates/rust-mcp-sdk/tests/common/common.rs index 31f9e9f..fabfb9e 100644 --- a/crates/rust-mcp-sdk/tests/common/common.rs +++ b/crates/rust-mcp-sdk/tests/common/common.rs @@ -132,10 +132,11 @@ pub async fn send_get_request( use futures::stream::Stream; // stream: &mut impl Stream>, +/// reads sse events and return them as (id, event, data) tuple pub async fn read_sse_event_from_stream( stream: &mut (impl Stream> + Unpin), event_count: usize, -) -> Option> { +) -> Option, Option, String)>> { let mut buffer = String::new(); let mut events = vec![]; @@ -146,27 +147,28 @@ pub async fn read_sse_event_from_stream( buffer.push_str(chunk_str); while let Some(pos) = buffer.find("\n\n") { - let data = { - // Scope to limit borrows - let (event_str, rest) = buffer.split_at(pos); - let mut data = None; - - // Process the event string - for line in event_str.lines() { - if line.starts_with("data:") { - data = Some(line.trim_start_matches("data:").trim().to_string()); - break; // Exit loop after finding data - } + let (event_str, rest) = buffer.split_at(pos); + let mut id = None; + let mut event = None; + let mut data = None; + + // Process the event string + for line in event_str.lines() { + if line.starts_with("id:") { + id = Some(line.trim_start_matches("id:").trim().to_string()); + } else if line.starts_with("event:") { + event = Some(line.trim_start_matches("event:").trim().to_string()); + } else if line.starts_with("data:") { + data = Some(line.trim_start_matches("data:").trim().to_string()); } + } - // Update buffer after processing - buffer = rest[2..].to_string(); // Skip "\n\n" - data - }; + // Update buffer after processing + buffer = rest[2..].to_string(); // Skip "\n\n" - // Return if data was found + // Only include events with data if let Some(data) = data { - events.push(data); + events.push((id, event, data)); if events.len().eq(&event_count) { return Some(events); } @@ -174,15 +176,22 @@ pub async fn read_sse_event_from_stream( } } Err(_e) => { - // return Err(TransportServerError::HyperError(e)); return None; } } } - None + if !events.is_empty() { + Some(events) + } else { + None + } } -pub async fn read_sse_event(response: Response, event_count: usize) -> Option> { +// return sse event as (id, event, data) tuple +pub async fn read_sse_event( + response: Response, + event_count: usize, +) -> Option, Option, String)>> { let mut stream = response.bytes_stream(); read_sse_event_from_stream(&mut stream, event_count).await } diff --git a/crates/rust-mcp-sdk/tests/common/test_server.rs b/crates/rust-mcp-sdk/tests/common/test_server.rs index 769f8c6..d64244b 100644 --- a/crates/rust-mcp-sdk/tests/common/test_server.rs +++ b/crates/rust-mcp-sdk/tests/common/test_server.rs @@ -7,6 +7,7 @@ pub mod test_server_common { CallToolRequest, CallToolResult, ListToolsRequest, ListToolsResult, ProtocolVersion, RpcError, }; + use rust_mcp_sdk::event_store::EventStore; use rust_mcp_sdk::id_generator::IdGenerator; use rust_mcp_sdk::mcp_server::hyper_runtime::HyperRuntime; use rust_mcp_sdk::schema::{ @@ -31,6 +32,7 @@ pub mod test_server_common { pub streamable_url: String, pub sse_url: String, pub sse_message_url: String, + pub event_store: Option>, } pub fn initialize_request() -> InitializeRequest { @@ -120,6 +122,7 @@ pub mod test_server_common { let sse_url = options.sse_url(); let sse_message_url = options.sse_message_url(); + let event_store_clone = options.event_store.clone(); let server = hyper_server::create_server(test_server_details(), TestServerHandler {}, options); @@ -132,6 +135,7 @@ pub mod test_server_common { streamable_url, sse_url, sse_message_url, + event_store: event_store_clone, } } diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs b/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs index cb82ff5..1d273e5 100644 --- a/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs +++ b/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs @@ -350,6 +350,7 @@ async fn should_receive_server_initiated_messaged() { streamable_url, sse_url, sse_message_url, + event_store, } = create_start_server(server_options).await; let (client, message_history) = create_client(&streamable_url, None).await; diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs b/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs index 4809d6d..edee3f6 100644 --- a/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs +++ b/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs @@ -12,7 +12,7 @@ use rust_mcp_schema::{ LoggingMessageNotificationParams, RequestId, RootsListChangedNotification, ServerNotification, ServerRequest, ServerResult, }; -use rust_mcp_sdk::mcp_server::HyperServerOptions; +use rust_mcp_sdk::{event_store::InMemoryEventStore, mcp_server::HyperServerOptions}; use serde_json::{json, Map, Value}; use crate::common::{ @@ -40,6 +40,7 @@ async fn initialize_server( "AAA-BBB-CCC".to_string() ]))), enable_json_response, + event_store: Some(Arc::new(InMemoryEventStore::default())), ..Default::default() }; @@ -169,7 +170,7 @@ async fn should_handle_post_requests_via_sse_response_correctly() { assert_eq!(response.status(), StatusCode::OK); let events = read_sse_event(response, 1).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0].2).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -220,7 +221,7 @@ async fn should_call_a_tool_and_return_the_result() { assert_eq!(response.status(), StatusCode::OK); let events = read_sse_event(response, 1).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0].2).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -345,7 +346,7 @@ async fn should_establish_standalone_stream_and_receive_server_messages() { .unwrap(); let events = read_sse_event(response, 1).await.unwrap(); - let message: ServerJsonrpcNotification = serde_json::from_str(&events[0]).unwrap(); + let message: ServerJsonrpcNotification = serde_json::from_str(&events[0].2).unwrap(); let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( notification, @@ -429,14 +430,14 @@ async fn should_establish_standalone_stream_and_receive_server_requests() { // read two events from the sse stream let events = read_sse_event(response, 2).await.unwrap(); - let message1: ServerJsonrpcRequest = serde_json::from_str(&events[0]).unwrap(); + let message1: ServerJsonrpcRequest = serde_json::from_str(&events[0].2).unwrap(); let RequestFromServer::ServerRequest(ServerRequest::ListRootsRequest(_)) = message1.request else { panic!("invalid message received!"); }; - let message2: ServerJsonrpcRequest = serde_json::from_str(&events[1]).unwrap(); + let message2: ServerJsonrpcRequest = serde_json::from_str(&events[1].2).unwrap(); let RequestFromServer::ServerRequest(ServerRequest::ListRootsRequest(_)) = message1.request else { @@ -472,7 +473,7 @@ async fn should_not_close_get_sse_stream() { let mut stream = response.bytes_stream(); let event = read_sse_event_from_stream(&mut stream, 1).await.unwrap()[0].clone(); - let message: ServerJsonrpcNotification = serde_json::from_str(&event).unwrap(); + let message: ServerJsonrpcNotification = serde_json::from_str(&event.2).unwrap(); let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( notification, @@ -501,7 +502,7 @@ async fn should_not_close_get_sse_stream() { .unwrap(); let event = read_sse_event_from_stream(&mut stream, 1).await.unwrap()[0].clone(); - let message: ServerJsonrpcNotification = serde_json::from_str(&event).unwrap(); + let message: ServerJsonrpcNotification = serde_json::from_str(&event.2).unwrap(); let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( notification_2, @@ -713,7 +714,7 @@ async fn should_send_response_messages_to_the_connection_that_sent_the_request() assert_eq!(response_2.status(), StatusCode::OK); let events = read_sse_event(response_2, 1).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0].2).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -729,7 +730,7 @@ async fn should_send_response_messages_to_the_connection_that_sent_the_request() ); let events = read_sse_event(response_1, 1).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0].2).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -1080,7 +1081,7 @@ async fn should_handle_batch_request_messages_with_sse_stream_for_responses() { ); let events = read_sse_event(response, 1).await.unwrap(); - let message: ServerMessages = serde_json::from_str(&events[0]).unwrap(); + let message: ServerMessages = serde_json::from_str(&events[0].2).unwrap(); let ServerMessages::Batch(mut messages) = message else { panic!("Invalid message type"); @@ -1358,5 +1359,108 @@ async fn should_skip_all_validations_when_false() { server.hyper_runtime.await_server().await.unwrap() } -//TODO: -// should return 400 error for invalid JSON-RPC messages +#[tokio::test] +async fn should_store_and_include_event_ids_in_server_sse_messages() { + let (server, session_id) = initialize_server(Some(true)).await.unwrap(); + let response = get_standalone_stream(&server.streamable_url, &session_id).await; + + assert_eq!(response.status(), StatusCode::OK); + + let _ = server + .hyper_runtime + .send_logging_message( + &session_id, + LoggingMessageNotificationParams { + data: json!("notification1"), + level: LoggingLevel::Info, + logger: None, + }, + ) + .await; + + let _ = server + .hyper_runtime + .send_logging_message( + &session_id, + LoggingMessageNotificationParams { + data: json!("notification2"), + level: LoggingLevel::Info, + logger: None, + }, + ) + .await; + + // read two events + let events = read_sse_event(response, 2).await.unwrap(); + assert_eq!(events.len(), 2); + // verify we got the notification with an event ID + let (first_id, _, data) = events[0].clone(); + let (second_id, _, _) = events[0].clone(); + + let message: ServerJsonrpcNotification = serde_json::from_str(&data).unwrap(); + + let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( + notification1, + )) = message.notification + else { + panic!("invalid message received!"); + }; + + assert_eq!(notification1.params.data.as_str().unwrap(), "notification1"); + + let first_id = first_id.unwrap(); + let second_id = second_id.unwrap(); + + //messages should be stored and accessible + let events = server + .event_store + .unwrap() + .events_after(first_id) + .await + .unwrap(); + assert_eq!(events.messages.len(), 1); + + // deserialize the message returned by event_store + let message: ServerJsonrpcNotification = serde_json::from_str(&events.messages[0]).unwrap(); + let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( + notification2, + )) = message.notification + else { + panic!("invalid message in store!"); + }; + assert_eq!(notification2.params.data.as_str().unwrap(), "notification2"); +} + +// TODO: should return 400 error for invalid JSON-RPC messages + +// should store and include event IDs in server SSE messages +// should store and replay MCP server tool notifications +// +// should keep stream open after sending server notifications + +// NA: should reject second initialization request +// NA: should pass request info to tool callback +// NA: should reject second SSE stream even in stateless mode + +// should reject requests to uninitialized server +// should accept requests with matching protocol version +// should accept when protocol version differs from negotiated version +// should call a tool with authInfo +// should calls tool without authInfo when it is optional +// should accept pre-parsed request body +// should handle pre-parsed batch messages +// should prefer pre-parsed body over request body + +// should operate without session ID validation +// should handle POST requests with various session IDs in stateless mode +// should call onsessionclosed callback when session is closed via DELETE +// should not call onsessionclosed callback when not provided +// should not call onsessionclosed callback for invalid session DELETE +// should call onsessionclosed callback with correct session ID when multiple sessions exist +// should support async onsessioninitialized callback +// should support sync onsessioninitialized callback (backwards compatibility) +// should support async onsessionclosed callback +// should propagate errors from async onsessioninitialized callback +// should propagate errors from async onsessionclosed callback +// should handle both async callbacks together +// should validate both host and origin when both are configured From c446a757c9885f8f905944b94f36489207150c96 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Tue, 16 Sep 2025 05:58:58 -0300 Subject: [PATCH 13/17] chore: add tracing to tests --- crates/rust-mcp-sdk/tests/common/common.rs | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/crates/rust-mcp-sdk/tests/common/common.rs b/crates/rust-mcp-sdk/tests/common/common.rs index fabfb9e..6b78895 100644 --- a/crates/rust-mcp-sdk/tests/common/common.rs +++ b/crates/rust-mcp-sdk/tests/common/common.rs @@ -11,9 +11,11 @@ use rust_mcp_sdk::mcp_client::ClientHandler; use rust_mcp_sdk::schema::{ClientCapabilities, Implementation, InitializeRequestParams}; use std::collections::HashMap; use std::process; +use std::sync::Once; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use tokio::time::timeout; use tokio_stream::StreamExt; +use tracing_subscriber::EnvFilter; use wiremock::{MockServer, Request, ResponseTemplate}; pub use test_client::*; @@ -23,7 +25,17 @@ pub const NPX_SERVER_EVERYTHING: &str = "@modelcontextprotocol/server-everything #[cfg(unix)] pub const UVX_SERVER_GIT: &str = "mcp-server-git"; +static INIT: Once = Once::new(); +pub fn init_tracing() { + INIT.call_once(|| { + let filter = EnvFilter::try_from_default_env() + .or_else(|_| EnvFilter::try_new("tracing")) + .unwrap(); + + tracing_subscriber::fmt().with_env_filter(filter).init(); + }); +} #[mcp_tool( name = "say_hello", description = "Accepts a person's name and says a personalized \"Hello\" to that person", @@ -126,6 +138,7 @@ pub async fn send_get_request( ); } } + client.get(url).headers(headers).send().await } @@ -193,7 +206,9 @@ pub async fn read_sse_event( event_count: usize, ) -> Option, Option, String)>> { let mut stream = response.bytes_stream(); - read_sse_event_from_stream(&mut stream, event_count).await + let events = read_sse_event_from_stream(&mut stream, event_count).await; + // drop(stream); + events } pub fn test_client_info() -> InitializeRequestParams { From cc91b0f1a6edcad68edfb28e908b04ace0945eaf Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Tue, 16 Sep 2025 05:59:34 -0300 Subject: [PATCH 14/17] chore: add test --- .../tests/test_streamable_http_server.rs | 98 ++++++++++++++++--- 1 file changed, 87 insertions(+), 11 deletions(-) diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs b/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs index edee3f6..26c0bb8 100644 --- a/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs +++ b/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs @@ -40,6 +40,7 @@ async fn initialize_server( "AAA-BBB-CCC".to_string() ]))), enable_json_response, + ping_interval: Duration::from_secs(1), event_store: Some(Arc::new(InMemoryEventStore::default())), ..Default::default() }; @@ -291,12 +292,20 @@ async fn should_reject_invalid_session_id() { server.hyper_runtime.await_server().await.unwrap() } -async fn get_standalone_stream(streamable_url: &str, session_id: &str) -> reqwest::Response { +async fn get_standalone_stream( + streamable_url: &str, + session_id: &str, + last_event_id: Option<&str>, +) -> reqwest::Response { let mut headers = HashMap::new(); headers.insert("Accept", "text/event-stream , application/json"); headers.insert("mcp-session-id", session_id); headers.insert("mcp-protocol-version", "2025-03-26"); + if let Some(last_event_id) = last_event_id.clone() { + headers.insert("last-event-id", last_event_id); + } + let response = send_get_request(streamable_url, Some(headers)) .await .unwrap(); @@ -307,7 +316,7 @@ async fn get_standalone_stream(streamable_url: &str, session_id: &str) -> reqwes #[tokio::test] async fn should_establish_standalone_stream_and_receive_server_messages() { let (server, session_id) = initialize_server(None).await.unwrap(); - let response = get_standalone_stream(&server.streamable_url, &session_id).await; + let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; assert_eq!(response.status(), StatusCode::OK); @@ -369,7 +378,7 @@ async fn should_establish_standalone_stream_and_receive_server_messages() { #[tokio::test] async fn should_establish_standalone_stream_and_receive_server_requests() { let (server, session_id) = initialize_server(None).await.unwrap(); - let response = get_standalone_stream(&server.streamable_url, &session_id).await; + let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; assert_eq!(response.status(), StatusCode::OK); @@ -454,7 +463,7 @@ async fn should_establish_standalone_stream_and_receive_server_requests() { #[tokio::test] async fn should_not_close_get_sse_stream() { let (server, session_id) = initialize_server(None).await.unwrap(); - let response = get_standalone_stream(&server.streamable_url, &session_id).await; + let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; assert_eq!(response.status(), StatusCode::OK); @@ -525,10 +534,10 @@ async fn should_not_close_get_sse_stream() { #[tokio::test] async fn should_reject_second_sse_stream_for_the_same_session() { let (server, session_id) = initialize_server(None).await.unwrap(); - let response = get_standalone_stream(&server.streamable_url, &session_id).await; + let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; assert_eq!(response.status(), StatusCode::OK); - let second_response = get_standalone_stream(&server.streamable_url, &session_id).await; + let second_response = get_standalone_stream(&server.streamable_url, &session_id, None).await; assert_eq!(second_response.status(), StatusCode::CONFLICT); let error_data: SdkError = second_response.json().await.unwrap(); @@ -1359,10 +1368,11 @@ async fn should_skip_all_validations_when_false() { server.hyper_runtime.await_server().await.unwrap() } +// should store and include event IDs in server SSE messages #[tokio::test] async fn should_store_and_include_event_ids_in_server_sse_messages() { let (server, session_id) = initialize_server(Some(true)).await.unwrap(); - let response = get_standalone_stream(&server.streamable_url, &session_id).await; + let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; assert_eq!(response.status(), StatusCode::OK); @@ -1409,7 +1419,7 @@ async fn should_store_and_include_event_ids_in_server_sse_messages() { assert_eq!(notification1.params.data.as_str().unwrap(), "notification1"); let first_id = first_id.unwrap(); - let second_id = second_id.unwrap(); + assert!(second_id.is_some()); //messages should be stored and accessible let events = server @@ -1431,10 +1441,76 @@ async fn should_store_and_include_event_ids_in_server_sse_messages() { assert_eq!(notification2.params.data.as_str().unwrap(), "notification2"); } -// TODO: should return 400 error for invalid JSON-RPC messages - -// should store and include event IDs in server SSE messages // should store and replay MCP server tool notifications +#[tokio::test] +async fn should_store_and_replay_mcp_server_tool_notifications() { + common::init_tracing(); + let (server, session_id) = initialize_server(Some(true)).await.unwrap(); + let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; + assert_eq!(response.status(), StatusCode::OK); + + let _ = server + .hyper_runtime + .send_logging_message( + &session_id, + LoggingMessageNotificationParams { + data: json!("notification1"), + level: LoggingLevel::Info, + logger: None, + }, + ) + .await; + + let events = read_sse_event(response, 1).await.unwrap(); + assert_eq!(events.len(), 1); + // verify we got the notification with an event ID + let (first_id, _, data) = events[0].clone(); + + let message: ServerJsonrpcNotification = serde_json::from_str(&data).unwrap(); + + let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( + notification1, + )) = message.notification + else { + panic!("invalid message received!"); + }; + + assert_eq!(notification1.params.data.as_str().unwrap(), "notification1"); + + let first_id = first_id.unwrap(); + + // sse connection is closed in read_sse_event() + // wait 4 seconds so server detect the disconnect and simulate a network error + tokio::time::sleep(Duration::from_secs(4)).await; + // we send another notification while SSE is disconnected + let result = server + .hyper_runtime + .send_logging_message( + &session_id, + LoggingMessageNotificationParams { + data: json!("notification1"), + level: LoggingLevel::Info, + logger: None, + }, + ) + .await; + + println!(">>> result {:?} ", result); + + // make a new standalone SSE connection to simulate a re-connection + let response = + get_standalone_stream(&server.streamable_url, &session_id, Some(&first_id)).await; + assert_eq!(response.status(), StatusCode::OK); + println!(">>> 90 {:?} ", 90); + + let events = read_sse_event(response, 1).await.unwrap(); + println!(">>> 100 {:?} ", 100); + + // assert_eq!(events.len(), 1); + // println!(">>> {:?} ", events); +} + +// TODO: should return 400 error for invalid JSON-RPC messages // // should keep stream open after sending server notifications From 14b072e24c032370d4012295444d489f67b05c06 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Tue, 16 Sep 2025 09:43:57 -0300 Subject: [PATCH 15/17] chore: refactor replaying logic --- .../src/hyper_servers/routes/hyper_utils.rs | 66 +++++++-------- .../src/mcp_runtimes/server_runtime.rs | 5 +- crates/rust-mcp-sdk/src/utils.rs | 8 -- .../tests/test_streamable_http_server.rs | 34 ++++---- crates/rust-mcp-transport/src/client_sse.rs | 4 +- .../src/client_streamable_http.rs | 4 +- .../src/event_store/in_memory_event_store.rs | 2 +- .../src/message_dispatcher.rs | 83 +++++++++++++++---- crates/rust-mcp-transport/src/sse.rs | 44 ++++++++-- crates/rust-mcp-transport/src/stdio.rs | 8 +- crates/rust-mcp-transport/src/transport.rs | 2 +- crates/rust-mcp-transport/src/utils.rs | 2 + .../src/utils/time_utils.rs | 8 ++ 13 files changed, 178 insertions(+), 92 deletions(-) create mode 100644 crates/rust-mcp-transport/src/utils/time_utils.rs diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs index f7f0874..01464bb 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs @@ -23,10 +23,10 @@ use axum::{ use futures::stream; use hyper::{header, HeaderMap, StatusCode}; use rust_mcp_transport::{ - EventId, McpDispatch, SessionId, SseTransport, StreamId, MCP_PROTOCOL_VERSION_HEADER, - MCP_SESSION_ID_HEADER, + EventId, McpDispatch, SessionId, SseTransport, StreamId, ID_SEPARATOR, + MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, }; -use std::{sync::Arc, time::Duration}; +use std::{clone, sync::Arc, time::Duration}; use tokio::io::{duplex, AsyncBufReadExt, BufReader}; const DUPLEX_BUFFER_SIZE: usize = 8192; @@ -55,16 +55,6 @@ async fn create_sse_stream( // writable stream to deliver message to the client let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); - let transport = Arc::new( - SseTransport::::new( - read_rx, - write_tx, - read_tx, - Arc::clone(&state.transport_options), - ) - .map_err(|err| TransportServerError::TransportError(err.to_string()))?, - ); - let session_id = Arc::new(session_id); let stream_id: Arc = if standalone { Arc::new(DEFAULT_STREAM_ID.to_string()) @@ -72,6 +62,21 @@ async fn create_sse_stream( Arc::new(state.stream_id_gen.generate()) }; + let event_store = state.event_store.as_ref().map(Arc::clone); + let resumability_enabled = event_store.is_some(); + + let mut transport = SseTransport::::new( + read_rx, + write_tx, + read_tx, + Arc::clone(&state.transport_options), + ) + .map_err(|err| TransportServerError::TransportError(err.to_string()))?; + if let Some(event_store) = event_store.clone() { + transport.make_resumable((*session_id).clone(), (*stream_id).clone(), event_store); + } + let transport = Arc::new(transport); + let ping_interval = state.ping_interval; let runtime_clone = Arc::clone(&runtime); let stream_id_clone = stream_id.clone(); @@ -96,14 +101,9 @@ async fn create_sse_stream( // Construct SSE stream let reader = BufReader::new(write_rx); - let session_id_clone = session_id.clone(); - let event_store = state.event_store.as_ref().map(Arc::clone); // send outgoing messages from server to the client over the sse stream let message_stream = stream::unfold(reader, move |mut reader| { - let session_id = session_id_clone.clone(); - let stream_id = stream_id.clone(); - let event_store = event_store.clone(); async move { let mut line = String::new(); @@ -117,24 +117,17 @@ async fn create_sse_stream( return Some((Ok(Event::default()), reader)); } - let mut event_id: Option = None; - // store the event for resumption if it is supported - if let Some(event_store) = event_store { - event_id = Some( - event_store - .store_event( - (*session_id).clone(), - (*stream_id).clone(), - current_timestamp(), - trimmed_line.clone(), - ) - .await, - ); - } + let (event_id, message) = match ( + resumability_enabled, + trimmed_line.split_once(char::from(ID_SEPARATOR)), + ) { + (true, Some((id, msg))) => (Some(id.to_string()), msg.to_string()), + _ => (None, trimmed_line), + }; let event = match event_id { - Some(id) => Event::default().data(trimmed_line).id(id), - None => Event::default().data(trimmed_line), + Some(id) => Event::default().data(message).id(id), + None => Event::default().data(message), }; Some((Ok(event), reader)) @@ -161,7 +154,8 @@ async fn create_sse_stream( if let Some(event_store) = state.event_store.as_ref() { if let Some(events) = event_store.events_after(last_event_id).await { for message_payload in events.messages { - let error = transport.write_str(&message_payload).await; + // skip storing replay messages + let error = transport.write_str(&message_payload, true).await; if let Err(error) = error { tracing::trace!("Error replaying message: {error}") } @@ -224,7 +218,7 @@ pub async fn create_standalone_stream( if let Some(last_event_id) = last_event_id.as_ref() { tracing::trace!( - "SSE stream rec-connected with last-event-id: {}", + "SSE stream re-connected with last-event-id: {}", last_event_id ); } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index 10ad896..f29805d 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -368,16 +368,17 @@ impl ServerRuntime { Ok(()) } + //TODO: re-visit and simlify unnecesarry hashmap pub(crate) async fn remove_transport(&self, stream_id: &str) -> SdkResult<()> { if stream_id != DEFAULT_STREAM_ID { return Ok(()); } - let mut transport_map = self.transport_map.write().await; + let transport_map = self.transport_map.read().await; tracing::trace!("removing transport for stream id : {}", stream_id); if let Some(transport) = transport_map.get(stream_id) { transport.shut_down().await?; } - transport_map.remove(stream_id); + // transport_map.remove(stream_id); Ok(()) } diff --git a/crates/rust-mcp-sdk/src/utils.rs b/crates/rust-mcp-sdk/src/utils.rs index cfdce16..16fe7c7 100644 --- a/crates/rust-mcp-sdk/src/utils.rs +++ b/crates/rust-mcp-sdk/src/utils.rs @@ -3,7 +3,6 @@ use crate::schema::schema_utils::{ClientMessages, SdkError}; use crate::error::{McpSdkError, ProtocolErrorKind, SdkResult}; use crate::schema::ProtocolVersion; use std::cmp::Ordering; -use std::time::{SystemTime, UNIX_EPOCH}; /// A guard type that automatically aborts a Tokio task when dropped. /// @@ -235,13 +234,6 @@ pub fn valid_initialize_method(json_str: &str) -> SdkResult<()> { Ok(()) } -pub fn current_timestamp() -> u128 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("Invalid time") - .as_nanos() // or `.as_millis()` or `.as_nanos()` if you want higher precision -} - #[cfg(test)] mod tests { use super::*; diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs b/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs index 26c0bb8..af2dce6 100644 --- a/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs +++ b/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs @@ -1371,6 +1371,7 @@ async fn should_skip_all_validations_when_false() { // should store and include event IDs in server SSE messages #[tokio::test] async fn should_store_and_include_event_ids_in_server_sse_messages() { + common::init_tracing(); let (server, session_id) = initialize_server(Some(true)).await.unwrap(); let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; @@ -1480,44 +1481,46 @@ async fn should_store_and_replay_mcp_server_tool_notifications() { let first_id = first_id.unwrap(); // sse connection is closed in read_sse_event() - // wait 4 seconds so server detect the disconnect and simulate a network error - tokio::time::sleep(Duration::from_secs(4)).await; + // wait so server detect the disconnect and simulate a network error + tokio::time::sleep(Duration::from_secs(3)).await; + tokio::task::yield_now().await; // we send another notification while SSE is disconnected - let result = server + let _result = server .hyper_runtime .send_logging_message( &session_id, LoggingMessageNotificationParams { - data: json!("notification1"), + data: json!("notification2"), level: LoggingLevel::Info, logger: None, }, ) .await; - println!(">>> result {:?} ", result); - // make a new standalone SSE connection to simulate a re-connection let response = get_standalone_stream(&server.streamable_url, &session_id, Some(&first_id)).await; assert_eq!(response.status(), StatusCode::OK); - println!(">>> 90 {:?} ", 90); - let events = read_sse_event(response, 1).await.unwrap(); - println!(">>> 100 {:?} ", 100); - // assert_eq!(events.len(), 1); - // println!(">>> {:?} ", events); + assert_eq!(events.len(), 1); + let message: ServerJsonrpcNotification = serde_json::from_str(&events[0].2).unwrap(); + + let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( + notification1, + )) = message.notification + else { + panic!("invalid message received!"); + }; + + assert_eq!(notification1.params.data.as_str().unwrap(), "notification2"); } -// TODO: should return 400 error for invalid JSON-RPC messages -// +// should return 400 error for invalid JSON-RPC messages // should keep stream open after sending server notifications - // NA: should reject second initialization request // NA: should pass request info to tool callback // NA: should reject second SSE stream even in stateless mode - // should reject requests to uninitialized server // should accept requests with matching protocol version // should accept when protocol version differs from negotiated version @@ -1526,7 +1529,6 @@ async fn should_store_and_replay_mcp_server_tool_notifications() { // should accept pre-parsed request body // should handle pre-parsed batch messages // should prefer pre-parsed body over request body - // should operate without session ID validation // should handle POST requests with various session IDs in stateless mode // should call onsessionclosed callback when session is closed via DELETE diff --git a/crates/rust-mcp-transport/src/client_sse.rs b/crates/rust-mcp-transport/src/client_sse.rs index 8d55bd0..0a1e8f3 100644 --- a/crates/rust-mcp-transport/src/client_sse.rs +++ b/crates/rust-mcp-transport/src/client_sse.rs @@ -457,10 +457,10 @@ impl McpDispatch sender.send_batch(message, request_timeout).await } - async fn write_str(&self, payload: &str) -> TransportResult<()> { + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> { let sender = self.message_sender.read().await; let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; - sender.write_str(payload).await + sender.write_str(payload, skip_store).await } } diff --git a/crates/rust-mcp-transport/src/client_streamable_http.rs b/crates/rust-mcp-transport/src/client_streamable_http.rs index c318649..edda062 100644 --- a/crates/rust-mcp-transport/src/client_streamable_http.rs +++ b/crates/rust-mcp-transport/src/client_streamable_http.rs @@ -496,10 +496,10 @@ impl McpDispatch sender.send_batch(message, request_timeout).await } - async fn write_str(&self, payload: &str) -> TransportResult<()> { + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> { let sender = self.message_sender.read().await; let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; - sender.write_str(payload).await + sender.write_str(payload, skip_store).await } } diff --git a/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs b/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs index c31902d..f258567 100644 --- a/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs +++ b/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs @@ -153,7 +153,7 @@ impl EventStore for InMemoryEventStore { let mut storage_map = self.storage_map.write().await; tracing::trace!( - "Storing event for session: {session_id}, stream_id: {stream_id}, message: {message}, {time_stamp} ", + "Storing event for session: {session_id}, stream_id: {stream_id}, message: '{message}', {time_stamp} ", ); let session_map = storage_map diff --git a/crates/rust-mcp-transport/src/message_dispatcher.rs b/crates/rust-mcp-transport/src/message_dispatcher.rs index 7c7c93e..cd9727c 100644 --- a/crates/rust-mcp-transport/src/message_dispatcher.rs +++ b/crates/rust-mcp-transport/src/message_dispatcher.rs @@ -1,13 +1,20 @@ -use crate::schema::{ - schema_utils::{ - self, ClientMessage, ClientMessages, McpMessage, RpcMessage, ServerMessage, ServerMessages, +use crate::error::{TransportError, TransportResult}; +use crate::schema::{RequestId, RpcError}; +use crate::utils::{await_timeout, current_timestamp}; +use crate::McpDispatch; +use crate::{ + event_store::EventStore, + schema::{ + schema_utils::{ + self, ClientMessage, ClientMessages, McpMessage, RpcMessage, ServerMessage, + ServerMessages, + }, + JsonrpcError, }, - JsonrpcError, + SessionId, StreamId, }; -use crate::schema::{RequestId, RpcError}; use async_trait::async_trait; use futures::future::join_all; - use std::collections::HashMap; use std::pin::Pin; use std::sync::Arc; @@ -16,9 +23,7 @@ use tokio::io::AsyncWriteExt; use tokio::sync::oneshot::{self}; use tokio::sync::Mutex; -use crate::error::{TransportError, TransportResult}; -use crate::utils::await_timeout; -use crate::McpDispatch; +pub const ID_SEPARATOR: u8 = b'|'; /// Provides a dispatcher for sending MCP messages and handling responses. /// @@ -37,6 +42,10 @@ pub struct MessageDispatcher { )>, >, request_timeout: Duration, + // resumability support + session_id: Option, + stream_id: Option, + event_store: Option>, } impl MessageDispatcher { @@ -60,6 +69,9 @@ impl MessageDispatcher { writable_std: Some(writable_std), writable_tx: None, request_timeout, + session_id: None, + stream_id: None, + event_store: None, } } @@ -76,9 +88,25 @@ impl MessageDispatcher { writable_tx: Some(writable_tx), writable_std: None, request_timeout, + session_id: None, + stream_id: None, + event_store: None, } } + /// Supports resumability for streamable HTTP transports by setting the session ID, + /// stream ID, and event store. + pub fn make_resumable( + &mut self, + session_id: SessionId, + stream_id: StreamId, + event_store: Arc, + ) { + self.session_id = Some(session_id); + self.stream_id = Some(stream_id); + self.event_store = Some(event_store); + } + async fn store_pending_request( &self, request_id: RequestId, @@ -141,7 +169,7 @@ impl McpDispatch crate::error::TransportError::JsonrpcError(RpcError::parse_error()) })?; - self.write_str(message_payload.as_str()).await?; + self.write_str(message_payload.as_str(), true).await?; if let Some(rx) = rx_response { // Wait for the response with timeout @@ -177,7 +205,7 @@ impl McpDispatch let message_payload = serde_json::to_string(&client_messages).map_err(|_| { crate::error::TransportError::JsonrpcError(RpcError::parse_error()) })?; - self.write_str(message_payload.as_str()).await?; + self.write_str(message_payload.as_str(), true).await?; // no request in the batch, no need to wait for the result if request_ids.is_empty() { @@ -233,7 +261,7 @@ impl McpDispatch /// Writes a string payload to the underlying asynchronous writable stream, /// appending a newline character and flushing the stream afterward. /// - async fn write_str(&self, payload: &str) -> TransportResult<()> { + async fn write_str(&self, payload: &str, _skip_store: bool) -> TransportResult<()> { if let Some(writable_std) = self.writable_std.as_ref() { let mut writable_std = writable_std.lock().await; writable_std.write_all(payload.as_bytes()).await?; @@ -289,7 +317,7 @@ impl McpDispatch crate::error::TransportError::JsonrpcError(RpcError::parse_error()) })?; - self.write_str(message_payload.as_str()).await?; + self.write_str(message_payload.as_str(), false).await?; if let Some(rx) = rx_response { match await_timeout(rx, request_timeout.unwrap_or(self.request_timeout)).await { @@ -317,7 +345,7 @@ impl McpDispatch crate::error::TransportError::JsonrpcError(RpcError::parse_error()) })?; - self.write_str(message_payload.as_str()).await?; + self.write_str(message_payload.as_str(), false).await?; // no request in the batch, no need to wait for the result if pending_tasks.is_empty() { @@ -375,9 +403,34 @@ impl McpDispatch /// Writes a string payload to the underlying asynchronous writable stream, /// appending a newline character and flushing the stream afterward. /// - async fn write_str(&self, payload: &str) -> TransportResult<()> { + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> { + let mut event_id = None; + + if !skip_store && !payload.trim().is_empty() { + if let (Some(session_id), Some(stream_id), Some(event_store)) = ( + self.session_id.as_ref(), + self.stream_id.as_ref(), + self.event_store.as_ref(), + ) { + event_id = Some( + event_store + .store_event( + session_id.clone(), + stream_id.clone(), + current_timestamp(), + payload.to_owned(), + ) + .await, + ) + }; + } + if let Some(writable_std) = self.writable_std.as_ref() { let mut writable_std = writable_std.lock().await; + if let Some(id) = event_id { + writable_std.write_all(id.as_bytes()).await?; + writable_std.write_all(&[ID_SEPARATOR]).await?; // separate id from message + } writable_std.write_all(payload.as_bytes()).await?; writable_std.write_all(b"\n").await?; // new line writable_std.flush().await?; diff --git a/crates/rust-mcp-transport/src/sse.rs b/crates/rust-mcp-transport/src/sse.rs index f019c60..daaba80 100644 --- a/crates/rust-mcp-transport/src/sse.rs +++ b/crates/rust-mcp-transport/src/sse.rs @@ -1,8 +1,10 @@ +use crate::event_store::EventStore; use crate::schema::schema_utils::{ ClientMessage, ClientMessages, MessageFromServer, SdkError, ServerMessage, ServerMessages, }; use crate::schema::RequestId; use async_trait::async_trait; +use rust_mcp_schema::BooleanSchema; use serde::de::DeserializeOwned; use std::collections::HashMap; use std::pin::Pin; @@ -19,7 +21,7 @@ use crate::mcp_stream::MCPStream; use crate::message_dispatcher::MessageDispatcher; use crate::transport::Transport; use crate::utils::{endpoint_with_session_id, CancellationTokenSource}; -use crate::{IoStream, McpDispatch, SessionId, TransportDispatcher, TransportOptions}; +use crate::{IoStream, McpDispatch, SessionId, StreamId, TransportDispatcher, TransportOptions}; pub struct SseTransport where @@ -33,6 +35,10 @@ where message_sender: Arc>>>, error_stream: tokio::sync::RwLock>, pending_requests: Arc>>>, + // resumability support + session_id: Option, + stream_id: Option, + event_store: Option>, } /// Server-Sent Events (SSE) transport implementation @@ -67,6 +73,9 @@ where message_sender: Arc::new(tokio::sync::RwLock::new(None)), error_stream: tokio::sync::RwLock::new(None), pending_requests: Arc::new(Mutex::new(HashMap::new())), + session_id: None, + stream_id: None, + event_store: None, }) } @@ -86,6 +95,19 @@ where let mut lock = self.error_stream.write().await; *lock = Some(IoStream::Writable(error_stream)); } + + /// Supports resumability for streamable HTTP transports by setting the session ID, + /// stream ID, and event store. + pub fn make_resumable( + &mut self, + session_id: SessionId, + stream_id: StreamId, + event_store: Arc, + ) { + self.session_id = Some(session_id); + self.stream_id = Some(stream_id); + self.event_store = Some(event_store); + } } #[async_trait] @@ -123,10 +145,10 @@ impl McpDispatch sender.send_batch(message, request_timeout).await } - async fn write_str(&self, payload: &str) -> TransportResult<()> { + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> { let sender = self.message_sender.read().await; let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; - sender.write_str(payload).await + sender.write_str(payload, skip_store).await } } @@ -161,7 +183,7 @@ impl Transport( + let (stream, mut sender, error_stream) = MCPStream::create::( Box::pin(read_rx), Mutex::new(Box::pin(write_tx)), IoStream::Writable(Box::pin(tokio::io::stderr())), @@ -170,6 +192,18 @@ impl Transport {} Err(TransportError::Io(error)) => { if error.kind() == std::io::ErrorKind::BrokenPipe { diff --git a/crates/rust-mcp-transport/src/stdio.rs b/crates/rust-mcp-transport/src/stdio.rs index 11bd0a6..7678c65 100644 --- a/crates/rust-mcp-transport/src/stdio.rs +++ b/crates/rust-mcp-transport/src/stdio.rs @@ -348,10 +348,10 @@ impl McpDispatch sender.send_batch(message, request_timeout).await } - async fn write_str(&self, payload: &str) -> TransportResult<()> { + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> { let sender = self.message_sender.read().await; let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; - sender.write_str(payload).await + sender.write_str(payload, skip_store).await } } @@ -400,10 +400,10 @@ impl McpDispatch sender.send_batch(message, request_timeout).await } - async fn write_str(&self, payload: &str) -> TransportResult<()> { + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> { let sender = self.message_sender.read().await; let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; - sender.write_str(payload).await + sender.write_str(payload, skip_store).await } } diff --git a/crates/rust-mcp-transport/src/transport.rs b/crates/rust-mcp-transport/src/transport.rs index b8e3ddc..a9e7190 100644 --- a/crates/rust-mcp-transport/src/transport.rs +++ b/crates/rust-mcp-transport/src/transport.rs @@ -82,7 +82,7 @@ where /// Writes a string payload to the underlying asynchronous writable stream, /// appending a newline character and flushing the stream afterward. /// - async fn write_str(&self, payload: &str) -> TransportResult<()>; + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()>; } /// A trait representing the transport layer for the MCP (Message Communication Protocol). diff --git a/crates/rust-mcp-transport/src/utils.rs b/crates/rust-mcp-transport/src/utils.rs index 82d7326..034f062 100644 --- a/crates/rust-mcp-transport/src/utils.rs +++ b/crates/rust-mcp-transport/src/utils.rs @@ -25,6 +25,8 @@ pub(crate) use sse_stream::*; pub(crate) use streamable_http_stream::*; #[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) use writable_channel::*; +mod time_utils; +pub use time_utils::*; use crate::schema::schema_utils::SdkError; use tokio::time::{timeout, Duration}; diff --git a/crates/rust-mcp-transport/src/utils/time_utils.rs b/crates/rust-mcp-transport/src/utils/time_utils.rs new file mode 100644 index 0000000..25c4f5d --- /dev/null +++ b/crates/rust-mcp-transport/src/utils/time_utils.rs @@ -0,0 +1,8 @@ +use std::time::{SystemTime, UNIX_EPOCH}; + +pub fn current_timestamp() -> u128 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Invalid time") + .as_nanos() +} From feab8d2191d7bc99272a226b5d89b8bedb2494bb Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Tue, 16 Sep 2025 09:47:37 -0300 Subject: [PATCH 16/17] chore: cleanup --- crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs | 4 ++-- crates/rust-mcp-transport/src/sse.rs | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs index 01464bb..7101a73 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs @@ -7,7 +7,7 @@ use crate::{ mcp_runtimes::server_runtime::DEFAULT_STREAM_ID, mcp_server::{server_runtime, ServerRuntime}, mcp_traits::{mcp_handler::McpServerHandler, IdGenerator}, - utils::{current_timestamp, validate_mcp_protocol_version}, + utils::validate_mcp_protocol_version, }; use crate::schema::schema_utils::{ClientMessage, SdkError}; @@ -26,7 +26,7 @@ use rust_mcp_transport::{ EventId, McpDispatch, SessionId, SseTransport, StreamId, ID_SEPARATOR, MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, }; -use std::{clone, sync::Arc, time::Duration}; +use std::{sync::Arc, time::Duration}; use tokio::io::{duplex, AsyncBufReadExt, BufReader}; const DUPLEX_BUFFER_SIZE: usize = 8192; diff --git a/crates/rust-mcp-transport/src/sse.rs b/crates/rust-mcp-transport/src/sse.rs index daaba80..89ca67f 100644 --- a/crates/rust-mcp-transport/src/sse.rs +++ b/crates/rust-mcp-transport/src/sse.rs @@ -4,7 +4,6 @@ use crate::schema::schema_utils::{ }; use crate::schema::RequestId; use async_trait::async_trait; -use rust_mcp_schema::BooleanSchema; use serde::de::DeserializeOwned; use std::collections::HashMap; use std::pin::Pin; From 46b29615b689eda77775d339afd2d820c9f79be0 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Thu, 18 Sep 2025 18:48:39 -0300 Subject: [PATCH 17/17] typo --- crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs | 2 +- .../src/event_store/in_memory_event_store.rs | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index f29805d..5502cee 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -368,7 +368,7 @@ impl ServerRuntime { Ok(()) } - //TODO: re-visit and simlify unnecesarry hashmap + //TODO: re-visit and simplify unnecessary hashmap pub(crate) async fn remove_transport(&self, stream_id: &str) -> SdkResult<()> { if stream_id != DEFAULT_STREAM_ID { return Ok(()); diff --git a/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs b/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs index f258567..66e738c 100644 --- a/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs +++ b/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs @@ -9,7 +9,7 @@ use crate::{ }; const MAX_EVENTS_PER_SESSION: usize = 64; -const ID_SEPERATOR: &str = "-.-"; +const ID_SEPARATOR: &str = "-.-"; #[derive(Debug, Clone)] struct EventEntry { @@ -76,7 +76,7 @@ impl InMemoryEventStore { stream_id: &StreamId, time_stamp: u128, ) -> String { - format!("{session_id}{ID_SEPERATOR}{stream_id}{ID_SEPERATOR}{time_stamp}") + format!("{session_id}{ID_SEPARATOR}{stream_id}{ID_SEPARATOR}{time_stamp}") } /// Parses an event ID into its session, stream, and timestamp components. @@ -108,7 +108,7 @@ impl InMemoryEventStore { } // Split into exactly three parts - let parts: Vec<&'a str> = event_id.split(ID_SEPERATOR).collect(); + let parts: Vec<&'a str> = event_id.split(ID_SEPARATOR).collect(); if parts.len() != 3 { return None; }