Skip to content

Commit 3fe9823

Browse files
authored
Merge pull request #122 from puddly/puddly/zdo-nwk-and-ieee-addr-req
Implement ZDO converters for `NWK_addr_req` and `IEEE_addr_req`
2 parents 48438ca + e6d3e0d commit 3fe9823

File tree

6 files changed

+180
-57
lines changed

6 files changed

+180
-57
lines changed

tests/application/test_joining.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,7 @@ async def test_unknown_device_discovery(device, make_application, mocker):
554554
Status=t.ZDOStatus.SUCCESS,
555555
IEEE=existing_ieee,
556556
NWK=existing_nwk + 1,
557+
NumAssoc=0,
557558
Index=0,
558559
Devices=[],
559560
),
@@ -591,6 +592,7 @@ async def test_unknown_device_discovery(device, make_application, mocker):
591592
Status=t.ZDOStatus.SUCCESS,
592593
IEEE=new_ieee,
593594
NWK=new_nwk,
595+
NumAssoc=0,
594596
Index=0,
595597
Devices=[],
596598
),
@@ -603,3 +605,24 @@ async def test_unknown_device_discovery(device, make_application, mocker):
603605
assert new_dev.ieee == new_ieee
604606

605607
await app.pre_shutdown()
608+
609+
610+
@pytest.mark.parametrize("device", FORMED_DEVICES)
611+
async def test_unknown_device_discovery_failure(device, make_application, mocker):
612+
mocker.patch("zigpy_znp.zigbee.application.IEEE_ADDR_DISCOVERY_TIMEOUT", new=0.1)
613+
614+
app, znp_server = make_application(server_cls=device)
615+
await app.startup(auto_form=False)
616+
617+
znp_server.reply_once_to(
618+
request=c.ZDO.IEEEAddrReq.Req(partial=True),
619+
responses=[
620+
c.ZDO.IEEEAddrReq.Rsp(Status=t.Status.SUCCESS),
621+
],
622+
)
623+
624+
# Discovery will throw an exception when the device cannot be found
625+
with pytest.raises(KeyError):
626+
await app._get_or_discover_device(nwk=0x3456)
627+
628+
await app.pre_shutdown()

tests/application/test_zigpy_callbacks.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,20 @@
1010
from ..conftest import FORMED_DEVICES, CoroutineMock
1111

1212

13-
def awaitable_mock(return_value):
13+
def awaitable_mock(*, return_value=None, side_effect=None):
14+
assert (return_value or side_effect) and not (return_value and side_effect)
15+
1416
mock_called = asyncio.get_running_loop().create_future()
1517

16-
def side_effect(*args, **kwargs):
18+
def side_effect_(*args, **kwargs):
1719
mock_called.set_result((args, kwargs))
1820

19-
return return_value
21+
if return_value is not None:
22+
return return_value
23+
else:
24+
raise side_effect
2025

21-
return mock_called, CoroutineMock(side_effect=side_effect)
26+
return mock_called, CoroutineMock(side_effect=side_effect_)
2227

2328

