Skip to content

Commit 99a1589

Browse files
authored
Merge pull request #1 from rohit-ganguly/restaurantdata
Bring in sample data from Pittsburgh restaurants
2 parents 541adda + d68572c commit 99a1589

File tree

10 files changed

+90565
-181897
lines changed

10 files changed

+90565
-181897
lines changed

convert_csv_json.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
import ast
import csv
import json


def _parse_list_literal(text):
    """Parse a bracketed string such as "['a', 'b']" into a Python object.

    Parsers are tried from strictest to most lenient so that input which is
    already valid is never rewritten before it is parsed:

    1. strict JSON, e.g. ``["it's good", "ok"]``
    2. a Python literal, e.g. ``['bbq', 'bar']``
    3. JSON after swapping single quotes for double quotes, e.g. ``['a', null]``

    Returns the original text unchanged when no parser accepts it.
    """
    parsers = (
        json.loads,
        ast.literal_eval,
        # Lossy last resort: the quote swap can corrupt apostrophes inside
        # strings, which is why it only runs after the exact parsers failed.
        lambda s: json.loads(s.replace("'", '"')),
    )
    for parse in parsers:
        try:
            return parse(text)
        except (ValueError, SyntaxError, TypeError):
            continue
    return text


def _convert_value(raw):
    """Convert one CSV cell to a list, bool, int or float, else a stripped str."""
    value = raw.strip()

    # Values that look like an array are parsed into a Python object.
    if value.startswith("[") and value.endswith("]"):
        return _parse_list_literal(value)

    # Boolean strings (case-insensitive).
    lowered = value.lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False

    # Non-negative numbers. isdecimal() (not isdigit()) is used because
    # isdigit() also accepts characters like superscript digits, which int()
    # and float() reject with ValueError.
    if value.isdecimal():
        return int(value)
    if value.replace(".", "", 1).isdecimal():
        return float(value)

    return value


def _row_to_item(header, row):
    """Build one record from a CSV row; cells beyond the header are ignored."""
    item = {column: _convert_value(cell) for column, cell in zip(header, row)}
    # Remove the is_open column; tolerate rows or files that do not have it.
    item.pop("is_open", None)
    return item


def main(csv_path="data.csv", json_path="data.json"):
    """Convert the CSV file at ``csv_path`` into a JSON array at ``json_path``.

    The first CSV row is the header; every following non-blank row becomes one
    JSON object keyed by the header names, with the ``is_open`` column removed.
    An empty CSV file produces an empty JSON array.

    Returns the list of converted records.
    """
    # newline="" lets the csv module handle line endings itself, so newlines
    # embedded in quoted fields are read correctly.
    with open(csv_path, encoding="utf-8", newline="") as csv_file:
        csv_reader = csv.reader(csv_file, quoting=csv.QUOTE_ALL, doublequote=True, escapechar="\\")
        header = next(csv_reader, [])  # Header row; [] when the file is empty
        # csv.reader yields [] for blank lines; those carry no record.
        json_data = [_row_to_item(header, row) for row in csv_reader if row]

    # Write to JSON file
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(json_data, f, indent=4, ensure_ascii=False)

    print(f"Successfully converted CSV data to JSON format with {len(json_data)} records")
    return json_data


if __name__ == "__main__":
    main()

src/backend/fastapi_app/api_models.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,18 @@ class RetrievalResponseDelta(BaseModel):
7171

7272
class ItemPublic(BaseModel):
    """API-facing schema for one restaurant item.

    Field types mirror the ORM ``Item`` columns in postgres_models.py so that a
    row's dictionary form validates against this model.
    """

    id: int
    name: str
    location: str
    cuisine: str
    rating: int
    price_level: int
    review_count: int
    # Text, not a number: the ORM column is Mapped[str].
    hours: str
    # List of tag strings: the ORM column is ARRAY(VARCHAR) / Mapped[list[str]].
    tags: list[str]
    description: str
    menu_summary: str
    top_reviews: str
    vibe: str
7986

8087

8188
class ItemWithDistance(ItemPublic):

src/backend/fastapi_app/postgres_models.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

33
from pgvector.sqlalchemy import Vector
4-
from sqlalchemy import Index
4+
from sqlalchemy import VARCHAR, Index
5+
from sqlalchemy.dialects import postgresql
56
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
67

78

