diff --git a/server/server.go b/server/server.go index 800628cf0..799d1f84b 100644 --- a/server/server.go +++ b/server/server.go @@ -147,7 +147,7 @@ type MCPServer struct { tools map[string]ServerTool notificationHandlers map[string]NotificationHandlerFunc capabilities serverCapabilities - sessions sync.Map + sessionizer Sessionizer hooks *Hooks } @@ -175,7 +175,7 @@ func (s *MCPServer) RegisterSession( session ClientSession, ) error { sessionID := session.SessionID() - if _, exists := s.sessions.LoadOrStore(sessionID, session); exists { + if _, exists := s.sessionizer.LoadOrStore(sessionID, session); exists { return fmt.Errorf("session %s is already registered", sessionID) } return nil @@ -185,7 +185,7 @@ func (s *MCPServer) RegisterSession( func (s *MCPServer) UnregisterSession( sessionID string, ) { - s.sessions.Delete(sessionID) + s.sessionizer.Delete(sessionID) } // sendNotificationToAllClients sends a notification to all the currently active clients. @@ -203,16 +203,15 @@ func (s *MCPServer) sendNotificationToAllClients( }, } - s.sessions.Range(func(k, v any) bool { - if session, ok := v.(ClientSession); ok && session.Initialized() { + for _, session := range s.sessionizer.All() { + if session.Initialized() { select { case session.NotificationChannel() <- notification: default: // TODO: log blocked channel in the future versions } } - return true - }) + } } // SendNotificationToClient sends a notification to the current client @@ -322,6 +321,12 @@ func WithInstructions(instructions string) ServerOption { } } +func WithSessionizer(sessionizer Sessionizer) ServerOption { + return func(s *MCPServer) { + s.sessionizer = sessionizer + } +} + // NewMCPServer creates a new MCP server instance with the given name, version and options func NewMCPServer( name, version string, @@ -342,6 +347,7 @@ func NewMCPServer( prompts: nil, logging: false, }, + sessionizer: &SyncMapSessionizer{}, } for _, opt := range opts { diff --git a/server/sessionizer.go b/server/sessionizer.go new file mode 100644 index 000000000..b845ece36 --- /dev/null +++ b/server/sessionizer.go @@ -0,0 +1,38 @@ +package server + +import "sync" + +type Sessionizer interface { + LoadOrStore(sessionID string, session ClientSession) (ClientSession, bool) + + Delete(sessionID string) + + All() []ClientSession +} + +type SyncMapSessionizer struct { + sessions sync.Map +} + +var _ Sessionizer = (*SyncMapSessionizer)(nil) + +func (s *SyncMapSessionizer) LoadOrStore(sessionID string, session ClientSession) (ClientSession, bool) { + actual, ok := s.sessions.LoadOrStore(sessionID, session) + if ok { + return actual.(ClientSession), true + } + return session, false +} + +func (s *SyncMapSessionizer) Delete(sessionID string) { + s.sessions.Delete(sessionID) +} + +func (s *SyncMapSessionizer) All() []ClientSession { + var sessions []ClientSession + s.sessions.Range(func(key, value any) bool { + sessions = append(sessions, value.(ClientSession)) + return true + }) + return sessions +}