From 4fbf75f30ae8792a8616d9901e1386f73b461a6e Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Sun, 30 May 2021 12:23:16 -0400 Subject: [PATCH 1/5] Treat messages from unknown devices as implicit joins --- tests/application/test_zigpy_callbacks.py | 59 ++++++-- zigpy_znp/commands/zdo.py | 9 +- zigpy_znp/zigbee/application.py | 156 ++++++++++++++++------ 3 files changed, 167 insertions(+), 57 deletions(-) diff --git a/tests/application/test_zigpy_callbacks.py b/tests/application/test_zigpy_callbacks.py index f7c6f5ae..2f8fb760 100644 --- a/tests/application/test_zigpy_callbacks.py +++ b/tests/application/test_zigpy_callbacks.py @@ -1,3 +1,4 @@ +import asyncio import logging import pytest @@ -6,20 +7,34 @@ import zigpy_znp.types as t import zigpy_znp.commands as c -from ..conftest import FORMED_DEVICES +from ..conftest import FORMED_DEVICES, CoroutineMock pytestmark = [pytest.mark.asyncio] +def awaitable_mock(return_value): + mock_called = asyncio.get_running_loop().create_future() + + def side_effect(*args, **kwargs): + mock_called.set_result((args, kwargs)) + + return return_value + + return mock_called, CoroutineMock(side_effect=side_effect) + + @pytest.mark.parametrize("device", FORMED_DEVICES) async def test_on_zdo_relays_message_callback(device, make_application, mocker): app, znp_server = make_application(server_cls=device) await app.startup(auto_form=False) device = mocker.Mock() - mocker.patch.object(app, "get_device", return_value=device) + discover_called, discover_mock = awaitable_mock(return_value=device) + mocker.patch.object(app, "_get_or_discover_device", new=discover_mock) znp_server.send(c.ZDO.SrcRtgInd.Callback(DstAddr=0x1234, Relays=[0x5678, 0xABCD])) + + await discover_called assert device.relays == [0x5678, 0xABCD] await app.shutdown() @@ -32,8 +47,13 @@ async def test_on_zdo_relays_message_callback_unknown( app, znp_server = make_application(server_cls=device) await app.startup(auto_form=False) + discover_called, discover_mock = awaitable_mock(return_value=None) + mocker.patch.object(app, "_get_or_discover_device", new=discover_mock) + caplog.set_level(logging.WARNING) znp_server.send(c.ZDO.SrcRtgInd.Callback(DstAddr=0x1234, Relays=[0x5678, 0xABCD])) + + await discover_called assert "unknown device" in caplog.text await app.shutdown() @@ -98,12 +118,10 @@ async def test_on_af_message_callback(device, make_application, mocker): await app.startup(auto_form=False) device = mocker.Mock() - mocker.patch.object( - app, - "get_device", - side_effect=[device, device, device, KeyError("No such device")], - ) + discover_called, discover_mock = awaitable_mock(return_value=device) + mocker.patch.object(app, "_get_or_discover_device", new=discover_mock) mocker.patch.object(app, "handle_message") + mocker.patch.object(app, "get_device") af_message = c.AF.IncomingMsg.Callback( GroupId=1, @@ -123,43 +141,56 @@ async def test_on_af_message_callback(device, make_application, mocker): # Normal message znp_server.send(af_message) - app.get_device.assert_called_once_with(nwk=0xABCD) + + await discover_called device.radio_details.assert_called_once_with(lqi=19, rssi=None) app.handle_message.assert_called_once_with( sender=device, profile=260, cluster=2, src_ep=4, dst_ep=1, message=b"test" ) - # ZLL message device.reset_mock() app.handle_message.reset_mock() app.get_device.reset_mock() + # ZLL message + discover_called, discover_mock = awaitable_mock(return_value=device) + mocker.patch.object(app, "_get_or_discover_device", new=discover_mock) + znp_server.send(af_message.replace(DstEndpoint=2)) - app.get_device.assert_called_once_with(nwk=0xABCD) + + await discover_called device.radio_details.assert_called_once_with(lqi=19, rssi=None) app.handle_message.assert_called_once_with( sender=device, profile=49246, cluster=2, src_ep=4, dst_ep=2, message=b"test" ) - # Message on an unknown endpoint (is this possible?) device.reset_mock() app.handle_message.reset_mock() app.get_device.reset_mock() + # Message on an unknown endpoint (is this possible?) + discover_called, discover_mock = awaitable_mock(return_value=device) + mocker.patch.object(app, "_get_or_discover_device", new=discover_mock) + znp_server.send(af_message.replace(DstEndpoint=3)) - app.get_device.assert_called_once_with(nwk=0xABCD) + + await discover_called device.radio_details.assert_called_once_with(lqi=19, rssi=None) app.handle_message.assert_called_once_with( sender=device, profile=260, cluster=2, src_ep=4, dst_ep=3, message=b"test" ) - # Message from an unknown device device.reset_mock() app.handle_message.reset_mock() app.get_device.reset_mock() + # Message from an unknown device + discover_called, discover_mock = awaitable_mock(return_value=None) + mocker.patch.object(app, "_get_or_discover_device", new=discover_mock) + znp_server.send(af_message) - app.get_device.assert_called_once_with(nwk=0xABCD) + + await discover_called assert device.radio_details.call_count == 0 assert app.handle_message.call_count == 0 diff --git a/zigpy_znp/commands/zdo.py b/zigpy_znp/commands/zdo.py index 9b629df6..6c44d05b 100644 --- a/zigpy_znp/commands/zdo.py +++ b/zigpy_znp/commands/zdo.py @@ -141,6 +141,11 @@ def serialize(self) -> bytes: return super().serialize() +class AddrRequestType(t.enum_uint8): + SINGLE = 0x00 + EXTENDED = 0x01 + + class ZDO(t.CommandsBase, subsystem=t.Subsystem.ZDO): # send a "Network Address Request". This message sends a broadcast message looking # for a 16 bit address with a known 64 bit IEEE address. You must subscribe to @@ -156,7 +161,7 @@ class ZDO(t.CommandsBase, subsystem=t.Subsystem.ZDO): ), t.Param( "RequestType", - t.uint8_t, + AddrRequestType, "0x00 -- single device request, 0x01 -- Extended", ), t.Param( @@ -174,7 +179,7 @@ class ZDO(t.CommandsBase, subsystem=t.Subsystem.ZDO): t.Param("NWK", t.NWK, "Short address of the device"), t.Param( "RequestType", - t.uint8_t, + AddrRequestType, "0x00 -- single device request, 0x01 -- Extended", ), t.Param( diff --git a/zigpy_znp/zigbee/application.py b/zigpy_znp/zigbee/application.py index 59834acf..4d7efd1b 100644 --- a/zigpy_znp/zigbee/application.py +++ b/zigpy_znp/zigbee/application.py @@ -4,6 +4,7 @@ import time import asyncio import logging +import functools import itertools import contextlib @@ -80,6 +81,36 @@ LOGGER = logging.getLogger(__name__) +def combine_concurrent_calls(function): + """ + Decorator that allows concurrent calls to expensive coroutines to share a result. + """ + + futures = {} + + @functools.wraps(function) + async def replacement(*args, **kwargs): + key = (tuple(args), tuple([(k, v) for k, v in kwargs.items()])) + + if key in futures: + return await futures[key] + + future = futures[key] = asyncio.get_running_loop().create_future() + + try: + result = await function(*args, **kwargs) + except Exception as e: + future.set_exception(e) + raise + else: + future.set_result(result) + return result + finally: + del futures[key] + + return replacement + + class ZNPCoordinator(zigpy.device.Device): """ Coordinator zigpy device that keeps track of our endpoints and clusters. @@ -124,7 +155,7 @@ def __init__(self, config: conf.ConfigType): self._version_rsp = None self._concurrent_requests_semaphore = None self._currently_waiting_requests = 0 - self._route_discovery_futures = {} + self._join_announce_tasks = {} ################################################################## @@ -175,11 +206,6 @@ def close(self): self._reconnect_task.cancel() self._watchdog_task.cancel() - for f in self._route_discovery_futures.values(): - f.cancel() - - self._route_discovery_futures.clear() - # This will close the UART, which will then close the transport if self._znp is not None: self._znp.close() @@ -713,6 +739,8 @@ async def permit(self, time_s=60, node=None): # Always permit joins on the coordinator first. # This unfortunately makes it impossible to permit joins via just one router # alone but firmware changes can make this possible on newer hardware. + LOGGER.info("Permitting joins for %d seconds", time_s) + response = await self._znp.request_callback_rsp( request=c.ZDO.MgmtPermitJoinReq.Req( AddrMode=t.AddrMode.NWK, @@ -864,18 +892,16 @@ def on_zdo_permit_join_message(self, msg: c.ZDO.PermitJoinInd.Callback) -> None: else: LOGGER.info("Coordinator is permitting joins for %d seconds", msg.Duration) - def on_zdo_relays_message(self, msg: c.ZDO.SrcRtgInd.Callback) -> None: + async def on_zdo_relays_message(self, msg: c.ZDO.SrcRtgInd.Callback) -> None: """ ZDO source routing message callback """ - LOGGER.info("ZDO device relays: %s", msg) + device = await self._get_or_discover_device(nwk=msg.DstAddr) - try: - device = self.get_device(nwk=msg.DstAddr) - except KeyError: + if device is None: LOGGER.warning( - "Received a ZDO message from an unknown device: 0x%04x", msg.DstAddr + "Received a ZDO message from an unknown device: %s", msg.DstAddr ) return @@ -894,11 +920,7 @@ def on_zdo_device_announce(self, msg: c.ZDO.EndDeviceAnnceInd.Callback) -> None: self._join_announce_tasks.pop(msg.IEEE).cancel() # Sometimes devices change their NWK when announcing so re-join it. - self.handle_join( - nwk=msg.NWK, - ieee=msg.IEEE, - parent_nwk=None, - ) + self.handle_join(nwk=msg.NWK, ieee=msg.IEEE, parent_nwk=None) device = self.get_device(ieee=msg.IEEE) @@ -942,16 +964,16 @@ def on_zdo_device_leave(self, msg: c.ZDO.LeaveInd.Callback) -> None: LOGGER.info("ZDO device left: %s", msg) self.handle_leave(nwk=msg.NWK, ieee=msg.IEEE) - def on_af_message(self, msg: c.AF.IncomingMsg.Callback) -> None: + async def on_af_message(self, msg: c.AF.IncomingMsg.Callback) -> None: """ Handler for all non-ZDO messages. """ - try: - device = self.get_device(nwk=msg.SrcAddr) - except KeyError: + device = await self._get_or_discover_device(nwk=msg.SrcAddr) + + if device is None: LOGGER.warning( - "Received an AF message from an unknown device: 0x%04x", msg.SrcAddr + "Received an AF message from an unknown device: %s", msg.SrcAddr ) return @@ -1020,7 +1042,68 @@ async def _watchdog_loop(self): return - async def _set_led_mode(self, *, led, mode) -> None: + @combine_concurrent_calls + async def _get_or_discover_device(self, nwk: t.NWK) -> zigpy.device.Device | None: + """ + Finds a device by its NWK address. If a device does not exist in the zigpy + database, attempt to look up its new NWK address. If joins are currently allowed + then the device will be treated as a new join if it does not exist in the zigpy + database. + """ + + try: + return self.get_device(nwk=nwk) + except KeyError: + pass + + LOGGER.debug("Device with NWK 0x%04X not in database", nwk) + + try: + # XXX: Multiple responses may arrive but we only use the first one + ieee_addr_rsp = await self._znp.request_callback_rsp( + request=c.ZDO.IEEEAddrReq.Req( + NWK=nwk, + RequestType=c.zdo.AddrRequestType.SINGLE, + StartIndex=0, + ), + RspStatus=t.Status.SUCCESS, + callback=c.ZDO.IEEEAddrRsp.Callback( + partial=True, + NWK=nwk, + ), + timeout=5, # We don't want to wait forever + ) + except asyncio.TimeoutError: + return + + ieee = ieee_addr_rsp.IEEE + + try: + device = self.get_device(ieee=ieee) + except KeyError: + if self._permit_joins_task.done(): + LOGGER.warning("Ignoring device because joins are not permitted") + return + + LOGGER.debug("Joins are permitted, treating unknown device as a new join") + self.handle_join(nwk=nwk, ieee=ieee, parent_nwk=None) + + return self.get_device(ieee=ieee) + + LOGGER.warning( + "Device %s changed its NWK from %s to %s", + device.ieee, + device.nwk, + nwk, + ) + + # Notify zigpy of the change + device.nwk = nwk + self.listener_event("raw_device_initialized", device) + + return device + + async def _set_led_mode(self, *, led: t.uint8_t, mode: c.util.LEDMode) -> None: """ Attempts to set the provided LED's mode. A Z-Stack bug causes the underlying command to never receive a response if the board has no LEDs, requiring this @@ -1433,6 +1516,7 @@ async def _send_request_raw( return response + @combine_concurrent_calls async def _discover_route(self, nwk: t.NWK) -> None: """ Instructs the coordinator to re-discover routes to the provided NWK. @@ -1444,25 +1528,15 @@ async def _discover_route(self, nwk: t.NWK) -> None: if self._znp.version < 3.30: return - if nwk in self._route_discovery_futures: - return await self._route_discovery_futures[nwk] - - future = asyncio.get_running_loop().create_future() - self._route_discovery_futures[nwk] = future - - try: - await self._znp.request( - c.ZDO.ExtRouteDisc.Req( - Dst=nwk, - Options=c.zdo.RouteDiscoveryOptions.UNICAST, - Radius=30, - ), - ) + await self._znp.request( + c.ZDO.ExtRouteDisc.Req( + Dst=nwk, + Options=c.zdo.RouteDiscoveryOptions.UNICAST, + Radius=30, + ), + ) - await asyncio.sleep(0.1 * 13) - finally: - future.set_result(True) - del self._route_discovery_futures[nwk] + await asyncio.sleep(0.1 * 13) async def _send_request( self, From 9c6b5f7c88516a42f95f80e06055c86a0b052c68 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Wed, 2 Jun 2021 13:50:39 -0400 Subject: [PATCH 2/5] Retry requests using a device's IEEE address --- zigpy_znp/zigbee/application.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/zigpy_znp/zigbee/application.py b/zigpy_znp/zigbee/application.py index 4d7efd1b..584fdc49 100644 --- a/zigpy_znp/zigbee/application.py +++ b/zigpy_znp/zigbee/application.py @@ -1575,8 +1575,8 @@ async def _send_request( tried_assoc_remove = False tried_route_discovery = False - tried_disable_route_discovery_suppression = False tried_last_good_route = False + tried_ieee_address = False # Don't release the concurrency-limiting semaphore until we are done trying. # There is no point in allowing requests to take turns getting buffer errors. @@ -1698,11 +1698,18 @@ async def _send_request( # letting the retry mechanism deal with it simpler. await self._discover_route(dst_addr.address) tried_route_discovery = True - elif not tried_disable_route_discovery_suppression: - # Disable route discovery suppression. This appears to - # generate a bit more network traffic. - options &= ~c.af.TransmitOptions.SUPPRESS_ROUTE_DISC_NETWORK - tried_disable_route_discovery_suppression = True + elif ( + not tried_ieee_address + and device is not None + and dst_addr.mode == t.AddrMode.NWK + ): + # Try using the device's IEEE address instead of its NWK. + # If it works, the NWK will be updated when relays arrive. + tried_ieee_address = True + dst_addr = t.AddrModeAddress( + mode=t.AddrMode.IEEE, + address=device.ieee, + ) LOGGER.debug( "Request failed (%s), retry attempt %s of %s", From 1987db6fb1fbd958ad1d06f4fcc69f36437b4b5b Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Wed, 2 Jun 2021 13:52:08 -0400 Subject: [PATCH 3/5] Remove broken remnant of code to only handle joins when they're enabled --- zigpy_znp/zigbee/application.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/zigpy_znp/zigbee/application.py b/zigpy_znp/zigbee/application.py index 584fdc49..e9b7ca38 100644 --- a/zigpy_znp/zigbee/application.py +++ b/zigpy_znp/zigbee/application.py @@ -1081,11 +1081,7 @@ async def _get_or_discover_device(self, nwk: t.NWK) -> zigpy.device.Device | Non try: device = self.get_device(ieee=ieee) except KeyError: - if self._permit_joins_task.done(): - LOGGER.warning("Ignoring device because joins are not permitted") - return - - LOGGER.debug("Joins are permitted, treating unknown device as a new join") + LOGGER.debug("Treating unknown device as a new join") self.handle_join(nwk=nwk, ieee=ieee, parent_nwk=None) return self.get_device(ieee=ieee) From 62ccd3efe9494b39aaaf114e993f3ff0c4130ad8 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Sun, 4 Jul 2021 17:10:59 -0400 Subject: [PATCH 4/5] Consolidate utility classes and functions into `zigpy_znp.utils` --- tests/api/test_response.py | 16 +-- zigpy_znp/api.py | 148 +------------------------- zigpy_znp/utils.py | 183 ++++++++++++++++++++++++++++++++ zigpy_znp/zigbee/application.py | 34 +----- 4 files changed, 198 insertions(+), 183 deletions(-) create mode 100644 zigpy_znp/utils.py diff --git a/tests/api/test_response.py b/tests/api/test_response.py index ee086abd..bfcf9ae9 100644 --- a/tests/api/test_response.py +++ b/tests/api/test_response.py @@ -5,7 +5,7 @@ import zigpy_znp.types as t import zigpy_znp.commands as c -from zigpy_znp.api import _deduplicate_commands +from zigpy_znp.utils import deduplicate_commands pytestmark = [pytest.mark.asyncio] @@ -215,15 +215,15 @@ async def test_command_deduplication_simple(): c1 = c.SYS.Ping.Rsp(partial=True) c2 = c.UTIL.TimeAlive.Rsp(Seconds=12) - assert _deduplicate_commands([]) == () - assert _deduplicate_commands([c1]) == (c1,) - assert _deduplicate_commands([c1, c1]) == (c1,) - assert _deduplicate_commands([c1, c2]) == (c1, c2) - assert _deduplicate_commands([c2, c1, c2]) == (c2, c1) + assert deduplicate_commands([]) == () + assert deduplicate_commands([c1]) == (c1,) + assert deduplicate_commands([c1, c1]) == (c1,) + assert deduplicate_commands([c1, c2]) == (c1, c2) + assert deduplicate_commands([c2, c1, c2]) == (c2, c1) async def test_command_deduplication_complex(): - result = _deduplicate_commands( + result = deduplicate_commands( [ c.SYS.Ping.Rsp(Capabilities=t.MTCapabilities.SYS), # Duplicating matching commands shouldn't do anything @@ -302,7 +302,7 @@ async def async_callback(response): c.UTIL.TimeAlive.Rsp(Seconds=10), ] - assert set(_deduplicate_commands(responses)) == { + assert set(deduplicate_commands(responses)) == { c.SYS.Ping.Rsp(partial=True), c.UTIL.TimeAlive.Rsp(Seconds=12), c.UTIL.TimeAlive.Rsp(Seconds=10), diff --git a/zigpy_znp/api.py b/zigpy_znp/api.py index 71d3dd38..1b2dffd3 100644 --- a/zigpy_znp/api.py +++ b/zigpy_znp/api.py @@ -1,12 +1,10 @@ from __future__ import annotations import time -import typing import asyncio import logging import itertools import contextlib -import dataclasses from collections import Counter, defaultdict import async_timeout @@ -17,6 +15,11 @@ import zigpy_znp.commands as c from zigpy_znp import uart from zigpy_znp.nvram import NVRAMHelper +from zigpy_znp.utils import ( + BaseResponseListener, + OneShotResponseListener, + CallbackResponseListener, +) from zigpy_znp.frames import GeneralFrame from zigpy_znp.znp.utils import NetworkInfo, load_network_info, detect_zstack_version from zigpy_znp.exceptions import CommandNotRecognized, InvalidCommandResponse @@ -26,147 +29,6 @@ STARTUP_DELAY = 1 # seconds -def _deduplicate_commands( - commands: typing.Iterable[t.CommandBase], -) -> tuple[t.CommandBase]: - """ - Deduplicates an iterable of commands by folding more-specific commands into less- - specific commands. Used to avoid triggering callbacks multiple times per packet. - """ - - # We essentially need to find the "maximal" commands, if you treat the relationship - # between two commands as a partial order. - maximal_commands = [] - - # Command matching as a relation forms a partially ordered set. - for command in commands: - for index, other_command in enumerate(maximal_commands): - if other_command.matches(command): - # If the other command matches us, we are redundant - break - elif command.matches(other_command): - # If we match another command, we replace it - maximal_commands[index] = command - break - else: - # Otherwise, we keep looking - continue # pragma: no cover - else: - # If we matched nothing and nothing matched us, we extend the list - maximal_commands.append(command) - - # The start of each chain is the maximal element - return tuple(maximal_commands) - - -@dataclasses.dataclass(frozen=True) -class BaseResponseListener: - matching_commands: tuple[t.CommandBase] - - def __post_init__(self): - commands = _deduplicate_commands(self.matching_commands) - - if not commands: - raise ValueError("Cannot create a listener without any matching commands") - - # We're frozen so __setattr__ is disallowed - object.__setattr__(self, "matching_commands", commands) - - def matching_headers(self) -> set[t.CommandHeader]: - """ - Returns the set of Z-Stack MT command headers for all the matching commands. - """ - - return {response.header for response in self.matching_commands} - - def resolve(self, response: t.CommandBase) -> bool: - """ - Attempts to resolve the listener with a given response. Can be called with any - command as an argument, including ones we don't match. - """ - - if not any(c.matches(response) for c in self.matching_commands): - return False - - return self._resolve(response) - - def _resolve(self, response: t.CommandBase) -> bool: - """ - Implemented by subclasses to handle matched commands. - - Return value indicates whether or not the listener has actually resolved, - which can sometimes be unavoidable. - """ - - raise NotImplementedError() # pragma: no cover - - def cancel(self): - """ - Implement by subclasses to cancel the listener. - - Return value indicates whether or not the listener is cancelable. - """ - - raise NotImplementedError() # pragma: no cover - - -@dataclasses.dataclass(frozen=True) -class OneShotResponseListener(BaseResponseListener): - """ - A response listener that resolves a single future exactly once. - """ - - future: asyncio.Future = dataclasses.field( - default_factory=lambda: asyncio.get_running_loop().create_future() - ) - - def _resolve(self, response: t.CommandBase) -> bool: - if self.future.done(): - # This happens if the UART receives multiple packets during the same - # event loop step and all of them match this listener. Our Future's - # add_done_callback will not fire synchronously and thus the listener - # is never properly removed. This isn't going to break anything. - LOGGER.debug("Future already has a result set: %s", self.future) - return False - - self.future.set_result(response) - return True - - def cancel(self): - if not self.future.done(): - self.future.cancel() - - return True - - -@dataclasses.dataclass(frozen=True) -class CallbackResponseListener(BaseResponseListener): - """ - A response listener with a sync or async callback that is never resolved. - """ - - callback: typing.Callable[[t.CommandBase], typing.Any] - - def _resolve(self, response: t.CommandBase) -> bool: - try: - result = self.callback(response) - - # Run coroutines in the background - if asyncio.iscoroutine(result): - asyncio.create_task(result) - except Exception: - LOGGER.warning( - "Caught an exception while executing callback", exc_info=True - ) - - # Callbacks are always resolved - return True - - def cancel(self): - # You can't cancel a callback - return False - - class ZNP: def __init__(self, config: conf.ConfigType): self._uart = None diff --git a/zigpy_znp/utils.py b/zigpy_znp/utils.py new file mode 100644 index 00000000..7b53c868 --- /dev/null +++ b/zigpy_znp/utils.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +import typing +import asyncio +import logging +import functools +import dataclasses + +import zigpy_znp.types as t + +LOGGER = logging.getLogger(__name__) + + +def deduplicate_commands( + commands: typing.Iterable[t.CommandBase], +) -> tuple[t.CommandBase]: + """ + Deduplicates an iterable of commands by folding more-specific commands into less- + specific commands. Used to avoid triggering callbacks multiple times per packet. + """ + + # We essentially need to find the "maximal" commands, if you treat the relationship + # between two commands as a partial order. + maximal_commands = [] + + # Command matching as a relation forms a partially ordered set. + for command in commands: + for index, other_command in enumerate(maximal_commands): + if other_command.matches(command): + # If the other command matches us, we are redundant + break + elif command.matches(other_command): + # If we match another command, we replace it + maximal_commands[index] = command + break + else: + # Otherwise, we keep looking + continue # pragma: no cover + else: + # If we matched nothing and nothing matched us, we extend the list + maximal_commands.append(command) + + # The start of each chain is the maximal element + return tuple(maximal_commands) + + +@dataclasses.dataclass(frozen=True) +class BaseResponseListener: + matching_commands: tuple[t.CommandBase] + + def __post_init__(self): + commands = deduplicate_commands(self.matching_commands) + + if not commands: + raise ValueError("Cannot create a listener without any matching commands") + + # We're frozen so __setattr__ is disallowed + object.__setattr__(self, "matching_commands", commands) + + def matching_headers(self) -> set[t.CommandHeader]: + """ + Returns the set of Z-Stack MT command headers for all the matching commands. + """ + + return {response.header for response in self.matching_commands} + + def resolve(self, response: t.CommandBase) -> bool: + """ + Attempts to resolve the listener with a given response. Can be called with any + command as an argument, including ones we don't match. + """ + + if not any(c.matches(response) for c in self.matching_commands): + return False + + return self._resolve(response) + + def _resolve(self, response: t.CommandBase) -> bool: + """ + Implemented by subclasses to handle matched commands. + + Return value indicates whether or not the listener has actually resolved, + which can sometimes be unavoidable. + """ + + raise NotImplementedError() # pragma: no cover + + def cancel(self): + """ + Implement by subclasses to cancel the listener. + + Return value indicates whether or not the listener is cancelable. + """ + + raise NotImplementedError() # pragma: no cover + + +@dataclasses.dataclass(frozen=True) +class OneShotResponseListener(BaseResponseListener): + """ + A response listener that resolves a single future exactly once. + """ + + future: asyncio.Future = dataclasses.field( + default_factory=lambda: asyncio.get_running_loop().create_future() + ) + + def _resolve(self, response: t.CommandBase) -> bool: + if self.future.done(): + # This happens if the UART receives multiple packets during the same + # event loop step and all of them match this listener. Our Future's + # add_done_callback will not fire synchronously and thus the listener + # is never properly removed. This isn't going to break anything. + LOGGER.debug("Future already has a result set: %s", self.future) + return False + + self.future.set_result(response) + return True + + def cancel(self): + if not self.future.done(): + self.future.cancel() + + return True + + +@dataclasses.dataclass(frozen=True) +class CallbackResponseListener(BaseResponseListener): + """ + A response listener with a sync or async callback that is never resolved. + """ + + callback: typing.Callable[[t.CommandBase], typing.Any] + + def _resolve(self, response: t.CommandBase) -> bool: + try: + result = self.callback(response) + + # Run coroutines in the background + if asyncio.iscoroutine(result): + asyncio.create_task(result) + except Exception: + LOGGER.warning( + "Caught an exception while executing callback", exc_info=True + ) + + # Callbacks are always resolved + return True + + def cancel(self): + # You can't cancel a callback + return False + + +def combine_concurrent_calls(function): + """ + Decorator that allows concurrent calls to expensive coroutines to share a result. + """ + + futures = {} + + @functools.wraps(function) + async def replacement(*args, **kwargs): + # XXX: all args and kwargs are assumed to be hashable + key = (tuple(args), tuple([(k, v) for k, v in kwargs.items()])) + + if key in futures: + return await futures[key] + + future = futures[key] = asyncio.get_running_loop().create_future() + + try: + result = await function(*args, **kwargs) + except Exception as e: + future.set_exception(e) + raise + else: + future.set_result(result) + return result + finally: + del futures[key] + + return replacement diff --git a/zigpy_znp/zigbee/application.py b/zigpy_znp/zigbee/application.py index e9b7ca38..81078b82 100644 --- a/zigpy_znp/zigbee/application.py +++ b/zigpy_znp/zigbee/application.py @@ -4,7 +4,6 @@ import time import asyncio import logging -import functools import itertools import contextlib @@ -27,6 +26,7 @@ import zigpy_znp.config as conf import zigpy_znp.commands as c from zigpy_znp.api import ZNP +from zigpy_znp.utils import combine_concurrent_calls from zigpy_znp.exceptions import CommandNotRecognized, InvalidCommandResponse from zigpy_znp.types.nvids import OsalNvIds from zigpy_znp.zigbee.zdo_converters import ZDO_CONVERTERS @@ -45,7 +45,7 @@ MULTICAST_SEND_WAIT_DURATION = 3 REQUEST_MAX_RETRIES = 5 -REQUEST_ERROR_RETRY_DELAY = 0.5 # seconds +REQUEST_ERROR_RETRY_DELAY = 0.5 # Errors that go away on their own after waiting for a bit REQUEST_TRANSIENT_ERRORS = { @@ -81,36 +81,6 @@ LOGGER = logging.getLogger(__name__) -def combine_concurrent_calls(function): - """ - Decorator that allows concurrent calls to expensive coroutines to share a result. - """ - - futures = {} - - @functools.wraps(function) - async def replacement(*args, **kwargs): - key = (tuple(args), tuple([(k, v) for k, v in kwargs.items()])) - - if key in futures: - return await futures[key] - - future = futures[key] = asyncio.get_running_loop().create_future() - - try: - result = await function(*args, **kwargs) - except Exception as e: - future.set_exception(e) - raise - else: - future.set_result(result) - return result - finally: - del futures[key] - - return replacement - - class ZNPCoordinator(zigpy.device.Device): """ Coordinator zigpy device that keeps track of our endpoints and clusters. From ca4572ec97010a1c2c441c869bda22404a9b8887 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Mon, 5 Jul 2021 19:06:42 -0400 Subject: [PATCH 5/5] Increase test coverage for unknown device discovery --- tests/application/test_joining.py | 79 ++++++++++++++++++++++++++++++ tests/application/test_requests.py | 69 ++++++++++++++++++++++++-- tests/application/test_startup.py | 2 + zigpy_znp/zigbee/application.py | 7 ++- 4 files changed, 148 insertions(+), 9 deletions(-) diff --git a/tests/application/test_joining.py b/tests/application/test_joining.py index 32487a11..fa084f34 100644 --- a/tests/application/test_joining.py +++ b/tests/application/test_joining.py @@ -508,3 +508,82 @@ def bind_req_callback(request): await cluster.bind() await app.shutdown() + + +@pytest.mark.parametrize("device", FORMED_DEVICES) +async def test_unknown_device_discovery(device, make_application, mocker): + app, znp_server = make_application(server_cls=device) + await app.startup(auto_form=False) + + mocker.spy(app, "handle_join") + + # Existing devices do not need to be discovered + existing_nwk = 0x1234 + existing_ieee = t.EUI64(range(8)) + device = app.add_initialized_device(ieee=existing_ieee, nwk=existing_nwk) + + assert (await app._get_or_discover_device(nwk=existing_nwk)) is device + assert app.handle_join.call_count == 0 + + # If the device changes its NWK but doesn't tell zigpy, it will be re-discovered + did_ieee_addr_req1 = znp_server.reply_once_to( + request=c.ZDO.IEEEAddrReq.Req( + NWK=existing_nwk + 1, + RequestType=c.zdo.AddrRequestType.SINGLE, + StartIndex=0, + ), + responses=[ + c.ZDO.IEEEAddrReq.Rsp(Status=t.Status.SUCCESS), + c.ZDO.IEEEAddrRsp.Callback( + Status=t.ZDOStatus.SUCCESS, + IEEE=existing_ieee, + NWK=existing_nwk + 1, + Index=0, + Devices=[], + ), + ], + ) + + # The same device is discovered and its NWK was updated. Handles concurrency. + devices = await asyncio.gather( + app._get_or_discover_device(nwk=existing_nwk + 1), + app._get_or_discover_device(nwk=existing_nwk + 1), + app._get_or_discover_device(nwk=existing_nwk + 1), + app._get_or_discover_device(nwk=existing_nwk + 1), + app._get_or_discover_device(nwk=existing_nwk + 1), + ) + + assert devices == [device] * 5 + + # Only a single request is sent, since the coroutines are grouped + await did_ieee_addr_req1 + assert device.nwk == existing_nwk + 1 + assert app.handle_join.call_count == 0 + + # If a completely unknown device joins the network, it will be treated as a new join + new_nwk = 0x5678 + new_ieee = t.EUI64(range(1, 9)) + did_ieee_addr_req2 = znp_server.reply_once_to( + request=c.ZDO.IEEEAddrReq.Req( + NWK=new_nwk, + RequestType=c.zdo.AddrRequestType.SINGLE, + StartIndex=0, + ), + responses=[ + c.ZDO.IEEEAddrReq.Rsp(Status=t.Status.SUCCESS), + c.ZDO.IEEEAddrRsp.Callback( + Status=t.ZDOStatus.SUCCESS, + IEEE=new_ieee, + NWK=new_nwk, + Index=0, + Devices=[], + ), + ], + ) + new_dev = await app._get_or_discover_device(nwk=new_nwk) + await did_ieee_addr_req2 + assert app.handle_join.call_count == 1 + assert new_dev.nwk == new_nwk + assert new_dev.ieee == new_ieee + + await app.pre_shutdown() diff --git a/tests/application/test_requests.py b/tests/application/test_requests.py index bc1e2c78..ba53bcb8 100644 --- a/tests/application/test_requests.py +++ b/tests/application/test_requests.py @@ -487,7 +487,6 @@ async def test_nonstandard_profile(device, make_application): await app.startup(auto_form=False) device = app.add_initialized_device(ieee=t.EUI64(range(8)), nwk=0xFA9E) - device.node_desc, _ = device.node_desc.deserialize(bytes(14)) ep = device.add_endpoint(2) ep.status = zigpy.endpoint.Status.ZDO_INIT @@ -605,7 +604,6 @@ async def test_request_recovery_route_rediscovery_zdo(device, make_application, app._znp._config[conf.CONF_ZNP_CONFIG][conf.CONF_ARSP_TIMEOUT] = 1 device = app.add_initialized_device(ieee=t.EUI64(range(8)), nwk=0xABCD) - device.node_desc, _ = device.node_desc.deserialize(bytes(14)) # Fail the first time route_discovered = False @@ -669,7 +667,6 @@ async def test_request_recovery_route_rediscovery_af(device, make_application, m app._znp._config[conf.CONF_ZNP_CONFIG][conf.CONF_ARSP_TIMEOUT] = 1 device = app.add_initialized_device(ieee=t.EUI64(range(8)), nwk=0xABCD) - device.node_desc, _ = device.node_desc.deserialize(bytes(14)) # Fail the first time route_discovered = False @@ -728,6 +725,70 @@ def set_route_discovered(req): await app.shutdown() +@pytest.mark.parametrize("device", [FormedLaunchpadCC26X2R1]) +async def test_request_recovery_use_ieee_addr(device, make_application, mocker): + app, znp_server = make_application(server_cls=device) + + await app.startup(auto_form=False) + + # The data confirm timeout must be shorter than the ARSP timeout + mocker.patch("zigpy_znp.zigbee.application.DATA_CONFIRM_TIMEOUT", new=0.1) + app._znp._config[conf.CONF_ZNP_CONFIG][conf.CONF_ARSP_TIMEOUT] = 1 + + device = app.add_initialized_device(ieee=t.EUI64(range(8)), nwk=0xABCD) + + was_ieee_addr_used = False + + def data_confirm_replier(req): + nonlocal was_ieee_addr_used + + if req.DstAddrModeAddress.mode == t.AddrMode.IEEE: + status = t.Status.SUCCESS + was_ieee_addr_used = True + else: + status = t.Status.MAC_NO_ACK + + return c.AF.DataConfirm.Callback(Status=status, Endpoint=1, TSN=1) + + znp_server.reply_once_to( + c.ZDO.ExtRouteDisc.Req( + Dst=device.nwk, Options=c.zdo.RouteDiscoveryOptions.UNICAST, partial=True + ), + responses=[c.ZDO.ExtRouteDisc.Rsp(Status=t.Status.SUCCESS)], + ) + + znp_server.reply_to( + c.AF.DataRequestExt.Req(partial=True), + responses=[ + c.AF.DataRequestExt.Rsp(Status=t.Status.SUCCESS), + data_confirm_replier, + ], + ) + + # Ignore the source routing request as well + znp_server.reply_to( + c.AF.DataRequestSrcRtg.Req(partial=True), + responses=[ + c.AF.DataRequestSrcRtg.Rsp(Status=t.Status.SUCCESS), + c.AF.DataConfirm.Callback(Status=t.Status.MAC_NO_ACK, Endpoint=1, TSN=1), + ], + ) + + await app.request( + device=device, + profile=260, + cluster=1, + src_ep=1, + dst_ep=1, + sequence=1, + data=b"\x00", + ) + + assert was_ieee_addr_used + + await app.shutdown() + + @pytest.mark.parametrize("device_cls", FORMED_DEVICES) @pytest.mark.parametrize("fw_assoc_remove", [True, False]) @pytest.mark.parametrize("final_status", [t.Status.SUCCESS, t.Status.APS_NO_ACK]) @@ -744,7 +805,6 @@ async def test_request_recovery_assoc_remove( app._znp._config[conf.CONF_ZNP_CONFIG][conf.CONF_ARSP_TIMEOUT] = 1 device = app.add_initialized_device(ieee=t.EUI64(range(8)), nwk=0xABCD) - device.node_desc, _ = device.node_desc.deserialize(bytes(14)) assoc_device, _ = c.util.Device.deserialize(b"\xFF" * 100) assoc_device.shortAddr = device.nwk @@ -880,7 +940,6 @@ async def test_request_recovery_manual_source_route( device = app.add_initialized_device(ieee=t.EUI64(range(8)), nwk=0xABCD) device.relays = relays - device.node_desc, _ = device.node_desc.deserialize(bytes(14)) def data_confirm_replier(req): if isinstance(req, c.AF.DataRequestExt.Req) or not succeed: diff --git a/tests/application/test_startup.py b/tests/application/test_startup.py index a173a63b..ef1acc0e 100644 --- a/tests/application/test_startup.py +++ b/tests/application/test_startup.py @@ -71,6 +71,7 @@ async def test_info( assert app.channel is None assert app.channels is None assert app.network_key is None + assert app.network_key_seq is None await app.startup(auto_form=False) @@ -84,6 +85,7 @@ async def test_info( assert app.channel == channel assert app.channels == channels assert app.network_key == network_key + assert app.network_key_seq == 0 assert app.zigpy_device.manufacturer == "Texas Instruments" assert app.zigpy_device.model == model diff --git a/zigpy_znp/zigbee/application.py b/zigpy_znp/zigbee/application.py index 81078b82..e8a8cf72 100644 --- a/zigpy_znp/zigbee/application.py +++ b/zigpy_znp/zigbee/application.py @@ -1016,9 +1016,8 @@ async def _watchdog_loop(self): async def _get_or_discover_device(self, nwk: t.NWK) -> zigpy.device.Device | None: """ Finds a device by its NWK address. If a device does not exist in the zigpy - database, attempt to look up its new NWK address. If joins are currently allowed - then the device will be treated as a new join if it does not exist in the zigpy - database. + database, attempt to look up its new NWK address. If it does not exist in the + zigpy database, treat the device as a new join. """ try: @@ -1044,7 +1043,7 @@ async def _get_or_discover_device(self, nwk: t.NWK) -> zigpy.device.Device | Non timeout=5, # We don't want to wait forever ) except asyncio.TimeoutError: - return + return None ieee = ieee_addr_rsp.IEEE