@@ -12,39 +12,39 @@ def build_search_function() -> list[ChatCompletionToolParam]:
12
12
"type" : "function" ,
13
13
"function" : {
14
14
"name" : "search_database" ,
15
- "description" : "Search PostgreSQL database for relevant products based on user query" ,
15
+ "description" : "Search PostgreSQL database for relevant restaurants based on user query" ,
16
16
"parameters" : {
17
17
"type" : "object" ,
18
18
"properties" : {
19
19
"search_query" : {
20
20
"type" : "string" ,
21
21
"description" : "Query string to use for full text search, e.g. 'red shoes'" ,
22
22
},
23
- "price_filter " : {
23
+ "price_level_filter " : {
24
24
"type" : "object" ,
25
- "description" : "Filter search results based on price of the product" ,
25
+ "description" : "Filter search results to a certain price level (from 1 $ to 4 $$$$, with 4 being most costly)" , # noqa: E501
26
26
"properties" : {
27
27
"comparison_operator" : {
28
28
"type" : "string" ,
29
- "description" : "Operator to compare the column value, either '>', '<', '>=', '<=', '='" , # noqa
29
+ "description" : "Operator to compare the column value, either '>', '<', '>=', '<=', '='" , # noqa: E501
30
30
},
31
31
"value" : {
32
32
"type" : "number" ,
33
- "description" : "Value to compare against, e.g. 30 " ,
33
+ "description" : "Value to compare against, either 1, 2, 3, 4 " ,
34
34
},
35
35
},
36
36
},
37
- "brand_filter " : {
37
+ "rating_filter " : {
38
38
"type" : "object" ,
39
- "description" : "Filter search results based on brand of the product" ,
39
+ "description" : "Filter search results based on ratings of restaurant (from 1 to 5 stars, with 5 the best)" , # noqa: E501
40
40
"properties" : {
41
41
"comparison_operator" : {
42
42
"type" : "string" ,
43
- "description" : "Operator to compare the column value, either '=' or '! ='" ,
43
+ "description" : "Operator to compare the column value, either '>', '<', '>=', '<=', ' ='" , # noqa: E501
44
44
},
45
45
"value" : {
46
46
"type" : "string" ,
47
- "description" : "Value to compare against, e.g. AirStrider " ,
47
+ "description" : "Value to compare against, either 0 1 2 3 4 5 " ,
48
48
},
49
49
},
50
50
},
@@ -69,22 +69,26 @@ def extract_search_arguments(original_user_query: str, chat_completion: ChatComp
69
69
arg = json .loads (function .arguments )
70
70
# Even though its required, search_query is not always specified
71
71
search_query = arg .get ("search_query" , original_user_query )
72
- if "price_filter" in arg and arg ["price_filter" ] and isinstance (arg ["price_filter" ], dict ):
73
- price_filter = arg ["price_filter" ]
72
+ if (
73
+ "price_level_filter" in arg
74
+ and arg ["price_level_filter" ]
75
+ and isinstance (arg ["price_level_filter" ], dict )
76
+ ):
77
+ price_level_filter = arg ["price_level_filter" ]
74
78
filters .append (
75
79
{
76
- "column" : "price " ,
77
- "comparison_operator" : price_filter ["comparison_operator" ],
78
- "value" : price_filter ["value" ],
80
+ "column" : "price_level " ,
81
+ "comparison_operator" : price_level_filter ["comparison_operator" ],
82
+ "value" : price_level_filter ["value" ],
79
83
}
80
84
)
81
- if "brand_filter " in arg and arg ["brand_filter " ] and isinstance (arg ["brand_filter " ], dict ):
82
- brand_filter = arg ["brand_filter " ]
85
+ if "rating_filter " in arg and arg ["rating_filter " ] and isinstance (arg ["rating_filter " ], dict ):
86
+ rating_filter = arg ["rating_filter " ]
83
87
filters .append (
84
88
{
85
- "column" : "brand " ,
86
- "comparison_operator" : brand_filter ["comparison_operator" ],
87
- "value" : brand_filter ["value" ],
89
+ "column" : "rating " ,
90
+ "comparison_operator" : rating_filter ["comparison_operator" ],
91
+ "value" : rating_filter ["value" ],
88
92
}
89
93
)
90
94
elif query_text := response_message .content :
0 commit comments