diff --git a/server/inprocess_session.go b/server/inprocess_session.go index fecf0aad0..daaf28a5c 100644 --- a/server/inprocess_session.go +++ b/server/inprocess_session.go @@ -16,13 +16,14 @@ type SamplingHandler interface { } type InProcessSession struct { - sessionID string - notifications chan mcp.JSONRPCNotification - initialized atomic.Bool - loggingLevel atomic.Value - clientInfo atomic.Value - samplingHandler SamplingHandler - mu sync.RWMutex + sessionID string + notifications chan mcp.JSONRPCNotification + initialized atomic.Bool + loggingLevel atomic.Value + clientInfo atomic.Value + clientCapabilities atomic.Value + samplingHandler SamplingHandler + mu sync.RWMutex } func NewInProcessSession(sessionID string, samplingHandler SamplingHandler) *InProcessSession { @@ -63,6 +64,19 @@ func (s *InProcessSession) SetClientInfo(clientInfo mcp.Implementation) { s.clientInfo.Store(clientInfo) } +func (s *InProcessSession) GetClientCapabilities() mcp.ClientCapabilities { + if value := s.clientCapabilities.Load(); value != nil { + if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok { + return clientCapabilities + } + } + return mcp.ClientCapabilities{} +} + +func (s *InProcessSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) { + s.clientCapabilities.Store(clientCapabilities) +} + func (s *InProcessSession) SetLogLevel(level mcp.LoggingLevel) { s.loggingLevel.Store(level) } diff --git a/server/server.go b/server/server.go index ee2eb9aa4..291f9cb19 100644 --- a/server/server.go +++ b/server/server.go @@ -583,8 +583,10 @@ func (s *MCPServer) handleInitialize( // Store client info if the session supports it if sessionWithClientInfo, ok := session.(SessionWithClientInfo); ok { sessionWithClientInfo.SetClientInfo(request.Params.ClientInfo) + sessionWithClientInfo.SetClientCapabilities(request.Params.Capabilities) } } + return &result, nil } diff --git a/server/session.go b/server/session.go index 165ecaedd..11ee8a2f1 100644 --- a/server/session.go +++ b/server/session.go @@ -46,6 +46,10 @@ type SessionWithClientInfo interface { GetClientInfo() mcp.Implementation // SetClientInfo sets the client information for this session SetClientInfo(clientInfo mcp.Implementation) + // GetClientCapabilities returns the client capabilities for this session + GetClientCapabilities() mcp.ClientCapabilities + // SetClientCapabilities sets the client capabilities for this session + SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) } // SessionWithStreamableHTTPConfig extends ClientSession to support streamable HTTP transport configurations diff --git a/server/session_test.go b/server/session_test.go index 22da95714..35f9b8db2 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -106,6 +106,7 @@ type sessionTestClientWithClientInfo struct { notificationChannel chan mcp.JSONRPCNotification initialized bool clientInfo atomic.Value + clientCapabilities atomic.Value } func (f *sessionTestClientWithClientInfo) SessionID() string { @@ -137,6 +138,19 @@ func (f *sessionTestClientWithClientInfo) SetClientInfo(clientInfo mcp.Implement f.clientInfo.Store(clientInfo) } +func (f *sessionTestClientWithClientInfo) GetClientCapabilities() mcp.ClientCapabilities { + if value := f.clientCapabilities.Load(); value != nil { + if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok { + return clientCapabilities + } + } + return mcp.ClientCapabilities{} +} + +func (f *sessionTestClientWithClientInfo) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) { + f.clientCapabilities.Store(clientCapabilities) +} + // sessionTestClientWithTools implements the SessionWithLogging interface for testing type sessionTestClientWithLogging struct { sessionID string @@ -888,7 +902,7 @@ func TestMCPServer_SessionToolCapabilitiesBehavior(t *testing.T) { validateServer func(t *testing.T, s *MCPServer, session *sessionTestClientWithTools) }{ { - name: "no tool capabilities provided", + name: "no tool capabilities provided", serverOptions: []ServerOption{ // No WithToolCapabilities }, @@ -1099,10 +1113,14 @@ func TestSessionWithClientInfo_Integration(t *testing.T) { Version: "1.0.0", } + clientCapability := mcp.ClientCapabilities{ + Sampling: &struct{}{}, + } + initRequest := mcp.InitializeRequest{} initRequest.Params.ClientInfo = clientInfo initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION - initRequest.Params.Capabilities = mcp.ClientCapabilities{} + initRequest.Params.Capabilities = clientCapability sessionCtx := server.WithContext(context.Background(), session) @@ -1125,6 +1143,10 @@ func TestSessionWithClientInfo_Integration(t *testing.T) { assert.Equal(t, clientInfo.Name, storedClientInfo.Name, "Client name should match") assert.Equal(t, clientInfo.Version, storedClientInfo.Version, "Client version should match") + + storedClientCapabilities := sessionWithClientInfo.GetClientCapabilities() + + assert.Equal(t, clientCapability, storedClientCapabilities, "Client capability should match") } // New test function to cover log notification functionality diff --git a/server/sse.go b/server/sse.go index c20e67820..9c9766cf3 100644 --- a/server/sse.go +++ b/server/sse.go @@ -30,6 +30,7 @@ type sseSession struct { loggingLevel atomic.Value tools sync.Map // stores session-specific tools clientInfo atomic.Value // stores session-specific client info + clientCapabilities atomic.Value // stores session-specific client capabilities } // SSEContextFunc is a function that takes an existing context and the current @@ -108,6 +109,19 @@ func (s *sseSession) SetClientInfo(clientInfo mcp.Implementation) { s.clientInfo.Store(clientInfo) } +func (s *sseSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) { + s.clientCapabilities.Store(clientCapabilities) +} + +func (s *sseSession) GetClientCapabilities() mcp.ClientCapabilities { + if value := s.clientCapabilities.Load(); value != nil { + if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok { + return clientCapabilities + } + } + return mcp.ClientCapabilities{} +} + var ( _ ClientSession = (*sseSession)(nil) _ SessionWithTools = (*sseSession)(nil) diff --git a/server/stdio.go b/server/stdio.go index 33ac9bb88..4d567d8cb 100644 --- a/server/stdio.go +++ b/server/stdio.go @@ -52,15 +52,16 @@ func WithStdioContextFunc(fn StdioContextFunc) StdioOption { // stdioSession is a static client session, since stdio has only one client. type stdioSession struct { - notifications chan mcp.JSONRPCNotification - initialized atomic.Bool - loggingLevel atomic.Value - clientInfo atomic.Value // stores session-specific client info - writer io.Writer // for sending requests to client - requestID atomic.Int64 // for generating unique request IDs - mu sync.RWMutex // protects writer - pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests - pendingMu sync.RWMutex // protects pendingRequests + notifications chan mcp.JSONRPCNotification + initialized atomic.Bool + loggingLevel atomic.Value + clientInfo atomic.Value // stores session-specific client info + clientCapabilities atomic.Value // stores session-specific client capabilities + writer io.Writer // for sending requests to client + requestID atomic.Int64 // for generating unique request IDs + mu sync.RWMutex // protects writer + pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests + pendingMu sync.RWMutex // protects pendingRequests } // samplingResponse represents a response to a sampling request @@ -100,6 +101,19 @@ func (s *stdioSession) SetClientInfo(clientInfo mcp.Implementation) { s.clientInfo.Store(clientInfo) } +func (s *stdioSession) GetClientCapabilities() mcp.ClientCapabilities { + if value := s.clientCapabilities.Load(); value != nil { + if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok { + return clientCapabilities + } + } + return mcp.ClientCapabilities{} +} + +func (s *stdioSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) { + s.clientCapabilities.Store(clientCapabilities) +} + func (s *stdioSession) SetLogLevel(level mcp.LoggingLevel) { s.loggingLevel.Store(level) } diff --git a/www/docs/pages/servers/advanced.mdx b/www/docs/pages/servers/advanced.mdx index 5751c8220..990599b05 100644 --- a/www/docs/pages/servers/advanced.mdx +++ b/www/docs/pages/servers/advanced.mdx @@ -821,6 +821,59 @@ func startWithGracefulShutdown(s *server.MCPServer) { } ``` +## Client Capability Based Filtering + +```go +package main + +import ( + "context" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +func main() { + s := server.NewMCPServer("Typed Server", "1.0.0", + server.WithToolCapabilities(true), + ) + + s.AddTool( + mcp.NewTool("calculate", + mcp.WithDescription("Perform basic mathematical calculations"), + mcp.WithString("operation", + mcp.Required(), + mcp.Enum("add", "subtract", "multiply", "divide"), + mcp.Description("The operation to perform"), + ), + mcp.WithNumber("x", mcp.Required(), mcp.Description("First number")), + mcp.WithNumber("y", mcp.Required(), mcp.Description("Second number")), + ), + handleCalculate, + ) + + server.ServeStdio(s) +} + +func handleCalculate(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + session := server.ClientSessionFromContext(ctx) + if session == nil { + return nil, fmt.Errorf("no active session") + } + + if clientSession, ok := session.(server.SessionWithClientInfo); ok { + clientCapabilities := clientSession.GetClientCapabilities() + if clientCapabilities.Sampling == nil { + fmt.Println("sampling is not enabled in client") + } + } + + // TODO: implement calculation logic + return mcp.NewToolResultError("not implemented"), nil +} +``` + ## Sampling (Advanced) Sampling is an advanced feature that allows servers to request LLM completions from clients. This enables bidirectional communication where servers can leverage client-side LLM capabilities for content generation, reasoning, and question answering.