diff --git a/models/spring-ai-qwen/pom.xml b/models/spring-ai-qwen/pom.xml new file mode 100644 index 00000000000..21f4ba1cfc4 --- /dev/null +++ b/models/spring-ai-qwen/pom.xml @@ -0,0 +1,89 @@ + + + + + 4.0.0 + + org.springframework.ai + spring-ai-parent + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-qwen + jar + Spring AI Model - Qwen + Qwen models support + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + 2.18.4 + + + + + + + org.springframework.ai + spring-ai-client-chat + ${project.parent.version} + + + + org.springframework.ai + spring-ai-core + ${project.parent.version} + + + + org.springframework + spring-context-support + + + + org.slf4j + slf4j-api + + + + com.alibaba + dashscope-sdk-java + ${dashscope.version} + + + org.slf4j + slf4j-simple + + + + + + + org.springframework.ai + spring-ai-test + ${project.version} + test + + + + diff --git a/models/spring-ai-qwen/src/main/java/org/springframework/ai/qwen/QwenChatModel.java b/models/spring-ai-qwen/src/main/java/org/springframework/ai/qwen/QwenChatModel.java new file mode 100644 index 00000000000..541f486b624 --- /dev/null +++ b/models/spring-ai-qwen/src/main/java/org/springframework/ai/qwen/QwenChatModel.java @@ -0,0 +1,248 @@ +package org.springframework.ai.qwen; + +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.MessageAggregator; +import org.springframework.ai.chat.observation.ChatModelObservationContext; +import org.springframework.ai.chat.observation.ChatModelObservationConvention; +import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; 
+import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; +import org.springframework.ai.model.tool.ToolExecutionResult; +import org.springframework.ai.observation.conventions.AiProvider; +import org.springframework.ai.qwen.api.QwenApi; +import org.springframework.ai.qwen.api.QwenModel; +import org.springframework.util.Assert; +import reactor.core.publisher.Flux; +import reactor.core.scheduler.Schedulers; + +import static org.springframework.ai.qwen.api.QwenApiHelper.getOrDefault; + +public class QwenChatModel implements ChatModel { + + private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); + + private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build(); + + private final QwenApi qwenApi; + + private final QwenChatOptions defaultOptions; + + private final ObservationRegistry observationRegistry; + + private final ToolCallingManager toolCallingManager; + + private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate; + + private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; + + public QwenChatModel(QwenApi openAiApi, QwenChatOptions defaultOptions, ToolCallingManager toolCallingManager, + ObservationRegistry observationRegistry) { + this(openAiApi, defaultOptions, toolCallingManager, observationRegistry, + new DefaultToolExecutionEligibilityPredicate()); + } + + public QwenChatModel(QwenApi qwenApi, 
QwenChatOptions defaultOptions, ToolCallingManager toolCallingManager, + ObservationRegistry observationRegistry, + ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { + Assert.notNull(qwenApi, "qwenApi cannot be null"); + Assert.notNull(defaultOptions, "defaultOptions cannot be null"); + Assert.notNull(observationRegistry, "observationRegistry cannot be null"); + Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate cannot be null"); + this.qwenApi = qwenApi; + this.defaultOptions = defaultOptions; + this.toolCallingManager = getOrDefault(toolCallingManager, DEFAULT_TOOL_CALLING_MANAGER); + this.observationRegistry = observationRegistry; + this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate; + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public ChatResponse call(Prompt prompt) { + Prompt requestPrompt = buildRequestPrompt(prompt); + return internalCall(requestPrompt, null); + } + + @Override + public Flux stream(Prompt prompt) { + Prompt requestPrompt = buildRequestPrompt(prompt); + return this.internalStream(requestPrompt, null); + } + + @Override + public ChatOptions getDefaultOptions() { + return QwenChatOptions.fromOptions(this.defaultOptions); + } + + /** + * Use the provided convention for reporting observation data + * @param observationConvention The provided convention + */ + public void setObservationConvention(ChatModelObservationConvention observationConvention) { + Assert.notNull(observationConvention, "observationConvention cannot be null"); + this.observationConvention = observationConvention; + } + + private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { + ChatModelObservationContext observationContext = ChatModelObservationContext.builder() + .prompt(prompt) + .provider(AiProvider.ALIBABA.value()) + .requestOptions(prompt.getOptions()) + .build(); + + ChatResponse response = 
ChatModelObservationDocumentation.CHAT_MODEL_OPERATION + .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry) + .observe(() -> { + ChatResponse chatResponse = qwenApi.call(prompt, previousChatResponse); + observationContext.setResponse(chatResponse); + return chatResponse; + }); + + if (toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); + if (toolExecutionResult.returnDirect()) { + // return tool execution result directly to the client + return ChatResponse.builder() + .from(response) + .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) + .build(); + } + else { + // send the tool execution result back to the model + return internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), + response); + } + } + + return response; + } + + private Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { + return Flux.deferContextual(contextView -> { + final ChatModelObservationContext observationContext = ChatModelObservationContext.builder() + .prompt(prompt) + .provider(AiProvider.ALIBABA.value()) + .requestOptions(prompt.getOptions()) + .build(); + + Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( + this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry); + + observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); + + Flux chatResponse = this.qwenApi.streamCall(prompt, previousChatResponse) + .flatMap(response -> { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + return Flux.defer(() -> { + var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); + if 
(toolExecutionResult.returnDirect()) { + // return tool execution result directly to the client + return Flux.just(ChatResponse.builder() + .from(response) + .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) + .build()); + } + else { + // send the tool execution result back to the model. + return this.internalStream( + new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), + response); + } + }).subscribeOn(Schedulers.boundedElastic()); + } + else { + return Flux.just(response); + } + }) + .doOnError(observation::error) + .doFinally(s -> observation.stop()) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + + return new MessageAggregator().aggregate(chatResponse, observationContext::setResponse); + }); + } + + private Prompt buildRequestPrompt(Prompt prompt) { + // process runtime options + QwenChatOptions runtimeOptions = null; + if (prompt.getOptions() != null) { + if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { + runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class, + QwenChatOptions.class); + } + else { + runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, + QwenChatOptions.class); + } + } + + QwenChatOptions requestOptions = QwenChatOptions.fromOptions(this.defaultOptions).overrideWith(runtimeOptions); + + ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks()); + + return new Prompt(prompt.getInstructions(), requestOptions); + } + + public static final class Builder { + + private QwenApi qwenApi; + + private QwenChatOptions defaultOptions = QwenChatOptions.builder().model(QwenModel.QWEN_MAX.getName()).build(); + + private ToolCallingManager toolCallingManager; + + private ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate = new DefaultToolExecutionEligibilityPredicate(); + + private ObservationRegistry observationRegistry = 
ObservationRegistry.NOOP; + + private Builder() { + } + + public Builder qwenApi(QwenApi qwenApi) { + this.qwenApi = qwenApi; + return this; + } + + public Builder defaultOptions(QwenChatOptions defaultOptions) { + this.defaultOptions = defaultOptions; + return this; + } + + public Builder toolCallingManager(ToolCallingManager toolCallingManager) { + this.toolCallingManager = toolCallingManager; + return this; + } + + public Builder toolExecutionEligibilityPredicate( + ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { + this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate; + return this; + } + + public Builder observationRegistry(ObservationRegistry observationRegistry) { + this.observationRegistry = observationRegistry; + return this; + } + + public QwenChatModel build() { + return new QwenChatModel(this.qwenApi, this.defaultOptions, this.toolCallingManager, + this.observationRegistry, this.toolExecutionEligibilityPredicate); + } + + } + +} diff --git a/models/spring-ai-qwen/src/main/java/org/springframework/ai/qwen/QwenChatOptions.java b/models/spring-ai-qwen/src/main/java/org/springframework/ai/qwen/QwenChatOptions.java new file mode 100644 index 00000000000..6af8a04ab65 --- /dev/null +++ b/models/spring-ai-qwen/src/main/java/org/springframework/ai/qwen/QwenChatOptions.java @@ -0,0 +1,776 @@ +package org.springframework.ai.qwen; + +import com.alibaba.dashscope.common.ResponseFormat; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.springframework.ai.qwen.api.QwenApiHelper.copyIfNotNull; +import static org.springframework.ai.qwen.api.QwenApiHelper.getOrDefault; + +/** + * Options for the OpenAI Chat 
API. + * + * @author Peng Jiang + * @since 1.0.0 + */ +@SuppressWarnings("LombokGetterMayBeUsed") +public class QwenChatOptions implements ToolCallingChatOptions { + + /** + * ID of the model to use. + */ + private String model; + + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on their + * existing frequency in the text so far, decreasing the model's likelihood to repeat + * the same line verbatim. + */ + private Double frequencyPenalty; + + /** + * The maximum number of tokens to generate in the chat completion. The total length + * of input tokens and generated tokens is limited by the model's context length. + */ + private Integer maxTokens; + + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether + * they appear in the text so far, increasing the model's likelihood to talk about new + * topics. + */ + private Double presencePenalty; + + /** + * An object specifying the format that the model must output. Setting to { "type": + * "json_object" } enables JSON mode, which guarantees the message the model generates + * is valid JSON. + */ + private ResponseFormat responseFormat; + + /** + * If specified, our system will make a best effort to sample deterministically, such + * that repeated requests with the same seed and parameters should return the same + * result. + */ + private Integer seed; + + /** + * Up to 4 sequences where the API will stop generating further tokens. + */ + private List stopSequences; + + /** + * What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make + * the output more random, while lower values like 0.2 will make it more focused and + * deterministic. We generally recommend altering this or top_p but not both. + */ + private Double temperature; + + /** + * An alternative to sampling with temperature, called nucleus sampling, where the + * model considers the results of the tokens with top_p probability mass. 
So 0.1 means + * only the tokens comprising the top 10% probability mass are considered. We + * generally recommend altering this or temperature but not both. + */ + private Double topP; + + /** + * The size of the candidate set for sampling during the generation process. For + * example, when the value is 50, only the 50 tokens with the highest scores in a + * single generation will form the candidate set for random sampling. The larger the + * value, the higher the randomness of the generation; the smaller the value,the + * higher the certainty of the generation. When the value is None or when top_k is + * greater than 100, it means that the top_k strategy is not enabled, and only the + * top_p strategy is effective. The value needs to be greater than or equal to 0. + */ + private Integer topK; + + /** + * Collection of {@link ToolCallback}s to be used for tool calling in the chat + * completion requests. + */ + private List toolCallbacks; + + /** + * Collection of tool names to be resolved at runtime and used for tool calling in the + * chat completion requests. + */ + private Set toolNames; + + /** + * Whether to enable the tool execution lifecycle internally in ChatModel. + */ + private Boolean internalToolExecutionEnabled; + + private Map toolContext; + + /** + * Controls which (if any) function is called by the model. none means the model will + * not call a function and instead generates a message. auto means the model can pick + * between generating a message or calling a function. Specifying a particular + * function via {"type: "function", "function": {"name": "my_function"}} forces the + * model to call that function. none is the default when no functions are present. + * auto is the default if functions are present. + */ + private Object toolChoice; + + /** + * Whether the model should use internet search results for reference when generating + * text. + */ + private Boolean enableSearch; + + /** + * The strategy for network search. 
Only takes effect when enableSearch is true. + */ + private SearchOptions searchOptions; + + /** + * The translation parameters you need to configure when you use the translation + * models. + */ + private TranslationOptions translationOptions; + + /** + * Whether to increase the default token limit for input images. The default token + * limit for input images is 1280. When configured to true, the token limit for input + * images is 16384. Default value is false. + */ + private Boolean vlHighResolutionImages; + + /** + * Whether the model is a multimodal model (whether it supports multimodal input). If + * not specified, it will be judged based on the model name when called, but these + * judgments may not keep up with the latest situation. + */ + private Boolean isMultimodalModel; + + /** + * Whether the model supports incremental output in the streaming output mode. This + * parameter is used to assist QwenChatModel in providing incremental output in stream + * mode. If not specified, it will be judged based on the model name when called, but + * these judgments may not keep up with the latest situation. + */ + private Boolean supportIncrementalOutput; + + /** + * User-defined parameters. They may have special effects on some special models. 
+ */ + private Map custom; + + private QwenChatOptions(Builder builder) { + this.model = builder.model; + this.frequencyPenalty = builder.frequencyPenalty; + this.maxTokens = builder.maxTokens; + this.presencePenalty = builder.presencePenalty; + this.responseFormat = builder.responseFormat; + this.seed = builder.seed; + this.stopSequences = builder.stopSequences; + this.temperature = builder.temperature; + this.topP = builder.topP; + this.topK = builder.topK; + this.toolCallbacks = builder.toolCallbacks; + this.toolNames = builder.toolNames; + this.internalToolExecutionEnabled = builder.internalToolExecutionEnabled; + this.toolContext = builder.toolContext; + this.toolChoice = builder.toolChoice; + this.enableSearch = builder.enableSearch; + this.searchOptions = builder.searchOptions; + this.translationOptions = builder.translationOptions; + this.vlHighResolutionImages = builder.vlHighResolutionImages; + this.isMultimodalModel = builder.isMultimodalModel; + this.supportIncrementalOutput = builder.supportIncrementalOutput; + this.custom = builder.custom; + } + + @Override + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public Double getFrequencyPenalty() { + return frequencyPenalty; + } + + public void setFrequencyPenalty(Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + @Override + public Integer getMaxTokens() { + return maxTokens; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + @Override + public Double getPresencePenalty() { + return presencePenalty; + } + + public void setPresencePenalty(Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public ResponseFormat getResponseFormat() { + return responseFormat; + } + + public void setResponseFormat(ResponseFormat responseFormat) { + this.responseFormat = responseFormat; + } + + public Integer getSeed() { + return seed; + } + + public void 
setSeed(Integer seed) { + this.seed = seed; + } + + public List getStopSequences() { + return stopSequences; + } + + public void setStopSequences(List stopSequences) { + this.stopSequences = stopSequences; + } + + @Override + public Double getTemperature() { + return temperature; + } + + public void setTemperature(Double temperature) { + this.temperature = temperature; + } + + @Override + public Double getTopP() { + return topP; + } + + @Override + public List getToolCallbacks() { + return this.toolCallbacks; + } + + @Override + public void setToolCallbacks(List toolCallbacks) { + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); + this.toolCallbacks = toolCallbacks; + } + + @Override + public Set getToolNames() { + return this.toolNames; + } + + @Override + public void setToolNames(Set toolNames) { + Assert.notNull(toolNames, "toolNames cannot be null"); + Assert.noNullElements(toolNames, "toolNames cannot contain null elements"); + toolNames.forEach(tool -> Assert.hasText(tool, "toolNames cannot contain empty elements")); + this.toolNames = toolNames; + } + + @Override + @Nullable + public Boolean getInternalToolExecutionEnabled() { + return internalToolExecutionEnabled; + } + + @Override + public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) { + this.internalToolExecutionEnabled = internalToolExecutionEnabled; + } + + public void setTopP(Double topP) { + this.topP = topP; + } + + @Override + public Integer getTopK() { + return topK; + } + + public void setTopK(Integer topK) { + this.topK = topK; + } + + @Override + public Map getToolContext() { + return toolContext; + } + + @Override + public void setToolContext(Map toolContext) { + this.toolContext = toolContext; + } + + public Object getToolChoice() { + return toolChoice; + } + + public void setToolChoice(Object toolChoice) { + this.toolChoice = toolChoice; + } + + public 
Boolean isEnableSearch() { + return enableSearch; + } + + public void setEnableSearch(Boolean enableSearch) { + this.enableSearch = enableSearch; + } + + public SearchOptions getSearchOptions() { + return searchOptions; + } + + public void setSearchOptions(SearchOptions searchOptions) { + this.searchOptions = searchOptions; + } + + public TranslationOptions getTranslationOptions() { + return translationOptions; + } + + public void setTranslationOptions(TranslationOptions translationOptions) { + this.translationOptions = translationOptions; + } + + public Boolean getVlHighResolutionImages() { + return vlHighResolutionImages; + } + + public void setVlHighResolutionImages(Boolean vlHighResolutionImages) { + this.vlHighResolutionImages = vlHighResolutionImages; + } + + public Boolean getIsMultimodalModel() { + return isMultimodalModel; + } + + public void setIsMultimodalModel(Boolean isMultimodalModel) { + this.isMultimodalModel = isMultimodalModel; + } + + public Boolean getSupportIncrementalOutput() { + return supportIncrementalOutput; + } + + public void setSupportIncrementalOutput(Boolean supportIncrementalOutput) { + this.supportIncrementalOutput = supportIncrementalOutput; + } + + public Map getCustom() { + return custom; + } + + public void setCustom(Map custom) { + this.custom = custom; + } + + @Override + public QwenChatOptions copy() { + return fromOptions(this); + } + + public static QwenChatOptions fromOptions(QwenChatOptions fromOptions) { + return QwenChatOptions.builder().overrideWith(fromOptions).build(); + } + + public QwenChatOptions overrideWith(QwenChatOptions that) { + return QwenChatOptions.builder().overrideWith(this).overrideWith(that).build(); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String model; + + private Double frequencyPenalty; + + private Integer maxTokens; + + private Double presencePenalty; + + private ResponseFormat responseFormat; + + private Integer seed; + + 
private List stopSequences = new ArrayList<>(); + + private Double temperature; + + private Double topP; + + private Integer topK; + + private List toolCallbacks = new ArrayList<>(); + + private Set toolNames = new HashSet<>(); + + private Boolean internalToolExecutionEnabled; + + private Map toolContext = new HashMap<>(); + + private Object toolChoice; + + private Boolean enableSearch; + + private SearchOptions searchOptions; + + private TranslationOptions translationOptions; + + private Boolean vlHighResolutionImages; + + private Boolean isMultimodalModel; + + private Boolean supportIncrementalOutput; + + private Map custom; + + public Builder model(String model) { + this.model = model; + return this; + } + + public Builder frequencyPenalty(Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder maxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + public Builder presencePenalty(Double presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + public Builder responseFormat(ResponseFormat responseFormat) { + this.responseFormat = responseFormat; + return this; + } + + public Builder seed(Integer seed) { + this.seed = seed; + return this; + } + + public Builder stopSequences(List stopSequences) { + this.stopSequences = stopSequences; + return this; + } + + public Builder temperature(Double temperature) { + this.temperature = temperature; + return this; + } + + public Builder topP(Double topP) { + this.topP = topP; + return this; + } + + public Builder topK(Integer topK) { + this.topK = topK; + return this; + } + + public Builder toolCallbacks(List toolCallbacks) { + this.toolCallbacks = toolCallbacks; + return this; + } + + public Builder toolNames(Set toolNames) { + this.toolNames = toolNames; + return this; + } + + public Builder internalToolExecutionEnabled(Boolean enabled) { + this.internalToolExecutionEnabled = enabled; + return this; + } + + public 
Builder toolContext(Map toolContext) { + this.toolContext = toolContext; + return this; + } + + public Builder toolChoice(Object toolChoice) { + this.toolChoice = toolChoice; + return this; + } + + public Builder enableSearch(Boolean enableSearch) { + this.enableSearch = enableSearch; + return this; + } + + public Builder searchOptions(SearchOptions searchOptions) { + this.searchOptions = searchOptions; + return this; + } + + public Builder translationOptions(TranslationOptions translationOptions) { + this.translationOptions = translationOptions; + return this; + } + + public Builder vlHighResolutionImages(Boolean vlHighResolutionImages) { + this.vlHighResolutionImages = vlHighResolutionImages; + return this; + } + + public Builder isMultimodalModel(Boolean isMultimodalModel) { + this.isMultimodalModel = isMultimodalModel; + return this; + } + + public Builder supportIncrementalOutput(Boolean supportIncrementalOutput) { + this.supportIncrementalOutput = supportIncrementalOutput; + return this; + } + + public Builder custom(Map custom) { + this.custom = custom; + return this; + } + + public Builder overrideWith(QwenChatOptions fromOptions) { + if (fromOptions == null) { + return this; + } + + this.model(getOrDefault(fromOptions.getModel(), this.model)); + this.frequencyPenalty(getOrDefault(fromOptions.getFrequencyPenalty(), this.frequencyPenalty)); + this.maxTokens(getOrDefault(fromOptions.getMaxTokens(), this.maxTokens)); + this.presencePenalty(getOrDefault(fromOptions.getPresencePenalty(), this.presencePenalty)); + this.responseFormat(getOrDefault(fromOptions.getResponseFormat(), this.responseFormat)); + this.seed(getOrDefault(fromOptions.getSeed(), this.seed)); + this.stopSequences(copyIfNotNull(getOrDefault(fromOptions.getStopSequences(), this.stopSequences))); + this.temperature(getOrDefault(fromOptions.getTemperature(), this.temperature)); + this.topP(getOrDefault(fromOptions.getTopP(), this.topP)); + this.topK(getOrDefault(fromOptions.getTopK(), this.topK)); 
+ this.toolCallbacks(copyIfNotNull(getOrDefault(fromOptions.getToolCallbacks(), this.toolCallbacks))); + this.toolNames(copyIfNotNull(getOrDefault(fromOptions.getToolNames(), this.toolNames))); + this.internalToolExecutionEnabled( + getOrDefault(fromOptions.isInternalToolExecutionEnabled(), this.internalToolExecutionEnabled)); + this.toolContext(getOrDefault(fromOptions.getToolContext(), this.toolContext)); + this.toolChoice(getOrDefault(fromOptions.getToolChoice(), this.toolChoice)); + this.enableSearch(getOrDefault(fromOptions.isEnableSearch(), this.enableSearch)); + this.searchOptions(getOrDefault(fromOptions.getSearchOptions(), this.searchOptions)); + this.translationOptions(getOrDefault(fromOptions.getTranslationOptions(), this.translationOptions)); + this.vlHighResolutionImages( + getOrDefault(fromOptions.getVlHighResolutionImages(), this.vlHighResolutionImages)); + this.isMultimodalModel(getOrDefault(fromOptions.getIsMultimodalModel(), this.isMultimodalModel)); + this.supportIncrementalOutput( + getOrDefault(fromOptions.getSupportIncrementalOutput(), this.supportIncrementalOutput)); + this.custom(copyIfNotNull(getOrDefault(fromOptions.getCustom(), this.custom))); + return this; + } + + public QwenChatOptions build() { + return new QwenChatOptions(this); + } + + } + + /** + * The strategy for network search. + * + * @param enableSource Whether to display the searched information in the returned + * results. Default value is false. + * @param enableCitation Whether to enable the [1] or [ref_1] style superscript + * annotation function. This function takes effect only when enable_source is true. + * Default value is false. + * @param citationFormat Subscript style. Only available when enable_citation is true. + * Supported styles: “[1]” and “[ref_1]”. Default value is “[1]”. + * @param forcedSearch Whether to force search to start. + * @param searchStrategy The amount of Internet information searched. Supported + * values: “standard” and “pro”. 
Default value is “standard”. + */ + public record SearchOptions(Boolean enableSource, Boolean enableCitation, String citationFormat, + Boolean forcedSearch, String searchStrategy) { + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private Boolean enableSource; + + private Boolean enableCitation; + + private String citationFormat; + + private Boolean forcedSearch; + + private String searchStrategy; + + public Builder enableSource(Boolean enableSource) { + this.enableSource = enableSource; + return this; + } + + public Builder enableCitation(Boolean enableCitation) { + this.enableCitation = enableCitation; + return this; + } + + public Builder citationFormat(String citationFormat) { + this.citationFormat = citationFormat; + return this; + } + + public Builder forcedSearch(Boolean forcedSearch) { + this.forcedSearch = forcedSearch; + return this; + } + + public Builder searchStrategy(String searchStrategy) { + this.searchStrategy = searchStrategy; + return this; + } + + public SearchOptions build() { + return new SearchOptions(enableSource, enableCitation, citationFormat, forcedSearch, searchStrategy); + } + + } + } + + /** + * The translation parameters you need to configure when you use the translation + * models. + * + * @param sourceLang The full English name of the source language.For more + * information, see Supported + * Languages. You can set source_lang to "auto" and the model will automatically + * determine the language of the input text. + * @param targetLang The full English name of the target language.For more + * information, see Supported + * Languages. + * @param terms An array of terms that needs to be set when using the + * term-intervention-translation feature. + * @param tmList The translation memory array that needs to be set when using the + * translation-memory feature. + * @param domains The domain prompt statement needs to be set when using the + * domain-prompt feature. 
+ */ + public record TranslationOptions(String sourceLang, String targetLang, List terms, + List tmList, String domains) { + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String sourceLang; + + private String targetLang; + + private List terms; + + private List tmLists; + + private String domains; + + public Builder sourceLang(String sourceLang) { + this.sourceLang = sourceLang; + return this; + } + + public Builder targetLang(String targetLang) { + this.targetLang = targetLang; + return this; + } + + public Builder terms(List terms) { + this.terms = terms; + return this; + } + + public Builder tmLists(List tmLists) { + this.tmLists = tmLists; + return this; + } + + public Builder domains(String domains) { + this.domains = domains; + return this; + } + + public TranslationOptions build() { + return new TranslationOptions(sourceLang, targetLang, terms, tmLists, domains); + } + + } + } + + /** + * The term. + * + * @param source The term in the source language. + * @param target The term in the target language. 
/**
 * A single term pair used by the translation options.
 *
 * @param source The term in the source language.
 * @param target The term in the target language.
 */
public record TranslationOptionTerm(String source, String target) {

    /**
     * @return a fresh builder with both sides of the pair unset ({@code null})
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Fluent builder for {@link TranslationOptionTerm}.
     */
    public static class Builder {

        private String sourceTerm;

        private String targetTerm;

        public Builder source(String source) {
            sourceTerm = source;
            return this;
        }

        public Builder target(String target) {
            targetTerm = target;
            return this;
        }

        /**
         * @return a record carrying the values collected so far; unset sides stay
         * {@code null}
         */
        public TranslationOptionTerm build() {
            return new TranslationOptionTerm(this.sourceTerm, this.targetTerm);
        }

    }
}
-0,0 +1,329 @@ +package org.springframework.ai.qwen.api; + +import com.alibaba.dashscope.aigc.generation.GenerationOutput; +import com.alibaba.dashscope.aigc.generation.GenerationParam; +import com.alibaba.dashscope.aigc.generation.GenerationResult; +import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversation; +import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversationOutput; +import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversationParam; +import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversationResult; +import com.alibaba.dashscope.common.MultiModalMessage; +import com.alibaba.dashscope.exception.InputRequiredException; +import com.alibaba.dashscope.exception.NoApiKeyException; +import com.alibaba.dashscope.exception.UploadFileException; +import com.alibaba.dashscope.protocol.Protocol; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.metadata.UsageUtils; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.ApiKey; +import org.springframework.ai.model.SimpleApiKey; +import org.springframework.ai.qwen.QwenChatOptions; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; + +import java.util.List; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; + +import static org.springframework.ai.qwen.api.QwenApiHelper.toQwenSearchInfo; +import static org.springframework.ai.qwen.api.QwenApiHelper.defaultUsageFrom; +import static 
org.springframework.ai.qwen.api.QwenApiHelper.generationsFrom; +import static org.springframework.ai.qwen.api.QwenApiHelper.getOrDefault; +import static org.springframework.ai.qwen.api.QwenApiHelper.isMultimodalModelName; +import static org.springframework.ai.qwen.api.QwenApiHelper.isStreamingDone; +import static org.springframework.ai.qwen.api.QwenApiHelper.isStreamingToolCall; +import static org.springframework.ai.qwen.api.QwenApiHelper.isSupportingIncrementalOutputModelName; +import static org.springframework.ai.qwen.api.QwenApiHelper.newGenerationResult; +import static org.springframework.ai.qwen.api.QwenApiHelper.toGenerationParam; +import static org.springframework.ai.qwen.api.QwenApiHelper.toMultiModalConversationParam; +import static org.springframework.ai.qwen.api.QwenApiHelper.toQwenResultCallback; + +public class QwenApi { + + private final String apiKey; + + private final com.alibaba.dashscope.aigc.generation.Generation generation; + + private final com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversation conv; + + /** + * Some models support deeply customized parameters. Here is a way to intervene in the + * request parameters of the qwen models at runtime. + */ + private Consumer> generationParamCustomizer = p -> { + }; + + /** + * Some models support deeply customized parameters. Here is a way to intervene in the + * request parameters of the qwen multimodal-models at runtime. 
/**
 * Creates the API facade and the two DashScope SDK clients it delegates to.
 * <p>
 * The transport is chosen from {@code baseUrl}: a blank value uses the SDK's own default
 * endpoint, a {@code wss://} URL selects the WebSocket protocol, and anything else is
 * treated as an HTTP endpoint.
 * @param baseUrl endpoint override; may be {@code null} or blank
 * @param apiKey key holder whose value is read once, at construction time
 * @throws IllegalArgumentException if {@code apiKey} is {@code null}
 */
public QwenApi(String baseUrl, ApiKey apiKey) {
    // Fail fast with the same message the Builder uses, instead of an anonymous
    // NullPointerException after both SDK clients have already been created.
    Assert.notNull(apiKey, "apiKey cannot be null");

    if (!StringUtils.hasText(baseUrl)) {
        this.conv = new MultiModalConversation();
        this.generation = new com.alibaba.dashscope.aigc.generation.Generation();
    }
    else if (baseUrl.startsWith("wss://")) {
        this.conv = new MultiModalConversation(Protocol.WEBSOCKET.getValue(), baseUrl);
        this.generation = new com.alibaba.dashscope.aigc.generation.Generation(Protocol.WEBSOCKET.getValue(),
                baseUrl);
    }
    else {
        this.conv = new MultiModalConversation(Protocol.HTTP.getValue(), baseUrl);
        this.generation = new com.alibaba.dashscope.aigc.generation.Generation(Protocol.HTTP.getValue(), baseUrl);
    }

    this.apiKey = apiKey.getValue();
}
IllegalArgumentException(e); + } + } + + private ChatResponse callMultimodalModel(Prompt prompt, ChatResponse previousChatResponse) { + MultiModalConversationParam param = toMultiModalConversationParam(apiKey, prompt, false, + multimodalConversationParamCustomizer); + + try { + MultiModalConversationResult result = conv.call(param); + List generations = generationsFrom(result); + Usage currentUsage = defaultUsageFrom(result.getUsage()); + Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse); + ChatResponseMetadata metadata = ChatResponseMetadata.builder() + .id(result.getRequestId()) + .usage(accumulatedUsage) + .model(prompt.getOptions().getModel()) + .build(); + return new ChatResponse(generations, metadata); + } + catch (NoApiKeyException e) { + throw new IllegalArgumentException(e); + } + catch (UploadFileException e) { + throw new IllegalStateException(e); + } + } + + public Flux streamCall(Prompt prompt, ChatResponse previousChatResponse) { + return isMultimodalModel(prompt) ? 
streamCallMultimodalModel(prompt, previousChatResponse) + : streamCallNonMultimodalModel(prompt, previousChatResponse); + } + + private Flux streamCallNonMultimodalModel(Prompt prompt, ChatResponse previousChatResponse) { + boolean incrementalOutput = supportIncrementalOutput(prompt); + GenerationParam param = toGenerationParam(apiKey, prompt, incrementalOutput, generationParamCustomizer); + StringBuilder generatedContent = new StringBuilder(); + Sinks.Many sink = Sinks.many().multicast().onBackpressureBuffer(); + AtomicBoolean isInsideTool = new AtomicBoolean(false); + + try { + generation.streamCall(param, toQwenResultCallback(sink)); + + return sink.asFlux().map(result -> { + if (isStreamingToolCall(result)) { + isInsideTool.set(true); + } + if (!incrementalOutput) { + // unified into incremental output mode + Optional.of(result) + .map(GenerationResult::getOutput) + .map(GenerationOutput::getChoices) + .filter(choices -> !choices.isEmpty()) + .map(choices -> choices.get(0)) + .map(GenerationOutput.Choice::getMessage) + .filter(message -> StringUtils.hasText(message.getContent())) + .ifPresent(message -> { + String partialContent = message.getContent().substring(generatedContent.length()); + generatedContent.append(partialContent); + message.setContent(partialContent); + }); + } + return result; + }).windowUntil(result -> { + if (isInsideTool.get() && isStreamingDone(result)) { + isInsideTool.set(false); + return true; + } + return !isInsideTool.get(); + }).concatMapIterable(window -> { + Mono monoChunk = window.reduce(newGenerationResult(), QwenApiHelper::mergeResult); + return List.of(monoChunk); + }).flatMap(mono -> mono).map(result -> { + List generations = generationsFrom(result); + Usage currentUsage = defaultUsageFrom(result.getUsage()); + Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse); + ChatResponseMetadata.Builder metadataBuilder = ChatResponseMetadata.builder() + .id(result.getRequestId()) + 
.usage(accumulatedUsage) + .model(prompt.getOptions().getModel()); + if (result.getOutput().getSearchInfo() != null) { + metadataBuilder.keyValue("searchInfo", toQwenSearchInfo(result.getOutput().getSearchInfo())); + } + return new ChatResponse(generations, metadataBuilder.build()); + }); + + } + catch (NoApiKeyException | InputRequiredException e) { + throw new IllegalArgumentException(e); + } + } + + private Flux streamCallMultimodalModel(Prompt prompt, ChatResponse previousChatResponse) { + boolean incrementalOutput = supportIncrementalOutput(prompt); + MultiModalConversationParam param = toMultiModalConversationParam(apiKey, prompt, incrementalOutput, + multimodalConversationParamCustomizer); + + StringBuilder generatedContent = new StringBuilder(); + Sinks.Many sink = Sinks.many().multicast().onBackpressureBuffer(); + + try { + // note: multimodal models do not support toolcalls + conv.streamCall(param, toQwenResultCallback(sink)); + + return sink.asFlux().map(result -> { + if (!incrementalOutput) { + // unified into incremental output mode + Optional.of(result) + .map(MultiModalConversationResult::getOutput) + .map(MultiModalConversationOutput::getChoices) + .filter(choices -> !choices.isEmpty()) + .map(choices -> choices.get(0)) + .map(MultiModalConversationOutput.Choice::getMessage) + .map(MultiModalMessage::getContent) + .filter(contents -> !contents.isEmpty()) + .map(contents -> contents.get(0)) + .filter(content -> StringUtils.hasText((String) content.get("text"))) + .ifPresent(content -> { + String textContent = (String) content.get("text"); + String partialContent = textContent.substring(generatedContent.length()); + generatedContent.append(partialContent); + content.put("text", partialContent); + }); + } + return result; + }).map(result -> { + List generations = generationsFrom(result); + Usage currentUsage = defaultUsageFrom(result.getUsage()); + Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse); + 
ChatResponseMetadata metadata = ChatResponseMetadata.builder() + .id(result.getRequestId()) + .usage(accumulatedUsage) + .model(prompt.getOptions().getModel()) + .build(); + return new ChatResponse(generations, metadata); + }); + + } + catch (NoApiKeyException | InputRequiredException e) { + throw new IllegalArgumentException(e); + } + catch (UploadFileException e) { + throw new IllegalStateException(e); + } + } + + boolean isMultimodalModel(Prompt prompt) { + ChatOptions options = prompt.getOptions(); + if (!(options instanceof QwenChatOptions)) { + throw new IllegalArgumentException("options should be an instance of QwenChatOption"); + } + + String modelName = options.getModel(); + Boolean isMultimodalModel = ((QwenChatOptions) options).getIsMultimodalModel(); + isMultimodalModel = getOrDefault(isMultimodalModel, isMultimodalModelName(modelName)); + + return Boolean.TRUE.equals(isMultimodalModel); + } + + boolean supportIncrementalOutput(Prompt prompt) { + ChatOptions options = prompt.getOptions(); + if (!(options instanceof QwenChatOptions)) { + throw new IllegalArgumentException("options should be an instance of QwenChatOption"); + } + + String modelName = options.getModel(); + Boolean supportIncrementalOutput = ((QwenChatOptions) options).getSupportIncrementalOutput(); + supportIncrementalOutput = getOrDefault(supportIncrementalOutput, + isSupportingIncrementalOutputModelName(modelName)); + + return Boolean.TRUE.equals(supportIncrementalOutput); + } + + public void setGenerationParamCustomizer( + Consumer> generationParamCustomizer) { + this.generationParamCustomizer = generationParamCustomizer; + } + + public void setMultimodalConversationParamCustomizer( + Consumer> multimodalConversationParamCustomizer) { + this.multimodalConversationParamCustomizer = multimodalConversationParamCustomizer; + } + + public static class Builder { + + private String baseUrl; + + private ApiKey apiKey; + + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + 
return this; + } + + public Builder apiKey(ApiKey apiKey) { + Assert.notNull(apiKey, "apiKey cannot be null"); + this.apiKey = apiKey; + return this; + } + + public Builder apiKey(String simpleApiKey) { + Assert.notNull(simpleApiKey, "simpleApiKey cannot be null"); + this.apiKey = new SimpleApiKey(simpleApiKey); + return this; + } + + public QwenApi build() { + Assert.notNull(this.apiKey, "apiKey must be set"); + return new QwenApi(this.baseUrl, this.apiKey); + } + + } + +} diff --git a/models/spring-ai-qwen/src/main/java/org/springframework/ai/qwen/api/QwenApiHelper.java b/models/spring-ai-qwen/src/main/java/org/springframework/ai/qwen/api/QwenApiHelper.java new file mode 100644 index 00000000000..18378a76c72 --- /dev/null +++ b/models/spring-ai-qwen/src/main/java/org/springframework/ai/qwen/api/QwenApiHelper.java @@ -0,0 +1,988 @@ +package org.springframework.ai.qwen.api; + +import com.alibaba.dashscope.aigc.generation.GenerationOutput; +import com.alibaba.dashscope.aigc.generation.GenerationParam; +import com.alibaba.dashscope.aigc.generation.GenerationResult; +import com.alibaba.dashscope.aigc.generation.GenerationUsage; +import com.alibaba.dashscope.aigc.generation.SearchInfo; +import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversationOutput; +import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversationParam; +import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversationResult; +import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversationUsage; +import com.alibaba.dashscope.common.DashScopeResult; +import com.alibaba.dashscope.common.MessageContentBase; +import com.alibaba.dashscope.common.MessageContentImageURL; +import com.alibaba.dashscope.common.MessageContentText; +import com.alibaba.dashscope.common.MultiModalMessage; +import com.alibaba.dashscope.common.ResultCallback; +import com.alibaba.dashscope.common.Role; +import com.alibaba.dashscope.tools.FunctionDefinition; +import 
com.alibaba.dashscope.tools.ToolBase; +import com.alibaba.dashscope.tools.ToolCallBase; +import com.alibaba.dashscope.tools.ToolCallFunction; +import com.alibaba.dashscope.tools.ToolFunction; +import com.alibaba.dashscope.tools.codeinterpretertool.ToolCallCodeInterpreter; +import com.alibaba.dashscope.tools.search.ToolCallQuarkSearch; +import com.alibaba.dashscope.utils.JsonUtils; +import com.google.gson.JsonObject; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.DefaultUsage; +import org.springframework.ai.chat.metadata.EmptyUsage; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.content.Media; +import org.springframework.ai.qwen.QwenChatOptions; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.util.CollectionUtils; +import org.springframework.util.MimeType; +import org.springframework.util.StringUtils; +import reactor.core.publisher.Sinks; + +import java.util.ArrayList; +import java.util.Base64; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.BiFunction; +import java.util.function.BinaryOperator; +import java.util.function.Consumer; + +import static com.alibaba.dashscope.aigc.conversation.ConversationParam.ResultFormat.MESSAGE; +import static java.util.stream.Collectors.toList; + +public 
class QwenApiHelper { + + private static final Logger log = LoggerFactory.getLogger(QwenApiHelper.class); + + static boolean isMultimodalModelName(String modelName) { + // rough judgment + return modelName.contains("-vl-") || modelName.contains("-audio-"); + } + + static boolean isSupportingIncrementalOutputModelName(String modelName) { + // rough judgment + return !(modelName.contains("-vl-") || modelName.contains("-audio-") || modelName.contains("-mt-")); + } + + static List toQwenMessages(List messages) { + return sanitizeMessages(messages).stream().map(QwenApiHelper::toQwenMessage).toList(); + } + + static com.alibaba.dashscope.common.Message toQwenMessage(Message message) { + return com.alibaba.dashscope.common.Message.builder() + .role(roleFrom(message)) + .content(contentFrom(message)) + .name(nameFrom(message)) + .toolCallId(toolCallIdFrom(message)) + .toolCalls(toolCallsFrom(message)) + .build(); + } + + private static String roleFrom(Message message) { + if (message.getMessageType() == MessageType.ASSISTANT) { + return Role.ASSISTANT.getValue(); + } + else if (message.getMessageType() == MessageType.SYSTEM) { + return Role.SYSTEM.getValue(); + } + else if (message.getMessageType() == MessageType.TOOL) { + return Role.TOOL.getValue(); + } + else { + return Role.USER.getValue(); + } + } + + private static String nameFrom(Message message) { + if (message.getMessageType() == MessageType.TOOL) { + return ((ToolResponseMessage) message).getResponses().get(0).name(); + } + return null; + } + + private static String contentFrom(Message message) { + if (message.getMessageType() == MessageType.TOOL) { + return ((ToolResponseMessage) message).getResponses().get(0).responseData(); + } + return message.getText(); + } + + private static String toolCallIdFrom(Message message) { + if (message.getMessageType() == MessageType.TOOL) { + return ((ToolResponseMessage) message).getResponses().get(0).id(); + } + return null; + } + + private static List toolCallsFrom(Message 
message) { + if (message.getMessageType() == MessageType.ASSISTANT && ((AssistantMessage) message).hasToolCalls()) { + return toToolCalls(((AssistantMessage) message).getToolCalls()); + } + return null; + } + + private static List toToolCalls(Collection toolExecutionRequests) { + return toolExecutionRequests.stream().map(QwenApiHelper::toToolCall).toList(); + } + + private static ToolCallBase toToolCall(AssistantMessage.ToolCall toolExecutionRequest) { + ToolCallFunction toolCallFunction = new ToolCallFunction(); + toolCallFunction.setId(toolExecutionRequest.id()); + ToolCallFunction.CallFunction callFunction = toolCallFunction.new CallFunction(); + callFunction.setName(toolExecutionRequest.name()); + callFunction.setArguments(toolExecutionRequest.arguments()); + toolCallFunction.setFunction(callFunction); + return toolCallFunction; + } + + private static List toToolFunctions(Collection toolSpecifications) { + if (CollectionUtils.isEmpty(toolSpecifications)) { + return Collections.emptyList(); + } + + return toolSpecifications.stream().map(QwenApiHelper::toToolFunction).toList(); + } + + private static ToolBase toToolFunction(ToolCallback toolCallback) { + FunctionDefinition functionDefinition = FunctionDefinition.builder() + .name(toolCallback.getToolDefinition().name()) + .description(getOrDefault(toolCallback.getToolDefinition().description(), "")) + .parameters(toParameters(toolCallback)) + .build(); + return ToolFunction.builder().function(functionDefinition).build(); + } + + private static JsonObject toParameters(ToolCallback toolCallback) { + if (StringUtils.hasText(toolCallback.getToolDefinition().inputSchema())) { + return JsonUtils.parse(toolCallback.getToolDefinition().inputSchema()); + } + else { + return JsonUtils.toJsonObject(Collections.emptyMap()); + } + } + + static List toQwenMultiModalMessages(List messages) { + return messages.stream().map(QwenApiHelper::toQwenMultiModalMessage).collect(toList()); + } + + private static MultiModalMessage 
toQwenMultiModalMessage(Message message) { + return MultiModalMessage.builder().role(roleFrom(message)).content(toMultiModalContents(message)).build(); + } + + private static List> toMultiModalContents(Message message) { + List> contents = new LinkedList<>(); + if (StringUtils.hasText(message.getText())) { + contents.add(toMultiModalContent(message.getText())); + } + + List media = switch (message.getMessageType()) { + case USER -> ((UserMessage) message).getMedia(); + case ASSISTANT -> ((AssistantMessage) message).getMedia(); + default -> Collections.emptyList(); + }; + + media.stream().map(QwenApiHelper::toMultiModalContent).forEach(contents::add); + + if (message instanceof ToolResponseMessage toolMessage) { + List toolResponses = toolMessage.getResponses(); + if (!CollectionUtils.isEmpty(toolResponses)) { + for (ToolResponseMessage.ToolResponse toolResponse : toolResponses) { + contents.add(Map.of("content", toolResponse.responseData(), "tool_call_id", toolResponse.id())); + } + } + } + + return contents; + } + + static Map toMultiModalContent(Media media) { + MimeType mimeType = media.getMimeType(); + return switch (mimeType.getType()) { + case "image" -> Collections.singletonMap("image", fromMediaData(mimeType, media.getData())); + case "audio" -> Collections.singletonMap("audio", fromMediaData(mimeType, media.getData())); + case "video" -> Collections.singletonMap("video", fromMediaData(mimeType, media.getData())); + case "text" -> Collections.singletonMap("text", media.getData()); + default -> Collections.emptyMap(); + }; + } + + static Map toMultiModalContent(String text) { + return Collections.singletonMap("text", text); + } + + private static String fromMediaData(MimeType mimeType, Object mediaContentData) { + if (mediaContentData instanceof byte[] bytes) { + // Assume the bytes are an image. So, convert the bytes to a base64 encoded + // following the prefix pattern. 
+ return String.format("data:%s;base64,%s", mimeType.toString(), Base64.getEncoder().encodeToString(bytes)); + } + else if (mediaContentData instanceof String text) { + // Assume the text is a URLs or a base64 encoded image prefixed by the user. + return text; + } + else { + throw new IllegalArgumentException( + "Unsupported media data type: " + mediaContentData.getClass().getSimpleName()); + } + } + + static List sanitizeMessages(List messages) { + LinkedList sanitizedMessages = messages.stream() + .reduce(new LinkedList<>(), messageAccumulator(), messageCombiner()); + + // Ensure the last message is a user/tool_execution_result message + while (!sanitizedMessages.isEmpty() && !isInputMessageType(sanitizedMessages.getLast())) { + Message removedMessage = sanitizedMessages.removeLast(); + log.warn("The last message should be a user/tool_execution_result message, but found: {}", removedMessage); + } + + return sanitizedMessages; + } + + private static BiFunction, Message, LinkedList> messageAccumulator() { + return (acc, message) -> { + MessageType type = message.getMessageType(); + if (acc.isEmpty()) { + // Ensure the first message is a system message or a user message. 
+ if (type == MessageType.SYSTEM || type == MessageType.USER) { + acc.add(message); + } + else { + log.warn("The first message should be a system message or a user message, but found: {}", message); + } + return acc; + } + + if (type == MessageType.SYSTEM) { + if (acc.getFirst().getMessageType() == MessageType.SYSTEM) { + log.warn("Drop existed system message: {}", acc); + acc.removeFirst(); + } + acc.addFirst(message); + return acc; + } + + MessageType lastType = acc.getLast().getMessageType(); + if (lastType == MessageType.SYSTEM && type != MessageType.USER) { + log.warn("The first non-system message must be a user message, but found: {}", message); + return acc; + } + + if (type == MessageType.USER) { + while (acc.getLast().getMessageType() != MessageType.SYSTEM && !isNormalAiType(acc.getLast())) { + Message removedMessage = acc.removeLast(); + log.warn( + "Tool execution result should follow a tool execution request message. Drop duplicated message: {}", + removedMessage); + } + } + else if (type == MessageType.TOOL) { + while (!isToolCallAiType(acc.getLast())) { + Message removedMessage = acc.removeLast(); + log.warn( + "Tool execution result should follow a tool execution request message. Drop duplicated message: {}", + removedMessage); + } + } + else if (type == MessageType.ASSISTANT) { + while (!isInputMessageType(acc.getLast())) { + Message removedMessage = acc.removeLast(); + log.warn( + "AI message should follow a user/tool_execution_result message. 
Drop duplicated message: {}", + removedMessage); + } + } + + acc.add(message); + return acc; + }; + } + + private static BinaryOperator> messageCombiner() { + return (acc1, acc2) -> { + throw new UnsupportedOperationException("Parallel stream not supported"); + }; + } + + private static boolean isInputMessageType(Message message) { + MessageType type = message.getMessageType(); + return type == MessageType.USER || type == MessageType.TOOL; + } + + private static boolean isNormalAiType(Message message) { + return message.getMessageType() == MessageType.ASSISTANT && !((AssistantMessage) message).hasToolCalls(); + } + + private static boolean isToolCallAiType(Message message) { + return message.getMessageType() == MessageType.ASSISTANT && ((AssistantMessage) message).hasToolCalls(); + } + + static GenerationParam toGenerationParam(String apiKey, Prompt prompt, boolean incrementalOutput, + Consumer> generationParamCustomizer) { + QwenChatOptions options = (QwenChatOptions) prompt.getOptions(); + validateGenerationParameters(options); + + GenerationParam.GenerationParamBuilder builder = GenerationParam.builder() + .apiKey(apiKey) + .model(options.getModel()) + .topP(options.getTopP()) + .topK(options.getTopK()) + .enableSearch(getOrDefault(options.isEnableSearch(), false)) + .searchOptions(toQwenSearchOptions(options.getSearchOptions())) + .seed(options.getSeed()) + .repetitionPenalty(frequencyPenaltyToRepetitionPenalty(options.getFrequencyPenalty())) + .maxTokens(options.getMaxTokens()) + .messages(toQwenMessages(prompt.getInstructions())) + .responseFormat(options.getResponseFormat()) + .resultFormat(MESSAGE) + .incrementalOutput(incrementalOutput); + + if (options.getTemperature() != null) { + builder.temperature(options.getTemperature().floatValue()); + } + + if (options.getStopSequences() != null) { + builder.stopStrings(options.getStopSequences()); + } + + if (!CollectionUtils.isEmpty(options.getToolCallbacks())) { + 
builder.tools(toToolFunctions(options.getToolCallbacks())); + if (options.getToolChoice() != null) { + Object toolChoiceObject = options.getToolChoice(); + if (toolChoiceObject instanceof ToolCallback toolCallback) { + builder.toolChoice(toToolFunction(toolCallback)); + } + else { + builder.toolChoice(toolChoiceObject); + } + } + } + + if (options.getTranslationOptions() != null) { + // no java field is provided yet + builder.parameter("translation_options", toQwenTranslationOptions(options.getTranslationOptions())); + } + + if (options.getCustom() != null) { + // no java field is provided yet + builder.parameter("custom", options.getCustom()); + } + + if (generationParamCustomizer != null) { + generationParamCustomizer.accept(builder); + } + + return builder.build(); + } + + static void validateGenerationParameters(QwenChatOptions options) { + if (options.getVlHighResolutionImages() != null) { + throw new UnsupportedOperationException( + "'vlHighResolutionImages' parameter is not supported by " + options.getModel()); + } + } + + static MultiModalConversationParam toMultiModalConversationParam(String apiKey, Prompt prompt, + boolean incrementalOutput, + Consumer> multimodalConversationParamCustomizer) { + QwenChatOptions options = (QwenChatOptions) prompt.getOptions(); + validateMultimodalConversationParameters(options); + + MultiModalConversationParam.MultiModalConversationParamBuilder builder = MultiModalConversationParam + .builder() + .apiKey(apiKey) + .model(options.getModel()) + .topP(options.getTopP()) + .topK(options.getTopK()) + .enableSearch(getOrDefault(options.isEnableSearch(), false)) + .seed(options.getSeed()) + .maxTokens(options.getMaxTokens()) + .messages(toQwenMultiModalMessages(prompt.getInstructions())) + .incrementalOutput(incrementalOutput); + + if (options.getTemperature() != null) { + builder.temperature(options.getTemperature().floatValue()); + } + + if (options.getVlHighResolutionImages() != null) { + // no java field is provided yet + 
builder.parameter("vl_high_resolution_images", options.getVlHighResolutionImages()); + } + + if (options.getCustom() != null) { + // no java field is provided yet + builder.parameter("custom", options.getCustom()); + } + + if (multimodalConversationParamCustomizer != null) { + multimodalConversationParamCustomizer.accept(builder); + } + + return builder.build(); + } + + static void validateMultimodalConversationParameters(QwenChatOptions options) { + if (options.getSearchOptions() != null) { + throw new UnsupportedOperationException( + "'searchOptions' parameter is not supported by " + options.getModel()); + } + + if (options.getFrequencyPenalty() != null) { + throw new UnsupportedOperationException( + "'frequencyPenalty' parameter is not supported by " + options.getModel()); + } + + if (!CollectionUtils.isEmpty(options.getStopSequences())) { + throw new UnsupportedOperationException( + "'stopSequences' parameter is not supported by " + options.getModel()); + } + + if (!CollectionUtils.isEmpty(options.getToolCallbacks()) || !CollectionUtils.isEmpty(options.getToolNames()) + || !CollectionUtils.isEmpty(options.getToolContext()) || options.getToolChoice() != null) { + throw new UnsupportedOperationException("'tools' parameter is not supported by " + options.getModel()); + } + + if (options.getTranslationOptions() != null) { + throw new UnsupportedOperationException( + "'translationOptions' parameter is not supported by " + options.getModel()); + } + + if (options.getResponseFormat() != null) { + throw new UnsupportedOperationException( + "'responseFormat' parameter is not supported by " + options.getModel()); + } + } + + static com.alibaba.dashscope.aigc.generation.SearchOptions toQwenSearchOptions( + QwenChatOptions.SearchOptions searchOptions) { + if (searchOptions == null) { + return null; + } + + return com.alibaba.dashscope.aigc.generation.SearchOptions.builder() + .citationFormat(searchOptions.citationFormat()) + .enableCitation(searchOptions.enableCitation()) 
+ .enableSource(searchOptions.enableSource()) + .forcedSearch(searchOptions.forcedSearch()) + .searchStrategy(searchOptions.searchStrategy()) + .build(); + } + + static Map toQwenTranslationOptions(QwenChatOptions.TranslationOptions translationOptions) { + if (translationOptions == null) { + return null; + } + + // no java class is provided yet + Map translationOptionsMap = new HashMap<>(5); + translationOptionsMap.put("source_lang", translationOptions.sourceLang()); + translationOptionsMap.put("target_lang", translationOptions.targetLang()); + translationOptionsMap.put("terms", toTermList(translationOptions.terms())); + translationOptionsMap.put("tm_list", toTermList(translationOptions.tmList())); + translationOptionsMap.put("domains", translationOptions.domains()); + return translationOptionsMap; + } + + static List> toTermList(List list) { + if (list == null) { + return null; + } + + return list.stream().map(term -> Map.of("source", term.source(), "target", term.target())).toList(); + } + + static boolean isStreamingToolCall(GenerationResult result) { + return Optional.of(result) + .map(GenerationResult::getOutput) + .map(GenerationOutput::getChoices) + .filter(choices -> !choices.isEmpty()) + .map(choices -> choices.get(0)) + .map(GenerationOutput.Choice::getMessage) + .map(com.alibaba.dashscope.common.Message::getToolCalls) + .map(toolCalls -> !toolCalls.isEmpty()) + .orElse(false); + } + + static boolean isStreamingDone(GenerationResult result) { + return getFinishReason(result) != null; + } + + private static String getFinishReason(GenerationResult result) { + return Optional.of(result) + .map(GenerationResult::getOutput) + .map(GenerationOutput::getChoices) + .filter(choices -> !choices.isEmpty()) + .map(choices -> choices.get(0)) + .map(QwenApiHelper::getFinishReason) + .orElse(null); + } + + private static String getFinishReason(GenerationOutput.Choice choice) { + String finishReason = choice.getFinishReason(); + return StringUtils.hasText(finishReason) 
&& !"null".equals(finishReason) ? finishReason : null; + } + + private static String getFinishReason(MultiModalConversationOutput.Choice choice) { + String finishReason = choice.getFinishReason(); + return StringUtils.hasText(finishReason) && !"null".equals(finishReason) ? finishReason : null; + } + + static GenerationResult newGenerationResult() { + DashScopeResult emptyResult = new DashScopeResult(); + emptyResult.setOutput(new JsonObject()); + return GenerationResult.fromDashScopeResult(emptyResult); + } + + static GenerationResult mergeResult(GenerationResult previous, GenerationResult current) { + String requestId = getOrDefault(current.getRequestId(), previous.getRequestId()); + GenerationUsage usage = getOrDefault(current.getUsage(), previous.getUsage()); + GenerationOutput output = mergeOutput(previous.getOutput(), current.getOutput()); + + GenerationResult result = newGenerationResult(); + result.setRequestId(requestId); + result.setUsage(usage); + result.setOutput(output); + + return result; + } + + private static GenerationOutput mergeOutput(GenerationOutput previous, GenerationOutput current) { + GenerationOutput output = new GenerationOutput(); + + String finishReason = getOrDefault(current.getFinishReason(), previous.getFinishReason()); + String text = merge(previous.getText(), current.getText()); // keep stream order: earlier text first, then the new chunk + List choices = mergeChoices(output, previous.getChoices(), current.getChoices()); + SearchInfo searchInfo = mergeSearchInfo(previous.getSearchInfo(), current.getSearchInfo()); + + output.setFinishReason(finishReason); + output.setText(text); + output.setChoices(choices); + output.setSearchInfo(searchInfo); + + return output; + } + + private static SearchInfo mergeSearchInfo(SearchInfo previous, SearchInfo current) { + if (previous == null) { + return current; + } + if (current == null) { + return previous; + } + List searchResults = merge(previous.getSearchResults(), current.getSearchResults()); + return
SearchInfo.builder().searchResults(searchResults).build(); + } + + private static List mergeChoices(GenerationOutput output, + List previous, List current) { + // in most cases, there is only one + List choices = new ArrayList<>(1); + GenerationOutput.Choice lastPreviousChoice = null; + + if (!CollectionUtils.isEmpty(previous)) { + lastPreviousChoice = previous.get(previous.size() - 1); + if (previous.size() > 1) { + choices.addAll(previous.subList(0, previous.size() - 1)); + } + } + + if (!CollectionUtils.isEmpty(current)) { + var iterator = current.iterator(); + var firstChoice = iterator.next(); + // the first one should be merged with previous last one + choices.add(mergeChoice(output, lastPreviousChoice, firstChoice)); + while (iterator.hasNext()) { + choices.add(iterator.next()); + } + } + else { + if (lastPreviousChoice != null) { + choices.add(lastPreviousChoice); + } + } + + return choices; + } + + private static GenerationOutput.Choice mergeChoice(GenerationOutput output, GenerationOutput.Choice previous, + GenerationOutput.Choice current) { + if (previous == null) { + return current; + } + if (current == null) { + return previous; + } + + Integer index = getOrDefault(current.getIndex(), previous.getIndex()); + String finishReason = getOrDefault(current.getFinishReason(), previous.getFinishReason()); + com.alibaba.dashscope.common.Message message = mergeMessage(previous.getMessage(), current.getMessage()); + + GenerationOutput.Choice choice = output.new Choice(); + choice.setIndex(index); + choice.setFinishReason(finishReason); + choice.setMessage(message); + + return choice; + } + + private static com.alibaba.dashscope.common.Message mergeMessage(com.alibaba.dashscope.common.Message previous, + com.alibaba.dashscope.common.Message current) { + + if (previous == null) { + return current; + } + if (current == null) { + return previous; + } + + String content = merge(previous.getContent(), current.getContent()); + String reasoningContent = 
merge(previous.getReasoningContent(), current.getReasoningContent()); + String role = getOrDefault(current.getRole(), previous.getRole()); + role = getOrDefault(role, Role.ASSISTANT.getValue()); + String name = getOrDefault(current.getName(), previous.getName()); + List contents = merge(previous.getContents(), current.getContents()); + List toolCalls = mergeToolCalls(previous.getToolCalls(), current.getToolCalls()); + String toolCallId = getOrDefault(current.getToolCallId(), previous.getToolCallId()); + + return com.alibaba.dashscope.common.Message.builder() + .content(content) + .contents(contents) + .toolCalls(toolCalls) + .toolCallId(toolCallId) + .name(name) + .role(role) + .reasoningContent(reasoningContent) + .build(); + } + + private static List mergeToolCalls(List previous, List current) { + // in most cases, there is only one + List toolCalls = new ArrayList<>(1); + ToolCallBase lastPreviousTooCall = null; + + if (!CollectionUtils.isEmpty(previous)) { + lastPreviousTooCall = previous.get(previous.size() - 1); + if (previous.size() > 1) { + toolCalls.addAll(previous.subList(0, previous.size() - 1)); + } + } + + if (!CollectionUtils.isEmpty(current)) { + var iterator = current.iterator(); + var firstToolCall = iterator.next(); + // the first one should be merged with previous last one + if (StringUtils.hasText(firstToolCall.getId())) { + if (lastPreviousTooCall != null) { + toolCalls.add(lastPreviousTooCall); + } + toolCalls.add(firstToolCall); + } + else { + toolCalls.add(mergeToolCall(lastPreviousTooCall, firstToolCall)); + } + while (iterator.hasNext()) { + toolCalls.add(iterator.next()); + } + } + else { + if (lastPreviousTooCall != null) { + toolCalls.add(lastPreviousTooCall); + } + } + + return toolCalls; + } + + private static ToolCallBase mergeToolCall(ToolCallBase previous, ToolCallBase current) { + if (previous == null) { + return current; + } + + String id = (StringUtils.hasText(current.getId()) ? 
current.getId() : previous.getId()); + String type = getOrDefault(current.getType(), previous.getType()); + + if (previous instanceof ToolCallFunction previousToolCallFunction + && current instanceof ToolCallFunction currentToolCallFunction) { + ToolCallFunction newToolCall = new ToolCallFunction(); + ToolCallFunction.CallFunction callFunction = mergeToolCallFunction(newToolCall, + previousToolCallFunction.getFunction(), currentToolCallFunction.getFunction()); + newToolCall.setFunction(callFunction); + newToolCall.setId(id); + newToolCall.setType(type); + return newToolCall; + } + else if (current instanceof ToolCallCodeInterpreter) { + ToolCallCodeInterpreter newToolCall = new ToolCallCodeInterpreter(); + newToolCall.setId(id); + newToolCall.setType(type); + return newToolCall; + } + else if (previous instanceof ToolCallQuarkSearch previousQuarkToolCall + && current instanceof ToolCallQuarkSearch currentQuarkToolCall) { + Map quarkSearch = merge(previousQuarkToolCall.getQuarkSearch(), + currentQuarkToolCall.getQuarkSearch()); + ToolCallQuarkSearch newToolCall = new ToolCallQuarkSearch(); + newToolCall.setId(id); + newToolCall.setType(type); + newToolCall.setQuarkSearch(quarkSearch); + return newToolCall; + } + else { + return current; + } + } + + private static ToolCallFunction.CallFunction mergeToolCallFunction(ToolCallFunction toolCallFunction, + ToolCallFunction.CallFunction previous, ToolCallFunction.CallFunction current) { + if (previous == null) { + return current; + } + + String name = merge(previous.getName(), current.getName()); + String arguments = merge(previous.getArguments(), current.getArguments()); + String output = merge(previous.getOutput(), current.getOutput()); + + ToolCallFunction.CallFunction callFunction = toolCallFunction.new CallFunction(); + callFunction.setName(name); + callFunction.setArguments(arguments); + callFunction.setOutput(output); + return callFunction; + } + + private static Map merge(Map previous, Map current) { + if (previous 
== null) { + return current; + } + if (current == null) { + return previous; + } + Map merged = new HashMap<>(previous); + merged.putAll(current); + return merged; + } + + private static List merge(List previous, List current) { + if (previous == null) { + return current; + } + if (current == null) { + return previous; + } + List merged = new ArrayList<>(previous.size() + current.size()); + merged.addAll(previous); + merged.addAll(current); + return merged; + } + + private static String merge(String previous, String current) { + if (previous == null) { + return current; + } + if (current == null) { + return previous; + } + return previous + current; + } + + static Float frequencyPenaltyToRepetitionPenalty(Double frequencyPenalty) { + // repetitionPenalty: + // https://www.alibabacloud.com/help/en/model-studio/use-qwen-by-calling-api#2ed5ee7377fum + // frequencyPenalty: + // https://platform.openai.com/docs/api-reference/chat/create#chat-create-frequency_penalty + // map: [-2, 2] -> (0, ∞), and 0 -> 1 + // use logit function (https://en.wikipedia.org/wiki/Logit) + + if (frequencyPenalty == null) { + return null; + } + else if (frequencyPenalty >= 2) { + return Float.POSITIVE_INFINITY; + } + else if (frequencyPenalty < -2) { + throw new IllegalArgumentException("Value of frequencyPenalty must be within [-2.0, 2.0]"); + } + + // limit the input to 0.5 to 1 (as the repetition penalty is a positive value) + double x = (frequencyPenalty + 6) / 8; + // make sure repetition penalty is 1 when frequency penalty is 0 + double denominator = logit(0.75d); + + return (float) (logit(x) / denominator); + } + + private static double logit(double x) { + return Math.log(x / (1 - x)); + } + + static List generationsFrom(GenerationResult result) { + return Optional.of(result) + .map(GenerationResult::getOutput) + .map(GenerationOutput::getChoices) + .orElse(Collections.emptyList()) + .stream() + .map(choice -> buildGeneration(result.getRequestId(), choice)) + .toList(); + } + + private 
static Generation buildGeneration(String id, GenerationOutput.Choice choice) { + com.alibaba.dashscope.common.Message message = choice.getMessage(); + List toolCalls = Optional.ofNullable(message.getToolCalls()) + .orElse(Collections.emptyList()) + .stream() + .filter(ToolCallFunction.class::isInstance) + .map(ToolCallFunction.class::cast) + .map(toolCall -> new AssistantMessage.ToolCall(toolCall.getId(), toolCall.getType(), + toolCall.getFunction().getName(), toolCall.getFunction().getArguments())) + .toList(); + + String finishReason = getFinishReason(choice); + List media = new LinkedList<>(); + String text = message.getContent(); + List contents = message.getContents(); + if (!CollectionUtils.isEmpty(contents)) { + for (MessageContentBase content : contents) { + if (content instanceof MessageContentImageURL imageContent) { + media + .add(Media.builder().mimeType(Media.Format.IMAGE_PNG).data(imageContent.getImageURL()).build()); + } + else if (content instanceof MessageContentText textContent) { + media.add(Media.builder().mimeType(Media.Format.DOC_TXT).data(textContent.getText()).build()); + } + } + } + Map metadata = CollectionUtils.newHashMap(6); + putIfNotNull(metadata, "id", id); + putIfNotNull(metadata, "role", message.getRole()); + putIfNotNull(metadata, "name", message.getName()); + putIfNotNull(metadata, "index", choice.getIndex()); + putIfNotNull(metadata, "finishReason", finishReason); + putIfNotNull(metadata, "reasoningContent", message.getReasoningContent()); + + return new Generation(new AssistantMessage(text, metadata, toolCalls, media), + ChatGenerationMetadata.builder().finishReason(finishReason).build()); + } + + static Usage defaultUsageFrom(GenerationUsage qwenUsage) { + return qwenUsage == null ? 
new EmptyUsage() : new DefaultUsage(qwenUsage.getInputTokens(), + qwenUsage.getOutputTokens(), qwenUsage.getTotalTokens(), qwenUsage); + } + + static List generationsFrom(MultiModalConversationResult result) { + return Optional.of(result) + .map(MultiModalConversationResult::getOutput) + .map(MultiModalConversationOutput::getChoices) + .orElse(Collections.emptyList()) + .stream() + .map(choice -> buildGeneration(result.getRequestId(), choice)) + .toList(); + } + + private static Generation buildGeneration(String id, MultiModalConversationOutput.Choice choice) { + com.alibaba.dashscope.common.MultiModalMessage message = choice.getMessage(); + List toolCalls = Collections.emptyList(); + + String finishReason = getFinishReason(choice); + List media = new LinkedList<>(); + List textContents = new LinkedList<>(); + List> contents = message.getContent(); + if (!CollectionUtils.isEmpty(contents)) { + for (Map content : contents) { + if (content.containsKey("text")) { + textContents.add((String) content.get("text")); + } + + if (content.containsKey("image")) { + media.add(Media.builder().mimeType(Media.Format.IMAGE_PNG).data(content.get("image")).build()); + } + } + } + + String text = String.join("\n", textContents); + + Map metadata = CollectionUtils.newHashMap(3); + putIfNotNull(metadata, "id", id); + putIfNotNull(metadata, "role", message.getRole()); + putIfNotNull(metadata, "finishReason", finishReason); + + return new Generation(new AssistantMessage(text, metadata, toolCalls, media), + ChatGenerationMetadata.builder().finishReason(finishReason).build()); + } + + static Usage defaultUsageFrom(MultiModalConversationUsage qwenUsage) { + return qwenUsage == null ? new EmptyUsage() : new DefaultUsage(qwenUsage.getInputTokens(), + qwenUsage.getOutputTokens(), qwenUsage.getTotalTokens(), qwenUsage); + } + + static QwenSearchInfo toQwenSearchInfo(SearchInfo searchInfo) { + List searchResults = searchInfo == null + || CollectionUtils.isEmpty(searchInfo.getSearchResults()) ? 
Collections.emptyList() + : searchInfo.getSearchResults().stream().map(QwenApiHelper::toQwenSearchResult).toList(); + + return QwenSearchInfo.builder().searchResults(searchResults).build(); + } + + private static QwenSearchResult toQwenSearchResult(SearchInfo.SearchResult searchResult) { + return QwenSearchResult.builder() + .siteName(searchResult.getSiteName()) + .icon(searchResult.getIcon()) + .index(searchResult.getIndex()) + .title(searchResult.getTitle()) + .url(searchResult.getUrl()) + .build(); + } + + static ResultCallback toQwenResultCallback(Sinks.Many sink) { + return new ResultCallback<>() { + @Override + public void onEvent(T result) { + sink.tryEmitNext(result); + } + + @Override + public void onComplete() { + sink.tryEmitComplete(); + } + + @Override + public void onError(Exception e) { + sink.tryEmitError(e); + } + }; + } + + public static T getOrDefault(T value, T defaultValue) { + return value != null ? value : defaultValue; + } + + public static List copyIfNotNull(List list) { + return list == null ? null : Collections.unmodifiableList(list); + } + + public static Set copyIfNotNull(Set set) { + return set == null ? null : Collections.unmodifiableSet(set); + } + + public static Map copyIfNotNull(Map map) { + return map == null ? 
null : Collections.unmodifiableMap(map); + } + + public static void putIfNotNull(Map map, K key, V value) { + if (value != null) { + map.put(key, value); + } + } + +} diff --git a/models/spring-ai-qwen/src/main/java/org/springframework/ai/qwen/api/QwenModel.java b/models/spring-ai-qwen/src/main/java/org/springframework/ai/qwen/api/QwenModel.java new file mode 100644 index 00000000000..ee635730ee6 --- /dev/null +++ b/models/spring-ai-qwen/src/main/java/org/springframework/ai/qwen/api/QwenModel.java @@ -0,0 +1,71 @@ +package org.springframework.ai.qwen.api; + +import org.springframework.ai.model.ChatModelDescription; + +public enum QwenModel implements ChatModelDescription { + + QWEN_TURBO("qwen-turbo", "Qwen base model, stable version."), + QWEN_TURBO_LATEST("qwen-turbo-latest", "Qwen base model, latest version."), + QWEN_PLUS("qwen-plus", "Qwen plus model, stable version."), + QWEN_PLUS_LATEST("qwen-plus-latest", "Qwen plus model, latest version."), + QWEN_MAX("qwen-max", "Qwen max model, stable version."), + QWEN_MAX_LATEST("qwen-max-latest", "Qwen max model, latest version."), + QWEN_LONG("qwen-long", "Qwen long model, 10m context."), + QWEN_7B_CHAT("qwen-7b-chat", "Qwen open sourced 7-billion-parameters model."), + QWEN_14B_CHAT("qwen-14b-chat", "Qwen open sourced 14-billion-parameters model."), + QWEN_72B_CHAT("qwen-72b-chat", "Qwen open sourced 72-billion-parameters model."), + QWEN1_5_7B_CHAT("qwen1.5-7b-chat", "Qwen open sourced 7-billion-parameters model (v1.5)."), + QWEN1_5_14B_CHAT("qwen1.5-14b-chat", "Qwen open sourced 14-billion-parameters model (v1.5)."), + QWEN1_5_32B_CHAT("qwen1.5-32b-chat", "Qwen open sourced 32-billion-parameters model (v1.5)."), + QWEN1_5_72B_CHAT("qwen1.5-72b-chat", "Qwen open sourced 72-billion-parameters model (v1.5)."), + QWEN2_0_5B_INSTRUCT("qwen2-0.5b-instruct", "Qwen open sourced 0.5-billion-parameters model (v2)."), + QWEN2_1_5B_INSTRUCT("qwen2-1.5b-instruct", "Qwen open sourced 1.5-billion-parameters model (v2)."), + 
QWEN2_7B_INSTRUCT("qwen2-7b-instruct", "Qwen open sourced 7-billion-parameters model (v2)."), + QWEN2_72B_INSTRUCT("qwen2-72b-instruct", "Qwen open sourced 72-billion-parameters model (v2)."), + QWEN2_57B_A14B_INSTRUCT("qwen2-57b-a14b-instruct", + "Qwen open sourced 57-billion-parameters and 14-billion-activation-parameters MOE model (v2)."), + QWEN2_5_0_5B_INSTRUCT("qwen2.5-0.5b-instruct", "Qwen open sourced 0.5-billion-parameters model (v2.5)."), + QWEN2_5_1_5B_INSTRUCT("qwen2.5-1.5b-instruct", "Qwen open sourced 1.5-billion-parameters model (v2.5)."), + QWEN2_5_3B_INSTRUCT("qwen2.5-3b-instruct", "Qwen open sourced 3-billion-parameters model (v2.5)."), + QWEN2_5_7B_INSTRUCT("qwen2.5-7b-instruct", "Qwen open sourced 7-billion-parameters model (v2.5)."), + QWEN2_5_14B_INSTRUCT("qwen2.5-14b-instruct", "Qwen open sourced 14-billion-parameters model (v2.5)."), + QWEN2_5_32B_INSTRUCT("qwen2.5-32b-instruct", "Qwen open sourced 32-billion-parameters model (v2.5)."), + QWEN2_5_72B_INSTRUCT("qwen2.5-72b-instruct", "Qwen open sourced 72-billion-parameters model (v2.5)."), + QWEN_VL_PLUS("qwen-vl-plus", "Qwen multi-modal model, supports image and text information, stable version."), + QWEN_VL_PLUS_LATEST("qwen-vl-plus-latest", + "Qwen multi-modal model, supports image and text information, latest version."), + QWEN_VL_MAX("qwen-vl-max", + "Qwen multi-modal model, supports image and text information, offers optimal performance, stable version."), + QWEN_VL_MAX_LATEST("qwen-vl-max-latest", + "Qwen multi-modal model, supports image and text information, offers optimal performance, latest version."), + QWEN_AUDIO_TURBO("qwen-audio-turbo", "Qwen audio understanding model, stable version."), + QWEN_AUDIO_TURBO_LATEST("qwen-audio-turbo-latest", "Qwen audio understanding model, latest version."), + QWEN_MT_TURBO("qwen-mt-turbo", "Qwen turbo model for translation."), + QWEN_MT_PLUS("qwen-mt-plus", "Qwen plus model for translation."), + QWQ_PLUS("qwq-plus", "Qwen reasoning model, 
stable version."), + QWQ_PLUS_LATEST("qwq-plus-latest", "Qwen reasoning model, latest version."); + + private final String name; + + private final String description; + + QwenModel(String name) { + this(name, ""); + } + + QwenModel(String name, String description) { + this.name = name; + this.description = description; + } + + @Override + public String getName() { + return this.name; + } + + @Override + public String getDescription() { + return this.description; + } + +} diff --git a/models/spring-ai-qwen/src/main/java/org/springframework/ai/qwen/api/QwenSearchInfo.java b/models/spring-ai-qwen/src/main/java/org/springframework/ai/qwen/api/QwenSearchInfo.java new file mode 100644 index 00000000000..dd5ab0cefac --- /dev/null +++ b/models/spring-ai-qwen/src/main/java/org/springframework/ai/qwen/api/QwenSearchInfo.java @@ -0,0 +1,31 @@ +package org.springframework.ai.qwen.api; + +import java.util.List; + +/** + * The information searched on the Internet will be returned after the search_options + * parameter is set. + * + * @param searchResults a list of results from online searches + */ +public record QwenSearchInfo(List searchResults) { + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private List searchResults; + + public Builder searchResults(List searchResults) { + this.searchResults = searchResults; + return this; + } + + public QwenSearchInfo build() { + return new QwenSearchInfo(searchResults); + } + + } +} diff --git a/models/spring-ai-qwen/src/main/java/org/springframework/ai/qwen/api/QwenSearchResult.java b/models/spring-ai-qwen/src/main/java/org/springframework/ai/qwen/api/QwenSearchResult.java new file mode 100644 index 00000000000..20b9ddf1467 --- /dev/null +++ b/models/spring-ai-qwen/src/main/java/org/springframework/ai/qwen/api/QwenSearchResult.java @@ -0,0 +1,63 @@ +package org.springframework.ai.qwen.api; + +/** + * Results from online searches. 
+ * + * @see QwenSearchInfo + * @param siteName the name of the website from which the search results came + * @param icon the URL of the icon from the source website, or an empty string if there is + * no icon + * @param index the sequence number of the search result, indicating the index of the + * search result in search_results + * @param title the title of the search result + * @param url the URL of the search result + */ +public record QwenSearchResult(String siteName, String icon, Integer index, String title, String url) { + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String siteName; + + private String icon; + + private Integer index; + + private String title; + + private String url; + + public Builder siteName(String siteName) { + this.siteName = siteName; + return this; + } + + public Builder icon(String icon) { + this.icon = icon; + return this; + } + + public Builder index(Integer index) { + this.index = index; + return this; + } + + public Builder title(String title) { + this.title = title; + return this; + } + + public Builder url(String url) { + this.url = url; + return this; + } + + public QwenSearchResult build() { + return new QwenSearchResult(siteName, icon, index, title, url); + } + + } +} diff --git a/models/spring-ai-qwen/src/main/resources/META-INF/spring/aot.factories b/models/spring-ai-qwen/src/main/resources/META-INF/spring/aot.factories new file mode 100644 index 00000000000..b49c0accf16 --- /dev/null +++ b/models/spring-ai-qwen/src/main/resources/META-INF/spring/aot.factories @@ -0,0 +1,2 @@ +org.springframework.aot.hint.RuntimeHintsRegistrar=\ + org.springframework.ai.qwen.aot.QwenRuntimeHints \ No newline at end of file diff --git a/models/spring-ai-qwen/src/test/java/org/springframework/ai/qwen/MockWeatherService.java b/models/spring-ai-qwen/src/test/java/org/springframework/ai/qwen/MockWeatherService.java new file mode 100644 index 00000000000..fc682d1192c --- 
/dev/null +++ b/models/spring-ai-qwen/src/test/java/org/springframework/ai/qwen/MockWeatherService.java @@ -0,0 +1,76 @@ +package org.springframework.ai.qwen; + +import com.fasterxml.jackson.annotation.JsonClassDescription; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +import java.util.function.Function; + +public class MockWeatherService implements Function { + + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, 15, temperature / 2, temperature * 2, 53, 45, request.unit); + } + + /** + * Temperature units. + */ + public enum Unit { + + /** + * Celsius. + */ + C("metric"), + /** + * Fahrenheit. + */ + F("imperial"); + + /** + * Human readable unit name. + */ + public final String unitName; + + Unit(String text) { + this.unitName = text; + } + + } + + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty("lat") @JsonPropertyDescription("The city latitude") double lat, + @JsonProperty("lon") @JsonPropertyDescription("The city longitude") double lon, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + + } + + /** + * Weather Function response. 
+ */ + public record Response(double temperature, double feels_like, double temp_min, double temp_max, int pressure, + int humidity, Unit unit) { + + } + +} diff --git a/models/spring-ai-qwen/src/test/java/org/springframework/ai/qwen/QwenChatModelIT.java b/models/spring-ai-qwen/src/test/java/org/springframework/ai/qwen/QwenChatModelIT.java new file mode 100644 index 00000000000..24d3d17a439 --- /dev/null +++ b/models/spring-ai-qwen/src/test/java/org/springframework/ai/qwen/QwenChatModelIT.java @@ -0,0 +1,295 @@ +package org.springframework.ai.qwen; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.converter.BeanOutputConverter; +import org.springframework.ai.converter.ListOutputConverter; +import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.ai.qwen.api.QwenApi; +import org.springframework.ai.qwen.api.QwenModel; +import org.springframework.ai.qwen.api.QwenSearchInfo; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.core.convert.support.DefaultConversionService; +import 
org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.util.MimeTypeUtils; + +import java.io.IOException; +import java.net.URL; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +import static java.util.Collections.singletonList; +import static org.assertj.core.api.Assertions.assertThat; + +@SpringBootTest(classes = QwenChatModelIT.TestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "DASHSCOPE_API_KEY", matches = ".+") +class QwenChatModelIT { + + private static final Logger logger = LoggerFactory.getLogger(QwenChatModelIT.class); + + @Autowired + private QwenChatModel chatModel; + + @Test + void roleTest() { + Message systemMessage = new SystemPromptTemplate(""" + You are a helpful AI assistant. Your name is {name}. + You are an AI assistant that helps people find information. + Your name is {name} + You should reply to the user's request with your name and also in the style of a {voice}. + """).createMessage(Map.of("name", "Bob", "voice", "pirate")); + + UserMessage userMessage = new UserMessage("Generate the names of 5 famous pirates."); + + Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); + ChatResponse response = this.chatModel.call(prompt); + assertThat(response.getResult().getOutput().getText()).contains("Blackbeard"); + } + + @Test + void messageHistoryTest() { + + Message systemMessage = new SystemPromptTemplate(""" + You are a helpful AI assistant. Your name is {name}. + You are an AI assistant that helps people find information. + Your name is {name} + You should reply to the user's request with your name and also in the style of a {voice}. 
+ """).createMessage(Map.of("name", "Bob", "voice", "pirate")); + + UserMessage userMessage = new UserMessage( + "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); + + Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); + + ChatResponse response = this.chatModel.call(prompt); + assertThat(response.getResult().getOutput().getText()).containsAnyOf("Blackbeard"); + + var promptWithMessageHistory = new Prompt(List.of(new UserMessage("Dummy"), response.getResult().getOutput(), + new UserMessage("Repeat the last assistant message."))); + response = this.chatModel.call(promptWithMessageHistory); + + System.out.println(response.getResult().getOutput().getText()); + assertThat(response.getResult().getOutput().getText()).containsAnyOf("Blackbeard"); + } + + @Test + void listOutputConverter() { + DefaultConversionService conversionService = new DefaultConversionService(); + ListOutputConverter outputConverter = new ListOutputConverter(conversionService); + + String format = outputConverter.getFormat(); + String template = """ + List five {subject} + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, + Map.of("subject", "ice cream flavors", "format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = this.chatModel.call(prompt).getResult(); + + List list = outputConverter.convert(generation.getOutput().getText()); + assertThat(list).hasSize(5); + + } + + @Test + void mapOutputConverter() { + MapOutputConverter outputConverter = new MapOutputConverter(); + + String format = outputConverter.getFormat(); + String template = """ + Provide me a List of {subject} + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, + Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = 
this.chatModel.call(prompt).getResult(); + + Map result = outputConverter.convert(generation.getOutput().getText()); + assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); + + } + + @Test + void beanOutputConverter() { + + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilms.class); + + String format = outputConverter.getFormat(); + String template = """ + Generate the filmography for a random actor. + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = this.chatModel.call(prompt).getResult(); + + ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getText()); + assertThat(actorsFilms.actor()).isNotNull(); + } + + @Test + void beanOutputConverterRecords() { + + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class); + + String format = outputConverter.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = this.chatModel.call(prompt).getResult(); + + ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getText()); + logger.info(actorsFilms.toString()); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + + @Test + void beanStreamOutputConverterRecords() { + + BeanOutputConverter converter = new BeanOutputConverter<>(ActorsFilmsRecord.class); + + String format = converter.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. 
+ {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + + String generationTextFromStream = this.chatModel.stream(prompt) + .collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getText) + .filter(Objects::nonNull) + .collect(Collectors.joining()); + + ActorsFilmsRecord actorsFilms = converter.convert(generationTextFromStream); + logger.info(actorsFilms.toString()); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + + @Test + void multiModalityImageUrl() throws IOException { + URL url = new URL("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"); + + String response = ChatClient.create(this.chatModel) + .prompt() + .options(QwenChatOptions.builder().model(QwenModel.QWEN_VL_MAX.getName()).build()) + .user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, url)) + .call() + .content(); + + logger.info(response); + assertThat(response).containsAnyOf("bananas", "apple", "bowl", "basket", "fruit stand"); + } + + @Test + void multiModalityImageResource() { + Resource resource = new ClassPathResource("multimodal.test.png"); + + String response = ChatClient.create(this.chatModel) + .prompt() + .options(QwenChatOptions.builder().model(QwenModel.QWEN_VL_MAX.getName()).build()) + .user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, resource)) + .call() + .content(); + + assertThat(response).containsAnyOf("bananas", "apple", "bowl", "basket", "fruit stand"); + } + + @Test + void answerAfterSearch() { + QwenChatOptions options = QwenChatOptions.builder() + .enableSearch(true) + .searchOptions(QwenChatOptions.SearchOptions.builder() + .citationFormat("[]") + .enableCitation(true) + .enableSource(true) + 
.forcedSearch(true) + .searchStrategy("standard") + .build()) + .build(); + + Prompt prompt = new Prompt("What is the weather of Beijing?", options); + + ChatResponse response = chatModel.call(prompt); + System.out.println(response.getResult().getOutput().getText()); + ChatResponseMetadata metadata = response.getMetadata(); + QwenSearchInfo searchInfo = metadata.get("searchInfo"); + assertThat(searchInfo).isNotNull(); + assertThat(searchInfo.searchResults()).isNotEmpty(); + } + + @Test + void translateMessage() { + QwenChatOptions options = QwenChatOptions.builder() + .model(QwenModel.QWEN_MT_PLUS.getName()) + .translationOptions(QwenChatOptions.TranslationOptions.builder() + .sourceLang("English") + .targetLang("Chinese") + .terms(singletonList( + QwenChatOptions.TranslationOptionTerm.builder().source("memory").target("内存").build())) + .domains("Translate into this IT domain style.") + .build()) + .build(); + + Prompt prompt = new Prompt("my memory", options); + + ChatResponse response = chatModel.call(prompt); + String chineseContent = response.getResult().getOutput().getText().trim(); + System.out.println(chineseContent); + assertThat(chineseContent).isEqualTo("我的内存"); + } + + record ActorsFilms(String actor, List movies) { + } + + record ActorsFilmsRecord(String actor, List movies) { + } + + @SpringBootConfiguration + public static class TestConfiguration { + + @Bean + public QwenApi qwenApi() { + return QwenApi.builder().apiKey(System.getenv("DASHSCOPE_API_KEY")).build(); + } + + @Bean + public QwenChatModel qwenChatModel(QwenApi qwenApi) { + return QwenChatModel.builder().qwenApi(qwenApi).build(); + } + + } + +} diff --git a/models/spring-ai-qwen/src/test/java/org/springframework/ai/qwen/QwenChatModelObservationIT.java b/models/spring-ai-qwen/src/test/java/org/springframework/ai/qwen/QwenChatModelObservationIT.java new file mode 100644 index 00000000000..18db586a189 --- /dev/null +++ 
b/models/spring-ai-qwen/src/test/java/org/springframework/ai/qwen/QwenChatModelObservationIT.java @@ -0,0 +1,177 @@ +package org.springframework.ai.qwen; + +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.junit.platform.commons.logging.Logger; +import org.junit.platform.commons.logging.LoggerFactory; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; +import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.observation.conventions.AiOperationType; +import org.springframework.ai.observation.conventions.AiProvider; +import org.springframework.ai.qwen.api.QwenApi; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import reactor.core.publisher.Flux; + +import java.util.List; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +@SpringBootTest(classes = QwenChatModelObservationIT.TestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "DASHSCOPE_API_KEY", matches = ".+") +public class QwenChatModelObservationIT { + + private static final Logger logger = LoggerFactory.getLogger(QwenChatModelObservationIT.class); + + @Autowired + private TestObservationRegistry observationRegistry; + + @Autowired + private QwenChatModel chatModel; + + @BeforeEach + void beforeEach() { + this.observationRegistry.clear(); + } + + @Test + void 
observationForImperativeChatOperation() { + + var options = QwenChatOptions.builder() + .frequencyPenalty(0.0) + .maxTokens(2048) + .presencePenalty(0.0) + .stopSequences(List.of("this-is-the-end")) + .temperature(0.7) + .topP(1.0) + .build(); + + Prompt prompt = new Prompt("Why does a raven look like a desk?", options); + + ChatResponse chatResponse = this.chatModel.call(prompt); + assertThat(chatResponse.getResult().getOutput().getText()).isNotEmpty(); + + ChatResponseMetadata responseMetadata = chatResponse.getMetadata(); + assertThat(responseMetadata).isNotNull(); + + validate(responseMetadata, true); + } + + @Test + void observationForStreamingChatOperation() { + + var options = QwenChatOptions.builder() + .frequencyPenalty(0.0) + .maxTokens(2048) + .presencePenalty(0.0) + .stopSequences(List.of("this-is-the-end")) + .temperature(0.7) + .topP(1.0) + .build(); + + Prompt prompt = new Prompt("Why does a raven look like a desk?", options); + + Flux chatResponseFlux = this.chatModel.stream(prompt); + List responses = chatResponseFlux.collectList().block(); + assertThat(responses).isNotEmpty(); + assertThat(responses).hasSizeGreaterThan(10); + + String aggregatedResponse = responses.subList(0, responses.size() - 1) + .stream() + .map(r -> r.getResult().getOutput().getText()) + .collect(Collectors.joining()); + assertThat(aggregatedResponse).isNotEmpty(); + + ChatResponse lastChatResponse = responses.get(responses.size() - 1); + + ChatResponseMetadata responseMetadata = lastChatResponse.getMetadata(); + assertThat(responseMetadata).isNotNull(); + + validate(responseMetadata, false); + } + + private void validate(ChatResponseMetadata responseMetadata, boolean checkModel) { + + TestObservationRegistryAssert.That that = TestObservationRegistryAssert.assertThat(this.observationRegistry) + .doesNotHaveAnyRemainingCurrentObservation() + .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME); + + if (checkModel) { + that.that() + 
.hasLowCardinalityKeyValue( + ChatModelObservationDocumentation.LowCardinalityKeyNames.RESPONSE_MODEL.asString(), + responseMetadata.getModel()); + } + + that.that() + .hasLowCardinalityKeyValue( + ChatModelObservationDocumentation.LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), + AiOperationType.CHAT.value()) + .hasLowCardinalityKeyValue(ChatModelObservationDocumentation.LowCardinalityKeyNames.AI_PROVIDER.asString(), + AiProvider.ALIBABA.value()) + .hasHighCardinalityKeyValue( + ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString(), + "0.0") + .hasHighCardinalityKeyValue( + ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_MAX_TOKENS.asString(), "2048") + .hasHighCardinalityKeyValue( + ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_PRESENCE_PENALTY.asString(), + "0.0") + .hasHighCardinalityKeyValue( + ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_STOP_SEQUENCES.asString(), + "[\"this-is-the-end\"]") + .hasHighCardinalityKeyValue( + ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_TEMPERATURE.asString(), "0.7") + .doesNotHaveHighCardinalityKeyValueWithKey( + ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_TOP_K.asString()) + .hasHighCardinalityKeyValue( + ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_TOP_P.asString(), "1.0") + .hasHighCardinalityKeyValue( + ChatModelObservationDocumentation.HighCardinalityKeyNames.RESPONSE_ID.asString(), + responseMetadata.getId()) + .hasHighCardinalityKeyValue( + ChatModelObservationDocumentation.HighCardinalityKeyNames.RESPONSE_FINISH_REASONS.asString(), + "[\"stop\"]") + .hasHighCardinalityKeyValue( + ChatModelObservationDocumentation.HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getPromptTokens())) + .hasHighCardinalityKeyValue( + ChatModelObservationDocumentation.HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(), + 
String.valueOf(responseMetadata.getUsage().getCompletionTokens())) + .hasHighCardinalityKeyValue( + ChatModelObservationDocumentation.HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getTotalTokens())) + .hasBeenStarted() + .hasBeenStopped(); + } + + @SpringBootConfiguration + public static class TestConfiguration { + + @Bean + public TestObservationRegistry observationRegistry() { + return TestObservationRegistry.create(); + } + + @Bean + public QwenApi qwenApi() { + return QwenApi.builder().apiKey(System.getenv("DASHSCOPE_API_KEY")).build(); + } + + @Bean + public QwenChatModel qwenChatModel(QwenApi qwenApi, TestObservationRegistry observationRegistry) { + return QwenChatModel.builder().qwenApi(qwenApi).observationRegistry(observationRegistry).build(); + } + + } + +} diff --git a/models/spring-ai-qwen/src/test/java/org/springframework/ai/qwen/QwenChatModelToolCallIT.java b/models/spring-ai-qwen/src/test/java/org/springframework/ai/qwen/QwenChatModelToolCallIT.java new file mode 100644 index 00000000000..17ab243c698 --- /dev/null +++ b/models/spring-ai-qwen/src/test/java/org/springframework/ai/qwen/QwenChatModelToolCallIT.java @@ -0,0 +1,192 @@ +package org.springframework.ai.qwen; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.qwen.api.QwenApi; +import org.springframework.ai.tool.function.FunctionToolCallback; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; 
+import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import reactor.core.publisher.Flux; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +@SpringBootTest(classes = QwenChatModelToolCallIT.TestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "DASHSCOPE_API_KEY", matches = ".+") +public class QwenChatModelToolCallIT { + + private static final Logger logger = LoggerFactory.getLogger(QwenChatModelIT.class); + + @Autowired + private QwenChatModel chatModel; + + private final MockWeatherService weatherService = new MockWeatherService(); + + @Test + void functionCallTest() { + + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, in Tokyo, and in Paris?"); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = QwenChatOptions.builder() + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .description("Get the current weather in a given location") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + + assertThat(response.getResult()).isNotNull(); + assertThat(response.getResult().getOutput()).isNotNull(); + assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15"); + assertThat(response.getMetadata()).isNotNull(); + assertThat(response.getMetadata().getUsage()).isNotNull(); + assertThat(response.getMetadata().getUsage().getTotalTokens()).isGreaterThan(600); + } + + @Test + void functionCallSequentialTest() { + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco? 
If the weather is above 25 degrees, please check the weather in Tokyo and Paris."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = QwenChatOptions.builder() + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .description("Get the current weather in a given location") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15"); + } + + @Test + void streamFunctionCallTest() { + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = QwenChatOptions.builder() + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .description("Get the current weather in a given location") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + + Flux response = this.chatModel.stream(new Prompt(messages, promptOptions)); + + final var counter = new AtomicInteger(); + String content = response.doOnEach(listSignal -> counter.getAndIncrement()) + .collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getText) + .collect(Collectors.joining()); + logger.info("Response: {}", content); + + assertThat(counter.get()).isGreaterThan(30).as("The response should be chunked in more than 30 messages"); + + assertThat(content).contains("30", "10", "15"); + + } + + @Test + void streamFunctionCallUsageTest() { + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = 
QwenChatOptions.builder() + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .description("Get the current weather in a given location") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + + Flux response = this.chatModel.stream(new Prompt(messages, promptOptions)); + + ChatResponse chatResponse = response.last().block(); + logger.info("Response: {}", chatResponse); + + assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isGreaterThan(600); + + } + + @Test + void functionCallSequentialAndStreamTest() { + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco? If the weather is above 25 degrees, please check the weather in Tokyo and Paris."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = QwenChatOptions.builder() + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .description("Get the current weather in a given location") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + + var response = this.chatModel.stream(new Prompt(messages, promptOptions)); + + final var counter = new AtomicInteger(); + String content = response.doOnEach(listSignal -> counter.getAndIncrement()) + .collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getText) + .filter(Objects::nonNull) + .collect(Collectors.joining()); + + logger.info("Response: {}", response); + + assertThat(content).contains("30", "10", "15"); + } + + @SpringBootConfiguration + public static class TestConfiguration { + + @Bean + public QwenApi qwenApi() { + return QwenApi.builder().apiKey(System.getenv("DASHSCOPE_API_KEY")).build(); + } + + @Bean + public QwenChatModel qwenChatModel(QwenApi qwenApi) { + return QwenChatModel.builder().qwenApi(qwenApi).build(); + } + + } + +} diff --git 
a/models/spring-ai-qwen/src/test/java/org/springframework/ai/qwen/aot/QwenRuntimeHintsTests.java b/models/spring-ai-qwen/src/test/java/org/springframework/ai/qwen/aot/QwenRuntimeHintsTests.java new file mode 100644 index 00000000000..31c500d51bb --- /dev/null +++ b/models/spring-ai-qwen/src/test/java/org/springframework/ai/qwen/aot/QwenRuntimeHintsTests.java @@ -0,0 +1,29 @@ +package org.springframework.ai.qwen.aot; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.aot.AiRuntimeHints; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.TypeReference; + +import java.util.Set; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; + +public class QwenRuntimeHintsTests { + + @Test + void registerHints() { + RuntimeHints runtimeHints = new RuntimeHints(); + QwenRuntimeHints qwenRuntimeHints = new QwenRuntimeHints(); + qwenRuntimeHints.registerHints(runtimeHints, null); + + Set qwenModelTypes = AiRuntimeHints.findClassesInPackage( + com.alibaba.dashscope.Version.class.getPackageName(), (metadataReader, metadataReaderFactory) -> true); + assertThat(qwenModelTypes.size()).isGreaterThan(100); + for (TypeReference modelType : qwenModelTypes) { + assertThat(runtimeHints).matches(reflection().onType(modelType)); + } + } + +} diff --git a/models/spring-ai-qwen/src/test/java/org/springframework/ai/qwen/api/MockImageContentFilter.java b/models/spring-ai-qwen/src/test/java/org/springframework/ai/qwen/api/MockImageContentFilter.java new file mode 100644 index 00000000000..ed16c9ccfed --- /dev/null +++ b/models/spring-ai-qwen/src/test/java/org/springframework/ai/qwen/api/MockImageContentFilter.java @@ -0,0 +1,50 @@ +package org.springframework.ai.qwen.api; + +import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversationParam; +import com.alibaba.dashscope.common.MultiModalMessage; +import 
org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +import java.util.LinkedList; +import java.util.List; +import java.util.Map; + +public class MockImageContentFilter { + + static void handle(MultiModalConversationParam.MultiModalConversationParamBuilder builder) { + List> filteredContents = new LinkedList<>(); + List filteredMessages = new LinkedList<>(); + boolean customized = false; + + for (Object message : builder.build().getMessages()) { + MultiModalMessage multiModalMessage = (MultiModalMessage) message; + for (Map content : multiModalMessage.getContent()) { + Map filteredContent = CollectionUtils.newHashMap(1); + for (String key : content.keySet()) { + Object value = content.get(key); + if ("image".equals(key)) { + String imageUrl = (String) content.get("image"); + if (StringUtils.hasText(imageUrl)) { + // Maybe an invalid image. Replace with a default one. + value = "https://avatars.githubusercontent.com/u/317776"; + customized = true; + } + } + filteredContent.put(key, value); + } + filteredContents.add(filteredContent); + } + MultiModalMessage filteredMessage = MultiModalMessage.builder() + .role(multiModalMessage.getRole()) + .content(filteredContents) + .build(); + filteredMessages.add(filteredMessage); + } + + if (customized) { + builder.clearMessages(); + builder.messages(filteredMessages); + } + } + +} diff --git a/models/spring-ai-qwen/src/test/java/org/springframework/ai/qwen/api/QwenApiIT.java b/models/spring-ai-qwen/src/test/java/org/springframework/ai/qwen/api/QwenApiIT.java new file mode 100644 index 00000000000..c52cda508ba --- /dev/null +++ b/models/spring-ai-qwen/src/test/java/org/springframework/ai/qwen/api/QwenApiIT.java @@ -0,0 +1,152 @@ +package org.springframework.ai.qwen.api; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import 
org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.content.Media; +import org.springframework.ai.qwen.QwenChatOptions; +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.util.MimeTypeUtils; + +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +@EnabledIfEnvironmentVariable(named = "DASHSCOPE_API_KEY", matches = ".+") +public class QwenApiIT { + + private static final Logger logger = LoggerFactory.getLogger(QwenApiIT.class); + + private QwenApi qwenApi() { + return QwenApi.builder().apiKey(System.getenv("DASHSCOPE_API_KEY")).build(); + } + + private List history() { + SystemMessage systemMessage = new SystemMessage(""" + Your name is Jack. + You like to answer other people's questions briefly. + It's rainy today. + """); + + UserMessage query1 = new UserMessage("Hello. 
What's your name?"); + AssistantMessage answer = new AssistantMessage("Jack!"); + UserMessage query2 = new UserMessage("How about the weather today?"); + + return List.of(systemMessage, query1, answer, query2); + } + + @Test + public void callNonMultimodalModel() { + QwenChatOptions options = QwenChatOptions.builder().model(QwenModel.QWEN_MAX.getName()).build(); + Prompt prompt = new Prompt(history(), options); + + QwenApi api = qwenApi(); + + ChatResponse response = api.call(prompt, null); + logger.info(response.getResult().getOutput().getText()); + assertThat(response.getResult().getOutput().getText()).containsIgnoringCase("rain"); + } + + @Test + public void streamingCallNonMultimodalModel() { + QwenChatOptions options = QwenChatOptions.builder().model(QwenModel.QWEN_MAX.getName()).build(); + Prompt prompt = new Prompt(history(), options); + + QwenApi api = qwenApi(); + + String generationTextFromStream = api.streamCall(prompt, null) + .collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getText) + .filter(Objects::nonNull) + .collect(Collectors.joining()); + + logger.info(generationTextFromStream); + assertThat(generationTextFromStream).containsIgnoringCase("rain"); + } + + @Test + public void callMultimodalModel() { + QwenChatOptions options = QwenChatOptions.builder().model(QwenModel.QWEN_VL_MAX.getName()).build(); + Resource resource = new ClassPathResource("multimodal.test.png"); + UserMessage message = new UserMessage("Explain what do you see on this picture?", + Media.builder().mimeType(MimeTypeUtils.IMAGE_PNG).data(resource).build()); + Prompt prompt = new Prompt(message, options); + + QwenApi api = qwenApi(); + + ChatResponse response = api.call(prompt, null); + logger.info(response.getResult().getOutput().getText()); + assertThat(response.getResult().getOutput().getText()).containsAnyOf("bananas", "apple", "bowl", "basket", + "fruit stand"); + } + + @Test 
+ public void callNonMultimodalModelWithCustomizedParameter() { + QwenChatOptions options = QwenChatOptions.builder().model(QwenModel.QWEN_MAX.getName()).build(); + Prompt prompt = new Prompt(history(), options); + + QwenApi api = qwenApi(); + api.setGenerationParamCustomizer(builder -> builder.stopString("rain")); + + ChatResponse response = api.call(prompt, null); + logger.info(response.getResult().getOutput().getText()); + assertThat(response.getResult().getOutput().getText()).doesNotContainIgnoringCase("rain"); + } + + @Test + public void streamingCallNonMultimodalModelWithCustomizedParameter() { + QwenChatOptions options = QwenChatOptions.builder().model(QwenModel.QWEN_MAX.getName()).build(); + Prompt prompt = new Prompt(history(), options); + + QwenApi api = qwenApi(); + api.setGenerationParamCustomizer(builder -> builder.stopString("rain")); + + String generationTextFromStream = api.streamCall(prompt, null) + .collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getText) + .filter(Objects::nonNull) + .collect(Collectors.joining()); + + logger.info(generationTextFromStream); + assertThat(generationTextFromStream).doesNotContainIgnoringCase("rain"); + } + + @Test + public void callMultimodalModelWithCustomizedParameter() { + QwenChatOptions options = QwenChatOptions.builder().model(QwenModel.QWEN_VL_MAX.getName()).build(); + Resource resource = new ClassPathResource("multimodal.test.png"); + UserMessage message = new UserMessage("Explain what do you see on this picture?", + Media.builder().mimeType(MimeTypeUtils.IMAGE_PNG).data(resource).build()); + Prompt prompt = new Prompt(message, options); + + QwenApi api = qwenApi(); + api.setMultimodalConversationParamCustomizer(MockImageContentFilter::handle); + + ChatResponse response = api.call(prompt, null); + logger.info(response.getResult().getOutput().getText()); + 
assertThat(response.getResult().getOutput().getText()).doesNotContainIgnoringCase("bananas", "apple", "bowl", + "basket", "fruit stand"); + } + +} diff --git a/models/spring-ai-qwen/src/test/resources/multimodal.test.png b/models/spring-ai-qwen/src/test/resources/multimodal.test.png new file mode 100644 index 00000000000..4f898454121 Binary files /dev/null and b/models/spring-ai-qwen/src/test/resources/multimodal.test.png differ diff --git a/pom.xml b/pom.xml index e4927004b8d..64618ba9754 100644 --- a/pom.xml +++ b/pom.xml @@ -176,6 +176,7 @@ models/spring-ai-watsonx-ai models/spring-ai-zhipuai models/spring-ai-moonshot + models/spring-ai-qwen spring-ai-spring-boot-starters/spring-ai-starter-model-anthropic spring-ai-spring-boot-starters/spring-ai-starter-model-azure-openai diff --git a/spring-ai-bom/pom.xml b/spring-ai-bom/pom.xml index e163de24f9f..e791eb30207 100644 --- a/spring-ai-bom/pom.xml +++ b/spring-ai-bom/pom.xml @@ -276,6 +276,12 @@ ${project.version} + + org.springframework.ai + spring-ai-qwen + ${project.version} + + diff --git a/spring-ai-commons/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java b/spring-ai-commons/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java index e723b679b02..3278e4feb92 100644 --- a/spring-ai-commons/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java +++ b/spring-ai-commons/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java @@ -99,7 +99,12 @@ public enum AiProvider { /** * AI system provided by ONNX. */ - ONNX("onnx"); + ONNX("onnx"), + + /** + * AI system provided by Alibaba + */ + ALIBABA("alibaba"); private final String value;