diff --git a/server/server.go b/server/server.go index 5b2d739dc..1a55ea6da 100644 --- a/server/server.go +++ b/server/server.go @@ -211,8 +211,8 @@ func (s *MCPServer) UnregisterSession( s.sessions.Delete(sessionID) } -// sendNotificationToAllClients sends a notification to all the currently active clients. -func (s *MCPServer) sendNotificationToAllClients( +// SendNotificationToAllClients sends a notification to all the currently active clients. +func (s *MCPServer) SendNotificationToAllClients( method string, params map[string]any, ) { @@ -472,7 +472,7 @@ func (s *MCPServer) AddTools(tools ...ServerTool) { s.toolsMu.Unlock() // Send notification to all initialized sessions - s.sendNotificationToAllClients("notifications/tools/list_changed", nil) + s.SendNotificationToAllClients("notifications/tools/list_changed", nil) } // SetTools replaces all existing tools with the provided list @@ -492,7 +492,7 @@ func (s *MCPServer) DeleteTools(names ...string) { s.toolsMu.Unlock() // Send notification to all initialized sessions - s.sendNotificationToAllClients("notifications/tools/list_changed", nil) + s.SendNotificationToAllClients("notifications/tools/list_changed", nil) } // AddNotificationHandler registers a new handler for incoming notifications diff --git a/server/server_test.go b/server/server_test.go index e55008f1b..859d5c526 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -573,6 +573,75 @@ func TestMCPServer_SendNotificationToClient(t *testing.T) { } } +func TestMCPServer_SendNotificationToAllClients(t *testing.T) { + + contextPrepare := func(ctx context.Context, srv *MCPServer) context.Context { + // Create 5 active sessions + for i := 0; i < 5; i++ { + err := srv.RegisterSession(ctx, &fakeSession{ + sessionID: fmt.Sprintf("test%d", i), + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + initialized: true, + }) + require.NoError(t, err) + } + return ctx + } + + validate := func(t *testing.T, ctx context.Context, srv *MCPServer) { + // Send 10 notifications to all sessions + for i := 0; i < 10; i++ { + srv.SendNotificationToAllClients("method", map[string]any{ + "count": i, + }) + } + + // Verify each session received all 10 notifications + srv.sessions.Range(func(k, v any) bool { + session := v.(ClientSession) + fakeSess := session.(*fakeSession) + notificationCount := 0 + + // Read all notifications from the channel + for notificationCount < 10 { + select { + case notification := <-fakeSess.notificationChannel: + // Verify notification method + assert.Equal(t, "method", notification.Method) + // Verify count parameter + count, ok := notification.Params.AdditionalFields["count"] + assert.True(t, ok, "count parameter not found") + assert.Equal(t, notificationCount, count.(int), "count should match notification count") + notificationCount++ + case <-time.After(100 * time.Millisecond): + t.Errorf("timeout waiting for notification %d for session %s", notificationCount, session.SessionID()) + return false + } + } + + // Verify no more notifications + select { + case notification := <-fakeSess.notificationChannel: + t.Errorf("unexpected notification received: %v", notification) + default: + // Channel empty as expected + } + return true + }) + } + + t.Run("all sessions", func(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0") + ctx := contextPrepare(context.Background(), server) + _ = server.HandleMessage(ctx, []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize" + }`)) + validate(t, ctx, server) + }) +} + func TestMCPServer_PromptHandling(t *testing.T) { server := NewMCPServer("test-server", "1.0.0", WithPromptCapabilities(true),