diff --git a/foundation-models/openai/pom.xml b/foundation-models/openai/pom.xml index 9ab13fbab..998b89dc1 100644 --- a/foundation-models/openai/pom.xml +++ b/foundation-models/openai/pom.xml @@ -38,11 +38,11 @@ ${project.basedir}/../../ - 72% + 70% 80% 76% 70% - 83% + 75% 84% diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiAssistantMessage.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiAssistantMessage.java index d2edc094e..26fe6d8b9 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiAssistantMessage.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiAssistantMessage.java @@ -8,6 +8,7 @@ import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestAssistantMessage; import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionRequestAssistantMessageContent; import com.sap.ai.sdk.foundationmodels.openai.generated.model.ToolCallType; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import javax.annotation.Nonnull; @@ -52,6 +53,21 @@ public class OpenAiAssistantMessage implements OpenAiMessage { @Nonnull List toolCalls; + /** + * Creates a new assistant message with the given content and additional tool calls. + * + * @param toolCalls the additional tool calls to associate with the message. + * @return a new assistant message with the given content and additional tool calls. + * @since 1.10.0 + */ + @Nonnull + public OpenAiAssistantMessage withToolCalls( + @Nonnull final List toolCalls) { + final List newToolCalls = new ArrayList<>(this.toolCalls); + newToolCalls.addAll(toolCalls); + return new OpenAiAssistantMessage(content, newToolCalls); + } + /** * Creates a new assistant message with the given single message as text content. 
* diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionConfig.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionConfig.java new file mode 100644 index 000000000..d53e37fef --- /dev/null +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionConfig.java @@ -0,0 +1,130 @@ +package com.sap.ai.sdk.foundationmodels.openai; + +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionStreamOptions; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequestAllOfResponseFormat; +import java.math.BigDecimal; +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; +import lombok.AccessLevel; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.With; + +/** Configuration for OpenAI chat completion requests. */ +@With +@NoArgsConstructor +@AllArgsConstructor +@Getter +public class OpenAiChatCompletionConfig { + + /** Upto 4 Stop sequences to interrupts token generation and returns a response without them. */ + @Nullable List stop; + + /** + * Controls the randomness of the completion. + * + *

Lower values (e.g. 0.0) make the model more deterministic and repetitive, while higher + * values (e.g. 1.0) make the model more random and creative. + */ + @Nullable BigDecimal temperature; + + /** + * Controls the cumulative probability threshold used for nucleus sampling. Alternative to {@link + * #temperature}. + * + *

Lower values (e.g. 0.1) limit the model to consider only the smallest set of tokens whose + * combined probabilities add up to at least 10% of the total. + */ + @Nullable BigDecimal topP; + + /** + * Controls the number of top tokens to consider for sampling. + * + *

Higher values (e.g. 50) allow the model to consider more tokens, while lower values (e.g. 1) + * restrict it to the most probable token. */ + @Nullable Integer topK; + + /** Maximum number of tokens that can be generated for the completion. */ + @Nullable Integer maxTokens; + + /** + * Maximum number of tokens that can be generated for the completion, including consumed reasoning + * tokens. This field supersedes {@link #maxTokens} and should be used with newer models. */ + @Nullable Integer maxCompletionTokens; + + /** + * Encourage new topics by penalising tokens based on their presence in the completion. + * + *

Value should be in range [-2, 2]. */ + @Nullable BigDecimal presencePenalty; + + /** + * Encourage new topics by penalising tokens based on their frequency in the completion. + * + *

Value should be in range [-2, 2]. + */ + @Nullable BigDecimal frequencyPenalty; + + /** + * A map that adjusts the likelihood of specified tokens by adding a bias value (between -100 and + * 100) to the logits before sampling. Extreme values can effectively ban or enforce the selection + * of tokens. + */ + @Nullable Map logitBias; + + /** + * Unique identifier for the end-user making the request. This can help with monitoring and abuse + * detection. + */ + @Nullable String user; + + /** Whether to include log probabilities in the response. */ + @Nullable Boolean logprobs; + + /** + * Number of top log probabilities to return for each token. An integer between 0 and 20. This is + * only relevant if {@code logprobs} is enabled. + */ + @Nullable Integer topLogprobs; + + /** Number of completions to generate. */ + @Nullable Integer n; + + /** Whether to allow parallel tool calls. */ + @Nullable Boolean parallelToolCalls; + + /** Seed for random number generation. */ + @Nullable Integer seed; + + /** Options for streaming the completion response. */ + @Nullable ChatCompletionStreamOptions streamOptions; + + /** Response format for the completion. */ + @Nullable CreateChatCompletionRequestAllOfResponseFormat responseFormat; + + /** + * Tools the model may invoke during chat completion (metadata only). + * + *

Use {@link #withToolsExecutable} for registering executable tools. + */ + @Nullable List tools; + + /** + * Tools the model may invoke during chat completion that are also executable at application + * runtime. + * + * @since 1.8.0 + */ + @Getter(value = AccessLevel.PACKAGE) + @Nullable + List toolsExecutable; + + /** Option to control which tool is invoked by the model. */ + @Nullable OpenAiToolChoice toolChoice; +} diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java index 4dfef7f39..cd413d1f2 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java @@ -4,7 +4,6 @@ import com.google.common.collect.Lists; import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionStreamOptions; import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool; -import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionToolChoiceOption; import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequest; import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequestAllOfResponseFormat; import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequestAllOfStop; @@ -12,12 +11,12 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.Objects; import javax.annotation.Nonnull; import javax.annotation.Nullable; import lombok.AccessLevel; import lombok.AllArgsConstructor; import lombok.Getter; +import lombok.Setter; import lombok.Value; import lombok.With; import lombok.experimental.Tolerate; @@ -36,166 +35,227 @@ @AllArgsConstructor(access = AccessLevel.PRIVATE) 
@Getter(value = AccessLevel.NONE) public class OpenAiChatCompletionRequest { + /** List of messages from the conversation. */ @Nonnull List messages; - /** Upto 4 Stop sequences to interrupts token generation and returns a response without them. */ - @Nullable List stop; + @Setter(AccessLevel.NONE) + @Getter(AccessLevel.PACKAGE) + OpenAiChatCompletionConfig config; /** - * Controls the randomness of the completion. + * Creates an OpenAiChatCompletionPrompt with string as user message. * - *

Lower values (e.g. 0.0) make the model more deterministic and repetitive, while higher - * values (e.g. 1.0) make the model more random and creative. + * @param message the message to be added to the prompt */ - @Nullable BigDecimal temperature; + @Tolerate + public OpenAiChatCompletionRequest(@Nonnull final String message) { + this(OpenAiMessage.user(message)); + } /** - * Controls the cumulative probability threshold used for nucleus sampling. Alternative to {@link - * #temperature}. + * Creates an OpenAiChatCompletionPrompt with multiple unpacked messages. * - *

Lower values (e.g. 0.1) limit the model to consider only the smallest set of tokens whose - * combined probabilities add up to at least 10% of the total. + * @param message the primary message to be added to the prompt + * @param messages additional messages to be added to the prompt */ - @Nullable BigDecimal topP; - - /** Maximum number of tokens that can be generated for the completion. */ - @Nullable Integer maxTokens; + @Tolerate + public OpenAiChatCompletionRequest( + @Nonnull final OpenAiMessage message, @Nonnull final OpenAiMessage... messages) { + this(Lists.asList(message, messages)); + } /** - * Maximum number of tokens that can be generated for the completion, including consumed reasoning - * tokens. This field supersedes {@link #maxTokens} and should be used with newer models. + * Creates an OpenAiChatCompletionPrompt with a list of messages. + * + * @param messages the list of messages to be added to the prompt + * @since 1.6.0 */ - @Nullable Integer maxCompletionTokens; + @Tolerate + public OpenAiChatCompletionRequest(@Nonnull final List messages) { + this(List.copyOf(messages), new OpenAiChatCompletionConfig()); + } /** - * Encourage new topic by penalising token based on their presence in the completion. + * Sets the stop sequences for the request. * - *

Value should be in range [-2, 2]. + * @param stop the stop sequences to be used in the request + * @return a new OpenAiChatCompletionRequest instance with the specified stop sequences */ - @Nullable BigDecimal presencePenalty; + @Nonnull + public OpenAiChatCompletionRequest withStop(@Nonnull final List stop) { + return this.withConfig(config.withStop(stop)); + } /** - * Encourage new topic by penalising tokens based on their frequency in the completion. + * Sets the temperature for the request. * - *

Value should be in range [-2, 2]. + * @param temperature the temperature value to be used in the request. + * @return a new OpenAiChatCompletionRequest instance with the specified temperature */ - @Nullable BigDecimal frequencyPenalty; + @Nonnull + public OpenAiChatCompletionRequest withTemperature(@Nonnull final BigDecimal temperature) { + return this.withConfig(config.withTemperature(temperature)); + } /** - * A map that adjusts the likelihood of specified tokens by adding a bias value (between -100 and - * 100) to the logits before sampling. Extreme values can effectively ban or enforce the selection - * of tokens. + * Sets the top-p sampling parameter for the request. + * + * @param topP the top-p value to be used in the request + * @return a new OpenAiChatCompletionRequest instance with the specified top-p value */ - @Nullable Map logitBias; + @Nonnull + public OpenAiChatCompletionRequest withTopP(@Nonnull final BigDecimal topP) { + return this.withConfig(config.withTopP(topP)); + } /** - * Unique identifier for the end-user making the request. This can help with monitoring and abuse - * detection. + * Sets the maximum number of tokens for the request. + * + * @param maxTokens the maximum number of tokens to be used in the request + * @return a new OpenAiChatCompletionRequest instance with the specified maximum tokens */ - @Nullable String user; - - /** Whether to include log probabilities in the response. */ - @With(AccessLevel.NONE) - @Nullable - Boolean logprobs; + @Nonnull + public OpenAiChatCompletionRequest withMaxTokens(@Nonnull final Integer maxTokens) { + return this.withConfig(config.withMaxTokens(maxTokens)); + } /** - * Number of top log probabilities to return for each token. An integer between 0 and 20. This is - * only relevant if {@code logprobs} is enabled. + * Sets the maximum number of completion tokens for the request. 
+ * + * @param maxCompletionTokens the maximum number of completion tokens to be used in the request + * @return a new OpenAiChatCompletionRequest instance with the specified maximum completion tokens */ - @Nullable Integer topLogprobs; + @Nonnull + public OpenAiChatCompletionRequest withMaxCompletionTokens( + @Nonnull final Integer maxCompletionTokens) { + return this.withConfig(config.withMaxCompletionTokens(maxCompletionTokens)); + } - /** Number of completions to generate. */ - @Nullable Integer n; + /** + * Sets the presence penalty for the request. + * + * @param presencePenalty the presence penalty value to be used in the request + * @return a new OpenAiChatCompletionRequest instance with the specified presence penalty + */ + @Nonnull + public OpenAiChatCompletionRequest withPresencePenalty( + @Nonnull final BigDecimal presencePenalty) { + return this.withConfig(config.withPresencePenalty(presencePenalty)); + } - /** Whether to allow parallel tool calls. */ - @With(AccessLevel.NONE) - @Nullable - Boolean parallelToolCalls; + /** + * Sets the frequency penalty for the request. + * + * @param frequencyPenalty the frequency penalty value to be used in the request + * @return a new OpenAiChatCompletionRequest instance with the specified frequency penalty + */ + @Nonnull + public OpenAiChatCompletionRequest withFrequencyPenalty( + @Nonnull final BigDecimal frequencyPenalty) { + return this.withConfig(config.withFrequencyPenalty(frequencyPenalty)); + } - /** Seed for random number generation. */ - @Nullable Integer seed; + /** + * Sets the top log probabilities for the request. 
+ * + * @param topLogprobs the number of top log probabilities to be included in the response + * @return a new OpenAiChatCompletionRequest instance with the specified top log probabilities + */ + @Nonnull + public OpenAiChatCompletionRequest withTopLogprobs(@Nonnull final Integer topLogprobs) { + return this.withConfig(config.withTopLogprobs(topLogprobs)); + } - /** Options for streaming the completion response. */ - @Nullable ChatCompletionStreamOptions streamOptions; + /** + * Sets the user identifier for the request. + * + * @param user the user identifier to be used in the request + * @return a new OpenAiChatCompletionRequest instance with the specified user identifier + */ + @Nonnull + public OpenAiChatCompletionRequest withUser(@Nonnull final String user) { + return this.withConfig(config.withUser(user)); + } - /** Response format for the completion. */ - @Nullable CreateChatCompletionRequestAllOfResponseFormat responseFormat; + /** + * Sets the logit bias for the request. + * + * @param logitBias the logit bias map to be used in the request + * @return a new OpenAiChatCompletionRequest instance with the specified logit bias + */ + @Nonnull + public OpenAiChatCompletionRequest withLogitBias(@Nonnull final Map logitBias) { + return this.withConfig(config.withLogitBias(logitBias)); + } /** - * Tools the model may invoke during chat completion (metadata only). + * Sets the number of completions to generate for the request. * - *

Use {@link #withToolsExecutable} for registering executable tools. + * @param n the number of completions to generate + * @return a new OpenAiChatCompletionRequest instance with the specified number of completions */ - @Nullable List tools; + @Nonnull + public OpenAiChatCompletionRequest withN(@Nonnull final Integer n) { + return this.withConfig(config.withN(n)); + } /** - * Tools the model may invoke during chat completion that are also executable at application - * runtime. + * Sets the random seed for the request. * - * @since 1.8.0 + * @param seed the random seed to be used in the request + * @return a new OpenAiChatCompletionRequest instance with the specified random seed */ - @Getter(value = AccessLevel.PACKAGE) - @Nullable - List toolsExecutable; + @Nonnull + public OpenAiChatCompletionRequest withSeed(@Nonnull final Integer seed) { + return this.withConfig(config.withSeed(seed)); + } - /** Option to control which tool is invoked by the model. */ - @With(AccessLevel.PRIVATE) - @Nullable - ChatCompletionToolChoiceOption toolChoice; + /** + * Sets the stream options for the request. + * + * @param streamOptions the stream options to be used in the request + * @return a new OpenAiChatCompletionRequest instance with the specified stream options + */ + @Nonnull + public OpenAiChatCompletionRequest withStreamOptions( + @Nonnull final ChatCompletionStreamOptions streamOptions) { + return this.withConfig(config.withStreamOptions(streamOptions)); + } /** - * Creates an OpenAiChatCompletionPrompt with string as user message. + * Sets the response format for the request. 
* - * @param message the message to be added to the prompt + * @param responseFormat the response format to be used in the request + * @return a new OpenAiChatCompletionRequest instance with the specified response format */ - @Tolerate - public OpenAiChatCompletionRequest(@Nonnull final String message) { - this(OpenAiMessage.user(message)); + @Nonnull + public OpenAiChatCompletionRequest withResponseFormat( + @Nonnull final CreateChatCompletionRequestAllOfResponseFormat responseFormat) { + return this.withConfig(config.withResponseFormat(responseFormat)); } /** - * Creates an OpenAiChatCompletionPrompt with a multiple unpacked messages. + * Sets the tools for the request. * - * @param message the primary message to be added to the prompt - * @param messages additional messages to be added to the prompt + * @param tools the list of tools to be used in the request + * @return a new OpenAiChatCompletionRequest instance with the specified tools */ - @Tolerate - public OpenAiChatCompletionRequest( - @Nonnull final OpenAiMessage message, @Nonnull final OpenAiMessage... messages) { - this(Lists.asList(message, messages)); + @Nonnull + public OpenAiChatCompletionRequest withTools(@Nonnull final List tools) { + return this.withConfig(config.withTools(tools)); } /** - * Creates an OpenAiChatCompletionPrompt with a list of messages. + * Sets the executable tools for the request. 
* - * @param messages the list of messages to be added to the prompt - * @since 1.6.0 + * @param toolsExecutable the list of executable tools to be used in the request + * @return a new OpenAiChatCompletionRequest instance with the specified executable tools */ - @Tolerate - public OpenAiChatCompletionRequest(@Nonnull final List messages) { - this( - List.copyOf(messages), - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null); + @Nonnull + public OpenAiChatCompletionRequest withToolsExecutable( + @Nonnull final List toolsExecutable) { + return this.withConfig(config.withToolsExecutable(toolsExecutable)); } /** @@ -209,7 +269,7 @@ public OpenAiChatCompletionRequest(@Nonnull final List messages) @Nonnull public OpenAiChatCompletionRequest withStop( @Nonnull final String sequence, @Nonnull final String... sequences) { - return this.withStop(Lists.asList(sequence, sequences)); + return withStop(Lists.asList(sequence, sequences)); } /** @@ -221,29 +281,7 @@ public OpenAiChatCompletionRequest withStop( @Nonnull public OpenAiChatCompletionRequest withParallelToolCalls( @Nonnull final Boolean parallelToolCalls) { - return Objects.equals(this.parallelToolCalls, parallelToolCalls) - ? 
this - : new OpenAiChatCompletionRequest( - this.messages, - this.stop, - this.temperature, - this.topP, - this.maxTokens, - this.maxCompletionTokens, - this.presencePenalty, - this.frequencyPenalty, - this.logitBias, - this.user, - this.logprobs, - this.topLogprobs, - this.n, - parallelToolCalls, - this.seed, - this.streamOptions, - this.responseFormat, - this.tools, - this.toolsExecutable, - this.toolChoice); + return this.withConfig(config.withParallelToolCalls(parallelToolCalls)); } /** @@ -254,29 +292,7 @@ public OpenAiChatCompletionRequest withParallelToolCalls( */ @Nonnull public OpenAiChatCompletionRequest withLogprobs(@Nonnull final Boolean logprobs) { - return Objects.equals(this.logprobs, logprobs) - ? this - : new OpenAiChatCompletionRequest( - this.messages, - this.stop, - this.temperature, - this.topP, - this.maxTokens, - this.maxCompletionTokens, - this.presencePenalty, - this.frequencyPenalty, - this.logitBias, - this.user, - logprobs, - this.topLogprobs, - this.n, - this.parallelToolCalls, - this.seed, - this.streamOptions, - this.responseFormat, - this.tools, - this.toolsExecutable, - this.toolChoice); + return this.withConfig(config.withLogprobs(logprobs)); } /** @@ -295,9 +311,8 @@ public OpenAiChatCompletionRequest withLogprobs(@Nonnull final Boolean logprobs) * @return the current OpenAiChatCompletionRequest instance. */ @Nonnull - @Tolerate public OpenAiChatCompletionRequest withToolChoice(@Nonnull final OpenAiToolChoice choice) { - return this.withToolChoice(choice.toolChoice); + return this.withConfig(config.withToolChoice(choice)); } /** @@ -311,27 +326,28 @@ CreateChatCompletionRequest createCreateChatCompletionRequest() { message -> request.addMessagesItem(OpenAiUtils.createChatCompletionRequestMessage(message))); - request.stop(this.stop != null ? CreateChatCompletionRequestAllOfStop.create(this.stop) : null); + request.stop( + config.stop != null ? 
CreateChatCompletionRequestAllOfStop.create(config.stop) : null); - request.temperature(this.temperature); - request.topP(this.topP); + request.temperature(config.temperature); + request.topP(config.topP); request.stream(null); - request.maxTokens(this.maxTokens); - request.maxCompletionTokens(this.maxCompletionTokens); - request.presencePenalty(this.presencePenalty); - request.frequencyPenalty(this.frequencyPenalty); - request.logitBias(this.logitBias); - request.user(this.user); - request.logprobs(this.logprobs); - request.topLogprobs(this.topLogprobs); - request.n(this.n); - request.parallelToolCalls(this.parallelToolCalls); - request.seed(this.seed); - request.streamOptions(this.streamOptions); - request.responseFormat(this.responseFormat); + request.maxTokens(config.maxTokens); + request.maxCompletionTokens(config.maxCompletionTokens); + request.presencePenalty(config.presencePenalty); + request.frequencyPenalty(config.frequencyPenalty); + request.logitBias(config.logitBias); + request.user(config.user); + request.logprobs(config.logprobs); + request.topLogprobs(config.topLogprobs); + request.n(config.n); + request.parallelToolCalls(config.parallelToolCalls); + request.seed(config.seed); + request.streamOptions(config.streamOptions); + request.responseFormat(config.responseFormat); request.tools(getChatCompletionTools()); - request.toolChoice(this.toolChoice); + request.toolChoice(config.toolChoice != null ? 
config.toolChoice.toolChoice : null); request.functionCall(null); request.functions(null); return request; @@ -340,11 +356,11 @@ CreateChatCompletionRequest createCreateChatCompletionRequest() { @Nullable private List getChatCompletionTools() { final var toolsCombined = new ArrayList(); - if (this.tools != null) { - toolsCombined.addAll(this.tools); + if (config.tools != null) { + toolsCombined.addAll(config.tools); } - if (this.toolsExecutable != null) { - for (final OpenAiTool tool : this.toolsExecutable) { + if (config.getToolsExecutable() != null) { + for (final OpenAiTool tool : config.getToolsExecutable()) { toolsCombined.add(tool.createChatCompletionTool()); } } diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionResponse.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionResponse.java index 32131d7b1..ff411e7d0 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionResponse.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionResponse.java @@ -112,7 +112,7 @@ public OpenAiAssistantMessage getMessage() { */ @Nonnull public List executeTools() { - final var tools = originalRequest.getToolsExecutable(); + final var tools = originalRequest.getConfig().getToolsExecutable(); return OpenAiTool.execute(tools != null ? 
tools : List.of(), getMessage()); } } diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiToolCall.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiToolCall.java index 9c400c4f4..9a4d3ff27 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiToolCall.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiToolCall.java @@ -1,6 +1,7 @@ package com.sap.ai.sdk.foundationmodels.openai; import com.google.common.annotations.Beta; +import javax.annotation.Nonnull; /** * Represents a tool called by an OpenAI model. @@ -8,4 +9,19 @@ * @since 1.6.0 */ @Beta -public sealed interface OpenAiToolCall permits OpenAiFunctionCall {} +public sealed interface OpenAiToolCall permits OpenAiFunctionCall { + /** + * Creates a new instance of {@link OpenAiToolCall}. + * + * @param id The unique identifier for the tool call. + * @param name The name of the tool to be called. + * @param arguments The arguments for the tool call, encoded as a JSON string. + * @return A new instance of {@link OpenAiToolCall}. 
+ * @since 1.10.0 + */ + @Nonnull + static OpenAiToolCall function( + @Nonnull final String id, @Nonnull final String name, @Nonnull final String arguments) { + return new OpenAiFunctionCall(id, name, arguments); + } +} diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/spring/OpenAiChatModel.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/spring/OpenAiChatModel.java new file mode 100644 index 000000000..2d7dee093 --- /dev/null +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/spring/OpenAiChatModel.java @@ -0,0 +1,115 @@ +package com.sap.ai.sdk.foundationmodels.openai.spring; + +import static org.springframework.ai.model.tool.ToolCallingChatOptions.isInternalToolExecutionEnabled; + +import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionRequest; +import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionResponse; +import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient; +import com.sap.ai.sdk.foundationmodels.openai.OpenAiMessage; +import com.sap.ai.sdk.foundationmodels.openai.OpenAiToolCall; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionMessageToolCall; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionResponseMessage; +import io.vavr.control.Option; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import javax.annotation.Nonnull; +import lombok.RequiredArgsConstructor; +import lombok.val; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.AssistantMessage.ToolCall; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import 
org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.tool.DefaultToolCallingManager; + +/** + * OpenAI Chat Model implementation that interacts with the OpenAI API to generate chat completions. + */ +@RequiredArgsConstructor +public class OpenAiChatModel implements ChatModel { + + private final OpenAiClient client; + + @Nonnull + private final DefaultToolCallingManager toolCallingManager = + DefaultToolCallingManager.builder().build(); + + @Override + @Nonnull + public ChatResponse call(@Nonnull final Prompt prompt) { + if (!(prompt.getOptions() instanceof OpenAiChatOptions options)) { + throw new IllegalArgumentException( + "Please add OpenAiChatOptions to the Prompt: new Prompt(\"message\", new OpenAiChatOptions(config))"); + } + val openAiRequest = toOpenAiRequest(prompt); + val request = new OpenAiChatCompletionRequest(openAiRequest).withTools(options.getTools()); + val result = client.chatCompletion(request); + val response = new ChatResponse(toGenerations(result)); + + if (isInternalToolExecutionEnabled(prompt.getOptions()) && response.hasToolCalls()) { + val toolExecutionResult = toolCallingManager.executeToolCalls(prompt, response); + // Send the tool execution result back to the model. 
+ return call(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions())); + } + return response; + } + + private List toOpenAiRequest(final Prompt prompt) { + final List result = new ArrayList<>(); + for (final Message message : prompt.getInstructions()) { + switch (message.getMessageType()) { + case USER -> Option.of(message.getText()).peek(t -> result.add(OpenAiMessage.user(t))); + case SYSTEM -> Option.of(message.getText()).peek(t -> result.add(OpenAiMessage.system(t))); + case ASSISTANT -> addAssistantMessage(result, (AssistantMessage) message); + case TOOL -> addToolMessages(result, (ToolResponseMessage) message); + } + } + return result; + } + + private static void addAssistantMessage( + final List result, final AssistantMessage message) { + if (message.getText() == null) { + return; + } + if (!message.hasToolCalls()) { + result.add(OpenAiMessage.assistant(message.getText())); + return; + } + final Function callTranslate = + toolCall -> OpenAiToolCall.function(toolCall.id(), toolCall.name(), toolCall.arguments()); + val calls = message.getToolCalls().stream().map(callTranslate).toList(); + result.add(OpenAiMessage.assistant(message.getText()).withToolCalls(calls)); + } + + private static void addToolMessages( + final List result, final ToolResponseMessage message) { + for (final ToolResponseMessage.ToolResponse response : message.getResponses()) { + result.add(OpenAiMessage.tool(response.responseData(), response.id())); + } + } + + @Nonnull + static List toGenerations(@Nonnull final OpenAiChatCompletionResponse result) { + return result.getOriginalResponse().getChoices().stream() + .map(message -> toGeneration(message.getMessage())) + .toList(); + } + + @Nonnull + static Generation toGeneration(@Nonnull final ChatCompletionResponseMessage choice) { + // no metadata for now + val calls = new ArrayList(); + for (final ChatCompletionMessageToolCall c : choice.getToolCalls()) { + val fnc = c.getFunction(); + calls.add(new ToolCall(c.getId(), 
c.getType().getValue(), fnc.getName(), fnc.getArguments())); + } + val message = new AssistantMessage(choice.getContent(), Map.of(), calls); + return new Generation(message); + } +} diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/spring/OpenAiChatOptions.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/spring/OpenAiChatOptions.java new file mode 100644 index 000000000..d724c89d8 --- /dev/null +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/spring/OpenAiChatOptions.java @@ -0,0 +1,131 @@ +package com.sap.ai.sdk.foundationmodels.openai.spring; + +import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionConfig; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool.TypeEnum; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.FunctionObject; +import io.vavr.control.Option; +import java.math.BigDecimal; +import java.util.List; +import java.util.Map; +import java.util.Set; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import lombok.AccessLevel; +import lombok.Data; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.val; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.ToolCallback; + +/** OpenAI Chat Options for configuring tool callbacks and execution settings. 
*/ +@Data +@NoArgsConstructor +public class OpenAiChatOptions implements ToolCallingChatOptions { + + @Nonnull private OpenAiChatCompletionConfig config; + + @Nonnull private List toolCallbacks = List.of(); + + @Nonnull private List tools = List.of(); + + @Getter(AccessLevel.NONE) + @Nullable + private Boolean internalToolExecutionEnabled; + + @Nonnull private Set toolNames = Set.of(); + + @Nonnull private Map toolContext = Map.of(); + + @Override + public void setToolCallbacks(@Nonnull final List toolCallbacks) { + this.toolCallbacks = toolCallbacks; + tools = toolCallbacks.stream().map(OpenAiChatOptions::toOpenAiTool).toList(); + } + + @Nullable + @Override + public Boolean getInternalToolExecutionEnabled() { + return this.internalToolExecutionEnabled; + } + + private static ChatCompletionTool toOpenAiTool(final ToolCallback toolCallback) { + val toolDef = toolCallback.getToolDefinition(); + val functionobject = + new FunctionObject() + .name(toolDef.name()) + .description(toolDef.description()) + .parameters(ModelOptionsUtils.jsonToMap(toolDef.inputSchema())); + return new ChatCompletionTool().type(TypeEnum.FUNCTION).function(functionobject); + } + + @Override + public void setInternalToolExecutionEnabled( + @Nullable final Boolean internalToolExecutionEnabled) { + this.internalToolExecutionEnabled = internalToolExecutionEnabled; + } + + @Override + @Nonnull + public String getModel() { + throw new UnsupportedOperationException( + "Model declaration not supported in OpenAI integration."); + } + + @Override + @Nullable + public Double getFrequencyPenalty() { + return Option.of(config.getFrequencyPenalty()).map(BigDecimal::doubleValue).getOrNull(); + } + + @Override + @Nullable + public Integer getMaxTokens() { + return config.getMaxTokens(); + } + + @Override + @Nullable + public Double getPresencePenalty() { + return Option.of(config.getPresencePenalty()).map(BigDecimal::doubleValue).getOrNull(); + } + + @Override + @Nullable + public List getStopSequences() { + 
return config.getStop(); + } + + @Override + @Nullable + public Double getTemperature() { + return Option.of(config.getTemperature()).map(BigDecimal::doubleValue).getOrNull(); + } + + @Override + @Nullable // top-k is not a native OpenAI request parameter; it was added to OpenAiChatCompletionConfig solely to back this accessor — TODO confirm it is actually honored by the API + public Integer getTopK() { + return config.getTopK(); + } + + @Override + @Nullable + public Double getTopP() { + return Option.of(config.getTopP()).map(BigDecimal::doubleValue).getOrNull(); + } + + @Override + @Nonnull + public T copy() { + final OpenAiChatOptions copy = new OpenAiChatOptions(); // NOTE(review): 'config' is neither initialized by the no-args constructor nor copied below — the copy loses all completion settings and its delegating getters may NPE; confirm intended + copy.setToolCallbacks(this.toolCallbacks); + copy.setInternalToolExecutionEnabled(this.internalToolExecutionEnabled); + copy.setTools(this.tools); + copy.setToolNames(this.toolNames); + copy.setToolContext(this.toolContext); + return (T) copy; + } +} diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOpenAiService.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOpenAiService.java index e1ff3b343..8f5d77bad 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOpenAiService.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOpenAiService.java @@ -2,20 +2,35 @@ import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient; import com.sap.ai.sdk.foundationmodels.openai.OpenAiModel; +import com.sap.ai.sdk.foundationmodels.openai.spring.OpenAiChatModel; +import com.sap.ai.sdk.foundationmodels.openai.spring.OpenAiChatOptions; import com.sap.ai.sdk.foundationmodels.openai.spring.OpenAiSpringEmbeddingModel; import java.util.List; +import java.util.Objects; import javax.annotation.Nonnull; +import lombok.val; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor; +import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository; +import 
org.springframework.ai.chat.memory.MessageWindowChatMemory; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.support.ToolCallbacks; import org.springframework.stereotype.Service; /** Service class for Spring AI integration with OpenAI */ @Service public class SpringAiOpenAiService { - private final OpenAiClient client = OpenAiClient.forModel(OpenAiModel.TEXT_EMBEDDING_3_SMALL); + private final OpenAiSpringEmbeddingModel embeddingClient = + new OpenAiSpringEmbeddingModel(OpenAiClient.forModel(OpenAiModel.TEXT_EMBEDDING_3_SMALL)); + private final ChatModel chatClient = + new OpenAiChatModel(OpenAiClient.forModel(OpenAiModel.GPT_4O_MINI)); /** * Embeds a list of strings using the OpenAI embedding model. @@ -28,7 +43,69 @@ public EmbeddingResponse embedStrings() { final var springAiRequest = new EmbeddingRequest(List.of("The quick brown fox jumps over the lazy dog."), options); - return new OpenAiSpringEmbeddingModel(client).call(springAiRequest); + return embeddingClient.call(springAiRequest); + } + + /** + * Chat request to OpenAI through the OpenAI service with a simple prompt. 
+ * + * @return the assistant response object + */ + @Nonnull + public ChatResponse completion() { + val options = new OpenAiChatOptions(); + val prompt = new Prompt("What is the capital of France?", options); + return chatClient.call(prompt); + } + + /** + * Chat completion request intended for streaming. NOTE(review): this currently delegates to the + * blocking {@code call} and returns a single complete response — no streaming happens; confirm + * whether a reactive {@code stream} call was intended here. + * + * @return the assistant response object + */ + @Nonnull + public ChatResponse streamChatCompletion() { + val options = new OpenAiChatOptions(); + val prompt = + new Prompt("Can you give me the first 100 numbers of the Fibonacci sequence?", options); + return chatClient.call(prompt); + } + + /** + * Turn a method into a tool by annotating it with {@code @Tool}. Spring AI + * Tool Method Declarative Specification + * + * @param internalToolExecutionEnabled whether the internal tool execution is enabled + * @return the assistant response object + */ + @Nonnull + public ChatResponse toolCalling(final boolean internalToolExecutionEnabled) { + val options = new OpenAiChatOptions(); + options.setToolCallbacks(List.of(ToolCallbacks.from(new WeatherMethod()))); + options.setInternalToolExecutionEnabled(internalToolExecutionEnabled); + + val prompt = new Prompt("What is the weather in Potsdam and in Toulouse?", options); + return chatClient.call(prompt); + } + + /** + * Chat request to OpenAI through the OpenAI service using chat memory. 
+ * + * @return the assistant response object + */ + @Nonnull + public ChatResponse chatMemory() { + val repository = new InMemoryChatMemoryRepository(); + val memory = MessageWindowChatMemory.builder().chatMemoryRepository(repository).build(); + val advisor = MessageChatMemoryAdvisor.builder(memory).build(); + val cl = ChatClient.builder(chatClient).defaultAdvisors(advisor).build(); + val prompt1 = new Prompt("What is the capital of France?", new OpenAiChatOptions()); + val prompt2 = new Prompt("And what is the typical food there?", new OpenAiChatOptions()); + + cl.prompt(prompt1).call().content(); + return Objects.requireNonNull( + cl.prompt(prompt2).call().chatResponse(), "Chat response is null"); } /** @@ -39,6 +116,6 @@ public EmbeddingResponse embedStrings() { @Nonnull public float[] embedDocument() { final var document = new Document("The quick brown fox jumps over the lazy dog."); - return new OpenAiSpringEmbeddingModel(client).embed(document); + return embeddingClient.embed(document); } } diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/SpringAiOpenAiTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/SpringAiOpenAiTest.java index 7d3ea42fd..a1f56d57b 100644 --- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/SpringAiOpenAiTest.java +++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/SpringAiOpenAiTest.java @@ -4,11 +4,16 @@ import com.sap.ai.sdk.app.services.SpringAiOpenAiService; import com.sap.ai.sdk.foundationmodels.openai.OpenAiModel; +import java.util.List; import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.model.ChatResponse; class SpringAiOpenAiTest { private final SpringAiOpenAiService service = new SpringAiOpenAiService(); + private static final org.slf4j.Logger log = + org.slf4j.LoggerFactory.getLogger(SpringAiOpenAiTest.class); @Test void 
testEmbedStrings() { @@ -23,4 +28,52 @@ void testEmbedStrings() { assertThat(response.getMetadata().getModel()) .isEqualTo(OpenAiModel.TEXT_EMBEDDING_3_SMALL.name()); } + + @Test + void testCompletion() { + ChatResponse response = service.completion(); + assertThat(response).isNotNull(); + assertThat(response.getResult().getOutput().getText()).contains("Paris"); + } + + @Test + void testStreamChatCompletion() { + ChatResponse response = service.streamChatCompletion(); + assertThat(response).isNotNull(); + assertThat(response.getResult().getOutput().getText()).isNotEmpty(); + } + + @Test + void testToolCallingWithExecution() { + ChatResponse response = service.toolCalling(true); + assertThat(response.getResult().getOutput().getText()).contains("Potsdam", "Toulouse", "°C"); + } + + @Test + void testToolCallingWithoutExecution() { + ChatResponse response = service.toolCalling(false); + List toolCalls = response.getResult().getOutput().getToolCalls(); + assertThat(toolCalls).hasSize(2); + AssistantMessage.ToolCall toolCall1 = toolCalls.get(0); + AssistantMessage.ToolCall toolCall2 = toolCalls.get(1); + assertThat(toolCall1.type()).isEqualTo("function"); + assertThat(toolCall2.type()).isEqualTo("function"); + assertThat(toolCall1.name()).isEqualTo("getCurrentWeather"); + assertThat(toolCall2.name()).isEqualTo("getCurrentWeather"); + assertThat(toolCall1.arguments()) + .isEqualTo("{\"arg0\": {\"location\": \"Potsdam\", \"unit\": \"C\"}}"); + assertThat(toolCall2.arguments()) + .isEqualTo("{\"arg0\": {\"location\": \"Toulouse\", \"unit\": \"C\"}}"); + } + + @Test + void testChatMemory() { + ChatResponse response = service.chatMemory(); + assertThat(response).isNotNull(); + String text = response.getResult().getOutput().getText(); + log.info(text); + assertThat(text) + .containsAnyOf( + "French", "onion", "pastries", "cheese", "baguette", "coq au vin", "foie gras"); + } }