Skip to content

Add initial typings #1127

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
[flake8]
select = C90,E,F,W,Y0
ignore = E402,E731,W503,W504,E252
exclude = .git,__pycache__,build,dist,.eggs,.github,.local,.venv,.tox
exclude = .git,__pycache__,build,dist,.eggs,.github,.local,.venv*,.tox
per-file-ignores = *.pyi: F401,F403,F405,F811,E127,E128,E203,E266,E301,E302,E305,E501,E701,E704,E741,B303,W503,W504
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@ jobs:
github_token: ${{ secrets.RELEASE_BOT_GITHUB_TOKEN }}
version_file: asyncpg/_version.py
version_line_pattern: |
__version__\s*=\s*(?:['"])([[:PEP440:]])(?:['"])
__version__(?:\s*:\s*typing\.Final)?\s*=\s*(?:['"])([[:PEP440:]])(?:['"])
- name: Stop if not approved
if: steps.checkver.outputs.approved != 'true'
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -34,3 +34,5 @@ docs/_build
/.eggs
/.vscode
/.mypy_cache
/.venv*
/.tox
2 changes: 1 addition & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
recursive-include docs *.py *.rst Makefile *.css
recursive-include examples *.py
recursive-include tests *.py *.pem
recursive-include asyncpg *.pyx *.pxd *.pxi *.py *.c *.h
recursive-include asyncpg *.pyx *.pxd *.pxi *.py *.pyi *.c *.h
include LICENSE README.rst Makefile performance.png .flake8
7 changes: 6 additions & 1 deletion asyncpg/__init__.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

from __future__ import annotations

from .connection import connect, Connection # NOQA
from .exceptions import * # NOQA
@@ -14,6 +15,10 @@

from ._version import __version__ # NOQA

from . import exceptions

__all__ = ('connect', 'create_pool', 'Pool', 'Record', 'Connection')

__all__: tuple[str, ...] = (
'connect', 'create_pool', 'Pool', 'Record', 'Connection'
)
__all__ += exceptions.__all__ # NOQA
13 changes: 10 additions & 3 deletions asyncpg/_asyncio_compat.py
Original file line number Diff line number Diff line change
@@ -4,18 +4,25 @@
#
# SPDX-License-Identifier: PSF-2.0

from __future__ import annotations

import asyncio
import functools
import sys
import typing

if typing.TYPE_CHECKING:
from . import compat

if sys.version_info < (3, 11):
from async_timeout import timeout as timeout_ctx
else:
from asyncio import timeout as timeout_ctx

_T = typing.TypeVar('_T')


async def wait_for(fut, timeout):
async def wait_for(fut: compat.Awaitable[_T], timeout: float | None) -> _T:
"""Wait for the single Future or coroutine to complete, with timeout.
Coroutine will be wrapped in Task.
@@ -65,7 +72,7 @@ async def wait_for(fut, timeout):
return await fut


async def _cancel_and_wait(fut):
async def _cancel_and_wait(fut: asyncio.Future[_T]) -> None:
"""Cancel the *fut* future or task and wait until it completes."""

loop = asyncio.get_running_loop()
@@ -82,6 +89,6 @@ async def _cancel_and_wait(fut):
fut.remove_done_callback(cb)


def _release_waiter(waiter, *args):
def _release_waiter(waiter: asyncio.Future[typing.Any], *args: object) -> None:
if not waiter.done():
waiter.set_result(None)
6 changes: 5 additions & 1 deletion asyncpg/_version.py
Original file line number Diff line number Diff line change
@@ -10,4 +10,8 @@
# supported platforms, publish the packages on PyPI, merge the PR
# to the target branch, create a Git tag pointing to the commit.

__version__ = '0.30.0.dev0'
from __future__ import annotations

import typing

__version__: typing.Final = '0.30.0.dev0'
24 changes: 18 additions & 6 deletions asyncpg/compat.py
Original file line number Diff line number Diff line change
@@ -4,22 +4,25 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

from __future__ import annotations

import pathlib
import platform
import typing
import sys

if typing.TYPE_CHECKING:
import asyncio

SYSTEM = platform.uname().system
SYSTEM: typing.Final = platform.uname().system


if SYSTEM == 'Windows':
if sys.platform == 'win32':
import ctypes.wintypes

CSIDL_APPDATA = 0x001a
CSIDL_APPDATA: typing.Final = 0x001a

def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
def get_pg_home_directory() -> pathlib.Path | None:
# We cannot simply use expanduser() as that returns the user's
# home directory, whereas Postgres stores its config in
# %AppData% on Windows.
@@ -31,14 +34,14 @@ def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
return pathlib.Path(buf.value) / 'postgresql'

else:
def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
def get_pg_home_directory() -> pathlib.Path | None:
try:
return pathlib.Path.home()
except (RuntimeError, KeyError):
return None


async def wait_closed(stream):
async def wait_closed(stream: asyncio.StreamWriter) -> None:
# Not all asyncio versions have StreamWriter.wait_closed().
if hasattr(stream, 'wait_closed'):
try:
@@ -59,3 +62,12 @@ async def wait_closed(stream):
from ._asyncio_compat import timeout_ctx as timeout # noqa: F401
else:
from asyncio import timeout as timeout # noqa: F401

if sys.version_info < (3, 9):
from typing import ( # noqa: F401
Awaitable as Awaitable,
)
else:
from collections.abc import ( # noqa: F401
Awaitable as Awaitable,
)
22 changes: 14 additions & 8 deletions asyncpg/introspection.py
Original file line number Diff line number Diff line change
@@ -4,8 +4,14 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

from __future__ import annotations

_TYPEINFO_13 = '''\
import typing

if typing.TYPE_CHECKING:
from . import protocol

_TYPEINFO_13: typing.Final = '''\
(
SELECT
t.oid AS oid,
@@ -124,7 +130,7 @@
'''.format(typeinfo=_TYPEINFO_13)


_TYPEINFO = '''\
_TYPEINFO: typing.Final = '''\
(
SELECT
t.oid AS oid,
@@ -248,7 +254,7 @@
'''.format(typeinfo=_TYPEINFO)


TYPE_BY_NAME = '''\
TYPE_BY_NAME: typing.Final = '''\
SELECT
t.oid,
t.typelem AS elemtype,
@@ -277,16 +283,16 @@
SCALAR_TYPE_KINDS = (b'b', b'd', b'e')


def is_scalar_type(typeinfo) -> bool:
def is_scalar_type(typeinfo: protocol.Record) -> bool:
return (
typeinfo['kind'] in SCALAR_TYPE_KINDS and
not typeinfo['elemtype']
)


def is_domain_type(typeinfo) -> bool:
return typeinfo['kind'] == b'd'
def is_domain_type(typeinfo: protocol.Record) -> bool:
return typeinfo['kind'] == b'd' # type: ignore[no-any-return]


def is_composite_type(typeinfo) -> bool:
return typeinfo['kind'] == b'c'
def is_composite_type(typeinfo: protocol.Record) -> bool:
return typeinfo['kind'] == b'c' # type: ignore[no-any-return]
2 changes: 2 additions & 0 deletions asyncpg/protocol/__init__.py
Original file line number Diff line number Diff line change
@@ -6,4 +6,6 @@

# flake8: NOQA

from __future__ import annotations

from .protocol import Protocol, Record, NO_TIMEOUT, BUILTIN_TYPE_NAME_MAP
300 changes: 300 additions & 0 deletions asyncpg/protocol/protocol.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,300 @@
import asyncio
import asyncio.protocols
import hmac
from codecs import CodecInfo
from collections.abc import Callable, Iterable, Iterator, Sequence
from hashlib import md5, sha256
from typing import (
Any,
ClassVar,
Final,
Generic,
Literal,
NewType,
TypeVar,
final,
overload,
)
from typing_extensions import TypeAlias

import asyncpg.pgproto.pgproto

from ..connect_utils import _ConnectionParameters
from ..pgproto.pgproto import WriteBuffer
from ..types import Attribute, Type

_T = TypeVar('_T')
_Record = TypeVar('_Record', bound=Record)
_OtherRecord = TypeVar('_OtherRecord', bound=Record)
_PreparedStatementState = TypeVar(
'_PreparedStatementState', bound=PreparedStatementState[Any]
)

_NoTimeoutType = NewType('_NoTimeoutType', object)
_TimeoutType: TypeAlias = float | None | _NoTimeoutType

BUILTIN_TYPE_NAME_MAP: Final[dict[str, int]]
BUILTIN_TYPE_OID_MAP: Final[dict[int, str]]
NO_TIMEOUT: Final[_NoTimeoutType]

hashlib_md5 = md5

@final
class ConnectionSettings(asyncpg.pgproto.pgproto.CodecContext):
__pyx_vtable__: Any
def __init__(self, conn_key: object) -> None: ...
def add_python_codec(
self,
typeoid: int,
typename: str,
typeschema: str,
typeinfos: Iterable[object],
typekind: str,
encoder: Callable[[Any], Any],
decoder: Callable[[Any], Any],
format: object,
) -> Any: ...
def clear_type_cache(self) -> None: ...
def get_data_codec(
self, oid: int, format: object = ..., ignore_custom_codec: bool = ...
) -> Any: ...
def get_text_codec(self) -> CodecInfo: ...
def register_data_types(self, types: Iterable[object]) -> None: ...
def remove_python_codec(
self, typeoid: int, typename: str, typeschema: str
) -> None: ...
def set_builtin_type_codec(
self,
typeoid: int,
typename: str,
typeschema: str,
typekind: str,
alias_to: str,
format: object = ...,
) -> Any: ...
def __getattr__(self, name: str) -> Any: ...
def __reduce__(self) -> Any: ...

@final
class PreparedStatementState(Generic[_Record]):
closed: bool
prepared: bool
name: str
query: str
refs: int
record_class: type[_Record]
ignore_custom_codec: bool
__pyx_vtable__: Any
def __init__(
self,
name: str,
query: str,
protocol: BaseProtocol[Any],
record_class: type[_Record],
ignore_custom_codec: bool,
) -> None: ...
def _get_parameters(self) -> tuple[Type, ...]: ...
def _get_attributes(self) -> tuple[Attribute, ...]: ...
def _init_types(self) -> set[int]: ...
def _init_codecs(self) -> None: ...
def attach(self) -> None: ...
def detach(self) -> None: ...
def mark_closed(self) -> None: ...
def mark_unprepared(self) -> None: ...
def __reduce__(self) -> Any: ...

class CoreProtocol:
backend_pid: Any
backend_secret: Any
__pyx_vtable__: Any
def __init__(self, addr: object, con_params: _ConnectionParameters) -> None: ...
def is_in_transaction(self) -> bool: ...
def __reduce__(self) -> Any: ...

class BaseProtocol(CoreProtocol, Generic[_Record]):
queries_count: Any
is_ssl: bool
__pyx_vtable__: Any
def __init__(
self,
addr: object,
connected_fut: object,
con_params: _ConnectionParameters,
record_class: type[_Record],
loop: object,
) -> None: ...
def set_connection(self, connection: object) -> None: ...
def get_server_pid(self, *args: object, **kwargs: object) -> int: ...
def get_settings(self, *args: object, **kwargs: object) -> ConnectionSettings: ...
def get_record_class(self) -> type[_Record]: ...
def abort(self) -> None: ...
async def bind(
self,
state: PreparedStatementState[_OtherRecord],
args: Sequence[object],
portal_name: str,
timeout: _TimeoutType,
) -> Any: ...
@overload
async def bind_execute(
self,
state: PreparedStatementState[_OtherRecord],
args: Sequence[object],
portal_name: str,
limit: int,
return_extra: Literal[False],
timeout: _TimeoutType,
) -> list[_OtherRecord]: ...
@overload
async def bind_execute(
self,
state: PreparedStatementState[_OtherRecord],
args: Sequence[object],
portal_name: str,
limit: int,
return_extra: Literal[True],
timeout: _TimeoutType,
) -> tuple[list[_OtherRecord], bytes, bool]: ...
@overload
async def bind_execute(
self,
state: PreparedStatementState[_OtherRecord],
args: Sequence[object],
portal_name: str,
limit: int,
return_extra: bool,
timeout: _TimeoutType,
) -> list[_OtherRecord] | tuple[list[_OtherRecord], bytes, bool]: ...
async def bind_execute_many(
self,
state: PreparedStatementState[_OtherRecord],
args: Iterable[Sequence[object]],
portal_name: str,
timeout: _TimeoutType,
) -> None: ...
async def close(self, timeout: _TimeoutType) -> None: ...
def _get_timeout(self, timeout: _TimeoutType) -> float | None: ...
def _is_cancelling(self) -> bool: ...
async def _wait_for_cancellation(self) -> None: ...
async def close_statement(
self, state: PreparedStatementState[_OtherRecord], timeout: _TimeoutType
) -> Any: ...
async def copy_in(self, *args: object, **kwargs: object) -> str: ...
async def copy_out(self, *args: object, **kwargs: object) -> str: ...
async def execute(self, *args: object, **kwargs: object) -> Any: ...
def is_closed(self, *args: object, **kwargs: object) -> Any: ...
def is_connected(self, *args: object, **kwargs: object) -> Any: ...
def data_received(self, data: object) -> None: ...
def connection_made(self, transport: object) -> None: ...
def connection_lost(self, exc: Exception | None) -> None: ...
def pause_writing(self, *args: object, **kwargs: object) -> Any: ...
@overload
async def prepare(
self,
stmt_name: str,
query: str,
timeout: float | None = ...,
*,
state: _PreparedStatementState,
ignore_custom_codec: bool = ...,
record_class: None,
) -> _PreparedStatementState: ...
@overload
async def prepare(
self,
stmt_name: str,
query: str,
timeout: float | None = ...,
*,
state: None = ...,
ignore_custom_codec: bool = ...,
record_class: type[_OtherRecord],
) -> PreparedStatementState[_OtherRecord]: ...
async def close_portal(self, portal_name: str, timeout: _TimeoutType) -> None: ...
async def query(self, *args: object, **kwargs: object) -> str: ...
def resume_writing(self, *args: object, **kwargs: object) -> Any: ...
def __reduce__(self) -> Any: ...

@final
class Codec:
__pyx_vtable__: Any
def __reduce__(self) -> Any: ...

class DataCodecConfig:
__pyx_vtable__: Any
def __init__(self, cache_key: object) -> None: ...
def add_python_codec(
self,
typeoid: int,
typename: str,
typeschema: str,
typekind: str,
typeinfos: Iterable[object],
encoder: Callable[[ConnectionSettings, WriteBuffer, object], object],
decoder: Callable[..., object],
format: object,
xformat: object,
) -> Any: ...
def add_types(self, types: Iterable[object]) -> Any: ...
def clear_type_cache(self) -> None: ...
def declare_fallback_codec(self, oid: int, name: str, schema: str) -> Codec: ...
def remove_python_codec(
self, typeoid: int, typename: str, typeschema: str
) -> Any: ...
def set_builtin_type_codec(
self,
typeoid: int,
typename: str,
typeschema: str,
typekind: str,
alias_to: str,
format: object = ...,
) -> Any: ...
def __reduce__(self) -> Any: ...

class Protocol(BaseProtocol[_Record], asyncio.protocols.Protocol): ...

class Record:
@overload
def get(self, key: str) -> Any | None: ...
@overload
def get(self, key: str, default: _T) -> Any | _T: ...
def items(self) -> Iterator[tuple[str, Any]]: ...
def keys(self) -> Iterator[str]: ...
def values(self) -> Iterator[Any]: ...
@overload
def __getitem__(self, index: str) -> Any: ...
@overload
def __getitem__(self, index: int) -> Any: ...
@overload
def __getitem__(self, index: slice) -> tuple[Any, ...]: ...
def __iter__(self) -> Iterator[Any]: ...
def __contains__(self, x: object) -> bool: ...
def __len__(self) -> int: ...

class Timer:
def __init__(self, budget: float | None) -> None: ...
def __enter__(self) -> None: ...
def __exit__(self, et: object, e: object, tb: object) -> None: ...
def get_remaining_budget(self) -> float: ...
def has_budget_greater_than(self, amount: float) -> bool: ...

@final
class SCRAMAuthentication:
AUTHENTICATION_METHODS: ClassVar[list[str]]
DEFAULT_CLIENT_NONCE_BYTES: ClassVar[int]
DIGEST = sha256
REQUIREMENTS_CLIENT_FINAL_MESSAGE: ClassVar[list[str]]
REQUIREMENTS_CLIENT_PROOF: ClassVar[list[str]]
SASLPREP_PROHIBITED: ClassVar[tuple[Callable[[str], bool], ...]]
authentication_method: bytes
authorization_message: bytes | None
client_channel_binding: bytes
client_first_message_bare: bytes | None
client_nonce: bytes | None
client_proof: bytes | None
password_salt: bytes | None
password_iterations: int
server_first_message: bytes | None
server_key: hmac.HMAC | None
server_nonce: bytes | None
24 changes: 17 additions & 7 deletions asyncpg/serverversion.py
Original file line number Diff line number Diff line change
@@ -4,12 +4,14 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

from __future__ import annotations

import re
import typing

from .types import ServerVersion

version_regex = re.compile(
version_regex: typing.Final = re.compile(
r"(Postgre[^\s]*)?\s*"
r"(?P<major>[0-9]+)\.?"
r"((?P<minor>[0-9]+)\.?)?"
@@ -19,7 +21,15 @@
)


def split_server_version_string(version_string):
class _VersionDict(typing.TypedDict):
major: int
minor: int | None
micro: int | None
releaselevel: str | None
serial: int | None


def split_server_version_string(version_string: str) -> ServerVersion:
version_match = version_regex.search(version_string)

if version_match is None:
@@ -28,17 +38,17 @@ def split_server_version_string(version_string):
f'version from "{version_string}"'
)

version = version_match.groupdict()
version: _VersionDict = version_match.groupdict() # type: ignore[assignment] # noqa: E501
for ver_key, ver_value in version.items():
# Cast all possible versions parts to int
try:
version[ver_key] = int(ver_value)
version[ver_key] = int(ver_value) # type: ignore[literal-required, call-overload] # noqa: E501
except (TypeError, ValueError):
pass

if version.get("major") < 10:
if version["major"] < 10:
return ServerVersion(
version.get("major"),
version["major"],
version.get("minor") or 0,
version.get("micro") or 0,
version.get("releaselevel") or "final",
@@ -52,7 +62,7 @@ def split_server_version_string(version_string):
# want to keep that behaviour consistent, i.e not fail
# a major version check due to a bugfix release.
return ServerVersion(
version.get("major"),
version["major"],
0,
version.get("minor") or 0,
version.get("releaselevel") or "final",
102 changes: 74 additions & 28 deletions asyncpg/types.py
Original file line number Diff line number Diff line change
@@ -4,22 +4,32 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

from __future__ import annotations

import collections
import typing

from asyncpg.pgproto.types import (
BitString, Point, Path, Polygon,
Box, Line, LineSegment, Circle,
)

if typing.TYPE_CHECKING:
from typing_extensions import Self


__all__ = (
'Type', 'Attribute', 'Range', 'BitString', 'Point', 'Path', 'Polygon',
'Box', 'Line', 'LineSegment', 'Circle', 'ServerVersion',
)


Type = collections.namedtuple('Type', ['oid', 'name', 'kind', 'schema'])
class Type(typing.NamedTuple):
oid: int
name: str
kind: str
schema: str


Type.__doc__ = 'Database data type.'
Type.oid.__doc__ = 'OID of the type.'
Type.name.__doc__ = 'Type name. For example "int2".'
@@ -28,25 +38,61 @@
Type.schema.__doc__ = 'Name of the database schema that defines the type.'


Attribute = collections.namedtuple('Attribute', ['name', 'type'])
class Attribute(typing.NamedTuple):
name: str
type: Type


Attribute.__doc__ = 'Database relation attribute.'
Attribute.name.__doc__ = 'Attribute name.'
Attribute.type.__doc__ = 'Attribute data type :class:`asyncpg.types.Type`.'


ServerVersion = collections.namedtuple(
'ServerVersion', ['major', 'minor', 'micro', 'releaselevel', 'serial'])
class ServerVersion(typing.NamedTuple):
major: int
minor: int
micro: int
releaselevel: str
serial: int


ServerVersion.__doc__ = 'PostgreSQL server version tuple.'


class Range:
"""Immutable representation of PostgreSQL `range` type."""
class _RangeValue(typing.Protocol):
def __eq__(self, __value: object) -> bool:
...

def __lt__(self, __other: _RangeValue) -> bool:
...

def __gt__(self, __other: _RangeValue) -> bool:
...


__slots__ = '_lower', '_upper', '_lower_inc', '_upper_inc', '_empty'
_RV = typing.TypeVar('_RV', bound=_RangeValue)


class Range(typing.Generic[_RV]):
"""Immutable representation of PostgreSQL `range` type."""

def __init__(self, lower=None, upper=None, *,
lower_inc=True, upper_inc=False,
empty=False):
__slots__ = ('_lower', '_upper', '_lower_inc', '_upper_inc', '_empty')

_lower: _RV | None
_upper: _RV | None
_lower_inc: bool
_upper_inc: bool
_empty: bool

def __init__(
self,
lower: _RV | None = None,
upper: _RV | None = None,
*,
lower_inc: bool = True,
upper_inc: bool = False,
empty: bool = False
) -> None:
self._empty = empty
if empty:
self._lower = self._upper = None
@@ -58,34 +104,34 @@ def __init__(self, lower=None, upper=None, *,
self._upper_inc = upper is not None and upper_inc

@property
def lower(self):
def lower(self) -> _RV | None:
return self._lower

@property
def lower_inc(self):
def lower_inc(self) -> bool:
return self._lower_inc

@property
def lower_inf(self):
def lower_inf(self) -> bool:
return self._lower is None and not self._empty

@property
def upper(self):
def upper(self) -> _RV | None:
return self._upper

@property
def upper_inc(self):
def upper_inc(self) -> bool:
return self._upper_inc

@property
def upper_inf(self):
def upper_inf(self) -> bool:
return self._upper is None and not self._empty

@property
def isempty(self):
def isempty(self) -> bool:
return self._empty

def _issubset_lower(self, other):
def _issubset_lower(self, other: Self) -> bool:
if other._lower is None:
return True
if self._lower is None:
@@ -96,7 +142,7 @@ def _issubset_lower(self, other):
and (other._lower_inc or not self._lower_inc)
)

def _issubset_upper(self, other):
def _issubset_upper(self, other: Self) -> bool:
if other._upper is None:
return True
if self._upper is None:
@@ -107,21 +153,21 @@ def _issubset_upper(self, other):
and (other._upper_inc or not self._upper_inc)
)

def issubset(self, other):
def issubset(self, other: Self) -> bool:
if self._empty:
return True
if other._empty:
return False

return self._issubset_lower(other) and self._issubset_upper(other)

def issuperset(self, other):
def issuperset(self, other: Self) -> bool:
return other.issubset(self)

def __bool__(self):
def __bool__(self) -> bool:
return not self._empty

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, Range):
return NotImplemented

@@ -132,14 +178,14 @@ def __eq__(self, other):
self._upper_inc,
self._empty
) == (
other._lower,
other._upper,
other._lower, # pyright: ignore [reportUnknownMemberType]
other._upper, # pyright: ignore [reportUnknownMemberType]
other._lower_inc,
other._upper_inc,
other._empty
)

def __hash__(self):
def __hash__(self) -> int:
return hash((
self._lower,
self._upper,
@@ -148,7 +194,7 @@ def __hash__(self):
self._empty
))

def __repr__(self):
def __repr__(self) -> str:
if self._empty:
desc = 'empty'
else:
27 changes: 26 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -28,7 +28,7 @@ classifiers = [
"Topic :: Database :: Front-Ends",
]
dependencies = [
'async_timeout>=4.0.3; python_version < "3.12.0"'
'async_timeout>=4.0.3; python_version < "3.12.0"',
]

[project.urls]
@@ -40,9 +40,11 @@ gssapi = [
]
test = [
'flake8~=6.1',
'flake8-pyi~=24.1.0',
'uvloop>=0.15.3; platform_system != "Windows" and python_version < "3.12.0"',
'gssapi; platform_system == "Linux"',
'k5test; platform_system == "Linux"',
'mypy~=1.8.0',
]
docs = [
'Sphinx~=5.3.0',
@@ -107,3 +109,26 @@ exclude_lines = [
"if __name__ == .__main__.",
]
show_missing = true

[tool.mypy]
incremental = true
strict = true
implicit_reexport = true

[[tool.mypy.overrides]]
module = [
"asyncpg._testbase",
"asyncpg._testbase.*",
"asyncpg.cluster",
"asyncpg.connect_utils",
"asyncpg.connection",
"asyncpg.connresource",
"asyncpg.cursor",
"asyncpg.exceptions",
"asyncpg.exceptions.*",
"asyncpg.pool",
"asyncpg.prepared_stmt",
"asyncpg.transaction",
"asyncpg.utils",
]
ignore_errors = true
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -43,7 +43,7 @@

with open(str(_ROOT / 'asyncpg' / '_version.py')) as f:
for line in f:
if line.startswith('__version__ ='):
if line.startswith('__version__: typing.Final ='):
_, _, version = line.partition('=')
VERSION = version.strip(" \n'\"")
break
33 changes: 32 additions & 1 deletion tests/test__sourcecode.py
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@ def find_root():
return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


class TestFlake8(unittest.TestCase):
class TestCodeQuality(unittest.TestCase):

def test_flake8(self):
try:
@@ -38,3 +38,34 @@ def test_flake8(self):
output = ex.output.decode()
raise AssertionError(
'flake8 validation failed:\n{}'.format(output)) from None

def test_mypy(self):
try:
import mypy # NoQA
except ImportError:
raise unittest.SkipTest('mypy module is missing')

root_path = find_root()
config_path = os.path.join(root_path, 'pyproject.toml')
if not os.path.exists(config_path):
raise RuntimeError('could not locate mypy.ini file')

try:
subprocess.run(
[
sys.executable,
'-m',
'mypy',
'--config-file',
config_path,
'asyncpg'
],
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
cwd=root_path
)
except subprocess.CalledProcessError as ex:
output = ex.output.decode()
raise AssertionError(
'mypy validation failed:\n{}'.format(output)) from None