Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,21 @@ def __init__(self, key: str = "langchain_messages"):

if key not in st.session_state:
st.session_state[key] = []
self._messages = st.session_state[key]
self._key = key

@property
def messages(self) -> List[BaseMessage]:
"""Retrieve the current list of messages"""
import streamlit as st

return st.session_state[self._key]
return self._messages

@messages.setter
def messages(self, value: List[BaseMessage]) -> None:
"""Set the messages list with a new value"""
import streamlit as st

st.session_state[self._key] = value
self._messages = st.session_state[self._key]

def add_message(self, message: BaseMessage) -> None:
"""Add a message to the session memory"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import streamlit as st
from langchain.memory import ConversationBufferMemory
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
from langchain_core.messages import message_to_dict
from langchain_core.messages import message_to_dict, BaseMessage

message_history = StreamlitChatMessageHistory()
memory = ConversationBufferMemory(chat_memory=message_history, return_messages=True)
Expand All @@ -23,6 +23,15 @@
st.markdown("Cleared!")
memory.chat_memory.clear()

# Use message setter
if st.checkbox("Override messages"):
memory.chat_memory.messages = [
BaseMessage(content="A basic message", type="basic")
]
st.session_state["langchain_messages"].append(
BaseMessage(content="extra cool message", type="basic")
)

# Write the output to st.code as a json blob for inspection
messages = memory.chat_memory.messages
messages_json = json.dumps([message_to_dict(msg) for msg in messages])
Expand All @@ -33,32 +42,33 @@
@pytest.mark.requires("streamlit")
def test_memory_with_message_store() -> None:
try:
from streamlit.testing.script_interactions import InteractiveScriptTests
from streamlit.testing.v1 import AppTest
except ModuleNotFoundError:
pytest.skip("Incorrect version of Streamlit installed")

test_handler = InteractiveScriptTests()
test_handler.setUp()
try:
sr = test_handler.script_from_string(test_script).run()
except TypeError:
# Earlier version expected 2 arguments
sr = test_handler.script_from_string("memory_test.py", test_script).run()
at = AppTest.from_string(test_script).run(timeout=10)

# Initial run should write two messages
messages_json = sr.get("text")[-1].value
messages_json = at.get("text")[-1].value
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json

# Uncheck the initial write, they should persist in session_state
sr = sr.get("checkbox")[0].uncheck().run()
assert sr.get("markdown")[0].value == "Skipped add"
messages_json = sr.get("text")[-1].value
at.get("checkbox")[0].uncheck().run()
assert at.get("markdown")[0].value == "Skipped add"
messages_json = at.get("text")[-1].value
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json

# Clear the message history
sr = sr.get("checkbox")[1].check().run()
assert sr.get("markdown")[1].value == "Cleared!"
messages_json = sr.get("text")[-1].value
at.get("checkbox")[1].check().run()
assert at.get("markdown")[1].value == "Cleared!"
messages_json = at.get("text")[-1].value
assert messages_json == "[]"

# Use message setter
at.get("checkbox")[1].uncheck()
at.get("checkbox")[2].check().run()
messages_json = at.get("text")[-1].value
assert "A basic message" in messages_json
assert "extra cool message" in messages_json