@@ -13,11 +14,19 @@ class Base(DeclarativeBase):
1314
class Item(Base):
1415
__tablename__ = "items"
1516
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
16-
type: Mapped[str] = mapped_column()
17-
brand: Mapped[str] = mapped_column()
1817
name: Mapped[str] = mapped_column()
18+
location: Mapped[str] = mapped_column()
19+
cuisine: Mapped[str] = mapped_column()
20+
rating: Mapped[int] = mapped_column()
21+
price_level: Mapped[int] = mapped_column()
22+
review_count: Mapped[int] = mapped_column()
23+
hours: Mapped[str] = mapped_column()
24+
tags: Mapped[list[str]] = mapped_column(postgresql.ARRAY(VARCHAR)) # Array of strings
1925
description: Mapped[str] = mapped_column()
20-
price: Mapped[float] = mapped_column()
26+
menu_summary: Mapped[str] = mapped_column()
27+
top_reviews: Mapped[str] = mapped_column()
28+
vibe: Mapped[str] = mapped_column()
29+
2130
# Embeddings for different models:
2231
embedding_3l: Mapped[Vector] = mapped_column(Vector(1024), nullable=True) # text-embedding-3-large
2332
embedding_nomic: Mapped[Vector] = mapped_column(Vector(768), nullable=True) # nomic-embed-text
@@ -33,10 +42,10 @@ def to_dict(self, include_embedding: bool = False):
3342
return model_dict
3443

3544
def to_str_for_rag(self):
36-
return f"Name:{self.name} Description:{self.description} Price:{self.price} Brand:{self.brand} Type:{self.type}"
45+
return f"Name:{self.name} Description:{self.description} Location:{self.location} Cuisine:{self.cuisine} Rating:{self.rating} Price Level:{self.price_level} Review Count:{self.review_count} Hours:{self.hours} Tags:{self.tags} Menu Summary:{self.menu_summary} Top Reviews:{self.top_reviews} Vibe:{self.vibe}" # noqa: E501
3746

3847
def to_str_for_embedding(self):
39-
return f"Name: {self.name} Description: {self.description} Type: {self.type}"
48+
return f"Name: {self.name} Description: {self.description} Cuisine: {self.cuisine} Tags: {self.tags} Menu Summary: {self.menu_summary} Top Reviews: {self.top_reviews} Vibe: {self.vibe}" # noqa: E501
4049

4150

