diff --git a/client/http_test.go b/client/http_test.go new file mode 100644 index 000000000..3c2e6a3b7 --- /dev/null +++ b/client/http_test.go @@ -0,0 +1,111 @@ +package client + +import ( + "context" + "fmt" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "testing" + "time" +) + +func TestHTTPClient(t *testing.T) { + hooks := &server.Hooks{} + hooks.AddAfterCallTool(func(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) { + clientSession := server.ClientSessionFromContext(ctx) + // wait until all the notifications are handled + for len(clientSession.NotificationChannel()) > 0 { + } + time.Sleep(time.Millisecond * 50) + }) + + // Create MCP server with capabilities + mcpServer := server.NewMCPServer( + "test-server", + "1.0.0", + server.WithToolCapabilities(true), + server.WithHooks(hooks), + ) + + mcpServer.AddTool( + mcp.NewTool("notify"), + func( + ctx context.Context, + request mcp.CallToolRequest, + ) (*mcp.CallToolResult, error) { + server := server.ServerFromContext(ctx) + err := server.SendNotificationToClient( + ctx, + "notifications/progress", + map[string]any{ + "progress": 10, + "total": 10, + "progressToken": 0, + }, + ) + if err != nil { + return nil, fmt.Errorf("failed to send notification: %w", err) + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: "notification sent successfully", + }, + }, + }, nil + }, + ) + + testServer := server.NewTestStreamableHTTPServer(mcpServer) + defer testServer.Close() + + t.Run("Can receive notification from server", func(t *testing.T) { + client, err := NewStreamableHttpClient(testServer.URL) + if err != nil { + t.Fatalf("create client failed %v", err) + return + } + + notificationNum := 0 + client.OnNotification(func(notification mcp.JSONRPCNotification) { + notificationNum += 1 + }) + + ctx := context.Background() + + if err := client.Start(ctx); err != nil { + t.Fatalf("Failed to start client: %v", err) + return + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = client.Initialize(ctx, initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v\n", err) + } + + request := mcp.CallToolRequest{} + request.Params.Name = "notify" + result, err := client.CallTool(ctx, request) + if err != nil { + t.Fatalf("CallTool failed: %v", err) + } + + if len(result.Content) != 1 { + t.Errorf("Expected 1 content item, got %d", len(result.Content)) + } + + if notificationNum != 1 { + t.Errorf("Expected 1 notification item, got %d", notificationNum) + } + }) +} diff --git a/server/session.go b/server/session.go index 7b1e12fee..a79da22ca 100644 --- a/server/session.go +++ b/server/session.go @@ -48,6 +48,22 @@ type SessionWithClientInfo interface { SetClientInfo(clientInfo mcp.Implementation) } +// SessionWithStreamableHTTPConfig extends ClientSession to support streamable HTTP transport configurations +type SessionWithStreamableHTTPConfig interface { + ClientSession + // UpgradeToSSEWhenReceiveNotification upgrades the client-server communication to SSE stream when the server + // sends notifications to the client + // + // The protocol specification: + // - If the server response contains any JSON-RPC notifications, it MUST either: + // - Return Content-Type: text/event-stream to initiate an SSE stream, OR + // - Return Content-Type: application/json for a single JSON object + // - The client MUST support both response types. + // + // Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#sending-messages-to-the-server + UpgradeToSSEWhenReceiveNotification() +} + // clientSessionKey is the context key for storing current client notification channel. type clientSessionKey struct{} @@ -146,6 +162,11 @@ func (s *MCPServer) SendNotificationToClient( return ErrNotificationNotInitialized } + // upgrades the client-server communication to SSE stream when the server sends notifications to the client + if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok { + sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification() + } + notification := mcp.JSONRPCNotification{ JSONRPC: mcp.JSONRPC_VERSION, Notification: mcp.Notification{ @@ -193,6 +214,11 @@ func (s *MCPServer) SendNotificationToSpecificClient( return ErrSessionNotInitialized } + // upgrades the client-server communication to SSE stream when the server sends notifications to the client + if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok { + sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification() + } + notification := mcp.JSONRPCNotification{ JSONRPC: mcp.JSONRPC_VERSION, Notification: mcp.Notification{ diff --git a/server/streamable_http.go b/server/streamable_http.go index b13577a8c..a92e00129 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -9,6 +9,7 @@ import ( "net/http/httptest" "strings" "sync" + "sync/atomic" "time" "github.com/google/uuid" @@ -243,9 +244,8 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request // handle potential notifications mu := sync.Mutex{} - upgraded := false + upgradedHeader := false done := make(chan struct{}) - defer close(done) go func() { for { @@ -254,6 +254,12 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request func() { mu.Lock() defer mu.Unlock() + // if the done chan is closed, as the request is terminated, just return + select { + case <-done: + return + default: + } defer func() { flusher, ok := w.(http.Flusher) if ok { @@ -261,13 +267,13 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request } }() - // if there's notifications, upgrade to SSE response - if !upgraded { - upgraded = true + // if there's notifications, upgradedHeader to SSE response + if !upgradedHeader { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Connection", "keep-alive") w.Header().Set("Cache-Control", "no-cache") w.WriteHeader(http.StatusAccepted) + upgradedHeader = true } err := writeSSEEvent(w, nt) if err != nil { @@ -294,10 +300,20 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request // Write response mu.Lock() defer mu.Unlock() + // close the done chan before unlock + defer close(done) if ctx.Err() != nil { return } - if upgraded { + // If client-server communication already upgraded to SSE stream + if session.upgradeToSSE.Load() { + if !upgradedHeader { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Cache-Control", "no-cache") + w.WriteHeader(http.StatusAccepted) + upgradedHeader = true + } if err := writeSSEEvent(w, response); err != nil { s.logger.Errorf("Failed to write final SSE response event: %v", err) } @@ -494,6 +510,7 @@ type streamableHttpSession struct { sessionID string notificationChannel chan mcp.JSONRPCNotification // server -> client notifications tools *sessionToolsStore + upgradeToSSE atomic.Bool } func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore) *streamableHttpSession { @@ -534,6 +551,12 @@ func (s *streamableHttpSession) SetSessionTools(tools map[string]ServerTool) { var _ SessionWithTools = (*streamableHttpSession)(nil) +func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() { + s.upgradeToSSE.Store(true) +} + +var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil) + // --- session id manager --- type SessionIdManager interface {