diff --git a/src/mcpm/core/router/app.py b/src/mcpm/core/router/app.py new file mode 100644 index 0000000..8be7ac4 --- /dev/null +++ b/src/mcpm/core/router/app.py @@ -0,0 +1,63 @@ +import asyncio +import logging +from contextlib import asynccontextmanager + +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Mount, Route +from starlette.types import Receive, Scope, Send + +from mcpm.monitor.event import monitor +from mcpm.router.router import MCPRouter + +from .middleware import SessionMiddleware +from .session import SessionManager +from .transport import SseTransport + +logger = logging.getLogger("mcpm.router") + +session_manager = SessionManager() +transport = SseTransport(endpoint="/messages/", session_manager=session_manager) + +router = MCPRouter(reload_server=False) + +class NoOpsResponse(Response): + async def __call__(self, scope: Scope, receive: Receive, send: Send): + # To comply with Starlette's ASGI application design, this method must return a response. + # Since no further client interaction is needed after server shutdown, we provide a no-operation response + # that allows the application to exit gracefully when cancelled by Uvicorn. + # No content is sent back to the client as EventSourceResponse has already returned a 200 status code. 
+ pass + + +async def handle_sse(request: Request): + try: + async with transport.connect_sse(request.scope, request.receive, request._send) as (read, write): + await router.aggregated_server.run(read, write, router.aggregated_server.initialization_options) # type: ignore + except asyncio.CancelledError: + return NoOpsResponse() + + +@asynccontextmanager +async def lifespan(app): + logger.info("Starting MCPRouter...") + await router.initialize_router() + await monitor.initialize_storage() + + yield + + logger.info("Shutting down MCPRouter...") + await router.shutdown() + await monitor.close() + +app = Starlette( + debug=True, + routes=[ + Route("/sse", endpoint=handle_sse), + Mount("/messages/", app=transport.handle_post_message) + ], + middleware=[Middleware(SessionMiddleware, session_manager=session_manager)], + lifespan=lifespan +) diff --git a/src/mcpm/core/router/extra.py b/src/mcpm/core/router/extra.py new file mode 100644 index 0000000..4368a23 --- /dev/null +++ b/src/mcpm/core/router/extra.py @@ -0,0 +1,49 @@ + +from typing import Any, Protocol + +from mcp.types import ServerResult +from starlette.requests import Request + +from mcpm.core.router.session import Session + + +class MetaRequestProcessor(Protocol): + + def process(self, request: Request, session: Session): + ... + +class MetaResponseProcessor(Protocol): + + def process(self, response: ServerResult, request_context: dict[str, Any], response_context: dict[str, Any]) -> ServerResult: + ... 
+ +class ProfileMetaRequestProcessor: + + def process(self, request: Request, session: Session): + profile = request.query_params.get("profile") + if not profile: + # fallback to headers + profile = request.headers.get("profile") + + if profile: + session["meta"]["profile"] = profile + +class ClientMetaRequestProcessor: + + def process(self, request: Request, session: Session): + client = request.query_params.get("client") + if client: + session["meta"]["client_id"] = client + + +class MCPResponseProcessor: + + def process(self, response: ServerResult, request_context: dict[str, Any], response_context: dict[str, Any]) -> ServerResult: + if not response.root.meta: + response.root.meta = {} + + response.root.meta.update({ + "request_context": request_context, + "response_context": response_context, + }) + return response diff --git a/src/mcpm/core/router/middleware.py b/src/mcpm/core/router/middleware.py new file mode 100644 index 0000000..6b7e08e --- /dev/null +++ b/src/mcpm/core/router/middleware.py @@ -0,0 +1,77 @@ +import logging +from uuid import UUID + +from starlette.requests import Request +from starlette.responses import Response +from starlette.types import ASGIApp, Receive, Scope, Send + +from .extra import ClientMetaRequestProcessor, MetaRequestProcessor, ProfileMetaRequestProcessor +from .session import SessionManager + +logger = logging.getLogger(__name__) + +class SessionMiddleware: + + def __init__( + self, + app: ASGIApp, + session_manager: SessionManager, + meta_request_processors: list[MetaRequestProcessor] = [ + ProfileMetaRequestProcessor(), + ClientMetaRequestProcessor(), + ] + ) -> None: + self.app = app + self.session_manager = session_manager + # patch meta data from request to session + self.meta_request_processors = meta_request_processors + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + # we related metadata with session through this middleware, so that in the transport layer we only need to handle + # 
session_id and dispatch message to the correct memory stream + + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + request = Request(scope) + + if scope["path"] == "/sse": + # retrieve metadata from query params or header + session = await self.session_manager.create_session() + # session.meta will identically copied to JSONRPCMessage + if self.meta_request_processors: + for processor in self.meta_request_processors: + processor.process(request, session) + + logger.debug(f"Created new session with ID: {session['id']}") + + scope["session_id"] = session["id"].hex + + if scope["path"] == "/messages/": + session_id_param = request.query_params.get("session_id") + if not session_id_param: + logger.debug("Missing session_id") + response = Response("session_id is required", status_code=400) + await response(scope, receive, send) + return + + # validate session_id + try: + session_id = UUID(hex=session_id_param) + except ValueError: + logger.warning(f"Received invalid session ID: {session_id_param}") + response = Response("invalid session ID", status_code=400) + await response(scope, receive, send) + return + + # if session_id is not in session manager, return 404 + if not self.session_manager.exist(session_id): + logger.debug(f"session {session_id} not found") + response = Response("session not found", status_code=404) + await response(scope, receive, send) + return + + scope["session_id"] = session_id.hex + + await self.app(scope, receive, send) diff --git a/src/mcpm/core/router/session.py b/src/mcpm/core/router/session.py new file mode 100644 index 0000000..d76b6f4 --- /dev/null +++ b/src/mcpm/core/router/session.py @@ -0,0 +1,95 @@ +from typing import Any, Protocol, TypedDict +from uuid import UUID, uuid4 + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from mcp import types + + +class Session(TypedDict): + id: UUID + # some read,write streams related with session + read_stream: 
MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
+    read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
+
+    write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
+    write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
+    # any meta data is saved here
+    meta: dict[str, Any]
+
+
+class SessionStore(Protocol):
+
+    def exist(self, session_id: UUID) -> bool:
+        ...
+
+    async def put(self, session: Session) -> None:
+        ...
+
+    async def get(self, session_id: UUID) -> Session:
+        ...
+
+    async def remove(self, session_id: UUID):
+        ...
+
+    async def cleanup(self):
+        ...
+
+
+class LocalSessionStore:
+
+    def __init__(self):
+        self._store: dict[UUID, Session] = {}
+
+    def exist(self, session_id: UUID) -> bool:
+        return session_id in self._store
+
+    async def put(self, session: Session) -> None:
+        self._store[session["id"]] = session
+
+    async def get(self, session_id: UUID) -> Session:
+        return self._store[session_id]
+
+    async def remove(self, session_id: UUID):
+        session = self._store.pop(session_id, None)
+        if session:
+            await session["read_stream_writer"].aclose()
+            await session["write_stream"].aclose()
+
+    async def cleanup(self):
+        keys = list(self._store.keys())
+        for session_id in keys:
+            await self.remove(session_id)
+
+
+class SessionManager:
+
+    def __init__(self):
+        self.session_store: SessionStore = LocalSessionStore()
+
+    def exist(self, session_id: UUID) -> bool:
+        return self.session_store.exist(session_id)
+
+    async def create_session(self, meta: dict[str, Any] | None = None) -> Session:
+        read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
+        write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
+        session_id = uuid4()
+        session = Session(
+            id=session_id,
+            read_stream=read_stream,
+            read_stream_writer=read_stream_writer,
+            write_stream=write_stream,
+            write_stream_reader=write_stream_reader,
+            meta=meta if meta is not None else {}  # fresh dict per session: the old `meta={}` default was a single shared dict, so meta written by one session (e.g. "profile") leaked into every later session
+        )
+        await self.session_store.put(session)
+        return session
+
+    async def 
get_session(self, session_id: UUID) -> Session: + return await self.session_store.get(session_id) + + async def close_session(self, session_id: UUID): + await self.session_store.remove(session_id) + + async def cleanup_resources(self): + await self.session_store.cleanup() diff --git a/src/mcpm/core/router/transport.py b/src/mcpm/core/router/transport.py new file mode 100644 index 0000000..6fa57b1 --- /dev/null +++ b/src/mcpm/core/router/transport.py @@ -0,0 +1,123 @@ +import asyncio +import logging +from contextlib import asynccontextmanager +from typing import Any +from urllib.parse import quote +from uuid import UUID + +import anyio +from mcp import types +from pydantic import ValidationError +from sse_starlette import EventSourceResponse +from starlette.background import BackgroundTask +from starlette.requests import Request +from starlette.responses import Response +from starlette.types import Receive, Scope, Send + +from .session import Session, SessionManager + +logger = logging.getLogger(__name__) + + +class SseTransport: + + def __init__(self, endpoint: str, session_manager: SessionManager) -> None: + self.session_manager = session_manager + self._endpoint = endpoint + + @asynccontextmanager + async def connect_sse(self, scope: Scope, receive: Receive, send: Send): + session_id_hex = scope["session_id"] + session_id: UUID = UUID(hex=session_id_hex) + session = await self.session_manager.get_session(session_id) + + session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}" + + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, Any]](0) + + async def sse_writer(): + logger.debug("Starting SSE writer") + async with sse_stream_writer, session["write_stream_reader"]: + await sse_stream_writer.send({"event": "endpoint", "data": session_uri}) + logger.debug(f"Sent endpoint event: {session_uri}") + + async for message in session["write_stream_reader"]: + logger.debug(f"Sending message via SSE: {message}") + # we should pop 
the meta field from message + if isinstance(message.root, types.JSONRPCResponse): + message.root.result.pop("_meta", None) + await sse_stream_writer.send( + { + "event": "message", + "data": message.model_dump_json(by_alias=True, exclude_none=True), + } + ) + + async with anyio.create_task_group() as tg: + async def on_client_disconnect(): + await self.session_manager.close_session(session_id) + + try: + response = EventSourceResponse( + content=sse_stream_reader, + data_sender_callable=sse_writer, + background=BackgroundTask(on_client_disconnect), + ) + logger.debug("Starting SSE response task") + tg.start_soon(response, scope, receive, send) + + logger.debug("Yielding read and write streams") + # Due to limitations with interrupting the MCP server run operation, + # this will always block here regardless of client disconnection status + yield (session["read_stream"], session["write_stream"]) + except asyncio.CancelledError as exc: + logger.warning(f"SSE connection for session {session_id} was cancelled") + tg.cancel_scope.cancel() + # raise the exception again so that to interrupt mcp server run operation + raise exc + finally: + # for server shutdown + await self.session_manager.cleanup_resources() + + async def handle_post_message(self, scope: Scope, receive: Receive, send: Send): + + session_id = scope["session_id"] + session: Session = await self.session_manager.get_session(UUID(hex=session_id)) + + request = Request(scope, receive) + body = await request.body() + + # send message to writer + writer = session["read_stream_writer"] + try: + message = types.JSONRPCMessage.model_validate_json(body) + logger.debug(f"Validated client message: {message}") + except ValidationError as err: + logger.error(f"Failed to parse message: {err}") + response = Response("Could not parse message", status_code=400) + await response(scope, receive, send) + try: + await writer.send(err) + except (BrokenPipeError, ConnectionError, OSError) as pipe_err: + logger.warning(f"Failed to 
send error due to pipe issue: {pipe_err}")
+            return
+
+        logger.debug(f"Sending message to writer: {message}")
+        response = Response("Accepted", status_code=202)
+        await response(scope, receive, send)
+
+        if session["meta"]:
+            if isinstance(message.root, types.JSONRPCRequest):
+                message.root.params = message.root.params or {}
+                message.root.params.setdefault("_meta", {}).update(session["meta"])
+
+        # add error handling, catch possible pipe errors
+        try:
+            await writer.send(message)
+        except (BrokenPipeError, ConnectionError, OSError) as e:
+            # if it's EPIPE error or other connection error, log it but don't throw an exception
+            if isinstance(e, OSError) and e.errno == 32:  # EPIPE
+                logger.warning(f"EPIPE error when sending message to session {session_id}, connection may be closing")
+            else:
+                logger.warning(f"Connection error when sending message to session {session_id}: {e}")
+            await self.session_manager.close_session(UUID(hex=session_id))  # session_id here is the hex str from scope; the store is keyed by UUID, so passing the str silently removed nothing and leaked the streams
diff --git a/src/mcpm/router/router.py b/src/mcpm/router/router.py
index 64df52d..5b4e48f 100644
--- a/src/mcpm/router/router.py
+++ b/src/mcpm/router/router.py
@@ -19,6 +19,7 @@
 from starlette.routing import Mount, Route
 from starlette.types import AppType, Lifespan
 
