diff --git a/server/sse.go b/server/sse.go index 9a4191500..64618399c 100644 --- a/server/sse.go +++ b/server/sse.go @@ -51,6 +51,12 @@ func (s *sseSession) Initialized() bool { var _ ClientSession = (*sseSession)(nil) +// Server is a type that implements ListenAndServe and Shutdown +type Server interface { + ListenAndServe() error + Shutdown(ctx context.Context) error +} + // SSEServer implements a Server-Sent Events (SSE) based MCP server. // It provides real-time communication capabilities over HTTP using the SSE protocol. type SSEServer struct { @@ -61,7 +67,7 @@ type SSEServer struct { messageEndpoint string sseEndpoint string sessions sync.Map - srv *http.Server + srv Server contextFunc SSEContextFunc keepAlive bool @@ -131,7 +137,7 @@ func WithSSEEndpoint(endpoint string) SSEOption { } // WithHTTPServer sets the HTTP server instance -func WithHTTPServer(srv *http.Server) SSEOption { +func WithHTTPServer(srv Server) SSEOption { return func(s *SSEServer) { s.srv = srv } @@ -165,6 +171,7 @@ func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer { sseEndpoint: "/sse", messageEndpoint: "/message", useFullURLForMessageEndpoint: true, + srv: nil, // will be set by Start keepAlive: false, keepAliveInterval: 10 * time.Second, } @@ -190,9 +197,11 @@ func NewTestServer(server *MCPServer, opts ...SSEOption) *httptest.Server { // It sets up HTTP handlers for SSE and message endpoints. func (s *SSEServer) Start(addr string) error { s.mu.Lock() - s.srv = &http.Server{ - Addr: addr, - Handler: s, + if s.srv == nil { + s.srv = &http.Server{ + Addr: addr, + Handler: s, + } } s.mu.Unlock()