Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.chat.client.advisor;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.model.Content;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.retrieval.source.DocumentRetriever;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
* This advisor implements common Retrieval Augmented Generation (RAG) flows using the
* building blocks defined in the {@link org.springframework.ai.rag} package and following
* the Modular RAG Architecture.
* <p>
* It's the successor of the {@link QuestionAnswerAdvisor}.
*
* @author Christian Tzolov
* @author Thomas Vitale
* @since 1.0.0
* @see <a href="http://export.arxiv.org/abs/2407.21059">arXiv:2407.21059</a>
* @see <a href="https://export.arxiv.org/abs/2312.10997">arXiv:2312.10997</a>
*/
public class RetrievalAugmentationAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {

public static final String DOCUMENT_CONTEXT = "rag_document_context";

public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("""
{query}

Context information is below. Use this information to answer the user query.

---------------------
{context}
---------------------

Given the context and provided history information and not prior knowledge,
reply to the user query. If the answer is not in the context, inform
the user that you can't answer the query.
""");

private final DocumentRetriever documentRetriever;

private final PromptTemplate promptTemplate;

private final boolean protectFromBlocking;

private final int order;

public RetrievalAugmentationAdvisor(DocumentRetriever documentRetriever, @Nullable PromptTemplate promptTemplate,
@Nullable Boolean protectFromBlocking, @Nullable Integer order) {
Assert.notNull(documentRetriever, "documentRetriever cannot be null");
this.documentRetriever = documentRetriever;
this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
this.protectFromBlocking = protectFromBlocking != null ? protectFromBlocking : false;
this.order = order != null ? order : 0;
}

public static Builder builder() {
return new Builder();
}

@Override
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
Assert.notNull(advisedRequest, "advisedRequest cannot be null");
Assert.notNull(chain, "chain cannot be null");

AdvisedRequest processedAdvisedRequest = before(advisedRequest);
AdvisedResponse advisedResponse = chain.nextAroundCall(processedAdvisedRequest);
return after(advisedResponse);
}

@Override
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
Assert.notNull(advisedRequest, "advisedRequest cannot be null");
Assert.notNull(chain, "chain cannot be null");

// This can be executed by both blocking and non-blocking Threads
// E.g. a command line or Tomcat blocking Thread implementation
// or by a WebFlux dispatch in a non-blocking manner.
Flux<AdvisedResponse> advisedResponses = (this.protectFromBlocking) ?
// @formatter:off
Mono.just(advisedRequest)
.publishOn(Schedulers.boundedElastic())
.map(this::before)
.flatMapMany(chain::nextAroundStream)
: chain.nextAroundStream(before(advisedRequest));
// @formatter:on

return advisedResponses.map(ar -> {
if (onFinishReason().test(ar)) {
ar = after(ar);
}
return ar;
});
}

private AdvisedRequest before(AdvisedRequest request) {
Map<String, Object> context = new HashMap<>(request.adviseContext());

// 0. Create a query from the user text and parameters.
Query query = new Query(new PromptTemplate(request.userText(), request.userParams()).render());

// 1. Retrieve similar documents for the original query.
List<Document> documents = this.documentRetriever.retrieve(query);
context.put(DOCUMENT_CONTEXT, documents);

// 2. Combine retrieved documents.
String documentContext = documents.stream()
.map(Content::getContent)
.collect(Collectors.joining(System.lineSeparator()));

// 3. Define augmentation prompt parameters.
Map<String, Object> promptParameters = Map.of("query", query.text(), "context", documentContext);

// 4. Augment user prompt with the context data.
UserMessage augmentedUserMessage = (UserMessage) this.promptTemplate.createMessage(promptParameters);

return AdvisedRequest.from(request)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have an issue, #1592 related to how we want to consistently implement builders. Does from go in the 'product' class or in the 'builder' class?

.withUserText(augmentedUserMessage.getContent())
.withAdviseContext(context)
.build();
}

private AdvisedResponse after(AdvisedResponse advisedResponse) {
ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(advisedResponse.response());
chatResponseBuilder.withMetadata(DOCUMENT_CONTEXT, advisedResponse.adviseContext().get(DOCUMENT_CONTEXT));
return new AdvisedResponse(chatResponseBuilder.build(), advisedResponse.adviseContext());
}

private Predicate<AdvisedResponse> onFinishReason() {
return advisedResponse -> advisedResponse.response()
.getResults()
.stream()
.anyMatch(result -> result != null && result.getMetadata() != null
&& StringUtils.hasText(result.getMetadata().getFinishReason()));
}

@Override
public String getName() {
return this.getClass().getSimpleName();
}

@Override
public int getOrder() {
return this.order;
}

