Skip to content

Commit 62c729d

Browse files
authoredJun 3, 2025
Merge pull request #3 from sacha-development-stuff/codex/review-implementation-of-client-credentials-flow
Refactor auth helper methods
2 parents dbbc6ce + 3f2a351 commit 62c729d

File tree

1 file changed

+48
-85
lines changed

1 file changed

+48
-85
lines changed
 

‎src/mcp/client/auth.py

Lines changed: 48 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,44 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None
4848
...
4949

5050

51+
def _get_authorization_base_url(server_url: str) -> str:
52+
"""Return the authorization base URL for ``server_url``.
53+
54+
Per MCP spec 2.3.2, the path component must be discarded so that
55+
``https://api.example.com/v1/mcp`` becomes ``https://api.example.com``.
56+
"""
57+
from urllib.parse import urlparse, urlunparse
58+
59+
parsed = urlparse(server_url)
60+
return urlunparse((parsed.scheme, parsed.netloc, "", "", "", ""))
61+
62+
63+
async def _discover_oauth_metadata(server_url: str) -> OAuthMetadata | None:
64+
"""Discover OAuth metadata from the server's well-known endpoint."""
65+
66+
auth_base_url = _get_authorization_base_url(server_url)
67+
url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server")
68+
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION}
69+
70+
async with httpx.AsyncClient() as client:
71+
try:
72+
response = await client.get(url, headers=headers)
73+
if response.status_code == 404:
74+
return None
75+
response.raise_for_status()
76+
return OAuthMetadata.model_validate(response.json())
77+
except Exception:
78+
try:
79+
response = await client.get(url)
80+
if response.status_code == 404:
81+
return None
82+
response.raise_for_status()
83+
return OAuthMetadata.model_validate(response.json())
84+
except Exception:
85+
logger.exception("Failed to discover OAuth metadata")
86+
return None
87+
88+
5189
class OAuthClientProvider(httpx.Auth):
5290
"""
5391
Authentication for httpx using anyio.
@@ -110,52 +148,6 @@ def _generate_code_challenge(self, code_verifier: str) -> str:
110148
digest = hashlib.sha256(code_verifier.encode()).digest()
111149
return base64.urlsafe_b64encode(digest).decode().rstrip("=")
112150

113-
def _get_authorization_base_url(self, server_url: str) -> str:
114-
"""
115-
Extract base URL by removing path component.
116-
117-
Per MCP spec 2.3.2: https://api.example.com/v1/mcp -> https://api.example.com
118-
"""
119-
from urllib.parse import urlparse, urlunparse
120-
121-
parsed = urlparse(server_url)
122-
# Remove path component
123-
return urlunparse((parsed.scheme, parsed.netloc, "", "", "", ""))
124-
125-
async def _discover_oauth_metadata(self, server_url: str) -> OAuthMetadata | None:
126-
"""
127-
Discover OAuth metadata from server's well-known endpoint.
128-
"""
129-
# Extract base URL per MCP spec
130-
auth_base_url = self._get_authorization_base_url(server_url)
131-
url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server")
132-
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION}
133-
134-
async with httpx.AsyncClient() as client:
135-
try:
136-
response = await client.get(url, headers=headers)
137-
if response.status_code == 404:
138-
return None
139-
response.raise_for_status()
140-
metadata_json = response.json()
141-
logger.debug(f"OAuth metadata discovered: {metadata_json}")
142-
return OAuthMetadata.model_validate(metadata_json)
143-
except Exception:
144-
# Retry without MCP header for CORS compatibility
145-
try:
146-
response = await client.get(url)
147-
if response.status_code == 404:
148-
return None
149-
response.raise_for_status()
150-
metadata_json = response.json()
151-
logger.debug(
152-
f"OAuth metadata discovered (no MCP header): {metadata_json}"
153-
)
154-
return OAuthMetadata.model_validate(metadata_json)
155-
except Exception:
156-
logger.exception("Failed to discover OAuth metadata")
157-
return None
158-
159151
async def _register_oauth_client(
160152
self,
161153
server_url: str,
@@ -166,13 +158,13 @@ async def _register_oauth_client(
166158
Register OAuth client with server.
167159
"""
168160
if not metadata:
169-
metadata = await self._discover_oauth_metadata(server_url)
161+
metadata = await _discover_oauth_metadata(server_url)
170162

171163
if metadata and metadata.registration_endpoint:
172164
registration_url = str(metadata.registration_endpoint)
173165
else:
174166
# Use fallback registration endpoint
175-
auth_base_url = self._get_authorization_base_url(server_url)
167+
auth_base_url = _get_authorization_base_url(server_url)
176168
registration_url = urljoin(auth_base_url, "/register")
177169

