From 78d1996aff62f97d5fb37f35c24e3dcb6f5f71c8 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Thu, 15 May 2025 22:06:49 +0300 Subject: [PATCH 01/13] Implement OAuth in the client --- client/oauth.go | 61 +++ client/oauth_test.go | 123 +++++ client/transport/oauth.go | 472 ++++++++++++++++++ client/transport/oauth_test.go | 139 ++++++ client/transport/oauth_utils.go | 56 +++ client/transport/oauth_utils_test.go | 88 ++++ client/transport/streamable_http.go | 73 +++ .../transport/streamable_http_oauth_test.go | 143 ++++++ examples/oauth_client/README.md | 59 +++ examples/oauth_client/main.go | 220 ++++++++ 10 files changed, 1434 insertions(+) create mode 100644 client/oauth.go create mode 100644 client/oauth_test.go create mode 100644 client/transport/oauth.go create mode 100644 client/transport/oauth_test.go create mode 100644 client/transport/oauth_utils.go create mode 100644 client/transport/oauth_utils_test.go create mode 100644 client/transport/streamable_http_oauth_test.go create mode 100644 examples/oauth_client/README.md create mode 100644 examples/oauth_client/main.go diff --git a/client/oauth.go b/client/oauth.go new file mode 100644 index 000000000..9dae8afac --- /dev/null +++ b/client/oauth.go @@ -0,0 +1,61 @@ +package client + +import ( + "fmt" + + "github.com/mark3labs/mcp-go/client/transport" +) + +// OAuthConfig is a convenience type that wraps transport.OAuthConfig +type OAuthConfig = transport.OAuthConfig + +// Token is a convenience type that wraps transport.Token +type Token = transport.Token + +// TokenStore is a convenience type that wraps transport.TokenStore +type TokenStore = transport.TokenStore + +// MemoryTokenStore is a convenience type that wraps transport.MemoryTokenStore +type MemoryTokenStore = transport.MemoryTokenStore + +// NewMemoryTokenStore is a convenience function that wraps transport.NewMemoryTokenStore +var NewMemoryTokenStore = transport.NewMemoryTokenStore + +// NewOAuthStreamableHttpClient creates a new streamable-http-based MCP client with OAuth support. +// Returns an error if the URL is invalid. +func NewOAuthStreamableHttpClient(baseURL string, oauthConfig OAuthConfig, options ...transport.StreamableHTTPCOption) (*Client, error) { + // Add OAuth option to the list of options + options = append(options, transport.WithOAuth(oauthConfig)) + + trans, err := transport.NewStreamableHTTP(baseURL, options...) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP transport: %w", err) + } + return NewClient(trans), nil +} + +// GenerateCodeVerifier generates a code verifier for PKCE +var GenerateCodeVerifier = transport.GenerateCodeVerifier + +// GenerateCodeChallenge generates a code challenge from a code verifier +var GenerateCodeChallenge = transport.GenerateCodeChallenge + +// GenerateState generates a state parameter for OAuth +var GenerateState = transport.GenerateState + +// OAuthAuthorizationRequiredError is returned when OAuth authorization is required +type OAuthAuthorizationRequiredError = transport.OAuthAuthorizationRequiredError + +// IsOAuthAuthorizationRequiredError checks if an error is an OAuthAuthorizationRequiredError +func IsOAuthAuthorizationRequiredError(err error) bool { + _, ok := err.(*OAuthAuthorizationRequiredError) + return ok +} + +// GetOAuthHandler extracts the OAuthHandler from an OAuthAuthorizationRequiredError +func GetOAuthHandler(err error) *transport.OAuthHandler { + if oauthErr, ok := err.(*OAuthAuthorizationRequiredError); ok { + return oauthErr.Handler + } + return nil +} \ No newline at end of file diff --git a/client/oauth_test.go b/client/oauth_test.go new file mode 100644 index 000000000..20c39ee02 --- /dev/null +++ b/client/oauth_test.go @@ -0,0 +1,123 @@ +package client + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/mark3labs/mcp-go/client/transport" +) + +func TestNewOAuthStreamableHttpClient(t *testing.T) { + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check for Authorization header + authHeader := r.Header.Get("Authorization") + if authHeader != "Bearer test-token" { + w.WriteHeader(http.StatusUnauthorized) + return + } + + // Return a successful response + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "result": map[string]any{ + "protocolVersion": "2024-11-05", + "serverInfo": map[string]any{ + "name": "test-server", + "version": "1.0.0", + }, + "capabilities": map[string]any{}, + }, + }) + })) + defer server.Close() + + // Create a token store with a valid token + tokenStore := NewMemoryTokenStore() + validToken := &Token{ + AccessToken: "test-token", + TokenType: "Bearer", + RefreshToken: "refresh-token", + ExpiresIn: 3600, + ExpiresAt: time.Now().Add(1 * time.Hour), // Valid for 1 hour + } + tokenStore.SaveToken(validToken) + + // Create OAuth config + oauthConfig := OAuthConfig{ + ClientID: "test-client", + RedirectURI: "http://localhost:8085/callback", + Scopes: []string{"mcp.read", "mcp.write"}, + TokenStore: tokenStore, + PKCEEnabled: true, + } + + // Create client with OAuth + client, err := NewOAuthStreamableHttpClient(server.URL, oauthConfig) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + // Start the client + if err := client.Start(context.Background()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + defer client.Close() + + // Verify that the client was created successfully + trans := client.GetTransport() + streamableHTTP, ok := trans.(*transport.StreamableHTTP) + if !ok { + t.Fatalf("Expected transport to be *transport.StreamableHTTP, got %T", trans) + } + + // Verify OAuth is enabled + if !streamableHTTP.IsOAuthEnabled() { + t.Errorf("Expected IsOAuthEnabled() to return true") + } + + // Verify the OAuth handler is set + if streamableHTTP.GetOAuthHandler() == nil { + t.Errorf("Expected GetOAuthHandler() to return a handler") + } +} + +func TestIsOAuthAuthorizationRequiredError(t *testing.T) { + // Create a test error + err := &transport.OAuthAuthorizationRequiredError{ + Handler: transport.NewOAuthHandler(transport.OAuthConfig{}), + } + + // Verify IsOAuthAuthorizationRequiredError returns true + if !IsOAuthAuthorizationRequiredError(err) { + t.Errorf("Expected IsOAuthAuthorizationRequiredError to return true") + } + + // Verify GetOAuthHandler returns the handler + handler := GetOAuthHandler(err) + if handler == nil { + t.Errorf("Expected GetOAuthHandler to return a handler") + } + + // Test with a different error + err2 := fmt.Errorf("some other error") + + // Verify IsOAuthAuthorizationRequiredError returns false + if IsOAuthAuthorizationRequiredError(err2) { + t.Errorf("Expected IsOAuthAuthorizationRequiredError to return false") + } + + // Verify GetOAuthHandler returns nil + handler = GetOAuthHandler(err2) + if handler != nil { + t.Errorf("Expected GetOAuthHandler to return nil") + } +} \ No newline at end of file diff --git a/client/transport/oauth.go b/client/transport/oauth.go new file mode 100644 index 000000000..6f6e671eb --- /dev/null +++ b/client/transport/oauth.go @@ -0,0 +1,472 @@ +package transport + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" +) + +// OAuthConfig holds the OAuth configuration for the client +type OAuthConfig struct { + // ClientID is the OAuth client ID + ClientID string + // ClientSecret is the OAuth client secret (for confidential clients) + ClientSecret string + // RedirectURI is the redirect URI for the OAuth flow + RedirectURI string + // Scopes is the list of OAuth scopes to request + Scopes []string + // TokenStore is the storage for OAuth tokens + TokenStore TokenStore + // AuthServerMetadataURL is the URL to the OAuth server metadata + // If empty, the client will attempt to discover it from the base URL + AuthServerMetadataURL string + // PKCEEnabled enables PKCE for the OAuth flow (recommended for public clients) + PKCEEnabled bool +} + +// TokenStore is an interface for storing and retrieving OAuth tokens +type TokenStore interface { + // GetToken returns the current token + GetToken() (*Token, error) + // SaveToken saves a token + SaveToken(token *Token) error +} + +// Token represents an OAuth token +type Token struct { + // AccessToken is the OAuth access token + AccessToken string `json:"access_token"` + // TokenType is the type of token (usually "Bearer") + TokenType string `json:"token_type"` + // RefreshToken is the OAuth refresh token + RefreshToken string `json:"refresh_token,omitempty"` + // ExpiresIn is the number of seconds until the token expires + ExpiresIn int64 `json:"expires_in,omitempty"` + // Scope is the scope of the token + Scope string `json:"scope,omitempty"` + // ExpiresAt is the time when the token expires + ExpiresAt time.Time `json:"expires_at,omitempty"` +} + +// IsExpired returns true if the token is expired +func (t *Token) IsExpired() bool { + if t.ExpiresAt.IsZero() { + return false + } + return time.Now().After(t.ExpiresAt) +} + +// MemoryTokenStore is a simple in-memory token store +type MemoryTokenStore struct { + token *Token + mu sync.RWMutex +} + +// NewMemoryTokenStore creates a new in-memory token store +func NewMemoryTokenStore() *MemoryTokenStore { + return &MemoryTokenStore{} +} + +// GetToken returns the current token +func (s *MemoryTokenStore) GetToken() (*Token, error) { + s.mu.RLock() + defer s.mu.RUnlock() + if s.token == nil { + return nil, errors.New("no token available") + } + return s.token, nil +} + +// SaveToken saves a token +func (s *MemoryTokenStore) SaveToken(token *Token) error { + s.mu.Lock() + defer s.mu.Unlock() + s.token = token + return nil +} + +// AuthServerMetadata represents the OAuth 2.0 Authorization Server Metadata +type AuthServerMetadata struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + RegistrationEndpoint string `json:"registration_endpoint,omitempty"` + JwksURI string `json:"jwks_uri,omitempty"` + ScopesSupported []string `json:"scopes_supported,omitempty"` + ResponseTypesSupported []string `json:"response_types_supported"` + GrantTypesSupported []string `json:"grant_types_supported,omitempty"` + TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported,omitempty"` +} + +// OAuthHandler handles OAuth authentication for HTTP requests +type OAuthHandler struct { + config OAuthConfig + httpClient *http.Client + serverMetadata *AuthServerMetadata + metadataFetchErr error + metadataOnce sync.Once +} + +// NewOAuthHandler creates a new OAuth handler +func NewOAuthHandler(config OAuthConfig) *OAuthHandler { + if config.TokenStore == nil { + config.TokenStore = NewMemoryTokenStore() + } + + return &OAuthHandler{ + config: config, + httpClient: &http.Client{Timeout: 30 * time.Second}, + } +} + +// GetAuthorizationHeader returns the Authorization header value for a request +func (h *OAuthHandler) GetAuthorizationHeader(ctx context.Context) (string, error) { + token, err := h.getValidToken(ctx) + if err != nil { + return "", err + } + return fmt.Sprintf("%s %s", token.TokenType, token.AccessToken), nil +} + +// getValidToken returns a valid token, refreshing if necessary +func (h *OAuthHandler) getValidToken(ctx context.Context) (*Token, error) { + token, err := h.config.TokenStore.GetToken() + if err == nil && !token.IsExpired() { + return token, nil + } + + // If we have a refresh token, try to use it + if err == nil && token.RefreshToken != "" { + newToken, err := h.refreshToken(ctx, token.RefreshToken) + if err == nil { + return newToken, nil + } + // If refresh fails, continue to authorization flow + } + + // We need to get a new token through the authorization flow + return nil, errors.New("no valid token available, authorization required") +} + +// refreshToken refreshes an OAuth token +func (h *OAuthHandler) refreshToken(ctx context.Context, refreshToken string) (*Token, error) { + metadata, err := h.getServerMetadata(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get server metadata: %w", err) + } + + data := url.Values{} + data.Set("grant_type", "refresh_token") + data.Set("refresh_token", refreshToken) + data.Set("client_id", h.config.ClientID) + if h.config.ClientSecret != "" { + data.Set("client_secret", h.config.ClientSecret) + } + + req, err := http.NewRequestWithContext( + ctx, + http.MethodPost, + metadata.TokenEndpoint, + strings.NewReader(data.Encode()), + ) + if err != nil { + return nil, fmt.Errorf("failed to create refresh token request: %w", err) + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := h.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send refresh token request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("refresh token request failed with status %d: %s", resp.StatusCode, body) + } + + var tokenResp Token + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + return nil, fmt.Errorf("failed to decode token response: %w", err) + } + + // Set expiration time + if tokenResp.ExpiresIn > 0 { + tokenResp.ExpiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + } + + // Save the token + if err := h.config.TokenStore.SaveToken(&tokenResp); err != nil { + return nil, fmt.Errorf("failed to save token: %w", err) + } + + return &tokenResp, nil +} + +// RefreshToken is a public wrapper for refreshToken +func (h *OAuthHandler) RefreshToken(ctx context.Context, refreshToken string) (*Token, error) { + return h.refreshToken(ctx, refreshToken) +} + +// GetClientID returns the client ID +func (h *OAuthHandler) GetClientID() string { + return h.config.ClientID +} + +// GetClientSecret returns the client secret +func (h *OAuthHandler) GetClientSecret() string { + return h.config.ClientSecret +} + +// getServerMetadata fetches the OAuth server metadata +func (h *OAuthHandler) getServerMetadata(ctx context.Context) (*AuthServerMetadata, error) { + h.metadataOnce.Do(func() { + var metadataURL string + if h.config.AuthServerMetadataURL != "" { + metadataURL = h.config.AuthServerMetadataURL + } else { + // Construct the well-known URL from the base URL + // According to the spec, we need to discard any path component + baseURL, err := url.Parse(h.config.AuthServerMetadataURL) + if err != nil { + h.metadataFetchErr = fmt.Errorf("invalid base URL: %w", err) + return + } + baseURL.Path = "" + metadataURL = baseURL.String() + "/.well-known/oauth-authorization-server" + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, metadataURL, nil) + if err != nil { + h.metadataFetchErr = fmt.Errorf("failed to create metadata request: %w", err) + return + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("MCP-Protocol-Version", "2025-03-26") + + resp, err := h.httpClient.Do(req) + if err != nil { + h.metadataFetchErr = fmt.Errorf("failed to send metadata request: %w", err) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + // If metadata discovery fails, use default endpoints + h.serverMetadata = h.getDefaultEndpoints(metadataURL) + return + } + + var metadata AuthServerMetadata + if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil { + h.metadataFetchErr = fmt.Errorf("failed to decode metadata response: %w", err) + return + } + + h.serverMetadata = &metadata + }) + + if h.metadataFetchErr != nil { + return nil, h.metadataFetchErr + } + + return h.serverMetadata, nil +} + +// GetServerMetadata is a public wrapper for getServerMetadata +func (h *OAuthHandler) GetServerMetadata(ctx context.Context) (*AuthServerMetadata, error) { + return h.getServerMetadata(ctx) +} + +// getDefaultEndpoints returns default OAuth endpoints based on the base URL +func (h *OAuthHandler) getDefaultEndpoints(baseURL string) *AuthServerMetadata { + // Parse the base URL to extract the authority + parsedURL, err := url.Parse(baseURL) + if err != nil { + return nil + } + + // Discard any path component to get the authorization base URL + parsedURL.Path = "" + authBaseURL := parsedURL.String() + + return &AuthServerMetadata{ + Issuer: authBaseURL, + AuthorizationEndpoint: authBaseURL + "/authorize", + TokenEndpoint: authBaseURL + "/token", + RegistrationEndpoint: authBaseURL + "/register", + } +} + +// RegisterClient performs dynamic client registration +func (h *OAuthHandler) RegisterClient(ctx context.Context, clientName string) error { + metadata, err := h.getServerMetadata(ctx) + if err != nil { + return fmt.Errorf("failed to get server metadata: %w", err) + } + + if metadata.RegistrationEndpoint == "" { + return errors.New("server does not support dynamic client registration") + } + + // Prepare registration request + regRequest := map[string]any{ + "client_name": clientName, + "redirect_uris": []string{h.config.RedirectURI}, + "token_endpoint_auth_method": "none", // For public clients + "grant_types": []string{"authorization_code", "refresh_token"}, + "response_types": []string{"code"}, + "scope": strings.Join(h.config.Scopes, " "), + } + + // Add client_secret if this is a confidential client + if h.config.ClientSecret != "" { + regRequest["token_endpoint_auth_method"] = "client_secret_basic" + } + + reqBody, err := json.Marshal(regRequest) + if err != nil { + return fmt.Errorf("failed to marshal registration request: %w", err) + } + + req, err := http.NewRequestWithContext( + ctx, + http.MethodPost, + metadata.RegistrationEndpoint, + bytes.NewReader(reqBody), + ) + if err != nil { + return fmt.Errorf("failed to create registration request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := h.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send registration request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("registration request failed with status %d: %s", resp.StatusCode, body) + } + + var regResponse struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret,omitempty"` + } + + if err := json.NewDecoder(resp.Body).Decode(®Response); err != nil { + return fmt.Errorf("failed to decode registration response: %w", err) + } + + // Update the client configuration + h.config.ClientID = regResponse.ClientID + if regResponse.ClientSecret != "" { + h.config.ClientSecret = regResponse.ClientSecret + } + + return nil +} + +// ProcessAuthorizationResponse processes the authorization response and exchanges the code for a token +func (h *OAuthHandler) ProcessAuthorizationResponse(ctx context.Context, code, state, codeVerifier string) error { + metadata, err := h.getServerMetadata(ctx) + if err != nil { + return fmt.Errorf("failed to get server metadata: %w", err) + } + + data := url.Values{} + data.Set("grant_type", "authorization_code") + data.Set("code", code) + data.Set("client_id", h.config.ClientID) + data.Set("redirect_uri", h.config.RedirectURI) + + if h.config.ClientSecret != "" { + data.Set("client_secret", h.config.ClientSecret) + } + + if h.config.PKCEEnabled && codeVerifier != "" { + data.Set("code_verifier", codeVerifier) + } + + req, err := http.NewRequestWithContext( + ctx, + http.MethodPost, + metadata.TokenEndpoint, + strings.NewReader(data.Encode()), + ) + if err != nil { + return fmt.Errorf("failed to create token request: %w", err) + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := h.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send token request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("token request failed with status %d: %s", resp.StatusCode, body) + } + + var tokenResp Token + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + return fmt.Errorf("failed to decode token response: %w", err) + } + + // Set expiration time + if tokenResp.ExpiresIn > 0 { + tokenResp.ExpiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + } + + // Save the token + if err := h.config.TokenStore.SaveToken(&tokenResp); err != nil { + return fmt.Errorf("failed to save token: %w", err) + } + + return nil +} + +// GetAuthorizationURL returns the URL for the authorization endpoint +func (h *OAuthHandler) GetAuthorizationURL(ctx context.Context, state, codeChallenge string) (string, error) { + metadata, err := h.getServerMetadata(ctx) + if err != nil { + return "", fmt.Errorf("failed to get server metadata: %w", err) + } + + params := url.Values{} + params.Set("response_type", "code") + params.Set("client_id", h.config.ClientID) + params.Set("redirect_uri", h.config.RedirectURI) + params.Set("state", state) + + if len(h.config.Scopes) > 0 { + params.Set("scope", strings.Join(h.config.Scopes, " ")) + } + + if h.config.PKCEEnabled && codeChallenge != "" { + params.Set("code_challenge", codeChallenge) + params.Set("code_challenge_method", "S256") + } + + return metadata.AuthorizationEndpoint + "?" + params.Encode(), nil +} \ No newline at end of file diff --git a/client/transport/oauth_test.go b/client/transport/oauth_test.go new file mode 100644 index 000000000..2b59541c9 --- /dev/null +++ b/client/transport/oauth_test.go @@ -0,0 +1,139 @@ +package transport + +import ( + "testing" + "time" +) + +func TestToken_IsExpired(t *testing.T) { + // Test cases + testCases := []struct { + name string + token Token + expected bool + }{ + { + name: "Valid token", + token: Token{ + AccessToken: "valid-token", + ExpiresAt: time.Now().Add(1 * time.Hour), + }, + expected: false, + }, + { + name: "Expired token", + token: Token{ + AccessToken: "expired-token", + ExpiresAt: time.Now().Add(-1 * time.Hour), + }, + expected: true, + }, + { + name: "Token with no expiration", + token: Token{ + AccessToken: "no-expiration-token", + }, + expected: false, + }, + } + + // Run test cases + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := tc.token.IsExpired() + if result != tc.expected { + t.Errorf("Expected IsExpired() to return %v, got %v", tc.expected, result) + } + }) + } +} + +func TestMemoryTokenStore(t *testing.T) { + // Create a token store + store := NewMemoryTokenStore() + + // Test getting token from empty store + _, err := store.GetToken() + if err == nil { + t.Errorf("Expected error when getting token from empty store") + } + + // Create a test token + token := &Token{ + AccessToken: "test-token", + TokenType: "Bearer", + RefreshToken: "refresh-token", + ExpiresIn: 3600, + ExpiresAt: time.Now().Add(1 * time.Hour), + } + + // Save the token + err = store.SaveToken(token) + if err != nil { + t.Fatalf("Failed to save token: %v", err) + } + + // Get the token + retrievedToken, err := store.GetToken() + if err != nil { + t.Fatalf("Failed to get token: %v", err) + } + + // Verify the token + if retrievedToken.AccessToken != token.AccessToken { + t.Errorf("Expected access token to be %s, got %s", token.AccessToken, retrievedToken.AccessToken) + } + if retrievedToken.TokenType != token.TokenType { + t.Errorf("Expected token type to be %s, got %s", token.TokenType, retrievedToken.TokenType) + } + if retrievedToken.RefreshToken != token.RefreshToken { + t.Errorf("Expected refresh token to be %s, got %s", token.RefreshToken, retrievedToken.RefreshToken) + } +} + +func TestValidateRedirectURI(t *testing.T) { + // Test cases + testCases := []struct { + name string + redirectURI string + expectError bool + }{ + { + name: "Valid HTTPS URI", + redirectURI: "https://example.com/callback", + expectError: false, + }, + { + name: "Valid localhost URI", + redirectURI: "http://localhost:8085/callback", + expectError: false, + }, + { + name: "Invalid HTTP URI (non-localhost)", + redirectURI: "http://example.com/callback", + expectError: true, + }, + { + name: "Empty URI", + redirectURI: "", + expectError: true, + }, + { + name: "Invalid scheme", + redirectURI: "ftp://example.com/callback", + expectError: true, + }, + } + + // Run test cases + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := ValidateRedirectURI(tc.redirectURI) + if tc.expectError && err == nil { + t.Errorf("Expected error for redirect URI %s, got nil", tc.redirectURI) + } else if !tc.expectError && err != nil { + t.Errorf("Expected no error for redirect URI %s, got %v", tc.redirectURI, err) + } + }) + } +} \ No newline at end of file diff --git a/client/transport/oauth_utils.go b/client/transport/oauth_utils.go new file mode 100644 index 000000000..9f2361209 --- /dev/null +++ b/client/transport/oauth_utils.go @@ -0,0 +1,56 @@ +package transport + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "fmt" +) + +// GenerateRandomString generates a random string of the specified length +func GenerateRandomString(length int) (string, error) { + bytes := make([]byte, length) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(bytes)[:length], nil +} + +// GenerateCodeVerifier generates a code verifier for PKCE +func GenerateCodeVerifier() (string, error) { + // According to RFC 7636, the code verifier should be between 43 and 128 characters + return GenerateRandomString(64) +} + +// GenerateCodeChallenge generates a code challenge from a code verifier +func GenerateCodeChallenge(codeVerifier string) string { + // SHA256 hash the code verifier + hash := sha256.Sum256([]byte(codeVerifier)) + // Base64url encode the hash + return base64.RawURLEncoding.EncodeToString(hash[:]) +} + +// GenerateState generates a state parameter for OAuth +func GenerateState() (string, error) { + return GenerateRandomString(32) +} + +// ValidateRedirectURI validates that a redirect URI is secure +func ValidateRedirectURI(redirectURI string) error { + // According to the spec, redirect URIs must be either localhost URLs or HTTPS URLs + if redirectURI == "" { + return fmt.Errorf("redirect URI cannot be empty") + } + + // Check if it's a localhost URL + if len(redirectURI) >= 9 && redirectURI[:9] == "http://lo" { + return nil + } + + // Check if it's an HTTPS URL + if len(redirectURI) >= 8 && redirectURI[:8] == "https://" { + return nil + } + + return fmt.Errorf("redirect URI must be either a localhost URL or an HTTPS URL") +} \ No newline at end of file diff --git a/client/transport/oauth_utils_test.go b/client/transport/oauth_utils_test.go new file mode 100644 index 000000000..da6eacaee --- /dev/null +++ b/client/transport/oauth_utils_test.go @@ -0,0 +1,88 @@ +package transport + +import ( + "fmt" + "testing" +) + +func TestGenerateRandomString(t *testing.T) { + // Test generating strings of different lengths + lengths := []int{10, 32, 64, 128} + for _, length := range lengths { + t.Run(fmt.Sprintf("Length_%d", length), func(t *testing.T) { + str, err := GenerateRandomString(length) + if err != nil { + t.Fatalf("Failed to generate random string: %v", err) + } + if len(str) != length { + t.Errorf("Expected string of length %d, got %d", length, len(str)) + } + + // Generate another string to ensure they're different + str2, err := GenerateRandomString(length) + if err != nil { + t.Fatalf("Failed to generate second random string: %v", err) + } + if str == str2 { + t.Errorf("Generated identical random strings: %s", str) + } + }) + } +} + +func TestGenerateCodeVerifierAndChallenge(t *testing.T) { + // Generate a code verifier + verifier, err := GenerateCodeVerifier() + if err != nil { + t.Fatalf("Failed to generate code verifier: %v", err) + } + + // Verify the length (should be 64 characters) + if len(verifier) != 64 { + t.Errorf("Expected code verifier of length 64, got %d", len(verifier)) + } + + // Generate a code challenge + challenge := GenerateCodeChallenge(verifier) + + // Verify the challenge is not empty + if challenge == "" { + t.Errorf("Generated empty code challenge") + } + + // Generate another verifier and challenge to ensure they're different + verifier2, _ := GenerateCodeVerifier() + challenge2 := GenerateCodeChallenge(verifier2) + + if verifier == verifier2 { + t.Errorf("Generated identical code verifiers: %s", verifier) + } + if challenge == challenge2 { + t.Errorf("Generated identical code challenges: %s", challenge) + } + + // Verify the same verifier always produces the same challenge + challenge3 := GenerateCodeChallenge(verifier) + if challenge != challenge3 { + t.Errorf("Same verifier produced different challenges: %s and %s", challenge, challenge3) + } +} + +func TestGenerateState(t *testing.T) { + // Generate a state parameter + state, err := GenerateState() + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Verify the length (should be 32 characters) + if len(state) != 32 { + t.Errorf("Expected state of length 32, got %d", len(state)) + } + + // Generate another state to ensure they're different + state2, _ := GenerateState() + if state == state2 { + t.Errorf("Generated identical states: %s", state) + } +} \ No newline at end of file diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index 98719bd04..91b07f00b 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -33,6 +33,13 @@ func WithHTTPTimeout(timeout time.Duration) StreamableHTTPCOption { } } +// WithOAuth enables OAuth authentication for the client. +func WithOAuth(config OAuthConfig) StreamableHTTPCOption { + return func(sc *StreamableHTTP) { + sc.oauthHandler = NewOAuthHandler(config) + } +} + // StreamableHTTP implements Streamable HTTP transport. // // It transmits JSON-RPC messages over individual HTTP requests. One message per request. @@ -59,6 +66,9 @@ type StreamableHTTP struct { notifyMu sync.RWMutex closed chan struct{} + + // OAuth support + oauthHandler *OAuthHandler } // NewStreamableHTTP creates a new Streamable HTTP transport with the given base URL. @@ -131,6 +141,15 @@ const ( headerKeySessionID = "Mcp-Session-Id" ) +// OAuthAuthorizationRequiredError is returned when OAuth authorization is required +type OAuthAuthorizationRequiredError struct { + Handler *OAuthHandler +} + +func (e *OAuthAuthorizationRequiredError) Error() string { + return "OAuth authorization required" +} + // SendRequest sends a JSON-RPC request to the server and waits for a response. // Returns the raw JSON response message or an error if the request fails. func (c *StreamableHTTP) SendRequest( @@ -173,6 +192,21 @@ func (c *StreamableHTTP) SendRequest( for k, v := range c.headers { req.Header.Set(k, v) } + + // Add OAuth authorization if configured + if c.oauthHandler != nil { + authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx) + if err != nil { + // If we get an authorization error, return a specific error that can be handled by the client + if err.Error() == "no valid token available, authorization required" { + return nil, &OAuthAuthorizationRequiredError{ + Handler: c.oauthHandler, + } + } + return nil, fmt.Errorf("failed to get authorization header: %w", err) + } + req.Header.Set("Authorization", authHeader) + } // Send request resp, err := c.httpClient.Do(req) @@ -188,6 +222,13 @@ func (c *StreamableHTTP) SendRequest( c.sessionID.CompareAndSwap(sessionID, "") return nil, fmt.Errorf("session terminated (404). need to re-initialize") } + + // Handle OAuth unauthorized error + if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil { + return nil, &OAuthAuthorizationRequiredError{ + Handler: c.oauthHandler, + } + } // handle error response var errResponse JSONRPCResponse @@ -363,6 +404,21 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp. for k, v := range c.headers { req.Header.Set(k, v) } + + // Add OAuth authorization if configured + if c.oauthHandler != nil { + authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx) + if err != nil { + // If we get an authorization error, return a specific error that can be handled by the client + if err.Error() == "no valid token available, authorization required" { + return &OAuthAuthorizationRequiredError{ + Handler: c.oauthHandler, + } + } + return fmt.Errorf("failed to get authorization header: %w", err) + } + req.Header.Set("Authorization", authHeader) + } // Send request resp, err := c.httpClient.Do(req) @@ -372,6 +428,13 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp. defer resp.Body.Close() if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + // Handle OAuth unauthorized error + if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil { + return &OAuthAuthorizationRequiredError{ + Handler: c.oauthHandler, + } + } + body, _ := io.ReadAll(resp.Body) return fmt.Errorf( "notification failed with status %d: %s", @@ -392,3 +455,13 @@ func (c *StreamableHTTP) SetNotificationHandler(handler func(mcp.JSONRPCNotifica func (c *StreamableHTTP) GetSessionId() string { return c.sessionID.Load().(string) } + +// GetOAuthHandler returns the OAuth handler if configured +func (c *StreamableHTTP) GetOAuthHandler() *OAuthHandler { + return c.oauthHandler +} + +// IsOAuthEnabled returns true if OAuth is enabled +func (c *StreamableHTTP) IsOAuthEnabled() bool { + return c.oauthHandler != nil +} diff --git a/client/transport/streamable_http_oauth_test.go b/client/transport/streamable_http_oauth_test.go new file mode 100644 index 000000000..7a94bc66c --- /dev/null +++ b/client/transport/streamable_http_oauth_test.go @@ -0,0 +1,143 @@ +package transport + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestStreamableHTTP_WithOAuth(t *testing.T) { + // Create a test server that requires OAuth + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check for Authorization header + authHeader := r.Header.Get("Authorization") + if authHeader != "Bearer test-token" { + w.WriteHeader(http.StatusUnauthorized) + return + } + + // Return a successful response + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "result": "success", + }) + })) + defer server.Close() + + // Create a token store with a valid token + tokenStore := NewMemoryTokenStore() + validToken := &Token{ + AccessToken: "test-token", + TokenType: "Bearer", + RefreshToken: "refresh-token", + ExpiresIn: 3600, + ExpiresAt: time.Now().Add(1 * time.Hour), // Valid for 1 hour + } + tokenStore.SaveToken(validToken) + + // Create OAuth config + oauthConfig := OAuthConfig{ + ClientID: "test-client", + RedirectURI: "http://localhost:8085/callback", + Scopes: []string{"mcp.read", "mcp.write"}, + TokenStore: tokenStore, + PKCEEnabled: true, + } + + // Create StreamableHTTP with OAuth + transport, err := NewStreamableHTTP(server.URL, WithOAuth(oauthConfig)) + if err != nil { + t.Fatalf("Failed to create StreamableHTTP: %v", err) + } + + // Verify that OAuth is enabled + if !transport.IsOAuthEnabled() { + t.Errorf("Expected IsOAuthEnabled() to return true") + } + + // Verify the OAuth handler is set + if transport.GetOAuthHandler() == nil { + t.Errorf("Expected GetOAuthHandler() to return a handler") + } +} + +func TestStreamableHTTP_WithOAuth_Unauthorized(t *testing.T) { + // Create a test server that requires OAuth + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Always return unauthorized + w.WriteHeader(http.StatusUnauthorized) + })) + defer server.Close() + + // Create an empty token store + tokenStore := NewMemoryTokenStore() + + // Create OAuth config + oauthConfig := OAuthConfig{ + ClientID: "test-client", + RedirectURI: "http://localhost:8085/callback", + Scopes: []string{"mcp.read", "mcp.write"}, + TokenStore: tokenStore, + PKCEEnabled: true, + } + + // Create StreamableHTTP with OAuth + transport, err := NewStreamableHTTP(server.URL, WithOAuth(oauthConfig)) + if err != nil { + t.Fatalf("Failed to create StreamableHTTP: %v", err) + } + + // Send a request + _, err = transport.SendRequest(context.Background(), JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "test", + }) + + // Verify the error is an OAuthAuthorizationRequiredError + if err == nil { + t.Fatalf("Expected error, got nil") + } + + oauthErr, ok := err.(*OAuthAuthorizationRequiredError) + if !ok { + t.Fatalf("Expected OAuthAuthorizationRequiredError, got %T: %v", err, err) + } + + // Verify the error has the handler + if oauthErr.Handler == nil { + t.Errorf("Expected OAuthAuthorizationRequiredError to have a handler") + } +} + +func TestStreamableHTTP_IsOAuthEnabled(t *testing.T) { + // Create StreamableHTTP without OAuth + transport1, err := NewStreamableHTTP("http://example.com") + if err != nil { + t.Fatalf("Failed to create StreamableHTTP: %v", err) + } + + // Verify OAuth is not enabled + if transport1.IsOAuthEnabled() { + t.Errorf("Expected IsOAuthEnabled() to return false") + } + + // Create StreamableHTTP with OAuth + transport2, err := NewStreamableHTTP("http://example.com", WithOAuth(OAuthConfig{ + ClientID: "test-client", + })) + if err != nil { + t.Fatalf("Failed to create StreamableHTTP: %v", err) + } + + // Verify OAuth is enabled + if !transport2.IsOAuthEnabled() { + t.Errorf("Expected IsOAuthEnabled() to return true") + } +} \ No newline at end of file diff --git a/examples/oauth_client/README.md b/examples/oauth_client/README.md new file mode 100644 index 000000000..a60bb7c5f --- /dev/null +++ b/examples/oauth_client/README.md @@ -0,0 +1,59 @@ +# OAuth Client Example + +This example demonstrates how to use the OAuth capabilities of the MCP Go client to authenticate with an MCP server that requires OAuth authentication. + +## Features + +- OAuth 2.1 authentication with PKCE support +- Dynamic client registration +- Authorization code flow +- Token refresh +- Local callback server for handling OAuth redirects + +## Usage + +```bash +# Set environment variables (optional) +export MCP_CLIENT_ID=your_client_id +export MCP_CLIENT_SECRET=your_client_secret + +# Run the example +go run main.go +``` + +## How it Works + +1. The client attempts to initialize a connection to the MCP server +2. If the server requires OAuth authentication, it will return a 401 Unauthorized response +3. The client detects this and starts the OAuth flow: + - Generates PKCE code verifier and challenge + - Generates a state parameter for security + - Opens a browser to the authorization URL + - Starts a local server to handle the callback +4. The user authorizes the application in their browser +5. The authorization server redirects back to the local callback server +6. The client exchanges the authorization code for an access token +7. The client retries the initialization with the access token +8. The client can now make authenticated requests to the MCP server + +## Configuration + +Edit the following constants in `main.go` to match your environment: + +```go +const ( + // Replace with your MCP server URL + serverURL = "https://api.example.com/v1/mcp" + // Use a localhost redirect URI for this example + redirectURI = "http://localhost:8085/oauth/callback" +) +``` + +## OAuth Scopes + +The example requests the following scopes: + +- `mcp.read` - Read access to MCP resources +- `mcp.write` - Write access to MCP resources + +You can modify the scopes in the `oauthConfig` to match the requirements of your MCP server. \ No newline at end of file diff --git a/examples/oauth_client/main.go b/examples/oauth_client/main.go new file mode 100644 index 000000000..c92da09dc --- /dev/null +++ b/examples/oauth_client/main.go @@ -0,0 +1,220 @@ +package main + +import ( + "context" + "fmt" + "log" + "net/http" + "os" + "os/exec" + "runtime" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/mcp" +) + +const ( + // Replace with your MCP server URL + serverURL = "https://api.example.com/v1/mcp" + // Use a localhost redirect URI for this example + redirectURI = "http://localhost:8085/oauth/callback" +) + +func main() { + // Create a token store to persist tokens + tokenStore := client.NewMemoryTokenStore() + + // Create OAuth configuration + oauthConfig := client.OAuthConfig{ + // Client ID can be empty if using dynamic registration + ClientID: os.Getenv("MCP_CLIENT_ID"), + ClientSecret: os.Getenv("MCP_CLIENT_SECRET"), + RedirectURI: redirectURI, + Scopes: []string{"mcp.read", "mcp.write"}, + TokenStore: tokenStore, + PKCEEnabled: true, // Enable PKCE for public clients + } + + // Create the client with OAuth support + c, err := client.NewOAuthStreamableHttpClient(serverURL, oauthConfig) + if err != nil { + log.Fatalf("Failed to create client: %v", err) + } + + // Start the client + if err := c.Start(context.Background()); err != nil { + log.Fatalf("Failed to start client: %v", err) + } + defer c.Close() + + // Try to initialize the client + result, err := c.Initialize(context.Background(), mcp.InitializeRequest{ + Params: struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities mcp.ClientCapabilities `json:"capabilities"` + ClientInfo mcp.Implementation `json:"clientInfo"` + }{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "mcp-go-oauth-example", + Version: "0.1.0", + }, + }, + }) + + // Check if we need OAuth authorization + if client.IsOAuthAuthorizationRequiredError(err) { + fmt.Println("OAuth authorization required. Starting authorization flow...") + + // Get the OAuth handler from the error + oauthHandler := client.GetOAuthHandler(err) + + // Start a local server to handle the OAuth callback + callbackChan := make(chan map[string]string) + server := startCallbackServer(callbackChan) + defer server.Close() + + // Generate PKCE code verifier and challenge + codeVerifier, err := client.GenerateCodeVerifier() + if err != nil { + log.Fatalf("Failed to generate code verifier: %v", err) + } + codeChallenge := client.GenerateCodeChallenge(codeVerifier) + + // Generate state parameter + state, err := client.GenerateState() + if err != nil { + log.Fatalf("Failed to generate state: %v", err) + } + + // Get the authorization URL + authURL, err := oauthHandler.GetAuthorizationURL(context.Background(), state, codeChallenge) + if err != nil { + log.Fatalf("Failed to get authorization URL: %v", err) + } + + // Open the browser to the authorization URL + fmt.Printf("Opening browser to: %s\n", authURL) + openBrowser(authURL) + + // Wait for the callback + fmt.Println("Waiting for authorization callback...") + params := <-callbackChan + + // Verify state parameter + if params["state"] != state { + log.Fatalf("State mismatch: expected %s, got %s", state, params["state"]) + } + + // Exchange the authorization code for a token + code := params["code"] + if code == "" { + log.Fatalf("No authorization code received") + } + + fmt.Println("Exchanging authorization code for token...") + err = oauthHandler.ProcessAuthorizationResponse(context.Background(), code, state, codeVerifier) + if err != nil { + log.Fatalf("Failed to process authorization response: %v", err) + } + + fmt.Println("Authorization successful!") + + // Try to initialize again with the token + result, err = c.Initialize(context.Background(), mcp.InitializeRequest{ + Params: struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities mcp.ClientCapabilities `json:"capabilities"` + ClientInfo mcp.Implementation `json:"clientInfo"` + }{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "mcp-go-oauth-example", + Version: "0.1.0", + }, + }, + }) + if err != nil { + log.Fatalf("Failed to initialize client after authorization: %v", err) + } + } else if err != nil { + log.Fatalf("Failed to initialize client: %v", err) + } + + fmt.Printf("Client initialized successfully! Server: %s %s\n", + result.ServerInfo.Name, + result.ServerInfo.Version) + + // Now you can use the client as usual + // For example, list resources + resources, err := c.ListResources(context.Background(), mcp.ListResourcesRequest{}) + if err != nil { + log.Fatalf("Failed to list resources: %v", err) + } + + fmt.Println("Available resources:") + for _, resource := range resources.Resources { + fmt.Printf("- %s\n", resource.URI) + } +} + +// startCallbackServer starts a local HTTP server to handle the OAuth callback +func startCallbackServer(callbackChan chan<- map[string]string) *http.Server { + server := &http.Server{ + Addr: ":8085", + } + + http.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) { + // Extract query parameters + params := make(map[string]string) + for key, values := range r.URL.Query() { + if len(values) > 0 { + params[key] = values[0] + } + } + + // Send parameters to the channel + callbackChan <- params + + // Respond to the user + w.Header().Set("Content-Type", "text/html") + w.Write([]byte(` + + +

