diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java index 2ab1a9517e9..1b74bfa9208 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java @@ -34,6 +34,7 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.document.MetadataMode; import org.springframework.ai.image.ImageMessage; +import org.springframework.ai.image.ImageOptionsBuilder; import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.openai.OpenAiAudioTranscriptionModel; import org.springframework.ai.openai.OpenAiAudioTranscriptionOptions; @@ -250,7 +251,8 @@ public void openAiImageTransientError() { .willThrow(new TransientAiException("Transient Error 2")) .willReturn(ResponseEntity.of(Optional.of(expectedResponse))); - var result = this.imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message")))); + var result = this.imageModel + .call(new ImagePrompt(List.of(new ImageMessage("Image Message")), ImageOptionsBuilder.builder().build())); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getUrl()).isEqualTo("url678"); @@ -262,8 +264,8 @@ public void openAiImageTransientError() { public void openAiImageNonTransientError() { given(this.openAiImageApi.createImage(isA(OpenAiImageRequest.class))) .willThrow(new RuntimeException("Transient Error 1")); - assertThrows(RuntimeException.class, - () -> this.imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message"))))); + assertThrows(RuntimeException.class, () -> this.imageModel + .call(new ImagePrompt(List.of(new ImageMessage("Image Message")), ImageOptionsBuilder.builder().build()))); } private static class TestRetryListener implements RetryListener { diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java index 6fec96a8ba3..1ecf46c78c8 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java @@ -110,8 +110,7 @@ void whenPromptWithMessagesThenReturn() { assertThat(spec.getMessages()).hasSize(2); assertThat(spec.getMessages().get(0).getText()).isEqualTo("instructions"); assertThat(spec.getMessages().get(1).getText()).isEqualTo("my question"); - assertThat(spec.getChatOptions()).isNotNull(); - assertThat(spec.getChatOptions()).isInstanceOf(ChatOptions.class); + assertThat(spec.getChatOptions()).isNull(); } @Test diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java index 8d3d4f83182..26a1413833c 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java @@ -47,7 +47,8 @@ public class Prompt implements ModelRequest> { private final List messages; - private final ChatOptions chatOptions; + @Nullable + private ChatOptions chatOptions; public Prompt(String contents) { this(new UserMessage(contents)); @@ -58,26 +59,26 @@ public Prompt(Message message) { } public Prompt(List messages) { - this(messages, ChatOptions.builder().build()); + this(messages, null); } public Prompt(Message... messages) { - this(Arrays.asList(messages), ChatOptions.builder().build()); + this(Arrays.asList(messages), null); } - public Prompt(String contents, ChatOptions chatOptions) { + public Prompt(String contents, @Nullable ChatOptions chatOptions) { this(new UserMessage(contents), chatOptions); } - public Prompt(Message message, ChatOptions chatOptions) { + public Prompt(Message message, @Nullable ChatOptions chatOptions) { this(Collections.singletonList(message), chatOptions); } - public Prompt(List messages, ChatOptions chatOptions) { + public Prompt(List messages, @Nullable ChatOptions chatOptions) { Assert.notNull(messages, "messages cannot be null"); Assert.noNullElements(messages, "messages cannot contain null elements"); this.messages = messages; - this.chatOptions = (chatOptions != null) ? chatOptions : ChatOptions.builder().build(); + this.chatOptions = chatOptions; } public String getContents() { @@ -89,6 +90,7 @@ public String getContents() { } @Override + @Nullable public ChatOptions getOptions() { return this.chatOptions; } @@ -134,7 +136,7 @@ public int hashCode() { } public Prompt copy() { - return new Prompt(instructionsCopy(), this.chatOptions.copy()); + return new Prompt(instructionsCopy(), null == this.chatOptions ? null : this.chatOptions.copy()); } private List instructionsCopy() { @@ -196,7 +198,9 @@ public Prompt augmentUserMessage(String newUserText) { public Builder mutate() { Builder builder = new Builder().messages(instructionsCopy()); - builder.chatOptions(this.chatOptions.copy()); + if (this.chatOptions != null) { + builder.chatOptions(this.chatOptions.copy()); + } return builder; } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/image/ImagePrompt.java b/spring-ai-model/src/main/java/org/springframework/ai/image/ImagePrompt.java index 6cc6efd8d92..a212c2cf4f5 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/image/ImagePrompt.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/image/ImagePrompt.java @@ -29,7 +29,12 @@ public class ImagePrompt implements ModelRequest> { private ImageOptions imageModelOptions; public ImagePrompt(List messages) { - this(messages, ImageOptionsBuilder.builder().build()); + this.messages = messages; + } + + public ImagePrompt(List messages, ImageOptions imageModelOptions) { + this.messages = messages; + this.imageModelOptions = imageModelOptions; } public ImagePrompt(ImageMessage imageMessage, ImageOptions imageOptions) { @@ -44,11 +49,6 @@ public ImagePrompt(String instructions) { this(new ImageMessage(instructions), ImageOptionsBuilder.builder().build()); } - public ImagePrompt(List messages, ImageOptions imageModelOptions) { - this.messages = messages; - this.imageModelOptions = imageModelOptions != null ? imageModelOptions : ImageOptionsBuilder.builder().build(); - } - @Override public List getInstructions() { return this.messages;