diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index b012a0968..24fd09bcb 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -57,7 +57,10 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C class _MCPServerWithClientSession(MCPServer, abc.ABC): """Base class for MCP servers that use a `ClientSession` to communicate with the server.""" - def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float | None): + def __init__( + self, cache_tools_list: bool, client_session_timeout_seconds: float | None, + allowed_tools: list[str] | None = None + ): """ Args: cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be @@ -68,6 +71,8 @@ def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float (by avoiding a round-trip to the server every time). client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. + + allowed_tools: the names of the tools from the server that can be accessed. """ self.session: ClientSession | None = None self.exit_stack: AsyncExitStack = AsyncExitStack() @@ -80,6 +85,7 @@ def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float # The cache is always dirty at startup, so that we fetch tools at least once self._cache_dirty = True self._tools_list: list[MCPTool] | None = None + self.allowed_tools = allowed_tools @abc.abstractmethod def create_streams( @@ -145,6 +151,13 @@ async def list_tools(self) -> list[MCPTool]: # Fetch the tools from the server self._tools_list = (await self.session.list_tools()).tools + # Filter out tools that should not be available + if self.allowed_tools: + self._tools_list = [ + tool for tool in self._tools_list + if tool.name in self.allowed_tools + ] + return self._tools_list async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult: @@ -206,6 +219,7 @@ def __init__( cache_tools_list: bool = False, name: str | None = None, client_session_timeout_seconds: float | None = 5, + allowed_tools: list[str] | None = None ): """Create a new MCP server based on the stdio transport. @@ -214,17 +228,23 @@ def __init__( start the server, the args to pass to the command, the environment variables to set for the server, the working directory to use when spawning the process, and the text encoding used when sending/receiving messages to the server. + cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be cached and only fetched from the server once. If `False`, the tools list will be fetched from the server on each call to `list_tools()`. The cache can be invalidated by calling `invalidate_tools_cache()`. You should set this to `True` if you know the server will not change its tools list, because it can drastically improve latency (by avoiding a round-trip to the server every time). + name: A readable name for the server. If not provided, we'll create one from the command. + client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. + + allowed_tools: A list of tool names provided by the server that the client is + permitted to access. """ - super().__init__(cache_tools_list, client_session_timeout_seconds) + super().__init__(cache_tools_list, client_session_timeout_seconds, allowed_tools) self.params = StdioServerParameters( command=params["command"], @@ -283,6 +303,7 @@ def __init__( cache_tools_list: bool = False, name: str | None = None, client_session_timeout_seconds: float | None = 5, + allowed_tools: list[str] | None = None ): """Create a new MCP server based on the HTTP with SSE transport. @@ -302,8 +323,11 @@ def __init__( URL. client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. + + allowed_tools: A list of tool names provided by the server that the client is + permitted to access. """ - super().__init__(cache_tools_list, client_session_timeout_seconds) + super().__init__(cache_tools_list, client_session_timeout_seconds, allowed_tools) self.params = params self._name = name or f"sse: {self.params['url']}" @@ -362,6 +386,7 @@ def __init__( cache_tools_list: bool = False, name: str | None = None, client_session_timeout_seconds: float | None = 5, + allowed_tools: list[str] | None = None ): """Create a new MCP server based on the Streamable HTTP transport. @@ -382,8 +407,11 @@ def __init__( URL. client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. + + allowed_tools: A list of tool names provided by the server that the client is + permitted to access. """ - super().__init__(cache_tools_list, client_session_timeout_seconds) + super().__init__(cache_tools_list, client_session_timeout_seconds, allowed_tools) self.params = params self._name = name or f"streamable_http: {self.params['url']}" diff --git a/tests/mcp/test_allowed_tools.py b/tests/mcp/test_allowed_tools.py new file mode 100644 index 000000000..1c0bb3b63 --- /dev/null +++ b/tests/mcp/test_allowed_tools.py @@ -0,0 +1,56 @@ +from unittest.mock import AsyncMock, patch + +import pytest +from mcp.types import ListToolsResult, Tool as MCPTool + +from agents.mcp import MCPServerStdio + +from .helpers import DummyStreamsContextManager, tee + + +@pytest.mark.asyncio +@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager()) +@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None) +@patch("mcp.client.session.ClientSession.list_tools") +async def test_server_allowed_tools( + mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client +): + """Test that if we specified allowed tools, the list of tools is reduced and contains only + the allowed ones on each call to `list_tools()`. + """ + allowed_tools = ["tool1", "tool3"] + server = MCPServerStdio( + params={ + "command": tee, + }, + cache_tools_list=True, + allowed_tools=allowed_tools + ) + + all_tools = [ + MCPTool(name="tool1", inputSchema={}), + MCPTool(name="tool2", inputSchema={}), + MCPTool(name="tool3", inputSchema={}), + MCPTool(name="tool4", inputSchema={}), + ] + + mock_list_tools.return_value = ListToolsResult(tools=all_tools) + + async with server: + tools = await server.list_tools() + + # Check it returns only the number of allowed tools + assert len(tools) == len(allowed_tools) + # Check it returns exactly only the allowed tools + assert {tool.name for tool in tools} == set(allowed_tools) + + # Call list_tools() again, should use cached filtered results + tools = await server.list_tools() + assert len(tools) == len(allowed_tools) + assert {tool.name for tool in tools} == set(allowed_tools) + + # Invalidate cache and verify filtering still works + server.invalidate_tools_cache() + tools = await server.list_tools() + assert len(tools) == len(allowed_tools) + assert {tool.name for tool in tools} == set(allowed_tools)