Authorization Successful

+

You can now close this window and return to the application.

+ + + + `)) + }) + + go func() { + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Printf("HTTP server error: %v", err) + } + }() + + return server +} + +// openBrowser opens the default browser to the specified URL +func openBrowser(url string) { + var err error + + switch runtime.GOOS { + case "linux": + err = exec.Command("xdg-open", url).Start() + case "windows": + err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() + case "darwin": + err = exec.Command("open", url).Start() + default: + err = fmt.Errorf("unsupported platform") + } + + if err != nil { + log.Printf("Failed to open browser: %v", err) + fmt.Printf("Please open the following URL in your browser: %s\n", url) + } +} \ No newline at end of file From 307415ca9c9c7177f0728e082ad47094c690b3b2 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Thu, 15 May 2025 22:13:20 +0300 Subject: [PATCH 02/13] Fix linting issues --- client/oauth_test.go | 10 +++++++--- client/transport/streamable_http_oauth_test.go | 10 +++++++--- examples/oauth_client/main.go | 5 ++++- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/client/oauth_test.go b/client/oauth_test.go index 20c39ee02..e327d289a 100644 --- a/client/oauth_test.go +++ b/client/oauth_test.go @@ -25,7 +25,7 @@ func TestNewOAuthStreamableHttpClient(t *testing.T) { // Return a successful response w.WriteHeader(http.StatusOK) w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]any{ + if err := json.NewEncoder(w).Encode(map[string]any{ "jsonrpc": "2.0", "id": 1, "result": map[string]any{ @@ -36,7 +36,9 @@ func TestNewOAuthStreamableHttpClient(t *testing.T) { }, "capabilities": map[string]any{}, }, - }) + }); err != nil { + t.Errorf("Failed to encode JSON response: %v", err) + } })) defer server.Close() @@ -49,7 +51,9 @@ func TestNewOAuthStreamableHttpClient(t *testing.T) { ExpiresIn: 3600, ExpiresAt: time.Now().Add(1 * time.Hour), // Valid for 1 hour } - tokenStore.SaveToken(validToken) + if err := tokenStore.SaveToken(validToken); err != nil { + t.Fatalf("Failed to save token: %v", err) + } // Create OAuth config oauthConfig := OAuthConfig{ diff --git a/client/transport/streamable_http_oauth_test.go b/client/transport/streamable_http_oauth_test.go index 7a94bc66c..8b745c525 100644 --- a/client/transport/streamable_http_oauth_test.go +++ b/client/transport/streamable_http_oauth_test.go @@ -22,11 +22,13 @@ func TestStreamableHTTP_WithOAuth(t *testing.T) { // Return a successful response w.WriteHeader(http.StatusOK) w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]any{ + if err := json.NewEncoder(w).Encode(map[string]any{ "jsonrpc": "2.0", "id": 1, "result": "success", - }) + }); err != nil { + t.Errorf("Failed to encode JSON response: %v", err) + } })) defer server.Close() @@ -39,7 +41,9 @@ func TestStreamableHTTP_WithOAuth(t *testing.T) { ExpiresIn: 3600, ExpiresAt: time.Now().Add(1 * time.Hour), // Valid for 1 hour } - tokenStore.SaveToken(validToken) + if err := tokenStore.SaveToken(validToken); err != nil { + t.Fatalf("Failed to save token: %v", err) + } // Create OAuth config oauthConfig := OAuthConfig{ diff --git a/examples/oauth_client/main.go b/examples/oauth_client/main.go index c92da09dc..3fc653324 100644 --- a/examples/oauth_client/main.go +++ b/examples/oauth_client/main.go @@ -178,7 +178,7 @@ func startCallbackServer(callbackChan chan<- map[string]string) *http.Server { // Respond to the user w.Header().Set("Content-Type", "text/html") - w.Write([]byte(` + _, err := w.Write([]byte(`

Authorization Successful

@@ -187,6 +187,9 @@ func startCallbackServer(callbackChan chan<- map[string]string) *http.Server { `)) + if err != nil { + log.Printf("Error writing response: %v", err) + } }) go func() { From 95de2fb1b47359073ebbfcfbf9978c424565ba6a Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Thu, 15 May 2025 23:10:18 +0300 Subject: [PATCH 03/13] More fixes --- client/transport/oauth.go | 16 ++-- client/transport/oauth_test.go | 80 +++++++++++++++++++ client/transport/oauth_utils.go | 20 ++++- client/transport/streamable_http.go | 12 ++- .../transport/streamable_http_oauth_test.go | 74 ++++++++++++++++- 5 files changed, 182 insertions(+), 20 deletions(-) diff --git a/client/transport/oauth.go b/client/transport/oauth.go index 6f6e671eb..2964a6ddb 100644 --- a/client/transport/oauth.go +++ b/client/transport/oauth.go @@ -140,7 +140,7 @@ func (h *OAuthHandler) GetAuthorizationHeader(ctx context.Context) (string, erro // getValidToken returns a valid token, refreshing if necessary func (h *OAuthHandler) getValidToken(ctx context.Context) (*Token, error) { token, err := h.config.TokenStore.GetToken() - if err == nil && !token.IsExpired() { + if err == nil && !token.IsExpired() && token.AccessToken != "" { return token, nil } @@ -154,7 +154,7 @@ func (h *OAuthHandler) getValidToken(ctx context.Context) (*Token, error) { } // We need to get a new token through the authorization flow - return nil, errors.New("no valid token available, authorization required") + return nil, ErrOAuthAuthorizationRequired } // refreshToken refreshes an OAuth token @@ -236,15 +236,9 @@ func (h *OAuthHandler) getServerMetadata(ctx context.Context) (*AuthServerMetada if h.config.AuthServerMetadataURL != "" { metadataURL = h.config.AuthServerMetadataURL } else { - // Construct the well-known URL from the base URL - // According to the spec, we need to discard any path component - baseURL, err := url.Parse(h.config.AuthServerMetadataURL) - if err != nil { - h.metadataFetchErr = fmt.Errorf("invalid base URL: %w", err) - return - } - baseURL.Path = "" - metadataURL = baseURL.String() + "/.well-known/oauth-authorization-server" + // If AuthServerMetadataURL is not provided, we can't discover the metadata + h.metadataFetchErr = fmt.Errorf("AuthServerMetadataURL is required but was not provided") + return } req, err := http.NewRequestWithContext(ctx, http.MethodGet, metadataURL, nil) diff --git a/client/transport/oauth_test.go b/client/transport/oauth_test.go index 2b59541c9..7c0fe74b1 100644 --- a/client/transport/oauth_test.go +++ b/client/transport/oauth_test.go @@ -1,6 +1,9 @@ package transport import ( + "context" + "errors" + "strings" "testing" "time" ) @@ -108,11 +111,21 @@ func TestValidateRedirectURI(t *testing.T) { redirectURI: "http://localhost:8085/callback", expectError: false, }, + { + name: "Valid localhost URI with 127.0.0.1", + redirectURI: "http://127.0.0.1:8085/callback", + expectError: false, + }, { name: "Invalid HTTP URI (non-localhost)", redirectURI: "http://example.com/callback", expectError: true, }, + { + name: "Invalid HTTP URI with 'local' in domain", + redirectURI: "http://localdomain.com/callback", + expectError: true, + }, { name: "Empty URI", redirectURI: "", @@ -123,6 +136,11 @@ func TestValidateRedirectURI(t *testing.T) { redirectURI: "ftp://example.com/callback", expectError: true, }, + { + name: "IPv6 localhost", + redirectURI: "http://[::1]:8080/callback", + expectError: false, // IPv6 localhost is valid + }, } // Run test cases @@ -136,4 +154,66 @@ func TestValidateRedirectURI(t *testing.T) { } }) } +} + +func TestOAuthHandler_GetAuthorizationHeader_EmptyAccessToken(t *testing.T) { + // Create a token store with a token that has an empty access token + tokenStore := NewMemoryTokenStore() + invalidToken := &Token{ + AccessToken: "", // Empty access token + TokenType: "Bearer", + RefreshToken: "refresh-token", + ExpiresIn: 3600, + ExpiresAt: time.Now().Add(1 * time.Hour), // Valid for 1 hour + } + if err := tokenStore.SaveToken(invalidToken); err != nil { + t.Fatalf("Failed to save token: %v", err) + } + + // Create an OAuth handler + config := OAuthConfig{ + ClientID: "test-client", + RedirectURI: "http://localhost:8085/callback", + Scopes: []string{"mcp.read", "mcp.write"}, + TokenStore: tokenStore, + PKCEEnabled: true, + } + + handler := NewOAuthHandler(config) + + // Test getting authorization header with empty access token + _, err := handler.GetAuthorizationHeader(context.Background()) + if err == nil { + t.Fatalf("Expected error when getting authorization header with empty access token") + } + + // Verify the error message + if !errors.Is(err, ErrOAuthAuthorizationRequired) { + t.Errorf("Expected error to be ErrOAuthAuthorizationRequired, got %v", err) + } +} + +func TestOAuthHandler_GetServerMetadata_EmptyURL(t *testing.T) { + // Create an OAuth handler with an empty AuthServerMetadataURL + config := OAuthConfig{ + ClientID: "test-client", + RedirectURI: "http://localhost:8085/callback", + Scopes: []string{"mcp.read"}, + TokenStore: NewMemoryTokenStore(), + AuthServerMetadataURL: "", // Empty URL + PKCEEnabled: true, + } + + handler := NewOAuthHandler(config) + + // Test getting server metadata with empty URL + _, err := handler.GetServerMetadata(context.Background()) + if err == nil { + t.Fatalf("Expected error when getting server metadata with empty URL") + } + + // Verify the error message + if !strings.Contains(err.Error(), "AuthServerMetadataURL is required") { + t.Errorf("Expected error message to contain 'AuthServerMetadataURL is required', got %s", err.Error()) + } } \ No newline at end of file diff --git a/client/transport/oauth_utils.go b/client/transport/oauth_utils.go index 9f2361209..e2104307a 100644 --- a/client/transport/oauth_utils.go +++ b/client/transport/oauth_utils.go @@ -5,6 +5,7 @@ import ( "crypto/sha256" "encoding/base64" "fmt" + "net/url" ) // GenerateRandomString generates a random string of the specified length @@ -42,15 +43,26 @@ func ValidateRedirectURI(redirectURI string) error { return fmt.Errorf("redirect URI cannot be empty") } + // Parse the URL + parsedURL, err := url.Parse(redirectURI) + if err != nil { + return fmt.Errorf("invalid redirect URI: %w", err) + } + // Check if it's a localhost URL - if len(redirectURI) >= 9 && redirectURI[:9] == "http://lo" { - return nil + if parsedURL.Scheme == "http" { + hostname := parsedURL.Hostname() + // Check for various forms of localhost + if hostname == "localhost" || hostname == "127.0.0.1" || hostname == "::1" || hostname == "[::1]" { + return nil + } + return fmt.Errorf("HTTP redirect URI must use localhost or 127.0.0.1") } // Check if it's an HTTPS URL - if len(redirectURI) >= 8 && redirectURI[:8] == "https://" { + if parsedURL.Scheme == "https" { return nil } - return fmt.Errorf("redirect URI must be either a localhost URL or an HTTPS URL") + return fmt.Errorf("redirect URI must use either HTTP with localhost or HTTPS") } \ No newline at end of file diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index 91b07f00b..ae092ab92 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "mime" @@ -141,13 +142,20 @@ const ( headerKeySessionID = "Mcp-Session-Id" ) +// ErrOAuthAuthorizationRequired is a sentinel error for OAuth authorization required +var ErrOAuthAuthorizationRequired = errors.New("no valid token available, authorization required") + // OAuthAuthorizationRequiredError is returned when OAuth authorization is required type OAuthAuthorizationRequiredError struct { Handler *OAuthHandler } func (e *OAuthAuthorizationRequiredError) Error() string { - return "OAuth authorization required" + return ErrOAuthAuthorizationRequired.Error() +} + +func (e *OAuthAuthorizationRequiredError) Unwrap() error { + return ErrOAuthAuthorizationRequired } // SendRequest sends a JSON-RPC request to the server and waits for a response. @@ -410,7 +418,7 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp. authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx) if err != nil { // If we get an authorization error, return a specific error that can be handled by the client - if err.Error() == "no valid token available, authorization required" { + if errors.Is(err, ErrOAuthAuthorizationRequired) { return &OAuthAuthorizationRequiredError{ Handler: c.oauthHandler, } diff --git a/client/transport/streamable_http_oauth_test.go b/client/transport/streamable_http_oauth_test.go index 8b745c525..0255800ad 100644 --- a/client/transport/streamable_http_oauth_test.go +++ b/client/transport/streamable_http_oauth_test.go @@ -10,18 +10,33 @@ import ( ) func TestStreamableHTTP_WithOAuth(t *testing.T) { + // Track request count to simulate 401 on first request, then success + requestCount := 0 + authHeaderReceived := "" + // Create a test server that requires OAuth server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Capture the Authorization header + authHeaderReceived = r.Header.Get("Authorization") + // Check for Authorization header - authHeader := r.Header.Get("Authorization") - if authHeader != "Bearer test-token" { + if requestCount == 0 { + // First request - simulate 401 to test error handling + requestCount++ + w.WriteHeader(http.StatusUnauthorized) + return + } + + // Subsequent requests - verify the Authorization header + if authHeaderReceived != "Bearer test-token" { + t.Errorf("Expected Authorization header 'Bearer test-token', got '%s'", authHeaderReceived) w.WriteHeader(http.StatusUnauthorized) return } // Return a successful response - w.WriteHeader(http.StatusOK) w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) if err := json.NewEncoder(w).Encode(map[string]any{ "jsonrpc": "2.0", "id": 1, @@ -69,6 +84,59 @@ func TestStreamableHTTP_WithOAuth(t *testing.T) { if transport.GetOAuthHandler() == nil { t.Errorf("Expected GetOAuthHandler() to return a handler") } + + // First request should fail with OAuthAuthorizationRequiredError + _, err = transport.SendRequest(context.Background(), JSONRPCRequest{ + JSONRPC: "2.0", + ID: 1, + Method: "test", + }) + + // Verify the error is an OAuthAuthorizationRequiredError + if err == nil { + t.Fatalf("Expected error on first request, got nil") + } + + oauthErr, ok := err.(*OAuthAuthorizationRequiredError) + if !ok { + t.Fatalf("Expected OAuthAuthorizationRequiredError, got %T: %v", err, err) + } + + // Verify the error has the handler + if oauthErr.Handler == nil { + t.Errorf("Expected OAuthAuthorizationRequiredError to have a handler") + } + + // Verify the server received the first request + if requestCount != 1 { + t.Errorf("Expected server to receive 1 request, got %d", requestCount) + } + + // Second request should succeed + response, err := transport.SendRequest(context.Background(), JSONRPCRequest{ + JSONRPC: "2.0", + ID: 2, + Method: "test", + }) + + if err != nil { + t.Fatalf("Failed to send second request: %v", err) + } + + // Verify the response + var resultStr string + if err := json.Unmarshal(response.Result, &resultStr); err != nil { + t.Fatalf("Failed to unmarshal result: %v", err) + } + + if resultStr != "success" { + t.Errorf("Expected result to be 'success', got %v", resultStr) + } + + // Verify the server received the Authorization header + if authHeaderReceived != "Bearer test-token" { + t.Errorf("Expected server to receive Authorization header 'Bearer test-token', got '%s'", authHeaderReceived) + } } func TestStreamableHTTP_WithOAuth_Unauthorized(t *testing.T) { From 92e28b8eef0afcdaaeafcdc34938fd7d88cf7cc0 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Thu, 22 May 2025 17:00:35 +0300 Subject: [PATCH 04/13] Update client/oauth.go Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- client/oauth.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client/oauth.go b/client/oauth.go index 9dae8afac..7b28db5cc 100644 --- a/client/oauth.go +++ b/client/oauth.go @@ -48,8 +48,8 @@ type OAuthAuthorizationRequiredError = transport.OAuthAuthorizationRequiredError // IsOAuthAuthorizationRequiredError checks if an error is an OAuthAuthorizationRequiredError func IsOAuthAuthorizationRequiredError(err error) bool { - _, ok := err.(*OAuthAuthorizationRequiredError) - return ok + var target *OAuthAuthorizationRequiredError + return errors.As(err, &target) } // GetOAuthHandler extracts the OAuthHandler from an OAuthAuthorizationRequiredError From 53088d0159c14c6e0fa3442201576a9ce9fbd93a Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Thu, 22 May 2025 17:01:52 +0300 Subject: [PATCH 05/13] Update client/transport/oauth.go Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- client/transport/oauth.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/client/transport/oauth.go b/client/transport/oauth.go index 2964a6ddb..c013dfb25 100644 --- a/client/transport/oauth.go +++ b/client/transport/oauth.go @@ -206,7 +206,14 @@ func (h *OAuthHandler) refreshToken(ctx context.Context, refreshToken string) (* tokenResp.ExpiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) } + // If no new refresh token is provided, keep the old one + if tokenResp.RefreshToken == "" && token != nil { + tokenResp.RefreshToken = token.RefreshToken + } + // Save the token + if err := h.config.TokenStore.SaveToken(&tokenResp); err != nil { + … if err := h.config.TokenStore.SaveToken(&tokenResp); err != nil { return nil, fmt.Errorf("failed to save token: %w", err) } From df07af5f09db491c54cca83faf6f9eaa157b12dc Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Thu, 22 May 2025 17:07:11 +0300 Subject: [PATCH 06/13] fix --- client/oauth.go | 1 + client/transport/oauth.go | 7 +++---- client/transport/streamable_http_oauth_test.go | 8 +++++--- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/client/oauth.go b/client/oauth.go index 7b28db5cc..5c59ffbdd 100644 --- a/client/oauth.go +++ b/client/oauth.go @@ -1,6 +1,7 @@ package client import ( + "errors" "fmt" "github.com/mark3labs/mcp-go/client/transport" diff --git a/client/transport/oauth.go b/client/transport/oauth.go index c013dfb25..b5adc4320 100644 --- a/client/transport/oauth.go +++ b/client/transport/oauth.go @@ -207,13 +207,12 @@ func (h *OAuthHandler) refreshToken(ctx context.Context, refreshToken string) (* } // If no new refresh token is provided, keep the old one - if tokenResp.RefreshToken == "" && token != nil { - tokenResp.RefreshToken = token.RefreshToken + oldToken, _ := h.config.TokenStore.GetToken() + if tokenResp.RefreshToken == "" && oldToken != nil { + tokenResp.RefreshToken = oldToken.RefreshToken } // Save the token - if err := h.config.TokenStore.SaveToken(&tokenResp); err != nil { - … if err := h.config.TokenStore.SaveToken(&tokenResp); err != nil { return nil, fmt.Errorf("failed to save token: %w", err) } diff --git a/client/transport/streamable_http_oauth_test.go b/client/transport/streamable_http_oauth_test.go index 0255800ad..41f9a67f2 100644 --- a/client/transport/streamable_http_oauth_test.go +++ b/client/transport/streamable_http_oauth_test.go @@ -7,6 +7,8 @@ import ( "net/http/httptest" "testing" "time" + + "github.com/mark3labs/mcp-go/mcp" ) func TestStreamableHTTP_WithOAuth(t *testing.T) { @@ -88,7 +90,7 @@ func TestStreamableHTTP_WithOAuth(t *testing.T) { // First request should fail with OAuthAuthorizationRequiredError _, err = transport.SendRequest(context.Background(), JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(1), Method: "test", }) @@ -115,7 +117,7 @@ func TestStreamableHTTP_WithOAuth(t *testing.T) { // Second request should succeed response, err := transport.SendRequest(context.Background(), JSONRPCRequest{ JSONRPC: "2.0", - ID: 2, + ID: mcp.NewRequestId(2), Method: "test", }) @@ -168,7 +170,7 @@ func TestStreamableHTTP_WithOAuth_Unauthorized(t *testing.T) { // Send a request _, err = transport.SendRequest(context.Background(), JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(1), Method: "test", }) From 7bb537b79fb0a2199a26dcffdbccc2b2264a2532 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Thu, 22 May 2025 17:16:14 +0300 Subject: [PATCH 07/13] oauth discovery --- client/transport/oauth.go | 126 +++++++++++++++++++++++++++++---- client/transport/oauth_test.go | 8 ++- 2 files changed, 117 insertions(+), 17 deletions(-) diff --git a/client/transport/oauth.go b/client/transport/oauth.go index b5adc4320..079047055 100644 --- a/client/transport/oauth.go +++ b/client/transport/oauth.go @@ -114,6 +114,7 @@ type OAuthHandler struct { serverMetadata *AuthServerMetadata metadataFetchErr error metadataOnce sync.Once + baseURL string } // NewOAuthHandler creates a new OAuth handler @@ -235,21 +236,40 @@ func (h *OAuthHandler) GetClientSecret() string { return h.config.ClientSecret } +// SetBaseURL sets the base URL for the API server +func (h *OAuthHandler) SetBaseURL(baseURL string) { + h.baseURL = baseURL +} + +// OAuthProtectedResource represents the response from /.well-known/oauth-protected-resource +type OAuthProtectedResource struct { + AuthorizationServers []string `json:"authorization_servers"` + Resource string `json:"resource"` + ResourceName string `json:"resource_name,omitempty"` +} + // getServerMetadata fetches the OAuth server metadata func (h *OAuthHandler) getServerMetadata(ctx context.Context) (*AuthServerMetadata, error) { h.metadataOnce.Do(func() { - var metadataURL string + // If AuthServerMetadataURL is explicitly provided, use it directly if h.config.AuthServerMetadataURL != "" { - metadataURL = h.config.AuthServerMetadataURL - } else { - // If AuthServerMetadataURL is not provided, we can't discover the metadata - h.metadataFetchErr = fmt.Errorf("AuthServerMetadataURL is required but was not provided") + h.fetchMetadataFromURL(ctx, h.config.AuthServerMetadataURL) return } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, metadataURL, nil) + // Try to discover the authorization server via OAuth Protected Resource + // as per RFC 9728 (https://datatracker.ietf.org/doc/html/rfc9728) + baseURL, err := h.extractBaseURL() if err != nil { - h.metadataFetchErr = fmt.Errorf("failed to create metadata request: %w", err) + h.metadataFetchErr = fmt.Errorf("failed to extract base URL: %w", err) + return + } + + // Try to fetch the OAuth Protected Resource metadata + protectedResourceURL := baseURL + "/.well-known/oauth-protected-resource" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, protectedResourceURL, nil) + if err != nil { + h.metadataFetchErr = fmt.Errorf("failed to create protected resource request: %w", err) return } @@ -258,24 +278,47 @@ func (h *OAuthHandler) getServerMetadata(ctx context.Context) (*AuthServerMetada resp, err := h.httpClient.Do(req) if err != nil { - h.metadataFetchErr = fmt.Errorf("failed to send metadata request: %w", err) + h.metadataFetchErr = fmt.Errorf("failed to send protected resource request: %w", err) return } defer resp.Body.Close() + // If we can't get the protected resource metadata, fall back to default endpoints if resp.StatusCode != http.StatusOK { - // If metadata discovery fails, use default endpoints - h.serverMetadata = h.getDefaultEndpoints(metadataURL) + h.serverMetadata = h.getDefaultEndpoints(baseURL) + return + } + + // Parse the protected resource metadata + var protectedResource OAuthProtectedResource + if err := json.NewDecoder(resp.Body).Decode(&protectedResource); err != nil { + h.metadataFetchErr = fmt.Errorf("failed to decode protected resource response: %w", err) return } - var metadata AuthServerMetadata - if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil { - h.metadataFetchErr = fmt.Errorf("failed to decode metadata response: %w", err) + // If no authorization servers are specified, fall back to default endpoints + if len(protectedResource.AuthorizationServers) == 0 { + h.serverMetadata = h.getDefaultEndpoints(baseURL) return } - h.serverMetadata = &metadata + // Use the first authorization server + authServerURL := protectedResource.AuthorizationServers[0] + + // Try OpenID Connect discovery first + h.fetchMetadataFromURL(ctx, authServerURL+"/.well-known/openid-configuration") + if h.serverMetadata != nil { + return + } + + // If OpenID Connect discovery fails, try OAuth Authorization Server Metadata + h.fetchMetadataFromURL(ctx, authServerURL+"/.well-known/oauth-authorization-server") + if h.serverMetadata != nil { + return + } + + // If both discovery methods fail, use default endpoints based on the authorization server URL + h.serverMetadata = h.getDefaultEndpoints(authServerURL) }) if h.metadataFetchErr != nil { @@ -285,6 +328,61 @@ func (h *OAuthHandler) getServerMetadata(ctx context.Context) (*AuthServerMetada return h.serverMetadata, nil } +// fetchMetadataFromURL fetches and parses OAuth server metadata from a URL +func (h *OAuthHandler) fetchMetadataFromURL(ctx context.Context, metadataURL string) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, metadataURL, nil) + if err != nil { + h.metadataFetchErr = fmt.Errorf("failed to create metadata request: %w", err) + return + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("MCP-Protocol-Version", "2025-03-26") + + resp, err := h.httpClient.Do(req) + if err != nil { + h.metadataFetchErr = fmt.Errorf("failed to send metadata request: %w", err) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + // If metadata discovery fails, don't set any metadata + return + } + + var metadata AuthServerMetadata + if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil { + h.metadataFetchErr = fmt.Errorf("failed to decode metadata response: %w", err) + return + } + + h.serverMetadata = &metadata +} + +// extractBaseURL extracts the base URL from the first request +func (h *OAuthHandler) extractBaseURL() (string, error) { + // If we have a base URL from a previous request, use it + if h.baseURL != "" { + return h.baseURL, nil + } + + // Otherwise, we need to infer it from the redirect URI + if h.config.RedirectURI == "" { + return "", fmt.Errorf("no base URL available and no redirect URI provided") + } + + // Parse the redirect URI to extract the authority + parsedURL, err := url.Parse(h.config.RedirectURI) + if err != nil { + return "", fmt.Errorf("failed to parse redirect URI: %w", err) + } + + // Use the scheme and host from the redirect URI + baseURL := fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host) + return baseURL, nil +} + // GetServerMetadata is a public wrapper for getServerMetadata func (h *OAuthHandler) GetServerMetadata(ctx context.Context) (*AuthServerMetadata, error) { return h.getServerMetadata(ctx) diff --git a/client/transport/oauth_test.go b/client/transport/oauth_test.go index 7c0fe74b1..c30c28fde 100644 --- a/client/transport/oauth_test.go +++ b/client/transport/oauth_test.go @@ -212,8 +212,10 @@ func TestOAuthHandler_GetServerMetadata_EmptyURL(t *testing.T) { t.Fatalf("Expected error when getting server metadata with empty URL") } - // Verify the error message - if !strings.Contains(err.Error(), "AuthServerMetadataURL is required") { - t.Errorf("Expected error message to contain 'AuthServerMetadataURL is required', got %s", err.Error()) + // Verify the error message contains something about a connection error + // since we're now trying to connect to the well-known endpoint + if !strings.Contains(err.Error(), "connection refused") && + !strings.Contains(err.Error(), "failed to send protected resource request") { + t.Errorf("Expected error message to contain connection error, got %s", err.Error()) } } \ No newline at end of file From 8a999d6bf46c2a371a198d60632765dfae23e45f Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Thu, 22 May 2025 17:33:59 +0300 Subject: [PATCH 08/13] fix --- client/transport/oauth.go | 26 +++++++++++++++++++++ client/transport/oauth_test.go | 42 ++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+) diff --git a/client/transport/oauth.go b/client/transport/oauth.go index 079047055..7f41ae2a1 100644 --- a/client/transport/oauth.go +++ b/client/transport/oauth.go @@ -115,6 +115,7 @@ type OAuthHandler struct { metadataFetchErr error metadataOnce sync.Once baseURL string + expectedState string // Expected state value for CSRF protection } // NewOAuthHandler creates a new OAuth handler @@ -241,6 +242,11 @@ func (h *OAuthHandler) SetBaseURL(baseURL string) { h.baseURL = baseURL } +// GetExpectedState returns the expected state value (for testing purposes) +func (h *OAuthHandler) GetExpectedState() string { + return h.expectedState +} + // OAuthProtectedResource represents the response from /.well-known/oauth-protected-resource type OAuthProtectedResource struct { AuthorizationServers []string `json:"authorization_servers"` @@ -481,8 +487,25 @@ func (h *OAuthHandler) RegisterClient(ctx context.Context, clientName string) er return nil } +// ErrInvalidState is returned when the state parameter doesn't match the expected value +var ErrInvalidState = errors.New("invalid state parameter, possible CSRF attack") + // ProcessAuthorizationResponse processes the authorization response and exchanges the code for a token func (h *OAuthHandler) ProcessAuthorizationResponse(ctx context.Context, code, state, codeVerifier string) error { + // Validate the state parameter to prevent CSRF attacks + if h.expectedState == "" { + return errors.New("no expected state found, authorization flow may not have been initiated properly") + } + + if state != h.expectedState { + return ErrInvalidState + } + + // Clear the expected state after validation + defer func() { + h.expectedState = "" + }() + metadata, err := h.getServerMetadata(ctx) if err != nil { return fmt.Errorf("failed to get server metadata: %w", err) @@ -551,6 +574,9 @@ func (h *OAuthHandler) GetAuthorizationURL(ctx context.Context, state, codeChall return "", fmt.Errorf("failed to get server metadata: %w", err) } + // Store the state for later validation + h.expectedState = state + params := url.Values{} params.Set("response_type", "code") params.Set("client_id", h.config.ClientID) diff --git a/client/transport/oauth_test.go b/client/transport/oauth_test.go index c30c28fde..cf4aa3eea 100644 --- a/client/transport/oauth_test.go +++ b/client/transport/oauth_test.go @@ -218,4 +218,46 @@ func TestOAuthHandler_GetServerMetadata_EmptyURL(t *testing.T) { !strings.Contains(err.Error(), "failed to send protected resource request") { t.Errorf("Expected error message to contain connection error, got %s", err.Error()) } +} + +func TestOAuthHandler_ProcessAuthorizationResponse_StateValidation(t *testing.T) { + // Create an OAuth handler + config := OAuthConfig{ + ClientID: "test-client", + RedirectURI: "http://localhost:8085/callback", + Scopes: []string{"mcp.read", "mcp.write"}, + TokenStore: NewMemoryTokenStore(), + AuthServerMetadataURL: "http://example.com/.well-known/oauth-authorization-server", + PKCEEnabled: true, + } + + handler := NewOAuthHandler(config) + + // Mock the server metadata to avoid nil pointer dereference + handler.serverMetadata = &AuthServerMetadata{ + Issuer: "http://example.com", + AuthorizationEndpoint: "http://example.com/authorize", + TokenEndpoint: "http://example.com/token", + } + + // Set the expected state + expectedState := "test-state-123" + handler.expectedState = expectedState + + // Test with non-matching state - this should fail immediately with ErrInvalidState + // before trying to connect to any server + err := handler.ProcessAuthorizationResponse(context.Background(), "test-code", "wrong-state", "test-code-verifier") + if !errors.Is(err, ErrInvalidState) { + t.Errorf("Expected ErrInvalidState, got %v", err) + } + + // Test with empty expected state + handler.expectedState = "" + err = handler.ProcessAuthorizationResponse(context.Background(), "test-code", expectedState, "test-code-verifier") + if err == nil { + t.Errorf("Expected error with empty expected state, got nil") + } + if errors.Is(err, ErrInvalidState) { + t.Errorf("Got ErrInvalidState when expected a different error for empty expected state") + } } \ No newline at end of file From e1818a453c4b7dac9da3d34fbcba281d417773c2 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Thu, 22 May 2025 17:47:02 +0300 Subject: [PATCH 09/13] handle invalid urls --- client/transport/oauth.go | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/client/transport/oauth.go b/client/transport/oauth.go index 7f41ae2a1..925512dc6 100644 --- a/client/transport/oauth.go +++ b/client/transport/oauth.go @@ -291,7 +291,12 @@ func (h *OAuthHandler) getServerMetadata(ctx context.Context) (*AuthServerMetada // If we can't get the protected resource metadata, fall back to default endpoints if resp.StatusCode != http.StatusOK { - h.serverMetadata = h.getDefaultEndpoints(baseURL) + metadata, err := h.getDefaultEndpoints(baseURL) + if err != nil { + h.metadataFetchErr = fmt.Errorf("failed to get default endpoints: %w", err) + return + } + h.serverMetadata = metadata return } @@ -304,7 +309,12 @@ func (h *OAuthHandler) getServerMetadata(ctx context.Context) (*AuthServerMetada // If no authorization servers are specified, fall back to default endpoints if len(protectedResource.AuthorizationServers) == 0 { - h.serverMetadata = h.getDefaultEndpoints(baseURL) + metadata, err := h.getDefaultEndpoints(baseURL) + if err != nil { + h.metadataFetchErr = fmt.Errorf("failed to get default endpoints: %w", err) + return + } + h.serverMetadata = metadata return } @@ -324,7 +334,12 @@ func (h *OAuthHandler) getServerMetadata(ctx context.Context) (*AuthServerMetada } // If both discovery methods fail, use default endpoints based on the authorization server URL - h.serverMetadata = h.getDefaultEndpoints(authServerURL) + metadata, err := h.getDefaultEndpoints(authServerURL) + if err != nil { + h.metadataFetchErr = fmt.Errorf("failed to get default endpoints: %w", err) + return + } + h.serverMetadata = metadata }) if h.metadataFetchErr != nil { @@ -395,23 +410,28 @@ func (h *OAuthHandler) GetServerMetadata(ctx context.Context) (*AuthServerMetada } // getDefaultEndpoints returns default OAuth endpoints based on the base URL -func (h *OAuthHandler) getDefaultEndpoints(baseURL string) *AuthServerMetadata { +func (h *OAuthHandler) getDefaultEndpoints(baseURL string) (*AuthServerMetadata, error) { // Parse the base URL to extract the authority parsedURL, err := url.Parse(baseURL) if err != nil { - return nil + return nil, fmt.Errorf("failed to parse base URL: %w", err) } // Discard any path component to get the authorization base URL parsedURL.Path = "" authBaseURL := parsedURL.String() + + // Validate that the URL has a scheme and host + if parsedURL.Scheme == "" || parsedURL.Host == "" { + return nil, fmt.Errorf("invalid base URL: missing scheme or host in %q", baseURL) + } return &AuthServerMetadata{ Issuer: authBaseURL, AuthorizationEndpoint: authBaseURL + "/authorize", TokenEndpoint: authBaseURL + "/token", RegistrationEndpoint: authBaseURL + "/register", - } + }, nil } // RegisterClient performs dynamic client registration From 7d132c9a04a8084be71f9c973c766e24f81f75b0 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Thu, 22 May 2025 18:06:13 +0300 Subject: [PATCH 10/13] more error handling --- client/transport/oauth.go | 33 +++++++++++++++++++++++++--- client/transport/oauth_test.go | 39 ++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 3 deletions(-) diff --git a/client/transport/oauth.go b/client/transport/oauth.go index 925512dc6..637258eae 100644 --- a/client/transport/oauth.go +++ b/client/transport/oauth.go @@ -195,7 +195,7 @@ func (h *OAuthHandler) refreshToken(ctx context.Context, refreshToken string) (* if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("refresh token request failed with status %d: %s", resp.StatusCode, body) + return nil, extractOAuthError(body, resp.StatusCode, "refresh token request failed") } var tokenResp Token @@ -232,6 +232,18 @@ func (h *OAuthHandler) GetClientID() string { return h.config.ClientID } +// extractOAuthError attempts to parse an OAuth error response from the response body +func extractOAuthError(body []byte, statusCode int, context string) error { + // Try to parse the error as an OAuth error response + var oauthErr OAuthError + if err := json.Unmarshal(body, &oauthErr); err == nil && oauthErr.ErrorCode != "" { + return fmt.Errorf("%s: %w", context, oauthErr) + } + + // If not a valid OAuth error, return the raw response + return fmt.Errorf("%s with status %d: %s", context, statusCode, body) +} + // GetClientSecret returns the client secret func (h *OAuthHandler) GetClientSecret() string { return h.config.ClientSecret @@ -247,6 +259,21 @@ func (h *OAuthHandler) GetExpectedState() string { return h.expectedState } +// OAuthError represents a standard OAuth 2.0 error response +type OAuthError struct { + ErrorCode string `json:"error"` + ErrorDescription string `json:"error_description,omitempty"` + ErrorURI string `json:"error_uri,omitempty"` +} + +// Error implements the error interface +func (e OAuthError) Error() string { + if e.ErrorDescription != "" { + return fmt.Sprintf("OAuth error: %s - %s", e.ErrorCode, e.ErrorDescription) + } + return fmt.Sprintf("OAuth error: %s", e.ErrorCode) +} + // OAuthProtectedResource represents the response from /.well-known/oauth-protected-resource type OAuthProtectedResource struct { AuthorizationServers []string `json:"authorization_servers"` @@ -486,7 +513,7 @@ func (h *OAuthHandler) RegisterClient(ctx context.Context, clientName string) er if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - return fmt.Errorf("registration request failed with status %d: %s", resp.StatusCode, body) + return extractOAuthError(body, resp.StatusCode, "registration request failed") } var regResponse struct { @@ -566,7 +593,7 @@ func (h *OAuthHandler) ProcessAuthorizationResponse(ctx context.Context, code, s if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - return fmt.Errorf("token request failed with status %d: %s", resp.StatusCode, body) + return extractOAuthError(body, resp.StatusCode, "token request failed") } var tokenResp Token diff --git a/client/transport/oauth_test.go b/client/transport/oauth_test.go index cf4aa3eea..c2660675b 100644 --- a/client/transport/oauth_test.go +++ b/client/transport/oauth_test.go @@ -220,6 +220,45 @@ func TestOAuthHandler_GetServerMetadata_EmptyURL(t *testing.T) { } } +func TestOAuthError(t *testing.T) { + testCases := []struct { + name string + errorCode string + description string + uri string + expected string + }{ + { + name: "Error with description", + errorCode: "invalid_request", + description: "The request is missing a required parameter", + uri: "https://example.com/errors/invalid_request", + expected: "OAuth error: invalid_request - The request is missing a required parameter", + }, + { + name: "Error without description", + errorCode: "unauthorized_client", + description: "", + uri: "", + expected: "OAuth error: unauthorized_client", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + oauthErr := OAuthError{ + ErrorCode: tc.errorCode, + ErrorDescription: tc.description, + ErrorURI: tc.uri, + } + + if oauthErr.Error() != tc.expected { + t.Errorf("Expected error message %q, got %q", tc.expected, oauthErr.Error()) + } + }) + } +} + func TestOAuthHandler_ProcessAuthorizationResponse_StateValidation(t *testing.T) { // Create an OAuth handler config := OAuthConfig{ From c4bbefcb5650e712ed4ee1dc534304000c4bf7fa Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Sat, 24 May 2025 14:32:28 +0300 Subject: [PATCH 11/13] use errors.As --- client/oauth.go | 3 ++- client/transport/streamable_http_oauth_test.go | 9 +++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/client/oauth.go b/client/oauth.go index 5c59ffbdd..f37a22bda 100644 --- a/client/oauth.go +++ b/client/oauth.go @@ -55,7 +55,8 @@ func IsOAuthAuthorizationRequiredError(err error) bool { // GetOAuthHandler extracts the OAuthHandler from an OAuthAuthorizationRequiredError func GetOAuthHandler(err error) *transport.OAuthHandler { - if oauthErr, ok := err.(*OAuthAuthorizationRequiredError); ok { + var oauthErr *OAuthAuthorizationRequiredError + if errors.As(err, &oauthErr) { return oauthErr.Handler } return nil diff --git a/client/transport/streamable_http_oauth_test.go b/client/transport/streamable_http_oauth_test.go index 41f9a67f2..3b38c2cbb 100644 --- a/client/transport/streamable_http_oauth_test.go +++ b/client/transport/streamable_http_oauth_test.go @@ -3,6 +3,7 @@ package transport import ( "context" "encoding/json" + "errors" "net/http" "net/http/httptest" "testing" @@ -99,8 +100,8 @@ func TestStreamableHTTP_WithOAuth(t *testing.T) { t.Fatalf("Expected error on first request, got nil") } - oauthErr, ok := err.(*OAuthAuthorizationRequiredError) - if !ok { + var oauthErr *OAuthAuthorizationRequiredError + if !errors.As(err, &oauthErr) { t.Fatalf("Expected OAuthAuthorizationRequiredError, got %T: %v", err, err) } @@ -179,8 +180,8 @@ func TestStreamableHTTP_WithOAuth_Unauthorized(t *testing.T) { t.Fatalf("Expected error, got nil") } - oauthErr, ok := err.(*OAuthAuthorizationRequiredError) - if !ok { + var oauthErr *OAuthAuthorizationRequiredError + if !errors.As(err, &oauthErr) { t.Fatalf("Expected OAuthAuthorizationRequiredError, got %T: %v", err, err) } From b279531bed6bebec0577862f27748e0758e8758b Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Sat, 24 May 2025 14:42:00 +0300 Subject: [PATCH 12/13] get baseURL from server --- client/transport/streamable_http.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index b7462cbaf..5be12c549 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -99,6 +99,11 @@ func NewStreamableHTTP(baseURL string, options ...StreamableHTTPCOption) (*Strea opt(smc) } + // If OAuth is configured, set the base URL for metadata discovery + if smc.oauthHandler != nil { + smc.oauthHandler.SetBaseURL(baseURL) + } + return smc, nil } From 775e9e57a8a77cbb4341d2ffa7380a41bbf2a960 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Sun, 25 May 2025 13:42:09 +0300 Subject: [PATCH 13/13] Fix misleading naming --- client/transport/streamable_http.go | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index 5be12c549..8727bb190 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -63,7 +63,7 @@ func WithOAuth(config OAuthConfig) StreamableHTTPCOption { // (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery) // - server -> client request type StreamableHTTP struct { - baseURL *url.URL + serverURL *url.URL httpClient *http.Client headers map[string]string headerFunc HTTPHeaderFunc @@ -79,16 +79,16 @@ type StreamableHTTP struct { oauthHandler *OAuthHandler } -// NewStreamableHTTP creates a new Streamable HTTP transport with the given base URL. +// NewStreamableHTTP creates a new Streamable HTTP transport with the given server URL. // Returns an error if the URL is invalid. -func NewStreamableHTTP(baseURL string, options ...StreamableHTTPCOption) (*StreamableHTTP, error) { - parsedURL, err := url.Parse(baseURL) +func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*StreamableHTTP, error) { + parsedURL, err := url.Parse(serverURL) if err != nil { return nil, fmt.Errorf("invalid URL: %w", err) } smc := &StreamableHTTP{ - baseURL: parsedURL, + serverURL: parsedURL, httpClient: &http.Client{}, headers: make(map[string]string), closed: make(chan struct{}), @@ -101,6 +101,8 @@ func NewStreamableHTTP(baseURL string, options ...StreamableHTTPCOption) (*Strea // If OAuth is configured, set the base URL for metadata discovery if smc.oauthHandler != nil { + // Extract base URL from server URL for metadata discovery + baseURL := fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host) smc.oauthHandler.SetBaseURL(baseURL) } @@ -131,7 +133,7 @@ func (c *StreamableHTTP) Close() error { go func() { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.baseURL.String(), nil) + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.serverURL.String(), nil) if err != nil { fmt.Printf("failed to create close request\n: %v", err) return @@ -196,7 +198,7 @@ func (c *StreamableHTTP) SendRequest( } // Create HTTP request - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL.String(), bytes.NewReader(requestBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody)) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } @@ -415,7 +417,7 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp. } // Create HTTP request - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL.String(), bytes.NewReader(requestBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody)) if err != nil { return fmt.Errorf("failed to create request: %w", err) }