
Commit f929ff8

Simplify AI Chat Response Streaming (#1167)
Reason
---
- Simplify the code and logic for streaming chat responses by relying solely on the asyncio event loop.
- Reduce the overhead of managing threads to increase efficiency and throughput (where possible).

Details
---
- Use async/await with no threading when generating chat responses via the OpenAI, Gemini and Anthropic AI model APIs.
- Use threading for the offline chat model, as llama-cpp doesn't support async streaming yet.
2 parents 973aded + a4b5842 commit f929ff8
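The core pattern behind this change: an `async def` function that `yield`s chunks is an async generator, so the event loop interleaves producer and consumer directly, with no worker thread in between. A minimal, runnable sketch (names are illustrative, not from the Khoj codebase):

```python
import asyncio
from typing import AsyncGenerator


async def stream_reply(prompt: str) -> AsyncGenerator[str, None]:
    # Stand-in for an AI model API's async streaming call.
    for token in ["Hello", ", ", "world", "!"]:
        await asyncio.sleep(0)  # hand control back to the event loop
        yield token


async def main():
    # The caller consumes chunks as they arrive; no worker thread or
    # queue-backed generator is needed to bridge producer and consumer.
    async for chunk in stream_reply("hi"):
        print(chunk, end="", flush=True)


asyncio.run(main())
```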

File tree

12 files changed: +362 −366 lines

src/khoj/database/adapters/__init__.py

+24 −22
@@ -763,9 +763,9 @@ async def ais_agent_accessible(agent: Agent, user: KhojUser) -> bool:
         return False
 
     @staticmethod
-    def get_conversation_agent_by_id(agent_id: int):
-        agent = Agent.objects.filter(id=agent_id).first()
-        if agent == AgentAdapters.get_default_agent():
+    async def aget_conversation_agent_by_id(agent_id: int):
+        agent = await Agent.objects.filter(id=agent_id).afirst()
+        if agent == await AgentAdapters.aget_default_agent():
             # If the agent is set to the default agent, then return None and let the default application code be used
             return None
         return agent
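This conversion pattern recurs throughout the adapters: each blocking ORM call is swapped for its `a`-prefixed async counterpart (available since Django 4.1), and the method is renamed with an `a` prefix. A hedged before/after sketch, assuming a configured Django project with an illustrative `Agent` model:

```python
# Before: blocks the calling thread; unsafe to call from async code.
def get_agent_by_id(agent_id: int):
    return Agent.objects.filter(id=agent_id).first()


# After: .afirst() awaits the query on the event loop (Django 4.1+).
async def aget_agent_by_id(agent_id: int):
    return await Agent.objects.filter(id=agent_id).afirst()
```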
@@ -1109,14 +1109,6 @@ def get_all_chat_models():
     async def aget_all_chat_models():
         return await sync_to_async(list)(ChatModel.objects.prefetch_related("ai_model_api").all())
 
-    @staticmethod
-    def get_vision_enabled_config():
-        chat_models = ConversationAdapters.get_all_chat_models()
-        for config in chat_models:
-            if config.vision_enabled:
-                return config
-        return None
-
     @staticmethod
     async def aget_vision_enabled_config():
         chat_models = await ConversationAdapters.aget_all_chat_models()
@@ -1171,7 +1163,11 @@ def get_chat_model(user: KhojUser):
     @staticmethod
     async def aget_chat_model(user: KhojUser):
         subscribed = await ais_user_subscribed(user)
-        config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
+        config = (
+            await UserConversationConfig.objects.filter(user=user)
+            .prefetch_related("setting", "setting__ai_model_api")
+            .afirst()
+        )
         if subscribed:
             # Subscribed users can use any available chat model
             if config:
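Prefetching `setting__ai_model_api` alongside `setting` matters in async code: touching a relation that was not eagerly loaded triggers a lazy synchronous query, which Django rejects with `SynchronousOnlyOperation` inside the event loop. A sketch of the idea, mirroring the models named in the diff (the helper name is hypothetical):

```python
# Sketch (Django 4.1+): assumes a configured project and the models
# named in the diff. Eagerly load the nested relation so no lazy
# (synchronous) query fires when related objects are touched later.
async def aget_chat_settings(user):
    config = (
        await UserConversationConfig.objects.filter(user=user)
        .prefetch_related("setting", "setting__ai_model_api")
        .afirst()
    )
    # Safe in async context: both relation levels were prefetched.
    return config.setting.ai_model_api if config else None
```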
@@ -1387,7 +1383,7 @@ def create_conversation_from_public_conversation(
 
     @staticmethod
     @require_valid_user
-    def save_conversation(
+    async def save_conversation(
         user: KhojUser,
         conversation_log: dict,
         client_application: ClientApplication = None,
@@ -1396,19 +1392,21 @@ def save_conversation(
     ):
         slug = user_message.strip()[:200] if user_message else None
         if conversation_id:
-            conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id).first()
+            conversation = await Conversation.objects.filter(
+                user=user, client=client_application, id=conversation_id
+            ).afirst()
         else:
             conversation = (
-                Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").first()
+                await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst()
             )
 
         if conversation:
             conversation.conversation_log = conversation_log
             conversation.slug = slug
             conversation.updated_at = datetime.now(tz=timezone.utc)
-            conversation.save()
+            await conversation.asave()
         else:
-            Conversation.objects.create(
+            await Conversation.objects.acreate(
                 user=user, conversation_log=conversation_log, client=client_application, slug=slug
             )

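`acreate()` (Django 4.1+) and the instance-level `asave()` (Django 4.2+) are the async counterparts of `create()` and `save()`. A short sketch of the update-or-create flow above, with an illustrative `Note` model standing in for `Conversation`:

```python
from datetime import datetime, timezone


# Sketch of the async upsert flow; `Note` is an illustrative model in
# an assumed, configured Django project.
async def touch_latest_note(user):
    note = await Note.objects.filter(owner=user).order_by("-updated_at").afirst()
    if note:
        note.updated_at = datetime.now(tz=timezone.utc)
        await note.asave()  # async counterpart of save()
    else:
        await Note.objects.acreate(owner=user)  # async counterpart of create()
```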
@@ -1455,17 +1453,21 @@ async def aget_conversation_starters(user: KhojUser, max_results=3):
         return random.sample(all_questions, max_results)
 
     @staticmethod
-    def get_valid_chat_model(user: KhojUser, conversation: Conversation, is_subscribed: bool):
+    async def aget_valid_chat_model(user: KhojUser, conversation: Conversation, is_subscribed: bool):
         agent: Agent = (
-            conversation.agent if is_subscribed and AgentAdapters.get_default_agent() != conversation.agent else None
+            conversation.agent
+            if is_subscribed and await AgentAdapters.aget_default_agent() != conversation.agent
+            else None
         )
         if agent and agent.chat_model:
-            chat_model = conversation.agent.chat_model
+            chat_model = await ChatModel.objects.select_related("ai_model_api").aget(
+                pk=conversation.agent.chat_model.pk
+            )
         else:
-            chat_model = ConversationAdapters.get_chat_model(user)
+            chat_model = await ConversationAdapters.aget_chat_model(user)
 
         if chat_model is None:
-            chat_model = ConversationAdapters.get_default_chat_model()
+            chat_model = await ConversationAdapters.aget_default_chat_model()
 
         if chat_model.model_type == ChatModel.ModelType.OFFLINE:
             if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
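Re-fetching the agent's chat model with `select_related("ai_model_api")` folds the foreign key into a single joined query, so the later attribute access needs no lazy follow-up query (which would be synchronous and disallowed in async code). A sketch mirroring the diff (the wrapper function name is hypothetical):

```python
# Sketch: one query fetches the ChatModel row joined with its
# ai_model_api foreign key.
async def aload_agent_chat_model(conversation):
    chat_model = await ChatModel.objects.select_related("ai_model_api").aget(
        pk=conversation.agent.chat_model.pk
    )
    return chat_model.ai_model_api  # already loaded by the join; no extra query
```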

src/khoj/processor/conversation/anthropic/anthropic_chat.py

+22 −12
@@ -1,6 +1,6 @@
 import logging
 from datetime import datetime, timedelta
-from typing import Dict, List, Optional
+from typing import AsyncGenerator, Dict, List, Optional
 
 import pyjson5
 from langchain.schema import ChatMessage
@@ -137,7 +137,7 @@ def anthropic_send_message_to_model(
     )
 
 
-def converse_anthropic(
+async def converse_anthropic(
     references,
     user_query,
     online_results: Optional[Dict[str, Dict]] = None,
@@ -161,7 +161,7 @@ def converse_anthropic(
     generated_asset_results: Dict[str, Dict] = {},
     deepthought: Optional[bool] = False,
     tracer: dict = {},
-):
+) -> AsyncGenerator[str, None]:
     """
     Converse with user using Anthropic's Claude
     """
@@ -191,11 +191,17 @@
 
     # Get Conversation Primer appropriate to Conversation Type
     if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
-        completion_func(chat_response=prompts.no_notes_found.format())
-        return iter([prompts.no_notes_found.format()])
+        response = prompts.no_notes_found.format()
+        if completion_func:
+            await completion_func(chat_response=response)
+        yield response
+        return
     elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
-        completion_func(chat_response=prompts.no_online_results_found.format())
-        return iter([prompts.no_online_results_found.format()])
+        response = prompts.no_online_results_found.format()
+        if completion_func:
+            await completion_func(chat_response=response)
+        yield response
+        return
 
     context_message = ""
     if not is_none_or_empty(references):
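Note the bare `return` after `yield` in the early-exit branches: an async generator may stop early with a plain `return`, but `return value` is a `SyntaxError` inside an async generator, hence the yield-then-return shape. An illustrative, runnable sketch:

```python
import asyncio


async def reply(notes: list[str]):
    if not notes:
        yield "No notes found."
        return  # bare return ends the async generator early
    for note in notes:
        yield note


async def main():
    async for chunk in reply([]):
        print(chunk)


asyncio.run(main())
```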
@@ -228,17 +234,21 @@
     logger.debug(f"Conversation Context for Claude: {messages_to_print(messages)}")
 
     # Get Response from Claude
-    return anthropic_chat_completion_with_backoff(
+    full_response = ""
+    async for chunk in anthropic_chat_completion_with_backoff(
         messages=messages,
-        compiled_references=references,
-        online_results=online_results,
         model_name=model,
         temperature=0.2,
         api_key=api_key,
         api_base_url=api_base_url,
         system_prompt=system_prompt,
-        completion_func=completion_func,
         max_prompt_size=max_prompt_size,
         deepthought=deepthought,
         tracer=tracer,
-    )
+    ):
+        full_response += chunk
+        yield chunk
+
+    # Call completion_func once streaming finishes and we have the full response
+    if completion_func:
+        await completion_func(chat_response=full_response)
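The rewritten function is a delegation pattern: re-yield every chunk from the inner async generator while accumulating the full response, then fire the optional async completion callback exactly once. A self-contained sketch with illustrative names:

```python
import asyncio
from typing import AsyncGenerator, Awaitable, Callable, Optional


async def inner() -> AsyncGenerator[str, None]:
    for chunk in ["stream", "ed ", "reply"]:
        yield chunk


async def outer(
    completion_func: Optional[Callable[..., Awaitable[None]]] = None,
) -> AsyncGenerator[str, None]:
    full_response = ""
    async for chunk in inner():
        full_response += chunk
        yield chunk
    # Fire the completion callback exactly once, after the stream ends.
    if completion_func:
        await completion_func(chat_response=full_response)


async def main():
    async def on_done(chat_response: str) -> None:
        print(f"\nsaved: {chat_response!r}")

    async for chunk in outer(on_done):
        print(chunk, end="")


asyncio.run(main())
```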

src/khoj/processor/conversation/anthropic/utils.py

+21 −64
@@ -1,5 +1,5 @@
 import logging
-from threading import Thread
+from time import perf_counter
 from typing import Dict, List
 
 import anthropic
@@ -13,13 +13,13 @@
 )
 
 from khoj.processor.conversation.utils import (
-    ThreadedGenerator,
     commit_conversation_trace,
     get_image_from_base64,
     get_image_from_url,
 )
 from khoj.utils.helpers import (
-    get_ai_api_info,
+    get_anthropic_async_client,
+    get_anthropic_client,
     get_chat_usage_metrics,
     is_none_or_empty,
     is_promptrace_enabled,
@@ -28,24 +28,12 @@
 logger = logging.getLogger(__name__)
 
 anthropic_clients: Dict[str, anthropic.Anthropic | anthropic.AnthropicVertex] = {}
+anthropic_async_clients: Dict[str, anthropic.AsyncAnthropic | anthropic.AsyncAnthropicVertex] = {}
 
 DEFAULT_MAX_TOKENS_ANTHROPIC = 8000
 MAX_REASONING_TOKENS_ANTHROPIC = 12000
 
 
-def get_anthropic_client(api_key, api_base_url=None) -> anthropic.Anthropic | anthropic.AnthropicVertex:
-    api_info = get_ai_api_info(api_key, api_base_url)
-    if api_info.api_key:
-        client = anthropic.Anthropic(api_key=api_info.api_key)
-    else:
-        client = anthropic.AnthropicVertex(
-            region=api_info.region,
-            project_id=api_info.project,
-            credentials=api_info.credentials,
-        )
-    return client
-
-
 @retry(
     wait=wait_random_exponential(min=1, max=10),
     stop=stop_after_attempt(2),
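The new `get_anthropic_async_client` helper now lives in `khoj.utils.helpers` and its body is not part of this diff; a plausible sketch, inferred from the removed synchronous `get_anthropic_client` above (treat the details as an assumption):

```python
import anthropic

from khoj.utils.helpers import get_ai_api_info


def get_anthropic_async_client(
    api_key, api_base_url=None
) -> anthropic.AsyncAnthropic | anthropic.AsyncAnthropicVertex:
    # Mirror of the deleted sync factory, swapping in the async clients.
    api_info = get_ai_api_info(api_key, api_base_url)
    if api_info.api_key:
        return anthropic.AsyncAnthropic(api_key=api_info.api_key)
    return anthropic.AsyncAnthropicVertex(
        region=api_info.region,
        project_id=api_info.project,
        credentials=api_info.credentials,
    )
```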
@@ -126,60 +114,23 @@ def anthropic_completion_with_backoff(
     before_sleep=before_sleep_log(logger, logging.DEBUG),
     reraise=True,
 )
-def anthropic_chat_completion_with_backoff(
+async def anthropic_chat_completion_with_backoff(
     messages: list[ChatMessage],
-    compiled_references,
-    online_results,
     model_name,
     temperature,
     api_key,
     api_base_url,
     system_prompt: str,
     max_prompt_size=None,
-    completion_func=None,
-    deepthought=False,
-    model_kwargs=None,
-    tracer={},
-):
-    g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
-    t = Thread(
-        target=anthropic_llm_thread,
-        args=(
-            g,
-            messages,
-            system_prompt,
-            model_name,
-            temperature,
-            api_key,
-            api_base_url,
-            max_prompt_size,
-            deepthought,
-            model_kwargs,
-            tracer,
-        ),
-    )
-    t.start()
-    return g
-
-
-def anthropic_llm_thread(
-    g,
-    messages: list[ChatMessage],
-    system_prompt: str,
-    model_name: str,
-    temperature,
-    api_key,
-    api_base_url=None,
-    max_prompt_size=None,
     deepthought=False,
     model_kwargs=None,
     tracer={},
 ):
     try:
-        client = anthropic_clients.get(api_key)
+        client = anthropic_async_clients.get(api_key)
         if not client:
-            client = get_anthropic_client(api_key, api_base_url)
-            anthropic_clients[api_key] = client
+            client = get_anthropic_async_client(api_key, api_base_url)
+            anthropic_async_clients[api_key] = client
 
         model_kwargs = model_kwargs or dict()
         max_tokens = DEFAULT_MAX_TOKENS_ANTHROPIC
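For contrast, the machinery this hunk deletes: `anthropic_chat_completion_with_backoff` used to hand a queue-backed `ThreadedGenerator` to the caller and feed it from a worker thread running `anthropic_llm_thread`. A simplified, runnable reconstruction of that pattern (not the exact Khoj implementation):

```python
import queue
import threading


class ThreadedGenerator:
    """Queue-backed generator bridging a producer thread and a sync consumer."""

    def __init__(self):
        self.queue: "queue.Queue[str | None]" = queue.Queue()

    def __iter__(self):
        return self

    def __next__(self):
        item = self.queue.get()
        if item is None:  # sentinel: the producer thread finished
            raise StopIteration
        return item

    def send(self, chunk: str):
        self.queue.put(chunk)

    def close(self):
        self.queue.put(None)


def produce(g: ThreadedGenerator):
    try:
        for chunk in ["threaded ", "streaming"]:
            g.send(chunk)
    finally:
        g.close()  # always unblock the consumer


g = ThreadedGenerator()
threading.Thread(target=produce, args=(g,)).start()
for chunk in g:
    print(chunk, end="")
```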
@@ -193,7 +144,8 @@ def anthropic_llm_thread(
 
         aggregated_response = ""
         final_message = None
-        with client.messages.stream(
+        start_time = perf_counter()
+        async with client.messages.stream(
             messages=formatted_messages,
             model=model_name,  # type: ignore
             temperature=temperature,
@@ -202,10 +154,17 @@ def anthropic_llm_thread(
             max_tokens=max_tokens,
             **model_kwargs,
         ) as stream:
-            for text in stream.text_stream:
+            async for text in stream.text_stream:
+                # Log the time taken to start response
+                if aggregated_response == "":
+                    logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
+                # Handle streamed response chunk
                 aggregated_response += text
-                g.send(text)
-        final_message = stream.get_final_message()
+                yield text
+            final_message = await stream.get_final_message()
+
+        # Log the time taken to stream the entire response
+        logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")
 
         # Calculate cost of chat
         input_tokens = final_message.usage.input_tokens
@@ -222,9 +181,7 @@
         if is_promptrace_enabled():
             commit_conversation_trace(messages, aggregated_response, tracer)
     except Exception as e:
-        logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
-    finally:
-        g.close()
+        logger.error(f"Error in anthropic_chat_completion_with_backoff stream: {e}", exc_info=True)
 
 
 def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt: str = None):
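The added `perf_counter` instrumentation in isolation: record one start time, log time-to-first-chunk on the first iteration, and log total stream time at the end. A runnable sketch with a fake stream standing in for the SDK call:

```python
import asyncio
from time import perf_counter


async def fake_text_stream():
    # Stand-in for the SDK's async text_stream.
    for chunk in ["a", "b", "c"]:
        await asyncio.sleep(0.1)
        yield chunk


async def main():
    aggregated = ""
    start_time = perf_counter()
    async for text in fake_text_stream():
        # First chunk marks time-to-first-token.
        if aggregated == "":
            print(f"First response took: {perf_counter() - start_time:.3f} seconds")
        aggregated += text
    print(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")


asyncio.run(main())
```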
