
Commit e2b7297

Merge pull request #104 from Azure-Samples/test-eval3
Add comment in workflow
2 parents (dc19e7c + adb2b2e), commit e2b7297

10 files changed: +167 -290 lines

.github/workflows/evaluate.yaml

Lines changed: 12 additions & 1 deletion
@@ -46,6 +46,17 @@ jobs:
         run: |
           echo "Comment contains #evaluate hashtag"
 
+      - name: Comment on pull request
+        uses: actions/github-script@v7
+        with:
+          script: |
+            github.rest.issues.createComment({
+              issue_number: context.issue.number,
+              owner: context.repo.owner,
+              repo: context.repo.repo,
+              body: "Starting evaluation! Check the Actions tab for progress, or wait for a comment with the results."
+            })
+
       - uses: actions/checkout@v4
 
       - name: Install pgvector
@@ -133,7 +144,7 @@ jobs:
 
       - name: Evaluate local RAG flow
         run: |
-          python evals/evaluate.py --targeturl=http://127.0.0.1:8000/chat --numquestions=2 --resultsdir=evals/results/pr${{ github.event.issue.number }}
+          python evals/evaluate.py --targeturl=http://127.0.0.1:8000/chat --resultsdir=evals/results/pr${{ github.event.issue.number }}
 
       - name: Upload server logs as build artifact
         uses: actions/upload-artifact@v4
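
For reference, a rough Python equivalent of the new comment step (a sketch only; the workflow itself uses `actions/github-script`). It calls the GitHub REST API endpoint for creating issue/PR comments; the repository name and token handling below are illustrative assumptions, not part of this commit.

```python
import os

import requests  # assumes the `requests` package is available


def post_pr_comment(owner: str, repo: str, issue_number: int, body: str) -> None:
    """Create a comment on a PR (PRs are issues in the GitHub REST API)."""
    url = f"https://api.github.com/repos/{owner}/{repo}/issues/{issue_number}/comments"
    response = requests.post(
        url,
        headers={
            "Authorization": f"Bearer {os.environ['GITHUB_TOKEN']}",
            "Accept": "application/vnd.github+json",
        },
        json={"body": body},
        timeout=30,
    )
    response.raise_for_status()


if __name__ == "__main__":
    post_pr_comment(
        owner="Azure-Samples",
        repo="your-repo-name",  # hypothetical placeholder; substitute the actual repository
        issue_number=104,       # the PR number from this merge, used only as an example
        body="Starting evaluation! Check the Actions tab for progress, or wait for a comment with the results.",
    )
```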

docs/evaluation.md

Lines changed: 6 additions & 6 deletions
@@ -33,17 +33,19 @@ pip install -r requirements-dev.txt
 
 ## Generate ground truth data
 
+Modify the prompt in `evals/generate.txt` to match your database table and RAG scenario.
+
 Generate ground truth data by running the following command:
 
 ```bash
-python evals/generate.py
+python evals/generate_ground_truth_data.py
 ```
 
 Review the generated data after running that script, removing any question/answer pairs that don't seem like realistic user input.
 
 ## Evaluate the RAG answer quality
 
-Review the configuration in `evals/eval_config.json` to ensure that everything is correctly setup. You may want to adjust the metrics used. [TODO: link to evaluator docs]
+Review the configuration in `evals/eval_config.json` to ensure that everything is correctly setup. You may want to adjust the metrics used. See [the ai-rag-chat-evaluator README](https://github.com/Azure-Samples/ai-rag-chat-evaluator) for more information on the available metrics.
 
 By default, the evaluation script will evaluate every question in the ground truth data.
 Run the evaluation script by running the following command:
@@ -68,8 +70,6 @@ Compare answers across runs by running the following command:
 python -m evaltools diff evals/results/baseline/
 ```
 
-## Run the evaluation in GitHub actions
-
+## Run the evaluation on a PR
 
-# TODO: Add GPT-4 deployment with high capacity for evaluation
-# TODO: Add CI workflow that can be triggered to run the evaluate on the local app
+To run the evaluation on the changes in a PR, you can add a `/evaluate` comment to the PR. This will trigger the evaluation workflow to run the evaluation on the PR changes, and will post the results to the PR.
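
The updated docs ask you to review the generated data and remove unrealistic question/answer pairs. Below is a minimal sketch of doing that review programmatically, assuming the `ground_truth.jsonl` output written by `evals/generate_ground_truth.py` later in this commit (one JSON object per line with `question` and `truth` keys); the drop indices are purely illustrative.

```python
import json
from pathlib import Path

# Assumed location, matching the output path used by evals/generate_ground_truth.py.
ground_truth_path = Path("evals/ground_truth.jsonl")

pairs = [
    json.loads(line)
    for line in ground_truth_path.read_text(encoding="utf-8").splitlines()
    if line.strip()
]

# Print each pair so a reviewer can decide which ones to keep.
for i, pair in enumerate(pairs):
    print(f"{i}: Q: {pair['question']}")
    print(f"   A: {pair['truth']}")

# Example only: drop pairs by index after manual review.
indices_to_drop = {3, 7}
kept = [pair for i, pair in enumerate(pairs) if i not in indices_to_drop]

with ground_truth_path.open("w", encoding="utf-8") as f:
    for pair in kept:
        f.write(json.dumps(pair) + "\n")
```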

evals/generate_ground_truth.py

Lines changed: 101 additions & 44 deletions
@@ -1,11 +1,13 @@
+import json
 import logging
 import os
 from collections.abc import Generator
 from pathlib import Path
 
-from azure.identity import AzureDeveloperCliCredential
-from dotenv import load_dotenv
-from evaltools.gen.generate import generate_test_qa_data
+from azure.identity import AzureDeveloperCliCredential, get_bearer_token_provider
+from dotenv_azd import load_azd_env
+from openai import AzureOpenAI, OpenAI
+from openai.types.chat import ChatCompletionToolParam
 from sqlalchemy import create_engine, select
 from sqlalchemy.orm import Session
 
@@ -14,7 +16,37 @@
 logger = logging.getLogger("ragapp")
 
 
-def source_retriever() -> Generator[dict, None, None]:
+def qa_pairs_tool(num_questions: int = 1) -> ChatCompletionToolParam:
+    return {
+        "type": "function",
+        "function": {
+            "name": "qa_pairs",
+            "description": "Send in question and answer pairs for a customer-facing chat app",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "qa_list": {
+                        "type": "array",
+                        "description": f"List of {num_questions} question and answer pairs",
+                        "items": {
+                            "type": "object",
+                            "properties": {
+                                "question": {"type": "string", "description": "The question text"},
+                                "answer": {"type": "string", "description": "The answer text"},
+                            },
+                            "required": ["question", "answer"],
+                        },
+                        "minItems": num_questions,
+                        "maxItems": num_questions,
+                    }
+                },
+                "required": ["qa_list"],
+            },
+        },
+    }
+
+
+def source_retriever() -> Generator[str, None, None]:
     # Connect to the database
     DBHOST = os.environ["POSTGRES_HOST"]
     DBUSER = os.environ["POSTGRES_USERNAME"]
@@ -27,16 +59,14 @@ def source_retriever() -> Generator[dict, None, None]:
         item_types = session.scalars(select(Item.type).distinct())
         for item_type in item_types:
             records = list(session.scalars(select(Item).filter(Item.type == item_type).order_by(Item.id)))
-            # logger.info(f"Processing database records for type: {item_type}")
-            # yield {
-            #     "citations": " ".join([f"[{record.id}] - {record.name}" for record in records]),
-            #     "content": "\n\n".join([record.to_str_for_rag() for record in records]),
-            # }
+            logger.info(f"Processing database records for type: {item_type}")
+            yield "\n\n".join([f"## Product ID: [{record.id}]\n" + record.to_str_for_rag() for record in records])
         # Fetch each item individually
-        records = list(session.scalars(select(Item).order_by(Item.id)))
-        for record in records:
-            logger.info(f"Processing database record: {record.name}")
-            yield {"id": record.id, "content": record.to_str_for_rag()}
+        # records = list(session.scalars(select(Item).order_by(Item.id)))
+        # for record in records:
+        #     logger.info(f"Processing database record: {record.name}")
+        #     yield f"## Product ID: [{record.id}]\n" + record.to_str_for_rag()
+        # await self.openai_chat_client.chat.completions.create(
 
 
 def source_to_text(source) -> str:
@@ -47,49 +77,76 @@ def answer_formatter(answer, source) -> str:
     return f"{answer} [{source['id']}]"
 
 
-def get_openai_config_dict() -> dict:
-    """Return a dictionary with OpenAI configuration based on environment variables."""
+def get_openai_client() -> tuple[AzureOpenAI | OpenAI, str]:
+    """Return an OpenAI client based on the environment variables"""
+    openai_client: AzureOpenAI | OpenAI
     OPENAI_CHAT_HOST = os.getenv("OPENAI_CHAT_HOST")
     if OPENAI_CHAT_HOST == "azure":
         if api_key := os.getenv("AZURE_OPENAI_KEY"):
             logger.info("Using Azure OpenAI Service with API Key from AZURE_OPENAI_KEY")
-            api_key = os.environ["AZURE_OPENAI_KEY"]
+            openai_client = AzureOpenAI(
+                api_version=os.environ["AZURE_OPENAI_VERSION"],
+                azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
+                api_key=api_key,
+            )
         else:
             logger.info("Using Azure OpenAI Service with Azure Developer CLI Credential")
-            azure_credential = AzureDeveloperCliCredential(process_timeout=60)
-            api_key = azure_credential.get_token("https://cognitiveservices.azure.com/.default").token
-        openai_config = {
-            "api_type": "azure",
-            "api_base": os.environ["AZURE_OPENAI_ENDPOINT"],
-            "api_key": api_key,
-            "api_version": os.environ["AZURE_OPENAI_VERSION"],
-            "deployment": os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"],
-            "model": os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"],
-        }
+            azure_credential = AzureDeveloperCliCredential(process_timeout=60, tenant_id=os.environ["AZURE_TENANT_ID"])
+            token_provider = get_bearer_token_provider(azure_credential, "https://cognitiveservices.azure.com/.default")
+            openai_client = AzureOpenAI(
+                api_version=os.environ["AZURE_OPENAI_VERSION"],
+                azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
+                azure_ad_token_provider=token_provider,
+            )
+        model = os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"]
    elif OPENAI_CHAT_HOST == "ollama":
         raise NotImplementedError("Ollama OpenAI Service is not supported. Switch to Azure or OpenAI.com")
     else:
         logger.info("Using OpenAI Service with API Key from OPENAICOM_KEY")
-        openai_config = {
-            "api_type": "openai",
-            "api_key": os.environ["OPENAICOM_KEY"],
-            "model": os.environ["OPENAICOM_CHAT_MODEL"],
-            "deployment": "none-needed-for-openaicom",
-        }
-    return openai_config
+        openai_client = OpenAI(api_key=os.environ["OPENAICOM_KEY"])
+        model = os.environ["OPENAICOM_CHAT_MODEL"]
+    return openai_client, model
+
+
+def generate_ground_truth_data(num_questions_total: int, num_questions_per_source: int = 5):
+    logger.info("Generating %d questions total", num_questions_total)
+    openai_client, model = get_openai_client()
+    current_dir = Path(__file__).parent
+    generate_prompt = open(current_dir / "generate_prompt.txt").read()
+    output_file = Path(__file__).parent / "ground_truth.jsonl"
+
+    qa: list[dict] = []
+    for source in source_retriever():
+        if len(qa) > num_questions_total:
+            logger.info("Generated enough questions already, stopping")
+            break
+        result = openai_client.chat.completions.create(
+            model=model,
+            messages=[
+                {"role": "system", "content": generate_prompt},
+                {"role": "user", "content": json.dumps(source)},
+            ],
+            tools=[qa_pairs_tool(num_questions=2)],
+        )
+        if not result.choices[0].message.tool_calls:
+            logger.warning("No tool calls found in response, skipping")
+            continue
+        qa_pairs = json.loads(result.choices[0].message.tool_calls[0].function.arguments)["qa_list"]
+        qa_pairs = [{"question": qa_pair["question"], "truth": qa_pair["answer"]} for qa_pair in qa_pairs]
+        qa.extend(qa_pairs)
+
+    logger.info("Writing %d questions to %s", num_questions_total, output_file)
+    directory = Path(output_file).parent
+    if not directory.exists():
+        directory.mkdir(parents=True)
+    with open(output_file, "w", encoding="utf-8") as f:
+        for item in qa[0:num_questions_total]:
+            f.write(json.dumps(item) + "\n")
 
 
 if __name__ == "__main__":
     logging.basicConfig(level=logging.WARNING)
     logger.setLevel(logging.INFO)
-    load_dotenv(".env", override=True)
-
-    generate_test_qa_data(
-        openai_config=get_openai_config_dict(),
-        num_questions_total=202,
-        num_questions_per_source=2,
-        output_file=Path(__file__).parent / "ground_truth.jsonl",
-        source_retriever=source_retriever,
-        source_to_text=source_to_text,
-        answer_formatter=answer_formatter,
-    )
+    load_azd_env()
+
+    generate_ground_truth_data(num_questions_total=10)
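
Since `source_retriever()` now yields plain strings (one chunk per item type, with each record prefixed by `## Product ID: [id]`), a quick way to inspect what the model will see as source text is to print the first chunk. This is a sketch only, assuming the `POSTGRES_*` environment variables the script reads are set and that `evals` is importable from the repository root:

```python
# Not part of this commit: a small local check of the retriever output.
from evals.generate_ground_truth import source_retriever

first_chunk = next(source_retriever())
print(first_chunk[:500])  # each chunk starts with "## Product ID: [<id>]" followed by the record text
```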

evals/generate_prompt.txt

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+Your job is to generate example questions that a customer might ask about the products.
+You should come up with a question and an answer based on the provided data.
+The answer should include the product ID in square brackets.
+For example,
+'What climbing gear do you have?'
+with answer:
+'We have a variety of climbing gear, including ropes, harnesses, and carabiners. [1][2]'.
+Remember that customers probably don't know the names of specific brands,
+so your questions should be more general questions from someone who is shopping for these types of products.
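
The prompt requires every generated answer to cite the product ID in square brackets. A small sanity check along those lines (a sketch, not part of this commit) could be run over the generated `ground_truth.jsonl`:

```python
import json
import re
from pathlib import Path

# Assumed output path, matching evals/generate_ground_truth.py.
for line in Path("evals/ground_truth.jsonl").read_text(encoding="utf-8").splitlines():
    if not line.strip():
        continue
    pair = json.loads(line)
    # Flag answers that never cite a product ID like "[42]".
    if not re.search(r"\[\d+\]", pair["truth"]):
        print("Answer has no bracketed product ID:", pair["question"])
```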
