chore: moved truncation logic to conversation manager and added should_truncate_results #192

Merged

@@ -44,14 +44,16 @@ class SlidingWindowConversationManager(ConversationManager):
invalid window states.
"""

def __init__(self, window_size: int = 40):
def __init__(self, window_size: int = 40, should_truncate_results: bool = True):
"""Initialize the sliding window conversation manager.

Args:
window_size: Maximum number of messages to keep in the agent's history.
Defaults to 40 messages.
should_truncate_results: Whether to truncate tool results when a message is too large for the model's context window. Defaults to True.
"""
self.window_size = window_size
self.should_truncate_results = should_truncate_results

def apply_management(self, agent: "Agent") -> None:
"""Apply the sliding window to the agent's messages array to maintain a manageable history size.
@@ -127,6 +129,19 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None:
converted.
"""
messages = agent.messages

# Try to truncate the tool result first
last_message_idx_with_tool_results = self._find_last_message_with_tool_results(messages)
if last_message_idx_with_tool_results is not None and self.should_truncate_results:
logger.debug(
"message_index=<%s> | found message with tool results at index", last_message_idx_with_tool_results
)
results_truncated = self._truncate_tool_results(messages, last_message_idx_with_tool_results)
if results_truncated:
logger.debug("message_index=<%s> | tool results truncated", last_message_idx_with_tool_results)
return

# Fall back to trimming the message history once tool results can no longer be truncated
# If the number of messages is at most the window size, default to trimming 2 messages; otherwise trim down to the window size
trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size

@@ -151,3 +166,69 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None:

# Overwrite message history
messages[:] = messages[trim_index:]

def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool:
"""Truncate tool results in a message to reduce context size.

When a message contains tool results that are too large for the model's context window, this function
replaces the content of those tool results with a simple error message.

Args:
messages: The conversation message history.
msg_idx: Index of the message containing tool results to truncate.

Returns:
True if any changes were made to the message, False otherwise.
"""
if msg_idx >= len(messages) or msg_idx < 0:
return False

message = messages[msg_idx]
changes_made = False
tool_result_too_large_message = "The tool result was too large!"
for i, content in enumerate(message.get("content", [])):
if isinstance(content, dict) and "toolResult" in content:
tool_result_content_text = next(
(item["text"] for item in content["toolResult"]["content"] if "text" in item),
"",
)
# Skip overwriting if this tool result has already been replaced with the error message
if (
message["content"][i]["toolResult"]["status"] == "error"
and tool_result_content_text == tool_result_too_large_message
):
logger.info("ToolResult has already been updated, skipping overwrite")
return False
# Update status to error with informative message
message["content"][i]["toolResult"]["status"] = "error"
message["content"][i]["toolResult"]["content"] = [{"text": tool_result_too_large_message}]
changes_made = True

return changes_made

def _find_last_message_with_tool_results(self, messages: Messages) -> Optional[int]:
"""Find the index of the last message containing tool results.

This is useful for identifying messages that might need to be truncated to reduce context size.

Args:
messages: The conversation message history.

Returns:
Index of the last message with tool results, or None if no such message exists.
"""
# Iterate backwards through all messages (from newest to oldest)
for idx in range(len(messages) - 1, -1, -1):
# Check if this message has any content with toolResult
current_message = messages[idx]
has_tool_result = False

for content in current_message.get("content", []):
if isinstance(content, dict) and "toolResult" in content:
has_tool_result = True
break

if has_tool_result:
return idx

return None
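As a quick usage sketch (not part of the diff; the import path, the stand-in agent object, and the example messages are assumptions for illustration), the new flag can be exercised by calling reduce_context directly. With should_truncate_results=True the manager overwrites the oversized tool result in place instead of dropping messages; constructing it with should_truncate_results=False skips that path and falls straight through to the sliding-window trim.

from types import SimpleNamespace

from strands.agent.conversation_manager import SlidingWindowConversationManager  # import path assumed

manager = SlidingWindowConversationManager(window_size=40, should_truncate_results=True)

# Conversation whose last tool result is presumed too large for the model's context window
messages = [
    {"role": "user", "content": [{"text": "Hello!"}]},
    {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "input": {}, "name": "test"}}]},
    {
        "role": "user",
        "content": [
            {"toolResult": {"toolUseId": "123", "content": [{"text": "<huge payload>"}], "status": "success"}}
        ],
    },
]

# reduce_context only needs agent.messages on this code path, so a simple stand-in object suffices here
agent = SimpleNamespace(messages=messages)
manager.reduce_context(agent)

# The oversized tool result was overwritten in place; no messages were removed
assert len(messages) == 3
assert messages[2]["content"][0]["toolResult"]["status"] == "error"
assert messages[2]["content"][0]["toolResult"]["content"] == [{"text": "The tool result was too large!"}]
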
69 changes: 2 additions & 67 deletions src/strands/event_loop/error_handler.py
@@ -6,14 +6,9 @@

import logging
import time
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Tuple

from ..telemetry.metrics import EventLoopMetrics
from ..types.content import Message, Messages
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.models import Model
from ..types.streaming import StopReason
from .message_processor import find_last_message_with_tool_results, truncate_tool_results
from ..types.exceptions import ModelThrottledException

logger = logging.getLogger(__name__)

@@ -59,63 +54,3 @@ def handle_throttling_error(

callback_handler(force_stop=True, force_stop_reason=str(e))
return False, current_delay


def handle_input_too_long_error(
e: ContextWindowOverflowException,
messages: Messages,
model: Model,
system_prompt: Optional[str],
tool_config: Any,
callback_handler: Any,
tool_handler: Any,
kwargs: Dict[str, Any],
) -> Tuple[StopReason, Message, EventLoopMetrics, Any]:
"""Handle 'Input is too long' errors by truncating tool results.

When a context window overflow exception occurs (input too long for the model), this function attempts to recover
by finding and truncating the most recent tool results in the conversation history. If truncation is successful, the
function will make a call to the event loop.

Args:
e: The ContextWindowOverflowException that occurred.
messages: The conversation message history.
model: Model provider for running inference.
system_prompt: System prompt for the model.
tool_config: Tool configuration for the conversation.
callback_handler: Callback for processing events as they happen.
tool_handler: Handler for tool execution.
kwargs: Additional arguments for the event loop.

Returns:
The results from the event loop call if successful.

Raises:
ContextWindowOverflowException: If messages cannot be truncated.
"""
from .event_loop import recurse_event_loop # Import here to avoid circular imports

# Find the last message with tool results
last_message_with_tool_results = find_last_message_with_tool_results(messages)

# If we found a message with toolResult
if last_message_with_tool_results is not None:
logger.debug("message_index=<%s> | found message with tool results at index", last_message_with_tool_results)

# Truncate the tool results in this message
truncate_tool_results(messages, last_message_with_tool_results)

return recurse_event_loop(
model=model,
system_prompt=system_prompt,
messages=messages,
tool_config=tool_config,
callback_handler=callback_handler,
tool_handler=tool_handler,
**kwargs,
)

# If we can't handle this error, pass it up
callback_handler(force_stop=True, force_stop_reason=str(e))
logger.error("an exception occurred in event_loop_cycle | %s", e)
raise ContextWindowOverflowException() from e
17 changes: 6 additions & 11 deletions src/strands/event_loop/event_loop.py
@@ -22,7 +22,7 @@
from ..types.models import Model
from ..types.streaming import Metrics, StopReason
from ..types.tools import ToolConfig, ToolHandler, ToolResult, ToolUse
from .error_handler import handle_input_too_long_error, handle_throttling_error
from .error_handler import handle_throttling_error
from .message_processor import clean_orphaned_empty_tool_uses
from .streaming import stream_messages

@@ -160,16 +160,7 @@ def event_loop_cycle(
except ContextWindowOverflowException as e:
if model_invoke_span:
tracer.end_span_with_error(model_invoke_span, str(e), e)
return handle_input_too_long_error(
e,
messages,
model,
system_prompt,
tool_config,
callback_handler,
tool_handler,
kwargs,
)
raise e

except ModelThrottledException as e:
if model_invoke_span:
@@ -248,6 +239,10 @@ def event_loop_cycle(
# Don't invoke the callback_handler or log the exception - we already did it when we
# raised the exception and we don't need that duplication.
raise
except ContextWindowOverflowException as e:
if cycle_span:
tracer.end_span_with_error(cycle_span, str(e), e)
raise e
except Exception as e:
if cycle_span:
tracer.end_span_with_error(cycle_span, str(e), e)
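Since the in-loop recovery has been removed, event_loop_cycle now re-raises ContextWindowOverflowException and leaves recovery to its caller. Below is a minimal sketch of that caller-side pattern, consistent with the agent tests further down but not the SDK's exact implementation (the retry helper and attempt limit are assumptions):

from strands.types.exceptions import ContextWindowOverflowException

def run_with_context_recovery(agent, invoke_cycle, max_attempts=5):
    """Illustrative retry loop: on overflow, ask the conversation manager to shrink history, then retry."""
    for _ in range(max_attempts):
        try:
            return invoke_cycle()
        except ContextWindowOverflowException as e:
            # reduce_context first tries to truncate the newest tool result (if enabled),
            # then falls back to trimming the window; it is expected to raise if it cannot shrink further.
            agent.conversation_manager.reduce_context(agent, e=e)
    raise ContextWindowOverflowException()
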
59 changes: 1 addition & 58 deletions src/strands/event_loop/message_processor.py
@@ -5,7 +5,7 @@
"""

import logging
from typing import Dict, Optional, Set, Tuple
from typing import Dict, Set, Tuple

from ..types.content import Messages

@@ -103,60 +103,3 @@ def clean_orphaned_empty_tool_uses(messages: Messages) -> bool:
logger.warning("failed to fix orphaned tool use | %s", e)

return True


def find_last_message_with_tool_results(messages: Messages) -> Optional[int]:
"""Find the index of the last message containing tool results.

This is useful for identifying messages that might need to be truncated to reduce context size.

Args:
messages: The conversation message history.

Returns:
Index of the last message with tool results, or None if no such message exists.
"""
# Iterate backwards through all messages (from newest to oldest)
for idx in range(len(messages) - 1, -1, -1):
# Check if this message has any content with toolResult
current_message = messages[idx]
has_tool_result = False

for content in current_message.get("content", []):
if isinstance(content, dict) and "toolResult" in content:
has_tool_result = True
break

if has_tool_result:
return idx

return None


def truncate_tool_results(messages: Messages, msg_idx: int) -> bool:
"""Truncate tool results in a message to reduce context size.

When a message contains tool results that are too large for the model's context window, this function replaces the
content of those tool results with a simple error message.

Args:
messages: The conversation message history.
msg_idx: Index of the message containing tool results to truncate.

Returns:
True if any changes were made to the message, False otherwise.
"""
if msg_idx >= len(messages) or msg_idx < 0:
return False

message = messages[msg_idx]
changes_made = False

for i, content in enumerate(message.get("content", [])):
if isinstance(content, dict) and "toolResult" in content:
# Update status to error with informative message
message["content"][i]["toolResult"]["status"] = "error"
message["content"][i]["toolResult"]["content"] = [{"text": "The tool result was too large!"}]
changes_made = True

return changes_made
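For code that previously imported these helpers from message_processor, the equivalent logic now lives on SlidingWindowConversationManager as private methods and is normally driven through reduce_context; a rough migration sketch follows (the import paths are assumed, and calling the underscored methods directly is for illustration only):

from strands.agent.conversation_manager import SlidingWindowConversationManager  # import path assumed
from strands.types.content import Messages

manager = SlidingWindowConversationManager()
messages: Messages = []  # conversation history, e.g. the example shown earlier

# Previously: find_last_message_with_tool_results(messages) and truncate_tool_results(messages, idx)
idx = manager._find_last_message_with_tool_results(messages)
if idx is not None:
    manager._truncate_tool_results(messages, idx)
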
40 changes: 38 additions & 2 deletions tests/strands/agent/test_agent.py
@@ -438,7 +438,7 @@ def test_agent__call__retry_with_reduced_context(mock_model, agent, tool):


def test_agent__call__always_sliding_window_conversation_manager_doesnt_infinite_loop(mock_model, agent, tool):
conversation_manager = SlidingWindowConversationManager(window_size=500)
conversation_manager = SlidingWindowConversationManager(window_size=500, should_truncate_results=False)
conversation_manager_spy = unittest.mock.Mock(wraps=conversation_manager)
agent.conversation_manager = conversation_manager_spy

@@ -484,10 +484,43 @@ def test_agent__call__null_conversation_window_manager__doesnt_infinite_loop(moc
agent("Test!")


def test_agent__call__tool_truncation_doesnt_infinite_loop(mock_model, agent):
messages: Messages = [
{"role": "user", "content": [{"text": "Hello!"}]},
{
"role": "assistant",
"content": [{"toolUse": {"toolUseId": "123", "input": {"hello": "world"}, "name": "test"}}],
},
{
"role": "user",
"content": [
{"toolResult": {"toolUseId": "123", "content": [{"text": "Some large input!"}], "status": "success"}}
],
},
]
agent.messages = messages

mock_model.mock_converse.side_effect = ContextWindowOverflowException(
RuntimeError("Input is too long for requested model")
)

with pytest.raises(ContextWindowOverflowException):
agent("Test!")


def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool):
conversation_manager_spy = unittest.mock.Mock(wraps=agent.conversation_manager)
agent.conversation_manager = conversation_manager_spy

messages: Messages = [
{"role": "user", "content": [{"text": "Hello!"}]},
{
"role": "assistant",
"content": [{"text": "Hi!"}],
},
]
agent.messages = messages

mock_model.mock_converse.side_effect = [
[
{
@@ -504,6 +537,9 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool):
{"contentBlockStop": {}},
{"messageStop": {"stopReason": "tool_use"}},
],
# Will truncate the tool result
ContextWindowOverflowException(RuntimeError("Input is too long for requested model")),
# Will reduce the context
ContextWindowOverflowException(RuntimeError("Input is too long for requested model")),
[],
]
@@ -538,7 +574,7 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool):
unittest.mock.ANY,
)

conversation_manager_spy.reduce_context.assert_not_called()
assert conversation_manager_spy.reduce_context.call_count == 2
assert conversation_manager_spy.apply_management.call_count == 1

