From 6361a2eb75531b364c80684d3dfdec858b7bcab5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=9C=E8=89=AE=E9=AD=81?= Date: Wed, 28 May 2025 22:07:03 +0800 Subject: [PATCH 1/5] fix for streamable server to send notification --- server/session.go | 26 ++++++++++++++++++++++++++ server/streamable_http.go | 16 ++++++++++++---- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/server/session.go b/server/session.go index 7b1e12fee..a79da22ca 100644 --- a/server/session.go +++ b/server/session.go @@ -48,6 +48,22 @@ type SessionWithClientInfo interface { SetClientInfo(clientInfo mcp.Implementation) } +// SessionWithStreamableHTTPConfig extends ClientSession to support streamable HTTP transport configurations +type SessionWithStreamableHTTPConfig interface { + ClientSession + // UpgradeToSSEWhenReceiveNotification upgrades the client-server communication to SSE stream when the server + // sends notifications to the client + // + // The protocol specification: + // - If the server response contains any JSON-RPC notifications, it MUST either: + // - Return Content-Type: text/event-stream to initiate an SSE stream, OR + // - Return Content-Type: application/json for a single JSON object + // - The client MUST support both response types. + // + // Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#sending-messages-to-the-server + UpgradeToSSEWhenReceiveNotification() +} + // clientSessionKey is the context key for storing current client notification channel. type clientSessionKey struct{} @@ -146,6 +162,11 @@ func (s *MCPServer) SendNotificationToClient( return ErrNotificationNotInitialized } + // upgrades the client-server communication to SSE stream when the server sends notifications to the client + if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok { + sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification() + } + notification := mcp.JSONRPCNotification{ JSONRPC: mcp.JSONRPC_VERSION, Notification: mcp.Notification{ @@ -193,6 +214,11 @@ func (s *MCPServer) SendNotificationToSpecificClient( return ErrSessionNotInitialized } + // upgrades the client-server communication to SSE stream when the server sends notifications to the client + if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok { + sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification() + } + notification := mcp.JSONRPCNotification{ JSONRPC: mcp.JSONRPC_VERSION, Notification: mcp.Notification{ diff --git a/server/streamable_http.go b/server/streamable_http.go index b13577a8c..5f6822316 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -9,6 +9,7 @@ import ( "net/http/httptest" "strings" "sync" + "sync/atomic" "time" "github.com/google/uuid" @@ -243,11 +244,11 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request // handle potential notifications mu := sync.Mutex{} - upgraded := false done := make(chan struct{}) defer close(done) go func() { + alreadyUpgraded := false for { select { case nt := <-session.notificationChannel: @@ -262,12 +263,12 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request }() // if there's notifications, upgrade to SSE response - if !upgraded { - upgraded = true + if (!alreadyUpgraded) && !session.upgradeToSSE.Load() { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Connection", "keep-alive") w.Header().Set("Cache-Control", "no-cache") w.WriteHeader(http.StatusAccepted) + alreadyUpgraded = true } err := writeSSEEvent(w, nt) if err != nil { @@ -297,7 +298,7 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request if ctx.Err() != nil { return } - if upgraded { + if session.upgradeToSSE.Load() { if err := writeSSEEvent(w, response); err != nil { s.logger.Errorf("Failed to write final SSE response event: %v", err) } @@ -494,6 +495,7 @@ type streamableHttpSession struct { sessionID string notificationChannel chan mcp.JSONRPCNotification // server -> client notifications tools *sessionToolsStore + upgradeToSSE atomic.Bool } func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore) *streamableHttpSession { @@ -534,6 +536,12 @@ func (s *streamableHttpSession) SetSessionTools(tools map[string]ServerTool) { var _ SessionWithTools = (*streamableHttpSession)(nil) +func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() { + s.upgradeToSSE.Store(true) +} + +var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil) + // --- session id manager --- type SessionIdManager interface { From fb64448309eb0b1e953d4e009a53649393274d9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=9C=E8=89=AE=E9=AD=81?= Date: Wed, 28 May 2025 22:25:36 +0800 Subject: [PATCH 2/5] update --- server/streamable_http.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/server/streamable_http.go b/server/streamable_http.go index 5f6822316..f174e2a39 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -248,7 +248,6 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request defer close(done) go func() { - alreadyUpgraded := false for { select { case nt := <-session.notificationChannel: @@ -263,13 +262,10 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request }() // if there's notifications, upgrade to SSE response - if (!alreadyUpgraded) && !session.upgradeToSSE.Load() { - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Cache-Control", "no-cache") - w.WriteHeader(http.StatusAccepted) - alreadyUpgraded = true - } + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Cache-Control", "no-cache") + w.WriteHeader(http.StatusAccepted) err := writeSSEEvent(w, nt) if err != nil { s.logger.Errorf("Failed to write SSE event: %v", err) @@ -299,6 +295,10 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request return } if session.upgradeToSSE.Load() { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Cache-Control", "no-cache") + w.WriteHeader(http.StatusAccepted) if err := writeSSEEvent(w, response); err != nil { s.logger.Errorf("Failed to write final SSE response event: %v", err) } From 6c97a1b516c3b086110c73595a88c4c5e246a417 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=9C=E8=89=AE=E9=AD=81?= Date: Wed, 28 May 2025 22:53:05 +0800 Subject: [PATCH 3/5] update --- client/http_test.go | 112 ++++++++++++++++++++++++++++++++++++++ server/streamable_http.go | 25 ++++++--- 2 files changed, 128 insertions(+), 9 deletions(-) create mode 100644 client/http_test.go diff --git a/client/http_test.go b/client/http_test.go new file mode 100644 index 000000000..b4595d237 --- /dev/null +++ b/client/http_test.go @@ -0,0 +1,112 @@ +package client + +import ( + "context" + "fmt" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "testing" + "time" +) + +func TestHTTPClient(t *testing.T) { + hooks := &server.Hooks{} + hooks.AddAfterCallTool(func(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) { + clientSession := server.ClientSessionFromContext(ctx) + // return until all the notifications are handled + for len(clientSession.NotificationChannel()) > 0 { + } + time.Sleep(time.Millisecond * 200) + }) + + // Create MCP server with capabilities + mcpServer := server.NewMCPServer( + "test-server", + "1.0.0", + server.WithToolCapabilities(true), + server.WithHooks(hooks), + ) + + mcpServer.AddTool( + mcp.NewTool("notify"), + func( + ctx context.Context, + request mcp.CallToolRequest, + ) (*mcp.CallToolResult, error) { + server := server.ServerFromContext(ctx) + err := server.SendNotificationToClient( + ctx, + "notifications/progress", + map[string]any{ + "progress": 10, + "total": 10, + "progressToken": 0, + }, + ) + if err != nil { + return nil, fmt.Errorf("failed to send notification: %w", err) + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: "notification sent successfully", + }, + }, + }, nil + }, + ) + + // Initialize + testServer := server.NewTestStreamableHTTPServer(mcpServer) + defer testServer.Close() + + t.Run("Can create client", func(t *testing.T) { + client, err := NewStreamableHttpClient(testServer.URL) + if err != nil { + t.Fatalf("create client failed %v", err) + return + } + + notificationNum := 0 + client.OnNotification(func(notification mcp.JSONRPCNotification) { + notificationNum += 1 + }) + + ctx := context.Background() + + if err := client.Start(ctx); err != nil { + t.Fatalf("Failed to start client: %v", err) + return + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = client.Initialize(ctx, initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v\n", err) + } + + request := mcp.CallToolRequest{} + request.Params.Name = "notify" + result, err := client.CallTool(ctx, request) + if err != nil { + t.Fatalf("CallTool failed: %v", err) + } + + if len(result.Content) != 1 { + t.Errorf("Expected 1 content item, got %d", len(result.Content)) + } + + if notificationNum != 1 { + t.Errorf("Expected 1 notification item, got %d", notificationNum) + } + }) +} diff --git a/server/streamable_http.go b/server/streamable_http.go index f174e2a39..838cc16e5 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -244,6 +244,7 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request // handle potential notifications mu := sync.Mutex{} + upgradedHeader := false done := make(chan struct{}) defer close(done) @@ -261,11 +262,14 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request } }() - // if there's notifications, upgrade to SSE response - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Cache-Control", "no-cache") - w.WriteHeader(http.StatusAccepted) + // if there's notifications, upgradedHeader to SSE response + if !upgradedHeader { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Cache-Control", "no-cache") + w.WriteHeader(http.StatusAccepted) + upgradedHeader = true + } err := writeSSEEvent(w, nt) if err != nil { s.logger.Errorf("Failed to write SSE event: %v", err) @@ -295,10 +299,13 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request return } if session.upgradeToSSE.Load() { - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Cache-Control", "no-cache") - w.WriteHeader(http.StatusAccepted) + if !upgradedHeader { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Cache-Control", "no-cache") + w.WriteHeader(http.StatusAccepted) + upgradedHeader = true + } if err := writeSSEEvent(w, response); err != nil { s.logger.Errorf("Failed to write final SSE response event: %v", err) } From dff933e6e572608fc5ab5b2a0205d638e216eaef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=9C=E8=89=AE=E9=AD=81?= Date: Thu, 29 May 2025 00:11:23 +0800 Subject: [PATCH 4/5] update --- client/http_test.go | 7 +++---- server/streamable_http.go | 9 ++++++++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/client/http_test.go b/client/http_test.go index b4595d237..3c2e6a3b7 100644 --- a/client/http_test.go +++ b/client/http_test.go @@ -13,10 +13,10 @@ func TestHTTPClient(t *testing.T) { hooks := &server.Hooks{} hooks.AddAfterCallTool(func(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) { clientSession := server.ClientSessionFromContext(ctx) - // return until all the notifications are handled + // wait until all the notifications are handled for len(clientSession.NotificationChannel()) > 0 { } - time.Sleep(time.Millisecond * 200) + time.Sleep(time.Millisecond * 50) }) // Create MCP server with capabilities @@ -58,11 +58,10 @@ func TestHTTPClient(t *testing.T) { }, ) - // Initialize testServer := server.NewTestStreamableHTTPServer(mcpServer) defer testServer.Close() - t.Run("Can create client", func(t *testing.T) { + t.Run("Can receive notification from server", func(t *testing.T) { client, err := NewStreamableHttpClient(testServer.URL) if err != nil { t.Fatalf("create client failed %v", err) diff --git a/server/streamable_http.go b/server/streamable_http.go index 838cc16e5..b0b5c5720 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -246,7 +246,6 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request mu := sync.Mutex{} upgradedHeader := false done := make(chan struct{}) - defer close(done) go func() { for { @@ -255,6 +254,12 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request func() { mu.Lock() defer mu.Unlock() + // if the done chan is closed, as the request is terminated, just return + select { + case <-done: + return + default: + } defer func() { flusher, ok := w.(http.Flusher) if ok { @@ -295,6 +300,8 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request // Write response mu.Lock() defer mu.Unlock() + // close the done chan before unlock + defer close(done) if ctx.Err() != nil { return } From f9bbaf50f19843a34f2196844c6e7e236408cca8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=9C=E8=89=AE=E9=AD=81?= Date: Thu, 29 May 2025 20:23:57 +0800 Subject: [PATCH 5/5] add comments for changes --- server/streamable_http.go | 1 + 1 file changed, 1 insertion(+) diff --git a/server/streamable_http.go b/server/streamable_http.go index b0b5c5720..a92e00129 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -305,6 +305,7 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request if ctx.Err() != nil { return } + // If client-server communication already upgraded to SSE stream if session.upgradeToSSE.Load() { if !upgradedHeader { w.Header().Set("Content-Type", "text/event-stream")