public static final class Builder {

private DocumentRetriever documentRetriever;

private PromptTemplate promptTemplate;

private Boolean protectFromBlocking;

private Integer order;

private Builder() {
}

public Builder documentRetriever(DocumentRetriever documentRetriever) {
this.documentRetriever = documentRetriever;
return this;
}

public Builder promptTemplate(PromptTemplate promptTemplate) {
this.promptTemplate = promptTemplate;
return this;
}

public Builder protectFromBlocking(Boolean protectFromBlocking) {
this.protectFromBlocking = protectFromBlocking;
return this;
}

public Builder order(Integer order) {
this.order = order;
return this;
}

public RetrievalAugmentationAdvisor build() {
return new RetrievalAugmentationAdvisor(this.documentRetriever, this.promptTemplate,
this.protectFromBlocking, this.order);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@
import org.springframework.ai.model.Media;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

/**
* The data of the chat client request that can be modified before the execution of the
* ChatClient's call method
*
* @author Christian Tzolov
* @since 1.0.0
* @param chatModel the chat model used
* @param userText the text provided by the user
* @param systemText the text provided by the system
Expand All @@ -57,13 +57,53 @@
* @param advisorParams the map of advisor parameters
* @param adviseContext the map of advise context
* @param toolContext the tool context
* @author Christian Tzolov
* @author Thomas Vitale
* @since 1.0.0
*/
public record AdvisedRequest(ChatModel chatModel, String userText, String systemText, ChatOptions chatOptions,
List<Media> media, List<String> functionNames, List<FunctionCallback> functionCallbacks, List<Message> messages,
Map<String, Object> userParams, Map<String, Object> systemParams, List<Advisor> advisors,
Map<String, Object> advisorParams, Map<String, Object> adviseContext, Map<String, Object> toolContext) {
public record AdvisedRequest(
// @formatter:off
ChatModel chatModel,
String userText,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

may be an odd case, but maybe should allow userText to be null?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought a lot about a possible use case where I would want that and didn't really come up with anything. For example, OpenAI requires a message to be included in each request to a chat model. It fails if the user message is null. Before this change, we supported the user message to be null, and in that case set the "userText" to be an empty string as a workaround. That could be dangerous because as a developer I wouldn't catch scenarios where I'm not passing a user message by mistake, but the model would still be called and the outcome would be unpredictable.

@Nullable
String systemText,
@Nullable
ChatOptions chatOptions,
List<Media> media,
List<String> functionNames,
List<FunctionCallback> functionCallbacks,
List<Message> messages,
Map<String, Object> userParams,
Map<String, Object> systemParams,
List<Advisor> advisors,
Map<String, Object> advisorParams,
Map<String, Object> adviseContext,
Map<String, Object> toolContext
// @formatter:on
) {

public AdvisedRequest {
Assert.notNull(chatModel, "chatModel cannot be null");
Assert.hasText(userText, "userText cannot be null or empty");
Assert.notNull(media, "media cannot be null");
Assert.notNull(functionNames, "functionNames cannot be null");
Assert.notNull(functionCallbacks, "functionCallbacks cannot be null");
Assert.notNull(messages, "messages cannot be null");
Assert.notNull(userParams, "userParams cannot be null");
Assert.notNull(systemParams, "systemParams cannot be null");
Assert.notNull(advisors, "advisors cannot be null");
Assert.notNull(advisorParams, "advisorParams cannot be null");
Assert.notNull(adviseContext, "adviseContext cannot be null");
Assert.notNull(toolContext, "toolContext cannot be null");
}

public static Builder builder() {
return new Builder();
}

public static Builder from(AdvisedRequest from) {
Assert.notNull(from, "AdvisedRequest cannot be null");

Builder builder = new Builder();
builder.chatModel = from.chatModel;
builder.userText = from.userText;
Expand All @@ -79,23 +119,18 @@ public static Builder from(AdvisedRequest from) {
builder.advisorParams = from.advisorParams;
builder.adviseContext = from.adviseContext;
builder.toolContext = from.toolContext;

return builder;
}

public static Builder builder() {
return new Builder();
}

public AdvisedRequest updateContext(Function<Map<String, Object>, Map<String, Object>> contextTransform) {
Assert.notNull(contextTransform, "contextTransform cannot be null");
return from(this)
.withAdviseContext(Collections.unmodifiableMap(contextTransform.apply(new HashMap<>(this.adviseContext))))
.build();
}

public Prompt toPrompt() {

var messages = new ArrayList<Message>(this.messages());
var messages = new ArrayList<>(this.messages());

String processedSystemText = this.systemText();
if (StringUtils.hasText(processedSystemText)) {
Expand All @@ -111,7 +146,6 @@ public Prompt toPrompt() {
? this.userText() + System.lineSeparator() + "{spring_ai_soc_format}" : this.userText();

if (StringUtils.hasText(processedUserText)) {

Map<String, Object> userParams = new HashMap<>(this.userParams());
if (StringUtils.hasText(formatParam)) {
userParams.put("spring_ai_soc_format", formatParam);
Expand All @@ -137,17 +171,15 @@ public Prompt toPrompt() {
return new Prompt(messages, this.chatOptions());
}

public static class Builder {

public Map<String, Object> toolContext = Map.of();
public static final class Builder {

private ChatModel chatModel;

private String userText = "";
private String userText;

private String systemText = "";
private String systemText;

private ChatOptions chatOptions = null;
private ChatOptions chatOptions;

private List<Media> media = List.of();

Expand All @@ -167,6 +199,11 @@ public static class Builder {

private Map<String, Object> adviseContext = Map.of();

public Map<String, Object> toolContext = Map.of();

private Builder() {
}

public Builder withChatModel(ChatModel chatModel) {
this.chatModel = chatModel;
return this;
Expand Down Expand Up @@ -202,11 +239,6 @@ public Builder withFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
return this;
}

public Builder withToolContext(Map<String, Object> toolContext) {
this.toolContext = toolContext;
return this;
}

public Builder withMessages(List<Message> messages) {
this.messages = messages;
return this;
Expand Down Expand Up @@ -237,6 +269,11 @@ public Builder withAdviseContext(Map<String, Object> adviseContext) {
return this;
}

public Builder withToolContext(Map<String, Object> toolContext) {
this.toolContext = toolContext;
return this;
}

public AdvisedRequest build() {
return new AdvisedRequest(this.chatModel, this.userText, this.systemText, this.chatOptions, this.media,
this.functionNames, this.functionCallbacks, this.messages, this.userParams, this.systemParams,
Expand Down
Loading