diff --git a/README.md b/README.md index ff95e1c29..25b0d5ae8 100644 --- a/README.md +++ b/README.md @@ -149,9 +149,21 @@ func main() { // Add the calculator handler s.AddTool(calculatorTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - op := request.Params.Arguments["operation"].(string) - x := request.Params.Arguments["x"].(float64) - y := request.Params.Arguments["y"].(float64) + // Using helper functions for type-safe argument access + op, err := request.RequireString("operation") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + x, err := request.RequireFloat("x") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + y, err := request.RequireFloat("y") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } var result float64 switch op { @@ -312,9 +324,10 @@ calculatorTool := mcp.NewTool("calculate", ) s.AddTool(calculatorTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - op := request.Params.Arguments["operation"].(string) - x := request.Params.Arguments["x"].(float64) - y := request.Params.Arguments["y"].(float64) + args := request.GetArguments() + op := args["operation"].(string) + x := args["x"].(float64) + y := args["y"].(float64) var result float64 switch op { @@ -355,10 +368,11 @@ httpTool := mcp.NewTool("http_request", ) s.AddTool(httpTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - method := request.Params.Arguments["method"].(string) - url := request.Params.Arguments["url"].(string) + args := request.GetArguments() + method := args["method"].(string) + url := args["url"].(string) body := "" - if b, ok := request.Params.Arguments["body"].(string); ok { + if b, ok := args["body"].(string); ok { body = b } diff --git a/client/inprocess_test.go b/client/inprocess_test.go index beaa0c06c..7b150e81e 100644 --- a/client/inprocess_test.go +++ b/client/inprocess_test.go @@ -32,7 +32,7 @@ func TestInProcessMCPClient(t *testing.T) { Content: []mcp.Content{ mcp.TextContent{ Type: "text", - Text: "Input parameter: " + request.Params.Arguments["parameter-1"].(string), + Text: "Input parameter: " + request.GetArguments()["parameter-1"].(string), }, mcp.AudioContent{ Type: "audio", diff --git a/client/sse_test.go b/client/sse_test.go index f02ed41a1..82e0e21bd 100644 --- a/client/sse_test.go +++ b/client/sse_test.go @@ -36,7 +36,7 @@ func TestSSEMCPClient(t *testing.T) { Content: []mcp.Content{ mcp.TextContent{ Type: "text", - Text: "Input parameter: " + request.Params.Arguments["parameter-1"].(string), + Text: "Input parameter: " + request.GetArguments()["parameter-1"].(string), }, }, }, nil diff --git a/examples/custom_context/main.go b/examples/custom_context/main.go index ef2c6dd0f..3e7cf7b4c 100644 --- a/examples/custom_context/main.go +++ b/examples/custom_context/main.go @@ -81,7 +81,7 @@ func handleMakeAuthenticatedRequestTool( ctx context.Context, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { - message, ok := request.Params.Arguments["message"].(string) + message, ok := request.GetArguments()["message"].(string) if !ok { return nil, fmt.Errorf("missing message") } diff --git a/examples/dynamic_path/main.go b/examples/dynamic_path/main.go index 57a7de88a..80d96789a 100644 --- a/examples/dynamic_path/main.go +++ b/examples/dynamic_path/main.go @@ -20,7 +20,7 @@ func main() { // Add a trivial tool for demonstration mcpServer.AddTool(mcp.NewTool("echo"), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return mcp.NewToolResultText(fmt.Sprintf("Echo: %v", req.Params.Arguments["message"])), nil + return mcp.NewToolResultText(fmt.Sprintf("Echo: %v", req.GetArguments()["message"])), nil }) // Use a dynamic base path based on a path parameter (Go 1.22+) diff --git a/examples/everything/main.go b/examples/everything/main.go index 751d3d2a9..9857e8d2d 100644 --- a/examples/everything/main.go +++ b/examples/everything/main.go @@ -312,7 +312,7 @@ func handleEchoTool( ctx context.Context, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { - arguments := request.Params.Arguments + arguments := request.GetArguments() message, ok := arguments["message"].(string) if !ok { return nil, fmt.Errorf("invalid message argument") @@ -331,7 +331,7 @@ func handleAddTool( ctx context.Context, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { - arguments := request.Params.Arguments + arguments := request.GetArguments() a, ok1 := arguments["a"].(float64) b, ok2 := arguments["b"].(float64) if !ok1 || !ok2 { @@ -382,7 +382,7 @@ func handleLongRunningOperationTool( ctx context.Context, request mcp.CallToolRequest, ) (*mcp.CallToolResult, error) { - arguments := request.Params.Arguments + arguments := request.GetArguments() progressToken := request.Params.Meta.ProgressToken duration, _ := arguments["duration"].(float64) steps, _ := arguments["steps"].(float64) diff --git a/examples/typed_tools/main.go b/examples/typed_tools/main.go new file mode 100644 index 000000000..5c49fed85 --- /dev/null +++ b/examples/typed_tools/main.go @@ -0,0 +1,105 @@ +package main + +import ( + "context" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// Define a struct for our typed arguments +type GreetingArgs struct { + Name string `json:"name"` + Age int `json:"age"` + IsVIP bool `json:"is_vip"` + Languages []string `json:"languages"` + Metadata struct { + Location string `json:"location"` + Timezone string `json:"timezone"` + } `json:"metadata"` +} + +func main() { + // Create a new MCP server + s := server.NewMCPServer( + "Typed Tools Demo 🚀", + "1.0.0", + server.WithToolCapabilities(false), + ) + + // Add tool with complex schema + tool := mcp.NewTool("greeting", + mcp.WithDescription("Generate a personalized greeting"), + mcp.WithString("name", + mcp.Required(), + mcp.Description("Name of the person to greet"), + ), + mcp.WithNumber("age", + mcp.Description("Age of the person"), + mcp.Min(0), + mcp.Max(150), + ), + mcp.WithBoolean("is_vip", + mcp.Description("Whether the person is a VIP"), + mcp.DefaultBool(false), + ), + mcp.WithArray("languages", + mcp.Description("Languages the person speaks"), + mcp.Items(map[string]any{"type": "string"}), + ), + mcp.WithObject("metadata", + mcp.Description("Additional information about the person"), + mcp.Properties(map[string]any{ + "location": map[string]any{ + "type": "string", + "description": "Current location", + }, + "timezone": map[string]any{ + "type": "string", + "description": "Timezone", + }, + }), + ), + ) + + // Add tool handler using the typed handler + s.AddTool(tool, mcp.NewTypedToolHandler(typedGreetingHandler)) + + // Start the stdio server + if err := server.ServeStdio(s); err != nil { + fmt.Printf("Server error: %v\n", err) + } +} + +// Our typed handler function that receives strongly-typed arguments +func typedGreetingHandler(ctx context.Context, request mcp.CallToolRequest, args GreetingArgs) (*mcp.CallToolResult, error) { + if args.Name == "" { + return mcp.NewToolResultError("name is required"), nil + } + + // Build a personalized greeting based on the complex arguments + greeting := fmt.Sprintf("Hello, %s!", args.Name) + + if args.Age > 0 { + greeting += fmt.Sprintf(" You are %d years old.", args.Age) + } + + if args.IsVIP { + greeting += " Welcome back, valued VIP customer!" + } + + if len(args.Languages) > 0 { + greeting += fmt.Sprintf(" You speak %d languages: %v.", len(args.Languages), args.Languages) + } + + if args.Metadata.Location != "" { + greeting += fmt.Sprintf(" I see you're from %s.", args.Metadata.Location) + + if args.Metadata.Timezone != "" { + greeting += fmt.Sprintf(" Your timezone is %s.", args.Metadata.Timezone) + } + } + + return mcp.NewToolResultText(greeting), nil +} \ No newline at end of file diff --git a/mcp/tools.go b/mcp/tools.go index 79d66e3f5..f69456ae4 100644 --- a/mcp/tools.go +++ b/mcp/tools.go @@ -4,6 +4,8 @@ import ( "encoding/json" "errors" "fmt" + "reflect" + "strconv" ) var errToolSchemaConflict = errors.New("provide either InputSchema or RawInputSchema, not both") @@ -44,12 +46,419 @@ type CallToolResult struct { type CallToolRequest struct { Request Params struct { - Name string `json:"name"` - Arguments map[string]any `json:"arguments,omitempty"` - Meta *Meta `json:"_meta,omitempty"` + Name string `json:"name"` + Arguments any `json:"arguments,omitempty"` + Meta *Meta `json:"_meta,omitempty"` } `json:"params"` } +// GetArguments returns the Arguments as map[string]any for backward compatibility +// If Arguments is not a map, it returns an empty map +func (r CallToolRequest) GetArguments() map[string]any { + if args, ok := r.Params.Arguments.(map[string]any); ok { + return args + } + return nil +} + +// GetRawArguments returns the Arguments as-is without type conversion +// This allows users to access the raw arguments in any format +func (r CallToolRequest) GetRawArguments() any { + return r.Params.Arguments +} + +// BindArguments unmarshals the Arguments into the provided struct +// This is useful for working with strongly-typed arguments +func (r CallToolRequest) BindArguments(target any) error { + if target == nil || reflect.ValueOf(target).Kind() != reflect.Ptr { + return fmt.Errorf("target must be a non-nil pointer") + } + + // Fast-path: already raw JSON + if raw, ok := r.Params.Arguments.(json.RawMessage); ok { + return json.Unmarshal(raw, target) + } + + data, err := json.Marshal(r.Params.Arguments) + if err != nil { + return fmt.Errorf("failed to marshal arguments: %w", err) + } + + return json.Unmarshal(data, target) +} + +// GetString returns a string argument by key, or the default value if not found +func (r CallToolRequest) GetString(key string, defaultValue string) string { + args := r.GetArguments() + if val, ok := args[key]; ok { + if str, ok := val.(string); ok { + return str + } + } + return defaultValue +} + +// RequireString returns a string argument by key, or an error if not found or not a string +func (r CallToolRequest) RequireString(key string) (string, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + if str, ok := val.(string); ok { + return str, nil + } + return "", fmt.Errorf("argument %q is not a string", key) + } + return "", fmt.Errorf("required argument %q not found", key) +} + +// GetInt returns an int argument by key, or the default value if not found +func (r CallToolRequest) GetInt(key string, defaultValue int) int { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case int: + return v + case float64: + return int(v) + case string: + if i, err := strconv.Atoi(v); err == nil { + return i + } + } + } + return defaultValue +} + +// RequireInt returns an int argument by key, or an error if not found or not convertible to int +func (r CallToolRequest) RequireInt(key string) (int, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case int: + return v, nil + case float64: + return int(v), nil + case string: + if i, err := strconv.Atoi(v); err == nil { + return i, nil + } + return 0, fmt.Errorf("argument %q cannot be converted to int", key) + default: + return 0, fmt.Errorf("argument %q is not an int", key) + } + } + return 0, fmt.Errorf("required argument %q not found", key) +} + +// GetFloat returns a float64 argument by key, or the default value if not found +func (r CallToolRequest) GetFloat(key string, defaultValue float64) float64 { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case float64: + return v + case int: + return float64(v) + case string: + if f, err := strconv.ParseFloat(v, 64); err == nil { + return f + } + } + } + return defaultValue +} + +// RequireFloat returns a float64 argument by key, or an error if not found or not convertible to float64 +func (r CallToolRequest) RequireFloat(key string) (float64, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case float64: + return v, nil + case int: + return float64(v), nil + case string: + if f, err := strconv.ParseFloat(v, 64); err == nil { + return f, nil + } + return 0, fmt.Errorf("argument %q cannot be converted to float64", key) + default: + return 0, fmt.Errorf("argument %q is not a float64", key) + } + } + return 0, fmt.Errorf("required argument %q not found", key) +} + +// GetBool returns a bool argument by key, or the default value if not found +func (r CallToolRequest) GetBool(key string, defaultValue bool) bool { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case bool: + return v + case string: + if b, err := strconv.ParseBool(v); err == nil { + return b + } + case int: + return v != 0 + case float64: + return v != 0 + } + } + return defaultValue +} + +// RequireBool returns a bool argument by key, or an error if not found or not convertible to bool +func (r CallToolRequest) RequireBool(key string) (bool, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case bool: + return v, nil + case string: + if b, err := strconv.ParseBool(v); err == nil { + return b, nil + } + return false, fmt.Errorf("argument %q cannot be converted to bool", key) + case int: + return v != 0, nil + case float64: + return v != 0, nil + default: + return false, fmt.Errorf("argument %q is not a bool", key) + } + } + return false, fmt.Errorf("required argument %q not found", key) +} + +// GetStringSlice returns a string slice argument by key, or the default value if not found +func (r CallToolRequest) GetStringSlice(key string, defaultValue []string) []string { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []string: + return v + case []any: + result := make([]string, 0, len(v)) + for _, item := range v { + if str, ok := item.(string); ok { + result = append(result, str) + } + } + return result + } + } + return defaultValue +} + +// RequireStringSlice returns a string slice argument by key, or an error if not found or not convertible to string slice +func (r CallToolRequest) RequireStringSlice(key string) ([]string, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []string: + return v, nil + case []any: + result := make([]string, 0, len(v)) + for i, item := range v { + if str, ok := item.(string); ok { + result = append(result, str) + } else { + return nil, fmt.Errorf("item %d in argument %q is not a string", i, key) + } + } + return result, nil + default: + return nil, fmt.Errorf("argument %q is not a string slice", key) + } + } + return nil, fmt.Errorf("required argument %q not found", key) +} + +// GetIntSlice returns an int slice argument by key, or the default value if not found +func (r CallToolRequest) GetIntSlice(key string, defaultValue []int) []int { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []int: + return v + case []any: + result := make([]int, 0, len(v)) + for _, item := range v { + switch num := item.(type) { + case int: + result = append(result, num) + case float64: + result = append(result, int(num)) + case string: + if i, err := strconv.Atoi(num); err == nil { + result = append(result, i) + } + } + } + return result + } + } + return defaultValue +} + +// RequireIntSlice returns an int slice argument by key, or an error if not found or not convertible to int slice +func (r CallToolRequest) RequireIntSlice(key string) ([]int, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []int: + return v, nil + case []any: + result := make([]int, 0, len(v)) + for i, item := range v { + switch num := item.(type) { + case int: + result = append(result, num) + case float64: + result = append(result, int(num)) + case string: + if i, err := strconv.Atoi(num); err == nil { + result = append(result, i) + } else { + return nil, fmt.Errorf("item %d in argument %q cannot be converted to int", i, key) + } + default: + return nil, fmt.Errorf("item %d in argument %q is not an int", i, key) + } + } + return result, nil + default: + return nil, fmt.Errorf("argument %q is not an int slice", key) + } + } + return nil, fmt.Errorf("required argument %q not found", key) +} + +// GetFloatSlice returns a float64 slice argument by key, or the default value if not found +func (r CallToolRequest) GetFloatSlice(key string, defaultValue []float64) []float64 { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []float64: + return v + case []any: + result := make([]float64, 0, len(v)) + for _, item := range v { + switch num := item.(type) { + case float64: + result = append(result, num) + case int: + result = append(result, float64(num)) + case string: + if f, err := strconv.ParseFloat(num, 64); err == nil { + result = append(result, f) + } + } + } + return result + } + } + return defaultValue +} + +// RequireFloatSlice returns a float64 slice argument by key, or an error if not found or not convertible to float64 slice +func (r CallToolRequest) RequireFloatSlice(key string) ([]float64, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []float64: + return v, nil + case []any: + result := make([]float64, 0, len(v)) + for i, item := range v { + switch num := item.(type) { + case float64: + result = append(result, num) + case int: + result = append(result, float64(num)) + case string: + if f, err := strconv.ParseFloat(num, 64); err == nil { + result = append(result, f) + } else { + return nil, fmt.Errorf("item %d in argument %q cannot be converted to float64", i, key) + } + default: + return nil, fmt.Errorf("item %d in argument %q is not a float64", i, key) + } + } + return result, nil + default: + return nil, fmt.Errorf("argument %q is not a float64 slice", key) + } + } + return nil, fmt.Errorf("required argument %q not found", key) +} + +// GetBoolSlice returns a bool slice argument by key, or the default value if not found +func (r CallToolRequest) GetBoolSlice(key string, defaultValue []bool) []bool { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []bool: + return v + case []any: + result := make([]bool, 0, len(v)) + for _, item := range v { + switch b := item.(type) { + case bool: + result = append(result, b) + case string: + if parsed, err := strconv.ParseBool(b); err == nil { + result = append(result, parsed) + } + case int: + result = append(result, b != 0) + case float64: + result = append(result, b != 0) + } + } + return result + } + } + return defaultValue +} + +// RequireBoolSlice returns a bool slice argument by key, or an error if not found or not convertible to bool slice +func (r CallToolRequest) RequireBoolSlice(key string) ([]bool, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []bool: + return v, nil + case []any: + result := make([]bool, 0, len(v)) + for i, item := range v { + switch b := item.(type) { + case bool: + result = append(result, b) + case string: + if parsed, err := strconv.ParseBool(b); err == nil { + result = append(result, parsed) + } else { + return nil, fmt.Errorf("item %d in argument %q cannot be converted to bool", i, key) + } + case int: + result = append(result, b != 0) + case float64: + result = append(result, b != 0) + default: + return nil, fmt.Errorf("item %d in argument %q is not a bool", i, key) + } + } + return result, nil + default: + return nil, fmt.Errorf("argument %q is not a bool slice", key) + } + } + return nil, fmt.Errorf("required argument %q not found", key) +} + // ToolListChangedNotification is an optional notification from the server to // the client, informing it that the list of tools it offers has changed. This may // be issued by servers without any previous subscription from the client. diff --git a/mcp/tools_test.go b/mcp/tools_test.go index e2be72fbc..7f2640b94 100644 --- a/mcp/tools_test.go +++ b/mcp/tools_test.go @@ -308,3 +308,223 @@ func TestParseToolCallToolRequest(t *testing.T) { t.Logf("param15 type: %T,value:%v", param15, param15) } + +func TestCallToolRequestBindArguments(t *testing.T) { + // Define a struct to bind to + type TestArgs struct { + Name string `json:"name"` + Age int `json:"age"` + Email string `json:"email"` + } + + // Create a request with map arguments + req := CallToolRequest{} + req.Params.Name = "test-tool" + req.Params.Arguments = map[string]any{ + "name": "John Doe", + "age": 30, + "email": "john@example.com", + } + + // Bind arguments to struct + var args TestArgs + err := req.BindArguments(&args) + assert.NoError(t, err) + assert.Equal(t, "John Doe", args.Name) + assert.Equal(t, 30, args.Age) + assert.Equal(t, "john@example.com", args.Email) +} + +func TestCallToolRequestHelperFunctions(t *testing.T) { + // Create a request with map arguments + req := CallToolRequest{} + req.Params.Name = "test-tool" + req.Params.Arguments = map[string]any{ + "string_val": "hello", + "int_val": 42, + "float_val": 3.14, + "bool_val": true, + "string_slice_val": []any{"one", "two", "three"}, + "int_slice_val": []any{1, 2, 3}, + "float_slice_val": []any{1.1, 2.2, 3.3}, + "bool_slice_val": []any{true, false, true}, + } + + // Test GetString + assert.Equal(t, "hello", req.GetString("string_val", "default")) + assert.Equal(t, "default", req.GetString("missing_val", "default")) + + // Test RequireString + str, err := req.RequireString("string_val") + assert.NoError(t, err) + assert.Equal(t, "hello", str) + _, err = req.RequireString("missing_val") + assert.Error(t, err) + + // Test GetInt + assert.Equal(t, 42, req.GetInt("int_val", 0)) + assert.Equal(t, 0, req.GetInt("missing_val", 0)) + + // Test RequireInt + i, err := req.RequireInt("int_val") + assert.NoError(t, err) + assert.Equal(t, 42, i) + _, err = req.RequireInt("missing_val") + assert.Error(t, err) + + // Test GetFloat + assert.Equal(t, 3.14, req.GetFloat("float_val", 0.0)) + assert.Equal(t, 0.0, req.GetFloat("missing_val", 0.0)) + + // Test RequireFloat + f, err := req.RequireFloat("float_val") + assert.NoError(t, err) + assert.Equal(t, 3.14, f) + _, err = req.RequireFloat("missing_val") + assert.Error(t, err) + + // Test GetBool + assert.Equal(t, true, req.GetBool("bool_val", false)) + assert.Equal(t, false, req.GetBool("missing_val", false)) + + // Test RequireBool + b, err := req.RequireBool("bool_val") + assert.NoError(t, err) + assert.Equal(t, true, b) + _, err = req.RequireBool("missing_val") + assert.Error(t, err) + + // Test GetStringSlice + assert.Equal(t, []string{"one", "two", "three"}, req.GetStringSlice("string_slice_val", nil)) + assert.Equal(t, []string{"default"}, req.GetStringSlice("missing_val", []string{"default"})) + + // Test RequireStringSlice + ss, err := req.RequireStringSlice("string_slice_val") + assert.NoError(t, err) + assert.Equal(t, []string{"one", "two", "three"}, ss) + _, err = req.RequireStringSlice("missing_val") + assert.Error(t, err) + + // Test GetIntSlice + assert.Equal(t, []int{1, 2, 3}, req.GetIntSlice("int_slice_val", nil)) + assert.Equal(t, []int{42}, req.GetIntSlice("missing_val", []int{42})) + + // Test RequireIntSlice + is, err := req.RequireIntSlice("int_slice_val") + assert.NoError(t, err) + assert.Equal(t, []int{1, 2, 3}, is) + _, err = req.RequireIntSlice("missing_val") + assert.Error(t, err) + + // Test GetFloatSlice + assert.Equal(t, []float64{1.1, 2.2, 3.3}, req.GetFloatSlice("float_slice_val", nil)) + assert.Equal(t, []float64{4.4}, req.GetFloatSlice("missing_val", []float64{4.4})) + + // Test RequireFloatSlice + fs, err := req.RequireFloatSlice("float_slice_val") + assert.NoError(t, err) + assert.Equal(t, []float64{1.1, 2.2, 3.3}, fs) + _, err = req.RequireFloatSlice("missing_val") + assert.Error(t, err) + + // Test GetBoolSlice + assert.Equal(t, []bool{true, false, true}, req.GetBoolSlice("bool_slice_val", nil)) + assert.Equal(t, []bool{false}, req.GetBoolSlice("missing_val", []bool{false})) + + // Test RequireBoolSlice + bs, err := req.RequireBoolSlice("bool_slice_val") + assert.NoError(t, err) + assert.Equal(t, []bool{true, false, true}, bs) + _, err = req.RequireBoolSlice("missing_val") + assert.Error(t, err) +} + +func TestFlexibleArgumentsWithMap(t *testing.T) { + // Create a request with map arguments + req := CallToolRequest{} + req.Params.Name = "test-tool" + req.Params.Arguments = map[string]any{ + "key1": "value1", + "key2": 123, + } + + // Test GetArguments + args := req.GetArguments() + assert.Equal(t, "value1", args["key1"]) + assert.Equal(t, 123, args["key2"]) + + // Test GetRawArguments + rawArgs := req.GetRawArguments() + mapArgs, ok := rawArgs.(map[string]any) + assert.True(t, ok) + assert.Equal(t, "value1", mapArgs["key1"]) + assert.Equal(t, 123, mapArgs["key2"]) +} + +func TestFlexibleArgumentsWithString(t *testing.T) { + // Create a request with non-map arguments + req := CallToolRequest{} + req.Params.Name = "test-tool" + req.Params.Arguments = "string-argument" + + // Test GetArguments (should return empty map) + args := req.GetArguments() + assert.Empty(t, args) + + // Test GetRawArguments + rawArgs := req.GetRawArguments() + strArg, ok := rawArgs.(string) + assert.True(t, ok) + assert.Equal(t, "string-argument", strArg) +} + +func TestFlexibleArgumentsWithStruct(t *testing.T) { + // Create a custom struct + type CustomArgs struct { + Field1 string `json:"field1"` + Field2 int `json:"field2"` + } + + // Create a request with struct arguments + req := CallToolRequest{} + req.Params.Name = "test-tool" + req.Params.Arguments = CustomArgs{ + Field1: "test", + Field2: 42, + } + + // Test GetArguments (should return empty map) + args := req.GetArguments() + assert.Empty(t, args) + + // Test GetRawArguments + rawArgs := req.GetRawArguments() + structArg, ok := rawArgs.(CustomArgs) + assert.True(t, ok) + assert.Equal(t, "test", structArg.Field1) + assert.Equal(t, 42, structArg.Field2) +} + +func TestFlexibleArgumentsJSONMarshalUnmarshal(t *testing.T) { + // Create a request with map arguments + req := CallToolRequest{} + req.Params.Name = "test-tool" + req.Params.Arguments = map[string]any{ + "key1": "value1", + "key2": 123, + } + + // Marshal to JSON + data, err := json.Marshal(req) + assert.NoError(t, err) + + // Unmarshal from JSON + var unmarshaledReq CallToolRequest + err = json.Unmarshal(data, &unmarshaledReq) + assert.NoError(t, err) + + // Check if arguments are correctly unmarshaled + args := unmarshaledReq.GetArguments() + assert.Equal(t, "value1", args["key1"]) + assert.Equal(t, float64(123), args["key2"]) // JSON numbers are unmarshaled as float64 +} diff --git a/mcp/typed_tools.go b/mcp/typed_tools.go new file mode 100644 index 000000000..68d8cdd1f --- /dev/null +++ b/mcp/typed_tools.go @@ -0,0 +1,20 @@ +package mcp + +import ( + "context" + "fmt" +) + +// TypedToolHandlerFunc is a function that handles a tool call with typed arguments +type TypedToolHandlerFunc[T any] func(ctx context.Context, request CallToolRequest, args T) (*CallToolResult, error) + +// NewTypedToolHandler creates a ToolHandlerFunc that automatically binds arguments to a typed struct +func NewTypedToolHandler[T any](handler TypedToolHandlerFunc[T]) func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) { + return func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) { + var args T + if err := request.BindArguments(&args); err != nil { + return NewToolResultError(fmt.Sprintf("failed to bind arguments: %v", err)), nil + } + return handler(ctx, request, args) + } +} diff --git a/mcp/typed_tools_test.go b/mcp/typed_tools_test.go new file mode 100644 index 000000000..109ade89c --- /dev/null +++ b/mcp/typed_tools_test.go @@ -0,0 +1,304 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTypedToolHandler(t *testing.T) { + // Define a test struct for arguments + type HelloArgs struct { + Name string `json:"name"` + Age int `json:"age"` + IsAdmin bool `json:"is_admin"` + } + + // Create a typed handler function + typedHandler := func(ctx context.Context, request CallToolRequest, args HelloArgs) (*CallToolResult, error) { + return NewToolResultText(args.Name), nil + } + + // Create a wrapped handler + wrappedHandler := NewTypedToolHandler(typedHandler) + + // Create a test request + req := CallToolRequest{} + req.Params.Name = "test-tool" + req.Params.Arguments = map[string]any{ + "name": "John Doe", + "age": 30, + "is_admin": true, + } + + // Call the wrapped handler + result, err := wrappedHandler(context.Background(), req) + + // Verify results + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, "John Doe", result.Content[0].(TextContent).Text) + + // Test with invalid arguments + req.Params.Arguments = map[string]any{ + "name": 123, // Wrong type + "age": "thirty", + "is_admin": "yes", + } + + // This should still work because of type conversion + result, err = wrappedHandler(context.Background(), req) + assert.NoError(t, err) + assert.NotNil(t, result) + + // Test with missing required field + req.Params.Arguments = map[string]any{ + "age": 30, + "is_admin": true, + // Name is missing + } + + // This should still work but name will be empty + result, err = wrappedHandler(context.Background(), req) + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, "", result.Content[0].(TextContent).Text) + + // Test with completely invalid arguments + req.Params.Arguments = "not a map" + result, err = wrappedHandler(context.Background(), req) + assert.NoError(t, err) // Error is wrapped in the result + assert.NotNil(t, result) + assert.True(t, result.IsError) +} + +func TestTypedToolHandlerWithValidation(t *testing.T) { + // Define a test struct for arguments with validation + type CalculatorArgs struct { + Operation string `json:"operation"` + X float64 `json:"x"` + Y float64 `json:"y"` + } + + // Create a typed handler function with validation + typedHandler := func(ctx context.Context, request CallToolRequest, args CalculatorArgs) (*CallToolResult, error) { + // Validate operation + if args.Operation == "" { + return NewToolResultError("operation is required"), nil + } + + var result float64 + switch args.Operation { + case "add": + result = args.X + args.Y + case "subtract": + result = args.X - args.Y + case "multiply": + result = args.X * args.Y + case "divide": + if args.Y == 0 { + return NewToolResultError("division by zero"), nil + } + result = args.X / args.Y + default: + return NewToolResultError("invalid operation"), nil + } + + return NewToolResultText(fmt.Sprintf("%.0f", result)), nil + } + + // Create a wrapped handler + wrappedHandler := NewTypedToolHandler(typedHandler) + + // Create a test request + req := CallToolRequest{} + req.Params.Name = "calculator" + req.Params.Arguments = map[string]any{ + "operation": "add", + "x": 10.5, + "y": 5.5, + } + + // Call the wrapped handler + result, err := wrappedHandler(context.Background(), req) + + // Verify results + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, "16", result.Content[0].(TextContent).Text) + + // Test division by zero + req.Params.Arguments = map[string]any{ + "operation": "divide", + "x": 10.0, + "y": 0.0, + } + + result, err = wrappedHandler(context.Background(), req) + assert.NoError(t, err) + assert.NotNil(t, result) + assert.True(t, result.IsError) + assert.Contains(t, result.Content[0].(TextContent).Text, "division by zero") +} + +func TestTypedToolHandlerWithComplexObjects(t *testing.T) { + // Define a complex test struct with nested objects + type Address struct { + Street string `json:"street"` + City string `json:"city"` + Country string `json:"country"` + ZipCode string `json:"zip_code"` + } + + type UserPreferences struct { + Theme string `json:"theme"` + Timezone string `json:"timezone"` + Newsletters []string `json:"newsletters"` + } + + type UserProfile struct { + Name string `json:"name"` + Email string `json:"email"` + Age int `json:"age"` + IsVerified bool `json:"is_verified"` + Address Address `json:"address"` + Preferences UserPreferences `json:"preferences"` + Tags []string `json:"tags"` + } + + // Create a typed handler function + typedHandler := func(ctx context.Context, request CallToolRequest, profile UserProfile) (*CallToolResult, error) { + // Validate required fields + if profile.Name == "" { + return NewToolResultError("name is required"), nil + } + if profile.Email == "" { + return NewToolResultError("email is required"), nil + } + + // Build a response that includes nested object data + response := fmt.Sprintf("User: %s (%s)", profile.Name, profile.Email) + + if profile.Age > 0 { + response += fmt.Sprintf(", Age: %d", profile.Age) + } + + if profile.IsVerified { + response += ", Verified: Yes" + } else { + response += ", Verified: No" + } + + // Include address information if available + if profile.Address.City != "" && profile.Address.Country != "" { + response += fmt.Sprintf(", Location: %s, %s", profile.Address.City, profile.Address.Country) + } + + // Include preferences if available + if profile.Preferences.Theme != "" { + response += fmt.Sprintf(", Theme: %s", profile.Preferences.Theme) + } + + if len(profile.Preferences.Newsletters) > 0 { + response += fmt.Sprintf(", Subscribed to %d newsletters", len(profile.Preferences.Newsletters)) + } + + if len(profile.Tags) > 0 { + response += fmt.Sprintf(", Tags: %v", profile.Tags) + } + + return NewToolResultText(response), nil + } + + // Create a wrapped handler + wrappedHandler := NewTypedToolHandler(typedHandler) + + // Test with complete complex object + req := CallToolRequest{} + req.Params.Name = "user_profile" + req.Params.Arguments = map[string]any{ + "name": "John Doe", + "email": "john@example.com", + "age": 35, + "is_verified": true, + "address": map[string]any{ + "street": "123 Main St", + "city": "San Francisco", + "country": "USA", + "zip_code": "94105", + }, + "preferences": map[string]any{ + "theme": "dark", + "timezone": "America/Los_Angeles", + "newsletters": []string{"weekly", "product_updates"}, + }, + "tags": []string{"premium", "early_adopter"}, + } + + // Call the wrapped handler + result, err := wrappedHandler(context.Background(), req) + + // Verify results + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Contains(t, result.Content[0].(TextContent).Text, "John Doe") + assert.Contains(t, result.Content[0].(TextContent).Text, "San Francisco, USA") + assert.Contains(t, result.Content[0].(TextContent).Text, "Theme: dark") + assert.Contains(t, result.Content[0].(TextContent).Text, "Subscribed to 2 newsletters") + assert.Contains(t, result.Content[0].(TextContent).Text, "Tags: [premium early_adopter]") + + // Test with partial data (missing some nested fields) + req.Params.Arguments = map[string]any{ + "name": "Jane Smith", + "email": "jane@example.com", + "age": 28, + "is_verified": false, + "address": map[string]any{ + "city": "London", + "country": "UK", + }, + "preferences": map[string]any{ + "theme": "light", + }, + } + + result, err = wrappedHandler(context.Background(), req) + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Contains(t, result.Content[0].(TextContent).Text, "Jane Smith") + assert.Contains(t, result.Content[0].(TextContent).Text, "London, UK") + assert.Contains(t, result.Content[0].(TextContent).Text, "Theme: light") + assert.NotContains(t, result.Content[0].(TextContent).Text, "newsletters") + + // Test with JSON string input (simulating raw JSON from client) + jsonInput := `{ + "name": "Bob Johnson", + "email": "bob@example.com", + "age": 42, + "is_verified": true, + "address": { + "street": "456 Park Ave", + "city": "New York", + "country": "USA", + "zip_code": "10022" + }, + "preferences": { + "theme": "system", + "timezone": "America/New_York", + "newsletters": ["monthly"] + }, + "tags": ["business"] + }` + + req.Params.Arguments = json.RawMessage(jsonInput) + result, err = wrappedHandler(context.Background(), req) + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Contains(t, result.Content[0].(TextContent).Text, "Bob Johnson") + assert.Contains(t, result.Content[0].(TextContent).Text, "New York, USA") + assert.Contains(t, result.Content[0].(TextContent).Text, "Theme: system") + assert.Contains(t, result.Content[0].(TextContent).Text, "Subscribed to 1 newsletters") +} \ No newline at end of file diff --git a/mcp/utils.go b/mcp/utils.go index bf6acbdff..d6d42b7e4 100644 --- a/mcp/utils.go +++ b/mcp/utils.go @@ -675,10 +675,11 @@ func ParseReadResourceResult(rawMessage *json.RawMessage) (*ReadResourceResult, } func ParseArgument(request CallToolRequest, key string, defaultVal any) any { - if _, ok := request.Params.Arguments[key]; !ok { + args := request.GetArguments() + if _, ok := args[key]; !ok { return defaultVal } else { - return request.Params.Arguments[key] + return args[key] } } diff --git a/mcptest/mcptest_test.go b/mcptest/mcptest_test.go index 3fa8af5ba..b71b93eaa 100644 --- a/mcptest/mcptest_test.go +++ b/mcptest/mcptest_test.go @@ -52,7 +52,7 @@ func TestServer(t *testing.T) { func helloWorldHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Extract name from request arguments - name, ok := request.Params.Arguments["name"].(string) + name, ok := request.GetArguments()["name"].(string) if !ok { name = "World" }