From f96dacc79ef131c09b5d79c8b08d66fe8820e937 Mon Sep 17 00:00:00 2001 From: flc1125 Date: Wed, 2 Apr 2025 13:42:56 +0800 Subject: [PATCH 1/2] refactor(server): replace sync.Map with Sessionizer interface for session management --- server/server.go | 24 +++++++++++++++--------- server/server_test.go | 4 ++-- server/sessionizer.go | 38 ++++++++++++++++++++++++++++++++++++++ server/sse.go | 2 +- server/sse_test.go | 6 +++--- 5 files changed, 59 insertions(+), 15 deletions(-) create mode 100644 server/sessionizer.go diff --git a/server/server.go b/server/server.go index 800628cf0..9f02f71a1 100644 --- a/server/server.go +++ b/server/server.go @@ -147,7 +147,7 @@ type MCPServer struct { tools map[string]ServerTool notificationHandlers map[string]NotificationHandlerFunc capabilities serverCapabilities - sessions sync.Map + sessionizer Sessionizer hooks *Hooks } @@ -175,7 +175,7 @@ func (s *MCPServer) RegisterSession( session ClientSession, ) error { sessionID := session.SessionID() - if _, exists := s.sessions.LoadOrStore(sessionID, session); exists { + if _, exists := s.sessionizer.LoadOrStore(sessionID, session); exists { return fmt.Errorf("session %s is already registered", sessionID) } return nil @@ -185,7 +185,7 @@ func (s *MCPServer) RegisterSession( func (s *MCPServer) UnregisterSession( sessionID string, ) { - s.sessions.Delete(sessionID) + s.sessionizer.Delete(sessionID) } // sendNotificationToAllClients sends a notification to all the currently active clients. @@ -203,16 +203,15 @@ func (s *MCPServer) sendNotificationToAllClients( }, } - s.sessions.Range(func(k, v any) bool { - if session, ok := v.(ClientSession); ok && session.Initialized() { + for _, session := range s.sessionizer.All() { + if session.Initialized() { select { case session.NotificationChannel() <- notification: default: // TODO: log blocked channel in the future versions } } - return true - }) + } } // SendNotificationToClient sends a notification to the current client @@ -322,6 +321,12 @@ func WithInstructions(instructions string) ServerOption { } } +func WithSessionizer(sessionizer Sessionizer) ServerOption { + return func(s *MCPServer) { + s.sessionizer = sessionizer + } +} + // NewMCPServer creates a new MCP server instance with the given name, version and options func NewMCPServer( name, version string, @@ -342,6 +347,7 @@ func NewMCPServer( prompts: nil, logging: false, }, + sessionizer: &SyncMapSessionizer{}, } for _, opt := range opts { @@ -410,7 +416,7 @@ func (s *MCPServer) AddTools(tools ...ServerTool) { } s.mu.Unlock() - // Send notification to all initialized sessions + // Send notification to all initialized sessionizer s.sendNotificationToAllClients("notifications/tools/list_changed", nil) } @@ -430,7 +436,7 @@ func (s *MCPServer) DeleteTools(names ...string) { } s.mu.Unlock() - // Send notification to all initialized sessions + // Send notification to all initialized sessionizer s.sendNotificationToAllClients("notifications/tools/list_changed", nil) } diff --git a/server/server_test.go b/server/server_test.go index 7367cbe69..806d803af 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -153,7 +153,7 @@ func TestMCPServer_Tools(t *testing.T) { validate func(*testing.T, []mcp.JSONRPCNotification, mcp.JSONRPCMessage) }{ { - name: "SetTools sends no notifications/tools/list_changed without active sessions", + name: "SetTools sends no notifications/tools/list_changed without active sessionizer", action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { server.SetTools(ServerTool{ Tool: mcp.NewTool("test-tool-1"), @@ -216,7 +216,7 @@ func TestMCPServer_Tools(t *testing.T) { }) require.NoError(t, err) } - // also let's register inactive sessions + // also let's register inactive sessionizer for i := range 5 { err := server.RegisterSession(&fakeSession{ sessionID: fmt.Sprintf("test%d", i+5), diff --git a/server/sessionizer.go b/server/sessionizer.go new file mode 100644 index 000000000..b845ece36 --- /dev/null +++ b/server/sessionizer.go @@ -0,0 +1,38 @@ +package server + +import "sync" + +type Sessionizer interface { + LoadOrStore(sessionID string, session ClientSession) (ClientSession, bool) + + Delete(sessionID string) + + All() []ClientSession +} + +type SyncMapSessionizer struct { + sessions sync.Map +} + +var _ Sessionizer = (*SyncMapSessionizer)(nil) + +func (s *SyncMapSessionizer) LoadOrStore(sessionID string, session ClientSession) (ClientSession, bool) { + actual, ok := s.sessions.LoadOrStore(sessionID, session) + if ok { + return actual.(ClientSession), true + } + return session, false +} + +func (s *SyncMapSessionizer) Delete(sessionID string) { + s.sessions.Delete(sessionID) +} + +func (s *SyncMapSessionizer) All() []ClientSession { + var sessions []ClientSession + s.sessions.Range(func(key, value any) bool { + sessions = append(sessions, value.(ClientSession)) + return true + }) + return sessions +} diff --git a/server/sse.go b/server/sse.go index a869ad1aa..140b439fa 100644 --- a/server/sse.go +++ b/server/sse.go @@ -167,7 +167,7 @@ func (s *SSEServer) Start(addr string) error { return s.srv.ListenAndServe() } -// Shutdown gracefully stops the SSE server, closing all active sessions +// Shutdown gracefully stops the SSE server, closing all active sessionizer // and shutting down the HTTP server. func (s *SSEServer) Shutdown(ctx context.Context) error { if s.srv != nil { diff --git a/server/sse_test.go b/server/sse_test.go index ced76e4ee..d6d4a4723 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -122,7 +122,7 @@ func TestSSEServer(t *testing.T) { } }) - t.Run("Can handle multiple sessions", func(t *testing.T) { + t.Run("Can handle multiple sessionizer", func(t *testing.T) { mcpServer := NewMCPServer("test", "1.0.0", WithResourceCapabilities(true, true), ) @@ -238,9 +238,9 @@ func TestSSEServer(t *testing.T) { select { case <-done: - // All sessions completed successfully + // All sessionizer completed successfully case <-time.After(5 * time.Second): - t.Fatal("Timeout waiting for sessions to complete") + t.Fatal("Timeout waiting for sessionizer to complete") } }) From 7f764d72d6babf5dea79941c9945be2a03548c22 Mon Sep 17 00:00:00 2001 From: flc1125 Date: Wed, 2 Apr 2025 13:45:38 +0800 Subject: [PATCH 2/2] refactor(server): update terminology from sessionizer to sessions for clarity --- server/server.go | 4 ++-- server/server_test.go | 4 ++-- server/sse.go | 2 +- server/sse_test.go | 6 +++--- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/server/server.go b/server/server.go index 9f02f71a1..799d1f84b 100644 --- a/server/server.go +++ b/server/server.go @@ -416,7 +416,7 @@ func (s *MCPServer) AddTools(tools ...ServerTool) { } s.mu.Unlock() - // Send notification to all initialized sessionizer + // Send notification to all initialized sessions s.sendNotificationToAllClients("notifications/tools/list_changed", nil) } @@ -436,7 +436,7 @@ func (s *MCPServer) DeleteTools(names ...string) { } s.mu.Unlock() - // Send notification to all initialized sessionizer + // Send notification to all initialized sessions s.sendNotificationToAllClients("notifications/tools/list_changed", nil) } diff --git a/server/server_test.go b/server/server_test.go index 806d803af..7367cbe69 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -153,7 +153,7 @@ func TestMCPServer_Tools(t *testing.T) { validate func(*testing.T, []mcp.JSONRPCNotification, mcp.JSONRPCMessage) }{ { - name: "SetTools sends no notifications/tools/list_changed without active sessionizer", + name: "SetTools sends no notifications/tools/list_changed without active sessions", action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { server.SetTools(ServerTool{ Tool: mcp.NewTool("test-tool-1"), @@ -216,7 +216,7 @@ func TestMCPServer_Tools(t *testing.T) { }) require.NoError(t, err) } - // also let's register inactive sessionizer + // also let's register inactive sessions for i := range 5 { err := server.RegisterSession(&fakeSession{ sessionID: fmt.Sprintf("test%d", i+5), diff --git a/server/sse.go b/server/sse.go index 140b439fa..a869ad1aa 100644 --- a/server/sse.go +++ b/server/sse.go @@ -167,7 +167,7 @@ func (s *SSEServer) Start(addr string) error { return s.srv.ListenAndServe() } -// Shutdown gracefully stops the SSE server, closing all active sessionizer +// Shutdown gracefully stops the SSE server, closing all active sessions // and shutting down the HTTP server. func (s *SSEServer) Shutdown(ctx context.Context) error { if s.srv != nil { diff --git a/server/sse_test.go b/server/sse_test.go index d6d4a4723..ced76e4ee 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -122,7 +122,7 @@ func TestSSEServer(t *testing.T) { } }) - t.Run("Can handle multiple sessionizer", func(t *testing.T) { + t.Run("Can handle multiple sessions", func(t *testing.T) { mcpServer := NewMCPServer("test", "1.0.0", WithResourceCapabilities(true, true), ) @@ -238,9 +238,9 @@ func TestSSEServer(t *testing.T) { select { case <-done: - // All sessionizer completed successfully + // All sessions completed successfully case <-time.After(5 * time.Second): - t.Fatal("Timeout waiting for sessionizer to complete") + t.Fatal("Timeout waiting for sessions to complete") } })