Skip to content

Commit adb2b2e

Browse files
committed
Make mypy happy
1 parent 31e7cc1 commit adb2b2e

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

evals/generate_ground_truth.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from azure.identity import AzureDeveloperCliCredential, get_bearer_token_provider
88
from dotenv_azd import load_azd_env
99
from openai import AzureOpenAI, OpenAI
10+
from openai.types.chat import ChatCompletionToolParam
1011
from sqlalchemy import create_engine, select
1112
from sqlalchemy.orm import Session
1213

@@ -15,7 +16,7 @@
1516
logger = logging.getLogger("ragapp")
1617

1718

18-
def qa_pairs_tool(num_questions: int = 1) -> dict:
19+
def qa_pairs_tool(num_questions: int = 1) -> ChatCompletionToolParam:
1920
return {
2021
"type": "function",
2122
"function": {
@@ -45,7 +46,7 @@ def qa_pairs_tool(num_questions: int = 1) -> dict:
4546
}
4647

4748

48-
def source_retriever() -> Generator[dict, None, None]:
49+
def source_retriever() -> Generator[str, None, None]:
4950
# Connect to the database
5051
DBHOST = os.environ["POSTGRES_HOST"]
5152
DBUSER = os.environ["POSTGRES_USERNAME"]
@@ -76,8 +77,9 @@ def answer_formatter(answer, source) -> str:
7677
return f"{answer} [{source['id']}]"
7778

7879

79-
def get_openai_client() -> AzureOpenAI | OpenAI:
80+
def get_openai_client() -> tuple[AzureOpenAI | OpenAI, str]:
8081
"""Return an OpenAI client based on the environment variables"""
82+
openai_client: AzureOpenAI | OpenAI
8183
OPENAI_CHAT_HOST = os.getenv("OPENAI_CHAT_HOST")
8284
if OPENAI_CHAT_HOST == "azure":
8385
if api_key := os.getenv("AZURE_OPENAI_KEY"):
@@ -101,8 +103,7 @@ def get_openai_client() -> AzureOpenAI | OpenAI:
101103
raise NotImplementedError("Ollama OpenAI Service is not supported. Switch to Azure or OpenAI.com")
102104
else:
103105
logger.info("Using OpenAI Service with API Key from OPENAICOM_KEY")
104-
openai_config = {"api_type": "openai", "api_key": os.environ["OPENAICOM_KEY"]}
105-
openai_client = OpenAI(**openai_config)
106+
openai_client = OpenAI(api_key=os.environ["OPENAICOM_KEY"])
106107
model = os.environ["OPENAICOM_CHAT_MODEL"]
107108
return openai_client, model
108109

@@ -127,6 +128,9 @@ def generate_ground_truth_data(num_questions_total: int, num_questions_per_sourc
127128
],
128129
tools=[qa_pairs_tool(num_questions=2)],
129130
)
131+
if not result.choices[0].message.tool_calls:
132+
logger.warning("No tool calls found in response, skipping")
133+
continue
130134
qa_pairs = json.loads(result.choices[0].message.tool_calls[0].function.arguments)["qa_list"]
131135
qa_pairs = [{"question": qa_pair["question"], "truth": qa_pair["answer"]} for qa_pair in qa_pairs]
132136
qa.extend(qa_pairs)

0 commit comments

Comments
 (0)