diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f698f5f1..7fe7abb0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,7 +18,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.7, 3.8, 3.9] + python-version: [3.8, 3.9, "3.10"] steps: - name: Check out code from GitHub uses: actions/checkout@v2 @@ -224,7 +224,7 @@ jobs: needs: prepare-base strategy: matrix: - python-version: [3.7, 3.8, 3.9] + python-version: [3.8, 3.9, "3.10"] name: >- Run tests Python ${{ matrix.python-version }} steps: diff --git a/.github/workflows/publish-to-pypi.yml b/.github/workflows/publish-to-pypi.yml index 3fcaed40..f607e707 100644 --- a/.github/workflows/publish-to-pypi.yml +++ b/.github/workflows/publish-to-pypi.yml @@ -10,10 +10,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@master - - name: Set up Python 3.7 + - name: Set up Python 3.8 uses: actions/setup-python@v1 with: - python-version: 3.7 + python-version: 3.8 - name: Install wheel run: >- pip install wheel diff --git a/README.md b/README.md index 3f8b8597..984f6477 100644 --- a/README.md +++ b/README.md @@ -51,9 +51,6 @@ Below are the defaults with the top-level Home Assistant `zha:` key. zha: zigpy_config: znp_config: - # "auto" picks the largest value that keeps the device's transmit buffer from getting full - max_concurrent_requests: auto - # Only if your stick has a built-in power amplifier (i.e. CC1352P and CC2592) # If set, must be between: # * CC1352/2652: -22 and 19 diff --git a/setup.cfg b/setup.cfg index b21a1a23..1a3ecd52 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,9 +13,7 @@ license = GPL-3.0 packages = find: python_requires = >=3.7 install_requires = - pyserial-asyncio; platform_system!="Windows" - pyserial-asyncio!=0.5; platform_system=="Windows" # 0.5 broke writes - zigpy>=0.50.0 + zigpy>=0.51.0 async_timeout voluptuous coloredlogs diff --git a/tests/application/test_joining.py b/tests/application/test_joining.py index 6b9ff9cc..1cd30203 100644 --- a/tests/application/test_joining.py +++ b/tests/application/test_joining.py @@ -333,107 +333,3 @@ async def test_on_zdo_device_join_and_announce_slow(device, make_application, mo assert app.handle_join.call_count == 2 await app.shutdown() - - -@pytest.mark.parametrize("device", FORMED_DEVICES) -async def test_unknown_device_discovery(device, make_application, mocker): - app, znp_server = await make_application(server_cls=device) - await app.startup(auto_form=False) - - mocker.spy(app, "handle_join") - - # Existing devices do not need to be discovered - existing_nwk = 0x1234 - existing_ieee = t.EUI64(range(8)) - device = app.add_initialized_device(ieee=existing_ieee, nwk=existing_nwk) - - assert (await app._get_or_discover_device(nwk=existing_nwk)) is device - assert app.handle_join.call_count == 0 - - # If the device changes its NWK but doesn't tell zigpy, it will be re-discovered - did_ieee_addr_req1 = znp_server.reply_once_to( - request=c.ZDO.IEEEAddrReq.Req( - NWK=existing_nwk + 1, - RequestType=c.zdo.AddrRequestType.SINGLE, - StartIndex=0, - ), - responses=[ - c.ZDO.IEEEAddrReq.Rsp(Status=t.Status.SUCCESS), - c.ZDO.IEEEAddrRsp.Callback( - Status=t.ZDOStatus.SUCCESS, - IEEE=existing_ieee, - NWK=existing_nwk + 1, - NumAssoc=0, - Index=0, - Devices=[], - ), - ], - ) - - # The same device is discovered and its NWK was updated. Handles concurrency. - devices = await asyncio.gather( - app._get_or_discover_device(nwk=existing_nwk + 1), - app._get_or_discover_device(nwk=existing_nwk + 1), - app._get_or_discover_device(nwk=existing_nwk + 1), - app._get_or_discover_device(nwk=existing_nwk + 1), - app._get_or_discover_device(nwk=existing_nwk + 1), - ) - - assert devices == [device] * 5 - - # Only a single request is sent, since the coroutines are grouped - await did_ieee_addr_req1 - assert device.nwk == existing_nwk + 1 - assert app.handle_join.call_count == 1 - - # If a completely unknown device joins the network, it will be treated as a new join - new_nwk = 0x5678 - new_ieee = t.EUI64(range(1, 9)) - - did_ieee_addr_req2 = znp_server.reply_once_to( - request=c.ZDO.IEEEAddrReq.Req( - NWK=new_nwk, - RequestType=c.zdo.AddrRequestType.SINGLE, - StartIndex=0, - ), - responses=[ - c.ZDO.IEEEAddrReq.Rsp(Status=t.Status.SUCCESS), - c.ZDO.IEEEAddrRsp.Callback( - Status=t.ZDOStatus.SUCCESS, - IEEE=new_ieee, - NWK=new_nwk, - NumAssoc=0, - Index=0, - Devices=[], - ), - ], - ) - - new_dev = await app._get_or_discover_device(nwk=new_nwk) - await did_ieee_addr_req2 - assert app.handle_join.call_count == 2 - assert new_dev.nwk == new_nwk - assert new_dev.ieee == new_ieee - - await app.shutdown() - - -@pytest.mark.parametrize("device", FORMED_DEVICES) -async def test_unknown_device_discovery_failure(device, make_application, mocker): - mocker.patch("zigpy_znp.zigbee.application.IEEE_ADDR_DISCOVERY_TIMEOUT", new=0.1) - - app, znp_server = await make_application(server_cls=device) - await app.startup(auto_form=False) - - znp_server.reply_once_to( - request=c.ZDO.IEEEAddrReq.Req(partial=True), - responses=[ - c.ZDO.IEEEAddrReq.Rsp(Status=t.Status.SUCCESS), - ], - ) - - # Discovery will throw an exception when the device cannot be found - with pytest.raises(KeyError): - await app._get_or_discover_device(nwk=0x3456) - - await app.shutdown() diff --git a/tests/application/test_requests.py b/tests/application/test_requests.py index 04739026..7ee25a71 100644 --- a/tests/application/test_requests.py +++ b/tests/application/test_requests.py @@ -1,7 +1,7 @@ import asyncio -import logging import pytest +import zigpy.types as zigpy_t import zigpy.endpoint import zigpy.profiles import zigpy.zdo.types as zdo_t @@ -142,13 +142,13 @@ async def test_zigpy_request_failure(device, make_application, mocker): ], ) - mocker.spy(app, "_send_request") + mocker.spy(app, "send_packet") # Fail to turn on the light with pytest.raises(InvalidCommandResponse): await device.endpoints[1].on_off.on() - assert app._send_request.call_count == 1 + assert app.send_packet.call_count == 1 await app.shutdown() @@ -167,7 +167,7 @@ async def test_request_addr_mode(device, addr, make_application, mocker): device = app.add_initialized_device(ieee=t.EUI64(range(8)), nwk=0xAABB) - mocker.patch.object(app, "_send_request", new=CoroutineMock()) + mocker.patch.object(app, "send_packet", new=CoroutineMock()) await app.request( device, @@ -180,8 +180,8 @@ async def test_request_addr_mode(device, addr, make_application, mocker): data=b"6", ) - assert app._send_request.call_count == 1 - assert app._send_request.mock_calls[0][2]["dst_addr"] == addr + assert app.send_packet.call_count == 1 + assert app.send_packet.mock_calls[0].args[0].dst == addr.as_zigpy_type() await app.shutdown() @@ -190,16 +190,17 @@ async def test_request_addr_mode(device, addr, make_application, mocker): async def test_mrequest(device, make_application, mocker): app, znp_server = await make_application(server_cls=device) - mocker.patch.object(app, "_send_request", new=CoroutineMock()) + mocker.patch.object(app, "send_packet", new=CoroutineMock()) group = app.groups.add_group(0x1234, "test group") await group.endpoint.on_off.on() - assert app._send_request.call_count == 1 - assert app._send_request.mock_calls[0][2]["dst_addr"] == t.AddrModeAddress( - mode=t.AddrMode.Group, address=0x1234 + assert app.send_packet.call_count == 1 + assert ( + app.send_packet.mock_calls[0].args[0].dst + == t.AddrModeAddress(mode=t.AddrMode.Group, address=0x1234).as_zigpy_type() ) - assert app._send_request.mock_calls[0][2]["data"] == b"\x01\x01\x01" + assert app.send_packet.mock_calls[0].args[0].data.serialize() == b"\x01\x01\x01" await app.shutdown() @@ -268,6 +269,7 @@ async def test_broadcast(device, make_application, mocker): radius=3, sequence=1, data=b"???", + broadcast_address=zigpy_t.BroadcastAddress.RX_ON_WHEN_IDLE, ) await app.shutdown() @@ -277,7 +279,7 @@ async def test_broadcast(device, make_application, mocker): async def test_request_concurrency(device, make_application, mocker): app, znp_server = await make_application( server_cls=device, - client_config={"znp_config": {conf.CONF_MAX_CONCURRENT_REQUESTS: 2}}, + client_config={conf.CONF_MAX_CONCURRENT_REQUESTS: 2}, ) await app.startup() @@ -319,7 +321,7 @@ async def callback(req): ) # We create a whole bunch at once - responses = await asyncio.gather( + await asyncio.gather( *[ app.request( device, @@ -334,69 +336,12 @@ async def callback(req): ] ) - assert all(status == t.Status.SUCCESS for status, msg in responses) assert in_flight_requests == 0 assert did_lock await app.shutdown() -""" -@pytest.mark.parametrize("device", [FormedLaunchpadCC26X2R1]) -async def test_request_concurrency_overflow(device, make_application, mocker): - mocker.patch("zigpy_znp.zigbee.application.MAX_WAITING_REQUESTS", new=1) - - app, znp_server = await make_application( - server_cls=device, client_config={ - 'znp_config': {conf.CONF_MAX_CONCURRENT_REQUESTS: 1} - } - ) - - await app.startup() - - device = app.add_initialized_device(ieee=t.EUI64(range(8)), nwk=0xAABB) - - def make_response(req): - async def callback(req): - await asyncio.sleep(0.01 * req.TSN) - - znp_server.send(c.AF.DataRequestExt.Rsp(Status=t.Status.SUCCESS)) - znp_server.send( - c.AF.DataConfirm.Callback( - Status=t.Status.SUCCESS, Endpoint=1, TSN=req.TSN - ) - ) - - asyncio.create_task(callback(req)) - - znp_server.reply_to( - request=c.AF.DataRequestExt.Req(partial=True), responses=[make_response] - ) - - # We can only handle 1 in-flight request and 1 enqueued request. Last one will fail. - responses = await asyncio.gather( - *[ - app.request( - device, - profile=260, - cluster=1, - src_ep=1, - dst_ep=1, - sequence=seq, - data=b"\x00", - ) - for seq in range(3) - ], return_exceptions=True) - - (rsp1, stat1), (rsp2, stat2), error3 = responses - - assert rsp1 == rsp2 == t.Status.SUCCESS - assert isinstance(error3, ValueError) - - await app.shutdown() -""" - - @pytest.mark.parametrize("device", FORMED_DEVICES) async def test_nonstandard_profile(device, make_application): app, znp_server = await make_application(server_cls=device) @@ -974,34 +919,93 @@ async def test_route_discovery_concurrency(device, make_application): await app.shutdown() -@pytest.mark.parametrize("device", [FormedLaunchpadCC26X2R1]) -async def test_zdo_from_unknown(device, make_application, caplog, mocker): - mocker.patch("zigpy_znp.zigbee.application.IEEE_ADDR_DISCOVERY_TIMEOUT", new=0.1) - +@pytest.mark.parametrize("device", FORMED_DEVICES) +async def test_send_security_and_packet_source_route(device, make_application, mocker): app, znp_server = await make_application(server_cls=device) + await app.startup(auto_form=False) - znp_server.reply_once_to( - request=c.ZDO.IEEEAddrReq.Req(partial=True), - responses=[c.ZDO.IEEEAddrReq.Rsp(Status=t.Status.SUCCESS)], + packet = zigpy_t.ZigbeePacket( + src=zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.NWK, address=app.state.node_info.nwk + ), + src_ep=0x9A, + dst=zigpy_t.AddrModeAddress(addr_mode=zigpy_t.AddrMode.NWK, address=0xEEFF), + dst_ep=0xBC, + tsn=0xDE, + profile_id=0x1234, + cluster_id=0x0006, + data=zigpy_t.SerializableBytes(b"test data"), + extended_timeout=False, + tx_options=( + zigpy_t.TransmitOptions.ACK | zigpy_t.TransmitOptions.APS_Encryption + ), + source_route=[0xAABB, 0xCCDD], + ) + + data_req = znp_server.reply_once_to( + request=c.AF.DataRequestSrcRtg.Req( + DstAddr=packet.dst.address, + DstEndpoint=packet.dst_ep, + # SrcEndpoint=packet.src_ep, + ClusterId=packet.cluster_id, + TSN=packet.tsn, + Data=packet.data.serialize(), + SourceRoute=packet.source_route, + partial=True, + ), + responses=[ + c.AF.DataRequestSrcRtg.Rsp(Status=t.Status.SUCCESS), + c.AF.DataConfirm.Callback( + Status=t.Status.SUCCESS, + Endpoint=packet.dst_ep, + TSN=packet.tsn, + ), + ], ) + await app.send_packet(packet) + req = await data_req + assert c.af.TransmitOptions.ENABLE_SECURITY in req.Options + + await app.shutdown() + + +@pytest.mark.parametrize("device", FORMED_DEVICES) +async def test_send_packet_failure(device, make_application, mocker): + app, znp_server = await make_application(server_cls=device) await app.startup(auto_form=False) - caplog.set_level(logging.WARNING) - - znp_server.send( - c.ZDO.MsgCbIncoming.Callback( - Src=0x1234, - IsBroadcast=t.Bool.false, - ClusterId=zdo_t.ZDOCmd.Mgmt_Leave_rsp, - SecurityUse=0, - TSN=123, - MacDst=0x0000, - Data=t.Bytes([123, 0x00]), - ) + packet = zigpy_t.ZigbeePacket( + src=zigpy_t.AddrModeAddress(addr_mode=zigpy_t.AddrMode.NWK, address=0x0000), + src_ep=0x9A, + dst=zigpy_t.AddrModeAddress(addr_mode=zigpy_t.AddrMode.NWK, address=0xEEFF), + dst_ep=0xBC, + tsn=0xDE, + profile_id=0x1234, + cluster_id=0x0006, + data=zigpy_t.SerializableBytes(b"test data"), + ) + + znp_server.reply_to( + request=c.ZDO.ExtRouteDisc.Req(Dst=packet.dst.address, partial=True), + responses=[c.ZDO.ExtRouteDisc.Rsp(Status=t.Status.SUCCESS)], ) - await asyncio.sleep(0.5) - assert "unknown device" in caplog.text + znp_server.reply_to( + request=c.AF.DataRequestExt.Req(partial=True), + responses=[ + c.AF.DataRequestExt.Rsp(Status=t.Status.SUCCESS), + c.AF.DataConfirm.Callback( + Status=t.Status.MAC_NO_ACK, + Endpoint=packet.dst_ep, + TSN=packet.tsn, + ), + ], + ) + + with pytest.raises(zigpy.exceptions.DeliveryError) as excinfo: + await app.send_packet(packet) + + assert excinfo.value.status == t.Status.MAC_NO_ACK await app.shutdown() diff --git a/tests/application/test_startup.py b/tests/application/test_startup.py index 2c6147dc..02edcd15 100644 --- a/tests/application/test_startup.py +++ b/tests/application/test_startup.py @@ -1,4 +1,5 @@ import pytest +import voluptuous as vol from zigpy.exceptions import NetworkNotFormed import zigpy_znp.types as t @@ -9,6 +10,7 @@ from zigpy_znp.types.nvids import ExNvIds, OsalNvIds from ..conftest import ( + ALL_DEVICES, EMPTY_DEVICES, FORMED_DEVICES, CoroutineMock, @@ -266,3 +268,28 @@ async def test_zstack_build_id_empty(device, make_application, mocker): assert app._zstack_build_id == 0x00000000 await app.shutdown() + + +@pytest.mark.parametrize("device", [FormedLaunchpadCC26X2R1]) +async def test_deprecated_concurrency_config(device, make_application): + with pytest.raises(vol.MultipleInvalid) as exc: + app, znp_server = await make_application( + server_cls=device, + client_config={ + conf.CONF_ZNP_CONFIG: { + conf.CONF_MAX_CONCURRENT_REQUESTS: 16, + } + }, + ) + + assert "max_concurrent_requests" in str(exc.value) + + +@pytest.mark.parametrize("device", ALL_DEVICES) +async def test_reset_network_info(device, make_application): + app, znp_server = await make_application(server_cls=device) + await app.connect() + await app.reset_network_info() + + with pytest.raises(NetworkNotFormed): + await app.start_network() diff --git a/tests/application/test_zigpy_callbacks.py b/tests/application/test_zigpy_callbacks.py index d90219fb..265f44c6 100644 --- a/tests/application/test_zigpy_callbacks.py +++ b/tests/application/test_zigpy_callbacks.py @@ -1,29 +1,14 @@ import asyncio -import logging +from unittest.mock import MagicMock import pytest +import zigpy.types as zigpy_t import zigpy.zdo.types as zdo_t import zigpy_znp.types as t import zigpy_znp.commands as c -from ..conftest import FORMED_DEVICES, CoroutineMock, serialize_zdo_command - - -def awaitable_mock(*, return_value=None, side_effect=None): - assert (return_value or side_effect) and not (return_value and side_effect) - - mock_called = asyncio.get_running_loop().create_future() - - def side_effect_(*args, **kwargs): - mock_called.set_result((args, kwargs)) - - if return_value is not None: - return return_value - else: - raise side_effect - - return mock_called, CoroutineMock(side_effect=side_effect_) +from ..conftest import FORMED_DEVICES, serialize_zdo_command @pytest.mark.parametrize("device", FORMED_DEVICES) @@ -31,33 +16,12 @@ async def test_on_zdo_relays_message_callback(device, make_application, mocker): app, znp_server = await make_application(server_cls=device) await app.startup(auto_form=False) - device = mocker.Mock() - discover_called, discover_mock = awaitable_mock(return_value=device) - mocker.patch.object(app, "_get_or_discover_device", new=discover_mock) - + app.handle_relays = MagicMock() znp_server.send(c.ZDO.SrcRtgInd.Callback(DstAddr=0x1234, Relays=[0x5678, 0xABCD])) - await discover_called - assert device.relays == [0x5678, 0xABCD] - - await app.shutdown() - - -@pytest.mark.parametrize("device", FORMED_DEVICES) -async def test_on_zdo_relays_message_callback_unknown( - device, make_application, mocker, caplog -): - app, znp_server = await make_application(server_cls=device) - await app.startup(auto_form=False) - - discover_called, discover_mock = awaitable_mock(side_effect=KeyError()) - mocker.patch.object(app, "_get_or_discover_device", new=discover_mock) - - caplog.set_level(logging.WARNING) - znp_server.send(c.ZDO.SrcRtgInd.Callback(DstAddr=0x1234, Relays=[0x5678, 0xABCD])) + await asyncio.sleep(0.1) - await discover_called - assert "unknown device" in caplog.text + app.handle_relays.assert_called_once_with(nwk=0x1234, relays=[0x5678, 0xABCD]) await app.shutdown() @@ -106,10 +70,8 @@ async def test_on_zdo_device_announce_nwk_change(device, make_application, mocke app.handle_join.assert_called_once_with( nwk=new_nwk, ieee=device.ieee, parent_nwk=None ) - assert app.handle_message.call_count == 1 - assert app.handle_message.mock_calls[0][2]["cluster"] == zdo_t.ZDOCmd.Device_annce - # The device's NWK updated + # The device's NWK has been updated assert device.nwk == new_nwk await app.shutdown() @@ -140,15 +102,13 @@ async def test_on_af_message_callback(device, make_application, mocker): app, znp_server = await make_application(server_cls=device) await app.startup(auto_form=False) - device = mocker.Mock() - discover_called, discover_mock = awaitable_mock(return_value=device) - mocker.patch.object(app, "_get_or_discover_device", new=discover_mock) mocker.patch.object(app, "handle_message") + device = app.add_initialized_device(ieee=t.EUI64(range(8)), nwk=0xAABB) af_message = c.AF.IncomingMsg.Callback( - GroupId=1, + GroupId=0x0000, ClusterId=2, - SrcAddr=0xABCD, + SrcAddr=device.nwk, SrcEndpoint=4, DstEndpoint=1, # ZHA endpoint WasBroadcast=False, @@ -163,54 +123,173 @@ async def test_on_af_message_callback(device, make_application, mocker): # Normal message znp_server.send(af_message) + await asyncio.sleep(0.1) - await discover_called - device.radio_details.assert_called_once_with(lqi=19, rssi=None) app.handle_message.assert_called_once_with( - sender=device, profile=260, cluster=2, src_ep=4, dst_ep=1, message=b"test" + sender=device, + profile=260, + cluster=2, + src_ep=4, + dst_ep=1, + message=b"test", + dst_addressing=zigpy_t.AddrMode.NWK, ) - device.reset_mock() app.handle_message.reset_mock() # ZLL message - discover_called, discover_mock = awaitable_mock(return_value=device) - mocker.patch.object(app, "_get_or_discover_device", new=discover_mock) - znp_server.send(af_message.replace(DstEndpoint=2)) + await asyncio.sleep(0.1) - await discover_called - device.radio_details.assert_called_once_with(lqi=19, rssi=None) app.handle_message.assert_called_once_with( - sender=device, profile=49246, cluster=2, src_ep=4, dst_ep=2, message=b"test" + sender=device, + profile=49246, + cluster=2, + src_ep=4, + dst_ep=2, + message=b"test", + dst_addressing=zigpy_t.AddrMode.NWK, ) - device.reset_mock() app.handle_message.reset_mock() # Message on an unknown endpoint (is this possible?) - discover_called, discover_mock = awaitable_mock(return_value=device) - mocker.patch.object(app, "_get_or_discover_device", new=discover_mock) - znp_server.send(af_message.replace(DstEndpoint=3)) + await asyncio.sleep(0.1) - await discover_called - device.radio_details.assert_called_once_with(lqi=19, rssi=None) app.handle_message.assert_called_once_with( - sender=device, profile=260, cluster=2, src_ep=4, dst_ep=3, message=b"test" + sender=device, + profile=260, + cluster=2, + src_ep=4, + dst_ep=3, + message=b"test", + dst_addressing=zigpy_t.AddrMode.NWK, ) - device.reset_mock() app.handle_message.reset_mock() - # Message from an unknown device - discover_called, discover_mock = awaitable_mock(side_effect=KeyError()) - mocker.patch.object(app, "_get_or_discover_device", new=discover_mock) - znp_server.send(af_message) +@pytest.mark.parametrize("device", FORMED_DEVICES) +async def test_receive_zdo_broadcast(device, make_application, mocker): + app, znp_server = await make_application(server_cls=device) + await app.startup(auto_form=False) + + mocker.patch.object(app, "packet_received") - await discover_called - assert device.radio_details.call_count == 0 - assert app.handle_message.call_count == 0 + zdo_callback = c.ZDO.MsgCbIncoming.Callback( + Src=0x35D9, + IsBroadcast=t.Bool.true, + ClusterId=19, + SecurityUse=0, + TSN=129, + MacDst=0xFFFF, + Data=b"bogus", + ) + znp_server.send(zdo_callback) + await asyncio.sleep(0.1) + + assert app.packet_received.call_count == 1 + packet = app.packet_received.mock_calls[0].args[0] + assert packet.src == zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.NWK, address=0x35D9 + ) + assert packet.src_ep == 0x00 + assert packet.dst == zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.Broadcast, + address=zigpy_t.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, + ) + assert packet.dst_ep == 0x00 + assert packet.cluster_id == zdo_callback.ClusterId + assert packet.tsn == zdo_callback.TSN + assert packet.data.serialize() == bytes([zdo_callback.TSN]) + zdo_callback.Data + + await app.shutdown() + + +@pytest.mark.parametrize("device", FORMED_DEVICES) +async def test_receive_af_broadcast(device, make_application, mocker): + app, znp_server = await make_application(server_cls=device) + await app.startup(auto_form=False) + + mocker.patch.object(app, "packet_received") + + af_callback = c.AF.IncomingMsg.Callback( + GroupId=0x0000, + ClusterId=4096, + SrcAddr=0x1234, + SrcEndpoint=254, + DstEndpoint=2, + WasBroadcast=t.Bool.true, + LQI=90, + SecurityUse=t.Bool.false, + TimeStamp=4442962, + TSN=0, + Data=b"\x11\xA6\x00\x74\xB5\x7C\x00\x02\x5F", + MacSrcAddr=0x0000, + MsgResultRadius=0, + ) + znp_server.send(af_callback) + await asyncio.sleep(0.1) + + assert app.packet_received.call_count == 1 + packet = app.packet_received.mock_calls[0].args[0] + assert packet.src == zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.NWK, + address=0x1234, + ) + assert packet.src_ep == af_callback.SrcEndpoint + assert packet.dst == zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.Broadcast, + address=zigpy_t.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, + ) + assert packet.dst_ep == af_callback.DstEndpoint + assert packet.cluster_id == af_callback.ClusterId + assert packet.tsn == af_callback.TSN + assert packet.lqi == af_callback.LQI + assert packet.data.serialize() == af_callback.Data + + await app.shutdown() + + +@pytest.mark.parametrize("device", FORMED_DEVICES) +async def test_receive_af_group(device, make_application, mocker): + app, znp_server = await make_application(server_cls=device) + await app.startup(auto_form=False) + + mocker.patch.object(app, "packet_received") + + af_callback = c.AF.IncomingMsg.Callback( + GroupId=0x1234, + ClusterId=4096, + SrcAddr=0x1234, + SrcEndpoint=254, + DstEndpoint=0, + WasBroadcast=t.Bool.false, + LQI=90, + SecurityUse=t.Bool.false, + TimeStamp=4442962, + TSN=0, + Data=b"\x11\xA6\x00\x74\xB5\x7C\x00\x02\x5F", + MacSrcAddr=0x0000, + MsgResultRadius=0, + ) + znp_server.send(af_callback) + await asyncio.sleep(0.1) + + assert app.packet_received.call_count == 1 + packet = app.packet_received.mock_calls[0].args[0] + assert packet.src == zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.NWK, + address=0x1234, + ) + assert packet.src_ep == af_callback.SrcEndpoint + assert packet.dst == zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.Group, address=0x1234 + ) + assert packet.cluster_id == af_callback.ClusterId + assert packet.tsn == af_callback.TSN + assert packet.lqi == af_callback.LQI + assert packet.data.serialize() == af_callback.Data await app.shutdown() diff --git a/zigpy_znp/api.py b/zigpy_znp/api.py index bf0d0df5..6f95a310 100644 --- a/zigpy_znp/api.py +++ b/zigpy_znp/api.py @@ -13,6 +13,7 @@ import zigpy.state import async_timeout import zigpy.zdo.types as zdo_t +import zigpy.exceptions from zigpy.exceptions import NetworkNotFormed import zigpy_znp @@ -92,7 +93,7 @@ async def detect_zstack_version(self) -> float: except CommandNotRecognized: return 3.0 - async def load_network_info(self, *, load_devices=False): + async def _load_network_info(self, *, load_devices=False): """ Loads low-level network information from NVRAM. Loading key data greatly increases the runtime so it not enabled by default. @@ -100,27 +101,20 @@ async def load_network_info(self, *, load_devices=False): from zigpy_znp.znp import security - is_on_network = None - nib = None - - try: - nib = await self.nvram.osal_read(OsalNvIds.NIB, item_type=t.NIB) - except KeyError: - is_on_network = False - else: - is_on_network = nib.nwkLogicalChannel != 0 and nib.nwkKeyLoaded + nib = await self.nvram.osal_read(OsalNvIds.NIB, item_type=t.NIB) - if is_on_network and self.version >= 3.0: - # This NVRAM item is the very first thing initialized in `zgInit` - is_on_network = ( - await self.nvram.osal_read( - OsalNvIds.BDBNODEISONANETWORK, item_type=t.uint8_t - ) - == 1 - ) + if nib.nwkLogicalChannel == 0 or not nib.nwkKeyLoaded: + raise NetworkNotFormed() - if not is_on_network: - raise NetworkNotFormed("Device is not a part of a network") + # This NVRAM item is the very first thing initialized in `zgInit` + if ( + self.version >= 3.0 + and await self.nvram.osal_read( + OsalNvIds.BDBNODEISONANETWORK, item_type=t.uint8_t + ) + != 1 + ): + raise NetworkNotFormed() ieee = await self.nvram.osal_read(OsalNvIds.EXTADDR, item_type=t.EUI64) logical_type = await self.nvram.osal_read( @@ -224,6 +218,17 @@ async def load_network_info(self, *, load_devices=False): self.network_info = network_info self.node_info = node_info + async def load_network_info(self, *, load_devices=False): + """ + Loads low-level network information from NVRAM. + Loading key data greatly increases the runtime so it not enabled by default. + """ + + try: + await self._load_network_info(load_devices=load_devices) + except KeyError as e: + raise NetworkNotFormed() from e + async def start_network(self): # Both startup sequences end with the same callback started_as_coordinator = self.wait_for_response( @@ -264,7 +269,7 @@ async def start_network(self): c.app_config.BDBCommissioningStatus.FormationFailure, c.app_config.BDBCommissioningStatus.Success, ): - raise RuntimeError( + raise zigpy.exceptions.FormationFailure( f"Network formation failed: {commissioning_rsp}" ) else: @@ -283,10 +288,11 @@ async def start_network(self): async with async_timeout.timeout(STARTUP_TIMEOUT): await started_as_coordinator except asyncio.TimeoutError as e: - raise RuntimeError( - "Network formation refused, RF environment is likely too noisy." - " Temporarily unscrew the antenna or shield the coordinator" - " with metal until a network is formed." + raise zigpy.exceptions.FormationFailure( + "Network formation refused: there is too much RF interference." + " Make sure your coordinator is on a USB 2.0 extension cable and" + " away from any sources of interference, like USB 3.0 ports, SSDs," + " 2.4GHz routers, motherboards, etc." ) from e LOGGER.debug("Waiting for NIB to stabilize") @@ -308,18 +314,11 @@ async def start_network(self): await asyncio.sleep(1) - async def write_network_info( - self, - *, - network_info: zigpy.state.NetworkInfo, - node_info: zigpy.state.NodeInfo, - ) -> None: + async def reset_network_info(self): """ - Writes network and node state to NVRAM. + Resets node network information and leaves the current network. """ - from zigpy_znp.znp import security - # Delete any existing NV items that store formation state await self.nvram.osal_delete(OsalNvIds.HAS_CONFIGURED_ZSTACK1) await self.nvram.osal_delete(OsalNvIds.HAS_CONFIGURED_ZSTACK3) @@ -334,6 +333,19 @@ async def write_network_info( await self.reset() + async def write_network_info( + self, + *, + network_info: zigpy.state.NetworkInfo, + node_info: zigpy.state.NodeInfo, + ) -> None: + """ + Writes network and node state to NVRAM. + """ + from zigpy_znp.znp import security + + await self.reset_network_info() + # Form a network with completely random settings to get NVRAM to a known state for item, value in { OsalNvIds.PANID: t.uint16_t(0xFFFF), @@ -703,7 +715,7 @@ async def connect(self, *, test_port=True) -> None: self.close() raise - LOGGER.debug("Connected to %s at %s baud", self._uart.name, self._uart.baudrate) + LOGGER.debug("Connected to %s", self._uart.url) def connection_made(self) -> None: """ diff --git a/zigpy_znp/config.py b/zigpy_znp/config.py index 0ba472fe..e0ecb9ce 100644 --- a/zigpy_znp/config.py +++ b/zigpy_znp/config.py @@ -16,6 +16,7 @@ CONF_NWK_TC_ADDRESS, CONF_NWK_TC_LINK_KEY, CONF_NWK_EXTENDED_PAN_ID, + CONF_MAX_CONCURRENT_REQUESTS, cv_boolean, ) @@ -75,6 +76,18 @@ def bool_to_upper_str(value: typing.Any) -> str: return str(value).upper() +def cv_deprecated(message: str) -> typing.Callable[[typing.Any], None]: + """ + Raises a deprecation exception when a value is passed in. + """ + + def validator(value: typing.Any) -> None: + if value is not None: + raise vol.Invalid(message) + + return validator + + CONF_ZNP_CONFIG = "znp_config" CONF_TX_POWER = "tx_power" CONF_LED_MODE = "led_mode" @@ -82,7 +95,6 @@ def bool_to_upper_str(value: typing.Any) -> str: CONF_SREQ_TIMEOUT = "sync_request_timeout" CONF_ARSP_TIMEOUT = "async_response_timeout" CONF_AUTO_RECONNECT_RETRY_DELAY = "auto_reconnect_retry_delay" -CONF_MAX_CONCURRENT_REQUESTS = "max_concurrent_requests" CONF_CONNECT_RTS_STATES = "connect_rts_pin_states" CONF_CONNECT_DTR_STATES = "connect_dtr_pin_states" @@ -104,8 +116,11 @@ def bool_to_upper_str(value: typing.Any) -> str: vol.Optional(CONF_LED_MODE, default=LEDMode.OFF): vol.Any( None, EnumValue(LEDMode, transformer=bool_to_upper_str) ), - vol.Optional(CONF_MAX_CONCURRENT_REQUESTS, default="auto"): vol.Any( - "auto", VolPositiveNumber + vol.Optional(CONF_MAX_CONCURRENT_REQUESTS, default=None): ( + cv_deprecated( + "`zigpy_config: znp_config: max_concurrent_requests` has" + " been renamed to `zigpy_config: max_concurrent_requests`." + ) ), vol.Optional( CONF_CONNECT_RTS_STATES, default=[False, True, False] diff --git a/zigpy_znp/types/__init__.py b/zigpy_znp/types/__init__.py index 8262656c..77573f39 100644 --- a/zigpy_znp/types/__init__.py +++ b/zigpy_znp/types/__init__.py @@ -1,3 +1,5 @@ +from .zigpy_types import * # noqa: F401, F403 # isort:skip + from .basic import * # noqa: F401, F403 from .named import * # noqa: F401, F403 from .cstruct import * # noqa: F401, F403 diff --git a/zigpy_znp/types/named.py b/zigpy_znp/types/named.py index d622b234..1c05e514 100644 --- a/zigpy_znp/types/named.py +++ b/zigpy_znp/types/named.py @@ -5,21 +5,10 @@ import logging import dataclasses -from zigpy.types import ( # noqa: F401 - NWK, - EUI64, - Bool, - PanId, - Struct, - KeyData, - Channels, - ClusterId, - ExtendedPanId, - CharacterString, -) +import zigpy.types from zigpy.zdo.types import Status as ZDOStatus # noqa: F401 -from . import basic +from . import basic, zigpy_types LOGGER = logging.getLogger(__name__) @@ -55,21 +44,36 @@ def __new__(cls, mode=None, address=None): return instance + @classmethod + def from_zigpy_type( + cls, zigpy_addr: zigpy.types.AddrModeAddress + ) -> AddrModeAddress: + return cls( + mode=AddrMode[zigpy_addr.addr_mode.name], + address=zigpy_addr.address, + ) + + def as_zigpy_type(self) -> zigpy.types.AddrModeAddress: + return zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode[self.mode.name], + address=self.address, + ) + def _get_address_type(self): return { - AddrMode.NWK: NWK, - AddrMode.Group: NWK, - AddrMode.Broadcast: NWK, - AddrMode.IEEE: EUI64, + AddrMode.NWK: zigpy_types.NWK, + AddrMode.Group: zigpy_types.NWK, + AddrMode.Broadcast: zigpy_types.NWK, + AddrMode.IEEE: zigpy_types.EUI64, }[self.mode] @classmethod def deserialize(cls, data: bytes) -> tuple[AddrModeAddress, bytes]: mode, data = AddrMode.deserialize(data) - address, data = EUI64.deserialize(data) + address, data = zigpy_types.EUI64.deserialize(data) if mode != AddrMode.IEEE: - address, _ = NWK.deserialize(address.serialize()) + address, _ = zigpy_types.NWK.deserialize(address.serialize()) return cls(mode=mode, address=address), data @@ -386,11 +390,13 @@ class DeviceTypeCapabilities(basic.enum_flag_uint8): EndDevice = 1 << 2 -class ClusterIdList(basic.LVList, item_type=ClusterId, length_type=basic.uint8_t): +class ClusterIdList( + basic.LVList, item_type=zigpy_types.ClusterId, length_type=basic.uint8_t +): pass -class NWKList(basic.LVList, item_type=NWK, length_type=basic.uint8_t): +class NWKList(basic.LVList, item_type=zigpy_types.NWK, length_type=basic.uint8_t): pass diff --git a/zigpy_znp/types/structs.py b/zigpy_znp/types/structs.py index 4992ce39..1b73ac9a 100644 --- a/zigpy_znp/types/structs.py +++ b/zigpy_znp/types/structs.py @@ -1,9 +1,9 @@ -from . import basic, named, cstruct +from . import basic, named, cstruct, zigpy_types class NwkKeyDesc(cstruct.CStruct): KeySeqNum: basic.uint8_t - Key: named.KeyData + Key: zigpy_types.KeyData class NwkState(basic.enum_uint8): @@ -46,17 +46,17 @@ class NIB(cstruct.CStruct): RouteDiscoveryTime: basic.uint8_t RouteExpiryTime: basic.uint8_t - nwkDevAddress: named.NWK + nwkDevAddress: zigpy_types.NWK nwkLogicalChannel: basic.uint8_t - nwkCoordAddress: named.NWK - nwkCoordExtAddress: named.EUI64 + nwkCoordAddress: zigpy_types.NWK + nwkCoordExtAddress: zigpy_types.EUI64 nwkPanId: basic.uint16_t # XXX: this is really a uint16_t but we pad with zeroes so it works out in the end nwkState: NwkState - channelList: named.Channels + channelList: zigpy_types.Channels beaconOrder: basic.uint8_t superFrameOrder: basic.uint8_t @@ -68,9 +68,9 @@ class NIB(cstruct.CStruct): nodeDepth: basic.uint8_t - extendedPANID: named.EUI64 + extendedPANID: zigpy_types.EUI64 - nwkKeyLoaded: named.Bool + nwkKeyLoaded: zigpy_types.Bool spare1: NwkKeyDesc spare2: NwkKeyDesc @@ -80,13 +80,13 @@ class NIB(cstruct.CStruct): nwkLinkStatusPeriod: basic.uint8_t nwkRouterAgeLimit: basic.uint8_t - nwkUseMultiCast: named.Bool - nwkIsConcentrator: named.Bool + nwkUseMultiCast: zigpy_types.Bool + nwkIsConcentrator: zigpy_types.Bool nwkConcentratorDiscoveryTime: basic.uint8_t nwkConcentratorRadius: basic.uint8_t nwkAllFresh: basic.uint8_t - nwkManagerAddr: named.NWK + nwkManagerAddr: zigpy_types.NWK nwkTotalTransmissions: basic.uint16_t nwkUpdateId: basic.uint8_t @@ -94,8 +94,8 @@ class NIB(cstruct.CStruct): class Beacon(cstruct.CStruct): """Beacon message.""" - Src: named.NWK - PanId: named.PanId + Src: zigpy_types.NWK + PanId: zigpy_types.PanId Channel: basic.uint8_t PermitJoining: basic.uint8_t RouterCapacity: basic.uint8_t @@ -105,12 +105,12 @@ class Beacon(cstruct.CStruct): LQI: basic.uint8_t Depth: basic.uint8_t UpdateId: basic.uint8_t - ExtendedPanId: named.ExtendedPanId + ExtendedPanId: zigpy_types.ExtendedPanId class TCLinkKey(cstruct.CStruct): - ExtAddr: named.EUI64 - Key: named.KeyData + ExtAddr: zigpy_types.EUI64 + Key: zigpy_types.KeyData TxFrameCounter: basic.uint32_t RxFrameCounter: basic.uint32_t @@ -163,7 +163,7 @@ class TCLKDevEntry(cstruct.CStruct): txFrmCntr: basic.uint32_t rxFrmCntr: basic.uint32_t - extAddr: named.EUI64 + extAddr: zigpy_types.EUI64 keyAttributes: KeyAttributes keyType: KeyType @@ -174,7 +174,7 @@ class TCLKDevEntry(cstruct.CStruct): class NwkSecMaterialDesc(cstruct.CStruct): FrameCounter: basic.uint32_t - ExtendedPanID: named.EUI64 + ExtendedPanID: zigpy_types.EUI64 class AddrMgrUserType(basic.enum_flag_uint8): @@ -187,8 +187,8 @@ class AddrMgrUserType(basic.enum_flag_uint8): class AddrMgrEntry(cstruct.CStruct): type: AddrMgrUserType - nwkAddr: named.NWK - extAddr: named.EUI64 + nwkAddr: zigpy_types.NWK + extAddr: zigpy_types.EUI64 class AddressManagerTable(basic.CompleteList, item_type=AddrMgrEntry): @@ -202,7 +202,7 @@ class AuthenticationOption(basic.enum_uint8): class APSKeyDataTableEntry(cstruct.CStruct): - Key: named.KeyData + Key: zigpy_types.KeyData TxFrameCounter: basic.uint32_t RxFrameCounter: basic.uint32_t @@ -251,7 +251,7 @@ class BaseAssociatedDevice(cstruct.CStruct): linkInfo: LinkInfo endDev: AgingEndDevice timeoutCounter: basic.uint32_t - keepaliveRcv: named.Bool + keepaliveRcv: zigpy_types.Bool class AssociatedDeviceZStack1(BaseAssociatedDevice): diff --git a/zigpy_znp/types/zigpy_types.py b/zigpy_znp/types/zigpy_types.py new file mode 100644 index 00000000..91cfb100 --- /dev/null +++ b/zigpy_znp/types/zigpy_types.py @@ -0,0 +1,13 @@ +from zigpy.types import ( # noqa: F401 + NWK, + EUI64, + Bool, + PanId, + Struct, + KeyData, + Channels, + ClusterId, + ExtendedPanId, + CharacterString, + SerializableBytes, +) diff --git a/zigpy_znp/uart.py b/zigpy_znp/uart.py index 5571e60d..ec81a070 100644 --- a/zigpy_znp/uart.py +++ b/zigpy_znp/uart.py @@ -1,9 +1,10 @@ +from __future__ import annotations + import typing import asyncio import logging -import warnings -import serial +import zigpy.serial import zigpy_znp.config as conf import zigpy_znp.frames as frames @@ -11,16 +12,6 @@ from zigpy_znp.types import Bytes from zigpy_znp.exceptions import InvalidFrame -with warnings.catch_warnings(): - warnings.filterwarnings( - action="ignore", - module="serial_asyncio", - message='"@coroutine" decorator is deprecated', - category=DeprecationWarning, - ) - import serial_asyncio # noqa: E402 - - LOGGER = logging.getLogger(__name__) @@ -29,12 +20,14 @@ class BufferTooShort(Exception): class ZnpMtProtocol(asyncio.Protocol): - def __init__(self, api): + def __init__(self, api, *, url: str | None = None) -> None: self._buffer = bytearray() self._api = api self._transport = None self._connected_event = asyncio.Event() + self.url = url + def close(self) -> None: """Closes the port.""" @@ -47,7 +40,7 @@ def close(self) -> None: self._transport.close() self._transport = None - def connection_lost(self, exc: typing.Optional[Exception]) -> None: + def connection_lost(self, exc: Exception | None) -> None: """Connection lost.""" if exc is not None: @@ -56,10 +49,10 @@ def connection_lost(self, exc: typing.Optional[Exception]) -> None: if self._api is not None: self._api.connection_lost(exc) - def connection_made(self, transport: serial_asyncio.SerialTransport) -> None: + def connection_made(self, transport: asyncio.BaseTransport) -> None: """Opened serial port.""" self._transport = transport - LOGGER.debug("Opened %s serial port", transport.serial.name) + LOGGER.debug("Opened %s serial port", self.url) self._connected_event.set() @@ -98,19 +91,15 @@ def write(self, data: bytes) -> None: self._transport.write(data) def set_dtr_rts(self, *, dtr: bool, rts: bool) -> None: + # TCP transport does not have DTR or RTS pins + if not hasattr(self._transport, "serial"): + return + LOGGER.debug("Setting serial pin states: DTR=%s, RTS=%s", dtr, rts) self._transport.serial.dtr = dtr self._transport.serial.rts = rts - @property - def name(self) -> str: - return self._transport.serial.name - - @property - def baudrate(self) -> int: - return self._transport.serial.baudrate - def _extract_frames(self) -> typing.Iterator[frames.TransportFrame]: """Extracts frames from the buffer until it is exhausted.""" while True: @@ -163,8 +152,7 @@ def _extract_frame(self) -> frames.TransportFrame: def __repr__(self) -> str: return ( f"<" - f"{type(self).__name__} connected to {self.name!r}" - f" at {self.baudrate} baud" + f"{type(self).__name__} connected to {self.url!r}" f" (api: {self._api})" f">" ) @@ -179,13 +167,11 @@ async def connect(config: conf.ConfigType, api) -> ZnpMtProtocol: LOGGER.debug("Connecting to %s at %s baud", port, baudrate) - _, protocol = await serial_asyncio.create_serial_connection( + _, protocol = await zigpy.serial.create_serial_connection( loop=loop, - protocol_factory=lambda: ZnpMtProtocol(api), + protocol_factory=lambda: ZnpMtProtocol(api, url=port), url=port, baudrate=baudrate, - parity=serial.PARITY_NONE, - stopbits=serial.STOPBITS_ONE, xonxoff=(flow_control == "software"), rtscts=(flow_control == "hardware"), ) diff --git a/zigpy_znp/zigbee/application.py b/zigpy_znp/zigbee/application.py index aa4a6af3..e6ff2f4d 100644 --- a/zigpy_znp/zigbee/application.py +++ b/zigpy_znp/zigbee/application.py @@ -1,11 +1,9 @@ from __future__ import annotations import os -import time import asyncio import logging import itertools -import contextlib import zigpy.zcl import zigpy.zdo @@ -18,7 +16,6 @@ import zigpy.profiles import zigpy.zdo.types as zdo_t import zigpy.application -from zigpy.types import deserialize as list_deserialize from zigpy.exceptions import DeliveryError import zigpy_znp.const as const @@ -39,7 +36,7 @@ PROBE_TIMEOUT = 5 STARTUP_TIMEOUT = 5 DATA_CONFIRM_TIMEOUT = 8 -IEEE_ADDR_DISCOVERY_TIMEOUT = 5 +EXTENDED_DATA_CONFIRM_TIMEOUT = 30 DEVICE_JOIN_MAX_DELAY = 5 WATCHDOG_PERIOD = 30 @@ -87,8 +84,6 @@ def __init__(self, config: conf.ConfigType): self._watchdog_task.cancel() self._version_rsp = None - self._concurrent_requests_semaphore = None - self._currently_waiting_requests = 0 self._join_announce_tasks: dict[t.EUI64, asyncio.TimerHandle] = {} @@ -166,6 +161,13 @@ async def load_network_info(self, *, load_devices=False) -> None: self.state.node_info = self._znp.node_info self.state.network_info = self._znp.network_info + async def reset_network_info(self) -> None: + """ + Resets node network information and leaves the current network. + """ + + await self._znp.reset_network_info() + async def write_network_info( self, *, @@ -218,13 +220,21 @@ async def start_network(self, *, read_only=False): ) await self._device.schedule_initialize() + # Deprecate ZNP-specific config + if self.znp_config[conf.CONF_MAX_CONCURRENT_REQUESTS] is not None: + raise RuntimeError( + "`zigpy_config:znp_config:max_concurrent_requests` is deprecated," + " move this key up to `zigpy_config:max_concurrent_requests` instead." + ) + # Now that we know what device we are, set the max concurrent requests - if self.znp_config[conf.CONF_MAX_CONCURRENT_REQUESTS] == "auto": + if self._config[conf.CONF_MAX_CONCURRENT_REQUESTS] is None: max_concurrent_requests = 16 if self._znp.nvram.align_structs else 2 else: - max_concurrent_requests = self.znp_config[conf.CONF_MAX_CONCURRENT_REQUESTS] + max_concurrent_requests = self._config[conf.CONF_MAX_CONCURRENT_REQUESTS] - self._concurrent_requests_semaphore = asyncio.Semaphore(max_concurrent_requests) + # Update the max value of the concurrent request semaphore at runtime + self._concurrent_requests_semaphore.max_value = max_concurrent_requests if self.state.network_info.network_key.key == const.Z2M_NETWORK_KEY: LOGGER.warning( @@ -271,93 +281,6 @@ def get_dst_address(self, cluster: zigpy.zcl.Cluster) -> zdo_t.MultiAddress: return dst_addr - @zigpy.util.retryable_request - async def request( - self, - device, - profile, - cluster, - src_ep, - dst_ep, - sequence, - data, - expect_reply=True, - use_ieee=False, - ) -> tuple[t.Status, str]: - tx_options = c.af.TransmitOptions.SUPPRESS_ROUTE_DISC_NETWORK - - if expect_reply: - tx_options |= c.af.TransmitOptions.ACK_REQUEST - - if use_ieee: - destination = t.AddrModeAddress(mode=t.AddrMode.IEEE, address=device.ieee) - else: - destination = t.AddrModeAddress(mode=t.AddrMode.NWK, address=device.nwk) - - return await self._send_request( - dst_addr=destination, - dst_ep=dst_ep, - src_ep=src_ep, - profile=profile, - cluster=cluster, - sequence=sequence, - options=tx_options, - radius=30, - data=data, - ) - - async def broadcast( - self, - profile, - cluster, - src_ep, - dst_ep, - grpid, - radius, - sequence, - data, - broadcast_address=zigpy.types.BroadcastAddress.RX_ON_WHEN_IDLE, - ) -> tuple[t.Status, str]: - assert grpid == 0 - - return await self._send_request( - dst_addr=t.AddrModeAddress( - mode=t.AddrMode.Broadcast, address=broadcast_address - ), - dst_ep=dst_ep, - src_ep=src_ep, - profile=profile, - cluster=cluster, - sequence=sequence, - options=c.af.TransmitOptions.NONE, - radius=radius, - data=data, - ) - - async def mrequest( - self, - group_id, - profile, - cluster, - src_ep, - sequence, - data, - *, - hops=0, - non_member_radius=3, - ) -> tuple[t.Status, str]: - return await self._send_request( - dst_addr=t.AddrModeAddress(mode=t.AddrMode.Group, address=group_id), - dst_ep=src_ep, - src_ep=src_ep, - profile=profile, - cluster=cluster, - sequence=sequence, - options=c.af.TransmitOptions.NONE, - radius=hops, - data=data, - ) - async def permit(self, time_s: int = 60, node: t.EUI64 = None): """ Permit joining the network via a specific node or via all router nodes. @@ -526,33 +449,54 @@ async def on_zdo_message(self, msg: c.ZDO.MsgCbIncoming.Callback) -> None: LOGGER.debug("Ignoring loopback ZDO request") return - message = t.uint8_t(msg.TSN).serialize() + msg.Data - hdr, data = zdo_t.ZDOHeader.deserialize(msg.ClusterId, message) - names, types = zdo_t.CLUSTERS[msg.ClusterId] - args, data = list_deserialize(data, types) - kwargs = dict(zip(names, args)) - - if msg.ClusterId == zdo_t.ZDOCmd.Device_annce: - self.on_zdo_device_announce(*args) - device = self.get_device(ieee=kwargs["IEEEAddr"]) + if msg.IsBroadcast: + dst = zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.Broadcast, + address=zigpy.types.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, + ) else: - try: - device = await self._get_or_discover_device(nwk=msg.Src) - except KeyError: - LOGGER.warning( - "Received a ZDO message from an unknown device: %s", msg.Src - ) - return + dst = zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.NWK, + address=self.state.node_info.nwk, + ) - self.handle_message( - sender=device, - profile=ZDO_PROFILE, - cluster=msg.ClusterId, + packet = zigpy.types.ZigbeePacket( + src=zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.NWK, + address=msg.Src, + ), src_ep=ZDO_ENDPOINT, + dst=dst, dst_ep=ZDO_ENDPOINT, - message=message, + tsn=msg.TSN, + profile_id=ZDO_PROFILE, + cluster_id=msg.ClusterId, + data=t.SerializableBytes(t.uint8_t(msg.TSN).serialize() + msg.Data), + tx_options=( + zigpy.types.TransmitOptions.APS_Encryption + if msg.SecurityUse + else zigpy.types.TransmitOptions.NONE + ), ) + # Peek into the ZDO packet so that we can cancel our existing TC join timer when + # a device actually sends an announcemement + try: + zdo_hdr, zdo_args = self._device.zdo.deserialize( + cluster_id=packet.cluster_id, data=packet.data.serialize() + ) + except Exception: + LOGGER.warning("Failed to deserialize ZDO packet", exc_info=True) + else: + if zdo_hdr.command_id == zdo_t.ZDOCmd.Device_annce: + _, ieee, _ = zdo_args + + # Cancel any existing TC join timers so we don't double announce + if ieee in self._join_announce_tasks: + self._join_announce_tasks.pop(ieee).cancel() + + self.packet_received(packet) + def on_zdo_permit_join_message(self, msg: c.ZDO.PermitJoinInd.Callback) -> None: """ Coordinator join status change message. Only sent with Z-Stack 1.2 and 3.0. @@ -568,35 +512,7 @@ async def on_zdo_relays_message(self, msg: c.ZDO.SrcRtgInd.Callback) -> None: ZDO source routing message callback """ - try: - device = await self._get_or_discover_device(nwk=msg.DstAddr) - except KeyError: - LOGGER.warning( - "Received a ZDO message from an unknown device: %s", msg.DstAddr - ) - return - - # `relays` is a property with a setter that emits an event - device.relays = msg.Relays - - def on_zdo_device_announce(self, nwk: t.NWK, ieee: t.EUI64, capabilities) -> None: - """ - ZDO end device announcement callback - """ - - LOGGER.info( - "ZDO device announce: nwk=%s, ieee=%s, capabilities=%s", - nwk, - ieee, - capabilities, - ) - - # Cancel an existing join timer so we don't double announce - if ieee in self._join_announce_tasks: - self._join_announce_tasks.pop(ieee).cancel() - - # Sometimes devices change their NWK when announcing so re-join it. - self.handle_join(nwk=nwk, ieee=ieee, parent_nwk=None) + self.handle_relays(nwk=msg.DstAddr, relays=msg.Relays) def on_zdo_tc_device_join(self, msg: c.ZDO.TCDevInd.Callback) -> None: """ @@ -646,30 +562,50 @@ async def on_af_message(self, msg: c.AF.IncomingMsg.Callback) -> None: Handler for all non-ZDO messages. """ - try: - device = await self._get_or_discover_device(nwk=msg.SrcAddr) - except KeyError: - LOGGER.warning( - "Received an AF message from an unknown device: %s", msg.SrcAddr - ) - return - - device.radio_details(lqi=msg.LQI, rssi=None) - # XXX: Is it possible to receive messages on non-assigned endpoints? - if msg.DstEndpoint in self._device.endpoints: + if msg.DstEndpoint != 0 and msg.DstEndpoint in self._device.endpoints: profile = self._device.endpoints[msg.DstEndpoint].profile_id else: LOGGER.warning("Received a message on an unregistered endpoint: %s", msg) profile = zigpy.profiles.zha.PROFILE_ID - self.handle_message( - sender=device, - profile=profile, - cluster=msg.ClusterId, - src_ep=msg.SrcEndpoint, - dst_ep=msg.DstEndpoint, - message=msg.Data, + if msg.WasBroadcast: + dst = zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.Broadcast, + address=zigpy.types.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, + ) + elif msg.GroupId != 0x0000: + dst = zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.Group, + address=msg.GroupId, + ) + else: + dst = zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.NWK, + address=self.state.node_info.nwk, + ) + + self.packet_received( + zigpy.types.ZigbeePacket( + src=zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.NWK, address=msg.SrcAddr + ), + src_ep=msg.SrcEndpoint, + dst=dst, + dst_ep=msg.DstEndpoint, + tsn=msg.TSN, + profile_id=profile, + cluster_id=msg.ClusterId, + data=t.SerializableBytes(bytes(msg.Data)), + tx_options=( + zigpy.types.TransmitOptions.APS_Encryption + if msg.SecurityUse + else zigpy.types.TransmitOptions.NONE + ), + radius=msg.MsgResultRadius, + lqi=msg.LQI, + rssi=None, + ) ) #################### @@ -723,64 +659,6 @@ async def _watchdog_loop(self): return - @combine_concurrent_calls - async def _get_or_discover_device(self, nwk: t.NWK) -> zigpy.device.Device: - """ - Finds a device by its NWK address. If a device does not exist in the zigpy - database, attempt to look up its new NWK address. If it does not exist in the - zigpy database, treat the device as a new join. - """ - - try: - return self.get_device(nwk=nwk) - except KeyError: - pass - - LOGGER.debug("Device with NWK 0x%04X not in database", nwk) - - try: - async with async_timeout.timeout(IEEE_ADDR_DISCOVERY_TIMEOUT): - ieee_addr_rsp = await self._znp.request_callback_rsp( - request=c.ZDO.IEEEAddrReq.Req( - NWK=nwk, - RequestType=c.zdo.AddrRequestType.SINGLE, - StartIndex=0, - ), - RspStatus=t.Status.SUCCESS, - callback=c.ZDO.IEEEAddrRsp.Callback(partial=True, NWK=nwk), - ) - except asyncio.TimeoutError: - raise KeyError(f"Unknown device: 0x{nwk:04X}") - else: - ieee = ieee_addr_rsp.IEEE - - try: - device = self.get_device(ieee=ieee) - except KeyError: - LOGGER.debug("Treating unknown device as a new join") - self.handle_join(nwk=nwk, ieee=ieee, parent_nwk=None) - - return self.get_device(ieee=ieee) - - # The `Device` object could have been updated while this coroutine is running - if device.nwk == nwk: - return device - - LOGGER.warning( - "Device %s changed its NWK from %s to %s", - device.ieee, - device.nwk, - nwk, - ) - - # Notify zigpy of the change - self.handle_join(nwk=nwk, ieee=ieee, parent_nwk=None) - - # `handle_join` will update the NWK - assert device.nwk == nwk - - return device - async def _set_led_mode(self, *, led: t.uint8_t, mode: c.util.LEDMode) -> None: """ Attempts to set the provided LED's mode. A Z-Stack bug causes the underlying @@ -836,42 +714,6 @@ async def _write_stack_settings(self) -> bool: return any_changed - @contextlib.asynccontextmanager - async def _limit_concurrency(self): - """ - Async context manager that prevents devices from being overwhelmed by requests. - Mainly a thin wrapper around `asyncio.Semaphore` that logs when it has to wait. - """ - - # Allow sending some requests before the application has fully started - if self._concurrent_requests_semaphore is None: - yield - return - - start_time = time.time() - was_locked = self._concurrent_requests_semaphore.locked() - - if was_locked: - self._currently_waiting_requests += 1 - LOGGER.debug( - "Max concurrency reached, delaying requests (%s enqueued)", - self._currently_waiting_requests, - ) - - try: - async with self._concurrent_requests_semaphore: - if was_locked: - LOGGER.debug( - "Previously delayed request is now running, " - "delayed by %0.2f seconds", - time.time() - start_time, - ) - - yield - finally: - if was_locked: - self._currently_waiting_requests -= 1 - async def _reconnect(self) -> None: """ Endlessly tries to reconnect to the currently configured radio. @@ -954,6 +796,7 @@ async def _send_request_raw( data, *, relays=None, + extended_timeout=False, ): """ Used by `request`/`mrequest`/`broadcast` to send a request. @@ -967,7 +810,7 @@ async def _send_request_raw( if relays is None: request = c.AF.DataRequestExt.Req( DstAddrModeAddress=dst_addr, - DstEndpoint=dst_ep, + DstEndpoint=dst_ep or 0, DstPanId=0x0000, SrcEndpoint=src_ep, ClusterId=cluster, @@ -979,7 +822,7 @@ async def _send_request_raw( else: request = c.AF.DataRequestSrcRtg.Req( DstAddr=dst_addr.address, - DstEndpoint=dst_ep, + DstEndpoint=dst_ep or 0, SrcEndpoint=src_ep, ClusterId=cluster, TSN=sequence, @@ -1030,13 +873,24 @@ async def _send_request_raw( dst_addr.mode == t.AddrMode.NWK and dst_addr.address == self._device.nwk ): - self.handle_message( - sender=self._device, - profile=profile, - cluster=cluster, - src_ep=src_ep, - dst_ep=dst_ep, - message=data, + self.packet_received( + zigpy.types.ZigbeePacket( + src=zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.NWK, + address=self._device.nwk, + ), + src_ep=src_ep, + dst=zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.NWK, + address=self._device.nwk, + ), + dst_ep=dst_ep, + tsn=sequence, + profile_id=profile, + cluster_id=cluster, + data=t.SerializableBytes(data), + radius=radius, + ) ) if dst_ep == ZDO_ENDPOINT or dst_addr.mode == t.AddrMode.Broadcast: @@ -1045,7 +899,11 @@ async def _send_request_raw( request=request, RspStatus=t.Status.SUCCESS ) else: - async with async_timeout.timeout(DATA_CONFIRM_TIMEOUT): + async with async_timeout.timeout( + EXTENDED_DATA_CONFIRM_TIMEOUT + if extended_timeout + else DATA_CONFIRM_TIMEOUT + ): # Shield from cancellation to prevent requests that time out in higher # layers from missing expected responses response = await asyncio.shield( @@ -1095,41 +953,40 @@ async def _discover_route(self, nwk: t.NWK) -> None: await asyncio.sleep(0.1 * 13) - async def _send_request( - self, - dst_addr, - dst_ep, - src_ep, - profile, - cluster, - sequence, - options, - radius, - data, - ) -> tuple[t.Status, str]: + async def send_packet(self, packet: zigpy.types.ZigbeePacket) -> None: """ Fault-tolerant wrapper around `_send_request_raw` that transparently attempts to repair routes and contact the device through other methods when Z-Stack errors are encountered. """ + LOGGER.debug("Sending packet %r", packet) + + options = c.af.TransmitOptions.SUPPRESS_ROUTE_DISC_NETWORK + + if zigpy.types.TransmitOptions.ACK in packet.tx_options: + options |= c.af.TransmitOptions.ACK_REQUEST + + if zigpy.types.TransmitOptions.APS_Encryption in packet.tx_options: + options |= c.af.TransmitOptions.ENABLE_SECURITY + try: - if dst_addr.mode == t.AddrMode.NWK: - device = self.get_device(nwk=dst_addr.address) - elif dst_addr.mode == t.AddrMode.IEEE: - device = self.get_device(ieee=dst_addr.address) - else: - device = None - except KeyError: + device = self.get_device_with_address(packet.dst) + except (KeyError, ValueError): # Sometimes a request is sent to a device not in the database. This should # work, the device object is only for recovery. device = None + dst_addr = t.AddrModeAddress.from_zigpy_type(packet.dst) + status = None response = None association = None force_relays = None + if packet.source_route is not None: + force_relays = packet.source_route + tried_assoc_remove = False tried_route_discovery = False tried_last_good_route = False @@ -1145,7 +1002,7 @@ async def _send_request( # indicating that a route is missing so we need to explicitly # check for one. if ( - dst_ep == ZDO_ENDPOINT + packet.dst_ep == ZDO_ENDPOINT and dst_addr.mode == t.AddrMode.NWK and dst_addr.address != self.state.node_info.nwk ): @@ -1165,16 +1022,18 @@ async def _send_request( response = await self._send_request_raw( dst_addr=dst_addr, - dst_ep=dst_ep, - src_ep=src_ep, - profile=profile, - cluster=cluster, - sequence=sequence, + dst_ep=packet.dst_ep, + src_ep=packet.src_ep, + profile=packet.profile_id, + cluster=packet.cluster_id, + sequence=packet.tsn, options=options, - radius=radius, - data=data, + radius=packet.radius or 0, + data=packet.data.serialize(), relays=force_relays, + extended_timeout=packet.extended_timeout, ) + status = response.Status break except InvalidCommandResponse as e: status = e.response.Status @@ -1292,7 +1151,8 @@ async def _send_request( else: raise DeliveryError( f"Request failed after {REQUEST_MAX_RETRIES} attempts:" - f" {status!r}" + f" {status!r}", + status=status, ) finally: # We *must* re-add the device association if we previously removed it but @@ -1306,8 +1166,3 @@ async def _send_request( NodeRelation=association.Device.nodeRelation, ) ) - - if response.Status != t.Status.SUCCESS: - return response.Status, "Failed to send request" - - return response.Status, "Sent request successfully"