7
7
from azure .identity import AzureDeveloperCliCredential , get_bearer_token_provider
8
8
from dotenv_azd import load_azd_env
9
9
from openai import AzureOpenAI , OpenAI
10
+ from openai .types .chat import ChatCompletionToolParam
10
11
from sqlalchemy import create_engine , select
11
12
from sqlalchemy .orm import Session
12
13
15
16
logger = logging .getLogger ("ragapp" )
16
17
17
18
18
- def qa_pairs_tool (num_questions : int = 1 ) -> dict :
19
+ def qa_pairs_tool (num_questions : int = 1 ) -> ChatCompletionToolParam :
19
20
return {
20
21
"type" : "function" ,
21
22
"function" : {
@@ -45,7 +46,7 @@ def qa_pairs_tool(num_questions: int = 1) -> dict:
45
46
}
46
47
47
48
48
- def source_retriever () -> Generator [dict , None , None ]:
49
+ def source_retriever () -> Generator [str , None , None ]:
49
50
# Connect to the database
50
51
DBHOST = os .environ ["POSTGRES_HOST" ]
51
52
DBUSER = os .environ ["POSTGRES_USERNAME" ]
@@ -76,8 +77,9 @@ def answer_formatter(answer, source) -> str:
76
77
return f"{ answer } [{ source ['id' ]} ]"
77
78
78
79
79
- def get_openai_client () -> AzureOpenAI | OpenAI :
80
+ def get_openai_client () -> tuple [ AzureOpenAI | OpenAI , str ] :
80
81
"""Return an OpenAI client based on the environment variables"""
82
+ openai_client : AzureOpenAI | OpenAI
81
83
OPENAI_CHAT_HOST = os .getenv ("OPENAI_CHAT_HOST" )
82
84
if OPENAI_CHAT_HOST == "azure" :
83
85
if api_key := os .getenv ("AZURE_OPENAI_KEY" ):
@@ -101,8 +103,7 @@ def get_openai_client() -> AzureOpenAI | OpenAI:
101
103
raise NotImplementedError ("Ollama OpenAI Service is not supported. Switch to Azure or OpenAI.com" )
102
104
else :
103
105
logger .info ("Using OpenAI Service with API Key from OPENAICOM_KEY" )
104
- openai_config = {"api_type" : "openai" , "api_key" : os .environ ["OPENAICOM_KEY" ]}
105
- openai_client = OpenAI (** openai_config )
106
+ openai_client = OpenAI (api_key = os .environ ["OPENAICOM_KEY" ])
106
107
model = os .environ ["OPENAICOM_CHAT_MODEL" ]
107
108
return openai_client , model
108
109
@@ -127,6 +128,9 @@ def generate_ground_truth_data(num_questions_total: int, num_questions_per_sourc
127
128
],
128
129
tools = [qa_pairs_tool (num_questions = 2 )],
129
130
)
131
+ if not result .choices [0 ].message .tool_calls :
132
+ logger .warning ("No tool calls found in response, skipping" )
133
+ continue
130
134
qa_pairs = json .loads (result .choices [0 ].message .tool_calls [0 ].function .arguments )["qa_list" ]
131
135
qa_pairs = [{"question" : qa_pair ["question" ], "truth" : qa_pair ["answer" ]} for qa_pair in qa_pairs ]
132
136
qa .extend (qa_pairs )
0 commit comments