From 6eb1b27dabeeadf91a712e3812d06f99f2ff47bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arda=20G=C3=BC=C3=A7l=C3=BC?= Date: Wed, 16 Jul 2025 14:21:59 +0300 Subject: [PATCH 1/4] Embed client capabilities into the Session Client sends its capabilities during the initialization step. This commit embeds the client capabilities into the client session in this step to enable subsequent executions are able to check the client capabilities to determine what actions they can perform. For instance, MCP Server checks the support of elicitation. If elicititation is supported by the client MCP Server can send elicitation request. --- server/inprocess_session.go | 28 +++++++++++++++++++++------- server/server.go | 2 ++ server/session.go | 4 ++++ server/session_test.go | 26 ++++++++++++++++++++++++-- server/sse.go | 14 ++++++++++++++ server/stdio.go | 32 +++++++++++++++++++++++--------- 6 files changed, 88 insertions(+), 18 deletions(-) diff --git a/server/inprocess_session.go b/server/inprocess_session.go index fecf0aad0..daaf28a5c 100644 --- a/server/inprocess_session.go +++ b/server/inprocess_session.go @@ -16,13 +16,14 @@ type SamplingHandler interface { } type InProcessSession struct { - sessionID string - notifications chan mcp.JSONRPCNotification - initialized atomic.Bool - loggingLevel atomic.Value - clientInfo atomic.Value - samplingHandler SamplingHandler - mu sync.RWMutex + sessionID string + notifications chan mcp.JSONRPCNotification + initialized atomic.Bool + loggingLevel atomic.Value + clientInfo atomic.Value + clientCapabilities atomic.Value + samplingHandler SamplingHandler + mu sync.RWMutex } func NewInProcessSession(sessionID string, samplingHandler SamplingHandler) *InProcessSession { @@ -63,6 +64,19 @@ func (s *InProcessSession) SetClientInfo(clientInfo mcp.Implementation) { s.clientInfo.Store(clientInfo) } +func (s *InProcessSession) GetClientCapabilities() mcp.ClientCapabilities { + if value := s.clientCapabilities.Load(); value != nil { + if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok { + return clientCapabilities + } + } + return mcp.ClientCapabilities{} +} + +func (s *InProcessSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) { + s.clientCapabilities.Store(clientCapabilities) +} + func (s *InProcessSession) SetLogLevel(level mcp.LoggingLevel) { s.loggingLevel.Store(level) } diff --git a/server/server.go b/server/server.go index ee2eb9aa4..291f9cb19 100644 --- a/server/server.go +++ b/server/server.go @@ -583,8 +583,10 @@ func (s *MCPServer) handleInitialize( // Store client info if the session supports it if sessionWithClientInfo, ok := session.(SessionWithClientInfo); ok { sessionWithClientInfo.SetClientInfo(request.Params.ClientInfo) + sessionWithClientInfo.SetClientCapabilities(request.Params.Capabilities) } } + return &result, nil } diff --git a/server/session.go b/server/session.go index 165ecaedd..11ee8a2f1 100644 --- a/server/session.go +++ b/server/session.go @@ -46,6 +46,10 @@ type SessionWithClientInfo interface { GetClientInfo() mcp.Implementation // SetClientInfo sets the client information for this session SetClientInfo(clientInfo mcp.Implementation) + // GetClientCapabilities returns the client capabilities for this session + GetClientCapabilities() mcp.ClientCapabilities + // SetClientCapabilities sets the client capabilities for this session + SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) } // SessionWithStreamableHTTPConfig extends ClientSession to support streamable HTTP transport configurations diff --git a/server/session_test.go b/server/session_test.go index 22da95714..35f9b8db2 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -106,6 +106,7 @@ type sessionTestClientWithClientInfo struct { notificationChannel chan mcp.JSONRPCNotification initialized bool clientInfo atomic.Value + clientCapabilities atomic.Value } func (f *sessionTestClientWithClientInfo) SessionID() string { @@ -137,6 +138,19 @@ func (f *sessionTestClientWithClientInfo) SetClientInfo(clientInfo mcp.Implement f.clientInfo.Store(clientInfo) } +func (f *sessionTestClientWithClientInfo) GetClientCapabilities() mcp.ClientCapabilities { + if value := f.clientCapabilities.Load(); value != nil { + if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok { + return clientCapabilities + } + } + return mcp.ClientCapabilities{} +} + +func (f *sessionTestClientWithClientInfo) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) { + f.clientCapabilities.Store(clientCapabilities) +} + // sessionTestClientWithTools implements the SessionWithLogging interface for testing type sessionTestClientWithLogging struct { sessionID string @@ -888,7 +902,7 @@ func TestMCPServer_SessionToolCapabilitiesBehavior(t *testing.T) { validateServer func(t *testing.T, s *MCPServer, session *sessionTestClientWithTools) }{ { - name: "no tool capabilities provided", + name: "no tool capabilities provided", serverOptions: []ServerOption{ // No WithToolCapabilities }, @@ -1099,10 +1113,14 @@ func TestSessionWithClientInfo_Integration(t *testing.T) { Version: "1.0.0", } + clientCapability := mcp.ClientCapabilities{ + Sampling: &struct{}{}, + } + initRequest := mcp.InitializeRequest{} initRequest.Params.ClientInfo = clientInfo initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION - initRequest.Params.Capabilities = mcp.ClientCapabilities{} + initRequest.Params.Capabilities = clientCapability sessionCtx := server.WithContext(context.Background(), session) @@ -1125,6 +1143,10 @@ func TestSessionWithClientInfo_Integration(t *testing.T) { assert.Equal(t, clientInfo.Name, storedClientInfo.Name, "Client name should match") assert.Equal(t, clientInfo.Version, storedClientInfo.Version, "Client version should match") + + storedClientCapabilities := sessionWithClientInfo.GetClientCapabilities() + + assert.Equal(t, clientCapability, storedClientCapabilities, "Client capability should match") } // New test function to cover log notification functionality diff --git a/server/sse.go b/server/sse.go index c20e67820..ac4089ced 100644 --- a/server/sse.go +++ b/server/sse.go @@ -30,6 +30,7 @@ type sseSession struct { loggingLevel atomic.Value tools sync.Map // stores session-specific tools clientInfo atomic.Value // stores session-specific client info + clientCapabilities atomic.Value // stores session-specific client capaabilities } // SSEContextFunc is a function that takes an existing context and the current @@ -108,6 +109,19 @@ func (s *sseSession) SetClientInfo(clientInfo mcp.Implementation) { s.clientInfo.Store(clientInfo) } +func (s *sseSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) { + s.clientCapabilities.Store(clientCapabilities) +} + +func (s *sseSession) GetClientCapabilities() mcp.ClientCapabilities { + if value := s.clientCapabilities.Load(); value != nil { + if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok { + return clientCapabilities + } + } + return mcp.ClientCapabilities{} +} + var ( _ ClientSession = (*sseSession)(nil) _ SessionWithTools = (*sseSession)(nil) diff --git a/server/stdio.go b/server/stdio.go index 33ac9bb88..7646ae1ad 100644 --- a/server/stdio.go +++ b/server/stdio.go @@ -52,15 +52,16 @@ func WithStdioContextFunc(fn StdioContextFunc) StdioOption { // stdioSession is a static client session, since stdio has only one client. type stdioSession struct { - notifications chan mcp.JSONRPCNotification - initialized atomic.Bool - loggingLevel atomic.Value - clientInfo atomic.Value // stores session-specific client info - writer io.Writer // for sending requests to client - requestID atomic.Int64 // for generating unique request IDs - mu sync.RWMutex // protects writer - pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests - pendingMu sync.RWMutex // protects pendingRequests + notifications chan mcp.JSONRPCNotification + initialized atomic.Bool + loggingLevel atomic.Value + clientInfo atomic.Value // stores session-specific client info + clientCapabilities atomic.Value // stores session-specific client capabilities + writer io.Writer // for sending requests to client + requestID atomic.Int64 // for generating unique request IDs + mu sync.RWMutex // protects writer + pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests + pendingMu sync.RWMutex // protects pendingRequests } // samplingResponse represents a response to a sampling request @@ -100,6 +101,19 @@ func (s *stdioSession) SetClientInfo(clientInfo mcp.Implementation) { s.clientInfo.Store(clientInfo) } +func (s *stdioSession) GetClientCapabilities() mcp.ClientCapabilities { + if value := s.clientCapabilities.Load(); value != nil { + if clientInfo, ok := value.(mcp.ClientCapabilities); ok { + return clientInfo + } + } + return mcp.ClientCapabilities{} +} + +func (s *stdioSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) { + s.clientCapabilities.Store(clientCapabilities) +} + func (s *stdioSession) SetLogLevel(level mcp.LoggingLevel) { s.loggingLevel.Store(level) } From bd225957d3b705b1376fb5cbadb45fd8f12d3cc2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arda=20G=C3=BC=C3=A7l=C3=BC?= Date: Wed, 16 Jul 2025 14:31:32 +0300 Subject: [PATCH 2/4] fix --- server/sse.go | 2 +- server/stdio.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/server/sse.go b/server/sse.go index ac4089ced..9c9766cf3 100644 --- a/server/sse.go +++ b/server/sse.go @@ -30,7 +30,7 @@ type sseSession struct { loggingLevel atomic.Value tools sync.Map // stores session-specific tools clientInfo atomic.Value // stores session-specific client info - clientCapabilities atomic.Value // stores session-specific client capaabilities + clientCapabilities atomic.Value // stores session-specific client capabilities } // SSEContextFunc is a function that takes an existing context and the current diff --git a/server/stdio.go b/server/stdio.go index 7646ae1ad..4d567d8cb 100644 --- a/server/stdio.go +++ b/server/stdio.go @@ -103,8 +103,8 @@ func (s *stdioSession) SetClientInfo(clientInfo mcp.Implementation) { func (s *stdioSession) GetClientCapabilities() mcp.ClientCapabilities { if value := s.clientCapabilities.Load(); value != nil { - if clientInfo, ok := value.(mcp.ClientCapabilities); ok { - return clientInfo + if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok { + return clientCapabilities } } return mcp.ClientCapabilities{} From 64b8730b2b391948c39307066d19eac638a978ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arda=20G=C3=BC=C3=A7l=C3=BC?= Date: Wed, 23 Jul 2025 09:06:28 +0300 Subject: [PATCH 3/4] Add document for client capability based filtering --- www/docs/pages/servers/advanced.mdx | 50 +++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/www/docs/pages/servers/advanced.mdx b/www/docs/pages/servers/advanced.mdx index 5751c8220..64a4d9353 100644 --- a/www/docs/pages/servers/advanced.mdx +++ b/www/docs/pages/servers/advanced.mdx @@ -821,6 +821,56 @@ func startWithGracefulShutdown(s *server.MCPServer) { } ``` +## Client Capability Based Filtering + +```go +package main + +import ( + "context" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +func main() { + s := server.NewMCPServer("Typed Server", "1.0.0", + server.WithToolCapabilities(true), + ) + + s.AddTool( + mcp.NewTool("calculate", + mcp.WithDescription("Perform basic mathematical calculations"), + mcp.WithString("operation", + mcp.Required(), + mcp.Enum("add", "subtract", "multiply", "divide"), + mcp.Description("The operation to perform"), + ), + mcp.WithNumber("x", mcp.Required(), mcp.Description("First number")), + mcp.WithNumber("y", mcp.Required(), mcp.Description("Second number")), + ), + handleCalculate, + ) + + server.ServeStdio(s) +} + +func handleCalculate(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + session := server.ClientSessionFromContext(ctx) + if session == nil { + return nil, fmt.Errorf("no active session") + } + + if clientSession, ok := session.(server.SessionWithClientInfo); ok { + clientCapabilities := clientSession.GetClientCapabilities() + if clientCapabilities.Sampling == nil { + fmt.Println("sampling is not enabled in client") + } + } +} +``` + ## Sampling (Advanced) Sampling is an advanced feature that allows servers to request LLM completions from clients. This enables bidirectional communication where servers can leverage client-side LLM capabilities for content generation, reasoning, and question answering. From a3b88281d6cf1a84bac4d0d814713350f9379382 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arda=20G=C3=BC=C3=A7l=C3=BC?= Date: Wed, 23 Jul 2025 09:11:42 +0300 Subject: [PATCH 4/4] Add return statement into the new example --- www/docs/pages/servers/advanced.mdx | 3 +++ 1 file changed, 3 insertions(+) diff --git a/www/docs/pages/servers/advanced.mdx b/www/docs/pages/servers/advanced.mdx index 64a4d9353..990599b05 100644 --- a/www/docs/pages/servers/advanced.mdx +++ b/www/docs/pages/servers/advanced.mdx @@ -868,6 +868,9 @@ func handleCalculate(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToo fmt.Println("sampling is not enabled in client") } } + + // TODO: implement calculation logic + return mcp.NewToolResultError("not implemented"), nil } ```