Port from pydantic-ai to openai-agents SDK #211

Merged: 8 commits, May 11, 2025

.github/workflows/app-tests.yaml (2 changes: 1 addition & 1 deletion)
@@ -123,7 +123,7 @@ jobs:
           key: mypy${{ matrix.os }}-${{ matrix.python_version }}-${{ hashFiles('requirements-dev.txt', 'src/backend/requirements.txt', 'src/backend/pyproject.toml') }}

       - name: Run MyPy
-        run: python3 -m mypy .
+        run: python3 -m mypy . --python-version ${{ matrix.python_version }}

       - name: Run Pytest
         run: python3 -m pytest -s -vv --cov --cov-fail-under=85

pyproject.toml (1 change: 0 additions & 1 deletion)
@@ -7,7 +7,6 @@ lint.isort.known-first-party = ["fastapi_app"]

 [tool.mypy]
 check_untyped_defs = true
-python_version = 3.9
 exclude = [".venv/*"]

 [tool.pytest.ini_options]

requirements-dev.txt (1 change: 0 additions & 1 deletion)
@@ -14,4 +14,3 @@ pytest-snapshot
 locust
 psycopg2
 dotenv-azd
-freezegun

src/backend/fastapi_app/api_models.py (9 changes: 4 additions & 5 deletions)
@@ -1,9 +1,8 @@
 from enum import Enum
-from typing import Any, Optional, Union
+from typing import Any, Optional

-from openai.types.chat import ChatCompletionMessageParam
+from openai.types.responses import ResponseInputItemParam
 from pydantic import BaseModel, Field
-from pydantic_ai.messages import ModelRequest, ModelResponse


 class AIChatRoles(str, Enum):
@@ -37,7 +36,7 @@ class ChatRequestContext(BaseModel):


 class ChatRequest(BaseModel):
-    messages: list[ChatCompletionMessageParam]
+    messages: list[ResponseInputItemParam]
     context: ChatRequestContext
     sessionState: Optional[Any] = None

@@ -96,7 +95,7 @@ class ChatParams(ChatRequestOverrides):
     enable_text_search: bool
     enable_vector_search: bool
     original_user_query: str
-    past_messages: list[Union[ModelRequest, ModelResponse]]
+    past_messages: list[ResponseInputItemParam]


 class Filter(BaseModel):
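
Note: ResponseInputItemParam (from openai.types.responses) is the per-item type of a Responses API input list, so past_messages can now mix plain role/content messages with function-call and function-call-output items. A minimal sketch of such a list, with made-up values that mirror the few-shot file below:

from openai.types.responses import ResponseInputItemParam

# Illustrative only: values are invented, shapes mirror query_fewshots.json below.
past_messages: list[ResponseInputItemParam] = [
    {"role": "user", "content": "good options for climbing gear that can be used outside?"},
    {
        "type": "function_call",
        "call_id": "call_abc123",
        "name": "search_database",
        "arguments": '{"search_query":"climbing gear outside"}',
    },
    {
        "type": "function_call_output",
        "call_id": "call_abc123",
        "output": "Search results for climbing gear that can be used outside: ...",
    },
]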

src/backend/fastapi_app/prompts/query_fewshots.json (84 changes: 22 additions & 62 deletions)
@@ -1,76 +1,36 @@
 [
   {
-    "parts": [
-      {
-        "content": "good options for climbing gear that can be used outside?",
-        "timestamp": "2025-05-07T19:02:46.977501Z",
-        "part_kind": "user-prompt"
-      }
-    ],
-    "instructions": null,
-    "kind": "request"
+    "role": "user",
+    "content": "good options for climbing gear that can be used outside?"
   },
   {
-    "parts": [
-      {
-        "tool_name": "search_database",
-        "args": "{\"search_query\":\"climbing gear outside\"}",
-        "tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
-        "part_kind": "tool-call"
-      }
-    ],
-    "model_name": "gpt-4o-mini-2024-07-18",
-    "timestamp": "2025-05-07T19:02:47Z",
-    "kind": "response"
+    "id": "madeup",
+    "call_id": "call_abc123",
+    "name": "search_database",
+    "arguments": "{\"search_query\":\"climbing gear outside\"}",
+    "type": "function_call"
   },
   {
-    "parts": [
-      {
-        "tool_name": "search_database",
-        "content": "Search results for climbing gear that can be used outside: ...",
-        "tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
-        "timestamp": "2025-05-07T19:02:48.242408Z",
-        "part_kind": "tool-return"
-      }
-    ],
-    "instructions": null,
-    "kind": "request"
+    "id": "madeupoutput",
+    "call_id": "call_abc123",
+    "output": "Search results for climbing gear that can be used outside: ...",
+    "type": "function_call_output"
   },
   {
-    "parts": [
-      {
-        "content": "are there any shoes less than $50?",
-        "timestamp": "2025-05-07T19:02:46.977501Z",
-        "part_kind": "user-prompt"
-      }
-    ],
-    "instructions": null,
-    "kind": "request"
+    "role": "user",
+    "content": "are there any shoes less than $50?"
   },
   {
-    "parts": [
-      {
-        "tool_name": "search_database",
-        "args": "{\"search_query\":\"shoes\",\"price_filter\":{\"comparison_operator\":\"<\",\"value\":50}}",
-        "tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
-        "part_kind": "tool-call"
-      }
-    ],
-    "model_name": "gpt-4o-mini-2024-07-18",
-    "timestamp": "2025-05-07T19:02:47Z",
-    "kind": "response"
+    "id": "madeup",
+    "call_id": "call_abc456",
+    "name": "search_database",
+    "arguments": "{\"search_query\":\"shoes\",\"price_filter\":{\"comparison_operator\":\"<\",\"value\":50}}",
+    "type": "function_call"
   },
   {
-    "parts": [
-      {
-        "tool_name": "search_database",
-        "content": "Search results for shoes cheaper than 50: ...",
-        "tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
-        "timestamp": "2025-05-07T19:02:48.242408Z",
-        "part_kind": "tool-return"
-      }
-    ],
-    "instructions": null,
-    "kind": "request"
+    "id": "madeupoutput",
+    "call_id": "call_abc456",
+    "output": "Search results for shoes cheaper than 50: ...",
+    "type": "function_call_output"
   }
 ]
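
The few-shot messages are now plain Responses API input items rather than pydantic-ai ModelRequest/ModelResponse objects: each function_call item is paired with its function_call_output by call_id, so the file can be loaded with json.loads and prepended directly to the agent input, as rag_advanced.py does below. A rough sketch of that flow (the query text here is invented):

import json
from pathlib import Path

from openai.types.responses import EasyInputMessageParam, ResponseInputItemParam

# Load the few-shot items shown above and append a new user turn;
# the combined list is what gets passed as input= to Runner.run().
few_shots: list[ResponseInputItemParam] = json.loads(
    Path("src/backend/fastapi_app/prompts/query_fewshots.json").read_text()
)
new_user_message = EasyInputMessageParam(role="user", content="Find search results for user query: waterproof tents")
all_messages = few_shots + [new_user_message]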

src/backend/fastapi_app/rag_advanced.py (146 changes: 80 additions & 66 deletions)
@@ -1,13 +1,19 @@
+import json
 from collections.abc import AsyncGenerator
 from typing import Optional, Union

+from agents import (
+    Agent,
+    ItemHelpers,
+    ModelSettings,
+    OpenAIChatCompletionsModel,
+    Runner,
+    ToolCallOutputItem,
+    function_tool,
+    set_tracing_disabled,
+)
 from openai import AsyncAzureOpenAI, AsyncOpenAI
-from openai.types.chat import ChatCompletionMessageParam
-from pydantic_ai import Agent, RunContext
-from pydantic_ai.messages import ModelMessagesTypeAdapter
-from pydantic_ai.models.openai import OpenAIModel
-from pydantic_ai.providers.openai import OpenAIProvider
-from pydantic_ai.settings import ModelSettings
+from openai.types.responses import EasyInputMessageParam, ResponseInputItemParam, ResponseTextDeltaEvent

 from fastapi_app.api_models import (
     AIChatRoles,
@@ -24,7 +30,9 @@
     ThoughtStep,
 )
 from fastapi_app.postgres_searcher import PostgresSearcher
-from fastapi_app.rag_base import ChatParams, RAGChatBase
+from fastapi_app.rag_base import RAGChatBase

+set_tracing_disabled(disabled=True)
+

 class AdvancedRAGChat(RAGChatBase):
@@ -34,7 +42,7 @@ class AdvancedRAGChat(RAGChatBase):
     def __init__(
         self,
         *,
-        messages: list[ChatCompletionMessageParam],
+        messages: list[ResponseInputItemParam],
         overrides: ChatRequestOverrides,
         searcher: PostgresSearcher,
         openai_chat_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
@@ -46,34 +54,29 @@ def __init__(
         self.model_for_thoughts = (
             {"model": chat_model, "deployment": chat_deployment} if chat_deployment else {"model": chat_model}
         )
-        pydantic_chat_model = OpenAIModel(
-            chat_model if chat_deployment is None else chat_deployment,
-            provider=OpenAIProvider(openai_client=openai_chat_client),
+        openai_agents_model = OpenAIChatCompletionsModel(
+            model=chat_model if chat_deployment is None else chat_deployment, openai_client=openai_chat_client
         )
-        self.search_agent = Agent[ChatParams, SearchResults](
-            pydantic_chat_model,
-            model_settings=ModelSettings(
-                temperature=0.0,
-                max_tokens=500,
-                **({"seed": self.chat_params.seed} if self.chat_params.seed is not None else {}),
-            ),
-            system_prompt=self.query_prompt_template,
-            tools=[self.search_database],
-            output_type=SearchResults,
+        self.search_agent = Agent(
+            name="Searcher",
+            instructions=self.query_prompt_template,
+            tools=[function_tool(self.search_database)],
+            tool_use_behavior="stop_on_first_tool",
+            model=openai_agents_model,
         )
         self.answer_agent = Agent(
-            pydantic_chat_model,
-            system_prompt=self.answer_prompt_template,
+            name="Answerer",
+            instructions=self.answer_prompt_template,
+            model=openai_agents_model,
             model_settings=ModelSettings(
                 temperature=self.chat_params.temperature,
                 max_tokens=self.chat_params.response_token_limit,
-                **({"seed": self.chat_params.seed} if self.chat_params.seed is not None else {}),
+                extra_body={"seed": self.chat_params.seed} if self.chat_params.seed is not None else {},
             ),
         )

     async def search_database(
         self,
-        ctx: RunContext[ChatParams],
         search_query: str,
         price_filter: Optional[PriceFilter] = None,
         brand_filter: Optional[BrandFilter] = None,
@@ -97,66 +100,73 @@ async def search_database(
             filters.append(brand_filter)
         results = await self.searcher.search_and_embed(
             search_query,
-            top=ctx.deps.top,
-            enable_vector_search=ctx.deps.enable_vector_search,
-            enable_text_search=ctx.deps.enable_text_search,
+            top=self.chat_params.top,
+            enable_vector_search=self.chat_params.enable_vector_search,
+            enable_text_search=self.chat_params.enable_text_search,
             filters=filters,
         )
         return SearchResults(
             query=search_query, items=[ItemPublic.model_validate(item.to_dict()) for item in results], filters=filters
         )

     async def prepare_context(self) -> tuple[list[ItemPublic], list[ThoughtStep]]:
-        few_shots = ModelMessagesTypeAdapter.validate_json(self.query_fewshots)
+        few_shots: list[ResponseInputItemParam] = json.loads(self.query_fewshots)
         user_query = f"Find search results for user query: {self.chat_params.original_user_query}"
-        results = await self.search_agent.run(
-            user_query,
-            message_history=few_shots + self.chat_params.past_messages,
-            deps=self.chat_params,
-        )
-        items = results.output.items
+        new_user_message = EasyInputMessageParam(role="user", content=user_query)
+        all_messages = few_shots + self.chat_params.past_messages + [new_user_message]
+
+        run_results = await Runner.run(self.search_agent, input=all_messages)
+        most_recent_response = run_results.new_items[-1]
+        if isinstance(most_recent_response, ToolCallOutputItem):
+            search_results = most_recent_response.output
+        else:
+            raise ValueError("Error retrieving search results, model did not call tool properly")

         thoughts = [
             ThoughtStep(
                 title="Prompt to generate search arguments",
-                description=results.all_messages(),
+                description=[{"content": self.query_prompt_template}]
+                + ItemHelpers.input_to_new_input_list(run_results.input),
                 props=self.model_for_thoughts,
             ),
             ThoughtStep(
                 title="Search using generated search arguments",
-                description=results.output.query,
+                description=search_results.query,
                 props={
                     "top": self.chat_params.top,
                     "vector_search": self.chat_params.enable_vector_search,
                     "text_search": self.chat_params.enable_text_search,
-                    "filters": results.output.filters,
+                    "filters": search_results.filters,
                 },
             ),
             ThoughtStep(
                 title="Search results",
-                description=items,
+                description=search_results.items,
             ),
         ]
-        return items, thoughts
+        return search_results.items, thoughts

     async def answer(
         self,
         items: list[ItemPublic],
         earlier_thoughts: list[ThoughtStep],
     ) -> RetrievalResponse:
-        response = await self.answer_agent.run(
-            user_prompt=self.prepare_rag_request(self.chat_params.original_user_query, items),
-            message_history=self.chat_params.past_messages,
+        run_results = await Runner.run(
+            self.answer_agent,
+            input=self.chat_params.past_messages
+            + [{"content": self.prepare_rag_request(self.chat_params.original_user_query, items), "role": "user"}],
         )

         return RetrievalResponse(
-            message=Message(content=str(response.output), role=AIChatRoles.ASSISTANT),
+            message=Message(content=str(run_results.final_output), role=AIChatRoles.ASSISTANT),
             context=RAGContext(
                 data_points={item.id: item for item in items},
                 thoughts=earlier_thoughts
                 + [
                     ThoughtStep(
                         title="Prompt to generate answer",
-                        description=response.all_messages(),
+                        description=[{"content": self.answer_prompt_template}]
+                        + ItemHelpers.input_to_new_input_list(run_results.input),
                         props=self.model_for_thoughts,
                     ),
                 ],
@@ -168,24 +178,28 @@ async def answer_stream(
         items: list[ItemPublic],
         earlier_thoughts: list[ThoughtStep],
     ) -> AsyncGenerator[RetrievalResponseDelta, None]:
-        async with self.answer_agent.run_stream(
-            self.prepare_rag_request(self.chat_params.original_user_query, items),
-            message_history=self.chat_params.past_messages,
-        ) as agent_stream_runner:
-            yield RetrievalResponseDelta(
-                context=RAGContext(
-                    data_points={item.id: item for item in items},
-                    thoughts=earlier_thoughts
-                    + [
-                        ThoughtStep(
-                            title="Prompt to generate answer",
-                            description=agent_stream_runner.all_messages(),
-                            props=self.model_for_thoughts,
-                        ),
-                    ],
-                ),
-            )
-
-            async for message in agent_stream_runner.stream_text(delta=True, debounce_by=None):
-                yield RetrievalResponseDelta(delta=Message(content=str(message), role=AIChatRoles.ASSISTANT))
-        return
+        run_results = Runner.run_streamed(
+            self.answer_agent,
+            input=self.chat_params.past_messages
+            + [{"content": self.prepare_rag_request(self.chat_params.original_user_query, items), "role": "user"}],  # noqa
+        )
+
+        yield RetrievalResponseDelta(
+            context=RAGContext(
+                data_points={item.id: item for item in items},
+                thoughts=earlier_thoughts
+                + [
+                    ThoughtStep(
+                        title="Prompt to generate answer",
+                        description=[{"content": self.answer_prompt_template}]
+                        + ItemHelpers.input_to_new_input_list(run_results.input),
+                        props=self.model_for_thoughts,
+                    ),
+                ],
+            ),
+        )
+
+        async for event in run_results.stream_events():
+            if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
+                yield RetrievalResponseDelta(delta=Message(content=str(event.data.delta), role=AIChatRoles.ASSISTANT))
+        return
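
The overall pattern in this file: plain async methods are wrapped with function_tool, the search agent is forced to stop after its first tool call (tool_use_behavior="stop_on_first_tool") so the tool's return value becomes the run output, and that value is read back from the final ToolCallOutputItem. A stripped-down, standalone sketch of the same pattern (the tool body and model name are placeholders, not part of this PR):

import asyncio

from agents import Agent, Runner, ToolCallOutputItem, function_tool, set_tracing_disabled

set_tracing_disabled(disabled=True)


@function_tool
async def search_database(search_query: str) -> str:
    """Placeholder tool: the real one queries PostgreSQL and returns SearchResults."""
    return f"Search results for {search_query}: ..."


search_agent = Agent(
    name="Searcher",
    instructions="Generate a search_database call for the user's question.",
    tools=[search_database],
    tool_use_behavior="stop_on_first_tool",
    model="gpt-4o-mini",  # placeholder model name
)


async def main() -> None:
    run_results = await Runner.run(search_agent, input="any climbing shoes under $50?")
    last_item = run_results.new_items[-1]
    if isinstance(last_item, ToolCallOutputItem):
        # With stop_on_first_tool, the first tool's return value is the run's final output.
        print(last_item.output)


asyncio.run(main())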