diff --git a/tests/api/test_request.py b/tests/api/test_request.py index 004d5698..cbf0a90d 100644 --- a/tests/api/test_request.py +++ b/tests/api/test_request.py @@ -112,43 +112,45 @@ async def test_callback_rsp_cleanup_timeout_internal(background, connected_znp): assert not znp._listeners -async def test_callback_rsp_cleanup_background_error(connected_znp): +async def test_callback_rsp_background_timeout(connected_znp, mocker): znp, znp_server = connected_znp znp._config[conf.CONF_ZNP_CONFIG][conf.CONF_SREQ_TIMEOUT] = 0.1 - znp._config[conf.CONF_ZNP_CONFIG][conf.CONF_ARSP_TIMEOUT] = 0.1 + znp._config[conf.CONF_ZNP_CONFIG][conf.CONF_ARSP_TIMEOUT] = 1.0 - assert not znp._listeners + mocker.spy(znp, "_unhandled_command") - # This request will timeout because we didn't send anything back - with pytest.raises(asyncio.TimeoutError): - await znp.request_callback_rsp( - request=c.UTIL.TimeAlive.Req(), - callback=c.SYS.ResetInd.Callback(partial=True), - background=True, + async def replier(req): + # SREQ reply works + await asyncio.sleep(0.05) + yield c.UTIL.TimeAlive.Rsp(Seconds=123) + + # And the callback will arrive before the AREQ timeout + await asyncio.sleep(0.9) + yield c.SYS.ResetInd.Callback( + Reason=t.ResetReason.PowerUp, + TransportRev=0x00, + ProductId=0x12, + MajorRel=0x01, + MinorRel=0x02, + MaintRel=0x03, ) - # We should be cleaned up - assert not znp._listeners - + reply = znp_server.reply_once_to(c.UTIL.TimeAlive.Req(), responses=replier) -async def test_callback_rsp_cleanup_background_timeout(connected_znp): - znp, znp_server = connected_znp - znp._config[conf.CONF_ZNP_CONFIG][conf.CONF_SREQ_TIMEOUT] = 0.1 - znp._config[conf.CONF_ZNP_CONFIG][conf.CONF_ARSP_TIMEOUT] = 0.1 - - assert not znp._listeners + await znp.request_callback_rsp( + request=c.UTIL.TimeAlive.Req(), + callback=c.SYS.ResetInd.Callback(partial=True), + background=True, + ) - # This request will timeout because we didn't send anything back - with pytest.raises(asyncio.TimeoutError): - await znp.request_callback_rsp( - request=c.UTIL.TimeAlive.Req(), - callback=c.SYS.ResetInd.Callback(partial=True), - background=True, - ) + await reply # We should be cleaned up assert not znp._listeners + # Command was properly handled + assert len(znp._unhandled_command.mock_calls) == 0 + async def test_callback_rsp_cleanup_concurrent(connected_znp, event_loop, mocker): znp, znp_server = connected_znp diff --git a/tests/conftest.py b/tests/conftest.py index ad225f14..20709b7b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import json import asyncio +import inspect import logging import pathlib import contextlib @@ -263,37 +264,44 @@ class BaseServerZNP(ZNP): align_structs = False version = None - def _flatten_responses(self, request, responses): + async def _flatten_responses(self, request, responses): if responses is None: return elif isinstance(responses, t.CommandBase): yield responses + elif inspect.iscoroutinefunction(responses): + async for rsp in responses(request): + yield rsp + elif inspect.isasyncgen(responses): + async for rsp in responses: + yield rsp elif callable(responses): - yield from self._flatten_responses(request, responses(request)) + async for rsp in self._flatten_responses(request, responses(request)): + yield rsp else: for response in responses: - yield from self._flatten_responses(request, response) + async for rsp in self._flatten_responses(request, response): + yield rsp + + async def _send_responses(self, request, responses): + async for response in self._flatten_responses(request, responses): + await asyncio.sleep(0.001) + LOGGER.debug("Replying to %s with %s", request, response) + self.send(response) def reply_once_to(self, request, responses, *, override=False): if override: self._listeners[request.header].clear() - future = self.wait_for_response(request) - called_future = asyncio.get_running_loop().create_future() + request_future = self.wait_for_response(request) async def replier(): - request = await future - - for response in self._flatten_responses(request, responses): - await asyncio.sleep(0.001) - LOGGER.debug("Replying to %s with %s", request, response) - self.send(response) + request = await request_future + await self._send_responses(request, responses) - called_future.set_result(request) + return request - asyncio.create_task(replier()) - - return called_future + return asyncio.create_task(replier()) def reply_to(self, request, responses, *, override=False): if override: @@ -301,11 +309,7 @@ def reply_to(self, request, responses, *, override=False): async def callback(request): callback.call_count += 1 - - for response in self._flatten_responses(request, responses): - await asyncio.sleep(0.001) - LOGGER.debug("Replying to %s with %s", request, response) - self.send(response) + await self._send_responses(request, responses) callback.call_count = 0 diff --git a/zigpy_znp/api.py b/zigpy_znp/api.py index eb48f45e..6135b5c1 100644 --- a/zigpy_znp/api.py +++ b/zigpy_znp/api.py @@ -815,6 +815,7 @@ async def request_callback_rsp( callback_rsp, listener = self.wait_for_responses([callback], context=True) + # Typical request/response/callbacks are not backgrounded if not background: try: async with async_timeout.timeout(timeout): @@ -824,26 +825,28 @@ async def request_callback_rsp( finally: self.remove_listener(listener) - start_time = time.time() + # Backgrounded callback handlers need to respect the provided timeout + start_time = time.monotonic() - # If the SREQ/SRSP pair fails, we must cancel the AREQ listener try: async with async_timeout.timeout(timeout): request_rsp = await self.request(request, **response_params) except Exception: + # If the SREQ/SRSP pair fails, we must cancel the AREQ listener self.remove_listener(listener) raise - async def callback_handler(timeout): + # If it succeeds, create a background task to receive the AREQ but take into + # account the time it took to start the SREQ to ensure we do not grossly exceed + # the timeout + async def callback_catcher(timeout): try: async with async_timeout.timeout(timeout): await callback_rsp finally: self.remove_listener(listener) - # If it succeeds, create a background task to receive the AREQ but take into - # account the time it took to start the SREQ to ensure we do not grossly exceed - # the timeout - asyncio.create_task(callback_handler(time.time() - start_time)) + callback_timeout = max(0, timeout - (time.monotonic() - start_time)) + asyncio.create_task(callback_catcher(callback_timeout)) return request_rsp