4 changes: 2 additions & 2 deletions pyproject.toml

@@ -26,7 +26,7 @@ opentelemetry = [
 ]
 pydantic = ["pydantic>=2.0.0,<3"]
 openai-agents = [
-    "openai-agents>=0.2.3,<=0.2.9", # 0.2.10 doesn't work: https://github.com/openai/openai-agents-python/issues/1639
+    "openai-agents>=0.2.11,<0.3",
     "eval-type-backport>=0.2.2; python_version < '3.10'"
 ]

@@ -57,7 +57,7 @@ dev = [
     "pytest-cov>=6.1.1",
     "httpx>=0.28.1",
     "pytest-pretty>=1.3.0",
-    "openai-agents[litellm]>=0.2.3,<=0.2.9", # 0.2.10 doesn't work: https://github.com/openai/openai-agents-python/issues/1639
+    "openai-agents[litellm]>=0.2.11,<0.3"
 ]

 [tool.poe.tasks]
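The dependency change above drops the temporary `<=0.2.9` pin (a workaround for the 0.2.10 regression linked in the removed comment) in favor of `>=0.2.11,<0.3`. If a test or setup script wants to confirm that an installed `openai-agents` falls inside the new range, a minimal sketch could look like the following; the version bounds come from the diff, while the helper name is hypothetical and `packaging` is assumed to be available:

import importlib.metadata

from packaging.specifiers import SpecifierSet


def check_openai_agents_version() -> None:
    # Bounds copied from the new pyproject.toml constraint: >=0.2.11,<0.3
    required = SpecifierSet(">=0.2.11,<0.3")
    # Raises PackageNotFoundError if openai-agents is not installed at all.
    installed = importlib.metadata.version("openai-agents")
    if installed not in required:
        raise RuntimeError(
            f"openai-agents {installed} is outside the supported range {required}"
        )


check_openai_agents_version()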
2 changes: 2 additions & 0 deletions temporalio/contrib/openai_agents/_invoke_model_activity.py

@@ -144,6 +144,7 @@ class ActivityModelInput(TypedDict, total=False):
     handoffs: list[HandoffInput]
     tracing: Required[ModelTracingInput]
     previous_response_id: Optional[str]
+    conversation_id: Optional[str]
     prompt: Optional[Any]


@@ -226,6 +227,7 @@ def make_tool(tool: ToolInput) -> Tool:
                 handoffs=handoffs,
                 tracing=ModelTracing(input["tracing"]),
                 previous_response_id=input.get("previous_response_id"),
+                conversation_id=input.get("conversation_id"),
                 prompt=input.get("prompt"),
             )
         except APIStatusError as e:
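Because `ActivityModelInput` is a `total=False` TypedDict, adding `conversation_id` is backward compatible: payloads produced by older workflow code simply omit the key, and the activity reads it with `.get()`, which yields `None`. A minimal, self-contained sketch of that behavior, using a hypothetical trimmed-down TypedDict rather than the real one:

from typing import Any, Optional, TypedDict


# Hypothetical stand-in for ActivityModelInput; the real TypedDict in
# _invoke_model_activity.py carries many more fields.
class _ModelInputSketch(TypedDict, total=False):
    previous_response_id: Optional[str]
    conversation_id: Optional[str]
    prompt: Optional[Any]


def forward(input: _ModelInputSketch) -> dict:
    # Mirrors how the activity forwards fields: .get() returns None when the
    # caller never set conversation_id, so older payloads keep working.
    return {
        "previous_response_id": input.get("previous_response_id"),
        "conversation_id": input.get("conversation_id"),
        "prompt": input.get("prompt"),
    }


print(forward({"prompt": "hello"}))  # conversation_id comes back as None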
3 changes: 3 additions & 0 deletions temporalio/contrib/openai_agents/_temporal_model_stub.py

@@ -68,6 +68,7 @@ async def get_response(
         tracing: ModelTracing,
         *,
         previous_response_id: Optional[str],
+        conversation_id: Optional[str],
         prompt: Optional[ResponsePromptParam],
     ) -> ModelResponse:
         def make_tool_info(tool: Tool) -> ToolInput:
@@ -134,6 +135,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
             handoffs=handoff_infos,
             tracing=ModelTracingInput(tracing.value),
             previous_response_id=previous_response_id,
+            conversation_id=conversation_id,
             prompt=prompt,
         )

@@ -178,6 +180,7 @@ def stream_response(
         tracing: ModelTracing,
         *,
         previous_response_id: Optional[str],
+        conversation_id: Optional[str],
         prompt: ResponsePromptParam | None,
     ) -> AsyncIterator[TResponseStreamEvent]:
         raise NotImplementedError("Temporal model doesn't support streams yet")
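The stub accepts the new keyword-only `conversation_id` on the workflow side and copies it verbatim into the activity input, the same way it already forwards `previous_response_id`. A rough sketch of that forwarding pattern, with hypothetical names and a much smaller payload than the real `ActivityModelInput`:

from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class _ActivityPayload:
    # Hypothetical, simplified payload; the real activity input has many more fields.
    previous_response_id: Optional[str] = None
    conversation_id: Optional[str] = None
    extra: dict[str, Any] = field(default_factory=dict)


class _ForwardingStub:
    """Sketch of the stub pattern: accept the model call in the workflow and
    turn every parameter, including conversation_id, into activity input."""

    def build_payload(
        self,
        *,
        previous_response_id: Optional[str],
        conversation_id: Optional[str],
        **extra: Any,
    ) -> _ActivityPayload:
        # Nothing is interpreted here; the stub only forwards what it was given.
        return _ActivityPayload(
            previous_response_id=previous_response_id,
            conversation_id=conversation_id,
            extra=extra,
        )


payload = _ForwardingStub().build_payload(
    previous_response_id=None, conversation_id="conv-123", prompt=None
)
print(payload.conversation_id)  # "conv-123"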
8 changes: 2 additions & 6 deletions temporalio/contrib/openai_agents/_temporal_openai_agents.py

@@ -126,9 +126,7 @@ async def get_response(
         output_schema: Union[AgentOutputSchemaBase, None],
         handoffs: list[Handoff],
         tracing: ModelTracing,
-        *,
-        previous_response_id: Union[str, None],
-        prompt: Union[ResponsePromptParam, None] = None,
+        **kwargs,
     ) -> ModelResponse:
         """Get a response from the model."""
         return self.fn()
@@ -142,9 +140,7 @@ def stream_response(
         output_schema: Optional[AgentOutputSchemaBase],
         handoffs: list[Handoff],
         tracing: ModelTracing,
-        *,
-        previous_response_id: Optional[str],
-        prompt: Optional[ResponsePromptParam],
+        **kwargs,
     ) -> AsyncIterator[TResponseStreamEvent]:
         """Get a streamed response from the model. Unimplemented."""
         raise NotImplementedError()
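Here the test model stops enumerating optional keyword parameters and accepts `**kwargs` instead, so future additions to the upstream `Model` interface (such as `conversation_id` in this release) cannot break its signature again. A minimal runnable sketch of the same idea, with hypothetical names in place of the real test model:

import asyncio
from typing import Any


class _StaticTestModel:
    """Hypothetical test double in the spirit of the change above: every
    optional model parameter is swallowed by **kwargs, so new upstream
    keywords such as conversation_id never break the signature."""

    def __init__(self, canned_response: str) -> None:
        self._canned_response = canned_response

    async def get_response(
        self, system_instructions: Any, input: Any, **kwargs: Any
    ) -> str:
        # Extra keywords (previous_response_id, conversation_id, prompt, ...)
        # land in kwargs and are deliberately ignored.
        return self._canned_response


model = _StaticTestModel("canned answer")
print(asyncio.run(model.get_response(None, "question", conversation_id="conv-1")))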
13 changes: 5 additions & 8 deletions tests/contrib/openai_agents/test_openai.py

@@ -1164,8 +1164,9 @@ async def get_response(
         output_schema: Union[AgentOutputSchemaBase, None],
         handoffs: list[Handoff],
         tracing: ModelTracing,
-        previous_response_id: Union[str, None],
-        prompt: Union[ResponsePromptParam, None] = None,
+        previous_response_id: Optional[str] = None,
+        conversation_id: Optional[str] = None,
+        prompt: Optional[ResponsePromptParam] = None,
     ) -> ModelResponse:
         if (
             system_instructions
@@ -1553,9 +1554,7 @@ async def get_response(
         output_schema: Union[AgentOutputSchemaBase, None],
         handoffs: list[Handoff],
         tracing: ModelTracing,
-        *,
-        previous_response_id: Union[str, None],
-        prompt: Union[ResponsePromptParam, None] = None,
+        **kwargs,
     ) -> ModelResponse:
         activity.logger.info("Waiting")
         await asyncio.sleep(1.0)
@@ -1571,9 +1570,7 @@ def stream_response(
         output_schema: Optional[AgentOutputSchemaBase],
         handoffs: list[Handoff],
         tracing: ModelTracing,
-        *,
-        previous_response_id: Optional[str],
-        prompt: Optional[ResponsePromptParam],
+        **kwargs,
     ) -> AsyncIterator[TResponseStreamEvent]:
         raise NotImplementedError()

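The test doubles follow the same two strategies: one declares `conversation_id` as an explicit optional keyword with a `None` default, the others fall back to `**kwargs`. A small, self-contained sketch (hypothetical class, not the actual test fixture) of the explicit-keyword variant, plus the kind of defensive check a test could make with `inspect.signature`:

import inspect
from typing import Any, Optional


class _ResponderSketch:
    """Hypothetical miniature of the updated test model: conversation_id is an
    optional keyword with a None default, so existing call sites keep working."""

    async def get_response(
        self,
        system_instructions: Optional[str],
        input: Any,
        previous_response_id: Optional[str] = None,
        conversation_id: Optional[str] = None,
        prompt: Optional[Any] = None,
    ) -> str:
        return "ok"


# The new parameter exists and callers are not forced to supply it.
sig = inspect.signature(_ResponderSketch.get_response)
param = sig.parameters["conversation_id"]
assert param.default is None
print("conversation_id accepted, default:", param.default)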
16 changes: 8 additions & 8 deletions uv.lock

Some generated files (uv.lock) are not rendered by default.