-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Modular RAG: Retrieval with Vector Stores #1604
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,224 @@ | ||
/* | ||
* Copyright 2023-2024 the original author or authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* https://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.springframework.ai.chat.client.advisor; | ||
|
||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.function.Predicate; | ||
import java.util.stream.Collectors; | ||
|
||
import reactor.core.publisher.Flux; | ||
import reactor.core.publisher.Mono; | ||
import reactor.core.scheduler.Schedulers; | ||
|
||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; | ||
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; | ||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; | ||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; | ||
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; | ||
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; | ||
import org.springframework.ai.chat.messages.UserMessage; | ||
import org.springframework.ai.chat.model.ChatResponse; | ||
import org.springframework.ai.chat.prompt.PromptTemplate; | ||
import org.springframework.ai.document.Document; | ||
import org.springframework.ai.model.Content; | ||
import org.springframework.ai.rag.Query; | ||
import org.springframework.ai.rag.retrieval.source.DocumentRetriever; | ||
import org.springframework.lang.Nullable; | ||
import org.springframework.util.Assert; | ||
import org.springframework.util.StringUtils; | ||
|
||
/** | ||
* This advisor implements common Retrieval Augmented Generation (RAG) flows using the | ||
* building blocks defined in the {@link org.springframework.ai.rag} package and following | ||
* the Modular RAG Architecture. | ||
* <p> | ||
* It's the successor of the {@link QuestionAnswerAdvisor}. | ||
* | ||
* @author Christian Tzolov | ||
* @author Thomas Vitale | ||
* @since 1.0.0 | ||
* @see <a href="https://export.arxiv.org/abs/2407.21059">arXiv:2407.21059</a> | ||
* @see <a href="https://export.arxiv.org/abs/2312.10997">arXiv:2312.10997</a> | ||
*/ | ||
public class RetrievalAugmentationAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { | ||
|
||
public static final String DOCUMENT_CONTEXT = "rag_document_context"; | ||
|
||
public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate(""" | ||
{query} | ||
|
||
Context information is below. Use this information to answer the user query. | ||
|
||
--------------------- | ||
{context} | ||
--------------------- | ||
|
||
Given the context and provided history information and not prior knowledge, | ||
reply to the user query. If the answer is not in the context, inform | ||
the user that you can't answer the query. | ||
"""); | ||
|
||
private final DocumentRetriever documentRetriever; | ||
|
||
private final PromptTemplate promptTemplate; | ||
|
||
private final boolean protectFromBlocking; | ||
|
||
private final int order; | ||
|
||
/**
 * Creates an advisor that augments the user text with retrieved documents.
 * @param documentRetriever retriever queried for documents; must not be null
 * @param promptTemplate template used to build the augmented user text;
 * {@link #DEFAULT_PROMPT_TEMPLATE} is used when null
 * @param protectFromBlocking whether the streaming path moves the retrieval step
 * onto {@code Schedulers.boundedElastic()}; treated as {@code false} when null
 * @param order the value returned by {@link #getOrder()}; {@code 0} when null
 * @throws IllegalArgumentException if {@code documentRetriever} is null
 */
public RetrievalAugmentationAdvisor(DocumentRetriever documentRetriever, @Nullable PromptTemplate promptTemplate,
        @Nullable Boolean protectFromBlocking, @Nullable Integer order) {
    Assert.notNull(documentRetriever, "documentRetriever cannot be null");
    this.documentRetriever = documentRetriever;
    this.promptTemplate = (promptTemplate == null) ? DEFAULT_PROMPT_TEMPLATE : promptTemplate;
    // A null flag means "not requested", i.e. false.
    this.protectFromBlocking = Boolean.TRUE.equals(protectFromBlocking);
    this.order = (order == null) ? 0 : order;
}
|
||
public static Builder builder() { | ||
return new Builder(); | ||
} | ||
|
||
@Override | ||
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { | ||
Assert.notNull(advisedRequest, "advisedRequest cannot be null"); | ||
Assert.notNull(chain, "chain cannot be null"); | ||
|
||
AdvisedRequest processedAdvisedRequest = before(advisedRequest); | ||
AdvisedResponse advisedResponse = chain.nextAroundCall(processedAdvisedRequest); | ||
return after(advisedResponse); | ||
} | ||
|
||
@Override | ||
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { | ||
Assert.notNull(advisedRequest, "advisedRequest cannot be null"); | ||
Assert.notNull(chain, "chain cannot be null"); | ||
|
||
// This can be executed by both blocking and non-blocking Threads | ||
// E.g. a command line or Tomcat blocking Thread implementation | ||
// or by a WebFlux dispatch in a non-blocking manner. | ||
Flux<AdvisedResponse> advisedResponses = (this.protectFromBlocking) ? | ||
// @formatter:off | ||
Mono.just(advisedRequest) | ||
.publishOn(Schedulers.boundedElastic()) | ||
.map(this::before) | ||
.flatMapMany(chain::nextAroundStream) | ||
: chain.nextAroundStream(before(advisedRequest)); | ||
// @formatter:on | ||
|
||
return advisedResponses.map(ar -> { | ||
if (onFinishReason().test(ar)) { | ||
ar = after(ar); | ||
} | ||
return ar; | ||
}); | ||
} | ||
|
||
private AdvisedRequest before(AdvisedRequest request) { | ||
Map<String, Object> context = new HashMap<>(request.adviseContext()); | ||
|
||
// 0. Create a query from the user text and parameters. | ||
Query query = new Query(new PromptTemplate(request.userText(), request.userParams()).render()); | ||
|
||
// 1. Retrieve similar documents for the original query. | ||
List<Document> documents = this.documentRetriever.retrieve(query); | ||
context.put(DOCUMENT_CONTEXT, documents); | ||
|
||
// 2. Combine retrieved documents. | ||
String documentContext = documents.stream() | ||
.map(Content::getContent) | ||
.collect(Collectors.joining(System.lineSeparator())); | ||
|
||
// 3. Define augmentation prompt parameters. | ||
Map<String, Object> promptParameters = Map.of("query", query.text(), "context", documentContext); | ||
|
||
// 4. Augment user prompt with the context data. | ||
UserMessage augmentedUserMessage = (UserMessage) this.promptTemplate.createMessage(promptParameters); | ||
|
||
return AdvisedRequest.from(request) | ||
.withUserText(augmentedUserMessage.getContent()) | ||
.withAdviseContext(context) | ||
.build(); | ||
} | ||
|
||
private AdvisedResponse after(AdvisedResponse advisedResponse) { | ||
ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(advisedResponse.response()); | ||
chatResponseBuilder.withMetadata(DOCUMENT_CONTEXT, advisedResponse.adviseContext().get(DOCUMENT_CONTEXT)); | ||
return new AdvisedResponse(chatResponseBuilder.build(), advisedResponse.adviseContext()); | ||
} | ||
|
||
private Predicate<AdvisedResponse> onFinishReason() { | ||
return advisedResponse -> advisedResponse.response() | ||
.getResults() | ||
.stream() | ||
.anyMatch(result -> result != null && result.getMetadata() != null | ||
&& StringUtils.hasText(result.getMetadata().getFinishReason())); | ||
} | ||
|
||
@Override | ||
public String getName() { | ||
return this.getClass().getSimpleName(); | ||
} | ||
|
||
@Override | ||
public int getOrder() { | ||
return this.order; | ||
} | ||
|
||
public static final class Builder { | ||
|
||
private DocumentRetriever documentRetriever; | ||
|
||
private PromptTemplate promptTemplate; | ||
|
||
private Boolean protectFromBlocking; | ||
|
||
private Integer order; | ||
|
||
private Builder() { | ||
} | ||
|
||
public Builder documentRetriever(DocumentRetriever documentRetriever) { | ||
this.documentRetriever = documentRetriever; | ||
return this; | ||
} | ||
|
||
public Builder promptTemplate(PromptTemplate promptTemplate) { | ||
this.promptTemplate = promptTemplate; | ||
return this; | ||
} | ||
|
||
public Builder protectFromBlocking(Boolean protectFromBlocking) { | ||
this.protectFromBlocking = protectFromBlocking; | ||
return this; | ||
} | ||
|
||
public Builder order(Integer order) { | ||
this.order = order; | ||
return this; | ||
} | ||
|
||
public RetrievalAugmentationAdvisor build() { | ||
return new RetrievalAugmentationAdvisor(this.documentRetriever, this.promptTemplate, | ||
this.protectFromBlocking, this.order); | ||
} | ||
|
||
} | ||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,15 +34,15 @@ | |
import org.springframework.ai.model.Media; | ||
import org.springframework.ai.model.function.FunctionCallback; | ||
import org.springframework.ai.model.function.FunctionCallingOptions; | ||
import org.springframework.lang.Nullable; | ||
import org.springframework.util.Assert; | ||
import org.springframework.util.CollectionUtils; | ||
import org.springframework.util.StringUtils; | ||
|
||
/** | ||
* The data of the chat client request that can be modified before the execution of the | ||
* ChatClient's call method | ||
* | ||
* @author Christian Tzolov | ||
* @since 1.0.0 | ||
* @param chatModel the chat model used | ||
* @param userText the text provided by the user | ||
* @param systemText the text provided by the system | ||
|
@@ -57,13 +57,53 @@ | |
* @param advisorParams the map of advisor parameters | ||
* @param adviseContext the map of advise context | ||
* @param toolContext the tool context | ||
* @author Christian Tzolov | ||
* @author Thomas Vitale | ||
* @since 1.0.0 | ||
*/ | ||
public record AdvisedRequest(ChatModel chatModel, String userText, String systemText, ChatOptions chatOptions, | ||
List<Media> media, List<String> functionNames, List<FunctionCallback> functionCallbacks, List<Message> messages, | ||
Map<String, Object> userParams, Map<String, Object> systemParams, List<Advisor> advisors, | ||
Map<String, Object> advisorParams, Map<String, Object> adviseContext, Map<String, Object> toolContext) { | ||
public record AdvisedRequest( | ||
// @formatter:off | ||
ChatModel chatModel, | ||
String userText, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This may be an odd case, but maybe we should allow userText to be null? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I thought a lot about a possible use case where I would want that and didn't really come up with anything. For example, OpenAI requires a message to be included in each request to a chat model; it fails if the user message is null. Before this change, we supported a null user message and, in that case, set "userText" to an empty string as a workaround. That could be dangerous: as a developer, I wouldn't catch scenarios where I fail to pass a user message by mistake, yet the model would still be called and the outcome would be unpredictable. |
||
@Nullable | ||
String systemText, | ||
@Nullable | ||
ChatOptions chatOptions, | ||
List<Media> media, | ||
List<String> functionNames, | ||
List<FunctionCallback> functionCallbacks, | ||
List<Message> messages, | ||
Map<String, Object> userParams, | ||
Map<String, Object> systemParams, | ||
List<Advisor> advisors, | ||
Map<String, Object> advisorParams, | ||
Map<String, Object> adviseContext, | ||
Map<String, Object> toolContext | ||
// @formatter:on | ||
) { | ||
|
||
public AdvisedRequest { | ||
Assert.notNull(chatModel, "chatModel cannot be null"); | ||
Assert.hasText(userText, "userText cannot be null or empty"); | ||
Assert.notNull(media, "media cannot be null"); | ||
Assert.notNull(functionNames, "functionNames cannot be null"); | ||
Assert.notNull(functionCallbacks, "functionCallbacks cannot be null"); | ||
Assert.notNull(messages, "messages cannot be null"); | ||
Assert.notNull(userParams, "userParams cannot be null"); | ||
Assert.notNull(systemParams, "systemParams cannot be null"); | ||
Assert.notNull(advisors, "advisors cannot be null"); | ||
Assert.notNull(advisorParams, "advisorParams cannot be null"); | ||
Assert.notNull(adviseContext, "adviseContext cannot be null"); | ||
Assert.notNull(toolContext, "toolContext cannot be null"); | ||
} | ||
|
||
public static Builder builder() { | ||
return new Builder(); | ||
} | ||
|
||
public static Builder from(AdvisedRequest from) { | ||
Assert.notNull(from, "AdvisedRequest cannot be null"); | ||
|
||
Builder builder = new Builder(); | ||
builder.chatModel = from.chatModel; | ||
builder.userText = from.userText; | ||
|
@@ -79,23 +119,18 @@ public static Builder from(AdvisedRequest from) { | |
builder.advisorParams = from.advisorParams; | ||
builder.adviseContext = from.adviseContext; | ||
builder.toolContext = from.toolContext; | ||
|
||
return builder; | ||
} | ||
|
||
public static Builder builder() { | ||
return new Builder(); | ||
} | ||
|
||
public AdvisedRequest updateContext(Function<Map<String, Object>, Map<String, Object>> contextTransform) { | ||
Assert.notNull(contextTransform, "contextTransform cannot be null"); | ||
return from(this) | ||
.withAdviseContext(Collections.unmodifiableMap(contextTransform.apply(new HashMap<>(this.adviseContext)))) | ||
.build(); | ||
} | ||
|
||
public Prompt toPrompt() { | ||
|
||
var messages = new ArrayList<Message>(this.messages()); | ||
var messages = new ArrayList<>(this.messages()); | ||
|
||
String processedSystemText = this.systemText(); | ||
if (StringUtils.hasText(processedSystemText)) { | ||
|
@@ -111,7 +146,6 @@ public Prompt toPrompt() { | |
? this.userText() + System.lineSeparator() + "{spring_ai_soc_format}" : this.userText(); | ||
|
||
if (StringUtils.hasText(processedUserText)) { | ||
|
||
Map<String, Object> userParams = new HashMap<>(this.userParams()); | ||
if (StringUtils.hasText(formatParam)) { | ||
userParams.put("spring_ai_soc_format", formatParam); | ||
|
@@ -137,17 +171,15 @@ public Prompt toPrompt() { | |
return new Prompt(messages, this.chatOptions()); | ||
} | ||
|
||
public static class Builder { | ||
|
||
public Map<String, Object> toolContext = Map.of(); | ||
public static final class Builder { | ||
|
||
private ChatModel chatModel; | ||
|
||
private String userText = ""; | ||
private String userText; | ||
|
||
private String systemText = ""; | ||
private String systemText; | ||
|
||
private ChatOptions chatOptions = null; | ||
private ChatOptions chatOptions; | ||
|
||
private List<Media> media = List.of(); | ||
|
||
|
@@ -167,6 +199,11 @@ public static class Builder { | |
|
||
private Map<String, Object> adviseContext = Map.of(); | ||
|
||
public Map<String, Object> toolContext = Map.of(); | ||
|
||
private Builder() { | ||
} | ||
|
||
public Builder withChatModel(ChatModel chatModel) { | ||
this.chatModel = chatModel; | ||
return this; | ||
|
@@ -202,11 +239,6 @@ public Builder withFunctionCallbacks(List<FunctionCallback> functionCallbacks) { | |
return this; | ||
} | ||
|
||
public Builder withToolContext(Map<String, Object> toolContext) { | ||
this.toolContext = toolContext; | ||
return this; | ||
} | ||
|
||
public Builder withMessages(List<Message> messages) { | ||
this.messages = messages; | ||
return this; | ||
|
@@ -237,6 +269,11 @@ public Builder withAdviseContext(Map<String, Object> adviseContext) { | |
return this; | ||
} | ||
|
||
public Builder withToolContext(Map<String, Object> toolContext) { | ||
this.toolContext = toolContext; | ||
return this; | ||
} | ||
|
||
public AdvisedRequest build() { | ||
return new AdvisedRequest(this.chatModel, this.userText, this.systemText, this.chatOptions, this.media, | ||
this.functionNames, this.functionCallbacks, this.messages, this.userParams, this.systemParams, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have an issue, #1592, related to how we want to consistently implement builders. Does `from` go in the 'product' class or in the 'builder' class?