diff --git a/tests/conftest.py b/tests/conftest.py index bc9f0e81..c1d404d2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -69,8 +69,9 @@ def write(self, data): assert self._is_connected self.protocol.data_received(data) - def close(self, *, error=ValueError("Connection was closed")): # noqa: + def close(self, *, error=ValueError("Connection was closed")): # noqa: B008 LOGGER.debug("Closing %s", self) + if not self._is_connected: return diff --git a/tests/test_commands.py b/tests/test_commands.py index 62e65965..3cf592ce 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -137,7 +137,7 @@ def test_commands_schema(): commands_by_id[cmd.Req.header].append(cmd.Req) else: - assert False, "Command is empty" + assert False, "Command is empty" # noqa: B011 elif cmd.type == t.CommandType.SRSP: # The one command like this is RPCError assert cmd is c.RPCError.CommandNotRecognized @@ -153,7 +153,7 @@ def test_commands_schema(): commands_by_id[cmd.Rsp.header].append(cmd.Rsp) else: - assert False, "Command has unknown type" + assert False, "Command has unknown type" # noqa: B011 duplicate_commands = { cmd: commands for cmd, commands in commands_by_id.items() if len(commands) > 1 diff --git a/tests/test_types_cstruct.py b/tests/test_types_cstruct.py index 25ad25e4..48e8fcae 100644 --- a/tests/test_types_cstruct.py +++ b/tests/test_types_cstruct.py @@ -1,7 +1,12 @@ +import typing + import pytest import zigpy_znp.types as t +if typing.TYPE_CHECKING: + import typing_extensions + def test_struct_fields(): class TestStruct(t.CStruct): @@ -266,7 +271,7 @@ class TestStruct(t.CStruct): def test_old_nib_deserialize(): - PaddingByte = t.uint8_t + PaddingByte: typing_extensions.TypeAlias = t.uint8_t class NwkState16(t.enum_uint16): NWK_INIT = 0 @@ -296,15 +301,15 @@ class OldNIB(t.CStruct): SecurityLevel: t.uint8_t SymLink: t.uint8_t CapabilityFlags: t.uint8_t - PaddingByte0: PaddingByte # type:ignore[valid-type] + PaddingByte0: PaddingByte TransactionPersistenceTime: t.uint16_t nwkProtocolVersion: t.uint8_t RouteDiscoveryTime: t.uint8_t RouteExpiryTime: t.uint8_t - PaddingByte1: PaddingByte # type:ignore[valid-type] + PaddingByte1: PaddingByte nwkDevAddress: t.NWK nwkLogicalChannel: t.uint8_t - PaddingByte2: PaddingByte # type:ignore[valid-type] + PaddingByte2: PaddingByte nwkCoordAddress: t.NWK nwkCoordExtAddress: t.EUI64 nwkPanId: t.uint16_t diff --git a/tests/test_uart.py b/tests/test_uart.py index 3cfc6e2b..018ce1a9 100644 --- a/tests/test_uart.py +++ b/tests/test_uart.py @@ -218,7 +218,7 @@ def test_uart_frame_received_error(connected_uart, mocker): uart.data_received(test_frame_bytes * 3) # We should have received all three frames - znp.frame_received.call_count == 3 + assert znp.frame_received.call_count == 3 async def test_connection_lost(dummy_serial_conn, mocker, event_loop): diff --git a/zigpy_znp/api.py b/zigpy_znp/api.py index 3e48c213..30c510c2 100644 --- a/zigpy_znp/api.py +++ b/zigpy_znp/api.py @@ -2,18 +2,17 @@ import os import time +import typing import asyncio import logging import itertools import contextlib import dataclasses -from typing import Union, overload from collections import Counter, defaultdict import zigpy.state import async_timeout import zigpy.zdo.types as zdo_t -from typing_extensions import Literal import zigpy_znp.const as const import zigpy_znp.types as t @@ -32,6 +31,9 @@ from zigpy_znp.exceptions import CommandNotRecognized, InvalidCommandResponse from zigpy_znp.types.nvids import ExNvIds, OsalNvIds +if typing.TYPE_CHECKING: + import typing_extensions + LOGGER = logging.getLogger(__name__) @@ -762,21 +764,21 @@ def callback_for_response( return self.callback_for_responses([response], callback) - @overload + @typing.overload def wait_for_responses( - self, responses, *, context: Literal[False] = ... + self, responses, *, context: typing_extensions.Literal[False] = ... ) -> asyncio.Future: ... - @overload + @typing.overload def wait_for_responses( - self, responses, *, context: Literal[True] + self, responses, *, context: typing_extensions.Literal[True] ) -> tuple[asyncio.Future, OneShotResponseListener]: ... def wait_for_responses( self, responses, *, context: bool = False - ) -> Union[asyncio.Future | tuple[asyncio.Future, OneShotResponseListener]]: + ) -> asyncio.Future | tuple[asyncio.Future, OneShotResponseListener]: """ Creates a one-shot listener that matches any *one* of the given responses. """ diff --git a/zigpy_znp/tools/common.py b/zigpy_znp/tools/common.py index fe70867d..dd13d1be 100644 --- a/zigpy_znp/tools/common.py +++ b/zigpy_znp/tools/common.py @@ -1,9 +1,9 @@ from __future__ import annotations import sys +import typing import logging import argparse -from typing import Optional, Sequence import jsonschema import coloredlogs @@ -117,9 +117,7 @@ def validate_backup_json(backup: t.JSONType) -> None: class CustomArgumentParser(argparse.ArgumentParser): - def parse_args( - self, args: Optional[Sequence[str]] = None, namespace=None - ): # type:ignore[override] + def parse_args(self, args: typing.Sequence[str] | None = None, namespace=None): args = super().parse_args(args, namespace) # Since we're running as a CLI tool, install our own log level and color logger diff --git a/zigpy_znp/types/basic.py b/zigpy_znp/types/basic.py index 8131c6f1..c65bbbba 100644 --- a/zigpy_znp/types/basic.py +++ b/zigpy_znp/types/basic.py @@ -1,6 +1,7 @@ from __future__ import annotations import enum +import typing from zigpy_znp.types.cstruct import CStruct @@ -36,12 +37,11 @@ def serialize_list(objects) -> Bytes: class FixedIntType(int): - _signed = None - _size = None # type:int + _signed: bool + _size: int - @classmethod def __new__(cls, *args, **kwargs): - if cls._signed is None or cls._size is None: + if getattr(cls, "_signed", None) is None or getattr(cls, "_size", None) is None: raise TypeError(f"{cls} is abstract and cannot be created") instance = super().__new__(cls, *args, **kwargs) @@ -49,7 +49,6 @@ def __new__(cls, *args, **kwargs): return instance - @classmethod def __init_subclass__(cls, signed=None, size=None, hex_repr=None) -> None: super().__init_subclass__() @@ -79,15 +78,13 @@ def serialize(self) -> bytes: raise ValueError(str(e)) from e @classmethod - def deserialize( - cls, data: bytes - ) -> tuple[FixedIntType, bytes]: # type:ignore[return-value] - if len(data) < cls._size: # type:ignore[operator] + def deserialize(cls, data: bytes) -> tuple[FixedIntType, bytes]: + if len(data) < cls._size: raise ValueError(f"Data is too short to contain {cls._size} bytes") r = cls.from_bytes(data[: cls._size], "little", signed=cls._signed) data = data[cls._size :] - return r, data # type:ignore[return-value] + return typing.cast(FixedIntType, r), data class uint_t(FixedIntType, signed=False): diff --git a/zigpy_znp/types/cstruct.py b/zigpy_znp/types/cstruct.py index cd454382..d1d67679 100644 --- a/zigpy_znp/types/cstruct.py +++ b/zigpy_znp/types/cstruct.py @@ -25,11 +25,7 @@ def __post_init__(self) -> None: def get_size_and_alignment(self, align=False) -> tuple[int, int]: if issubclass(self.type, (zigpy_t.FixedIntType, t.FixedIntType)): - return ( # type: ignore[return-value] - self.type._size, - self.type._size if align else 1, - 1, - ) + return self.type._size, (self.type._size if align else 1) elif issubclass(self.type, zigpy_t.EUI64): return 8, 1 elif issubclass(self.type, zigpy_t.KeyData): diff --git a/zigpy_znp/utils.py b/zigpy_znp/utils.py index b1ccfca2..04d05ef1 100644 --- a/zigpy_znp/utils.py +++ b/zigpy_znp/utils.py @@ -6,7 +6,6 @@ import logging import functools import dataclasses -from typing import CoroutineFunction import zigpy_znp.types as t @@ -136,6 +135,7 @@ class CallbackResponseListener(BaseResponseListener): def _resolve(self, response: t.CommandBase) -> bool: try: + # https://github.com/python/mypy/issues/5485 result = self.callback(response) # type:ignore[misc] # Run coroutines in the background @@ -166,8 +166,8 @@ def matches(self, other) -> bool: def combine_concurrent_calls( - function: CoroutineFunction, -) -> CoroutineFunction: + function: typing.CoroutineFunction, +) -> typing.CoroutineFunction: """ Decorator that allows concurrent calls to expensive coroutines to share a result. """ diff --git a/zigpy_znp/znp/security.py b/zigpy_znp/znp/security.py index a101b576..08b75cc1 100644 --- a/zigpy_znp/znp/security.py +++ b/zigpy_znp/znp/security.py @@ -24,8 +24,8 @@ def replace(self, **kwargs) -> StoredDevice: return dataclasses.replace(self, **kwargs) -def rotate(lst: typing.Sequence, n: int) -> typing.Sequence: - return lst[n:] + lst[:n] # type:ignore[operator] +def rotate(lst: list, n: int) -> list: + return lst[n:] + lst[:n] def compute_key(ieee: t.EUI64, tclk_seed: t.KeyData, shift: int) -> t.KeyData: @@ -174,7 +174,7 @@ async def read_addr_manager_entries(znp: ZNP) -> typing.Sequence[t.AddrMgrEntry] async def read_hashed_link_keys( # type:ignore[misc] znp: ZNP, tclk_seed: t.KeyData -) -> typing.Iterable[zigpy.state.Key]: +) -> typing.AsyncGenerator[zigpy.state.Key, None]: if znp.version >= 3.30: entries = znp.nvram.read_table( item_id=ExNvIds.TCLK_TABLE, @@ -204,9 +204,9 @@ async def read_hashed_link_keys( # type:ignore[misc] ) -async def read_unhashed_link_keys( # type:ignore[misc] +async def read_unhashed_link_keys( znp: ZNP, addr_mgr_entries: typing.Sequence[t.AddrMgrEntry] -) -> typing.Iterable[zigpy.state.Key]: +) -> typing.AsyncGenerator[zigpy.state.Key, None]: if znp.version == 3.30: link_key_offset_base = 0x0000 table = znp.nvram.read_table(