-
Notifications
You must be signed in to change notification settings - Fork 702
fix: panic when streamable HTTP server sends notification #348
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
package client | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"github.com/mark3labs/mcp-go/mcp" | ||
"github.com/mark3labs/mcp-go/server" | ||
"testing" | ||
"time" | ||
) | ||
|
||
func TestHTTPClient(t *testing.T) { | ||
hooks := &server.Hooks{} | ||
hooks.AddAfterCallTool(func(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) { | ||
clientSession := server.ClientSessionFromContext(ctx) | ||
// wait until all the notifications are handled | ||
for len(clientSession.NotificationChannel()) > 0 { | ||
} | ||
time.Sleep(time.Millisecond * 50) | ||
}) | ||
|
||
// Create MCP server with capabilities | ||
mcpServer := server.NewMCPServer( | ||
"test-server", | ||
"1.0.0", | ||
server.WithToolCapabilities(true), | ||
server.WithHooks(hooks), | ||
) | ||
|
||
mcpServer.AddTool( | ||
mcp.NewTool("notify"), | ||
func( | ||
ctx context.Context, | ||
request mcp.CallToolRequest, | ||
) (*mcp.CallToolResult, error) { | ||
server := server.ServerFromContext(ctx) | ||
err := server.SendNotificationToClient( | ||
ctx, | ||
"notifications/progress", | ||
map[string]any{ | ||
"progress": 10, | ||
"total": 10, | ||
"progressToken": 0, | ||
}, | ||
) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to send notification: %w", err) | ||
} | ||
|
||
return &mcp.CallToolResult{ | ||
Content: []mcp.Content{ | ||
mcp.TextContent{ | ||
Type: "text", | ||
Text: "notification sent successfully", | ||
}, | ||
}, | ||
}, nil | ||
}, | ||
) | ||
|
||
testServer := server.NewTestStreamableHTTPServer(mcpServer) | ||
defer testServer.Close() | ||
|
||
t.Run("Can receive notification from server", func(t *testing.T) { | ||
client, err := NewStreamableHttpClient(testServer.URL) | ||
if err != nil { | ||
t.Fatalf("create client failed %v", err) | ||
return | ||
} | ||
|
||
notificationNum := 0 | ||
client.OnNotification(func(notification mcp.JSONRPCNotification) { | ||
notificationNum += 1 | ||
}) | ||
|
||
ctx := context.Background() | ||
|
||
if err := client.Start(ctx); err != nil { | ||
t.Fatalf("Failed to start client: %v", err) | ||
return | ||
} | ||
|
||
// Initialize | ||
initRequest := mcp.InitializeRequest{} | ||
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION | ||
initRequest.Params.ClientInfo = mcp.Implementation{ | ||
Name: "test-client", | ||
Version: "1.0.0", | ||
} | ||
|
||
_, err = client.Initialize(ctx, initRequest) | ||
if err != nil { | ||
t.Fatalf("Failed to initialize: %v\n", err) | ||
} | ||
|
||
request := mcp.CallToolRequest{} | ||
request.Params.Name = "notify" | ||
result, err := client.CallTool(ctx, request) | ||
if err != nil { | ||
t.Fatalf("CallTool failed: %v", err) | ||
} | ||
|
||
if len(result.Content) != 1 { | ||
t.Errorf("Expected 1 content item, got %d", len(result.Content)) | ||
} | ||
|
||
if notificationNum != 1 { | ||
t.Errorf("Expected 1 notification item, got %d", notificationNum) | ||
} | ||
}) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -48,6 +48,22 @@ type SessionWithClientInfo interface { | |
SetClientInfo(clientInfo mcp.Implementation) | ||
} | ||
|
||
// SessionWithStreamableHTTPConfig extends ClientSession to support streamable HTTP transport configurations | ||
type SessionWithStreamableHTTPConfig interface { | ||
ClientSession | ||
// UpgradeToSSEWhenReceiveNotification upgrades the client-server communication to SSE stream when the server | ||
// sends notifications to the client | ||
// | ||
// The protocol specification: | ||
// - If the server response contains any JSON-RPC notifications, it MUST either: | ||
// - Return Content-Type: text/event-stream to initiate an SSE stream, OR | ||
// - Return Content-Type: application/json for a single JSON object | ||
// - The client MUST support both response types. | ||
// | ||
// Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#sending-messages-to-the-server | ||
UpgradeToSSEWhenReceiveNotification() | ||
} | ||
Comment on lines
+51
to
+65
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the documentation here. |
||
|
||
// clientSessionKey is the context key for storing current client notification channel. | ||
type clientSessionKey struct{} | ||
|
||
|
@@ -146,6 +162,11 @@ func (s *MCPServer) SendNotificationToClient( | |
return ErrNotificationNotInitialized | ||
} | ||
|
||
// upgrades the client-server communication to SSE stream when the server sends notifications to the client | ||
if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok { | ||
sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification() | ||
} | ||
|
||
notification := mcp.JSONRPCNotification{ | ||
JSONRPC: mcp.JSONRPC_VERSION, | ||
Notification: mcp.Notification{ | ||
|
@@ -193,6 +214,11 @@ func (s *MCPServer) SendNotificationToSpecificClient( | |
return ErrSessionNotInitialized | ||
} | ||
|
||
// upgrades the client-server communication to SSE stream when the server sends notifications to the client | ||
if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok { | ||
sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification() | ||
} | ||
|
||
notification := mcp.JSONRPCNotification{ | ||
JSONRPC: mcp.JSONRPC_VERSION, | ||
Notification: mcp.Notification{ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ import ( | |
"net/http/httptest" | ||
"strings" | ||
"sync" | ||
"sync/atomic" | ||
"time" | ||
|
||
"github.com/google/uuid" | ||
|
@@ -243,9 +244,8 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request | |
|
||
// handle potential notifications | ||
mu := sync.Mutex{} | ||
upgraded := false | ||
upgradedHeader := false | ||
done := make(chan struct{}) | ||
defer close(done) | ||
|
||
go func() { | ||
for { | ||
|
@@ -254,20 +254,26 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request | |
func() { | ||
mu.Lock() | ||
defer mu.Unlock() | ||
// if the done chan is closed, as the request is terminated, just return | ||
select { | ||
case <-done: | ||
return | ||
default: | ||
} | ||
defer func() { | ||
flusher, ok := w.(http.Flusher) | ||
if ok { | ||
flusher.Flush() | ||
} | ||
}() | ||
|
||
// if there's notifications, upgrade to SSE response | ||
if !upgraded { | ||
upgraded = true | ||
// if there's notifications, upgradedHeader to SSE response | ||
if !upgradedHeader { | ||
w.Header().Set("Content-Type", "text/event-stream") | ||
w.Header().Set("Connection", "keep-alive") | ||
w.Header().Set("Cache-Control", "no-cache") | ||
w.WriteHeader(http.StatusAccepted) | ||
upgradedHeader = true | ||
} | ||
err := writeSSEEvent(w, nt) | ||
if err != nil { | ||
|
@@ -294,10 +300,20 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request | |
// Write response | ||
mu.Lock() | ||
defer mu.Unlock() | ||
// close the done chan before unlock | ||
defer close(done) | ||
if ctx.Err() != nil { | ||
return | ||
} | ||
if upgraded { | ||
// If client-server communication already upgraded to SSE stream | ||
if session.upgradeToSSE.Load() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: Maybe add some inline comments to the entire There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for your review. comments are added as suggested. |
||
if !upgradedHeader { | ||
w.Header().Set("Content-Type", "text/event-stream") | ||
w.Header().Set("Connection", "keep-alive") | ||
w.Header().Set("Cache-Control", "no-cache") | ||
w.WriteHeader(http.StatusAccepted) | ||
upgradedHeader = true | ||
} | ||
if err := writeSSEEvent(w, response); err != nil { | ||
s.logger.Errorf("Failed to write final SSE response event: %v", err) | ||
} | ||
|
@@ -494,6 +510,7 @@ type streamableHttpSession struct { | |
sessionID string | ||
notificationChannel chan mcp.JSONRPCNotification // server -> client notifications | ||
tools *sessionToolsStore | ||
upgradeToSSE atomic.Bool | ||
} | ||
|
||
func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore) *streamableHttpSession { | ||
|
@@ -534,6 +551,12 @@ func (s *streamableHttpSession) SetSessionTools(tools map[string]ServerTool) { | |
|
||
var _ SessionWithTools = (*streamableHttpSession)(nil) | ||
|
||
func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() { | ||
s.upgradeToSSE.Store(true) | ||
} | ||
|
||
var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil) | ||
|
||
// --- session id manager --- | ||
|
||
type SessionIdManager interface { | ||
|
Uh oh!
There was an error while loading. Please reload this page.