From 7e2dd543c02e67e7415151b089e2a7d668ded47e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=9C=E8=89=AE=E9=AD=81?= Date: Sun, 25 May 2025 02:46:32 +0800 Subject: [PATCH 01/15] add RequestContext to support logging, progress, cancellation notification --- client/inprocess_test.go | 6 +- client/sse_test.go | 4 +- examples/custom_context/main.go | 1 + examples/dynamic_path/main.go | 2 +- examples/everything/main.go | 10 ++ examples/typed_tools/main.go | 16 +-- mcp/typed_tools.go | 20 --- mcp/types.go | 45 ++++++- mcp/utils.go | 2 +- mcptest/mcptest_test.go | 2 +- server/request_context.go | 118 ++++++++++++++++++ server/request_handler.go | 2 + server/resource_test.go | 10 +- server/server.go | 50 ++++---- server/server_race_test.go | 20 +-- server/server_test.go | 68 +++++----- server/session.go | 7 +- server/session_test.go | 45 +++++-- server/sse.go | 13 +- server/sse_test.go | 6 +- server/stdio.go | 19 +-- server/stdio_test.go | 2 +- server/streamable_http.go | 20 +++ server/streamable_http_test.go | 4 +- server/typed_tools_handler_func.go | 21 ++++ .../typed_tools_handler_func_test.go | 111 ++++++++-------- 26 files changed, 425 insertions(+), 199 deletions(-) delete mode 100644 mcp/typed_tools.go create mode 100644 server/request_context.go create mode 100644 server/typed_tools_handler_func.go rename mcp/typed_tools_test.go => server/typed_tools_handler_func_test.go (63%) diff --git a/client/inprocess_test.go b/client/inprocess_test.go index 7b150e81e..cdf426ccb 100644 --- a/client/inprocess_test.go +++ b/client/inprocess_test.go @@ -27,7 +27,7 @@ func TestInProcessMCPClient(t *testing.T) { mcp.WithDestructiveHintAnnotation(false), mcp.WithIdempotentHintAnnotation(true), mcp.WithOpenWorldHintAnnotation(false), - ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ), func(ctx context.Context, reqContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{ Content: []mcp.Content{ mcp.TextContent{ @@ -48,7 +48,7 @@ func TestInProcessMCPClient(t *testing.T) { URI: "resource://testresource", Name: "My Resource", }, - func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, reqContext server.RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: "resource://testresource", @@ -70,7 +70,7 @@ func TestInProcessMCPClient(t *testing.T) { }, }, }, - func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + func(ctx context.Context, reqContext server.RequestContext, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{ Messages: []mcp.PromptMessage{ { diff --git a/client/sse_test.go b/client/sse_test.go index f38c31b17..3a61de638 100644 --- a/client/sse_test.go +++ b/client/sse_test.go @@ -39,7 +39,7 @@ func TestSSEMCPClient(t *testing.T) { mcp.WithDestructiveHintAnnotation(false), mcp.WithIdempotentHintAnnotation(true), mcp.WithOpenWorldHintAnnotation(false), - ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ), func(ctx context.Context, reqContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{ Content: []mcp.Content{ mcp.TextContent{ @@ -52,7 +52,7 @@ func TestSSEMCPClient(t *testing.T) { mcpServer.AddTool(mcp.NewTool( "test-tool-for-http-header", mcp.WithDescription("Test tool for http header"), - ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ), func(ctx context.Context, reqContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // , X-Test-Header-Func return &mcp.CallToolResult{ Content: []mcp.Content{ diff --git a/examples/custom_context/main.go b/examples/custom_context/main.go index e41ab8db7..1232c6676 100644 --- a/examples/custom_context/main.go +++ b/examples/custom_context/main.go @@ -79,6 +79,7 @@ func makeRequest(ctx context.Context, message, token string) (*response, error) // using the token from the context. func handleMakeAuthenticatedRequestTool( ctx context.Context, + reqContext server.RequestContext, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { message, ok := request.GetArguments()["message"].(string) diff --git a/examples/dynamic_path/main.go b/examples/dynamic_path/main.go index 80d96789a..52da36756 100644 --- a/examples/dynamic_path/main.go +++ b/examples/dynamic_path/main.go @@ -19,7 +19,7 @@ func main() { mcpServer := server.NewMCPServer("dynamic-path-example", "1.0.0") // Add a trivial tool for demonstration - mcpServer.AddTool(mcp.NewTool("echo"), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + mcpServer.AddTool(mcp.NewTool("echo"), func(ctx context.Context, reqContext server.RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText(fmt.Sprintf("Echo: %v", req.GetArguments()["message"])), nil }) diff --git a/examples/everything/main.go b/examples/everything/main.go index 5489220c3..a4bcc3ffe 100644 --- a/examples/everything/main.go +++ b/examples/everything/main.go @@ -191,6 +191,7 @@ func generateResources() []mcp.Resource { func handleReadResource( ctx context.Context, + reqContext server.RequestContext, request mcp.ReadResourceRequest, ) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ @@ -204,6 +205,7 @@ func handleReadResource( func handleResourceTemplate( ctx context.Context, + reqContext server.RequestContext, request mcp.ReadResourceRequest, ) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ @@ -217,6 +219,7 @@ func handleResourceTemplate( func handleGeneratedResource( ctx context.Context, + reqContext server.RequestContext, request mcp.ReadResourceRequest, ) ([]mcp.ResourceContents, error) { uri := request.Params.URI @@ -254,6 +257,7 @@ func handleGeneratedResource( func handleSimplePrompt( ctx context.Context, + reqContext server.RequestContext, request mcp.GetPromptRequest, ) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{ @@ -272,6 +276,7 @@ func handleSimplePrompt( func handleComplexPrompt( ctx context.Context, + reqContext server.RequestContext, request mcp.GetPromptRequest, ) (*mcp.GetPromptResult, error) { arguments := request.Params.Arguments @@ -310,6 +315,7 @@ func handleComplexPrompt( func handleEchoTool( ctx context.Context, + reqContext server.RequestContext, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { arguments := request.GetArguments() @@ -329,6 +335,7 @@ func handleEchoTool( func handleAddTool( ctx context.Context, + reqContext server.RequestContext, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { arguments := request.GetArguments() @@ -350,6 +357,7 @@ func handleAddTool( func handleSendNotification( ctx context.Context, + reqContext server.RequestContext, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { @@ -380,6 +388,7 @@ func handleSendNotification( func handleLongRunningOperationTool( ctx context.Context, + reqContext server.RequestContext, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { arguments := request.GetArguments() @@ -445,6 +454,7 @@ func handleLongRunningOperationTool( func handleGetTinyImageTool( ctx context.Context, + reqContext server.RequestContext, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{ diff --git a/examples/typed_tools/main.go b/examples/typed_tools/main.go index 5c49fed85..da7f21e38 100644 --- a/examples/typed_tools/main.go +++ b/examples/typed_tools/main.go @@ -64,7 +64,7 @@ func main() { ) // Add tool handler using the typed handler - s.AddTool(tool, mcp.NewTypedToolHandler(typedGreetingHandler)) + s.AddTool(tool, server.NewTypedToolHandler(typedGreetingHandler)) // Start the stdio server if err := server.ServeStdio(s); err != nil { @@ -73,33 +73,33 @@ func main() { } // Our typed handler function that receives strongly-typed arguments -func typedGreetingHandler(ctx context.Context, request mcp.CallToolRequest, args GreetingArgs) (*mcp.CallToolResult, error) { +func typedGreetingHandler(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest, args GreetingArgs) (*mcp.CallToolResult, error) { if args.Name == "" { return mcp.NewToolResultError("name is required"), nil } // Build a personalized greeting based on the complex arguments greeting := fmt.Sprintf("Hello, %s!", args.Name) - + if args.Age > 0 { greeting += fmt.Sprintf(" You are %d years old.", args.Age) } - + if args.IsVIP { greeting += " Welcome back, valued VIP customer!" } - + if len(args.Languages) > 0 { greeting += fmt.Sprintf(" You speak %d languages: %v.", len(args.Languages), args.Languages) } - + if args.Metadata.Location != "" { greeting += fmt.Sprintf(" I see you're from %s.", args.Metadata.Location) - + if args.Metadata.Timezone != "" { greeting += fmt.Sprintf(" Your timezone is %s.", args.Metadata.Timezone) } } return mcp.NewToolResultText(greeting), nil -} \ No newline at end of file +} diff --git a/mcp/typed_tools.go b/mcp/typed_tools.go deleted file mode 100644 index 68d8cdd1f..000000000 --- a/mcp/typed_tools.go +++ /dev/null @@ -1,20 +0,0 @@ -package mcp - -import ( - "context" - "fmt" -) - -// TypedToolHandlerFunc is a function that handles a tool call with typed arguments -type TypedToolHandlerFunc[T any] func(ctx context.Context, request CallToolRequest, args T) (*CallToolResult, error) - -// NewTypedToolHandler creates a ToolHandlerFunc that automatically binds arguments to a typed struct -func NewTypedToolHandler[T any](handler TypedToolHandlerFunc[T]) func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) { - return func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) { - var args T - if err := request.BindArguments(&args); err != nil { - return NewToolResultError(fmt.Sprintf("failed to bind arguments: %v", err)), nil - } - return handler(ctx, request, args) - } -} diff --git a/mcp/types.go b/mcp/types.go index 9bbf16f27..c6ca3ff6f 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -5,8 +5,8 @@ package mcp import ( "encoding/json" "fmt" - "strconv" "maps" + "strconv" "github.com/yosida95/uritemplate/v3" ) @@ -56,17 +56,29 @@ const ( // MethodNotificationResourcesListChanged notifies when the list of available resources changes. // https://modelcontextprotocol.io/specification/2025-03-26/server/resources#list-changed-notification - MethodNotificationResourcesListChanged = "notifications/resources/list_changed" + MethodNotificationResourcesListChanged MCPMethod = "notifications/resources/list_changed" - MethodNotificationResourceUpdated = "notifications/resources/updated" + MethodNotificationResourceUpdated MCPMethod = "notifications/resources/updated" // MethodNotificationPromptsListChanged notifies when the list of available prompt templates changes. // https://modelcontextprotocol.io/specification/2025-03-26/server/prompts#list-changed-notification - MethodNotificationPromptsListChanged = "notifications/prompts/list_changed" + MethodNotificationPromptsListChanged MCPMethod = "notifications/prompts/list_changed" // MethodNotificationToolsListChanged notifies when the list of available tools changes. // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/list_changed/ - MethodNotificationToolsListChanged = "notifications/tools/list_changed" + MethodNotificationToolsListChanged MCPMethod = "notifications/tools/list_changed" + + // MethodNotificationMessage notifies when severs send log messages. + // https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging#log-message-notifications + MethodNotificationMessage MCPMethod = "notifications/message" + + // MethodNotificationProgress notifies progress updates for long-running operations. + // https://modelcontextprotocol.io/specification/2025-03-26/basic/utilities/progress + MethodNotificationProgress MCPMethod = "notifications/progress" + + // MethodNotificationCancellation cancel a previously-issued request + // https://modelcontextprotocol.io/specification/2025-03-26/basic/utilities/cancellation + MethodNotificationCancellation MCPMethod = "notifications/cancelled" ) type URITemplate struct { @@ -734,6 +746,29 @@ const ( LoggingLevelEmergency LoggingLevel = "emergency" ) +var ( + levelToSeverity = func() map[LoggingLevel]int { + return map[LoggingLevel]int{ + LoggingLevelEmergency: 0, + LoggingLevelAlert: 1, + LoggingLevelCritical: 2, + LoggingLevelError: 3, + LoggingLevelWarning: 4, + LoggingLevelNotice: 5, + LoggingLevelInfo: 6, + LoggingLevelDebug: 7, + } + }() +) + +// Allows is a helper function that decides a message could be sent to client or not according to the logging level +func (subscribedLevel LoggingLevel) Allows(currentLevel LoggingLevel) (bool, error) { + if _, ok := levelToSeverity[currentLevel]; !ok { + return false, fmt.Errorf("illegal message logging level:%s", currentLevel) + } + return levelToSeverity[subscribedLevel] >= levelToSeverity[currentLevel], nil +} + /* Sampling */ // CreateMessageRequest is a request from the server to sample an LLM via the diff --git a/mcp/utils.go b/mcp/utils.go index d6d42b7e4..d16850d11 100644 --- a/mcp/utils.go +++ b/mcp/utils.go @@ -141,7 +141,7 @@ func NewProgressNotification( ) ProgressNotification { notification := ProgressNotification{ Notification: Notification{ - Method: "notifications/progress", + Method: string(MethodNotificationProgress), }, Params: struct { ProgressToken ProgressToken `json:"progressToken"` diff --git a/mcptest/mcptest_test.go b/mcptest/mcptest_test.go index b71b93eaa..5e44f0f41 100644 --- a/mcptest/mcptest_test.go +++ b/mcptest/mcptest_test.go @@ -50,7 +50,7 @@ func TestServer(t *testing.T) { } } -func helloWorldHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func helloWorldHandler(ctx context.Context, reqContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Extract name from request arguments name, ok := request.GetArguments()["name"].(string) if !ok { diff --git a/server/request_context.go b/server/request_context.go new file mode 100644 index 000000000..bc983d043 --- /dev/null +++ b/server/request_context.go @@ -0,0 +1,118 @@ +package server + +import ( + "context" + "fmt" + "github.com/mark3labs/mcp-go/mcp" +) + +type requestIDKey struct{} + +// RequestContext represents an exchange with MCP client. The exchange provides +// methods to interact with the client and query its capabilities. +type RequestContext struct { + mcpServer *MCPServer + + progressToken *mcp.ProgressToken +} + +func NewRequestContext(mcpServer *MCPServer, requestParamMeta *mcp.Meta) RequestContext { + requestContext := RequestContext{ + mcpServer: mcpServer, + } + + // server should send progress notification if request metadata includes a progressToken + if requestParamMeta != nil && requestParamMeta.ProgressToken != nil { + requestContext.progressToken = &requestParamMeta.ProgressToken + } + + return requestContext +} + +// IsLoggingNotificationSupported returns true if server supports logging notification +func (exchange *RequestContext) IsLoggingNotificationSupported() bool { + return exchange.mcpServer != nil && exchange.mcpServer.capabilities.logging != nil && *exchange.mcpServer.capabilities.logging +} + +// SendLoggingNotification send logging notification to client. +// If server does not support logging notification, this method will do nothing. +func (exchange *RequestContext) SendLoggingNotification(ctx context.Context, level mcp.LoggingLevel, message map[string]any) error { + if !exchange.IsLoggingNotificationSupported() { + return nil + } + + clientLogLevel := ClientSessionFromContext(ctx).GetLogLevel() + allowed, err := clientLogLevel.Allows(level) + if err != nil { + return err + } + if !allowed { + return nil + } + + params := map[string]any{ + "level": level, + "message": message, + } + if ClientSessionFromContext(ctx).GetLoggerName() != nil { + params["logger"] = *ClientSessionFromContext(ctx).GetLoggerName() + } + + return exchange.mcpServer.SendNotificationToClient( + ctx, + string(mcp.MethodNotificationMessage), + params, + ) +} + +// SendProgressNotification send progress notification only if the client has requested progress +func (exchange *RequestContext) SendProgressNotification(ctx context.Context, progress, total *float64, message *string) error { + if exchange.progressToken == nil { + return nil + } + + params := map[string]any{ + "progress": progress, + "progressToken": *exchange.progressToken, + } + if total != nil { + params["total"] = *total + } + if message != nil { + params["message"] = *message + } + + return exchange.mcpServer.SendNotificationToClient( + ctx, + string(mcp.MethodNotificationProgress), + params, + ) +} + +// SendCancellationNotification send cancellation notification to client +func (exchange *RequestContext) SendCancellationNotification(ctx context.Context, reason *string) error { + requestIDRawValue := ctx.Value(requestIDKey{}) + if requestIDRawValue == nil { + return fmt.Errorf("invalid requestID") + } + + requestID, ok := requestIDRawValue.(mcp.RequestId) + if !ok { + return fmt.Errorf("invalid requestID type") + } + + params := map[string]any{ + "requestId": requestID.Value(), + } + if reason != nil { + params["reason"] = reason + } + + return exchange.mcpServer.SendNotificationToClient( + ctx, + string(mcp.MethodNotificationCancellation), + params, + ) +} + +// TODO should implement other methods like 'roots/list', this still could happen when server handle client request diff --git a/server/request_handler.go b/server/request_handler.go index 25f6ef14f..519799b84 100644 --- a/server/request_handler.go +++ b/server/request_handler.go @@ -54,6 +54,8 @@ func (s *MCPServer) HandleMessage( } s.handleNotification(ctx, notification) return nil // Return nil for notifications + } else { + ctx = context.WithValue(ctx, requestIDKey{}, mcp.NewRequestId(baseMessage.ID)) } if baseMessage.Result != nil { diff --git a/server/resource_test.go b/server/resource_test.go index 05a3b2793..e32cb2675 100644 --- a/server/resource_test.go +++ b/server/resource_test.go @@ -28,7 +28,7 @@ func TestMCPServer_RemoveResource(t *testing.T) { mcp.WithResourceDescription("Test resource 1"), mcp.WithMIMEType("text/plain"), ), - func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, reqContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: "test://resource1", @@ -47,7 +47,7 @@ func TestMCPServer_RemoveResource(t *testing.T) { mcp.WithResourceDescription("Test resource 2"), mcp.WithMIMEType("text/plain"), ), - func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, reqContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: "test://resource2", @@ -84,7 +84,7 @@ func TestMCPServer_RemoveResource(t *testing.T) { expectedNotifications: 1, validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, resourcesList mcp.JSONRPCMessage) { // Check that we received a list_changed notification - assert.Equal(t, mcp.MethodNotificationResourcesListChanged, notifications[0].Method) + assert.Equal(t, string(mcp.MethodNotificationResourcesListChanged), notifications[0].Method) // Verify we now have only one resource resp, ok := resourcesList.(mcp.JSONRPCResponse) @@ -108,7 +108,7 @@ func TestMCPServer_RemoveResource(t *testing.T) { mcp.WithResourceDescription("Test resource 1"), mcp.WithMIMEType("text/plain"), ), - func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, reqContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: "test://resource1", @@ -164,7 +164,7 @@ func TestMCPServer_RemoveResource(t *testing.T) { mcp.WithResourceDescription("Test resource 1"), mcp.WithMIMEType("text/plain"), ), - func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, reqContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: "test://resource1", diff --git a/server/server.go b/server/server.go index b913b1f70..ac796502f 100644 --- a/server/server.go +++ b/server/server.go @@ -28,16 +28,16 @@ type resourceTemplateEntry struct { type ServerOption func(*MCPServer) // ResourceHandlerFunc is a function that returns resource contents. -type ResourceHandlerFunc func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) +type ResourceHandlerFunc func(ctx context.Context, requestContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) // ResourceTemplateHandlerFunc is a function that returns a resource template. -type ResourceTemplateHandlerFunc func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) +type ResourceTemplateHandlerFunc func(ctx context.Context, requestContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) // PromptHandlerFunc handles prompt requests with given arguments. -type PromptHandlerFunc func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) +type PromptHandlerFunc func(ctx context.Context, requestContext RequestContext, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) // ToolHandlerFunc handles tool calls with given arguments. -type ToolHandlerFunc func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) +type ToolHandlerFunc func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) // ToolHandlerMiddleware is a middleware function that wraps a ToolHandlerFunc. type ToolHandlerMiddleware func(ToolHandlerFunc) ToolHandlerFunc @@ -217,7 +217,7 @@ func WithToolFilter( // WithRecovery adds a middleware that recovers from panics in tool handlers. func WithRecovery() ServerOption { return WithToolHandlerMiddleware(func(next ToolHandlerFunc) ToolHandlerFunc { - return func(ctx context.Context, request mcp.CallToolRequest) (result *mcp.CallToolResult, err error) { + return func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (result *mcp.CallToolResult, err error) { defer func() { if r := recover(); r != nil { err = fmt.Errorf( @@ -227,7 +227,7 @@ func WithRecovery() ServerOption { ) } }() - return next(ctx, request) + return next(ctx, requestContext, request) } }) } @@ -332,7 +332,7 @@ func (s *MCPServer) AddResource( // When the list of available resources changes, servers that declared the listChanged capability SHOULD send a notification if s.capabilities.resources.listChanged { // Send notification to all initialized sessions - s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil) + s.SendNotificationToAllClients(string(mcp.MethodNotificationResourcesListChanged), nil) } } @@ -347,7 +347,7 @@ func (s *MCPServer) RemoveResource(uri string) { // Send notification to all initialized sessions if listChanged capability is enabled and we actually remove a resource if exists && s.capabilities.resources != nil && s.capabilities.resources.listChanged { - s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil) + s.SendNotificationToAllClients(string(mcp.MethodNotificationResourcesListChanged), nil) } } @@ -379,7 +379,7 @@ func (s *MCPServer) AddResourceTemplate( // When the list of available resources changes, servers that declared the listChanged capability SHOULD send a notification if s.capabilities.resources.listChanged { // Send notification to all initialized sessions - s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil) + s.SendNotificationToAllClients(string(mcp.MethodNotificationResourcesListChanged), nil) } } @@ -406,7 +406,7 @@ func (s *MCPServer) AddPrompt(prompt mcp.Prompt, handler PromptHandlerFunc) { // When the list of available prompts changes, servers that declared the listChanged capability SHOULD send a notification. if s.capabilities.prompts.listChanged { // Send notification to all initialized sessions - s.SendNotificationToAllClients(mcp.MethodNotificationPromptsListChanged, nil) + s.SendNotificationToAllClients(string(mcp.MethodNotificationPromptsListChanged), nil) } } @@ -426,7 +426,7 @@ func (s *MCPServer) DeletePrompts(names ...string) { // Send notification to all initialized sessions if listChanged capability is enabled, and we actually remove a prompt if exists && s.capabilities.prompts != nil && s.capabilities.prompts.listChanged { // Send notification to all initialized sessions - s.SendNotificationToAllClients(mcp.MethodNotificationPromptsListChanged, nil) + s.SendNotificationToAllClients(string(mcp.MethodNotificationPromptsListChanged), nil) } } @@ -468,7 +468,7 @@ func (s *MCPServer) AddTools(tools ...ServerTool) { // When the list of available tools changes, servers that declared the listChanged capability SHOULD send a notification. if s.capabilities.tools.listChanged { // Send notification to all initialized sessions - s.SendNotificationToAllClients(mcp.MethodNotificationToolsListChanged, nil) + s.SendNotificationToAllClients(string(mcp.MethodNotificationToolsListChanged), nil) } } @@ -495,7 +495,7 @@ func (s *MCPServer) DeleteTools(names ...string) { // When the list of available tools changes, servers that declared the listChanged capability SHOULD send a notification. if exists && s.capabilities.tools != nil && s.capabilities.tools.listChanged { // Send notification to all initialized sessions - s.SendNotificationToAllClients(mcp.MethodNotificationToolsListChanged, nil) + s.SendNotificationToAllClients(string(mcp.MethodNotificationToolsListChanged), nil) } } @@ -587,15 +587,6 @@ func (s *MCPServer) handleSetLevel( } } - sessionLogging, ok := clientSession.(SessionWithLogging) - if !ok { - return nil, &requestError{ - id: id, - code: mcp.INTERNAL_ERROR, - err: ErrSessionDoesNotSupportLogging, - } - } - level := request.Params.Level // Validate logging level switch level { @@ -611,7 +602,7 @@ func (s *MCPServer) handleSetLevel( } } - sessionLogging.SetLogLevel(level) + clientSession.SetLogLevel(level) return &mcp.EmptyResult{}, nil } @@ -731,12 +722,15 @@ func (s *MCPServer) handleReadResource( id any, request mcp.ReadResourceRequest, ) (*mcp.ReadResourceResult, *requestError) { + + requestContext := NewRequestContext(s, request.Request.Params.Meta) + s.resourcesMu.RLock() // First try direct resource handlers if entry, ok := s.resources[request.Params.URI]; ok { handler := entry.handler s.resourcesMu.RUnlock() - contents, err := handler(ctx, request) + contents, err := handler(ctx, requestContext, request) if err != nil { return nil, &requestError{ id: id, @@ -767,7 +761,7 @@ func (s *MCPServer) handleReadResource( s.resourcesMu.RUnlock() if matched { - contents, err := matchedHandler(ctx, request) + contents, err := matchedHandler(ctx, requestContext, request) if err != nil { return nil, &requestError{ id: id, @@ -849,7 +843,8 @@ func (s *MCPServer) handleGetPrompt( } } - result, err := handler(ctx, request) + requestContext := NewRequestContext(s, request.Request.Params.Meta) + result, err := handler(ctx, requestContext, request) if err != nil { return nil, &requestError{ id: id, @@ -999,7 +994,8 @@ func (s *MCPServer) handleToolCall( finalHandler = mw[i](finalHandler) } - result, err := finalHandler(ctx, request) + requestContext := NewRequestContext(s, request.Request.Params.Meta) + result, err := finalHandler(ctx, requestContext, request) if err != nil { return nil, &requestError{ id: id, diff --git a/server/server_race_test.go b/server/server_race_test.go index 4e0be43a8..17354609c 100644 --- a/server/server_race_test.go +++ b/server/server_race_test.go @@ -39,7 +39,7 @@ func TestRaceConditions(t *testing.T) { srv.AddPrompt(mcp.Prompt{ Name: name, Description: "Test prompt", - }, func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + }, func(ctx context.Context, reqContext RequestContext, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{}, nil }) }) @@ -49,7 +49,7 @@ func TestRaceConditions(t *testing.T) { srv.AddPrompt(mcp.Prompt{ Name: name, Description: "Temporary prompt", - }, func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + }, func(ctx context.Context, reqContext RequestContext, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{}, nil }) srv.DeletePrompts(name) @@ -60,7 +60,7 @@ func TestRaceConditions(t *testing.T) { srv.AddTool(mcp.Tool{ Name: name, Description: "Test tool", - }, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + }, func(ctx context.Context, reqContext RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }) }) @@ -71,7 +71,7 @@ func TestRaceConditions(t *testing.T) { srv.AddTool(mcp.Tool{ Name: name, Description: "Temporary tool", - }, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + }, func(ctx context.Context, reqContext RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }) srv.DeleteTools(name) @@ -79,8 +79,8 @@ func TestRaceConditions(t *testing.T) { runConcurrentOperation(&wg, testDuration, "add-middleware", func() { middleware := func(next ToolHandlerFunc) ToolHandlerFunc { - return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return next(ctx, req) + return func(ctx context.Context, reqContext RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return next(ctx, reqContext, req) } } WithToolHandlerMiddleware(middleware)(srv) @@ -102,7 +102,7 @@ func TestRaceConditions(t *testing.T) { srv.AddTool(mcp.Tool{ Name: "persistent-tool", Description: "Test tool that always exists", - }, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + }, func(ctx context.Context, reqContext RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }) @@ -121,7 +121,7 @@ func TestRaceConditions(t *testing.T) { URI: uri, Name: uri, Description: "Test resource", - }, func(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + }, func(ctx context.Context, reqContext RequestContext, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: uri, @@ -169,12 +169,12 @@ func TestConcurrentPromptAdd(t *testing.T) { srv.AddPrompt(mcp.Prompt{ Name: "initial-prompt", Description: "Initial prompt", - }, func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + }, func(ctx context.Context, reqContext RequestContext, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { go func() { srv.AddPrompt(mcp.Prompt{ Name: fmt.Sprintf("new-prompt-%d", time.Now().UnixNano()), Description: "Added from handler", - }, func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + }, func(ctx context.Context, reqContext RequestContext, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{}, nil }) }() diff --git a/server/server_test.go b/server/server_test.go index c0ececc92..439d01a5c 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -160,12 +160,12 @@ func TestMCPServer_Tools(t *testing.T) { action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { server.SetTools(ServerTool{ Tool: mcp.NewTool("test-tool-1"), - Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, }, ServerTool{ Tool: mcp.NewTool("test-tool-2"), - Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, }) @@ -189,19 +189,19 @@ func TestMCPServer_Tools(t *testing.T) { require.NoError(t, err) server.SetTools(ServerTool{ Tool: mcp.NewTool("test-tool-1"), - Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, }, ServerTool{ Tool: mcp.NewTool("test-tool-2"), - Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, }) }, expectedNotifications: 1, validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, toolsList mcp.JSONRPCMessage) { - assert.Equal(t, mcp.MethodNotificationToolsListChanged, notifications[0].Method) + assert.Equal(t, string(mcp.MethodNotificationToolsListChanged), notifications[0].Method) tools := toolsList.(mcp.JSONRPCResponse).Result.(mcp.ListToolsResult).Tools assert.Len(t, tools, 2) assert.Equal(t, "test-tool-1", tools[0].Name) @@ -230,12 +230,12 @@ func TestMCPServer_Tools(t *testing.T) { } server.SetTools(ServerTool{ Tool: mcp.NewTool("test-tool-1"), - Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, }, ServerTool{ Tool: mcp.NewTool("test-tool-2"), - Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, }) @@ -243,7 +243,7 @@ func TestMCPServer_Tools(t *testing.T) { expectedNotifications: 5, validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, toolsList mcp.JSONRPCMessage) { for _, notification := range notifications { - assert.Equal(t, mcp.MethodNotificationToolsListChanged, notification.Method) + assert.Equal(t, string(mcp.MethodNotificationToolsListChanged), notification.Method) } tools := toolsList.(mcp.JSONRPCResponse).Result.(mcp.ListToolsResult).Tools assert.Len(t, tools, 2) @@ -262,21 +262,21 @@ func TestMCPServer_Tools(t *testing.T) { require.NoError(t, err) server.AddTool( mcp.NewTool("test-tool-1"), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, ) server.AddTool( mcp.NewTool("test-tool-2"), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, ) }, expectedNotifications: 2, validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, toolsList mcp.JSONRPCMessage) { - assert.Equal(t, mcp.MethodNotificationToolsListChanged, notifications[0].Method) - assert.Equal(t, mcp.MethodNotificationToolsListChanged, notifications[1].Method) + assert.Equal(t, string(mcp.MethodNotificationToolsListChanged), notifications[0].Method) + assert.Equal(t, string(mcp.MethodNotificationToolsListChanged), notifications[1].Method) tools := toolsList.(mcp.JSONRPCResponse).Result.(mcp.ListToolsResult).Tools assert.Len(t, tools, 2) assert.Equal(t, "test-tool-1", tools[0].Name) @@ -300,9 +300,9 @@ func TestMCPServer_Tools(t *testing.T) { expectedNotifications: 2, validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, toolsList mcp.JSONRPCMessage) { // One for SetTools - assert.Equal(t, mcp.MethodNotificationToolsListChanged, notifications[0].Method) + assert.Equal(t, string(mcp.MethodNotificationToolsListChanged), notifications[0].Method) // One for DeleteTools - assert.Equal(t, mcp.MethodNotificationToolsListChanged, notifications[1].Method) + assert.Equal(t, string(mcp.MethodNotificationToolsListChanged), notifications[1].Method) // Expect a successful response with an empty list of tools resp, ok := toolsList.(mcp.JSONRPCResponse) @@ -333,7 +333,7 @@ func TestMCPServer_Tools(t *testing.T) { expectedNotifications: 1, validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, toolsList mcp.JSONRPCMessage) { // Only one notification expected for SetTools - assert.Equal(t, mcp.MethodNotificationToolsListChanged, notifications[0].Method) + assert.Equal(t, string(mcp.MethodNotificationToolsListChanged), notifications[0].Method) // Confirm the tool list does not change tools := toolsList.(mcp.JSONRPCResponse).Result.(mcp.ListToolsResult).Tools @@ -706,7 +706,7 @@ func TestMCPServer_PromptHandling(t *testing.T) { server.AddPrompt( testPrompt, - func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + func(ctx context.Context, reqContext RequestContext, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{ Messages: []mcp.PromptMessage{ { @@ -843,9 +843,9 @@ func TestMCPServer_Prompts(t *testing.T) { expectedNotifications: 2, validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, promptsList mcp.JSONRPCMessage) { // One for AddPrompt - assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[0].Method) + assert.Equal(t, string(mcp.MethodNotificationPromptsListChanged), notifications[0].Method) // One for DeletePrompts - assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[1].Method) + assert.Equal(t, string(mcp.MethodNotificationPromptsListChanged), notifications[1].Method) // Expect a successful response with an empty list of prompts resp, ok := promptsList.(mcp.JSONRPCResponse) @@ -898,11 +898,11 @@ func TestMCPServer_Prompts(t *testing.T) { expectedNotifications: 3, validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, promptsList mcp.JSONRPCMessage) { // first notification expected for AddPrompt test-prompt-1 - assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[0].Method) + assert.Equal(t, string(mcp.MethodNotificationPromptsListChanged), notifications[0].Method) // second notification expected for AddPrompt test-prompt-2 - assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[1].Method) + assert.Equal(t, string(mcp.MethodNotificationPromptsListChanged), notifications[1].Method) // second notification expected for DeletePrompts test-prompt-1 - assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[2].Method) + assert.Equal(t, string(mcp.MethodNotificationPromptsListChanged), notifications[2].Method) // Confirm the prompt list does not change prompts := promptsList.(mcp.JSONRPCResponse).Result.(mcp.ListPromptsResult).Prompts @@ -951,9 +951,9 @@ func TestMCPServer_Prompts(t *testing.T) { expectedNotifications: 2, validate: func(t *testing.T, notifications []mcp.JSONRPCNotification, promptsList mcp.JSONRPCMessage) { // first notification expected for AddPrompt test-prompt-1 - assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[0].Method) + assert.Equal(t, string(mcp.MethodNotificationPromptsListChanged), notifications[0].Method) // second notification expected for AddPrompt test-prompt-2 - assert.Equal(t, mcp.MethodNotificationPromptsListChanged, notifications[1].Method) + assert.Equal(t, string(mcp.MethodNotificationPromptsListChanged), notifications[1].Method) // Confirm the prompt list does not change prompts := promptsList.(mcp.JSONRPCResponse).Result.(mcp.ListPromptsResult).Prompts @@ -1121,7 +1121,7 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) { IdempotentHint: mcp.ToBoolPtr(false), OpenWorldHint: mcp.ToBoolPtr(false), }, - }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + }, func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }) @@ -1383,7 +1383,7 @@ func TestMCPServer_ResourceTemplates(t *testing.T) { "test://{a}/test-resource{/b*}", "My Resource", ), - func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, reqContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { a := request.Params.Arguments["a"].([]string) b := request.Params.Arguments["b"].([]string) // Validate that the template arguments are passed correctly to the handler @@ -1470,7 +1470,7 @@ func createTestServer() *MCPServer { URI: "resource://testresource", Name: "My Resource", }, - func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, reqContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: "resource://testresource", @@ -1486,7 +1486,7 @@ func createTestServer() *MCPServer { Name: "test-tool", Description: "Test tool", }, - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{ Content: []mcp.Content{ mcp.TextContent{ @@ -1522,6 +1522,16 @@ func (f fakeSession) Initialized() bool { return f.initialized } +func (f fakeSession) SetLogLevel(level mcp.LoggingLevel) {} + +func (f fakeSession) GetLogLevel() mcp.LoggingLevel { + return mcp.LoggingLevelError +} + +func (f fakeSession) GetLoggerName() *string { + return nil +} + var _ ClientSession = fakeSession{} func TestMCPServer_WithHooks(t *testing.T) { @@ -1622,7 +1632,7 @@ func TestMCPServer_WithHooks(t *testing.T) { // Add a test tool server.AddTool( mcp.NewTool("test-tool"), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, ) @@ -1793,7 +1803,7 @@ func TestMCPServer_SessionHooks_NilHooks(t *testing.T) { } func TestMCPServer_WithRecover(t *testing.T) { - panicToolHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + panicToolHandler := func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { panic("test panic") } diff --git a/server/session.go b/server/session.go index 0c50a260e..f3ea8d36e 100644 --- a/server/session.go +++ b/server/session.go @@ -17,15 +17,12 @@ type ClientSession interface { NotificationChannel() chan<- mcp.JSONRPCNotification // SessionID is a unique identifier used to track user session. SessionID() string -} - -// SessionWithLogging is an extension of ClientSession that can receive log message notifications and set log level -type SessionWithLogging interface { - ClientSession // SetLogLevel sets the minimum log level SetLogLevel(level mcp.LoggingLevel) // GetLogLevel retrieves the minimum log level GetLogLevel() mcp.LoggingLevel + // GetLoggerName returns the logger name + GetLoggerName() *string } // SessionWithTools is an extension of ClientSession that can store session-specific tool data diff --git a/server/session_test.go b/server/session_test.go index 8f2cfa763..409ca6feb 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -41,6 +41,17 @@ func (f sessionTestClient) Initialized() bool { return f.initialized } +func (f sessionTestClient) SetLogLevel(level mcp.LoggingLevel) { +} + +func (f sessionTestClient) GetLogLevel() mcp.LoggingLevel { + return mcp.LoggingLevelError +} + +func (f sessionTestClient) GetLoggerName() *string { + return nil +} + // sessionTestClientWithTools implements the SessionWithTools interface for testing type sessionTestClientWithTools struct { sessionID string @@ -66,6 +77,17 @@ func (f *sessionTestClientWithTools) Initialized() bool { return f.initialized } +func (f *sessionTestClientWithTools) SetLogLevel(level mcp.LoggingLevel) { +} + +func (f *sessionTestClientWithTools) GetLogLevel() mcp.LoggingLevel { + return mcp.LoggingLevelError +} + +func (f *sessionTestClientWithTools) GetLoggerName() *string { + return nil +} + func (f *sessionTestClientWithTools) GetSessionTools() map[string]ServerTool { f.mu.RLock() defer f.mu.RUnlock() @@ -104,7 +126,7 @@ type sessionTestClientWithLogging struct { sessionID string notificationChannel chan mcp.JSONRPCNotification initialized bool - loggingLevel atomic.Value + loggingLevel atomic.Value } func (f *sessionTestClientWithLogging) SessionID() string { @@ -134,11 +156,14 @@ func (f *sessionTestClientWithLogging) GetLogLevel() mcp.LoggingLevel { return level.(mcp.LoggingLevel) } +func (f *sessionTestClientWithLogging) GetLoggerName() *string { + return nil +} + // Verify that all implementations satisfy their respective interfaces var ( - _ ClientSession = (*sessionTestClient)(nil) - _ SessionWithTools = (*sessionTestClientWithTools)(nil) - _ SessionWithLogging = (*sessionTestClientWithLogging)(nil) + _ ClientSession = (*sessionTestClient)(nil) + _ SessionWithTools = (*sessionTestClientWithTools)(nil) ) func TestSessionWithTools_Integration(t *testing.T) { @@ -147,7 +172,7 @@ func TestSessionWithTools_Integration(t *testing.T) { // Create session-specific tools sessionTool := ServerTool{ Tool: mcp.NewTool("session-tool"), - Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText("session-tool result"), nil }, } @@ -196,7 +221,7 @@ func TestSessionWithTools_Integration(t *testing.T) { require.NotNil(t, tool, "Session tool should not be nil") // Now test calling directly with the handler - result, err := tool.Handler(sessionCtx, testReq) + result, err := tool.Handler(sessionCtx, RequestContext{}, testReq) require.NoError(t, err, "No error calling session tool handler directly") require.NotNil(t, result, "Result should not be nil") require.Len(t, result.Content, 1, "Result should have one content item") @@ -316,7 +341,7 @@ func TestMCPServer_AddSessionTool(t *testing.T) { err = server.AddSessionTool( session.SessionID(), mcp.NewTool("session-tool-helper"), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText("helper result"), nil }, ) @@ -521,7 +546,7 @@ func TestMCPServer_CallSessionTool(t *testing.T) { server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) // Add global tool - server.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + server.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText("global result"), nil }) @@ -541,7 +566,7 @@ func TestMCPServer_CallSessionTool(t *testing.T) { err = server.AddSessionTool( session.SessionID(), mcp.NewTool("test_tool"), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText("session result"), nil }, ) @@ -1041,4 +1066,4 @@ func TestMCPServer_SetLevel(t *testing.T) { if session.GetLogLevel() != mcp.LoggingLevelCritical { t.Errorf("Expected critical level, got %v", session.GetLogLevel()) } -} \ No newline at end of file +} diff --git a/server/sse.go b/server/sse.go index d93fdb381..61dbc92f3 100644 --- a/server/sse.go +++ b/server/sse.go @@ -27,7 +27,9 @@ type sseSession struct { notificationChannel chan mcp.JSONRPCNotification initialized atomic.Bool loggingLevel atomic.Value - tools sync.Map // stores session-specific tools + // FIXME assign logger name in a proper way + loggerName *string + tools sync.Map // stores session-specific tools } // SSEContextFunc is a function that takes an existing context and the current @@ -72,6 +74,10 @@ func (s *sseSession) GetLogLevel() mcp.LoggingLevel { return level.(mcp.LoggingLevel) } +func (s *sseSession) GetLoggerName() *string { + return s.loggerName +} + func (s *sseSession) GetSessionTools() map[string]ServerTool { tools := make(map[string]ServerTool) s.tools.Range(func(key, value any) bool { @@ -94,9 +100,8 @@ func (s *sseSession) SetSessionTools(tools map[string]ServerTool) { } var ( - _ ClientSession = (*sseSession)(nil) - _ SessionWithTools = (*sseSession)(nil) - _ SessionWithLogging = (*sseSession)(nil) + _ ClientSession = (*sseSession)(nil) + _ SessionWithTools = (*sseSession)(nil) ) // SSEServer implements a Server-Sent Events (SSE) based MCP server. diff --git a/server/sse_test.go b/server/sse_test.go index 96912be49..92f4c1fd4 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -574,7 +574,7 @@ func TestSSEServer(t *testing.T) { WithResourceCapabilities(true, true), ) // Add a tool which uses the context function. - mcpServer.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + mcpServer.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Note this is agnostic to the transport type i.e. doesn't know about request headers. testVal := testValFromContext(ctx) return mcp.NewToolResultText(testVal), nil @@ -1183,7 +1183,7 @@ func TestSSEServer(t *testing.T) { Title: "Test Tool", }, }, - Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText("test"), nil }, }, @@ -1335,7 +1335,7 @@ func TestSSEServer(t *testing.T) { processingCompleted := make(chan struct{}) processingStarted := make(chan struct{}) - mcpServer.AddTool(mcp.NewTool("slowMethod"), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + mcpServer.AddTool(mcp.NewTool("slowMethod"), func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { close(processingStarted) // signal for processing started select { diff --git a/server/stdio.go b/server/stdio.go index 0ebed6a9a..d1501df36 100644 --- a/server/stdio.go +++ b/server/stdio.go @@ -51,9 +51,11 @@ func WithStdioContextFunc(fn StdioContextFunc) StdioOption { // stdioSession is a static client session, since stdio has only one client. type stdioSession struct { - notifications chan mcp.JSONRPCNotification - initialized atomic.Bool - loggingLevel atomic.Value + notifications chan mcp.JSONRPCNotification + initialized atomic.Bool + loggingLevel atomic.Value + // FIXME assign logger name in a proper way + loggerName *string } func (s *stdioSession) SessionID() string { @@ -74,11 +76,11 @@ func (s *stdioSession) Initialized() bool { return s.initialized.Load() } -func(s *stdioSession) SetLogLevel(level mcp.LoggingLevel) { +func (s *stdioSession) SetLogLevel(level mcp.LoggingLevel) { s.loggingLevel.Store(level) } -func(s *stdioSession) GetLogLevel() mcp.LoggingLevel { +func (s *stdioSession) GetLogLevel() mcp.LoggingLevel { level := s.loggingLevel.Load() if level == nil { return mcp.LoggingLevelError @@ -86,9 +88,12 @@ func(s *stdioSession) GetLogLevel() mcp.LoggingLevel { return level.(mcp.LoggingLevel) } +func (f *stdioSession) GetLoggerName() *string { + return f.loggerName +} + var ( - _ ClientSession = (*stdioSession)(nil) - _ SessionWithLogging = (*stdioSession)(nil) + _ ClientSession = (*stdioSession)(nil) ) var stdioSessionInstance = stdioSession{ diff --git a/server/stdio_test.go b/server/stdio_test.go index 8433fd0ac..448f29913 100644 --- a/server/stdio_test.go +++ b/server/stdio_test.go @@ -140,7 +140,7 @@ func TestStdioServer(t *testing.T) { // Create server mcpServer := NewMCPServer("test", "1.0.0") // Add a tool which uses the context function. - mcpServer.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + mcpServer.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Note this is agnostic to the transport type i.e. doesn't know about request headers. testVal := testValFromContext(ctx) return mcp.NewToolResultText(testVal), nil diff --git a/server/streamable_http.go b/server/streamable_http.go index b13577a8c..32e3cbab2 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -9,6 +9,7 @@ import ( "net/http/httptest" "strings" "sync" + "sync/atomic" "time" "github.com/google/uuid" @@ -494,6 +495,9 @@ type streamableHttpSession struct { sessionID string notificationChannel chan mcp.JSONRPCNotification // server -> client notifications tools *sessionToolsStore + loggingLevel atomic.Value + // FIXME assign logger name in a proper way + loggerName *string } func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore) *streamableHttpSession { @@ -522,6 +526,22 @@ func (s *streamableHttpSession) Initialized() bool { return true } +func (s *streamableHttpSession) SetLogLevel(level mcp.LoggingLevel) { + s.loggingLevel.Store(level) +} + +func (s *streamableHttpSession) GetLogLevel() mcp.LoggingLevel { + level := s.loggingLevel.Load() + if level == nil { + return mcp.LoggingLevelError + } + return level.(mcp.LoggingLevel) +} + +func (s *streamableHttpSession) GetLoggerName() *string { + return s.loggerName +} + var _ ClientSession = (*streamableHttpSession)(nil) func (s *streamableHttpSession) GetSessionTools() map[string]ServerTool { diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index 9f48eade7..891ba0000 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -39,7 +39,7 @@ var initRequest = map[string]any{ func addSSETool(mcpServer *MCPServer) { mcpServer.AddTool(mcp.Tool{ Name: "sseTool", - }, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + }, func(ctx context.Context, reqContext RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Send notification to client server := ServerFromContext(ctx) for i := 0; i < 10; i++ { @@ -601,7 +601,7 @@ func TestStreamableHTTP_SessionWithTools(t *testing.T) { Title: "Test Tool", }, }, - Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText("test"), nil }, }, diff --git a/server/typed_tools_handler_func.go b/server/typed_tools_handler_func.go new file mode 100644 index 000000000..8af35f5b5 --- /dev/null +++ b/server/typed_tools_handler_func.go @@ -0,0 +1,21 @@ +package server + +import ( + "context" + "fmt" + "github.com/mark3labs/mcp-go/mcp" +) + +// TypedToolHandlerFunc is a function that handles a tool call with typed arguments +type TypedToolHandlerFunc[T any] func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest, args T) (*mcp.CallToolResult, error) + +// NewTypedToolHandler creates a ToolHandlerFunc that automatically binds arguments to a typed struct +func NewTypedToolHandler[T any](handler TypedToolHandlerFunc[T]) func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + var args T + if err := request.BindArguments(&args); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to bind arguments: %v", err)), nil + } + return handler(ctx, requestContext, request, args) + } +} diff --git a/mcp/typed_tools_test.go b/server/typed_tools_handler_func_test.go similarity index 63% rename from mcp/typed_tools_test.go rename to server/typed_tools_handler_func_test.go index 109ade89c..8317538a3 100644 --- a/mcp/typed_tools_test.go +++ b/server/typed_tools_handler_func_test.go @@ -1,9 +1,10 @@ -package mcp +package server import ( "context" "encoding/json" "fmt" + "github.com/mark3labs/mcp-go/mcp" "testing" "github.com/stretchr/testify/assert" @@ -18,15 +19,15 @@ func TestTypedToolHandler(t *testing.T) { } // Create a typed handler function - typedHandler := func(ctx context.Context, request CallToolRequest, args HelloArgs) (*CallToolResult, error) { - return NewToolResultText(args.Name), nil + typedHandler := func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest, args HelloArgs) (*mcp.CallToolResult, error) { + return mcp.NewToolResultText(args.Name), nil } // Create a wrapped handler wrappedHandler := NewTypedToolHandler(typedHandler) // Create a test request - req := CallToolRequest{} + req := mcp.CallToolRequest{} req.Params.Name = "test-tool" req.Params.Arguments = map[string]any{ "name": "John Doe", @@ -35,12 +36,12 @@ func TestTypedToolHandler(t *testing.T) { } // Call the wrapped handler - result, err := wrappedHandler(context.Background(), req) + result, err := wrappedHandler(context.Background(), RequestContext{}, req) // Verify results assert.NoError(t, err) assert.NotNil(t, result) - assert.Equal(t, "John Doe", result.Content[0].(TextContent).Text) + assert.Equal(t, "John Doe", result.Content[0].(mcp.TextContent).Text) // Test with invalid arguments req.Params.Arguments = map[string]any{ @@ -50,7 +51,7 @@ func TestTypedToolHandler(t *testing.T) { } // This should still work because of type conversion - result, err = wrappedHandler(context.Background(), req) + result, err = wrappedHandler(context.Background(), RequestContext{}, req) assert.NoError(t, err) assert.NotNil(t, result) @@ -62,14 +63,14 @@ func TestTypedToolHandler(t *testing.T) { } // This should still work but name will be empty - result, err = wrappedHandler(context.Background(), req) + result, err = wrappedHandler(context.Background(), RequestContext{}, req) assert.NoError(t, err) assert.NotNil(t, result) - assert.Equal(t, "", result.Content[0].(TextContent).Text) + assert.Equal(t, "", result.Content[0].(mcp.TextContent).Text) // Test with completely invalid arguments req.Params.Arguments = "not a map" - result, err = wrappedHandler(context.Background(), req) + result, err = wrappedHandler(context.Background(), RequestContext{}, req) assert.NoError(t, err) // Error is wrapped in the result assert.NotNil(t, result) assert.True(t, result.IsError) @@ -84,10 +85,10 @@ func TestTypedToolHandlerWithValidation(t *testing.T) { } // Create a typed handler function with validation - typedHandler := func(ctx context.Context, request CallToolRequest, args CalculatorArgs) (*CallToolResult, error) { + typedHandler := func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest, args CalculatorArgs) (*mcp.CallToolResult, error) { // Validate operation if args.Operation == "" { - return NewToolResultError("operation is required"), nil + return mcp.NewToolResultError("operation is required"), nil } var result float64 @@ -100,21 +101,21 @@ func TestTypedToolHandlerWithValidation(t *testing.T) { result = args.X * args.Y case "divide": if args.Y == 0 { - return NewToolResultError("division by zero"), nil + return mcp.NewToolResultError("division by zero"), nil } result = args.X / args.Y default: - return NewToolResultError("invalid operation"), nil + return mcp.NewToolResultError("invalid operation"), nil } - return NewToolResultText(fmt.Sprintf("%.0f", result)), nil + return mcp.NewToolResultText(fmt.Sprintf("%.0f", result)), nil } // Create a wrapped handler wrappedHandler := NewTypedToolHandler(typedHandler) // Create a test request - req := CallToolRequest{} + req := mcp.CallToolRequest{} req.Params.Name = "calculator" req.Params.Arguments = map[string]any{ "operation": "add", @@ -123,12 +124,12 @@ func TestTypedToolHandlerWithValidation(t *testing.T) { } // Call the wrapped handler - result, err := wrappedHandler(context.Background(), req) + result, err := wrappedHandler(context.Background(), RequestContext{}, req) // Verify results assert.NoError(t, err) assert.NotNil(t, result) - assert.Equal(t, "16", result.Content[0].(TextContent).Text) + assert.Equal(t, "16", result.Content[0].(mcp.TextContent).Text) // Test division by zero req.Params.Arguments = map[string]any{ @@ -137,11 +138,11 @@ func TestTypedToolHandlerWithValidation(t *testing.T) { "y": 0.0, } - result, err = wrappedHandler(context.Background(), req) + result, err = wrappedHandler(context.Background(), RequestContext{}, req) assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) - assert.Contains(t, result.Content[0].(TextContent).Text, "division by zero") + assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "division by zero") } func TestTypedToolHandlerWithComplexObjects(t *testing.T) { @@ -160,64 +161,64 @@ func TestTypedToolHandlerWithComplexObjects(t *testing.T) { } type UserProfile struct { - Name string `json:"name"` - Email string `json:"email"` - Age int `json:"age"` - IsVerified bool `json:"is_verified"` - Address Address `json:"address"` + Name string `json:"name"` + Email string `json:"email"` + Age int `json:"age"` + IsVerified bool `json:"is_verified"` + Address Address `json:"address"` Preferences UserPreferences `json:"preferences"` - Tags []string `json:"tags"` + Tags []string `json:"tags"` } // Create a typed handler function - typedHandler := func(ctx context.Context, request CallToolRequest, profile UserProfile) (*CallToolResult, error) { + typedHandler := func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest, profile UserProfile) (*mcp.CallToolResult, error) { // Validate required fields if profile.Name == "" { - return NewToolResultError("name is required"), nil + return mcp.NewToolResultError("name is required"), nil } if profile.Email == "" { - return NewToolResultError("email is required"), nil + return mcp.NewToolResultError("email is required"), nil } // Build a response that includes nested object data response := fmt.Sprintf("User: %s (%s)", profile.Name, profile.Email) - + if profile.Age > 0 { response += fmt.Sprintf(", Age: %d", profile.Age) } - + if profile.IsVerified { response += ", Verified: Yes" } else { response += ", Verified: No" } - + // Include address information if available if profile.Address.City != "" && profile.Address.Country != "" { response += fmt.Sprintf(", Location: %s, %s", profile.Address.City, profile.Address.Country) } - + // Include preferences if available if profile.Preferences.Theme != "" { response += fmt.Sprintf(", Theme: %s", profile.Preferences.Theme) } - + if len(profile.Preferences.Newsletters) > 0 { response += fmt.Sprintf(", Subscribed to %d newsletters", len(profile.Preferences.Newsletters)) } - + if len(profile.Tags) > 0 { response += fmt.Sprintf(", Tags: %v", profile.Tags) } - - return NewToolResultText(response), nil + + return mcp.NewToolResultText(response), nil } // Create a wrapped handler wrappedHandler := NewTypedToolHandler(typedHandler) // Test with complete complex object - req := CallToolRequest{} + req := mcp.CallToolRequest{} req.Params.Name = "user_profile" req.Params.Arguments = map[string]any{ "name": "John Doe", @@ -239,16 +240,16 @@ func TestTypedToolHandlerWithComplexObjects(t *testing.T) { } // Call the wrapped handler - result, err := wrappedHandler(context.Background(), req) + result, err := wrappedHandler(context.Background(), RequestContext{}, req) // Verify results assert.NoError(t, err) assert.NotNil(t, result) - assert.Contains(t, result.Content[0].(TextContent).Text, "John Doe") - assert.Contains(t, result.Content[0].(TextContent).Text, "San Francisco, USA") - assert.Contains(t, result.Content[0].(TextContent).Text, "Theme: dark") - assert.Contains(t, result.Content[0].(TextContent).Text, "Subscribed to 2 newsletters") - assert.Contains(t, result.Content[0].(TextContent).Text, "Tags: [premium early_adopter]") + assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "John Doe") + assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "San Francisco, USA") + assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "Theme: dark") + assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "Subscribed to 2 newsletters") + assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "Tags: [premium early_adopter]") // Test with partial data (missing some nested fields) req.Params.Arguments = map[string]any{ @@ -265,13 +266,13 @@ func TestTypedToolHandlerWithComplexObjects(t *testing.T) { }, } - result, err = wrappedHandler(context.Background(), req) + result, err = wrappedHandler(context.Background(), RequestContext{}, req) assert.NoError(t, err) assert.NotNil(t, result) - assert.Contains(t, result.Content[0].(TextContent).Text, "Jane Smith") - assert.Contains(t, result.Content[0].(TextContent).Text, "London, UK") - assert.Contains(t, result.Content[0].(TextContent).Text, "Theme: light") - assert.NotContains(t, result.Content[0].(TextContent).Text, "newsletters") + assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "Jane Smith") + assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "London, UK") + assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "Theme: light") + assert.NotContains(t, result.Content[0].(mcp.TextContent).Text, "newsletters") // Test with JSON string input (simulating raw JSON from client) jsonInput := `{ @@ -294,11 +295,11 @@ func TestTypedToolHandlerWithComplexObjects(t *testing.T) { }` req.Params.Arguments = json.RawMessage(jsonInput) - result, err = wrappedHandler(context.Background(), req) + result, err = wrappedHandler(context.Background(), RequestContext{}, req) assert.NoError(t, err) assert.NotNil(t, result) - assert.Contains(t, result.Content[0].(TextContent).Text, "Bob Johnson") - assert.Contains(t, result.Content[0].(TextContent).Text, "New York, USA") - assert.Contains(t, result.Content[0].(TextContent).Text, "Theme: system") - assert.Contains(t, result.Content[0].(TextContent).Text, "Subscribed to 1 newsletters") -} \ No newline at end of file + assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "Bob Johnson") + assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "New York, USA") + assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "Theme: system") + assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "Subscribed to 1 newsletters") +} From 010bc7ad7488db3a2816e87500ad787fe834e47c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=9C=E8=89=AE=E9=AD=81?= Date: Sun, 25 May 2025 04:06:56 +0800 Subject: [PATCH 02/15] generate code --- server/internal/gen/request_handler.go.tmpl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/server/internal/gen/request_handler.go.tmpl b/server/internal/gen/request_handler.go.tmpl index 7e4a68a05..89e9d740e 100644 --- a/server/internal/gen/request_handler.go.tmpl +++ b/server/internal/gen/request_handler.go.tmpl @@ -55,7 +55,9 @@ func (s *MCPServer) HandleMessage( } s.handleNotification(ctx, notification) return nil // Return nil for notifications - } + } else { + ctx = context.WithValue(ctx, requestIDKey{}, mcp.NewRequestId(baseMessage.ID)) + } if baseMessage.Result != nil { // this is a response to a request sent by the server (e.g. from a ping From b5457f4bae56209abddedf19b77d2a2c097e86e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=9C=E8=89=AE=E9=AD=81?= Date: Sun, 25 May 2025 04:23:58 +0800 Subject: [PATCH 03/15] fix log key --- server/request_context.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/request_context.go b/server/request_context.go index bc983d043..011608d30 100644 --- a/server/request_context.go +++ b/server/request_context.go @@ -51,8 +51,8 @@ func (exchange *RequestContext) SendLoggingNotification(ctx context.Context, lev } params := map[string]any{ - "level": level, - "message": message, + "level": level, + "data": message, } if ClientSessionFromContext(ctx).GetLoggerName() != nil { params["logger"] = *ClientSessionFromContext(ctx).GetLoggerName() From 8ab7d004b96ca7ca82a30cab01ff64cc212fa461 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=9C=E8=89=AE=E9=AD=81?= Date: Sun, 25 May 2025 04:28:51 +0800 Subject: [PATCH 04/15] update progress type --- server/request_context.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/request_context.go b/server/request_context.go index 011608d30..6558dc11a 100644 --- a/server/request_context.go +++ b/server/request_context.go @@ -66,7 +66,7 @@ func (exchange *RequestContext) SendLoggingNotification(ctx context.Context, lev } // SendProgressNotification send progress notification only if the client has requested progress -func (exchange *RequestContext) SendProgressNotification(ctx context.Context, progress, total *float64, message *string) error { +func (exchange *RequestContext) SendProgressNotification(ctx context.Context, progress float64, total *float64, message *string) error { if exchange.progressToken == nil { return nil } From fd08508976f0f5df391f990fa2bf5c577ca4c7b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=9C=E8=89=AE=E9=AD=81?= Date: Sun, 25 May 2025 04:39:51 +0800 Subject: [PATCH 05/15] update readme --- README.md | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index a35a3ebe0..a459e6114 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ func main() { } } -func helloHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func helloHandler(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { name, err := request.RequireString("name") if err != nil { return mcp.NewToolResultError(err.Error()), nil @@ -92,6 +92,7 @@ MCP Go handles all the complex protocol details and server management, so you ca - [Resources](#resources) - [Tools](#tools) - [Prompts](#prompts) + - [RequestContext](#RequestContext) - [Examples](#examples) - [Extras](#extras) - [Transports](#transports) @@ -153,7 +154,7 @@ func main() { ) // Add the calculator handler - s.AddTool(calculatorTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + s.AddTool(calculatorTool, func(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Using helper functions for type-safe argument access op, err := request.RequireString("operation") if err != nil { @@ -251,7 +252,7 @@ resource := mcp.NewResource( ) // Add resource with its handler -s.AddResource(resource, func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { +s.AddResource(resource, func(ctx context.Context, requestContext server.RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { content, err := os.ReadFile("README.md") if err != nil { return nil, err @@ -279,7 +280,7 @@ template := mcp.NewResourceTemplate( ) // Add template with its handler -s.AddResourceTemplate(template, func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { +s.AddResourceTemplate(template, func(ctx context.Context, requestContext server.RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { // Extract ID from the URI using regex matching // The server automatically matches URIs to templates userID := extractIDFromURI(request.Params.URI) @@ -328,7 +329,7 @@ calculatorTool := mcp.NewTool("calculate", ), ) -s.AddTool(calculatorTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +s.AddTool(calculatorTool, func(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { args := request.GetArguments() op := args["operation"].(string) x := args["x"].(float64) @@ -372,7 +373,7 @@ httpTool := mcp.NewTool("http_request", ), ) -s.AddTool(httpTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +s.AddTool(httpTool, func(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { args := request.GetArguments() method := args["method"].(string) url := args["url"].(string) @@ -440,7 +441,7 @@ s.AddPrompt(mcp.NewPrompt("greeting", mcp.WithArgument("name", mcp.ArgumentDescription("Name of the person to greet"), ), -), func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { +), func(ctx context.Context, requestContext server.RequestContext, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { name := request.Params.Arguments["name"] if name == "" { name = "friend" @@ -464,7 +465,7 @@ s.AddPrompt(mcp.NewPrompt("code_review", mcp.ArgumentDescription("Pull request number to review"), mcp.RequiredArgument(), ), -), func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { +), func(ctx context.Context, requestContext server.RequestContext, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { prNumber := request.Params.Arguments["pr_number"] if prNumber == "" { return nil, fmt.Errorf("pr_number is required") @@ -495,7 +496,7 @@ s.AddPrompt(mcp.NewPrompt("query_builder", mcp.ArgumentDescription("Name of the table to query"), mcp.RequiredArgument(), ), -), func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { +), func(ctx context.Context, requestContext server.RequestContext, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { tableName := request.Params.Arguments["table"] if tableName == "" { return nil, fmt.Errorf("table name is required") @@ -530,6 +531,9 @@ Prompts can include: +### RequestContext + + ## Examples For examples, see the [`examples/`](examples/) directory. @@ -723,7 +727,7 @@ s := server.NewMCPServer( The session context is automatically passed to tool and resource handlers: ```go -s.AddTool(mcp.NewTool("session_aware"), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { +s.AddTool(mcp.NewTool("session_aware"), func(ctx context.Context, requestContext server.RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get the current session from context session := server.ClientSessionFromContext(ctx) if session == nil { From 56971310ad887008fcf6d4cf76eb14a8d2a43005 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=9C=E8=89=AE=E9=AD=81?= Date: Sun, 25 May 2025 04:49:17 +0800 Subject: [PATCH 06/15] update readme --- README.md | 32 ++++++++++++++++++++++++++++++++ server/request_context.go | 2 +- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index a459e6114..f7283e996 100644 --- a/README.md +++ b/README.md @@ -532,7 +532,39 @@ Prompts can include: ### RequestContext +
+Show RequestContext Examples + +The RequestContext object provides capabilities to interact with the client, such as sending logging notification and progress notification. +```go +// Example of using RequestContext to send logging notifications and progress notifications +mcpServer.AddTool(mcp.NewTool( + "test-RequestContent", + mcp.WithDescription("test RequestContent"), +), func(ctx context.Context, reqContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // you could invoke `reqContext.IsLoggingNotificationSupported()` first the check if server supports logging notification + // ff server does not support logging notification, this method will do nothing. + _ = reqContext.SendLoggingNotification(ctx, mcp.LoggingLevelInfo, map[string]any{ + "testLog": "test send log notification", + }) + + // server should send progress notification if request metadata includes a progressToken + total := float64(100) + progressMessage := "human readable progress information" + _ = reqContext.SendProgressNotification(ctx, float64(50), &total, &progressMessage) + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: "context from header: " + ctx.Value(testHeaderKey).(string) + ", " + ctx.Value(testHeaderFuncKey).(string), + }, + }, + }, nil +}) +``` +
## Examples diff --git a/server/request_context.go b/server/request_context.go index 6558dc11a..9f2257c85 100644 --- a/server/request_context.go +++ b/server/request_context.go @@ -8,7 +8,7 @@ import ( type requestIDKey struct{} -// RequestContext represents an exchange with MCP client. The exchange provides +// RequestContext represents an exchange with MCP client, and provides // methods to interact with the client and query its capabilities. type RequestContext struct { mcpServer *MCPServer From 78ea9244e1f82bce49cdbe60066b5730294998a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=9C=E8=89=AE=E9=AD=81?= Date: Sun, 25 May 2025 05:41:43 +0800 Subject: [PATCH 07/15] remove unused code --- mcp/types.go | 80 +++++++++------------------------------------------- mcp/utils.go | 59 -------------------------------------- 2 files changed, 13 insertions(+), 126 deletions(-) diff --git a/mcp/types.go b/mcp/types.go index c6ca3ff6f..09bd4e025 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -72,11 +72,22 @@ const ( // https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging#log-message-notifications MethodNotificationMessage MCPMethod = "notifications/message" - // MethodNotificationProgress notifies progress updates for long-running operations. + // MethodNotificationProgress notifies progress updates for long-running operations // https://modelcontextprotocol.io/specification/2025-03-26/basic/utilities/progress MethodNotificationProgress MCPMethod = "notifications/progress" - // MethodNotificationCancellation cancel a previously-issued request + // MethodNotificationCancellation can be sent by either side to indicate that it is + // cancelling a previously-issued request. + // + // The request SHOULD still be in-flight, but due to communication latency, it + // is always possible that this notification MAY arrive after the request has + // already finished. + // + // This notification indicates that the result will be unused, so any + // associated processing SHOULD cease. + // + // A client MUST NOT attempt to cancel its `initialize` request. + // // https://modelcontextprotocol.io/specification/2025-03-26/basic/utilities/cancellation MethodNotificationCancellation MCPMethod = "notifications/cancelled" ) @@ -366,34 +377,6 @@ const ( // EmptyResult represents a response that indicates success but carries no data. type EmptyResult Result -/* Cancellation */ - -// CancelledNotification can be sent by either side to indicate that it is -// cancelling a previously-issued request. -// -// The request SHOULD still be in-flight, but due to communication latency, it -// is always possible that this notification MAY arrive after the request has -// already finished. -// -// This notification indicates that the result will be unused, so any -// associated processing SHOULD cease. -// -// A client MUST NOT attempt to cancel its `initialize` request. -type CancelledNotification struct { - Notification - Params struct { - // The ID of the request to cancel. - // - // This MUST correspond to the ID of a request previously issued - // in the same direction. - RequestId RequestId `json:"requestId"` - - // An optional string describing the reason for the cancellation. This MAY - // be logged or presented to the user. - Reason string `json:"reason,omitempty"` - } `json:"params"` -} - /* Initialization */ // InitializeRequest is sent from the client to the server when it first @@ -491,27 +474,6 @@ type PingRequest struct { Request } -/* Progress notifications */ - -// ProgressNotification is an out-of-band notification used to inform the -// receiver of a progress update for a long-running request. -type ProgressNotification struct { - Notification - Params struct { - // The progress token which was given in the initial request, used to - // associate this notification with the request that is proceeding. - ProgressToken ProgressToken `json:"progressToken"` - // The progress thus far. This should increase every time progress is made, - // even if the total is unknown. - Progress float64 `json:"progress"` - // Total number of items to process (or total progress required), if known. - Total float64 `json:"total,omitempty"` - // Message related to progress. This should provide relevant human-readable - // progress information. - Message string `json:"message,omitempty"` - } `json:"params"` -} - /* Pagination */ type PaginatedRequest struct { @@ -713,22 +675,6 @@ type SetLevelRequest struct { } `json:"params"` } -// LoggingMessageNotification is a notification of a log message passed from -// server to client. If no logging/setLevel request has been sent from the client, -// the server MAY decide which messages to send automatically. -type LoggingMessageNotification struct { - Notification - Params struct { - // The severity of this log message. - Level LoggingLevel `json:"level"` - // An optional name of the logger issuing this message. - Logger string `json:"logger,omitempty"` - // The data to be logged, such as a string message or an object. Any JSON - // serializable type is allowed here. - Data any `json:"data"` - } `json:"params"` -} - // LoggingLevel represents the severity of a log message. // // These map to syslog message severities, as specified in RFC-5424: diff --git a/mcp/utils.go b/mcp/utils.go index d16850d11..bc7354ad0 100644 --- a/mcp/utils.go +++ b/mcp/utils.go @@ -22,8 +22,6 @@ var _ ClientRequest = &CallToolRequest{} var _ ClientRequest = &ListToolsRequest{} // ClientNotification types -var _ ClientNotification = &CancelledNotification{} -var _ ClientNotification = &ProgressNotification{} var _ ClientNotification = &InitializedNotification{} var _ ClientNotification = &RootsListChangedNotification{} @@ -38,9 +36,6 @@ var _ ServerRequest = &CreateMessageRequest{} var _ ServerRequest = &ListRootsRequest{} // ServerNotification types -var _ ServerNotification = &CancelledNotification{} -var _ ServerNotification = &ProgressNotification{} -var _ ServerNotification = &LoggingMessageNotification{} var _ ServerNotification = &ResourceUpdatedNotification{} var _ ServerNotification = &ResourceListChangedNotification{} var _ ServerNotification = &ToolListChangedNotification{} @@ -131,60 +126,6 @@ func NewJSONRPCError( } } -// NewProgressNotification -// Helper function for creating a progress notification -func NewProgressNotification( - token ProgressToken, - progress float64, - total *float64, - message *string, -) ProgressNotification { - notification := ProgressNotification{ - Notification: Notification{ - Method: string(MethodNotificationProgress), - }, - Params: struct { - ProgressToken ProgressToken `json:"progressToken"` - Progress float64 `json:"progress"` - Total float64 `json:"total,omitempty"` - Message string `json:"message,omitempty"` - }{ - ProgressToken: token, - Progress: progress, - }, - } - if total != nil { - notification.Params.Total = *total - } - if message != nil { - notification.Params.Message = *message - } - return notification -} - -// NewLoggingMessageNotification -// Helper function for creating a logging message notification -func NewLoggingMessageNotification( - level LoggingLevel, - logger string, - data any, -) LoggingMessageNotification { - return LoggingMessageNotification{ - Notification: Notification{ - Method: "notifications/message", - }, - Params: struct { - Level LoggingLevel `json:"level"` - Logger string `json:"logger,omitempty"` - Data any `json:"data"` - }{ - Level: level, - Logger: logger, - Data: data, - }, - } -} - // NewPromptMessage // Helper function to create a new PromptMessage func NewPromptMessage(role Role, content Content) PromptMessage { From da74e1b1b39d21b277e3074a31251117a8d4f234 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=9C=E8=89=AE=E9=AD=81?= Date: Mon, 26 May 2025 02:21:01 +0800 Subject: [PATCH 08/15] add unit test --- client/sse_test.go | 129 ++++++++++++++++++++++++++++++++++++++++++++- server/server.go | 2 +- 2 files changed, 128 insertions(+), 3 deletions(-) diff --git a/client/sse_test.go b/client/sse_test.go index 3a61de638..5c1bc77da 100644 --- a/client/sse_test.go +++ b/client/sse_test.go @@ -2,6 +2,7 @@ package client import ( "context" + "github.com/stretchr/testify/assert" "net/http" "testing" "time" @@ -27,6 +28,7 @@ func TestSSEMCPClient(t *testing.T) { server.WithResourceCapabilities(true, true), server.WithPromptCapabilities(true), server.WithToolCapabilities(true), + server.WithLogging(), ) // Add a test tool @@ -39,7 +41,7 @@ func TestSSEMCPClient(t *testing.T) { mcp.WithDestructiveHintAnnotation(false), mcp.WithIdempotentHintAnnotation(true), mcp.WithOpenWorldHintAnnotation(false), - ), func(ctx context.Context, reqContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ), func(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{ Content: []mcp.Content{ mcp.TextContent{ @@ -52,7 +54,7 @@ func TestSSEMCPClient(t *testing.T) { mcpServer.AddTool(mcp.NewTool( "test-tool-for-http-header", mcp.WithDescription("Test tool for http header"), - ), func(ctx context.Context, reqContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ), func(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // , X-Test-Header-Func return &mcp.CallToolResult{ Content: []mcp.Content{ @@ -64,6 +66,47 @@ func TestSSEMCPClient(t *testing.T) { }, nil }) + mcpServer.AddTool(mcp.NewTool( + "test-tool-for-sending-notification", + mcp.WithDescription("Test tool for sending log notification, and the log level is warn"), + ), func(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + + totalProgreddValue := float64(100) + startFuncMessage := "start func" + err := requestContext.SendProgressNotification(ctx, float64(0), &totalProgreddValue, &startFuncMessage) + if err != nil { + return nil, err + } + + err = requestContext.SendLoggingNotification(ctx, mcp.LoggingLevelInfo, map[string]any{ + "filtered_log_message": "will be filtered by log level", + }) + if err != nil { + return nil, err + } + err = requestContext.SendLoggingNotification(ctx, mcp.LoggingLevelError, map[string]any{ + "log_message": "log message value", + }) + if err != nil { + return nil, err + } + + startFuncMessage = "end func" + err = requestContext.SendProgressNotification(ctx, float64(100), &totalProgreddValue, &startFuncMessage) + if err != nil { + return nil, err + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: "result", + }, + }, + }, nil + }) + // Initialize testServer := server.NewTestServer(mcpServer, server.WithSSEContextFunc(func(ctx context.Context, r *http.Request) context.Context { @@ -330,4 +373,86 @@ func TestSSEMCPClient(t *testing.T) { t.Errorf("Got %q, want %q", result.Content[0].(mcp.TextContent).Text, "context from header: test-header-value, test-header-func-value") } }) + + t.Run("CallTool for testing log and progress notification", func(t *testing.T) { + client, err := NewSSEMCPClient(testServer.URL + "/sse") + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + var messageNotification *mcp.JSONRPCNotification + progressNotifications := make([]*mcp.JSONRPCNotification, 0) + client.OnNotification(func(notification mcp.JSONRPCNotification) { + if notification.Method == string(mcp.MethodNotificationMessage) { + messageNotification = ¬ification + } else if notification.Method == string(mcp.MethodNotificationProgress) { + progressNotifications = append(progressNotifications, ¬ification) + } + }) + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := client.Start(ctx); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = client.Initialize(ctx, initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + + setLevelRequest := mcp.SetLevelRequest{} + setLevelRequest.Params.Level = mcp.LoggingLevelWarning + err = client.SetLevel(ctx, setLevelRequest) + if err != nil { + t.Errorf("SetLevel failed: %v", err) + } + + request := mcp.CallToolRequest{} + request.Params.Name = "test-tool-for-sending-notification" + request.Params.Meta = &mcp.Meta{ + ProgressToken: "progress_token", + } + + result, err := client.CallTool(ctx, request) + if err != nil { + t.Fatalf("CallTool failed: %v", err) + } + if len(result.Content) != 1 { + t.Errorf("Expected 1 content item, got %d", len(result.Content)) + } + + time.Sleep(time.Millisecond * 200) + + assert.NotNil(t, messageNotification) + assert.Equal(t, messageNotification.Method, string(mcp.MethodNotificationMessage)) + assert.Equal(t, messageNotification.Params.AdditionalFields["level"], "error") + assert.Equal(t, messageNotification.Params.AdditionalFields["data"], map[string]any{ + "log_message": "log message value", + }) + + assert.Len(t, progressNotifications, 2) + assert.Equal(t, string(mcp.MethodNotificationProgress), progressNotifications[0].Method) + assert.Equal(t, "start func", progressNotifications[0].Params.AdditionalFields["message"]) + assert.EqualValues(t, 0, progressNotifications[0].Params.AdditionalFields["progress"]) + assert.Equal(t, "progress_token", progressNotifications[0].Params.AdditionalFields["progressToken"]) + assert.EqualValues(t, 100, progressNotifications[0].Params.AdditionalFields["total"]) + + // Assert second progress notification (end func) + assert.Equal(t, string(mcp.MethodNotificationProgress), progressNotifications[1].Method) + assert.Equal(t, "end func", progressNotifications[1].Params.AdditionalFields["message"]) + assert.EqualValues(t, 100, progressNotifications[1].Params.AdditionalFields["progress"]) + assert.Equal(t, "progress_token", progressNotifications[1].Params.AdditionalFields["progressToken"]) + assert.EqualValues(t, 100, progressNotifications[1].Params.AdditionalFields["total"]) + }) } diff --git a/server/server.go b/server/server.go index ac796502f..6f7f9bf08 100644 --- a/server/server.go +++ b/server/server.go @@ -994,7 +994,7 @@ func (s *MCPServer) handleToolCall( finalHandler = mw[i](finalHandler) } - requestContext := NewRequestContext(s, request.Request.Params.Meta) + requestContext := NewRequestContext(s, request.Params.Meta) result, err := finalHandler(ctx, requestContext, request) if err != nil { return nil, &requestError{ From 2704ca519f2af2c090aa2137b98cdd5999751297 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=9C=E8=89=AE=E9=AD=81?= Date: Mon, 26 May 2025 02:38:24 +0800 Subject: [PATCH 09/15] add test --- client/sse_test.go | 49 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/client/sse_test.go b/client/sse_test.go index 5c1bc77da..96754aa47 100644 --- a/client/sse_test.go +++ b/client/sse_test.go @@ -455,4 +455,53 @@ func TestSSEMCPClient(t *testing.T) { assert.Equal(t, "progress_token", progressNotifications[1].Params.AdditionalFields["progressToken"]) assert.EqualValues(t, 100, progressNotifications[1].Params.AdditionalFields["total"]) }) + + t.Run("Ensure the server does not send notifications", func(t *testing.T) { + client, err := NewSSEMCPClient(testServer.URL + "/sse") + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + notifications := make([]*mcp.JSONRPCNotification, 0) + client.OnNotification(func(notification mcp.JSONRPCNotification) { + notifications = append(notifications, ¬ification) + }) + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := client.Start(ctx); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = client.Initialize(ctx, initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + + setLevelRequest := mcp.SetLevelRequest{} + setLevelRequest.Params.Level = mcp.LoggingLevelCritical + err = client.SetLevel(ctx, setLevelRequest) + if err != nil { + t.Errorf("SetLevel failed: %v", err) + } + + // request param without progressToken + request := mcp.CallToolRequest{} + request.Params.Name = "test-tool-for-sending-notification" + + _, _ = client.CallTool(ctx, request) + time.Sleep(time.Millisecond * 200) + + assert.Len(t, notifications, 0) + }) } From 86db7bba3e1761163f70b28ae214e8f9564332c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=9C=E8=89=AE=E9=AD=81?= Date: Mon, 26 May 2025 03:22:57 +0800 Subject: [PATCH 10/15] add test --- README.md | 8 +- client/inprocess_test.go | 6 +- client/sse_test.go | 142 +++++++++++++++++++++++++++++++- examples/custom_context/main.go | 2 +- examples/dynamic_path/main.go | 2 +- examples/everything/main.go | 20 ++--- mcp/prompts.go | 1 + mcptest/mcptest_test.go | 2 +- server/resource_test.go | 8 +- server/server.go | 2 +- server/server_race_test.go | 20 ++--- server/server_test.go | 30 +++---- server/session_test.go | 10 +-- server/sse_test.go | 6 +- server/stdio_test.go | 2 +- server/streamable_http_test.go | 4 +- 16 files changed, 200 insertions(+), 65 deletions(-) diff --git a/README.md b/README.md index f7283e996..7528295b4 100644 --- a/README.md +++ b/README.md @@ -542,17 +542,17 @@ The RequestContext object provides capabilities to interact with the client, suc mcpServer.AddTool(mcp.NewTool( "test-RequestContent", mcp.WithDescription("test RequestContent"), -), func(ctx context.Context, reqContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - // you could invoke `reqContext.IsLoggingNotificationSupported()` first the check if server supports logging notification +), func(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // you could invoke `requestContext.IsLoggingNotificationSupported()` first the check if server supports logging notification // ff server does not support logging notification, this method will do nothing. - _ = reqContext.SendLoggingNotification(ctx, mcp.LoggingLevelInfo, map[string]any{ + _ = requestContext.SendLoggingNotification(ctx, mcp.LoggingLevelInfo, map[string]any{ "testLog": "test send log notification", }) // server should send progress notification if request metadata includes a progressToken total := float64(100) progressMessage := "human readable progress information" - _ = reqContext.SendProgressNotification(ctx, float64(50), &total, &progressMessage) + _ = requestContext.SendProgressNotification(ctx, float64(50), &total, &progressMessage) return &mcp.CallToolResult{ Content: []mcp.Content{ diff --git a/client/inprocess_test.go b/client/inprocess_test.go index cdf426ccb..e0bcfac79 100644 --- a/client/inprocess_test.go +++ b/client/inprocess_test.go @@ -27,7 +27,7 @@ func TestInProcessMCPClient(t *testing.T) { mcp.WithDestructiveHintAnnotation(false), mcp.WithIdempotentHintAnnotation(true), mcp.WithOpenWorldHintAnnotation(false), - ), func(ctx context.Context, reqContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ), func(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{ Content: []mcp.Content{ mcp.TextContent{ @@ -48,7 +48,7 @@ func TestInProcessMCPClient(t *testing.T) { URI: "resource://testresource", Name: "My Resource", }, - func(ctx context.Context, reqContext server.RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, requestContext server.RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: "resource://testresource", @@ -70,7 +70,7 @@ func TestInProcessMCPClient(t *testing.T) { }, }, }, - func(ctx context.Context, reqContext server.RequestContext, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + func(ctx context.Context, requestContext server.RequestContext, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{ Messages: []mcp.PromptMessage{ { diff --git a/client/sse_test.go b/client/sse_test.go index 96754aa47..2174c4568 100644 --- a/client/sse_test.go +++ b/client/sse_test.go @@ -2,6 +2,7 @@ package client import ( "context" + "fmt" "github.com/stretchr/testify/assert" "net/http" "testing" @@ -70,10 +71,9 @@ func TestSSEMCPClient(t *testing.T) { "test-tool-for-sending-notification", mcp.WithDescription("Test tool for sending log notification, and the log level is warn"), ), func(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - - totalProgreddValue := float64(100) + totalProgressValue := float64(100) startFuncMessage := "start func" - err := requestContext.SendProgressNotification(ctx, float64(0), &totalProgreddValue, &startFuncMessage) + err := requestContext.SendProgressNotification(ctx, float64(0), &totalProgressValue, &startFuncMessage) if err != nil { return nil, err } @@ -92,7 +92,7 @@ func TestSSEMCPClient(t *testing.T) { } startFuncMessage = "end func" - err = requestContext.SendProgressNotification(ctx, float64(100), &totalProgreddValue, &startFuncMessage) + err = requestContext.SendProgressNotification(ctx, float64(100), &totalProgressValue, &startFuncMessage) if err != nil { return nil, err } @@ -106,6 +106,48 @@ func TestSSEMCPClient(t *testing.T) { }, }, nil }) + mcpServer.AddPrompt(mcp.Prompt{ + Name: "prompt_get_for_server_notification", + Description: "Test prompt", + }, func(ctx context.Context, requestContext server.RequestContext, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + totalProgressValue := float64(100) + startFuncMessage := "start get prompt" + err := requestContext.SendProgressNotification(ctx, float64(0), &totalProgressValue, &startFuncMessage) + if err != nil { + return nil, err + } + + err = requestContext.SendLoggingNotification(ctx, mcp.LoggingLevelInfo, map[string]any{ + "filtered_log_message": "will be filtered by log level", + }) + if err != nil { + return nil, err + } + err = requestContext.SendLoggingNotification(ctx, mcp.LoggingLevelError, map[string]any{ + "log_message": "log message value", + }) + if err != nil { + return nil, err + } + + startFuncMessage = "end get prompt" + err = requestContext.SendProgressNotification(ctx, float64(100), &totalProgressValue, &startFuncMessage) + if err != nil { + return nil, err + } + + return &mcp.GetPromptResult{ + Messages: []mcp.PromptMessage{ + { + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: "prompt value", + }, + }, + }, + }, nil + }) // Initialize testServer := server.NewTestServer(mcpServer, @@ -380,6 +422,7 @@ func TestSSEMCPClient(t *testing.T) { t.Fatalf("Failed to create client: %v", err) } + notificationNum := 0 var messageNotification *mcp.JSONRPCNotification progressNotifications := make([]*mcp.JSONRPCNotification, 0) client.OnNotification(func(notification mcp.JSONRPCNotification) { @@ -388,6 +431,7 @@ func TestSSEMCPClient(t *testing.T) { } else if notification.Method == string(mcp.MethodNotificationProgress) { progressNotifications = append(progressNotifications, ¬ification) } + notificationNum += 1 }) defer client.Close() @@ -434,6 +478,7 @@ func TestSSEMCPClient(t *testing.T) { time.Sleep(time.Millisecond * 200) + assert.Equal(t, notificationNum, 3) assert.NotNil(t, messageNotification) assert.Equal(t, messageNotification.Method, string(mcp.MethodNotificationMessage)) assert.Equal(t, messageNotification.Params.AdditionalFields["level"], "error") @@ -504,4 +549,93 @@ func TestSSEMCPClient(t *testing.T) { assert.Len(t, notifications, 0) }) + + t.Run("GetPrompt for testing log and progress notification", func(t *testing.T) { + client, err := NewSSEMCPClient(testServer.URL + "/sse") + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + var messageNotification *mcp.JSONRPCNotification + progressNotifications := make([]*mcp.JSONRPCNotification, 0) + notificationNum := 0 + client.OnNotification(func(notification mcp.JSONRPCNotification) { + println(notification.Method) + if notification.Method == string(mcp.MethodNotificationMessage) { + messageNotification = ¬ification + } else if notification.Method == string(mcp.MethodNotificationProgress) { + progressNotifications = append(progressNotifications, ¬ification) + } + notificationNum += 1 + }) + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := client.Start(ctx); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = client.Initialize(ctx, initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + + setLevelRequest := mcp.SetLevelRequest{} + setLevelRequest.Params.Level = mcp.LoggingLevelWarning + err = client.SetLevel(ctx, setLevelRequest) + if err != nil { + t.Errorf("SetLevel failed: %v", err) + } + + request := mcp.GetPromptRequest{} + request.Params.Name = "prompt_get_for_server_notification" + request.Params.Meta = &mcp.Meta{ + ProgressToken: "progress_token", + } + + result, err := client.GetPrompt(ctx, request) + if err != nil { + t.Fatalf("GetPrompt failed: %v", err) + } + assert.NotNil(t, result) + assert.Len(t, result.Messages, 1) + assert.Equal(t, result.Messages[0].Role, mcp.RoleAssistant) + assert.Equal(t, result.Messages[0].Content.(mcp.TextContent).Type, "text") + assert.Equal(t, result.Messages[0].Content.(mcp.TextContent).Text, "prompt value") + + println(fmt.Sprintf("%v", result)) + + time.Sleep(time.Millisecond * 200) + + assert.Equal(t, notificationNum, 3) + assert.NotNil(t, messageNotification) + assert.Equal(t, messageNotification.Method, string(mcp.MethodNotificationMessage)) + assert.Equal(t, messageNotification.Params.AdditionalFields["level"], "error") + assert.Equal(t, messageNotification.Params.AdditionalFields["data"], map[string]any{ + "log_message": "log message value", + }) + + assert.Len(t, progressNotifications, 2) + assert.Equal(t, string(mcp.MethodNotificationProgress), progressNotifications[0].Method) + assert.Equal(t, "start get prompt", progressNotifications[0].Params.AdditionalFields["message"]) + assert.EqualValues(t, 0, progressNotifications[0].Params.AdditionalFields["progress"]) + assert.Equal(t, "progress_token", progressNotifications[0].Params.AdditionalFields["progressToken"]) + assert.EqualValues(t, 100, progressNotifications[0].Params.AdditionalFields["total"]) + + assert.Equal(t, string(mcp.MethodNotificationProgress), progressNotifications[1].Method) + assert.Equal(t, "end get prompt", progressNotifications[1].Params.AdditionalFields["message"]) + assert.EqualValues(t, 100, progressNotifications[1].Params.AdditionalFields["progress"]) + assert.Equal(t, "progress_token", progressNotifications[1].Params.AdditionalFields["progressToken"]) + assert.EqualValues(t, 100, progressNotifications[1].Params.AdditionalFields["total"]) + }) } diff --git a/examples/custom_context/main.go b/examples/custom_context/main.go index 1232c6676..42e83a268 100644 --- a/examples/custom_context/main.go +++ b/examples/custom_context/main.go @@ -79,7 +79,7 @@ func makeRequest(ctx context.Context, message, token string) (*response, error) // using the token from the context. func handleMakeAuthenticatedRequestTool( ctx context.Context, - reqContext server.RequestContext, + requestContext server.RequestContext, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { message, ok := request.GetArguments()["message"].(string) diff --git a/examples/dynamic_path/main.go b/examples/dynamic_path/main.go index 52da36756..fdad69b0b 100644 --- a/examples/dynamic_path/main.go +++ b/examples/dynamic_path/main.go @@ -19,7 +19,7 @@ func main() { mcpServer := server.NewMCPServer("dynamic-path-example", "1.0.0") // Add a trivial tool for demonstration - mcpServer.AddTool(mcp.NewTool("echo"), func(ctx context.Context, reqContext server.RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + mcpServer.AddTool(mcp.NewTool("echo"), func(ctx context.Context, requestContext server.RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText(fmt.Sprintf("Echo: %v", req.GetArguments()["message"])), nil }) diff --git a/examples/everything/main.go b/examples/everything/main.go index a4bcc3ffe..474f3e78c 100644 --- a/examples/everything/main.go +++ b/examples/everything/main.go @@ -191,7 +191,7 @@ func generateResources() []mcp.Resource { func handleReadResource( ctx context.Context, - reqContext server.RequestContext, + requestContext server.RequestContext, request mcp.ReadResourceRequest, ) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ @@ -205,7 +205,7 @@ func handleReadResource( func handleResourceTemplate( ctx context.Context, - reqContext server.RequestContext, + requestContext server.RequestContext, request mcp.ReadResourceRequest, ) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ @@ -219,7 +219,7 @@ func handleResourceTemplate( func handleGeneratedResource( ctx context.Context, - reqContext server.RequestContext, + requestContext server.RequestContext, request mcp.ReadResourceRequest, ) ([]mcp.ResourceContents, error) { uri := request.Params.URI @@ -257,7 +257,7 @@ func handleGeneratedResource( func handleSimplePrompt( ctx context.Context, - reqContext server.RequestContext, + requestContext server.RequestContext, request mcp.GetPromptRequest, ) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{ @@ -276,7 +276,7 @@ func handleSimplePrompt( func handleComplexPrompt( ctx context.Context, - reqContext server.RequestContext, + requestContext server.RequestContext, request mcp.GetPromptRequest, ) (*mcp.GetPromptResult, error) { arguments := request.Params.Arguments @@ -315,7 +315,7 @@ func handleComplexPrompt( func handleEchoTool( ctx context.Context, - reqContext server.RequestContext, + requestContext server.RequestContext, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { arguments := request.GetArguments() @@ -335,7 +335,7 @@ func handleEchoTool( func handleAddTool( ctx context.Context, - reqContext server.RequestContext, + requestContext server.RequestContext, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { arguments := request.GetArguments() @@ -357,7 +357,7 @@ func handleAddTool( func handleSendNotification( ctx context.Context, - reqContext server.RequestContext, + requestContext server.RequestContext, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { @@ -388,7 +388,7 @@ func handleSendNotification( func handleLongRunningOperationTool( ctx context.Context, - reqContext server.RequestContext, + requestContext server.RequestContext, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { arguments := request.GetArguments() @@ -454,7 +454,7 @@ func handleLongRunningOperationTool( func handleGetTinyImageTool( ctx context.Context, - reqContext server.RequestContext, + requestContext server.RequestContext, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{ diff --git a/mcp/prompts.go b/mcp/prompts.go index db2fc3b82..1bfc54237 100644 --- a/mcp/prompts.go +++ b/mcp/prompts.go @@ -24,6 +24,7 @@ type GetPromptRequest struct { Name string `json:"name"` // Arguments to use for templating the prompt. Arguments map[string]string `json:"arguments,omitempty"` + Meta *Meta `json:"_meta,omitempty"` } `json:"params"` } diff --git a/mcptest/mcptest_test.go b/mcptest/mcptest_test.go index 5e44f0f41..55b2276f6 100644 --- a/mcptest/mcptest_test.go +++ b/mcptest/mcptest_test.go @@ -50,7 +50,7 @@ func TestServer(t *testing.T) { } } -func helloWorldHandler(ctx context.Context, reqContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func helloWorldHandler(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Extract name from request arguments name, ok := request.GetArguments()["name"].(string) if !ok { diff --git a/server/resource_test.go b/server/resource_test.go index e32cb2675..fdc014c99 100644 --- a/server/resource_test.go +++ b/server/resource_test.go @@ -28,7 +28,7 @@ func TestMCPServer_RemoveResource(t *testing.T) { mcp.WithResourceDescription("Test resource 1"), mcp.WithMIMEType("text/plain"), ), - func(ctx context.Context, reqContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, requestContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: "test://resource1", @@ -47,7 +47,7 @@ func TestMCPServer_RemoveResource(t *testing.T) { mcp.WithResourceDescription("Test resource 2"), mcp.WithMIMEType("text/plain"), ), - func(ctx context.Context, reqContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, requestContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: "test://resource2", @@ -108,7 +108,7 @@ func TestMCPServer_RemoveResource(t *testing.T) { mcp.WithResourceDescription("Test resource 1"), mcp.WithMIMEType("text/plain"), ), - func(ctx context.Context, reqContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, requestContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: "test://resource1", @@ -164,7 +164,7 @@ func TestMCPServer_RemoveResource(t *testing.T) { mcp.WithResourceDescription("Test resource 1"), mcp.WithMIMEType("text/plain"), ), - func(ctx context.Context, reqContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, requestContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: "test://resource1", diff --git a/server/server.go b/server/server.go index 442f0dad0..515b92116 100644 --- a/server/server.go +++ b/server/server.go @@ -848,7 +848,7 @@ func (s *MCPServer) handleGetPrompt( } } - requestContext := NewRequestContext(s, request.Request.Params.Meta) + requestContext := NewRequestContext(s, request.Params.Meta) result, err := handler(ctx, requestContext, request) if err != nil { return nil, &requestError{ diff --git a/server/server_race_test.go b/server/server_race_test.go index 17354609c..228686c86 100644 --- a/server/server_race_test.go +++ b/server/server_race_test.go @@ -39,7 +39,7 @@ func TestRaceConditions(t *testing.T) { srv.AddPrompt(mcp.Prompt{ Name: name, Description: "Test prompt", - }, func(ctx context.Context, reqContext RequestContext, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + }, func(ctx context.Context, requestContext RequestContext, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{}, nil }) }) @@ -49,7 +49,7 @@ func TestRaceConditions(t *testing.T) { srv.AddPrompt(mcp.Prompt{ Name: name, Description: "Temporary prompt", - }, func(ctx context.Context, reqContext RequestContext, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + }, func(ctx context.Context, requestContext RequestContext, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{}, nil }) srv.DeletePrompts(name) @@ -60,7 +60,7 @@ func TestRaceConditions(t *testing.T) { srv.AddTool(mcp.Tool{ Name: name, Description: "Test tool", - }, func(ctx context.Context, reqContext RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + }, func(ctx context.Context, requestContext RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }) }) @@ -71,7 +71,7 @@ func TestRaceConditions(t *testing.T) { srv.AddTool(mcp.Tool{ Name: name, Description: "Temporary tool", - }, func(ctx context.Context, reqContext RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + }, func(ctx context.Context, requestContext RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }) srv.DeleteTools(name) @@ -79,8 +79,8 @@ func TestRaceConditions(t *testing.T) { runConcurrentOperation(&wg, testDuration, "add-middleware", func() { middleware := func(next ToolHandlerFunc) ToolHandlerFunc { - return func(ctx context.Context, reqContext RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return next(ctx, reqContext, req) + return func(ctx context.Context, requestContext RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return next(ctx, requestContext, req) } } WithToolHandlerMiddleware(middleware)(srv) @@ -102,7 +102,7 @@ func TestRaceConditions(t *testing.T) { srv.AddTool(mcp.Tool{ Name: "persistent-tool", Description: "Test tool that always exists", - }, func(ctx context.Context, reqContext RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + }, func(ctx context.Context, requestContext RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }) @@ -121,7 +121,7 @@ func TestRaceConditions(t *testing.T) { URI: uri, Name: uri, Description: "Test resource", - }, func(ctx context.Context, reqContext RequestContext, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + }, func(ctx context.Context, requestContext RequestContext, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: uri, @@ -169,12 +169,12 @@ func TestConcurrentPromptAdd(t *testing.T) { srv.AddPrompt(mcp.Prompt{ Name: "initial-prompt", Description: "Initial prompt", - }, func(ctx context.Context, reqContext RequestContext, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + }, func(ctx context.Context, requestContext RequestContext, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { go func() { srv.AddPrompt(mcp.Prompt{ Name: fmt.Sprintf("new-prompt-%d", time.Now().UnixNano()), Description: "Added from handler", - }, func(ctx context.Context, reqContext RequestContext, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + }, func(ctx context.Context, requestContext RequestContext, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{}, nil }) }() diff --git a/server/server_test.go b/server/server_test.go index 439d01a5c..78edb1975 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -160,12 +160,12 @@ func TestMCPServer_Tools(t *testing.T) { action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { server.SetTools(ServerTool{ Tool: mcp.NewTool("test-tool-1"), - Handler: func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, }, ServerTool{ Tool: mcp.NewTool("test-tool-2"), - Handler: func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, }) @@ -189,12 +189,12 @@ func TestMCPServer_Tools(t *testing.T) { require.NoError(t, err) server.SetTools(ServerTool{ Tool: mcp.NewTool("test-tool-1"), - Handler: func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, }, ServerTool{ Tool: mcp.NewTool("test-tool-2"), - Handler: func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, }) @@ -230,12 +230,12 @@ func TestMCPServer_Tools(t *testing.T) { } server.SetTools(ServerTool{ Tool: mcp.NewTool("test-tool-1"), - Handler: func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, }, ServerTool{ Tool: mcp.NewTool("test-tool-2"), - Handler: func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, }) @@ -262,13 +262,13 @@ func TestMCPServer_Tools(t *testing.T) { require.NoError(t, err) server.AddTool( mcp.NewTool("test-tool-1"), - func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, ) server.AddTool( mcp.NewTool("test-tool-2"), - func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, ) @@ -706,7 +706,7 @@ func TestMCPServer_PromptHandling(t *testing.T) { server.AddPrompt( testPrompt, - func(ctx context.Context, reqContext RequestContext, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + func(ctx context.Context, requestContext RequestContext, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{ Messages: []mcp.PromptMessage{ { @@ -1121,7 +1121,7 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) { IdempotentHint: mcp.ToBoolPtr(false), OpenWorldHint: mcp.ToBoolPtr(false), }, - }, func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + }, func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }) @@ -1383,7 +1383,7 @@ func TestMCPServer_ResourceTemplates(t *testing.T) { "test://{a}/test-resource{/b*}", "My Resource", ), - func(ctx context.Context, reqContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, requestContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { a := request.Params.Arguments["a"].([]string) b := request.Params.Arguments["b"].([]string) // Validate that the template arguments are passed correctly to the handler @@ -1470,7 +1470,7 @@ func createTestServer() *MCPServer { URI: "resource://testresource", Name: "My Resource", }, - func(ctx context.Context, reqContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, requestContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: "resource://testresource", @@ -1486,7 +1486,7 @@ func createTestServer() *MCPServer { Name: "test-tool", Description: "Test tool", }, - func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{ Content: []mcp.Content{ mcp.TextContent{ @@ -1632,7 +1632,7 @@ func TestMCPServer_WithHooks(t *testing.T) { // Add a test tool server.AddTool( mcp.NewTool("test-tool"), - func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, ) @@ -1803,7 +1803,7 @@ func TestMCPServer_SessionHooks_NilHooks(t *testing.T) { } func TestMCPServer_WithRecover(t *testing.T) { - panicToolHandler := func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + panicToolHandler := func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { panic("test panic") } diff --git a/server/session_test.go b/server/session_test.go index 700976f8f..fa9f142ee 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -222,7 +222,7 @@ func TestSessionWithTools_Integration(t *testing.T) { // Create session-specific tools sessionTool := ServerTool{ Tool: mcp.NewTool("session-tool"), - Handler: func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText("session-tool result"), nil }, } @@ -391,7 +391,7 @@ func TestMCPServer_AddSessionTool(t *testing.T) { err = server.AddSessionTool( session.SessionID(), mcp.NewTool("session-tool-helper"), - func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText("helper result"), nil }, ) @@ -596,7 +596,7 @@ func TestMCPServer_CallSessionTool(t *testing.T) { server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) // Add global tool - server.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + server.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText("global result"), nil }) @@ -616,7 +616,7 @@ func TestMCPServer_CallSessionTool(t *testing.T) { err = server.AddSessionTool( session.SessionID(), mcp.NewTool("test_tool"), - func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText("session result"), nil }, ) @@ -924,7 +924,7 @@ func TestMCPServer_SessionToolCapabilitiesBehavior(t *testing.T) { validateServer func(t *testing.T, s *MCPServer, session *sessionTestClientWithTools) }{ { - name: "no tool capabilities provided", + name: "no tool capabilities provided", serverOptions: []ServerOption{ // No WithToolCapabilities }, diff --git a/server/sse_test.go b/server/sse_test.go index 92f4c1fd4..3def52d13 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -574,7 +574,7 @@ func TestSSEServer(t *testing.T) { WithResourceCapabilities(true, true), ) // Add a tool which uses the context function. - mcpServer.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + mcpServer.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Note this is agnostic to the transport type i.e. doesn't know about request headers. testVal := testValFromContext(ctx) return mcp.NewToolResultText(testVal), nil @@ -1183,7 +1183,7 @@ func TestSSEServer(t *testing.T) { Title: "Test Tool", }, }, - Handler: func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText("test"), nil }, }, @@ -1335,7 +1335,7 @@ func TestSSEServer(t *testing.T) { processingCompleted := make(chan struct{}) processingStarted := make(chan struct{}) - mcpServer.AddTool(mcp.NewTool("slowMethod"), func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + mcpServer.AddTool(mcp.NewTool("slowMethod"), func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { close(processingStarted) // signal for processing started select { diff --git a/server/stdio_test.go b/server/stdio_test.go index 448f29913..e47d0cbf3 100644 --- a/server/stdio_test.go +++ b/server/stdio_test.go @@ -140,7 +140,7 @@ func TestStdioServer(t *testing.T) { // Create server mcpServer := NewMCPServer("test", "1.0.0") // Add a tool which uses the context function. - mcpServer.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + mcpServer.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Note this is agnostic to the transport type i.e. doesn't know about request headers. testVal := testValFromContext(ctx) return mcp.NewToolResultText(testVal), nil diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index 891ba0000..abb227733 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -39,7 +39,7 @@ var initRequest = map[string]any{ func addSSETool(mcpServer *MCPServer) { mcpServer.AddTool(mcp.Tool{ Name: "sseTool", - }, func(ctx context.Context, reqContext RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + }, func(ctx context.Context, requestContext RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Send notification to client server := ServerFromContext(ctx) for i := 0; i < 10; i++ { @@ -601,7 +601,7 @@ func TestStreamableHTTP_SessionWithTools(t *testing.T) { Title: "Test Tool", }, }, - Handler: func(ctx context.Context, reqContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText("test"), nil }, }, From 5fc3c83d06fde768382407f10ef66ebd76f0a80a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=9C=E8=89=AE=E9=AD=81?= Date: Mon, 26 May 2025 03:48:09 +0800 Subject: [PATCH 11/15] add test --- client/sse_test.go | 126 +++++++++++++++++++++++++++++++++++++++++++++ mcp/prompts.go | 4 +- mcp/types.go | 3 ++ server/server.go | 2 +- 4 files changed, 133 insertions(+), 2 deletions(-) diff --git a/client/sse_test.go b/client/sse_test.go index 2174c4568..9998629bf 100644 --- a/client/sse_test.go +++ b/client/sse_test.go @@ -149,6 +149,44 @@ func TestSSEMCPClient(t *testing.T) { }, nil }) + mcpServer.AddResource(mcp.Resource{ + URI: "resource://testresource", + Name: "My Resource", + }, func(ctx context.Context, requestContext server.RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + totalProgressValue := float64(100) + startFuncMessage := "start read resource" + err := requestContext.SendProgressNotification(ctx, float64(0), &totalProgressValue, &startFuncMessage) + if err != nil { + return nil, err + } + + err = requestContext.SendLoggingNotification(ctx, mcp.LoggingLevelInfo, map[string]any{ + "filtered_log_message": "will be filtered by log level", + }) + if err != nil { + return nil, err + } + err = requestContext.SendLoggingNotification(ctx, mcp.LoggingLevelError, map[string]any{ + "log_message": "log message value", + }) + if err != nil { + return nil, err + } + + startFuncMessage = "end read resource" + err = requestContext.SendProgressNotification(ctx, float64(100), &totalProgressValue, &startFuncMessage) + if err != nil { + return nil, err + } + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: "resource://testresource", + MIMEType: "text/plain", + Text: "test content", + }, + }, nil + }) + // Initialize testServer := server.NewTestServer(mcpServer, server.WithSSEContextFunc(func(ctx context.Context, r *http.Request) context.Context { @@ -638,4 +676,92 @@ func TestSSEMCPClient(t *testing.T) { assert.Equal(t, "progress_token", progressNotifications[1].Params.AdditionalFields["progressToken"]) assert.EqualValues(t, 100, progressNotifications[1].Params.AdditionalFields["total"]) }) + + t.Run("GetResource for testing log and progress notification", func(t *testing.T) { + client, err := NewSSEMCPClient(testServer.URL + "/sse") + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + var messageNotification *mcp.JSONRPCNotification + progressNotifications := make([]*mcp.JSONRPCNotification, 0) + notificationNum := 0 + client.OnNotification(func(notification mcp.JSONRPCNotification) { + println(notification.Method) + if notification.Method == string(mcp.MethodNotificationMessage) { + messageNotification = ¬ification + } else if notification.Method == string(mcp.MethodNotificationProgress) { + progressNotifications = append(progressNotifications, ¬ification) + } + notificationNum += 1 + }) + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := client.Start(ctx); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = client.Initialize(ctx, initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + + setLevelRequest := mcp.SetLevelRequest{} + setLevelRequest.Params.Level = mcp.LoggingLevelWarning + err = client.SetLevel(ctx, setLevelRequest) + if err != nil { + t.Errorf("SetLevel failed: %v", err) + } + + request := mcp.ReadResourceRequest{} + request.Params.URI = "resource://testresource" + request.Params.Meta = &mcp.Meta{ + ProgressToken: "progress_token", + } + + result, err := client.ReadResource(ctx, request) + if err != nil { + t.Fatalf("ReadResource failed: %v", err) + } + + assert.NotNil(t, result) + assert.Len(t, result.Contents, 1) + assert.Equal(t, result.Contents[0].(mcp.TextResourceContents).URI, "resource://testresource") + assert.Equal(t, result.Contents[0].(mcp.TextResourceContents).MIMEType, "text/plain") + assert.Equal(t, result.Contents[0].(mcp.TextResourceContents).Text, "test content") + + time.Sleep(time.Millisecond * 200) + + assert.Equal(t, notificationNum, 3) + assert.NotNil(t, messageNotification) + assert.Equal(t, messageNotification.Method, string(mcp.MethodNotificationMessage)) + assert.Equal(t, messageNotification.Params.AdditionalFields["level"], "error") + assert.Equal(t, messageNotification.Params.AdditionalFields["data"], map[string]any{ + "log_message": "log message value", + }) + + assert.Len(t, progressNotifications, 2) + assert.Equal(t, string(mcp.MethodNotificationProgress), progressNotifications[0].Method) + assert.Equal(t, "start read resource", progressNotifications[0].Params.AdditionalFields["message"]) + assert.EqualValues(t, 0, progressNotifications[0].Params.AdditionalFields["progress"]) + assert.Equal(t, "progress_token", progressNotifications[0].Params.AdditionalFields["progressToken"]) + assert.EqualValues(t, 100, progressNotifications[0].Params.AdditionalFields["total"]) + + assert.Equal(t, string(mcp.MethodNotificationProgress), progressNotifications[1].Method) + assert.Equal(t, "end read resource", progressNotifications[1].Params.AdditionalFields["message"]) + assert.EqualValues(t, 100, progressNotifications[1].Params.AdditionalFields["progress"]) + assert.Equal(t, "progress_token", progressNotifications[1].Params.AdditionalFields["progressToken"]) + assert.EqualValues(t, 100, progressNotifications[1].Params.AdditionalFields["total"]) + }) } diff --git a/mcp/prompts.go b/mcp/prompts.go index 1bfc54237..5029e304f 100644 --- a/mcp/prompts.go +++ b/mcp/prompts.go @@ -24,7 +24,9 @@ type GetPromptRequest struct { Name string `json:"name"` // Arguments to use for templating the prompt. Arguments map[string]string `json:"arguments,omitempty"` - Meta *Meta `json:"_meta,omitempty"` + // Meta is metadata attached to a request's parameters. This can include fields + // formally defined by the protocol or other arbitrary data. + Meta *Meta `json:"_meta,omitempty"` } `json:"params"` } diff --git a/mcp/types.go b/mcp/types.go index 09bd4e025..62e694c9a 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -531,6 +531,9 @@ type ReadResourceRequest struct { URI string `json:"uri"` // Arguments to pass to the resource handler Arguments map[string]any `json:"arguments,omitempty"` + // Meta is metadata attached to a request's parameters. This can include fields + // formally defined by the protocol or other arbitrary data. + Meta *Meta `json:"_meta,omitempty"` } `json:"params"` } diff --git a/server/server.go b/server/server.go index 515b92116..a34d525f3 100644 --- a/server/server.go +++ b/server/server.go @@ -728,7 +728,7 @@ func (s *MCPServer) handleReadResource( request mcp.ReadResourceRequest, ) (*mcp.ReadResourceResult, *requestError) { - requestContext := NewRequestContext(s, request.Request.Params.Meta) + requestContext := NewRequestContext(s, request.Params.Meta) s.resourcesMu.RLock() // First try direct resource handlers From c5b90a8663fe54db253bedcf77d4f922fb3aa939 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=9C=E8=89=AE=E9=AD=81?= Date: Mon, 26 May 2025 04:00:46 +0800 Subject: [PATCH 12/15] fix data race in test --- client/sse_test.go | 36 ++++++++++++++++++++++++++---------- server/sse.go | 3 ++- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/client/sse_test.go b/client/sse_test.go index 9998629bf..833eeade5 100644 --- a/client/sse_test.go +++ b/client/sse_test.go @@ -2,9 +2,9 @@ package client import ( "context" - "fmt" "github.com/stretchr/testify/assert" "net/http" + "sync" "testing" "time" @@ -460,10 +460,13 @@ func TestSSEMCPClient(t *testing.T) { t.Fatalf("Failed to create client: %v", err) } + mu := sync.Mutex{} notificationNum := 0 var messageNotification *mcp.JSONRPCNotification progressNotifications := make([]*mcp.JSONRPCNotification, 0) client.OnNotification(func(notification mcp.JSONRPCNotification) { + mu.Lock() + defer mu.Unlock() if notification.Method == string(mcp.MethodNotificationMessage) { messageNotification = ¬ification } else if notification.Method == string(mcp.MethodNotificationProgress) { @@ -514,8 +517,10 @@ func TestSSEMCPClient(t *testing.T) { t.Errorf("Expected 1 content item, got %d", len(result.Content)) } - time.Sleep(time.Millisecond * 200) + time.Sleep(time.Millisecond * 500) + mu.Lock() + defer mu.Unlock() assert.Equal(t, notificationNum, 3) assert.NotNil(t, messageNotification) assert.Equal(t, messageNotification.Method, string(mcp.MethodNotificationMessage)) @@ -537,6 +542,7 @@ func TestSSEMCPClient(t *testing.T) { assert.EqualValues(t, 100, progressNotifications[1].Params.AdditionalFields["progress"]) assert.Equal(t, "progress_token", progressNotifications[1].Params.AdditionalFields["progressToken"]) assert.EqualValues(t, 100, progressNotifications[1].Params.AdditionalFields["total"]) + }) t.Run("Ensure the server does not send notifications", func(t *testing.T) { @@ -545,8 +551,11 @@ func TestSSEMCPClient(t *testing.T) { t.Fatalf("Failed to create client: %v", err) } + mu := sync.Mutex{} notifications := make([]*mcp.JSONRPCNotification, 0) client.OnNotification(func(notification mcp.JSONRPCNotification) { + mu.Lock() + defer mu.Unlock() notifications = append(notifications, ¬ification) }) defer client.Close() @@ -583,8 +592,10 @@ func TestSSEMCPClient(t *testing.T) { request.Params.Name = "test-tool-for-sending-notification" _, _ = client.CallTool(ctx, request) - time.Sleep(time.Millisecond * 200) + time.Sleep(time.Millisecond * 500) + mu.Lock() + defer mu.Unlock() assert.Len(t, notifications, 0) }) @@ -594,11 +605,13 @@ func TestSSEMCPClient(t *testing.T) { t.Fatalf("Failed to create client: %v", err) } + mu := sync.Mutex{} var messageNotification *mcp.JSONRPCNotification progressNotifications := make([]*mcp.JSONRPCNotification, 0) notificationNum := 0 client.OnNotification(func(notification mcp.JSONRPCNotification) { - println(notification.Method) + mu.Lock() + defer mu.Unlock() if notification.Method == string(mcp.MethodNotificationMessage) { messageNotification = ¬ification } else if notification.Method == string(mcp.MethodNotificationProgress) { @@ -650,11 +663,10 @@ func TestSSEMCPClient(t *testing.T) { assert.Equal(t, result.Messages[0].Role, mcp.RoleAssistant) assert.Equal(t, result.Messages[0].Content.(mcp.TextContent).Type, "text") assert.Equal(t, result.Messages[0].Content.(mcp.TextContent).Text, "prompt value") + time.Sleep(time.Millisecond * 500) - println(fmt.Sprintf("%v", result)) - - time.Sleep(time.Millisecond * 200) - + mu.Lock() + defer mu.Unlock() assert.Equal(t, notificationNum, 3) assert.NotNil(t, messageNotification) assert.Equal(t, messageNotification.Method, string(mcp.MethodNotificationMessage)) @@ -683,11 +695,13 @@ func TestSSEMCPClient(t *testing.T) { t.Fatalf("Failed to create client: %v", err) } + mu := sync.Mutex{} var messageNotification *mcp.JSONRPCNotification progressNotifications := make([]*mcp.JSONRPCNotification, 0) notificationNum := 0 client.OnNotification(func(notification mcp.JSONRPCNotification) { - println(notification.Method) + mu.Lock() + defer mu.Unlock() if notification.Method == string(mcp.MethodNotificationMessage) { messageNotification = ¬ification } else if notification.Method == string(mcp.MethodNotificationProgress) { @@ -741,8 +755,10 @@ func TestSSEMCPClient(t *testing.T) { assert.Equal(t, result.Contents[0].(mcp.TextResourceContents).MIMEType, "text/plain") assert.Equal(t, result.Contents[0].(mcp.TextResourceContents).Text, "test content") - time.Sleep(time.Millisecond * 200) + time.Sleep(time.Millisecond * 500) + mu.Lock() + defer mu.Unlock() assert.Equal(t, notificationNum, 3) assert.NotNil(t, messageNotification) assert.Equal(t, messageNotification.Method, string(mcp.MethodNotificationMessage)) diff --git a/server/sse.go b/server/sse.go index 0f16a05d8..1de631d1c 100644 --- a/server/sse.go +++ b/server/sse.go @@ -519,7 +519,8 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { var message string if eventData, err := json.Marshal(response); err != nil { // If there is an error marshalling the response, send a generic error response - log.Printf("failed to marshal response: %v", err) + marshal, _ := json.Marshal(response) + log.Printf("failed to marshal response: %v, response %s", err, string(marshal)) message = "event: message\ndata: {\"error\": \"internal error\",\"jsonrpc\": \"2.0\", \"id\": null}\n\n" } else { message = fmt.Sprintf("event: message\ndata: %s\n\n", eventData) From 51fbd1d323042c3acfbbc645958dd11c5c329a43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=9C=E8=89=AE=E9=AD=81?= Date: Mon, 26 May 2025 13:25:58 +0800 Subject: [PATCH 13/15] rename --- README.md | 44 +++++++++---------- client/inprocess_test.go | 6 +-- client/sse_test.go | 34 +++++++------- examples/custom_context/main.go | 2 +- examples/dynamic_path/main.go | 2 +- examples/everything/main.go | 20 ++++----- examples/typed_tools/main.go | 2 +- mcptest/mcptest_test.go | 2 +- ...{request_context.go => request_session.go} | 20 ++++----- server/resource_test.go | 8 ++-- server/server.go | 26 +++++------ server/server_race_test.go | 20 ++++----- server/server_test.go | 30 ++++++------- server/session_test.go | 10 ++--- server/sse_test.go | 6 +-- server/stdio_test.go | 2 +- server/streamable_http_test.go | 4 +- server/typed_tools_handler_func.go | 8 ++-- server/typed_tools_handler_func_test.go | 24 +++++----- 19 files changed, 135 insertions(+), 135 deletions(-) rename server/{request_context.go => request_session.go} (82%) diff --git a/README.md b/README.md index 7528295b4..65cfa447b 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ func main() { } } -func helloHandler(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func helloHandler(ctx context.Context, requestSession server.RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { name, err := request.RequireString("name") if err != nil { return mcp.NewToolResultError(err.Error()), nil @@ -92,7 +92,7 @@ MCP Go handles all the complex protocol details and server management, so you ca - [Resources](#resources) - [Tools](#tools) - [Prompts](#prompts) - - [RequestContext](#RequestContext) + - [RequestSession](#requestSession) - [Examples](#examples) - [Extras](#extras) - [Transports](#transports) @@ -154,7 +154,7 @@ func main() { ) // Add the calculator handler - s.AddTool(calculatorTool, func(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + s.AddTool(calculatorTool, func(ctx context.Context, requestSession server.RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Using helper functions for type-safe argument access op, err := request.RequireString("operation") if err != nil { @@ -252,7 +252,7 @@ resource := mcp.NewResource( ) // Add resource with its handler -s.AddResource(resource, func(ctx context.Context, requestContext server.RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { +s.AddResource(resource, func(ctx context.Context, requestSession server.RequestSession, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { content, err := os.ReadFile("README.md") if err != nil { return nil, err @@ -280,7 +280,7 @@ template := mcp.NewResourceTemplate( ) // Add template with its handler -s.AddResourceTemplate(template, func(ctx context.Context, requestContext server.RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { +s.AddResourceTemplate(template, func(ctx context.Context, requestSession server.RequestSession, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { // Extract ID from the URI using regex matching // The server automatically matches URIs to templates userID := extractIDFromURI(request.Params.URI) @@ -329,7 +329,7 @@ calculatorTool := mcp.NewTool("calculate", ), ) -s.AddTool(calculatorTool, func(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +s.AddTool(calculatorTool, func(ctx context.Context, requestSession server.RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { args := request.GetArguments() op := args["operation"].(string) x := args["x"].(float64) @@ -373,7 +373,7 @@ httpTool := mcp.NewTool("http_request", ), ) -s.AddTool(httpTool, func(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +s.AddTool(httpTool, func(ctx context.Context, requestSession server.RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { args := request.GetArguments() method := args["method"].(string) url := args["url"].(string) @@ -441,7 +441,7 @@ s.AddPrompt(mcp.NewPrompt("greeting", mcp.WithArgument("name", mcp.ArgumentDescription("Name of the person to greet"), ), -), func(ctx context.Context, requestContext server.RequestContext, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { +), func(ctx context.Context, requestSession server.RequestSession, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { name := request.Params.Arguments["name"] if name == "" { name = "friend" @@ -465,7 +465,7 @@ s.AddPrompt(mcp.NewPrompt("code_review", mcp.ArgumentDescription("Pull request number to review"), mcp.RequiredArgument(), ), -), func(ctx context.Context, requestContext server.RequestContext, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { +), func(ctx context.Context, requestSession server.RequestSession, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { prNumber := request.Params.Arguments["pr_number"] if prNumber == "" { return nil, fmt.Errorf("pr_number is required") @@ -496,7 +496,7 @@ s.AddPrompt(mcp.NewPrompt("query_builder", mcp.ArgumentDescription("Name of the table to query"), mcp.RequiredArgument(), ), -), func(ctx context.Context, requestContext server.RequestContext, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { +), func(ctx context.Context, requestSession server.RequestSession, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { tableName := request.Params.Arguments["table"] if tableName == "" { return nil, fmt.Errorf("table name is required") @@ -531,28 +531,28 @@ Prompts can include: -### RequestContext +### requestSession
-Show RequestContext Examples +Show RequestSession Examples -The RequestContext object provides capabilities to interact with the client, such as sending logging notification and progress notification. +The RequestSession object provides capabilities to interact with the client, such as sending logging notification and progress notification. ```go -// Example of using RequestContext to send logging notifications and progress notifications +// Example of using RequestSession to send logging notifications and progress notifications mcpServer.AddTool(mcp.NewTool( - "test-RequestContent", - mcp.WithDescription("test RequestContent"), -), func(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - // you could invoke `requestContext.IsLoggingNotificationSupported()` first the check if server supports logging notification - // ff server does not support logging notification, this method will do nothing. - _ = requestContext.SendLoggingNotification(ctx, mcp.LoggingLevelInfo, map[string]any{ + "test-RequestSession", + mcp.WithDescription("test RequestSession"), +), func(ctx context.Context, requestSession server.RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // you could invoke `requestSession.IsLoggingNotificationSupported()` first the check if server supports logging notification + // if server does not support logging notification, this method will do nothing. + _ = requestSession.SendLoggingNotification(ctx, mcp.LoggingLevelInfo, map[string]any{ "testLog": "test send log notification", }) // server should send progress notification if request metadata includes a progressToken total := float64(100) progressMessage := "human readable progress information" - _ = requestContext.SendProgressNotification(ctx, float64(50), &total, &progressMessage) + _ = requestSession.SendProgressNotification(ctx, float64(50), &total, &progressMessage) return &mcp.CallToolResult{ Content: []mcp.Content{ @@ -759,7 +759,7 @@ s := server.NewMCPServer( The session context is automatically passed to tool and resource handlers: ```go -s.AddTool(mcp.NewTool("session_aware"), func(ctx context.Context, requestContext server.RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { +s.AddTool(mcp.NewTool("session_aware"), func(ctx context.Context, requestSession server.RequestSession, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Get the current session from context session := server.ClientSessionFromContext(ctx) if session == nil { diff --git a/client/inprocess_test.go b/client/inprocess_test.go index e0bcfac79..ec7520353 100644 --- a/client/inprocess_test.go +++ b/client/inprocess_test.go @@ -27,7 +27,7 @@ func TestInProcessMCPClient(t *testing.T) { mcp.WithDestructiveHintAnnotation(false), mcp.WithIdempotentHintAnnotation(true), mcp.WithOpenWorldHintAnnotation(false), - ), func(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ), func(ctx context.Context, requestSession server.RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{ Content: []mcp.Content{ mcp.TextContent{ @@ -48,7 +48,7 @@ func TestInProcessMCPClient(t *testing.T) { URI: "resource://testresource", Name: "My Resource", }, - func(ctx context.Context, requestContext server.RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, requestSession server.RequestSession, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: "resource://testresource", @@ -70,7 +70,7 @@ func TestInProcessMCPClient(t *testing.T) { }, }, }, - func(ctx context.Context, requestContext server.RequestContext, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + func(ctx context.Context, requestSession server.RequestSession, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{ Messages: []mcp.PromptMessage{ { diff --git a/client/sse_test.go b/client/sse_test.go index 833eeade5..08bb8ad9c 100644 --- a/client/sse_test.go +++ b/client/sse_test.go @@ -42,7 +42,7 @@ func TestSSEMCPClient(t *testing.T) { mcp.WithDestructiveHintAnnotation(false), mcp.WithIdempotentHintAnnotation(true), mcp.WithOpenWorldHintAnnotation(false), - ), func(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ), func(ctx context.Context, requestSession server.RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{ Content: []mcp.Content{ mcp.TextContent{ @@ -55,7 +55,7 @@ func TestSSEMCPClient(t *testing.T) { mcpServer.AddTool(mcp.NewTool( "test-tool-for-http-header", mcp.WithDescription("Test tool for http header"), - ), func(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ), func(ctx context.Context, requestSession server.RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // , X-Test-Header-Func return &mcp.CallToolResult{ Content: []mcp.Content{ @@ -70,21 +70,21 @@ func TestSSEMCPClient(t *testing.T) { mcpServer.AddTool(mcp.NewTool( "test-tool-for-sending-notification", mcp.WithDescription("Test tool for sending log notification, and the log level is warn"), - ), func(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ), func(ctx context.Context, requestSession server.RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { totalProgressValue := float64(100) startFuncMessage := "start func" - err := requestContext.SendProgressNotification(ctx, float64(0), &totalProgressValue, &startFuncMessage) + err := requestSession.SendProgressNotification(ctx, float64(0), &totalProgressValue, &startFuncMessage) if err != nil { return nil, err } - err = requestContext.SendLoggingNotification(ctx, mcp.LoggingLevelInfo, map[string]any{ + err = requestSession.SendLoggingNotification(ctx, mcp.LoggingLevelInfo, map[string]any{ "filtered_log_message": "will be filtered by log level", }) if err != nil { return nil, err } - err = requestContext.SendLoggingNotification(ctx, mcp.LoggingLevelError, map[string]any{ + err = requestSession.SendLoggingNotification(ctx, mcp.LoggingLevelError, map[string]any{ "log_message": "log message value", }) if err != nil { @@ -92,7 +92,7 @@ func TestSSEMCPClient(t *testing.T) { } startFuncMessage = "end func" - err = requestContext.SendProgressNotification(ctx, float64(100), &totalProgressValue, &startFuncMessage) + err = requestSession.SendProgressNotification(ctx, float64(100), &totalProgressValue, &startFuncMessage) if err != nil { return nil, err } @@ -109,21 +109,21 @@ func TestSSEMCPClient(t *testing.T) { mcpServer.AddPrompt(mcp.Prompt{ Name: "prompt_get_for_server_notification", Description: "Test prompt", - }, func(ctx context.Context, requestContext server.RequestContext, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + }, func(ctx context.Context, requestSession server.RequestSession, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { totalProgressValue := float64(100) startFuncMessage := "start get prompt" - err := requestContext.SendProgressNotification(ctx, float64(0), &totalProgressValue, &startFuncMessage) + err := requestSession.SendProgressNotification(ctx, float64(0), &totalProgressValue, &startFuncMessage) if err != nil { return nil, err } - err = requestContext.SendLoggingNotification(ctx, mcp.LoggingLevelInfo, map[string]any{ + err = requestSession.SendLoggingNotification(ctx, mcp.LoggingLevelInfo, map[string]any{ "filtered_log_message": "will be filtered by log level", }) if err != nil { return nil, err } - err = requestContext.SendLoggingNotification(ctx, mcp.LoggingLevelError, map[string]any{ + err = requestSession.SendLoggingNotification(ctx, mcp.LoggingLevelError, map[string]any{ "log_message": "log message value", }) if err != nil { @@ -131,7 +131,7 @@ func TestSSEMCPClient(t *testing.T) { } startFuncMessage = "end get prompt" - err = requestContext.SendProgressNotification(ctx, float64(100), &totalProgressValue, &startFuncMessage) + err = requestSession.SendProgressNotification(ctx, float64(100), &totalProgressValue, &startFuncMessage) if err != nil { return nil, err } @@ -152,21 +152,21 @@ func TestSSEMCPClient(t *testing.T) { mcpServer.AddResource(mcp.Resource{ URI: "resource://testresource", Name: "My Resource", - }, func(ctx context.Context, requestContext server.RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + }, func(ctx context.Context, requestSession server.RequestSession, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { totalProgressValue := float64(100) startFuncMessage := "start read resource" - err := requestContext.SendProgressNotification(ctx, float64(0), &totalProgressValue, &startFuncMessage) + err := requestSession.SendProgressNotification(ctx, float64(0), &totalProgressValue, &startFuncMessage) if err != nil { return nil, err } - err = requestContext.SendLoggingNotification(ctx, mcp.LoggingLevelInfo, map[string]any{ + err = requestSession.SendLoggingNotification(ctx, mcp.LoggingLevelInfo, map[string]any{ "filtered_log_message": "will be filtered by log level", }) if err != nil { return nil, err } - err = requestContext.SendLoggingNotification(ctx, mcp.LoggingLevelError, map[string]any{ + err = requestSession.SendLoggingNotification(ctx, mcp.LoggingLevelError, map[string]any{ "log_message": "log message value", }) if err != nil { @@ -174,7 +174,7 @@ func TestSSEMCPClient(t *testing.T) { } startFuncMessage = "end read resource" - err = requestContext.SendProgressNotification(ctx, float64(100), &totalProgressValue, &startFuncMessage) + err = requestSession.SendProgressNotification(ctx, float64(100), &totalProgressValue, &startFuncMessage) if err != nil { return nil, err } diff --git a/examples/custom_context/main.go b/examples/custom_context/main.go index 42e83a268..88136a2c6 100644 --- a/examples/custom_context/main.go +++ b/examples/custom_context/main.go @@ -79,7 +79,7 @@ func makeRequest(ctx context.Context, message, token string) (*response, error) // using the token from the context. func handleMakeAuthenticatedRequestTool( ctx context.Context, - requestContext server.RequestContext, + requestSession server.RequestSession, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { message, ok := request.GetArguments()["message"].(string) diff --git a/examples/dynamic_path/main.go b/examples/dynamic_path/main.go index fdad69b0b..a38175245 100644 --- a/examples/dynamic_path/main.go +++ b/examples/dynamic_path/main.go @@ -19,7 +19,7 @@ func main() { mcpServer := server.NewMCPServer("dynamic-path-example", "1.0.0") // Add a trivial tool for demonstration - mcpServer.AddTool(mcp.NewTool("echo"), func(ctx context.Context, requestContext server.RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + mcpServer.AddTool(mcp.NewTool("echo"), func(ctx context.Context, requestSession server.RequestSession, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText(fmt.Sprintf("Echo: %v", req.GetArguments()["message"])), nil }) diff --git a/examples/everything/main.go b/examples/everything/main.go index 474f3e78c..1c3ccd566 100644 --- a/examples/everything/main.go +++ b/examples/everything/main.go @@ -191,7 +191,7 @@ func generateResources() []mcp.Resource { func handleReadResource( ctx context.Context, - requestContext server.RequestContext, + requestSession server.RequestSession, request mcp.ReadResourceRequest, ) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ @@ -205,7 +205,7 @@ func handleReadResource( func handleResourceTemplate( ctx context.Context, - requestContext server.RequestContext, + requestSession server.RequestSession, request mcp.ReadResourceRequest, ) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ @@ -219,7 +219,7 @@ func handleResourceTemplate( func handleGeneratedResource( ctx context.Context, - requestContext server.RequestContext, + requestSession server.RequestSession, request mcp.ReadResourceRequest, ) ([]mcp.ResourceContents, error) { uri := request.Params.URI @@ -257,7 +257,7 @@ func handleGeneratedResource( func handleSimplePrompt( ctx context.Context, - requestContext server.RequestContext, + requestSession server.RequestSession, request mcp.GetPromptRequest, ) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{ @@ -276,7 +276,7 @@ func handleSimplePrompt( func handleComplexPrompt( ctx context.Context, - requestContext server.RequestContext, + requestSession server.RequestSession, request mcp.GetPromptRequest, ) (*mcp.GetPromptResult, error) { arguments := request.Params.Arguments @@ -315,7 +315,7 @@ func handleComplexPrompt( func handleEchoTool( ctx context.Context, - requestContext server.RequestContext, + requestSession server.RequestSession, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { arguments := request.GetArguments() @@ -335,7 +335,7 @@ func handleEchoTool( func handleAddTool( ctx context.Context, - requestContext server.RequestContext, + requestSession server.RequestSession, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { arguments := request.GetArguments() @@ -357,7 +357,7 @@ func handleAddTool( func handleSendNotification( ctx context.Context, - requestContext server.RequestContext, + requestSession server.RequestSession, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { @@ -388,7 +388,7 @@ func handleSendNotification( func handleLongRunningOperationTool( ctx context.Context, - requestContext server.RequestContext, + requestSession server.RequestSession, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { arguments := request.GetArguments() @@ -454,7 +454,7 @@ func handleLongRunningOperationTool( func handleGetTinyImageTool( ctx context.Context, - requestContext server.RequestContext, + requestSession server.RequestSession, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{ diff --git a/examples/typed_tools/main.go b/examples/typed_tools/main.go index da7f21e38..90348e425 100644 --- a/examples/typed_tools/main.go +++ b/examples/typed_tools/main.go @@ -73,7 +73,7 @@ func main() { } // Our typed handler function that receives strongly-typed arguments -func typedGreetingHandler(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest, args GreetingArgs) (*mcp.CallToolResult, error) { +func typedGreetingHandler(ctx context.Context, requestSession server.RequestSession, request mcp.CallToolRequest, args GreetingArgs) (*mcp.CallToolResult, error) { if args.Name == "" { return mcp.NewToolResultError("name is required"), nil } diff --git a/mcptest/mcptest_test.go b/mcptest/mcptest_test.go index 55b2276f6..8da2506dc 100644 --- a/mcptest/mcptest_test.go +++ b/mcptest/mcptest_test.go @@ -50,7 +50,7 @@ func TestServer(t *testing.T) { } } -func helloWorldHandler(ctx context.Context, requestContext server.RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func helloWorldHandler(ctx context.Context, requestSession server.RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Extract name from request arguments name, ok := request.GetArguments()["name"].(string) if !ok { diff --git a/server/request_context.go b/server/request_session.go similarity index 82% rename from server/request_context.go rename to server/request_session.go index 9f2257c85..bc61df155 100644 --- a/server/request_context.go +++ b/server/request_session.go @@ -8,35 +8,35 @@ import ( type requestIDKey struct{} -// RequestContext represents an exchange with MCP client, and provides +// RequestSession represents an exchange with MCP client, and provides // methods to interact with the client and query its capabilities. -type RequestContext struct { +type RequestSession struct { mcpServer *MCPServer progressToken *mcp.ProgressToken } -func NewRequestContext(mcpServer *MCPServer, requestParamMeta *mcp.Meta) RequestContext { - requestContext := RequestContext{ +func NewRequestSession(mcpServer *MCPServer, requestParamMeta *mcp.Meta) RequestSession { + requestSession := RequestSession{ mcpServer: mcpServer, } // server should send progress notification if request metadata includes a progressToken if requestParamMeta != nil && requestParamMeta.ProgressToken != nil { - requestContext.progressToken = &requestParamMeta.ProgressToken + requestSession.progressToken = &requestParamMeta.ProgressToken } - return requestContext + return requestSession } // IsLoggingNotificationSupported returns true if server supports logging notification -func (exchange *RequestContext) IsLoggingNotificationSupported() bool { +func (exchange *RequestSession) IsLoggingNotificationSupported() bool { return exchange.mcpServer != nil && exchange.mcpServer.capabilities.logging != nil && *exchange.mcpServer.capabilities.logging } // SendLoggingNotification send logging notification to client. // If server does not support logging notification, this method will do nothing. -func (exchange *RequestContext) SendLoggingNotification(ctx context.Context, level mcp.LoggingLevel, message map[string]any) error { +func (exchange *RequestSession) SendLoggingNotification(ctx context.Context, level mcp.LoggingLevel, message map[string]any) error { if !exchange.IsLoggingNotificationSupported() { return nil } @@ -66,7 +66,7 @@ func (exchange *RequestContext) SendLoggingNotification(ctx context.Context, lev } // SendProgressNotification send progress notification only if the client has requested progress -func (exchange *RequestContext) SendProgressNotification(ctx context.Context, progress float64, total *float64, message *string) error { +func (exchange *RequestSession) SendProgressNotification(ctx context.Context, progress float64, total *float64, message *string) error { if exchange.progressToken == nil { return nil } @@ -90,7 +90,7 @@ func (exchange *RequestContext) SendProgressNotification(ctx context.Context, pr } // SendCancellationNotification send cancellation notification to client -func (exchange *RequestContext) SendCancellationNotification(ctx context.Context, reason *string) error { +func (exchange *RequestSession) SendCancellationNotification(ctx context.Context, reason *string) error { requestIDRawValue := ctx.Value(requestIDKey{}) if requestIDRawValue == nil { return fmt.Errorf("invalid requestID") diff --git a/server/resource_test.go b/server/resource_test.go index fdc014c99..9b32c29fe 100644 --- a/server/resource_test.go +++ b/server/resource_test.go @@ -28,7 +28,7 @@ func TestMCPServer_RemoveResource(t *testing.T) { mcp.WithResourceDescription("Test resource 1"), mcp.WithMIMEType("text/plain"), ), - func(ctx context.Context, requestContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, requestSession RequestSession, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: "test://resource1", @@ -47,7 +47,7 @@ func TestMCPServer_RemoveResource(t *testing.T) { mcp.WithResourceDescription("Test resource 2"), mcp.WithMIMEType("text/plain"), ), - func(ctx context.Context, requestContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, requestSession RequestSession, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: "test://resource2", @@ -108,7 +108,7 @@ func TestMCPServer_RemoveResource(t *testing.T) { mcp.WithResourceDescription("Test resource 1"), mcp.WithMIMEType("text/plain"), ), - func(ctx context.Context, requestContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, requestSession RequestSession, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: "test://resource1", @@ -164,7 +164,7 @@ func TestMCPServer_RemoveResource(t *testing.T) { mcp.WithResourceDescription("Test resource 1"), mcp.WithMIMEType("text/plain"), ), - func(ctx context.Context, requestContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, requestSession RequestSession, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: "test://resource1", diff --git a/server/server.go b/server/server.go index a34d525f3..7475b33b3 100644 --- a/server/server.go +++ b/server/server.go @@ -28,16 +28,16 @@ type resourceTemplateEntry struct { type ServerOption func(*MCPServer) // ResourceHandlerFunc is a function that returns resource contents. -type ResourceHandlerFunc func(ctx context.Context, requestContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) +type ResourceHandlerFunc func(ctx context.Context, requestSession RequestSession, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) // ResourceTemplateHandlerFunc is a function that returns a resource template. -type ResourceTemplateHandlerFunc func(ctx context.Context, requestContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) +type ResourceTemplateHandlerFunc func(ctx context.Context, requestSession RequestSession, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) // PromptHandlerFunc handles prompt requests with given arguments. -type PromptHandlerFunc func(ctx context.Context, requestContext RequestContext, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) +type PromptHandlerFunc func(ctx context.Context, requestSession RequestSession, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) // ToolHandlerFunc handles tool calls with given arguments. -type ToolHandlerFunc func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) +type ToolHandlerFunc func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) // ToolHandlerMiddleware is a middleware function that wraps a ToolHandlerFunc. type ToolHandlerMiddleware func(ToolHandlerFunc) ToolHandlerFunc @@ -217,7 +217,7 @@ func WithToolFilter( // WithRecovery adds a middleware that recovers from panics in tool handlers. func WithRecovery() ServerOption { return WithToolHandlerMiddleware(func(next ToolHandlerFunc) ToolHandlerFunc { - return func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (result *mcp.CallToolResult, err error) { + return func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (result *mcp.CallToolResult, err error) { defer func() { if r := recover(); r != nil { err = fmt.Errorf( @@ -227,7 +227,7 @@ func WithRecovery() ServerOption { ) } }() - return next(ctx, requestContext, request) + return next(ctx, requestSession, request) } }) } @@ -728,14 +728,14 @@ func (s *MCPServer) handleReadResource( request mcp.ReadResourceRequest, ) (*mcp.ReadResourceResult, *requestError) { - requestContext := NewRequestContext(s, request.Params.Meta) + requestSession := NewRequestSession(s, request.Params.Meta) s.resourcesMu.RLock() // First try direct resource handlers if entry, ok := s.resources[request.Params.URI]; ok { handler := entry.handler s.resourcesMu.RUnlock() - contents, err := handler(ctx, requestContext, request) + contents, err := handler(ctx, requestSession, request) if err != nil { return nil, &requestError{ id: id, @@ -766,7 +766,7 @@ func (s *MCPServer) handleReadResource( s.resourcesMu.RUnlock() if matched { - contents, err := matchedHandler(ctx, requestContext, request) + contents, err := matchedHandler(ctx, requestSession, request) if err != nil { return nil, &requestError{ id: id, @@ -848,8 +848,8 @@ func (s *MCPServer) handleGetPrompt( } } - requestContext := NewRequestContext(s, request.Params.Meta) - result, err := handler(ctx, requestContext, request) + requestSession := NewRequestSession(s, request.Params.Meta) + result, err := handler(ctx, requestSession, request) if err != nil { return nil, &requestError{ id: id, @@ -999,8 +999,8 @@ func (s *MCPServer) handleToolCall( finalHandler = mw[i](finalHandler) } - requestContext := NewRequestContext(s, request.Params.Meta) - result, err := finalHandler(ctx, requestContext, request) + requestSession := NewRequestSession(s, request.Params.Meta) + result, err := finalHandler(ctx, requestSession, request) if err != nil { return nil, &requestError{ id: id, diff --git a/server/server_race_test.go b/server/server_race_test.go index 228686c86..26ef7686d 100644 --- a/server/server_race_test.go +++ b/server/server_race_test.go @@ -39,7 +39,7 @@ func TestRaceConditions(t *testing.T) { srv.AddPrompt(mcp.Prompt{ Name: name, Description: "Test prompt", - }, func(ctx context.Context, requestContext RequestContext, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + }, func(ctx context.Context, requestSession RequestSession, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{}, nil }) }) @@ -49,7 +49,7 @@ func TestRaceConditions(t *testing.T) { srv.AddPrompt(mcp.Prompt{ Name: name, Description: "Temporary prompt", - }, func(ctx context.Context, requestContext RequestContext, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + }, func(ctx context.Context, requestSession RequestSession, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{}, nil }) srv.DeletePrompts(name) @@ -60,7 +60,7 @@ func TestRaceConditions(t *testing.T) { srv.AddTool(mcp.Tool{ Name: name, Description: "Test tool", - }, func(ctx context.Context, requestContext RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + }, func(ctx context.Context, requestSession RequestSession, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }) }) @@ -71,7 +71,7 @@ func TestRaceConditions(t *testing.T) { srv.AddTool(mcp.Tool{ Name: name, Description: "Temporary tool", - }, func(ctx context.Context, requestContext RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + }, func(ctx context.Context, requestSession RequestSession, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }) srv.DeleteTools(name) @@ -79,8 +79,8 @@ func TestRaceConditions(t *testing.T) { runConcurrentOperation(&wg, testDuration, "add-middleware", func() { middleware := func(next ToolHandlerFunc) ToolHandlerFunc { - return func(ctx context.Context, requestContext RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return next(ctx, requestContext, req) + return func(ctx context.Context, requestSession RequestSession, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return next(ctx, requestSession, req) } } WithToolHandlerMiddleware(middleware)(srv) @@ -102,7 +102,7 @@ func TestRaceConditions(t *testing.T) { srv.AddTool(mcp.Tool{ Name: "persistent-tool", Description: "Test tool that always exists", - }, func(ctx context.Context, requestContext RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + }, func(ctx context.Context, requestSession RequestSession, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }) @@ -121,7 +121,7 @@ func TestRaceConditions(t *testing.T) { URI: uri, Name: uri, Description: "Test resource", - }, func(ctx context.Context, requestContext RequestContext, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + }, func(ctx context.Context, requestSession RequestSession, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: uri, @@ -169,12 +169,12 @@ func TestConcurrentPromptAdd(t *testing.T) { srv.AddPrompt(mcp.Prompt{ Name: "initial-prompt", Description: "Initial prompt", - }, func(ctx context.Context, requestContext RequestContext, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + }, func(ctx context.Context, requestSession RequestSession, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { go func() { srv.AddPrompt(mcp.Prompt{ Name: fmt.Sprintf("new-prompt-%d", time.Now().UnixNano()), Description: "Added from handler", - }, func(ctx context.Context, requestContext RequestContext, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + }, func(ctx context.Context, requestSession RequestSession, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{}, nil }) }() diff --git a/server/server_test.go b/server/server_test.go index 78edb1975..9cb511483 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -160,12 +160,12 @@ func TestMCPServer_Tools(t *testing.T) { action: func(t *testing.T, server *MCPServer, notificationChannel chan mcp.JSONRPCNotification) { server.SetTools(ServerTool{ Tool: mcp.NewTool("test-tool-1"), - Handler: func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, }, ServerTool{ Tool: mcp.NewTool("test-tool-2"), - Handler: func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, }) @@ -189,12 +189,12 @@ func TestMCPServer_Tools(t *testing.T) { require.NoError(t, err) server.SetTools(ServerTool{ Tool: mcp.NewTool("test-tool-1"), - Handler: func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, }, ServerTool{ Tool: mcp.NewTool("test-tool-2"), - Handler: func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, }) @@ -230,12 +230,12 @@ func TestMCPServer_Tools(t *testing.T) { } server.SetTools(ServerTool{ Tool: mcp.NewTool("test-tool-1"), - Handler: func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, }, ServerTool{ Tool: mcp.NewTool("test-tool-2"), - Handler: func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, }) @@ -262,13 +262,13 @@ func TestMCPServer_Tools(t *testing.T) { require.NoError(t, err) server.AddTool( mcp.NewTool("test-tool-1"), - func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, ) server.AddTool( mcp.NewTool("test-tool-2"), - func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, ) @@ -706,7 +706,7 @@ func TestMCPServer_PromptHandling(t *testing.T) { server.AddPrompt( testPrompt, - func(ctx context.Context, requestContext RequestContext, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + func(ctx context.Context, requestSession RequestSession, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { return &mcp.GetPromptResult{ Messages: []mcp.PromptMessage{ { @@ -1121,7 +1121,7 @@ func TestMCPServer_HandleUndefinedHandlers(t *testing.T) { IdempotentHint: mcp.ToBoolPtr(false), OpenWorldHint: mcp.ToBoolPtr(false), }, - }, func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + }, func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }) @@ -1383,7 +1383,7 @@ func TestMCPServer_ResourceTemplates(t *testing.T) { "test://{a}/test-resource{/b*}", "My Resource", ), - func(ctx context.Context, requestContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, requestSession RequestSession, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { a := request.Params.Arguments["a"].([]string) b := request.Params.Arguments["b"].([]string) // Validate that the template arguments are passed correctly to the handler @@ -1470,7 +1470,7 @@ func createTestServer() *MCPServer { URI: "resource://testresource", Name: "My Resource", }, - func(ctx context.Context, requestContext RequestContext, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + func(ctx context.Context, requestSession RequestSession, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return []mcp.ResourceContents{ mcp.TextResourceContents{ URI: "resource://testresource", @@ -1486,7 +1486,7 @@ func createTestServer() *MCPServer { Name: "test-tool", Description: "Test tool", }, - func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{ Content: []mcp.Content{ mcp.TextContent{ @@ -1632,7 +1632,7 @@ func TestMCPServer_WithHooks(t *testing.T) { // Add a test tool server.AddTool( mcp.NewTool("test-tool"), - func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return &mcp.CallToolResult{}, nil }, ) @@ -1803,7 +1803,7 @@ func TestMCPServer_SessionHooks_NilHooks(t *testing.T) { } func TestMCPServer_WithRecover(t *testing.T) { - panicToolHandler := func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + panicToolHandler := func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { panic("test panic") } diff --git a/server/session_test.go b/server/session_test.go index fa9f142ee..1eb5e7ca9 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -222,7 +222,7 @@ func TestSessionWithTools_Integration(t *testing.T) { // Create session-specific tools sessionTool := ServerTool{ Tool: mcp.NewTool("session-tool"), - Handler: func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText("session-tool result"), nil }, } @@ -271,7 +271,7 @@ func TestSessionWithTools_Integration(t *testing.T) { require.NotNil(t, tool, "Session tool should not be nil") // Now test calling directly with the handler - result, err := tool.Handler(sessionCtx, RequestContext{}, testReq) + result, err := tool.Handler(sessionCtx, RequestSession{}, testReq) require.NoError(t, err, "No error calling session tool handler directly") require.NotNil(t, result, "Result should not be nil") require.Len(t, result.Content, 1, "Result should have one content item") @@ -391,7 +391,7 @@ func TestMCPServer_AddSessionTool(t *testing.T) { err = server.AddSessionTool( session.SessionID(), mcp.NewTool("session-tool-helper"), - func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText("helper result"), nil }, ) @@ -596,7 +596,7 @@ func TestMCPServer_CallSessionTool(t *testing.T) { server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) // Add global tool - server.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + server.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText("global result"), nil }) @@ -616,7 +616,7 @@ func TestMCPServer_CallSessionTool(t *testing.T) { err = server.AddSessionTool( session.SessionID(), mcp.NewTool("test_tool"), - func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText("session result"), nil }, ) diff --git a/server/sse_test.go b/server/sse_test.go index 3def52d13..c9789f121 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -574,7 +574,7 @@ func TestSSEServer(t *testing.T) { WithResourceCapabilities(true, true), ) // Add a tool which uses the context function. - mcpServer.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + mcpServer.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Note this is agnostic to the transport type i.e. doesn't know about request headers. testVal := testValFromContext(ctx) return mcp.NewToolResultText(testVal), nil @@ -1183,7 +1183,7 @@ func TestSSEServer(t *testing.T) { Title: "Test Tool", }, }, - Handler: func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText("test"), nil }, }, @@ -1335,7 +1335,7 @@ func TestSSEServer(t *testing.T) { processingCompleted := make(chan struct{}) processingStarted := make(chan struct{}) - mcpServer.AddTool(mcp.NewTool("slowMethod"), func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + mcpServer.AddTool(mcp.NewTool("slowMethod"), func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { close(processingStarted) // signal for processing started select { diff --git a/server/stdio_test.go b/server/stdio_test.go index e47d0cbf3..477a799b9 100644 --- a/server/stdio_test.go +++ b/server/stdio_test.go @@ -140,7 +140,7 @@ func TestStdioServer(t *testing.T) { // Create server mcpServer := NewMCPServer("test", "1.0.0") // Add a tool which uses the context function. - mcpServer.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + mcpServer.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Note this is agnostic to the transport type i.e. doesn't know about request headers. testVal := testValFromContext(ctx) return mcp.NewToolResultText(testVal), nil diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index abb227733..c0c1f31c0 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -39,7 +39,7 @@ var initRequest = map[string]any{ func addSSETool(mcpServer *MCPServer) { mcpServer.AddTool(mcp.Tool{ Name: "sseTool", - }, func(ctx context.Context, requestContext RequestContext, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + }, func(ctx context.Context, requestSession RequestSession, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Send notification to client server := ServerFromContext(ctx) for i := 0; i < 10; i++ { @@ -601,7 +601,7 @@ func TestStreamableHTTP_SessionWithTools(t *testing.T) { Title: "Test Tool", }, }, - Handler: func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + Handler: func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return mcp.NewToolResultText("test"), nil }, }, diff --git a/server/typed_tools_handler_func.go b/server/typed_tools_handler_func.go index 8af35f5b5..c58c95775 100644 --- a/server/typed_tools_handler_func.go +++ b/server/typed_tools_handler_func.go @@ -7,15 +7,15 @@ import ( ) // TypedToolHandlerFunc is a function that handles a tool call with typed arguments -type TypedToolHandlerFunc[T any] func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest, args T) (*mcp.CallToolResult, error) +type TypedToolHandlerFunc[T any] func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest, args T) (*mcp.CallToolResult, error) // NewTypedToolHandler creates a ToolHandlerFunc that automatically binds arguments to a typed struct -func NewTypedToolHandler[T any](handler TypedToolHandlerFunc[T]) func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { +func NewTypedToolHandler[T any](handler TypedToolHandlerFunc[T]) func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { var args T if err := request.BindArguments(&args); err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to bind arguments: %v", err)), nil } - return handler(ctx, requestContext, request, args) + return handler(ctx, requestSession, request, args) } } diff --git a/server/typed_tools_handler_func_test.go b/server/typed_tools_handler_func_test.go index 8317538a3..32873140e 100644 --- a/server/typed_tools_handler_func_test.go +++ b/server/typed_tools_handler_func_test.go @@ -19,7 +19,7 @@ func TestTypedToolHandler(t *testing.T) { } // Create a typed handler function - typedHandler := func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest, args HelloArgs) (*mcp.CallToolResult, error) { + typedHandler := func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest, args HelloArgs) (*mcp.CallToolResult, error) { return mcp.NewToolResultText(args.Name), nil } @@ -36,7 +36,7 @@ func TestTypedToolHandler(t *testing.T) { } // Call the wrapped handler - result, err := wrappedHandler(context.Background(), RequestContext{}, req) + result, err := wrappedHandler(context.Background(), RequestSession{}, req) // Verify results assert.NoError(t, err) @@ -51,7 +51,7 @@ func TestTypedToolHandler(t *testing.T) { } // This should still work because of type conversion - result, err = wrappedHandler(context.Background(), RequestContext{}, req) + result, err = wrappedHandler(context.Background(), RequestSession{}, req) assert.NoError(t, err) assert.NotNil(t, result) @@ -63,14 +63,14 @@ func TestTypedToolHandler(t *testing.T) { } // This should still work but name will be empty - result, err = wrappedHandler(context.Background(), RequestContext{}, req) + result, err = wrappedHandler(context.Background(), RequestSession{}, req) assert.NoError(t, err) assert.NotNil(t, result) assert.Equal(t, "", result.Content[0].(mcp.TextContent).Text) // Test with completely invalid arguments req.Params.Arguments = "not a map" - result, err = wrappedHandler(context.Background(), RequestContext{}, req) + result, err = wrappedHandler(context.Background(), RequestSession{}, req) assert.NoError(t, err) // Error is wrapped in the result assert.NotNil(t, result) assert.True(t, result.IsError) @@ -85,7 +85,7 @@ func TestTypedToolHandlerWithValidation(t *testing.T) { } // Create a typed handler function with validation - typedHandler := func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest, args CalculatorArgs) (*mcp.CallToolResult, error) { + typedHandler := func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest, args CalculatorArgs) (*mcp.CallToolResult, error) { // Validate operation if args.Operation == "" { return mcp.NewToolResultError("operation is required"), nil @@ -124,7 +124,7 @@ func TestTypedToolHandlerWithValidation(t *testing.T) { } // Call the wrapped handler - result, err := wrappedHandler(context.Background(), RequestContext{}, req) + result, err := wrappedHandler(context.Background(), RequestSession{}, req) // Verify results assert.NoError(t, err) @@ -138,7 +138,7 @@ func TestTypedToolHandlerWithValidation(t *testing.T) { "y": 0.0, } - result, err = wrappedHandler(context.Background(), RequestContext{}, req) + result, err = wrappedHandler(context.Background(), RequestSession{}, req) assert.NoError(t, err) assert.NotNil(t, result) assert.True(t, result.IsError) @@ -171,7 +171,7 @@ func TestTypedToolHandlerWithComplexObjects(t *testing.T) { } // Create a typed handler function - typedHandler := func(ctx context.Context, requestContext RequestContext, request mcp.CallToolRequest, profile UserProfile) (*mcp.CallToolResult, error) { + typedHandler := func(ctx context.Context, requestSession RequestSession, request mcp.CallToolRequest, profile UserProfile) (*mcp.CallToolResult, error) { // Validate required fields if profile.Name == "" { return mcp.NewToolResultError("name is required"), nil @@ -240,7 +240,7 @@ func TestTypedToolHandlerWithComplexObjects(t *testing.T) { } // Call the wrapped handler - result, err := wrappedHandler(context.Background(), RequestContext{}, req) + result, err := wrappedHandler(context.Background(), RequestSession{}, req) // Verify results assert.NoError(t, err) @@ -266,7 +266,7 @@ func TestTypedToolHandlerWithComplexObjects(t *testing.T) { }, } - result, err = wrappedHandler(context.Background(), RequestContext{}, req) + result, err = wrappedHandler(context.Background(), RequestSession{}, req) assert.NoError(t, err) assert.NotNil(t, result) assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "Jane Smith") @@ -295,7 +295,7 @@ func TestTypedToolHandlerWithComplexObjects(t *testing.T) { }` req.Params.Arguments = json.RawMessage(jsonInput) - result, err = wrappedHandler(context.Background(), RequestContext{}, req) + result, err = wrappedHandler(context.Background(), RequestSession{}, req) assert.NoError(t, err) assert.NotNil(t, result) assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "Bob Johnson") From 45d7713ba2223737c24e3e34b3fd8a386d8271c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=9C=E8=89=AE=E9=AD=81?= Date: Tue, 27 May 2025 20:29:14 +0800 Subject: [PATCH 14/15] update RequestSession --- server/request_session.go | 24 ++++++++++++------------ server/server.go | 6 +++--- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/server/request_session.go b/server/request_session.go index bc61df155..80c3c0217 100644 --- a/server/request_session.go +++ b/server/request_session.go @@ -11,15 +11,11 @@ type requestIDKey struct{} // RequestSession represents an exchange with MCP client, and provides // methods to interact with the client and query its capabilities. type RequestSession struct { - mcpServer *MCPServer - progressToken *mcp.ProgressToken } -func NewRequestSession(mcpServer *MCPServer, requestParamMeta *mcp.Meta) RequestSession { - requestSession := RequestSession{ - mcpServer: mcpServer, - } +func NewRequestSession(requestParamMeta *mcp.Meta) RequestSession { + requestSession := RequestSession{} // server should send progress notification if request metadata includes a progressToken if requestParamMeta != nil && requestParamMeta.ProgressToken != nil { @@ -30,14 +26,15 @@ func NewRequestSession(mcpServer *MCPServer, requestParamMeta *mcp.Meta) Request } // IsLoggingNotificationSupported returns true if server supports logging notification -func (exchange *RequestSession) IsLoggingNotificationSupported() bool { - return exchange.mcpServer != nil && exchange.mcpServer.capabilities.logging != nil && *exchange.mcpServer.capabilities.logging +func (exchange *RequestSession) IsLoggingNotificationSupported(ctx context.Context) bool { + mcpServer := ServerFromContext(ctx) + return mcpServer != nil && mcpServer.capabilities.logging != nil && *mcpServer.capabilities.logging } // SendLoggingNotification send logging notification to client. // If server does not support logging notification, this method will do nothing. func (exchange *RequestSession) SendLoggingNotification(ctx context.Context, level mcp.LoggingLevel, message map[string]any) error { - if !exchange.IsLoggingNotificationSupported() { + if !exchange.IsLoggingNotificationSupported(ctx) { return nil } @@ -58,7 +55,8 @@ func (exchange *RequestSession) SendLoggingNotification(ctx context.Context, lev params["logger"] = *ClientSessionFromContext(ctx).GetLoggerName() } - return exchange.mcpServer.SendNotificationToClient( + mcpServer := ServerFromContext(ctx) + return mcpServer.SendNotificationToClient( ctx, string(mcp.MethodNotificationMessage), params, @@ -82,7 +80,8 @@ func (exchange *RequestSession) SendProgressNotification(ctx context.Context, pr params["message"] = *message } - return exchange.mcpServer.SendNotificationToClient( + mcpServer := ServerFromContext(ctx) + return mcpServer.SendNotificationToClient( ctx, string(mcp.MethodNotificationProgress), params, @@ -108,7 +107,8 @@ func (exchange *RequestSession) SendCancellationNotification(ctx context.Context params["reason"] = reason } - return exchange.mcpServer.SendNotificationToClient( + mcpServer := ServerFromContext(ctx) + return mcpServer.SendNotificationToClient( ctx, string(mcp.MethodNotificationCancellation), params, diff --git a/server/server.go b/server/server.go index 7475b33b3..164ec17e2 100644 --- a/server/server.go +++ b/server/server.go @@ -728,7 +728,7 @@ func (s *MCPServer) handleReadResource( request mcp.ReadResourceRequest, ) (*mcp.ReadResourceResult, *requestError) { - requestSession := NewRequestSession(s, request.Params.Meta) + requestSession := NewRequestSession(request.Params.Meta) s.resourcesMu.RLock() // First try direct resource handlers @@ -848,7 +848,7 @@ func (s *MCPServer) handleGetPrompt( } } - requestSession := NewRequestSession(s, request.Params.Meta) + requestSession := NewRequestSession(request.Params.Meta) result, err := handler(ctx, requestSession, request) if err != nil { return nil, &requestError{ @@ -999,7 +999,7 @@ func (s *MCPServer) handleToolCall( finalHandler = mw[i](finalHandler) } - requestSession := NewRequestSession(s, request.Params.Meta) + requestSession := NewRequestSession(request.Params.Meta) result, err := finalHandler(ctx, requestSession, request) if err != nil { return nil, &requestError{ From bc8c32605e4be1d117dc8eb8a2ca9c22807e1d9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=9C=E8=89=AE=E9=AD=81?= Date: Fri, 30 May 2025 13:52:06 +0800 Subject: [PATCH 15/15] update --- client/http_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/http_test.go b/client/http_test.go index 2fcc71344..b64c8944a 100644 --- a/client/http_test.go +++ b/client/http_test.go @@ -31,7 +31,7 @@ func TestHTTPClient(t *testing.T) { mcp.NewTool("notify"), func( ctx context.Context, - requestionSession server.RequestSession, + requestSession server.RequestSession, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { server := server.ServerFromContext(ctx)