@@ -48,6 +48,44 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None
48
48
...
49
49
50
50
51
+ def _get_authorization_base_url (server_url : str ) -> str :
52
+ """Return the authorization base URL for ``server_url``.
53
+
54
+ Per MCP spec 2.3.2, the path component must be discarded so that
55
+ ``https://api.example.com/v1/mcp`` becomes ``https://api.example.com``.
56
+ """
57
+ from urllib .parse import urlparse , urlunparse
58
+
59
+ parsed = urlparse (server_url )
60
+ return urlunparse ((parsed .scheme , parsed .netloc , "" , "" , "" , "" ))
61
+
62
+
63
+ async def _discover_oauth_metadata (server_url : str ) -> OAuthMetadata | None :
64
+ """Discover OAuth metadata from the server's well-known endpoint."""
65
+
66
+ auth_base_url = _get_authorization_base_url (server_url )
67
+ url = urljoin (auth_base_url , "/.well-known/oauth-authorization-server" )
68
+ headers = {"MCP-Protocol-Version" : LATEST_PROTOCOL_VERSION }
69
+
70
+ async with httpx .AsyncClient () as client :
71
+ try :
72
+ response = await client .get (url , headers = headers )
73
+ if response .status_code == 404 :
74
+ return None
75
+ response .raise_for_status ()
76
+ return OAuthMetadata .model_validate (response .json ())
77
+ except Exception :
78
+ try :
79
+ response = await client .get (url )
80
+ if response .status_code == 404 :
81
+ return None
82
+ response .raise_for_status ()
83
+ return OAuthMetadata .model_validate (response .json ())
84
+ except Exception :
85
+ logger .exception ("Failed to discover OAuth metadata" )
86
+ return None
87
+
88
+
51
89
class OAuthClientProvider (httpx .Auth ):
52
90
"""
53
91
Authentication for httpx using anyio.
@@ -110,52 +148,6 @@ def _generate_code_challenge(self, code_verifier: str) -> str:
110
148
digest = hashlib .sha256 (code_verifier .encode ()).digest ()
111
149
return base64 .urlsafe_b64encode (digest ).decode ().rstrip ("=" )
112
150
113
- def _get_authorization_base_url (self , server_url : str ) -> str :
114
- """
115
- Extract base URL by removing path component.
116
-
117
- Per MCP spec 2.3.2: https://api.example.com/v1/mcp -> https://api.example.com
118
- """
119
- from urllib .parse import urlparse , urlunparse
120
-
121
- parsed = urlparse (server_url )
122
- # Remove path component
123
- return urlunparse ((parsed .scheme , parsed .netloc , "" , "" , "" , "" ))
124
-
125
- async def _discover_oauth_metadata (self , server_url : str ) -> OAuthMetadata | None :
126
- """
127
- Discover OAuth metadata from server's well-known endpoint.
128
- """
129
- # Extract base URL per MCP spec
130
- auth_base_url = self ._get_authorization_base_url (server_url )
131
- url = urljoin (auth_base_url , "/.well-known/oauth-authorization-server" )
132
- headers = {"MCP-Protocol-Version" : LATEST_PROTOCOL_VERSION }
133
-
134
- async with httpx .AsyncClient () as client :
135
- try :
136
- response = await client .get (url , headers = headers )
137
- if response .status_code == 404 :
138
- return None
139
- response .raise_for_status ()
140
- metadata_json = response .json ()
141
- logger .debug (f"OAuth metadata discovered: { metadata_json } " )
142
- return OAuthMetadata .model_validate (metadata_json )
143
- except Exception :
144
- # Retry without MCP header for CORS compatibility
145
- try :
146
- response = await client .get (url )
147
- if response .status_code == 404 :
148
- return None
149
- response .raise_for_status ()
150
- metadata_json = response .json ()
151
- logger .debug (
152
- f"OAuth metadata discovered (no MCP header): { metadata_json } "
153
- )
154
- return OAuthMetadata .model_validate (metadata_json )
155
- except Exception :
156
- logger .exception ("Failed to discover OAuth metadata" )
157
- return None
158
-
159
151
async def _register_oauth_client (
160
152
self ,
161
153
server_url : str ,
@@ -166,13 +158,13 @@ async def _register_oauth_client(
166
158
Register OAuth client with server.
167
159
"""
168
160
if not metadata :
169
- metadata = await self . _discover_oauth_metadata (server_url )
161
+ metadata = await _discover_oauth_metadata (server_url )
170
162
171
163
if metadata and metadata .registration_endpoint :
172
164
registration_url = str (metadata .registration_endpoint )
173
165
else :
174
166
# Use fallback registration endpoint
175
- auth_base_url = self . _get_authorization_base_url (server_url )
167
+ auth_base_url = _get_authorization_base_url (server_url )
176
168
registration_url = urljoin (auth_base_url , "/register" )
177
169
178
170
# Handle default scope
@@ -321,7 +313,7 @@ async def _perform_oauth_flow(self) -> None:
321
313
322
314
# Discover OAuth metadata
323
315
if not self ._metadata :
324
- self ._metadata = await self . _discover_oauth_metadata (self .server_url )
316
+ self ._metadata = await _discover_oauth_metadata (self .server_url )
325
317
326
318
# Ensure client registration
327
319
client_info = await self ._get_or_register_client ()
@@ -335,7 +327,7 @@ async def _perform_oauth_flow(self) -> None:
335
327
auth_url_base = str (self ._metadata .authorization_endpoint )
336
328
else :
337
329
# Use fallback authorization endpoint
338
- auth_base_url = self . _get_authorization_base_url (self .server_url )
330
+ auth_base_url = _get_authorization_base_url (self .server_url )
339
331
auth_url_base = urljoin (auth_base_url , "/authorize" )
340
332
341
333
# Build authorization URL
@@ -386,7 +378,7 @@ async def _exchange_code_for_token(
386
378
token_url = str (self ._metadata .token_endpoint )
387
379
else :
388
380
# Use fallback token endpoint
389
- auth_base_url = self . _get_authorization_base_url (self .server_url )
381
+ auth_base_url = _get_authorization_base_url (self .server_url )
390
382
token_url = urljoin (auth_base_url , "/token" )
391
383
392
384
token_data = {
@@ -453,7 +445,7 @@ async def _refresh_access_token(self) -> bool:
453
445
token_url = str (self ._metadata .token_endpoint )
454
446
else :
455
447
# Use fallback token endpoint
456
- auth_base_url = self . _get_authorization_base_url (self .server_url )
448
+ auth_base_url = _get_authorization_base_url (self .server_url )
457
449
token_url = urljoin (auth_base_url , "/token" )
458
450
459
451
refresh_data = {
@@ -523,48 +515,19 @@ def __init__(
523
515
524
516
self ._token_lock = anyio .Lock ()
525
517
526
- def _get_authorization_base_url (self , server_url : str ) -> str :
527
- from urllib .parse import urlparse , urlunparse
528
-
529
- parsed = urlparse (server_url )
530
- return urlunparse ((parsed .scheme , parsed .netloc , "" , "" , "" , "" ))
531
-
532
- async def _discover_oauth_metadata (self , server_url : str ) -> OAuthMetadata | None :
533
- auth_base_url = self ._get_authorization_base_url (server_url )
534
- url = urljoin (auth_base_url , "/.well-known/oauth-authorization-server" )
535
- headers = {"MCP-Protocol-Version" : LATEST_PROTOCOL_VERSION }
536
-
537
- async with httpx .AsyncClient () as client :
538
- try :
539
- response = await client .get (url , headers = headers )
540
- if response .status_code == 404 :
541
- return None
542
- response .raise_for_status ()
543
- return OAuthMetadata .model_validate (response .json ())
544
- except Exception :
545
- try :
546
- response = await client .get (url )
547
- if response .status_code == 404 :
548
- return None
549
- response .raise_for_status ()
550
- return OAuthMetadata .model_validate (response .json ())
551
- except Exception :
552
- logger .exception ("Failed to discover OAuth metadata" )
553
- return None
554
-
555
518
async def _register_oauth_client (
556
519
self ,
557
520
server_url : str ,
558
521
client_metadata : OAuthClientMetadata ,
559
522
metadata : OAuthMetadata | None = None ,
560
523
) -> OAuthClientInformationFull :
561
524
if not metadata :
562
- metadata = await self . _discover_oauth_metadata (server_url )
525
+ metadata = await _discover_oauth_metadata (server_url )
563
526
564
527
if metadata and metadata .registration_endpoint :
565
528
registration_url = str (metadata .registration_endpoint )
566
529
else :
567
- auth_base_url = self . _get_authorization_base_url (server_url )
530
+ auth_base_url = _get_authorization_base_url (server_url )
568
531
registration_url = urljoin (auth_base_url , "/register" )
569
532
570
533
if (
@@ -636,14 +599,14 @@ async def _get_or_register_client(self) -> OAuthClientInformationFull:
636
599
637
600
async def _request_token (self ) -> None :
638
601
if not self ._metadata :
639
- self ._metadata = await self . _discover_oauth_metadata (self .server_url )
602
+ self ._metadata = await _discover_oauth_metadata (self .server_url )
640
603
641
604
client_info = await self ._get_or_register_client ()
642
605
643
606
if self ._metadata and self ._metadata .token_endpoint :
644
607
token_url = str (self ._metadata .token_endpoint )
645
608
else :
646
- auth_base_url = self . _get_authorization_base_url (self .server_url )
609
+ auth_base_url = _get_authorization_base_url (self .server_url )
647
610
token_url = urljoin (auth_base_url , "/token" )
648
611
649
612
token_data = {
0 commit comments