Skip to content

Add per-session notifications handling #46

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/everything/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ func handleSendNotification(
server := server.ServerFromContext(ctx)

err := server.SendNotificationToClient(
ctx,
"notifications/progress",
map[string]interface{}{
"progress": 10,
Expand Down Expand Up @@ -336,6 +337,7 @@ func handleLongRunningOperationTool(
time.Sleep(time.Duration(stepDuration * float64(time.Second)))
if progressToken != nil {
server.SendNotificationToClient(
ctx,
"notifications/progress",
map[string]interface{}{
"progress": i,
Expand Down
105 changes: 71 additions & 34 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,23 @@ type ServerTool struct {
Handler ToolHandlerFunc
}

// NotificationContext provides client identification for notifications
type NotificationContext struct {
ClientID string
SessionID string
// ClientSession represents an active session that can be used by MCPServer to interact with client.
type ClientSession interface {
// NotificationChannel provides a channel suitable for sending notifications to client.
NotificationChannel() chan<- mcp.JSONRPCNotification
// SessionID is a unique identifier used to track user session.
SessionID() string
}

// ServerNotification combines the notification with client context
type ServerNotification struct {
Context NotificationContext
Notification mcp.JSONRPCNotification
// clientSessionKey is the context key for storing current client notification channel.
type clientSessionKey struct{}

// ClientSessionFromContext retrieves current client notification context from context.
func ClientSessionFromContext(ctx context.Context) ClientSession {
if session, ok := ctx.Value(clientSessionKey{}).(ClientSession); ok {
return session
}
return nil
}

// NotificationHandlerFunc handles incoming notifications.
Expand All @@ -75,9 +82,7 @@ type MCPServer struct {
tools map[string]ServerTool
notificationHandlers map[string]NotificationHandlerFunc
capabilities serverCapabilities
notifications chan ServerNotification
clientMu sync.Mutex // Separate mutex for client context
currentClient NotificationContext
sessions sync.Map
initialized atomic.Bool // Use atomic for the initialized flag
}

Expand All @@ -92,30 +97,70 @@ func ServerFromContext(ctx context.Context) *MCPServer {
return nil
}

// WithContext sets the current client context and returns the provided context
// WithContext sets the current client session and returns the provided context
func (s *MCPServer) WithContext(
ctx context.Context,
notifCtx NotificationContext,
session ClientSession,
) context.Context {
s.clientMu.Lock()
s.currentClient = notifCtx
s.clientMu.Unlock()
return ctx
return context.WithValue(ctx, clientSessionKey{}, session)
}

// RegisterSession saves session that should be notified in case if some server attributes changed.
func (s *MCPServer) RegisterSession(
session ClientSession,
) error {
sessionID := session.SessionID()
if _, exists := s.sessions.LoadOrStore(sessionID, session); exists {
return fmt.Errorf("session %s is already registered", sessionID)
}
return nil
}

// UnregisterSession removes from storage session that is shut down.
func (s *MCPServer) UnregisterSession(
sessionID string,
) {
s.sessions.Delete(sessionID)
}

// sendNotificationToAllClients sends a notification to all the currently active clients.
func (s *MCPServer) sendNotificationToAllClients(
method string,
params map[string]any,
) {
notification := mcp.JSONRPCNotification{
JSONRPC: mcp.JSONRPC_VERSION,
Notification: mcp.Notification{
Method: method,
Params: mcp.NotificationParams{
AdditionalFields: params,
},
},
}

s.sessions.Range(func(k, v any) bool {
if session, ok := v.(ClientSession); ok {
select {
case session.NotificationChannel() <- notification:
default:
// TODO: log blocked channel in the future versions
}
}
return true
})
}

// SendNotificationToClient sends a notification to the current client
func (s *MCPServer) SendNotificationToClient(
ctx context.Context,
method string,
params map[string]interface{},
params map[string]any,
) error {
if s.notifications == nil {
session := ClientSessionFromContext(ctx)
if session == nil {
return fmt.Errorf("notification channel not initialized")
}

s.clientMu.Lock()
clientContext := s.currentClient
s.clientMu.Unlock()

notification := mcp.JSONRPCNotification{
JSONRPC: mcp.JSONRPC_VERSION,
Notification: mcp.Notification{
Expand All @@ -127,10 +172,7 @@ func (s *MCPServer) SendNotificationToClient(
}

select {
case s.notifications <- ServerNotification{
Context: clientContext,
Notification: notification,
}:
case session.NotificationChannel() <- notification:
return nil
default:
return fmt.Errorf("notification channel full or blocked")
Expand Down Expand Up @@ -220,7 +262,6 @@ func NewMCPServer(
name: name,
version: version,
notificationHandlers: make(map[string]NotificationHandlerFunc),
notifications: make(chan ServerNotification, 100),
capabilities: serverCapabilities{
tools: nil,
resources: nil,
Expand Down Expand Up @@ -491,9 +532,7 @@ func (s *MCPServer) AddTools(tools ...ServerTool) {

// Send notification if server is already initialized
if initialized {
if err := s.SendNotificationToClient("notifications/tools/list_changed", nil); err != nil {
// We can't return the error, but in a future version we could log it
}
s.sendNotificationToAllClients("notifications/tools/list_changed", nil)
}
}

Expand All @@ -516,9 +555,7 @@ func (s *MCPServer) DeleteTools(names ...string) {

// Send notification if server is already initialized
if initialized {
if err := s.SendNotificationToClient("notifications/tools/list_changed", nil); err != nil {
// We can't return the error, but in a future version we could log it
}
s.sendNotificationToAllClients("notifications/tools/list_changed", nil)
}
}

Expand Down
Loading