diff --git a/zigpy_znp/async_utils.py b/zigpy_znp/async_utils.py new file mode 100644 index 00000000..0613c643 --- /dev/null +++ b/zigpy_znp/async_utils.py @@ -0,0 +1,197 @@ +import asyncio +import logging +import functools +import threading + +LOGGER = logging.getLogger(__name__) + + +_znp_loop = None # the loop in which the serial communication will be handled +_worker_loop = ( + None # the loop in which the frames are handled (MainThread in home assistant) +) + +# if there is a need to create a worker loop this will be the thread it is running in +_worker_loop_thread = None + + +def try_get_running_loop_as_worker_loop(): + """ + this function will set the worker loop to the currently running loop + (if there is one). + """ + global _worker_loop + if _worker_loop is None: + try: + _worker_loop = asyncio.get_running_loop() + except RuntimeError: + pass + + +# this will get the running loop in case of integration in home assistant +# if there is no running loop, a loop will be created later +try_get_running_loop_as_worker_loop() + + +def get_worker_loop(): + """ + Getter for the worker loop. + """ + global _worker_loop + if _worker_loop is None: + try: + _worker_loop = asyncio.get_running_loop() + LOGGER.info("used asyncio's running loop") + except RuntimeError: + create_new_worker_loop(True) + return _worker_loop + + +def get_znp_loop(): + """ + Getter for the ZNP serial loop. + """ + return _znp_loop + + +def start_worker_loop_in_thread(): + """ + Create a thread and run the worker loop. + """ + global _worker_loop_thread, _worker_loop + if _worker_loop_thread is None and _worker_loop is not None: + + def run_worker_loop(): + asyncio.set_event_loop(_worker_loop) + _worker_loop.run_forever() + + _worker_loop_thread = threading.Thread( + target=run_worker_loop, daemon=True, name="ZigpyWorkerThread" + ) + _worker_loop_thread.start() + + +def create_new_worker_loop(start_thread: bool = True): + """ + Creates a new worker loop, starts a new thread too, if start_thread is True. + """ + global _worker_loop + LOGGER.info("creating new event loop as worker loop") + _worker_loop = asyncio.new_event_loop() + if start_thread: + start_worker_loop_in_thread() + + +def init_znp_loop(): + """ + Create and run ZNP loop. + """ + global _znp_loop + if _znp_loop is None: + _znp_loop = asyncio.new_event_loop() + + def run_znp_loop(): + # asyncio.set_event_loop(_znp_loop) + _znp_loop.run_forever() + + znp_thread = threading.Thread( + target=run_znp_loop, daemon=True, name="ZigpySerialThread" + ) + znp_thread.start() + + +# will create and start a new ZNP loop on module initialization +if _znp_loop is None: + init_znp_loop() + + +def run_in_loop( + function, loop=None, loop_getter=None, wait_for_result: bool = True, *args, **kwargs +): + """ + Can be used as decorator or as normal function. + Will run the function in the specified loop. + @param function: + The co-routine that shall be run (function call only) + @param loop: + Loop in which the co-routine shall run (either loop or loop_getter must be set) + @param loop_getter: + Getter for the loop in which the co-routine shall run + (either loop or loop_getter must be set) + @param wait_for_result: + Will "fire and forget" if false. Otherwise, + the return value of the coro is returned. + @param args: args + @param kwargs: kwargs + @return: + None if wait_for_result is false. Otherwise, the return value of the co-routine. + """ + if loop is None and loop_getter is None: + raise RuntimeError("either loop or loop_getter must be passed to run_in_loop") + + if asyncio.iscoroutine(function): + # called as a function call + _loop = loop if loop is not None else loop_getter() + future = asyncio.run_coroutine_threadsafe(function, _loop) + return future.result() if wait_for_result else None + else: + # probably a decorator + + # wrap the function in a new function, + # that will run the co-routine in the loop provided + @functools.wraps(function) + def new_sync(*args, **kwargs): + loop if loop is not None else loop_getter() + return run_in_loop( + function(*args, **kwargs), + loop=loop, + loop_getter=loop_getter, + wait_for_result=wait_for_result, + ) + + if not asyncio.iscoroutinefunction(function): + return new_sync + else: + # wrap the function again in an async function, so that it can be awaited + async def new_async(*args, **kwargs): + return new_sync(*args, **kwargs) + + return new_async + + +def run_in_znp_loop(*args, **kwargs): + """ + Can be used as decorator or as normal function. + Will run the function in the znp loop. + @param function: + The co-routine that shall be run (function call only) + @param wait_for_result: + Will "fire and forget" if false. + Otherwise, the return value of the coro is returned. + @param args: args + @param kwargs: kwargs + @return: + None if wait_for_result is false. + Otherwise, the return value of the co-routine. + """ + kwargs["loop_getter"] = get_znp_loop + return run_in_loop(*args, **kwargs) + + +def run_in_worker_loop(*args, **kwargs): + """ + Can be used as decorator or as normal function. + Will run the function in the worker loop. + @param function: + The co-routine that shall be run (function call only) + @param wait_for_result: + Will "fire and forget" if false. + Otherwise, the return value of the coro is returned. + @param args: args + @param kwargs: kwargs + @return: + None if wait_for_result is false. + Otherwise, the return value of the co-routine. + """ + kwargs["loop_getter"] = get_worker_loop + return run_in_loop(*args, **kwargs) diff --git a/zigpy_znp/tools/energy_scan.py b/zigpy_znp/tools/energy_scan.py index 3eb0fb85..21bfedfe 100644 --- a/zigpy_znp/tools/energy_scan.py +++ b/zigpy_znp/tools/energy_scan.py @@ -1,5 +1,4 @@ import sys -import asyncio import logging import itertools from collections import deque, defaultdict @@ -8,6 +7,7 @@ from zigpy.exceptions import NetworkNotFormed import zigpy_znp.types as t +from zigpy_znp import async_utils from zigpy_znp.tools.common import setup_parser from zigpy_znp.zigbee.application import ControllerApplication @@ -93,4 +93,4 @@ async def main(argv): if __name__ == "__main__": - asyncio.run(main(sys.argv[1:])) # pragma: no cover + async_utils.run_in_worker_loop(main(sys.argv[1:])) # pragma: no cover diff --git a/zigpy_znp/tools/flash_read.py b/zigpy_znp/tools/flash_read.py index 7499a101..6ad7ab96 100644 --- a/zigpy_znp/tools/flash_read.py +++ b/zigpy_znp/tools/flash_read.py @@ -5,6 +5,7 @@ import async_timeout import zigpy_znp.commands as c +from zigpy_znp import async_utils from zigpy_znp.api import ZNP from zigpy_znp.config import CONFIG_SCHEMA from zigpy_znp.tools.common import ClosableFileType, setup_parser @@ -87,4 +88,4 @@ async def main(argv): if __name__ == "__main__": - asyncio.run(main(sys.argv[1:])) # pragma: no cover + async_utils.run_in_worker_loop(main(sys.argv[1:])) # pragma: no cover diff --git a/zigpy_znp/tools/flash_write.py b/zigpy_znp/tools/flash_write.py index 713f5dd7..a0f1c7be 100644 --- a/zigpy_znp/tools/flash_write.py +++ b/zigpy_znp/tools/flash_write.py @@ -8,6 +8,7 @@ import zigpy_znp.types as t import zigpy_znp.commands as c +from zigpy_znp import async_utils from zigpy_znp.api import ZNP from zigpy_znp.config import CONFIG_SCHEMA from zigpy_znp.tools.common import ClosableFileType, setup_parser @@ -174,4 +175,4 @@ async def main(argv): if __name__ == "__main__": - asyncio.run(main(sys.argv[1:])) # pragma: no cover + async_utils.run_in_worker_loop(main(sys.argv[1:])) # pragma: no cover diff --git a/zigpy_znp/tools/network_backup.py b/zigpy_znp/tools/network_backup.py index f5f9c521..491aa96e 100644 --- a/zigpy_znp/tools/network_backup.py +++ b/zigpy_znp/tools/network_backup.py @@ -2,7 +2,6 @@ import sys import json -import asyncio import logging import datetime @@ -10,6 +9,7 @@ import zigpy_znp import zigpy_znp.types as t +from zigpy_znp import async_utils from zigpy_znp.api import ZNP from zigpy_znp.tools.common import ClosableFileType, setup_parser, validate_backup_json from zigpy_znp.zigbee.application import ControllerApplication @@ -117,6 +117,8 @@ async def main(argv: list[str]) -> None: f.write(json.dumps(backup_obj, indent=4)) + LOGGER.info("done") + if __name__ == "__main__": - asyncio.run(main(sys.argv[1:])) # pragma: no cover + async_utils.run_in_worker_loop(main(sys.argv[1:])) # pragma: no cover diff --git a/zigpy_znp/tools/network_restore.py b/zigpy_znp/tools/network_restore.py index e3c9eafb..7a37eff0 100644 --- a/zigpy_znp/tools/network_restore.py +++ b/zigpy_znp/tools/network_restore.py @@ -2,13 +2,13 @@ import sys import json -import asyncio import zigpy.state import zigpy.zdo.types as zdo_t import zigpy_znp.const as const import zigpy_znp.types as t +from zigpy_znp import async_utils from zigpy_znp.api import ZNP from zigpy_znp.tools.common import ClosableFileType, setup_parser, validate_backup_json from zigpy_znp.zigbee.application import ControllerApplication @@ -130,4 +130,4 @@ async def main(argv: list[str]) -> None: if __name__ == "__main__": - asyncio.run(main(sys.argv[1:])) # pragma: no cover + async_utils.run_in_worker_loop(main(sys.argv[1:])) # pragma: no cover diff --git a/zigpy_znp/tools/network_scan.py b/zigpy_znp/tools/network_scan.py index d2f21754..0bcafae0 100644 --- a/zigpy_znp/tools/network_scan.py +++ b/zigpy_znp/tools/network_scan.py @@ -1,11 +1,11 @@ import sys import time -import asyncio import logging import itertools import zigpy_znp.types as t import zigpy_znp.commands as c +from zigpy_znp import async_utils from zigpy_znp.api import ZNP from zigpy_znp.config import CONFIG_SCHEMA from zigpy_znp.types.nvids import OsalNvIds @@ -155,4 +155,4 @@ async def main(argv): if __name__ == "__main__": - asyncio.run(main(sys.argv[1:])) # pragma: no cover + async_utils.run_in_worker_loop(main(sys.argv[1:])) # pragma: no cover diff --git a/zigpy_znp/tools/nvram_read.py b/zigpy_znp/tools/nvram_read.py index c3be94f4..e1e66748 100644 --- a/zigpy_znp/tools/nvram_read.py +++ b/zigpy_znp/tools/nvram_read.py @@ -1,9 +1,9 @@ import sys import json -import asyncio import logging import zigpy_znp.types as t +from zigpy_znp import async_utils from zigpy_znp.api import ZNP from zigpy_znp.config import CONFIG_SCHEMA from zigpy_znp.exceptions import SecurityError, CommandNotRecognized @@ -96,4 +96,4 @@ async def main(argv): if __name__ == "__main__": - asyncio.run(main(sys.argv[1:])) # pragma: no cover + async_utils.run_in_worker_loop(main(sys.argv[1:])) # pragma: no cover diff --git a/zigpy_znp/tools/nvram_reset.py b/zigpy_znp/tools/nvram_reset.py index 31caa78f..19d448df 100644 --- a/zigpy_znp/tools/nvram_reset.py +++ b/zigpy_znp/tools/nvram_reset.py @@ -1,7 +1,7 @@ import sys -import asyncio import logging +from zigpy_znp import async_utils from zigpy_znp.api import ZNP from zigpy_znp.config import CONFIG_SCHEMA from zigpy_znp.types.nvids import ( @@ -79,4 +79,4 @@ async def main(argv): if __name__ == "__main__": - asyncio.run(main(sys.argv[1:])) # pragma: no cover + async_utils.run_in_worker_loop(main(sys.argv[1:])) # pragma: no cover diff --git a/zigpy_znp/tools/nvram_write.py b/zigpy_znp/tools/nvram_write.py index f672f0a9..17152373 100644 --- a/zigpy_znp/tools/nvram_write.py +++ b/zigpy_znp/tools/nvram_write.py @@ -1,8 +1,8 @@ import sys import json -import asyncio import logging +from zigpy_znp import async_utils from zigpy_znp.api import ZNP from zigpy_znp.config import CONFIG_SCHEMA from zigpy_znp.types.nvids import ExNvIds, OsalNvIds @@ -68,4 +68,4 @@ async def main(argv): if __name__ == "__main__": - asyncio.run(main(sys.argv[1:])) # pragma: no cover + async_utils.run_in_worker_loop(main(sys.argv[1:])) # pragma: no cover diff --git a/zigpy_znp/uart.py b/zigpy_znp/uart.py index 5571e60d..fa58460f 100644 --- a/zigpy_znp/uart.py +++ b/zigpy_znp/uart.py @@ -1,26 +1,18 @@ import typing import asyncio import logging -import warnings +import threading -import serial +import serialpy as serial +import serialpy as serial_asyncio import zigpy_znp.config as conf import zigpy_znp.frames as frames import zigpy_znp.logger as log +import zigpy_znp.async_utils as async_utils from zigpy_znp.types import Bytes from zigpy_znp.exceptions import InvalidFrame -with warnings.catch_warnings(): - warnings.filterwarnings( - action="ignore", - module="serial_asyncio", - message='"@coroutine" decorator is deprecated', - category=DeprecationWarning, - ) - import serial_asyncio # noqa: E402 - - LOGGER = logging.getLogger(__name__) @@ -35,6 +27,7 @@ def __init__(self, api): self._transport = None self._connected_event = asyncio.Event() + @async_utils.run_in_znp_loop def close(self) -> None: """Closes the port.""" @@ -44,8 +37,43 @@ def close(self) -> None: if self._transport is not None: LOGGER.debug("Closing serial port") - self._transport.close() - self._transport = None + def close_transport(): + LOGGER.warning( + "Closing serial port in thread %s" % threading.current_thread().name + ) + self._transport.close() + self._transport = None + + try: + close_transport() + except RuntimeError: + LOGGER.warning("Trying to close transport in its own loop") + close_transport_in_loop = async_utils.run_in_loop( + close_transport, self._transport._loop + ) + try: + close_transport_in_loop() + except RuntimeError: + LOGGER.warning("Trying to close transport in znp loop") + close_transport_in_loop = async_utils.run_in_loop( + close_transport, async_utils.get_znp_loop() + ) + + try: + close_transport_in_loop() + except RuntimeError: + LOGGER.warning("Trying to close transport in worker loop") + close_transport_in_loop = async_utils.run_in_loop( + close_transport, async_utils.get_worker_loop() + ) + + try: + close_transport_in_loop() + except RuntimeError as e: + LOGGER.error( + "Failed to close serial connection in any loop" + ) + raise e def connection_lost(self, exc: typing.Optional[Exception]) -> None: """Connection lost.""" @@ -76,7 +104,13 @@ def data_received(self, data: bytes) -> None: LOGGER.log(log.TRACE, "Parsed frame: %s", frame) try: - self._api.frame_received(frame.payload) + + async def _frame_received(payload): + self._api.frame_received(payload) + + async_utils.run_in_worker_loop( + _frame_received(frame.payload), wait_for_result=False + ) except Exception as e: LOGGER.error( "Received an exception while passing frame to API: %s", @@ -170,14 +204,21 @@ def __repr__(self) -> str: ) +@async_utils.run_in_znp_loop async def connect(config: conf.ConfigType, api) -> ZnpMtProtocol: - loop = asyncio.get_running_loop() + LOGGER.info("uart connecting in thread %s" % threading.current_thread().name) + loop = async_utils.get_znp_loop() port = config[conf.CONF_DEVICE_PATH] baudrate = config[conf.CONF_DEVICE_BAUDRATE] flow_control = config[conf.CONF_DEVICE_FLOW_CONTROL] - LOGGER.debug("Connecting to %s at %s baud", port, baudrate) + LOGGER.debug( + "Connecting to %s at %s baud in thread %s", + port, + baudrate, + threading.current_thread().name, + ) _, protocol = await serial_asyncio.create_serial_connection( loop=loop, diff --git a/zigpy_znp/utils.py b/zigpy_znp/utils.py index 04d05ef1..5a5acf89 100644 --- a/zigpy_znp/utils.py +++ b/zigpy_znp/utils.py @@ -73,7 +73,6 @@ def resolve(self, response: t.CommandBase) -> bool: if not any(c.matches(response) for c in self.matching_commands): return False - return self._resolve(response) def _resolve(self, response: t.CommandBase) -> bool: