Skip to content

Commit 1c9c101

Browse files
committed
[WIP] Tool guardrails
1 parent a81601a commit 1c9c101

File tree

5 files changed

+402
-17
lines changed

5 files changed

+402
-17
lines changed

src/agents/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,15 @@
3333
output_guardrail,
3434
)
3535
from .handoffs import Handoff, HandoffInputData, HandoffInputFilter, handoff
36+
from .tool_guardrails import (
37+
ToolGuardrailFunctionOutput,
38+
ToolInputGuardrail,
39+
ToolInputGuardrailData,
40+
ToolOutputGuardrail,
41+
ToolOutputGuardrailData,
42+
tool_input_guardrail,
43+
tool_output_guardrail,
44+
)
3645
from .items import (
3746
HandoffCallItem,
3847
HandoffOutputItem,
@@ -204,6 +213,13 @@ def enable_verbose_stdout_logging():
204213
"GuardrailFunctionOutput",
205214
"input_guardrail",
206215
"output_guardrail",
216+
"ToolInputGuardrail",
217+
"ToolOutputGuardrail",
218+
"ToolGuardrailFunctionOutput",
219+
"ToolInputGuardrailData",
220+
"ToolOutputGuardrailData",
221+
"tool_input_guardrail",
222+
"tool_output_guardrail",
207223
"handoff",
208224
"Handoff",
209225
"HandoffInputData",

src/agents/_run_impl.py

Lines changed: 61 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@
8080
Tool,
8181
)
8282
from .tool_context import ToolContext
83+
from .tool_guardrails import (
84+
ToolInputGuardrailData,
85+
ToolOutputGuardrailData,
86+
)
8387
from .tracing import (
8488
SpanError,
8589
Trace,
@@ -556,24 +560,64 @@ async def run_single_tool(
556560
if config.trace_include_sensitive_data:
557561
span_fn.span_data.input = tool_call.arguments
558562
try:
559-
_, _, result = await asyncio.gather(
560-
hooks.on_tool_start(tool_context, agent, func_tool),
561-
(
562-
agent.hooks.on_tool_start(tool_context, agent, func_tool)
563-
if agent.hooks
564-
else _coro.noop_coroutine()
565-
),
566-
func_tool.on_invoke_tool(tool_context, tool_call.arguments),
567-
)
563+
# 1) Run input tool guardrails, if any
564+
final_result: Any | None = None
565+
if func_tool.tool_input_guardrails:
566+
for guardrail in func_tool.tool_input_guardrails:
567+
gr_out = await guardrail.run(
568+
ToolInputGuardrailData(
569+
context=tool_context,
570+
agent=agent,
571+
tool_call=tool_call,
572+
)
573+
)
574+
if gr_out.tripwire_triggered:
575+
# Use the provided model message as the tool output
576+
final_result = str(gr_out.model_message or "")
577+
break
578+
579+
if final_result is None:
580+
# 2) Actually run the tool
581+
await asyncio.gather(
582+
hooks.on_tool_start(tool_context, agent, func_tool),
583+
(
584+
agent.hooks.on_tool_start(tool_context, agent, func_tool)
585+
if agent.hooks
586+
else _coro.noop_coroutine()
587+
),
588+
)
589+
real_result = await func_tool.on_invoke_tool(
590+
tool_context, tool_call.arguments
591+
)
568592

569-
await asyncio.gather(
570-
hooks.on_tool_end(tool_context, agent, func_tool, result),
571-
(
572-
agent.hooks.on_tool_end(tool_context, agent, func_tool, result)
573-
if agent.hooks
574-
else _coro.noop_coroutine()
575-
),
576-
)
593+
# 3) Run output tool guardrails, if any
594+
final_result = real_result
595+
if func_tool.tool_output_guardrails:
596+
for guardrail in func_tool.tool_output_guardrails:
597+
gr_out = await guardrail.run(
598+
ToolOutputGuardrailData(
599+
context=tool_context,
600+
agent=agent,
601+
tool_call=tool_call,
602+
output=real_result,
603+
)
604+
)
605+
if gr_out.tripwire_triggered:
606+
final_result = str(gr_out.model_message or "")
607+
break
608+
609+
# 4) Tool end hooks (with final result, which may have been overridden)
610+
await asyncio.gather(
611+
hooks.on_tool_end(tool_context, agent, func_tool, final_result),
612+
(
613+
agent.hooks.on_tool_end(
614+
tool_context, agent, func_tool, final_result
615+
)
616+
if agent.hooks
617+
else _coro.noop_coroutine()
618+
),
619+
)
620+
result = final_result
577621
except Exception as e:
578622
_error_tracing.attach_error_to_current_span(
579623
SpanError(

src/agents/tool.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,13 @@ class FunctionTool:
9393
and returns whether the tool is enabled. You can use this to dynamically enable/disable a tool
9494
based on your context/state."""
9595

96+
# Tool-specific guardrails
97+
tool_input_guardrails: list["ToolInputGuardrail[Any]"] | None = None
98+
"""Optional list of input guardrails to run before invoking this tool."""
99+
100+
tool_output_guardrails: list["ToolOutputGuardrail[Any]"] | None = None
101+
"""Optional list of output guardrails to run after invoking this tool."""
102+
96103
def __post_init__(self):
97104
if self.strict_json_schema:
98105
self.params_json_schema = ensure_strict_json_schema(self.params_json_schema)

src/agents/tool_guardrails.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
from __future__ import annotations
2+
3+
import inspect
4+
from collections.abc import Awaitable
5+
from dataclasses import dataclass
6+
from typing import Any, Callable, Generic, Optional, overload
7+
8+
from typing_extensions import TypeVar
9+
10+
from .agent import Agent
11+
from .tool_context import ToolContext
12+
from .util._types import MaybeAwaitable
13+
from openai.types.responses import ResponseFunctionToolCall
14+
15+
16+
@dataclass
class ToolGuardrailFunctionOutput:
    """The result returned by a tool guardrail function."""

    output_info: Any
    """Free-form data about the checks the guardrail performed. The field is required; pass
    `None` when there is nothing to report."""

    tripwire_triggered: bool
    """Whether the guardrail was tripped."""

    model_message: str | None = None
    """Message to send back to the model as the tool output if the guardrail was tripped."""
28+
29+
30+
@dataclass
class ToolInputGuardrailData:
    """Bundle of values handed to a tool input guardrail function."""

    context: ToolContext[Any]
    """The tool context associated with the call being checked."""

    agent: Agent[Any]
    """The agent associated with the tool call."""

    tool_call: ResponseFunctionToolCall
    """The function tool call under inspection."""
37+
38+
39+
@dataclass
class ToolOutputGuardrailData(ToolInputGuardrailData):
    """Bundle of values handed to a tool output guardrail function.

    Carries every field of `ToolInputGuardrailData` plus the value the tool produced.
    """

    output: Any
    """The tool's output."""
47+
48+
49+
TContext_co = TypeVar("TContext_co", bound=Any, covariant=True)


@dataclass
class ToolInputGuardrail(Generic[TContext_co]):
    """A guardrail that runs before a function tool is invoked."""

    guardrail_function: Callable[
        [ToolInputGuardrailData], MaybeAwaitable[ToolGuardrailFunctionOutput]
    ]
    """Sync or async callable that inspects the pending tool call and returns a
    `ToolGuardrailFunctionOutput`."""

    name: str | None = None
    """Optional display name. When unset, the name is derived from `guardrail_function`."""

    def get_name(self) -> str:
        """Return the explicit name, or a name derived from the guardrail callable."""
        if self.name:
            return self.name
        fn = self.guardrail_function
        # Not every callable has `__name__` (e.g. `functools.partial` objects or instances
        # defining `__call__`); fall back to the type name instead of raising AttributeError.
        return getattr(fn, "__name__", None) or type(fn).__name__

    async def run(self, data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput:
        """Call the guardrail function with `data`, awaiting the result if it is awaitable."""
        result = self.guardrail_function(data)
        if inspect.isawaitable(result):
            return await result  # type: ignore[return-value]
        return result  # type: ignore[return-value]
69+
70+
71+
@dataclass
class ToolOutputGuardrail(Generic[TContext_co]):
    """A guardrail that runs after a function tool is invoked."""

    guardrail_function: Callable[
        [ToolOutputGuardrailData], MaybeAwaitable[ToolGuardrailFunctionOutput]
    ]
    """Sync or async callable that inspects the tool call and its output and returns a
    `ToolGuardrailFunctionOutput`."""

    name: str | None = None
    """Optional display name. When unset, the name is derived from `guardrail_function`."""

    def get_name(self) -> str:
        """Return the explicit name, or a name derived from the guardrail callable."""
        if self.name:
            return self.name
        fn = self.guardrail_function
        # Not every callable has `__name__` (e.g. `functools.partial` objects or instances
        # defining `__call__`); fall back to the type name instead of raising AttributeError.
        return getattr(fn, "__name__", None) or type(fn).__name__

    async def run(self, data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput:
        """Call the guardrail function with `data`, awaiting the result if it is awaitable."""
        result = self.guardrail_function(data)
        if inspect.isawaitable(result):
            return await result  # type: ignore[return-value]
        return result  # type: ignore[return-value]
88+
89+
90+
# Decorators
_ToolInputFuncSync = Callable[[ToolInputGuardrailData], ToolGuardrailFunctionOutput]
_ToolInputFuncAsync = Callable[[ToolInputGuardrailData], Awaitable[ToolGuardrailFunctionOutput]]


@overload
def tool_input_guardrail(func: _ToolInputFuncSync) -> ToolInputGuardrail[Any]: ...


@overload
def tool_input_guardrail(func: _ToolInputFuncAsync) -> ToolInputGuardrail[Any]: ...


@overload
def tool_input_guardrail(
    *, name: str | None = None
) -> Callable[[_ToolInputFuncSync | _ToolInputFuncAsync], ToolInputGuardrail[Any]]: ...


def tool_input_guardrail(
    func: _ToolInputFuncSync | _ToolInputFuncAsync | None = None,
    *,
    name: str | None = None,
) -> (
    ToolInputGuardrail[Any]
    | Callable[[_ToolInputFuncSync | _ToolInputFuncAsync], ToolInputGuardrail[Any]]
):
    """Decorator to create a ToolInputGuardrail from a function.

    Works both bare and with keyword arguments:

        @tool_input_guardrail
        def check(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: ...

        @tool_input_guardrail(name="my_check")
        async def check(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: ...

    Args:
        func: The sync or async guardrail function when used as a bare decorator.
        name: Optional guardrail name; defaults to the function's name.

    Returns:
        A `ToolInputGuardrail` when `func` is given, otherwise a decorator that builds one.
    """

    def decorator(f: _ToolInputFuncSync | _ToolInputFuncAsync) -> ToolInputGuardrail[Any]:
        # Callables such as `functools.partial` objects have no `__name__`; fall back to
        # the type name so wrapping them does not raise AttributeError.
        resolved_name = name or getattr(f, "__name__", None) or type(f).__name__
        return ToolInputGuardrail(guardrail_function=f, name=resolved_name)

    if func is not None:
        return decorator(func)
    return decorator
126+
127+
128+
_ToolOutputFuncSync = Callable[[ToolOutputGuardrailData], ToolGuardrailFunctionOutput]
_ToolOutputFuncAsync = Callable[[ToolOutputGuardrailData], Awaitable[ToolGuardrailFunctionOutput]]


@overload
def tool_output_guardrail(func: _ToolOutputFuncSync) -> ToolOutputGuardrail[Any]: ...


@overload
def tool_output_guardrail(func: _ToolOutputFuncAsync) -> ToolOutputGuardrail[Any]: ...


@overload
def tool_output_guardrail(
    *, name: str | None = None
) -> Callable[[_ToolOutputFuncSync | _ToolOutputFuncAsync], ToolOutputGuardrail[Any]]: ...


def tool_output_guardrail(
    func: _ToolOutputFuncSync | _ToolOutputFuncAsync | None = None,
    *,
    name: str | None = None,
) -> (
    ToolOutputGuardrail[Any]
    | Callable[[_ToolOutputFuncSync | _ToolOutputFuncAsync], ToolOutputGuardrail[Any]]
):
    """Decorator to create a ToolOutputGuardrail from a function.

    Works both bare and with keyword arguments:

        @tool_output_guardrail
        def check(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: ...

        @tool_output_guardrail(name="my_check")
        async def check(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: ...

    Args:
        func: The sync or async guardrail function when used as a bare decorator.
        name: Optional guardrail name; defaults to the function's name.

    Returns:
        A `ToolOutputGuardrail` when `func` is given, otherwise a decorator that builds one.
    """

    def decorator(f: _ToolOutputFuncSync | _ToolOutputFuncAsync) -> ToolOutputGuardrail[Any]:
        # Callables such as `functools.partial` objects have no `__name__`; fall back to
        # the type name so wrapping them does not raise AttributeError.
        resolved_name = name or getattr(f, "__name__", None) or type(f).__name__
        return ToolOutputGuardrail(guardrail_function=f, name=resolved_name)

    if func is not None:
        return decorator(func)
    return decorator
163+

0 commit comments

Comments
 (0)