Skip to content

Commit 2d6c062

Browse files
committed
merge with recent branch
1 parent d04d17c commit 2d6c062

File tree

15 files changed

+55
-115
lines changed

15 files changed

+55
-115
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -859,14 +859,14 @@ async def main():
859859
# instead of OAuthClientProvider.
860860

861861
# If you already have a user token from another provider, you can
862-
# exchange it for an MCP token using the token-exchange grant
862+
# exchange it for an MCP token using the token_exchange grant
863863
# implemented by TokenExchangeProvider.
864864
token_exchange_auth = TokenExchangeProvider(
865865
server_url="https://api.example.com",
866866
client_metadata=OAuthClientMetadata(
867867
client_name="My Client",
868868
redirect_uris=["http://localhost:3000/callback"],
869-
grant_types=["token-exchange"],
869+
grant_types=["token_exchange"],
870870
response_types=["code"],
871871
),
872872
storage=CustomTokenStorage(),

docs/api.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
The Python SDK exposes the entire `mcp` package for use in your own projects.
22
It includes an OAuth server implementation with support for the RFC 8693
3-
`token-exchange` grant type.
3+
`token_exchange` grant type.
44

55
::: mcp

docs/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@ This is the MCP Server implementation in Python.
44

55
It only contains the [API Reference](api.md) for the time being.
66

7-
The built-in OAuth server supports the RFC 8693 `token-exchange` grant type,
7+
The built-in OAuth server supports the RFC 8693 `token_exchange` grant type,
88
allowing clients to exchange user tokens from external providers for MCP
99
access tokens.

examples/servers/simple-auth/mcp_simple_auth/server.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,9 +252,7 @@ async def exchange_token(
252252
"""Exchange an external token for an MCP access token."""
253253
raise NotImplementedError("Token exchange is not supported")
254254

255-
async def exchange_client_credentials(
256-
self, client: OAuthClientInformationFull, scopes: list[str]
257-
) -> OAuthToken:
255+
async def exchange_client_credentials(self, client: OAuthClientInformationFull, scopes: list[str]) -> OAuthToken:
258256
"""Exchange client credentials for an access token."""
259257
token = f"mcp_{secrets.token_hex(32)}"
260258
self.tokens[token] = AccessToken(

src/mcp/client/auth.py

Lines changed: 11 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import anyio
1818
import httpx
1919

20-
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
2120
from mcp.shared.auth import (
2221
OAuthClientInformationFull,
2322
OAuthClientMetadata,
@@ -90,9 +89,7 @@ async def _discover_oauth_metadata(server_url: str) -> OAuthMetadata | None:
9089
return None
9190
response.raise_for_status()
9291
metadata_json = response.json()
93-
logger.debug(
94-
f"OAuth metadata discovered (no MCP header): {metadata_json}"
95-
)
92+
logger.debug(f"OAuth metadata discovered (no MCP header): {metadata_json}")
9693
return OAuthMetadata.model_validate(metadata_json)
9794
except Exception:
9895
logger.exception("Failed to discover OAuth metadata")
@@ -513,16 +510,10 @@ async def _register_oauth_client(
513510
auth_base_url = _get_authorization_base_url(server_url)
514511
registration_url = urljoin(auth_base_url, "/register")
515512

516-
if (
517-
client_metadata.scope is None
518-
and metadata
519-
and metadata.scopes_supported is not None
520-
):
513+
if client_metadata.scope is None and metadata and metadata.scopes_supported is not None:
521514
client_metadata.scope = " ".join(metadata.scopes_supported)
522515

523-
registration_data = client_metadata.model_dump(
524-
by_alias=True, mode="json", exclude_none=True
525-
)
516+
registration_data = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)
526517

527518
async with httpx.AsyncClient() as client:
528519
response = await client.post(
@@ -558,9 +549,7 @@ async def _validate_token_scopes(self, token_response: OAuthToken) -> None:
558549
returned_scopes = set(token_response.scope.split())
559550
unauthorized_scopes = returned_scopes - requested_scopes
560551
if unauthorized_scopes:
561-
raise Exception(
562-
f"Server granted unauthorized scopes: {unauthorized_scopes}."
563-
)
552+
raise Exception(f"Server granted unauthorized scopes: {unauthorized_scopes}.")
564553
else:
565554
granted = set(token_response.scope.split())
566555
logger.debug(
@@ -574,9 +563,7 @@ async def initialize(self) -> None:
574563

575564
async def _get_or_register_client(self) -> OAuthClientInformationFull:
576565
if not self._client_info:
577-
self._client_info = await self._register_oauth_client(
578-
self.server_url, self.client_metadata, self._metadata
579-
)
566+
self._client_info = await self._register_oauth_client(self.server_url, self.client_metadata, self._metadata)
580567
await self.storage.set_client_info(self._client_info)
581568
return self._client_info
582569

@@ -612,9 +599,7 @@ async def _request_token(self) -> None:
612599
)
613600

614601
if response.status_code != 200:
615-
raise Exception(
616-
f"Token request failed: {response.status_code} {response.text}"
617-
)
602+
raise Exception(f"Token request failed: {response.status_code} {response.text}")
618603

619604
token_response = OAuthToken.model_validate(response.json())
620605
await self._validate_token_scopes(token_response)
@@ -633,17 +618,13 @@ async def ensure_token(self) -> None:
633618
return
634619
await self._request_token()
635620

636-
async def async_auth_flow(
637-
self, request: httpx.Request
638-
) -> AsyncGenerator[httpx.Request, httpx.Response]:
621+
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
639622
if not self._has_valid_token():
640623
await self.initialize()
641624
await self.ensure_token()
642625

643626
if self._current_tokens and self._current_tokens.access_token:
644-
request.headers["Authorization"] = (
645-
f"Bearer {self._current_tokens.access_token}"
646-
)
627+
request.headers["Authorization"] = f"Bearer {self._current_tokens.access_token}"
647628

648629
response = yield request
649630

@@ -688,12 +669,10 @@ async def _request_token(self) -> None:
688669
token_url = urljoin(auth_base_url, "/token")
689670

690671
subject_token = await self.subject_token_supplier()
691-
actor_token = (
692-
await self.actor_token_supplier() if self.actor_token_supplier else None
693-
)
672+
actor_token = await self.actor_token_supplier() if self.actor_token_supplier else None
694673

695674
token_data = {
696-
"grant_type": "token-exchange",
675+
"grant_type": "token_exchange",
697676
"client_id": client_info.client_id,
698677
"subject_token": subject_token,
699678
"subject_token_type": self.subject_token_type,
@@ -722,9 +701,7 @@ async def _request_token(self) -> None:
722701
)
723702

724703
if response.status_code != 200:
725-
raise Exception(
726-
f"Token request failed: {response.status_code} {response.text}"
727-
)
704+
raise Exception(f"Token request failed: {response.status_code} {response.text}")
728705

729706
token_response = OAuthToken.model_validate(response.json())
730707
await self._validate_token_scopes(token_response)

src/mcp/client/streamable_http.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -176,11 +176,7 @@ async def _handle_sse_event(
176176
# Call resumption token callback if we have an ID. Only update
177177
# the resumption token on notifications to avoid overwriting it
178178
# with the token from the final response.
179-
if (
180-
sse.id
181-
and resumption_callback
182-
and not isinstance(message.root, JSONRPCResponse | JSONRPCError)
183-
):
179+
if sse.id and resumption_callback and not isinstance(message.root, JSONRPCResponse | JSONRPCError):
184180
await resumption_callback(sse.id.strip())
185181

186182
# If this is a response or error return True indicating completion

src/mcp/server/auth/handlers/register.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ async def handle(self, request: Request) -> Response:
7272
valid_sets = [
7373
{"authorization_code", "refresh_token"},
7474
{"client_credentials"},
75-
{"token-exchange"},
75+
{"token_exchange"},
7676
]
7777

7878
if grant_types_set not in valid_sets:

src/mcp/server/auth/handlers/token.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,11 @@ class ClientCredentialsRequest(BaseModel):
4747
class TokenExchangeRequest(BaseModel):
4848
"""RFC 8693 token exchange request."""
4949

50-
grant_type: Literal["token-exchange"]
50+
grant_type: Literal["token_exchange"]
5151
subject_token: str = Field(..., description="Token to exchange")
5252
subject_token_type: str = Field(..., description="Type of the subject token")
5353
actor_token: str | None = Field(None, description="Optional actor token")
54-
actor_token_type: str | None = Field(
55-
None, description="Type of the actor token if provided"
56-
)
54+
actor_token_type: str | None = Field(None, description="Type of the actor token if provided")
5755
resource: str | None = None
5856
audience: str | None = None
5957
scope: str | None = None
@@ -64,19 +62,13 @@ class TokenExchangeRequest(BaseModel):
6462
class TokenRequest(
6563
RootModel[
6664
Annotated[
67-
AuthorizationCodeRequest
68-
| RefreshTokenRequest
69-
| ClientCredentialsRequest
70-
| TokenExchangeRequest,
65+
AuthorizationCodeRequest | RefreshTokenRequest | ClientCredentialsRequest | TokenExchangeRequest,
7166
Field(discriminator="grant_type"),
7267
]
7368
]
7469
):
7570
root: Annotated[
76-
AuthorizationCodeRequest
77-
| RefreshTokenRequest
78-
| ClientCredentialsRequest
79-
| TokenExchangeRequest,
71+
AuthorizationCodeRequest | RefreshTokenRequest | ClientCredentialsRequest | TokenExchangeRequest,
8072
Field(discriminator="grant_type"),
8173
]
8274

@@ -223,9 +215,7 @@ async def handle(self, request: Request):
223215
else []
224216
)
225217
try:
226-
tokens = await self.provider.exchange_client_credentials(
227-
client_info, scopes
228-
)
218+
tokens = await self.provider.exchange_client_credentials(client_info, scopes)
229219
except TokenError as e:
230220
return self.response(
231221
TokenErrorResponse(

src/mcp/server/auth/provider.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,7 @@ async def exchange_refresh_token(
239239
"""
240240
...
241241

242-
async def exchange_client_credentials(
243-
self, client: OAuthClientInformationFull, scopes: list[str]
244-
) -> OAuthToken:
242+
async def exchange_client_credentials(self, client: OAuthClientInformationFull, scopes: list[str]) -> OAuthToken:
245243
"""Exchange client credentials for an access token."""
246244
...
247245

src/mcp/server/auth/routes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def build_metadata(
163163
"authorization_code",
164164
"refresh_token",
165165
"client_credentials",
166-
"token-exchange",
166+
"token_exchange",
167167
],
168168
token_endpoint_auth_methods_supported=["client_secret_post"],
169169
token_endpoint_auth_signing_alg_values_supported=None,

src/mcp/shared/auth.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,13 @@ class OAuthClientMetadata(BaseModel):
4747
# client_secret_post;
4848
# ie: we do not support client_secret_basic
4949
token_endpoint_auth_method: Literal["none", "client_secret_post"] = "client_secret_post"
50-
# grant_types: this implementation supports authorization_code, refresh_token, client_credentials, & token-exchange
50+
# grant_types: this implementation supports authorization_code, refresh_token, client_credentials, & token_exchange
5151
grant_types: list[
5252
Literal[
5353
"authorization_code",
5454
"refresh_token",
5555
"client_credentials",
56-
"token-exchange",
56+
"token_exchange",
5757
]
5858
] = [
5959
"authorization_code",
@@ -129,14 +129,12 @@ class OAuthMetadata(BaseModel):
129129
"authorization_code",
130130
"refresh_token",
131131
"client_credentials",
132-
"token-exchange",
132+
"token_exchange",
133133
]
134134
]
135135
| None
136136
) = None
137-
token_endpoint_auth_methods_supported: (
138-
list[Literal["none", "client_secret_post"]] | None
139-
) = None
137+
token_endpoint_auth_methods_supported: list[Literal["none", "client_secret_post"]] | None = None
140138
token_endpoint_auth_signing_alg_values_supported: None = None
141139
service_documentation: AnyHttpUrl | None = None
142140
ui_locales_supported: list[str] | None = None

src/mcp/shared/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ async def _receive_loop(self) -> None:
370370
)
371371

372372
session_message = SessionMessage(message=JSONRPCMessage(error_response))
373-
373+
374374
await self._write_stream.send(session_message)
375375

376376
elif isinstance(message.message.root, JSONRPCNotification):

tests/client/test_auth.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def oauth_metadata():
9191
"authorization_code",
9292
"refresh_token",
9393
"client_credentials",
94-
"token-exchange",
94+
"token_exchange",
9595
],
9696
code_challenge_methods_supported=["S256"],
9797
)
@@ -205,13 +205,13 @@ async def test_generate_code_challenge(self, oauth_provider):
205205
async def test_get_authorization_base_url(self, oauth_provider):
206206
"""Test authorization base URL extraction."""
207207
# Test with path
208-
assert (_get_authorization_base_url("https://api.example.com/v1/mcp") == "https://api.example.com")
208+
assert _get_authorization_base_url("https://api.example.com/v1/mcp") == "https://api.example.com"
209209

210210
# Test with no path
211-
assert (_get_authorization_base_url("https://api.example.com") == "https://api.example.com")
211+
assert _get_authorization_base_url("https://api.example.com") == "https://api.example.com"
212212

213213
# Test with port
214-
assert (_get_authorization_base_url("https://api.example.com:8080/path/to/mcp") == "https://api.example.com:8080")
214+
assert _get_authorization_base_url("https://api.example.com:8080/path/to/mcp") == "https://api.example.com:8080"
215215

216216
@pytest.mark.anyio
217217
async def test_discover_oauth_metadata_success(self, oauth_provider, oauth_metadata):
@@ -930,7 +930,7 @@ def test_build_metadata(
930930
"authorization_code",
931931
"refresh_token",
932932
"client_credentials",
933-
"token-exchange",
933+
"token_exchange",
934934
],
935935
token_endpoint_auth_methods_supported=["client_secret_post"],
936936
service_documentation=AnyHttpUrl(service_documentation_url),
@@ -969,10 +969,7 @@ async def test_request_token_success(
969969
await client_credentials_provider.ensure_token()
970970

971971
mock_client.post.assert_called_once()
972-
assert (
973-
client_credentials_provider._current_tokens.access_token
974-
== oauth_token.access_token
975-
)
972+
assert client_credentials_provider._current_tokens.access_token == oauth_token.access_token
976973

977974
@pytest.mark.anyio
978975
async def test_async_auth_flow(self, client_credentials_provider, oauth_token):
@@ -985,10 +982,7 @@ async def test_async_auth_flow(self, client_credentials_provider, oauth_token):
985982

986983
auth_flow = client_credentials_provider.async_auth_flow(request)
987984
updated_request = await auth_flow.__anext__()
988-
assert (
989-
updated_request.headers["Authorization"]
990-
== f"Bearer {oauth_token.access_token}"
991-
)
985+
assert updated_request.headers["Authorization"] == f"Bearer {oauth_token.access_token}"
992986
try:
993987
await auth_flow.asend(mock_response)
994988
except StopAsyncIteration:
@@ -1022,7 +1016,4 @@ async def test_request_token_success(
10221016
await token_exchange_provider.ensure_token()
10231017

10241018
mock_client.post.assert_called_once()
1025-
assert (
1026-
token_exchange_provider._current_tokens.access_token
1027-
== oauth_token.access_token
1028-
)
1019+
assert token_exchange_provider._current_tokens.access_token == oauth_token.access_token

0 commit comments

Comments
 (0)