Skip to content

Commit 7d132c9

Browse files
committed
more error handling
1 parent e1818a4 commit 7d132c9

File tree

2 files changed

+69
-3
lines changed

2 files changed

+69
-3
lines changed

client/transport/oauth.go

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ func (h *OAuthHandler) refreshToken(ctx context.Context, refreshToken string) (*
195195

196196
if resp.StatusCode != http.StatusOK {
197197
body, _ := io.ReadAll(resp.Body)
198-
return nil, fmt.Errorf("refresh token request failed with status %d: %s", resp.StatusCode, body)
198+
return nil, extractOAuthError(body, resp.StatusCode, "refresh token request failed")
199199
}
200200

201201
var tokenResp Token
@@ -232,6 +232,18 @@ func (h *OAuthHandler) GetClientID() string {
232232
return h.config.ClientID
233233
}
234234

235+
// extractOAuthError attempts to parse an OAuth error response from the response body
236+
func extractOAuthError(body []byte, statusCode int, context string) error {
237+
// Try to parse the error as an OAuth error response
238+
var oauthErr OAuthError
239+
if err := json.Unmarshal(body, &oauthErr); err == nil && oauthErr.ErrorCode != "" {
240+
return fmt.Errorf("%s: %w", context, oauthErr)
241+
}
242+
243+
// If not a valid OAuth error, return the raw response
244+
return fmt.Errorf("%s with status %d: %s", context, statusCode, body)
245+
}
246+
235247
// GetClientSecret returns the client secret
236248
func (h *OAuthHandler) GetClientSecret() string {
237249
return h.config.ClientSecret
@@ -247,6 +259,21 @@ func (h *OAuthHandler) GetExpectedState() string {
247259
return h.expectedState
248260
}
249261

262+
// OAuthError represents a standard OAuth 2.0 error response
263+
type OAuthError struct {
264+
ErrorCode string `json:"error"`
265+
ErrorDescription string `json:"error_description,omitempty"`
266+
ErrorURI string `json:"error_uri,omitempty"`
267+
}
268+
269+
// Error implements the error interface
270+
func (e OAuthError) Error() string {
271+
if e.ErrorDescription != "" {
272+
return fmt.Sprintf("OAuth error: %s - %s", e.ErrorCode, e.ErrorDescription)
273+
}
274+
return fmt.Sprintf("OAuth error: %s", e.ErrorCode)
275+
}
276+
250277
// OAuthProtectedResource represents the response from /.well-known/oauth-protected-resource
251278
type OAuthProtectedResource struct {
252279
AuthorizationServers []string `json:"authorization_servers"`
@@ -486,7 +513,7 @@ func (h *OAuthHandler) RegisterClient(ctx context.Context, clientName string) er
486513

487514
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
488515
body, _ := io.ReadAll(resp.Body)
489-
return fmt.Errorf("registration request failed with status %d: %s", resp.StatusCode, body)
516+
return extractOAuthError(body, resp.StatusCode, "registration request failed")
490517
}
491518

492519
var regResponse struct {
@@ -566,7 +593,7 @@ func (h *OAuthHandler) ProcessAuthorizationResponse(ctx context.Context, code, s
566593

567594
if resp.StatusCode != http.StatusOK {
568595
body, _ := io.ReadAll(resp.Body)
569-
return fmt.Errorf("token request failed with status %d: %s", resp.StatusCode, body)
596+
return extractOAuthError(body, resp.StatusCode, "token request failed")
570597
}
571598

572599
var tokenResp Token

client/transport/oauth_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,45 @@ func TestOAuthHandler_GetServerMetadata_EmptyURL(t *testing.T) {
220220
}
221221
}
222222

223+
func TestOAuthError(t *testing.T) {
224+
testCases := []struct {
225+
name string
226+
errorCode string
227+
description string
228+
uri string
229+
expected string
230+
}{
231+
{
232+
name: "Error with description",
233+
errorCode: "invalid_request",
234+
description: "The request is missing a required parameter",
235+
uri: "https://example.com/errors/invalid_request",
236+
expected: "OAuth error: invalid_request - The request is missing a required parameter",
237+
},
238+
{
239+
name: "Error without description",
240+
errorCode: "unauthorized_client",
241+
description: "",
242+
uri: "",
243+
expected: "OAuth error: unauthorized_client",
244+
},
245+
}
246+
247+
for _, tc := range testCases {
248+
t.Run(tc.name, func(t *testing.T) {
249+
oauthErr := OAuthError{
250+
ErrorCode: tc.errorCode,
251+
ErrorDescription: tc.description,
252+
ErrorURI: tc.uri,
253+
}
254+
255+
if oauthErr.Error() != tc.expected {
256+
t.Errorf("Expected error message %q, got %q", tc.expected, oauthErr.Error())
257+
}
258+
})
259+
}
260+
}
261+
223262
func TestOAuthHandler_ProcessAuthorizationResponse_StateValidation(t *testing.T) {
224263
// Create an OAuth handler
225264
config := OAuthConfig{

0 commit comments

Comments
 (0)