Skip to content

Implement Auto Flush #413

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
250 changes: 231 additions & 19 deletions deepgram/clients/live/v1/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import logging
from typing import Dict, Union, Optional, cast, Any
from datetime import datetime

import websockets
from websockets.client import WebSocketClientProtocol
Expand All @@ -28,6 +29,7 @@
from .options import LiveOptions

ONE_SECOND = 1
HALF_SECOND = 0.5
DEEPGRAM_INTERVAL = 5
PING_INTERVAL = 20

Expand All @@ -49,8 +51,12 @@ class AsyncLiveClient: # pylint: disable=too-many-instance-attributes

_socket: WebSocketClientProtocol
_event_handlers: Dict[LiveTranscriptionEvents, list]
_listen_thread: asyncio.Task
_keep_alive_thread: asyncio.Task

_last_datagram: Optional[datetime] = None

_listen_thread: Union[asyncio.Task, None]
_keep_alive_thread: Union[asyncio.Task, None]
_flush_thread: Union[asyncio.Task, None]

_kwargs: Optional[Dict] = None
_addons: Optional[Dict] = None
Expand All @@ -67,7 +73,16 @@ def __init__(self, config: DeepgramClientOptions):

self._config = config
self._endpoint = "v1/listen"

self._listen_thread = None
self._keep_alive_thread = None
self._flush_thread = None

# exit
self._exit_event = asyncio.Event()

