Skip to content

Commit 6f9af6c

Browse files
committed
Update settings
1 parent a863825 commit 6f9af6c

File tree

2 files changed

+116
-12
lines changed

2 files changed

+116
-12
lines changed

server/http_transport_options.go

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -72,15 +72,48 @@ func (o commonOption) isHTTPServerOption() {}
7272
func (o commonOption) applyToSSE(s *SSEServer) { o.apply(s) }
7373
func (o commonOption) applyToStreamableHTTP(s *StreamableHTTPServer) { o.apply(s) }
7474

75-
// Add stub methods to satisfy httpTransportConfigurable
76-
77-
func (s *StreamableHTTPServer) setBasePath(string) {}
78-
func (s *StreamableHTTPServer) setDynamicBasePath(DynamicBasePathFunc) {}
79-
func (s *StreamableHTTPServer) setKeepAliveInterval(time.Duration) {}
80-
func (s *StreamableHTTPServer) setKeepAlive(bool) {}
81-
func (s *StreamableHTTPServer) setContextFunc(HTTPContextFunc) {}
82-
func (s *StreamableHTTPServer) setHTTPServer(srv *http.Server) {}
83-
func (s *StreamableHTTPServer) setBaseURL(baseURL string) {}
75+
// Implement methods to satisfy httpTransportConfigurable interface
76+
77+
// setBasePath sets the base path for the server
78+
func (s *StreamableHTTPServer) setBasePath(path string) {
79+
s.basePath = path
80+
}
81+
82+
// setDynamicBasePath sets a function to dynamically determine the base path
83+
// for each request based on the request and session ID
84+
func (s *StreamableHTTPServer) setDynamicBasePath(fn DynamicBasePathFunc) {
85+
s.dynamicBasePathFunc = fn
86+
// Note: The ServeHTTP method would need to be updated to use this function
87+
// for determining the base path for each request
88+
}
89+
90+
// setKeepAliveInterval sets the interval for sending keep-alive messages
91+
func (s *StreamableHTTPServer) setKeepAliveInterval(interval time.Duration) {
92+
s.keepAliveInterval = interval
93+
// Note: Additional implementation would be needed to send keep-alive messages
94+
// at this interval in the SSE streams
95+
}
96+
97+
// setKeepAlive enables or disables keep-alive messages
98+
func (s *StreamableHTTPServer) setKeepAlive(enabled bool) {
99+
s.keepAliveEnabled = enabled
100+
// Note: This works in conjunction with setKeepAliveInterval
101+
}
102+
103+
// setContextFunc sets a function to customize the context for each request
104+
func (s *StreamableHTTPServer) setContextFunc(fn HTTPContextFunc) {
105+
s.contextFunc = fn
106+
}
107+
108+
// setHTTPServer sets the HTTP server instance
109+
func (s *StreamableHTTPServer) setHTTPServer(srv *http.Server) {
110+
s.srv = srv
111+
}
112+
113+
// setBaseURL sets the base URL for the server
114+
func (s *StreamableHTTPServer) setBaseURL(baseURL string) {
115+
s.baseURL = baseURL
116+
}
84117

85118
// Ensure the option types implement the correct interfaces
86119
var (

server/streamable_http.go

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,8 @@ func WithOriginAllowlist(allowlist []string) StreamableHTTPOption {
232232
})
233233
}
234234

