Skip to content

Fixed message order for JDBC Chat Memory #2781

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ void addGetAndClear_shouldAllExecute() {
assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).hasSize(1);
assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).isEqualTo(List.of(userMessage));

var assistantMessage = new AssistantMessage("Message from the assistant");

chatMemory.add(conversationId, List.of(assistantMessage));

assertThat(chatMemory.get(conversationId)).hasSize(2);
assertThat(chatMemory.get(conversationId)).isEqualTo(List.of(userMessage, assistantMessage));

chatMemory.clear(conversationId);

assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).isEmpty();
Expand Down Expand Up @@ -142,6 +149,13 @@ void useAutoConfiguredChatMemoryWithJdbc() {
assertThat(chatMemory.get(conversationId)).hasSize(1);
assertThat(chatMemory.get(conversationId)).isEqualTo(List.of(userMessage));

var assistantMessage = new AssistantMessage("Message from the assistant");

chatMemory.add(conversationId, List.of(assistantMessage));

assertThat(chatMemory.get(conversationId)).hasSize(2);
assertThat(chatMemory.get(conversationId)).isEqualTo(List.of(userMessage, assistantMessage));

chatMemory.clear(conversationId);

assertThat(chatMemory.get(conversationId)).isEmpty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.time.Instant;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;

import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
Expand All @@ -46,10 +50,10 @@
public class JdbcChatMemory implements ChatMemory {

private static final String QUERY_ADD = """
INSERT INTO ai_chat_memory (conversation_id, content, type) VALUES (?, ?, ?)""";
INSERT INTO ai_chat_memory (conversation_id, content, type, "timestamp") VALUES (?, ?, ?, ?)""";

private static final String QUERY_GET = """
SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" LIMIT ?""";
SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" DESC LIMIT ?""";

private static final String QUERY_CLEAR = "DELETE FROM ai_chat_memory WHERE conversation_id = ?";

Expand All @@ -70,23 +74,31 @@ public void add(String conversationId, List<Message> messages) {

@Override
public List<Message> get(String conversationId, int lastN) {
return this.jdbcTemplate.query(QUERY_GET, new MessageRowMapper(), conversationId, lastN);
List<Message> messages = this.jdbcTemplate.query(QUERY_GET, new MessageRowMapper(), conversationId, lastN);
Collections.reverse(messages);
return messages;
}

@Override
public void clear(String conversationId) {
this.jdbcTemplate.update(QUERY_CLEAR, conversationId);
}

private record AddBatchPreparedStatement(String conversationId,
List<Message> messages) implements BatchPreparedStatementSetter {
private record AddBatchPreparedStatement(String conversationId, List<Message> messages,
AtomicLong instantSeq) implements BatchPreparedStatementSetter {

private AddBatchPreparedStatement(String conversationId, List<Message> messages) {
this(conversationId, messages, new AtomicLong(Instant.now().toEpochMilli()));
}

@Override
public void setValues(PreparedStatement ps, int i) throws SQLException {
var message = this.messages.get(i);

ps.setString(1, this.conversationId);
ps.setString(2, message.getText());
ps.setString(3, message.getMessageType().name());
ps.setTimestamp(4, new Timestamp(instantSeq.getAndIncrement()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,11 @@
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;

/**
* An implementation of {@link ChatMemoryRepository} for JDBC.
Expand All @@ -44,7 +47,7 @@ public class JdbcChatMemoryRepository implements ChatMemoryRepository {
""";

private static final String QUERY_ADD = """
INSERT INTO ai_chat_memory (conversation_id, content, type) VALUES (?, ?, ?)
INSERT INTO ai_chat_memory (conversation_id, content, type, "timestamp") VALUES (?, ?, ?, ?)
""";

private static final String QUERY_GET = """
Expand Down Expand Up @@ -93,15 +96,21 @@ public void deleteByConversationId(String conversationId) {
this.jdbcTemplate.update(QUERY_CLEAR, conversationId);
}

private record AddBatchPreparedStatement(String conversationId,
List<Message> messages) implements BatchPreparedStatementSetter {
private record AddBatchPreparedStatement(String conversationId, List<Message> messages,
AtomicLong instantSeq) implements BatchPreparedStatementSetter {

private AddBatchPreparedStatement(String conversationId, List<Message> messages) {
this(conversationId, messages, new AtomicLong(Instant.now().toEpochMilli()));
}

@Override
public void setValues(PreparedStatement ps, int i) throws SQLException {
var message = this.messages.get(i);

ps.setString(1, this.conversationId);
ps.setString(2, message.getText());
ps.setString(3, message.getMessageType().name());
ps.setTimestamp(4, new Timestamp(instantSeq.getAndIncrement()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ CREATE TABLE IF NOT EXISTS ai_chat_memory (
conversation_id VARCHAR(36) NOT NULL,
content TEXT NOT NULL,
type VARCHAR(10) NOT NULL,
`timestamp` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
`timestamp` TIMESTAMP NOT NULL,
CONSTRAINT type_check CHECK (type IN ('USER', 'ASSISTANT', 'SYSTEM', 'TOOL'))
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ CREATE TABLE IF NOT EXISTS ai_chat_memory (
conversation_id VARCHAR(36) NOT NULL,
content TEXT NOT NULL,
type VARCHAR(10) NOT NULL CHECK (type IN ('USER', 'ASSISTANT', 'SYSTEM', 'TOOL')),
"timestamp" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
"timestamp" TIMESTAMP NOT NULL
);

CREATE INDEX IF NOT EXISTS ai_chat_memory_conversation_id_timestamp_idx
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,12 @@

package org.springframework.ai.chat.memory.jdbc;

import java.sql.Timestamp;
import java.util.List;
import java.util.UUID;

import javax.sql.DataSource;

import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.testcontainers.containers.PostgreSQLContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.MountableFile;

import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.messages.*;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
Expand All @@ -46,6 +31,15 @@
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Primary;
import org.springframework.jdbc.core.JdbcTemplate;
import org.testcontainers.containers.PostgreSQLContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.MountableFile;

import javax.sql.DataSource;
import java.sql.Timestamp;
import java.util.List;
import java.util.UUID;

import static org.assertj.core.api.Assertions.assertThat;

Expand Down Expand Up @@ -147,10 +141,11 @@ void get_shouldReturnMessages() {
this.contextRunner.run(context -> {
var chatMemory = context.getBean(ChatMemory.class);
var conversationId = UUID.randomUUID().toString();
var messages = List.<Message>of(new AssistantMessage("Message from assistant 1 - " + conversationId),
new AssistantMessage("Message from assistant 2 - " + conversationId),
new UserMessage("Message from user - " + conversationId),
new SystemMessage("Message from system - " + conversationId));
var messages = List.<Message>of(new SystemMessage("Message from system - " + conversationId),
new UserMessage("Message from user 1 - " + conversationId),
new AssistantMessage("Message from assistant 1 - " + conversationId),
new UserMessage("Message from user 2 - " + conversationId),
new AssistantMessage("Message from assistant 2 - " + conversationId));

chatMemory.add(conversationId, messages);

Expand All @@ -161,6 +156,24 @@ void get_shouldReturnMessages() {
});
}

@Test
void get_afterMultipleAdds_shouldReturnMessagesInSameOrder() {
this.contextRunner.run(context -> {
var chatMemory = context.getBean(ChatMemory.class);
var conversationId = UUID.randomUUID().toString();
var userMessage = new UserMessage("Message from user - " + conversationId);
var assistantMessage = new AssistantMessage("Message from assistant - " + conversationId);

chatMemory.add(conversationId, userMessage);
chatMemory.add(conversationId, assistantMessage);

var results = chatMemory.get(conversationId, Integer.MAX_VALUE);

assertThat(results.size()).isEqualTo(2);
assertThat(results).isEqualTo(List.of(userMessage, assistantMessage));
});
}

@Test
void clear_shouldDeleteMessages() {
this.contextRunner.run(context -> {
Expand Down