Skip to content

Commit 67c327e

Browse files
committed
Add per-session notifications handling
1 parent 258aee9 commit 67c327e

File tree

5 files changed

+332
-103
lines changed

5 files changed

+332
-103
lines changed

examples/everything/main.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@ func handleSendNotification(
300300
server := server.ServerFromContext(ctx)
301301

302302
err := server.SendNotificationToClient(
303+
ctx,
303304
"notifications/progress",
304305
map[string]interface{}{
305306
"progress": 10,
@@ -336,6 +337,7 @@ func handleLongRunningOperationTool(
336337
time.Sleep(time.Duration(stepDuration * float64(time.Second)))
337338
if progressToken != nil {
338339
server.SendNotificationToClient(
340+
ctx,
339341
"notifications/progress",
340342
map[string]interface{}{
341343
"progress": i,

server/server.go

Lines changed: 71 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,23 @@ type ServerTool struct {
4646
Handler ToolHandlerFunc
4747
}
4848

49-
// NotificationContext provides client identification for notifications
50-
type NotificationContext struct {
51-
ClientID string
52-
SessionID string
49+
// ClientSession represents an active session that can be used by MCPServer to interact with client.
50+
type ClientSession interface {
51+
// NotificationChannel provides a channel suitable for sending notifications to client.
52+
NotificationChannel() chan<- mcp.JSONRPCNotification
53+
// SessionID is a unique identifier used to track user session.
54+
SessionID() string
5355
}
5456

55-
// ServerNotification combines the notification with client context
56-
type ServerNotification struct {
57-
Context NotificationContext
58-
Notification mcp.JSONRPCNotification
57+
// clientSessionKey is the context key for storing current client notification channel.
58+
type clientSessionKey struct{}
59+
60+
// ClientSessionFromContext retrieves current client notification context from context.
61+
func ClientSessionFromContext(ctx context.Context) ClientSession {
62+
if session, ok := ctx.Value(clientSessionKey{}).(ClientSession); ok {
63+
return session
64+
}
65+
return nil
5966
}
6067

6168
// NotificationHandlerFunc handles incoming notifications.
@@ -74,9 +81,7 @@ type MCPServer struct {
7481
tools map[string]ServerTool
7582
notificationHandlers map[string]NotificationHandlerFunc
7683
capabilities serverCapabilities
77-
notifications chan ServerNotification
78-
clientMu sync.Mutex // Separate mutex for client context
79-
currentClient NotificationContext
84+
sessions sync.Map
8085
initialized atomic.Bool // Use atomic for the initialized flag
8186
}
8287

@@ -91,30 +96,70 @@ func ServerFromContext(ctx context.Context) *MCPServer {
9196
return nil
9297
}
9398

94-
// WithContext sets the current client context and returns the provided context
99+
// WithContext sets the current client session and returns the provided context
95100
func (s *MCPServer) WithContext(
96101
ctx context.Context,
97-
notifCtx NotificationContext,
102+
session ClientSession,
98103
) context.Context {
99-
s.clientMu.Lock()
100-
s.currentClient = notifCtx
101-
s.clientMu.Unlock()
102-
return ctx
104+
return context.WithValue(ctx, clientSessionKey{}, session)
105+
}
106+
107+
// RegisterSession saves session that should be notified in case if some server attributes changed.
108+
func (s *MCPServer) RegisterSession(
109+
session ClientSession,
110+
) error {
111+
sessionID := session.SessionID()
112+
if _, exists := s.sessions.LoadOrStore(sessionID, session); exists {
113+
return fmt.Errorf("session %s is already registered", sessionID)
114+
}
115+
return nil
116+
}
117+
118+
// UnregisterSession removes from storage session that is shut down.
119+
func (s *MCPServer) UnregisterSession(
120+
sessionID string,
121+
) {
122+
s.sessions.Delete(sessionID)
123+
}
124+
125+
// sendNotificationToAllClients sends a notification to all the currently active clients.
126+
func (s *MCPServer) sendNotificationToAllClients(
127+
method string,
128+
params map[string]any,
129+
) {
130+
notification := mcp.JSONRPCNotification{
131+
JSONRPC: mcp.JSONRPC_VERSION,
132+
Notification: mcp.Notification{
133+
Method: method,
134+
Params: mcp.NotificationParams{
135+
AdditionalFields: params,
136+
},
137+
},
138+
}
139+
140+
s.sessions.Range(func(k, v any) bool {
141+
if session, ok := v.(ClientSession); ok {
142+
select {
143+
case session.NotificationChannel() <- notification:
144+
default:
145+
// TODO: log blocked channel in the future versions
146+
}
147+
}
148+
return true
149+
})
103150
}
104151

105152
// SendNotificationToClient sends a notification to the current client
106153
func (s *MCPServer) SendNotificationToClient(
154+
ctx context.Context,
107155
method string,
108-
params map[string]interface{},
156+
params map[string]any,
109157
) error {
110-
if s.notifications == nil {
158+
session := ClientSessionFromContext(ctx)
159+
if session == nil {
111160
return fmt.Errorf("notification channel not initialized")
112161
}
113162

114-
s.clientMu.Lock()
115-
clientContext := s.currentClient
116-
s.clientMu.Unlock()
117-
118163
notification := mcp.JSONRPCNotification{
119164
JSONRPC: mcp.JSONRPC_VERSION,
120165
Notification: mcp.Notification{
@@ -126,10 +171,7 @@ func (s *MCPServer) SendNotificationToClient(
126171
}
127172

128173
select {
129-
case s.notifications <- ServerNotification{
130-
Context: clientContext,
131-
Notification: notification,
132-
}:
174+
case session.NotificationChannel() <- notification:
133175
return nil
134176
default:
135177
return fmt.Errorf("notification channel full or blocked")
@@ -212,7 +254,6 @@ func NewMCPServer(
212254
name: name,
213255
version: version,
214256
notificationHandlers: make(map[string]NotificationHandlerFunc),
215-
notifications: make(chan ServerNotification, 100),
216257
capabilities: serverCapabilities{
217258
tools: nil,
218259
resources: nil,
@@ -483,9 +524,7 @@ func (s *MCPServer) AddTools(tools ...ServerTool) {
483524

484525
// Send notification if server is already initialized
485526
if initialized {
486-
if err := s.SendNotificationToClient("notifications/tools/list_changed", nil); err != nil {
487-
// We can't return the error, but in a future version we could log it
488-
}
527+
s.sendNotificationToAllClients("notifications/tools/list_changed", nil)
489528
}
490529
}
491530

@@ -508,9 +547,7 @@ func (s *MCPServer) DeleteTools(names ...string) {
508547

509548
// Send notification if server is already initialized
510549
if initialized {
511-
if err := s.SendNotificationToClient("notifications/tools/list_changed", nil); err != nil {
512-
// We can't return the error, but in a future version we could log it
513-
}
550+
s.sendNotificationToAllClients("notifications/tools/list_changed", nil)
514551
}
515552
}
516553

0 commit comments

Comments
 (0)