diff --git a/docs/release_notes.md b/docs/release_notes.md index 27ed1ca29..a71ae2db9 100644 --- a/docs/release_notes.md +++ b/docs/release_notes.md @@ -43,8 +43,10 @@ ### ✨ New Functionality -- - +- [Orchestration] Added embedding generation support with new `OrchestrationClient#embed()` methods. + - Added `OrchestrationEmbeddingModel` with `TEXT_EMBEDDING_3_SMALL`, `TEXT_EMBEDDING_3_LARGE`, `AMAZON_TITAN_EMBED_TEXT` and `NVIDIA_LLAMA_32_NV_EMBEDQA_1B` embedding models. + - Introduced `OrchestrationEmbeddingRequest` for building requests fluently and `OrchestrationEmbeddingResponse#getEmbeddingVectors()` to retrieve embeddings. + ### 📈 Improvements - diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java index 36479a061..b7ba4568a 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java @@ -226,15 +226,34 @@ public Stream streamChatCompletionDeltas( } /** - * Generate embeddings for the given request. + * Generate embeddings for a {@code OrchestrationEmbeddingRequest} request. * * @param request the request containing the input text and other parameters. * @return the response containing the embeddings. * @throws OrchestrationClientException if the request fails - * @since 1.9.0 + * @since 1.12.0 */ @Nonnull - EmbeddingsPostResponse embed(@Nonnull final EmbeddingsPostRequest request) + public OrchestrationEmbeddingResponse embed(@Nonnull final OrchestrationEmbeddingRequest request) + throws OrchestrationClientException { + final var response = embed(request.createEmbeddingsPostRequest()); + return new OrchestrationEmbeddingResponse(response); + } + + /** + * Generates embeddings using the low-level API request. + * + *

