Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.image.ImageMessage;
import org.springframework.ai.image.ImageOptionsBuilder;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.openai.OpenAiAudioTranscriptionModel;
import org.springframework.ai.openai.OpenAiAudioTranscriptionOptions;
Expand Down Expand Up @@ -250,7 +251,8 @@ public void openAiImageTransientError() {
.willThrow(new TransientAiException("Transient Error 2"))
.willReturn(ResponseEntity.of(Optional.of(expectedResponse)));

var result = this.imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message"))));
var result = this.imageModel
.call(new ImagePrompt(List.of(new ImageMessage("Image Message")), ImageOptionsBuilder.builder().build()));

assertThat(result).isNotNull();
assertThat(result.getResult().getOutput().getUrl()).isEqualTo("url678");
Expand All @@ -262,8 +264,8 @@ public void openAiImageTransientError() {
public void openAiImageNonTransientError() {
given(this.openAiImageApi.createImage(isA(OpenAiImageRequest.class)))
.willThrow(new RuntimeException("Transient Error 1"));
assertThrows(RuntimeException.class,
() -> this.imageModel.call(new ImagePrompt(List.of(new ImageMessage("Image Message")))));
assertThrows(RuntimeException.class, () -> this.imageModel
.call(new ImagePrompt(List.of(new ImageMessage("Image Message")), ImageOptionsBuilder.builder().build())));
}

private static class TestRetryListener implements RetryListener {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,7 @@ void whenPromptWithMessagesThenReturn() {
assertThat(spec.getMessages()).hasSize(2);
assertThat(spec.getMessages().get(0).getText()).isEqualTo("instructions");
assertThat(spec.getMessages().get(1).getText()).isEqualTo("my question");
assertThat(spec.getChatOptions()).isNotNull();
assertThat(spec.getChatOptions()).isInstanceOf(ChatOptions.class);
assertThat(spec.getChatOptions()).isNull();
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ public class Prompt implements ModelRequest<List<Message>> {

private final List<Message> messages;

private final ChatOptions chatOptions;
@Nullable
private ChatOptions chatOptions;

public Prompt(String contents) {
this(new UserMessage(contents));
Expand All @@ -58,26 +59,26 @@ public Prompt(Message message) {
}

public Prompt(List<Message> messages) {
this(messages, ChatOptions.builder().build());
this(messages, null);
}

public Prompt(Message... messages) {
this(Arrays.asList(messages), ChatOptions.builder().build());
this(Arrays.asList(messages), null);
}

public Prompt(String contents, ChatOptions chatOptions) {
public Prompt(String contents, @Nullable ChatOptions chatOptions) {
this(new UserMessage(contents), chatOptions);
}

public Prompt(Message message, ChatOptions chatOptions) {
public Prompt(Message message, @Nullable ChatOptions chatOptions) {
this(Collections.singletonList(message), chatOptions);
}

public Prompt(List<Message> messages, ChatOptions chatOptions) {
public Prompt(List<Message> messages, @Nullable ChatOptions chatOptions) {
Assert.notNull(messages, "messages cannot be null");
Assert.noNullElements(messages, "messages cannot contain null elements");
this.messages = messages;
this.chatOptions = (chatOptions != null) ? chatOptions : ChatOptions.builder().build();
this.chatOptions = chatOptions;
}

public String getContents() {
Expand All @@ -89,6 +90,7 @@ public String getContents() {
}

@Override
@Nullable
public ChatOptions getOptions() {
return this.chatOptions;
}
Expand Down Expand Up @@ -134,7 +136,7 @@ public int hashCode() {
}

public Prompt copy() {
return new Prompt(instructionsCopy(), this.chatOptions.copy());
return new Prompt(instructionsCopy(), null == this.chatOptions ? null : this.chatOptions.copy());
}

private List<Message> instructionsCopy() {
Expand Down Expand Up @@ -196,7 +198,9 @@ public Prompt augmentUserMessage(String newUserText) {

public Builder mutate() {
Builder builder = new Builder().messages(instructionsCopy());
builder.chatOptions(this.chatOptions.copy());
if (this.chatOptions != null) {
builder.chatOptions(this.chatOptions.copy());
}
return builder;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@ public class ImagePrompt implements ModelRequest<List<ImageMessage>> {
private ImageOptions imageModelOptions;

public ImagePrompt(List<ImageMessage> messages) {
this(messages, ImageOptionsBuilder.builder().build());
this.messages = messages;
}

public ImagePrompt(List<ImageMessage> messages, ImageOptions imageModelOptions) {
this.messages = messages;
this.imageModelOptions = imageModelOptions;
}

public ImagePrompt(ImageMessage imageMessage, ImageOptions imageOptions) {
Expand All @@ -44,11 +49,6 @@ public ImagePrompt(String instructions) {
this(new ImageMessage(instructions), ImageOptionsBuilder.builder().build());
}

public ImagePrompt(List<ImageMessage> messages, ImageOptions imageModelOptions) {
this.messages = messages;
this.imageModelOptions = imageModelOptions != null ? imageModelOptions : ImageOptionsBuilder.builder().build();
}

@Override
public List<ImageMessage> getInstructions() {
return this.messages;
Expand Down