Skip to content

Commit aa4f40e

Browse files
authored
Merge pull request #2 from Azure-Samples/mergeupstream
merge from upstream main
2 parents 99a1589 + f55564d commit aa4f40e

25 files changed

+961
-429
lines changed

.github/workflows/app-tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ jobs:
8585
architecture: x64
8686

8787
- name: Install uv
88-
uses: astral-sh/setup-uv@v5
88+
uses: astral-sh/setup-uv@v6
8989
with:
9090
enable-cache: true
9191
version: "0.4.20"

.github/workflows/evaluate.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ jobs:
8282
python-version: '3.12'
8383

8484
- name: Install uv
85-
uses: astral-sh/setup-uv@v5
85+
uses: astral-sh/setup-uv@v6
8686
with:
8787
enable-cache: true
8888
version: "0.4.20"

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ pytest-snapshot
1414
locust
1515
psycopg2
1616
dotenv-azd
17+
freezegun

src/backend/fastapi_app/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,13 @@ class State(TypedDict):
3434
@asynccontextmanager
3535
async def lifespan(app: fastapi.FastAPI) -> AsyncIterator[State]:
3636
context = await common_parameters()
37-
azure_credential = await get_azure_credential()
37+
azure_credential = None
38+
if (
39+
os.getenv("OPENAI_CHAT_HOST") == "azure"
40+
or os.getenv("OPENAI_EMBED_HOST") == "azure"
41+
or os.getenv("POSTGRES_HOST", "").endswith(".database.azure.com")
42+
):
43+
azure_credential = await get_azure_credential()
3844
engine = await create_postgres_engine_from_env(azure_credential)
3945
sessionmaker = await create_async_sessionmaker(engine)
4046
chat_client = await create_openai_chat_client(azure_credential)
@@ -53,6 +59,7 @@ def create_app(testing: bool = False):
5359
if not testing:
5460
load_dotenv(override=True)
5561
logging.basicConfig(level=logging.INFO)
62+
5663
# Turn off particularly noisy INFO level logs from Azure Core SDK:
5764
logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(logging.WARNING)
5865
logging.getLogger("azure.identity").setLevel(logging.WARNING)

src/backend/fastapi_app/api_models.py

Lines changed: 58 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from enum import Enum
2-
from typing import Any, Optional
2+
from typing import Any, Optional, Union
33

44
from openai.types.chat import ChatCompletionMessageParam
5-
from pydantic import BaseModel
5+
from pydantic import BaseModel, Field
6+
from pydantic_ai.messages import ModelRequest, ModelResponse
67

78

89
class AIChatRoles(str, Enum):
@@ -40,6 +41,30 @@ class ChatRequest(BaseModel):
4041
context: ChatRequestContext
4142
sessionState: Optional[Any] = None
4243

44+
45+
class ItemPublic(BaseModel):
46+
id: int
47+
name: str
48+
location: str
49+
cuisine: str
50+
rating: int
51+
price_level: int
52+
review_count: int
53+
hours: int
54+
tags: str
55+
description: str
56+
menu_summary: str
57+
top_reviews: str
58+
vibe: str
59+
60+
61+
class ItemWithDistance(ItemPublic):
62+
distance: float
63+
64+
def __init__(self, **data):
65+
super().__init__(**data)
66+
self.distance = round(self.distance, 2)
67+
4368

4469
class ThoughtStep(BaseModel):
4570
title: str
@@ -48,7 +73,7 @@ class ThoughtStep(BaseModel):
4873

4974

5075
class RAGContext(BaseModel):
51-
data_points: dict[int, dict[str, Any]]
76+
data_points: dict[int, ItemPublic]
5277
thoughts: list[ThoughtStep]
5378
followup_questions: Optional[list[str]] = None
5479

@@ -69,34 +94,39 @@ class RetrievalResponseDelta(BaseModel):
6994
sessionState: Optional[Any] = None
7095

7196

72-
class ItemPublic(BaseModel):
73-
id: int
74-
name: str
75-
location: str
76-
cuisine: str
77-
rating: int
78-
price_level: int
79-
review_count: int
80-
hours: int
81-
tags: str
82-
description: str
83-
menu_summary: str
84-
top_reviews: str
85-
vibe: str
86-
87-
88-
class ItemWithDistance(ItemPublic):
89-
distance: float
90-
91-
def __init__(self, **data):
92-
super().__init__(**data)
93-
self.distance = round(self.distance, 2)
94-
95-
9697
class ChatParams(ChatRequestOverrides):
9798
prompt_template: str
9899
response_token_limit: int = 1024
99100
enable_text_search: bool
100101
enable_vector_search: bool
101102
original_user_query: str
102-
past_messages: list[ChatCompletionMessageParam]
103+
past_messages: list[Union[ModelRequest, ModelResponse]]
104+
105+
106+
class Filter(BaseModel):
107+
column: str
108+
comparison_operator: str
109+
value: Any
110+
111+
112+
class PriceLevelFilter(Filter):
113+
column: str = Field(default="price_level", description="The column to filter on (always 'price_level' for this filter)")
114+
comparison_operator: str = Field(description="The operator for price level comparison ('>', '<', '>=', '<=', '=')")
115+
value: float = Field(description="Value to compare against, either 1, 2, 3, 4")
116+
117+
118+
class RatingFilter(Filter):
119+
column: str = Field(default="rating", description="The column to filter on (always 'rating' for this filter)")
120+
comparison_operator: str = Field(description="The operator for rating comparison ('>', '<', '>=', '<=', '=')")
121+
value: str = Field(description="Value to compare against, either 0 1 2 3 4")
122+
123+
124+
class SearchResults(BaseModel):
125+
query: str
126+
"""The original search query"""
127+
128+
items: list[ItemPublic]
129+
"""List of items that match the search query and filters"""
130+
131+
filters: list[Filter]
132+
"""List of filters applied to the search results"""

