Skip to content

Commit cc705a7

Browse files
committed
fix: Incorrect order when Advisors have the same order
- Fix incorrect order when Advisors have the same order - Added the `hasNextCallAdvisor` method to `CallAdvisorChain`. - Added the `hasNextStreamAdvisor` method to `StreamAdvisorChain`. - The `DefaultAroundAdvisorChain` implements both `hasNextCallAdvisor` and `hasNextStreamAdvisor` method. - Added a last StreamAdvisor check in `ChatModelStreamAdvisor`. - Added a last CallAdvisor check in `ChatModelCallAdvisor`. - Updated the corresponding test cases. Signed-off-by: YunKui Lu <[email protected]>
1 parent 694bb50 commit cc705a7

File tree

8 files changed

+184
-2
lines changed

8 files changed

+184
-2
lines changed

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisor.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ private ChatModelCallAdvisor(ChatModel chatModel) {
4848
@Override
4949
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
5050
Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null");
51+
Assert.isTrue(!callAdvisorChain.hasNextCallAdvisor(),
52+
"ChatModelCallAdvisor should be the last CallAdvisor in the chain");
5153

5254
ChatClientRequest formattedChatClientRequest = augmentWithFormatInstructions(chatClientRequest);
5355

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelStreamAdvisor.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ private ChatModelStreamAdvisor(ChatModel chatModel) {
4848
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
4949
StreamAdvisorChain streamAdvisorChain) {
5050
Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null");
51+
Assert.isTrue(!streamAdvisorChain.hasNextStreamAdvisor(),
52+
"ChatModelStreamAdvisor should be the last StreamAdvisor in the chain");
5153

5254
return this.chatModel.stream(chatClientRequest.prompt())
5355
.map(chatResponse -> ChatClientResponse.builder()

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,11 +146,21 @@ public List<CallAdvisor> getCallAdvisors() {
146146
return this.originalCallAdvisors;
147147
}
148148

149+
@Override
150+
public boolean hasNextCallAdvisor() {
151+
return !this.callAdvisors.isEmpty();
152+
}
153+
149154
@Override
150155
public List<StreamAdvisor> getStreamAdvisors() {
151156
return this.originalStreamAdvisors;
152157
}
153158

159+
@Override
160+
public boolean hasNextStreamAdvisor() {
161+
return !this.streamAdvisors.isEmpty();
162+
}
163+
154164
@Override
155165
public ObservationRegistry getObservationRegistry() {
156166
return this.observationRegistry;
@@ -192,7 +202,7 @@ public Builder pushAll(List<? extends Advisor> advisors) {
192202
.toList();
193203

194204
if (!CollectionUtils.isEmpty(callAroundAdvisorList)) {
195-
callAroundAdvisorList.forEach(this.callAdvisors::push);
205+
this.callAdvisors.addAll(callAroundAdvisorList);
196206
}
197207

198208
List<StreamAdvisor> streamAroundAdvisorList = advisors.stream()
@@ -201,7 +211,7 @@ public Builder pushAll(List<? extends Advisor> advisors) {
201211
.toList();
202212

203213
if (!CollectionUtils.isEmpty(streamAroundAdvisorList)) {
204-
streamAroundAdvisorList.forEach(this.streamAdvisors::push);
214+
this.streamAdvisors.addAll(streamAroundAdvisorList);
205215
}
206216

207217
this.reOrder();

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAdvisorChain.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
* @author Christian Tzolov
2929
* @author Dariusz Jedrzejczyk
3030
* @author Thomas Vitale
31+
* @author YunKui Lu
3132
* @since 1.0.0
3233
*/
3334
public interface CallAdvisorChain extends AdvisorChain {
@@ -44,4 +45,11 @@ public interface CallAdvisorChain extends AdvisorChain {
4445
*/
4546
List<CallAdvisor> getCallAdvisors();
4647

48+
/**
49+
* Returns true if there is a next {@link CallAdvisor} in the chain.
50+
*/
51+
default boolean hasNextCallAdvisor() {
52+
throw new UnsupportedOperationException("This CallAdvisorChain does not support hasNextCallAdvisor()");
53+
}
54+
4755
}

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAdvisorChain.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
* @author Christian Tzolov
3131
* @author Dariusz Jedrzejczyk
3232
* @author Thomas Vitale
33+
* @author YunKui Lu
3334
* @since 1.0.0
3435
*/
3536
public interface StreamAdvisorChain extends AdvisorChain {
@@ -46,4 +47,11 @@ public interface StreamAdvisorChain extends AdvisorChain {
4647
*/
4748
List<StreamAdvisor> getStreamAdvisors();
4849

50+
/**
51+
* Returns true if there is a next {@link StreamAdvisor} in the chain.
52+
*/
53+
default boolean hasNextStreamAdvisor() {
54+
throw new UnsupportedOperationException("This StreamAdvisorChain does not support hasNextStreamAdvisor()");
55+
}
56+
4957
}

spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisorTests.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,13 @@
1818

1919
import org.junit.jupiter.api.Test;
2020

21+
import org.springframework.ai.chat.client.ChatClientRequest;
22+
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
23+
import org.springframework.ai.chat.model.ChatModel;
24+
2125
import static org.assertj.core.api.Assertions.assertThatThrownBy;
26+
import static org.mockito.Mockito.mock;
27+
import static org.mockito.Mockito.when;
2228

2329
/**
2430
* Unit tests for {@link ChatModelCallAdvisor}.
@@ -34,4 +40,19 @@ void whenChatModelIsNullThenThrow() {
3440
.hasMessage("chatModel cannot be null");
3541
}
3642

43+
@Test
44+
void whenNotLastInChainThrow() {
45+
ChatModel chatModel = mock(ChatModel.class);
46+
ChatClientRequest chatClientRequest = mock(ChatClientRequest.class);
47+
CallAdvisorChain callAdvisorChain = mock(CallAdvisorChain.class);
48+
49+
when(callAdvisorChain.hasNextCallAdvisor()).thenReturn(true);
50+
51+
ChatModelCallAdvisor chatModelCallAdvisor = ChatModelCallAdvisor.builder().chatModel(chatModel).build();
52+
53+
assertThatThrownBy(() -> chatModelCallAdvisor.adviseCall(chatClientRequest, callAdvisorChain))
54+
.isInstanceOf(IllegalArgumentException.class)
55+
.hasMessage("ChatModelCallAdvisor should be the last CallAdvisor in the chain");
56+
}
57+
3758
}

spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/ChatModelStreamAdvisorTests.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,13 @@
1818

1919
import org.junit.jupiter.api.Test;
2020

21+
import org.springframework.ai.chat.client.ChatClientRequest;
22+
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
23+
import org.springframework.ai.chat.model.ChatModel;
24+
2125
import static org.assertj.core.api.Assertions.assertThatThrownBy;
26+
import static org.mockito.Mockito.mock;
27+
import static org.mockito.Mockito.when;
2228

2329
/**
2430
* Unit tests for {@link ChatModelStreamAdvisor}.
@@ -34,4 +40,19 @@ void whenChatModelIsNullThenThrow() {
3440
.hasMessage("chatModel cannot be null");
3541
}
3642

43+
@Test
44+
void whenNotLastInChainThrow() {
45+
ChatModel chatModel = mock(ChatModel.class);
46+
ChatClientRequest chatClientRequest = mock(ChatClientRequest.class);
47+
StreamAdvisorChain streamAdvisorChain = mock(StreamAdvisorChain.class);
48+
49+
when(streamAdvisorChain.hasNextStreamAdvisor()).thenReturn(true);
50+
51+
ChatModelStreamAdvisor chatModelStreamAdvisor = ChatModelStreamAdvisor.builder().chatModel(chatModel).build();
52+
53+
assertThatThrownBy(() -> chatModelStreamAdvisor.adviseStream(chatClientRequest, streamAdvisorChain))
54+
.isInstanceOf(IllegalArgumentException.class)
55+
.hasMessage("ChatModelStreamAdvisor should be the last StreamAdvisor in the chain");
56+
}
57+
3758
}

spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChainTests.java

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,34 @@ void getCallAdvisors() {
103103
assertThat(chain.getCallAdvisors()).containsExactlyInAnyOrder(advisors.toArray(new CallAdvisor[0]));
104104
}
105105

106+
@Test
107+
void hasNextCallAdvisor() {
108+
// The first advisor
109+
TestAdvisor advisor1 = new TestAdvisor("advisor1", 1) {
110+
@Override
111+
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest,
112+
CallAdvisorChain callAdvisorChain) {
113+
assertThat(callAdvisorChain.hasNextCallAdvisor()).isTrue();
114+
return callAdvisorChain.nextCall(chatClientRequest);
115+
}
116+
};
117+
118+
// The last advisor
119+
TestAdvisor advisor2 = new TestAdvisor("advisor2", 2) {
120+
@Override
121+
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest,
122+
CallAdvisorChain callAdvisorChain) {
123+
assertThat(callAdvisorChain.hasNextCallAdvisor()).isFalse();
124+
return null;
125+
}
126+
};
127+
128+
List<CallAdvisor> advisors = List.of(advisor1, advisor2);
129+
CallAdvisorChain chain = DefaultAroundAdvisorChain.builder(ObservationRegistry.NOOP).pushAll(advisors).build();
130+
131+
chain.nextCall(mock(ChatClientRequest.class));
132+
}
133+
106134
@Test
107135
void getStreamAdvisors() {
108136
StreamAdvisor mockAdvisor1 = mock(StreamAdvisor.class);
@@ -125,4 +153,86 @@ void getStreamAdvisors() {
125153
assertThat(chain.getStreamAdvisors()).containsExactlyInAnyOrder(advisors.toArray(new StreamAdvisor[0]));
126154
}
127155

156+
@Test
157+
void hasNextStreamAdvisor() {
158+
// The first advisor
159+
TestAdvisor advisor1 = new TestAdvisor("advisor1", 1) {
160+
@Override
161+
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
162+
StreamAdvisorChain streamAdvisorChain) {
163+
assertThat(streamAdvisorChain.hasNextStreamAdvisor()).isTrue();
164+
return streamAdvisorChain.nextStream(chatClientRequest);
165+
}
166+
};
167+
// The last advisor
168+
TestAdvisor advisor2 = new TestAdvisor("advisor2", 2) {
169+
@Override
170+
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
171+
StreamAdvisorChain streamAdvisorChain) {
172+
assertThat(streamAdvisorChain.hasNextStreamAdvisor()).isFalse();
173+
return Flux.empty();
174+
}
175+
};
176+
177+
List<StreamAdvisor> advisors = List.of(advisor1, advisor2);
178+
StreamAdvisorChain chain = DefaultAroundAdvisorChain.builder(ObservationRegistry.NOOP)
179+
.pushAll(advisors)
180+
.build();
181+
182+
chain.nextStream(mock(ChatClientRequest.class)).blockLast();
183+
}
184+
185+
@Test
186+
void testOrder() {
187+
TestAdvisor advisor1 = new TestAdvisor("advisor1", 1);
188+
TestAdvisor advisor21 = new TestAdvisor("advisor2_1", 2);
189+
TestAdvisor advisor22 = new TestAdvisor("advisor2_2", 2);
190+
TestAdvisor advisor3 = new TestAdvisor("advisor3", 3);
191+
192+
var advisors = List.of(advisor3, advisor1, advisor21, advisor22);
193+
194+
DefaultAroundAdvisorChain chain = DefaultAroundAdvisorChain.builder(ObservationRegistry.NOOP)
195+
.pushAll(advisors)
196+
.build();
197+
198+
assertThat(chain.getStreamAdvisors()).containsExactly(advisor1, advisor21, advisor22, advisor3);
199+
assertThat(chain.getCallAdvisors()).containsExactly(advisor1, advisor21, advisor22, advisor3);
200+
}
201+
202+
private static class TestAdvisor implements CallAdvisor, StreamAdvisor {
203+
204+
private final String name;
205+
206+
private final int order;
207+
208+
private TestAdvisor(String name, int order) {
209+
this.name = name;
210+
this.order = order;
211+
}
212+
213+
@Override
214+
public String getName() {
215+
return name;
216+
}
217+
218+
@Override
219+
public int getOrder() {
220+
return order;
221+
}
222+
223+
@Override
224+
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
225+
System.out.println(callAdvisorChain.hasNextCallAdvisor());
226+
return callAdvisorChain.nextCall(chatClientRequest);
227+
}
228+
229+
@Override
230+
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
231+
StreamAdvisorChain streamAdvisorChain) {
232+
System.out.println(streamAdvisorChain.hasNextStreamAdvisor());
233+
return streamAdvisorChain.nextStream(chatClientRequest);
234+
}
235+
236+
}
237+
128238
}

0 commit comments

Comments
 (0)