stop;
+
+ /**
+ * Controls the randomness of the completion.
+ *
+ * Lower values (e.g. 0.0) make the model more deterministic and repetitive, while higher
+ * values (e.g. 1.0) make the model more random and creative.
+ */
+ @Nullable BigDecimal temperature;
+
+ /**
+ * Controls the cumulative probability threshold used for nucleus sampling. Alternative to {@link
+ * #temperature}.
+ *
+ *
+ * <p>Lower values (e.g. 0.1) limit the model to consider only the smallest set of tokens whose
+ * combined probabilities add up to at least 10% of the total.
+ */
+ @Nullable BigDecimal topP;
+
+ /**
+ * Controls the number of top tokens to consider for sampling.
+ *
+ *
+ * <p>Higher values (e.g. 50) allow the model to consider more tokens, while lower values (e.g. 1)
+ * restrict it to the most probable token.
+ */
+ @Nullable Integer topK;
+
+ /** Maximum number of tokens that can be generated for the completion. */
+ @Nullable Integer maxTokens;
+
+ /**
+ * Maximum number of tokens that can be generated for the completion, including consumed reasoning
+ * tokens. This field supersedes {@link #maxTokens} and should be used with newer models.
+ */
+ @Nullable Integer maxCompletionTokens;
+
+ /**
+ * Encourage new topic by penalising tokens based on their presence in the completion.
+ *
+ *
+ * <p>Value should be in range [-2, 2].
+ */
+ @Nullable BigDecimal presencePenalty;
+
+ /**
+ * Encourage new topic by penalising tokens based on their frequency in the completion.
+ *
+ *
+ * <p>Value should be in range [-2, 2].
+ */
+ @Nullable BigDecimal frequencyPenalty;
+
+ /**
+ * A map that adjusts the likelihood of specified tokens by adding a bias value (between -100 and
+ * 100) to the logits before sampling. Extreme values can effectively ban or enforce the selection
+ * of tokens.
+ */
+ @Nullable Map logitBias;
+
+ /**
+ * Unique identifier for the end-user making the request. This can help with monitoring and abuse
+ * detection.
+ */
+ @Nullable String user;
+
+ /** Whether to include log probabilities in the response. */
+ @Nullable Boolean logprobs;
+
+ /**
+ * Number of top log probabilities to return for each token. An integer between 0 and 20. This is
+ * only relevant if {@code logprobs} is enabled.
+ */
+ @Nullable Integer topLogprobs;
+
+ /** Number of completions to generate. */
+ @Nullable Integer n;
+
+ /** Whether to allow parallel tool calls. */
+ @Nullable Boolean parallelToolCalls;
+
+ /** Seed for random number generation. */
+ @Nullable Integer seed;
+
+ /** Options for streaming the completion response. */
+ @Nullable ChatCompletionStreamOptions streamOptions;
+
+ /** Response format for the completion. */
+ @Nullable CreateChatCompletionRequestAllOfResponseFormat responseFormat;
+
+ /**
+ * Tools the model may invoke during chat completion (metadata only).
+ *
+ * Use {@link #withToolsExecutable} for registering executable tools.
+ */
+ @Nullable List tools;
+
+ /**
+ * Tools the model may invoke during chat completion that are also executable at application
+ * runtime.
+ *
+ * @since 1.8.0
+ */
+ @Getter(value = AccessLevel.PACKAGE)
+ @Nullable
+ List toolsExecutable;
+
+ /** Option to control which tool is invoked by the model. */
+ @Nullable OpenAiToolChoice toolChoice;
+}
diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java
index 4dfef7f39..cd413d1f2 100644
--- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java
+++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionRequest.java
@@ -4,7 +4,6 @@
import com.google.common.collect.Lists;
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionStreamOptions;
import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool;
-import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionToolChoiceOption;
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequest;
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequestAllOfResponseFormat;
import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionRequestAllOfStop;
@@ -12,12 +11,12 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
-import java.util.Objects;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Getter;
+import lombok.Setter;
import lombok.Value;
import lombok.With;
import lombok.experimental.Tolerate;
@@ -36,166 +35,227 @@
@AllArgsConstructor(access = AccessLevel.PRIVATE)
@Getter(value = AccessLevel.NONE)
public class OpenAiChatCompletionRequest {
+
/** List of messages from the conversation. */
@Nonnull List messages;
- /** Upto 4 Stop sequences to interrupts token generation and returns a response without them. */
- @Nullable List stop;
+ @Setter(AccessLevel.NONE)
+ @Getter(AccessLevel.PACKAGE)
+ OpenAiChatCompletionConfig config;
/**
- * Controls the randomness of the completion.
+ * Creates an OpenAiChatCompletionPrompt with string as user message.
*
- * Lower values (e.g. 0.0) make the model more deterministic and repetitive, while higher
- * values (e.g. 1.0) make the model more random and creative.
+ * @param message the message to be added to the prompt
*/
- @Nullable BigDecimal temperature;
+ @Tolerate
+ public OpenAiChatCompletionRequest(@Nonnull final String message) {
+ this(OpenAiMessage.user(message));
+ }
/**
- * Controls the cumulative probability threshold used for nucleus sampling. Alternative to {@link
- * #temperature}.
+ * Creates an OpenAiChatCompletionPrompt with multiple unpacked messages.
*
- *
- * <p>Lower values (e.g. 0.1) limit the model to consider only the smallest set of tokens whose
- * combined probabilities add up to at least 10% of the total.
+ * @param message the primary message to be added to the prompt
+ * @param messages additional messages to be added to the prompt
*/
- @Nullable BigDecimal topP;
-
- /** Maximum number of tokens that can be generated for the completion. */
- @Nullable Integer maxTokens;
+ @Tolerate
+ public OpenAiChatCompletionRequest(
+ @Nonnull final OpenAiMessage message, @Nonnull final OpenAiMessage... messages) {
+ this(Lists.asList(message, messages));
+ }
/**
- * Maximum number of tokens that can be generated for the completion, including consumed reasoning
- * tokens. This field supersedes {@link #maxTokens} and should be used with newer models.
+ * Creates an OpenAiChatCompletionPrompt with a list of messages.
+ *
+ * @param messages the list of messages to be added to the prompt
+ * @since 1.6.0
*/
- @Nullable Integer maxCompletionTokens;
+ @Tolerate
+ public OpenAiChatCompletionRequest(@Nonnull final List messages) {
+ this(List.copyOf(messages), new OpenAiChatCompletionConfig());
+ }
/**
- * Encourage new topic by penalising token based on their presence in the completion.
+ * Sets the stop sequences for the request.
*
- * Value should be in range [-2, 2].
+ * @param stop the stop sequences to be used in the request
+ * @return a new OpenAiChatCompletionRequest instance with the specified stop sequences
*/
- @Nullable BigDecimal presencePenalty;
+ @Nonnull
+ public OpenAiChatCompletionRequest withStop(@Nonnull final List stop) {
+ return this.withConfig(config.withStop(stop));
+ }
/**
- * Encourage new topic by penalising tokens based on their frequency in the completion.
+ * Sets the temperature for the request.
*
- * Value should be in range [-2, 2].
+ * @param temperature the temperature value to be used in the request.
+ * @return a new OpenAiChatCompletionRequest instance with the specified temperature
*/
- @Nullable BigDecimal frequencyPenalty;
+ @Nonnull
+ public OpenAiChatCompletionRequest withTemperature(@Nonnull final BigDecimal temperature) {
+ return this.withConfig(config.withTemperature(temperature));
+ }
/**
- * A map that adjusts the likelihood of specified tokens by adding a bias value (between -100 and
- * 100) to the logits before sampling. Extreme values can effectively ban or enforce the selection
- * of tokens.
+ * Sets the top-p sampling parameter for the request.
+ *
+ * @param topP the top-p value to be used in the request
+ * @return a new OpenAiChatCompletionRequest instance with the specified top-p value
*/
- @Nullable Map logitBias;
+ @Nonnull
+ public OpenAiChatCompletionRequest withTopP(@Nonnull final BigDecimal topP) {
+ return this.withConfig(config.withTopP(topP));
+ }
/**
- * Unique identifier for the end-user making the request. This can help with monitoring and abuse
- * detection.
+ * Sets the maximum number of tokens for the request.
+ *
+ * @param maxTokens the maximum number of tokens to be used in the request
+ * @return a new OpenAiChatCompletionRequest instance with the specified maximum tokens
*/
- @Nullable String user;
-
- /** Whether to include log probabilities in the response. */
- @With(AccessLevel.NONE)
- @Nullable
- Boolean logprobs;
+ @Nonnull
+ public OpenAiChatCompletionRequest withMaxTokens(@Nonnull final Integer maxTokens) {
+ return this.withConfig(config.withMaxTokens(maxTokens));
+ }
/**
- * Number of top log probabilities to return for each token. An integer between 0 and 20. This is
- * only relevant if {@code logprobs} is enabled.
+ * Sets the maximum number of completion tokens for the request.
+ *
+ * @param maxCompletionTokens the maximum number of completion tokens to be used in the request
+ * @return a new OpenAiChatCompletionRequest instance with the specified maximum completion tokens
*/
- @Nullable Integer topLogprobs;
+ @Nonnull
+ public OpenAiChatCompletionRequest withMaxCompletionTokens(
+ @Nonnull final Integer maxCompletionTokens) {
+ return this.withConfig(config.withMaxCompletionTokens(maxCompletionTokens));
+ }
- /** Number of completions to generate. */
- @Nullable Integer n;
+ /**
+ * Sets the presence penalty for the request.
+ *
+ * @param presencePenalty the presence penalty value to be used in the request
+ * @return a new OpenAiChatCompletionRequest instance with the specified presence penalty
+ */
+ @Nonnull
+ public OpenAiChatCompletionRequest withPresencePenalty(
+ @Nonnull final BigDecimal presencePenalty) {
+ return this.withConfig(config.withPresencePenalty(presencePenalty));
+ }
- /** Whether to allow parallel tool calls. */
- @With(AccessLevel.NONE)
- @Nullable
- Boolean parallelToolCalls;
+ /**
+ * Sets the frequency penalty for the request.
+ *
+ * @param frequencyPenalty the frequency penalty value to be used in the request
+ * @return a new OpenAiChatCompletionRequest instance with the specified frequency penalty
+ */
+ @Nonnull
+ public OpenAiChatCompletionRequest withFrequencyPenalty(
+ @Nonnull final BigDecimal frequencyPenalty) {
+ return this.withConfig(config.withFrequencyPenalty(frequencyPenalty));
+ }
- /** Seed for random number generation. */
- @Nullable Integer seed;
+ /**
+ * Sets the top log probabilities for the request.
+ *
+ * @param topLogprobs the number of top log probabilities to be included in the response
+ * @return a new OpenAiChatCompletionRequest instance with the specified top log probabilities
+ */
+ @Nonnull
+ public OpenAiChatCompletionRequest withTopLogprobs(@Nonnull final Integer topLogprobs) {
+ return this.withConfig(config.withTopLogprobs(topLogprobs));
+ }
- /** Options for streaming the completion response. */
- @Nullable ChatCompletionStreamOptions streamOptions;
+ /**
+ * Sets the user identifier for the request.
+ *
+ * @param user the user identifier to be used in the request
+ * @return a new OpenAiChatCompletionRequest instance with the specified user identifier
+ */
+ @Nonnull
+ public OpenAiChatCompletionRequest withUser(@Nonnull final String user) {
+ return this.withConfig(config.withUser(user));
+ }
- /** Response format for the completion. */
- @Nullable CreateChatCompletionRequestAllOfResponseFormat responseFormat;
+ /**
+ * Sets the logit bias for the request.
+ *
+ * @param logitBias the logit bias map to be used in the request
+ * @return a new OpenAiChatCompletionRequest instance with the specified logit bias
+ */
+ @Nonnull
+ public OpenAiChatCompletionRequest withLogitBias(@Nonnull final Map logitBias) {
+ return this.withConfig(config.withLogitBias(logitBias));
+ }
/**
- * Tools the model may invoke during chat completion (metadata only).
+ * Sets the number of completions to generate for the request.
*
- * Use {@link #withToolsExecutable} for registering executable tools.
+ * @param n the number of completions to generate
+ * @return a new OpenAiChatCompletionRequest instance with the specified number of completions
*/
- @Nullable List tools;
+ @Nonnull
+ public OpenAiChatCompletionRequest withN(@Nonnull final Integer n) {
+ return this.withConfig(config.withN(n));
+ }
/**
- * Tools the model may invoke during chat completion that are also executable at application
- * runtime.
+ * Sets the random seed for the request.
*
- * @since 1.8.0
+ * @param seed the random seed to be used in the request
+ * @return a new OpenAiChatCompletionRequest instance with the specified random seed
*/
- @Getter(value = AccessLevel.PACKAGE)
- @Nullable
- List toolsExecutable;
+ @Nonnull
+ public OpenAiChatCompletionRequest withSeed(@Nonnull final Integer seed) {
+ return this.withConfig(config.withSeed(seed));
+ }
- /** Option to control which tool is invoked by the model. */
- @With(AccessLevel.PRIVATE)
- @Nullable
- ChatCompletionToolChoiceOption toolChoice;
+ /**
+ * Sets the stream options for the request.
+ *
+ * @param streamOptions the stream options to be used in the request
+ * @return a new OpenAiChatCompletionRequest instance with the specified stream options
+ */
+ @Nonnull
+ public OpenAiChatCompletionRequest withStreamOptions(
+ @Nonnull final ChatCompletionStreamOptions streamOptions) {
+ return this.withConfig(config.withStreamOptions(streamOptions));
+ }
/**
- * Creates an OpenAiChatCompletionPrompt with string as user message.
+ * Sets the response format for the request.
*
- * @param message the message to be added to the prompt
+ * @param responseFormat the response format to be used in the request
+ * @return a new OpenAiChatCompletionRequest instance with the specified response format
*/
- @Tolerate
- public OpenAiChatCompletionRequest(@Nonnull final String message) {
- this(OpenAiMessage.user(message));
+ @Nonnull
+ public OpenAiChatCompletionRequest withResponseFormat(
+ @Nonnull final CreateChatCompletionRequestAllOfResponseFormat responseFormat) {
+ return this.withConfig(config.withResponseFormat(responseFormat));
}
/**
- * Creates an OpenAiChatCompletionPrompt with a multiple unpacked messages.
+ * Sets the tools for the request.
*
- * @param message the primary message to be added to the prompt
- * @param messages additional messages to be added to the prompt
+ * @param tools the list of tools to be used in the request
+ * @return a new OpenAiChatCompletionRequest instance with the specified tools
*/
- @Tolerate
- public OpenAiChatCompletionRequest(
- @Nonnull final OpenAiMessage message, @Nonnull final OpenAiMessage... messages) {
- this(Lists.asList(message, messages));
+ @Nonnull
+ public OpenAiChatCompletionRequest withTools(@Nonnull final List tools) {
+ return this.withConfig(config.withTools(tools));
}
/**
- * Creates an OpenAiChatCompletionPrompt with a list of messages.
+ * Sets the executable tools for the request.
*
- * @param messages the list of messages to be added to the prompt
- * @since 1.6.0
+ * @param toolsExecutable the list of executable tools to be used in the request
+ * @return a new OpenAiChatCompletionRequest instance with the specified executable tools
*/
- @Tolerate
- public OpenAiChatCompletionRequest(@Nonnull final List messages) {
- this(
- List.copyOf(messages),
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null);
+ @Nonnull
+ public OpenAiChatCompletionRequest withToolsExecutable(
+ @Nonnull final List toolsExecutable) {
+ return this.withConfig(config.withToolsExecutable(toolsExecutable));
}
/**
@@ -209,7 +269,7 @@ public OpenAiChatCompletionRequest(@Nonnull final List messages)
@Nonnull
public OpenAiChatCompletionRequest withStop(
@Nonnull final String sequence, @Nonnull final String... sequences) {
- return this.withStop(Lists.asList(sequence, sequences));
+ return withStop(Lists.asList(sequence, sequences));
}
/**
@@ -221,29 +281,7 @@ public OpenAiChatCompletionRequest withStop(
@Nonnull
public OpenAiChatCompletionRequest withParallelToolCalls(
@Nonnull final Boolean parallelToolCalls) {
- return Objects.equals(this.parallelToolCalls, parallelToolCalls)
- ? this
- : new OpenAiChatCompletionRequest(
- this.messages,
- this.stop,
- this.temperature,
- this.topP,
- this.maxTokens,
- this.maxCompletionTokens,
- this.presencePenalty,
- this.frequencyPenalty,
- this.logitBias,
- this.user,
- this.logprobs,
- this.topLogprobs,
- this.n,
- parallelToolCalls,
- this.seed,
- this.streamOptions,
- this.responseFormat,
- this.tools,
- this.toolsExecutable,
- this.toolChoice);
+ return this.withConfig(config.withParallelToolCalls(parallelToolCalls));
}
/**
@@ -254,29 +292,7 @@ public OpenAiChatCompletionRequest withParallelToolCalls(
*/
@Nonnull
public OpenAiChatCompletionRequest withLogprobs(@Nonnull final Boolean logprobs) {
- return Objects.equals(this.logprobs, logprobs)
- ? this
- : new OpenAiChatCompletionRequest(
- this.messages,
- this.stop,
- this.temperature,
- this.topP,
- this.maxTokens,
- this.maxCompletionTokens,
- this.presencePenalty,
- this.frequencyPenalty,
- this.logitBias,
- this.user,
- logprobs,
- this.topLogprobs,
- this.n,
- this.parallelToolCalls,
- this.seed,
- this.streamOptions,
- this.responseFormat,
- this.tools,
- this.toolsExecutable,
- this.toolChoice);
+ return this.withConfig(config.withLogprobs(logprobs));
}
/**
@@ -295,9 +311,8 @@ public OpenAiChatCompletionRequest withLogprobs(@Nonnull final Boolean logprobs)
* @return the current OpenAiChatCompletionRequest instance.
*/
@Nonnull
- @Tolerate
public OpenAiChatCompletionRequest withToolChoice(@Nonnull final OpenAiToolChoice choice) {
- return this.withToolChoice(choice.toolChoice);
+ return this.withConfig(config.withToolChoice(choice));
}
/**
@@ -311,27 +326,28 @@ CreateChatCompletionRequest createCreateChatCompletionRequest() {
message ->
request.addMessagesItem(OpenAiUtils.createChatCompletionRequestMessage(message)));
- request.stop(this.stop != null ? CreateChatCompletionRequestAllOfStop.create(this.stop) : null);
+ request.stop(
+ config.stop != null ? CreateChatCompletionRequestAllOfStop.create(config.stop) : null);
- request.temperature(this.temperature);
- request.topP(this.topP);
+ request.temperature(config.temperature);
+ request.topP(config.topP);
request.stream(null);
- request.maxTokens(this.maxTokens);
- request.maxCompletionTokens(this.maxCompletionTokens);
- request.presencePenalty(this.presencePenalty);
- request.frequencyPenalty(this.frequencyPenalty);
- request.logitBias(this.logitBias);
- request.user(this.user);
- request.logprobs(this.logprobs);
- request.topLogprobs(this.topLogprobs);
- request.n(this.n);
- request.parallelToolCalls(this.parallelToolCalls);
- request.seed(this.seed);
- request.streamOptions(this.streamOptions);
- request.responseFormat(this.responseFormat);
+ request.maxTokens(config.maxTokens);
+ request.maxCompletionTokens(config.maxCompletionTokens);
+ request.presencePenalty(config.presencePenalty);
+ request.frequencyPenalty(config.frequencyPenalty);
+ request.logitBias(config.logitBias);
+ request.user(config.user);
+ request.logprobs(config.logprobs);
+ request.topLogprobs(config.topLogprobs);
+ request.n(config.n);
+ request.parallelToolCalls(config.parallelToolCalls);
+ request.seed(config.seed);
+ request.streamOptions(config.streamOptions);
+ request.responseFormat(config.responseFormat);
request.tools(getChatCompletionTools());
- request.toolChoice(this.toolChoice);
+ request.toolChoice(config.toolChoice != null ? config.toolChoice.toolChoice : null);
request.functionCall(null);
request.functions(null);
return request;
@@ -340,11 +356,11 @@ CreateChatCompletionRequest createCreateChatCompletionRequest() {
@Nullable
private List getChatCompletionTools() {
final var toolsCombined = new ArrayList();
- if (this.tools != null) {
- toolsCombined.addAll(this.tools);
+ if (config.tools != null) {
+ toolsCombined.addAll(config.tools);
}
- if (this.toolsExecutable != null) {
- for (final OpenAiTool tool : this.toolsExecutable) {
+ if (config.getToolsExecutable() != null) {
+ for (final OpenAiTool tool : config.getToolsExecutable()) {
toolsCombined.add(tool.createChatCompletionTool());
}
}
diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionResponse.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionResponse.java
index 32131d7b1..ff411e7d0 100644
--- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionResponse.java
+++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiChatCompletionResponse.java
@@ -112,7 +112,7 @@ public OpenAiAssistantMessage getMessage() {
*/
@Nonnull
public List executeTools() {
- final var tools = originalRequest.getToolsExecutable();
+ final var tools = originalRequest.getConfig().getToolsExecutable();
return OpenAiTool.execute(tools != null ? tools : List.of(), getMessage());
}
}
diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiToolCall.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiToolCall.java
index 9c400c4f4..9a4d3ff27 100644
--- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiToolCall.java
+++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiToolCall.java
@@ -1,6 +1,7 @@
package com.sap.ai.sdk.foundationmodels.openai;
import com.google.common.annotations.Beta;
+import javax.annotation.Nonnull;
/**
* Represents a tool called by an OpenAI model.
@@ -8,4 +9,19 @@
* @since 1.6.0
*/
@Beta
-public sealed interface OpenAiToolCall permits OpenAiFunctionCall {}
+public sealed interface OpenAiToolCall permits OpenAiFunctionCall {
+ /**
+ * Creates a new instance of {@link OpenAiToolCall}.
+ *
+ * @param id The unique identifier for the tool call.
+ * @param name The name of the tool to be called.
+ * @param arguments The arguments for the tool call, encoded as a JSON string.
+ * @return A new instance of {@link OpenAiToolCall}.
+ * @since 1.10.0
+ */
+ @Nonnull
+ static OpenAiToolCall function(
+ @Nonnull final String id, @Nonnull final String name, @Nonnull final String arguments) {
+ return new OpenAiFunctionCall(id, name, arguments);
+ }
+}
diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/spring/OpenAiChatModel.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/spring/OpenAiChatModel.java
new file mode 100644
index 000000000..2d7dee093
--- /dev/null
+++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/spring/OpenAiChatModel.java
@@ -0,0 +1,115 @@
+package com.sap.ai.sdk.foundationmodels.openai.spring;
+
+import static org.springframework.ai.model.tool.ToolCallingChatOptions.isInternalToolExecutionEnabled;
+
+import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionRequest;
+import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionResponse;
+import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient;
+import com.sap.ai.sdk.foundationmodels.openai.OpenAiMessage;
+import com.sap.ai.sdk.foundationmodels.openai.OpenAiToolCall;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionMessageToolCall;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionResponseMessage;
+import io.vavr.control.Option;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+import javax.annotation.Nonnull;
+import lombok.RequiredArgsConstructor;
+import lombok.val;
+import org.springframework.ai.chat.messages.AssistantMessage;
+import org.springframework.ai.chat.messages.AssistantMessage.ToolCall;
+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.messages.ToolResponseMessage;
+import org.springframework.ai.chat.model.ChatModel;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.model.Generation;
+import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.model.tool.DefaultToolCallingManager;
+
+/**
+ * OpenAI Chat Model implementation that interacts with the OpenAI API to generate chat completions.
+ */
+@RequiredArgsConstructor
+public class OpenAiChatModel implements ChatModel {
+
+ private final OpenAiClient client;
+
+ @Nonnull
+ private final DefaultToolCallingManager toolCallingManager =
+ DefaultToolCallingManager.builder().build();
+
+ @Override
+ @Nonnull
+ public ChatResponse call(@Nonnull final Prompt prompt) {
+ if (!(prompt.getOptions() instanceof OpenAiChatOptions options)) {
+ throw new IllegalArgumentException(
+ "Please add OpenAiChatOptions to the Prompt: new Prompt(\"message\", new OpenAiChatOptions(config))");
+ }
+ val openAiRequest = toOpenAiRequest(prompt);
+ val request = new OpenAiChatCompletionRequest(openAiRequest).withTools(options.getTools());
+ val result = client.chatCompletion(request);
+ val response = new ChatResponse(toGenerations(result));
+
+ if (isInternalToolExecutionEnabled(prompt.getOptions()) && response.hasToolCalls()) {
+ val toolExecutionResult = toolCallingManager.executeToolCalls(prompt, response);
+ // Send the tool execution result back to the model.
+ return call(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()));
+ }
+ return response;
+ }
+
+ private List toOpenAiRequest(final Prompt prompt) {
+ final List result = new ArrayList<>();
+ for (final Message message : prompt.getInstructions()) {
+ switch (message.getMessageType()) {
+ case USER -> Option.of(message.getText()).peek(t -> result.add(OpenAiMessage.user(t)));
+ case SYSTEM -> Option.of(message.getText()).peek(t -> result.add(OpenAiMessage.system(t)));
+ case ASSISTANT -> addAssistantMessage(result, (AssistantMessage) message);
+ case TOOL -> addToolMessages(result, (ToolResponseMessage) message);
+ }
+ }
+ return result;
+ }
+
+ private static void addAssistantMessage(
+ final List result, final AssistantMessage message) {
+ if (message.getText() == null) {
+ return;
+ }
+ if (!message.hasToolCalls()) {
+ result.add(OpenAiMessage.assistant(message.getText()));
+ return;
+ }
+ final Function callTranslate =
+ toolCall -> OpenAiToolCall.function(toolCall.id(), toolCall.name(), toolCall.arguments());
+ val calls = message.getToolCalls().stream().map(callTranslate).toList();
+ result.add(OpenAiMessage.assistant(message.getText()).withToolCalls(calls));
+ }
+
+ private static void addToolMessages(
+ final List result, final ToolResponseMessage message) {
+ for (final ToolResponseMessage.ToolResponse response : message.getResponses()) {
+ result.add(OpenAiMessage.tool(response.responseData(), response.id()));
+ }
+ }
+
+ @Nonnull
+ static List toGenerations(@Nonnull final OpenAiChatCompletionResponse result) {
+ return result.getOriginalResponse().getChoices().stream()
+ .map(message -> toGeneration(message.getMessage()))
+ .toList();
+ }
+
+ @Nonnull
+ static Generation toGeneration(@Nonnull final ChatCompletionResponseMessage choice) {
+ // no metadata for now
+ val calls = new ArrayList();
+ for (final ChatCompletionMessageToolCall c : choice.getToolCalls()) {
+ val fnc = c.getFunction();
+ calls.add(new ToolCall(c.getId(), c.getType().getValue(), fnc.getName(), fnc.getArguments()));
+ }
+ val message = new AssistantMessage(choice.getContent(), Map.of(), calls);
+ return new Generation(message);
+ }
+}
diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/spring/OpenAiChatOptions.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/spring/OpenAiChatOptions.java
new file mode 100644
index 000000000..d724c89d8
--- /dev/null
+++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/spring/OpenAiChatOptions.java
@@ -0,0 +1,131 @@
+package com.sap.ai.sdk.foundationmodels.openai.spring;
+
+import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionConfig;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool.TypeEnum;
+import com.sap.ai.sdk.foundationmodels.openai.generated.model.FunctionObject;
+import io.vavr.control.Option;
+import java.math.BigDecimal;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import lombok.AccessLevel;
+import lombok.Data;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import lombok.val;
+import org.springframework.ai.chat.prompt.ChatOptions;
+import org.springframework.ai.model.ModelOptionsUtils;
+import org.springframework.ai.model.tool.ToolCallingChatOptions;
+import org.springframework.ai.tool.ToolCallback;
+
+/** OpenAI Chat Options for configuring tool callbacks and execution settings. */
+@Data
+@NoArgsConstructor
+public class OpenAiChatOptions implements ToolCallingChatOptions {
+
+ @Nonnull private OpenAiChatCompletionConfig config;
+
+ @Nonnull private List toolCallbacks = List.of();
+
+ @Nonnull private List tools = List.of();
+
+ @Getter(AccessLevel.NONE)
+ @Nullable
+ private Boolean internalToolExecutionEnabled;
+
+ @Nonnull private Set toolNames = Set.of();
+
+ @Nonnull private Map toolContext = Map.of();
+
+ @Override
+ public void setToolCallbacks(@Nonnull final List toolCallbacks) {
+ this.toolCallbacks = toolCallbacks;
+ tools = toolCallbacks.stream().map(OpenAiChatOptions::toOpenAiTool).toList();
+ }
+
+ @Nullable
+ @Override
+ public Boolean getInternalToolExecutionEnabled() {
+ return this.internalToolExecutionEnabled;
+ }
+
+ private static ChatCompletionTool toOpenAiTool(final ToolCallback toolCallback) {
+ val toolDef = toolCallback.getToolDefinition();
+ val functionobject =
+ new FunctionObject()
+ .name(toolDef.name())
+ .description(toolDef.description())
+ .parameters(ModelOptionsUtils.jsonToMap(toolDef.inputSchema()));
+ return new ChatCompletionTool().type(TypeEnum.FUNCTION).function(functionobject);
+ }
+
+ @Override
+ public void setInternalToolExecutionEnabled(
+ @Nullable final Boolean internalToolExecutionEnabled) {
+ this.internalToolExecutionEnabled = internalToolExecutionEnabled;
+ }
+
+ @Override
+ @Nonnull
+ public String getModel() {
+ throw new UnsupportedOperationException(
+ "Model declaration not supported in OpenAI integration.");
+ }
+
+ @Override
+ @Nullable
+ public Double getFrequencyPenalty() {
+ return Option.of(config.getFrequencyPenalty()).map(BigDecimal::doubleValue).getOrNull();
+ }
+
+ @Override
+ @Nullable
+ public Integer getMaxTokens() {
+ return config.getMaxTokens();
+ }
+
+ @Override
+ @Nullable
+ public Double getPresencePenalty() {
+ return Option.of(config.getPresencePenalty()).map(BigDecimal::doubleValue).getOrNull();
+ }
+
+ @Override
+ @Nullable
+ public List getStopSequences() {
+ return config.getStop();
+ }
+
+ @Override
+ @Nullable
+ public Double getTemperature() {
+ return Option.of(config.getTemperature()).map(BigDecimal::doubleValue).getOrNull();
+ }
+
+ @Override
+ @Nullable
+ public Integer getTopK() {
+ return config.getTopK();
+ }
+
+ @Override
+ @Nullable
+ public Double getTopP() {
+ return Option.of(config.getTopP()).map(BigDecimal::doubleValue).getOrNull();
+ }
+
+ @Override
+ @Nonnull
+ public T copy() {
+ final OpenAiChatOptions copy = new OpenAiChatOptions();
+ copy.setToolCallbacks(this.toolCallbacks);
+ copy.setInternalToolExecutionEnabled(this.internalToolExecutionEnabled);
+ copy.setTools(this.tools);
+ copy.setToolNames(this.toolNames);
+ copy.setToolContext(this.toolContext);
+ return (T) copy;
+ }
+}
diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOpenAiService.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOpenAiService.java
index e1ff3b343..8f5d77bad 100644
--- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOpenAiService.java
+++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOpenAiService.java
@@ -2,20 +2,35 @@
import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient;
import com.sap.ai.sdk.foundationmodels.openai.OpenAiModel;
+import com.sap.ai.sdk.foundationmodels.openai.spring.OpenAiChatModel;
+import com.sap.ai.sdk.foundationmodels.openai.spring.OpenAiChatOptions;
import com.sap.ai.sdk.foundationmodels.openai.spring.OpenAiSpringEmbeddingModel;
import java.util.List;
+import java.util.Objects;
import javax.annotation.Nonnull;
+import lombok.val;
+import org.springframework.ai.chat.client.ChatClient;
+import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
+import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository;
+import org.springframework.ai.chat.memory.MessageWindowChatMemory;
+import org.springframework.ai.chat.model.ChatModel;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
+import org.springframework.ai.support.ToolCallbacks;
import org.springframework.stereotype.Service;
/** Service class for Spring AI integration with OpenAI */
@Service
public class SpringAiOpenAiService {
- private final OpenAiClient client = OpenAiClient.forModel(OpenAiModel.TEXT_EMBEDDING_3_SMALL);
+ private final OpenAiSpringEmbeddingModel embeddingClient =
+ new OpenAiSpringEmbeddingModel(OpenAiClient.forModel(OpenAiModel.TEXT_EMBEDDING_3_SMALL));
+ private final ChatModel chatClient =
+ new OpenAiChatModel(OpenAiClient.forModel(OpenAiModel.GPT_4O_MINI));
/**
* Embeds a list of strings using the OpenAI embedding model.
@@ -28,7 +43,69 @@ public EmbeddingResponse embedStrings() {
final var springAiRequest =
new EmbeddingRequest(List.of("The quick brown fox jumps over the lazy dog."), options);
- return new OpenAiSpringEmbeddingModel(client).call(springAiRequest);
+ return embeddingClient.call(springAiRequest);
+ }
+
+  /**
+   * Sends a single-turn chat request to OpenAI with default options.
+   *
+   * @return the assistant response object
+   */
+  @Nonnull
+  public ChatResponse completion() {
+    val chatOptions = new OpenAiChatOptions();
+    return chatClient.call(new Prompt("What is the capital of France?", chatOptions));
+  }
+
+  /**
+   * Chat request asking for the first 100 Fibonacci numbers.
+   *
+   * <p>NOTE(review): despite the name and original javadoc, this performs a blocking
+   * {@code call} rather than a streaming request -- confirm whether
+   * {@code chatClient.stream(prompt)} was intended here.
+   *
+   * @return the assistant message response
+   */
+  @Nonnull
+  public ChatResponse streamChatCompletion() {
+    val options = new OpenAiChatOptions();
+    val prompt =
+        new Prompt("Can you give me the first 100 numbers of the Fibonacci sequence?", options);
+    return chatClient.call(prompt);
+  }
+
+  /**
+   * Demonstrates tool calling: a method annotated with {@code @Tool} is registered as a tool
+   * callback and offered to the model (Spring AI Tool Method Declarative Specification).
+   *
+   * @param internalToolExecutionEnabled whether the internal tool execution is enabled
+   * @return the assistant response object
+   */
+  @Nonnull
+  public ChatResponse toolCalling(final boolean internalToolExecutionEnabled) {
+    val chatOptions = new OpenAiChatOptions();
+    chatOptions.setInternalToolExecutionEnabled(internalToolExecutionEnabled);
+    chatOptions.setToolCallbacks(List.of(ToolCallbacks.from(new WeatherMethod())));
+
+    return chatClient.call(
+        new Prompt("What is the weather in Potsdam and in Toulouse?", chatOptions));
+  }
+
+  /**
+   * Two-turn conversation backed by an in-memory chat history, so the follow-up question can
+   * refer to the first answer.
+   *
+   * @return the assistant response object for the follow-up question
+   */
+  @Nonnull
+  public ChatResponse chatMemory() {
+    val chatMemory =
+        MessageWindowChatMemory.builder()
+            .chatMemoryRepository(new InMemoryChatMemoryRepository())
+            .build();
+    val client =
+        ChatClient.builder(chatClient)
+            .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).build())
+            .build();
+
+    // First turn only primes the conversation memory; its content is intentionally discarded.
+    client
+        .prompt(new Prompt("What is the capital of France?", new OpenAiChatOptions()))
+        .call()
+        .content();
+
+    val followUp = new Prompt("And what is the typical food there?", new OpenAiChatOptions());
+    return Objects.requireNonNull(
+        client.prompt(followUp).call().chatResponse(), "Chat response is null");
+  }
/**
@@ -39,6 +116,6 @@ public EmbeddingResponse embedStrings() {
@Nonnull
public float[] embedDocument() {
  final var document = new Document("The quick brown fox jumps over the lazy dog.");
-  return new OpenAiSpringEmbeddingModel(client).embed(document);
+  // Reuse the long-lived embedding model field instead of constructing one per call.
+  return embeddingClient.embed(document);
}
}
diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/SpringAiOpenAiTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/SpringAiOpenAiTest.java
index 7d3ea42fd..a1f56d57b 100644
--- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/SpringAiOpenAiTest.java
+++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/SpringAiOpenAiTest.java
@@ -4,11 +4,16 @@
import com.sap.ai.sdk.app.services.SpringAiOpenAiService;
import com.sap.ai.sdk.foundationmodels.openai.OpenAiModel;
+import java.util.List;
import org.junit.jupiter.api.Test;
+import org.springframework.ai.chat.messages.AssistantMessage;
+import org.springframework.ai.chat.model.ChatResponse;
class SpringAiOpenAiTest {
private final SpringAiOpenAiService service = new SpringAiOpenAiService();
+  // Logger must be bound to this test class; the original referenced
+  // SpringAiOrchestrationTest.class (copy-paste error), which mislabels log output.
+  private static final org.slf4j.Logger log =
+      org.slf4j.LoggerFactory.getLogger(SpringAiOpenAiTest.class);
@Test
void testEmbedStrings() {
@@ -23,4 +28,52 @@ void testEmbedStrings() {
assertThat(response.getMetadata().getModel())
.isEqualTo(OpenAiModel.TEXT_EMBEDDING_3_SMALL.name());
}
+
+ @Test
+ void testCompletion() {
+ ChatResponse response = service.completion();
+ assertThat(response).isNotNull();
+ assertThat(response.getResult().getOutput().getText()).contains("Paris");
+ }
+
+ @Test
+ void testStreamChatCompletion() {
+ ChatResponse response = service.streamChatCompletion();
+ assertThat(response).isNotNull();
+ assertThat(response.getResult().getOutput().getText()).isNotEmpty();
+ }
+
+ @Test
+ void testToolCallingWithExecution() {
+ ChatResponse response = service.toolCalling(true);
+ assertThat(response.getResult().getOutput().getText()).contains("Potsdam", "Toulouse", "°C");
+ }
+
+ @Test
+ void testToolCallingWithoutExecution() {
+ ChatResponse response = service.toolCalling(false);
+ List toolCalls = response.getResult().getOutput().getToolCalls();
+ assertThat(toolCalls).hasSize(2);
+ AssistantMessage.ToolCall toolCall1 = toolCalls.get(0);
+ AssistantMessage.ToolCall toolCall2 = toolCalls.get(1);
+ assertThat(toolCall1.type()).isEqualTo("function");
+ assertThat(toolCall2.type()).isEqualTo("function");
+ assertThat(toolCall1.name()).isEqualTo("getCurrentWeather");
+ assertThat(toolCall2.name()).isEqualTo("getCurrentWeather");
+ assertThat(toolCall1.arguments())
+ .isEqualTo("{\"arg0\": {\"location\": \"Potsdam\", \"unit\": \"C\"}}");
+ assertThat(toolCall2.arguments())
+ .isEqualTo("{\"arg0\": {\"location\": \"Toulouse\", \"unit\": \"C\"}}");
+ }
+
+ @Test
+ void testChatMemory() {
+ ChatResponse response = service.chatMemory();
+ assertThat(response).isNotNull();
+ String text = response.getResult().getOutput().getText();
+ log.info(text);
+ assertThat(text)
+ .containsAnyOf(
+ "French", "onion", "pastries", "cheese", "baguette", "coq au vin", "foie gras");
+ }
}