src/backend/fastapi_app/openai_clients.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99

1010

1111
async def create_openai_chat_client(
12-
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential],
12+
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential, None],
1313
) -> Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]:
1414
openai_chat_client: Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]
1515
OPENAI_CHAT_HOST = os.getenv("OPENAI_CHAT_HOST")
1616
if OPENAI_CHAT_HOST == "azure":
17-
api_version = os.environ["AZURE_OPENAI_VERSION"] or "2024-03-01-preview"
17+
api_version = os.environ["AZURE_OPENAI_VERSION"] or "2024-10-21"
1818
azure_endpoint = os.environ["AZURE_OPENAI_ENDPOINT"]
1919
azure_deployment = os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"]
2020
if api_key := os.getenv("AZURE_OPENAI_KEY"):
@@ -29,7 +29,7 @@ async def create_openai_chat_client(
2929
azure_deployment=azure_deployment,
3030
api_key=api_key,
3131
)
32-
else:
32+
elif azure_credential:
3333
logger.info(
3434
"Setting up Azure OpenAI client for chat completions using Azure Identity, endpoint %s, deployment %s",
3535
azure_endpoint,
@@ -44,6 +44,8 @@ async def create_openai_chat_client(
4444
azure_deployment=azure_deployment,
4545
azure_ad_token_provider=token_provider,
4646
)
47+
else:
48+
raise ValueError("Azure OpenAI client requires either an API key or Azure Identity credential.")
4749
elif OPENAI_CHAT_HOST == "ollama":
4850
logger.info("Setting up OpenAI client for chat completions using Ollama")
4951
openai_chat_client = openai.AsyncOpenAI(
@@ -67,7 +69,7 @@ async def create_openai_chat_client(
6769

6870

6971
async def create_openai_embed_client(
70-
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential],
72+
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential, None],
7173
) -> Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]:
7274
openai_embed_client: Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]
7375
OPENAI_EMBED_HOST = os.getenv("OPENAI_EMBED_HOST")
@@ -87,7 +89,7 @@ async def create_openai_embed_client(
8789
azure_deployment=azure_deployment,
8890
api_key=api_key,
8991
)
90-
else:
92+
elif azure_credential:
9193
logger.info(
9294
"Setting up Azure OpenAI client for embeddings using Azure Identity, endpoint %s, deployment %s",
9395
azure_endpoint,
@@ -102,6 +104,8 @@ async def create_openai_embed_client(
102104
azure_deployment=azure_deployment,
103105
azure_ad_token_provider=token_provider,
104106
)
107+
else:
108+
raise ValueError("Azure OpenAI client requires either an API key or Azure Identity credential.")
105109
elif OPENAI_EMBED_HOST == "ollama":
106110
logger.info("Setting up OpenAI client for embeddings using Ollama")
107111
openai_embed_client = openai.AsyncOpenAI(

src/backend/fastapi_app/postgres_searcher.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from sqlalchemy import Float, Integer, column, select, text
66
from sqlalchemy.ext.asyncio import AsyncSession
77

8+
from fastapi_app.api_models import Filter
89
from fastapi_app.embeddings import compute_text_embedding
910
from fastapi_app.postgres_models import Item
1011

@@ -26,21 +27,24 @@ def __init__(
2627
self.embed_dimensions = embed_dimensions
2728
self.embedding_column = embedding_column
2829

29-
def build_filter_clause(self, filters) -> tuple[str, str]:
30+
def build_filter_clause(self, filters: Optional[list[Filter]]) -> tuple[str, str]:
3031
if filters is None:
3132
return "", ""
3233
filter_clauses = []
3334
for filter in filters:
34-
if isinstance(filter["value"], str):
35-
filter["value"] = f"'{filter['value']}'"
36-
filter_clauses.append(f"{filter['column']} {filter['comparison_operator']} {filter['value']}")
35+
filter_value = f"'{filter.value}'" if isinstance(filter.value, str) else filter.value
36+
filter_clauses.append(f"{filter.column} {filter.comparison_operator} {filter_value}")
3737
filter_clause = " AND ".join(filter_clauses)
3838
if len(filter_clause) > 0:
3939
return f"WHERE {filter_clause}", f"AND {filter_clause}"
4040
return "", ""
4141

4242
async def search(
43-
self, query_text: Optional[str], query_vector: list[float], top: int = 5, filters: Optional[list[dict]] = None
43+
self,
44+
query_text: Optional[str],
45+
query_vector: list[float],
46+
top: int = 5,
47+
filters: Optional[list[Filter]] = None,
4448
):
4549
filter_clause_where, filter_clause_and = self.build_filter_clause(filters)
4650
table_name = Item.__tablename__
@@ -106,7 +110,7 @@ async def search_and_embed(
106110
top: int = 5,
107111
enable_vector_search: bool = False,
108112
enable_text_search: bool = False,
109-
filters: Optional[list[dict]] = None,
113+
filters: Optional[list[Filter]] = None,
110114
) -> list[Item]:
111115
"""
112116
Search rows by query text. Optionally converts the query text to a vector if enable_vector_search is True.
Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
Below is a history of the conversation so far, and a new question asked by the user that needs to be answered by searching database rows.
2-
You have access to an Azure PostgreSQL database with a restaurants table that has name, description, menu summary, vibe, ratings, etc.
3-
Generate a search query based on the conversation and the new question.
4-
If the question is not in English, translate the question to English before generating the search query.
5-
If you cannot generate a search query, return the original user question.
6-
DO NOT return anything besides the query.
1+
Your job is to find search results based on the user's question and past messages.
2+
You have access to only these tools:
3+
1. **search_database**: This tool allows you to search a table for restaurants based on a query.
4+
You can pass in a search query and optional filters.
5+
Once you get the search results, you're done.
Lines changed: 74 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,76 @@
11
[
2-
{"role": "user", "content": "good options for ethiopian restaurants?"},
3-
{"role": "assistant", "tool_calls": [
4-
{
5-
"id": "call_abc123",
6-
"type": "function",
7-
"function": {
8-
"arguments": "{\"search_query\":\"ethiopian\"}",
9-
"name": "search_database"
10-
}
11-
}
12-
]},
13-
{
14-
"role": "tool",
15-
"tool_call_id": "call_abc123",
16-
"content": "Search results for ethiopian: ..."
17-
},
18-
{"role": "user", "content": "are there any inexpensive chinese restaurants?"},
19-
{"role": "assistant", "tool_calls": [
20-
{
21-
"id": "call_abc456",
22-
"type": "function",
23-
"function": {
24-
"arguments": "{\"search_query\":\"chinese\",\"price_level_filter\":{\"comparison_operator\":\"<\",\"value\":3}}",
25-
"name": "search_database"
26-
}
27-
}
28-
]},
29-
{
30-
"role": "tool",
31-
"tool_call_id": "call_abc456",
32-
"content": "Search results for chinese: ..."
33-
}
2+
{
3+
"parts": [
4+
{
5+
"content": "good options for ethiopian restaurants?",
6+
"timestamp": "2025-05-07T19:02:46.977501Z",
7+
"part_kind": "user-prompt"
8+
}
9+
],
10+
"instructions": null,
11+
"kind": "request"
12+
},
13+
{
14+
"parts": [
15+
{
16+
"tool_name": "search_database",
17+
"args": "{\"search_query\":\"ethiopian\"}",
18+
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
19+
"part_kind": "tool-call"
20+
}
21+
],
22+
"model_name": "gpt-4o-mini-2024-07-18",
23+
"timestamp": "2025-05-07T19:02:47Z",
24+
"kind": "response"
25+
},
26+
{
27+
"parts": [
28+
{
29+
"tool_name": "search_database",
30+
"content": "Search results for ethiopian: ...",
31+
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
32+
"timestamp": "2025-05-07T19:02:48.242408Z",
33+
"part_kind": "tool-return"
34+
}
35+
],
36+
"instructions": null,
37+
"kind": "request"
38+
},
39+
{
40+
"parts": [
41+
{
42+
"content": "are there any inexpensive chinese restaurants?",
43+
"timestamp": "2025-05-07T19:02:46.977501Z",
44+
"part_kind": "user-prompt"
45+
}
46+
],
47+
"instructions": null,
48+
"kind": "request"
49+
},
50+
{
51+
"parts": [
52+
{
53+
"tool_name": "search_database",
54+
"args": "{\"search_query\":\"chinese\",\"price_level_filter\":{\"comparison_operator\":\"<\",\"value\":3}}",
55+
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
56+
"part_kind": "tool-call"
57+
}
58+
],
59+
"model_name": "gpt-4o-mini-2024-07-18",
60+
"timestamp": "2025-05-07T19:02:47Z",
61+
"kind": "response"
62+
},
63+
{
64+
"parts": [
65+
{
66+
"tool_name": "search_database",
67+
"content": "Search results for chinese: ...",
68+
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
69+
"timestamp": "2025-05-07T19:02:48.242408Z",
70+
"part_kind": "tool-return"
71+
}
72+
],
73+
"instructions": null,
74+
"kind": "request"
75+
}
3476
]

0 commit comments

Comments
 (0)