diff --git a/server/server.go b/server/server.go index 64b516234..800628cf0 100644 --- a/server/server.go +++ b/server/server.go @@ -8,7 +8,6 @@ import ( "fmt" "sort" "sync" - "sync/atomic" "github.com/mark3labs/mcp-go/mcp" ) @@ -48,6 +47,10 @@ type ServerTool struct { // ClientSession represents an active session that can be used by MCPServer to interact with client. type ClientSession interface { + // Initialize marks session as fully initialized and ready for notifications + Initialize() + // Initialized returns if session is ready to accept notifications + Initialized() bool // NotificationChannel provides a channel suitable for sending notifications to client. NotificationChannel() chan<- mcp.JSONRPCNotification // SessionID is a unique identifier used to track user session. @@ -145,7 +148,6 @@ type MCPServer struct { notificationHandlers map[string]NotificationHandlerFunc capabilities serverCapabilities sessions sync.Map - initialized atomic.Bool // Use atomic for the initialized flag hooks *Hooks } @@ -202,7 +204,7 @@ func (s *MCPServer) sendNotificationToAllClients( } s.sessions.Range(func(k, v any) bool { - if session, ok := v.(ClientSession); ok { + if session, ok := v.(ClientSession); ok && session.Initialized() { select { case session.NotificationChannel() <- notification: default: @@ -220,7 +222,7 @@ func (s *MCPServer) SendNotificationToClient( params map[string]any, ) error { session := ClientSessionFromContext(ctx) - if session == nil { + if session == nil || !session.Initialized() { return fmt.Errorf("notification channel not initialized") } @@ -406,13 +408,10 @@ func (s *MCPServer) AddTools(tools ...ServerTool) { for _, entry := range tools { s.tools[entry.Tool.Name] = entry } - initialized := s.initialized.Load() s.mu.Unlock() - // Send notification if server is already initialized - if initialized { - s.sendNotificationToAllClients("notifications/tools/list_changed", nil) - } + // Send notification to all initialized sessions + s.sendNotificationToAllClients("notifications/tools/list_changed", nil) } // SetTools replaces all existing tools with the provided list @@ -429,13 +428,10 @@ func (s *MCPServer) DeleteTools(names ...string) { for _, name := range names { delete(s.tools, name) } - initialized := s.initialized.Load() s.mu.Unlock() - // Send notification if server is already initialized - if initialized { - s.sendNotificationToAllClients("notifications/tools/list_changed", nil) - } + // Send notification to all initialized sessions + s.sendNotificationToAllClients("notifications/tools/list_changed", nil) } // AddNotificationHandler registers a new handler for incoming notifications @@ -498,7 +494,9 @@ func (s *MCPServer) handleInitialize( Instructions: s.instructions, } - s.initialized.Store(true) + if session := ClientSessionFromContext(ctx); session != nil { + session.Initialize() + } return &result, nil } diff --git a/server/server_test.go b/server/server_test.go index d2bbaf040..7367cbe69 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -181,6 +181,7 @@ func TestMCPServer_Tools(t *testing.T) { err := server.RegisterSession(&fakeSession{ sessionID: "test", notificationChannel: notificationChannel, + initialized: true, }) require.NoError(t, err) server.SetTools(ServerTool{ @@ -211,6 +212,16 @@ func TestMCPServer_Tools(t *testing.T) { err := server.RegisterSession(&fakeSession{ sessionID: fmt.Sprintf("test%d", i), notificationChannel: notificationChannel, + initialized: true, + }) + require.NoError(t, err) + } + // also let's register inactive sessions + for i := range 5 { + err := server.RegisterSession(&fakeSession{ + sessionID: fmt.Sprintf("test%d", i+5), + notificationChannel: notificationChannel, + initialized: false, }) require.NoError(t, err) } @@ -243,6 +254,7 @@ func TestMCPServer_Tools(t *testing.T) { err := server.RegisterSession(&fakeSession{ sessionID: "test", notificationChannel: notificationChannel, + initialized: true, }) require.NoError(t, err) server.AddTool(mcp.NewTool("test-tool-1"), @@ -270,6 +282,7 @@ func TestMCPServer_Tools(t *testing.T) { err := server.RegisterSession(&fakeSession{ sessionID: "test", notificationChannel: notificationChannel, + initialized: true, }) require.NoError(t, err) server.SetTools( @@ -489,12 +502,28 @@ func TestMCPServer_SendNotificationToClient(t *testing.T) { require.Error(t, srv.SendNotificationToClient(ctx, "method", nil)) }, }, + { + name: "uninit session", + contextPrepare: func(ctx context.Context, srv *MCPServer) context.Context { + return srv.WithContext(ctx, fakeSession{ + sessionID: "test", + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + initialized: false, + }) + }, + validate: func(t *testing.T, ctx context.Context, srv *MCPServer) { + require.Error(t, srv.SendNotificationToClient(ctx, "method", nil)) + _, ok := ClientSessionFromContext(ctx).(fakeSession) + require.True(t, ok, "session not found or of incorrect type") + }, + }, { name: "active session", contextPrepare: func(ctx context.Context, srv *MCPServer) context.Context { return srv.WithContext(ctx, fakeSession{ sessionID: "test", notificationChannel: make(chan mcp.JSONRPCNotification, 10), + initialized: true, }) }, validate: func(t *testing.T, ctx context.Context, srv *MCPServer) { @@ -519,6 +548,7 @@ func TestMCPServer_SendNotificationToClient(t *testing.T) { return srv.WithContext(ctx, fakeSession{ sessionID: "test", notificationChannel: make(chan mcp.JSONRPCNotification, 1), + initialized: true, }) }, validate: func(t *testing.T, ctx context.Context, srv *MCPServer) { @@ -1136,6 +1166,7 @@ func createTestServer() *MCPServer { type fakeSession struct { sessionID string notificationChannel chan mcp.JSONRPCNotification + initialized bool } func (f fakeSession) SessionID() string { @@ -1146,6 +1177,13 @@ func (f fakeSession) NotificationChannel() chan<- mcp.JSONRPCNotification { return f.notificationChannel } +func (f fakeSession) Initialize() { +} + +func (f fakeSession) Initialized() bool { + return f.initialized +} + var _ ClientSession = fakeSession{} func TestMCPServer_WithHooks(t *testing.T) { diff --git a/server/sse.go b/server/sse.go index c3c070e6c..a869ad1aa 100644 --- a/server/sse.go +++ b/server/sse.go @@ -9,6 +9,7 @@ import ( "net/url" "strings" "sync" + "sync/atomic" "github.com/google/uuid" "github.com/mark3labs/mcp-go/mcp" @@ -22,6 +23,7 @@ type sseSession struct { eventQueue chan string // Channel for queuing events sessionID string notificationChannel chan mcp.JSONRPCNotification + initialized atomic.Bool } // SSEContextFunc is a function that takes an existing context and the current @@ -37,6 +39,14 @@ func (s *sseSession) NotificationChannel() chan<- mcp.JSONRPCNotification { return s.notificationChannel } +func (s *sseSession) Initialize() { + s.initialized.Store(true) +} + +func (s *sseSession) Initialized() bool { + return s.initialized.Load() +} + var _ ClientSession = (*sseSession)(nil) // SSEServer implements a Server-Sent Events (SSE) based MCP server. diff --git a/server/stdio.go b/server/stdio.go index 441c50b99..2464a37fb 100644 --- a/server/stdio.go +++ b/server/stdio.go @@ -9,6 +9,7 @@ import ( "log" "os" "os/signal" + "sync/atomic" "syscall" "github.com/mark3labs/mcp-go/mcp" @@ -51,6 +52,7 @@ func WithStdioContextFunc(fn StdioContextFunc) StdioOption { // stdioSession is a static client session, since stdio has only one client. type stdioSession struct { notifications chan mcp.JSONRPCNotification + initialized atomic.Bool } func (s *stdioSession) SessionID() string { @@ -61,6 +63,14 @@ func (s *stdioSession) NotificationChannel() chan<- mcp.JSONRPCNotification { return s.notifications } +func (s *stdioSession) Initialize() { + s.initialized.Store(true) +} + +func (s *stdioSession) Initialized() bool { + return s.initialized.Load() +} + var _ ClientSession = (*stdioSession)(nil) var stdioSessionInstance = stdioSession{