diff --git a/client/oauth.go b/client/oauth.go new file mode 100644 index 000000000..f37a22bda --- /dev/null +++ b/client/oauth.go @@ -0,0 +1,63 @@ +package client + +import ( + "errors" + "fmt" + + "github.com/mark3labs/mcp-go/client/transport" +) + +// OAuthConfig is a convenience type that wraps transport.OAuthConfig +type OAuthConfig = transport.OAuthConfig + +// Token is a convenience type that wraps transport.Token +type Token = transport.Token + +// TokenStore is a convenience type that wraps transport.TokenStore +type TokenStore = transport.TokenStore + +// MemoryTokenStore is a convenience type that wraps transport.MemoryTokenStore +type MemoryTokenStore = transport.MemoryTokenStore + +// NewMemoryTokenStore is a convenience function that wraps transport.NewMemoryTokenStore +var NewMemoryTokenStore = transport.NewMemoryTokenStore + +// NewOAuthStreamableHttpClient creates a new streamable-http-based MCP client with OAuth support. +// Returns an error if the URL is invalid. +func NewOAuthStreamableHttpClient(baseURL string, oauthConfig OAuthConfig, options ...transport.StreamableHTTPCOption) (*Client, error) { + // Add OAuth option to the list of options + options = append(options, transport.WithOAuth(oauthConfig)) + + trans, err := transport.NewStreamableHTTP(baseURL, options...) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP transport: %w", err) + } + return NewClient(trans), nil +} + +// GenerateCodeVerifier generates a code verifier for PKCE +var GenerateCodeVerifier = transport.GenerateCodeVerifier + +// GenerateCodeChallenge generates a code challenge from a code verifier +var GenerateCodeChallenge = transport.GenerateCodeChallenge + +// GenerateState generates a state parameter for OAuth +var GenerateState = transport.GenerateState + +// OAuthAuthorizationRequiredError is returned when OAuth authorization is required +type OAuthAuthorizationRequiredError = transport.OAuthAuthorizationRequiredError + +// IsOAuthAuthorizationRequiredError checks if an error is an OAuthAuthorizationRequiredError +func IsOAuthAuthorizationRequiredError(err error) bool { + var target *OAuthAuthorizationRequiredError + return errors.As(err, &target) +} + +// GetOAuthHandler extracts the OAuthHandler from an OAuthAuthorizationRequiredError +func GetOAuthHandler(err error) *transport.OAuthHandler { + var oauthErr *OAuthAuthorizationRequiredError + if errors.As(err, &oauthErr) { + return oauthErr.Handler + } + return nil +} \ No newline at end of file diff --git a/client/oauth_test.go b/client/oauth_test.go new file mode 100644 index 000000000..e327d289a --- /dev/null +++ b/client/oauth_test.go @@ -0,0 +1,127 @@ +package client + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/mark3labs/mcp-go/client/transport" +) + +func TestNewOAuthStreamableHttpClient(t *testing.T) { + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check for Authorization header + authHeader := r.Header.Get("Authorization") + if authHeader != "Bearer test-token" { + w.WriteHeader(http.StatusUnauthorized) + return + } + + // Return a successful response + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "result": map[string]any{ + "protocolVersion": "2024-11-05", + "serverInfo": map[string]any{ + "name": "test-server", + "version": "1.0.0", + }, + "capabilities": map[string]any{}, + }, + }); err != nil { + t.Errorf("Failed to encode JSON response: %v", err) + } + })) + defer server.Close() + + // Create a token store with a valid token + tokenStore := NewMemoryTokenStore() + validToken := &Token{ + AccessToken: "test-token", + TokenType: "Bearer", + RefreshToken: "refresh-token", + ExpiresIn: 3600, + ExpiresAt: time.Now().Add(1 * time.Hour), // Valid for 1 hour + } + if err := tokenStore.SaveToken(validToken); err != nil { + t.Fatalf("Failed to save token: %v", err) + } + + // Create OAuth config + oauthConfig := OAuthConfig{ + ClientID: "test-client", + RedirectURI: "http://localhost:8085/callback", + Scopes: []string{"mcp.read", "mcp.write"}, + TokenStore: tokenStore, + PKCEEnabled: true, + } + + // Create client with OAuth + client, err := NewOAuthStreamableHttpClient(server.URL, oauthConfig) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + // Start the client + if err := client.Start(context.Background()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + defer client.Close() + + // Verify that the client was created successfully + trans := client.GetTransport() + streamableHTTP, ok := trans.(*transport.StreamableHTTP) + if !ok { + t.Fatalf("Expected transport to be *transport.StreamableHTTP, got %T", trans) + } + + // Verify OAuth is enabled + if !streamableHTTP.IsOAuthEnabled() { + t.Errorf("Expected IsOAuthEnabled() to return true") + } + + // Verify the OAuth handler is set + if streamableHTTP.GetOAuthHandler() == nil { + t.Errorf("Expected GetOAuthHandler() to return a handler") + } +} + +func TestIsOAuthAuthorizationRequiredError(t *testing.T) { + // Create a test error + err := &transport.OAuthAuthorizationRequiredError{ + Handler: transport.NewOAuthHandler(transport.OAuthConfig{}), + } + + // Verify IsOAuthAuthorizationRequiredError returns true + if !IsOAuthAuthorizationRequiredError(err) { + t.Errorf("Expected IsOAuthAuthorizationRequiredError to return true") + } + + // Verify GetOAuthHandler returns the handler + handler := GetOAuthHandler(err) + if handler == nil { + t.Errorf("Expected GetOAuthHandler to return a handler") + } + + // Test with a different error + err2 := fmt.Errorf("some other error") + + // Verify IsOAuthAuthorizationRequiredError returns false + if IsOAuthAuthorizationRequiredError(err2) { + t.Errorf("Expected IsOAuthAuthorizationRequiredError to return false") + } + + // Verify GetOAuthHandler returns nil + handler = GetOAuthHandler(err2) + if handler != nil { + t.Errorf("Expected GetOAuthHandler to return nil") + } +} \ No newline at end of file diff --git a/client/transport/oauth.go b/client/transport/oauth.go new file mode 100644 index 000000000..637258eae --- /dev/null +++ b/client/transport/oauth.go @@ -0,0 +1,643 @@ +package transport + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" +) + +// OAuthConfig holds the OAuth configuration for the client +type OAuthConfig struct { + // ClientID is the OAuth client ID + ClientID string + // ClientSecret is the OAuth client secret (for confidential clients) + ClientSecret string + // RedirectURI is the redirect URI for the OAuth flow + RedirectURI string + // Scopes is the list of OAuth scopes to request + Scopes []string + // TokenStore is the storage for OAuth tokens + TokenStore TokenStore + // AuthServerMetadataURL is the URL to the OAuth server metadata + // If empty, the client will attempt to discover it from the base URL + AuthServerMetadataURL string + // PKCEEnabled enables PKCE for the OAuth flow (recommended for public clients) + PKCEEnabled bool +} + +// TokenStore is an interface for storing and retrieving OAuth tokens +type TokenStore interface { + // GetToken returns the current token + GetToken() (*Token, error) + // SaveToken saves a token + SaveToken(token *Token) error +} + +// Token represents an OAuth token +type Token struct { + // AccessToken is the OAuth access token + AccessToken string `json:"access_token"` + // TokenType is the type of token (usually "Bearer") + TokenType string `json:"token_type"` + // RefreshToken is the OAuth refresh token + RefreshToken string `json:"refresh_token,omitempty"` + // ExpiresIn is the number of seconds until the token expires + ExpiresIn int64 `json:"expires_in,omitempty"` + // Scope is the scope of the token + Scope string `json:"scope,omitempty"` + // ExpiresAt is the time when the token expires + ExpiresAt time.Time `json:"expires_at,omitempty"` +} + +// IsExpired returns true if the token is expired +func (t *Token) IsExpired() bool { + if t.ExpiresAt.IsZero() { + return false + } + return time.Now().After(t.ExpiresAt) +} + +// MemoryTokenStore is a simple in-memory token store +type MemoryTokenStore struct { + token *Token + mu sync.RWMutex +} + +// NewMemoryTokenStore creates a new in-memory token store +func NewMemoryTokenStore() *MemoryTokenStore { + return &MemoryTokenStore{} +} + +// GetToken returns the current token +func (s *MemoryTokenStore) GetToken() (*Token, error) { + s.mu.RLock() + defer s.mu.RUnlock() + if s.token == nil { + return nil, errors.New("no token available") + } + return s.token, nil +} + +// SaveToken saves a token +func (s *MemoryTokenStore) SaveToken(token *Token) error { + s.mu.Lock() + defer s.mu.Unlock() + s.token = token + return nil +} + +// AuthServerMetadata represents the OAuth 2.0 Authorization Server Metadata +type AuthServerMetadata struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + RegistrationEndpoint string `json:"registration_endpoint,omitempty"` + JwksURI string `json:"jwks_uri,omitempty"` + ScopesSupported []string `json:"scopes_supported,omitempty"` + ResponseTypesSupported []string `json:"response_types_supported"` + GrantTypesSupported []string `json:"grant_types_supported,omitempty"` + TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported,omitempty"` +} + +// OAuthHandler handles OAuth authentication for HTTP requests +type OAuthHandler struct { + config OAuthConfig + httpClient *http.Client + serverMetadata *AuthServerMetadata + metadataFetchErr error + metadataOnce sync.Once + baseURL string + expectedState string // Expected state value for CSRF protection +} + +// NewOAuthHandler creates a new OAuth handler +func NewOAuthHandler(config OAuthConfig) *OAuthHandler { + if config.TokenStore == nil { + config.TokenStore = NewMemoryTokenStore() + } + + return &OAuthHandler{ + config: config, + httpClient: &http.Client{Timeout: 30 * time.Second}, + } +} + +// GetAuthorizationHeader returns the Authorization header value for a request +func (h *OAuthHandler) GetAuthorizationHeader(ctx context.Context) (string, error) { + token, err := h.getValidToken(ctx) + if err != nil { + return "", err + } + return fmt.Sprintf("%s %s", token.TokenType, token.AccessToken), nil +} + +// getValidToken returns a valid token, refreshing if necessary +func (h *OAuthHandler) getValidToken(ctx context.Context) (*Token, error) { + token, err := h.config.TokenStore.GetToken() + if err == nil && !token.IsExpired() && token.AccessToken != "" { + return token, nil + } + + // If we have a refresh token, try to use it + if err == nil && token.RefreshToken != "" { + newToken, err := h.refreshToken(ctx, token.RefreshToken) + if err == nil { + return newToken, nil + } + // If refresh fails, continue to authorization flow + } + + // We need to get a new token through the authorization flow + return nil, ErrOAuthAuthorizationRequired +} + +// refreshToken refreshes an OAuth token +func (h *OAuthHandler) refreshToken(ctx context.Context, refreshToken string) (*Token, error) { + metadata, err := h.getServerMetadata(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get server metadata: %w", err) + } + + data := url.Values{} + data.Set("grant_type", "refresh_token") + data.Set("refresh_token", refreshToken) + data.Set("client_id", h.config.ClientID) + if h.config.ClientSecret != "" { + data.Set("client_secret", h.config.ClientSecret) + } + + req, err := http.NewRequestWithContext( + ctx, + http.MethodPost, + metadata.TokenEndpoint, + strings.NewReader(data.Encode()), + ) + if err != nil { + return nil, fmt.Errorf("failed to create refresh token request: %w", err) + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := h.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send refresh token request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, extractOAuthError(body, resp.StatusCode, "refresh token request failed") + } + + var tokenResp Token + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + return nil, fmt.Errorf("failed to decode token response: %w", err) + } + + // Set expiration time + if tokenResp.ExpiresIn > 0 { + tokenResp.ExpiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + } + + // If no new refresh token is provided, keep the old one + oldToken, _ := h.config.TokenStore.GetToken() + if tokenResp.RefreshToken == "" && oldToken != nil { + tokenResp.RefreshToken = oldToken.RefreshToken + } + + // Save the token + if err := h.config.TokenStore.SaveToken(&tokenResp); err != nil { + return nil, fmt.Errorf("failed to save token: %w", err) + } + + return &tokenResp, nil +} + +// RefreshToken is a public wrapper for refreshToken +func (h *OAuthHandler) RefreshToken(ctx context.Context, refreshToken string) (*Token, error) { + return h.refreshToken(ctx, refreshToken) +} + +// GetClientID returns the client ID +func (h *OAuthHandler) GetClientID() string { + return h.config.ClientID +} + +// extractOAuthError attempts to parse an OAuth error response from the response body +func extractOAuthError(body []byte, statusCode int, context string) error { + // Try to parse the error as an OAuth error response + var oauthErr OAuthError + if err := json.Unmarshal(body, &oauthErr); err == nil && oauthErr.ErrorCode != "" { + return fmt.Errorf("%s: %w", context, oauthErr) + } + + // If not a valid OAuth error, return the raw response + return fmt.Errorf("%s with status %d: %s", context, statusCode, body) +} + +// GetClientSecret returns the client secret +func (h *OAuthHandler) GetClientSecret() string { + return h.config.ClientSecret +} + +// SetBaseURL sets the base URL for the API server +func (h *OAuthHandler) SetBaseURL(baseURL string) { + h.baseURL = baseURL +} + +// GetExpectedState returns the expected state value (for testing purposes) +func (h *OAuthHandler) GetExpectedState() string { + return h.expectedState +} + +// OAuthError represents a standard OAuth 2.0 error response +type OAuthError struct { + ErrorCode string `json:"error"` + ErrorDescription string `json:"error_description,omitempty"` + ErrorURI string `json:"error_uri,omitempty"` +} + +// Error implements the error interface +func (e OAuthError) Error() string { + if e.ErrorDescription != "" { + return fmt.Sprintf("OAuth error: %s - %s", e.ErrorCode, e.ErrorDescription) + } + return fmt.Sprintf("OAuth error: %s", e.ErrorCode) +} + +// OAuthProtectedResource represents the response from /.well-known/oauth-protected-resource +type OAuthProtectedResource struct { + AuthorizationServers []string `json:"authorization_servers"` + Resource string `json:"resource"` + ResourceName string `json:"resource_name,omitempty"` +} + +// getServerMetadata fetches the OAuth server metadata +func (h *OAuthHandler) getServerMetadata(ctx context.Context) (*AuthServerMetadata, error) { + h.metadataOnce.Do(func() { + // If AuthServerMetadataURL is explicitly provided, use it directly + if h.config.AuthServerMetadataURL != "" { + h.fetchMetadataFromURL(ctx, h.config.AuthServerMetadataURL) + return + } + + // Try to discover the authorization server via OAuth Protected Resource + // as per RFC 9728 (https://datatracker.ietf.org/doc/html/rfc9728) + baseURL, err := h.extractBaseURL() + if err != nil { + h.metadataFetchErr = fmt.Errorf("failed to extract base URL: %w", err) + return + } + + // Try to fetch the OAuth Protected Resource metadata + protectedResourceURL := baseURL + "/.well-known/oauth-protected-resource" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, protectedResourceURL, nil) + if err != nil { + h.metadataFetchErr = fmt.Errorf("failed to create protected resource request: %w", err) + return + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("MCP-Protocol-Version", "2025-03-26") + + resp, err := h.httpClient.Do(req) + if err != nil { + h.metadataFetchErr = fmt.Errorf("failed to send protected resource request: %w", err) + return + } + defer resp.Body.Close() + + // If we can't get the protected resource metadata, fall back to default endpoints + if resp.StatusCode != http.StatusOK { + metadata, err := h.getDefaultEndpoints(baseURL) + if err != nil { + h.metadataFetchErr = fmt.Errorf("failed to get default endpoints: %w", err) + return + } + h.serverMetadata = metadata + return + } + + // Parse the protected resource metadata + var protectedResource OAuthProtectedResource + if err := json.NewDecoder(resp.Body).Decode(&protectedResource); err != nil { + h.metadataFetchErr = fmt.Errorf("failed to decode protected resource response: %w", err) + return + } + + // If no authorization servers are specified, fall back to default endpoints + if len(protectedResource.AuthorizationServers) == 0 { + metadata, err := h.getDefaultEndpoints(baseURL) + if err != nil { + h.metadataFetchErr = fmt.Errorf("failed to get default endpoints: %w", err) + return + } + h.serverMetadata = metadata + return + } + + // Use the first authorization server + authServerURL := protectedResource.AuthorizationServers[0] + + // Try OpenID Connect discovery first + h.fetchMetadataFromURL(ctx, authServerURL+"/.well-known/openid-configuration") + if h.serverMetadata != nil { + return + } + + // If OpenID Connect discovery fails, try OAuth Authorization Server Metadata + h.fetchMetadataFromURL(ctx, authServerURL+"/.well-known/oauth-authorization-server") + if h.serverMetadata != nil { + return + } + + // If both discovery methods fail, use default endpoints based on the authorization server URL + metadata, err := h.getDefaultEndpoints(authServerURL) + if err != nil { + h.metadataFetchErr = fmt.Errorf("failed to get default endpoints: %w", err) + return + } + h.serverMetadata = metadata + }) + + if h.metadataFetchErr != nil { + return nil, h.metadataFetchErr + } + + return h.serverMetadata, nil +} + +// fetchMetadataFromURL fetches and parses OAuth server metadata from a URL +func (h *OAuthHandler) fetchMetadataFromURL(ctx context.Context, metadataURL string) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, metadataURL, nil) + if err != nil { + h.metadataFetchErr = fmt.Errorf("failed to create metadata request: %w", err) + return + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("MCP-Protocol-Version", "2025-03-26") + + resp, err := h.httpClient.Do(req) + if err != nil { + h.metadataFetchErr = fmt.Errorf("failed to send metadata request: %w", err) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + // If metadata discovery fails, don't set any metadata + return + } + + var metadata AuthServerMetadata + if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil { + h.metadataFetchErr = fmt.Errorf("failed to decode metadata response: %w", err) + return + } + + h.serverMetadata = &metadata +} + +// extractBaseURL extracts the base URL from the first request +func (h *OAuthHandler) extractBaseURL() (string, error) { + // If we have a base URL from a previous request, use it + if h.baseURL != "" { + return h.baseURL, nil + } + + // Otherwise, we need to infer it from the redirect URI + if h.config.RedirectURI == "" { + return "", fmt.Errorf("no base URL available and no redirect URI provided") + } + + // Parse the redirect URI to extract the authority + parsedURL, err := url.Parse(h.config.RedirectURI) + if err != nil { + return "", fmt.Errorf("failed to parse redirect URI: %w", err) + } + + // Use the scheme and host from the redirect URI + baseURL := fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host) + return baseURL, nil +} + +// GetServerMetadata is a public wrapper for getServerMetadata +func (h *OAuthHandler) GetServerMetadata(ctx context.Context) (*AuthServerMetadata, error) { + return h.getServerMetadata(ctx) +} + +// getDefaultEndpoints returns default OAuth endpoints based on the base URL +func (h *OAuthHandler) getDefaultEndpoints(baseURL string) (*AuthServerMetadata, error) { + // Parse the base URL to extract the authority + parsedURL, err := url.Parse(baseURL) + if err != nil { + return nil, fmt.Errorf("failed to parse base URL: %w", err) + } + + // Discard any path component to get the authorization base URL + parsedURL.Path = "" + authBaseURL := parsedURL.String() + + // Validate that the URL has a scheme and host + if parsedURL.Scheme == "" || parsedURL.Host == "" { + return nil, fmt.Errorf("invalid base URL: missing scheme or host in %q", baseURL) + } + + return &AuthServerMetadata{ + Issuer: authBaseURL, + AuthorizationEndpoint: authBaseURL + "/authorize", + TokenEndpoint: authBaseURL + "/token", + RegistrationEndpoint: authBaseURL + "/register", + }, nil +} + +// RegisterClient performs dynamic client registration +func (h *OAuthHandler) RegisterClient(ctx context.Context, clientName string) error { + metadata, err := h.getServerMetadata(ctx) + if err != nil { + return fmt.Errorf("failed to get server metadata: %w", err) + } + + if metadata.RegistrationEndpoint == "" { + return errors.New("server does not support dynamic client registration") + } + + // Prepare registration request + regRequest := map[string]any{ + "client_name": clientName, + "redirect_uris": []string{h.config.RedirectURI}, + "token_endpoint_auth_method": "none", // For public clients + "grant_types": []string{"authorization_code", "refresh_token"}, + "response_types": []string{"code"}, + "scope": strings.Join(h.config.Scopes, " "), + } + + // Add client_secret if this is a confidential client + if h.config.ClientSecret != "" { + regRequest["token_endpoint_auth_method"] = "client_secret_basic" + } + + reqBody, err := json.Marshal(regRequest) + if err != nil { + return fmt.Errorf("failed to marshal registration request: %w", err) + } + + req, err := http.NewRequestWithContext( + ctx, + http.MethodPost, + metadata.RegistrationEndpoint, + bytes.NewReader(reqBody), + ) + if err != nil { + return fmt.Errorf("failed to create registration request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := h.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send registration request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return extractOAuthError(body, resp.StatusCode, "registration request failed") + } + + var regResponse struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret,omitempty"` + } + + if err := json.NewDecoder(resp.Body).Decode(®Response); err != nil { + return fmt.Errorf("failed to decode registration response: %w", err) + } + + // Update the client configuration + h.config.ClientID = regResponse.ClientID + if regResponse.ClientSecret != "" { + h.config.ClientSecret = regResponse.ClientSecret + } + + return nil +} + +// ErrInvalidState is returned when the state parameter doesn't match the expected value +var ErrInvalidState = errors.New("invalid state parameter, possible CSRF attack") + +// ProcessAuthorizationResponse processes the authorization response and exchanges the code for a token +func (h *OAuthHandler) ProcessAuthorizationResponse(ctx context.Context, code, state, codeVerifier string) error { + // Validate the state parameter to prevent CSRF attacks + if h.expectedState == "" { + return errors.New("no expected state found, authorization flow may not have been initiated properly") + } + + if state != h.expectedState { + return ErrInvalidState + } + + // Clear the expected state after validation + defer func() { + h.expectedState = "" + }() + + metadata, err := h.getServerMetadata(ctx) + if err != nil { + return fmt.Errorf("failed to get server metadata: %w", err) + } + + data := url.Values{} + data.Set("grant_type", "authorization_code") + data.Set("code", code) + data.Set("client_id", h.config.ClientID) + data.Set("redirect_uri", h.config.RedirectURI) + + if h.config.ClientSecret != "" { + data.Set("client_secret", h.config.ClientSecret) + } + + if h.config.PKCEEnabled && codeVerifier != "" { + data.Set("code_verifier", codeVerifier) + } + + req, err := http.NewRequestWithContext( + ctx, + http.MethodPost, + metadata.TokenEndpoint, + strings.NewReader(data.Encode()), + ) + if err != nil { + return fmt.Errorf("failed to create token request: %w", err) + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := h.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send token request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return extractOAuthError(body, resp.StatusCode, "token request failed") + } + + var tokenResp Token + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + return fmt.Errorf("failed to decode token response: %w", err) + } + + // Set expiration time + if tokenResp.ExpiresIn > 0 { + tokenResp.ExpiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + } + + // Save the token + if err := h.config.TokenStore.SaveToken(&tokenResp); err != nil { + return fmt.Errorf("failed to save token: %w", err) + } + + return nil +} + +// GetAuthorizationURL returns the URL for the authorization endpoint +func (h *OAuthHandler) GetAuthorizationURL(ctx context.Context, state, codeChallenge string) (string, error) { + metadata, err := h.getServerMetadata(ctx) + if err != nil { + return "", fmt.Errorf("failed to get server metadata: %w", err) + } + + // Store the state for later validation + h.expectedState = state + + params := url.Values{} + params.Set("response_type", "code") + params.Set("client_id", h.config.ClientID) + params.Set("redirect_uri", h.config.RedirectURI) + params.Set("state", state) + + if len(h.config.Scopes) > 0 { + params.Set("scope", strings.Join(h.config.Scopes, " ")) + } + + if h.config.PKCEEnabled && codeChallenge != "" { + params.Set("code_challenge", codeChallenge) + params.Set("code_challenge_method", "S256") + } + + return metadata.AuthorizationEndpoint + "?" + params.Encode(), nil +} \ No newline at end of file diff --git a/client/transport/oauth_test.go b/client/transport/oauth_test.go new file mode 100644 index 000000000..c2660675b --- /dev/null +++ b/client/transport/oauth_test.go @@ -0,0 +1,302 @@ +package transport + +import ( + "context" + "errors" + "strings" + "testing" + "time" +) + +func TestToken_IsExpired(t *testing.T) { + // Test cases + testCases := []struct { + name string + token Token + expected bool + }{ + { + name: "Valid token", + token: Token{ + AccessToken: "valid-token", + ExpiresAt: time.Now().Add(1 * time.Hour), + }, + expected: false, + }, + { + name: "Expired token", + token: Token{ + AccessToken: "expired-token", + ExpiresAt: time.Now().Add(-1 * time.Hour), + }, + expected: true, + }, + { + name: "Token with no expiration", + token: Token{ + AccessToken: "no-expiration-token", + }, + expected: false, + }, + } + + // Run test cases + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := tc.token.IsExpired() + if result != tc.expected { + t.Errorf("Expected IsExpired() to return %v, got %v", tc.expected, result) + } + }) + } +} + +func TestMemoryTokenStore(t *testing.T) { + // Create a token store + store := NewMemoryTokenStore() + + // Test getting token from empty store + _, err := store.GetToken() + if err == nil { + t.Errorf("Expected error when getting token from empty store") + } + + // Create a test token + token := &Token{ + AccessToken: "test-token", + TokenType: "Bearer", + RefreshToken: "refresh-token", + ExpiresIn: 3600, + ExpiresAt: time.Now().Add(1 * time.Hour), + } + + // Save the token + err = store.SaveToken(token) + if err != nil { + t.Fatalf("Failed to save token: %v", err) + } + + // Get the token + retrievedToken, err := store.GetToken() + if err != nil { + t.Fatalf("Failed to get token: %v", err) + } + + // Verify the token + if retrievedToken.AccessToken != token.AccessToken { + t.Errorf("Expected access token to be %s, got %s", token.AccessToken, retrievedToken.AccessToken) + } + if retrievedToken.TokenType != token.TokenType { + t.Errorf("Expected token type to be %s, got %s", token.TokenType, retrievedToken.TokenType) + } + if retrievedToken.RefreshToken != token.RefreshToken { + t.Errorf("Expected refresh token to be %s, got %s", token.RefreshToken, retrievedToken.RefreshToken) + } +} + +func TestValidateRedirectURI(t *testing.T) { + // Test cases + testCases := []struct { + name string + redirectURI string + expectError bool + }{ + { + name: "Valid HTTPS URI", + redirectURI: "https://example.com/callback", + expectError: false, + }, + { + name: "Valid localhost URI", + redirectURI: "http://localhost:8085/callback", + expectError: false, + }, + { + name: "Valid localhost URI with 127.0.0.1", + redirectURI: "http://127.0.0.1:8085/callback", + expectError: false, + }, + { + name: "Invalid HTTP URI (non-localhost)", + redirectURI: "http://example.com/callback", + expectError: true, + }, + { + name: "Invalid HTTP URI with 'local' in domain", + redirectURI: "http://localdomain.com/callback", + expectError: true, + }, + { + name: "Empty URI", + redirectURI: "", + expectError: true, + }, + { + name: "Invalid scheme", + redirectURI: "ftp://example.com/callback", + expectError: true, + }, + { + name: "IPv6 localhost", + redirectURI: "http://[::1]:8080/callback", + expectError: false, // IPv6 localhost is valid + }, + } + + // Run test cases + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := ValidateRedirectURI(tc.redirectURI) + if tc.expectError && err == nil { + t.Errorf("Expected error for redirect URI %s, got nil", tc.redirectURI) + } else if !tc.expectError && err != nil { + t.Errorf("Expected no error for redirect URI %s, got %v", tc.redirectURI, err) + } + }) + } +} + +func TestOAuthHandler_GetAuthorizationHeader_EmptyAccessToken(t *testing.T) { + // Create a token store with a token that has an empty access token + tokenStore := NewMemoryTokenStore() + invalidToken := &Token{ + AccessToken: "", // Empty access token + TokenType: "Bearer", + RefreshToken: "refresh-token", + ExpiresIn: 3600, + ExpiresAt: time.Now().Add(1 * time.Hour), // Valid for 1 hour + } + if err := tokenStore.SaveToken(invalidToken); err != nil { + t.Fatalf("Failed to save token: %v", err) + } + + // Create an OAuth handler + config := OAuthConfig{ + ClientID: "test-client", + RedirectURI: "http://localhost:8085/callback", + Scopes: []string{"mcp.read", "mcp.write"}, + TokenStore: tokenStore, + PKCEEnabled: true, + } + + handler := NewOAuthHandler(config) + + // Test getting authorization header with empty access token + _, err := handler.GetAuthorizationHeader(context.Background()) + if err == nil { + t.Fatalf("Expected error when getting authorization header with empty access token") + } + + // Verify the error message + if !errors.Is(err, ErrOAuthAuthorizationRequired) { + t.Errorf("Expected error to be ErrOAuthAuthorizationRequired, got %v", err) + } +} + +func TestOAuthHandler_GetServerMetadata_EmptyURL(t *testing.T) { + // Create an OAuth handler with an empty AuthServerMetadataURL + config := OAuthConfig{ + ClientID: "test-client", + RedirectURI: "http://localhost:8085/callback", + Scopes: []string{"mcp.read"}, + TokenStore: NewMemoryTokenStore(), + AuthServerMetadataURL: "", // Empty URL + PKCEEnabled: true, + } + + handler := NewOAuthHandler(config) + + // Test getting server metadata with empty URL + _, err := handler.GetServerMetadata(context.Background()) + if err == nil { + t.Fatalf("Expected error when getting server metadata with empty URL") + } + + // Verify the error message contains something about a connection error + // since we're now trying to connect to the well-known endpoint + if !strings.Contains(err.Error(), "connection refused") && + !strings.Contains(err.Error(), "failed to send protected resource request") { + t.Errorf("Expected error message to contain connection error, got %s", err.Error()) + } +} + +func TestOAuthError(t *testing.T) { + testCases := []struct { + name string + errorCode string + description string + uri string + expected string + }{ + { + name: "Error with description", + errorCode: "invalid_request", + description: "The request is missing a required parameter", + uri: "https://example.com/errors/invalid_request", + expected: "OAuth error: invalid_request - The request is missing a required parameter", + }, + { + name: "Error without description", + errorCode: "unauthorized_client", + description: "", + uri: "", + expected: "OAuth error: unauthorized_client", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + oauthErr := OAuthError{ + ErrorCode: tc.errorCode, + ErrorDescription: tc.description, + ErrorURI: tc.uri, + } + + if oauthErr.Error() != tc.expected { + t.Errorf("Expected error message %q, got %q", tc.expected, oauthErr.Error()) + } + }) + } +} + +func TestOAuthHandler_ProcessAuthorizationResponse_StateValidation(t *testing.T) { + // Create an OAuth handler + config := OAuthConfig{ + ClientID: "test-client", + RedirectURI: "http://localhost:8085/callback", + Scopes: []string{"mcp.read", "mcp.write"}, + TokenStore: NewMemoryTokenStore(), + AuthServerMetadataURL: "http://example.com/.well-known/oauth-authorization-server", + PKCEEnabled: true, + } + + handler := NewOAuthHandler(config) + + // Mock the server metadata to avoid nil pointer dereference + handler.serverMetadata = &AuthServerMetadata{ + Issuer: "http://example.com", + AuthorizationEndpoint: "http://example.com/authorize", + TokenEndpoint: "http://example.com/token", + } + + // Set the expected state + expectedState := "test-state-123" + handler.expectedState = expectedState + + // Test with non-matching state - this should fail immediately with ErrInvalidState + // before trying to connect to any server + err := handler.ProcessAuthorizationResponse(context.Background(), "test-code", "wrong-state", "test-code-verifier") + if !errors.Is(err, ErrInvalidState) { + t.Errorf("Expected ErrInvalidState, got %v", err) + } + + // Test with empty expected state + handler.expectedState = "" + err = handler.ProcessAuthorizationResponse(context.Background(), "test-code", expectedState, "test-code-verifier") + if err == nil { + t.Errorf("Expected error with empty expected state, got nil") + } + if errors.Is(err, ErrInvalidState) { + t.Errorf("Got ErrInvalidState when expected a different error for empty expected state") + } +} \ No newline at end of file diff --git a/client/transport/oauth_utils.go b/client/transport/oauth_utils.go new file mode 100644 index 000000000..e2104307a --- /dev/null +++ b/client/transport/oauth_utils.go @@ -0,0 +1,68 @@ +package transport + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "fmt" + "net/url" +) + +// GenerateRandomString generates a random string of the specified length +func GenerateRandomString(length int) (string, error) { + bytes := make([]byte, length) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(bytes)[:length], nil +} + +// GenerateCodeVerifier generates a code verifier for PKCE +func GenerateCodeVerifier() (string, error) { + // According to RFC 7636, the code verifier should be between 43 and 128 characters + return GenerateRandomString(64) +} + +// GenerateCodeChallenge generates a code challenge from a code verifier +func GenerateCodeChallenge(codeVerifier string) string { + // SHA256 hash the code verifier + hash := sha256.Sum256([]byte(codeVerifier)) + // Base64url encode the hash + return base64.RawURLEncoding.EncodeToString(hash[:]) +} + +// GenerateState generates a state parameter for OAuth +func GenerateState() (string, error) { + return GenerateRandomString(32) +} + +// ValidateRedirectURI validates that a redirect URI is secure +func ValidateRedirectURI(redirectURI string) error { + // According to the spec, redirect URIs must be either localhost URLs or HTTPS URLs + if redirectURI == "" { + return fmt.Errorf("redirect URI cannot be empty") + } + + // Parse the URL + parsedURL, err := url.Parse(redirectURI) + if err != nil { + return fmt.Errorf("invalid redirect URI: %w", err) + } + + // Check if it's a localhost URL + if parsedURL.Scheme == "http" { + hostname := parsedURL.Hostname() + // Check for various forms of localhost + if hostname == "localhost" || hostname == "127.0.0.1" || hostname == "::1" || hostname == "[::1]" { + return nil + } + return fmt.Errorf("HTTP redirect URI must use localhost or 127.0.0.1") + } + + // Check if it's an HTTPS URL + if parsedURL.Scheme == "https" { + return nil + } + + return fmt.Errorf("redirect URI must use either HTTP with localhost or HTTPS") +} \ No newline at end of file diff --git a/client/transport/oauth_utils_test.go b/client/transport/oauth_utils_test.go new file mode 100644 index 000000000..da6eacaee --- /dev/null +++ b/client/transport/oauth_utils_test.go @@ -0,0 +1,88 @@ +package transport + +import ( + "fmt" + "testing" +) + +func TestGenerateRandomString(t *testing.T) { + // Test generating strings of different lengths + lengths := []int{10, 32, 64, 128} + for _, length := range lengths { + t.Run(fmt.Sprintf("Length_%d", length), func(t *testing.T) { + str, err := GenerateRandomString(length) + if err != nil { + t.Fatalf("Failed to generate random string: %v", err) + } + if len(str) != length { + t.Errorf("Expected string of length %d, got %d", length, len(str)) + } + + // Generate another string to ensure they're different + str2, err := GenerateRandomString(length) + if err != nil { + t.Fatalf("Failed to generate second random string: %v", err) + } + if str == str2 { + t.Errorf("Generated identical random strings: %s", str) + } + }) + } +} + +func TestGenerateCodeVerifierAndChallenge(t *testing.T) { + // Generate a code verifier + verifier, err := GenerateCodeVerifier() + if err != nil { + t.Fatalf("Failed to generate code verifier: %v", err) + } + + // Verify the length (should be 64 characters) + if len(verifier) != 64 { + t.Errorf("Expected code verifier of length 64, got %d", len(verifier)) + } + + // Generate a code challenge + challenge := GenerateCodeChallenge(verifier) + + // Verify the challenge is not empty + if challenge == "" { + t.Errorf("Generated empty code challenge") + } + + // Generate another verifier and challenge to ensure they're different + verifier2, _ := GenerateCodeVerifier() + challenge2 := GenerateCodeChallenge(verifier2) + + if verifier == verifier2 { + t.Errorf("Generated identical code verifiers: %s", verifier) + } + if challenge == challenge2 { + t.Errorf("Generated identical code challenges: %s", challenge) + } + + // Verify the same verifier always produces the same challenge + challenge3 := GenerateCodeChallenge(verifier) + if challenge != challenge3 { + t.Errorf("Same verifier produced different challenges: %s and %s", challenge, challenge3) + } +} + +func TestGenerateState(t *testing.T) { + // Generate a state parameter + state, err := GenerateState() + if err != nil { + t.Fatalf("Failed to generate state: %v", err) + } + + // Verify the length (should be 32 characters) + if len(state) != 32 { + t.Errorf("Expected state of length 32, got %d", len(state)) + } + + // Generate another state to ensure they're different + state2, _ := GenerateState() + if state == state2 { + t.Errorf("Generated identical states: %s", state) + } +} \ No newline at end of file diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index afeea7e9a..8727bb190 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "mime" @@ -39,6 +40,13 @@ func WithHTTPTimeout(timeout time.Duration) StreamableHTTPCOption { } } +// WithOAuth enables OAuth authentication for the client. +func WithOAuth(config OAuthConfig) StreamableHTTPCOption { + return func(sc *StreamableHTTP) { + sc.oauthHandler = NewOAuthHandler(config) + } +} + // StreamableHTTP implements Streamable HTTP transport. // // It transmits JSON-RPC messages over individual HTTP requests. One message per request. @@ -55,7 +63,7 @@ func WithHTTPTimeout(timeout time.Duration) StreamableHTTPCOption { // (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery) // - server -> client request type StreamableHTTP struct { - baseURL *url.URL + serverURL *url.URL httpClient *http.Client headers map[string]string headerFunc HTTPHeaderFunc @@ -66,18 +74,21 @@ type StreamableHTTP struct { notifyMu sync.RWMutex closed chan struct{} + + // OAuth support + oauthHandler *OAuthHandler } -// NewStreamableHTTP creates a new Streamable HTTP transport with the given base URL. +// NewStreamableHTTP creates a new Streamable HTTP transport with the given server URL. // Returns an error if the URL is invalid. -func NewStreamableHTTP(baseURL string, options ...StreamableHTTPCOption) (*StreamableHTTP, error) { - parsedURL, err := url.Parse(baseURL) +func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*StreamableHTTP, error) { + parsedURL, err := url.Parse(serverURL) if err != nil { return nil, fmt.Errorf("invalid URL: %w", err) } smc := &StreamableHTTP{ - baseURL: parsedURL, + serverURL: parsedURL, httpClient: &http.Client{}, headers: make(map[string]string), closed: make(chan struct{}), @@ -88,6 +99,13 @@ func NewStreamableHTTP(baseURL string, options ...StreamableHTTPCOption) (*Strea opt(smc) } + // If OAuth is configured, set the base URL for metadata discovery + if smc.oauthHandler != nil { + // Extract base URL from server URL for metadata discovery + baseURL := fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host) + smc.oauthHandler.SetBaseURL(baseURL) + } + return smc, nil } @@ -115,7 +133,7 @@ func (c *StreamableHTTP) Close() error { go func() { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.baseURL.String(), nil) + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.serverURL.String(), nil) if err != nil { fmt.Printf("failed to create close request\n: %v", err) return @@ -137,6 +155,22 @@ const ( headerKeySessionID = "Mcp-Session-Id" ) +// ErrOAuthAuthorizationRequired is a sentinel error for OAuth authorization required +var ErrOAuthAuthorizationRequired = errors.New("no valid token available, authorization required") + +// OAuthAuthorizationRequiredError is returned when OAuth authorization is required +type OAuthAuthorizationRequiredError struct { + Handler *OAuthHandler +} + +func (e *OAuthAuthorizationRequiredError) Error() string { + return ErrOAuthAuthorizationRequired.Error() +} + +func (e *OAuthAuthorizationRequiredError) Unwrap() error { + return ErrOAuthAuthorizationRequired +} + // SendRequest sends a JSON-RPC request to the server and waits for a response. // Returns the raw JSON response message or an error if the request fails. func (c *StreamableHTTP) SendRequest( @@ -164,7 +198,7 @@ func (c *StreamableHTTP) SendRequest( } // Create HTTP request - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL.String(), bytes.NewReader(requestBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody)) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } @@ -179,6 +213,22 @@ func (c *StreamableHTTP) SendRequest( for k, v := range c.headers { req.Header.Set(k, v) } + + // Add OAuth authorization if configured + if c.oauthHandler != nil { + authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx) + if err != nil { + // If we get an authorization error, return a specific error that can be handled by the client + if err.Error() == "no valid token available, authorization required" { + return nil, &OAuthAuthorizationRequiredError{ + Handler: c.oauthHandler, + } + } + return nil, fmt.Errorf("failed to get authorization header: %w", err) + } + req.Header.Set("Authorization", authHeader) + } + if c.headerFunc != nil { for k, v := range c.headerFunc(ctx) { req.Header.Set(k, v) @@ -199,6 +249,13 @@ func (c *StreamableHTTP) SendRequest( c.sessionID.CompareAndSwap(sessionID, "") return nil, fmt.Errorf("session terminated (404). need to re-initialize") } + + // Handle OAuth unauthorized error + if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil { + return nil, &OAuthAuthorizationRequiredError{ + Handler: c.oauthHandler, + } + } // handle error response var errResponse JSONRPCResponse @@ -360,7 +417,7 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp. } // Create HTTP request - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL.String(), bytes.NewReader(requestBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody)) if err != nil { return fmt.Errorf("failed to create request: %w", err) } @@ -374,6 +431,22 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp. for k, v := range c.headers { req.Header.Set(k, v) } + + // Add OAuth authorization if configured + if c.oauthHandler != nil { + authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx) + if err != nil { + // If we get an authorization error, return a specific error that can be handled by the client + if errors.Is(err, ErrOAuthAuthorizationRequired) { + return &OAuthAuthorizationRequiredError{ + Handler: c.oauthHandler, + } + } + return fmt.Errorf("failed to get authorization header: %w", err) + } + req.Header.Set("Authorization", authHeader) + } + if c.headerFunc != nil { for k, v := range c.headerFunc(ctx) { req.Header.Set(k, v) @@ -388,6 +461,13 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp. defer resp.Body.Close() if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + // Handle OAuth unauthorized error + if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil { + return &OAuthAuthorizationRequiredError{ + Handler: c.oauthHandler, + } + } + body, _ := io.ReadAll(resp.Body) return fmt.Errorf( "notification failed with status %d: %s", @@ -408,3 +488,13 @@ func (c *StreamableHTTP) SetNotificationHandler(handler func(mcp.JSONRPCNotifica func (c *StreamableHTTP) GetSessionId() string { return c.sessionID.Load().(string) } + +// GetOAuthHandler returns the OAuth handler if configured +func (c *StreamableHTTP) GetOAuthHandler() *OAuthHandler { + return c.oauthHandler +} + +// IsOAuthEnabled returns true if OAuth is enabled +func (c *StreamableHTTP) IsOAuthEnabled() bool { + return c.oauthHandler != nil +} diff --git a/client/transport/streamable_http_oauth_test.go b/client/transport/streamable_http_oauth_test.go new file mode 100644 index 000000000..3b38c2cbb --- /dev/null +++ b/client/transport/streamable_http_oauth_test.go @@ -0,0 +1,218 @@ +package transport + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +func TestStreamableHTTP_WithOAuth(t *testing.T) { + // Track request count to simulate 401 on first request, then success + requestCount := 0 + authHeaderReceived := "" + + // Create a test server that requires OAuth + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Capture the Authorization header + authHeaderReceived = r.Header.Get("Authorization") + + // Check for Authorization header + if requestCount == 0 { + // First request - simulate 401 to test error handling + requestCount++ + w.WriteHeader(http.StatusUnauthorized) + return + } + + // Subsequent requests - verify the Authorization header + if authHeaderReceived != "Bearer test-token" { + t.Errorf("Expected Authorization header 'Bearer test-token', got '%s'", authHeaderReceived) + w.WriteHeader(http.StatusUnauthorized) + return + } + + // Return a successful response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "result": "success", + }); err != nil { + t.Errorf("Failed to encode JSON response: %v", err) + } + })) + defer server.Close() + + // Create a token store with a valid token + tokenStore := NewMemoryTokenStore() + validToken := &Token{ + AccessToken: "test-token", + TokenType: "Bearer", + RefreshToken: "refresh-token", + ExpiresIn: 3600, + ExpiresAt: time.Now().Add(1 * time.Hour), // Valid for 1 hour + } + if err := tokenStore.SaveToken(validToken); err != nil { + t.Fatalf("Failed to save token: %v", err) + } + + // Create OAuth config + oauthConfig := OAuthConfig{ + ClientID: "test-client", + RedirectURI: "http://localhost:8085/callback", + Scopes: []string{"mcp.read", "mcp.write"}, + TokenStore: tokenStore, + PKCEEnabled: true, + } + + // Create StreamableHTTP with OAuth + transport, err := NewStreamableHTTP(server.URL, WithOAuth(oauthConfig)) + if err != nil { + t.Fatalf("Failed to create StreamableHTTP: %v", err) + } + + // Verify that OAuth is enabled + if !transport.IsOAuthEnabled() { + t.Errorf("Expected IsOAuthEnabled() to return true") + } + + // Verify the OAuth handler is set + if transport.GetOAuthHandler() == nil { + t.Errorf("Expected GetOAuthHandler() to return a handler") + } + + // First request should fail with OAuthAuthorizationRequiredError + _, err = transport.SendRequest(context.Background(), JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(1), + Method: "test", + }) + + // Verify the error is an OAuthAuthorizationRequiredError + if err == nil { + t.Fatalf("Expected error on first request, got nil") + } + + var oauthErr *OAuthAuthorizationRequiredError + if !errors.As(err, &oauthErr) { + t.Fatalf("Expected OAuthAuthorizationRequiredError, got %T: %v", err, err) + } + + // Verify the error has the handler + if oauthErr.Handler == nil { + t.Errorf("Expected OAuthAuthorizationRequiredError to have a handler") + } + + // Verify the server received the first request + if requestCount != 1 { + t.Errorf("Expected server to receive 1 request, got %d", requestCount) + } + + // Second request should succeed + response, err := transport.SendRequest(context.Background(), JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(2), + Method: "test", + }) + + if err != nil { + t.Fatalf("Failed to send second request: %v", err) + } + + // Verify the response + var resultStr string + if err := json.Unmarshal(response.Result, &resultStr); err != nil { + t.Fatalf("Failed to unmarshal result: %v", err) + } + + if resultStr != "success" { + t.Errorf("Expected result to be 'success', got %v", resultStr) + } + + // Verify the server received the Authorization header + if authHeaderReceived != "Bearer test-token" { + t.Errorf("Expected server to receive Authorization header 'Bearer test-token', got '%s'", authHeaderReceived) + } +} + +func TestStreamableHTTP_WithOAuth_Unauthorized(t *testing.T) { + // Create a test server that requires OAuth + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Always return unauthorized + w.WriteHeader(http.StatusUnauthorized) + })) + defer server.Close() + + // Create an empty token store + tokenStore := NewMemoryTokenStore() + + // Create OAuth config + oauthConfig := OAuthConfig{ + ClientID: "test-client", + RedirectURI: "http://localhost:8085/callback", + Scopes: []string{"mcp.read", "mcp.write"}, + TokenStore: tokenStore, + PKCEEnabled: true, + } + + // Create StreamableHTTP with OAuth + transport, err := NewStreamableHTTP(server.URL, WithOAuth(oauthConfig)) + if err != nil { + t.Fatalf("Failed to create StreamableHTTP: %v", err) + } + + // Send a request + _, err = transport.SendRequest(context.Background(), JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(1), + Method: "test", + }) + + // Verify the error is an OAuthAuthorizationRequiredError + if err == nil { + t.Fatalf("Expected error, got nil") + } + + var oauthErr *OAuthAuthorizationRequiredError + if !errors.As(err, &oauthErr) { + t.Fatalf("Expected OAuthAuthorizationRequiredError, got %T: %v", err, err) + } + + // Verify the error has the handler + if oauthErr.Handler == nil { + t.Errorf("Expected OAuthAuthorizationRequiredError to have a handler") + } +} + +func TestStreamableHTTP_IsOAuthEnabled(t *testing.T) { + // Create StreamableHTTP without OAuth + transport1, err := NewStreamableHTTP("http://example.com") + if err != nil { + t.Fatalf("Failed to create StreamableHTTP: %v", err) + } + + // Verify OAuth is not enabled + if transport1.IsOAuthEnabled() { + t.Errorf("Expected IsOAuthEnabled() to return false") + } + + // Create StreamableHTTP with OAuth + transport2, err := NewStreamableHTTP("http://example.com", WithOAuth(OAuthConfig{ + ClientID: "test-client", + })) + if err != nil { + t.Fatalf("Failed to create StreamableHTTP: %v", err) + } + + // Verify OAuth is enabled + if !transport2.IsOAuthEnabled() { + t.Errorf("Expected IsOAuthEnabled() to return true") + } +} \ No newline at end of file diff --git a/examples/oauth_client/README.md b/examples/oauth_client/README.md new file mode 100644 index 000000000..a60bb7c5f --- /dev/null +++ b/examples/oauth_client/README.md @@ -0,0 +1,59 @@ +# OAuth Client Example + +This example demonstrates how to use the OAuth capabilities of the MCP Go client to authenticate with an MCP server that requires OAuth authentication. + +## Features + +- OAuth 2.1 authentication with PKCE support +- Dynamic client registration +- Authorization code flow +- Token refresh +- Local callback server for handling OAuth redirects + +## Usage + +```bash +# Set environment variables (optional) +export MCP_CLIENT_ID=your_client_id +export MCP_CLIENT_SECRET=your_client_secret + +# Run the example +go run main.go +``` + +## How it Works + +1. The client attempts to initialize a connection to the MCP server +2. If the server requires OAuth authentication, it will return a 401 Unauthorized response +3. The client detects this and starts the OAuth flow: + - Generates PKCE code verifier and challenge + - Generates a state parameter for security + - Opens a browser to the authorization URL + - Starts a local server to handle the callback +4. The user authorizes the application in their browser +5. The authorization server redirects back to the local callback server +6. The client exchanges the authorization code for an access token +7. The client retries the initialization with the access token +8. The client can now make authenticated requests to the MCP server + +## Configuration + +Edit the following constants in `main.go` to match your environment: + +```go +const ( + // Replace with your MCP server URL + serverURL = "https://api.example.com/v1/mcp" + // Use a localhost redirect URI for this example + redirectURI = "http://localhost:8085/oauth/callback" +) +``` + +## OAuth Scopes + +The example requests the following scopes: + +- `mcp.read` - Read access to MCP resources +- `mcp.write` - Write access to MCP resources + +You can modify the scopes in the `oauthConfig` to match the requirements of your MCP server. \ No newline at end of file diff --git a/examples/oauth_client/main.go b/examples/oauth_client/main.go new file mode 100644 index 000000000..3fc653324 --- /dev/null +++ b/examples/oauth_client/main.go @@ -0,0 +1,223 @@ +package main + +import ( + "context" + "fmt" + "log" + "net/http" + "os" + "os/exec" + "runtime" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/mcp" +) + +const ( + // Replace with your MCP server URL + serverURL = "https://api.example.com/v1/mcp" + // Use a localhost redirect URI for this example + redirectURI = "http://localhost:8085/oauth/callback" +) + +func main() { + // Create a token store to persist tokens + tokenStore := client.NewMemoryTokenStore() + + // Create OAuth configuration + oauthConfig := client.OAuthConfig{ + // Client ID can be empty if using dynamic registration + ClientID: os.Getenv("MCP_CLIENT_ID"), + ClientSecret: os.Getenv("MCP_CLIENT_SECRET"), + RedirectURI: redirectURI, + Scopes: []string{"mcp.read", "mcp.write"}, + TokenStore: tokenStore, + PKCEEnabled: true, // Enable PKCE for public clients + } + + // Create the client with OAuth support + c, err := client.NewOAuthStreamableHttpClient(serverURL, oauthConfig) + if err != nil { + log.Fatalf("Failed to create client: %v", err) + } + + // Start the client + if err := c.Start(context.Background()); err != nil { + log.Fatalf("Failed to start client: %v", err) + } + defer c.Close() + + // Try to initialize the client + result, err := c.Initialize(context.Background(), mcp.InitializeRequest{ + Params: struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities mcp.ClientCapabilities `json:"capabilities"` + ClientInfo mcp.Implementation `json:"clientInfo"` + }{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "mcp-go-oauth-example", + Version: "0.1.0", + }, + }, + }) + + // Check if we need OAuth authorization + if client.IsOAuthAuthorizationRequiredError(err) { + fmt.Println("OAuth authorization required. Starting authorization flow...") + + // Get the OAuth handler from the error + oauthHandler := client.GetOAuthHandler(err) + + // Start a local server to handle the OAuth callback + callbackChan := make(chan map[string]string) + server := startCallbackServer(callbackChan) + defer server.Close() + + // Generate PKCE code verifier and challenge + codeVerifier, err := client.GenerateCodeVerifier() + if err != nil { + log.Fatalf("Failed to generate code verifier: %v", err) + } + codeChallenge := client.GenerateCodeChallenge(codeVerifier) + + // Generate state parameter + state, err := client.GenerateState() + if err != nil { + log.Fatalf("Failed to generate state: %v", err) + } + + // Get the authorization URL + authURL, err := oauthHandler.GetAuthorizationURL(context.Background(), state, codeChallenge) + if err != nil { + log.Fatalf("Failed to get authorization URL: %v", err) + } + + // Open the browser to the authorization URL + fmt.Printf("Opening browser to: %s\n", authURL) + openBrowser(authURL) + + // Wait for the callback + fmt.Println("Waiting for authorization callback...") + params := <-callbackChan + + // Verify state parameter + if params["state"] != state { + log.Fatalf("State mismatch: expected %s, got %s", state, params["state"]) + } + + // Exchange the authorization code for a token + code := params["code"] + if code == "" { + log.Fatalf("No authorization code received") + } + + fmt.Println("Exchanging authorization code for token...") + err = oauthHandler.ProcessAuthorizationResponse(context.Background(), code, state, codeVerifier) + if err != nil { + log.Fatalf("Failed to process authorization response: %v", err) + } + + fmt.Println("Authorization successful!") + + // Try to initialize again with the token + result, err = c.Initialize(context.Background(), mcp.InitializeRequest{ + Params: struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities mcp.ClientCapabilities `json:"capabilities"` + ClientInfo mcp.Implementation `json:"clientInfo"` + }{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "mcp-go-oauth-example", + Version: "0.1.0", + }, + }, + }) + if err != nil { + log.Fatalf("Failed to initialize client after authorization: %v", err) + } + } else if err != nil { + log.Fatalf("Failed to initialize client: %v", err) + } + + fmt.Printf("Client initialized successfully! Server: %s %s\n", + result.ServerInfo.Name, + result.ServerInfo.Version) + + // Now you can use the client as usual + // For example, list resources + resources, err := c.ListResources(context.Background(), mcp.ListResourcesRequest{}) + if err != nil { + log.Fatalf("Failed to list resources: %v", err) + } + + fmt.Println("Available resources:") + for _, resource := range resources.Resources { + fmt.Printf("- %s\n", resource.URI) + } +} + +// startCallbackServer starts a local HTTP server to handle the OAuth callback +func startCallbackServer(callbackChan chan<- map[string]string) *http.Server { + server := &http.Server{ + Addr: ":8085", + } + + http.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) { + // Extract query parameters + params := make(map[string]string) + for key, values := range r.URL.Query() { + if len(values) > 0 { + params[key] = values[0] + } + } + + // Send parameters to the channel + callbackChan <- params + + // Respond to the user + w.Header().Set("Content-Type", "text/html") + _, err := w.Write([]byte(` + + +

Authorization Successful

+

You can now close this window and return to the application.

+ + + + `)) + if err != nil { + log.Printf("Error writing response: %v", err) + } + }) + + go func() { + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Printf("HTTP server error: %v", err) + } + }() + + return server +} + +// openBrowser opens the default browser to the specified URL +func openBrowser(url string) { + var err error + + switch runtime.GOOS { + case "linux": + err = exec.Command("xdg-open", url).Start() + case "windows": + err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() + case "darwin": + err = exec.Command("open", url).Start() + default: + err = fmt.Errorf("unsupported platform") + } + + if err != nil { + log.Printf("Failed to open browser: %v", err) + fmt.Printf("Please open the following URL in your browser: %s\n", url) + } +} \ No newline at end of file