Skip to content

fix java mcp message endpoint #75

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 33 additions & 14 deletions server/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,15 @@ var _ ClientSession = (*sseSession)(nil)
// SSEServer implements a Server-Sent Events (SSE) based MCP server.
// It provides real-time communication capabilities over HTTP using the SSE protocol.
type SSEServer struct {
server *MCPServer
baseURL string
basePath string
messageEndpoint string
sseEndpoint string
sessions sync.Map
srv *http.Server
contextFunc SSEContextFunc
server *MCPServer
baseURL string
basePath string
messageEndpoint string
useFullURLForMessageEndpoint bool
sseEndpoint string
sessions sync.Map
srv *http.Server
contextFunc SSEContextFunc
}

// SSEOption defines a function type for configuring SSEServer
Expand Down Expand Up @@ -106,6 +107,15 @@ func WithMessageEndpoint(endpoint string) SSEOption {
}
}

// WithUseFullURLForMessageEndpoint controls whether the SSE server returns a complete URL (including baseURL)
// or just the path portion for the message endpoint. Set to false when clients will concatenate
// the baseURL themselves to avoid malformed URLs like "http://localhost/mcphttp://localhost/mcp/message".
func WithUseFullURLForMessageEndpoint(useFullURLForMessageEndpoint bool) SSEOption {
return func(s *SSEServer) {
s.useFullURLForMessageEndpoint = useFullURLForMessageEndpoint
}
}

// WithSSEEndpoint sets the SSE endpoint path
func WithSSEEndpoint(endpoint string) SSEOption {
return func(s *SSEServer) {
Expand All @@ -131,9 +141,10 @@ func WithSSEContextFunc(fn SSEContextFunc) SSEOption {
// NewSSEServer creates a new SSE server instance with the given MCP server and options.
func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer {
s := &SSEServer{
server: server,
sseEndpoint: "/sse",
messageEndpoint: "/message",
server: server,
sseEndpoint: "/sse",
messageEndpoint: "/message",
useFullURLForMessageEndpoint: true,
}

// Apply all options
Expand Down Expand Up @@ -244,10 +255,8 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
}
}()

messageEndpoint := fmt.Sprintf("%s?sessionId=%s", s.CompleteMessageEndpoint(), sessionID)

// Send the initial endpoint event
fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", messageEndpoint)
fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", s.GetMessageEndpointForClient(sessionID))
flusher.Flush()

// Main event loop - this runs in the HTTP handler goroutine
Expand All @@ -264,6 +273,16 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
}
}

// GetMessageEndpointForClient returns the appropriate message endpoint URL with session ID
// based on the useFullURLForMessageEndpoint configuration.
func (s *SSEServer) GetMessageEndpointForClient(sessionID string) string {
messageEndpoint := s.messageEndpoint
if s.useFullURLForMessageEndpoint {
messageEndpoint = s.CompleteMessageEndpoint()
}
return fmt.Sprintf("%s?sessionId=%s", messageEndpoint, sessionID)
}

// handleMessage processes incoming JSON-RPC messages from clients and sends responses
// back through both the SSE connection and HTTP response.
func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
Expand Down
80 changes: 80 additions & 0 deletions server/sse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,81 @@ func TestSSEServer(t *testing.T) {
cancel()
})

t.Run("test useFullURLForMessageEndpoint", func(t *testing.T) {
mcpServer := NewMCPServer("test", "1.0.0")
sseServer := NewSSEServer(mcpServer)

mux := http.NewServeMux()
mux.Handle("/mcp/", sseServer)

ts := httptest.NewServer(mux)
defer ts.Close()

sseServer.baseURL = ts.URL + "/mcp"
sseServer.useFullURLForMessageEndpoint = false
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/sse", sseServer.baseURL), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}

resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Failed to connect to SSE endpoint: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}

// Read the endpoint event
buf := make([]byte, 1024)
n, err := resp.Body.Read(buf)
if err != nil {
t.Fatalf("Failed to read SSE response: %v", err)
}

endpointEvent := string(buf[:n])
messageURL := strings.TrimSpace(
strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0],
)
if !strings.HasPrefix(messageURL, sseServer.messageEndpoint) {
t.Errorf("Expected messageURL to be %s, got %s", sseServer.messageEndpoint, messageURL)
}

// The messageURL should already be correct since we set the baseURL correctly
// Test message endpoint
initRequest := map[string]interface{}{
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": map[string]interface{}{
"protocolVersion": "2024-11-05",
"clientInfo": map[string]interface{}{
"name": "test-client",
"version": "1.0.0",
},
},
}
requestBody, _ := json.Marshal(initRequest)

resp, err = http.Post(sseServer.baseURL+messageURL, "application/json", bytes.NewBuffer(requestBody))
if err != nil {
t.Fatalf("Failed to send message: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusAccepted {
t.Errorf("Expected status 202, got %d", resp.StatusCode)
}

// Clean up SSE connection
cancel()
})

t.Run("works as http.Handler with custom basePath", func(t *testing.T) {
mcpServer := NewMCPServer("test", "1.0.0")
sseServer := NewSSEServer(mcpServer, WithBasePath("/mcp"))
Expand Down Expand Up @@ -621,11 +696,13 @@ func TestSSEServer(t *testing.T) {
baseURL := "http://localhost:8080/test"
messageEndpoint := "/message-test"
sseEndpoint := "/sse-test"
useFullURLForMessageEndpoint := false
srv := &http.Server{}
rands := []SSEOption{
WithBasePath(basePath),
WithBaseURL(baseURL),
WithMessageEndpoint(messageEndpoint),
WithUseFullURLForMessageEndpoint(useFullURLForMessageEndpoint),
WithSSEEndpoint(sseEndpoint),
WithHTTPServer(srv),
}
Expand All @@ -641,6 +718,9 @@ func TestSSEServer(t *testing.T) {
if sseServer.basePath != basePath {
t.Fatalf("basePath %v, got: %v", basePath, sseServer.basePath)
}
if sseServer.useFullURLForMessageEndpoint != useFullURLForMessageEndpoint {
t.Fatalf("useFullURLForMessageEndpoint %v, got: %v", useFullURLForMessageEndpoint, sseServer.useFullURLForMessageEndpoint)
}

if sseServer.baseURL != baseURL {
t.Fatalf("baseURL %v, got: %v", baseURL, sseServer.baseURL)
Expand Down