Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/agents/voice/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,13 @@ class StreamedAudioInput:
"""

def __init__(self):
self.queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32]] = asyncio.Queue()
self.queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32] | None] = asyncio.Queue()

async def add_audio(self, audio: npt.NDArray[np.int16 | np.float32]):
async def add_audio(self, audio: npt.NDArray[np.int16 | np.float32] | None):
"""Adds more audio data to the stream.

Args:
audio: The audio data to add. Must be a numpy array of int16 or float32.
audio: The audio data to add. Must be a numpy array of int16 or float32 or None.
If None passed, it indicates the end of the stream.
"""
await self.queue.put(audio)
4 changes: 2 additions & 2 deletions src/agents/voice/models/openai_stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(
self._trace_include_sensitive_data = trace_include_sensitive_data
self._trace_include_sensitive_audio_data = trace_include_sensitive_audio_data

self._input_queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32]] = input.queue
self._input_queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32] | None] = input.queue
self._output_queue: asyncio.Queue[str | ErrorSentinel | SessionCompleteSentinel] = (
asyncio.Queue()
)
Expand Down Expand Up @@ -245,7 +245,7 @@ async def _handle_events(self) -> None:
await self._output_queue.put(SessionCompleteSentinel())

async def _stream_audio(
self, audio_queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32]]
self, audio_queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32] | None]
) -> None:
assert self._websocket is not None, "Websocket not initialized"
self._start_turn()
Expand Down
11 changes: 9 additions & 2 deletions tests/voice/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,14 @@ async def test_streamed_audio_input(self):
# Verify the queue contents
assert streamed_input.queue.qsize() == 2
# Test non-blocking get
assert np.array_equal(streamed_input.queue.get_nowait(), audio1)
retrieved_audio1 = streamed_input.queue.get_nowait()
# Satisfy type checker
assert retrieved_audio1 is not None
assert np.array_equal(retrieved_audio1, audio1)

# Test blocking get
assert np.array_equal(await streamed_input.queue.get(), audio2)
retrieved_audio2 = await streamed_input.queue.get()
# Satisfy type checker
assert retrieved_audio2 is not None
assert np.array_equal(retrieved_audio2, audio2)
assert streamed_input.queue.empty()