diff --git a/README.md b/README.md index a35a3ebe0..65cfa447b 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ func main() { } } -func helloHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func helloHandler(ctx context.Context, requestSession server.RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { name, err := request.RequireString("name") if err != nil { return mcp.NewToolResultError(err.Error()), nil @@ -92,6 +92,7 @@ MCP Go handles all the complex protocol details and server management, so you ca - [Resources](#resources) - [Tools](#tools) - [Prompts](#prompts) + - [RequestSession](#requestSession) - [Examples](#examples) - [Extras](#extras) - [Transports](#transports) @@ -153,7 +154,7 @@ func main() { ) // Add the calculator handler - s.AddTool(calculatorTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + s.AddTool(calculatorTool, func(ctx context.Context, requestSession server.RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Using helper functions for type-safe argument access op, err := request.RequireString("operation") if err != nil { @@ -251,7 +252,7 @@ resource := mcp.NewResource( ) // Add resource with its handler -s.AddResource(resource, func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { +s.AddResource(resource, func(ctx context.Context, requestSession server.RequestSession, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { content, err := os.ReadFile("README.md") if err != nil { return nil, err @@ -279,7 +280,7 @@ template := mcp.NewResourceTemplate( ) // Add template with its handler -s.AddResourceTemplate(template, func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { +s.AddResourceTemplate(template, func(ctx context.Context, requestSession server.RequestSession, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { // Extract ID from the URI using regex matching // The server automatically matches URIs to templates userID := extractIDFromURI(request.Params.URI) @@ -328,7 +329,7 @@ calculatorTool := mcp.NewTool("calculate", ), ) -s.AddTool(calculatorTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +s.AddTool(calculatorTool, func(ctx context.Context, requestSession server.RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { args := request.GetArguments() op := args["operation"].(string) x := args["x"].(float64) @@ -372,7 +373,7 @@ httpTool := mcp.NewTool("http_request", ), ) -s.AddTool(httpTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +s.AddTool(httpTool, func(ctx context.Context, requestSession server.RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { args := request.GetArguments() method := args["method"].(string) url := args["url"].(string) @@ -440,7 +441,7 @@ s.AddPrompt(mcp.NewPrompt("greeting", mcp.WithArgument("name", mcp.ArgumentDescription("Name of the person to greet"), ), -), func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { +), func(ctx context.Context, requestSession server.RequestSession, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { name := request.Params.Arguments["name"] if name == "" { name = "friend" @@ -464,7 +465,7 @@ s.AddPrompt(mcp.NewPrompt("code_review", mcp.ArgumentDescription("Pull request number to review"), mcp.RequiredArgument(), ), -), func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { +), func(ctx context.Context, requestSession server.RequestSession, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { prNumber := request.Params.Arguments["pr_number"] if prNumber == "" { return nil, fmt.Errorf("pr_number is required") @@ -495,7 +496,7 @@ s.AddPrompt(mcp.NewPrompt("query_builder", mcp.ArgumentDescription("Name of the table to query"), mcp.RequiredArgument(), ), -), func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { +), func(ctx context.Context, requestSession server.RequestSession, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { tableName := request.Params.Arguments["table"] if tableName == "" { return nil, fmt.Errorf("table name is required") @@ -530,6 +531,41 @@ Prompts can include: +### requestSession +
+Show RequestSession Examples + +The RequestSession object provides capabilities to interact with the client, such as sending logging notification and progress notification. + +```go +// Example of using RequestSession to send logging notifications and progress notifications +mcpServer.AddTool(mcp.NewTool( + "test-RequestSession", + mcp.WithDescription("test RequestSession"), +), func(ctx context.Context, requestSession server.RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // you could invoke `requestSession.IsLoggingNotificationSupported()` first the check if server supports logging notification + // if server does not support logging notification, this method will do nothing. + _ = requestSession.SendLoggingNotification(ctx, mcp.LoggingLevelInfo, map[string]any{ + "testLog": "test send log notification", + }) + + // server should send progress notification if request metadata includes a progressToken + total := float64(100) + progressMessage := "human readable progress information" + _ = requestSession.SendProgressNotification(ctx, float64(50), &total, &progressMessage) + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: "context from header: " + ctx.Value(testHeaderKey).(string) + ", " + ctx.Value(testHeaderFuncKey).(string), + }, + }, + }, nil +}) +``` +
+ ## Examples For examples, see the [`examples/`](examples/) directory. @@ -723,7 +759,7 @@ s := server.NewMCPServer( The session context is automatically passed to tool and resource handlers: ```go -s.AddTool(mcp.NewTool("session_aware"), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { +s.AddTool(mcp.NewTool("session_aware"), func(ctx context.Context, requestSession server.RequestSession, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get the current session from context session := server.ClientSessionFromContext(ctx) if session == nil { diff --git a/client/http_test.go b/client/http_test.go index 3c2e6a3b7..b64c8944a 100644 --- a/client/http_test.go +++ b/client/http_test.go @@ -31,6 +31,7 @@ func TestHTTPClient(t *testing.T) { mcp.NewTool("notify"), func( ctx context.Context, + requestSession server.RequestSession, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { server := server.ServerFromContext(ctx) diff --git a/client/inprocess_test.go b/client/inprocess_test.go index 7b150e81e..ec7520353 100644 --- a/client/inprocess_test.go +++ b/client/inprocess_test.go @@ -27,7 +27,7 @@ func TestInProcessMCPClient(t *testing.T) { mcp.WithDestructiveHintAnnotation(false), mcp.WithIdempotentHintAnnotation(true), mcp.WithOpenWorldHintAnnotation(false), - ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ), func(ctx context.Context, requestSession server.RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{ Content: []mcp.Content{ mcp.TextContent{ @@ -48,7 +48,7 @@ func TestInProcessMCPClient(t *testing.T) { URI: "resource://testresource", Name: "My Resource", }, - func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, requestSession server.RequestSession, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: "resource://testresource", @@ -70,7 +70,7 @@ func TestInProcessMCPClient(t *testing.T) { }, }, }, - func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + func(ctx context.Context, requestSession server.RequestSession, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{ Messages: []mcp.PromptMessage{ { diff --git a/client/sse_test.go b/client/sse_test.go index f38c31b17..08bb8ad9c 100644 --- a/client/sse_test.go +++ b/client/sse_test.go @@ -2,7 +2,9 @@ package client import ( "context" + "github.com/stretchr/testify/assert" "net/http" + "sync" "testing" "time" @@ -27,6 +29,7 @@ func TestSSEMCPClient(t *testing.T) { server.WithResourceCapabilities(true, true), server.WithPromptCapabilities(true), server.WithToolCapabilities(true), + server.WithLogging(), ) // Add a test tool @@ -39,7 +42,7 @@ func TestSSEMCPClient(t *testing.T) { mcp.WithDestructiveHintAnnotation(false), mcp.WithIdempotentHintAnnotation(true), mcp.WithOpenWorldHintAnnotation(false), - ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ), func(ctx context.Context, requestSession server.RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{ Content: []mcp.Content{ mcp.TextContent{ @@ -52,7 +55,7 @@ func TestSSEMCPClient(t *testing.T) { mcpServer.AddTool(mcp.NewTool( "test-tool-for-http-header", mcp.WithDescription("Test tool for http header"), - ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ), func(ctx context.Context, requestSession server.RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // , X-Test-Header-Func return &mcp.CallToolResult{ Content: []mcp.Content{ @@ -64,6 +67,126 @@ func TestSSEMCPClient(t *testing.T) { }, nil }) + mcpServer.AddTool(mcp.NewTool( + "test-tool-for-sending-notification", + mcp.WithDescription("Test tool for sending log notification, and the log level is warn"), + ), func(ctx context.Context, requestSession server.RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + totalProgressValue := float64(100) + startFuncMessage := "start func" + err := requestSession.SendProgressNotification(ctx, float64(0), &totalProgressValue, &startFuncMessage) + if err != nil { + return nil, err + } + + err = requestSession.SendLoggingNotification(ctx, mcp.LoggingLevelInfo, map[string]any{ + "filtered_log_message": "will be filtered by log level", + }) + if err != nil { + return nil, err + } + err = requestSession.SendLoggingNotification(ctx, mcp.LoggingLevelError, map[string]any{ + "log_message": "log message value", + }) + if err != nil { + return nil, err + } + + startFuncMessage = "end func" + err = requestSession.SendProgressNotification(ctx, float64(100), &totalProgressValue, &startFuncMessage) + if err != nil { + return nil, err + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: "result", + }, + }, + }, nil + }) + mcpServer.AddPrompt(mcp.Prompt{ + Name: "prompt_get_for_server_notification", + Description: "Test prompt", + }, func(ctx context.Context, requestSession server.RequestSession, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + totalProgressValue := float64(100) + startFuncMessage := "start get prompt" + err := requestSession.SendProgressNotification(ctx, float64(0), &totalProgressValue, &startFuncMessage) + if err != nil { + return nil, err + } + + err = requestSession.SendLoggingNotification(ctx, mcp.LoggingLevelInfo, map[string]any{ + "filtered_log_message": "will be filtered by log level", + }) + if err != nil { + return nil, err + } + err = requestSession.SendLoggingNotification(ctx, mcp.LoggingLevelError, map[string]any{ + "log_message": "log message value", + }) + if err != nil { + return nil, err + } + + startFuncMessage = "end get prompt" + err = requestSession.SendProgressNotification(ctx, float64(100), &totalProgressValue, &startFuncMessage) + if err != nil { + return nil, err + } + + return &mcp.GetPromptResult{ + Messages: []mcp.PromptMessage{ + { + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: "prompt value", + }, + }, + }, + }, nil + }) + + mcpServer.AddResource(mcp.Resource{ + URI: "resource://testresource", + Name: "My Resource", + }, func(ctx context.Context, requestSession server.RequestSession, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + totalProgressValue := float64(100) + startFuncMessage := "start read resource" + err := requestSession.SendProgressNotification(ctx, float64(0), &totalProgressValue, &startFuncMessage) + if err != nil { + return nil, err + } + + err = requestSession.SendLoggingNotification(ctx, mcp.LoggingLevelInfo, map[string]any{ + "filtered_log_message": "will be filtered by log level", + }) + if err != nil { + return nil, err + } + err = requestSession.SendLoggingNotification(ctx, mcp.LoggingLevelError, map[string]any{ + "log_message": "log message value", + }) + if err != nil { + return nil, err + } + + startFuncMessage = "end read resource" + err = requestSession.SendProgressNotification(ctx, float64(100), &totalProgressValue, &startFuncMessage) + if err != nil { + return nil, err + } + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: "resource://testresource", + MIMEType: "text/plain", + Text: "test content", + }, + }, nil + }) + // Initialize testServer := server.NewTestServer(mcpServer, server.WithSSEContextFunc(func(ctx context.Context, r *http.Request) context.Context { @@ -330,4 +453,331 @@ func TestSSEMCPClient(t *testing.T) { t.Errorf("Got %q, want %q", result.Content[0].(mcp.TextContent).Text, "context from header: test-header-value, test-header-func-value") } }) + + t.Run("CallTool for testing log and progress notification", func(t *testing.T) { + client, err := NewSSEMCPClient(testServer.URL + "/sse") + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + mu := sync.Mutex{} + notificationNum := 0 + var messageNotification *mcp.JSONRPCNotification + progressNotifications := make([]*mcp.JSONRPCNotification, 0) + client.OnNotification(func(notification mcp.JSONRPCNotification) { + mu.Lock() + defer mu.Unlock() + if notification.Method == string(mcp.MethodNotificationMessage) { + messageNotification = ¬ification + } else if notification.Method == string(mcp.MethodNotificationProgress) { + progressNotifications = append(progressNotifications, ¬ification) + } + notificationNum += 1 + }) + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := client.Start(ctx); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = client.Initialize(ctx, initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + + setLevelRequest := mcp.SetLevelRequest{} + setLevelRequest.Params.Level = mcp.LoggingLevelWarning + err = client.SetLevel(ctx, setLevelRequest) + if err != nil { + t.Errorf("SetLevel failed: %v", err) + } + + request := mcp.CallToolRequest{} + request.Params.Name = "test-tool-for-sending-notification" + request.Params.Meta = &mcp.Meta{ + ProgressToken: "progress_token", + } + + result, err := client.CallTool(ctx, request) + if err != nil { + t.Fatalf("CallTool failed: %v", err) + } + if len(result.Content) != 1 { + t.Errorf("Expected 1 content item, got %d", len(result.Content)) + } + + time.Sleep(time.Millisecond * 500) + + mu.Lock() + defer mu.Unlock() + assert.Equal(t, notificationNum, 3) + assert.NotNil(t, messageNotification) + assert.Equal(t, messageNotification.Method, string(mcp.MethodNotificationMessage)) + assert.Equal(t, messageNotification.Params.AdditionalFields["level"], "error") + assert.Equal(t, messageNotification.Params.AdditionalFields["data"], map[string]any{ + "log_message": "log message value", + }) + + assert.Len(t, progressNotifications, 2) + assert.Equal(t, string(mcp.MethodNotificationProgress), progressNotifications[0].Method) + assert.Equal(t, "start func", progressNotifications[0].Params.AdditionalFields["message"]) + assert.EqualValues(t, 0, progressNotifications[0].Params.AdditionalFields["progress"]) + assert.Equal(t, "progress_token", progressNotifications[0].Params.AdditionalFields["progressToken"]) + assert.EqualValues(t, 100, progressNotifications[0].Params.AdditionalFields["total"]) + + // Assert second progress notification (end func) + assert.Equal(t, string(mcp.MethodNotificationProgress), progressNotifications[1].Method) + assert.Equal(t, "end func", progressNotifications[1].Params.AdditionalFields["message"]) + assert.EqualValues(t, 100, progressNotifications[1].Params.AdditionalFields["progress"]) + assert.Equal(t, "progress_token", progressNotifications[1].Params.AdditionalFields["progressToken"]) + assert.EqualValues(t, 100, progressNotifications[1].Params.AdditionalFields["total"]) + + }) + + t.Run("Ensure the server does not send notifications", func(t *testing.T) { + client, err := NewSSEMCPClient(testServer.URL + "/sse") + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + mu := sync.Mutex{} + notifications := make([]*mcp.JSONRPCNotification, 0) + client.OnNotification(func(notification mcp.JSONRPCNotification) { + mu.Lock() + defer mu.Unlock() + notifications = append(notifications, ¬ification) + }) + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := client.Start(ctx); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = client.Initialize(ctx, initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + + setLevelRequest := mcp.SetLevelRequest{} + setLevelRequest.Params.Level = mcp.LoggingLevelCritical + err = client.SetLevel(ctx, setLevelRequest) + if err != nil { + t.Errorf("SetLevel failed: %v", err) + } + + // request param without progressToken + request := mcp.CallToolRequest{} + request.Params.Name = "test-tool-for-sending-notification" + + _, _ = client.CallTool(ctx, request) + time.Sleep(time.Millisecond * 500) + + mu.Lock() + defer mu.Unlock() + assert.Len(t, notifications, 0) + }) + + t.Run("GetPrompt for testing log and progress notification", func(t *testing.T) { + client, err := NewSSEMCPClient(testServer.URL + "/sse") + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + mu := sync.Mutex{} + var messageNotification *mcp.JSONRPCNotification + progressNotifications := make([]*mcp.JSONRPCNotification, 0) + notificationNum := 0 + client.OnNotification(func(notification mcp.JSONRPCNotification) { + mu.Lock() + defer mu.Unlock() + if notification.Method == string(mcp.MethodNotificationMessage) { + messageNotification = ¬ification + } else if notification.Method == string(mcp.MethodNotificationProgress) { + progressNotifications = append(progressNotifications, ¬ification) + } + notificationNum += 1 + }) + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := client.Start(ctx); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = client.Initialize(ctx, initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + + setLevelRequest := mcp.SetLevelRequest{} + setLevelRequest.Params.Level = mcp.LoggingLevelWarning + err = client.SetLevel(ctx, setLevelRequest) + if err != nil { + t.Errorf("SetLevel failed: %v", err) + } + + request := mcp.GetPromptRequest{} + request.Params.Name = "prompt_get_for_server_notification" + request.Params.Meta = &mcp.Meta{ + ProgressToken: "progress_token", + } + + result, err := client.GetPrompt(ctx, request) + if err != nil { + t.Fatalf("GetPrompt failed: %v", err) + } + assert.NotNil(t, result) + assert.Len(t, result.Messages, 1) + assert.Equal(t, result.Messages[0].Role, mcp.RoleAssistant) + assert.Equal(t, result.Messages[0].Content.(mcp.TextContent).Type, "text") + assert.Equal(t, result.Messages[0].Content.(mcp.TextContent).Text, "prompt value") + time.Sleep(time.Millisecond * 500) + + mu.Lock() + defer mu.Unlock() + assert.Equal(t, notificationNum, 3) + assert.NotNil(t, messageNotification) + assert.Equal(t, messageNotification.Method, string(mcp.MethodNotificationMessage)) + assert.Equal(t, messageNotification.Params.AdditionalFields["level"], "error") + assert.Equal(t, messageNotification.Params.AdditionalFields["data"], map[string]any{ + "log_message": "log message value", + }) + + assert.Len(t, progressNotifications, 2) + assert.Equal(t, string(mcp.MethodNotificationProgress), progressNotifications[0].Method) + assert.Equal(t, "start get prompt", progressNotifications[0].Params.AdditionalFields["message"]) + assert.EqualValues(t, 0, progressNotifications[0].Params.AdditionalFields["progress"]) + assert.Equal(t, "progress_token", progressNotifications[0].Params.AdditionalFields["progressToken"]) + assert.EqualValues(t, 100, progressNotifications[0].Params.AdditionalFields["total"]) + + assert.Equal(t, string(mcp.MethodNotificationProgress), progressNotifications[1].Method) + assert.Equal(t, "end get prompt", progressNotifications[1].Params.AdditionalFields["message"]) + assert.EqualValues(t, 100, progressNotifications[1].Params.AdditionalFields["progress"]) + assert.Equal(t, "progress_token", progressNotifications[1].Params.AdditionalFields["progressToken"]) + assert.EqualValues(t, 100, progressNotifications[1].Params.AdditionalFields["total"]) + }) + + t.Run("GetResource for testing log and progress notification", func(t *testing.T) { + client, err := NewSSEMCPClient(testServer.URL + "/sse") + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + mu := sync.Mutex{} + var messageNotification *mcp.JSONRPCNotification + progressNotifications := make([]*mcp.JSONRPCNotification, 0) + notificationNum := 0 + client.OnNotification(func(notification mcp.JSONRPCNotification) { + mu.Lock() + defer mu.Unlock() + if notification.Method == string(mcp.MethodNotificationMessage) { + messageNotification = ¬ification + } else if notification.Method == string(mcp.MethodNotificationProgress) { + progressNotifications = append(progressNotifications, ¬ification) + } + notificationNum += 1 + }) + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := client.Start(ctx); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = client.Initialize(ctx, initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + + setLevelRequest := mcp.SetLevelRequest{} + setLevelRequest.Params.Level = mcp.LoggingLevelWarning + err = client.SetLevel(ctx, setLevelRequest) + if err != nil { + t.Errorf("SetLevel failed: %v", err) + } + + request := mcp.ReadResourceRequest{} + request.Params.URI = "resource://testresource" + request.Params.Meta = &mcp.Meta{ + ProgressToken: "progress_token", + } + + result, err := client.ReadResource(ctx, request) + if err != nil { + t.Fatalf("ReadResource failed: %v", err) + } + + assert.NotNil(t, result) + assert.Len(t, result.Contents, 1) + assert.Equal(t, result.Contents[0].(mcp.TextResourceContents).URI, "resource://testresource") + assert.Equal(t, result.Contents[0].(mcp.TextResourceContents).MIMEType, "text/plain") + assert.Equal(t, result.Contents[0].(mcp.TextResourceContents).Text, "test content") + + time.Sleep(time.Millisecond * 500) + + mu.Lock() + defer mu.Unlock() + assert.Equal(t, notificationNum, 3) + assert.NotNil(t, messageNotification) + assert.Equal(t, messageNotification.Method, string(mcp.MethodNotificationMessage)) + assert.Equal(t, messageNotification.Params.AdditionalFields["level"], "error") + assert.Equal(t, messageNotification.Params.AdditionalFields["data"], map[string]any{ + "log_message": "log message value", + }) + + assert.Len(t, progressNotifications, 2) + assert.Equal(t, string(mcp.MethodNotificationProgress), progressNotifications[0].Method) + assert.Equal(t, "start read resource", progressNotifications[0].Params.AdditionalFields["message"]) + assert.EqualValues(t, 0, progressNotifications[0].Params.AdditionalFields["progress"]) + assert.Equal(t, "progress_token", progressNotifications[0].Params.AdditionalFields["progressToken"]) + assert.EqualValues(t, 100, progressNotifications[0].Params.AdditionalFields["total"]) + + assert.Equal(t, string(mcp.MethodNotificationProgress), progressNotifications[1].Method) + assert.Equal(t, "end read resource", progressNotifications[1].Params.AdditionalFields["message"]) + assert.EqualValues(t, 100, progressNotifications[1].Params.AdditionalFields["progress"]) + assert.Equal(t, "progress_token", progressNotifications[1].Params.AdditionalFields["progressToken"]) + assert.EqualValues(t, 100, progressNotifications[1].Params.AdditionalFields["total"]) + }) } diff --git a/examples/custom_context/main.go b/examples/custom_context/main.go index e41ab8db7..88136a2c6 100644 --- a/examples/custom_context/main.go +++ b/examples/custom_context/main.go @@ -79,6 +79,7 @@ func makeRequest(ctx context.Context, message, token string) (*response, error) // using the token from the context. func handleMakeAuthenticatedRequestTool( ctx context.Context, + requestSession server.RequestSession, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { message, ok := request.GetArguments()["message"].(string) diff --git a/examples/dynamic_path/main.go b/examples/dynamic_path/main.go index 80d96789a..a38175245 100644 --- a/examples/dynamic_path/main.go +++ b/examples/dynamic_path/main.go @@ -19,7 +19,7 @@ func main() { mcpServer := server.NewMCPServer("dynamic-path-example", "1.0.0") // Add a trivial tool for demonstration - mcpServer.AddTool(mcp.NewTool("echo"), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + mcpServer.AddTool(mcp.NewTool("echo"), func(ctx context.Context, requestSession server.RequestSession, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText(fmt.Sprintf("Echo: %v", req.GetArguments()["message"])), nil }) diff --git a/examples/everything/main.go b/examples/everything/main.go index 5489220c3..1c3ccd566 100644 --- a/examples/everything/main.go +++ b/examples/everything/main.go @@ -191,6 +191,7 @@ func generateResources() []mcp.Resource { func handleReadResource( ctx context.Context, + requestSession server.RequestSession, request mcp.ReadResourceRequest, ) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ @@ -204,6 +205,7 @@ func handleReadResource( func handleResourceTemplate( ctx context.Context, + requestSession server.RequestSession, request mcp.ReadResourceRequest, ) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ @@ -217,6 +219,7 @@ func handleResourceTemplate( func handleGeneratedResource( ctx context.Context, + requestSession server.RequestSession, request mcp.ReadResourceRequest, ) ([]mcp.ResourceContents, error) { uri := request.Params.URI @@ -254,6 +257,7 @@ func handleGeneratedResource( func handleSimplePrompt( ctx context.Context, + requestSession server.RequestSession, request mcp.GetPromptRequest, ) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{ @@ -272,6 +276,7 @@ func handleSimplePrompt( func handleComplexPrompt( ctx context.Context, + requestSession server.RequestSession, request mcp.GetPromptRequest, ) (*mcp.GetPromptResult, error) { arguments := request.Params.Arguments @@ -310,6 +315,7 @@ func handleComplexPrompt( func handleEchoTool( ctx context.Context, + requestSession server.RequestSession, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { arguments := request.GetArguments() @@ -329,6 +335,7 @@ func handleEchoTool( func handleAddTool( ctx context.Context, + requestSession server.RequestSession, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { arguments := request.GetArguments() @@ -350,6 +357,7 @@ func handleAddTool( func handleSendNotification( ctx context.Context, + requestSession server.RequestSession, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { @@ -380,6 +388,7 @@ func handleSendNotification( func handleLongRunningOperationTool( ctx context.Context, + requestSession server.RequestSession, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { arguments := request.GetArguments() @@ -445,6 +454,7 @@ func handleLongRunningOperationTool( func handleGetTinyImageTool( ctx context.Context, + requestSession server.RequestSession, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{ diff --git a/examples/typed_tools/main.go b/examples/typed_tools/main.go index f9bd3c21e..90348e425 100644 --- a/examples/typed_tools/main.go +++ b/examples/typed_tools/main.go @@ -64,7 +64,7 @@ func main() { ) // Add tool handler using the typed handler - s.AddTool(tool, mcp.NewTypedToolHandler(typedGreetingHandler)) + s.AddTool(tool, server.NewTypedToolHandler(typedGreetingHandler)) // Start the stdio server if err := server.ServeStdio(s); err != nil { @@ -73,7 +73,7 @@ func main() { } // Our typed handler function that receives strongly-typed arguments -func typedGreetingHandler(ctx context.Context, request mcp.CallToolRequest, args GreetingArgs) (*mcp.CallToolResult, error) { +func typedGreetingHandler(ctx context.Context, requestSession server.RequestSession, request mcp.CallToolRequest, args GreetingArgs) (*mcp.CallToolResult, error) { if args.Name == "" { return mcp.NewToolResultError("name is required"), nil } diff --git a/mcp/prompts.go b/mcp/prompts.go index a63a21450..38880d8ef 100644 --- a/mcp/prompts.go +++ b/mcp/prompts.go @@ -27,6 +27,9 @@ type GetPromptParams struct { Name string `json:"name"` // Arguments to use for templating the prompt. Arguments map[string]string `json:"arguments,omitempty"` + // Meta is metadata attached to a request's parameters. This can include fields + // formally defined by the protocol or other arbitrary data. + Meta *Meta `json:"_meta,omitempty"` } // GetPromptResult is the server's response to a prompts/get request from the diff --git a/mcp/typed_tools.go b/mcp/typed_tools.go deleted file mode 100644 index 68d8cdd1f..000000000 --- a/mcp/typed_tools.go +++ /dev/null @@ -1,20 +0,0 @@ -package mcp - -import ( - "context" - "fmt" -) - -// TypedToolHandlerFunc is a function that handles a tool call with typed arguments -type TypedToolHandlerFunc[T any] func(ctx context.Context, request CallToolRequest, args T) (*CallToolResult, error) - -// NewTypedToolHandler creates a ToolHandlerFunc that automatically binds arguments to a typed struct -func NewTypedToolHandler[T any](handler TypedToolHandlerFunc[T]) func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) { - return func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) { - var args T - if err := request.BindArguments(&args); err != nil { - return NewToolResultError(fmt.Sprintf("failed to bind arguments: %v", err)), nil - } - return handler(ctx, request, args) - } -} diff --git a/mcp/types.go b/mcp/types.go index 0091d2e42..a83e9ebb9 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -56,17 +56,40 @@ const ( // MethodNotificationResourcesListChanged notifies when the list of available resources changes. // https://modelcontextprotocol.io/specification/2025-03-26/server/resources#list-changed-notification - MethodNotificationResourcesListChanged = "notifications/resources/list_changed" + MethodNotificationResourcesListChanged MCPMethod = "notifications/resources/list_changed" - MethodNotificationResourceUpdated = "notifications/resources/updated" + MethodNotificationResourceUpdated MCPMethod = "notifications/resources/updated" // MethodNotificationPromptsListChanged notifies when the list of available prompt templates changes. // https://modelcontextprotocol.io/specification/2025-03-26/server/prompts#list-changed-notification - MethodNotificationPromptsListChanged = "notifications/prompts/list_changed" + MethodNotificationPromptsListChanged MCPMethod = "notifications/prompts/list_changed" // MethodNotificationToolsListChanged notifies when the list of available tools changes. // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/list_changed/ - MethodNotificationToolsListChanged = "notifications/tools/list_changed" + MethodNotificationToolsListChanged MCPMethod = "notifications/tools/list_changed" + + // MethodNotificationMessage notifies when severs send log messages. + // https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging#log-message-notifications + MethodNotificationMessage MCPMethod = "notifications/message" + + // MethodNotificationProgress notifies progress updates for long-running operations + // https://modelcontextprotocol.io/specification/2025-03-26/basic/utilities/progress + MethodNotificationProgress MCPMethod = "notifications/progress" + + // MethodNotificationCancellation can be sent by either side to indicate that it is + // cancelling a previously-issued request. + // + // The request SHOULD still be in-flight, but due to communication latency, it + // is always possible that this notification MAY arrive after the request has + // already finished. + // + // This notification indicates that the result will be unused, so any + // associated processing SHOULD cease. + // + // A client MUST NOT attempt to cancel its `initialize` request. + // + // https://modelcontextprotocol.io/specification/2025-03-26/basic/utilities/cancellation + MethodNotificationCancellation MCPMethod = "notifications/cancelled" ) type URITemplate struct { @@ -576,6 +599,9 @@ type ReadResourceParams struct { URI string `json:"uri"` // Arguments to pass to the resource handler Arguments map[string]any `json:"arguments,omitempty"` + // Meta is metadata attached to a request's parameters. This can include fields + // formally defined by the protocol or other arbitrary data. + Meta *Meta `json:"_meta,omitempty"` } // ReadResourceResult is the server's response to a resources/read request @@ -761,6 +787,29 @@ const ( LoggingLevelEmergency LoggingLevel = "emergency" ) +var ( + levelToSeverity = func() map[LoggingLevel]int { + return map[LoggingLevel]int{ + LoggingLevelEmergency: 0, + LoggingLevelAlert: 1, + LoggingLevelCritical: 2, + LoggingLevelError: 3, + LoggingLevelWarning: 4, + LoggingLevelNotice: 5, + LoggingLevelInfo: 6, + LoggingLevelDebug: 7, + } + }() +) + +// Allows is a helper function that decides a message could be sent to client or not according to the logging level +func (subscribedLevel LoggingLevel) Allows(currentLevel LoggingLevel) (bool, error) { + if _, ok := levelToSeverity[currentLevel]; !ok { + return false, fmt.Errorf("illegal message logging level:%s", currentLevel) + } + return levelToSeverity[subscribedLevel] >= levelToSeverity[currentLevel], nil +} + /* Sampling */ // CreateMessageRequest is a request from the server to sample an LLM via the diff --git a/mcp/utils.go b/mcp/utils.go index 55bef7a99..dcdb8e4a7 100644 --- a/mcp/utils.go +++ b/mcp/utils.go @@ -22,8 +22,6 @@ var _ ClientRequest = &CallToolRequest{} var _ ClientRequest = &ListToolsRequest{} // ClientNotification types -var _ ClientNotification = &CancelledNotification{} -var _ ClientNotification = &ProgressNotification{} var _ ClientNotification = &InitializedNotification{} var _ ClientNotification = &RootsListChangedNotification{} @@ -38,9 +36,6 @@ var _ ServerRequest = &CreateMessageRequest{} var _ ServerRequest = &ListRootsRequest{} // ServerNotification types -var _ ServerNotification = &CancelledNotification{} -var _ ServerNotification = &ProgressNotification{} -var _ ServerNotification = &LoggingMessageNotification{} var _ ServerNotification = &ResourceUpdatedNotification{} var _ ServerNotification = &ResourceListChangedNotification{} var _ ServerNotification = &ToolListChangedNotification{} @@ -131,60 +126,6 @@ func NewJSONRPCError( } } -// NewProgressNotification -// Helper function for creating a progress notification -func NewProgressNotification( - token ProgressToken, - progress float64, - total *float64, - message *string, -) ProgressNotification { - notification := ProgressNotification{ - Notification: Notification{ - Method: "notifications/progress", - }, - Params: struct { - ProgressToken ProgressToken `json:"progressToken"` - Progress float64 `json:"progress"` - Total float64 `json:"total,omitempty"` - Message string `json:"message,omitempty"` - }{ - ProgressToken: token, - Progress: progress, - }, - } - if total != nil { - notification.Params.Total = *total - } - if message != nil { - notification.Params.Message = *message - } - return notification -} - -// NewLoggingMessageNotification -// Helper function for creating a logging message notification -func NewLoggingMessageNotification( - level LoggingLevel, - logger string, - data any, -) LoggingMessageNotification { - return LoggingMessageNotification{ - Notification: Notification{ - Method: "notifications/message", - }, - Params: struct { - Level LoggingLevel `json:"level"` - Logger string `json:"logger,omitempty"` - Data any `json:"data"` - }{ - Level: level, - Logger: logger, - Data: data, - }, - } -} - // NewPromptMessage // Helper function to create a new PromptMessage func NewPromptMessage(role Role, content Content) PromptMessage { diff --git a/mcptest/mcptest_test.go b/mcptest/mcptest_test.go index 129ef39ff..75dd632d0 100644 --- a/mcptest/mcptest_test.go +++ b/mcptest/mcptest_test.go @@ -50,7 +50,7 @@ func TestServerWithTool(t *testing.T) { } } -func helloWorldHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func helloWorldHandler(ctx context.Context, requestSession server.RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Extract name from request arguments name, ok := request.GetArguments()["name"].(string) if !ok { diff --git a/server/internal/gen/request_handler.go.tmpl b/server/internal/gen/request_handler.go.tmpl index 7e4a68a05..89e9d740e 100644 --- a/server/internal/gen/request_handler.go.tmpl +++ b/server/internal/gen/request_handler.go.tmpl @@ -55,7 +55,9 @@ func (s *MCPServer) HandleMessage( } s.handleNotification(ctx, notification) return nil // Return nil for notifications - } + } else { + ctx = context.WithValue(ctx, requestIDKey{}, mcp.NewRequestId(baseMessage.ID)) + } if baseMessage.Result != nil { // this is a response to a request sent by the server (e.g. from a ping diff --git a/server/request_handler.go b/server/request_handler.go index 25f6ef14f..519799b84 100644 --- a/server/request_handler.go +++ b/server/request_handler.go @@ -54,6 +54,8 @@ func (s *MCPServer) HandleMessage( } s.handleNotification(ctx, notification) return nil // Return nil for notifications + } else { + ctx = context.WithValue(ctx, requestIDKey{}, mcp.NewRequestId(baseMessage.ID)) } if baseMessage.Result != nil { diff --git a/server/request_session.go b/server/request_session.go new file mode 100644 index 000000000..80c3c0217 --- /dev/null +++ b/server/request_session.go @@ -0,0 +1,118 @@ +package server + +import ( + "context" + "fmt" + "github.com/mark3labs/mcp-go/mcp" +) + +type requestIDKey struct{} + +// RequestSession represents an exchange with MCP client, and provides +// methods to interact with the client and query its capabilities. +type RequestSession struct { + progressToken *mcp.ProgressToken +} + +func NewRequestSession(requestParamMeta *mcp.Meta) RequestSession { + requestSession := RequestSession{} + + // server should send progress notification if request metadata includes a progressToken + if requestParamMeta != nil && requestParamMeta.ProgressToken != nil { + requestSession.progressToken = &requestParamMeta.ProgressToken + } + + return requestSession +} + +// IsLoggingNotificationSupported returns true if server supports logging notification +func (exchange *RequestSession) IsLoggingNotificationSupported(ctx context.Context) bool { + mcpServer := ServerFromContext(ctx) + return mcpServer != nil && mcpServer.capabilities.logging != nil && *mcpServer.capabilities.logging +} + +// SendLoggingNotification send logging notification to client. +// If server does not support logging notification, this method will do nothing. +func (exchange *RequestSession) SendLoggingNotification(ctx context.Context, level mcp.LoggingLevel, message map[string]any) error { + if !exchange.IsLoggingNotificationSupported(ctx) { + return nil + } + + clientLogLevel := ClientSessionFromContext(ctx).GetLogLevel() + allowed, err := clientLogLevel.Allows(level) + if err != nil { + return err + } + if !allowed { + return nil + } + + params := map[string]any{ + "level": level, + "data": message, + } + if ClientSessionFromContext(ctx).GetLoggerName() != nil { + params["logger"] = *ClientSessionFromContext(ctx).GetLoggerName() + } + + mcpServer := ServerFromContext(ctx) + return mcpServer.SendNotificationToClient( + ctx, + string(mcp.MethodNotificationMessage), + params, + ) +} + +// SendProgressNotification send progress notification only if the client has requested progress +func (exchange *RequestSession) SendProgressNotification(ctx context.Context, progress float64, total *float64, message *string) error { + if exchange.progressToken == nil { + return nil + } + + params := map[string]any{ + "progress": progress, + "progressToken": *exchange.progressToken, + } + if total != nil { + params["total"] = *total + } + if message != nil { + params["message"] = *message + } + + mcpServer := ServerFromContext(ctx) + return mcpServer.SendNotificationToClient( + ctx, + string(mcp.MethodNotificationProgress), + params, + ) +} + +// SendCancellationNotification send cancellation notification to client +func (exchange *RequestSession) SendCancellationNotification(ctx context.Context, reason *string) error { + requestIDRawValue := ctx.Value(requestIDKey{}) + if requestIDRawValue == nil { + return fmt.Errorf("invalid requestID") + } + + requestID, ok := requestIDRawValue.(mcp.RequestId) + if !ok { + return fmt.Errorf("invalid requestID type") + } + + params := map[string]any{ + "requestId": requestID.Value(), + } + if reason != nil { + params["reason"] = reason + } + + mcpServer := ServerFromContext(ctx) + return mcpServer.SendNotificationToClient( + ctx, + string(mcp.MethodNotificationCancellation), + params, + ) +} + +// TODO should implement other methods like 'roots/list', this still could happen when server handle client request diff --git a/server/resource_test.go b/server/resource_test.go index 05a3b2793..9b32c29fe 100644 --- a/server/resource_test.go +++ b/server/resource_test.go @@ -28,7 +28,7 @@ func TestMCPServer_RemoveResource(t *testing.T) { mcp.WithResourceDescription("Test resource 1"), mcp.WithMIMEType("text/plain"), ), - func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, requestSession RequestSession, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: "test://resource1", @@ -47,7 +47,7 @@ func TestMCPServer_RemoveResource(t *testing.T) { mcp.WithResourceDescription("Test resource 2"), mcp.WithMIMEType("text/plain"), ), - func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, requestSession RequestSession, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: "test://resource2", @@ -84,7 +84,7 @@ func TestMCPServer_RemoveResource(t *testing.T) { expectedNotifications: 1, validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, resourcesList mcp.JSONRPCMessage) { // Check that we received a list_changed notification - assert.Equal(t, mcp.MethodNotificationResourcesListChanged, notifications[0].Method) + assert.Equal(t, string(mcp.MethodNotificationResourcesListChanged), notifications[0].Method) // Verify we now have only one resource resp, ok := resourcesList.(mcp.JSONRPCResponse) @@ -108,7 +108,7 @@ func TestMCPServer_RemoveResource(t *testing.T) { mcp.WithResourceDescription("Test resource 1"), mcp.WithMIMEType("text/plain"), ), - func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, requestSession RequestSession, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: "test://resource1", @@ -164,7 +164,7 @@ func TestMCPServer_RemoveResource(t *testing.T) { mcp.WithResourceDescription("Test resource 1"), mcp.WithMIMEType("text/plain"), ), - func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, requestSession RequestSession, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: "test://resource1", diff --git a/server/server.go b/server/server.go index 46e6d9c57..d196a441a 100644 --- a/server/server.go +++ b/server/server.go @@ -29,16 +29,16 @@ type resourceTemplateEntry struct { type ServerOption func(*MCPServer) // ResourceHandlerFunc is a function that returns resource contents. -type ResourceHandlerFunc func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) +type ResourceHandlerFunc func(ctx context.Context, requestSession RequestSession, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) // ResourceTemplateHandlerFunc is a function that returns a resource template. -type ResourceTemplateHandlerFunc func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) +type ResourceTemplateHandlerFunc func(ctx context.Context, requestSession RequestSession, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) // PromptHandlerFunc handles prompt requests with given arguments. -type PromptHandlerFunc func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) +type PromptHandlerFunc func(ctx context.Context, requestSession RequestSession, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) // ToolHandlerFunc handles tool calls with given arguments. -type ToolHandlerFunc func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) +type ToolHandlerFunc func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) // ToolHandlerMiddleware is a middleware function that wraps a ToolHandlerFunc. type ToolHandlerMiddleware func(ToolHandlerFunc) ToolHandlerFunc @@ -230,7 +230,7 @@ func WithToolFilter( // WithRecovery adds a middleware that recovers from panics in tool handlers. func WithRecovery() ServerOption { return WithToolHandlerMiddleware(func(next ToolHandlerFunc) ToolHandlerFunc { - return func(ctx context.Context, request mcp.CallToolRequest) (result *mcp.CallToolResult, err error) { + return func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (result *mcp.CallToolResult, err error) { defer func() { if r := recover(); r != nil { err = fmt.Errorf( @@ -240,7 +240,7 @@ func WithRecovery() ServerOption { ) } }() - return next(ctx, request) + return next(ctx, requestSession, request) } }) } @@ -333,7 +333,7 @@ func (s *MCPServer) AddResources(resources ...ServerResource) { // When the list of available resources changes, servers that declared the listChanged capability SHOULD send a notification if s.capabilities.resources.listChanged { // Send notification to all initialized sessions - s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil) + s.SendNotificationToAllClients(string(mcp.MethodNotificationResourcesListChanged), nil) } } @@ -356,7 +356,7 @@ func (s *MCPServer) RemoveResource(uri string) { // Send notification to all initialized sessions if listChanged capability is enabled and we actually remove a resource if exists && s.capabilities.resources != nil && s.capabilities.resources.listChanged { - s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil) + s.SendNotificationToAllClients(string(mcp.MethodNotificationResourcesListChanged), nil) } } @@ -377,7 +377,7 @@ func (s *MCPServer) AddResourceTemplate( // When the list of available resources changes, servers that declared the listChanged capability SHOULD send a notification if s.capabilities.resources.listChanged { // Send notification to all initialized sessions - s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil) + s.SendNotificationToAllClients(string(mcp.MethodNotificationResourcesListChanged), nil) } } @@ -395,7 +395,7 @@ func (s *MCPServer) AddPrompts(prompts ...ServerPrompt) { // When the list of available prompts changes, servers that declared the listChanged capability SHOULD send a notification. if s.capabilities.prompts.listChanged { // Send notification to all initialized sessions - s.SendNotificationToAllClients(mcp.MethodNotificationPromptsListChanged, nil) + s.SendNotificationToAllClients(string(mcp.MethodNotificationPromptsListChanged), nil) } } @@ -420,7 +420,7 @@ func (s *MCPServer) DeletePrompts(names ...string) { // Send notification to all initialized sessions if listChanged capability is enabled, and we actually remove a prompt if exists && s.capabilities.prompts != nil && s.capabilities.prompts.listChanged { // Send notification to all initialized sessions - s.SendNotificationToAllClients(mcp.MethodNotificationPromptsListChanged, nil) + s.SendNotificationToAllClients(string(mcp.MethodNotificationPromptsListChanged), nil) } } @@ -481,7 +481,7 @@ func (s *MCPServer) AddTools(tools ...ServerTool) { // When the list of available tools changes, servers that declared the listChanged capability SHOULD send a notification. if s.capabilities.tools.listChanged { // Send notification to all initialized sessions - s.SendNotificationToAllClients(mcp.MethodNotificationToolsListChanged, nil) + s.SendNotificationToAllClients(string(mcp.MethodNotificationToolsListChanged), nil) } } @@ -508,7 +508,7 @@ func (s *MCPServer) DeleteTools(names ...string) { // When the list of available tools changes, servers that declared the listChanged capability SHOULD send a notification. if exists && s.capabilities.tools != nil && s.capabilities.tools.listChanged { // Send notification to all initialized sessions - s.SendNotificationToAllClients(mcp.MethodNotificationToolsListChanged, nil) + s.SendNotificationToAllClients(string(mcp.MethodNotificationToolsListChanged), nil) } } @@ -613,15 +613,6 @@ func (s *MCPServer) handleSetLevel( } } - sessionLogging, ok := clientSession.(SessionWithLogging) - if !ok { - return nil, &requestError{ - id: id, - code: mcp.INTERNAL_ERROR, - err: ErrSessionDoesNotSupportLogging, - } - } - level := request.Params.Level // Validate logging level switch level { @@ -637,7 +628,7 @@ func (s *MCPServer) handleSetLevel( } } - sessionLogging.SetLogLevel(level) + clientSession.SetLogLevel(level) return &mcp.EmptyResult{}, nil } @@ -757,12 +748,15 @@ func (s *MCPServer) handleReadResource( id any, request mcp.ReadResourceRequest, ) (*mcp.ReadResourceResult, *requestError) { + + requestSession := NewRequestSession(request.Params.Meta) + s.resourcesMu.RLock() // First try direct resource handlers if entry, ok := s.resources[request.Params.URI]; ok { handler := entry.handler s.resourcesMu.RUnlock() - contents, err := handler(ctx, request) + contents, err := handler(ctx, requestSession, request) if err != nil { return nil, &requestError{ id: id, @@ -793,7 +787,7 @@ func (s *MCPServer) handleReadResource( s.resourcesMu.RUnlock() if matched { - contents, err := matchedHandler(ctx, request) + contents, err := matchedHandler(ctx, requestSession, request) if err != nil { return nil, &requestError{ id: id, @@ -875,7 +869,8 @@ func (s *MCPServer) handleGetPrompt( } } - result, err := handler(ctx, request) + requestSession := NewRequestSession(request.Params.Meta) + result, err := handler(ctx, requestSession, request) if err != nil { return nil, &requestError{ id: id, @@ -1025,7 +1020,8 @@ func (s *MCPServer) handleToolCall( finalHandler = mw[i](finalHandler) } - result, err := finalHandler(ctx, request) + requestSession := NewRequestSession(request.Params.Meta) + result, err := finalHandler(ctx, requestSession, request) if err != nil { return nil, &requestError{ id: id, diff --git a/server/server_race_test.go b/server/server_race_test.go index 4e0be43a8..26ef7686d 100644 --- a/server/server_race_test.go +++ b/server/server_race_test.go @@ -39,7 +39,7 @@ func TestRaceConditions(t *testing.T) { srv.AddPrompt(mcp.Prompt{ Name: name, Description: "Test prompt", - }, func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + }, func(ctx context.Context, requestSession RequestSession, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{}, nil }) }) @@ -49,7 +49,7 @@ func TestRaceConditions(t *testing.T) { srv.AddPrompt(mcp.Prompt{ Name: name, Description: "Temporary prompt", - }, func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + }, func(ctx context.Context, requestSession RequestSession, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{}, nil }) srv.DeletePrompts(name) @@ -60,7 +60,7 @@ func TestRaceConditions(t *testing.T) { srv.AddTool(mcp.Tool{ Name: name, Description: "Test tool", - }, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + }, func(ctx context.Context, requestSession RequestSession, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }) }) @@ -71,7 +71,7 @@ func TestRaceConditions(t *testing.T) { srv.AddTool(mcp.Tool{ Name: name, Description: "Temporary tool", - }, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + }, func(ctx context.Context, requestSession RequestSession, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }) srv.DeleteTools(name) @@ -79,8 +79,8 @@ func TestRaceConditions(t *testing.T) { runConcurrentOperation(&wg, testDuration, "add-middleware", func() { middleware := func(next ToolHandlerFunc) ToolHandlerFunc { - return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return next(ctx, req) + return func(ctx context.Context, requestSession RequestSession, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return next(ctx, requestSession, req) } } WithToolHandlerMiddleware(middleware)(srv) @@ -102,7 +102,7 @@ func TestRaceConditions(t *testing.T) { srv.AddTool(mcp.Tool{ Name: "persistent-tool", Description: "Test tool that always exists", - }, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + }, func(ctx context.Context, requestSession RequestSession, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }) @@ -121,7 +121,7 @@ func TestRaceConditions(t *testing.T) { URI: uri, Name: uri, Description: "Test resource", - }, func(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + }, func(ctx context.Context, requestSession RequestSession, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: uri, @@ -169,12 +169,12 @@ func TestConcurrentPromptAdd(t *testing.T) { srv.AddPrompt(mcp.Prompt{ Name: "initial-prompt", Description: "Initial prompt", - }, func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + }, func(ctx context.Context, requestSession RequestSession, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { go func() { srv.AddPrompt(mcp.Prompt{ Name: fmt.Sprintf("new-prompt-%d", time.Now().UnixNano()), Description: "Added from handler", - }, func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + }, func(ctx context.Context, requestSession RequestSession, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{}, nil }) }() diff --git a/server/server_test.go b/server/server_test.go index 1c81d18dd..6af3a63fd 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -160,12 +160,12 @@ func TestMCPServer_Tools(t *testing.T) { action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { server.SetTools(ServerTool{ Tool: mcp.NewTool("test-tool-1"), - Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, }, ServerTool{ Tool: mcp.NewTool("test-tool-2"), - Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, }) @@ -189,19 +189,19 @@ func TestMCPServer_Tools(t *testing.T) { require.NoError(t, err) server.SetTools(ServerTool{ Tool: mcp.NewTool("test-tool-1"), - Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, }, ServerTool{ Tool: mcp.NewTool("test-tool-2"), - Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, }) }, expectedNotifications: 1, validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, toolsList mcp.JSONRPCMessage) { - assert.Equal(t, mcp.MethodNotificationToolsListChanged, notifications[0].Method) + assert.Equal(t, string(mcp.MethodNotificationToolsListChanged), notifications[0].Method) tools := toolsList.(mcp.JSONRPCResponse).Result.(mcp.ListToolsResult).Tools assert.Len(t, tools, 2) assert.Equal(t, "test-tool-1", tools[0].Name) @@ -230,12 +230,12 @@ func TestMCPServer_Tools(t *testing.T) { } server.SetTools(ServerTool{ Tool: mcp.NewTool("test-tool-1"), - Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, }, ServerTool{ Tool: mcp.NewTool("test-tool-2"), - Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, }) @@ -243,7 +243,7 @@ func TestMCPServer_Tools(t *testing.T) { expectedNotifications: 5, validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, toolsList mcp.JSONRPCMessage) { for _, notification := range notifications { - assert.Equal(t, mcp.MethodNotificationToolsListChanged, notification.Method) + assert.Equal(t, string(mcp.MethodNotificationToolsListChanged), notification.Method) } tools := toolsList.(mcp.JSONRPCResponse).Result.(mcp.ListToolsResult).Tools assert.Len(t, tools, 2) @@ -262,21 +262,21 @@ func TestMCPServer_Tools(t *testing.T) { require.NoError(t, err) server.AddTool( mcp.NewTool("test-tool-1"), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, ) server.AddTool( mcp.NewTool("test-tool-2"), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, ) }, expectedNotifications: 2, validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, toolsList mcp.JSONRPCMessage) { - assert.Equal(t, mcp.MethodNotificationToolsListChanged, notifications[0].Method) - assert.Equal(t, mcp.MethodNotificationToolsListChanged, notifications[1].Method) + assert.Equal(t, string(mcp.MethodNotificationToolsListChanged), notifications[0].Method) + assert.Equal(t, string(mcp.MethodNotificationToolsListChanged), notifications[1].Method) tools := toolsList.(mcp.JSONRPCResponse).Result.(mcp.ListToolsResult).Tools assert.Len(t, tools, 2) assert.Equal(t, "test-tool-1", tools[0].Name) @@ -300,9 +300,9 @@ func TestMCPServer_Tools(t *testing.T) { expectedNotifications: 2, validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, toolsList mcp.JSONRPCMessage) { // One for SetTools - assert.Equal(t, mcp.MethodNotificationToolsListChanged, notifications[0].Method) + assert.Equal(t, string(mcp.MethodNotificationToolsListChanged), notifications[0].Method) // One for DeleteTools - assert.Equal(t, mcp.MethodNotificationToolsListChanged, notifications[1].Method) + assert.Equal(t, string(mcp.MethodNotificationToolsListChanged), notifications[1].Method) // Expect a successful response with an empty list of tools resp, ok := toolsList.(mcp.JSONRPCResponse) @@ -333,7 +333,7 @@ func TestMCPServer_Tools(t *testing.T) { expectedNotifications: 1, validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, toolsList mcp.JSONRPCMessage) { // Only one notification expected for SetTools - assert.Equal(t, mcp.MethodNotificationToolsListChanged, notifications[0].Method) + assert.Equal(t, string(mcp.MethodNotificationToolsListChanged), notifications[0].Method) // Confirm the tool list does not change tools := toolsList.(mcp.JSONRPCResponse).Result.(mcp.ListToolsResult).Tools @@ -706,7 +706,7 @@ func TestMCPServer_PromptHandling(t *testing.T) { server.AddPrompt( testPrompt, - func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + func(ctx context.Context, requestSession RequestSession, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{ Messages: []mcp.PromptMessage{ { @@ -843,9 +843,9 @@ func TestMCPServer_Prompts(t *testing.T) { expectedNotifications: 2, validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, promptsList mcp.JSONRPCMessage) { // One for AddPrompt - assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[0].Method) + assert.Equal(t, string(mcp.MethodNotificationPromptsListChanged), notifications[0].Method) // One for DeletePrompts - assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[1].Method) + assert.Equal(t, string(mcp.MethodNotificationPromptsListChanged), notifications[1].Method) // Expect a successful response with an empty list of prompts resp, ok := promptsList.(mcp.JSONRPCResponse) @@ -898,11 +898,11 @@ func TestMCPServer_Prompts(t *testing.T) { expectedNotifications: 3, validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, promptsList mcp.JSONRPCMessage) { // first notification expected for AddPrompt test-prompt-1 - assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[0].Method) + assert.Equal(t, string(mcp.MethodNotificationPromptsListChanged), notifications[0].Method) // second notification expected for AddPrompt test-prompt-2 - assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[1].Method) + assert.Equal(t, string(mcp.MethodNotificationPromptsListChanged), notifications[1].Method) // second notification expected for DeletePrompts test-prompt-1 - assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[2].Method) + assert.Equal(t, string(mcp.MethodNotificationPromptsListChanged), notifications[2].Method) // Confirm the prompt list does not change prompts := promptsList.(mcp.JSONRPCResponse).Result.(mcp.ListPromptsResult).Prompts @@ -951,9 +951,9 @@ func TestMCPServer_Prompts(t *testing.T) { expectedNotifications: 2, validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, promptsList mcp.JSONRPCMessage) { // first notification expected for AddPrompt test-prompt-1 - assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[0].Method) + assert.Equal(t, string(mcp.MethodNotificationPromptsListChanged), notifications[0].Method) // second notification expected for AddPrompt test-prompt-2 - assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[1].Method) + assert.Equal(t, string(mcp.MethodNotificationPromptsListChanged), notifications[1].Method) // Confirm the prompt list does not change prompts := promptsList.(mcp.JSONRPCResponse).Result.(mcp.ListPromptsResult).Prompts @@ -1121,7 +1121,7 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) { IdempotentHint: mcp.ToBoolPtr(false), OpenWorldHint: mcp.ToBoolPtr(false), }, - }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + }, func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }) @@ -1383,7 +1383,7 @@ func TestMCPServer_ResourceTemplates(t *testing.T) { "test://{a}/test-resource{/b*}", "My Resource", ), - func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, requestSession RequestSession, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { a := request.Params.Arguments["a"].([]string) b := request.Params.Arguments["b"].([]string) // Validate that the template arguments are passed correctly to the handler @@ -1470,7 +1470,7 @@ func createTestServer() *MCPServer { URI: "resource://testresource", Name: "My Resource", }, - func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, requestSession RequestSession, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: "resource://testresource", @@ -1486,7 +1486,7 @@ func createTestServer() *MCPServer { Name: "test-tool", Description: "Test tool", }, - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{ Content: []mcp.Content{ mcp.TextContent{ @@ -1522,6 +1522,16 @@ func (f fakeSession) Initialized() bool { return f.initialized } +func (f fakeSession) SetLogLevel(level mcp.LoggingLevel) {} + +func (f fakeSession) GetLogLevel() mcp.LoggingLevel { + return mcp.LoggingLevelError +} + +func (f fakeSession) GetLoggerName() *string { + return nil +} + var _ ClientSession = fakeSession{} func TestMCPServer_WithHooks(t *testing.T) { @@ -1622,7 +1632,7 @@ func TestMCPServer_WithHooks(t *testing.T) { // Add a test tool server.AddTool( mcp.NewTool("test-tool"), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, ) @@ -1793,7 +1803,7 @@ func TestMCPServer_SessionHooks_NilHooks(t *testing.T) { } func TestMCPServer_WithRecover(t *testing.T) { - panicToolHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + panicToolHandler := func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { panic("test panic") } diff --git a/server/session.go b/server/session.go index a79da22ca..340e07772 100644 --- a/server/session.go +++ b/server/session.go @@ -17,15 +17,12 @@ type ClientSession interface { NotificationChannel() chan<- mcp.JSONRPCNotification // SessionID is a unique identifier used to track user session. SessionID() string -} - -// SessionWithLogging is an extension of ClientSession that can receive log message notifications and set log level -type SessionWithLogging interface { - ClientSession // SetLogLevel sets the minimum log level SetLogLevel(level mcp.LoggingLevel) // GetLogLevel retrieves the minimum log level GetLogLevel() mcp.LoggingLevel + // GetLoggerName returns the logger name + GetLoggerName() *string } // SessionWithTools is an extension of ClientSession that can store session-specific tool data diff --git a/server/session_test.go b/server/session_test.go index 3067f4e9c..1eb5e7ca9 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -42,6 +42,17 @@ func (f sessionTestClient) Initialized() bool { return f.initialized } +func (f sessionTestClient) SetLogLevel(level mcp.LoggingLevel) { +} + +func (f sessionTestClient) GetLogLevel() mcp.LoggingLevel { + return mcp.LoggingLevelError +} + +func (f sessionTestClient) GetLoggerName() *string { + return nil +} + // sessionTestClientWithTools implements the SessionWithTools interface for testing type sessionTestClientWithTools struct { sessionID string @@ -67,6 +78,17 @@ func (f *sessionTestClientWithTools) Initialized() bool { return f.initialized } +func (f *sessionTestClientWithTools) SetLogLevel(level mcp.LoggingLevel) { +} + +func (f *sessionTestClientWithTools) GetLogLevel() mcp.LoggingLevel { + return mcp.LoggingLevelError +} + +func (f *sessionTestClientWithTools) GetLoggerName() *string { + return nil +} + func (f *sessionTestClientWithTools) GetSessionTools() map[string]ServerTool { f.mu.RLock() defer f.mu.RUnlock() @@ -137,6 +159,17 @@ func (f *sessionTestClientWithClientInfo) SetClientInfo(clientInfo mcp.Implement f.clientInfo.Store(clientInfo) } +func (f *sessionTestClientWithClientInfo) SetLogLevel(level mcp.LoggingLevel) { +} + +func (f *sessionTestClientWithClientInfo) GetLogLevel() mcp.LoggingLevel { + return mcp.LoggingLevelError +} + +func (f *sessionTestClientWithClientInfo) GetLoggerName() *string { + return nil +} + // sessionTestClientWithTools implements the SessionWithLogging interface for testing type sessionTestClientWithLogging struct { sessionID string @@ -172,11 +205,14 @@ func (f *sessionTestClientWithLogging) GetLogLevel() mcp.LoggingLevel { return level.(mcp.LoggingLevel) } +func (f *sessionTestClientWithLogging) GetLoggerName() *string { + return nil +} + // Verify that all implementations satisfy their respective interfaces var ( _ ClientSession = (*sessionTestClient)(nil) _ SessionWithTools = (*sessionTestClientWithTools)(nil) - _ SessionWithLogging = (*sessionTestClientWithLogging)(nil) _ SessionWithClientInfo = (*sessionTestClientWithClientInfo)(nil) ) @@ -186,7 +222,7 @@ func TestSessionWithTools_Integration(t *testing.T) { // Create session-specific tools sessionTool := ServerTool{ Tool: mcp.NewTool("session-tool"), - Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText("session-tool result"), nil }, } @@ -235,7 +271,7 @@ func TestSessionWithTools_Integration(t *testing.T) { require.NotNil(t, tool, "Session tool should not be nil") // Now test calling directly with the handler - result, err := tool.Handler(sessionCtx, testReq) + result, err := tool.Handler(sessionCtx, RequestSession{}, testReq) require.NoError(t, err, "No error calling session tool handler directly") require.NotNil(t, result, "Result should not be nil") require.Len(t, result.Content, 1, "Result should have one content item") @@ -355,7 +391,7 @@ func TestMCPServer_AddSessionTool(t *testing.T) { err = server.AddSessionTool( session.SessionID(), mcp.NewTool("session-tool-helper"), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText("helper result"), nil }, ) @@ -560,7 +596,7 @@ func TestMCPServer_CallSessionTool(t *testing.T) { server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) // Add global tool - server.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + server.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText("global result"), nil }) @@ -580,7 +616,7 @@ func TestMCPServer_CallSessionTool(t *testing.T) { err = server.AddSessionTool( session.SessionID(), mcp.NewTool("test_tool"), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText("session result"), nil }, ) diff --git a/server/sse.go b/server/sse.go index 416995730..194de674e 100644 --- a/server/sse.go +++ b/server/sse.go @@ -28,8 +28,10 @@ type sseSession struct { notificationChannel chan mcp.JSONRPCNotification initialized atomic.Bool loggingLevel atomic.Value - tools sync.Map // stores session-specific tools - clientInfo atomic.Value // stores session-specific client info + // FIXME assign logger name in a proper way + loggerName *string + tools sync.Map // stores session-specific tools + clientInfo atomic.Value // stores session-specific client info } // SSEContextFunc is a function that takes an existing context and the current @@ -74,6 +76,10 @@ func (s *sseSession) GetLogLevel() mcp.LoggingLevel { return level.(mcp.LoggingLevel) } +func (s *sseSession) GetLoggerName() *string { + return s.loggerName +} + func (s *sseSession) GetSessionTools() map[string]ServerTool { tools := make(map[string]ServerTool) s.tools.Range(func(key, value any) bool { @@ -111,7 +117,6 @@ func (s *sseSession) SetClientInfo(clientInfo mcp.Implementation) { var ( _ ClientSession = (*sseSession)(nil) _ SessionWithTools = (*sseSession)(nil) - _ SessionWithLogging = (*sseSession)(nil) _ SessionWithClientInfo = (*sseSession)(nil) ) @@ -516,7 +521,8 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { var message string if eventData, err := json.Marshal(response); err != nil { // If there is an error marshalling the response, send a generic error response - log.Printf("failed to marshal response: %v", err) + marshal, _ := json.Marshal(response) + log.Printf("failed to marshal response: %v, response %s", err, string(marshal)) message = "event: message\ndata: {\"error\": \"internal error\",\"jsonrpc\": \"2.0\", \"id\": null}\n\n" } else { message = fmt.Sprintf("event: message\ndata: %s\n\n", eventData) diff --git a/server/sse_test.go b/server/sse_test.go index 96912be49..c9789f121 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -574,7 +574,7 @@ func TestSSEServer(t *testing.T) { WithResourceCapabilities(true, true), ) // Add a tool which uses the context function. - mcpServer.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + mcpServer.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Note this is agnostic to the transport type i.e. doesn't know about request headers. testVal := testValFromContext(ctx) return mcp.NewToolResultText(testVal), nil @@ -1183,7 +1183,7 @@ func TestSSEServer(t *testing.T) { Title: "Test Tool", }, }, - Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText("test"), nil }, }, @@ -1335,7 +1335,7 @@ func TestSSEServer(t *testing.T) { processingCompleted := make(chan struct{}) processingStarted := make(chan struct{}) - mcpServer.AddTool(mcp.NewTool("slowMethod"), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + mcpServer.AddTool(mcp.NewTool("slowMethod"), func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { close(processingStarted) // signal for processing started select { diff --git a/server/stdio.go b/server/stdio.go index dabe9c160..6b18d4761 100644 --- a/server/stdio.go +++ b/server/stdio.go @@ -54,7 +54,9 @@ type stdioSession struct { notifications chan mcp.JSONRPCNotification initialized atomic.Bool loggingLevel atomic.Value - clientInfo atomic.Value // stores session-specific client info + // FIXME assign logger name in a proper way + loggerName *string + clientInfo atomic.Value // stores session-specific client info } func (s *stdioSession) SessionID() string { @@ -100,9 +102,12 @@ func (s *stdioSession) GetLogLevel() mcp.LoggingLevel { return level.(mcp.LoggingLevel) } +func (f *stdioSession) GetLoggerName() *string { + return f.loggerName +} + var ( _ ClientSession = (*stdioSession)(nil) - _ SessionWithLogging = (*stdioSession)(nil) _ SessionWithClientInfo = (*stdioSession)(nil) ) diff --git a/server/stdio_test.go b/server/stdio_test.go index 8433fd0ac..477a799b9 100644 --- a/server/stdio_test.go +++ b/server/stdio_test.go @@ -140,7 +140,7 @@ func TestStdioServer(t *testing.T) { // Create server mcpServer := NewMCPServer("test", "1.0.0") // Add a tool which uses the context function. - mcpServer.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + mcpServer.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Note this is agnostic to the transport type i.e. doesn't know about request headers. testVal := testValFromContext(ctx) return mcp.NewToolResultText(testVal), nil diff --git a/server/streamable_http.go b/server/streamable_http.go index fecc63d76..f92ccfb5a 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -540,7 +540,10 @@ type streamableHttpSession struct { sessionID string notificationChannel chan mcp.JSONRPCNotification // server -> client notifications tools *sessionToolsStore - upgradeToSSE atomic.Bool + loggingLevel atomic.Value + // FIXME assign logger name in a proper way + loggerName *string + upgradeToSSE atomic.Bool } func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore) *streamableHttpSession { @@ -569,6 +572,22 @@ func (s *streamableHttpSession) Initialized() bool { return true } +func (s *streamableHttpSession) SetLogLevel(level mcp.LoggingLevel) { + s.loggingLevel.Store(level) +} + +func (s *streamableHttpSession) GetLogLevel() mcp.LoggingLevel { + level := s.loggingLevel.Load() + if level == nil { + return mcp.LoggingLevelError + } + return level.(mcp.LoggingLevel) +} + +func (s *streamableHttpSession) GetLoggerName() *string { + return s.loggerName +} + var _ ClientSession = (*streamableHttpSession)(nil) func (s *streamableHttpSession) GetSessionTools() map[string]ServerTool { diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index 7474464f8..ad69e0b35 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -39,7 +39,7 @@ var initRequest = map[string]any{ func addSSETool(mcpServer *MCPServer) { mcpServer.AddTool(mcp.Tool{ Name: "sseTool", - }, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + }, func(ctx context.Context, requestSession RequestSession, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Send notification to client server := ServerFromContext(ctx) for i := 0; i < 10; i++ { @@ -601,7 +601,7 @@ func TestStreamableHTTP_SessionWithTools(t *testing.T) { Title: "Test Tool", }, }, - Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText("test"), nil }, }, diff --git a/server/typed_tools_handler_func.go b/server/typed_tools_handler_func.go new file mode 100644 index 000000000..c58c95775 --- /dev/null +++ b/server/typed_tools_handler_func.go @@ -0,0 +1,21 @@ +package server + +import ( + "context" + "fmt" + "github.com/mark3labs/mcp-go/mcp" +) + +// TypedToolHandlerFunc is a function that handles a tool call with typed arguments +type TypedToolHandlerFunc[T any] func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest, args T) (*mcp.CallToolResult, error) + +// NewTypedToolHandler creates a ToolHandlerFunc that automatically binds arguments to a typed struct +func NewTypedToolHandler[T any](handler TypedToolHandlerFunc[T]) func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + var args T + if err := request.BindArguments(&args); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to bind arguments: %v", err)), nil + } + return handler(ctx, requestSession, request, args) + } +} diff --git a/mcp/typed_tools_test.go b/server/typed_tools_handler_func_test.go similarity index 66% rename from mcp/typed_tools_test.go rename to server/typed_tools_handler_func_test.go index d78d47028..32873140e 100644 --- a/mcp/typed_tools_test.go +++ b/server/typed_tools_handler_func_test.go @@ -1,9 +1,10 @@ -package mcp +package server import ( "context" "encoding/json" "fmt" + "github.com/mark3labs/mcp-go/mcp" "testing" "github.com/stretchr/testify/assert" @@ -18,15 +19,15 @@ func TestTypedToolHandler(t *testing.T) { } // Create a typed handler function - typedHandler := func(ctx context.Context, request CallToolRequest, args HelloArgs) (*CallToolResult, error) { - return NewToolResultText(args.Name), nil + typedHandler := func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest, args HelloArgs) (*mcp.CallToolResult, error) { + return mcp.NewToolResultText(args.Name), nil } // Create a wrapped handler wrappedHandler := NewTypedToolHandler(typedHandler) // Create a test request - req := CallToolRequest{} + req := mcp.CallToolRequest{} req.Params.Name = "test-tool" req.Params.Arguments = map[string]any{ "name": "John Doe", @@ -35,12 +36,12 @@ func TestTypedToolHandler(t *testing.T) { } // Call the wrapped handler - result, err := wrappedHandler(context.Background(), req) + result, err := wrappedHandler(context.Background(), RequestSession{}, req) // Verify results assert.NoError(t, err) assert.NotNil(t, result) - assert.Equal(t, "John Doe", result.Content[0].(TextContent).Text) + assert.Equal(t, "John Doe", result.Content[0].(mcp.TextContent).Text) // Test with invalid arguments req.Params.Arguments = map[string]any{ @@ -50,7 +51,7 @@ func TestTypedToolHandler(t *testing.T) { } // This should still work because of type conversion - result, err = wrappedHandler(context.Background(), req) + result, err = wrappedHandler(context.Background(), RequestSession{}, req) assert.NoError(t, err) assert.NotNil(t, result) @@ -62,14 +63,14 @@ func TestTypedToolHandler(t *testing.T) { } // This should still work but name will be empty - result, err = wrappedHandler(context.Background(), req) + result, err = wrappedHandler(context.Background(), RequestSession{}, req) assert.NoError(t, err) assert.NotNil(t, result) - assert.Equal(t, "", result.Content[0].(TextContent).Text) + assert.Equal(t, "", result.Content[0].(mcp.TextContent).Text) // Test with completely invalid arguments req.Params.Arguments = "not a map" - result, err = wrappedHandler(context.Background(), req) + result, err = wrappedHandler(context.Background(), RequestSession{}, req) assert.NoError(t, err) // Error is wrapped in the result assert.NotNil(t, result) assert.True(t, result.IsError) @@ -84,10 +85,10 @@ func TestTypedToolHandlerWithValidation(t *testing.T) { } // Create a typed handler function with validation - typedHandler := func(ctx context.Context, request CallToolRequest, args CalculatorArgs) (*CallToolResult, error) { + typedHandler := func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest, args CalculatorArgs) (*mcp.CallToolResult, error) { // Validate operation if args.Operation == "" { - return NewToolResultError("operation is required"), nil + return mcp.NewToolResultError("operation is required"), nil } var result float64 @@ -100,21 +101,21 @@ func TestTypedToolHandlerWithValidation(t *testing.T) { result = args.X * args.Y case "divide": if args.Y == 0 { - return NewToolResultError("division by zero"), nil + return mcp.NewToolResultError("division by zero"), nil } result = args.X / args.Y default: - return NewToolResultError("invalid operation"), nil + return mcp.NewToolResultError("invalid operation"), nil } - return NewToolResultText(fmt.Sprintf("%.0f", result)), nil + return mcp.NewToolResultText(fmt.Sprintf("%.0f", result)), nil } // Create a wrapped handler wrappedHandler := NewTypedToolHandler(typedHandler) // Create a test request - req := CallToolRequest{} + req := mcp.CallToolRequest{} req.Params.Name = "calculator" req.Params.Arguments = map[string]any{ "operation": "add", @@ -123,12 +124,12 @@ func TestTypedToolHandlerWithValidation(t *testing.T) { } // Call the wrapped handler - result, err := wrappedHandler(context.Background(), req) + result, err := wrappedHandler(context.Background(), RequestSession{}, req) // Verify results assert.NoError(t, err) assert.NotNil(t, result) - assert.Equal(t, "16", result.Content[0].(TextContent).Text) + assert.Equal(t, "16", result.Content[0].(mcp.TextContent).Text) // Test division by zero req.Params.Arguments = map[string]any{ @@ -137,11 +138,11 @@ func TestTypedToolHandlerWithValidation(t *testing.T) { "y": 0.0, } - result, err = wrappedHandler(context.Background(), req) + result, err = wrappedHandler(context.Background(), RequestSession{}, req) assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) - assert.Contains(t, result.Content[0].(TextContent).Text, "division by zero") + assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "division by zero") } func TestTypedToolHandlerWithComplexObjects(t *testing.T) { @@ -170,13 +171,13 @@ func TestTypedToolHandlerWithComplexObjects(t *testing.T) { } // Create a typed handler function - typedHandler := func(ctx context.Context, request CallToolRequest, profile UserProfile) (*CallToolResult, error) { + typedHandler := func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest, profile UserProfile) (*mcp.CallToolResult, error) { // Validate required fields if profile.Name == "" { - return NewToolResultError("name is required"), nil + return mcp.NewToolResultError("name is required"), nil } if profile.Email == "" { - return NewToolResultError("email is required"), nil + return mcp.NewToolResultError("email is required"), nil } // Build a response that includes nested object data @@ -210,14 +211,14 @@ func TestTypedToolHandlerWithComplexObjects(t *testing.T) { response += fmt.Sprintf(", Tags: %v", profile.Tags) } - return NewToolResultText(response), nil + return mcp.NewToolResultText(response), nil } // Create a wrapped handler wrappedHandler := NewTypedToolHandler(typedHandler) // Test with complete complex object - req := CallToolRequest{} + req := mcp.CallToolRequest{} req.Params.Name = "user_profile" req.Params.Arguments = map[string]any{ "name": "John Doe", @@ -239,16 +240,16 @@ func TestTypedToolHandlerWithComplexObjects(t *testing.T) { } // Call the wrapped handler - result, err := wrappedHandler(context.Background(), req) + result, err := wrappedHandler(context.Background(), RequestSession{}, req) // Verify results assert.NoError(t, err) assert.NotNil(t, result) - assert.Contains(t, result.Content[0].(TextContent).Text, "John Doe") - assert.Contains(t, result.Content[0].(TextContent).Text, "San Francisco, USA") - assert.Contains(t, result.Content[0].(TextContent).Text, "Theme: dark") - assert.Contains(t, result.Content[0].(TextContent).Text, "Subscribed to 2 newsletters") - assert.Contains(t, result.Content[0].(TextContent).Text, "Tags: [premium early_adopter]") + assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "John Doe") + assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "San Francisco, USA") + assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "Theme: dark") + assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "Subscribed to 2 newsletters") + assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "Tags: [premium early_adopter]") // Test with partial data (missing some nested fields) req.Params.Arguments = map[string]any{ @@ -265,13 +266,13 @@ func TestTypedToolHandlerWithComplexObjects(t *testing.T) { }, } - result, err = wrappedHandler(context.Background(), req) + result, err = wrappedHandler(context.Background(), RequestSession{}, req) assert.NoError(t, err) assert.NotNil(t, result) - assert.Contains(t, result.Content[0].(TextContent).Text, "Jane Smith") - assert.Contains(t, result.Content[0].(TextContent).Text, "London, UK") - assert.Contains(t, result.Content[0].(TextContent).Text, "Theme: light") - assert.NotContains(t, result.Content[0].(TextContent).Text, "newsletters") + assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "Jane Smith") + assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "London, UK") + assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "Theme: light") + assert.NotContains(t, result.Content[0].(mcp.TextContent).Text, "newsletters") // Test with JSON string input (simulating raw JSON from client) jsonInput := `{ @@ -294,11 +295,11 @@ func TestTypedToolHandlerWithComplexObjects(t *testing.T) { }` req.Params.Arguments = json.RawMessage(jsonInput) - result, err = wrappedHandler(context.Background(), req) + result, err = wrappedHandler(context.Background(), RequestSession{}, req) assert.NoError(t, err) assert.NotNil(t, result) - assert.Contains(t, result.Content[0].(TextContent).Text, "Bob Johnson") - assert.Contains(t, result.Content[0].(TextContent).Text, "New York, USA") - assert.Contains(t, result.Content[0].(TextContent).Text, "Theme: system") - assert.Contains(t, result.Content[0].(TextContent).Text, "Subscribed to 1 newsletters") + assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "Bob Johnson") + assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "New York, USA") + assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "Theme: system") + assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "Subscribed to 1 newsletters") }