Skip to content

Commit 807a9ba

Browse files
a67793581Buf Generate
andauthored
fix SSEOption (#63)
* fix SSEOption * fix SSEOption * fix SSEOption * fix SSEOption * fix SSEOption --------- Co-authored-by: Buf Generate <[email protected]>
1 parent e183dd1 commit 807a9ba

File tree

2 files changed

+100
-20
lines changed

2 files changed

+100
-20
lines changed

server/sse.go

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"net/http"
88
"net/http/httptest"
9+
"net/url"
910
"strings"
1011
"sync"
1112

@@ -57,7 +58,23 @@ type SSEOption func(*SSEServer)
5758
// WithBaseURL sets the base URL for the SSE server
5859
func WithBaseURL(baseURL string) SSEOption {
5960
return func(s *SSEServer) {
60-
s.baseURL = baseURL
61+
if baseURL != "" {
62+
u, err := url.Parse(baseURL)
63+
if err != nil {
64+
return
65+
}
66+
if u.Scheme != "http" && u.Scheme != "https" {
67+
return
68+
}
69+
// Check if the host is empty or only contains a port
70+
if u.Host == "" || strings.HasPrefix(u.Host, ":") {
71+
return
72+
}
73+
if len(u.Query()) > 0 {
74+
return
75+
}
76+
}
77+
s.baseURL = strings.TrimSuffix(baseURL, "/")
6178
}
6279
}
6380

@@ -69,7 +86,6 @@ func WithBasePath(basePath string) SSEOption {
6986
basePath = "/" + basePath
7087
}
7188
s.basePath = strings.TrimSuffix(basePath, "/")
72-
s.baseURL = s.baseURL + s.basePath
7389
}
7490
}
7591

@@ -108,7 +124,6 @@ func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer {
108124
server: server,
109125
sseEndpoint: "/sse",
110126
messageEndpoint: "/message",
111-
basePath: "",
112127
}
113128

114129
// Apply all options
@@ -219,12 +234,7 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
219234
}
220235
}()
221236

222-
messageEndpoint := fmt.Sprintf(
223-
"%s%s?sessionId=%s",
224-
s.baseURL,
225-
s.messageEndpoint,
226-
sessionID,
227-
)
237+
messageEndpoint := fmt.Sprintf("%s?sessionId=%s", s.CompleteMessageEndpoint(), sessionID)
228238

229239
// Send the initial endpoint event
230240
fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", messageEndpoint)
@@ -345,22 +355,47 @@ func (s *SSEServer) SendEventToSession(
345355
return fmt.Errorf("event queue full")
346356
}
347357
}
358+
func (s *SSEServer) GetUrlPath(input string) (string, error) {
359+
parse, err := url.Parse(input)
360+
if err != nil {
361+
return "", fmt.Errorf("failed to parse URL %s: %w", input, err)
362+
}
363+
return parse.Path, nil
364+
}
365+
366+
func (s *SSEServer) CompleteSseEndpoint() string {
367+
return s.baseURL + s.basePath + s.sseEndpoint
368+
}
369+
func (s *SSEServer) CompleteSsePath() string {
370+
path, err := s.GetUrlPath(s.CompleteSseEndpoint())
371+
if err != nil {
372+
return s.basePath + s.sseEndpoint
373+
}
374+
return path
375+
}
376+
377+
func (s *SSEServer) CompleteMessageEndpoint() string {
378+
return s.baseURL + s.basePath + s.messageEndpoint
379+
}
380+
func (s *SSEServer) CompleteMessagePath() string {
381+
path, err := s.GetUrlPath(s.CompleteMessageEndpoint())
382+
if err != nil {
383+
return s.basePath + s.messageEndpoint
384+
}
385+
return path
386+
}
348387

349388
// ServeHTTP implements the http.Handler interface.
350389
func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
351390
path := r.URL.Path
352-
353-
// Construct the full SSE and message paths
354-
ssePath := s.basePath + s.sseEndpoint
355-
messagePath := s.basePath + s.messageEndpoint
356-
357391
// Use exact path matching rather than Contains
358-
if path == ssePath {
392+
ssePath := s.CompleteSsePath()
393+
if ssePath != "" && path == ssePath {
359394
s.handleSSE(w, r)
360395
return
361396
}
362-
363-
if path == messagePath {
397+
messagePath := s.CompleteMessagePath()
398+
if messagePath != "" && path == messagePath {
364399
s.handleMessage(w, r)
365400
return
366401
}

server/sse_test.go

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"context"
66
"encoding/json"
77
"fmt"
8+
"math/rand"
89
"net/http"
910
"net/http/httptest"
1011
"strings"
@@ -29,9 +30,9 @@ func TestSSEServer(t *testing.T) {
2930
if sseServer.server == nil {
3031
t.Error("MCPServer should not be nil")
3132
}
32-
if sseServer.baseURL != "http://localhost:8080/mcp" {
33+
if sseServer.baseURL != "http://localhost:8080" {
3334
t.Errorf(
34-
"Expected baseURL http://localhost:8080/mcp, got %s",
35+
"Expected baseURL http://localhost:8080, got %s",
3536
sseServer.baseURL,
3637
)
3738
}
@@ -350,7 +351,7 @@ func TestSSEServer(t *testing.T) {
350351
sseServer := NewSSEServer(mcpServer)
351352

352353
mux := http.NewServeMux()
353-
mux.Handle("/mcp/", http.StripPrefix("/mcp", sseServer))
354+
mux.Handle("/mcp/", sseServer)
354355

355356
ts := httptest.NewServer(mux)
356357
defer ts.Close()
@@ -614,4 +615,48 @@ func TestSSEServer(t *testing.T) {
614615
}
615616
})
616617

618+
t.Run("SSEOption should not have negative effects when used repeatedly but should always remain consistent.", func(t *testing.T) {
619+
mcpServer := NewMCPServer("test", "1.0.0")
620+
basePath := "/mcp-test"
621+
baseURL := "http://localhost:8080/test"
622+
messageEndpoint := "/message-test"
623+
sseEndpoint := "/sse-test"
624+
srv := &http.Server{}
625+
rands := []SSEOption{
626+
WithBasePath(basePath),
627+
WithBaseURL(baseURL),
628+
WithMessageEndpoint(messageEndpoint),
629+
WithSSEEndpoint(sseEndpoint),
630+
WithHTTPServer(srv),
631+
}
632+
for i := 0; i < 100; i++ {
633+
634+
var options []SSEOption
635+
for i2 := 0; i2 < 100; i2++ {
636+
index := rand.Intn(len(rands))
637+
options = append(options, rands[index])
638+
}
639+
sseServer := NewSSEServer(mcpServer, options...)
640+
641+
if sseServer.basePath != basePath {
642+
t.Fatalf("basePath %v, got: %v", basePath, sseServer.basePath)
643+
}
644+
645+
if sseServer.baseURL != baseURL {
646+
t.Fatalf("baseURL %v, got: %v", baseURL, sseServer.baseURL)
647+
}
648+
649+
if sseServer.sseEndpoint != sseEndpoint {
650+
t.Fatalf("sseEndpoint %v, got: %v", sseEndpoint, sseServer.sseEndpoint)
651+
}
652+
653+
if sseServer.messageEndpoint != messageEndpoint {
654+
t.Fatalf("messageEndpoint %v, got: %v", messageEndpoint, sseServer.messageEndpoint)
655+
}
656+
657+
if sseServer.srv != srv {
658+
t.Fatalf("srv %v, got: %v", srv, sseServer.srv)
659+
}
660+
}
661+
})
617662
}

0 commit comments

Comments
 (0)