235-
// realStreamableHTTPServer is the concrete implementation of StreamableHTTPServer.
236-
// It provides HTTP transport capabilities following the MCP Streamable HTTP specification.
235+
// StreamableHTTPServer is the concrete implementation of a server that supports
236+
// the MCP Streamable HTTP transport specification.
237237
type StreamableHTTPServer struct {
238238
// Implement the httpTransportConfigurable interface
239239
server *MCPServer
@@ -250,6 +250,13 @@ type StreamableHTTPServer struct {
250250
requestToStreamMap sync.Map // Maps requestID to streamID
251251
statelessMode bool
252252
originAllowlist []string // List of allowed origins for CORS validation
253+
254+
// Fields for dynamic base path
255+
dynamicBasePathFunc DynamicBasePathFunc
256+
257+
// Fields for keep-alive
258+
keepAliveEnabled bool
259+
keepAliveInterval time.Duration
253260
}
254261

255262
// NewStreamableHTTPServer creates a new Streamable HTTP server instance with the given MCP server and options.
@@ -314,7 +321,21 @@ func (s *StreamableHTTPServer) Shutdown(ctx context.Context) error {
314321
// ServeHTTP implements the http.Handler interface.
315322
func (s *StreamableHTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
316323
path := r.URL.Path
317-
endpoint := s.basePath + s.endpoint
324+
325+
// Determine the endpoint path
326+
var endpoint string
327+
328+
// If dynamic base path function is set, use it to determine the base path
329+
if s.dynamicBasePathFunc != nil {
330+
// Get the session ID from the header if present
331+
sessionID := r.Header.Get("Mcp-Session-Id")
332+
// Use the dynamic base path function to determine the base path
333+
dynamicBasePath := s.dynamicBasePathFunc(r, sessionID)
334+
endpoint = dynamicBasePath + s.endpoint
335+
} else {
336+
// Use the static base path
337+
endpoint = s.basePath + s.endpoint
338+
}
318339

319340
if path != endpoint {
320341
http.NotFound(w, r)
@@ -667,6 +688,13 @@ func (s *StreamableHTTPServer) handleSSEResponse(w http.ResponseWriter, r *http.
667688
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
668689
defer cancel()
669690

691+
// Set up keep-alive if enabled
692+
keepAliveTicker := time.NewTicker(24 * time.Hour) // Default to a very long interval (effectively disabled)
693+
if s.keepAliveEnabled && s.keepAliveInterval > 0 {
694+
keepAliveTicker.Reset(s.keepAliveInterval)
695+
}
696+
defer keepAliveTicker.Stop()
697+
670698
// Process notifications in the main handler goroutine
671699
for {
672700
select {
@@ -684,6 +712,24 @@ func (s *StreamableHTTPServer) handleSSEResponse(w http.ResponseWriter, r *http.
684712
if err := s.writeSSEEvent(streamID, "", notification); err != nil {
685713
fmt.Printf("Error writing notification: %v\n", err)
686714
}
715+
case <-keepAliveTicker.C:
716+
// Send a keep-alive message
717+
if s.keepAliveEnabled {
718+
keepAliveMsg := mcp.JSONRPCNotification{
719+
JSONRPC: "2.0",
720+
Notification: mcp.Notification{
721+
Method: "connection/keepalive",
722+
Params: mcp.NotificationParams{
723+
AdditionalFields: map[string]interface{}{
724+
"timestamp": time.Now().UnixNano() / int64(time.Millisecond),
725+
},
726+
},
727+
},
728+
}
729+
if err := s.writeSSEEvent(streamID, "keepalive", keepAliveMsg); err != nil {
730+
fmt.Printf("Error writing keep-alive: %v\n", err)
731+
}
732+
}
687733
case <-ctx.Done():
688734
// Request context is done or timeout reached, exit the loop
689735
return
@@ -788,6 +834,13 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
788834
// For standalone SSE streams, we'll keep the connection open until the client disconnects
789835
ctx := r.Context()
790836

837+
// Set up keep-alive if enabled
838+
keepAliveTicker := time.NewTicker(24 * time.Hour) // Default to a very long interval (effectively disabled)
839+
if s.keepAliveEnabled && s.keepAliveInterval > 0 {
840+
keepAliveTicker.Reset(s.keepAliveInterval)
841+
}
842+
defer keepAliveTicker.Stop()
843+
791844
// Process notifications in the main handler goroutine
792845
for {
793846
select {
@@ -796,6 +849,24 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
796849
if err := s.writeSSEEvent(streamID, "", notification); err != nil {
797850
fmt.Printf("Error writing notification: %v\n", err)
798851
}
852+
case <-keepAliveTicker.C:
853+
// Send a keep-alive message
854+
if s.keepAliveEnabled {
855+
keepAliveMsg := mcp.JSONRPCNotification{
856+
JSONRPC: "2.0",
857+
Notification: mcp.Notification{
858+
Method: "connection/keepalive",
859+
Params: mcp.NotificationParams{
860+
AdditionalFields: map[string]interface{}{
861+
"timestamp": time.Now().UnixNano() / int64(time.Millisecond),
862+
},
863+
},
864+
},
865+
}
866+
if err := s.writeSSEEvent(streamID, "keepalive", keepAliveMsg); err != nil {
867+
fmt.Printf("Error writing keep-alive: %v\n", err)
868+
}
869+
}
799870
case <-ctx.Done():
800871
// Request context is done, exit the loop
801872
return

0 commit comments

Comments
 (0)