diff --git a/requirements-dev.txt b/requirements-dev.txt index 722434db..1d7ad271 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,10 +1,12 @@ -r src/backend/requirements.txt ruff +mypy pre-commit pip-tools pip-compile-cross-platform pytest pytest-cov pytest-asyncio +pytest-snapshot mypy locust diff --git a/src/backend/fastapi_app/api_models.py b/src/backend/fastapi_app/api_models.py index 2e214a5e..c98ca76d 100644 --- a/src/backend/fastapi_app/api_models.py +++ b/src/backend/fastapi_app/api_models.py @@ -1,17 +1,42 @@ +from enum import Enum from typing import Any from openai.types.chat import ChatCompletionMessageParam from pydantic import BaseModel +class AIChatRoles(str, Enum): + USER = "user" + ASSISTANT = "assistant" + SYSTEM = "system" + + class Message(BaseModel): content: str - role: str = "user" + role: AIChatRoles = AIChatRoles.USER + + +class RetrievalMode(str, Enum): + TEXT = "text" + VECTORS = "vectors" + HYBRID = "hybrid" + + +class ChatRequestOverrides(BaseModel): + top: int = 3 + temperature: float = 0.3 + retrieval_mode: RetrievalMode = RetrievalMode.HYBRID + use_advanced_flow: bool = True + prompt_template: str | None = None + + +class ChatRequestContext(BaseModel): + overrides: ChatRequestOverrides class ChatRequest(BaseModel): messages: list[ChatCompletionMessageParam] - context: dict = {} + context: ChatRequestContext class ThoughtStep(BaseModel): @@ -32,6 +57,12 @@ class RetrievalResponse(BaseModel): session_state: Any | None = None +class RetrievalResponseDelta(BaseModel): + delta: Message | None = None + context: RAGContext | None = None + session_state: Any | None = None + + class ItemPublic(BaseModel): id: int type: str @@ -43,3 +74,12 @@ class ItemPublic(BaseModel): class ItemWithDistance(ItemPublic): distance: float + + +class ChatParams(ChatRequestOverrides): + prompt_template: str + response_token_limit: int = 1024 + enable_text_search: bool + enable_vector_search: bool + original_user_query: str + past_messages: list[ChatCompletionMessageParam] diff --git a/src/backend/fastapi_app/rag_advanced.py b/src/backend/fastapi_app/rag_advanced.py index 024a5fbd..ddbd65ce 100644 --- a/src/backend/fastapi_app/rag_advanced.py +++ b/src/backend/fastapi_app/rag_advanced.py @@ -1,19 +1,25 @@ -import pathlib from collections.abc import AsyncGenerator -from typing import ( - Any, -) +from typing import Any -from openai import AsyncAzureOpenAI, AsyncOpenAI -from openai.types.chat import ChatCompletion, ChatCompletionMessageParam +from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam from openai_messages_token_helper import build_messages, get_token_limit -from .api_models import Message, RAGContext, RetrievalResponse, ThoughtStep -from .postgres_searcher import PostgresSearcher -from .query_rewriter import build_search_function, extract_search_arguments +from fastapi_app.api_models import ( + AIChatRoles, + Message, + RAGContext, + RetrievalResponse, + RetrievalResponseDelta, + ThoughtStep, +) +from fastapi_app.postgres_models import Item +from fastapi_app.postgres_searcher import PostgresSearcher +from fastapi_app.query_rewriter import build_search_function, extract_search_arguments +from fastapi_app.rag_base import ChatParams, RAGChatBase -class AdvancedRAGChat: +class AdvancedRAGChat(RAGChatBase): def __init__( self, *, @@ -27,24 +33,11 @@ def __init__( self.chat_model = chat_model self.chat_deployment = chat_deployment self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True) - current_dir = pathlib.Path(__file__).parent - self.query_prompt_template = open(current_dir / "prompts/query.txt").read() - self.answer_prompt_template = open(current_dir / "prompts/answer.txt").read() - - async def run( - self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any] = {} - ) -> RetrievalResponse | AsyncGenerator[dict[str, Any], None]: - text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None] - vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None] - top = overrides.get("top", 3) - - original_user_query = messages[-1]["content"] - if not isinstance(original_user_query, str): - raise ValueError("The most recent message content must be a string.") - past_messages = messages[:-1] - - # Generate an optimized keyword search query based on the chat history and the last question - query_response_token_limit = 500 + + async def generate_search_query( + self, original_user_query: str, past_messages: list[ChatCompletionMessageParam], query_response_token_limit: int + ) -> tuple[list[ChatCompletionMessageParam], Any | str | None, list]: + """Generate an optimized keyword search query based on the chat history and the last question""" query_messages: list[ChatCompletionMessageParam] = build_messages( model=self.chat_model, system_prompt=self.query_prompt_template, @@ -67,12 +60,23 @@ async def run( query_text, filters = extract_search_arguments(original_user_query, chat_completion) + return query_messages, query_text, filters + + async def prepare_context( + self, chat_params: ChatParams + ) -> tuple[list[ChatCompletionMessageParam], list[Item], list[ThoughtStep]]: + query_messages, query_text, filters = await self.generate_search_query( + original_user_query=chat_params.original_user_query, + past_messages=chat_params.past_messages, + query_response_token_limit=500, + ) + # Retrieve relevant items from the database with the GPT optimized query results = await self.searcher.search_and_embed( query_text, - top=top, - enable_vector_search=vector_search, - enable_text_search=text_search, + top=chat_params.top, + enable_vector_search=chat_params.enable_vector_search, + enable_text_search=chat_params.enable_text_search, filters=filters, ) @@ -80,55 +84,104 @@ async def run( content = "\n".join(sources_content) # Generate a contextual and content specific answer using the search results and chat history - response_token_limit = 1024 contextual_messages: list[ChatCompletionMessageParam] = build_messages( model=self.chat_model, - system_prompt=overrides.get("prompt_template") or self.answer_prompt_template, - new_user_content=original_user_query + "\n\nSources:\n" + content, - past_messages=past_messages, - max_tokens=self.chat_token_limit - response_token_limit, + system_prompt=chat_params.prompt_template, + new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content, + past_messages=chat_params.past_messages, + max_tokens=self.chat_token_limit - chat_params.response_token_limit, fallback_to_default=True, ) + thoughts = [ + ThoughtStep( + title="Prompt to generate search arguments", + description=[str(message) for message in query_messages], + props=( + {"model": self.chat_model, "deployment": self.chat_deployment} + if self.chat_deployment + else {"model": self.chat_model} + ), + ), + ThoughtStep( + title="Search using generated search arguments", + description=query_text, + props={ + "top": chat_params.top, + "vector_search": chat_params.enable_vector_search, + "text_search": chat_params.enable_text_search, + "filters": filters, + }, + ), + ThoughtStep( + title="Search results", + description=[result.to_dict() for result in results], + ), + ] + return contextual_messages, results, thoughts + + async def answer( + self, + chat_params: ChatParams, + contextual_messages: list[ChatCompletionMessageParam], + results: list[Item], + earlier_thoughts: list[ThoughtStep], + ) -> RetrievalResponse: chat_completion_response: ChatCompletion = await self.openai_chat_client.chat.completions.create( # Azure OpenAI takes the deployment name as the model name model=self.chat_deployment if self.chat_deployment else self.chat_model, messages=contextual_messages, - temperature=overrides.get("temperature", 0.3), - max_tokens=response_token_limit, + temperature=chat_params.temperature, + max_tokens=chat_params.response_token_limit, n=1, stream=False, ) - first_choice_message = chat_completion_response.choices[0].message return RetrievalResponse( - message=Message(content=str(first_choice_message.content), role=first_choice_message.role), + message=Message( + content=str(chat_completion_response.choices[0].message.content), role=AIChatRoles.ASSISTANT + ), context=RAGContext( data_points={item.id: item.to_dict() for item in results}, - thoughts=[ + thoughts=earlier_thoughts + + [ ThoughtStep( - title="Prompt to generate search arguments", - description=[str(message) for message in query_messages], + title="Prompt to generate answer", + description=[str(message) for message in contextual_messages], props=( {"model": self.chat_model, "deployment": self.chat_deployment} if self.chat_deployment else {"model": self.chat_model} ), ), - ThoughtStep( - title="Search using generated search arguments", - description=query_text, - props={ - "top": top, - "vector_search": vector_search, - "text_search": text_search, - "filters": filters, - }, - ), - ThoughtStep( - title="Search results", - description=[result.to_dict() for result in results], - ), + ], + ), + ) + + async def answer_stream( + self, + chat_params: ChatParams, + contextual_messages: list[ChatCompletionMessageParam], + results: list[Item], + earlier_thoughts: list[ThoughtStep], + ) -> AsyncGenerator[RetrievalResponseDelta, None]: + chat_completion_async_stream: AsyncStream[ + ChatCompletionChunk + ] = await self.openai_chat_client.chat.completions.create( + # Azure OpenAI takes the deployment name as the model name + model=self.chat_deployment if self.chat_deployment else self.chat_model, + messages=contextual_messages, + temperature=chat_params.temperature, + max_tokens=chat_params.response_token_limit, + n=1, + stream=True, + ) + + yield RetrievalResponseDelta( + context=RAGContext( + data_points={item.id: item.to_dict() for item in results}, + thoughts=earlier_thoughts + + [ ThoughtStep( title="Prompt to generate answer", description=[str(message) for message in contextual_messages], @@ -141,3 +194,11 @@ async def run( ], ), ) + + async for response_chunk in chat_completion_async_stream: + # first response has empty choices and last response has empty content + if response_chunk.choices and response_chunk.choices[0].delta.content: + yield RetrievalResponseDelta( + delta=Message(content=str(response_chunk.choices[0].delta.content), role=AIChatRoles.ASSISTANT) + ) + return diff --git a/src/backend/fastapi_app/rag_base.py b/src/backend/fastapi_app/rag_base.py new file mode 100644 index 00000000..f7f7bff4 --- /dev/null +++ b/src/backend/fastapi_app/rag_base.py @@ -0,0 +1,73 @@ +import pathlib +from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator + +from openai.types.chat import ChatCompletionMessageParam + +from fastapi_app.api_models import ( + ChatParams, + ChatRequestOverrides, + RetrievalResponse, + RetrievalResponseDelta, + ThoughtStep, +) +from fastapi_app.postgres_models import Item + + +class RAGChatBase(ABC): + current_dir = pathlib.Path(__file__).parent + query_prompt_template = open(current_dir / "prompts/query.txt").read() + answer_prompt_template = open(current_dir / "prompts/answer.txt").read() + + def get_params(self, messages: list[ChatCompletionMessageParam], overrides: ChatRequestOverrides) -> ChatParams: + response_token_limit = 1024 + prompt_template = overrides.prompt_template or self.answer_prompt_template + + enable_text_search = overrides.retrieval_mode in ["text", "hybrid", None] + enable_vector_search = overrides.retrieval_mode in ["vectors", "hybrid", None] + + original_user_query = messages[-1]["content"] + if not isinstance(original_user_query, str): + raise ValueError("The most recent message content must be a string.") + past_messages = messages[:-1] + + return ChatParams( + top=overrides.top, + temperature=overrides.temperature, + retrieval_mode=overrides.retrieval_mode, + use_advanced_flow=overrides.use_advanced_flow, + response_token_limit=response_token_limit, + prompt_template=prompt_template, + enable_text_search=enable_text_search, + enable_vector_search=enable_vector_search, + original_user_query=original_user_query, + past_messages=past_messages, + ) + + @abstractmethod + async def prepare_context( + self, chat_params: ChatParams + ) -> tuple[list[ChatCompletionMessageParam], list[Item], list[ThoughtStep]]: + raise NotImplementedError + + @abstractmethod + async def answer( + self, + chat_params: ChatParams, + contextual_messages: list[ChatCompletionMessageParam], + results: list[Item], + earlier_thoughts: list[ThoughtStep], + ) -> RetrievalResponse: + raise NotImplementedError + + @abstractmethod + async def answer_stream( + self, + chat_params: ChatParams, + contextual_messages: list[ChatCompletionMessageParam], + results: list[Item], + earlier_thoughts: list[ThoughtStep], + ) -> AsyncGenerator[RetrievalResponseDelta, None]: + raise NotImplementedError + if False: + yield 0 diff --git a/src/backend/fastapi_app/rag_simple.py b/src/backend/fastapi_app/rag_simple.py index f8db974e..2e6d859e 100644 --- a/src/backend/fastapi_app/rag_simple.py +++ b/src/backend/fastapi_app/rag_simple.py @@ -1,16 +1,23 @@ -import pathlib from collections.abc import AsyncGenerator -from typing import Any -from openai import AsyncAzureOpenAI, AsyncOpenAI -from openai.types.chat import ChatCompletion, ChatCompletionMessageParam +from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam from openai_messages_token_helper import build_messages, get_token_limit -from .api_models import Message, RAGContext, RetrievalResponse, ThoughtStep -from .postgres_searcher import PostgresSearcher +from fastapi_app.api_models import ( + AIChatRoles, + Message, + RAGContext, + RetrievalResponse, + RetrievalResponseDelta, + ThoughtStep, +) +from fastapi_app.postgres_models import Item +from fastapi_app.postgres_searcher import PostgresSearcher +from fastapi_app.rag_base import ChatParams, RAGChatBase -class SimpleRAGChat: +class SimpleRAGChat(RAGChatBase): def __init__( self, *, @@ -24,69 +31,112 @@ def __init__( self.chat_model = chat_model self.chat_deployment = chat_deployment self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True) - current_dir = pathlib.Path(__file__).parent - self.answer_prompt_template = open(current_dir / "prompts/answer.txt").read() - async def run( - self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any] = {} - ) -> RetrievalResponse | AsyncGenerator[dict[str, Any], None]: - text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None] - vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None] - top = overrides.get("top", 3) - - original_user_query = messages[-1]["content"] - if not isinstance(original_user_query, str): - raise ValueError("The most recent message content must be a string.") - past_messages = messages[:-1] + async def prepare_context( + self, chat_params: ChatParams + ) -> tuple[list[ChatCompletionMessageParam], list[Item], list[ThoughtStep]]: + """Retrieve relevant items from the database and build a context for the chat model.""" # Retrieve relevant items from the database results = await self.searcher.search_and_embed( - original_user_query, top=top, enable_vector_search=vector_search, enable_text_search=text_search + chat_params.original_user_query, + top=chat_params.top, + enable_vector_search=chat_params.enable_vector_search, + enable_text_search=chat_params.enable_text_search, ) sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results] content = "\n".join(sources_content) # Generate a contextual and content specific answer using the search results and chat history - response_token_limit = 1024 contextual_messages: list[ChatCompletionMessageParam] = build_messages( model=self.chat_model, - system_prompt=overrides.get("prompt_template") or self.answer_prompt_template, - new_user_content=original_user_query + "\n\nSources:\n" + content, - past_messages=past_messages, - max_tokens=self.chat_token_limit - response_token_limit, + system_prompt=chat_params.prompt_template, + new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content, + past_messages=chat_params.past_messages, + max_tokens=self.chat_token_limit - chat_params.response_token_limit, fallback_to_default=True, ) + thoughts = [ + ThoughtStep( + title="Search query for database", + description=chat_params.original_user_query, + props={ + "top": chat_params.top, + "vector_search": chat_params.enable_vector_search, + "text_search": chat_params.enable_text_search, + }, + ), + ThoughtStep( + title="Search results", + description=[result.to_dict() for result in results], + ), + ] + return contextual_messages, results, thoughts + + async def answer( + self, + chat_params: ChatParams, + contextual_messages: list[ChatCompletionMessageParam], + results: list[Item], + earlier_thoughts: list[ThoughtStep], + ) -> RetrievalResponse: chat_completion_response: ChatCompletion = await self.openai_chat_client.chat.completions.create( # Azure OpenAI takes the deployment name as the model name model=self.chat_deployment if self.chat_deployment else self.chat_model, messages=contextual_messages, - temperature=overrides.get("temperature", 0.3), - max_tokens=response_token_limit, + temperature=chat_params.temperature, + max_tokens=chat_params.response_token_limit, n=1, stream=False, ) - first_choice_message = chat_completion_response.choices[0].message return RetrievalResponse( - message=Message(content=str(first_choice_message.content), role=first_choice_message.role), + message=Message( + content=str(chat_completion_response.choices[0].message.content), role=AIChatRoles.ASSISTANT + ), context=RAGContext( data_points={item.id: item.to_dict() for item in results}, - thoughts=[ - ThoughtStep( - title="Search query for database", - description=original_user_query if text_search else None, - props={ - "top": top, - "vector_search": vector_search, - "text_search": text_search, - }, - ), + thoughts=earlier_thoughts + + [ ThoughtStep( - title="Search results", - description=[result.to_dict() for result in results], + title="Prompt to generate answer", + description=[str(message) for message in contextual_messages], + props=( + {"model": self.chat_model, "deployment": self.chat_deployment} + if self.chat_deployment + else {"model": self.chat_model} + ), ), + ], + ), + ) + + async def answer_stream( + self, + chat_params: ChatParams, + contextual_messages: list[ChatCompletionMessageParam], + results: list[Item], + earlier_thoughts: list[ThoughtStep], + ) -> AsyncGenerator[RetrievalResponseDelta, None]: + chat_completion_async_stream: AsyncStream[ + ChatCompletionChunk + ] = await self.openai_chat_client.chat.completions.create( + # Azure OpenAI takes the deployment name as the model name + model=self.chat_deployment if self.chat_deployment else self.chat_model, + messages=contextual_messages, + temperature=chat_params.temperature, + max_tokens=chat_params.response_token_limit, + n=1, + stream=True, + ) + + yield RetrievalResponseDelta( + context=RAGContext( + data_points={item.id: item.to_dict() for item in results}, + thoughts=earlier_thoughts + + [ ThoughtStep( title="Prompt to generate answer", description=[str(message) for message in contextual_messages], @@ -99,3 +149,10 @@ async def run( ], ), ) + async for response_chunk in chat_completion_async_stream: + # first response has empty choices and last response has empty content + if response_chunk.choices and response_chunk.choices[0].delta.content: + yield RetrievalResponseDelta( + delta=Message(content=str(response_chunk.choices[0].delta.content), role=AIChatRoles.ASSISTANT) + ) + return diff --git a/src/backend/fastapi_app/routes/api_routes.py b/src/backend/fastapi_app/routes/api_routes.py index b0f02189..44e08320 100644 --- a/src/backend/fastapi_app/routes/api_routes.py +++ b/src/backend/fastapi_app/routes/api_routes.py @@ -1,8 +1,19 @@ +import json +import logging +from collections.abc import AsyncGenerator + import fastapi from fastapi import HTTPException +from fastapi.responses import StreamingResponse from sqlalchemy import select -from fastapi_app.api_models import ChatRequest, ItemPublic, ItemWithDistance, RetrievalResponse +from fastapi_app.api_models import ( + ChatRequest, + ItemPublic, + ItemWithDistance, + RetrievalResponse, + RetrievalResponseDelta, +) from fastapi_app.dependencies import ChatClient, CommonDeps, DBSession, EmbeddingsClient from fastapi_app.postgres_models import Item from fastapi_app.postgres_searcher import PostgresSearcher @@ -12,6 +23,18 @@ router = fastapi.APIRouter() +async def format_as_ndjson(r: AsyncGenerator[RetrievalResponseDelta, None]) -> AsyncGenerator[str, None]: + """ + Format the response as NDJSON + """ + try: + async for event in r: + yield event.model_dump_json() + "\n" + except Exception as error: + logging.exception("Exception while generating response stream: %s", error) + yield json.dumps({"error": str(error)}, ensure_ascii=False) + "\n" + + @router.get("/items/{id}", response_model=ItemPublic) async def item_handler(database_session: DBSession, id: int) -> ItemPublic: """A simple API to get an item by ID.""" @@ -70,8 +93,6 @@ async def chat_handler( openai_chat: ChatClient, chat_request: ChatRequest, ): - overrides = chat_request.context.get("overrides", {}) - searcher = PostgresSearcher( db_session=database_session, openai_embed_client=openai_embed.client, @@ -79,20 +100,70 @@ async def chat_handler( embed_model=context.openai_embed_model, embed_dimensions=context.openai_embed_dimensions, ) - if overrides.get("use_advanced_flow"): - run_ragchat = AdvancedRAGChat( + rag_flow: SimpleRAGChat | AdvancedRAGChat + if chat_request.context.overrides.use_advanced_flow: + rag_flow = AdvancedRAGChat( searcher=searcher, openai_chat_client=openai_chat.client, chat_model=context.openai_chat_model, chat_deployment=context.openai_chat_deployment, - ).run + ) else: - run_ragchat = SimpleRAGChat( + rag_flow = SimpleRAGChat( searcher=searcher, openai_chat_client=openai_chat.client, chat_model=context.openai_chat_model, chat_deployment=context.openai_chat_deployment, - ).run + ) + + chat_params = rag_flow.get_params(chat_request.messages, chat_request.context.overrides) - response = await run_ragchat(chat_request.messages, overrides=overrides) + contextual_messages, results, thoughts = await rag_flow.prepare_context(chat_params) + response = await rag_flow.answer( + chat_params=chat_params, contextual_messages=contextual_messages, results=results, earlier_thoughts=thoughts + ) return response + + +@router.post("/chat/stream") +async def chat_stream_handler( + context: CommonDeps, + database_session: DBSession, + openai_embed: EmbeddingsClient, + openai_chat: ChatClient, + chat_request: ChatRequest, +): + searcher = PostgresSearcher( + db_session=database_session, + openai_embed_client=openai_embed.client, + embed_deployment=context.openai_embed_deployment, + embed_model=context.openai_embed_model, + embed_dimensions=context.openai_embed_dimensions, + ) + + rag_flow: SimpleRAGChat | AdvancedRAGChat + if chat_request.context.overrides.use_advanced_flow: + rag_flow = AdvancedRAGChat( + searcher=searcher, + openai_chat_client=openai_chat.client, + chat_model=context.openai_chat_model, + chat_deployment=context.openai_chat_deployment, + ) + else: + rag_flow = SimpleRAGChat( + searcher=searcher, + openai_chat_client=openai_chat.client, + chat_model=context.openai_chat_model, + chat_deployment=context.openai_chat_deployment, + ) + + chat_params = rag_flow.get_params(chat_request.messages, chat_request.context.overrides) + + # Intentionally do this before we stream down a response, to avoid using database connections during stream + # See https://github.com/tiangolo/fastapi/discussions/11321 + contextual_messages, results, thoughts = await rag_flow.prepare_context(chat_params) + + result = rag_flow.answer_stream( + chat_params=chat_params, contextual_messages=contextual_messages, results=results, earlier_thoughts=thoughts + ) + return StreamingResponse(content=format_as_ndjson(result), media_type="application/x-ndjson") diff --git a/src/frontend/src/api/models.ts b/src/frontend/src/api/models.ts index deee7b68..4e9c3e26 100644 --- a/src/frontend/src/api/models.ts +++ b/src/frontend/src/api/models.ts @@ -1,4 +1,4 @@ -import { AIChatCompletion } from "@microsoft/ai-chat-protocol"; +import { AIChatCompletion, AIChatCompletionDelta, AIChatCompletionOperationOptions } from "@microsoft/ai-chat-protocol"; export const enum RetrievalMode { Hybrid = "hybrid", @@ -14,6 +14,14 @@ export type ChatAppRequestOverrides = { prompt_template?: string; }; +export type ChatAppRequestContext = { + overrides: ChatAppRequestOverrides; +}; + +export interface ChatAppRequestOptions extends AIChatCompletionOperationOptions { + context: ChatAppRequestContext +} + export type Thoughts = { title: string; description: any; // It can be any output from the api @@ -29,3 +37,7 @@ export type RAGContext = { export interface RAGChatCompletion extends AIChatCompletion { context: RAGContext; } + +export interface RAGChatCompletionDelta extends AIChatCompletionDelta { + context: RAGContext; +} diff --git a/src/frontend/src/pages/chat/Chat.tsx b/src/frontend/src/pages/chat/Chat.tsx index 6918cf76..da0b6934 100644 --- a/src/frontend/src/pages/chat/Chat.tsx +++ b/src/frontend/src/pages/chat/Chat.tsx @@ -1,11 +1,11 @@ import { useRef, useState, useEffect } from "react"; import { Panel, DefaultButton, TextField, SpinButton, Slider, Checkbox } from "@fluentui/react"; import { SparkleFilled } from "@fluentui/react-icons"; -import { AIChatMessage, AIChatProtocolClient } from "@microsoft/ai-chat-protocol"; import styles from "./Chat.module.css"; -import {RetrievalMode, RAGChatCompletion} from "../../api"; +import { RetrievalMode, RAGChatCompletion, RAGChatCompletionDelta, ChatAppRequestOptions } from "../../api"; +import { AIChatProtocolClient, AIChatMessage } from "@microsoft/ai-chat-protocol"; import { Answer, AnswerError, AnswerLoading } from "../../components/Answer"; import { QuestionInput } from "../../components/QuestionInput"; import { ExampleList } from "../../components/Example"; @@ -22,11 +22,13 @@ const Chat = () => { const [retrieveCount, setRetrieveCount] = useState(3); const [retrievalMode, setRetrievalMode] = useState(RetrievalMode.Hybrid); const [useAdvancedFlow, setUseAdvancedFlow] = useState(true); + const [shouldStream, setShouldStream] = useState(true); const lastQuestionRef = useRef(""); const chatMessageStreamEnd = useRef(null); const [isLoading, setIsLoading] = useState(false); + const [isStreaming, setIsStreaming] = useState(false); const [error, setError] = useState(); const [activeCitation, setActiveCitation] = useState(); @@ -34,7 +36,55 @@ const Chat = () => { const [selectedAnswer, setSelectedAnswer] = useState(0); const [answers, setAnswers] = useState<[user: string, response: RAGChatCompletion][]>([]); - + const [streamedAnswers, setStreamedAnswers] = useState<[user: string, response: RAGChatCompletion][]>([]); + + const handleAsyncRequest = async (question: string, answers: [string, RAGChatCompletion][], result: AsyncIterable) => { + let answer = ""; + let chatCompletion: RAGChatCompletion = { + context: { + data_points: {}, + followup_questions: null, + thoughts: [] + }, + message: { content: "", role: "assistant" } + }; + const updateState = (newContent: string) => { + return new Promise(resolve => { + setTimeout(() => { + answer += newContent; + // We need to create a new object to trigger a re-render + const latestCompletion: RAGChatCompletion = { + ...chatCompletion, + message: { content: answer, role: chatCompletion.message.role } + }; + setStreamedAnswers([...answers, [question, latestCompletion]]); + resolve(null); + }, 33); + }); + }; + try { + setIsStreaming(true); + for await (const response of result) { + if (response.context) { + chatCompletion.context = { + ...chatCompletion.context, + ...response.context + }; + } + if (response.delta && response.delta.role) { + chatCompletion.message.role = response.delta.role; + } + if (response.delta && response.delta.content) { + setIsLoading(false); + await updateState(response.delta.content); + } + } + } finally { + setIsStreaming(false); + } + chatCompletion.message.content = answer; + return chatCompletion; + }; const makeApiRequest = async (question: string) => { lastQuestionRef.current = question; @@ -49,7 +99,7 @@ const Chat = () => { { content: answer[1].message.content, role: "assistant" } ]); const allMessages: AIChatMessage[] = [...messages, { content: question, role: "user" }]; - const options = { + const options: ChatAppRequestOptions = { context: { overrides: { use_advanced_flow: useAdvancedFlow, @@ -61,8 +111,14 @@ const Chat = () => { } }; const chatClient: AIChatProtocolClient = new AIChatProtocolClient("/chat"); - const result = await chatClient.getCompletion(allMessages, options) as RAGChatCompletion; - setAnswers([...answers, [question, result]]); + if (shouldStream) { + const result = (await chatClient.getStreamedCompletion(allMessages, options)) as AsyncIterable; + const parsedResponse = await handleAsyncRequest(question, answers, result); + setAnswers([...answers, [question, parsedResponse]]); + } else { + const result = (await chatClient.getCompletion(allMessages, options)) as RAGChatCompletion; + setAnswers([...answers, [question, result]]); + } } catch (e) { setError(e); } finally { @@ -76,10 +132,13 @@ const Chat = () => { setActiveCitation(undefined); setActiveAnalysisPanelTab(undefined); setAnswers([]); + setStreamedAnswers([]); setIsLoading(false); + setIsStreaming(false); }; useEffect(() => chatMessageStreamEnd.current?.scrollIntoView({ behavior: "smooth" }), [isLoading]); + useEffect(() => chatMessageStreamEnd.current?.scrollIntoView({ behavior: "auto" }), [streamedAnswers]); const onPromptTemplateChange = (_ev?: React.FormEvent, newValue?: string) => { setPromptTemplate(newValue || ""); @@ -99,7 +158,11 @@ const Chat = () => { const onUseAdvancedFlowChange = (_ev?: React.FormEvent, checked?: boolean) => { setUseAdvancedFlow(!!checked); - } + }; + + const onShouldStreamChange = (_ev?: React.FormEvent, checked?: boolean) => { + setShouldStream(!!checked); + }; const onExampleClicked = (example: string) => { makeApiRequest(example); @@ -143,7 +206,26 @@ const Chat = () => { ) : (
- {answers.map((answer, index) => ( + {isStreaming && + streamedAnswers.map((streamedAnswer, index) => ( +
+ +
+ onShowCitation(c, index)} + onThoughtProcessClicked={() => onToggleTab(AnalysisPanelTabs.ThoughtProcessTab, index)} + onSupportingContentClicked={() => onToggleTab(AnalysisPanelTabs.SupportingContentTab, index)} + onFollowupQuestionClicked={q => makeApiRequest(q)} + /> +
+
+ ))} + {!isStreaming && + answers.map((answer, index) => (
@@ -210,7 +292,6 @@ const Chat = () => { onRenderFooterContent={() => setIsConfigPanelOpen(false)}>Close} isFooterAtBottom={true} > - { onChange={onRetrieveCountChange} /> - setRetrievalMode(retrievalMode)} - /> - + setRetrievalMode(retrievalMode)} />

Settings for final chat completion:

@@ -257,6 +335,12 @@ const Chat = () => { snapToStep /> +
diff --git a/tests/snapshots/test_api_routes/test_advanced_chat_flow/advanced_chat_flow_response.json b/tests/snapshots/test_api_routes/test_advanced_chat_flow/advanced_chat_flow_response.json new file mode 100644 index 00000000..2e9eb3ae --- /dev/null +++ b/tests/snapshots/test_api_routes/test_advanced_chat_flow/advanced_chat_flow_response.json @@ -0,0 +1,68 @@ +{ + "message": { + "content": "The capital of France is Paris. [Benefit_Options-2.pdf].", + "role": "assistant" + }, + "context": { + "data_points": { + "1": { + "id": 1, + "type": "Footwear", + "brand": "Daybird", + "name": "Wanderer Black Hiking Boots", + "description": "Daybird's Wanderer Hiking Boots in sleek black are perfect for all your outdoor adventures. These boots are made with a waterproof leather upper and a durable rubber sole for superior traction. With their cushioned insole and padded collar, these boots will keep you comfortable all day long.", + "price": 109.99 + } + }, + "thoughts": [ + { + "title": "Prompt to generate search arguments", + "description": [ + "{'role': 'system', 'content': 'Below is a history of the conversation so far, and a new question asked by the user that needs to be answered by searching database rows.\\nYou have access to an Azure PostgreSQL database with an items table that has columns for title, description, brand, price, and type.\\nGenerate a search query based on the conversation and the new question.\\nIf the question is not in English, translate the question to English before generating the search query.\\nIf you cannot generate a search query, return the original user question.\\nDO NOT return anything besides the query.'}", + "{'role': 'user', 'content': 'What is the capital of France?'}" + ], + "props": { + "model": "gpt-35-turbo", + "deployment": "gpt-35-turbo" + } + }, + { + "title": "Search using generated search arguments", + "description": "The capital of France is Paris. [Benefit_Options-2.pdf].", + "props": { + "top": 1, + "vector_search": true, + "text_search": true, + "filters": [] + } + }, + { + "title": "Search results", + "description": [ + { + "id": 1, + "type": "Footwear", + "brand": "Daybird", + "name": "Wanderer Black Hiking Boots", + "description": "Daybird's Wanderer Hiking Boots in sleek black are perfect for all your outdoor adventures. These boots are made with a waterproof leather upper and a durable rubber sole for superior traction. With their cushioned insole and padded collar, these boots will keep you comfortable all day long.", + "price": 109.99 + } + ], + "props": {} + }, + { + "title": "Prompt to generate answer", + "description": [ + "{'role': 'system', 'content': \"Assistant helps customers with questions about products.\\nRespond as if you are a salesperson helping a customer in a store. Do NOT respond with tables.\\nAnswer ONLY with the product details listed in the products.\\nIf there isn't enough information below, say you don't know.\\nDo not generate answers that don't use the sources below.\\nEach product has an ID in brackets followed by colon and the product details.\\nAlways include the product ID for each product you use in the response.\\nUse square brackets to reference the source, for example [52].\\nDon't combine citations, list each product separately, for example [27][51].\"}", + "{'role': 'user', 'content': \"What is the capital of France?\\n\\nSources:\\n[1]:Name:Wanderer Black Hiking Boots Description:Daybird's Wanderer Hiking Boots in sleek black are perfect for all your outdoor adventures. These boots are made with a waterproof leather upper and a durable rubber sole for superior traction. With their cushioned insole and padded collar, these boots will keep you comfortable all day long. Price:109.99 Brand:Daybird Type:Footwear\\n\\n\"}" + ], + "props": { + "model": "gpt-35-turbo", + "deployment": "gpt-35-turbo" + } + } + ], + "followup_questions": null + }, + "session_state": null +} \ No newline at end of file diff --git a/tests/snapshots/test_api_routes/test_advanced_chat_streaming_flow/advanced_chat_streaming_flow_response.jsonlines b/tests/snapshots/test_api_routes/test_advanced_chat_streaming_flow/advanced_chat_streaming_flow_response.jsonlines new file mode 100644 index 00000000..8b65342f --- /dev/null +++ b/tests/snapshots/test_api_routes/test_advanced_chat_streaming_flow/advanced_chat_streaming_flow_response.jsonlines @@ -0,0 +1,2 @@ +{"delta":null,"context":{"data_points":{"1":{"id":1,"type":"Footwear","brand":"Daybird","name":"Wanderer Black Hiking Boots","description":"Daybird's Wanderer Hiking Boots in sleek black are perfect for all your outdoor adventures. These boots are made with a waterproof leather upper and a durable rubber sole for superior traction. With their cushioned insole and padded collar, these boots will keep you comfortable all day long.","price":109.99}},"thoughts":[{"title":"Prompt to generate search arguments","description":["{'role': 'system', 'content': 'Below is a history of the conversation so far, and a new question asked by the user that needs to be answered by searching database rows.\\nYou have access to an Azure PostgreSQL database with an items table that has columns for title, description, brand, price, and type.\\nGenerate a search query based on the conversation and the new question.\\nIf the question is not in English, translate the question to English before generating the search query.\\nIf you cannot generate a search query, return the original user question.\\nDO NOT return anything besides the query.'}","{'role': 'user', 'content': 'What is the capital of France?'}"],"props":{"model":"gpt-35-turbo","deployment":"gpt-35-turbo"}},{"title":"Search using generated search arguments","description":"The capital of France is Paris. [Benefit_Options-2.pdf].","props":{"top":1,"vector_search":true,"text_search":true,"filters":[]}},{"title":"Search results","description":[{"id":1,"type":"Footwear","brand":"Daybird","name":"Wanderer Black Hiking Boots","description":"Daybird's Wanderer Hiking Boots in sleek black are perfect for all your outdoor adventures. These boots are made with a waterproof leather upper and a durable rubber sole for superior traction. With their cushioned insole and padded collar, these boots will keep you comfortable all day long.","price":109.99}],"props":{}},{"title":"Prompt to generate answer","description":["{'role': 'system', 'content': \"Assistant helps customers with questions about products.\\nRespond as if you are a salesperson helping a customer in a store. Do NOT respond with tables.\\nAnswer ONLY with the product details listed in the products.\\nIf there isn't enough information below, say you don't know.\\nDo not generate answers that don't use the sources below.\\nEach product has an ID in brackets followed by colon and the product details.\\nAlways include the product ID for each product you use in the response.\\nUse square brackets to reference the source, for example [52].\\nDon't combine citations, list each product separately, for example [27][51].\"}","{'role': 'user', 'content': \"What is the capital of France?\\n\\nSources:\\n[1]:Name:Wanderer Black Hiking Boots Description:Daybird's Wanderer Hiking Boots in sleek black are perfect for all your outdoor adventures. These boots are made with a waterproof leather upper and a durable rubber sole for superior traction. With their cushioned insole and padded collar, these boots will keep you comfortable all day long. Price:109.99 Brand:Daybird Type:Footwear\\n\\n\"}"],"props":{"model":"gpt-35-turbo","deployment":"gpt-35-turbo"}}],"followup_questions":null},"session_state":null} +{"delta":{"content":"The capital of France is Paris. [Benefit_Options-2.pdf].","role":"assistant"},"context":null,"session_state":null} diff --git a/tests/snapshots/test_api_routes/test_simple_chat_flow/simple_chat_flow_response.json b/tests/snapshots/test_api_routes/test_simple_chat_flow/simple_chat_flow_response.json new file mode 100644 index 00000000..d5ecba21 --- /dev/null +++ b/tests/snapshots/test_api_routes/test_simple_chat_flow/simple_chat_flow_response.json @@ -0,0 +1,56 @@ +{ + "message": { + "content": "The capital of France is Paris. [Benefit_Options-2.pdf].", + "role": "assistant" + }, + "context": { + "data_points": { + "1": { + "id": 1, + "type": "Footwear", + "brand": "Daybird", + "name": "Wanderer Black Hiking Boots", + "description": "Daybird's Wanderer Hiking Boots in sleek black are perfect for all your outdoor adventures. These boots are made with a waterproof leather upper and a durable rubber sole for superior traction. With their cushioned insole and padded collar, these boots will keep you comfortable all day long.", + "price": 109.99 + } + }, + "thoughts": [ + { + "title": "Search query for database", + "description": "What is the capital of France?", + "props": { + "top": 1, + "vector_search": true, + "text_search": true + } + }, + { + "title": "Search results", + "description": [ + { + "id": 1, + "type": "Footwear", + "brand": "Daybird", + "name": "Wanderer Black Hiking Boots", + "description": "Daybird's Wanderer Hiking Boots in sleek black are perfect for all your outdoor adventures. These boots are made with a waterproof leather upper and a durable rubber sole for superior traction. With their cushioned insole and padded collar, these boots will keep you comfortable all day long.", + "price": 109.99 + } + ], + "props": {} + }, + { + "title": "Prompt to generate answer", + "description": [ + "{'role': 'system', 'content': \"Assistant helps customers with questions about products.\\nRespond as if you are a salesperson helping a customer in a store. Do NOT respond with tables.\\nAnswer ONLY with the product details listed in the products.\\nIf there isn't enough information below, say you don't know.\\nDo not generate answers that don't use the sources below.\\nEach product has an ID in brackets followed by colon and the product details.\\nAlways include the product ID for each product you use in the response.\\nUse square brackets to reference the source, for example [52].\\nDon't combine citations, list each product separately, for example [27][51].\"}", + "{'role': 'user', 'content': \"What is the capital of France?\\n\\nSources:\\n[1]:Name:Wanderer Black Hiking Boots Description:Daybird's Wanderer Hiking Boots in sleek black are perfect for all your outdoor adventures. These boots are made with a waterproof leather upper and a durable rubber sole for superior traction. With their cushioned insole and padded collar, these boots will keep you comfortable all day long. Price:109.99 Brand:Daybird Type:Footwear\\n\\n\"}" + ], + "props": { + "model": "gpt-35-turbo", + "deployment": "gpt-35-turbo" + } + } + ], + "followup_questions": null + }, + "session_state": null +} \ No newline at end of file diff --git a/tests/snapshots/test_api_routes/test_simple_chat_streaming_flow/simple_chat_streaming_flow_response.jsonlines b/tests/snapshots/test_api_routes/test_simple_chat_streaming_flow/simple_chat_streaming_flow_response.jsonlines new file mode 100644 index 00000000..6251bd52 --- /dev/null +++ b/tests/snapshots/test_api_routes/test_simple_chat_streaming_flow/simple_chat_streaming_flow_response.jsonlines @@ -0,0 +1,2 @@ +{"delta":null,"context":{"data_points":{"1":{"id":1,"type":"Footwear","brand":"Daybird","name":"Wanderer Black Hiking Boots","description":"Daybird's Wanderer Hiking Boots in sleek black are perfect for all your outdoor adventures. These boots are made with a waterproof leather upper and a durable rubber sole for superior traction. With their cushioned insole and padded collar, these boots will keep you comfortable all day long.","price":109.99}},"thoughts":[{"title":"Search query for database","description":"What is the capital of France?","props":{"top":1,"vector_search":true,"text_search":true}},{"title":"Search results","description":[{"id":1,"type":"Footwear","brand":"Daybird","name":"Wanderer Black Hiking Boots","description":"Daybird's Wanderer Hiking Boots in sleek black are perfect for all your outdoor adventures. These boots are made with a waterproof leather upper and a durable rubber sole for superior traction. With their cushioned insole and padded collar, these boots will keep you comfortable all day long.","price":109.99}],"props":{}},{"title":"Prompt to generate answer","description":["{'role': 'system', 'content': \"Assistant helps customers with questions about products.\\nRespond as if you are a salesperson helping a customer in a store. Do NOT respond with tables.\\nAnswer ONLY with the product details listed in the products.\\nIf there isn't enough information below, say you don't know.\\nDo not generate answers that don't use the sources below.\\nEach product has an ID in brackets followed by colon and the product details.\\nAlways include the product ID for each product you use in the response.\\nUse square brackets to reference the source, for example [52].\\nDon't combine citations, list each product separately, for example [27][51].\"}","{'role': 'user', 'content': \"What is the capital of France?\\n\\nSources:\\n[1]:Name:Wanderer Black Hiking Boots Description:Daybird's Wanderer Hiking Boots in sleek black are perfect for all your outdoor adventures. These boots are made with a waterproof leather upper and a durable rubber sole for superior traction. With their cushioned insole and padded collar, these boots will keep you comfortable all day long. Price:109.99 Brand:Daybird Type:Footwear\\n\\n\"}"],"props":{"model":"gpt-35-turbo","deployment":"gpt-35-turbo"}}],"followup_questions":null},"session_state":null} +{"delta":{"content":"The capital of France is Paris. [Benefit_Options-2.pdf].","role":"assistant"},"context":null,"session_state":null} diff --git a/tests/test_api_routes.py b/tests/test_api_routes.py index 98e3ceeb..cd221fbb 100644 --- a/tests/test_api_routes.py +++ b/tests/test_api_routes.py @@ -1,3 +1,5 @@ +import json + import pytest from tests.data import test_data @@ -105,7 +107,7 @@ async def test_search_handler_422(test_client): @pytest.mark.asyncio -async def test_simple_chat_flow(test_client): +async def test_simple_chat_flow(test_client, snapshot): """test the simple chat flow route with hybrid retrieval mode""" response = test_client.post( "/chat", @@ -120,115 +122,29 @@ async def test_simple_chat_flow(test_client): assert response.status_code == 200 assert response.headers["Content-Type"] == "application/json" - assert response_data["message"]["content"] == "The capital of France is Paris. [Benefit_Options-2.pdf]." - assert response_data["message"]["role"] == "assistant" - assert response_data["context"]["data_points"] == { - "1": { - "id": 1, - "name": "Wanderer Black Hiking Boots", - "description": "Daybird's Wanderer Hiking Boots in sleek black are perfect for all " - "your outdoor adventures. These boots are made with a waterproof " - "leather upper and a durable rubber sole for superior traction. With " - "their cushioned insole and padded collar, these boots will keep you " - "comfortable all day long.", - "brand": "Daybird", - "price": 109.99, - "type": "Footwear", - } - } - assert response_data["context"]["thoughts"] == [ - { - "description": "What is the capital of France?", - "props": {"text_search": True, "top": 1, "vector_search": True}, - "title": "Search query for database", - }, - { - "description": [ - { - "brand": "Daybird", - "description": "Daybird's Wanderer Hiking Boots in sleek black are perfect for all your " - "outdoor adventures. These boots are made with a waterproof leather upper and a durable " - "rubber sole for superior traction. With their cushioned insole and padded collar, " - "these boots will keep you comfortable all day long.", - "id": 1, - "name": "Wanderer Black Hiking Boots", - "price": 109.99, - "type": "Footwear", - }, - ], - "props": {}, - "title": "Search results", - }, - { - "description": [ - "{'role': 'system', 'content': \"Assistant helps customers with questions about " - "products.\\nRespond as if you are a salesperson helping a customer in a store. " - "Do NOT respond with tables.\\nAnswer ONLY with the product details listed in the " - "products.\\nIf there isn't enough information below, say you don't know.\\nDo not " - "generate answers that don't use the sources below.\\nEach product has an ID in brackets " - "followed by colon and the product details.\\nAlways include the product ID for each product " - "you use in the response.\\nUse square brackets to reference the source, " - "for example [52].\\nDon't combine citations, list each product separately, for example [27][51].\"}", - "{'role': 'user', 'content': \"What is the capital of France?\\n\\nSources:\\n[1]:Name:Wanderer " - "Black Hiking Boots Description:Daybird's Wanderer Hiking Boots in sleek black are perfect for " - "all your outdoor adventures. These boots are made with a waterproof leather upper and a durable " - "rubber sole for superior traction. With their cushioned insole and padded collar, " - "these boots will keep you comfortable all day long. Price:109.99 Brand:Daybird " - 'Type:Footwear\\n\\n"}', - ], - "props": {"deployment": "gpt-35-turbo", "model": "gpt-35-turbo"}, - "title": "Prompt to generate answer", - }, - ] - assert response_data["context"]["thoughts"] == [ - { - "description": "What is the capital of France?", - "props": {"text_search": True, "top": 1, "vector_search": True}, - "title": "Search query for database", - }, - { - "description": [ - { - "brand": "Daybird", - "description": "Daybird's Wanderer Hiking Boots in sleek black are perfect for all " - "your outdoor adventures. These boots are made with a waterproof leather upper and " - "a durable rubber sole for superior traction. With their cushioned insole and padded " - "collar, these boots will keep you comfortable all day long.", - "id": 1, - "name": "Wanderer Black Hiking Boots", - "price": 109.99, - "type": "Footwear", - } - ], - "props": {}, - "title": "Search results", - }, - { - "description": [ - "{'role': 'system', 'content': \"Assistant helps customers with questions about " - "products.\\nRespond as if you are a salesperson helping a customer in a store. " - "Do NOT respond with tables.\\nAnswer ONLY with the product details listed in the " - "products.\\nIf there isn't enough information below, say you don't know.\\nDo not " - "generate answers that don't use the sources below.\\nEach product has an ID in brackets " - "followed by colon and the product details.\\nAlways include the product ID for each product " - "you use in the response.\\nUse square brackets to reference the source, " - "for example [52].\\nDon't combine citations, list each product separately, for example [27][51].\"}", - "{'role': 'user', 'content': \"What is the capital of France?\\n\\nSources:\\n[1]:Name:Wanderer " - "Black Hiking Boots Description:Daybird's Wanderer Hiking Boots in sleek black are perfect for " - "all your outdoor adventures. These boots are made with a waterproof leather upper and a durable " - "rubber sole for superior traction. With their cushioned insole and padded collar, " - "these boots will keep you comfortable all day long. Price:109.99 Brand:Daybird " - 'Type:Footwear\\n\\n"}', - ], - "props": {"deployment": "gpt-35-turbo", "model": "gpt-35-turbo"}, - "title": "Prompt to generate answer", + snapshot.assert_match(json.dumps(response_data, indent=4), "simple_chat_flow_response.json") + + +@pytest.mark.asyncio +async def test_simple_chat_streaming_flow(test_client, snapshot): + """test the simple chat streaming flow route with hybrid retrieval mode""" + response = test_client.post( + "/chat/stream", + json={ + "context": { + "overrides": {"top": 1, "use_advanced_flow": False, "retrieval_mode": "hybrid", "temperature": 0.3} + }, + "messages": [{"content": "What is the capital of France?", "role": "user"}], }, - ] - assert response_data["session_state"] is None + ) + response_data = response.content + assert response.status_code == 200 + assert response.headers["Content-Type"] == "application/x-ndjson" + snapshot.assert_match(response_data, "simple_chat_streaming_flow_response.jsonlines") @pytest.mark.asyncio -async def test_advanced_chat_flow(test_client): +async def test_advanced_chat_flow(test_client, snapshot): """test the advanced chat flow route with hybrid retrieval mode""" response = test_client.post( "/chat", @@ -243,145 +159,25 @@ async def test_advanced_chat_flow(test_client): assert response.status_code == 200 assert response.headers["Content-Type"] == "application/json" - assert response_data["message"]["content"] == "The capital of France is Paris. [Benefit_Options-2.pdf]." - assert response_data["message"]["role"] == "assistant" - assert response_data["context"]["data_points"] == { - "1": { - "id": 1, - "name": "Wanderer Black Hiking Boots", - "description": "Daybird's Wanderer Hiking Boots in sleek black are perfect for all " - "your outdoor adventures. These boots are made with a waterproof " - "leather upper and a durable rubber sole for superior traction. With " - "their cushioned insole and padded collar, these boots will keep you " - "comfortable all day long.", - "brand": "Daybird", - "price": 109.99, - "type": "Footwear", - } - } - assert response_data["context"]["thoughts"] == [ - { - "description": [ - "{'role': 'system', 'content': 'Below is a history of the " - "conversation so far, and a new question asked by the user that " - "needs to be answered by searching database rows.\\nYou have " - "access to an Azure PostgreSQL database with an items table that " - "has columns for title, description, brand, price, and " - "type.\\nGenerate a search query based on the conversation and the " - "new question.\\nIf the question is not in English, translate the " - "question to English before generating the search query.\\nIf you " - "cannot generate a search query, return the original user " - "question.\\nDO NOT return anything besides the query.'}", - "{'role': 'user', 'content': 'What is the capital of France?'}", - ], - "props": {"deployment": "gpt-35-turbo", "model": "gpt-35-turbo"}, - "title": "Prompt to generate search arguments", - }, - { - "description": "The capital of France is Paris. [Benefit_Options-2.pdf].", - "props": {"filters": [], "text_search": True, "top": 1, "vector_search": True}, - "title": "Search using generated search arguments", - }, - { - "description": [ - { - "brand": "Daybird", - "description": "Daybird's Wanderer Hiking Boots in sleek black are perfect for all your " - "outdoor adventures. These boots are made with a waterproof leather upper and a durable " - "rubber sole for superior traction. With their cushioned insole and padded collar, " - "these boots will keep you comfortable all day long.", - "id": 1, - "name": "Wanderer Black Hiking Boots", - "price": 109.99, - "type": "Footwear", - }, - ], - "props": {}, - "title": "Search results", - }, - { - "description": [ - "{'role': 'system', 'content': \"Assistant helps customers with questions about " - "products.\\nRespond as if you are a salesperson helping a customer in a store. " - "Do NOT respond with tables.\\nAnswer ONLY with the product details listed in the " - "products.\\nIf there isn't enough information below, say you don't know.\\nDo not " - "generate answers that don't use the sources below.\\nEach product has an ID in brackets " - "followed by colon and the product details.\\nAlways include the product ID for each product " - "you use in the response.\\nUse square brackets to reference the source, " - "for example [52].\\nDon't combine citations, list each product separately, for example [27][51].\"}", - "{'role': 'user', 'content': \"What is the capital of France?\\n\\nSources:\\n[1]:Name:Wanderer " - "Black Hiking Boots Description:Daybird's Wanderer Hiking Boots in sleek black are perfect for " - "all your outdoor adventures. These boots are made with a waterproof leather upper and a durable " - "rubber sole for superior traction. With their cushioned insole and padded collar, " - "these boots will keep you comfortable all day long. Price:109.99 Brand:Daybird " - 'Type:Footwear\\n\\n"}', - ], - "props": {"deployment": "gpt-35-turbo", "model": "gpt-35-turbo"}, - "title": "Prompt to generate answer", - }, - ] - assert response_data["context"]["thoughts"] == [ - { - "description": [ - "{'role': 'system', 'content': 'Below is a history of the " - "conversation so far, and a new question asked by the user that " - "needs to be answered by searching database rows.\\nYou have " - "access to an Azure PostgreSQL database with an items table that " - "has columns for title, description, brand, price, and " - "type.\\nGenerate a search query based on the conversation and the " - "new question.\\nIf the question is not in English, translate the " - "question to English before generating the search query.\\nIf you " - "cannot generate a search query, return the original user " - "question.\\nDO NOT return anything besides the query.'}", - "{'role': 'user', 'content': 'What is the capital of France?'}", - ], - "props": {"deployment": "gpt-35-turbo", "model": "gpt-35-turbo"}, - "title": "Prompt to generate search arguments", - }, - { - "description": "The capital of France is Paris. [Benefit_Options-2.pdf].", - "props": {"filters": [], "text_search": True, "top": 1, "vector_search": True}, - "title": "Search using generated search arguments", - }, - { - "description": [ - { - "brand": "Daybird", - "description": "Daybird's Wanderer Hiking Boots in sleek black are perfect for all " - "your outdoor adventures. These boots are made with a waterproof leather upper and " - "a durable rubber sole for superior traction. With their cushioned insole and padded " - "collar, these boots will keep you comfortable all day long.", - "id": 1, - "name": "Wanderer Black Hiking Boots", - "price": 109.99, - "type": "Footwear", - } - ], - "props": {}, - "title": "Search results", - }, - { - "description": [ - "{'role': 'system', 'content': \"Assistant helps customers with questions about " - "products.\\nRespond as if you are a salesperson helping a customer in a store. " - "Do NOT respond with tables.\\nAnswer ONLY with the product details listed in the " - "products.\\nIf there isn't enough information below, say you don't know.\\nDo not " - "generate answers that don't use the sources below.\\nEach product has an ID in brackets " - "followed by colon and the product details.\\nAlways include the product ID for each product " - "you use in the response.\\nUse square brackets to reference the source, " - "for example [52].\\nDon't combine citations, list each product separately, for example [27][51].\"}", - "{'role': 'user', 'content': \"What is the capital of France?\\n\\nSources:\\n[1]:Name:Wanderer " - "Black Hiking Boots Description:Daybird's Wanderer Hiking Boots in sleek black are perfect for " - "all your outdoor adventures. These boots are made with a waterproof leather upper and a durable " - "rubber sole for superior traction. With their cushioned insole and padded collar, " - "these boots will keep you comfortable all day long. Price:109.99 Brand:Daybird " - 'Type:Footwear\\n\\n"}', - ], - "props": {"deployment": "gpt-35-turbo", "model": "gpt-35-turbo"}, - "title": "Prompt to generate answer", + snapshot.assert_match(json.dumps(response_data, indent=4), "advanced_chat_flow_response.json") + + +@pytest.mark.asyncio +async def test_advanced_chat_streaming_flow(test_client, snapshot): + """test the advanced chat streaming flow route with hybrid retrieval mode""" + response = test_client.post( + "/chat/stream", + json={ + "context": { + "overrides": {"top": 1, "use_advanced_flow": True, "retrieval_mode": "hybrid", "temperature": 0.3} + }, + "messages": [{"content": "What is the capital of France?", "role": "user"}], }, - ] - assert response_data["session_state"] is None + ) + response_data = response.content + assert response.status_code == 200 + assert response.headers["Content-Type"] == "application/x-ndjson" + snapshot.assert_match(response_data, "advanced_chat_streaming_flow_response.jsonlines") @pytest.mark.asyncio