178170
# Handle default scope
@@ -321,7 +313,7 @@ async def _perform_oauth_flow(self) -> None:
321313

322314
# Discover OAuth metadata
323315
if not self._metadata:
324-
self._metadata = await self._discover_oauth_metadata(self.server_url)
316+
self._metadata = await _discover_oauth_metadata(self.server_url)
325317

326318
# Ensure client registration
327319
client_info = await self._get_or_register_client()
@@ -335,7 +327,7 @@ async def _perform_oauth_flow(self) -> None:
335327
auth_url_base = str(self._metadata.authorization_endpoint)
336328
else:
337329
# Use fallback authorization endpoint
338-
auth_base_url = self._get_authorization_base_url(self.server_url)
330+
auth_base_url = _get_authorization_base_url(self.server_url)
339331
auth_url_base = urljoin(auth_base_url, "/authorize")
340332

341333
# Build authorization URL
@@ -386,7 +378,7 @@ async def _exchange_code_for_token(
386378
token_url = str(self._metadata.token_endpoint)
387379
else:
388380
# Use fallback token endpoint
389-
auth_base_url = self._get_authorization_base_url(self.server_url)
381+
auth_base_url = _get_authorization_base_url(self.server_url)
390382
token_url = urljoin(auth_base_url, "/token")
391383

392384
token_data = {
@@ -453,7 +445,7 @@ async def _refresh_access_token(self) -> bool:
453445
token_url = str(self._metadata.token_endpoint)
454446
else:
455447
# Use fallback token endpoint
456-
auth_base_url = self._get_authorization_base_url(self.server_url)
448+
auth_base_url = _get_authorization_base_url(self.server_url)
457449
token_url = urljoin(auth_base_url, "/token")
458450

459451
refresh_data = {
@@ -523,48 +515,19 @@ def __init__(
523515

524516
self._token_lock = anyio.Lock()
525517

526-
def _get_authorization_base_url(self, server_url: str) -> str:
527-
from urllib.parse import urlparse, urlunparse
528-
529-
parsed = urlparse(server_url)
530-
return urlunparse((parsed.scheme, parsed.netloc, "", "", "", ""))
531-
532-
async def _discover_oauth_metadata(self, server_url: str) -> OAuthMetadata | None:
533-
auth_base_url = self._get_authorization_base_url(server_url)
534-
url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server")
535-
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION}
536-
537-
async with httpx.AsyncClient() as client:
538-
try:
539-
response = await client.get(url, headers=headers)
540-
if response.status_code == 404:
541-
return None
542-
response.raise_for_status()
543-
return OAuthMetadata.model_validate(response.json())
544-
except Exception:
545-
try:
546-
response = await client.get(url)
547-
if response.status_code == 404:
548-
return None
549-
response.raise_for_status()
550-
return OAuthMetadata.model_validate(response.json())
551-
except Exception:
552-
logger.exception("Failed to discover OAuth metadata")
553-
return None
554-
555518
async def _register_oauth_client(
556519
self,
557520
server_url: str,
558521
client_metadata: OAuthClientMetadata,
559522
metadata: OAuthMetadata | None = None,
560523
) -> OAuthClientInformationFull:
561524
if not metadata:
562-
metadata = await self._discover_oauth_metadata(server_url)
525+
metadata = await _discover_oauth_metadata(server_url)
563526

564527
if metadata and metadata.registration_endpoint:
565528
registration_url = str(metadata.registration_endpoint)
566529
else:
567-
auth_base_url = self._get_authorization_base_url(server_url)
530+
auth_base_url = _get_authorization_base_url(server_url)
568531
registration_url = urljoin(auth_base_url, "/register")
569532

570533
if (
@@ -636,14 +599,14 @@ async def _get_or_register_client(self) -> OAuthClientInformationFull:
636599

637600
async def _request_token(self) -> None:
638601
if not self._metadata:
639-
self._metadata = await self._discover_oauth_metadata(self.server_url)
602+
self._metadata = await _discover_oauth_metadata(self.server_url)
640603

641604
client_info = await self._get_or_register_client()
642605

643606
if self._metadata and self._metadata.token_endpoint:
644607
token_url = str(self._metadata.token_endpoint)
645608
else:
646-
auth_base_url = self._get_authorization_base_url(self.server_url)
609+
auth_base_url = _get_authorization_base_url(self.server_url)
647610
token_url = urljoin(auth_base_url, "/token")
648611

649612
token_data = {

0 commit comments

Comments
 (0)