# auto flush
self._flush_event = asyncio.Event()
self._event_handlers = {
event: [] for event in LiveTranscriptionEvents.__members__.values()
}
Expand Down Expand Up @@ -112,7 +127,7 @@ async def start(

if isinstance(options, LiveOptions):
self._logger.info("LiveOptions switching class -> dict")
self._options = cast(Dict[str, str], options.to_dict())
self._options = options.to_dict()
elif options is not None:
self._options = options
else:
Expand Down Expand Up @@ -146,12 +161,19 @@ async def start(
self._listen_thread = asyncio.create_task(self._listening())

# keepalive thread
if self._config.options.get("keepalive") == "true":
if self._config.is_keep_alive_enabled():
self._logger.notice("keepalive is enabled")
self._keep_alive_thread = asyncio.create_task(self._keep_alive())
else:
self._logger.notice("keepalive is disabled")

# flush thread
if self._config.is_auto_flush_enabled():
self._logger.notice("autoflush is enabled")
self._flush_thread = asyncio.create_task(self._flush())
else:
self._logger.notice("autoflush is disabled")

# push open event
await self._emit(
LiveTranscriptionEvents(LiveTranscriptionEvents.Open),
Expand Down Expand Up @@ -186,7 +208,7 @@ def on(self, event: LiveTranscriptionEvents, handler) -> None:
"""
Registers event handlers for specific events.
"""
self._logger.info("event fired: %s", event)
self._logger.info("event subscribed: %s", event)
if event in LiveTranscriptionEvents.__members__.values() and callable(handler):
self._event_handlers[event].append(handler)

Expand All @@ -195,13 +217,14 @@ async def _emit(self, event: LiveTranscriptionEvents, *args, **kwargs) -> None:
"""
Emits events to the registered event handlers.
"""
self._logger.debug("callback handlers for: %s", event)
for handler in self._event_handlers[event]:
if asyncio.iscoroutinefunction(handler):
await handler(self, *args, **kwargs)
else:
asyncio.create_task(handler(self, *args, **kwargs))

# pylint: disable=too-many-return-statements,too-many-statements,too-many-locals
# pylint: disable=too-many-return-statements,too-many-statements,too-many-locals,too-many-branches
async def _listening(self) -> None:
"""
Listens for messages from the WebSocket connection.
Expand Down Expand Up @@ -244,6 +267,13 @@ async def _listening(self) -> None:
message
)
self._logger.verbose("LiveResultResponse: %s", msg_result)

# auto flush
if self._config.is_inspecting_messages():
inspect_res = await self._inspect(msg_result)
if not inspect_res:
self._logger.error("inspect_res failed")

await self._emit(
LiveTranscriptionEvents(LiveTranscriptionEvents.Transcript),
result=msg_result,
Expand Down Expand Up @@ -426,8 +456,7 @@ async def _keep_alive(self) -> None:

# deepgram keepalive
if counter % DEEPGRAM_INTERVAL == 0:
self._logger.verbose("Sending KeepAlive...")
await self.send(json.dumps({"type": "KeepAlive"}))
await self.keep_alive()

except websockets.exceptions.ConnectionClosedOK as e:
self._logger.notice(f"_keep_alive({e.code}) exiting gracefully")
Expand Down Expand Up @@ -514,6 +543,132 @@ async def _keep_alive(self) -> None:
raise
return

## pylint: disable=too-many-return-statements,too-many-statements
async def _flush(self) -> None:
    """
    Background task that auto-flushes the transcript after a quiet period.

    Every HALF_SECOND the task compares the current time with
    ``self._last_datagram`` (the timestamp of the most recent interim
    result). Once the gap reaches the ``auto_flush_reply_delta`` option
    (interpreted as milliseconds), the timestamp is cleared and
    ``finalize()`` is called.

    The task ends when the exit event is set, the socket is None, the
    connection closes, or an error occurs. On errors an Error event is
    emitted and ``_signal_exit()`` is awaited.

    Raises:
        The handled exception is re-raised only when the
        ``termination_exception`` option equals ``"true"``.
    """
    self._logger.debug("AsyncLiveClient._flush ENTER")

    delta_in_ms_str = self._config.options.get("auto_flush_reply_delta")
    if delta_in_ms_str is None:
        self._logger.error("auto_flush_reply_delta is None")
        self._logger.debug("AsyncLiveClient._flush LEAVE")
        return
    try:
        delta_in_ms = float(delta_in_ms_str)
    except (TypeError, ValueError):
        # A malformed option would otherwise kill the task with an
        # exception nobody observes; treat it like a missing option.
        self._logger.error(
            "auto_flush_reply_delta is not a number: %s", delta_in_ms_str
        )
        self._logger.debug("AsyncLiveClient._flush LEAVE")
        return

    async def _fail(title: str, exc: Exception, kind: str) -> None:
        # Shared teardown for every error path: emit the Error event,
        # signal exit, then log LEAVE (same order for all handlers).
        error = ErrorResponse(title, f"{exc}", kind)
        await self._emit(
            LiveTranscriptionEvents(LiveTranscriptionEvents.Error),
            error=error,
            # _kwargs defaults to None on the class; never let that turn
            # error reporting itself into a TypeError.
            **dict(self._kwargs or {}),
        )

        # signal exit and close
        await self._signal_exit()

        self._logger.debug("AsyncLiveClient._flush LEAVE")

    while True:
        try:
            await asyncio.sleep(HALF_SECOND)

            if self._exit_event.is_set():
                self._logger.notice("_flush exiting gracefully")
                self._logger.debug("AsyncLiveClient._flush LEAVE")
                return

            if self._socket is None:
                self._logger.notice("socket is None, exiting flush")
                self._logger.debug("AsyncLiveClient._flush LEAVE")
                return

            if self._last_datagram is None:
                # no interim result is pending, nothing to flush
                self._logger.debug("AutoFlush last_datagram is None")
                continue

            # _last_datagram is written with datetime.now(), so the same
            # clock is used here to measure the gap.
            delta = datetime.now() - self._last_datagram
            diff_in_ms = delta.total_seconds() * 1000
            self._logger.debug("AutoFlush delta: %f", diff_in_ms)
            if diff_in_ms < delta_in_ms:
                self._logger.debug("AutoFlush delta is less than threshold")
                continue

            # clear the timestamp first so one quiet period triggers
            # exactly one Finalize
            self._last_datagram = None
            await self.finalize()

        except websockets.exceptions.ConnectionClosedOK as e:
            self._logger.notice(f"_flush({e.code}) exiting gracefully")
            self._logger.debug("AsyncLiveClient._flush LEAVE")
            return

        except websockets.exceptions.ConnectionClosed as e:
            if e.code == 1000:
                # normal closure, not an error
                self._logger.notice(f"_flush({e.code}) exiting gracefully")
                self._logger.debug("AsyncLiveClient._flush LEAVE")
                return

            self._logger.error(
                "ConnectionClosed in AsyncLiveClient._flush with code %s: %s",
                e.code,
                e.reason,
            )
            await _fail(
                "ConnectionClosed in AsyncLiveClient._flush",
                e,
                "ConnectionClosed",
            )

            if self._config.options.get("termination_exception") == "true":
                raise
            return

        except websockets.exceptions.WebSocketException as e:
            self._logger.error(
                "WebSocketException in AsyncLiveClient._flush: %s", e
            )
            await _fail(
                "WebSocketException in AsyncLiveClient._flush",
                e,
                "Exception",
            )

            if self._config.options.get("termination_exception") == "true":
                raise
            return

        except Exception as e:  # pylint: disable=broad-except
            # logged once (the original logged this twice)
            self._logger.error("Exception in AsyncLiveClient._flush: %s", e)
            await _fail(
                "Exception in AsyncLiveClient._flush",
                e,
                "Exception",
            )

            if self._config.options.get("termination_exception") == "true":
                raise
            return

# pylint: enable=too-many-return-statements

# pylint: disable=too-many-return-statements

async def send(self, data: Union[str, bytes]) -> bool:
"""
Sends data over the WebSocket connection.
Expand Down Expand Up @@ -570,6 +725,35 @@ async def send(self, data: Union[str, bytes]) -> bool:

# pylint: enable=too-many-return-statements

async def keep_alive(self) -> bool:
    """
    Send a single KeepAlive message over the connection.

    Returns:
        True when ``send()`` reports success; False when the exit event
        is set, the socket is None, or ``send()`` reports failure.
    """
    log = self._logger
    log.spam("AsyncLiveClient.keep_alive ENTER")

    # Work out whether there is any reason not to send at all; the exit
    # event is checked before the socket.
    skip_reason = None
    if self._exit_event.is_set():
        skip_reason = "keep_alive exiting gracefully"
    elif self._socket is None:
        skip_reason = "socket is not intialized"

    if skip_reason is not None:
        log.notice(skip_reason)
        log.debug("AsyncLiveClient.keep_alive LEAVE")
        return False

    log.notice("Sending KeepAlive...")
    payload = json.dumps({"type": "KeepAlive"})
    delivered = await self.send(payload)

    if delivered:
        log.notice("keep_alive succeeded")
        log.spam("AsyncLiveClient.keep_alive LEAVE")
        return True

    log.error("keep_alive failed")
    log.spam("AsyncLiveClient.keep_alive LEAVE")
    return False

async def finalize(self) -> bool:
"""
Finalizes the Transcript connection by flushing it
Expand All @@ -581,14 +765,18 @@ async def finalize(self) -> bool:
self._logger.debug("AsyncLiveClient.finalize LEAVE")
return False

if self._socket is not None:
self._logger.notice("sending Finalize...")
ret = await self.send(json.dumps({"type": "Finalize"}))
if self._socket is None:
self._logger.notice("socket is not intialized")
self._logger.debug("AsyncLiveClient.finalize LEAVE")
return False

if not ret:
self._logger.error("finalize failed")
self._logger.spam("AsyncLiveClient.finalize LEAVE")
return False
self._logger.notice("Sending Finalize...")
ret = await self.send(json.dumps({"type": "Finalize"}))

if not ret:
self._logger.error("finalize failed")
self._logger.spam("AsyncLiveClient.finalize LEAVE")
return False

self._logger.notice("finalize succeeded")
self._logger.spam("AsyncLiveClient.finalize LEAVE")
Expand All @@ -609,13 +797,20 @@ async def finish(self) -> bool:
try:
# Before cancelling, check if the tasks were created
tasks = []
if self._config.options.get("keepalive") == "true":
if self._keep_alive_thread is not None:
self._keep_alive_thread.cancel()
tasks.append(self._keep_alive_thread)
if self._keep_alive_thread is not None:
self._keep_alive_thread.cancel()
tasks.append(self._keep_alive_thread)
self._logger.notice("processing _keep_alive_thread cancel...")

if self._flush_thread is not None:
self._flush_thread.cancel()
tasks.append(self._flush_thread)
self._logger.notice("processing _flush_thread cancel...")

if self._listen_thread is not None:
self._listen_thread.cancel()
tasks.append(self._listen_thread)
self._logger.notice("processing _listen_thread cancel...")

# Use asyncio.gather to wait for tasks to be cancelled
await asyncio.gather(*filter(None, tasks), return_exceptions=True)
Expand Down Expand Up @@ -673,3 +868,20 @@ async def _signal_exit(self) -> None:
self._logger.error("socket.wait_closed failed: %s", e)

self._socket = None # type: ignore

async def _inspect(self, msg_result: LiveResultResponse) -> bool:
    """
    Update the auto-flush timestamp from a transcript result.

    A non-empty interim result stamps ``self._last_datagram`` with the
    current time; a non-empty final result clears it. Results with no
    alternatives or an empty transcript leave the timestamp untouched.

    Args:
        msg_result: the transcript result to examine.

    Returns:
        Always True.
    """
    alternatives = msg_result.channel.alternatives
    if not alternatives:
        # Nothing to inspect. Indexing [0] here would raise and surface
        # as a listener error over pure bookkeeping.
        return True

    sentence = alternatives[0].transcript
    if not sentence:
        # empty (or missing) transcript carries no speech
        return True

    if msg_result.is_final:
        self._logger.debug("AutoFlush is_final received")
        self._last_datagram = None
    else:
        self._last_datagram = datetime.now()
        self._logger.debug(
            "AutoFlush interim received: %s",
            str(self._last_datagram),
        )

    return True
Loading