Skip to content

Feat messages builder #427

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ void roleTest() {
You should reply to the user's request with your name and also in the style of a {voice}.
""").createMessage(Map.of("name", "Bob", "voice", "pirate"));

UserMessage userMessage = new UserMessage("Generate the names of 5 famous pirates.");
UserMessage userMessage = UserMessage.builder().withContent("Generate the names of 5 famous pirates.").build();

Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
ChatResponse response = chatClient.call(prompt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ class AzureOpenAiChatClientFunctionCallIT {
@Test
void functionCallTest() {

UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, in Tokyo, and in Paris?");
UserMessage userMessage = UserMessage.builder()
.withContent("What's the weather like in San Francisco, in Tokyo, and in Paris?")
.build();

List<Message> messages = new ArrayList<>(List.of(userMessage));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@ class BedrockAnthropicChatClientIT {

@Test
void roleTest() {
UserMessage userMessage = new UserMessage(
"Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.");
UserMessage userMessage = UserMessage.builder()
.withContent("Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.")
.build();
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ void roleTest() {
String request = "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.";
String name = "Bob";
String voice = "pirate";
UserMessage userMessage = new UserMessage(request);
UserMessage userMessage = UserMessage.builder().withContent(request).build();
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice));
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ class BedrockLlama2ChatClientIT {

@Test
void roleTest() {
UserMessage userMessage = new UserMessage(
"Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.");
UserMessage userMessage = UserMessage.builder()
.withContent("Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.")
.build();
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ void roleTest() {
String request = "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.";
String name = "Bob";
String voice = "pirate";
UserMessage userMessage = new UserMessage(request);
UserMessage userMessage = UserMessage.builder().withContent(request).build();
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice));
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ class MistralAiChatClientIT {

@Test
void roleTest() {
UserMessage userMessage = new UserMessage(
"Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.");
UserMessage userMessage = UserMessage.builder()
.withContent("Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.")
.build();
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
// NOTE: Mistral expects the system message to be before the user message or will
Expand Down Expand Up @@ -188,7 +189,9 @@ void beanStreamOutputParserRecords() {
@Test
void functionCallTest() {

UserMessage userMessage = new UserMessage("What's the weather like in San Francisco?");
UserMessage userMessage = UserMessage.builder()
.withContent("What's the weather like in San Francisco?")
.build();

List<Message> messages = new ArrayList<>(List.of(userMessage));

Expand All @@ -211,7 +214,7 @@ void functionCallTest() {
@Test
void streamFunctionCallTest() {

UserMessage userMessage = new UserMessage("What's the weather like in Tokyo, Japan?");
UserMessage userMessage = UserMessage.builder().withContent("What's the weather like in Tokyo, Japan?").build();

List<Message> messages = new ArrayList<>(List.of(userMessage));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ void roleTest() {
You should reply to the user's request with your name and also in the style of a {voice}.
""").createMessage(Map.of("name", "Bob", "voice", "pirate"));

UserMessage userMessage = new UserMessage("Tell me about 5 famous pirates from the Golden Age of Piracy.");
UserMessage userMessage = UserMessage.builder()
.withContent("Tell me about 5 famous pirates from the Golden Age of Piracy.")
.build();

