diff --git a/stubs/psycopg2/psycopg2/_psycopg.pyi b/stubs/psycopg2/psycopg2/_psycopg.pyi index dd2150819b09..abe74a9e0180 100644 --- a/stubs/psycopg2/psycopg2/_psycopg.pyi +++ b/stubs/psycopg2/psycopg2/_psycopg.pyi @@ -98,7 +98,7 @@ class cursor: def execute(self, query: str | bytes | Composable, vars: _Vars = ...) -> None: ... def executemany(self, query: str | bytes | Composable, vars_list: Iterable[_Vars]) -> None: ... def fetchall(self) -> list[tuple[Any, ...]]: ... - def fetchmany(self, size=...) -> list[tuple[Any, ...]]: ... + def fetchmany(self, size: int | None = ...) -> list[tuple[Any, ...]]: ... def fetchone(self) -> tuple[Any, ...] | None: ... def mogrify(self, *args, **kwargs): ... def nextset(self): ... diff --git a/stubs/psycopg2/psycopg2/extras.pyi b/stubs/psycopg2/psycopg2/extras.pyi index edd63870b07f..8037e3c273ce 100644 --- a/stubs/psycopg2/psycopg2/extras.pyi +++ b/stubs/psycopg2/psycopg2/extras.pyi @@ -1,5 +1,6 @@ from collections import OrderedDict -from typing import Any +from collections.abc import Callable +from typing import Any, NamedTuple, TypeVar, overload from psycopg2._ipaddress import register_ipaddress as register_ipaddress from psycopg2._json import ( @@ -28,22 +29,37 @@ from psycopg2._range import ( from .extensions import connection as _connection, cursor as _cursor, quote_ident as quote_ident +_T_cur = TypeVar("_T_cur", bound=_cursor) + class DictCursorBase(_cursor): - row_factory: Any def __init__(self, *args, **kwargs) -> None: ... - def fetchone(self) -> tuple[Any, ...] | None: ... - def fetchmany(self, size: int | None = ...) -> list[tuple[Any, ...]]: ... - def fetchall(self) -> list[tuple[Any, ...]]: ... - def __iter__(self): ... class DictConnection(_connection): - def cursor(self, *args, **kwargs): ... + @overload + def cursor(self, name: str | bytes | None = ..., *, withhold: bool = ..., scrollable: bool | None = ...) -> DictCursor: ... + @overload + def cursor( + self, + name: str | bytes | None = ..., + *, + cursor_factory: Callable[..., _T_cur], + withhold: bool = ..., + scrollable: bool | None = ..., + ) -> _T_cur: ... + @overload + def cursor( + self, name: str | bytes | None, cursor_factory: Callable[..., _T_cur], withhold: bool = ..., scrollable: bool | None = ... + ) -> _T_cur: ... class DictCursor(DictCursorBase): def __init__(self, *args, **kwargs) -> None: ... index: Any def execute(self, query, vars: Any | None = ...): ... def callproc(self, procname, vars: Any | None = ...): ... + def fetchone(self) -> DictRow | None: ... # type: ignore[override] + def fetchmany(self, size: int | None = ...) -> list[DictRow]: ... # type: ignore[override] + def fetchall(self) -> list[DictRow]: ... # type: ignore[override] + def __next__(self) -> DictRow: ... # type: ignore[override] class DictRow(list[Any]): def __init__(self, cursor) -> None: ... @@ -58,20 +74,56 @@ class DictRow(list[Any]): def __reduce__(self): ... class RealDictConnection(_connection): - def cursor(self, *args, **kwargs): ... + @overload + def cursor( + self, name: str | bytes | None = ..., *, withhold: bool = ..., scrollable: bool | None = ... + ) -> RealDictCursor: ... + @overload + def cursor( + self, + name: str | bytes | None = ..., + *, + cursor_factory: Callable[..., _T_cur], + withhold: bool = ..., + scrollable: bool | None = ..., + ) -> _T_cur: ... + @overload + def cursor( + self, name: str | bytes | None, cursor_factory: Callable[..., _T_cur], withhold: bool = ..., scrollable: bool | None = ... + ) -> _T_cur: ... class RealDictCursor(DictCursorBase): def __init__(self, *args, **kwargs) -> None: ... column_mapping: Any def execute(self, query, vars: Any | None = ...): ... def callproc(self, procname, vars: Any | None = ...): ... + def fetchone(self) -> RealDictRow | None: ... # type: ignore[override] + def fetchmany(self, size: int | None = ...) -> list[RealDictRow]: ... # type: ignore[override] + def fetchall(self) -> list[RealDictRow]: ... # type: ignore[override] + def __next__(self) -> RealDictRow: ... # type: ignore[override] class RealDictRow(OrderedDict[Any, Any]): def __init__(self, *args, **kwargs) -> None: ... def __setitem__(self, key, value) -> None: ... class NamedTupleConnection(_connection): - def cursor(self, *args, **kwargs): ... + @overload + def cursor( + self, name: str | bytes | None = ..., *, withhold: bool = ..., scrollable: bool | None = ... + ) -> NamedTupleCursor: ... + @overload + def cursor( + self, + name: str | bytes | None = ..., + *, + cursor_factory: Callable[..., _T_cur], + withhold: bool = ..., + scrollable: bool | None = ..., + ) -> _T_cur: ... + @overload + def cursor( + self, name: str | bytes | None, cursor_factory: Callable[..., _T_cur], withhold: bool = ..., scrollable: bool | None = ... + ) -> _T_cur: ... class NamedTupleCursor(_cursor): Record: Any @@ -79,10 +131,10 @@ class NamedTupleCursor(_cursor): def execute(self, query, vars: Any | None = ...): ... def executemany(self, query, vars): ... def callproc(self, procname, vars: Any | None = ...): ... - def fetchone(self) -> tuple[Any, ...] | None: ... - def fetchmany(self, size: int | None = ...) -> list[tuple[Any, ...]]: ... - def fetchall(self) -> list[tuple[Any, ...]]: ... - def __iter__(self): ... + def fetchone(self) -> NamedTuple | None: ... # type: ignore[override] + def fetchmany(self, size: int | None = ...) -> list[NamedTuple]: ... # type: ignore[override] + def fetchall(self) -> list[NamedTuple]: ... # type: ignore[override] + def __next__(self) -> NamedTuple: ... # type: ignore[override] class LoggingConnection(_connection): log: Any