2429
@pytest.mark.parametrize("device", FORMED_DEVICES)
@@ -45,7 +50,7 @@ async def test_on_zdo_relays_message_callback_unknown(
4550
app, znp_server = make_application(server_cls=device)
4651
await app.startup(auto_form=False)
4752

48-
discover_called, discover_mock = awaitable_mock(return_value=None)
53+
discover_called, discover_mock = awaitable_mock(side_effect=KeyError())
4954
mocker.patch.object(app, "_get_or_discover_device", new=discover_mock)
5055

5156
caplog.set_level(logging.WARNING)
@@ -183,7 +188,7 @@ async def test_on_af_message_callback(device, make_application, mocker):
183188
app.get_device.reset_mock()
184189

185190
# Message from an unknown device
186-
discover_called, discover_mock = awaitable_mock(return_value=None)
191+
discover_called, discover_mock = awaitable_mock(side_effect=KeyError())
187192
mocker.patch.object(app, "_get_or_discover_device", new=discover_mock)
188193

189194
znp_server.send(af_message)

zigpy_znp/commands/zdo.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,10 @@ class ChildInfoList(t.LVList, item_type=t.EUI64, length_type=t.uint8_t):
116116
pass
117117

118118

119+
class NWKArray(t.CompleteList, item_type=t.NWK):
120+
pass
121+
122+
119123
class NullableNodeDescriptor(zigpy.zdo.types.NodeDescriptor):
120124
@classmethod
121125
def deserialize(cls, data: bytes) -> tuple[NullableNodeDescriptor, bytes]:
@@ -935,12 +939,13 @@ class ZDO(t.CommandsBase, subsystem=t.Subsystem.ZDO):
935939
),
936940
t.Param("IEEE", t.EUI64, "Extended address of the source device"),
937941
t.Param("NWK", t.NWK, "Short address of the source device"),
942+
t.Param("NumAssoc", t.uint8_t, "Number of associated devices"),
938943
t.Param(
939944
"Index",
940945
t.uint8_t,
941946
"Starting index into the list of associated devices",
942947
),
943-
t.Param("Devices", t.NWKList, "List of the associated devices"),
948+
t.Param("Devices", NWKArray, "List of the associated devices"),
944949
),
945950
)
946951

@@ -954,12 +959,13 @@ class ZDO(t.CommandsBase, subsystem=t.Subsystem.ZDO):
954959
),
955960
t.Param("IEEE", t.EUI64, "Extended address of the source device"),
956961
t.Param("NWK", t.NWK, "Short address of the source device"),
962+
t.Param("NumAssoc", t.uint8_t, "Number of associated devices"),
957963
t.Param(
958964
"Index",
959965
t.uint8_t,
960966
"Starting index into the list of associated devices",
961967
),
962-
t.Param("Devices", t.NWKList, "List of the associated devices"),
968+
t.Param("Devices", NWKArray, "List of the associated devices"),
963969
),
964970
)
965971

zigpy_znp/utils.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -164,12 +164,14 @@ def matches(self, other) -> bool:
164164
return True
165165

166166

167-
def combine_concurrent_calls(function):
167+
def combine_concurrent_calls(
168+
function: typing.CoroutineFunction,
169+
) -> typing.CoroutineFunction:
168170
"""
169171
Decorator that allows concurrent calls to expensive coroutines to share a result.
170172
"""
171173

172-
futures = {}
174+
tasks = {}
173175
signature = inspect.signature(function)
174176

175177
@functools.wraps(function)
@@ -180,20 +182,15 @@ async def replacement(*args, **kwargs):
180182
# XXX: all args and kwargs are assumed to be hashable
181183
key = tuple(bound.arguments.items())
182184

183-
if key in futures:
184-
return await futures[key]
185+
if key in tasks:
186+
return await tasks[key]
185187

186-
future = futures[key] = asyncio.get_running_loop().create_future()
188+
tasks[key] = asyncio.create_task(function(*args, **kwargs))
187189

188190
try:
189-
result = await function(*args, **kwargs)
190-
except Exception as e:
191-
future.set_exception(e)
192-
raise
193-
else:
194-
future.set_result(result)
195-
return result
191+
return await tasks[key]
196192
finally:
197-
del futures[key]
193+
assert tasks[key].done()
194+
del tasks[key]
198195

199196
return replacement

zigpy_znp/zigbee/application.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
STARTUP_TIMEOUT = 5
4848
ZDO_REQUEST_TIMEOUT = 15
4949
DATA_CONFIRM_TIMEOUT = 8
50+
IEEE_ADDR_DISCOVERY_TIMEOUT = 5
5051
DEVICE_JOIN_MAX_DELAY = 5
5152
WATCHDOG_PERIOD = 30
5253
BROADCAST_SEND_WAIT_DURATION = 3
@@ -670,6 +671,12 @@ def _bind_callbacks(self) -> None:
670671
self.on_intentionally_unhandled_message,
671672
)
672673

