Skip to content

Added mcp tool filtering and unit test #854

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions src/agents/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C
class _MCPServerWithClientSession(MCPServer, abc.ABC):
"""Base class for MCP servers that use a `ClientSession` to communicate with the server."""

def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float | None):
def __init__(
self, cache_tools_list: bool, client_session_timeout_seconds: float | None,
allowed_tools: list[str] | None = None
):
"""
Args:
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
Expand All @@ -68,6 +71,8 @@ def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float
(by avoiding a round-trip to the server every time).

client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.

allowed_tools: the names of the tools from the server that can be accessed.
"""
self.session: ClientSession | None = None
self.exit_stack: AsyncExitStack = AsyncExitStack()
Expand All @@ -80,6 +85,7 @@ def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float
# The cache is always dirty at startup, so that we fetch tools at least once
self._cache_dirty = True
self._tools_list: list[MCPTool] | None = None
self.allowed_tools = allowed_tools

@abc.abstractmethod
def create_streams(
Expand Down Expand Up @@ -145,6 +151,13 @@ async def list_tools(self) -> list[MCPTool]:

# Fetch the tools from the server
self._tools_list = (await self.session.list_tools()).tools
# Filter out tools that should not be available
if self.allowed_tools:
self._tools_list = [
tool for tool in self._tools_list
if tool.name in self.allowed_tools
]

return self._tools_list

async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> CallToolResult:
Expand Down Expand Up @@ -206,6 +219,7 @@ def __init__(
cache_tools_list: bool = False,
name: str | None = None,
client_session_timeout_seconds: float | None = 5,
allowed_tools: list[str] | None = None
):
"""Create a new MCP server based on the stdio transport.

Expand All @@ -214,17 +228,23 @@ def __init__(
start the server, the args to pass to the command, the environment variables to
set for the server, the working directory to use when spawning the process, and
the text encoding used when sending/receiving messages to the server.

cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
cached and only fetched from the server once. If `False`, the tools list will be
fetched from the server on each call to `list_tools()`. The cache can be
invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
if you know the server will not change its tools list, because it can drastically
improve latency (by avoiding a round-trip to the server every time).

name: A readable name for the server. If not provided, we'll create one from the
command.

client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.

allowed_tools: A list of tool names provided by the server that the client is
permitted to access.
"""
super().__init__(cache_tools_list, client_session_timeout_seconds)
super().__init__(cache_tools_list, client_session_timeout_seconds, allowed_tools)

self.params = StdioServerParameters(
command=params["command"],
Expand Down Expand Up @@ -283,6 +303,7 @@ def __init__(
cache_tools_list: bool = False,
name: str | None = None,
client_session_timeout_seconds: float | None = 5,
allowed_tools: list[str] | None = None
):
"""Create a new MCP server based on the HTTP with SSE transport.

Expand All @@ -302,8 +323,11 @@ def __init__(
URL.

client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.

allowed_tools: A list of tool names provided by the server that the client is
permitted to access.
"""
super().__init__(cache_tools_list, client_session_timeout_seconds)
super().__init__(cache_tools_list, client_session_timeout_seconds, allowed_tools)

self.params = params
self._name = name or f"sse: {self.params['url']}"
Expand Down Expand Up @@ -362,6 +386,7 @@ def __init__(
cache_tools_list: bool = False,
name: str | None = None,
client_session_timeout_seconds: float | None = 5,
allowed_tools: list[str] | None = None
):
"""Create a new MCP server based on the Streamable HTTP transport.

Expand All @@ -382,8 +407,11 @@ def __init__(
URL.

client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.

allowed_tools: A list of tool names provided by the server that the client is
permitted to access.
"""
super().__init__(cache_tools_list, client_session_timeout_seconds)
super().__init__(cache_tools_list, client_session_timeout_seconds, allowed_tools)

self.params = params
self._name = name or f"streamable_http: {self.params['url']}"
Expand Down
56 changes: 56 additions & 0 deletions tests/mcp/test_allowed_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from unittest.mock import AsyncMock, patch

import pytest
from mcp.types import ListToolsResult, Tool as MCPTool

from agents.mcp import MCPServerStdio

from .helpers import DummyStreamsContextManager, tee


@pytest.mark.asyncio
@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager())
@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None)
@patch("mcp.client.session.ClientSession.list_tools")
async def test_server_allowed_tools(
    mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client
):
    """Verify that `allowed_tools` restricts what `list_tools()` returns.

    The mocked session advertises four tools while the server is configured to allow
    only two of them. Every call to `list_tools()` -- the first call, a repeated call
    with `cache_tools_list=True`, and a call made after `invalidate_tools_cache()` --
    must return exactly the allowed tools.
    """
    permitted = ["tool1", "tool3"]
    expected_names = set(permitted)

    server = MCPServerStdio(
        params={"command": tee},
        cache_tools_list=True,
        allowed_tools=permitted,
    )

    # The patched ClientSession.list_tools reports more tools than the server allows.
    advertised = [
        MCPTool(name=tool_name, inputSchema={})
        for tool_name in ("tool1", "tool2", "tool3", "tool4")
    ]
    mock_list_tools.return_value = ListToolsResult(tools=advertised)

    def assert_only_permitted(listed) -> None:
        # Both the number of tools and their names must match the allow-list.
        assert len(listed) == len(permitted)
        assert {tool.name for tool in listed} == expected_names

    async with server:
        # First call to list_tools().
        assert_only_permitted(await server.list_tools())

        # Call list_tools() again, should use cached filtered results.
        assert_only_permitted(await server.list_tools())

        # Invalidate the cache and verify filtering still works.
        server.invalidate_tools_cache()
        assert_only_permitted(await server.list_tools())