3 changes: 3 additions & 0 deletions src/strands/models/litellm.py
@@ -35,10 +35,13 @@ class LiteLLMConfig(TypedDict, total=False):
params: Model parameters (e.g., max_tokens).
For a complete list of supported parameters, see
https://docs.litellm.ai/docs/completion/input#input-params-1.
streaming: Optional flag to indicate whether provider streaming should be used.
If omitted, defaults to True (preserves existing behaviour).
"""

model_id: str
params: Optional[dict[str, Any]]
streaming: Optional[bool]

def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[LiteLLMConfig]) -> None:
"""Initialize provider instance.
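The new streaming flag is surfaced the same way on LiteLLMConfig and OpenAIConfig. A minimal usage sketch, assuming the constructor signatures shown in this diff and the import path implied by the file layout; the API key, model id, and token limit are placeholders, not values from this PR:

from strands.models.openai import OpenAIModel

# Omitting "streaming" keeps the previous behaviour: provider streaming stays on.
default_model = OpenAIModel(
    client_args={"api_key": "<placeholder>"},  # hypothetical client args
    model_id="gpt-4o",                         # hypothetical model id
    params={"max_tokens": 256},
)

# Opting out of provider streaming. stream() still yields the usual event
# sequence, but the events are synthesized from one non-streaming completion.
non_streaming_model = OpenAIModel(
    client_args={"api_key": "<placeholder>"},
    model_id="gpt-4o",
    params={"max_tokens": 256},
    streaming=False,
)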
151 changes: 115 additions & 36 deletions src/strands/models/openai.py
@@ -50,10 +50,13 @@ class OpenAIConfig(TypedDict, total=False):
params: Model parameters (e.g., max_tokens).
For a complete list of supported parameters, see
https://platform.openai.com/docs/api-reference/chat/create.
streaming: Optional flag to indicate whether provider streaming should be used.
If omitted, defaults to True (preserves existing behaviour).
"""

model_id: str
params: Optional[dict[str, Any]]
streaming: Optional[bool]

def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIConfig]) -> None:
"""Initialize provider instance.
@@ -263,7 +266,8 @@ def format_request(
return {
"messages": self.format_request_messages(messages, system_prompt),
"model": self.config["model_id"],
"stream": True,
# Use configured streaming flag; default True to preserve previous behavior.
"stream": bool(self.get_config().get("streaming", True)),
"stream_options": {"include_usage": True},
"tools": [
{
@@ -352,6 +356,68 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
case _:
raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")

def _convert_non_streaming_to_streaming(self, response: Any) -> list[StreamEvent]:
"""Convert a provider non-streaming response into streaming-style events.

This helper intentionally *does not* emit the initial message_start/content_start events,
because the caller (stream) already yields them to preserve parity with the streaming flow.
"""
events: list[StreamEvent] = []

# Extract main text content from first choice if available
if getattr(response, "choices", None):
choice = response.choices[0]
content = None
if hasattr(choice, "message") and hasattr(choice.message, "content"):
content = choice.message.content

# handle str content
if isinstance(content, str):
events.append(self.format_chunk({"chunk_type": "content_delta", "data_type": "text", "data": content}))
# handle list content (list of blocks/dicts)
elif isinstance(content, list):
for block in content:
if isinstance(block, dict):
# reasoning content
if "reasoningContent" in block and isinstance(block["reasoningContent"], dict):
try:
text = block["reasoningContent"]["reasoningText"]["text"]
events.append(
self.format_chunk(
{"chunk_type": "content_delta", "data_type": "reasoning_content", "data": text}
)
)
except Exception:
# malformed reasoning block; skip it (no text fallback is emitted)
pass
# text block
elif "text" in block:
events.append(
self.format_chunk(
{"chunk_type": "content_delta", "data_type": "text", "data": block["text"]}
)
)
# ignore other block types for now
elif isinstance(block, str):
events.append(
self.format_chunk({"chunk_type": "content_delta", "data_type": "text", "data": block})
)

# content stop
events.append(self.format_chunk({"chunk_type": "content_stop"}))

# message stop — convert finish reason if available
stop_reason = None
if getattr(response, "choices", None):
stop_reason = getattr(response.choices[0], "finish_reason", None)
events.append(self.format_chunk({"chunk_type": "message_stop", "data": stop_reason or "stop"}))

# metadata (usage) if present
if getattr(response, "usage", None):
events.append(self.format_chunk({"chunk_type": "metadata", "data": response.usage}))

return events

@override
async def stream(
self,
@@ -409,50 +475,63 @@ async def stream(

tool_calls: dict[int, list[Any]] = {}

async for event in response:
# Defensive: skip events with empty or missing choices
if not getattr(event, "choices", None):
continue
choice = event.choices[0]

if choice.delta.content:
yield self.format_chunk(
{"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
)

if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
yield self.format_chunk(
{
"chunk_type": "content_delta",
"data_type": "reasoning_content",
"data": choice.delta.reasoning_content,
}
)
streaming = bool(self.get_config().get("streaming", True))

if streaming:
# response is an async iterator when streaming=True
async for event in response:
# Defensive: skip events with empty or missing choices
if not getattr(event, "choices", None):
continue
choice = event.choices[0]

if choice.delta.content:
yield self.format_chunk(
{"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
)

if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
yield self.format_chunk(
{
"chunk_type": "content_delta",
"data_type": "reasoning_content",
"data": choice.delta.reasoning_content,
}
)

for tool_call in choice.delta.tool_calls or []:
tool_calls.setdefault(tool_call.index, []).append(tool_call)
for tool_call in choice.delta.tool_calls or []:
tool_calls.setdefault(tool_call.index, []).append(tool_call)

if choice.finish_reason:
break
if choice.finish_reason:
break

yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})

for tool_deltas in tool_calls.values():
yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
for tool_deltas in tool_calls.values():
yield self.format_chunk(
{"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}
)

for tool_delta in tool_deltas:
yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
for tool_delta in tool_deltas:
yield self.format_chunk(
{"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}
)

yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})

yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})

# Skip remaining events as we don't have use for anything except the final usage payload
async for event in response:
_ = event
# Skip remaining events as we don't have use for anything except the final usage payload
async for event in response:
_ = event

if event.usage:
yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
if event.usage:
yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
else:
# Non-streaming provider response — convert to streaming-style events (excluding the initial
# message_start/content_start because we already emitted them above).
for ev in self._convert_non_streaming_to_streaming(response):
yield ev

logger.debug("finished streaming response from model")

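Whichever branch stream() takes, the consumer-facing contract is unchanged. A sketch of the calling side, using the message shape from the test below; the prompt text and helper name are illustrative only:

import asyncio

from strands.models.openai import OpenAIModel

async def collect_events(model: OpenAIModel) -> list:
    # stream() is an async generator in both modes. With streaming=False the
    # deltas after messageStart/contentBlockStart are synthesized by
    # _convert_non_streaming_to_streaming instead of arriving as provider chunks.
    events = []
    async for event in model.stream([{"role": "user", "content": [{"text": "hi"}]}]):
        events.append(event)
    return events

# events = asyncio.run(collect_events(non_streaming_model))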
46 changes: 46 additions & 0 deletions tests/strands/models/test_openai.py
@@ -612,6 +612,52 @@ async def test_stream(openai_client, model_id, model, agenerator, alist):
openai_client.chat.completions.create.assert_called_once_with(**expected_request)


@pytest.mark.asyncio
async def test_stream_respects_streaming_flag(openai_client, model_id, alist):
# Model configured to NOT stream
model = OpenAIModel(client_args={}, model_id=model_id, params={"max_tokens": 1}, streaming=False)

# Mock a non-streaming response object
mock_choice = unittest.mock.Mock()
mock_choice.finish_reason = "stop"
mock_choice.message = unittest.mock.Mock()
mock_choice.message.content = "non-stream result"
mock_response = unittest.mock.Mock()
mock_response.choices = [mock_choice]
mock_response.usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30)

openai_client.chat.completions.create = unittest.mock.AsyncMock(return_value=mock_response)

# Consume the generator and verify the events
response_gen = model.stream([{"role": "user", "content": [{"text": "hi"}]}])
tru_events = await alist(response_gen)

expected_request = {
"max_tokens": 1,
"model": model_id,
"messages": [{"role": "user", "content": [{"text": "hi", "type": "text"}]}],
"stream": False,
"stream_options": {"include_usage": True},
"tools": [],
}
openai_client.chat.completions.create.assert_called_once_with(**expected_request)

exp_events = [
{"messageStart": {"role": "assistant"}},
{"contentBlockStart": {"start": {}}},
{"contentBlockDelta": {"delta": {"text": "non-stream result"}}},
{"contentBlockStop": {}},
{"messageStop": {"stopReason": "end_turn"}},
{
"metadata": {
"usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30},
"metrics": {"latencyMs": 0},
}
},
]
assert tru_events == exp_events


@pytest.mark.asyncio
async def test_stream_empty(openai_client, model_id, model, agenerator, alist):
mock_delta = unittest.mock.Mock(content=None, tool_calls=None, reasoning_content=None)