Skip to content

Commit d42432b

Browse files
authoredMar 5, 2024
Add initial typings (#1127)
* Added typings to miscellaneous files * Added unit test to check codebase with mypy * Updated release workflow and build to account for annotations * Updated manifest to include stub files
1 parent 1d4e568 commit d42432b

File tree

16 files changed

+512
-60
lines changed

16 files changed

+512
-60
lines changed
 

‎.flake8

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
[flake8]
2+
select = C90,E,F,W,Y0
23
ignore = E402,E731,W503,W504,E252
3-
exclude = .git,__pycache__,build,dist,.eggs,.github,.local,.venv,.tox
4+
exclude = .git,__pycache__,build,dist,.eggs,.github,.local,.venv*,.tox
5+
per-file-ignores = *.pyi: F401,F403,F405,F811,E127,E128,E203,E266,E301,E302,E305,E501,E701,E704,E741,B303,W503,W504

‎.github/workflows/release.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
github_token: ${{ secrets.RELEASE_BOT_GITHUB_TOKEN }}
2323
version_file: asyncpg/_version.py
2424
version_line_pattern: |
25-
__version__\s*=\s*(?:['"])([[:PEP440:]])(?:['"])
25+
__version__(?:\s*:\s*typing\.Final)?\s*=\s*(?:['"])([[:PEP440:]])(?:['"])
2626
2727
- name: Stop if not approved
2828
if: steps.checkver.outputs.approved != 'true'

‎.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,5 @@ docs/_build
3434
/.eggs
3535
/.vscode
3636
/.mypy_cache
37+
/.venv*
38+
/.tox

‎MANIFEST.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
recursive-include docs *.py *.rst Makefile *.css
22
recursive-include examples *.py
33
recursive-include tests *.py *.pem
4-
recursive-include asyncpg *.pyx *.pxd *.pxi *.py *.c *.h
4+
recursive-include asyncpg *.pyx *.pxd *.pxi *.py *.pyi *.c *.h
55
include LICENSE README.rst Makefile performance.png .flake8

‎asyncpg/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This module is part of asyncpg and is released under
55
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
66

7+
from __future__ import annotations
78

89
from .connection import connect, Connection # NOQA
910
from .exceptions import * # NOQA
@@ -14,6 +15,10 @@
1415

1516
from ._version import __version__ # NOQA
1617

18+
from . import exceptions
1719

18-
__all__ = ('connect', 'create_pool', 'Pool', 'Record', 'Connection')
20+
21+
__all__: tuple[str, ...] = (
22+
'connect', 'create_pool', 'Pool', 'Record', 'Connection'
23+
)
1924
__all__ += exceptions.__all__ # NOQA

‎asyncpg/_asyncio_compat.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,25 @@
44
#
55
# SPDX-License-Identifier: PSF-2.0
66

7+
from __future__ import annotations
78

89
import asyncio
910
import functools
1011
import sys
12+
import typing
13+
14+
if typing.TYPE_CHECKING:
15+
from . import compat
1116

1217
if sys.version_info < (3, 11):
1318
from async_timeout import timeout as timeout_ctx
1419
else:
1520
from asyncio import timeout as timeout_ctx
1621

22+
_T = typing.TypeVar('_T')
23+
1724

18-
async def wait_for(fut, timeout):
25+
async def wait_for(fut: compat.Awaitable[_T], timeout: float | None) -> _T:
1926
"""Wait for the single Future or coroutine to complete, with timeout.
2027
2128
Coroutine will be wrapped in Task.
@@ -65,7 +72,7 @@ async def wait_for(fut, timeout):
6572
return await fut
6673

6774

68-
async def _cancel_and_wait(fut):
75+
async def _cancel_and_wait(fut: asyncio.Future[_T]) -> None:
6976
"""Cancel the *fut* future or task and wait until it completes."""
7077

7178
loop = asyncio.get_running_loop()
@@ -82,6 +89,6 @@ async def _cancel_and_wait(fut):
8289
fut.remove_done_callback(cb)
8390

8491

85-
def _release_waiter(waiter, *args):
92+
def _release_waiter(waiter: asyncio.Future[typing.Any], *args: object) -> None:
8693
if not waiter.done():
8794
waiter.set_result(None)

‎asyncpg/_version.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,8 @@
1010
# supported platforms, publish the packages on PyPI, merge the PR
1111
# to the target branch, create a Git tag pointing to the commit.
1212

13-
__version__ = '0.30.0.dev0'
13+
from __future__ import annotations
14+
15+
import typing
16+
17+
__version__: typing.Final = '0.30.0.dev0'

‎asyncpg/compat.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,25 @@
44
# This module is part of asyncpg and is released under
55
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
66

7+
from __future__ import annotations
78

89
import pathlib
910
import platform
1011
import typing
1112
import sys
1213

14+
if typing.TYPE_CHECKING:
15+
import asyncio
1316

14-
SYSTEM = platform.uname().system
17+
SYSTEM: typing.Final = platform.uname().system
1518

1619

17-
if SYSTEM == 'Windows':
20+
if sys.platform == 'win32':
1821
import ctypes.wintypes
1922

20-
CSIDL_APPDATA = 0x001a
23+
CSIDL_APPDATA: typing.Final = 0x001a
2124

22-
def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
25+
def get_pg_home_directory() -> pathlib.Path | None:
2326
# We cannot simply use expanduser() as that returns the user's
2427
# home directory, whereas Postgres stores its config in
2528
# %AppData% on Windows.
@@ -31,14 +34,14 @@ def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
3134
return pathlib.Path(buf.value) / 'postgresql'
3235

3336
else:
34-
def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
37+
def get_pg_home_directory() -> pathlib.Path | None:
3538
try:
3639
return pathlib.Path.home()
3740
except (RuntimeError, KeyError):
3841
return None
3942

4043

41-
async def wait_closed(stream):
44+
async def wait_closed(stream: asyncio.StreamWriter) -> None:
4245
# Not all asyncio versions have StreamWriter.wait_closed().
4346
if hasattr(stream, 'wait_closed'):
4447
try:
@@ -59,3 +62,12 @@ async def wait_closed(stream):
5962
from ._asyncio_compat import timeout_ctx as timeout # noqa: F401
6063
else:
6164
from asyncio import timeout as timeout # noqa: F401
65+
66+
if sys.version_info < (3, 9):
67+
from typing import ( # noqa: F401
68+
Awaitable as Awaitable,
69+
)
70+
else:
71+
from collections.abc import ( # noqa: F401
72+
Awaitable as Awaitable,
73+
)

‎asyncpg/introspection.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,14 @@
44
# This module is part of asyncpg and is released under
55
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
66

7+
from __future__ import annotations
78

8-
_TYPEINFO_13 = '''\
9+
import typing
10+
11+
if typing.TYPE_CHECKING:
12+
from . import protocol
13+
14+
_TYPEINFO_13: typing.Final = '''\
915
(
1016
SELECT
1117
t.oid AS oid,
@@ -124,7 +130,7 @@
124130
'''.format(typeinfo=_TYPEINFO_13)
125131

126132

127-
_TYPEINFO = '''\
133+
_TYPEINFO: typing.Final = '''\
128134
(
129135
SELECT
130136
t.oid AS oid,
@@ -248,7 +254,7 @@
248254
'''.format(typeinfo=_TYPEINFO)
249255

250256

251-
TYPE_BY_NAME = '''\
257+
TYPE_BY_NAME: typing.Final = '''\
252258
SELECT
253259
t.oid,
254260
t.typelem AS elemtype,
@@ -277,16 +283,16 @@
277283
SCALAR_TYPE_KINDS = (b'b', b'd', b'e')
278284

279285

280-
def is_scalar_type(typeinfo) -> bool:
286+
def is_scalar_type(typeinfo: protocol.Record) -> bool:
281287
return (
282288
typeinfo['kind'] in SCALAR_TYPE_KINDS and
283289
not typeinfo['elemtype']
284290
)
285291

286292

287-
def is_domain_type(typeinfo) -> bool:
288-
return typeinfo['kind'] == b'd'
293+
def is_domain_type(typeinfo: protocol.Record) -> bool:
294+
return typeinfo['kind'] == b'd' # type: ignore[no-any-return]
289295

290296

291-
def is_composite_type(typeinfo) -> bool:
292-
return typeinfo['kind'] == b'c'
297+
def is_composite_type(typeinfo: protocol.Record) -> bool:
298+
return typeinfo['kind'] == b'c' # type: ignore[no-any-return]

‎asyncpg/protocol/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,6 @@
66

77
# flake8: NOQA
88

9+
from __future__ import annotations
10+
911
from .protocol import Protocol, Record, NO_TIMEOUT, BUILTIN_TYPE_NAME_MAP

‎asyncpg/protocol/protocol.pyi

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
import asyncio
2+
import asyncio.protocols
3+
import hmac
4+
from codecs import CodecInfo
5+
from collections.abc import Callable, Iterable, Iterator, Sequence
6+
from hashlib import md5, sha256
7+
from typing import (
8+
Any,
9+
ClassVar,
10+
Final,
11+
Generic,
12+
Literal,
13+
NewType,
14+
TypeVar,
15+
final,
16+
overload,
17+
)
18+
from typing_extensions import TypeAlias
19+
20+
import asyncpg.pgproto.pgproto
21+
22+
from ..connect_utils import _ConnectionParameters
23+
from ..pgproto.pgproto import WriteBuffer
24+
from ..types import Attribute, Type
25+
26+
_T = TypeVar('_T')
27+
_Record = TypeVar('_Record', bound=Record)
28+
_OtherRecord = TypeVar('_OtherRecord', bound=Record)
29+
_PreparedStatementState = TypeVar(
30+
'_PreparedStatementState', bound=PreparedStatementState[Any]
31+
)
32+
33+
_NoTimeoutType = NewType('_NoTimeoutType', object)
34+
_TimeoutType: TypeAlias = float | None | _NoTimeoutType
35+
36+
BUILTIN_TYPE_NAME_MAP: Final[dict[str, int]]
37+
BUILTIN_TYPE_OID_MAP: Final[dict[int, str]]
38+
NO_TIMEOUT: Final[_NoTimeoutType]
39+
40+
hashlib_md5 = md5
41+
42+
@final
43+
class ConnectionSettings(asyncpg.pgproto.pgproto.CodecContext):
44+
__pyx_vtable__: Any
45+
def __init__(self, conn_key: object) -> None: ...
46+
def add_python_codec(
47+
self,
48+
typeoid: int,
49+
typename: str,
50+
typeschema: str,
51+
typeinfos: Iterable[object],
52+
typekind: str,
53+
encoder: Callable[[Any], Any],
54+
decoder: Callable[[Any], Any],
55+
format: object,
56+
) -> Any: ...
57+
def clear_type_cache(self) -> None: ...
58+
def get_data_codec(
59+
self, oid: int, format: object = ..., ignore_custom_codec: bool = ...
60+
) -> Any: ...
61+
def get_text_codec(self) -> CodecInfo: ...
62+
def register_data_types(self, types: Iterable[object]) -> None: ...
63+
def remove_python_codec(
64+
self, typeoid: int, typename: str, typeschema: str
65+
) -> None: ...
66+
def set_builtin_type_codec(
67+
self,
68+
typeoid: int,
69+
typename: str,
70+
typeschema: str,
71+
typekind: str,
72+
alias_to: str,
73+
format: object = ...,
74+
) -> Any: ...
75+
def __getattr__(self, name: str) -> Any: ...
76+
def __reduce__(self) -> Any: ...
77+
78+
@final
79+
class PreparedStatementState(Generic[_Record]):
80+
closed: bool
81+
prepared: bool
82+
name: str
83+
query: str
84+
refs: int
85+
record_class: type[_Record]
86+
ignore_custom_codec: bool
87+
__pyx_vtable__: Any
88+
def __init__(
89+
self,
90+
name: str,
91+
query: str,
92+
protocol: BaseProtocol[Any],
93+
record_class: type[_Record],
94+
ignore_custom_codec: bool,
95+
) -> None: ...
96+
def _get_parameters(self) -> tuple[Type, ...]: ...
97+
def _get_attributes(self) -> tuple[Attribute, ...]: ...
98+
def _init_types(self) -> set[int]: ...
99+
def _init_codecs(self) -> None: ...
100+
def attach(self) -> None: ...
101+
def detach(self) -> None: ...
102+
def mark_closed(self) -> None: ...
103+
def mark_unprepared(self) -> None: ...
104+
def __reduce__(self) -> Any: ...
105+
106+
class CoreProtocol:
107+
backend_pid: Any
108+
backend_secret: Any
109+
__pyx_vtable__: Any
110+
def __init__(self, addr: object, con_params: _ConnectionParameters) -> None: ...
111+
def is_in_transaction(self) -> bool: ...
112+
def __reduce__(self) -> Any: ...
113+
114+
class BaseProtocol(CoreProtocol, Generic[_Record]):
115+
queries_count: Any
116+
is_ssl: bool
117+
__pyx_vtable__: Any
118+
def __init__(
119+
self,
120+
addr: object,
121+
connected_fut: object,
122+
con_params: _ConnectionParameters,
123+
record_class: type[_Record],
124+
loop: object,
125+
) -> None: ...
126+
def set_connection(self, connection: object) -> None: ...
127+
def get_server_pid(self, *args: object, **kwargs: object) -> int: ...
128+
def get_settings(self, *args: object, **kwargs: object) -> ConnectionSettings: ...
129+
def get_record_class(self) -> type[_Record]: ...
130+
def abort(self) -> None: ...
131+
async def bind(
132+
self,
133+
state: PreparedStatementState[_OtherRecord],
134+
args: Sequence[object],
135+
portal_name: str,
136+
timeout: _TimeoutType,
137+
) -> Any: ...
138+
@overload
139+
async def bind_execute(
140+
self,
141+
state: PreparedStatementState[_OtherRecord],
142+
args: Sequence[object],
143+
portal_name: str,
144+
limit: int,
145+
return_extra: Literal[False],
146+
timeout: _TimeoutType,
147+
) -> list[_OtherRecord]: ...
148+
@overload
149+
async def bind_execute(
150+
self,
151+
state: PreparedStatementState[_OtherRecord],
152+
args: Sequence[object],
153+
portal_name: str,
154+
limit: int,
155+
return_extra: Literal[True],
156+
timeout: _TimeoutType,
157+
) -> tuple[list[_OtherRecord], bytes, bool]: ...
158+
@overload
159+
async def bind_execute(
160+
self,
161+
state: PreparedStatementState[_OtherRecord],
162+
args: Sequence[object],
163+
portal_name: str,
164+
limit: int,
165+
return_extra: bool,
166+
timeout: _TimeoutType,
167+
) -> list[_OtherRecord] | tuple[list[_OtherRecord], bytes, bool]: ...
168+
async def bind_execute_many(
169+
self,
170+
state: PreparedStatementState[_OtherRecord],
171+
args: Iterable[Sequence[object]],
172+
portal_name: str,
173+
timeout: _TimeoutType,
174+
) -> None: ...
175+
async def close(self, timeout: _TimeoutType) -> None: ...
176+
def _get_timeout(self, timeout: _TimeoutType) -> float | None: ...
177+
def _is_cancelling(self) -> bool: ...
178+
async def _wait_for_cancellation(self) -> None: ...
179+
async def close_statement(
180+
self, state: PreparedStatementState[_OtherRecord], timeout: _TimeoutType
181+
) -> Any: ...
182+
async def copy_in(self, *args: object, **kwargs: object) -> str: ...
183+
async def copy_out(self, *args: object, **kwargs: object) -> str: ...
184+
async def execute(self, *args: object, **kwargs: object) -> Any: ...
185+
def is_closed(self, *args: object, **kwargs: object) -> Any: ...
186+
def is_connected(self, *args: object, **kwargs: object) -> Any: ...
187+
def data_received(self, data: object) -> None: ...
188+
def connection_made(self, transport: object) -> None: ...
189+
def connection_lost(self, exc: Exception | None) -> None: ...
190+
def pause_writing(self, *args: object, **kwargs: object) -> Any: ...
191+
@overload
192+
async def prepare(
193+
self,
194+
stmt_name: str,
195+
query: str,
196+
timeout: float | None = ...,
197+
*,
198+
state: _PreparedStatementState,
199+
ignore_custom_codec: bool = ...,
200+
record_class: None,
201+
) -> _PreparedStatementState: ...
202+
@overload
203+
async def prepare(
204+
self,
205+
stmt_name: str,
206+
query: str,
207+
timeout: float | None = ...,
208+
*,
209+
state: None = ...,
210+
ignore_custom_codec: bool = ...,
211+
record_class: type[_OtherRecord],
212+
) -> PreparedStatementState[_OtherRecord]: ...
213+
async def close_portal(self, portal_name: str, timeout: _TimeoutType) -> None: ...
214+
async def query(self, *args: object, **kwargs: object) -> str: ...
215+
def resume_writing(self, *args: object, **kwargs: object) -> Any: ...
216+
def __reduce__(self) -> Any: ...
217+
218+
@final
219+
class Codec:
220+
__pyx_vtable__: Any
221+
def __reduce__(self) -> Any: ...
222+
223+
class DataCodecConfig:
224+
__pyx_vtable__: Any
225+
def __init__(self, cache_key: object) -> None: ...
226+
def add_python_codec(
227+
self,
228+
typeoid: int,
229+
typename: str,
230+
typeschema: str,
231+
typekind: str,
232+
typeinfos: Iterable[object],
233+
encoder: Callable[[ConnectionSettings, WriteBuffer, object], object],
234+
decoder: Callable[..., object],
235+
format: object,
236+
xformat: object,
237+
) -> Any: ...
238+
def add_types(self, types: Iterable[object]) -> Any: ...
239+
def clear_type_cache(self) -> None: ...
240+
def declare_fallback_codec(self, oid: int, name: str, schema: str) -> Codec: ...
241+
def remove_python_codec(
242+
self, typeoid: int, typename: str, typeschema: str
243+
) -> Any: ...
244+
def set_builtin_type_codec(
245+
self,
246+
typeoid: int,
247+
typename: str,
248+
typeschema: str,
249+
typekind: str,
250+
alias_to: str,
251+
format: object = ...,
252+
) -> Any: ...
253+
def __reduce__(self) -> Any: ...
254+
255+
class Protocol(BaseProtocol[_Record], asyncio.protocols.Protocol): ...
256+
257+
class Record:
258+
@overload
259+
def get(self, key: str) -> Any | None: ...
260+
@overload
261+
def get(self, key: str, default: _T) -> Any | _T: ...
262+
def items(self) -> Iterator[tuple[str, Any]]: ...
263+
def keys(self) -> Iterator[str]: ...
264+
def values(self) -> Iterator[Any]: ...
265+
@overload
266+
def __getitem__(self, index: str) -> Any: ...
267+
@overload
268+
def __getitem__(self, index: int) -> Any: ...
269+
@overload
270+
def __getitem__(self, index: slice) -> tuple[Any, ...]: ...
271+
def __iter__(self) -> Iterator[Any]: ...
272+
def __contains__(self, x: object) -> bool: ...
273+
def __len__(self) -> int: ...
274+
275+
class Timer:
276+
def __init__(self, budget: float | None) -> None: ...
277+
def __enter__(self) -> None: ...
278+
def __exit__(self, et: object, e: object, tb: object) -> None: ...
279+
def get_remaining_budget(self) -> float: ...
280+
def has_budget_greater_than(self, amount: float) -> bool: ...
281+
282+
@final
283+
class SCRAMAuthentication:
284+
AUTHENTICATION_METHODS: ClassVar[list[str]]
285+
DEFAULT_CLIENT_NONCE_BYTES: ClassVar[int]
286+
DIGEST = sha256
287+
REQUIREMENTS_CLIENT_FINAL_MESSAGE: ClassVar[list[str]]
288+
REQUIREMENTS_CLIENT_PROOF: ClassVar[list[str]]
289+
SASLPREP_PROHIBITED: ClassVar[tuple[Callable[[str], bool], ...]]
290+
authentication_method: bytes
291+
authorization_message: bytes | None
292+
client_channel_binding: bytes
293+
client_first_message_bare: bytes | None
294+
client_nonce: bytes | None
295+
client_proof: bytes | None
296+
password_salt: bytes | None
297+
password_iterations: int
298+
server_first_message: bytes | None
299+
server_key: hmac.HMAC | None
300+
server_nonce: bytes | None

‎asyncpg/serverversion.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
# This module is part of asyncpg and is released under
55
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
66

7+
from __future__ import annotations
78

89
import re
10+
import typing
911

1012
from .types import ServerVersion
1113

12-
version_regex = re.compile(
14+
version_regex: typing.Final = re.compile(
1315
r"(Postgre[^\s]*)?\s*"
1416
r"(?P<major>[0-9]+)\.?"
1517
r"((?P<minor>[0-9]+)\.?)?"
@@ -19,7 +21,15 @@
1921
)
2022

2123

22-
def split_server_version_string(version_string):
24+
class _VersionDict(typing.TypedDict):
25+
major: int
26+
minor: int | None
27+
micro: int | None
28+
releaselevel: str | None
29+
serial: int | None
30+
31+
32+
def split_server_version_string(version_string: str) -> ServerVersion:
2333
version_match = version_regex.search(version_string)
2434

2535
if version_match is None:
@@ -28,17 +38,17 @@ def split_server_version_string(version_string):
2838
f'version from "{version_string}"'
2939
)
3040

31-
version = version_match.groupdict()
41+
version: _VersionDict = version_match.groupdict() # type: ignore[assignment] # noqa: E501
3242
for ver_key, ver_value in version.items():
3343
# Cast all possible versions parts to int
3444
try:
35-
version[ver_key] = int(ver_value)
45+
version[ver_key] = int(ver_value) # type: ignore[literal-required, call-overload] # noqa: E501
3646
except (TypeError, ValueError):
3747
pass
3848

39-
if version.get("major") < 10:
49+
if version["major"] < 10:
4050
return ServerVersion(
41-
version.get("major"),
51+
version["major"],
4252
version.get("minor") or 0,
4353
version.get("micro") or 0,
4454
version.get("releaselevel") or "final",
@@ -52,7 +62,7 @@ def split_server_version_string(version_string):
5262
# want to keep that behaviour consistent, i.e not fail
5363
# a major version check due to a bugfix release.
5464
return ServerVersion(
55-
version.get("major"),
65+
version["major"],
5666
0,
5767
version.get("minor") or 0,
5868
version.get("releaselevel") or "final",

‎asyncpg/types.py

Lines changed: 74 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,32 @@
44
# This module is part of asyncpg and is released under
55
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
66

7+
from __future__ import annotations
78

8-
import collections
9+
import typing
910

1011
from asyncpg.pgproto.types import (
1112
BitString, Point, Path, Polygon,
1213
Box, Line, LineSegment, Circle,
1314
)
1415

16+
if typing.TYPE_CHECKING:
17+
from typing_extensions import Self
18+
1519

1620
__all__ = (
1721
'Type', 'Attribute', 'Range', 'BitString', 'Point', 'Path', 'Polygon',
1822
'Box', 'Line', 'LineSegment', 'Circle', 'ServerVersion',
1923
)
2024

2125

22-
Type = collections.namedtuple('Type', ['oid', 'name', 'kind', 'schema'])
26+
class Type(typing.NamedTuple):
27+
oid: int
28+
name: str
29+
kind: str
30+
schema: str
31+
32+
2333
Type.__doc__ = 'Database data type.'
2434
Type.oid.__doc__ = 'OID of the type.'
2535
Type.name.__doc__ = 'Type name. For example "int2".'
@@ -28,25 +38,61 @@
2838
Type.schema.__doc__ = 'Name of the database schema that defines the type.'
2939

3040

31-
Attribute = collections.namedtuple('Attribute', ['name', 'type'])
41+
class Attribute(typing.NamedTuple):
42+
name: str
43+
type: Type
44+
45+
3246
Attribute.__doc__ = 'Database relation attribute.'
3347
Attribute.name.__doc__ = 'Attribute name.'
3448
Attribute.type.__doc__ = 'Attribute data type :class:`asyncpg.types.Type`.'
3549

3650

37-
ServerVersion = collections.namedtuple(
38-
'ServerVersion', ['major', 'minor', 'micro', 'releaselevel', 'serial'])
51+
class ServerVersion(typing.NamedTuple):
52+
major: int
53+
minor: int
54+
micro: int
55+
releaselevel: str
56+
serial: int
57+
58+
3959
ServerVersion.__doc__ = 'PostgreSQL server version tuple.'
4060

4161

42-
class Range:
43-
"""Immutable representation of PostgreSQL `range` type."""
62+
class _RangeValue(typing.Protocol):
63+
def __eq__(self, __value: object) -> bool:
64+
...
65+
66+
def __lt__(self, __other: _RangeValue) -> bool:
67+
...
68+
69+
def __gt__(self, __other: _RangeValue) -> bool:
70+
...
71+
4472

45-
__slots__ = '_lower', '_upper', '_lower_inc', '_upper_inc', '_empty'
73+
_RV = typing.TypeVar('_RV', bound=_RangeValue)
74+
75+
76+
class Range(typing.Generic[_RV]):
77+
"""Immutable representation of PostgreSQL `range` type."""
4678

47-
def __init__(self, lower=None, upper=None, *,
48-
lower_inc=True, upper_inc=False,
49-
empty=False):
79+
__slots__ = ('_lower', '_upper', '_lower_inc', '_upper_inc', '_empty')
80+
81+
_lower: _RV | None
82+
_upper: _RV | None
83+
_lower_inc: bool
84+
_upper_inc: bool
85+
_empty: bool
86+
87+
def __init__(
88+
self,
89+
lower: _RV | None = None,
90+
upper: _RV | None = None,
91+
*,
92+
lower_inc: bool = True,
93+
upper_inc: bool = False,
94+
empty: bool = False
95+
) -> None:
5096
self._empty = empty
5197
if empty:
5298
self._lower = self._upper = None
@@ -58,34 +104,34 @@ def __init__(self, lower=None, upper=None, *,
58104
self._upper_inc = upper is not None and upper_inc
59105

60106
@property
61-
def lower(self):
107+
def lower(self) -> _RV | None:
62108
return self._lower
63109

64110
@property
65-
def lower_inc(self):
111+
def lower_inc(self) -> bool:
66112
return self._lower_inc
67113

68114
@property
69-
def lower_inf(self):
115+
def lower_inf(self) -> bool:
70116
return self._lower is None and not self._empty
71117

72118
@property
73-
def upper(self):
119+
def upper(self) -> _RV | None:
74120
return self._upper
75121

76122
@property
77-
def upper_inc(self):
123+
def upper_inc(self) -> bool:
78124
return self._upper_inc
79125

80126
@property
81-
def upper_inf(self):
127+
def upper_inf(self) -> bool:
82128
return self._upper is None and not self._empty
83129

84130
@property
85-
def isempty(self):
131+
def isempty(self) -> bool:
86132
return self._empty
87133

88-
def _issubset_lower(self, other):
134+
def _issubset_lower(self, other: Self) -> bool:
89135
if other._lower is None:
90136
return True
91137
if self._lower is None:
@@ -96,7 +142,7 @@ def _issubset_lower(self, other):
96142
and (other._lower_inc or not self._lower_inc)
97143
)
98144

99-
def _issubset_upper(self, other):
145+
def _issubset_upper(self, other: Self) -> bool:
100146
if other._upper is None:
101147
return True
102148
if self._upper is None:
@@ -107,21 +153,21 @@ def _issubset_upper(self, other):
107153
and (other._upper_inc or not self._upper_inc)
108154
)
109155

110-
def issubset(self, other):
156+
def issubset(self, other: Self) -> bool:
111157
if self._empty:
112158
return True
113159
if other._empty:
114160
return False
115161

116162
return self._issubset_lower(other) and self._issubset_upper(other)
117163

118-
def issuperset(self, other):
164+
def issuperset(self, other: Self) -> bool:
119165
return other.issubset(self)
120166

121-
def __bool__(self):
167+
def __bool__(self) -> bool:
122168
return not self._empty
123169

124-
def __eq__(self, other):
170+
def __eq__(self, other: object) -> bool:
125171
if not isinstance(other, Range):
126172
return NotImplemented
127173

@@ -132,14 +178,14 @@ def __eq__(self, other):
132178
self._upper_inc,
133179
self._empty
134180
) == (
135-
other._lower,
136-
other._upper,
181+
other._lower, # pyright: ignore [reportUnknownMemberType]
182+
other._upper, # pyright: ignore [reportUnknownMemberType]
137183
other._lower_inc,
138184
other._upper_inc,
139185
other._empty
140186
)
141187

142-
def __hash__(self):
188+
def __hash__(self) -> int:
143189
return hash((
144190
self._lower,
145191
self._upper,
@@ -148,7 +194,7 @@ def __hash__(self):
148194
self._empty
149195
))
150196

151-
def __repr__(self):
197+
def __repr__(self) -> str:
152198
if self._empty:
153199
desc = 'empty'
154200
else:

‎pyproject.toml

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ classifiers = [
2828
"Topic :: Database :: Front-Ends",
2929
]
3030
dependencies = [
31-
'async_timeout>=4.0.3; python_version < "3.12.0"'
31+
'async_timeout>=4.0.3; python_version < "3.12.0"',
3232
]
3333

3434
[project.urls]
@@ -40,9 +40,11 @@ gssapi = [
4040
]
4141
test = [
4242
'flake8~=6.1',
43+
'flake8-pyi~=24.1.0',
4344
'uvloop>=0.15.3; platform_system != "Windows" and python_version < "3.12.0"',
4445
'gssapi; platform_system == "Linux"',
4546
'k5test; platform_system == "Linux"',
47+
'mypy~=1.8.0',
4648
]
4749
docs = [
4850
'Sphinx~=5.3.0',
@@ -107,3 +109,26 @@ exclude_lines = [
107109
"if __name__ == .__main__.",
108110
]
109111
show_missing = true
112+
113+
[tool.mypy]
114+
incremental = true
115+
strict = true
116+
implicit_reexport = true
117+
118+
[[tool.mypy.overrides]]
119+
module = [
120+
"asyncpg._testbase",
121+
"asyncpg._testbase.*",
122+
"asyncpg.cluster",
123+
"asyncpg.connect_utils",
124+
"asyncpg.connection",
125+
"asyncpg.connresource",
126+
"asyncpg.cursor",
127+
"asyncpg.exceptions",
128+
"asyncpg.exceptions.*",
129+
"asyncpg.pool",
130+
"asyncpg.prepared_stmt",
131+
"asyncpg.transaction",
132+
"asyncpg.utils",
133+
]
134+
ignore_errors = true

‎setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343

4444
with open(str(_ROOT / 'asyncpg' / '_version.py')) as f:
4545
for line in f:
46-
if line.startswith('__version__ ='):
46+
if line.startswith('__version__: typing.Final ='):
4747
_, _, version = line.partition('=')
4848
VERSION = version.strip(" \n'\"")
4949
break

‎tests/test__sourcecode.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def find_root():
1414
return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
1515

1616

17-
class TestFlake8(unittest.TestCase):
17+
class TestCodeQuality(unittest.TestCase):
1818

1919
def test_flake8(self):
2020
try:
@@ -38,3 +38,34 @@ def test_flake8(self):
3838
output = ex.output.decode()
3939
raise AssertionError(
4040
'flake8 validation failed:\n{}'.format(output)) from None
41+
42+
def test_mypy(self):
43+
try:
44+
import mypy # NoQA
45+
except ImportError:
46+
raise unittest.SkipTest('mypy module is missing')
47+
48+
root_path = find_root()
49+
config_path = os.path.join(root_path, 'pyproject.toml')
50+
if not os.path.exists(config_path):
51+
raise RuntimeError('could not locate mypy.ini file')
52+
53+
try:
54+
subprocess.run(
55+
[
56+
sys.executable,
57+
'-m',
58+
'mypy',
59+
'--config-file',
60+
config_path,
61+
'asyncpg'
62+
],
63+
check=True,
64+
stdout=subprocess.PIPE,
65+
stderr=subprocess.STDOUT,
66+
cwd=root_path
67+
)
68+
except subprocess.CalledProcessError as ex:
69+
output = ex.output.decode()
70+
raise AssertionError(
71+
'mypy validation failed:\n{}'.format(output)) from None

0 commit comments

Comments
 (0)
Please sign in to comment.