6
6
"fmt"
7
7
"net/http"
8
8
"net/http/httptest"
9
+ "net/url"
9
10
"strings"
10
11
"sync"
11
12
@@ -57,7 +58,23 @@ type SSEOption func(*SSEServer)
57
58
// WithBaseURL sets the base URL for the SSE server
58
59
func WithBaseURL (baseURL string ) SSEOption {
59
60
return func (s * SSEServer ) {
60
- s .baseURL = baseURL
61
+ if baseURL != "" {
62
+ u , err := url .Parse (baseURL )
63
+ if err != nil {
64
+ return
65
+ }
66
+ if u .Scheme != "http" && u .Scheme != "https" {
67
+ return
68
+ }
69
+ // Check if the host is empty or only contains a port
70
+ if u .Host == "" || strings .HasPrefix (u .Host , ":" ) {
71
+ return
72
+ }
73
+ if len (u .Query ()) > 0 {
74
+ return
75
+ }
76
+ }
77
+ s .baseURL = strings .TrimSuffix (baseURL , "/" )
61
78
}
62
79
}
63
80
@@ -69,7 +86,6 @@ func WithBasePath(basePath string) SSEOption {
69
86
basePath = "/" + basePath
70
87
}
71
88
s .basePath = strings .TrimSuffix (basePath , "/" )
72
- s .baseURL = s .baseURL + s .basePath
73
89
}
74
90
}
75
91
@@ -108,7 +124,6 @@ func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer {
108
124
server : server ,
109
125
sseEndpoint : "/sse" ,
110
126
messageEndpoint : "/message" ,
111
- basePath : "" ,
112
127
}
113
128
114
129
// Apply all options
@@ -219,12 +234,7 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
219
234
}
220
235
}()
221
236
222
- messageEndpoint := fmt .Sprintf (
223
- "%s%s?sessionId=%s" ,
224
- s .baseURL ,
225
- s .messageEndpoint ,
226
- sessionID ,
227
- )
237
+ messageEndpoint := fmt .Sprintf ("%s?sessionId=%s" , s .CompleteMessageEndpoint (), sessionID )
228
238
229
239
// Send the initial endpoint event
230
240
fmt .Fprintf (w , "event: endpoint\n data: %s\r \n \r \n " , messageEndpoint )
@@ -345,22 +355,47 @@ func (s *SSEServer) SendEventToSession(
345
355
return fmt .Errorf ("event queue full" )
346
356
}
347
357
}
358
+ func (s * SSEServer ) GetUrlPath (input string ) (string , error ) {
359
+ parse , err := url .Parse (input )
360
+ if err != nil {
361
+ return "" , fmt .Errorf ("failed to parse URL %s: %w" , input , err )
362
+ }
363
+ return parse .Path , nil
364
+ }
365
+
366
+ func (s * SSEServer ) CompleteSseEndpoint () string {
367
+ return s .baseURL + s .basePath + s .sseEndpoint
368
+ }
369
+ func (s * SSEServer ) CompleteSsePath () string {
370
+ path , err := s .GetUrlPath (s .CompleteSseEndpoint ())
371
+ if err != nil {
372
+ return s .basePath + s .sseEndpoint
373
+ }
374
+ return path
375
+ }
376
+
377
+ func (s * SSEServer ) CompleteMessageEndpoint () string {
378
+ return s .baseURL + s .basePath + s .messageEndpoint
379
+ }
380
+ func (s * SSEServer ) CompleteMessagePath () string {
381
+ path , err := s .GetUrlPath (s .CompleteMessageEndpoint ())
382
+ if err != nil {
383
+ return s .basePath + s .messageEndpoint
384
+ }
385
+ return path
386
+ }
348
387
349
388
// ServeHTTP implements the http.Handler interface.
350
389
func (s * SSEServer ) ServeHTTP (w http.ResponseWriter , r * http.Request ) {
351
390
path := r .URL .Path
352
-
353
- // Construct the full SSE and message paths
354
- ssePath := s .basePath + s .sseEndpoint
355
- messagePath := s .basePath + s .messageEndpoint
356
-
357
391
// Use exact path matching rather than Contains
358
- if path == ssePath {
392
+ ssePath := s .CompleteSsePath ()
393
+ if ssePath != "" && path == ssePath {
359
394
s .handleSSE (w , r )
360
395
return
361
396
}
362
-
363
- if path == messagePath {
397
+ messagePath := s . CompleteMessagePath ()
398
+ if messagePath != "" && path == messagePath {
364
399
s .handleMessage (w , r )
365
400
return
366
401
}
0 commit comments