
Commit f815280
Merge branch 'feature/include-pr-197'
Committed Jun 17, 2025 · 2 parents e12bc2f + 239b00a

File tree: 4 files changed, +636 −3 lines
 

README.md
Lines changed: 12 additions & 0 deletions

````diff
@@ -118,6 +118,7 @@ from strands import Agent
 from strands.models import BedrockModel
 from strands.models.ollama import OllamaModel
 from strands.models.llamaapi import LlamaAPIModel
+from strands.models.portkey import PortkeyModel

 # Bedrock
 bedrock_model = BedrockModel(
@@ -142,6 +143,17 @@ llama_model = LlamaAPIModel(
 )
 agent = Agent(model=llama_model)
 response = agent("Tell me about Agentic AI")
+
+# Portkey for all models
+portkey_model = PortkeyModel(
+    api_key="<PORTKEY_API_KEY>",
+    model_id="anthropic.claude-3-5-sonnet-20241022-v2:0",
+    virtual_key="<BEDROCK_VIRTUAL_KEY>",
+    provider="bedrock",
+    base_url="http://portkey-service-gateway.service.prod.example.com/v1",
+)
+agent = Agent(model=portkey_model)
+response = agent("Tell me about Agentic AI")
 ```

 Built-in providers:
````
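The README example above routes to Bedrock through Portkey. Because the provider and credentials are plain constructor arguments, the same pattern reaches any upstream the gateway fronts. A hedged variant targeting OpenAI; the virtual key and gateway URL below are illustrative placeholders, not values from this commit:

```python
from strands import Agent
from strands.models.portkey import PortkeyModel

# Same gateway, different upstream provider; all values are placeholders.
openai_via_portkey = PortkeyModel(
    api_key="<PORTKEY_API_KEY>",
    model_id="gpt-4o",
    virtual_key="<OPENAI_VIRTUAL_KEY>",
    provider="openai",
    base_url="<PORTKEY_GATEWAY_BASE_URL>",
)

agent = Agent(model=openai_via_portkey)
response = agent("Tell me about Agentic AI")
```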

pyproject.toml
Lines changed: 10 additions & 3 deletions

```diff
@@ -49,6 +49,8 @@ packages = ["src/strands"]
 anthropic = [
     "anthropic>=0.21.0,<1.0.0",
 ]
+# Optional dependencies for different AI providers
+
 dev = [
     "commitizen>=4.4.0,<5.0.0",
     "hatch>=1.0.0,<2.0.0",
@@ -88,12 +90,17 @@ a2a = [
     "starlette>=0.46.2",
 ]

+portkey = [
+    "portkey-ai>=1.0.0,<2.0.0",
+]
+
 [tool.hatch.version]
 # Tells Hatch to use your version control system (git) to determine the version.
 source = "vcs"

 [tool.hatch.envs.hatch-static-analysis]
-features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel"]
+features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "portkey"]
+
 dependencies = [
     "mypy>=1.15.0,<2.0.0",
     "ruff>=0.11.6,<0.12.0",
@@ -116,7 +123,7 @@ lint-fix = [
 ]

 [tool.hatch.envs.hatch-test]
-features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel"]
+features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "portkey"]
 extra-dependencies = [
     "moto>=5.1.0,<6.0.0",
     "pytest>=8.0.0,<9.0.0",
@@ -132,7 +139,7 @@ extra-args = [

 [tool.hatch.envs.dev]
 dev-mode = true
-features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel"]
+features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel", "portkey"]

 [tool.hatch.envs.a2a]
 dev-mode = true
```
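Since `portkey` ships as an optional dependency group, importing `strands.models.portkey` without `portkey-ai` installed fails at its top-level `from portkey_ai import Portkey` statement. A minimal sketch of the guard a consumer might add; the error wording is an assumption, not part of this commit:

```python
# Hypothetical import guard for the optional Portkey extra; not from this commit.
try:
    from strands.models.portkey import PortkeyModel
except ImportError as exc:
    raise ImportError(
        "PortkeyModel requires the optional 'portkey' dependency group "
        "(portkey-ai>=1.0.0,<2.0.0); install it before constructing the model."
    ) from exc
```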

src/strands/models/portkey.py
Lines changed: 386 additions & 0 deletions (new file)

```python
"""Implementation of the Portkey model provider integration."""

import json
import logging
import uuid
from typing import Any, Dict, Iterable, List, Optional, cast

from portkey_ai import Portkey
from typing_extensions import TypedDict, override

from ..types.content import Messages
from ..types.exceptions import ContextWindowOverflowException
from ..types.models import Model
from ..types.streaming import StreamEvent
from ..types.tools import ToolSpec

# Configure logger for debug-level output
logger = logging.getLogger(__name__)


class PortkeyModel(Model):
    """Portkey model provider implementation."""

    class PortkeyConfig(TypedDict, total=False):
        """Configuration schema for the Portkey model."""

        api_key: str
        virtual_key: str
        base_url: str
        model_id: str
        provider: str
        streaming: bool

    def __init__(self, **model_config: PortkeyConfig):
        """Initialize the Portkey model provider.

        Sets up the model configuration and initializes the Portkey client.

        Args:
            **model_config (PortkeyConfig): Configuration parameters for the model.
        """
        self.config = PortkeyModel.PortkeyConfig()
        self.config["streaming"] = True
        self.update_config(**model_config)

        # Extract the provider (bedrock, openai, anthropic, etc.) from model_config.
        self.provider: str = str(model_config["provider"])

        logger.debug("PortkeyModel initialized with config: %s", self.config)

        self.client = Portkey(
            api_key=self.config["api_key"],
            virtual_key=self.config["virtual_key"],
            base_url=self.config["base_url"],
            model=self.config["model_id"],
        )
        self._current_tool_use_id: Optional[str] = None
        self._current_tool_name: Optional[str] = None
        self._current_tool_args = ""

    @override
    def update_config(self, **model_config: PortkeyConfig) -> None:
        """Update the model configuration.

        Args:
            **model_config (PortkeyConfig): Configuration parameters to update.
        """
        logger.debug("Updating config with: %s", model_config)
        self.config.update(cast(PortkeyModel.PortkeyConfig, model_config))

    @override
    def get_config(self) -> PortkeyConfig:
        """Retrieve the current model configuration.

        Returns:
            PortkeyConfig: The current configuration dictionary.
        """
        logger.debug("Retrieving current model config")
        return self.config

    @override
    def format_request(
        self,
        messages: Messages,
        tool_specs: Optional[List[ToolSpec]] = None,
        system_prompt: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Format the input messages and tool specifications into a request dictionary.

        Prepares the messages, system prompt, and tool specifications into the format
        required by the Portkey client for streaming chat completions.

        Args:
            messages (Messages): List of messages to format.
            tool_specs (Optional[List[ToolSpec]]): Optional list of tool specifications.
            system_prompt (Optional[str]): Optional system prompt string.

        Returns:
            Dict[str, Any]: Formatted request dictionary.
        """
        formatted_messages = []
        for msg in messages:
            role = msg.get("role")
            content = msg.get("content")
            if role in ("user", "assistant") and content:
                formatted_messages.extend(self._format_message_parts(role, content))

        if system_prompt:
            formatted_messages.insert(0, {"role": "system", "content": system_prompt})

        request = {
            "messages": formatted_messages,
            "model": self.config["model_id"],
            "stream": True,
        }

        allow_tools = self._allow_tool_use()

        if tool_specs and allow_tools:
            tool_calls = self._map_tools(tool_specs)
        else:
            tool_calls = None

        if tool_calls:
            request["tools"] = tool_calls
            request["tool_choice"] = "auto"
        logger.debug("Formatted Portkey request: %s", json.dumps(request, default=str)[:300])
        return request

    def _allow_tool_use(self) -> bool:
        """Determine whether tool use is allowed based on provider and model.

        Returns:
            bool: True if tool use is allowed for the current provider and model.
        """
        provider = str(self.provider).lower()
        if provider == "openai":
            return True
        if provider == "bedrock":
            model_id = self.config.get("model_id", "").lower()
            return "anthropic" in model_id
        return False

    @override
    def stream(self, request: Dict[str, Any]) -> Iterable[Any]:
        """Stream responses from the Portkey client based on the request.

        Args:
            request (Dict[str, Any]): The formatted request dictionary.

        Returns:
            Iterable[Any]: An iterable stream of response events.

        Raises:
            ContextWindowOverflowException: If the context window is exceeded.
        """
        try:
            return iter(self.client.chat.completions.create(**request))
        except ContextWindowOverflowException:
            logger.error("Context window exceeded for request: %s", request)
            raise

    @override
    def format_chunk(self, event: Any) -> StreamEvent:
        """Format a single response event into a stream event for Strands Agents.

        Converts the raw event from the Portkey client into the structured stream event
        format expected downstream.

        Args:
            event (Any): The raw response event from the model.

        Returns:
            StreamEvent: The formatted stream event dictionary.
        """
        choice = event.get("choices", [{}])[0]
        delta = choice.get("delta", {})

        tool_calls = delta.get("tool_calls")
        if tool_calls:
            tool_call = tool_calls[0]
            tool_name = tool_call.get("function", {}).get("name")
            call_type = tool_call.get("type")
            arguments_chunk = tool_call.get("function", {}).get("arguments", "")
            if tool_name and call_type and not self._current_tool_name:
                self._current_tool_name = tool_name
                self._current_tool_use_id = f"{tool_name}-{uuid.uuid4().hex[:6]}"
                self._current_tool_args = arguments_chunk
                return cast(
                    StreamEvent,
                    {
                        "contentBlockStart": {
                            "start": {
                                "toolUse": {
                                    "name": self._current_tool_name,
                                    "toolUseId": self._current_tool_use_id,
                                }
                            }
                        }
                    },
                )

            if arguments_chunk:
                return cast(StreamEvent, {"contentBlockDelta": {"delta": {"toolUse": {"input": arguments_chunk}}}})

        if choice.get("finish_reason") in ("tool_calls", "tool_use"):
            return cast(
                StreamEvent,
                {
                    "contentBlockStop": {
                        "name": self._current_tool_name,
                        "toolUseId": self._current_tool_use_id,
                    }
                },
            )

        if delta.get("content"):
            return cast(StreamEvent, {"contentBlockDelta": {"delta": {"text": delta["content"]}}})
        elif event.get("usage"):
            usage_data = event["usage"]
            return cast(
                StreamEvent,
                {
                    "metadata": {
                        "metrics": {"latencyMs": 0},
                        "usage": {
                            "inputTokens": usage_data["prompt_tokens"],
                            "outputTokens": usage_data["completion_tokens"],
                            "totalTokens": usage_data["total_tokens"],
                        },
                    }
                },
            )
        return cast(StreamEvent, {})

    @override
    def converse(
        self,
        messages: Messages,
        tool_specs: Optional[list[ToolSpec]] = None,
        system_prompt: Optional[str] = None,
    ) -> Iterable[StreamEvent]:
        """Converse with the model by streaming formatted message chunks.

        Handles the full lifecycle of conversing with the model, including formatting
        the request, sending it, and yielding formatted response chunks.

        Args:
            messages (Messages): List of message objects to be processed by the model.
            tool_specs (Optional[list[ToolSpec]]): List of tool specifications available to the model.
            system_prompt (Optional[str]): System prompt to provide context to the model.

        Yields:
            Iterable[StreamEvent]: Formatted message chunks from the model.

        Raises:
            ModelThrottledException: When the model service is throttling requests from the client.
        """
        logger.debug("formatting request")
        request = self.format_request(messages, tool_specs, system_prompt)

        logger.debug("invoking model %s", request)
        response = self.stream(request)
        logger.debug("streaming response from model %s", response)

        yield cast(StreamEvent, {"messageStart": {"role": "assistant"}})

        for event in response:
            yield self.format_chunk(event)

            if self._should_terminate_with_tool_use(event):
                yield cast(StreamEvent, {"messageStop": {"stopReason": "tool_use"}})
        logger.debug("finished streaming response from model")

        self._current_tool_use_id = None
        self._current_tool_name = None
        self._current_tool_args = ""

    @staticmethod
    def _should_terminate_with_tool_use(event: dict) -> bool:
        """Determine whether the stream should terminate due to a tool use.

        This accounts for inconsistencies across providers: some may return a 'tool_calls'
        payload but label the finish_reason as 'stop' instead of 'tool_calls'.

        Args:
            event (dict): The raw event from the model.

        Returns:
            bool: True if the event indicates a tool use termination.
        """
        choice = event.get("choices", [{}])[0]
        finish_reason = (choice.get("finish_reason") or "").lower()
        return finish_reason in ["tool_calls", "tool_use"]

    def _format_tool_use_part(self, part: dict) -> dict:
        """Format a tool use part of a message into the standard dictionary format.

        Args:
            part (dict): The part of the message representing a tool use.

        Returns:
            dict: Formatted dictionary representing the tool use.
        """
        logger.debug("Formatting tool use part: %s", part)
        self._current_tool_use_id = part["toolUse"]["toolUseId"]
        return {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": self._current_tool_use_id,
                    "type": "function",
                    "function": {"name": part["toolUse"]["name"], "arguments": json.dumps(part["toolUse"]["input"])},
                }
            ],
            "content": None,
        }

    def _format_tool_result_part(self, part: dict) -> dict:
        """Format a tool result part of a message into the standard dictionary format.

        Args:
            part (dict): The part of the message representing a tool result.

        Returns:
            dict: Formatted dictionary representing the tool result.
        """
        logger.debug("Formatting tool result part: %s", part)
        result_text = " ".join([c["text"] for c in part["toolResult"]["content"] if "text" in c])
        return {"role": "tool", "tool_call_id": self._current_tool_use_id, "content": result_text}

    def _format_message_parts(self, role: str, content: Any) -> List[Dict[str, Any]]:
        """Format message parts into a list of standardized message dictionaries.

        Handles plain text content as well as structured parts including tool uses and results.

        Args:
            role (str): The role of the message sender (e.g., 'user', 'assistant').
            content (Any): The content of the message; can be a string or a list of parts.

        Returns:
            List[Dict[str, Any]]: List of formatted message dictionaries.
        """
        logger.debug("Formatting message parts for role '%s' with content: %s", role, content)
        parts = []
        if isinstance(content, str):
            parts.append({"role": role, "content": content})
        elif isinstance(content, list):
            for part in content:
                if "text" in part and isinstance(part["text"], str):
                    parts.append({"role": role, "content": part["text"]})
                elif "toolUse" in part:
                    parts.append(self._format_tool_use_part(part))
                elif "toolResult" in part and self._current_tool_use_id:
                    parts.append(self._format_tool_result_part(part))
        return parts

    @staticmethod
    def _map_tools(tool_specs: List[ToolSpec]) -> List[Dict[str, Any]]:
        """Map tool specifications to the format expected by Portkey.

        Args:
            tool_specs (List[ToolSpec]): List of tool specifications.

        Returns:
            List[Dict[str, Any]]: Mapped list of tool dictionaries.
        """
        logger.debug("Mapping tool specs: %s", tool_specs)
        return [
            {
                "type": "function",
                "function": {
                    "name": spec["name"],
                    "description": spec["description"],
                    "parameters": {
                        "type": "object",
                        "properties": {
                            k: {key: value for key, value in v.items() if key != "default" or value is not None}
                            for k, v in spec["inputSchema"]["json"].get("properties", {}).items()
                        },
                        "required": spec["inputSchema"]["json"].get("required", []),
                    },
                },
            }
            for spec in tool_specs
        ]
```
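For review context, a minimal sketch of how the event stream produced by `converse` is consumed, based only on the class above. A real run would hit a live gateway, so the credentials and model id are placeholders; the event shapes (`messageStart`, `contentBlockDelta`, `messageStop`) come straight from `format_chunk` and `converse`:

```python
from strands.models.portkey import PortkeyModel

# Placeholder credentials; a real run needs a reachable Portkey gateway.
model = PortkeyModel(
    api_key="<PORTKEY_API_KEY>",
    virtual_key="<OPENAI_VIRTUAL_KEY>",
    base_url="<PORTKEY_GATEWAY_BASE_URL>",
    model_id="gpt-4o",
    provider="openai",
)

# converse() first yields {"messageStart": ...}, then content deltas, and a
# {"messageStop": {"stopReason": "tool_use"}} only when a tool call finishes.
for event in model.converse([{"role": "user", "content": "Hello"}]):
    if "contentBlockDelta" in event:
        print(event["contentBlockDelta"]["delta"].get("text", ""), end="")
```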

tests/strands/models/test_portkey.py
Lines changed: 228 additions & 0 deletions (new file)

```python
# Python
import unittest.mock

import pytest

from src.strands.models.portkey import PortkeyModel
from src.strands.types.exceptions import ContextWindowOverflowException


@pytest.fixture
def model_config():
    return {
        "api_key": "test_api_key",
        "virtual_key": "test_virtual_key",
        "base_url": "https://test.url",
        "model_id": "test_model_id",
        "provider": "openai",
    }


@pytest.fixture
def portkey_model(model_config):
    return PortkeyModel(**model_config)


def test__init__(portkey_model):
    assert portkey_model.config["api_key"] == "test_api_key"
    assert portkey_model.provider == "openai"


def test_get_config(portkey_model):
    config = portkey_model.get_config()
    assert config["api_key"] == "test_api_key"


def test_format_request_no_tools(portkey_model):
    messages = [{"role": "user", "content": "Hello"}]
    request = portkey_model.format_request(messages)
    assert "tools" not in request


def test_format_request_with_tools(portkey_model):
    messages = [{"role": "user", "content": "Hello"}]
    tool_specs = [{"name": "test_tool", "description": "Test tool", "inputSchema": {"json": {"properties": {}}}}]
    request = portkey_model.format_request(messages, tool_specs)
    assert "tools" in request


def test_format_request_system_prompt(portkey_model):
    messages = [{"role": "user", "content": "Hello"}]
    system_prompt = "Test system prompt"
    request = portkey_model.format_request(messages, system_prompt=system_prompt)
    assert request["messages"][0]["role"] == "system"


def test_allow_tool_use_openai(portkey_model):
    assert portkey_model._allow_tool_use()


def test_allow_tool_use_bedrock():
    model_config = {
        "api_key": "test_api_key",
        "virtual_key": "test_virtual_key",
        "base_url": "https://test.url",
        "model_id": "anthropic_model_id",
        "provider": "bedrock",
    }
    portkey_model = PortkeyModel(**model_config)
    assert portkey_model._allow_tool_use() is True


def test_allow_tool_use_false():
    model_config = {
        "api_key": "test_api_key",
        "virtual_key": "test_virtual_key",
        "base_url": "https://test.url",
        "model_id": "test_model_id",
        "provider": "unknown",
    }
    portkey_model = PortkeyModel(**model_config)
    assert portkey_model._allow_tool_use() is False


def test_stream(portkey_model):
    mock_event = {"choices": [{"delta": {"content": "test"}}]}
    with unittest.mock.patch.object(portkey_model.client.chat.completions, "create", return_value=iter([mock_event])):
        request = {"messages": [{"role": "user", "content": "Hello"}], "model": "test_model_id", "stream": True}
        response = list(portkey_model.stream(request))
        assert response[0]["choices"][0]["delta"]["content"] == "test"


def test_stream_context_window_exception(portkey_model):
    with unittest.mock.patch.object(
        portkey_model.client.chat.completions,
        "create",
        side_effect=ContextWindowOverflowException("Context window exceeded"),
    ):
        request = {"messages": [{"role": "user", "content": "Hello"}], "model": "test_model_id", "stream": True}
        with pytest.raises(ContextWindowOverflowException):
            list(portkey_model.stream(request))


def test_format_chunk_tool_calls(portkey_model):
    event = {
        "choices": [
            {
                "delta": {
                    "tool_calls": [
                        {
                            "function": {"name": "test_tool", "arguments": "test_args"},
                            "type": "function",
                        }
                    ]
                },
                "finish_reason": None,
            }
        ]
    }
    chunk = portkey_model.format_chunk(event)
    assert "contentBlockStart" in chunk


def test_format_chunk_arguments_chunk(portkey_model):
    event = {
        "choices": [
            {
                "delta": {
                    "tool_calls": [
                        {
                            "function": {"arguments": "test_args"},
                        }
                    ]
                },
                "finish_reason": None,
            }
        ]
    }
    chunk = portkey_model.format_chunk(event)
    assert "contentBlockDelta" in chunk


def test_format_chunk_finish_reason_tool_calls(portkey_model):
    event = {"choices": [{"finish_reason": "tool_calls"}]}
    chunk = portkey_model.format_chunk(event)
    assert "contentBlockStop" in chunk


def test_format_chunk_usage(portkey_model):
    event = {
        "usage": {
            "prompt_tokens": 10,
            "completion_tokens": 5,
            "total_tokens": 15,
        },
        "choices": [{"delta": {"content": None}}],  # Ensure 'content' key exists
    }
    chunk = portkey_model.format_chunk(event)
    assert chunk["metadata"]["usage"]["totalTokens"] == 15


def test_format_message_parts_string(portkey_model):
    parts = portkey_model._format_message_parts("user", "test content")
    assert parts == [{"role": "user", "content": "test content"}]


def test_format_message_parts_list_with_text(portkey_model):
    content = [{"text": "test text"}]
    parts = portkey_model._format_message_parts("assistant", content)
    assert parts == [{"role": "assistant", "content": "test text"}]


def test_format_message_parts_tool_use(portkey_model):
    content = [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {}}}]
    parts = portkey_model._format_message_parts("assistant", content)
    assert "tool_calls" in parts[0]


def test_format_message_parts_tool_result(portkey_model):
    portkey_model._current_tool_use_id = "123"
    content = [{"toolResult": {"content": [{"text": "result text"}]}}]
    parts = portkey_model._format_message_parts("assistant", content)
    assert parts[0]["content"] == "result text"


def test_map_tools(portkey_model):
    tool_specs = [
        {
            "name": "test_tool",
            "description": "Test tool",
            "inputSchema": {
                "json": {
                    "properties": {"arg1": {"type": "string"}},
                    "required": ["arg1"],
                }
            },
        }
    ]
    tools = portkey_model._map_tools(tool_specs)
    assert tools[0]["function"]["name"] == "test_tool"
    assert tools[0]["function"]["parameters"]["required"] == ["arg1"]


def test_format_tool_use_part(portkey_model):
    part = {"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {}}}
    formatted = portkey_model._format_tool_use_part(part)
    assert formatted["tool_calls"][0]["function"]["name"] == "test_tool"


def test_format_tool_result_part(portkey_model):
    portkey_model._current_tool_use_id = "123"
    part = {"toolResult": {"content": [{"text": "result text"}]}}
    formatted = portkey_model._format_tool_result_part(part)
    assert formatted["content"] == "result text"


def test_should_terminate_with_tool_use(portkey_model):
    event = {"choices": [{"finish_reason": "tool_calls"}]}
    assert portkey_model._should_terminate_with_tool_use(event) is True


def test_converse(portkey_model):
    mock_event = {"choices": [{"delta": {"content": "test"}}]}
    with unittest.mock.patch.object(portkey_model.client.chat.completions, "create", return_value=iter([mock_event])):
        messages = [{"role": "user", "content": "Hello"}]
        tool_specs = [{"name": "test_tool", "description": "Test tool", "inputSchema": {"json": {"properties": {}}}}]
        system_prompt = "Test system prompt"
        response = list(portkey_model.converse(messages, tool_specs, system_prompt))
        assert response[0]["messageStart"]["role"] == "assistant"
```
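One case the suite leaves untested is the provider inconsistency called out in `_should_terminate_with_tool_use`'s docstring: a stream whose `finish_reason` is "stop". A sketch of such a test, reusing the file's existing fixture and documenting that the helper currently returns False for it:

```python
def test_should_terminate_with_tool_use_stop(portkey_model):
    # A plain "stop" finish_reason is not treated as tool-use termination.
    event = {"choices": [{"finish_reason": "stop"}]}
    assert portkey_model._should_terminate_with_tool_use(event) is False
```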
