
Port to Pydantic-AI #206

Merged · 12 commits · May 8, 2025
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -14,3 +14,4 @@ pytest-snapshot
locust
psycopg2
dotenv-azd
freezegun
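freezegun is presumably added so the test suite can pin "now" when asserting on the timestamped Pydantic-AI messages introduced below. A minimal sketch of that usage; the test itself is illustrative, not from this PR:

from datetime import datetime, timezone

from freezegun import freeze_time

# Illustrative only: freeze_time pins the clock so snapshots that embed
# message timestamps stay deterministic.
@freeze_time("2025-05-07 19:02:46")
def test_clock_is_frozen():
    assert datetime.now(timezone.utc) == datetime(2025, 5, 7, 19, 2, 46, tzinfo=timezone.utc)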
9 changes: 8 additions & 1 deletion src/backend/fastapi_app/__init__.py
@@ -34,7 +34,13 @@ class State(TypedDict):
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI) -> AsyncIterator[State]:
context = await common_parameters()
azure_credential = await get_azure_credential()
azure_credential = None
if (
os.getenv("OPENAI_CHAT_HOST") == "azure"
or os.getenv("OPENAI_EMBED_HOST") == "azure"
or os.getenv("POSTGRES_HOST", "").endswith(".database.azure.com")
):
azure_credential = await get_azure_credential()
engine = await create_postgres_engine_from_env(azure_credential)
sessionmaker = await create_async_sessionmaker(engine)
chat_client = await create_openai_chat_client(azure_credential)
@@ -53,6 +59,7 @@ def create_app(testing: bool = False):
if not testing:
load_dotenv(override=True)
logging.basicConfig(level=logging.INFO)

# Turn off particularly noisy INFO level logs from Azure Core SDK:
logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(logging.WARNING)
logging.getLogger("azure.identity").setLevel(logging.WARNING)
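For readers skimming the hunk above: the Azure credential is now created lazily, only when some dependency is actually Azure-hosted. Restated as a standalone predicate purely for illustration (the helper name is an assumption, not code from this PR):

import os


def needs_azure_credential() -> bool:
    # Mirrors the condition added to lifespan(): an Azure Identity credential
    # is only needed when chat, embeddings, or Postgres run on Azure.
    return (
        os.getenv("OPENAI_CHAT_HOST") == "azure"
        or os.getenv("OPENAI_EMBED_HOST") == "azure"
        or os.getenv("POSTGRES_HOST", "").endswith(".database.azure.com")
    )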
75 changes: 54 additions & 21 deletions src/backend/fastapi_app/api_models.py
@@ -1,8 +1,9 @@
from enum import Enum
from typing import Any, Optional
from typing import Any, Optional, Union

from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel
from pydantic import BaseModel, Field
from pydantic_ai.messages import ModelRequest, ModelResponse


class AIChatRoles(str, Enum):
@@ -41,14 +42,34 @@ class ChatRequest(BaseModel):
sessionState: Optional[Any] = None


class ItemPublic(BaseModel):
id: int
type: str
brand: str
name: str
description: str
price: float

def to_str_for_rag(self):
return f"Name:{self.name} Description:{self.description} Price:{self.price} Brand:{self.brand} Type:{self.type}"


class ItemWithDistance(ItemPublic):
distance: float

def __init__(self, **data):
super().__init__(**data)
self.distance = round(self.distance, 2)


class ThoughtStep(BaseModel):
title: str
description: Any
props: dict = {}


class RAGContext(BaseModel):
data_points: dict[int, dict[str, Any]]
data_points: dict[int, ItemPublic]
thoughts: list[ThoughtStep]
followup_questions: Optional[list[str]] = None

@@ -69,27 +90,39 @@ class RetrievalResponseDelta(BaseModel):
sessionState: Optional[Any] = None


class ItemPublic(BaseModel):
id: int
type: str
brand: str
name: str
description: str
price: float


class ItemWithDistance(ItemPublic):
distance: float

def __init__(self, **data):
super().__init__(**data)
self.distance = round(self.distance, 2)


class ChatParams(ChatRequestOverrides):
prompt_template: str
response_token_limit: int = 1024
enable_text_search: bool
enable_vector_search: bool
original_user_query: str
past_messages: list[ChatCompletionMessageParam]
past_messages: list[Union[ModelRequest, ModelResponse]]


class Filter(BaseModel):
column: str
comparison_operator: str
value: Any


class PriceFilter(Filter):
column: str = Field(default="price", description="The column to filter on (always 'price' for this filter)")
comparison_operator: str = Field(description="The operator for price comparison ('>', '<', '>=', '<=', '=')")
value: float = Field(description="The price value to compare against (e.g., 30.00)")


class BrandFilter(Filter):
column: str = Field(default="brand", description="The column to filter on (always 'brand' for this filter)")
comparison_operator: str = Field(description="The operator for brand comparison ('=' or '!=')")
value: str = Field(description="The brand name to compare against (e.g., 'AirStrider')")


class SearchResults(BaseModel):
query: str
"""The original search query"""

items: list[ItemPublic]
"""List of items that match the search query and filters"""

filters: list[Filter]
"""List of filters applied to the search results"""
14 changes: 9 additions & 5 deletions src/backend/fastapi_app/openai_clients.py
@@ -9,12 +9,12 @@


async def create_openai_chat_client(
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential],
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential, None],
) -> Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]:
openai_chat_client: Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]
OPENAI_CHAT_HOST = os.getenv("OPENAI_CHAT_HOST")
if OPENAI_CHAT_HOST == "azure":
api_version = os.environ["AZURE_OPENAI_VERSION"] or "2024-03-01-preview"
api_version = os.environ["AZURE_OPENAI_VERSION"] or "2024-10-21"
azure_endpoint = os.environ["AZURE_OPENAI_ENDPOINT"]
azure_deployment = os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"]
if api_key := os.getenv("AZURE_OPENAI_KEY"):
@@ -29,7 +29,7 @@ async def create_openai_chat_client(
azure_deployment=azure_deployment,
api_key=api_key,
)
else:
elif azure_credential:
logger.info(
"Setting up Azure OpenAI client for chat completions using Azure Identity, endpoint %s, deployment %s",
azure_endpoint,
@@ -44,6 +44,8 @@
azure_deployment=azure_deployment,
azure_ad_token_provider=token_provider,
)
else:
raise ValueError("Azure OpenAI client requires either an API key or Azure Identity credential.")
elif OPENAI_CHAT_HOST == "ollama":
logger.info("Setting up OpenAI client for chat completions using Ollama")
openai_chat_client = openai.AsyncOpenAI(
@@ -67,7 +69,7 @@


async def create_openai_embed_client(
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential],
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential, None],
) -> Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]:
openai_embed_client: Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]
OPENAI_EMBED_HOST = os.getenv("OPENAI_EMBED_HOST")
@@ -87,7 +89,7 @@
azure_deployment=azure_deployment,
api_key=api_key,
)
else:
elif azure_credential:
logger.info(
"Setting up Azure OpenAI client for embeddings using Azure Identity, endpoint %s, deployment %s",
azure_endpoint,
@@ -102,6 +104,8 @@
azure_deployment=azure_deployment,
azure_ad_token_provider=token_provider,
)
else:
raise ValueError("Azure OpenAI client requires either an API key or Azure Identity credential.")
elif OPENAI_EMBED_HOST == "ollama":
logger.info("Setting up OpenAI client for embeddings using Ollama")
openai_embed_client = openai.AsyncOpenAI(
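With the credential parameter now Optional, the Azure branch fails fast instead of building an unauthenticated client. A sketch of the expected behavior under an assumed environment (the endpoint and deployment names are placeholders):

import asyncio
import os

from fastapi_app.openai_clients import create_openai_chat_client

async def main() -> None:
    os.environ["OPENAI_CHAT_HOST"] = "azure"
    os.environ["AZURE_OPENAI_VERSION"] = "2024-10-21"
    os.environ["AZURE_OPENAI_ENDPOINT"] = "https://example.openai.azure.com"
    os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"] = "gpt-4o-mini"
    os.environ.pop("AZURE_OPENAI_KEY", None)
    try:
        await create_openai_chat_client(None)  # no API key and no credential
    except ValueError as exc:
        print(exc)  # Azure OpenAI client requires either an API key or Azure Identity credential.

asyncio.run(main())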
16 changes: 10 additions & 6 deletions src/backend/fastapi_app/postgres_searcher.py
@@ -5,6 +5,7 @@
from sqlalchemy import Float, Integer, column, select, text
from sqlalchemy.ext.asyncio import AsyncSession

from fastapi_app.api_models import Filter
from fastapi_app.embeddings import compute_text_embedding
from fastapi_app.postgres_models import Item

@@ -26,21 +27,24 @@ def __init__(
self.embed_dimensions = embed_dimensions
self.embedding_column = embedding_column

def build_filter_clause(self, filters) -> tuple[str, str]:
def build_filter_clause(self, filters: Optional[list[Filter]]) -> tuple[str, str]:
if filters is None:
return "", ""
filter_clauses = []
for filter in filters:
if isinstance(filter["value"], str):
filter["value"] = f"'{filter['value']}'"
filter_clauses.append(f"{filter['column']} {filter['comparison_operator']} {filter['value']}")
filter_value = f"'{filter.value}'" if isinstance(filter.value, str) else filter.value
filter_clauses.append(f"{filter.column} {filter.comparison_operator} {filter_value}")
filter_clause = " AND ".join(filter_clauses)
if len(filter_clause) > 0:
return f"WHERE {filter_clause}", f"AND {filter_clause}"
return "", ""

async def search(
self, query_text: Optional[str], query_vector: list[float], top: int = 5, filters: Optional[list[dict]] = None
self,
query_text: Optional[str],
query_vector: list[float],
top: int = 5,
filters: Optional[list[Filter]] = None,
):
filter_clause_where, filter_clause_and = self.build_filter_clause(filters)
table_name = Item.__tablename__
@@ -106,7 +110,7 @@ async def search_and_embed(
top: int = 5,
enable_vector_search: bool = False,
enable_text_search: bool = False,
filters: Optional[list[dict]] = None,
filters: Optional[list[Filter]] = None,
) -> list[Item]:
"""
Search rows by query text. Optionally converts the query text to a vector if enable_vector_search is True.
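build_filter_clause now consumes typed Filter models instead of raw dicts. A sketch of the SQL fragments it produces, calling the method unbound purely for illustration since it does not touch instance state:

from fastapi_app.api_models import BrandFilter, PriceFilter
from fastapi_app.postgres_searcher import PostgresSearcher

filters = [
    PriceFilter(comparison_operator="<", value=50.0),
    BrandFilter(comparison_operator="=", value="AirStrider"),
]
where_clause, and_clause = PostgresSearcher.build_filter_clause(None, filters)
print(where_clause)  # WHERE price < 50.0 AND brand = 'AirStrider'
print(and_clause)    # AND price < 50.0 AND brand = 'AirStrider'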
11 changes: 5 additions & 6 deletions src/backend/fastapi_app/prompts/query.txt
@@ -1,6 +1,5 @@
Below is a history of the conversation so far, and a new question asked by the user that needs to be answered by searching database rows.
You have access to an Azure PostgreSQL database with an items table that has columns for title, description, brand, price, and type.
Generate a search query based on the conversation and the new question.
If the question is not in English, translate the question to English before generating the search query.
If you cannot generate a search query, return the original user question.
DO NOT return anything besides the query.
Your job is to find search results based off the user's question and past messages.
You have access to only these tools:
1. **search_database**: This tool allows you to search a table for items based on a query.
You can pass in a search query and optional filters.
Once you get the search results, you're done.
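The rewritten prompt assumes the model is driven by a Pydantic-AI agent that exposes a single search_database tool. The PR's actual agent wiring is not shown in this excerpt; the sketch below only illustrates how such a tool could pair with the prompt and the new filter models (the agent name, model id, and tool body are assumptions):

from pathlib import Path
from typing import Optional

from pydantic_ai import Agent

from fastapi_app.api_models import BrandFilter, PriceFilter, SearchResults

query_prompt = Path("src/backend/fastapi_app/prompts/query.txt").read_text()
query_agent = Agent("openai:gpt-4o-mini", system_prompt=query_prompt)

@query_agent.tool_plain
def search_database(
    search_query: str,
    price_filter: Optional[PriceFilter] = None,
    brand_filter: Optional[BrandFilter] = None,
) -> SearchResults:
    """Search the items table for rows matching the query and optional filters."""
    filters = [f for f in (price_filter, brand_filter) if f is not None]
    # The real app would run PostgresSearcher here; an empty result keeps the sketch self-contained.
    return SearchResults(query=search_query, items=[], filters=filters)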
106 changes: 74 additions & 32 deletions src/backend/fastapi_app/prompts/query_fewshots.json
@@ -1,34 +1,76 @@
[
{"role": "user", "content": "good options for climbing gear that can be used outside?"},
{"role": "assistant", "tool_calls": [
{
"id": "call_abc123",
"type": "function",
"function": {
"arguments": "{\"search_query\":\"climbing gear outside\"}",
"name": "search_database"
}
}
]},
{
"role": "tool",
"tool_call_id": "call_abc123",
"content": "Search results for climbing gear that can be used outside: ..."
},
{"role": "user", "content": "are there any shoes less than $50?"},
{"role": "assistant", "tool_calls": [
{
"id": "call_abc456",
"type": "function",
"function": {
"arguments": "{\"search_query\":\"shoes\",\"price_filter\":{\"comparison_operator\":\"<\",\"value\":50}}",
"name": "search_database"
}
}
]},
{
"role": "tool",
"tool_call_id": "call_abc456",
"content": "Search results for shoes cheaper than 50: ..."
}
{
"parts": [
{
"content": "good options for climbing gear that can be used outside?",
"timestamp": "2025-05-07T19:02:46.977501Z",
"part_kind": "user-prompt"
}
],
"instructions": null,
"kind": "request"
},
{
"parts": [
{
"tool_name": "search_database",
"args": "{\"search_query\":\"climbing gear outside\"}",
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
"part_kind": "tool-call"
}
],
"model_name": "gpt-4o-mini-2024-07-18",
"timestamp": "2025-05-07T19:02:47Z",
"kind": "response"
},
{
"parts": [
{
"tool_name": "search_database",
"content": "Search results for climbing gear that can be used outside: ...",
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
"timestamp": "2025-05-07T19:02:48.242408Z",
"part_kind": "tool-return"
}
],
"instructions": null,
"kind": "request"
},
{
"parts": [
{
"content": "are there any shoes less than $50?",
"timestamp": "2025-05-07T19:02:46.977501Z",
"part_kind": "user-prompt"
}
],
"instructions": null,
"kind": "request"
},
{
"parts": [
{
"tool_name": "search_database",
"args": "{\"search_query\":\"shoes\",\"price_filter\":{\"comparison_operator\":\"<\",\"value\":50}}",
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
"part_kind": "tool-call"
}
],
"model_name": "gpt-4o-mini-2024-07-18",
"timestamp": "2025-05-07T19:02:47Z",
"kind": "response"
},
{
"parts": [
{
"tool_name": "search_database",
"content": "Search results for shoes cheaper than 50: ...",
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
"timestamp": "2025-05-07T19:02:48.242408Z",
"part_kind": "tool-return"
}
],
"instructions": null,
"kind": "request"
}
]
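These few-shots are now serialized Pydantic-AI ModelRequest/ModelResponse messages rather than OpenAI chat dicts. Assuming they are loaded back with pydantic-ai's message TypeAdapter (the PR may wire this up differently), deserialization looks roughly like:

from pathlib import Path

from pydantic_ai.messages import ModelMessagesTypeAdapter

fewshots_json = Path("src/backend/fastapi_app/prompts/query_fewshots.json").read_bytes()
fewshot_messages = ModelMessagesTypeAdapter.validate_json(fewshots_json)

for message in fewshot_messages:
    # Each entry is a ModelRequest (kind="request") or ModelResponse (kind="response").
    print(message.kind, [part.part_kind for part in message.parts])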