diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py
index 3381247c..683cb52f 100644
--- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py
+++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py
@@ -44,14 +44,16 @@ class SlidingWindowConversationManager(ConversationManager):
     invalid window states.
     """
 
-    def __init__(self, window_size: int = 40):
+    def __init__(self, window_size: int = 40, should_truncate_results: bool = True):
         """Initialize the sliding window conversation manager.
 
         Args:
            window_size: Maximum number of messages to keep in the agent's history.
                Defaults to 40 messages.
+            should_truncate_results: Truncate tool results when a message is too large for the model's context window.
         """
         self.window_size = window_size
+        self.should_truncate_results = should_truncate_results
 
     def apply_management(self, agent: "Agent") -> None:
         """Apply the sliding window to the agent's messages array to maintain a manageable history size.
@@ -127,6 +129,19 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None:
                 converted.
         """
         messages = agent.messages
+
+        # Try to truncate the tool result first
+        last_message_idx_with_tool_results = self._find_last_message_with_tool_results(messages)
+        if last_message_idx_with_tool_results is not None and self.should_truncate_results:
+            logger.debug(
+                "message_index=<%s> | found message with tool results at index", last_message_idx_with_tool_results
+            )
+            results_truncated = self._truncate_tool_results(messages, last_message_idx_with_tool_results)
+            if results_truncated:
+                logger.debug("message_index=<%s> | tool results truncated", last_message_idx_with_tool_results)
+                return
+
+        # Fall back to trimming the message history when the tool result can no longer be truncated
         # If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size
         trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size
 
@@ -151,3 +166,69 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None:
 
         # Overwrite message history
         messages[:] = messages[trim_index:]
+
+    def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool:
+        """Truncate tool results in a message to reduce context size.
+
+        When a message contains tool results that are too large for the model's context window, this function
+        replaces the content of those tool results with a simple error message.
+
+        Args:
+            messages: The conversation message history.
+            msg_idx: Index of the message containing tool results to truncate.
+
+        Returns:
+            True if any changes were made to the message, False otherwise.
+        """
+        if msg_idx >= len(messages) or msg_idx < 0:
+            return False
+
+        message = messages[msg_idx]
+        changes_made = False
+        tool_result_too_large_message = "The tool result was too large!"
+        for i, content in enumerate(message.get("content", [])):
+            if isinstance(content, dict) and "toolResult" in content:
+                tool_result_content_text = next(
+                    (item["text"] for item in content["toolResult"]["content"] if "text" in item),
+                    "",
+                )
+                # Skip the overwrite if this tool result has already been truncated
+                if (
+                    message["content"][i]["toolResult"]["status"] == "error"
+                    and tool_result_content_text == tool_result_too_large_message
+                ):
+                    logger.info("ToolResult has already been updated, skipping overwrite")
+                    return False
+                # Update status to error with informative message
+                message["content"][i]["toolResult"]["status"] = "error"
+                message["content"][i]["toolResult"]["content"] = [{"text": tool_result_too_large_message}]
+                changes_made = True
+
+        return changes_made
+
+    def _find_last_message_with_tool_results(self, messages: Messages) -> Optional[int]:
+        """Find the index of the last message containing tool results.
+
+        This is useful for identifying messages that might need to be truncated to reduce context size.
+
+        Args:
+            messages: The conversation message history.
+
+        Returns:
+            Index of the last message with tool results, or None if no such message exists.
+        """
+        # Iterate backwards through all messages (from newest to oldest)
+        for idx in range(len(messages) - 1, -1, -1):
+            # Check if this message has any content with toolResult
+            current_message = messages[idx]
+            has_tool_result = False
+
+            for content in current_message.get("content", []):
+                if isinstance(content, dict) and "toolResult" in content:
+                    has_tool_result = True
+                    break
+
+            if has_tool_result:
+                return idx
+
+        return None
diff --git a/src/strands/event_loop/error_handler.py b/src/strands/event_loop/error_handler.py
index a5c85668..6dc0d9ee 100644
--- a/src/strands/event_loop/error_handler.py
+++ b/src/strands/event_loop/error_handler.py
@@ -6,14 +6,9 @@
 import logging
 import time
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Tuple
 
-from ..telemetry.metrics import EventLoopMetrics
-from ..types.content import Message, Messages
-from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
-from ..types.models import Model
-from ..types.streaming import StopReason
-from .message_processor import find_last_message_with_tool_results, truncate_tool_results
+from ..types.exceptions import ModelThrottledException
 
 logger = logging.getLogger(__name__)
 
 
@@ -59,63 +54,3 @@ def handle_throttling_error(
     callback_handler(force_stop=True, force_stop_reason=str(e))
     return False, current_delay
-
-
-def handle_input_too_long_error(
-    e: ContextWindowOverflowException,
-    messages: Messages,
-    model: Model,
-    system_prompt: Optional[str],
-    tool_config: Any,
-    callback_handler: Any,
-    tool_handler: Any,
-    kwargs: Dict[str, Any],
-) -> Tuple[StopReason, Message, EventLoopMetrics, Any]:
-    """Handle 'Input is too long' errors by truncating tool results.
-
-    When a context window overflow exception occurs (input too long for the model), this function attempts to recover
-    by finding and truncating the most recent tool results in the conversation history. If truncation is successful, the
-    function will make a call to the event loop.
-
-    Args:
-        e: The ContextWindowOverflowException that occurred.
-        messages: The conversation message history.
-        model: Model provider for running inference.
-        system_prompt: System prompt for the model.
-        tool_config: Tool configuration for the conversation.
-        callback_handler: Callback for processing events as they happen.
-        tool_handler: Handler for tool execution.
-        kwargs: Additional arguments for the event loop.
-
-    Returns:
-        The results from the event loop call if successful.
-
-    Raises:
-        ContextWindowOverflowException: If messages cannot be truncated.
-    """
-    from .event_loop import recurse_event_loop  # Import here to avoid circular imports
-
-    # Find the last message with tool results
-    last_message_with_tool_results = find_last_message_with_tool_results(messages)
-
-    # If we found a message with toolResult
-    if last_message_with_tool_results is not None:
-        logger.debug("message_index=<%s> | found message with tool results at index", last_message_with_tool_results)
-
-        # Truncate the tool results in this message
-        truncate_tool_results(messages, last_message_with_tool_results)
-
-        return recurse_event_loop(
-            model=model,
-            system_prompt=system_prompt,
-            messages=messages,
-            tool_config=tool_config,
-            callback_handler=callback_handler,
-            tool_handler=tool_handler,
-            **kwargs,
-        )
-
-    # If we can't handle this error, pass it up
-    callback_handler(force_stop=True, force_stop_reason=str(e))
-    logger.error("an exception occurred in event_loop_cycle | %s", e)
-    raise ContextWindowOverflowException() from e
diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py
index 23d7bd0f..71165926 100644
--- a/src/strands/event_loop/event_loop.py
+++ b/src/strands/event_loop/event_loop.py
@@ -22,7 +22,7 @@
 from ..types.models import Model
 from ..types.streaming import Metrics, StopReason
 from ..types.tools import ToolConfig, ToolHandler, ToolResult, ToolUse
-from .error_handler import handle_input_too_long_error, handle_throttling_error
+from .error_handler import handle_throttling_error
 from .message_processor import clean_orphaned_empty_tool_uses
 from .streaming import stream_messages
 
@@ -160,16 +160,7 @@ def event_loop_cycle(
     except ContextWindowOverflowException as e:
         if model_invoke_span:
             tracer.end_span_with_error(model_invoke_span, str(e), e)
-        return handle_input_too_long_error(
-            e,
-            messages,
-            model,
-            system_prompt,
-            tool_config,
-            callback_handler,
-            tool_handler,
-            kwargs,
-        )
+        raise e
 
     except ModelThrottledException as e:
         if model_invoke_span:
@@ -248,6 +239,10 @@ def event_loop_cycle(
         # Don't invoke the callback_handler or log the exception - we already did it when we
         # raised the exception and we don't need that duplication.
         raise
+    except ContextWindowOverflowException as e:
+        if cycle_span:
+            tracer.end_span_with_error(cycle_span, str(e), e)
+        raise e
     except Exception as e:
         if cycle_span:
             tracer.end_span_with_error(cycle_span, str(e), e)
diff --git a/src/strands/event_loop/message_processor.py b/src/strands/event_loop/message_processor.py
index 61e1c1d7..4e1a39dc 100644
--- a/src/strands/event_loop/message_processor.py
+++ b/src/strands/event_loop/message_processor.py
@@ -5,7 +5,7 @@
 """
 
 import logging
-from typing import Dict, Optional, Set, Tuple
+from typing import Dict, Set, Tuple
 
 from ..types.content import Messages
 
@@ -103,60 +103,3 @@ def clean_orphaned_empty_tool_uses(messages: Messages) -> bool:
             logger.warning("failed to fix orphaned tool use | %s", e)
 
     return True
-
-
-def find_last_message_with_tool_results(messages: Messages) -> Optional[int]:
-    """Find the index of the last message containing tool results.
-
-    This is useful for identifying messages that might need to be truncated to reduce context size.
-
-    Args:
-        messages: The conversation message history.
-
-    Returns:
-        Index of the last message with tool results, or None if no such message exists.
- """ - # Iterate backwards through all messages (from newest to oldest) - for idx in range(len(messages) - 1, -1, -1): - # Check if this message has any content with toolResult - current_message = messages[idx] - has_tool_result = False - - for content in current_message.get("content", []): - if isinstance(content, dict) and "toolResult" in content: - has_tool_result = True - break - - if has_tool_result: - return idx - - return None - - -def truncate_tool_results(messages: Messages, msg_idx: int) -> bool: - """Truncate tool results in a message to reduce context size. - - When a message contains tool results that are too large for the model's context window, this function replaces the - content of those tool results with a simple error message. - - Args: - messages: The conversation message history. - msg_idx: Index of the message containing tool results to truncate. - - Returns: - True if any changes were made to the message, False otherwise. - """ - if msg_idx >= len(messages) or msg_idx < 0: - return False - - message = messages[msg_idx] - changes_made = False - - for i, content in enumerate(message.get("content", [])): - if isinstance(content, dict) and "toolResult" in content: - # Update status to error with informative message - message["content"][i]["toolResult"]["status"] = "error" - message["content"][i]["toolResult"]["content"] = [{"text": "The tool result was too large!"}] - changes_made = True - - return changes_made diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 0ea20b64..5104e280 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -438,7 +438,7 @@ def test_agent__call__retry_with_reduced_context(mock_model, agent, tool): def test_agent__call__always_sliding_window_conversation_manager_doesnt_infinite_loop(mock_model, agent, tool): - conversation_manager = SlidingWindowConversationManager(window_size=500) + conversation_manager = SlidingWindowConversationManager(window_size=500, should_truncate_results=False) conversation_manager_spy = unittest.mock.Mock(wraps=conversation_manager) agent.conversation_manager = conversation_manager_spy @@ -484,10 +484,43 @@ def test_agent__call__null_conversation_window_manager__doesnt_infinite_loop(moc agent("Test!") +def test_agent__call__tool_truncation_doesnt_infinite_loop(mock_model, agent): + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello!"}]}, + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "123", "input": {"hello": "world"}, "name": "test"}}], + }, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Some large input!"}], "status": "success"}} + ], + }, + ] + agent.messages = messages + + mock_model.mock_converse.side_effect = ContextWindowOverflowException( + RuntimeError("Input is too long for requested model") + ) + + with pytest.raises(ContextWindowOverflowException): + agent("Test!") + + def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool): conversation_manager_spy = unittest.mock.Mock(wraps=agent.conversation_manager) agent.conversation_manager = conversation_manager_spy + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello!"}]}, + { + "role": "assistant", + "content": [{"text": "Hi!"}], + }, + ] + agent.messages = messages + mock_model.mock_converse.side_effect = [ [ { @@ -504,6 +537,9 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool): {"contentBlockStop": {}}, {"messageStop": {"stopReason": 
"tool_use"}}, ], + # Will truncate the tool result + ContextWindowOverflowException(RuntimeError("Input is too long for requested model")), + # Will reduce the context ContextWindowOverflowException(RuntimeError("Input is too long for requested model")), [], ] @@ -538,7 +574,7 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool): unittest.mock.ANY, ) - conversation_manager_spy.reduce_context.assert_not_called() + assert conversation_manager_spy.reduce_context.call_count == 2 assert conversation_manager_spy.apply_management.call_count == 1 diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index bbec3cd1..7d43199e 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -25,6 +25,7 @@ def test_is_assistant_message(role, exp_result): def conversation_manager(request): params = { "window_size": 2, + "should_truncate_results": False, } if hasattr(request, "param"): params.update(request.param) @@ -168,7 +169,7 @@ def test_apply_management(conversation_manager, messages, expected_messages): def test_sliding_window_conversation_manager_with_untrimmable_history_raises_context_window_overflow_exception(): - manager = strands.agent.conversation_manager.SlidingWindowConversationManager(1) + manager = strands.agent.conversation_manager.SlidingWindowConversationManager(1, False) messages = [ {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, {"role": "user", "content": [{"toolResult": {"toolUseId": "789", "content": [], "status": "success"}}]}, @@ -182,6 +183,40 @@ def test_sliding_window_conversation_manager_with_untrimmable_history_raises_con assert messages == original_messages +def test_sliding_window_conversation_manager_with_tool_results_truncated(): + manager = strands.agent.conversation_manager.SlidingWindowConversationManager(1) + messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "789", "content": [{"text": "large input"}], "status": "success"}} + ], + }, + ] + test_agent = Agent(messages=messages) + + manager.reduce_context(test_agent) + + expected_messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "789", + "content": [{"text": "The tool result was too large!"}], + "status": "error", + } + } + ], + }, + ] + + assert messages == expected_messages + + def test_null_conversation_manager_reduce_context_raises_context_window_overflow_exception(): """Test that NullConversationManager doesn't modify messages.""" manager = strands.agent.conversation_manager.NullConversationManager() diff --git a/tests/strands/event_loop/test_error_handler.py b/tests/strands/event_loop/test_error_handler.py index 7249adef..fe1b3e9f 100644 --- a/tests/strands/event_loop/test_error_handler.py +++ b/tests/strands/event_loop/test_error_handler.py @@ -4,7 +4,7 @@ import pytest import strands -from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException +from strands.types.exceptions import ModelThrottledException @pytest.fixture @@ -95,126 +95,3 @@ def test_handle_throttling_error_does_not_exist(callback_handler, kwargs): assert tru_retry == exp_retry and tru_delay == exp_delay 

     callback_handler.assert_called_with(force_stop=True, force_stop_reason=str(exception))
-
-
-@pytest.mark.parametrize("event_stream_error", ["Input is too long for requested model"], indirect=True)
-def test_handle_input_too_long_error(
-    sdk_event_loop,
-    event_stream_error,
-    model,
-    system_prompt,
-    tool_config,
-    callback_handler,
-    tool_handler,
-    kwargs,
-):
-    sdk_event_loop.return_value = "success"
-
-    messages = [
-        {
-            "role": "user",
-            "content": [
-                {"toolResult": {"toolUseId": "t1", "status": "success", "content": [{"text": "needs truncation"}]}}
-            ],
-        }
-    ]
-
-    tru_result = strands.event_loop.error_handler.handle_input_too_long_error(
-        event_stream_error,
-        messages,
-        model,
-        system_prompt,
-        tool_config,
-        callback_handler,
-        tool_handler,
-        kwargs,
-    )
-    exp_result = "success"
-
-    tru_messages = messages
-    exp_messages = [
-        {
-            "role": "user",
-            "content": [
-                {
-                    "toolResult": {
-                        "toolUseId": "t1",
-                        "status": "error",
-                        "content": [{"text": "The tool result was too large!"}],
-                    },
-                },
-            ],
-        },
-    ]
-
-    assert tru_result == exp_result and tru_messages == exp_messages
-
-    sdk_event_loop.assert_called_once_with(
-        model=model,
-        system_prompt=system_prompt,
-        messages=messages,
-        tool_config=tool_config,
-        callback_handler=callback_handler,
-        tool_handler=tool_handler,
-        request_state="value",
-    )
-
-    callback_handler.assert_not_called()
-
-
-@pytest.mark.parametrize("event_stream_error", ["Other error"], indirect=True)
-def test_handle_input_too_long_error_does_not_exist(
-    sdk_event_loop,
-    event_stream_error,
-    model,
-    system_prompt,
-    tool_config,
-    callback_handler,
-    tool_handler,
-    kwargs,
-):
-    messages = []
-
-    with pytest.raises(ContextWindowOverflowException):
-        strands.event_loop.error_handler.handle_input_too_long_error(
-            event_stream_error,
-            messages,
-            model,
-            system_prompt,
-            tool_config,
-            callback_handler,
-            tool_handler,
-            kwargs,
-        )
-
-    sdk_event_loop.assert_not_called()
-    callback_handler.assert_called_with(force_stop=True, force_stop_reason=str(event_stream_error))
-
-
-@pytest.mark.parametrize("event_stream_error", ["Input is too long for requested model"], indirect=True)
-def test_handle_input_too_long_error_no_tool_result(
-    sdk_event_loop,
-    event_stream_error,
-    model,
-    system_prompt,
-    tool_config,
-    callback_handler,
-    tool_handler,
-    kwargs,
-):
-    messages = []
-
-    with pytest.raises(ContextWindowOverflowException):
-        strands.event_loop.error_handler.handle_input_too_long_error(
-            event_stream_error,
-            messages,
-            model,
-            system_prompt,
-            tool_config,
-            callback_handler,
-            tool_handler,
-            kwargs,
-        )
-
-    sdk_event_loop.assert_not_called()
-    callback_handler.assert_called_with(force_stop=True, force_stop_reason=str(event_stream_error))
diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py
index 569462f1..8c46e009 100644
--- a/tests/strands/event_loop/test_event_loop.py
+++ b/tests/strands/event_loop/test_event_loop.py
@@ -157,55 +157,6 @@ def test_event_loop_cycle_text_response(
     assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state
 
 
-def test_event_loop_cycle_text_response_input_too_long(
-    model,
-    model_id,
-    system_prompt,
-    messages,
-    tool_config,
-    callback_handler,
-    tool_handler,
-    tool_execution_handler,
-):
-    model.converse.side_effect = [
-        ContextWindowOverflowException(RuntimeError("Input is too long for requested model")),
-        [
-            {"contentBlockDelta": {"delta": {"text": "test text"}}},
-            {"contentBlockStop": {}},
-        ],
-    ]
-    messages.append(
-        {
-            "role": "user",
-            "content": [
-                {
-                    "toolResult": {
-                        "toolUseId": "t1",
-                        "status": "success",
-                        "content": [{"text": "2025-04-01T00:00:00"}],
-                    },
-                },
-            ],
-        }
-    )
-
-    tru_stop_reason, tru_message, _, tru_request_state = strands.event_loop.event_loop.event_loop_cycle(
-        model=model,
-        model_id=model_id,
-        system_prompt=system_prompt,
-        messages=messages,
-        tool_config=tool_config,
-        callback_handler=callback_handler,
-        tool_handler=tool_handler,
-        tool_execution_handler=tool_execution_handler,
-    )
-    exp_stop_reason = "end_turn"
-    exp_message = {"role": "assistant", "content": [{"text": "test text"}]}
-    exp_request_state = {}
-
-    assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state
-
-
 @unittest.mock.patch.object(strands.event_loop.error_handler, "time")
 def test_event_loop_cycle_text_response_throttling(
     model,
diff --git a/tests/strands/event_loop/test_message_processor.py b/tests/strands/event_loop/test_message_processor.py
index 395c71a1..fcf531df 100644
--- a/tests/strands/event_loop/test_message_processor.py
+++ b/tests/strands/event_loop/test_message_processor.py
@@ -45,84 +45,3 @@ def test_clean_orphaned_empty_tool_uses(messages, expected, expected_messages):
     result = message_processor.clean_orphaned_empty_tool_uses(test_messages)
     assert result == expected
     assert test_messages == expected_messages
-
-
-@pytest.mark.parametrize(
-    "messages,expected_idx",
-    [
-        (
-            [
-                {"role": "user", "content": [{"text": "hi"}]},
-                {"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]},
-                {"role": "assistant", "content": [{"text": "ok"}]},
-            ],
-            1,
-        ),
-        (
-            [
-                {"role": "user", "content": [{"text": "hi"}]},
-                {"role": "assistant", "content": [{"text": "ok"}]},
-            ],
-            None,
-        ),
-        (
-            [],
-            None,
-        ),
-    ],
-)
-def test_find_last_message_with_tool_results(messages, expected_idx):
-    idx = message_processor.find_last_message_with_tool_results(messages)
-    assert idx == expected_idx
-
-
-@pytest.mark.parametrize(
-    "messages,msg_idx,expected_changed,expected_content",
-    [
-        (
-            [
-                {
-                    "role": "user",
-                    "content": [{"toolResult": {"toolUseId": "1", "status": "ok", "content": [{"text": "big"}]}}],
-                }
-            ],
-            0,
-            True,
-            [
-                {
-                    "toolResult": {
-                        "toolUseId": "1",
-                        "status": "error",
-                        "content": [{"text": "The tool result was too large!"}],
-                    }
-                }
-            ],
-        ),
-        (
-            [{"role": "user", "content": [{"text": "no tool result"}]}],
-            0,
-            False,
-            [{"text": "no tool result"}],
-        ),
-        (
-            [],
-            0,
-            False,
-            [],
-        ),
-        (
-            [{"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]}],
-            2,
-            False,
-            [{"toolResult": {"toolUseId": "1"}}],
-        ),
-    ],
-)
-def test_truncate_tool_results(messages, msg_idx, expected_changed, expected_content):
-    test_messages = copy.deepcopy(messages)
-    changed = message_processor.truncate_tool_results(test_messages, msg_idx)
-    assert changed == expected_changed
-    if 0 <= msg_idx < len(test_messages):
-        assert test_messages[msg_idx]["content"] == expected_content
-    else:
-        assert test_messages == messages
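
Usage sketch (illustrative, outside the diff): the snippet below shows how the new should_truncate_results flag is expected to be used, relying only on names introduced in this patch and on the module path exercised by the tests; the Agent import path is an assumption based on those test modules.

    from strands.agent.agent import Agent  # assumed import path; the tests construct Agent(messages=...) directly
    from strands.agent.conversation_manager import SlidingWindowConversationManager

    # Default (should_truncate_results=True): on a context window overflow, reduce_context()
    # first rewrites the newest oversized toolResult block to status "error" with the text
    # "The tool result was too large!" and returns without dropping any messages.
    manager = SlidingWindowConversationManager(window_size=40, should_truncate_results=True)

    # Opting out restores the previous behaviour: reduce_context() goes straight to trimming
    # the oldest messages and raises ContextWindowOverflowException if nothing can be trimmed.
    legacy_manager = SlidingWindowConversationManager(window_size=40, should_truncate_results=False)

    agent = Agent(messages=[])
    agent.conversation_manager = manager  # the tests wire the manager onto the agent the same way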