4251
"""
Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
Assistant helps customers with questions about products.
2-
Respond as if you are a salesperson helping a customer in a store. Do NOT respond with tables.
3-
Answer ONLY with the product details listed in the products.
1+
Assistant helps PyCon attendees with questions about restaurants.
2+
Respond as if you are a conference volunteer. Do NOT respond with tables.
3+
Answer ONLY with the restaurant details listed in the sources.
44
If there isn't enough information below, say you don't know.
55
Do not generate answers that don't use the sources below.
6-
Each product has an ID in brackets followed by colon and the product details.
7-
Always include the product ID for each product you use in the response.
8-
Use square brackets to reference the source, for example [52].
9-
Don't combine citations, list each product separately, for example [27][51].
6+
Each restaurant has an ID in brackets followed by colon and the restaurant details.
7+
Always include the restaurant ID for each restaurant you reference in the response.
8+
Use square brackets to reference the restaurant, for example [52].
9+
Don't combine references, cite each restaurant separately, for example [27][51].

src/backend/fastapi_app/prompts/query.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
Below is a history of the conversation so far, and a new question asked by the user that needs to be answered by searching database rows.
2-
You have access to an Azure PostgreSQL database with an items table that has columns for title, description, brand, price, and type.
2+
You have access to an Azure PostgreSQL database with an items table of restaurants that has columns for name, description, cuisine, tags, menu summary, vibe, rating, price level, etc.
33
Generate a search query based on the conversation and the new question.
44
If the question is not in English, translate the question to English before generating the search query.
55
If you cannot generate a search query, return the original user question.
Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,34 @@
11
[
2-
{"role": "user", "content": "good options for climbing gear that can be used outside?"},
2+
{"role": "user", "content": "good options for ethiopian restaurants?"},
33
{"role": "assistant", "tool_calls": [
44
{
55
"id": "call_abc123",
66
"type": "function",
77
"function": {
8-
"arguments": "{\"search_query\":\"climbing gear outside\"}",
8+
"arguments": "{\"search_query\":\"ethiopian\"}",
99
"name": "search_database"
1010
}
1111
}
1212
]},
1313
{
1414
"role": "tool",
1515
"tool_call_id": "call_abc123",
16-
"content": "Search results for climbing gear that can be used outside: ..."
16+
"content": "Search results for ethiopian: ..."
1717
},
18-
{"role": "user", "content": "are there any shoes less than $50?"},
18+
{"role": "user", "content": "are there any inexpensive chinese restaurants?"},
1919
{"role": "assistant", "tool_calls": [
2020
{
2121
"id": "call_abc456",
2222
"type": "function",
2323
"function": {
24-
"arguments": "{\"search_query\":\"shoes\",\"price_filter\":{\"comparison_operator\":\"<\",\"value\":50}}",
24+
"arguments": "{\"search_query\":\"chinese\",\"price_level_filter\":{\"comparison_operator\":\"<\",\"value\":3}}",
2525
"name": "search_database"
2626
}
2727
}
2828
]},
2929
{
3030
"role": "tool",
3131
"tool_call_id": "call_abc456",
32-
"content": "Search results for shoes cheaper than 50: ..."
32+
"content": "Search results for chinese: ..."
3333
}
3434
]

src/backend/fastapi_app/query_rewriter.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,39 +12,39 @@ def build_search_function() -> list[ChatCompletionToolParam]:
1212
"type": "function",
1313
"function": {
1414
"name": "search_database",
15-
"description": "Search PostgreSQL database for relevant products based on user query",
15+
"description": "Search PostgreSQL database for relevant restaurants based on user query",
1616
"parameters": {
1717
"type": "object",
1818
"properties": {
1919
"search_query": {
2020
"type": "string",
2121
"description": "Query string to use for full text search, e.g. 'red shoes'",
2222
},
23-
"price_filter": {
23+
"price_level_filter": {
2424
"type": "object",
25-
"description": "Filter search results based on price of the product",
25+
"description": "Filter search results to a certain price level (from 1 $ to 4 $$$$, with 4 being most costly)", # noqa: E501
2626
"properties": {
2727
"comparison_operator": {
2828
"type": "string",
29-
"description": "Operator to compare the column value, either '>', '<', '>=', '<=', '='", # noqa
29+
"description": "Operator to compare the column value, either '>', '<', '>=', '<=', '='", # noqa: E501
3030
},
3131
"value": {
3232
"type": "number",
33-
"description": "Value to compare against, e.g. 30",
33+
"description": "Value to compare against, either 1, 2, 3, 4",
3434
},
3535
},
3636
},
37-
"brand_filter": {
37+
"rating_filter": {
3838
"type": "object",
39-
"description": "Filter search results based on brand of the product",
39+
"description": "Filter search results based on ratings of restaurant (from 1 to 5 stars, with 5 the best)", # noqa: E501
4040
"properties": {
4141
"comparison_operator": {
4242
"type": "string",
43-
"description": "Operator to compare the column value, either '=' or '!='",
43+
"description": "Operator to compare the column value, either '>', '<', '>=', '<=', '='", # noqa: E501
4444
},
4545
"value": {
4646
"type": "string",
47-
"description": "Value to compare against, e.g. AirStrider",
47+
"description": "Value to compare against, either 0 1 2 3 4 5",
4848
},
4949
},
5050
},
@@ -69,22 +69,26 @@ def extract_search_arguments(original_user_query: str, chat_completion: ChatComp
6969
arg = json.loads(function.arguments)
7070
# Even though its required, search_query is not always specified
7171
search_query = arg.get("search_query", original_user_query)
72-
if "price_filter" in arg and arg["price_filter"] and isinstance(arg["price_filter"], dict):
73-
price_filter = arg["price_filter"]
72+
if (
73+
"price_level_filter" in arg
74+
and arg["price_level_filter"]
75+
and isinstance(arg["price_level_filter"], dict)
76+
):
77+
price_level_filter = arg["price_level_filter"]
7478
filters.append(
7579
{
76-
"column": "price",
77-
"comparison_operator": price_filter["comparison_operator"],
78-
"value": price_filter["value"],
80+
"column": "price_level",
81+
"comparison_operator": price_level_filter["comparison_operator"],
82+
"value": price_level_filter["value"],
7983
}
8084
)
81-
if "brand_filter" in arg and arg["brand_filter"] and isinstance(arg["brand_filter"], dict):
82-
brand_filter = arg["brand_filter"]
85+
if "rating_filter" in arg and arg["rating_filter"] and isinstance(arg["rating_filter"], dict):
86+
rating_filter = arg["rating_filter"]
8387
filters.append(
8488
{
85-
"column": "brand",
86-
"comparison_operator": brand_filter["comparison_operator"],
87-
"value": brand_filter["value"],
89+
"column": "rating",
90+
"comparison_operator": rating_filter["comparison_operator"],
91+
"value": rating_filter["value"],
8892
}
8993
)
9094
elif query_text := response_message.content:

0 commit comments

Comments
 (0)