diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ae757c5c..08b04bd4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,9 +14,23 @@ repos: hooks: - id: flake8 entry: pflake8 - additional_dependencies: ['pyproject-flake8==0.0.1a2'] + additional_dependencies: + - pyproject-flake8==0.0.1a2 + - flake8-bugbear==22.1.11 + - flake8-comprehensions==3.8.0 + - flake8_2020==1.6.1 + - mccabe==0.6.1 + - pycodestyle==2.8.0 + - pyflakes==2.4.0 - repo: https://github.com/PyCQA/isort rev: 5.10.1 hooks: - id: isort + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.931 + hooks: + - id: mypy + additional_dependencies: + - zigpy==0.43.0 diff --git a/setup.cfg b/setup.cfg index 0054551c..869d1b46 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,3 +39,23 @@ testing = [coverage:run] source = zigpy_znp + +[flake8] +max-line-length = 88 + +[mypy] +ignore_missing_imports = True +install_types = True +non_interactive = True +check_untyped_defs = True +show_error_codes = True +show_error_context = True +disable_error_code = + attr-defined, + arg-type, + type-var, + var-annotated, + assignment, + call-overload, + name-defined, + union-attr diff --git a/tests/conftest.py b/tests/conftest.py index ab8080a4..c1d404d2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,7 @@ # Python 3.8 already has this from unittest.mock import AsyncMock as CoroutineMock # noqa: F401 except ImportError: - from asynctest import CoroutineMock # noqa: F401 + from asynctest import CoroutineMock # type:ignore[no-redef] # noqa: F401 import zigpy.endpoint import zigpy.zdo.types as zdo_t @@ -69,8 +69,9 @@ def write(self, data): assert self._is_connected self.protocol.data_received(data) - def close(self, *, error=ValueError("Connection was closed")): + def close(self, *, error=ValueError("Connection was closed")): # noqa: B008 LOGGER.debug("Closing %s", self) + if not self._is_connected: return diff --git a/tests/test_commands.py b/tests/test_commands.py index 62e65965..3cf592ce 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -137,7 +137,7 @@ def test_commands_schema(): commands_by_id[cmd.Req.header].append(cmd.Req) else: - assert False, "Command is empty" + assert False, "Command is empty" # noqa: B011 elif cmd.type == t.CommandType.SRSP: # The one command like this is RPCError assert cmd is c.RPCError.CommandNotRecognized @@ -153,7 +153,7 @@ def test_commands_schema(): commands_by_id[cmd.Rsp.header].append(cmd.Rsp) else: - assert False, "Command has unknown type" + assert False, "Command has unknown type" # noqa: B011 duplicate_commands = { cmd: commands for cmd, commands in commands_by_id.items() if len(commands) > 1 diff --git a/tests/test_types_cstruct.py b/tests/test_types_cstruct.py index df09216f..48e8fcae 100644 --- a/tests/test_types_cstruct.py +++ b/tests/test_types_cstruct.py @@ -1,7 +1,12 @@ +import typing + import pytest import zigpy_znp.types as t +if typing.TYPE_CHECKING: + import typing_extensions + def test_struct_fields(): class TestStruct(t.CStruct): @@ -266,7 +271,7 @@ class TestStruct(t.CStruct): def test_old_nib_deserialize(): - PaddingByte = t.uint8_t + PaddingByte: typing_extensions.TypeAlias = t.uint8_t class NwkState16(t.enum_uint16): NWK_INIT = 0 @@ -330,11 +335,11 @@ class OldNIB(t.CStruct): nwkConcentratorDiscoveryTime: t.uint8_t nwkConcentratorRadius: t.uint8_t nwkAllFresh: t.uint8_t - PaddingByte3: PaddingByte + PaddingByte3: PaddingByte # type:ignore[valid-type] nwkManagerAddr: t.NWK nwkTotalTransmissions: t.uint16_t nwkUpdateId: t.uint8_t - PaddingByte4: PaddingByte + PaddingByte4: PaddingByte # type:ignore[valid-type] nib = t.NIB( SequenceNum=54, diff --git a/tests/test_uart.py b/tests/test_uart.py index 3cfc6e2b..018ce1a9 100644 --- a/tests/test_uart.py +++ b/tests/test_uart.py @@ -218,7 +218,7 @@ def test_uart_frame_received_error(connected_uart, mocker): uart.data_received(test_frame_bytes * 3) # We should have received all three frames - znp.frame_received.call_count == 3 + assert znp.frame_received.call_count == 3 async def test_connection_lost(dummy_serial_conn, mocker, event_loop): diff --git a/zigpy_znp/api.py b/zigpy_znp/api.py index f7e155c2..30c510c2 100644 --- a/zigpy_znp/api.py +++ b/zigpy_znp/api.py @@ -2,6 +2,7 @@ import os import time +import typing import asyncio import logging import itertools @@ -30,6 +31,9 @@ from zigpy_znp.exceptions import CommandNotRecognized, InvalidCommandResponse from zigpy_znp.types.nvids import ExNvIds, OsalNvIds +if typing.TYPE_CHECKING: + import typing_extensions + LOGGER = logging.getLogger(__name__) @@ -50,8 +54,8 @@ def __init__(self, config: conf.ConfigType): self._listeners = defaultdict(list) self._sync_request_lock = asyncio.Lock() - self.capabilities = None - self.version = None + self.capabilities = None # type: int + self.version = None # type: float self.nvram = NVRAMHelper(self) self.network_info: zigpy.state.NetworkInformation = None @@ -542,7 +546,7 @@ async def ping_task(): try: async with async_timeout.timeout(CONNECT_PING_TIMEOUT): - result = await ping_task + result = await ping_task # type:ignore[misc] except asyncio.TimeoutError: ping_task.cancel() @@ -609,7 +613,7 @@ def close(self) -> None: self._app = None - for header, listeners in self._listeners.items(): + for _header, listeners in self._listeners.items(): for listener in listeners: listener.cancel() @@ -659,7 +663,7 @@ def remove_listener(self, listener: BaseResponseListener) -> None: counts[OneShotResponseListener], ) - def frame_received(self, frame: GeneralFrame) -> bool: + def frame_received(self, frame: GeneralFrame) -> bool | None: """ Called when a frame has been received. Returns whether or not the frame was handled by any listener. @@ -669,7 +673,7 @@ def frame_received(self, frame: GeneralFrame) -> bool: if frame.header not in c.COMMANDS_BY_ID: LOGGER.error("Received an unknown frame: %s", frame) - return + return None command_cls = c.COMMANDS_BY_ID[frame.header] @@ -680,7 +684,7 @@ def frame_received(self, frame: GeneralFrame) -> bool: # https://github.com/home-assistant/core/issues/50005 if command_cls == c.ZDO.ParentAnnceRsp.Callback: LOGGER.warning("Failed to parse broken %s as %s", frame, command_cls) - return + return None raise @@ -760,7 +764,21 @@ def callback_for_response( return self.callback_for_responses([response], callback) - def wait_for_responses(self, responses, *, context=False) -> asyncio.Future: + @typing.overload + def wait_for_responses( + self, responses, *, context: typing_extensions.Literal[False] = ... + ) -> asyncio.Future: + ... + + @typing.overload + def wait_for_responses( + self, responses, *, context: typing_extensions.Literal[True] + ) -> tuple[asyncio.Future, OneShotResponseListener]: + ... + + def wait_for_responses( + self, responses, *, context: bool = False + ) -> asyncio.Future | tuple[asyncio.Future, OneShotResponseListener]: """ Creates a one-shot listener that matches any *one* of the given responses. """ @@ -787,7 +805,9 @@ def wait_for_response(self, response: t.CommandBase) -> asyncio.Future: return self.wait_for_responses([response]) - async def request(self, request: t.CommandBase, **response_params) -> t.CommandBase: + async def request( + self, request: t.CommandBase, **response_params + ) -> t.CommandBase | None: """ Sends a SREQ/AREQ request and returns its SRSP (only for SREQ), failing if any of the SRSP's parameters don't match `response_params`. @@ -827,7 +847,7 @@ async def request(self, request: t.CommandBase, **response_params) -> t.CommandB if not request.Rsp: LOGGER.debug("Request has no response, not waiting for one.") self._uart.send(frame) - return + return None # We need to create the response listener before we send the request response_future = self.wait_for_responses( diff --git a/zigpy_znp/tools/common.py b/zigpy_znp/tools/common.py index 521d4ed6..dd13d1be 100644 --- a/zigpy_znp/tools/common.py +++ b/zigpy_znp/tools/common.py @@ -1,6 +1,7 @@ from __future__ import annotations import sys +import typing import logging import argparse @@ -116,7 +117,7 @@ def validate_backup_json(backup: t.JSONType) -> None: class CustomArgumentParser(argparse.ArgumentParser): - def parse_args(self, args: list[str] = None, namespace=None): + def parse_args(self, args: typing.Sequence[str] | None = None, namespace=None): args = super().parse_args(args, namespace) # Since we're running as a CLI tool, install our own log level and color logger diff --git a/zigpy_znp/types/basic.py b/zigpy_znp/types/basic.py index 47fc2e70..c65bbbba 100644 --- a/zigpy_znp/types/basic.py +++ b/zigpy_znp/types/basic.py @@ -1,6 +1,7 @@ from __future__ import annotations import enum +import typing from zigpy_znp.types.cstruct import CStruct @@ -36,11 +37,11 @@ def serialize_list(objects) -> Bytes: class FixedIntType(int): - _signed = None - _size = None + _signed: bool + _size: int def __new__(cls, *args, **kwargs): - if cls._signed is None or cls._size is None: + if getattr(cls, "_signed", None) is None or getattr(cls, "_size", None) is None: raise TypeError(f"{cls} is abstract and cannot be created") instance = super().__new__(cls, *args, **kwargs) @@ -58,7 +59,7 @@ def __init_subclass__(cls, signed=None, size=None, hex_repr=None) -> None: cls._size = size if hex_repr: - fmt = f"0x{{:0{cls._size * 2}X}}" + fmt = f"0x{{:0{cls._size * 2}X}}" # type:ignore[operator] cls.__str__ = cls.__repr__ = lambda self: fmt.format(self) elif hex_repr is not None and not hex_repr: cls.__str__ = super().__str__ @@ -83,7 +84,7 @@ def deserialize(cls, data: bytes) -> tuple[FixedIntType, bytes]: r = cls.from_bytes(data[: cls._size], "little", signed=cls._signed) data = data[cls._size :] - return r, data + return typing.cast(FixedIntType, r), data class uint_t(FixedIntType, signed=False): @@ -162,7 +163,7 @@ class ShortBytes(Bytes): _header = uint8_t def serialize(self) -> Bytes: - return self._header(len(self)).serialize() + self + return self._header(len(self)).serialize() + self # type:ignore[return-value] @classmethod def deserialize(cls, data: bytes) -> tuple[Bytes, bytes]: @@ -182,7 +183,7 @@ class BaseListType(list): @classmethod def _serialize_item(cls, item, *, align): if not isinstance(item, cls._item_type): - item = cls._item_type(item) + item = cls._item_type(item) # type:ignore[misc] if issubclass(cls._item_type, CStruct): return item.serialize(align=align) @@ -215,7 +216,7 @@ def serialize(self, *, align=False) -> bytes: def deserialize(cls, data: bytes, *, align=False) -> tuple[LVList, bytes]: length, data = cls._header.deserialize(data) r = cls() - for i in range(length): + for _i in range(length): item, data = cls._deserialize_item(data, align=align) r.append(item) return r, data @@ -242,7 +243,7 @@ def serialize(self, *, align=False) -> bytes: @classmethod def deserialize(cls, data: bytes, *, align=False) -> tuple[FixedList, bytes]: r = cls() - for i in range(cls._length): + for _i in range(cls._length): item, data = cls._deserialize_item(data, align=align) r.append(item) return r, data @@ -271,7 +272,7 @@ def enum_flag_factory(int_type: FixedIntType) -> enum.Flag: appropriate methods but with only one non-Enum parent class. """ - class _NewEnum(int_type, enum.Flag): + class _NewEnum(int_type, enum.Flag): # type:ignore[misc,valid-type] # Rebind classmethods to our own class _missing_ = classmethod(enum.IntFlag._missing_.__func__) _create_pseudo_member_ = classmethod( @@ -286,7 +287,7 @@ class _NewEnum(int_type, enum.Flag): __rxor__ = enum.IntFlag.__rxor__ __invert__ = enum.IntFlag.__invert__ - return _NewEnum + return _NewEnum # type:ignore[return-value] class enum_uint8(uint8_t, enum.Enum): @@ -321,33 +322,33 @@ class enum_uint64(uint64_t, enum.Enum): pass -class enum_flag_uint8(enum_flag_factory(uint8_t)): +class enum_flag_uint8(enum_flag_factory(uint8_t)): # type:ignore[misc] pass -class enum_flag_uint16(enum_flag_factory(uint16_t)): +class enum_flag_uint16(enum_flag_factory(uint16_t)): # type:ignore[misc] pass -class enum_flag_uint24(enum_flag_factory(uint24_t)): +class enum_flag_uint24(enum_flag_factory(uint24_t)): # type:ignore[misc] pass -class enum_flag_uint32(enum_flag_factory(uint32_t)): +class enum_flag_uint32(enum_flag_factory(uint32_t)): # type:ignore[misc] pass -class enum_flag_uint40(enum_flag_factory(uint40_t)): +class enum_flag_uint40(enum_flag_factory(uint40_t)): # type:ignore[misc] pass -class enum_flag_uint48(enum_flag_factory(uint48_t)): +class enum_flag_uint48(enum_flag_factory(uint48_t)): # type:ignore[misc] pass -class enum_flag_uint56(enum_flag_factory(uint56_t)): +class enum_flag_uint56(enum_flag_factory(uint56_t)): # type:ignore[misc] pass -class enum_flag_uint64(enum_flag_factory(uint64_t)): +class enum_flag_uint64(enum_flag_factory(uint64_t)): # type:ignore[misc] pass diff --git a/zigpy_znp/types/commands.py b/zigpy_znp/types/commands.py index 02ff7381..4a1fc7e1 100644 --- a/zigpy_znp/types/commands.py +++ b/zigpy_znp/types/commands.py @@ -214,7 +214,7 @@ class Req(CommandBase, header=header, schema=definition.req_schema): req_header = header rsp_header = CommandHeader(0x0040 + req_header) - class Req( + class Req( # type:ignore[no-redef] CommandBase, header=req_header, schema=definition.req_schema ): pass @@ -261,7 +261,9 @@ class Callback( ) # pragma: no cover # If there is no request, this is a just a response - class Rsp(CommandBase, header=header, schema=definition.rsp_schema): + class Rsp( # type:ignore[no-redef] + CommandBase, header=header, schema=definition.rsp_schema + ): pass Rsp.__qualname__ = qualname + ".Rsp" diff --git a/zigpy_znp/types/cstruct.py b/zigpy_znp/types/cstruct.py index 440309b7..d1d67679 100644 --- a/zigpy_znp/types/cstruct.py +++ b/zigpy_znp/types/cstruct.py @@ -25,7 +25,7 @@ def __post_init__(self) -> None: def get_size_and_alignment(self, align=False) -> tuple[int, int]: if issubclass(self.type, (zigpy_t.FixedIntType, t.FixedIntType)): - return self.type._size, self.type._size if align else 1 + return self.type._size, (self.type._size if align else 1) elif issubclass(self.type, zigpy_t.EUI64): return 8, 1 elif issubclass(self.type, zigpy_t.KeyData): @@ -131,7 +131,7 @@ def get_alignment(cls, *, align=False) -> int: def get_size(cls, *, align=False) -> int: total_size = 0 - for padding, size, field in cls.get_padded_fields(align=align): + for padding, size, _field in cls.get_padded_fields(align=align): total_size += padding + size final_padding = (-total_size) % cls.get_alignment(align=align) @@ -198,7 +198,7 @@ def replace(self, **kwargs) -> CStruct: return type(self)(**d) - def __eq__(self, other: CStruct) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(self, type(other)) and not isinstance(other, type(self)): return NotImplemented diff --git a/zigpy_znp/types/named.py b/zigpy_znp/types/named.py index 032729af..004fa0be 100644 --- a/zigpy_znp/types/named.py +++ b/zigpy_znp/types/named.py @@ -64,7 +64,7 @@ def _get_address_type(self): }[self.mode] @classmethod - def deserialize(cls, data: bytes) -> "AddrModeAddress": + def deserialize(cls, data: bytes) -> tuple[AddrModeAddress, bytes]: mode, data = AddrMode.deserialize(data) address, data = EUI64.deserialize(data) diff --git a/zigpy_znp/utils.py b/zigpy_znp/utils.py index 7e4cd469..04d05ef1 100644 --- a/zigpy_znp/utils.py +++ b/zigpy_znp/utils.py @@ -14,7 +14,7 @@ def deduplicate_commands( commands: typing.Iterable[t.CommandBase], -) -> tuple[t.CommandBase]: +) -> tuple[t.CommandBase, ...]: """ Deduplicates an iterable of commands by folding more-specific commands into less- specific commands. Used to avoid triggering callbacks multiple times per packet. @@ -135,7 +135,8 @@ class CallbackResponseListener(BaseResponseListener): def _resolve(self, response: t.CommandBase) -> bool: try: - result = self.callback(response) + # https://github.com/python/mypy/issues/5485 + result = self.callback(response) # type:ignore[misc] # Run coroutines in the background if asyncio.iscoroutine(result): diff --git a/zigpy_znp/zigbee/application.py b/zigpy_znp/zigbee/application.py index 370dd517..8697dabc 100644 --- a/zigpy_znp/zigbee/application.py +++ b/zigpy_znp/zigbee/application.py @@ -806,7 +806,7 @@ def _zstack_build_id(self) -> t.uint32_t: # Old versions of Z-Stack do not include `CodeRevision` in the version response if self._version_rsp.CodeRevision is None: - return 0x00000000 + return t.uint32_t(0x00000000) return self._version_rsp.CodeRevision diff --git a/zigpy_znp/znp/security.py b/zigpy_znp/znp/security.py index 0b01a53b..08b75cc1 100644 --- a/zigpy_znp/znp/security.py +++ b/zigpy_znp/znp/security.py @@ -24,7 +24,7 @@ def replace(self, **kwargs) -> StoredDevice: return dataclasses.replace(self, **kwargs) -def rotate(lst: typing.Sequence, n: int) -> typing.Sequence: +def rotate(lst: list, n: int) -> list: return lst[n:] + lst[:n] @@ -172,9 +172,9 @@ async def read_addr_manager_entries(znp: ZNP) -> typing.Sequence[t.AddrMgrEntry] return entries -async def read_hashed_link_keys( +async def read_hashed_link_keys( # type:ignore[misc] znp: ZNP, tclk_seed: t.KeyData -) -> typing.Iterable[zigpy.state.Key]: +) -> typing.AsyncGenerator[zigpy.state.Key, None]: if znp.version >= 3.30: entries = znp.nvram.read_table( item_id=ExNvIds.TCLK_TABLE, @@ -206,7 +206,7 @@ async def read_hashed_link_keys( async def read_unhashed_link_keys( znp: ZNP, addr_mgr_entries: typing.Sequence[t.AddrMgrEntry] -) -> typing.Iterable[zigpy.state.Key]: +) -> typing.AsyncGenerator[zigpy.state.Key, None]: if znp.version == 3.30: link_key_offset_base = 0x0000 table = znp.nvram.read_table(