diff --git a/server/sse.go b/server/sse.go index 0e526be0c..5a28aaa2a 100644 --- a/server/sse.go +++ b/server/sse.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "strings" "sync" "github.com/google/uuid" @@ -25,6 +26,7 @@ type sseSession struct { type SSEServer struct { server *MCPServer baseURL string + basePath string messageEndpoint string sseEndpoint string sessions sync.Map @@ -41,6 +43,18 @@ func WithBaseURL(baseURL string) Option { } } +// Add a new option for setting base path +func WithBasePath(basePath string) Option { + return func(s *SSEServer) { + // Ensure the path starts with / and doesn't end with / + if !strings.HasPrefix(basePath, "/") { + basePath = "/" + basePath + } + s.basePath = strings.TrimSuffix(basePath, "/") + s.baseURL = s.baseURL + s.basePath + } +} + // WithMessageEndpoint sets the message endpoint path func WithMessageEndpoint(endpoint string) Option { return func(s *SSEServer) { @@ -68,6 +82,7 @@ func NewSSEServer(server *MCPServer, opts ...Option) *SSEServer { server: server, sseEndpoint: "/sse", messageEndpoint: "/message", + basePath: "", } // Apply all options @@ -299,12 +314,22 @@ func (s *SSEServer) SendEventToSession( // ServeHTTP implements the http.Handler interface. func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case s.sseEndpoint: + path := r.URL.Path + + // Construct the full SSE and message paths + ssePath := s.basePath + s.sseEndpoint + messagePath := s.basePath + s.messageEndpoint + + // Use exact path matching rather than Contains + if path == ssePath { s.handleSSE(w, r) - case s.messageEndpoint: + return + } + + if path == messagePath { s.handleMessage(w, r) - default: - http.NotFound(w, r) + return } + + http.NotFound(w, r) } diff --git a/server/sse_test.go b/server/sse_test.go index 8ea49e26e..bd88cacf3 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -16,7 +16,10 @@ import ( func TestSSEServer(t *testing.T) { t.Run("Can instantiate", func(t *testing.T) { mcpServer := NewMCPServer("test", "1.0.0") - sseServer := NewSSEServer(mcpServer, WithBaseURL("http://localhost:8080")) + sseServer := NewSSEServer(mcpServer, + WithBaseURL("http://localhost:8080"), + WithBasePath("/mcp"), + ) if sseServer == nil { t.Error("SSEServer should not be nil") @@ -24,12 +27,18 @@ func TestSSEServer(t *testing.T) { if sseServer.server == nil { t.Error("MCPServer should not be nil") } - if sseServer.baseURL != "http://localhost:8080" { + if sseServer.baseURL != "http://localhost:8080/mcp" { t.Errorf( - "Expected baseURL http://localhost:8080, got %s", + "Expected baseURL http://localhost:8080/mcp, got %s", sseServer.baseURL, ) } + if sseServer.basePath != "/mcp" { + t.Errorf( + "Expected basePath /mcp, got %s", + sseServer.basePath, + ) + } }) t.Run("Can send and receive messages", func(t *testing.T) { @@ -405,4 +414,58 @@ func TestSSEServer(t *testing.T) { // Clean up SSE connection cancel() }) + + t.Run("works as http.Handler with custom basePath", func(t *testing.T) { + mcpServer := NewMCPServer("test", "1.0.0") + sseServer := NewSSEServer(mcpServer, WithBasePath("/mcp")) + + ts := httptest.NewServer(sseServer) + defer ts.Close() + + // Test 404 for unknown path first (simpler case) + resp, err := http.Get(fmt.Sprintf("%s/sse", ts.URL)) + if err != nil { + t.Fatalf("Failed to make request: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusNotFound { + t.Errorf("Expected status 404, got %d", resp.StatusCode) + } + + // Test SSE endpoint with proper cleanup + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sseURL := fmt.Sprintf("%s/sse", ts.URL+sseServer.basePath) + req, err := http.NewRequestWithContext(ctx, "GET", sseURL, nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + resp, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Failed to connect to SSE endpoint: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + + // Read initial message in goroutine + done := make(chan struct{}) + go func() { + defer close(done) + buf := make([]byte, 1024) + _, err := resp.Body.Read(buf) + if err != nil && err.Error() != "context canceled" { + t.Errorf("Failed to read from SSE stream: %v", err) + } + }() + + // Wait briefly for initial response then cancel + time.Sleep(100 * time.Millisecond) + cancel() + <-done + }) }