@@ -232,8 +232,8 @@ func WithOriginAllowlist(allowlist []string) StreamableHTTPOption {
232
232
})
233
233
}
234
234
235
- // realStreamableHTTPServer is the concrete implementation of StreamableHTTPServer.
236
- // It provides HTTP transport capabilities following the MCP Streamable HTTP specification.
235
+ // StreamableHTTPServer is the concrete implementation of a server that supports
236
+ // the MCP Streamable HTTP transport specification.
237
237
type StreamableHTTPServer struct {
238
238
// Implement the httpTransportConfigurable interface
239
239
server * MCPServer
@@ -250,6 +250,13 @@ type StreamableHTTPServer struct {
250
250
requestToStreamMap sync.Map // Maps requestID to streamID
251
251
statelessMode bool
252
252
originAllowlist []string // List of allowed origins for CORS validation
253
+
254
+ // Fields for dynamic base path
255
+ dynamicBasePathFunc DynamicBasePathFunc
256
+
257
+ // Fields for keep-alive
258
+ keepAliveEnabled bool
259
+ keepAliveInterval time.Duration
253
260
}
254
261
255
262
// NewStreamableHTTPServer creates a new Streamable HTTP server instance with the given MCP server and options.
@@ -314,7 +321,21 @@ func (s *StreamableHTTPServer) Shutdown(ctx context.Context) error {
314
321
// ServeHTTP implements the http.Handler interface.
315
322
func (s * StreamableHTTPServer ) ServeHTTP (w http.ResponseWriter , r * http.Request ) {
316
323
path := r .URL .Path
317
- endpoint := s .basePath + s .endpoint
324
+
325
+ // Determine the endpoint path
326
+ var endpoint string
327
+
328
+ // If dynamic base path function is set, use it to determine the base path
329
+ if s .dynamicBasePathFunc != nil {
330
+ // Get the session ID from the header if present
331
+ sessionID := r .Header .Get ("Mcp-Session-Id" )
332
+ // Use the dynamic base path function to determine the base path
333
+ dynamicBasePath := s .dynamicBasePathFunc (r , sessionID )
334
+ endpoint = dynamicBasePath + s .endpoint
335
+ } else {
336
+ // Use the static base path
337
+ endpoint = s .basePath + s .endpoint
338
+ }
318
339
319
340
if path != endpoint {
320
341
http .NotFound (w , r )
@@ -667,6 +688,13 @@ func (s *StreamableHTTPServer) handleSSEResponse(w http.ResponseWriter, r *http.
667
688
ctx , cancel := context .WithTimeout (r .Context (), 5 * time .Second )
668
689
defer cancel ()
669
690
691
+ // Set up keep-alive if enabled
692
+ keepAliveTicker := time .NewTicker (24 * time .Hour ) // Default to a very long interval (effectively disabled)
693
+ if s .keepAliveEnabled && s .keepAliveInterval > 0 {
694
+ keepAliveTicker .Reset (s .keepAliveInterval )
695
+ }
696
+ defer keepAliveTicker .Stop ()
697
+
670
698
// Process notifications in the main handler goroutine
671
699
for {
672
700
select {
@@ -684,6 +712,24 @@ func (s *StreamableHTTPServer) handleSSEResponse(w http.ResponseWriter, r *http.
684
712
if err := s .writeSSEEvent (streamID , "" , notification ); err != nil {
685
713
fmt .Printf ("Error writing notification: %v\n " , err )
686
714
}
715
+ case <- keepAliveTicker .C :
716
+ // Send a keep-alive message
717
+ if s .keepAliveEnabled {
718
+ keepAliveMsg := mcp.JSONRPCNotification {
719
+ JSONRPC : "2.0" ,
720
+ Notification : mcp.Notification {
721
+ Method : "connection/keepalive" ,
722
+ Params : mcp.NotificationParams {
723
+ AdditionalFields : map [string ]interface {}{
724
+ "timestamp" : time .Now ().UnixNano () / int64 (time .Millisecond ),
725
+ },
726
+ },
727
+ },
728
+ }
729
+ if err := s .writeSSEEvent (streamID , "keepalive" , keepAliveMsg ); err != nil {
730
+ fmt .Printf ("Error writing keep-alive: %v\n " , err )
731
+ }
732
+ }
687
733
case <- ctx .Done ():
688
734
// Request context is done or timeout reached, exit the loop
689
735
return
@@ -788,6 +834,13 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
788
834
// For standalone SSE streams, we'll keep the connection open until the client disconnects
789
835
ctx := r .Context ()
790
836
837
+ // Set up keep-alive if enabled
838
+ keepAliveTicker := time .NewTicker (24 * time .Hour ) // Default to a very long interval (effectively disabled)
839
+ if s .keepAliveEnabled && s .keepAliveInterval > 0 {
840
+ keepAliveTicker .Reset (s .keepAliveInterval )
841
+ }
842
+ defer keepAliveTicker .Stop ()
843
+
791
844
// Process notifications in the main handler goroutine
792
845
for {
793
846
select {
@@ -796,6 +849,24 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
796
849
if err := s .writeSSEEvent (streamID , "" , notification ); err != nil {
797
850
fmt .Printf ("Error writing notification: %v\n " , err )
798
851
}
852
+ case <- keepAliveTicker .C :
853
+ // Send a keep-alive message
854
+ if s .keepAliveEnabled {
855
+ keepAliveMsg := mcp.JSONRPCNotification {
856
+ JSONRPC : "2.0" ,
857
+ Notification : mcp.Notification {
858
+ Method : "connection/keepalive" ,
859
+ Params : mcp.NotificationParams {
860
+ AdditionalFields : map [string ]interface {}{
861
+ "timestamp" : time .Now ().UnixNano () / int64 (time .Millisecond ),
862
+ },
863
+ },
864
+ },
865
+ }
866
+ if err := s .writeSSEEvent (streamID , "keepalive" , keepAliveMsg ); err != nil {
867
+ fmt .Printf ("Error writing keep-alive: %v\n " , err )
868
+ }
869
+ }
799
870
case <- ctx .Done ():
800
871
// Request context is done, exit the loop
801
872
return
0 commit comments