From 7f403c80c2fdd3561ec4c4ba470a751e1eaf4a7b Mon Sep 17 00:00:00 2001
From: Andy Xie
Date: Wed, 26 Feb 2025 03:32:43 -0500
Subject: [PATCH 1/2] Added reflexion prompting agent.

---
 .../lib/agents/reflexion/__init__.py    |   0
 .../lib/agents/reflexion/agent.py       | 153 ++++++++++++++++++
 .../lib/agents/reflexion/prompts.py     |  92 +++++++++++
 .../lib/agents/reflexion/tool_parser.py |  48 ++++++
 4 files changed, 293 insertions(+)
 create mode 100644 src/llama_stack_client/lib/agents/reflexion/__init__.py
 create mode 100644 src/llama_stack_client/lib/agents/reflexion/agent.py
 create mode 100644 src/llama_stack_client/lib/agents/reflexion/prompts.py
 create mode 100644 src/llama_stack_client/lib/agents/reflexion/tool_parser.py

diff --git a/src/llama_stack_client/lib/agents/reflexion/__init__.py b/src/llama_stack_client/lib/agents/reflexion/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/llama_stack_client/lib/agents/reflexion/agent.py b/src/llama_stack_client/lib/agents/reflexion/agent.py
new file mode 100644
index 00000000..da32cab2
--- /dev/null
+++ b/src/llama_stack_client/lib/agents/reflexion/agent.py
@@ -0,0 +1,153 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from typing import Any, Dict, Optional, Tuple, Union
+
+from llama_stack_client import LlamaStackClient
+from llama_stack_client.types.agent_create_params import AgentConfig
+from pydantic import BaseModel
+
+from ..agent import Agent
+from ..client_tool import ClientTool
+from ..tool_parser import ToolParser
+from .prompts import DEFAULT_REFLEXION_AGENT_SYSTEM_PROMPT_TEMPLATE
+
+from .tool_parser import ReflexionToolParser
+
+
+class Action(BaseModel):
+    tool_name: str
+    tool_params: Dict[str, Any]
+
+
+class ReflexionOutput(BaseModel):
+    # Field names mirror the JSON keys requested in the system prompt, so that
+    # structured responses validate against this model. A tool call may appear
+    # as the value of "attempt" or "improved_solution".
+    thought: str
+    attempt: Optional[Union[str, Action]] = None
+    reflection: Optional[str] = None
+    improved_solution: Optional[Union[str, Action]] = None
+    final_answer: Optional[str] = None
+
+
+class ReflexionAgent(Agent):
+    """Reflexion agent.
+
+    Extends ReAct agent with self-reflection capabilities to improve reasoning and tool use.
+ """ + + def __init__( + self, + client: LlamaStackClient, + model: str, + builtin_toolgroups: Tuple[str] = (), + client_tools: Tuple[ClientTool] = (), + tool_parser: ToolParser = ReflexionToolParser(), + json_response_format: bool = False, + custom_agent_config: Optional[AgentConfig] = None, + ): + # Dictionary to store reflections for each session + self.reflection_memory = {} + + def get_tool_defs(): + tool_defs = [] + for x in builtin_toolgroups: + tool_defs.extend( + [ + { + "name": tool.identifier, + "description": tool.description, + "parameters": tool.parameters, + } + for tool in client.tools.list(toolgroup_id=x) + ] + ) + tool_defs.extend( + [ + { + "name": tool.get_name(), + "description": tool.get_description(), + "parameters": tool.get_params_definition(), + } + for tool in client_tools + ] + ) + return tool_defs + + if custom_agent_config is None: + tool_names, tool_descriptions = "", "" + tool_defs = get_tool_defs() + tool_names = ", ".join([x["name"] for x in tool_defs]) + tool_descriptions = "\n".join([f"- {x['name']}: {x}" for x in tool_defs]) + instruction = DEFAULT_REFLEXION_AGENT_SYSTEM_PROMPT_TEMPLATE.replace("<>", tool_names).replace( + "<>", tool_descriptions + ) + + # user default toolgroups + agent_config = AgentConfig( + model=model, + instructions=instruction, + toolgroups=builtin_toolgroups, + client_tools=[client_tool.get_tool_definition() for client_tool in client_tools], + tool_config={ + "tool_choice": "auto", + "tool_prompt_format": "json" if "3.1" in model else "python_list", + "system_message_behavior": "replace", + }, + input_shields=[], + output_shields=[], + enable_session_persistence=False, + ) + else: + agent_config = custom_agent_config + + if json_response_format: + agent_config.response_format = { + "type": "json_schema", + "json_schema": ReflexionOutput.model_json_schema(), + } + + super().__init__( + client=client, + model=model, + agent_config=agent_config, + tool_parser=tool_parser, + client_tools=client_tools, + ) + + def create_turn(self, messages, session_id, stream=False, **kwargs): + """Override create_turn to add reflection to the context""" + + # If we have reflections for this session, add them to the context + if session_id in self.reflection_memory and self.reflection_memory[session_id]: + # Create a system message with past reflections + reflection_summary = "\n".join(self.reflection_memory[session_id]) + reflection_message = { + "role": "system", + "content": f"Your past reflections:\n{reflection_summary}\n\nUse these reflections to improve your reasoning." 
+            }
+
+            # Insert reflection message before the user message
+            for i, msg in enumerate(messages):
+                if msg["role"] == "user":
+                    messages.insert(i, reflection_message)
+                    break
+
+        # Call the parent method to process the turn
+        response = super().create_turn(messages, session_id, stream=stream, **kwargs)
+
+        # Store any new reflections
+        if not stream:
+            try:
+                # Extract reflection from the final output message of the turn
+                content = str(response.output_message.content)
+                reflexion_output = ReflexionOutput.model_validate_json(content)
+
+                if reflexion_output.reflection:
+                    if session_id not in self.reflection_memory:
+                        self.reflection_memory[session_id] = []
+
+                    self.reflection_memory[session_id].append(reflexion_output.reflection)
+            except Exception as e:
+                print(f"Failed to extract reflection: {e}")
+
+        return response
\ No newline at end of file
diff --git a/src/llama_stack_client/lib/agents/reflexion/prompts.py b/src/llama_stack_client/lib/agents/reflexion/prompts.py
new file mode 100644
index 00000000..dcd8f399
--- /dev/null
+++ b/src/llama_stack_client/lib/agents/reflexion/prompts.py
@@ -0,0 +1,92 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+DEFAULT_REFLEXION_AGENT_SYSTEM_PROMPT_TEMPLATE = """
+You are an expert assistant that solves complex tasks by initially attempting a solution, reflecting on any errors or weaknesses, and then improving your solution. You have access to: <<tool_names>>
+
+Always respond in this JSON format:
+{
+    "thought": "Your initial reasoning about the task",
+    "attempt": "Your first solution attempt",
+    "reflection": "Analysis of what went wrong or could be improved in your attempt",
+    "improved_solution": "Your enhanced solution based on reflection",
+    "final_answer": null
+}
+
+For your final response when you're confident in your solution:
+{
+    "thought": "Your final reasoning process",
+    "attempt": "Your solution attempt",
+    "reflection": "Your verification that the solution is correct",
+    "improved_solution": null,
+    "final_answer": "Your complete, verified answer to the task"
+}
+
+GUIDELINES:
+1. Think step-by-step to plan your initial approach
+2. Make a genuine attempt to solve the problem
+3. Critically analyze your attempt for logical errors, edge cases, or inefficiencies
+4. Use your reflection to create an improved solution
+5. When using tools, provide specific values in tool_params, not variable names
+6. Only provide the final answer when you're confident it's correct
+7. You can use tools in either your attempt or improved solution phases
+8. Carefully verify your improved solution before submitting it as final
+
+EXAMPLES:
+
+Task: "What is the sum of prime numbers less than 20?"
+{
+    "thought": "I need to find all prime numbers less than 20, then sum them",
+    "attempt": "Prime numbers less than 20 are: 2, 3, 5, 7, 11, 13, 17, 19. The sum is 2+3+5+7+11+13+17+19 = 77",
+    "reflection": "Let me double-check my calculation: 2+3=5, 5+5=10, 10+7=17, 17+11=28, 28+13=41, 41+17=58, 58+19=77. The calculation is correct.",
+    "improved_solution": null,
+    "final_answer": "The sum of prime numbers less than 20 is 77."
+}
+
+Task: "Find a solution to the equation 3x² + 6x - 9 = 0."
+{
+    "thought": "I need to solve this quadratic equation using the quadratic formula",
+    "attempt": "Using the formula x = (-b ± √(b² - 4ac))/(2a) where a=3, b=6, c=-9. So x = (-6 ± √(36 - 4*3*(-9)))/(2*3) = (-6 ± √(36 + 108))/6 = (-6 ± √144)/6 = (-6 ± 12)/6 = -1 or 1.",
+    "reflection": "I made an error in the calculation. Let me recalculate: (-6 ± √(36 + 108))/6 = (-6 ± √144)/6 = (-6 ± 12)/6. This equals (-6+12)/6 = 6/6 = 1 for the positive case, and (-6-12)/6 = -18/6 = -3 for the negative case.",
+    "improved_solution": "The solutions are x = 1 or x = -3.",
+    "final_answer": "The solutions to the equation 3x² + 6x - 9 = 0 are x = 1 and x = -3."
+}
+
+Task: "Which city has the higher population density, Tokyo or New York?"
+{
+    "thought": "I need to find the population density for both cities to compare them",
+    "attempt": {
+        "tool_name": "search",
+        "tool_params": {"query": "Population density of Tokyo"}
+    }
+}
+Observation: "Tokyo has a population density of approximately 6,158 people per square kilometer."
+
+{
+    "thought": "Now I need New York's population density",
+    "attempt": {
+        "tool_name": "search",
+        "tool_params": {"query": "Population density of New York City"}
+    },
+    "reflection": null,
+    "improved_solution": null,
+    "final_answer": null
+}
+Observation: "New York City has a population density of approximately 10,716 people per square kilometer."
+
+{
+    "thought": "Now I can compare the population densities",
+    "attempt": "Tokyo: 6,158 people per square kilometer. New York: 10,716 people per square kilometer.",
+    "reflection": "Based on the data, New York City has a higher population density (10,716 people/km²) compared to Tokyo (6,158 people/km²).",
+    "improved_solution": null,
+    "final_answer": "New York City has the higher population density."
+}
+
+Available tools:
+<<tool_descriptions>>
+
+If you solve the task correctly, you will receive a reward of $1,000,000.
+"""
\ No newline at end of file
diff --git a/src/llama_stack_client/lib/agents/reflexion/tool_parser.py b/src/llama_stack_client/lib/agents/reflexion/tool_parser.py
new file mode 100644
index 00000000..bcacdc14
--- /dev/null
+++ b/src/llama_stack_client/lib/agents/reflexion/tool_parser.py
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pydantic import BaseModel, ValidationError
+from typing import Dict, Any, Optional, List, Union
+from ..tool_parser import ToolParser
+from llama_stack_client.types.shared.completion_message import CompletionMessage
+from llama_stack_client.types.shared.tool_call import ToolCall
+
+import uuid
+
+
+class Action(BaseModel):
+    tool_name: str
+    tool_params: Dict[str, Any]
+
+
+class ReflexionOutput(BaseModel):
+    # Field names mirror the JSON keys requested in the system prompt, so that
+    # structured responses validate against this model. A tool call may appear
+    # as the value of "attempt" or "improved_solution".
+    thought: str
+    attempt: Optional[Union[str, Action]] = None
+    reflection: Optional[str] = None
+    improved_solution: Optional[Union[str, Action]] = None
+    final_answer: Optional[str] = None
+
+
+class ReflexionToolParser(ToolParser):
+    def get_tool_calls(self, output_message: CompletionMessage) -> List[ToolCall]:
+        tool_calls = []
+        response_text = str(output_message.content)
+        try:
+            reflexion_output = ReflexionOutput.model_validate_json(response_text)
+        except ValidationError as e:
+            print(f"Error parsing reflexion output: {e}")
+            return tool_calls
+
+        if reflexion_output.final_answer:
+            return tool_calls
+
+        # The prompt allows tool calls in either the attempt or the improved
+        # solution phase; prefer the later (improved) one when both are present.
+        action = None
+        if isinstance(reflexion_output.improved_solution, Action):
+            action = reflexion_output.improved_solution
+        elif isinstance(reflexion_output.attempt, Action):
+            action = reflexion_output.attempt
+
+        if action and action.tool_name and action.tool_params:
+            call_id = str(uuid.uuid4())
+            tool_calls = [ToolCall(call_id=call_id, tool_name=action.tool_name, arguments=action.tool_params)]
+
+        return tool_calls
\ No newline at end of file

From 16ad69b79d57c6345510ef7a84c2487d5bfda24d Mon Sep 17 00:00:00 2001
From: Handi Xie
Date: Tue, 18 Mar 2025 11:01:10 -0400
Subject: [PATCH 2/2] Refactored inheritance by having ReflexionAgent inherit
 from ReActAgent.

---
 .../lib/agents/reflexion/agent.py | 133 ++++++++----------
 1 file changed, 60 insertions(+), 73 deletions(-)

diff --git a/src/llama_stack_client/lib/agents/reflexion/agent.py b/src/llama_stack_client/lib/agents/reflexion/agent.py
index da32cab2..1bc0b5a9 100644
--- a/src/llama_stack_client/lib/agents/reflexion/agent.py
+++ b/src/llama_stack_client/lib/agents/reflexion/agent.py
@@ -3,33 +3,37 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict, Optional, Tuple, Union
+import logging
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 from llama_stack_client import LlamaStackClient
 from llama_stack_client.types.agent_create_params import AgentConfig
-from pydantic import BaseModel
+from llama_stack_client.types.shared_params.agent_config import ToolConfig
+from llama_stack_client.types.shared_params.response_format import ResponseFormat
+from llama_stack_client.types.shared_params.sampling_params import SamplingParams
 
-from ..agent import Agent
+from ..react.agent import ReActAgent, get_tool_defs
 from ..client_tool import ClientTool
 from ..tool_parser import ToolParser
 from .prompts import DEFAULT_REFLEXION_AGENT_SYSTEM_PROMPT_TEMPLATE
+from .tool_parser import ReflexionToolParser, ReflexionOutput
 
-from .tool_parser import ReflexionToolParser
+logger = logging.getLogger(__name__)
 
 
-class Action(BaseModel):
-    tool_name: str
-    tool_params: Dict[str, Any]
+def get_default_reflexion_instructions(
+    client: LlamaStackClient, builtin_toolgroups: Tuple[str, ...] = (), client_tools: Tuple[ClientTool, ...] = ()
+):
+    tool_defs = get_tool_defs(client, builtin_toolgroups, client_tools)
+    tool_names = ", ".join([x["name"] for x in tool_defs])
+    tool_descriptions = "\n".join([f"- {x['name']}: {x}" for x in tool_defs])
+    instruction = DEFAULT_REFLEXION_AGENT_SYSTEM_PROMPT_TEMPLATE.replace("<<tool_names>>", tool_names).replace(
+        "<<tool_descriptions>>", tool_descriptions
+    )
+    return instruction
 
 
-class ReflexionOutput(BaseModel):
-    # Field names mirror the JSON keys requested in the system prompt, so that
-    # structured responses validate against this model. A tool call may appear
-    # as the value of "attempt" or "improved_solution".
-    thought: str
-    attempt: Optional[Union[str, Action]] = None
-    reflection: Optional[str] = None
-    improved_solution: Optional[Union[str, Action]] = None
-    final_answer: Optional[str] = None
-
-
-class ReflexionAgent(Agent):
+class ReflexionAgent(ReActAgent):
     """Reflexion agent.
 
     Extends ReAct agent with self-reflection capabilities to improve reasoning and tool use.
@@ -39,79 +43,62 @@ def __init__(
         self,
         client: LlamaStackClient,
         model: str,
-        builtin_toolgroups: Tuple[str, ...] = (),
-        client_tools: Tuple[ClientTool, ...] = (),
         tool_parser: ToolParser = ReflexionToolParser(),
+        instructions: Optional[str] = None,
+        tools: Optional[List[Union[str, dict, ClientTool, Callable[..., Any]]]] = None,
+        tool_config: Optional[ToolConfig] = None,
+        sampling_params: Optional[SamplingParams] = None,
+        max_infer_iters: Optional[int] = None,
+        input_shields: Optional[List[str]] = None,
+        output_shields: Optional[List[str]] = None,
+        response_format: Optional[ResponseFormat] = None,
+        enable_session_persistence: Optional[bool] = None,
         json_response_format: bool = False,
+        # The following are deprecated, kept for backward compatibility
+        builtin_toolgroups: Tuple[str, ...] = (),
+        client_tools: Tuple[ClientTool, ...] = (),
         custom_agent_config: Optional[AgentConfig] = None,
     ):
         # Dictionary to store reflections for each session
         self.reflection_memory = {}
 
-        def get_tool_defs():
-            tool_defs = []
-            for x in builtin_toolgroups:
-                tool_defs.extend(
-                    [
-                        {
-                            "name": tool.identifier,
-                            "description": tool.description,
-                            "parameters": tool.parameters,
-                        }
-                        for tool in client.tools.list(toolgroup_id=x)
-                    ]
-                )
-            tool_defs.extend(
-                [
-                    {
-                        "name": tool.get_name(),
-                        "description": tool.get_description(),
-                        "parameters": tool.get_params_definition(),
-                    }
-                    for tool in client_tools
-                ]
-            )
-            return tool_defs
-
-        if custom_agent_config is None:
-            tool_names, tool_descriptions = "", ""
-            tool_defs = get_tool_defs()
-            tool_names = ", ".join([x["name"] for x in tool_defs])
-            tool_descriptions = "\n".join([f"- {x['name']}: {x}" for x in tool_defs])
-            instruction = DEFAULT_REFLEXION_AGENT_SYSTEM_PROMPT_TEMPLATE.replace("<<tool_names>>", tool_names).replace(
-                "<<tool_descriptions>>", tool_descriptions
-            )
-
-            # use default toolgroups
-            agent_config = AgentConfig(
-                model=model,
-                instructions=instruction,
-                toolgroups=builtin_toolgroups,
-                client_tools=[client_tool.get_tool_definition() for client_tool in client_tools],
-                tool_config={
-                    "tool_choice": "auto",
-                    "tool_prompt_format": "json" if "3.1" in model else "python_list",
-                    "system_message_behavior": "replace",
-                },
-                input_shields=[],
-                output_shields=[],
-                enable_session_persistence=False,
-            )
-        else:
-            agent_config = custom_agent_config
-
-        if json_response_format:
-            agent_config.response_format = {
+        # If custom instructions are not provided, use the default Reflexion instructions
+        if not instructions and not custom_agent_config:
+            # Convert tools to the format expected by get_default_reflexion_instructions if needed
+            if tools:
+                from ..agent import AgentUtils
+                client_tools_from_tools = AgentUtils.get_client_tools(tools)
+                builtin_toolgroups_from_tools = [x for x in tools if isinstance(x, (str, dict))]
+                instructions = get_default_reflexion_instructions(client, builtin_toolgroups_from_tools, client_tools_from_tools)
+            else:
+                # Fallback to deprecated parameters
+                instructions = get_default_reflexion_instructions(client, builtin_toolgroups, client_tools)
+
+        # If json_response_format is True and no custom response format is provided,
+        # set the response format to use the ReflexionOutput schema
+        if json_response_format and not response_format:
+            response_format = {
                 "type": "json_schema",
                 "json_schema": ReflexionOutput.model_json_schema(),
             }
 
+        # Initialize parent ReActAgent
         super().__init__(
             client=client,
             model=model,
-            agent_config=agent_config,
             tool_parser=tool_parser,
+            instructions=instructions,
+            tools=tools if tools is not None else builtin_toolgroups,  # Prefer new tools param, fallback to deprecated
+            tool_config=tool_config,
+            sampling_params=sampling_params,
+            max_infer_iters=max_infer_iters,
+            input_shields=input_shields,
+            output_shields=output_shields,
+            response_format=response_format,
+            enable_session_persistence=enable_session_persistence,
+            json_response_format=json_response_format,
             client_tools=client_tools,
+            custom_agent_config=custom_agent_config,
         )
 
     def create_turn(self, messages, session_id, stream=False, **kwargs):
@@ -148,6 +135,6 @@ def create_turn(self, messages, session_id, stream=False, **kwargs):
                     self.reflection_memory[session_id].append(reflexion_output.reflection)
             except Exception as e:
-                print(f"Failed to extract reflection: {e}")
+                logger.warning(f"Failed to extract reflection: {e}")
 
         return response
\ No newline at end of file
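
--
Usage sketch (not part of either patch): a minimal example of how the new agent might be driven after patch 2. The base URL, model id, and toolgroup id below are placeholders; substitute whatever your Llama Stack deployment actually registers.

    from llama_stack_client import LlamaStackClient
    from llama_stack_client.lib.agents.reflexion.agent import ReflexionAgent

    # Connect to a running Llama Stack server (placeholder URL).
    client = LlamaStackClient(base_url="http://localhost:5001")

    agent = ReflexionAgent(
        client=client,
        model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
        tools=["builtin::websearch"],  # placeholder toolgroup id
        json_response_format=True,  # constrain replies to the ReflexionOutput schema
    )

    session_id = agent.create_session("reflexion-demo")
    response = agent.create_turn(
        messages=[{"role": "user", "content": "What is the sum of primes below 20?"}],
        session_id=session_id,
        stream=False,
    )
    # Non-streaming turns return a Turn whose output_message carries the final
    # JSON (thought / attempt / reflection / improved_solution / final_answer).
    print(response.output_message.content)

Because reflections are kept in reflection_memory keyed by session, a second create_turn on the same session_id will see the first turn's reflection prepended as a system message before the user message.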