diff --git a/client/sse_test.go b/client/sse_test.go index 6bff34f3f..366fbc517 100644 --- a/client/sse_test.go +++ b/client/sse_test.go @@ -16,6 +16,7 @@ func TestSSEMCPClient(t *testing.T) { "1.0.0", server.WithResourceCapabilities(true, true), server.WithPromptCapabilities(true), + server.WithToolCapabilities(true), ) // Add a test tool diff --git a/server/server.go b/server/server.go index e98b4a24d..64ed9529d 100644 --- a/server/server.go +++ b/server/server.go @@ -138,6 +138,7 @@ func (s *MCPServer) SendNotificationToClient( // serverCapabilities defines the supported features of the MCP server type serverCapabilities struct { + tools *toolCapabilities resources *resourceCapabilities prompts *promptCapabilities logging bool @@ -154,6 +155,11 @@ type promptCapabilities struct { listChanged bool } +// toolCapabilities defines the supported tool-related features +type toolCapabilities struct { + listChanged bool +} + // WithResourceCapabilities configures resource-related server capabilities func WithResourceCapabilities(subscribe, listChanged bool) ServerOption { return func(s *MCPServer) { @@ -173,6 +179,15 @@ func WithPromptCapabilities(listChanged bool) ServerOption { } } +// WithToolCapabilities configures tool-related server capabilities +func WithToolCapabilities(listChanged bool) ServerOption { + return func(s *MCPServer) { + s.capabilities.tools = &toolCapabilities{ + listChanged: listChanged, + } + } +} + // WithLogging enables logging capabilities for the server func WithLogging() ServerOption { return func(s *MCPServer) { @@ -195,6 +210,12 @@ func NewMCPServer( version: version, notificationHandlers: make(map[string]NotificationHandlerFunc), notifications: make(chan ServerNotification, 100), + capabilities: serverCapabilities{ + tools: &toolCapabilities{}, + resources: &resourceCapabilities{}, + prompts: &promptCapabilities{}, + logging: false, + }, } for _, opt := range opts { @@ -304,7 +325,7 @@ func (s *MCPServer) HandleMessage( } return s.handleListResourceTemplates(ctx, baseMessage.ID, request) case "resources/read": - if s.capabilities.resources == nil { + if !s.capabilities.resources.listChanged { return createErrorResponse( baseMessage.ID, mcp.METHOD_NOT_FOUND, @@ -338,7 +359,7 @@ func (s *MCPServer) HandleMessage( } return s.handleListPrompts(ctx, baseMessage.ID, request) case "prompts/get": - if s.capabilities.prompts == nil { + if !s.capabilities.prompts.listChanged { return createErrorResponse( baseMessage.ID, mcp.METHOD_NOT_FOUND, @@ -372,7 +393,7 @@ func (s *MCPServer) HandleMessage( } return s.handleListTools(ctx, baseMessage.ID, request) case "tools/call": - if len(s.tools) == 0 { + if !s.capabilities.tools.listChanged || len(s.tools) == 0 { return createErrorResponse( baseMessage.ID, mcp.METHOD_NOT_FOUND, @@ -508,20 +529,20 @@ func (s *MCPServer) handleInitialize( Subscribe bool `json:"subscribe,omitempty"` ListChanged bool `json:"listChanged,omitempty"` }{ - Subscribe: false, - ListChanged: true, + Subscribe: s.capabilities.resources.subscribe, + ListChanged: s.capabilities.resources.listChanged, } capabilities.Prompts = &struct { ListChanged bool `json:"listChanged,omitempty"` }{ - ListChanged: true, + ListChanged: s.capabilities.prompts.listChanged, } capabilities.Tools = &struct { ListChanged bool `json:"listChanged,omitempty"` }{ - ListChanged: true, + ListChanged: s.capabilities.tools.listChanged, } if s.capabilities.logging { diff --git a/server/server_test.go b/server/server_test.go index fb4d8d822..4574456ac 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -42,11 +42,11 @@ func TestMCPServer_Capabilities(t *testing.T) { assert.Equal(t, "1.0.0", initResult.ServerInfo.Version) assert.NotNil(t, initResult.Capabilities.Resources) assert.False(t, initResult.Capabilities.Resources.Subscribe) - assert.True(t, initResult.Capabilities.Resources.ListChanged) + assert.False(t, initResult.Capabilities.Resources.ListChanged) assert.NotNil(t, initResult.Capabilities.Prompts) - assert.True(t, initResult.Capabilities.Prompts.ListChanged) + assert.False(t, initResult.Capabilities.Prompts.ListChanged) assert.NotNil(t, initResult.Capabilities.Tools) - assert.True(t, initResult.Capabilities.Tools.ListChanged) + assert.False(t, initResult.Capabilities.Tools.ListChanged) assert.Nil(t, initResult.Capabilities.Logging) }, }, @@ -55,6 +55,7 @@ func TestMCPServer_Capabilities(t *testing.T) { options: []ServerOption{ WithResourceCapabilities(true, true), WithPromptCapabilities(true), + WithToolCapabilities(true), WithLogging(), }, validate: func(t *testing.T, response mcp.JSONRPCMessage) { @@ -73,8 +74,8 @@ func TestMCPServer_Capabilities(t *testing.T) { assert.Equal(t, "1.0.0", initResult.ServerInfo.Version) assert.NotNil(t, initResult.Capabilities.Resources) - // Resources capabilities are now always false for subscribe and true for listChanged - assert.False(t, initResult.Capabilities.Resources.Subscribe) + + assert.True(t, initResult.Capabilities.Resources.Subscribe) assert.True(t, initResult.Capabilities.Resources.ListChanged) assert.NotNil(t, initResult.Capabilities.Prompts) @@ -83,6 +84,43 @@ func TestMCPServer_Capabilities(t *testing.T) { assert.NotNil(t, initResult.Capabilities.Tools) assert.True(t, initResult.Capabilities.Tools.ListChanged) + assert.NotNil(t, initResult.Capabilities.Logging) + }, + }, + { + name: "Specific capabilities", + options: []ServerOption{ + WithResourceCapabilities(true, false), + WithPromptCapabilities(true), + WithToolCapabilities(false), + WithLogging(), + }, + validate: func(t *testing.T, response mcp.JSONRPCMessage) { + resp, ok := response.(mcp.JSONRPCResponse) + assert.True(t, ok) + + initResult, ok := resp.Result.(mcp.InitializeResult) + assert.True(t, ok) + + assert.Equal( + t, + mcp.LATEST_PROTOCOL_VERSION, + initResult.ProtocolVersion, + ) + assert.Equal(t, "test-server", initResult.ServerInfo.Name) + assert.Equal(t, "1.0.0", initResult.ServerInfo.Version) + + assert.NotNil(t, initResult.Capabilities.Resources) + + assert.True(t, initResult.Capabilities.Resources.Subscribe) + assert.False(t, initResult.Capabilities.Resources.ListChanged) + + assert.NotNil(t, initResult.Capabilities.Prompts) + assert.True(t, initResult.Capabilities.Prompts.ListChanged) + + assert.NotNil(t, initResult.Capabilities.Tools) + assert.False(t, initResult.Capabilities.Tools.ListChanged) + assert.NotNil(t, initResult.Capabilities.Logging) }, }, @@ -525,6 +563,7 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) { server := NewMCPServer("test-server", "1.0.0", WithResourceCapabilities(true, true), WithPromptCapabilities(true), + WithToolCapabilities(true), ) // Add a test tool to enable tool capabilities @@ -600,14 +639,10 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) { } func TestMCPServer_HandleMethodsWithoutCapabilities(t *testing.T) { - server := NewMCPServer( - "test-server", - "1.0.0", - ) - tests := []struct { name string message string + options []ServerOption expectedErr int }{ { @@ -620,6 +655,9 @@ func TestMCPServer_HandleMethodsWithoutCapabilities(t *testing.T) { "name": "test-tool" } }`, + options: []ServerOption{ + WithToolCapabilities(false), + }, expectedErr: mcp.METHOD_NOT_FOUND, }, { @@ -632,6 +670,9 @@ func TestMCPServer_HandleMethodsWithoutCapabilities(t *testing.T) { "name": "test-prompt" } }`, + options: []ServerOption{ + WithPromptCapabilities(false), + }, expectedErr: mcp.METHOD_NOT_FOUND, }, { @@ -644,12 +685,16 @@ func TestMCPServer_HandleMethodsWithoutCapabilities(t *testing.T) { "uri": "test-resource" } }`, + options: []ServerOption{ + WithResourceCapabilities(false, false), + }, expectedErr: mcp.METHOD_NOT_FOUND, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", tt.options...) response := server.HandleMessage( context.Background(), []byte(tt.message),