674+
# These are responses to a broadcast but we ignore all but the first
675+
self._znp.callback_for_response(
676+
c.ZDO.IEEEAddrRsp.Callback(partial=True),
677+
self.on_intentionally_unhandled_message,
678+
)
679+
673680
def on_intentionally_unhandled_message(self, msg: t.CommandBase) -> None:
674681
"""
675682
Some commands are unhandled but frequently sent by devices on the network. To
@@ -693,9 +700,9 @@ async def on_zdo_relays_message(self, msg: c.ZDO.SrcRtgInd.Callback) -> None:
693700
ZDO source routing message callback
694701
"""
695702

696-
device = await self._get_or_discover_device(nwk=msg.DstAddr)
697-
698-
if device is None:
703+
try:
704+
device = await self._get_or_discover_device(nwk=msg.DstAddr)
705+
except KeyError:
699706
LOGGER.warning(
700707
"Received a ZDO message from an unknown device: %s", msg.DstAddr
701708
)
@@ -778,9 +785,9 @@ async def on_af_message(self, msg: c.AF.IncomingMsg.Callback) -> None:
778785
Handler for all non-ZDO messages.
779786
"""
780787

781-
device = await self._get_or_discover_device(nwk=msg.SrcAddr)
782-
783-
if device is None:
788+
try:
789+
device = await self._get_or_discover_device(nwk=msg.SrcAddr)
790+
except KeyError:
784791
LOGGER.warning(
785792
"Received an AF message from an unknown device: %s", msg.SrcAddr
786793
)
@@ -864,7 +871,7 @@ async def _watchdog_loop(self):
864871
return
865872

866873
@combine_concurrent_calls
867-
async def _get_or_discover_device(self, nwk: t.NWK) -> zigpy.device.Device | None:
874+
async def _get_or_discover_device(self, nwk: t.NWK) -> zigpy.device.Device:
868875
"""
869876
Finds a device by its NWK address. If a device does not exist in the zigpy
870877
database, attempt to look up its new NWK address. If it does not exist in the
@@ -880,23 +887,16 @@ async def _get_or_discover_device(self, nwk: t.NWK) -> zigpy.device.Device | Non
880887

881888
try:
882889
# XXX: Multiple responses may arrive but we only use the first one
883-
ieee_addr_rsp = await self._znp.request_callback_rsp(
884-
request=c.ZDO.IEEEAddrReq.Req(
885-
NWK=nwk,
886-
RequestType=c.zdo.AddrRequestType.SINGLE,
887-
StartIndex=0,
888-
),
889-
RspStatus=t.Status.SUCCESS,
890-
callback=c.ZDO.IEEEAddrRsp.Callback(
891-
partial=True,
892-
NWK=nwk,
893-
),
894-
timeout=5, # We don't want to wait forever
895-
)
890+
async with async_timeout.timeout(IEEE_ADDR_DISCOVERY_TIMEOUT):
891+
_, ieee, _, _, _, _ = await self.zigpy_device.zdo.IEEE_addr_req(
892+
*{
893+
"NWKAddrOfInterest": nwk,
894+
"RequestType": c.zdo.AddrRequestType.SINGLE,
895+
"StartIndex": 0,
896+
}.values()
897+
)
896898
except asyncio.TimeoutError:
897-
return None
898-
899-
ieee = ieee_addr_rsp.IEEE
899+
raise KeyError(f"Unknown device: 0x{nwk:04X}")
900900

901901
try:
902902
device = self.get_device(ieee=ieee)
@@ -1276,7 +1276,7 @@ async def _send_zdo_request(
12761276
# Call the converter with the ZDO request's kwargs
12771277
req_factory, rsp_factory, zdo_rsp_factory = ZDO_CONVERTERS[cluster]
12781278
request = req_factory(dst_addr, **zdo_kwargs)
1279-
callback = rsp_factory(dst_addr)
1279+
callback = rsp_factory(dst_addr, **zdo_kwargs)
12801280

12811281
LOGGER.debug(
12821282
"Intercepted AP ZDO request %s(%s) and replaced with %s",

0 commit comments

Comments
 (0)