Skip to content

Additional return types for psycopg2 connections #8528

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Aug 17, 2022
2 changes: 1 addition & 1 deletion stubs/psycopg2/psycopg2/_psycopg.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class cursor:
def execute(self, query: str | bytes | Composable, vars: _Vars = ...) -> None: ...
def executemany(self, query: str | bytes | Composable, vars_list: Iterable[_Vars]) -> None: ...
def fetchall(self) -> list[tuple[Any, ...]]: ...
def fetchmany(self, size=...) -> list[tuple[Any, ...]]: ...
def fetchmany(self, size: int | None = ...) -> list[tuple[Any, ...]]: ...
def fetchone(self) -> tuple[Any, ...] | None: ...
def mogrify(self, *args, **kwargs): ...
def nextset(self): ...
Expand Down
78 changes: 65 additions & 13 deletions stubs/psycopg2/psycopg2/extras.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import OrderedDict
from typing import Any
from collections.abc import Callable
from typing import Any, NamedTuple, TypeVar, overload

from psycopg2._ipaddress import register_ipaddress as register_ipaddress
from psycopg2._json import (
Expand Down Expand Up @@ -28,22 +29,37 @@ from psycopg2._range import (

from .extensions import connection as _connection, cursor as _cursor, quote_ident as quote_ident

_T_cur = TypeVar("_T_cur", bound=_cursor)

class DictCursorBase(_cursor):
row_factory: Any
def __init__(self, *args, **kwargs) -> None: ...
def fetchone(self) -> tuple[Any, ...] | None: ...
def fetchmany(self, size: int | None = ...) -> list[tuple[Any, ...]]: ...
def fetchall(self) -> list[tuple[Any, ...]]: ...
def __iter__(self): ...

class DictConnection(_connection):
def cursor(self, *args, **kwargs): ...
@overload
def cursor(self, name: str | bytes | None = ..., *, withhold: bool = ..., scrollable: bool | None = ...) -> DictCursor: ...
@overload
def cursor(
self,
name: str | bytes | None = ...,
*,
cursor_factory: Callable[..., _T_cur],
withhold: bool = ...,
scrollable: bool | None = ...,
) -> _T_cur: ...
@overload
def cursor(
self, name: str | bytes | None, cursor_factory: Callable[..., _T_cur], withhold: bool = ..., scrollable: bool | None = ...
) -> _T_cur: ...
Comment on lines -40 to +52
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These overloads are consistent with the base class's definitions of the cursor method except that the default cursor_factory case has been updated to return DictCursor. This comment also applies to RealDictConnection and NamedTupleConnection.


class DictCursor(DictCursorBase):
def __init__(self, *args, **kwargs) -> None: ...
index: Any
def execute(self, query, vars: Any | None = ...): ...
def callproc(self, procname, vars: Any | None = ...): ...
def fetchone(self) -> DictRow | None: ... # type: ignore[override]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An alternative could be to make the base class generic. But that may cause problems until we have TypeVar defaults.

def fetchmany(self, size: int | None = ...) -> list[DictRow]: ... # type: ignore[override]
def fetchall(self) -> list[DictRow]: ... # type: ignore[override]
def __next__(self) -> DictRow: ... # type: ignore[override]
Comment on lines +59 to +62
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return types aren't compatible with the base cursor class's definition, but that is consistent with the reality of the implementation. This comment also applies to the other modified cursor type defs.


class DictRow(list[Any]):
def __init__(self, cursor) -> None: ...
Expand All @@ -58,31 +74,67 @@ class DictRow(list[Any]):
def __reduce__(self): ...

class RealDictConnection(_connection):
def cursor(self, *args, **kwargs): ...
@overload
def cursor(
self, name: str | bytes | None = ..., *, withhold: bool = ..., scrollable: bool | None = ...
) -> RealDictCursor: ...
@overload
def cursor(
self,
name: str | bytes | None = ...,
*,
cursor_factory: Callable[..., _T_cur],
withhold: bool = ...,
scrollable: bool | None = ...,
) -> _T_cur: ...
@overload
def cursor(
self, name: str | bytes | None, cursor_factory: Callable[..., _T_cur], withhold: bool = ..., scrollable: bool | None = ...
) -> _T_cur: ...

class RealDictCursor(DictCursorBase):
def __init__(self, *args, **kwargs) -> None: ...
column_mapping: Any
def execute(self, query, vars: Any | None = ...): ...
def callproc(self, procname, vars: Any | None = ...): ...
def fetchone(self) -> RealDictRow | None: ... # type: ignore[override]
def fetchmany(self, size: int | None = ...) -> list[RealDictRow]: ... # type: ignore[override]
def fetchall(self) -> list[RealDictRow]: ... # type: ignore[override]
def __next__(self) -> RealDictRow: ... # type: ignore[override]

class RealDictRow(OrderedDict[Any, Any]):
def __init__(self, *args, **kwargs) -> None: ...
def __setitem__(self, key, value) -> None: ...

class NamedTupleConnection(_connection):
def cursor(self, *args, **kwargs): ...
@overload
def cursor(
self, name: str | bytes | None = ..., *, withhold: bool = ..., scrollable: bool | None = ...
) -> NamedTupleCursor: ...
@overload
def cursor(
self,
name: str | bytes | None = ...,
*,
cursor_factory: Callable[..., _T_cur],
withhold: bool = ...,
scrollable: bool | None = ...,
) -> _T_cur: ...
@overload
def cursor(
self, name: str | bytes | None, cursor_factory: Callable[..., _T_cur], withhold: bool = ..., scrollable: bool | None = ...
) -> _T_cur: ...

class NamedTupleCursor(_cursor):
Record: Any
MAX_CACHE: int
def execute(self, query, vars: Any | None = ...): ...
def executemany(self, query, vars): ...
def callproc(self, procname, vars: Any | None = ...): ...
def fetchone(self) -> tuple[Any, ...] | None: ...
def fetchmany(self, size: int | None = ...) -> list[tuple[Any, ...]]: ...
def fetchall(self) -> list[tuple[Any, ...]]: ...
def __iter__(self): ...
def fetchone(self) -> NamedTuple | None: ... # type: ignore[override]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Returning typing.NamedTuple is a little weird, but it apparently works with both pyright and mypy, so I suppose it's fine.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I agree. But at least it offers a hint to the caller that they can use either list indexing or access the fields with dot-notation.

def fetchmany(self, size: int | None = ...) -> list[NamedTuple]: ... # type: ignore[override]
def fetchall(self) -> list[NamedTuple]: ... # type: ignore[override]
def __next__(self) -> NamedTuple: ... # type: ignore[override]

class LoggingConnection(_connection):
log: Any
Expand Down