This method provides direct access to the underlying API for advanced use cases. For most + * scenarios, prefer {@link #embed(OrchestrationEmbeddingRequest)}. + * + * @param request the low-level API request + * @return the low level response object + * @throws OrchestrationClientException if the request fails + * @since 1.12.0 + * @see #embed(OrchestrationEmbeddingRequest) + */ + @Nonnull + public EmbeddingsPostResponse embed(@Nonnull final EmbeddingsPostRequest request) throws OrchestrationClientException { return executor.execute("/v2/embeddings", request, EmbeddingsPostResponse.class, customHeaders); } diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationEmbeddingModel.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationEmbeddingModel.java new file mode 100644 index 000000000..d9d8b8b46 --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationEmbeddingModel.java @@ -0,0 +1,73 @@ +package com.sap.ai.sdk.orchestration; + +import com.google.common.annotations.Beta; +import com.sap.ai.sdk.core.AiModel; +import com.sap.ai.sdk.orchestration.model.EmbeddingsModelDetails; +import com.sap.ai.sdk.orchestration.model.EmbeddingsModelParams; +import com.sap.ai.sdk.orchestration.model.EmbeddingsModelParams.EncodingFormatEnum; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import lombok.AccessLevel; +import lombok.AllArgsConstructor; +import lombok.Value; +import lombok.With; +import lombok.experimental.Accessors; + +/** + * Configuration for embedding models in the Orchestration service. + * + * @since 1.12.0 + */ +@Beta +@With +@Value +@Accessors(fluent = true) +@AllArgsConstructor(access = AccessLevel.PRIVATE) +public class OrchestrationEmbeddingModel implements AiModel { + /** The name of the embedding model. */ + @Nonnull String name; + + /** The version of the model, defaults to latest if not specified. */ + @Nullable String version; + + /** The number of dimensions for the output embeddings. */ + @Nullable Integer dimensions; + + /** Whether to normalize the embedding vectors. */ + @Nullable Boolean normalize; + + /** Azure OpenAI Text Embedding 3 Small model */ + public static final OrchestrationEmbeddingModel TEXT_EMBEDDING_3_SMALL = + new OrchestrationEmbeddingModel("text-embedding-3-small"); + + /** Azure OpenAI Text Embedding 3 Large model */ + public static final OrchestrationEmbeddingModel TEXT_EMBEDDING_3_LARGE = + new OrchestrationEmbeddingModel("text-embedding-3-large"); + + /** Amazon Titan Embed Text model */ + public static final OrchestrationEmbeddingModel AMAZON_TITAN_EMBED_TEXT = + new OrchestrationEmbeddingModel("amazon--titan-embed-text"); + + /** NVIDIA LLaMA 3.2 7B NV EmbedQA model */ + public static final OrchestrationEmbeddingModel NVIDIA_LLAMA_32_NV_EMBEDQA_1B = + new OrchestrationEmbeddingModel("nvidia--llama-3.2-nv-embedqa-1b"); + + /** + * Creates a new embedding model configuration with the specified name. + * + * @param name the model name + */ + public OrchestrationEmbeddingModel(@Nonnull final String name) { + this(name, null, null, null); + } + + @Nonnull + EmbeddingsModelDetails createEmbeddingsModelDetails() { + final var params = + EmbeddingsModelParams.create() + .dimensions(dimensions) + .normalize(normalize) + .encodingFormat(EncodingFormatEnum.FLOAT); + return EmbeddingsModelDetails.create().name(name).version(version).params(params); + } +} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationEmbeddingRequest.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationEmbeddingRequest.java new file mode 100644 index 000000000..d1ebfe91e --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationEmbeddingRequest.java @@ -0,0 +1,156 @@ +package com.sap.ai.sdk.orchestration; + +import static lombok.AccessLevel.NONE; +import static lombok.AccessLevel.PRIVATE; + +import com.google.common.annotations.Beta; +import com.google.common.collect.Lists; +import com.sap.ai.sdk.orchestration.model.EmbeddingsInput; +import com.sap.ai.sdk.orchestration.model.EmbeddingsInputText; +import com.sap.ai.sdk.orchestration.model.EmbeddingsModelConfig; +import com.sap.ai.sdk.orchestration.model.EmbeddingsModuleConfigs; +import com.sap.ai.sdk.orchestration.model.EmbeddingsOrchestrationConfig; +import com.sap.ai.sdk.orchestration.model.EmbeddingsPostRequest; +import com.sap.ai.sdk.orchestration.model.MaskingModuleConfigProviders; +import java.util.List; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.Value; +import lombok.With; +import lombok.experimental.Tolerate; + +/** + * Represents a request for generating embeddings through the SAP AI Core Orchestration service. + * + * @since 1.12.0 + */ +@Beta +@Value +@AllArgsConstructor(access = PRIVATE) +public class OrchestrationEmbeddingRequest { + + /** The embedding model to use for generating vector representations. */ + @Nonnull OrchestrationEmbeddingModel model; + + /** The list of text inputs to be converted into embeddings. */ + @Nonnull List inputs; + + /** Optional masking providers for data privacy and security. */ + @With(value = PRIVATE) + @Nullable + List masking; + + /** Optional embedding input type classification to optimize embedding generation. */ + @With(value = PRIVATE) + @Getter(NONE) + @Nullable + EmbeddingsInput.TypeEnum inputType; + + /** + * Create an embedding request using fluent API starting with model selection. + * + *

{@code
+   * OrchestrationEmbeddingRequest.forModel(myModel).forInputs("text to embed");
+   * }
+ * + * @param model the embedding model to use + * @return a step for specifying inputs + */ + @Nonnull + public static InputStep forModel(@Nonnull final OrchestrationEmbeddingModel model) { + return inputs -> new OrchestrationEmbeddingRequest(model, List.copyOf(inputs), null, null); + } + + /** Builder step for specifying text inputs to embed. */ + @FunctionalInterface + public interface InputStep { + + /** + * Specifies text inputs to be embedded. + * + * @param inputs the text strings to embed + * @return a new embedding request instance + */ + @Nonnull + OrchestrationEmbeddingRequest forInputs(@Nonnull final List inputs); + + /** + * Specifies multiple text inputs using variable arguments. + * + * @param firstInput string to embed + * @param inputs optional additional strings to embed + * @return a new embedding request instance + */ + @Nonnull + default OrchestrationEmbeddingRequest forInputs( + @Nonnull final String firstInput, @Nonnull final String... inputs) { + return forInputs(Lists.asList(firstInput, inputs)); + } + } + + /** + * Adds data masking providers to enable detection and masking of sensitive information. + * + * @param maskingProvider the primary masking provider + * @param maskingProviders additional masking providers + * @return a new request instance with the specified masking providers + * @see MaskingProvider + */ + @Tolerate + @Nonnull + public OrchestrationEmbeddingRequest withMasking( + @Nonnull final MaskingProvider maskingProvider, + @Nonnull final MaskingProvider... maskingProviders) { + return withMasking(Lists.asList(maskingProvider, maskingProviders)); + } + + /** + * Configures this request to optimize embeddings for document content. + * + * @return a new request instance configured for document embedding + */ + @Nonnull + public OrchestrationEmbeddingRequest asDocument() { + return withInputType(EmbeddingsInput.TypeEnum.DOCUMENT); + } + + /** + * Configures this request to optimize embeddings for general text content. + * + * @return a new request instance configured for text embedding + */ + @Nonnull + public OrchestrationEmbeddingRequest asText() { + return withInputType(EmbeddingsInput.TypeEnum.TEXT); + } + + /** + * Configures this request to optimize embeddings for query content. + * + * @return a new request instance configured for query embedding + */ + @Nonnull + public OrchestrationEmbeddingRequest asQuery() { + return withInputType(EmbeddingsInput.TypeEnum.QUERY); + } + + @Nonnull + EmbeddingsPostRequest createEmbeddingsPostRequest() { + + final var input = + EmbeddingsInput.create().text(EmbeddingsInputText.create(inputs)).type(inputType); + final var embeddingsModelConfig = + EmbeddingsModelConfig.create().model(model.createEmbeddingsModelDetails()); + final var modules = + EmbeddingsOrchestrationConfig.create() + .modules(EmbeddingsModuleConfigs.create().embeddings(embeddingsModelConfig)); + + if (masking != null) { + final var dpiConfigs = masking.stream().map(MaskingProvider::createConfig).toList(); + modules.getModules().setMasking(MaskingModuleConfigProviders.create().providers(dpiConfigs)); + } + return EmbeddingsPostRequest.create().config(modules).input(input); + } +} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationEmbeddingResponse.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationEmbeddingResponse.java new file mode 100644 index 000000000..109825775 --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationEmbeddingResponse.java @@ -0,0 +1,48 @@ +package com.sap.ai.sdk.orchestration; + +import static lombok.AccessLevel.PACKAGE; + +import com.google.common.annotations.Beta; +import com.sap.ai.sdk.orchestration.model.Embedding; +import com.sap.ai.sdk.orchestration.model.EmbeddingsPostResponse; +import java.util.ArrayList; +import java.util.List; +import javax.annotation.Nonnull; +import lombok.AllArgsConstructor; +import lombok.Value; + +/** + * Response wrapper for orchestration embedding operations. + * + *

Wraps {@link EmbeddingsPostResponse} and provides convenient access to embedding vectors. + * + * @since 1.12.0 + */ +@Beta +@Value +@AllArgsConstructor(access = PACKAGE) +public class OrchestrationEmbeddingResponse { + + /** The original embedding response from the orchestration API. */ + @Nonnull EmbeddingsPostResponse originalResponse; + + /** + * Extracts embedding vectors as float arrays. + * + * @return list of embedding vectors, never {@code null} + */ + @Nonnull + public List getEmbeddingVectors() { + final var embeddings = new ArrayList(); + for (final var container : originalResponse.getFinalResult().getData()) { + final var bigDecimals = (Embedding.InnerBigDecimals) container.getEmbedding(); + final var values = bigDecimals.values(); + final float[] arr = new float[values.size()]; + for (int i = 0; i < values.size(); i++) { + arr[i] = values.get(i).floatValue(); + } + embeddings.add(arr); + } + return embeddings; + } +} diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationEmbeddingTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationEmbeddingTest.java new file mode 100644 index 000000000..0f88cd492 --- /dev/null +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationEmbeddingTest.java @@ -0,0 +1,157 @@ +package com.sap.ai.sdk.orchestration; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; +import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor; +import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; +import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; +import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; +import static com.github.tomakehurst.wiremock.client.WireMock.verify; +import static com.sap.ai.sdk.orchestration.OrchestrationEmbeddingModel.TEXT_EMBEDDING_3_SMALL; +import static com.sap.ai.sdk.orchestration.model.DPIEntities.PERSON; +import static com.sap.ai.sdk.orchestration.model.EmbeddingResult.ObjectEnum.EMBEDDING; +import static com.sap.ai.sdk.orchestration.model.EmbeddingsInput.TypeEnum.DOCUMENT; +import static com.sap.ai.sdk.orchestration.model.EmbeddingsInput.TypeEnum.QUERY; +import static com.sap.ai.sdk.orchestration.model.EmbeddingsInput.TypeEnum.TEXT; +import static com.sap.ai.sdk.orchestration.model.EmbeddingsResponse.ObjectEnum.LIST; +import static org.assertj.core.api.Assertions.assertThat; + +import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; +import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import com.sap.ai.sdk.orchestration.model.EmbeddingsModelParams.EncodingFormatEnum; +import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor; +import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Cache; +import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination; +import java.io.InputStream; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.Function; +import lombok.SneakyThrows; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +@WireMockTest +class OrchestrationEmbeddingTest { + private static OrchestrationClient client; + + private final Function fileLoader = + filename -> Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(filename)); + + @BeforeEach + void setup(WireMockRuntimeInfo server) { + final DefaultHttpDestination destination = + DefaultHttpDestination.builder(server.getHttpBaseUrl()).build(); + client = new OrchestrationClient(destination); + ApacheHttpClient5Accessor.setHttpClientCache(ApacheHttpClient5Cache.DISABLED); + } + + @AfterEach + void reset() { + ApacheHttpClient5Accessor.setHttpClientCache(null); + ApacheHttpClient5Accessor.setHttpClientFactory(null); + } + + @Test + void testEmbeddingModel() { + final var model = + TEXT_EMBEDDING_3_SMALL.withVersion("v1").withDimensions(1536).withNormalize(true); + + final var lowLevelModel1 = model.createEmbeddingsModelDetails(); + assertThat(lowLevelModel1.getName()).isEqualTo("text-embedding-3-small"); + assertThat(lowLevelModel1.getVersion()).isEqualTo("v1"); + assertThat(lowLevelModel1.getParams().getDimensions()).isEqualTo(1536); + assertThat(lowLevelModel1.getParams().isNormalize()).isTrue(); + assertThat(lowLevelModel1.getParams().getEncodingFormat()).isEqualTo(EncodingFormatEnum.FLOAT); + + final var model2 = new OrchestrationEmbeddingModel("custom-model"); + final var lowLevelModel2 = model2.createEmbeddingsModelDetails(); + assertThat(lowLevelModel2.getName()).isEqualTo("custom-model"); + } + + @Test + void testEmbeddingRequestTokenTypes() { + var request = + OrchestrationEmbeddingRequest.forModel(TEXT_EMBEDDING_3_SMALL).forInputs("Hello World"); + + request = request.asText(); + var lowLevelRequest = request.createEmbeddingsPostRequest(); + assertThat(lowLevelRequest.getInput().getType()).isEqualTo(TEXT); + + request = request.asDocument(); + lowLevelRequest = request.createEmbeddingsPostRequest(); + assertThat(lowLevelRequest.getInput().getType()).isEqualTo(DOCUMENT); + + request = request.asQuery(); + lowLevelRequest = request.createEmbeddingsPostRequest(); + assertThat(lowLevelRequest.getInput().getType()).isEqualTo(QUERY); + } + + @SneakyThrows + @Test + void testEmbeddingRequest() { + stubFor( + post(urlPathEqualTo("/v2/embeddings")) + .withHeader("Content-Type", equalTo("application/json; charset=UTF-8")) + .willReturn( + aResponse() + .withBodyFile("embeddingResponse.json") + .withHeader("Content-Type", "application/json"))); + + final var masking = + DpiMasking.anonymization().withEntities(PERSON).withAllowList(List.of("SAP", "Joule")); + final var request = + OrchestrationEmbeddingRequest.forModel(TEXT_EMBEDDING_3_SMALL) + .forInputs(List.of("Hi SAP Orchestration Service", "I am John Doe")) + .withMasking(masking); + + final var response = client.embed(request); + + assertThat(response.getEmbeddingVectors()) + .isNotNull() + .hasSize(2) + .isInstanceOf(List.class) + .contains( + new float[] {-0.09068491f, -0.3462946f, 0.88297224f, -0.29537824f, 0.0704844f}, + new float[] {0.18703203f, -0.10362422f, -0.65176725f, 0.6386932f, -0.34864223f}); + + assertThat(response.getOriginalResponse().getRequestId()) + .isEqualTo("62935941-7c2d-4c16-8a35-0b9dce7c8c9e"); + assertThat(response.getOriginalResponse().getIntermediateResults().getInputMasking().getData()) + .isEqualTo( + Map.of("masked_input", List.of("Hi SAP Orchestration Service", "I am MASKED_PERSON"))); + + final var finalResult = response.getOriginalResponse().getFinalResult(); + assertThat(finalResult.getModel()).isEqualTo("text-embedding-3-small"); + assertThat(finalResult.getUsage().getPromptTokens()).isEqualTo(11); + assertThat(finalResult.getUsage().getTotalTokens()).isEqualTo(11); + assertThat(finalResult.getObject()).isEqualTo(LIST); + assertThat(finalResult.getData()) + .hasSize(2) + .allSatisfy( + embeddingData -> { + assertThat(embeddingData.getObject()).isEqualTo(EMBEDDING); + assertThat(embeddingData.getEmbedding()).isNotNull(); + assertThat(embeddingData.getIndex()).isIn(0, 1); + }); + + final var moduleResults = response.getOriginalResponse().getIntermediateResults(); + assertThat(moduleResults).isNotNull(); + assertThat(moduleResults.getInputMasking()).isNotNull(); + assertThat(moduleResults.getInputMasking().getMessage()) + .isEqualTo("Embedding input is masked successfully."); + assertThat(moduleResults.getInputMasking().getData()).isNotNull(); + assertThat(moduleResults.getInputMasking().getData()) + .isEqualTo( + Map.of("masked_input", List.of("Hi SAP Orchestration Service", "I am MASKED_PERSON"))); + + try (var inputStream = fileLoader.apply("embeddingRequest.json")) { + var requestJson = new String(inputStream.readAllBytes()); + verify( + postRequestedFor(urlEqualTo("/v2/embeddings")).withRequestBody(equalToJson(requestJson))); + } + } +} diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java index bd3a934f5..092403158 100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java @@ -13,7 +13,6 @@ import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor; import static com.github.tomakehurst.wiremock.client.WireMock.serverError; import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; -import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; import static com.github.tomakehurst.wiremock.client.WireMock.verify; import static com.sap.ai.sdk.orchestration.AzureFilterThreshold.ALLOW_SAFE; @@ -43,22 +42,9 @@ import com.github.tomakehurst.wiremock.junit5.WireMockTest; import com.github.tomakehurst.wiremock.stubbing.Scenario; import com.sap.ai.sdk.orchestration.model.ChatDelta; -import com.sap.ai.sdk.orchestration.model.DPIConfig; import com.sap.ai.sdk.orchestration.model.DPIEntities; -import com.sap.ai.sdk.orchestration.model.DPIStandardEntity; import com.sap.ai.sdk.orchestration.model.DataRepositoryType; import com.sap.ai.sdk.orchestration.model.DocumentGroundingFilter; -import com.sap.ai.sdk.orchestration.model.Embedding; -import com.sap.ai.sdk.orchestration.model.EmbeddingsInput; -import com.sap.ai.sdk.orchestration.model.EmbeddingsInputText; -import com.sap.ai.sdk.orchestration.model.EmbeddingsModelConfig; -import com.sap.ai.sdk.orchestration.model.EmbeddingsModelDetails; -import com.sap.ai.sdk.orchestration.model.EmbeddingsModelParams; -import com.sap.ai.sdk.orchestration.model.EmbeddingsModuleConfigs; -import com.sap.ai.sdk.orchestration.model.EmbeddingsOrchestrationConfig; -import com.sap.ai.sdk.orchestration.model.EmbeddingsPostRequest; -import com.sap.ai.sdk.orchestration.model.EmbeddingsPostResponse; -import com.sap.ai.sdk.orchestration.model.EmbeddingsResponse; import com.sap.ai.sdk.orchestration.model.ErrorResponse; import com.sap.ai.sdk.orchestration.model.GenericModuleResult; import com.sap.ai.sdk.orchestration.model.GroundingFilterSearchConfiguration; @@ -67,7 +53,6 @@ import com.sap.ai.sdk.orchestration.model.GroundingModuleConfigConfigPlaceholders; import com.sap.ai.sdk.orchestration.model.KeyValueListPair; import com.sap.ai.sdk.orchestration.model.LlamaGuard38b; -import com.sap.ai.sdk.orchestration.model.MaskingModuleConfigProviders; import com.sap.ai.sdk.orchestration.model.ModuleResultsStreaming; import com.sap.ai.sdk.orchestration.model.ResponseFormatText; import com.sap.ai.sdk.orchestration.model.SearchDocumentKeyValueListPair; @@ -80,7 +65,6 @@ import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination; import java.io.IOException; import java.io.InputStream; -import java.math.BigDecimal; import java.nio.file.Files; import java.nio.file.Path; import java.util.List; @@ -1336,156 +1320,4 @@ void testGetAllMessages() { assertThat(messageListTools.get(1)).isInstanceOf(AssistantMessage.class); assertThat(messageListTools.get(2)).isInstanceOf(ToolMessage.class); } - - @Test - void testEmbeddingCallWithMasking() { - - stubFor( - post(urlEqualTo("/v2/embeddings")) - .willReturn( - aResponse() - .withStatus(200) - .withBody( - """ - { - "request_id": "2ee98443-e1ee-9503-b800-e38b5b80fe45", - "intermediate_results": { - "input_masking": { - "message": "Embedding input is masked successfully.", - "data": { - "masked_input": "['Hello', 'MASKED_PERSON', '!']" - } - } - }, - "final_result": { - "object": "list", - "data": [ - { - "object": "embedding", - "embedding": [ - 0.43988228, - -0.82985526, - -0.15936942, - 0.041005015, - 0.30127057 - ], - "index": 0 - } - ], - "model": "text-embedding-3-large", - "usage": { - "prompt_tokens": 10, - "total_tokens": 10 - } - } - } - """))); - - val dpiConfig = - DPIConfig.create() - .type(DPIConfig.TypeEnum.SAP_DATA_PRIVACY_INTEGRATION) - .method(DPIConfig.MethodEnum.ANONYMIZATION) - .entities(List.of(DPIStandardEntity.create().type(DPIEntities.PERSON))); - val maskingConfig = MaskingModuleConfigProviders.create().providers(List.of(dpiConfig)); - - val modelParams = - EmbeddingsModelParams.create() - .encodingFormat(EmbeddingsModelParams.EncodingFormatEnum.FLOAT) - .dimensions(5) - .normalize(false); - val modelConfig = - EmbeddingsModelConfig.create() - .model( - EmbeddingsModelDetails.create().name("text-embedding-3-large").params(modelParams)); - - val orchestrationConfig = - EmbeddingsOrchestrationConfig.create() - .modules( - EmbeddingsModuleConfigs.create().embeddings(modelConfig).masking(maskingConfig)); - - val inputText = - EmbeddingsInput.create().text(EmbeddingsInputText.create("['Hello', 'Müller', '!']")); - - val request = EmbeddingsPostRequest.create().config(orchestrationConfig).input(inputText); - - EmbeddingsPostResponse response = client.embed(request); - - assertThat(response).isNotNull(); - assertThat(response.getRequestId()).isEqualTo("2ee98443-e1ee-9503-b800-e38b5b80fe45"); - - val orchestrationResult = response.getFinalResult(); - assertThat(orchestrationResult).isNotNull(); - assertThat(orchestrationResult.getObject()).isEqualTo(EmbeddingsResponse.ObjectEnum.LIST); - assertThat(orchestrationResult.getModel()).isEqualTo("text-embedding-3-large"); - - val data = orchestrationResult.getData(); - assertThat(data).isNotEmpty(); - assertThat(data.get(0).getEmbedding()) - .isEqualTo( - Embedding.create( - List.of( - BigDecimal.valueOf(0.43988228), - BigDecimal.valueOf(-0.82985526), - BigDecimal.valueOf(-0.15936942), - BigDecimal.valueOf(0.041005015), - BigDecimal.valueOf(0.30127057)))); - assertThat(data.get(0).getIndex()).isZero(); - - val usage = orchestrationResult.getUsage(); - assertThat(usage).isNotNull(); - assertThat(usage.getPromptTokens()).isEqualTo(10); - assertThat(usage.getTotalTokens()).isEqualTo(10); - - val moduleResults = response.getIntermediateResults(); - assertThat(moduleResults).isNotNull(); - assertThat(moduleResults.getInputMasking()).isNotNull(); - assertThat(moduleResults.getInputMasking().getMessage()) - .isEqualTo("Embedding input is masked successfully."); - assertThat(moduleResults.getInputMasking().getData()).isNotNull(); - assertThat(moduleResults.getInputMasking().getData()) - .isEqualTo(Map.of("masked_input", "['Hello', 'MASKED_PERSON', '!']")); - - verify( - postRequestedFor(urlEqualTo("/v2/embeddings")) - .withRequestBody( - equalToJson( - """ - { - "config": { - "modules": { - "embeddings": { - "model": { - "name": "text-embedding-3-large", - "version": "latest", - "timeout" : 600, - "max_retries" : 2, - "params": { - "encoding_format": "float", - "dimensions": 5, - "normalize": false - } - } - }, - "masking": { - "providers": [ - { - "type": "sap_data_privacy_integration", - "method": "anonymization", - "entities": [ - { - "type": "profile-person" - } - ], - "allowlist" : [ ] - } - ] - } - } - }, - "input": { - "text": "['Hello', 'Müller', '!']" - } - } - """))); - } } diff --git a/orchestration/src/test/resources/__files/embeddingResponse.json b/orchestration/src/test/resources/__files/embeddingResponse.json new file mode 100644 index 000000000..8b6ca278c --- /dev/null +++ b/orchestration/src/test/resources/__files/embeddingResponse.json @@ -0,0 +1,46 @@ +{ + "request_id": "62935941-7c2d-4c16-8a35-0b9dce7c8c9e", + "intermediate_results": { + "input_masking": { + "message": "Embedding input is masked successfully.", + "data": { + "masked_input": [ + "Hi SAP Orchestration Service", + "I am MASKED_PERSON" + ] + } + } + }, + "final_result": { + "object": "list", + "data": [ + { + "object": "embedding", + "embedding": [ + -0.09068491, + -0.3462946, + 0.88297224, + -0.29537824, + 0.0704844 + ], + "index": 0 + }, + { + "object": "embedding", + "embedding": [ + 0.18703203, + -0.10362422, + -0.65176725, + 0.6386932, + -0.34864223 + ], + "index": 1 + } + ], + "model": "text-embedding-3-small", + "usage": { + "prompt_tokens": 11, + "total_tokens": 11 + } + } +} diff --git a/orchestration/src/test/resources/embeddingRequest.json b/orchestration/src/test/resources/embeddingRequest.json new file mode 100644 index 000000000..9c201b87e --- /dev/null +++ b/orchestration/src/test/resources/embeddingRequest.json @@ -0,0 +1,42 @@ +{ + "config": { + "modules": { + "embeddings": { + "model": { + "name": "text-embedding-3-small", + "params": { + "encoding_format": "float" + }, + "timeout": 600, + "max_retries": 2 + } + }, + "masking": { + "providers": [ + { + "type": "sap_data_privacy_integration", + "method": "anonymization", + "entities": [ + { + "type": "profile-person" + } + ], + "allowlist": [ + "SAP", + "Joule" + ], + "mask_grounding_input": { + "enabled": false + } + } + ] + } + } + }, + "input": { + "text": [ + "Hi SAP Orchestration Service", + "I am John Doe" + ] + } +} \ No newline at end of file diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java index 6f1d48b1f..43281d59d 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java @@ -340,4 +340,14 @@ Object translation(@RequestParam(value = "format", required = false) final Strin } return response.getContent(); } + + @GetMapping("/embedding") + @Nonnull + Object embedding(@RequestParam(value = "format", required = false) final String format) { + final var response = service.embed(List.of("Hi SAP Orchestration Service", "I am John Doe")); + if ("json".equals(format)) { + return response; + } + return response.getEmbeddingVectors(); + } } diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OrchestrationService.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OrchestrationService.java index a77b2d169..d74a55bba 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OrchestrationService.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OrchestrationService.java @@ -3,6 +3,7 @@ import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GEMINI_2_5_FLASH; import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_4O_MINI; import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.Parameter.TEMPERATURE; +import static com.sap.ai.sdk.orchestration.OrchestrationEmbeddingModel.TEXT_EMBEDDING_3_SMALL; import com.fasterxml.jackson.annotation.JsonProperty; import com.sap.ai.sdk.core.AiCoreService; @@ -16,6 +17,8 @@ import com.sap.ai.sdk.orchestration.OrchestrationChatResponse; import com.sap.ai.sdk.orchestration.OrchestrationClient; import com.sap.ai.sdk.orchestration.OrchestrationClientException; +import com.sap.ai.sdk.orchestration.OrchestrationEmbeddingRequest; +import com.sap.ai.sdk.orchestration.OrchestrationEmbeddingResponse; import com.sap.ai.sdk.orchestration.OrchestrationModuleConfig; import com.sap.ai.sdk.orchestration.OrchestrationPrompt; import com.sap.ai.sdk.orchestration.ResponseJsonSchema; @@ -599,4 +602,26 @@ public OrchestrationChatResponse translation() { return client.chatCompletion(prompt, configWithTranslation); } + + /** + * Create text embeddings using the Orchestration service. + * + * @link AI + * Core: Orchestration - Embedding + * @param texts the list of texts to embed + * @return the embedding response object + */ + @Nonnull + public OrchestrationEmbeddingResponse embed(@Nonnull final List texts) { + final var masking = + DpiMasking.anonymization() + .withEntities(DPIEntities.PERSON) + .withAllowList(List.of("SAP", "Joule")); + + final var request = + OrchestrationEmbeddingRequest.forModel(TEXT_EMBEDDING_3_SMALL) + .forInputs(texts) + .withMasking(masking); + return client.embed(request); + } } diff --git a/sample-code/spring-app/src/main/resources/static/index.html b/sample-code/spring-app/src/main/resources/static/index.html index 62a5f9052..7474dc579 100644 --- a/sample-code/spring-app/src/main/resources/static/index.html +++ b/sample-code/spring-app/src/main/resources/static/index.html @@ -557,6 +557,26 @@

Orchestration

+
+
+
Embedding
+
+
    +
  • +
    + +
    + Get the embedding of a text after masking using the Orchestration + service. +
    +
    +
  • +
+
diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java index 7aa1730e2..1a51b82f3 100644 --- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java +++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java @@ -465,6 +465,18 @@ void testTranslation() { assertThat(outputTranslation.getMessage()).isEqualTo("Output Translation successful"); } + @Test + void testEmbedding() { + val result = service.embed(List.of("Hi SAP Orchestration Service", "I am John Doe")); + val embeddingVectors = result.getEmbeddingVectors(); + + assertThat(embeddingVectors) + .isNotNull() + .hasSize(2) + .isInstanceOf(List.class) + .allSatisfy(vector -> assertThat(vector).isInstanceOf(float[].class)); + } + @Test void wrongModelVersion() { val filterConfig =