diff --git a/src/fastapi_app/query_rewriter.py b/src/fastapi_app/query_rewriter.py index 9cf4fffe..6274ca50 100644 --- a/src/fastapi_app/query_rewriter.py +++ b/src/fastapi_app/query_rewriter.py @@ -56,7 +56,7 @@ def build_search_function() -> list[ChatCompletionToolParam]: ] -def extract_search_arguments(chat_completion: ChatCompletion): +def extract_search_arguments(original_user_query: str, chat_completion: ChatCompletion): response_message = chat_completion.choices[0].message search_query = None filters = [] @@ -67,7 +67,8 @@ def extract_search_arguments(chat_completion: ChatCompletion): function = tool.function if function.name == "search_database": arg = json.loads(function.arguments) - search_query = arg.get("search_query") + # Even though its required, search_query is not always specified + search_query = arg.get("search_query", original_user_query) if "price_filter" in arg and arg["price_filter"]: price_filter = arg["price_filter"] filters.append( diff --git a/src/fastapi_app/rag_advanced.py b/src/fastapi_app/rag_advanced.py index 00e96bfd..d603d997 100644 --- a/src/fastapi_app/rag_advanced.py +++ b/src/fastapi_app/rag_advanced.py @@ -65,7 +65,7 @@ async def run( tool_choice="auto", ) - query_text, filters = extract_search_arguments(chat_completion) + query_text, filters = extract_search_arguments(original_user_query, chat_completion) # Retrieve relevant items from the database with the GPT optimized query results = await self.searcher.search_and_embed(