diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryPostgresqlAutoConfigurationIT.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryPostgresqlAutoConfigurationIT.java index 30caa5a23e..086688f6e4 100644 --- a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryPostgresqlAutoConfigurationIT.java +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryPostgresqlAutoConfigurationIT.java @@ -87,6 +87,13 @@ void addGetAndClear_shouldAllExecute() { assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).hasSize(1); assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).isEqualTo(List.of(userMessage)); + var assistantMessage = new AssistantMessage("Message from the assistant"); + + chatMemory.add(conversationId, List.of(assistantMessage)); + + assertThat(chatMemory.get(conversationId)).hasSize(2); + assertThat(chatMemory.get(conversationId)).isEqualTo(List.of(userMessage, assistantMessage)); + chatMemory.clear(conversationId); assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).isEmpty(); @@ -142,6 +149,13 @@ void useAutoConfiguredChatMemoryWithJdbc() { assertThat(chatMemory.get(conversationId)).hasSize(1); assertThat(chatMemory.get(conversationId)).isEqualTo(List.of(userMessage)); + var assistantMessage = new AssistantMessage("Message from the assistant"); + + chatMemory.add(conversationId, List.of(assistantMessage)); + + assertThat(chatMemory.get(conversationId)).hasSize(2); + assertThat(chatMemory.get(conversationId)).isEqualTo(List.of(userMessage, assistantMessage)); + chatMemory.clear(conversationId); assertThat(chatMemory.get(conversationId)).isEmpty(); diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemory.java b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemory.java index 51827ed0f6..86174c496f 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemory.java +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemory.java @@ -19,7 +19,11 @@ import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; +import java.sql.Timestamp; +import java.time.Instant; +import java.util.Collections; import java.util.List; +import java.util.concurrent.atomic.AtomicLong; import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.memory.MessageWindowChatMemory; @@ -46,10 +50,10 @@ public class JdbcChatMemory implements ChatMemory { private static final String QUERY_ADD = """ - INSERT INTO ai_chat_memory (conversation_id, content, type) VALUES (?, ?, ?)"""; + INSERT INTO ai_chat_memory (conversation_id, content, type, "timestamp") VALUES (?, ?, ?, ?)"""; private static final String QUERY_GET = """ - SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" LIMIT ?"""; + SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" DESC LIMIT ?"""; private static final String QUERY_CLEAR = "DELETE FROM ai_chat_memory WHERE conversation_id = ?"; @@ -70,7 +74,9 @@ public void add(String conversationId, List messages) { @Override public List get(String conversationId, int lastN) { - return this.jdbcTemplate.query(QUERY_GET, new MessageRowMapper(), conversationId, lastN); + List messages = this.jdbcTemplate.query(QUERY_GET, new MessageRowMapper(), conversationId, lastN); + Collections.reverse(messages); + return messages; } @Override @@ -78,8 +84,13 @@ public void clear(String conversationId) { this.jdbcTemplate.update(QUERY_CLEAR, conversationId); } - private record AddBatchPreparedStatement(String conversationId, - List messages) implements BatchPreparedStatementSetter { + private record AddBatchPreparedStatement(String conversationId, List messages, + AtomicLong instantSeq) implements BatchPreparedStatementSetter { + + private AddBatchPreparedStatement(String conversationId, List messages) { + this(conversationId, messages, new AtomicLong(Instant.now().toEpochMilli())); + } + @Override public void setValues(PreparedStatement ps, int i) throws SQLException { var message = this.messages.get(i); @@ -87,6 +98,7 @@ public void setValues(PreparedStatement ps, int i) throws SQLException { ps.setString(1, this.conversationId); ps.setString(2, message.getText()); ps.setString(3, message.getMessageType().name()); + ps.setTimestamp(4, new Timestamp(instantSeq.getAndIncrement())); } @Override diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepository.java b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepository.java index 09ba5f1653..01b94b6e06 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepository.java +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepository.java @@ -27,8 +27,11 @@ import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; +import java.sql.Timestamp; +import java.time.Instant; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.atomic.AtomicLong; /** * An implementation of {@link ChatMemoryRepository} for JDBC. @@ -44,7 +47,7 @@ public class JdbcChatMemoryRepository implements ChatMemoryRepository { """; private static final String QUERY_ADD = """ - INSERT INTO ai_chat_memory (conversation_id, content, type) VALUES (?, ?, ?) + INSERT INTO ai_chat_memory (conversation_id, content, type, "timestamp") VALUES (?, ?, ?, ?) """; private static final String QUERY_GET = """ @@ -93,8 +96,13 @@ public void deleteByConversationId(String conversationId) { this.jdbcTemplate.update(QUERY_CLEAR, conversationId); } - private record AddBatchPreparedStatement(String conversationId, - List messages) implements BatchPreparedStatementSetter { + private record AddBatchPreparedStatement(String conversationId, List messages, + AtomicLong instantSeq) implements BatchPreparedStatementSetter { + + private AddBatchPreparedStatement(String conversationId, List messages) { + this(conversationId, messages, new AtomicLong(Instant.now().toEpochMilli())); + } + @Override public void setValues(PreparedStatement ps, int i) throws SQLException { var message = this.messages.get(i); @@ -102,6 +110,7 @@ public void setValues(PreparedStatement ps, int i) throws SQLException { ps.setString(1, this.conversationId); ps.setString(2, message.getText()); ps.setString(3, message.getMessageType().name()); + ps.setTimestamp(4, new Timestamp(instantSeq.getAndIncrement())); } @Override diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-mariadb.sql b/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-mariadb.sql index 174c3b545f..976c920964 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-mariadb.sql +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-mariadb.sql @@ -2,7 +2,7 @@ CREATE TABLE IF NOT EXISTS ai_chat_memory ( conversation_id VARCHAR(36) NOT NULL, content TEXT NOT NULL, type VARCHAR(10) NOT NULL, - `timestamp` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + `timestamp` TIMESTAMP NOT NULL, CONSTRAINT type_check CHECK (type IN ('USER', 'ASSISTANT', 'SYSTEM', 'TOOL')) ); diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql b/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql index 31a9f301e0..bc12680145 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql @@ -2,7 +2,7 @@ CREATE TABLE IF NOT EXISTS ai_chat_memory ( conversation_id VARCHAR(36) NOT NULL, content TEXT NOT NULL, type VARCHAR(10) NOT NULL CHECK (type IN ('USER', 'ASSISTANT', 'SYSTEM', 'TOOL')), - "timestamp" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + "timestamp" TIMESTAMP NOT NULL ); CREATE INDEX IF NOT EXISTS ai_chat_memory_conversation_id_timestamp_idx diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryIT.java b/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryIT.java index 96b0e7ca5f..990a088c82 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryIT.java +++ b/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryIT.java @@ -16,27 +16,12 @@ package org.springframework.ai.chat.memory.jdbc; -import java.sql.Timestamp; -import java.util.List; -import java.util.UUID; - -import javax.sql.DataSource; - import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; -import org.testcontainers.containers.PostgreSQLContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.utility.MountableFile; - import org.springframework.ai.chat.memory.ChatMemory; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.MessageType; -import org.springframework.ai.chat.messages.SystemMessage; -import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.messages.*; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; @@ -46,6 +31,15 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Primary; import org.springframework.jdbc.core.JdbcTemplate; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.MountableFile; + +import javax.sql.DataSource; +import java.sql.Timestamp; +import java.util.List; +import java.util.UUID; import static org.assertj.core.api.Assertions.assertThat; @@ -147,10 +141,11 @@ void get_shouldReturnMessages() { this.contextRunner.run(context -> { var chatMemory = context.getBean(ChatMemory.class); var conversationId = UUID.randomUUID().toString(); - var messages = List.of(new AssistantMessage("Message from assistant 1 - " + conversationId), - new AssistantMessage("Message from assistant 2 - " + conversationId), - new UserMessage("Message from user - " + conversationId), - new SystemMessage("Message from system - " + conversationId)); + var messages = List.of(new SystemMessage("Message from system - " + conversationId), + new UserMessage("Message from user 1 - " + conversationId), + new AssistantMessage("Message from assistant 1 - " + conversationId), + new UserMessage("Message from user 2 - " + conversationId), + new AssistantMessage("Message from assistant 2 - " + conversationId)); chatMemory.add(conversationId, messages); @@ -161,6 +156,24 @@ void get_shouldReturnMessages() { }); } + @Test + void get_afterMultipleAdds_shouldReturnMessagesInSameOrder() { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + var userMessage = new UserMessage("Message from user - " + conversationId); + var assistantMessage = new AssistantMessage("Message from assistant - " + conversationId); + + chatMemory.add(conversationId, userMessage); + chatMemory.add(conversationId, assistantMessage); + + var results = chatMemory.get(conversationId, Integer.MAX_VALUE); + + assertThat(results.size()).isEqualTo(2); + assertThat(results).isEqualTo(List.of(userMessage, assistantMessage)); + }); + } + @Test void clear_shouldDeleteMessages() { this.contextRunner.run(context -> {