From 14151a0d5d24594efb3ed088df89b989713f7304 Mon Sep 17 00:00:00 2001 From: Buf Generate Date: Wed, 26 Mar 2025 20:05:23 +0800 Subject: [PATCH 1/3] fix java mcp message endpoint --- server/sse.go | 37 ++++++++++++++------- server/sse_test.go | 80 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 12 deletions(-) diff --git a/server/sse.go b/server/sse.go index a869ad1aa..fa325ba27 100644 --- a/server/sse.go +++ b/server/sse.go @@ -52,14 +52,15 @@ var _ ClientSession = (*sseSession)(nil) // SSEServer implements a Server-Sent Events (SSE) based MCP server. // It provides real-time communication capabilities over HTTP using the SSE protocol. type SSEServer struct { - server *MCPServer - baseURL string - basePath string - messageEndpoint string - sseEndpoint string - sessions sync.Map - srv *http.Server - contextFunc SSEContextFunc + server *MCPServer + baseURL string + basePath string + messageEndpoint string + isCompleteMessageEndpoint bool + sseEndpoint string + sessions sync.Map + srv *http.Server + contextFunc SSEContextFunc } // SSEOption defines a function type for configuring SSEServer @@ -106,6 +107,13 @@ func WithMessageEndpoint(endpoint string) SSEOption { } } +// WithIsCompleteMessageEndpoint sets the flag for whether the endpoint is for complete messages or not +func WithIsCompleteMessageEndpoint(isCompleteMessageEndpoint bool) SSEOption { + return func(s *SSEServer) { + s.isCompleteMessageEndpoint = isCompleteMessageEndpoint + } +} + // WithSSEEndpoint sets the SSE endpoint path func WithSSEEndpoint(endpoint string) SSEOption { return func(s *SSEServer) { @@ -131,9 +139,10 @@ func WithSSEContextFunc(fn SSEContextFunc) SSEOption { // NewSSEServer creates a new SSE server instance with the given MCP server and options. func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer { s := &SSEServer{ - server: server, - sseEndpoint: "/sse", - messageEndpoint: "/message", + server: server, + sseEndpoint: "/sse", + messageEndpoint: "/message", + isCompleteMessageEndpoint: true, } // Apply all options @@ -244,7 +253,11 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { } }() - messageEndpoint := fmt.Sprintf("%s?sessionId=%s", s.CompleteMessageEndpoint(), sessionID) + messageEndpoint := s.messageEndpoint + if s.isCompleteMessageEndpoint { + messageEndpoint = s.CompleteMessageEndpoint() + } + messageEndpoint = fmt.Sprintf("%s?sessionId=%s", messageEndpoint, sessionID) // Send the initial endpoint event fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", messageEndpoint) diff --git a/server/sse_test.go b/server/sse_test.go index ced76e4ee..5a35a0c23 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -418,6 +418,81 @@ func TestSSEServer(t *testing.T) { cancel() }) + t.Run("test isCompleteMessageEndpoint", func(t *testing.T) { + mcpServer := NewMCPServer("test", "1.0.0") + sseServer := NewSSEServer(mcpServer) + + mux := http.NewServeMux() + mux.Handle("/mcp/", sseServer) + + ts := httptest.NewServer(mux) + defer ts.Close() + + sseServer.baseURL = ts.URL + "/mcp" + sseServer.isCompleteMessageEndpoint = false + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/sse", sseServer.baseURL), nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Failed to connect to SSE endpoint: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + + // Read the endpoint event + buf := make([]byte, 1024) + n, err := resp.Body.Read(buf) + if err != nil { + t.Fatalf("Failed to read SSE response: %v", err) + } + + endpointEvent := string(buf[:n]) + messageURL := strings.TrimSpace( + strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0], + ) + if !strings.HasPrefix(messageURL, sseServer.messageEndpoint) { + t.Errorf("Expected messageURL to be %s, got %s", sseServer.messageEndpoint, messageURL) + } + + // The messageURL should already be correct since we set the baseURL correctly + // Test message endpoint + initRequest := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "clientInfo": map[string]interface{}{ + "name": "test-client", + "version": "1.0.0", + }, + }, + } + requestBody, _ := json.Marshal(initRequest) + + resp, err = http.Post(sseServer.baseURL+messageURL, "application/json", bytes.NewBuffer(requestBody)) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusAccepted { + t.Errorf("Expected status 202, got %d", resp.StatusCode) + } + + // Clean up SSE connection + cancel() + }) + t.Run("works as http.Handler with custom basePath", func(t *testing.T) { mcpServer := NewMCPServer("test", "1.0.0") sseServer := NewSSEServer(mcpServer, WithBasePath("/mcp")) @@ -621,11 +696,13 @@ func TestSSEServer(t *testing.T) { baseURL := "http://localhost:8080/test" messageEndpoint := "/message-test" sseEndpoint := "/sse-test" + isCompleteMessageEndpoint := false srv := &http.Server{} rands := []SSEOption{ WithBasePath(basePath), WithBaseURL(baseURL), WithMessageEndpoint(messageEndpoint), + WithIsCompleteMessageEndpoint(isCompleteMessageEndpoint), WithSSEEndpoint(sseEndpoint), WithHTTPServer(srv), } @@ -641,6 +718,9 @@ func TestSSEServer(t *testing.T) { if sseServer.basePath != basePath { t.Fatalf("basePath %v, got: %v", basePath, sseServer.basePath) } + if sseServer.isCompleteMessageEndpoint != isCompleteMessageEndpoint { + t.Fatalf("isCompleteMessageEndpoint %v, got: %v", isCompleteMessageEndpoint, sseServer.isCompleteMessageEndpoint) + } if sseServer.baseURL != baseURL { t.Fatalf("baseURL %v, got: %v", baseURL, sseServer.baseURL) From 38225de9967618624d2c81c89f4c0e2f53dbad8f Mon Sep 17 00:00:00 2001 From: Buf Generate Date: Wed, 26 Mar 2025 20:18:05 +0800 Subject: [PATCH 2/3] fix java mcp message endpoint --- server/sse.go | 39 +++++++++++++++++++++------------------ server/sse_test.go | 12 ++++++------ 2 files changed, 27 insertions(+), 24 deletions(-) diff --git a/server/sse.go b/server/sse.go index fa325ba27..2703d47d8 100644 --- a/server/sse.go +++ b/server/sse.go @@ -52,15 +52,15 @@ var _ ClientSession = (*sseSession)(nil) // SSEServer implements a Server-Sent Events (SSE) based MCP server. // It provides real-time communication capabilities over HTTP using the SSE protocol. type SSEServer struct { - server *MCPServer - baseURL string - basePath string - messageEndpoint string - isCompleteMessageEndpoint bool - sseEndpoint string - sessions sync.Map - srv *http.Server - contextFunc SSEContextFunc + server *MCPServer + baseURL string + basePath string + messageEndpoint string + useFullURLForMessageEndpoint bool + sseEndpoint string + sessions sync.Map + srv *http.Server + contextFunc SSEContextFunc } // SSEOption defines a function type for configuring SSEServer @@ -107,10 +107,12 @@ func WithMessageEndpoint(endpoint string) SSEOption { } } -// WithIsCompleteMessageEndpoint sets the flag for whether the endpoint is for complete messages or not -func WithIsCompleteMessageEndpoint(isCompleteMessageEndpoint bool) SSEOption { +// WithUseFullURLForMessageEndpoint controls whether the SSE server returns a complete URL (including baseURL) +// or just the path portion for the message endpoint. Set to false when clients will concatenate +// the baseURL themselves to avoid malformed URLs like "http://localhost/mcphttp://localhost/mcp/message". +func WithUseFullURLForMessageEndpoint(useFullURLForMessageEndpoint bool) SSEOption { return func(s *SSEServer) { - s.isCompleteMessageEndpoint = isCompleteMessageEndpoint + s.useFullURLForMessageEndpoint = useFullURLForMessageEndpoint } } @@ -139,10 +141,10 @@ func WithSSEContextFunc(fn SSEContextFunc) SSEOption { // NewSSEServer creates a new SSE server instance with the given MCP server and options. func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer { s := &SSEServer{ - server: server, - sseEndpoint: "/sse", - messageEndpoint: "/message", - isCompleteMessageEndpoint: true, + server: server, + sseEndpoint: "/sse", + messageEndpoint: "/message", + useFullURLForMessageEndpoint: true, } // Apply all options @@ -252,9 +254,10 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { } } }() - + // Use either just the path or the complete URL based on configuration. + // This prevents issues with clients that concatenate the base URL themselves. messageEndpoint := s.messageEndpoint - if s.isCompleteMessageEndpoint { + if s.useFullURLForMessageEndpoint { messageEndpoint = s.CompleteMessageEndpoint() } messageEndpoint = fmt.Sprintf("%s?sessionId=%s", messageEndpoint, sessionID) diff --git a/server/sse_test.go b/server/sse_test.go index 5a35a0c23..111c58456 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -418,7 +418,7 @@ func TestSSEServer(t *testing.T) { cancel() }) - t.Run("test isCompleteMessageEndpoint", func(t *testing.T) { + t.Run("test useFullURLForMessageEndpoint", func(t *testing.T) { mcpServer := NewMCPServer("test", "1.0.0") sseServer := NewSSEServer(mcpServer) @@ -429,7 +429,7 @@ func TestSSEServer(t *testing.T) { defer ts.Close() sseServer.baseURL = ts.URL + "/mcp" - sseServer.isCompleteMessageEndpoint = false + sseServer.useFullURLForMessageEndpoint = false ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -696,13 +696,13 @@ func TestSSEServer(t *testing.T) { baseURL := "http://localhost:8080/test" messageEndpoint := "/message-test" sseEndpoint := "/sse-test" - isCompleteMessageEndpoint := false + useFullURLForMessageEndpoint := false srv := &http.Server{} rands := []SSEOption{ WithBasePath(basePath), WithBaseURL(baseURL), WithMessageEndpoint(messageEndpoint), - WithIsCompleteMessageEndpoint(isCompleteMessageEndpoint), + WithUseFullURLForMessageEndpoint(useFullURLForMessageEndpoint), WithSSEEndpoint(sseEndpoint), WithHTTPServer(srv), } @@ -718,8 +718,8 @@ func TestSSEServer(t *testing.T) { if sseServer.basePath != basePath { t.Fatalf("basePath %v, got: %v", basePath, sseServer.basePath) } - if sseServer.isCompleteMessageEndpoint != isCompleteMessageEndpoint { - t.Fatalf("isCompleteMessageEndpoint %v, got: %v", isCompleteMessageEndpoint, sseServer.isCompleteMessageEndpoint) + if sseServer.useFullURLForMessageEndpoint != useFullURLForMessageEndpoint { + t.Fatalf("useFullURLForMessageEndpoint %v, got: %v", useFullURLForMessageEndpoint, sseServer.useFullURLForMessageEndpoint) } if sseServer.baseURL != baseURL { From c6f4e767f82ae9d33e0965777ff5601442d6373e Mon Sep 17 00:00:00 2001 From: Buf Generate Date: Wed, 26 Mar 2025 20:39:57 +0800 Subject: [PATCH 3/3] fix java mcp message endpoint --- server/sse.go | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/server/sse.go b/server/sse.go index 2703d47d8..ea0a03b8b 100644 --- a/server/sse.go +++ b/server/sse.go @@ -254,16 +254,9 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { } } }() - // Use either just the path or the complete URL based on configuration. - // This prevents issues with clients that concatenate the base URL themselves. - messageEndpoint := s.messageEndpoint - if s.useFullURLForMessageEndpoint { - messageEndpoint = s.CompleteMessageEndpoint() - } - messageEndpoint = fmt.Sprintf("%s?sessionId=%s", messageEndpoint, sessionID) // Send the initial endpoint event - fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", messageEndpoint) + fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", s.GetMessageEndpointForClient(sessionID)) flusher.Flush() // Main event loop - this runs in the HTTP handler goroutine @@ -280,6 +273,16 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { } } +// GetMessageEndpointForClient returns the appropriate message endpoint URL with session ID +// based on the useFullURLForMessageEndpoint configuration. +func (s *SSEServer) GetMessageEndpointForClient(sessionID string) string { + messageEndpoint := s.messageEndpoint + if s.useFullURLForMessageEndpoint { + messageEndpoint = s.CompleteMessageEndpoint() + } + return fmt.Sprintf("%s?sessionId=%s", messageEndpoint, sessionID) +} + // handleMessage processes incoming JSON-RPC messages from clients and sends responses // back through both the SSE connection and HTTP response. func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {