Feat: Improve mem0 memory tool to support storing short-term memories #96

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 47 additions & 23 deletions src/strands_tools/mem0_memory.py
@@ -73,7 +73,6 @@
import boto3
from mem0 import Memory as Mem0Memory
from mem0 import MemoryClient
from opensearchpy import AWSV4SignerAuth, RequestsHttpConnection
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
@@ -91,20 +90,20 @@
"description": (
"Memory management tool for storing, retrieving, and managing memories in Mem0.\n\n"
"Features:\n"
"1. Store memories with metadata (requires user_id or agent_id)\n"
"2. Retrieve memories by ID or semantic search (requires user_id or agent_id)\n"
"3. List all memories for a user/agent (requires user_id or agent_id)\n"
"1. Store memories with metadata (requires user_id or agent_id or run_id)\n"
"2. Retrieve memories by ID or semantic search (requires user_id or agent_id or run_id)\n"
"3. List all memories for a user/agent (requires user_id or agent_id or run_id)\n"
"4. Delete memories\n"
"5. Get memory history\n\n"
"Actions:\n"
"- store: Store new memory (requires user_id or agent_id)\n"
"- store: Store new memory (requires user_id or agent_id or run_id)\n"
Member:
Nit: Any reason to write this as (requires user_id or agent_id or run_id) instead of requires user_id, agent_id, or run_id? Not sure if there is some LLM preference here? The second option seems more readable.

"- get: Get memory by ID\n"
"- list: List all memories (requires user_id or agent_id)\n"
"- retrieve: Semantic search (requires user_id or agent_id)\n"
"- list: List all memories (requires user_id or agent_id or run_id)\n"
"- retrieve: Semantic search (requires user_id or agent_id or run_id)\n"
"- delete: Delete memory\n"
"- history: Get memory history\n\n"
"Note: Most operations require either user_id or agent_id to be specified. The tool will automatically "
"attempt to retrieve relevant memories when user_id or agent_id is available."
"Note: Most operations require either user_id or agent_id or run_id to be specified. The tool will "
"automatically attempt to retrieve relevant memories when user_id or agent_id or run_id is available."
),
"inputSchema": {
"json": {
@@ -135,12 +134,16 @@
"type": "string",
"description": "Agent ID for the memory operations (required for store, list, retrieve actions)",
},
"run_id": {
"type": "string",
"description": "Run/Session ID for memory operations (required for store, list, retrieve actions)",
},
"metadata": {
"type": "object",
"description": "Optional metadata to store with the memory",
},
},
"required": ["action"]
"required": ["action"],
}
},
}
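For reviewers who want to try the new run_id support end to end, here is a minimal usage sketch, assuming the tool is loaded into a Strands agent and invoked with the direct agent.tool syntax; the agent wiring and the session identifier are illustrative and not part of this diff:

```python
# Illustrative sketch (not part of this PR): exercising the new run_id field to
# store and retrieve a short-term, session-scoped memory.
from strands import Agent
from strands_tools import mem0_memory

agent = Agent(tools=[mem0_memory])

# Store a memory tied to the current run/session rather than a user or agent.
agent.tool.mem0_memory(
    action="store",
    content="User prefers metric units for this session",
    run_id="session-1234",  # hypothetical session identifier
)

# Later in the same run, semantic search scoped to that session.
agent.tool.mem0_memory(
    action="retrieve",
    query="unit preferences",
    run_id="session-1234",
)
```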
@@ -166,7 +169,6 @@ class Mem0ServiceClient:
"collection_name": "mem0_memories",
"host": os.environ.get("OPENSEARCH_HOST"),
"embedding_model_dims": 1024,
"connection_class": RequestsHttpConnection,
"pool_maxsize": 20,
"use_ssl": True,
"verify_certs": True,
@@ -216,7 +218,18 @@ def _initialize_opensearch_client(self, config: Optional[Dict] = None) -> Mem0Me
Returns:
An initialized Mem0Memory instance configured for OpenSearch.
Raises:
ImportError: If opensearchpy package is not installed.
"""
try:
from opensearchpy import AWSV4SignerAuth, RequestsHttpConnection
Member:
I had trouble installing faiss before, so I can understand that we didn't include it as a top-level import, but I don't see anything wrong with including this at the top level. We have it specified in the mem0_memory optional dependencies here: https://github.com/strands-agents/tools/blob/main/pyproject.toml#L77

Collaborator (author):
Yeah, I was thinking that we could make mem0 a default dependency in the tools repo and make faiss and opensearch optional, so that anyone who only wants to use the mem0 platform doesn't have to install faiss or opensearch. Let me know what you think.

Member:
If someone wants to use the mem0 tool, they should be able to use this optional dependency and have the tool just work. I had trouble installing faiss, so I'm guessing others might as well, which is why I'm fine leaving that out with a runtime check. But for the other two, we should just include them in the optional dependency on mem0.

Since we already have this optional dependency for mem0, I'm hesitant to remove it, and I don't think it makes sense to move the mem0 dependency to a default if we already have this optional dependency.

except ImportError as err:
raise ImportError(
"The opensearchpy package is required for using OpenSearch as the vector store backend for Mem0. "
"Please install it using: pip install opensearchpy"
) from err

# Set up AWS region
self.region = os.environ.get("AWS_REGION", "us-west-2")
if not os.environ.get("AWS_REGION"):
@@ -230,6 +243,8 @@ def _initialize_opensearch_client(self, config: Optional[Dict] = None) -> Mem0Me
# Prepare configuration
merged_config = self._merge_config(config)
merged_config["vector_store"]["config"].update({"http_auth": auth, "host": os.environ["OPENSEARCH_HOST"]})
# Set the connection_class dynamically
merged_config["vector_store"]["config"]["connection_class"] = RequestsHttpConnection

return Mem0Memory.from_config(config_dict=merged_config)
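As a side note for the OpenSearch path above, a rough sketch of the environment and config override this initializer expects; the Mem0ServiceClient constructor signature is inferred from the _merge_config call and the hostname is a placeholder, so treat this as an assumption rather than a confirmed API:

```python
# Assumed setup for the OpenSearch backend: opensearchpy must be installed
# (optional dependency), OPENSEARCH_HOST must be set, and AWS_REGION falls back
# to us-west-2 when unset. Hostname and constructor usage are illustrative.
import os

os.environ["OPENSEARCH_HOST"] = "search-my-domain.us-west-2.es.amazonaws.com"
os.environ["AWS_REGION"] = "us-west-2"

from strands_tools.mem0_memory import Mem0ServiceClient

# Keys passed here are merged over the defaults shown above (collection_name,
# embedding_model_dims, pool_maxsize, use_ssl, verify_certs, ...); http_auth and
# connection_class are filled in at initialization time.
client = Mem0ServiceClient(
    config={"vector_store": {"config": {"collection_name": "short_term_memories"}}}
)
```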

@@ -291,32 +306,37 @@ def store_memory(
content: str,
user_id: Optional[str] = None,
agent_id: Optional[str] = None,
run_id: Optional[str] = None,
metadata: Optional[Dict] = None,
):
"""Store a memory in Mem0."""
if not user_id and not agent_id:
raise ValueError("Either user_id or agent_id must be provided")
if not user_id and not agent_id and not run_id:
raise ValueError("Either user_id or agent_id or run_id must be provided")

messages = [{"role": "user", "content": content}]
return self.mem0.add(messages, user_id=user_id, agent_id=agent_id, metadata=metadata)
return self.mem0.add(messages, user_id=user_id, agent_id=agent_id, run_id=run_id, metadata=metadata)

def get_memory(self, memory_id: str):
"""Get a memory by ID."""
return self.mem0.get(memory_id)

def list_memories(self, user_id: Optional[str] = None, agent_id: Optional[str] = None):
def list_memories(
self, user_id: Optional[str] = None, agent_id: Optional[str] = None, run_id: Optional[str] = None
):
"""List all memories for a user or agent."""
if not user_id and not agent_id:
raise ValueError("Either user_id or agent_id must be provided")
if not user_id and not agent_id and not run_id:
raise ValueError("Either user_id or agent_id or run_id must be provided")

return self.mem0.get_all(user_id=user_id, agent_id=agent_id)
return self.mem0.get_all(user_id=user_id, agent_id=agent_id, run_id=run_id)

def search_memories(self, query: str, user_id: Optional[str] = None, agent_id: Optional[str] = None):
def search_memories(
self, query: str, user_id: Optional[str] = None, agent_id: Optional[str] = None, run_id: Optional[str] = None
):
"""Search memories using semantic search."""
if not user_id and not agent_id:
raise ValueError("Either user_id or agent_id must be provided")
if not user_id and not agent_id and not run_id:
raise ValueError("Either user_id or agent_id or run_id must be provided")

return self.mem0.search(query=query, user_id=user_id, agent_id=agent_id)
return self.mem0.search(query=query, user_id=user_id, agent_id=agent_id, run_id=run_id)

def delete_memory(self, memory_id: str):
"""Delete a memory by ID."""
@@ -583,6 +603,7 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult:
tool_input["content"],
tool_input.get("user_id"),
tool_input.get("agent_id"),
tool_input.get("run_id"),
tool_input.get("metadata"),
)

@@ -609,7 +630,9 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult:
)

elif action == "list":
memories = client.list_memories(tool_input.get("user_id"), tool_input.get("agent_id"))
memories = client.list_memories(
tool_input.get("user_id"), tool_input.get("agent_id"), tool_input.get("run_id")
)
# Normalize to list
results_list = memories if isinstance(memories, list) else memories.get("results", [])
panel = format_list_response(results_list)
@@ -628,6 +651,7 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult:
tool_input["query"],
tool_input.get("user_id"),
tool_input.get("agent_id"),
tool_input.get("run_id"),
)
# Normalize to list
results_list = memories if isinstance(memories, list) else memories.get("results", [])
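Finally, for the retrieve path in this last hunk, a hedged sketch of calling the tool handler directly with a run_id; the exact ToolUse dict keys are assumed from the usual Strands tool conventions, and a configured mem0 backend (for example MEM0_API_KEY or OPENSEARCH_HOST) is presumed to be in place:

```python
# Illustrative direct call to the tool handler. The "toolUseId"/"input" keys are
# assumed from Strands' ToolUse shape; identifiers are placeholders.
from strands_tools.mem0_memory import mem0_memory

result = mem0_memory(
    tool={
        "toolUseId": "tooluse-001",
        "input": {
            "action": "retrieve",
            "query": "unit preferences",
            "run_id": "session-1234",
        },
    }
)
# Whatever shape the backend returns, the handler normalizes it to a plain list
# of memories before formatting, as the diff above shows.
```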