Skip to content

Keep the device firmware version in sync #427

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 45 additions & 2 deletions tests/test_device.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test ZHA device switch."""

import asyncio
from datetime import UTC, datetime
import logging
import time
from unittest import mock
Expand All @@ -13,6 +14,7 @@
from zigpy.quirks.v2 import DeviceAlertLevel, DeviceAlertMetadata, QuirkBuilder
import zigpy.types
from zigpy.zcl.clusters import general
from zigpy.zcl.clusters.general import Ota
from zigpy.zcl.foundation import Status, WriteAttributesResponse
import zigpy.zdo.types as zdo_t

Expand Down Expand Up @@ -42,7 +44,11 @@
from zha.application.platforms.sensor import LQISensor, RSSISensor
from zha.application.platforms.switch import Switch
from zha.exceptions import ZHAException
from zha.zigbee.device import ClusterBinding, get_device_automation_triggers
from zha.zigbee.device import (
ClusterBinding,
DeviceUpdatedEvent,
get_device_automation_triggers,
)
from zha.zigbee.group import Group


Expand Down Expand Up @@ -741,7 +747,7 @@ async def test_device_properties(

assert zha_device.power_configuration_ch is None
assert zha_device.basic_ch is not None
assert zha_device.sw_version is None
assert zha_device.firmware_version is None

assert len(zha_device.platform_entities) == 3

Expand Down Expand Up @@ -780,6 +786,43 @@ async def test_device_properties(
assert zha_device.is_coordinator is None


async def test_device_firmware_version_syncing(zha_gateway: Gateway) -> None:
"""Test device firmware version syncing."""
zigpy_dev = await zigpy_device_from_json(
zha_gateway.application_controller,
"tests/data/devices/philips-sml001.json",
)

zha_device = await join_zigpy_device(zha_gateway, zigpy_dev)

# Register a callback to listen for device updates
update_callback = mock.Mock()
zha_device.on_event(DeviceUpdatedEvent.event_type, update_callback)

# The firmware version is restored on device initialization
assert zha_device.firmware_version == "0x42006bb7"

# If we update the entity, the device updates as well
update_entity = get_entity(zha_device, platform=Platform.UPDATE)
update_entity._ota_cluster_handler.attribute_updated(
attrid=Ota.AttributeDefs.current_file_version.id,
value=zigpy.types.uint32_t(0xABCD1234),
timestamp=datetime.now(UTC),
)

assert zha_device.firmware_version == "0xabcd1234"

# Duplicate updates are ignored
update_entity._ota_cluster_handler.attribute_updated(
attrid=Ota.AttributeDefs.current_file_version.id,
value=zigpy.types.uint32_t(0xABCD1234),
timestamp=datetime.now(UTC),
)

assert zha_device.firmware_version == "0xabcd1234"
assert update_callback.mock_calls == [call(DeviceUpdatedEvent())]


async def test_quirks_v2_device_renaming(zha_gateway: Gateway) -> None:
"""Test quirks v2 device renaming."""
registry = DeviceRegistry()
Expand Down
1 change: 0 additions & 1 deletion tests/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ async def setup_test_data(
)

zha_device = await join_zigpy_device(zha_gateway, zigpy_device)
zha_device.async_update_sw_build_id(installed_fw_version)

return zha_device, ota_cluster, fw_image, installed_fw_version

Expand Down
1 change: 1 addition & 0 deletions zha/application/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def pretty_name(self) -> str:
ZHA_CLUSTER_HANDLER_CFG_DONE = "zha_channel_cfg_done"
ZHA_CLUSTER_HANDLER_READS_PER_REQ = 5
ZHA_EVENT = "zha_event"
ZHA_DEVICE_UPDATED_EVENT = "zha_device_updated_event"
ZHA_GW_MSG = "zha_gateway_message"
ZHA_GW_MSG_DEVICE_FULL_INIT = "device_fully_initialized"
ZHA_GW_MSG_DEVICE_INFO = "device_info"
Expand Down
1 change: 0 additions & 1 deletion zha/zigbee/cluster_handlers/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,6 @@ def cluster_command(
self.cluster.update_attribute(
Ota.AttributeDefs.current_file_version.id, current_file_version
)
self._endpoint.device.sw_version = current_file_version


@registries.CLUSTER_HANDLER_REGISTRY.register(Partition.cluster_id)
Expand Down
70 changes: 52 additions & 18 deletions zha/zigbee/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import annotations

import asyncio
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from enum import Enum
from functools import cached_property
Expand Down Expand Up @@ -58,10 +58,16 @@
UNKNOWN_MODEL,
ZHA_CLUSTER_HANDLER_CFG_DONE,
ZHA_CLUSTER_HANDLER_MSG,
ZHA_DEVICE_UPDATED_EVENT,
ZHA_EVENT,
)
from zha.application.helpers import convert_to_zcl_values, convert_zcl_value
from zha.application.platforms import BaseEntityInfo, PlatformEntity
from zha.application.platforms import (
BaseEntityInfo,
EntityStateChangedEvent,
PlatformEntity,
)
from zha.const import STATE_CHANGED
from zha.event import EventBase
from zha.exceptions import ZHAException
from zha.mixins import LogMixin
Expand Down Expand Up @@ -113,6 +119,14 @@ class ZHAEvent:
event: Final[str] = ZHA_EVENT


@dataclass(kw_only=True, frozen=True)
class DeviceUpdatedEvent:
"""Event generated when the device information has changed."""

event_type: Final[str] = ZHA_DEVICE_UPDATED_EVENT
event: Final[str] = ZHA_DEVICE_UPDATED_EVENT


@dataclass(kw_only=True, frozen=True)
class ClusterHandlerConfigurationComplete:
"""Event generated when all cluster handlers are configured."""
Expand Down Expand Up @@ -218,7 +232,7 @@ def __init__(
self._power_config_ch: ClusterHandler | None = None
self._identify_ch: ClusterHandler | None = None
self._basic_ch: ClusterHandler | None = None
self._sw_build_id: int | None = None
self._firmware_version: str | None = None

device_options = _gateway.config.config.device_options
if self.is_mains_powered:
Expand All @@ -238,15 +252,20 @@ def __init__(
self._pending_entities: list[PlatformEntity] = []
self.semaphore: asyncio.Semaphore = asyncio.Semaphore(3)

self._on_remove_callbacks: list[Callable[[], None]] = []

self._zdo_handler: ZDOClusterHandler = ZDOClusterHandler(self)
self._zdo_handler.on_add()
self._on_remove_callbacks.append(self._zdo_handler.on_remove)

self.status: DeviceStatus = DeviceStatus.CREATED

self._endpoints: dict[int, Endpoint] = {}
for ep_id, endpoint in zigpy_device.endpoints.items():
if ep_id != 0:
self._endpoints[ep_id] = Endpoint.new(endpoint, self)
ep = Endpoint.new(endpoint, self)
self._endpoints[ep_id] = ep
self._on_remove_callbacks.append(ep.on_remove)

def __repr__(self) -> str:
"""Return a string representation of the device."""
Expand Down Expand Up @@ -523,14 +542,9 @@ def zigbee_signature(self) -> dict[str, Any]:
}

@property
def sw_version(self) -> int | None:
def firmware_version(self) -> str | None:
"""Return the software version for this device."""
return self._sw_build_id

@sw_version.setter
def sw_version(self, sw_build_id: int) -> None:
"""Set the software version for this device."""
self._sw_build_id = sw_build_id
return self._firmware_version

@property
def platform_entities(self) -> dict[tuple[Platform, str], PlatformEntity]:
Expand All @@ -553,9 +567,13 @@ def new(
"""Create new device."""
return cls(zigpy_dev, gateway)

def async_update_sw_build_id(self, sw_version: int) -> None:
"""Update device sw version."""
self._sw_build_id = sw_version
def async_update_firmware_version(self, firmware_version: str) -> None:
"""Update device firmware version."""
if firmware_version == self._firmware_version:
return

self._firmware_version = firmware_version
self.emit(DeviceUpdatedEvent.event_type, DeviceUpdatedEvent())

async def _check_available(self, *_: Any) -> None:
# don't flip the availability state of the coordinator
Expand Down Expand Up @@ -869,16 +887,32 @@ async def async_initialize(self, from_cache: bool = False) -> None:
# At this point we can compute a primary entity
self._compute_primary_entity()

# Sync the device's firmware version with the first platform entity
for (platform, _unique_id), entity in self.platform_entities.items():
if platform != Platform.UPDATE:
continue
Comment on lines +890 to +893
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems a bit weird to go from entity level back up to device level for this info. We create a small (unnecessary?) dependence here.

Alternative might be for the device.firmware_version to return the cached current_file_version attribute (like the update entity currently does) and also have the device subscribe to current_file_version attribute changes (instead of the update entity doing that).
Then, you would have the update entity use device.firmware_version and have it subscribe to the DeviceUpdatedEvent events.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I was thinking of this from the opposite direction: the update entity can be overriden per-device and then provide a formatted firmware version (without the ZHA device needing to be patched). It would avoid having the device scan through clusters to determine the correct firmware version.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the update entity can be overriden per-device and then provide a formatted firmware version (without the ZHA device needing to be patched)

Does the base ZHA device need to know the "friendly version" though? For HA, we generally use the update entity for that, right? Except for the little firmware version string in the device card/registry...

But yeah, I'm fine with this as well, so feel free to merge as-is if you want.


self._firmware_version = entity.installed_version

def entity_update_listener(event: EntityStateChangedEvent) -> None:
"""Listen to firmware update entity changes."""
entity = self.get_platform_entity(event.platform, event.unique_id)
self.async_update_firmware_version(entity.installed_version)

self._on_remove_callbacks.append(
entity.on_event(STATE_CHANGED, entity_update_listener)
)

break

self.debug("power source: %s", self.power_source)
self.status = DeviceStatus.INITIALIZED
self.debug("completed initialization")

async def on_remove(self) -> None:
"""Cancel tasks this device owns."""
self._zdo_handler.on_remove()

for endpoint in self._endpoints.values():
endpoint.on_remove()
for callback in self._on_remove_callbacks:
callback()

for platform_entity in self._platform_entities.values():
await platform_entity.on_remove()
Expand Down
Loading