// portable/generic options
var portableOptions = ChatOptionsBuilder.builder().withTemperature(0.7f).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ void acmeChain() {
// be relevant.

Message systemMessage = getSystemMessage(similarDocuments);
UserMessage userMessage = new UserMessage(userQuery);
UserMessage userMessage = UserMessage.builder().withContent(userQuery).build();

// Create the prompt ad-hoc for now, need to put in system message and user
// message via ChatPromptTemplate or some other message building mechanic;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ class OpenAiChatClientIT extends AbstractIT {

@Test
void roleTest() {
UserMessage userMessage = new UserMessage(
"Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.");
UserMessage userMessage = UserMessage.builder()
.withContent("Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.")
.build();
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
Expand Down Expand Up @@ -184,7 +185,9 @@ void beanStreamOutputParserRecords() {
@Test
void functionCallTest() {

UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
UserMessage userMessage = UserMessage.builder()
.withContent("What's the weather like in San Francisco, Tokyo, and Paris?")
.build();

List<Message> messages = new ArrayList<>(List.of(userMessage));

Expand All @@ -209,7 +212,9 @@ void functionCallTest() {
@Test
void streamFunctionCallTest() {

UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
UserMessage userMessage = UserMessage.builder()
.withContent("What's the weather like in San Francisco, Tokyo, and Paris?")
.build();

List<Message> messages = new ArrayList<>(List.of(userMessage));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,19 @@ protected void evaluateQuestionAndAnswer(String question, ChatResponse response,
Map.of("question", question, "answer", answer));
SystemMessage systemMessage;
if (factBased) {
systemMessage = new SystemMessage(qaEvalutaorFactBasedAnswerResource);
systemMessage = SystemMessage.builder().withResource(qaEvalutaorFactBasedAnswerResource).build();
}
else {
systemMessage = new SystemMessage(qaEvaluatorAccurateAnswerResource);
systemMessage = SystemMessage.builder().withResource(qaEvaluatorAccurateAnswerResource).build();
}
Message userMessage = userPromptTemplate.createMessage();
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
String yesOrNo = openAiChatClient.call(prompt).getResult().getOutput().getContent();
logger.info("Is Answer related to question: " + yesOrNo);
if (yesOrNo.equalsIgnoreCase("no")) {
SystemMessage notRelatedSystemMessage = new SystemMessage(qaEvaluatorNotRelatedResource);
SystemMessage notRelatedSystemMessage = SystemMessage.builder()
.withResource(qaEvaluatorNotRelatedResource)
.build();
prompt = new Prompt(List.of(userMessage, notRelatedSystemMessage));
String reasonForFailure = openAiChatClient.call(prompt).getResult().getOutput().getContent();
fail(reasonForFailure);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ void roleTest() {
String request = "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.";
String name = "Bob";
String voice = "pirate";
UserMessage userMessage = new UserMessage(request);
UserMessage userMessage = UserMessage.builder().withContent(request).build();
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice));
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
Expand Down Expand Up @@ -188,8 +188,10 @@ void multiModalityTest() throws IOException {

byte[] data = new ClassPathResource("/vertex.test.png").getContentAsByteArray();

var userMessage = new UserMessage("Explain what do you see o this picture?",
List.of(new Media(MimeTypeUtils.IMAGE_PNG, data)));
var userMessage = UserMessage.builder()
.withContent("Explain what do you see o this picture?")
.withMedia(List.of(new Media(MimeTypeUtils.IMAGE_PNG, data)))
.build();

ChatResponse response = client.call(new Prompt(List.of(userMessage)));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,10 @@ public void afterEach() {
@Test
public void functionCallExplicitOpenApiSchema() {

UserMessage userMessage = new UserMessage(
"What's the weather like in San Francisco, in Paris and in Tokyo, Japan? Use Multi-turn function calling. Provide answer for all requested locations.");
UserMessage userMessage = UserMessage.builder()
.withContent(
"What's the weather like in San Francisco, in Paris and in Tokyo, Japan? Use Multi-turn function calling. Provide answer for all requested locations.")
.build();

List<Message> messages = new ArrayList<>(List.of(userMessage));

Expand Down Expand Up @@ -116,7 +118,7 @@ public void functionCallTestInferredOpenApiSchema() {

// UserMessage userMessage = new UserMessage("What's the weather like in San
// Francisco, Paris and Tokyo?");
UserMessage userMessage = new UserMessage("What's the weather like in Paris?");
UserMessage userMessage = UserMessage.builder().withContent("What's the weather like in Paris?").build();

List<Message> messages = new ArrayList<>(List.of(userMessage));

Expand Down Expand Up @@ -145,8 +147,10 @@ public void functionCallTestInferredOpenApiSchema() {
@Test
public void functionCallTestInferredOpenApiSchemaStream() {

UserMessage userMessage = new UserMessage(
"What's the weather like in San Francisco, in Paris and in Tokyo? Use Multi-turn function calling.");
UserMessage userMessage = UserMessage.builder()
.withContent(
"What's the weather like in San Francisco, in Paris and in Tokyo? Use Multi-turn function calling.")
.build();

List<Message> messages = new ArrayList<>(List.of(userMessage));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ void roleTest() {
String request = "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.";
String name = "Bob";
String voice = "pirate";
UserMessage userMessage = new UserMessage(request);
UserMessage userMessage = UserMessage.builder().withContent(request).build();
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice));
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
public interface ChatClient extends ModelClient<Prompt, ChatResponse> {

default String call(String message) {
Prompt prompt = new Prompt(new UserMessage(message));
Prompt prompt = new Prompt(UserMessage.builder().withContent(message).build());
Generation generation = call(prompt).getResult();
return (generation != null) ? generation.getOutput().getContent() : "";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ public class Generation implements ModelResult<AssistantMessage> {
private ChatGenerationMetadata chatGenerationMetadata;

public Generation(String text) {
this.assistantMessage = new AssistantMessage(text);
this.assistantMessage = AssistantMessage.builder().withContent(text).build();
}

public Generation(String text, Map<String, Object> properties) {
this.assistantMessage = new AssistantMessage(text, properties);
this.assistantMessage = AssistantMessage.builder().withContent(text).withProperties(properties).build();
}

@Override
Expand Down
Loading