diff --git a/server/server.go b/server/server.go index 46e6d9c57..d40ba4b06 100644 --- a/server/server.go +++ b/server/server.go @@ -11,6 +11,7 @@ import ( "sync" "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/util" ) // resourceEntry holds both a resource and its handler @@ -160,6 +161,7 @@ type MCPServer struct { paginationLimit *int sessions sync.Map hooks *Hooks + logger util.Logger } // WithPaginationLimit sets the pagination limit for the server. @@ -281,6 +283,12 @@ func WithLogging() ServerOption { } } +func WithMCPLogger(logger util.Logger) ServerOption { + return func(s *MCPServer) { + s.logger = logger + } +} + // WithInstructions sets the server instructions for the client returned in the initialize response func WithInstructions(instructions string) ServerOption { return func(s *MCPServer) { @@ -308,6 +316,7 @@ func NewMCPServer( prompts: nil, logging: nil, }, + logger: util.DefaultLogger(), } for _, opt := range opts { @@ -474,6 +483,9 @@ func (s *MCPServer) AddTools(tools ...ServerTool) { s.toolsMu.Lock() for _, entry := range tools { + if _, exists := s.tools[entry.Tool.Name]; exists { + s.logger.Infof("Tool with name %q already exists. Overwriting.", entry.Tool.Name) + } s.tools[entry.Tool.Name] = entry } s.toolsMu.Unlock() diff --git a/server/server_race_test.go b/server/server_race_test.go index 4e0be43a8..35e334894 100644 --- a/server/server_race_test.go +++ b/server/server_race_test.go @@ -3,6 +3,7 @@ package server import ( "context" "fmt" + "math/rand" "sync" "testing" "time" @@ -56,7 +57,7 @@ func TestRaceConditions(t *testing.T) { }) runConcurrentOperation(&wg, testDuration, "add-tools", func() { - name := fmt.Sprintf("tool-%d", time.Now().UnixNano()) + name := fmt.Sprintf("tool-%d-%d", time.Now().UnixNano(), rand.Int()) srv.AddTool(mcp.Tool{ Name: name, Description: "Test tool", diff --git a/server/server_test.go b/server/server_test.go index 1c81d18dd..87c8fe827 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1,6 +1,7 @@ package server import ( + "bytes" "context" "encoding/base64" "encoding/json" @@ -12,6 +13,7 @@ import ( "time" "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -342,11 +344,50 @@ func TestMCPServer_Tools(t *testing.T) { assert.Equal(t, "test-tool-2", tools[1].Name) }, }, + { + name: "AddTools overwrites tool with same name", + action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { + err := server.RegisterSession(context.TODO(), &fakeSession{ + sessionID: "test", + notificationChannel: notificationChannel, + initialized: true, + }) + require.NoError(t, err) + server.AddTool( + mcp.NewTool("test-tool-dup"), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{}, nil + }, + ) + // Add same tool name with different handler or data to test overwrite + server.AddTool( + mcp.NewTool("test-tool-dup"), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{}, nil + }, + ) + }, + expectedNotifications: 2, // one per AddTool with active session + validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, toolsList mcp.JSONRPCMessage) { + // Both adds must have triggered notifications + assert.Equal(t, mcp.MethodNotificationToolsListChanged, notifications[0].Method) + assert.Equal(t, mcp.MethodNotificationToolsListChanged, notifications[1].Method) + + tools := toolsList.(mcp.JSONRPCResponse).Result.(mcp.ListToolsResult).Tools + assert.Len(t, tools, 1, "Expected only one tool after overwrite") + assert.Equal(t, "test-tool-dup", tools[0].Name) + }, + }, + } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := context.Background() - server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) + + var buf bytes.Buffer + logger := &testutil.TestLogger{Buf: &buf} + + server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true), WithMCPLogger(logger)) _ = server.HandleMessage(ctx, []byte(`{ "jsonrpc": "2.0", "id": 1, @@ -373,6 +414,11 @@ func TestMCPServer_Tools(t *testing.T) { "method": "tools/list" }`)) tt.validate(t, notifications, toolsList) + + if tt.name == "AddTools overwrites tool with same name" { + logOutput := buf.String() + assert.Contains(t, logOutput, "already exists") + } }) } } diff --git a/testutil/testlogger.go b/testutil/testlogger.go new file mode 100644 index 000000000..b9a62b720 --- /dev/null +++ b/testutil/testlogger.go @@ -0,0 +1,18 @@ +package testutil + +import ( + "bytes" + "fmt" +) + +type TestLogger struct { + Buf *bytes.Buffer +} + +func (l *TestLogger) Infof(format string, v ...any) { + fmt.Fprintf(l.Buf, "INFO: "+format, v...) +} + +func (l *TestLogger) Errorf(format string, v ...any) { + fmt.Fprintf(l.Buf, "ERROR: "+format, v...) +}