+from mcpm.core.router.extra import MCPResponseProcessor, MetaResponseProcessor
 from mcpm.monitor.base import AccessEventType
 from mcpm.monitor.event import trace_event
 from mcpm.profile.profile_config import ProfileConfigManager
@@ -62,6 +63,8 @@ def __init__(self, reload_server: bool = False, profile_path: str | None = None,
         self.watcher = ConfigWatcher(self.profile_manager.profile_path)
         self.strict: bool = strict
         self.error_log_manager = ServerErrorLogManager()
+        # we can just inject the processor here
+        self.meta_response_processor: MetaResponseProcessor = MCPResponseProcessor()
 
     def get_unique_servers(self) -> list[ServerConfig]:
         profiles = self.profile_manager.list_profiles()
@@ -263,7 +266,11 @@ async def list_prompts(req: types.ListPromptsRequest) -> types.ServerResult:
server_id = get_capability_server_id("prompts", server_prompt_id)
                 if server_id in active_servers:
                     prompts.append(prompt.model_copy(update={"name": server_prompt_id}))
-            return types.ServerResult(types.ListPromptsResult(prompts=prompts))
+            return self.meta_response_processor.process(
+                types.ServerResult(types.ListPromptsResult(prompts=prompts)),
+                request_context=req.params.meta.model_dump() if req.params and req.params.meta else {},  # type: ignore  # params/meta are optional on list requests; unguarded model_dump() raised AttributeError
+                response_context={}
+            )
 
         @trace_event(AccessEventType.PROMPT_EXECUTION)
         async def get_prompt(req: types.GetPromptRequest) -> types.ServerResult:
@@ -279,7 +286,11 @@ async def get_prompt(req: types.GetPromptRequest) -> types.ServerResult:
             if prompt is None:
                 return empty_result()
             result = await self.server_sessions[server_id].session.get_prompt(prompt.name, req.params.arguments)
-            return types.ServerResult(result)
+            return self.meta_response_processor.process(
+                types.ServerResult(result),
+                request_context=req.params.meta.model_dump() if req.params.meta else {},  # type: ignore  # meta is optional; guard against None
+                response_context={"server_id": server_id, "prompt": prompt.name}
+            )
 
         async def list_resources(req: types.ListResourcesRequest) -> types.ServerResult:
             resources: list[types.Resource] = []
@@ -290,7 +301,11 @@ async def list_resources(req: types.ListResourcesRequest) -> types.ServerResult:
                     continue
                 if server_id in active_servers:
                     resources.append(resource.model_copy(update={"uri": AnyUrl(server_resource_id)}))
-            return types.ServerResult(types.ListResourcesResult(resources=resources))
+            return self.meta_response_processor.process(
+                types.ServerResult(types.ListResourcesResult(resources=resources)),
+                request_context=req.params.meta.model_dump() if req.params and req.params.meta else {},  # type: ignore  # params/meta are optional on list requests
+                response_context={}
+            )
 
         async def list_resource_templates(req: types.ListResourceTemplatesRequest) -> types.ServerResult:
             resource_templates: list[types.ResourceTemplate] = []
@@ -303,7 +318,11 @@ async def list_resource_templates(req: types.ListResourceTemplatesRequest) -> types.ServerResult:
                     resource_templates.append(
                         resource_template.model_copy(update={"uriTemplate": 
server_resource_template_id})
                     )
-            return types.ServerResult(types.ListResourceTemplatesResult(resourceTemplates=resource_templates))
+            return self.meta_response_processor.process(
+                types.ServerResult(types.ListResourceTemplatesResult(resourceTemplates=resource_templates)),
+                request_context=req.params.meta.model_dump() if req.params and req.params.meta else {},  # type: ignore  # params/meta are optional on list requests; unguarded model_dump() raised AttributeError
+                response_context={}
+            )
 
         @trace_event(AccessEventType.RESOURCE_ACCESS)
         async def read_resource(req: types.ReadResourceRequest) -> types.ServerResult:
@@ -319,7 +338,11 @@ async def read_resource(req: types.ReadResourceRequest) -> types.ServerResult:
                 return empty_result()
 
             result = await self.server_sessions[server_id].session.read_resource(resource.uri)
-            return types.ServerResult(result)
+            return self.meta_response_processor.process(
+                types.ServerResult(result),
+                request_context=req.params.meta.model_dump() if req.params.meta else {},  # type: ignore  # meta is optional; guard against None
+                response_context={"server_id": server_id, "resource": str(resource.uri)}
+            )
 
         async def list_tools(req: types.ListToolsRequest) -> types.ServerResult:
             tools: list[types.Tool] = []
@@ -334,7 +357,11 @@ async def list_tools(req: types.ListToolsRequest) -> types.ServerResult:
 
             if not tools:
                 return empty_result()
-            return types.ServerResult(types.ListToolsResult(tools=tools))
+            return self.meta_response_processor.process(
+                types.ServerResult(types.ListToolsResult(tools=tools)),
+                request_context=req.params.meta.model_dump() if req.params and req.params.meta else {},  # type: ignore  # params/meta are optional on list requests
+                response_context={}
+            )
 
         @trace_event(AccessEventType.TOOL_INVOCATION)
         async def call_tool(req: types.CallToolRequest) -> types.ServerResult:
@@ -357,7 +384,11 @@ async def call_tool(req: types.CallToolRequest) -> types.ServerResult:
 
             try:
                 result = await self.server_sessions[server_id].session.call_tool(tool.name, req.params.arguments or {})
-                return types.ServerResult(result)
+                return self.meta_response_processor.process(
+                    types.ServerResult(result),
+                    request_context=req.params.meta.model_dump() if req.params.meta else {},  # type: ignore  # meta is optional; guard against None
+                    response_context={"server_id": server_id, "tool": tool.name}
+                )
except Exception as e:
                 logger.error(f"Error calling tool {tool_name} on server {server_id}: {e}")
                 return types.ServerResult(
@@ -392,8 +423,12 @@ async def complete(req: types.CompleteRequest) -> types.ServerResult:
             if server_id not in active_servers:
                 return empty_result()
 
-            result = await self.server_sessions[server_id].session.complete(ref, req.params.arguments or {})
-            return types.ServerResult(result)
+            result = await self.server_sessions[server_id].session.complete(ref, req.params.argument.model_dump() or {})
+            return self.meta_response_processor.process(
+                types.ServerResult(result),
+                request_context=req.params.meta.model_dump() if req.params.meta else {},  # type: ignore  # meta is optional on requests; unguarded model_dump() raised AttributeError
+                response_context={"server_id": server_id}
+            )
 
         app.request_handlers[types.ListPromptsRequest] = list_prompts
         app.request_handlers[types.GetPromptRequest] = get_prompt
diff --git a/src/mcpm/utils/errlog_manager.py b/src/mcpm/utils/errlog_manager.py
index f5d198f..536bb05 100644
--- a/src/mcpm/utils/errlog_manager.py
+++ b/src/mcpm/utils/errlog_manager.py
@@ -30,5 +30,6 @@ def close_errlog_file(self, server_id: str) -> None:
         del self._log_files[server_id]
 
     def close_all(self) -> None:
-        for server_id in self._log_files:
+        keys = list(self._log_files.keys())
+        for server_id in keys:
             self.close_errlog_file(server_id)