Skip to content

fix: panic when streamable HTTP server sends notification #348

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions client/http_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package client

import (
"context"
"fmt"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
"testing"
"time"
)

func TestHTTPClient(t *testing.T) {
hooks := &server.Hooks{}
hooks.AddAfterCallTool(func(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) {
clientSession := server.ClientSessionFromContext(ctx)
// wait until all the notifications are handled
for len(clientSession.NotificationChannel()) > 0 {
}
time.Sleep(time.Millisecond * 50)
})

// Create MCP server with capabilities
mcpServer := server.NewMCPServer(
"test-server",
"1.0.0",
server.WithToolCapabilities(true),
server.WithHooks(hooks),
)

mcpServer.AddTool(
mcp.NewTool("notify"),
func(
ctx context.Context,
request mcp.CallToolRequest,
) (*mcp.CallToolResult, error) {
server := server.ServerFromContext(ctx)
err := server.SendNotificationToClient(
ctx,
"notifications/progress",
map[string]any{
"progress": 10,
"total": 10,
"progressToken": 0,
},
)
if err != nil {
return nil, fmt.Errorf("failed to send notification: %w", err)
}

return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{
Type: "text",
Text: "notification sent successfully",
},
},
}, nil
},
)

testServer := server.NewTestStreamableHTTPServer(mcpServer)
defer testServer.Close()

t.Run("Can receive notification from server", func(t *testing.T) {
client, err := NewStreamableHttpClient(testServer.URL)
if err != nil {
t.Fatalf("create client failed %v", err)
return
}

notificationNum := 0
client.OnNotification(func(notification mcp.JSONRPCNotification) {
notificationNum += 1
})

ctx := context.Background()

if err := client.Start(ctx); err != nil {
t.Fatalf("Failed to start client: %v", err)
return
}

// Initialize
initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initRequest.Params.ClientInfo = mcp.Implementation{
Name: "test-client",
Version: "1.0.0",
}

_, err = client.Initialize(ctx, initRequest)
if err != nil {
t.Fatalf("Failed to initialize: %v\n", err)
}

request := mcp.CallToolRequest{}
request.Params.Name = "notify"
result, err := client.CallTool(ctx, request)
if err != nil {
t.Fatalf("CallTool failed: %v", err)
}

if len(result.Content) != 1 {
t.Errorf("Expected 1 content item, got %d", len(result.Content))
}

if notificationNum != 1 {
t.Errorf("Expected 1 notification item, got %d", notificationNum)
}
})
}
26 changes: 26 additions & 0 deletions server/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,22 @@ type SessionWithClientInfo interface {
SetClientInfo(clientInfo mcp.Implementation)
}

// SessionWithStreamableHTTPConfig extends ClientSession to support streamable HTTP transport configurations
type SessionWithStreamableHTTPConfig interface {
ClientSession
// UpgradeToSSEWhenReceiveNotification upgrades the client-server communication to SSE stream when the server
// sends notifications to the client
//
// The protocol specification:
// - If the server response contains any JSON-RPC notifications, it MUST either:
// - Return Content-Type: text/event-stream to initiate an SSE stream, OR
// - Return Content-Type: application/json for a single JSON object
// - The client MUST support both response types.
//
// Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#sending-messages-to-the-server
UpgradeToSSEWhenReceiveNotification()
}
Comment on lines +51 to +65
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the documentation here.


// clientSessionKey is the context key for storing current client notification channel.
type clientSessionKey struct{}

Expand Down Expand Up @@ -146,6 +162,11 @@ func (s *MCPServer) SendNotificationToClient(
return ErrNotificationNotInitialized
}

// upgrades the client-server communication to SSE stream when the server sends notifications to the client
if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok {
sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification()
}

notification := mcp.JSONRPCNotification{
JSONRPC: mcp.JSONRPC_VERSION,
Notification: mcp.Notification{
Expand Down Expand Up @@ -193,6 +214,11 @@ func (s *MCPServer) SendNotificationToSpecificClient(
return ErrSessionNotInitialized
}

// upgrades the client-server communication to SSE stream when the server sends notifications to the client
if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok {
sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification()
}

notification := mcp.JSONRPCNotification{
JSONRPC: mcp.JSONRPC_VERSION,
Notification: mcp.Notification{
Expand Down
35 changes: 29 additions & 6 deletions server/streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/google/uuid"
Expand Down Expand Up @@ -243,9 +244,8 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request

// handle potential notifications
mu := sync.Mutex{}
upgraded := false
upgradedHeader := false
done := make(chan struct{})
defer close(done)

go func() {
for {
Expand All @@ -254,20 +254,26 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
func() {
mu.Lock()
defer mu.Unlock()
// if the done chan is closed, as the request is terminated, just return
select {
case <-done:
return
default:
}
defer func() {
flusher, ok := w.(http.Flusher)
if ok {
flusher.Flush()
}
}()

// if there's notifications, upgrade to SSE response
if !upgraded {
upgraded = true
// if there's notifications, upgradedHeader to SSE response
if !upgradedHeader {
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Cache-Control", "no-cache")
w.WriteHeader(http.StatusAccepted)
upgradedHeader = true
}
err := writeSSEEvent(w, nt)
if err != nil {
Expand All @@ -294,10 +300,20 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
// Write response
mu.Lock()
defer mu.Unlock()
// close the done chan before unlock
defer close(done)
if ctx.Err() != nil {
return
}
if upgraded {
// If client-server communication already upgraded to SSE stream
if session.upgradeToSSE.Load() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Maybe add some inline comments to the entire streamable_http.go file for the changes you made?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your review. comments are added as suggested.

if !upgradedHeader {
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Cache-Control", "no-cache")
w.WriteHeader(http.StatusAccepted)
upgradedHeader = true
}
if err := writeSSEEvent(w, response); err != nil {
s.logger.Errorf("Failed to write final SSE response event: %v", err)
}
Expand Down Expand Up @@ -494,6 +510,7 @@ type streamableHttpSession struct {
sessionID string
notificationChannel chan mcp.JSONRPCNotification // server -> client notifications
tools *sessionToolsStore
upgradeToSSE atomic.Bool
}

func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore) *streamableHttpSession {
Expand Down Expand Up @@ -534,6 +551,12 @@ func (s *streamableHttpSession) SetSessionTools(tools map[string]ServerTool) {

var _ SessionWithTools = (*streamableHttpSession)(nil)

func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() {
s.upgradeToSSE.Store(true)
}

var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil)

// --- session id manager ---

type SessionIdManager interface {
Expand Down