diff --git a/tests/common.py b/tests/common.py index 11a944af9..b8d48c36b 100644 --- a/tests/common.py +++ b/tests/common.py @@ -19,10 +19,15 @@ _LOGGER = logging.getLogger(__name__) -def patch_cluster(cluster: zigpy.zcl.Cluster) -> None: +def patch_cluster( + cluster: zigpy.zcl.Cluster, unsupported_attr: set[str] | None = None +) -> None: """Patch a cluster for testing.""" cluster.PLUGGED_ATTR_READS = {} + if unsupported_attr is None: + unsupported_attr = set() + async def _read_attribute_raw(attributes: Any, *args: Any, **kwargs: Any) -> Any: result = [] for attr_id in attributes: @@ -44,6 +49,19 @@ async def _read_attribute_raw(attributes: Any, *args: Any, **kwargs: Any) -> Any result.append(zcl_f.ReadAttributeRecord(attr_id, zcl_f.Status.FAILURE)) return (result,) + async def _discover_attributes(*args: Any, **kwargs: Any) -> Any: + schema = zcl_f.GENERAL_COMMANDS[ + zcl_f.GeneralCommand.Discover_Attributes_rsp + ].schema + records = [ + zcl_f.DiscoverAttributesResponseRecord.from_dict( + {"attrid": attr.id, "datatype": 0} + ) + for attr in cluster.attributes.values() + if attr.name not in unsupported_attr + ] + return schema(discovery_complete=t.Bool.true, attribute_info=records) + cluster.bind = AsyncMock(return_value=[0]) cluster.configure_reporting = AsyncMock( return_value=[ @@ -61,6 +79,7 @@ async def _read_attribute_raw(attributes: Any, *args: Any, **kwargs: Any) -> Any cluster._write_attributes = AsyncMock( return_value=[zcl_f.WriteAttributesResponse.deserialize(b"\x00")[0]] ) + cluster.discover_attributes = AsyncMock(side_effect=_discover_attributes) if cluster.cluster_id == 4: cluster.add = AsyncMock(return_value=[0]) if cluster.cluster_id == 0x1000: diff --git a/tests/conftest.py b/tests/conftest.py index 4e4c2c22f..9ffc6d18f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -436,6 +436,7 @@ def _mock_dev( patch_cluster: bool = True, quirk: Optional[Callable] = None, attributes: dict[int, dict[str, dict[str, Any]]] = None, + unsupported_attr: dict[int, set[str]] | None = None, ) -> zigpy.device.Device: """Make a fake device using the specified cluster classes.""" device = zigpy.device.Device( @@ -464,12 +465,16 @@ def _mock_dev( device = get_device(device) if patch_cluster: + if unsupported_attr is None: + unsupported_attr = {} for endpoint in (ep for epid, ep in device.endpoints.items() if epid): endpoint.request = AsyncMock(return_value=[0]) for cluster in itertools.chain( endpoint.in_clusters.values(), endpoint.out_clusters.values() ): - common.patch_cluster(cluster) + common.patch_cluster( + cluster, unsupported_attr.get(endpoint.endpoint_id, set()) + ) if attributes is not None: for ep_id, clusters in attributes.items(): diff --git a/tests/test_discover.py b/tests/test_discover.py index 7e4a87fb5..16e5d4a82 100644 --- a/tests/test_discover.py +++ b/tests/test_discover.py @@ -6,7 +6,7 @@ import re from typing import Any, Final from unittest import mock -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock import pytest from zhaquirks.ikea import PowerConfig1CRCluster, ScenesCluster @@ -116,10 +116,6 @@ async def _mock( return _mock -@patch( - "zigpy.zcl.clusters.general.Identify.request", - new=AsyncMock(return_value=[mock.sentinel.data, zcl_f.Status.SUCCESS]), -) @pytest.mark.parametrize("device", DEVICES) async def test_devices( device, @@ -140,7 +136,9 @@ async def test_devices( cluster_identify = _get_identify_cluster(zigpy_device) if cluster_identify: - cluster_identify.request.reset_mock() + cluster_identify.request = AsyncMock( + return_value=[mock.sentinel.data, zcl_f.Status.SUCCESS] + ) zha_dev: Device = await device_joined(zigpy_device) await zha_gateway.async_block_till_done() @@ -151,14 +149,24 @@ async def test_devices( False, cluster_identify.commands_by_name["trigger_effect"].id, cluster_identify.commands_by_name["trigger_effect"].schema, + manufacturer=None, + expect_reply=True, + tsn=None, effect_id=zigpy.zcl.clusters.general.Identify.EffectIdentifier.Okay, effect_variant=( zigpy.zcl.clusters.general.Identify.EffectVariant.Default ), - expect_reply=True, + ), + mock.call( + True, + zcl_f.GeneralCommand.Discover_Attributes, + zcl_f.GENERAL_COMMANDS[zcl_f.GeneralCommand.Discover_Attributes].schema, manufacturer=None, + expect_reply=True, tsn=None, - ) + start_attribute_id=0, + max_attribute_ids=255, + ), ] event_cluster_handlers = { diff --git a/tests/test_sensor.py b/tests/test_sensor.py index 0b42a5a55..e0a10d227 100644 --- a/tests/test_sensor.py +++ b/tests/test_sensor.py @@ -798,14 +798,12 @@ async def test_unsupported_attributes_sensor( SIG_EP_TYPE: zigpy.profiles.zha.DeviceType.ON_OFF_SWITCH, SIG_EP_PROFILE: zigpy.profiles.zha.PROFILE_ID, } - } + }, + unsupported_attr={1: unsupported_attributes}, ) - cluster = zigpy_device.endpoints[1].in_clusters[cluster_id] if cluster_id == smartenergy.Metering.cluster_id: # this one is mains powered zigpy_device.node_desc.mac_capability_flags |= 0b_0000_0100 - for attr in unsupported_attributes: - cluster.add_unsupported_attribute(attr) zha_device = await device_joined(zigpy_device) diff --git a/zha/zigbee/cluster_handlers/__init__.py b/zha/zigbee/cluster_handlers/__init__.py index 3ff76ee1a..f7365a702 100644 --- a/zha/zigbee/cluster_handlers/__init__.py +++ b/zha/zigbee/cluster_handlers/__init__.py @@ -11,11 +11,15 @@ from typing import TYPE_CHECKING, Any, Final, ParamSpec, TypedDict import zigpy.exceptions +import zigpy.types import zigpy.util import zigpy.zcl from zigpy.zcl.foundation import ( + GENERAL_COMMANDS, CommandSchema, ConfigureReportingResponseRecord, + DiscoverAttributesResponseRecord, + GeneralCommand, Status, ZCLAttributeDef, ) @@ -441,6 +445,7 @@ async def async_configure(self) -> None: if ch_specific_cfg: self.debug("Performing cluster handler specific configuration") await ch_specific_cfg() + self.debug("finished cluster handler configuration") else: self.debug("skipping cluster handler configuration") @@ -458,6 +463,10 @@ async def async_initialize(self, from_cache: bool) -> None: uncached = [a for a, cached in self.ZCL_INIT_ATTRS.items() if not cached] uncached.extend([cfg["attr"] for cfg in self.REPORT_CONFIG]) + if not from_cache: + self.debug("discovering unsupported attributes") + await self.discover_unsupported_attributes() + if cached: self.debug("initializing cached cluster handler attributes: %s", cached) await self._get_attributes( @@ -624,6 +633,47 @@ async def write_attributes_safe( f"Failed to write attribute {name}={value}: {record.status}", ) + async def _discover_attributes_all( + self, + ) -> list[DiscoverAttributesResponseRecord] | None: + discovery_complete = zigpy.types.Bool.false + start_attribute_id = 0 + attribute_info = [] + cluster = self.cluster + while discovery_complete != zigpy.types.Bool.true: + rsp = await cluster.discover_attributes( + start_attribute_id=start_attribute_id, max_attribute_ids=0xFF + ) + if not isinstance( + rsp, GENERAL_COMMANDS[GeneralCommand.Discover_Attributes_rsp].schema + ): + self.debug( + "Ignoring attribute discovery due to unexpected default response: %r", + rsp, + ) + return None + + attribute_info.extend(rsp.attribute_info) + discovery_complete = rsp.discovery_complete + start_attribute_id = ( + max((info.attrid for info in rsp.attribute_info), default=0) + 1 + ) + return attribute_info + + async def discover_unsupported_attributes(self): + """Discover the list of unsupported attributes from the device.""" + attribute_info = await self._discover_attributes_all() + if attribute_info is None: + return + attr_ids = {info.attrid for info in attribute_info} + + cluster = self.cluster + for attr_id in cluster.attributes: + if attr_id in attr_ids: + cluster.remove_unsupported_attribute(attr_id) + else: + cluster.add_unsupported_attribute(attr_id) + def log(self, level, msg, *args, **kwargs) -> None: """Log a message.""" msg = f"[%s:%s]: {msg}"