From 511e3ea2c93222fb964ddd14373f390c9b43a66f Mon Sep 17 00:00:00 2001 From: Diego Argueta <620513-dargueta@users.noreply.github.com> Date: Tue, 17 Nov 2020 23:30:05 -0800 Subject: [PATCH 01/22] Add type annotations for `peewee` --- third_party/2and3/peewee.pyi | 2471 ++++++++++++++++++++++++++++++++++ 1 file changed, 2471 insertions(+) create mode 100644 third_party/2and3/peewee.pyi diff --git a/third_party/2and3/peewee.pyi b/third_party/2and3/peewee.pyi new file mode 100644 index 000000000000..7f3bc7969d7d --- /dev/null +++ b/third_party/2and3/peewee.pyi @@ -0,0 +1,2471 @@ +from bisect import bisect_left +from bisect import bisect_right +from contextlib import contextmanager +from functools import wraps +from typing import Any +from typing import AnyStr +from typing import Callable +from typing import ClassVar +from typing import Container +from typing import ContextManager +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Mapping +from typing import MutableMapping +from typing import MutableSet +from typing import NamedTuple +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Set +from typing import Sequence +from typing import Text +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union +import calendar +import collections +import datetime +import decimal +import operator +import re +import threading +import time +import uuid + +from typing_extensions import Literal +from typing_extensions import Protocol + +T = TypeVar("T") +__TModel = TypeVar("__TModel", bound="Model") +__TConvFunc = Callable[[Any], Any] +__TFunc = TypeVar("__TFunc", bound=Callable) +__TClass = TypeVar("__TClass", bound=type) +__TModelOrTable = Union[Type["Model"], "ModelAlias", "Table"] +__TSubquery = Union[Tuple["Query", Type["Model"]], Type["Model"], "ModelAlias"] +__TContextClass = TypeVar("__TContextClass", bound="Context") +__TField = TypeVar("__TField", bound="Field") +__TFieldOrModel = Union[__TModelOrTable, "Field"] +__TNode = TypeVar("__TNode", bound="Node") + +__version__: str +__all__: List[str] + +class __ICursor(Protocol): + description: Tuple[str, Any, Any, Any, Any, Any, Any] + rowcount: int + def fetchone(self) -> Optional[tuple]: ... + def fetchmany(self, size: int = ...) -> Iterable[tuple]: ... + def fetchall(self) -> Iterable[tuple]: ... + +class __IConnection(Protocol): + def cursor(self) -> __ICursor: ... + def execute(self, sql: str, *args: object) -> __ICursor: ... + def commit(self) -> Any: ... + def rollback(self) -> Any: ... + +def _sqlite_date_part(lookup_type: str, datetime_string: str) -> Optional[str]: ... +def _sqlite_date_trunc(lookup_type: str, datetime_string: str) -> Optional[str]: ... + +class attrdict(dict): + def __getattr__(self, attr: str) -> Any: ... + def __setattr__(self, attr: str, value: Any) -> None: ... + def __iadd__(self, rhs: Mapping[str, object]) -> attrdict: ... + def __add__(self, rhs: Mapping[str, object]) -> Mapping[str, object]: ... + +SENTINEL: object + +OP: attrdict + +DJANGO_MAP: attrdict + +FIELD: attrdict + +JOIN: attrdict + +ROW: attrdict + +SCOPE_NORMAL: int +SCOPE_SOURCE: int +SCOPE_VALUES: int +SCOPE_CTE: int +SCOPE_COLUMN: int + +CSQ_PARENTHESES_NEVER: int +CSQ_PARENTHESES_ALWAYS: int +CSQ_PARENTHESES_UNNESTED: int + +SNAKE_CASE_STEP1: re.Pattern +SNAKE_CASE_STEP2: re.Pattern + +MODEL_BASE: str + +# TODO (dargueta) +class _callable_context_manager(object): + def __call__(self, fn): + @wraps(fn) + def inner(*args, **kwargs): + with self: + return fn(*args, **kwargs) + return inner + +class Proxy(object): + obj: Any + def initialize(self, obj: Any) -> None: ... + def attach_callback(self, callback: __TConvFunc) -> __TConvFunc: ... + def passthrough(method: __TFunc) -> __TFunc: ... + def __enter__(self) -> Any: ... + def __exit__(self, exc_type, exc_val, exc_tb) -> Any: ... + def __getattr__(self, attr: str) -> Any: ... + def __setattr__(self, attr: str, value: Any) -> None: ... + +class DatabaseProxy(Proxy): + def connection_context(self) -> ConnectionContext: ... + def atomic(self, *args: object, **kwargs: object) -> _atomic: ... + def manual_commit(self) -> _manual: ... + def transaction(self, *args: object, **kwargs: object) -> _transaction: ... + def savepoint(self) -> _savepoint: ... + +class ModelDescriptor(object): ... + +# SQL Generation. + +class AliasManager(object): + @property + def mapping(self) -> MutableMapping["Source", str]: ... + def add(self, source: Source) -> str: ... + def get(self, source: Source, any_depth: bool = ...) -> str: ... + def __getitem__(self, source: Source) -> str: ... + def __setitem__(self, source: Source, alias: str) -> None: ... + def push(self) -> None: ... + def pop(self) -> None: ... + +class __State(NamedTuple): + scope: int + parentheses: bool + # From the source code we know this to be a Dict and not just a MutableMapping. + settings: Dict[str, Any] + +class State(__State): + def __new__(cls, scope: int = ..., parentheses: bool = ..., **kwargs: object) -> State: ... + def __call__(self, scope: Optional[int] = ..., parentheses: Optional[int] = ..., **kwargs: object) -> State: ... + def __getattr__(self, attr_name: str) -> Any: ... + +class Context(object): + stack: List[State] + alias_manager: AliasManager + state: State + def __init__(self, **settings: Any) -> None: ... + def as_new(self) -> Context: ... + def column_sort_key(self, item: Sequence[Union[ColumnBase, Source]]) -> Tuple[str, ...]: ... + @property + def scope(self) -> int: ... + @property + def parentheses(self) -> bool: ... + @property + def subquery(self): + return self.state.subquery + def __call__(self, **overrides: Any) -> Context: ... + def scope_normal(self) -> ContextManager[Context]: ... + def scope_source(self) -> ContextManager[Context]: ... + def scope_values(self) -> ContextManager[Context]: ... + def scope_cte(self) -> ContextManager[Context]: ... + def scope_column(self) -> ContextManager[Context]: ... + def __enter__(self) -> Context: ... + def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: Any) -> None: ... + @contextmanager + def push_alias(self) -> Iterator[None]: ... + # TODO (dargueta): Is this right? + def sql(self, obj: Any) -> Context: ... + def literal(self, keyword: str) -> Context: ... + def value(self, value: Any, converter: Optional[__TConvFunc] = ..., add_param: bool = ...) -> Context: ... + def __sql__(self, ctx: Context) -> Context: ... + def parse(self, node: Node) -> Tuple[str, Optional[tuple]]: ... + def query(self) -> Tuple[str, Optional[tuple]]: ... + +def query_to_string(query: Node) -> str: ... + +class Node(object): + def clone(self) -> Node: ... + def __sql__(self, ctx: Context) -> Context: ... + # FIXME (dargueta): Is there a way to make this a proper decorator? + @staticmethod + def copy(method: __TFunc) -> __TFunc: + def inner(self: T, *args: object, **kwargs: object) -> T: + clone = self.clone() + method(clone, *args, **kwargs) + return clone + return inner + def coerce(self, _coerce: bool = ...) -> Node: ... + def is_alias(self) -> bool: ... + def unwrap(self) -> Node: ... + +class ColumnFactory(object): + node: Node + def __init__(self, node: Node): ... + def __getattr__(self, attr: str) -> Column: ... + +class _DynamicColumn(object): + @overload + def __get__(self, instance: None, instance_type: type) -> _DynamicColumn: ... + @overload + def __get__(self, instance: T, instance_type: Type[T]) -> ColumnFactory: ... + +class _ExplicitColumn(object): + @overload + def __get__(self, instance: None, instance_type: type) -> _ExplicitColumn: ... + @overload + def __get__(self, instance: T, instance_type: Type[T]) -> NoReturn: ... + +class Source(Node): + c: ClassVar[_DynamicColumn] + def __init__(self, alias: Optional[str] = ...): ... + def alias(self, name: str) -> Source: ... + def select(self, *columns: Field) -> Select: ... + def join(self, dest, join_type: int = ..., on: Optional[Expression] = ...) -> Join: ... + def left_outer_join(self, dest, on: Optional[Expression] = ...) -> Join: ... + def cte(self, name: str, recursive: bool = ..., columns=None, materialized=None) -> CTE: ... + def get_sort_key(self, ctx) -> Tuple[str, ...]: ... + def apply_alias(self, ctx: Context) -> Context: ... + def apply_column(self, ctx: Context) -> Context: ... + +class _HashableSource(object): + def __init__(self, *args: object, **kwargs: object): ... + def alias(self, name: str) -> _HashableSource: ... + def __hash__(self) -> int: ... + @overload + def __eq__(self, other: _HashableSource) -> bool: ... + @overload + def __eq__(self, other: Any) -> Expression: ... + @overload + def __ne__(self, other: _HashableSource) -> bool: ... + @overload + def __ne__(self, other: Any) -> Expression: ... + def __lt__(self, other: Any) -> Expression: ... + def __le__(self, other: Any) -> Expression: ... + def __gt__(self, other: Any) -> Expression: ... + def __ge__(self, other: Any) -> Expression: ... + +def __join__(join_type: int = ..., inverted: bool = ...) -> Callable[[Any, Any], Join]: ... + +class BaseTable(Source): + def __and__(self, other: Any) -> Join: ... + def __add__(self, other: Any) -> Join: ... + def __sub__(self, other: Any) -> Join: ... + def __or__(self, other: Any) -> Join: ... + def __mul__(self, other: Any) -> Join: ... + def __rand__(self, other: Any) -> Join: ... + def __radd__(self, other: Any) -> Join: ... + def __rsub__(self, other: Any) -> Join: ... + def __ror__(self, other: Any) -> Join: ... + def __rmul__(self, other: Any) -> Join: ... + +class _BoundTableContext(_callable_context_manager): + table: Table + database: Database + def __init__(self, table: Table, database: Database): ... + def __enter__(self) -> Table: ... + def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: Any) -> None: ... + +class Table(_HashableSource, BaseTable): + __name__: str + c: _ExplicitColumn + primary_key: Optional[Union[Field, CompositeKey]] + def __init__( + self, + name: str, + columns: Optional[Iterable[str]] = ..., + primary_key: Optional[Union[Field, CompositeKey]] = ..., + schema: Optional[str] = ..., + alias: Optional[str] = ..., + _model: Optional[Type[Model]] = ..., + _database: Optional[Database] = ..., + ): ... + def clone(self) -> Table: ... + def bind(self, database: Optional[Database] = ...) -> Table: ... + def bind_ctx(self, database: Optional[Database] = ...) -> _BoundTableContext: ... + def select(self, *columns: Column) -> Select: ... + @overload + def insert(self, insert: Optional[Select], columns: Sequence[Union[str, Field, Column]]) -> Insert: ... + @overload + def insert(self, insert: Union[Mapping[str, object], Iterable[Mapping[str, object]]], **kwargs: object): ... + @overload + def replace(self, insert: Optional[Select], columns: Sequence[Union[str, Field, Column]]) -> Insert: ... + @overload + def replace(self, insert: Union[Mapping[str, object], Iterable[Mapping[str, object]]], **kwargs: object): ... + def update(self, update: Optional[Mapping[str, object]] = ..., **kwargs: object) -> Update: ... + def delete(self) -> Delete: ... + def __sql__(self, ctx: Context) -> Context: ... + +class Join(BaseTable): + lhs: Any # TODO + rhs: Any # TODO + join_type: int + def __init__(self, lhs, rhs, join_type: int = ..., on: Optional[Expression] = ..., alias: Optional[str] = ...): ... + def on(self, predicate: Expression) -> Join: ... + def __sql__(self, ctx: Context) -> Context: ... + +class ValuesList(_HashableSource, BaseTable): + def __init__(self, values, columns=None, alias: Optional[str] = ...): ... + # FIXME (dargueta) `names` might be wrong + def columns(self, *names: str) -> ValuesList: ... + def __sql__(self, ctx: Context) -> Context: ... + +class CTE(_HashableSource, Source): + def __init__( + self, + name: str, + query: Select, + recursive: bool = ..., + columns: Optional[Iterable[Union[Column, Field, str]]] = ..., + materialized: bool = ..., + ): ... + # TODO (dargueta): Is `columns` just for column names? + def select_from(self, *columns: Union[Column, Field]) -> Select: ... + def _get_hash(self) -> int: ... + def union_all(self, rhs) -> CTE: ... + __add__ = union_all + def union(self, rhs: SelectQuery) -> CTE: ... + __or__ = union + def __sql__(self, ctx: Context) -> Context: ... + +class ColumnBase(Node): + _converter: Optional[__TConvFunc] + def converter(self, converter: Optional[__TConvFunc] = ...) -> ColumnBase: ... + @overload + def alias(self, alias: None) -> ColumnBase: ... + @overload + def alias(self, alias: str) -> Alias: ... + def unalias(self) -> ColumnBase: ... + def cast(self, as_type: str) -> Cast: ... + def asc(self, collation: Optional[str] = ..., nulls: Optional[str] = ...) -> Asc: ... + __pos__ = asc + def desc(self, collation: Optional[str] = ..., nulls: Optional[str] = ...) -> Desc: ... + __neg__ = desc + def __invert__(self) -> Negated: ... + def __and__(self, other: Any) -> Expression: ... + def __or__(self, other: Any) -> Expression: ... + def __add__(self, other: Any) -> Expression: ... + def __sub__(self, other: Any) -> Expression: ... + def __mul__(self, other: Any) -> Expression: ... + def __div__(self, other: Any) -> Expression: ... + def __truediv__(self, other: Any) -> Expression: ... + def __xor__(self, other: Any) -> Expression: ... + def __radd__(self, other: Any) -> Expression: ... + def __rsub__(self, other: Any) -> Expression: ... + def __rmul__(self, other: Any) -> Expression: ... + def __rdiv__(self, other: Any) -> Expression: ... + def __rtruediv__(self, other: Any) -> Expression: ... + def __rand__(self, other: Any) -> Expression: ... + def __ror__(self, other: Any) -> Expression: ... + def __rxor__(self, other: Any) -> Expression: ... + def __eq__(self, rhs: Optional["Node"]) -> Expression: ... + def __ne__(self, rhs: Optional[Node]) -> Expression: ... + def __lt__(self, other: Any) -> Expression: ... + def __le__(self, other: Any) -> Expression: ... + def __gt__(self, other: Any) -> Expression: ... + def __ge__(self, other: Any) -> Expression: ... + def __lshift__(self, other: Any) -> Expression: ... + def __rshift__(self, other: Any) -> Expression: ... + def __mod__(self, other: Any) -> Expression: ... + def __pow__(self, other: Any) -> Expression: ... + def bin_and(self, other: Any) -> Expression: ... + def bin_or(self, other: Any) -> Expression: ... + def in_(self, other: Any) -> Expression: ... + def not_in(self, other: Any) -> Expression: ... + def regexp(self, other: Any) -> Expression: ... + def is_null(self, is_null: bool = ...) -> Expression: ... + def contains(self, rhs: Union[Node, str]) -> Expression: ... + def startswith(self, rhs: Union[Node, str]) -> Expression: ... + def endswith(self, rhs: Union[Node, str]) -> Expression: ... + def between(self, lo: Any, hi: Any) -> Expression: ... + def concat(self, rhs: Any) -> StringExpression: ... + def iregexp(self, rhs: Any) -> Expression: ... + def __getitem__(self, item: Any) -> Expression: ... + def distinct(self) -> NodeList: ... + def collate(self, collation: str) -> NodeList: ... + def get_sort_key(self, ctx: Context) -> Tuple[str, ...]: ... + +class Column(ColumnBase): + source: Source + name: str + def __init__(self, source: Source, name: str): ... + def get_sort_key(self, ctx: Context) -> Tuple[str, ...]: ... + def __hash__(self) -> int: ... + def __sql__(self, ctx: Context) -> Context: ... + +class WrappedNode(ColumnBase, Generic[__TNode]): + node: __TNode + _coerce: bool + _converter: Optional[__TConvFunc] + def __init__(self, node: __TNode): ... + def is_alias(self) -> bool: ... + def unwrap(self) -> __TNode: ... + +class EntityFactory(object): + node: Node + def __init__(self, node: Node): ... + def __getattr__(self, attr: str) -> Entity: ... + +class _DynamicEntity(object): + __slots__ = () + @overload + def __get__(self, instance: None, instance_type: type) -> _DynamicEntity: ... + @overload + def __get__(self, instance: T, instance_type: Type[T]) -> EntityFactory: ... + +class Alias(WrappedNode): + c: ClassVar[_DynamicEntity] + def __init__(self, node: Node, alias: str): ... + def __hash__(self) -> int: ... + @overload + def alias(self, alias: None) -> Node: ... + @overload + def alias(self, alias: str) -> Alias: ... + def unalias(self) -> Node: ... + def is_alias(self) -> bool: ... + def __sql__(self, ctx: Context) -> Context: ... + +class Negated(WrappedNode): + def __invert__(self) -> Node: ... + def __sql__(self, ctx: Context) -> Context: ... + +class BitwiseMixin(object): + def __and__(self, other): + return self.bin_and(other) + def __or__(self, other): + return self.bin_or(other) + def __sub__(self, other): + return self.bin_and(other.bin_negated()) + def __invert__(self) -> BitwiseNegated: ... + +class BitwiseNegated(BitwiseMixin, WrappedNode): + def __invert__(self) -> Node: ... + def __sql__(self, ctx: Context) -> Context: ... + +class Value(ColumnBase): + value: object + converter: Optional[__TConvFunc] + multi: bool + def __init__(self, value: object, converter: Optional[__TConvFunc] = ..., unpack: bool = ...): + self.value = value + self.converter = converter + self.multi = unpack and isinstance(self.value, multi_types) + if self.multi: + self.values = [] + for item in self.value: + if isinstance(item, Node): + self.values.append(item) + else: + self.values.append(Value(item, self.converter)) + def __sql__(self, ctx): + if self.multi: + # For multi-part values (e.g. lists of IDs). + return ctx.sql(EnclosedNodeList(self.values)) + + return ctx.value(self.value, self.converter) + +def AsIs(value: object) -> Value: ... + +class Cast(WrappedNode): + def __init__(self, node: Node, cast: str): ... + def __sql__(self, ctx: Context) -> Context: ... + +class Ordering(WrappedNode): + direction: str + collation: Optional[str] + nulls: Optional[str] + def __init__(self, node: Node, direction: str, collation: Optional[str] = ..., nulls: Optional[str] = ...): ... + def collate(self, collation: Optional[str] = ...) -> Ordering: ... + def __sql__(self, ctx: Context) -> Context: ... + +def Asc(node: Node, collation: Optional[str] = ..., nulls: Optional[str] = ...) -> Ordering: ... +def Desc(node: Node, collation: Optional[str] = ..., nulls: Optional[str] = ...) -> Ordering: ... + +class Expression(ColumnBase): + lhs: Optional[Union[Node, str]] + op: int + rhs: Optional[Union[Node, str]] + flat: bool + def __init__(self, lhs: Optional[Union[Node, str]], op: int, rhs: Optional[Union[Node, str]], flat: bool = ...): ... + def __sql__(self, ctx: Context) -> Context: ... + +class StringExpression(Expression): + def __add__(self, rhs: Any) -> StringExpression: ... + def __radd__(self, lhs: Any) -> StringExpression: ... + +class Entity(ColumnBase): + def __init__(self, *path: str): ... + def __getattr__(self, attr: str) -> Entity: ... + def get_sort_key(self, ctx: Context) -> Tuple[str, ...]: ... + def __hash__(self) -> int: ... + def __sql__(self, ctx: Context) -> Context: ... + +class SQL(ColumnBase): + sql: str + params: Optional[Mapping[str, object]] + def __init__(self, sql: str, params: Mapping[str, object] = ...): ... + def __sql__(self, ctx: Context) -> Context: ... + +def Check(constraint: str) -> SQL: + return SQL("CHECK (%s)" % constraint) + +class Function(ColumnBase): + name: str + arguments: tuple + def __init__(self, name: str, arguments: tuple, coerce: bool = ..., python_value: Optional[__TConvFunc] = ...): ... + def __getattr__(self, attr: str) -> Callable[..., Function]: ... + # TODO (dargueta): `where` is an educated guess + def filter(self, where: Optional[Expression] = ...) -> Function: ... + def order_by(self, *ordering) -> Function: + self._order_by = ordering + def python_value(self, func: Optional[__TConvFunc] = ...) -> Function: ... + def over( + self, + partition_by: Optional[Union[Sequence[Field], Window]] = ..., + order_by: Optional[Sequence[Union[Field, Expression]]] = ..., + start: Optional[Union[str, SQL]] = ..., + end: Optional[Union[str, SQL]] = ..., + frame_type: Optional[str] = ..., + window: Optional[Window] = ..., + exclude: Optional[SQL] = ..., + ) -> NodeList: ... + def __sql__(self, ctx: Context) -> Context: ... + +fn: Function + +class Window(Node): + CURRENT_ROW: ClassVar[SQL] + GROUP: ClassVar[SQL] + TIES: ClassVar[SQL] + NO_OTHERS: ClassVar[SQL] + GROUPS: ClassVar[str] + RANGE: ClassVar[str] + ROWS: ClassVar[str] + # Instance variables + partition_by: Tuple[Union[Field, Expression], ...] + order_by: Tuple[Union[Field, Expression], ...] + start: Optional[Union[str, SQL]] + end: Optional[Union[str, SQL]] + frame_type: Optional[Any] # TODO + @overload + def __init__( + self, + partition_by: Optional[Union[Sequence[Field], Window]] = ..., + order_by: Optional[Sequence[Union[Field, Expression]]] = ..., + start: Optional[Union[str, SQL]] = ..., + end: None = ..., + frame_type: Optional[str] = ..., + extends: Optional[Union[Window, WindowAlias, str]] = ..., + exclude: Optional[SQL] = ..., + alias: Optional[str] = ..., + _inline: bool = ..., + ): ... + @overload + def __init__( + self, + partition_by: Optional[Union[Sequence[Field], Window]] = ..., + order_by: Optional[Sequence[Union[Field, Expression]]] = ..., + start: Union[str, SQL] = ..., + end: Union[str, SQL] = ..., + frame_type: Optional[str] = ..., + extends: Optional[Union[Window, WindowAlias, str]] = ..., + exclude: Optional[SQL] = ..., + alias: Optional[str] = ..., + _inline: bool = ..., + ): ... + def alias(self, alias: Optional[str] = ...) -> Window: ... + def as_range(self) -> Window: ... + def as_rows(self) -> Window: ... + def as_groups(self) -> Window: ... + def extends(self, window: Optional[Union[Window, WindowAlias, str]] = ...) -> Window: ... + def exclude(self, frame_exclusion: Optional[Union[str, SQL]] = ...) -> Window: ... + @staticmethod + def following(value: Optional[int] = ...) -> SQL: ... + @staticmethod + def preceding(value: Optional[int] = ...) -> SQL: ... + def __sql__(self, ctx: Context) -> Context: ... + +class WindowAlias(Node): + window: Window + def __init__(self, window: Window): ... + def alias(self, window_alias: str) -> WindowAlias: ... + def __sql__(self, ctx: Context) -> Context: ... + +class ForUpdate(Node): + def __init__( + self, + expr: Union[Literal[True], str], + of: Optional[Union[__TModelOrTable, List[__TModelOrTable], Set[__TModelOrTable], Tuple[__TModelOrTable, ...]]] = ..., + nowait: Optional[bool] = ..., + ): ... + def __sql__(self, ctx: Context) -> Context: ... + +def Case(predicate: Optional[Node], expression_tuples: Iterable[Tuple[Expression, Any]], default: Any = ...) -> NodeList: ... + +class NodeList(ColumnBase): + nodes: Sequence[Any] # TODO (dargueta): Narrow this type + glue: str + parens: bool + def __init__(self, nodes: Sequence[Any], glue: str = ..., parens: bool = ...): ... + def __sql__(self, ctx: Context) -> Context: ... + +def CommaNodeList(nodes: Sequence[Any]) -> NodeList: ... +def EnclosedNodeList(nodes: Sequence[Any]) -> NodeList: ... + +class _Namespace(Node): + def __init__(self, name: str): ... + def __getattr__(self, attr: str) -> NamespaceAttribute: ... + def __getitem__(self, attr: str) -> NamespaceAttribute: ... + +class NamespaceAttribute(ColumnBase): + def __init__(self, namespace: _Namespace, attribute: str): ... + def __sql__(self, ctx: Context) -> Context: ... + +EXCLUDED: _Namespace + +class DQ(ColumnBase): + query: Dict[str, Any] + + # TODO (dargueta): Narrow this down? + def __init__(self, **query: Any): ... + def __invert__(self) -> DQ: ... + def clone(self) -> DQ: ... + +class QualifiedNames(WrappedNode): + def __sql__(self, ctx: Context) -> Context: ... + +@overload +def qualify_names(node: Expression) -> Expression: ... +@overload +def qualify_names(node: ColumnBase) -> QualifiedNames: ... +@overload +def qualify_names(node: T) -> T: ... + +class OnConflict(Node): + @overload + def __init__( + self, + action: Optional[str] = ..., + update: Optional[Mapping[str, object]] = ..., + preserve: Optional[Union[Field, Iterable[Field]]] = ..., + where: Optional[Expression] = ..., + conflict_target: Optional[Union[Field, Sequence[Field]]] = ..., + conflict_where: None = ..., + conflict_constraint: Optional[str] = ..., + ): ... + @overload + def __init__( + self, + action: Optional[str] = ..., + update: Optional[Mapping[str, object]] = ..., + preserve: Optional[Union[Field, Iterable[Field]]] = ..., + where: Optional[Expression] = ..., + conflict_target: None = ..., + conflict_where: Optional[Expression] = ..., + conflict_constraint: Optional[str] = ..., + ): ... + # undocumented + def get_conflict_statement(self, ctx: Context, query: Query): + return ctx.state.conflict_statement(self, query) + def get_conflict_update(self, ctx, query): + return ctx.state.conflict_update(self, query) + def preserve(self, *columns) -> OnConflict: ... + def update(self, _data: Optional[Mapping[str, object]] = ..., **kwargs: object) -> OnConflict: ... + def where(self, *expressions: Expression) -> OnConflict: ... + def conflict_target(self, *constraints) -> OnConflict: ... + def conflict_where(self, *expressions: Expression) -> OnConflict: ... + def conflict_constraint(self, constraint: str) -> OnConflict: ... + +# BASE QUERY INTERFACE. + +class BaseQuery(Node): + default_row_type: ClassVar[int] + def __init__(self, _database: Optional[Database] = ..., **kwargs: object): ... + def bind(self, database: Optional[Database] = ...) -> BaseQuery: ... + def clone(self) -> BaseQuery: ... + def dicts(self, as_dict: bool = ...) -> BaseQuery: ... + def tuples(self, as_tuple: bool = ...) -> BaseQuery: ... + def namedtuples(self, as_namedtuple: bool = ...) -> BaseQuery: ... + def objects(self, constructor: Optional[__TConvFunc] = ...) -> BaseQuery: ... + def __sql__(self, ctx: Context) -> Context: ... + def sql(self) -> Tuple[str, Optional[tuple]]: ... + def execute(self, database: Optional[Database] = ...) -> CursorWrapper: ... + # TODO (dargueta): `Any` is too loose; list types of the cursor wrappers + def iterator(self, database: Optional[Database] = ...) -> Iterator[Any]: ... + def __iter__(self) -> Iterator[Any]: ... + @overload + def __getitem__(self, value: int) -> Any: ... + @overload + def __getitem__(self, value: slice) -> Sequence[Any]: ... + def __len__(self) -> int: ... + def __str__(self) -> str: ... + +class RawQuery(BaseQuery): + # TODO (dargueta): `tuple` may not be 100% accurate, maybe Sequence[object]? + def __init__(self, sql: Optional[str] = ..., params: Optional[tuple] = ..., **kwargs: object): ... + def __sql__(self, ctx: Context) -> Context: ... + +class Query(BaseQuery): + # TODO (dargueta): Verify type of order_by + def __init__( + self, + where: Optional[Expression] = ..., + order_by: Optional[Sequence[Node]] = ..., + limit: Optional[int] = ..., + offset: Optional[int] = ..., + **kwargs: object, + ): ... + def with_cte(self, *cte_list: CTE) -> Query: ... + def where(self, *expressions: Expression) -> Query: ... + def orwhere(self, *expressions: Expression) -> Query: ... + def order_by(self, *values: Node) -> Query: ... + def order_by_extend(self, *values: Node) -> Query: ... + def limit(self, value: Optional[int] = ...) -> Query: ... + def offset(self, value: Optional[int] = ...) -> Query: ... + def paginate(self, page: int, paginate_by: int = ...) -> Query: ... + def _apply_ordering(self, ctx: Context) -> Context: ... + def __sql__(self, ctx: Context) -> Context: ... + +def __compound_select__(operation: str, inverted: bool = ...) -> Callable[[Any, Any], CompoundSelectQuery]: ... + +class SelectQuery(Query): + def union_all(self, other: object) -> CompoundSelectQuery: ... + def union(self, other: object) -> CompoundSelectQuery: ... + def intersect(self, other: object) -> CompoundSelectQuery: ... + def except_(self, other: object) -> CompoundSelectQuery: ... + def __add__(self, other: object) -> CompoundSelectQuery: ... + def __or__(self, other: object) -> CompoundSelectQuery: ... + def __and__(self, other: object) -> CompoundSelectQuery: ... + def __sub__(self, other: object) -> CompoundSelectQuery: ... + def __radd__(self, other: object) -> CompoundSelectQuery: ... + def __ror__(self, other: object) -> CompoundSelectQuery: ... + def __rand__(self, other: object) -> CompoundSelectQuery: ... + def __rsub__(self, other: object) -> CompoundSelectQuery: ... + def select_from(self, *columns: Field) -> Select: ... + +class SelectBase(_HashableSource, Source, SelectQuery): + @overload + def peek(self, database: Optional[Database] = ..., n: Literal[1] = ...) -> Any: ... + @overload + def peek(self, database: Optional[Database] = ..., n: int = ...) -> List[Any]: ... + @overload + def first(self, database: Optional[Database] = ..., n: Literal[1] = ...) -> Any: ... + @overload + def first(self, database: Optional[Database] = ..., n: int = ...) -> List[Any]: ... + @overload + def scalar(self, database: Optional[Database] = ..., as_tuple: Literal[False] = ...) -> Any: ... + @overload + def scalar(self, database: Optional[Database] = ..., as_tuple: Literal[True] = ...) -> tuple: ... + def count(self, database: Optional[Database] = ..., clear_limit: bool = ...) -> int: ... + def exists(self, database: Optional[Database] = ...) -> bool: ... + def get(self, database: Optional[Database] = ...) -> Any: ... + +# QUERY IMPLEMENTATIONS. + +class CompoundSelectQuery(SelectBase): + lhs: Any # TODO (dargueta) + op: str + rhs: Any # TODO (dargueta) + def __init__(self, lhs: Any, op: str, rhs: Any): ... + def exists(self, database: Optional[Database] = ...) -> bool: ... + def __sql__(self, ctx: Context) -> Context: ... + +class Select(SelectBase): + def __init__( + self, + from_list: Optional[Sequence[Union[Column, Field]]] = ..., # TODO (dargueta): `Field` might be wrong + columns: Optional[Iterable[Union[Column, Field]]] = ..., # TODO (dargueta): `Field` might be wrong + # Docs say this is a "[l]ist of columns or values to group by" so we don't have + # a whole lot to restrict this to thanks to "or values" + group_by: Sequence[Any] = ..., + having: Optional[Expression] = ..., + distinct: Optional[Union[bool, Sequence[Column]]] = ..., + windows: Optional[Container[Window]] = ..., + for_update: Optional[Union[bool, str]] = ..., + for_update_of: Optional[Union[Table, Iterable[Table]]] = ..., + nowait: Optional[bool] = ..., + lateral: Optional[bool] = ..., # undocumented + **kwargs: object, + ): ... + def clone(self) -> Select: ... + def columns(self, *columns, **kwargs: object) -> Select: + self._returning = columns + select = columns + def select_extend(self, *columns) -> Select: + self._returning = tuple(self._returning) + columns + # TODO (dargueta): Is `sources` right? + def from_(self, *sources: Union[Source, Type[Model]]) -> Select: ... + def join(self, dest: Type[Model], join_type: int = ..., on: Optional[Expression] = ...) -> Select: ... + def group_by(self, *columns: Union[Table, Field]) -> Select: ... + def group_by_extend(self, *values: Union[Table, Field]) -> Select: ... + def having(self, *expressions: Expression) -> Select: ... + @overload + def distinct(self, _: bool) -> Select: ... + @overload + def distinct(self, *columns: Field) -> Select: ... + def window(self, *windows: Window) -> Select: ... + def for_update(self, for_update: bool = ..., of=None, nowait=None) -> Select: + if not for_update and (of is not None or nowait): + for_update = True + self._for_update = for_update + self._for_update_of = of + self._for_update_nowait = nowait + def lateral(self, lateral: bool = ...) -> Select: ... + +class _WriteQuery(Query): + table: Table + def __init__(self, table: Table, returning: Optional[Iterable[Union[Type[Model], Field]]] = ..., **kwargs: object): ... + def returning(self, *returning: Union[Type[Model], Field]) -> _WriteQuery: ... + def apply_returning(self, ctx: Context) -> Context: ... + def execute_returning(self, database: Database) -> CursorWrapper: ... + def handle_result(self, database: Database, cursor: __ICursor) -> Union[int, __ICursor]: ... + def __sql__(self, ctx: Context) -> Context: ... + +class Update(_WriteQuery): + def __init__(self, table: Table, update=None, **kwargs): + super(Update, self).__init__(table, **kwargs) + self._update = update + self._from = None + @Node.copy + def from_(self, *sources) -> None: + self._from = sources + def __sql__(self, ctx: Context) -> Context: ... + +class Insert(_WriteQuery): + SIMPLE: ClassVar[int] + QUERY: ClassVar[int] + MULTI: ClassVar[int] + DefaultValuesException: Type[Exception] + def __init__( + self, + table: Table, + insert: Optional[Union[Mapping[str, object], Iterable[Mapping[str, object]], SelectQuery, SQL]] = ..., + columns: Optional[Iterable[Union[str, Field]]] = ..., # FIXME: Might be `Column` not `Field` + on_conflict: Optional[OnConflict] = ..., + **kwargs: object, + ): ... + def where(self, *expressions: Expression) -> NoReturn: ... + def on_conflict_ignore(self, ignore: bool = ...) -> Insert: ... + def on_conflict_replace(self, replace: bool = ...) -> Insert: ... + def on_conflict(self, *args, **kwargs) -> Insert: ... + def get_default_data(self) -> dict: ... + def get_default_columns(self) -> Optional[List[Field]]: ... + def __sql__(self, ctx: Context) -> Context: ... + def handle_result(self, database: Database, cursor: __ICursor) -> Union[__ICursor, int]: ... + +class Delete(_WriteQuery): + def __sql__(self, ctx: Context) -> Context: ... + +class Index(Node): + def __init__( + self, + name: str, + table, + expressions, + unique: bool = ..., + safe: bool = ..., + where: Optional[Expression] = ..., + using: Optional[str] = ..., + ): ... + def safe(self, _safe: bool = ...) -> Index: ... + def where(self, *expressions: Expression) -> Index: ... + def using(self, _using: Optional[str] = ...) -> Index: ... + def __sql__(self, ctx: Context) -> Context: ... + +class ModelIndex(Index): + def __init__( + self, + model: Type[__TModel], + fields: Iterable[Union[Field, Node, str]], + unique: bool = ..., + safe: bool = ..., + where: Optional[Expression] = ..., + using: Optional[str] = ..., + name: Optional[str] = ..., + ): ... + +# DB-API 2.0 EXCEPTIONS. + +class PeeweeException(Exception): + # This attribute only exists if an exception was passed into the constructor. + # Attempting to access it otherwise will result in an AttributeError. + orig: Exception + def __init__(self, *args: object): ... + +class ImproperlyConfigured(PeeweeException): ... +class DatabaseError(PeeweeException): ... +class DataError(DatabaseError): ... +class IntegrityError(DatabaseError): ... +class InterfaceError(PeeweeException): ... +class InternalError(DatabaseError): ... +class NotSupportedError(DatabaseError): ... +class OperationalError(DatabaseError): ... +class ProgrammingError(DatabaseError): ... + +class ExceptionWrapper(object): + exceptions: Mapping[str, Type[Exception]] + def __init__(self, exceptions: Mapping[str, Type[Exception]]): ... + def __enter__(self) -> None: ... + def __exit__(self, exc_type: Type[Exception], exc_value: Exception, traceback: Any) -> None: ... + +EXCEPTIONS: Mapping[str, Type[Exception]] + +__exception_wrapper__: ExceptionWrapper + +class IndexMetadata(NamedTuple): + name: str + sql: str + columns: List[str] + unique: bool + table: str + +class ColumnMetadata(NamedTuple): + name: str + data_type: str + null: bool + primary_key: bool + table: str + default: object + +class ForeignKeyMetadata(NamedTuple): + column: str + dest_table: str + dest_column: str + table: str + +class ViewMetadata(NamedTuple): + name: str + sql: str + +class _ConnectionState(object): + closed: bool + conn: Optional[__IConnection] + ctx: List[ConnectionContext] + transactions: List[Union[_manual, _transaction]] + def reset(self) -> None: ... + def set_connection(self, conn: __IConnection) -> None: ... + +class _ConnectionLocal(_ConnectionState, threading.local): ... + +class ConnectionContext(_callable_context_manager): + __slots__ = ("db",) + def __init__(self, db): + self.db = db + def __enter__(self): + if self.db.is_closed(): + self.db.connect() + def __exit__(self, exc_type, exc_val, exc_tb): + self.db.close() + +class Database(_callable_context_manager): + context_class: ClassVar[Type[__TContextClass]] + field_types: ClassVar[Mapping[str, str]] + operations: ClassVar[Mapping[str, Any]] # TODO (dargueta) Verify k/v types + param: ClassVar[str] + quote: ClassVar[str] + server_version: ClassVar[Optional[Tuple[int, ...]]] + commit_select: ClassVar[bool] + compound_select_parentheses: ClassVar[int] + for_update: ClassVar[bool] + index_schema_prefix: ClassVar[bool] + index_using_precedes_table: ClassVar[bool] + limit_max: ClassVar[Optional[int]] + nulls_ordering: ClassVar[bool] + returning_clause: ClassVar[bool] + safe_create_index: ClassVar[bool] + safe_drop_index: ClassVar[bool] + sequences: ClassVar[bool] + truncate_table: ClassVar[bool] + # Instance variables + database: __IConnection + deferred: bool + autoconnect: bool + autorollback: bool + thread_safe: bool + connect_params: Mapping[str, Any] + server_version: Optional[Union[int, Tuple[int, ...]]] + def __init__( + self, + database: __IConnection, + thread_safe: bool = ..., + autorollback: bool = ..., + field_types: Optional[Mapping[str, str]] = ..., + operations: Optional[Mapping[str, str]] = ..., + autocommit: bool = ..., + autoconnect: bool = ..., + **kwargs: object, + ): ... + def init(self, database: __IConnection, **kwargs: object) -> None: ... + def __enter__(self) -> Database: ... + def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: Any) -> None: ... + def connection_context(self) -> ConnectionContext: ... + def connect(self, reuse_if_open: bool = ...) -> bool: ... + def close(self) -> bool: ... + def is_closed(self) -> bool: ... + def is_connection_usable(self) -> bool: ... + def connection(self) -> __IConnection: ... + def cursor(self, commit: Optional[bool] = ...) -> __ICursor: ... + def execute_sql(self, sql: str, params: Optional[tuple] = ..., commit: Union[bool, Literal[SENTINEL]] = ...) -> __ICursor: ... + def execute(self, query: Query, commit: Union[bool, Literal[SENTINEL]] = ..., **context_options: Any) -> __ICursor: ... + def get_context_options(self) -> Mapping[str, object]: ... + def get_sql_context(self, **context_options: Any): + context = self.get_context_options() + if context_options: + context.update(context_options) + return self.context_class(**context) + def conflict_statement(self, on_conflict: OnConflict, query: Query) -> Optional[SQL]: ... + def conflict_update(self, oc: OnConflict, query: Query) -> NodeList: ... + def last_insert_id(self, cursor: __ICursor, query_type: Optional[int] = ...) -> int: ... + def rows_affected(self, cursor: __ICursor) -> int: ... + def default_values_insert(self, ctx: Context) -> Context: ... + def session_start(self) -> _transaction: ... + def session_commit(self) -> bool: ... + def session_rollback(self) -> bool: ... + def in_transaction(self) -> bool: ... + def push_transaction(self, transaction) -> None: ... + def pop_transaction(self) -> Union[_manual, _transaction]: ... + def transaction_depth(self) -> int: ... + def top_transaction(self) -> Optional[Union[_manual, _transaction]]: ... + def atomic(self, *args: object, **kwargs: object) -> _atomic: ... + def manual_commit(self) -> _manual: ... + def transaction(self, *args: object, **kwargs: object) -> _transaction: ... + def savepoint(self) -> _savepoint: ... + def begin(self) -> None: ... + def commit(self) -> None: ... + def rollback(self) -> None: ... + def batch_commit(self, it: Iterable[T], n: int) -> Iterator[T]: ... + def table_exists(self, table_name: str, schema: Optional[str] = ...) -> str: ... + def get_tables(self, schema: Optional[str] = ...) -> List[str]: ... + def get_indexes(self, table: str, schema: Optional[str] = ...) -> List[IndexMetadata]: ... + def get_columns(self, table: str, schema: Optional[str] = ...) -> List[ColumnMetadata]: ... + def get_primary_keys(self, table: str, schema: Optional[str] = ...) -> List[str]: ... + def get_foreign_keys(self, table: str, schema: Optional[str] = ...) -> List[ForeignKeyMetadata]: ... + def sequence_exists(self, seq: str) -> bool: ... + def create_tables(self, models, **options): + for model in sort_models(models): + model.create_table(**options) + def drop_tables(self, models, **kwargs): + for model in reversed(sort_models(models)): + model.drop_table(**kwargs) + def extract_date(self, date_part: str, date_field: Node) -> Node: ... + def truncate_date(self, date_part: str, date_field: Node) -> Node: ... + def to_timestamp(self, date_field: str) -> Node: ... + def from_timestamp(self, date_field: str) -> Node: ... + def random(self) -> Node: ... + def bind(self, models: Iterable[Type[Model]], bind_refs: bool = ..., bind_backrefs: bool = ...) -> None: ... + def bind_ctx( + self, models: Iterable[Type[Model]], bind_refs: bool = ..., bind_backrefs: bool = ... + ) -> _BoundModelsContext: ... + def get_noop_select(self, ctx: Context) -> Context: ... + +def __pragma__(name): + def __get__(self): + return self.pragma(name) + def __set__(self, value): + return self.pragma(name, value) + return property(__get__, __set__) + +class SqliteDatabase(Database): + field_types: ClassVar[Mapping[str, int]] + operations: ClassVar[Mapping[str, str]] + index_schema_prefix: ClassVar[bool] + limit_max: ClassVar[int] + server_version: ClassVar[Tuple[int, ...]] + truncate_table: ClassVar[bool] + # Instance variables + timeout: int + nulls_ordering: bool + # Properties + cache_size: int + def __init__( + self, + database: str, + *args: object, + pragmas: Union[Mapping[str, object], Iterable[Tuple[str, Any]]] = ..., + **kwargs: object, + ): ... + def init( + self, + database: str, + pragmas: Optional[Union[Mapping[str, object], Iterable[Tuple[str, Any]]]] = ..., + timeout: int = ..., + **kwargs: object, + ) -> None: ... + def pragma(self, key: str, value: Union[str, bool, int] = ..., permanent: bool = ..., schema: Optional[str] = ...) -> Any: ... + foreign_keys = __pragma__("foreign_keys") + journal_mode = __pragma__("journal_mode") + journal_size_limit = __pragma__("journal_size_limit") + mmap_size = __pragma__("mmap_size") + page_size = __pragma__("page_size") + read_uncommitted = __pragma__("read_uncommitted") + synchronous = __pragma__("synchronous") + wal_autocheckpoint = __pragma__("wal_autocheckpoint") + def register_aggregate(self, klass, name=None, num_params=-1): + self._aggregates[name or klass.__name__.lower()] = (klass, num_params) + if not self.is_closed(): + self._load_aggregates(self.connection()) + def aggregate(self, name: Optional[str] = ..., num_params: int = ...) -> Callable[[__TClass], __TClass]: ... + def register_collation(self, fn: Callable, name: Optional[str] = ...) -> None: ... + def collation(self, name: Optional[str] = ...) -> Callable[[__TFunc], __TFunc]: ... + def register_function(self, fn: Callable, name: Optional[str] = ..., num_params: int = ...) -> int: ... + def func(self, name: Optional[str] = ..., num_params: int = ...) -> Callable[[__TFunc], __TFunc]: ... + def register_window_function(self, klass: type, name: Optional[str] = ..., num_params: int = ...) -> None: ... + def window_function(self, name: Optional[str] = ..., num_params: int = ...) -> Callable[[__TClass], __TClass]: ... + def register_table_function(self, klass, name=None): + if name is not None: + klass.name = name + self._table_functions.append(klass) + if not self.is_closed(): + klass.register(self.connection()) + def table_function(self, name=None): + def decorator(klass): + self.register_table_function(klass, name) + return klass + return decorator + def unregister_aggregate(self, name: str) -> None: ... + def unregister_collation(self, name: str) -> None: ... + def unregister_function(self, name: str) -> None: ... + def unregister_window_function(self, name: str) -> None: ... + def unregister_table_function(self, name: str) -> bool: ... + def load_extension(self, extension: str) -> None: ... + def unload_extension(self, extension: str) -> None: ... + def attach(self, filename: str, name: str) -> bool: ... + def detach(self, name: str) -> bool: ... + def begin(self, lock_type: Optional[str] = ...) -> None: ... + def get_views(self, schema: Optional[str] = ...) -> List[ViewMetadata]: ... + def get_binary_type(self) -> type: ... + +class PostgresqlDatabase(Database): + field_types: ClassVar[Mapping[str, str]] + operations: ClassVar[Mapping[str, str]] + param: ClassVar[str] + commit_select: ClassVar[bool] + compound_select_parentheses: ClassVar[int] + for_update: ClassVar[bool] + nulls_ordering: ClassVar[bool] + returning_clause: ClassVar[bool] + safe_create_index: ClassVar[bool] + sequences: ClassVar[bool] + # Instance variables + server_version: int + # Technically this *only* exists if we have Postgres >=9.6 and it will always be + # True in that case. + safe_create_index: bool + def init( + self, + database: __IConnection, + register_unicode: bool = ..., + encoding: Optional[str] = ..., + isolation_level: Optional[int] = ..., + **kwargs: object, + ): ... + def is_connection_usable(self) -> bool: ... + @overload + def last_insert_id(self, cursor: __ICursor, query_type: Literal[Insert.SIMPLE] = ...) -> Optional[int]: ... # I think + @overload + def last_insert_id(self, cursor: __ICursor, query_type: Optional[int] = ...) -> __ICursor: ... + def get_views(self, schema: Optional[str] = ...) -> List[ViewMetadata]: ... + def get_binary_type(self) -> type: ... + def get_noop_select(self, ctx: Context) -> SelectQuery: ... + def set_time_zone(self, timezone: str) -> None: ... + +class MySQLDatabase(Database): + field_types: ClassVar[Mapping[str, str]] + operations: ClassVar[Mapping[str, str]] + param: ClassVar[str] + quote: ClassVar[str] + commit_select: ClassVar[bool] + compound_select_parentheses: ClassVar[int] + for_update: ClassVar[bool] + index_using_precedes_table: ClassVar[bool] + limit_max = 2 ** 64 - 1 + safe_create_index: ClassVar[bool] + safe_drop_index: ClassVar[bool] + sql_mode: ClassVar[str] + # Instance variables + server_version: Tuple[int, ...] + def init(self, database: __IConnection, **kwargs: object): ... + def default_values_insert(self, ctx: Context) -> SQL: ... + def get_views(self, schema: Optional[str] = ...) -> List[ViewMetadata]: ... + def get_binary_type(self) -> type: ... + def extract_date(self, date_part, date_field): + return fn.EXTRACT(NodeList((SQL(date_part), SQL("FROM"), date_field))) + def truncate_date(self, date_part, date_field): + return fn.DATE_FORMAT(date_field, __mysql_date_trunc__[date_part], python_value=simple_date_time) + def to_timestamp(self, date_field): + return fn.UNIX_TIMESTAMP(date_field) + def from_timestamp(self, date_field): + return fn.FROM_UNIXTIME(date_field) + def random(self): + return fn.rand() + def get_noop_select(self, ctx): + return ctx.literal("DO 0") + +# TRANSACTION CONTROL. + +class _manual(_callable_context_manager): + db: Database + def __init__(self, db: Database): ... + def __enter__(self) -> None: ... + def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: Any) -> None: ... + +class _atomic(_callable_context_manager): + db: Database + def __init__(self, db: Database, *args: object, **kwargs: object): ... + def __enter__(self) -> Union[_transaction, _savepoint]: ... + def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: Any) -> None: ... + +class _transaction(_callable_context_manager): + db: Database + def __init__(self, db: Database, *args: object, **kwargs: object): ... + def commit(self, begin: bool = ...) -> None: ... + def rollback(self, begin: bool = ...) -> None: ... + def __enter__(self) -> _transaction: ... + def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: Any) -> None: ... + +class _savepoint(_callable_context_manager): + db: Database + sid: str + quoted_sid: str + def __init__(self, db: Database, sid: Optional[str] = ...): ... + def commit(self, begin: bool = ...) -> None: ... + def rollback(self) -> None: ... + def __enter__(self) -> _savepoint: ... + def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: Any) -> None: ... + +class CursorWrapper(Generic[T]): + cursor: __ICursor + count: int + index: int + initialized: bool + populated: bool + row_cache: List[T] + def __init__(self, cursor: __ICursor): ... + def __iter__(self) -> Union[ResultIterator[T], Iterator[T]]: ... + @overload + def __getitem__(self, item: int) -> T: ... + @overload + def __getitem__(self, item: slice) -> List[T]: ... + def __len__(self) -> int: ... + def initialize(self) -> None: ... + def iterate(self, cache=True): + row = self.cursor.fetchone() + if row is None: + self.populated = True + self.cursor.close() + raise StopIteration + elif not self.initialized: + self.initialize() # Lazy initialization. + self.initialized = True + self.count += 1 + result = self.process_row(row) + if cache: + self.row_cache.append(result) + return result + def process_row(self, row: tuple) -> T: ... + def iterator(self): + """Efficient one-pass iteration over the result set.""" + while True: + try: + yield self.iterate(False) + except StopIteration: + return + def fill_cache(self, n: int = 0) -> None: ... + +class DictCursorWrapper(CursorWrapper[Mapping[str, object]]): ... + +class NamedTupleCursorWrapper(CursorWrapper[NamedTuple]): + tuple_class: Type[NamedTuple] + +# TODO: Indicate this inherits from DictCursorWrapper but also change the return type +class ObjectCursorWrapper(DictCursorWrapper): + def __init__(self, cursor, constructor): + super(ObjectCursorWrapper, self).__init__(cursor) + self.constructor = constructor + def process_row(self, row): + row_dict = self._row_to_dict(row) + return self.constructor(**row_dict) + +class ResultIterator(Generic[T]): + cursor_wrapper: CursorWrapper[T] + index: int + def __init__(self, cursor_wrapper: CursorWrapper[T]): ... + def __iter__(self) -> Iterator[T]: ... + +# FIELDS + +class FieldAccessor(object): + model: Type[Model] + field: Field + name: str + def __init__(self, model: Type[Model], field: Field, name: str): ... + @overload + def __get__(self, instance: None, instance_type: type) -> Field: ... + @overload + def __get__(self, instance: T, instance_type: Type[T]) -> Any: ... + +class ForeignKeyAccessor(FieldAccessor): + model: Type[Model] + field: ForeignKeyField + name: str + rel_model: Type[Model] + def __init__(self, model: Type[Model], field: ForeignKeyField, name: str): ... + def get_rel_instance(self, instance: Model) -> Any: + value = instance.__data__.get(self.name) + if value is not None or self.name in instance.__rel__: + if self.name not in instance.__rel__: + obj = self.rel_model.get(self.field.rel_field == value) + instance.__rel__[self.name] = obj + return instance.__rel__[self.name] + elif not self.field.null: + raise self.rel_model.DoesNotExist + return value + def __get__(self, instance, instance_type=None): + if instance is not None: + return self.get_rel_instance(instance) + return self.field + def __set__(self, instance, obj): + if isinstance(obj, self.rel_model): + instance.__data__[self.name] = getattr(obj, self.field.rel_field.name) + instance.__rel__[self.name] = obj + else: + fk_value = instance.__data__.get(self.name) + instance.__data__[self.name] = obj + if obj != fk_value and self.name in instance.__rel__: + del instance.__rel__[self.name] + instance._dirty.add(self.name) + +class NoQueryForeignKeyAccessor(ForeignKeyAccessor): + def get_rel_instance(self, instance: Model) -> Any: + value = instance.__data__.get(self.name) + if value is not None: + return instance.__rel__.get(self.name, value) + elif not self.field.null: + raise self.rel_model.DoesNotExist + +class BackrefAccessor(object): + field: ForeignKeyField + model: Type[Model] + rel_model: Type[Model] + def __init__(self, field: ForeignKeyField): ... + @overload + def __get__(self, instance: None, instance_type: type) -> BackrefAccessor: ... + @overload + def __get__(self, instance: Field, instance_type: Type["Field"]) -> SelectQuery: ... + +class ObjectIdAccessor(object): + """Gives direct access to the underlying id""" + + field: ForeignKeyField + def __init__(self, field: ForeignKeyField): ... + @overload + def __get__(self, instance: None, instance_type: Type[Model]) -> ForeignKeyField: ... + @overload + def __get__(self, instance: __TModel, instance_type: Type[__TModel] = ...) -> Any: ... + def __set__(self, instance: Model, value: Any) -> None: ... + +class Field(ColumnBase): + accessor_class: ClassVar[Type[FieldAccessor]] + auto_increment: ClassVar[bool] + default_index_type: ClassVar[Optional[str]] + field_type: ClassVar[str] + unpack: ClassVar[bool] + # Instance variables + model: Type[Model] + null: bool + index: bool + unique: bool + column_name: str + default: Any + primary_key: bool + constraints: Optional[Iterable[Check, SQL]] + sequence: Optional[str] + collation: Optional[str] + unindexed: bool + help_text: Optional[str] + verbose_name: Optional[str] + index_type: Optional[str] + def __init__( + self, + null: bool = ..., + index: bool = ..., + unique: bool = ..., + column_name: str = ..., + default: Any = ..., + primary_key: bool = ..., + constraints: Optional[Iterable[Check, SQL]] = ..., + sequence: Optional[str] = ..., + collation: Optional[str] = ..., + unindexed: Optional[bool] = ..., + choices: Optional[Iterable[Tuple[Any, str]]] = ..., + help_text: Optional[str] = ..., + verbose_name: Optional[str] = ..., + index_type: Optional[str] = ..., + db_column: Optional[str] = ..., # Deprecated argument, undocumented + _hidden: bool = ..., + ): ... + def __hash__(self) -> int: ... + def __repr__(self) -> str: ... + def bind(self, model: Type[Model], name: str, set_attribute: bool = ...) -> None: ... + @property + def column(self) -> Column: ... + def adapt(self, value: T) -> T: ... + def db_value(self, value: T) -> T: ... + def python_value(self, value: T) -> T: ... + def to_value(self, value: Any) -> Value: ... + def get_sort_key(self, ctx: Context) -> Tuple[int, int]: ... + def __sql__(self, ctx: Context) -> Context: ... + def get_modifiers(self) -> None: ... + def ddl_datatype(self, ctx: Context) -> SQL: ... + def ddl(self, ctx: Context) -> NodeList: ... + +class IntegerField(Field): + @overload + def adapt(self, value: Union[int, str, float, bool]) -> int: ... + @overload + def adapt(self, value: T) -> T: ... + +class BigIntegerField(IntegerField): ... +class SmallIntegerField(IntegerField): ... + +class AutoField(IntegerField): + def __init__(self, *args: object, primary_key: bool = ..., **kwargs: object): ... + +class BigAutoField(AutoField): ... + +class IdentityField(AutoField): + def __init__(self, generate_always: bool = ..., **kwargs: object): ... + +class PrimaryKeyField(AutoField): ... + +class FloatField(Field): + @overload + def adapt(self, value: Union[str, int, float, bool]) -> float: ... + @overload + def adapt(self, value: T) -> T: ... + +class DoubleField(FloatField): ... + +class DecimalField(Field): + max_digits: int + decimal_places: int + auto_round: int + rounding: bool + def __init__( + self, + max_digits: int = ..., + decimal_places: int = ..., + auto_round: bool = ..., + rounding: bool = ..., + *args: object, + **kwargs: object, + ): ... + def get_modifiers(self) -> List[int]: ... + @overload + def db_value(self, value: None) -> None: ... + @overload + def db_value(self, value: Union[int, float, decimal.Decimal]) -> decimal.Decimal: ... + @overload + def db_value(self, value: T) -> T: ... + @overload + def python_value(self, value: None) -> None: ... + @overload + def python_value(self, value: Union[int, str, float, decimal.Decimal]) -> decimal.Decimal: ... + +class _StringField(Field): + def adapt(self, value: AnyStr) -> str: ... + def __add__(self, other: Any) -> StringExpression: ... + def __radd__(self, other: Any) -> StringExpression: ... + +class CharField(_StringField): + max_length: int + def __init__(self, max_length: int = ..., *args: object, **kwargs: object): ... + def get_modifiers(self) -> Optional[List[int]]: ... + +class FixedCharField(CharField): ... +class TextField(_StringField): ... + +class BlobField(Field): + @overload + def db_value(self, value: Union[str, bytes]) -> bytearray: ... + @overload + def db_value(self, value: T) -> T: ... + +class BitField(BitwiseMixin, BigIntegerField): + def __init__(self, *args: object, default: Optional[int] = ..., **kwargs: object): ... + # FIXME (dargueta) Return type isn't 100% accurate; function creates a new class + def flag(self, value: Optional[int] = ...) -> ColumnBase: ... + +class BigBitFieldData(object): + def __init__(self, instance, name): + self.instance = instance + self.name = name + value = self.instance.__data__.get(self.name) + if not value: + value = bytearray() + elif not isinstance(value, bytearray): + value = bytearray(value) + self._buffer = self.instance.__data__[self.name] = value + def set_bit(self, idx: int) -> None: ... + def clear_bit(self, idx: bool) -> None: ... + def toggle_bit(self, idx: int) -> bool: ... + def is_set(self, idx: int) -> bool: ... + def __repr__(self) -> str: ... + +class BigBitFieldAccessor(FieldAccessor): + def __get__(self, instance, instance_type=None): + if instance is None: + return self.field + return BigBitFieldData(instance, self.name) + def __set__(self, instance: Any, value: Union[memoryview, bytearray, BigBitFieldData, str, bytes]) -> None: ... + +class BigBitField(BlobField): + accessor_class: ClassVar[Type[BigBitFieldAccessor]] + def __init__(self, *args: object, default: type = ..., **kwargs: object): ... + @overload + def db_value(self, value: None) -> None: ... + @overload + def db_value(self, value: T) -> bytes: ... + +class UUIDField(Field): + @overload + def db_value(self, value: AnyStr) -> str: ... + @overload + def db_value(self, value: T) -> T: ... + def python_value(self, value): + if isinstance(value, uuid.UUID): + return value + return uuid.UUID(value) if value is not None else None + +class BinaryUUIDField(BlobField): + @overload + def db_value(self, value: None) -> None: ... + @overload + def db_value(self, value: Optional[Union[bytearray, bytes, str, uuid.UUID]]) -> bytes: ... + @overload + def python_value(self, value: None) -> None: ... + @overload + def python_value(self, value: Union[bytearray, bytes, memoryview, uuid.UUID]) -> uuid.UUID: ... + +def format_date_time(value: str, formats: Iterable[str], post_process: Optional[__TConvFunc] = ...) -> str: ... +@overload +def simple_date_time(value: T) -> T: ... + +class _BaseFormattedField(Field): + # TODO (dargueta): This is a class variable that can be overridden for instances + formats: Optional[Container[str]] + def __init__(self, formats: Optional[Container[str]] = ..., *args: object, **kwargs: object): ... + +class DateTimeField(_BaseFormattedField): + @property + def year(self) -> int: ... + @property + def month(self) -> int: ... + @property + def day(self) -> int: ... + @property + def hour(self) -> int: ... + @property + def minute(self) -> int: ... + @property + def second(self) -> int: ... + @overload + def adapt(self, value: str) -> str: ... + @overload + def adapt(self, value: T) -> T: ... + def to_timestamp(self): + return self.model._meta.database.to_timestamp(self) + def truncate(self, part): + return self.model._meta.database.truncate_date(part, self) + +class DateField(_BaseFormattedField): + @property + def year(self) -> int: ... + @property + def month(self) -> int: ... + @property + def day(self) -> int: ... + @overload + def adapt(self, value: str) -> str: ... + @overload + def adapt(self, value: datetime.datetime) -> datetime.date: ... + @overload + def adapt(self, value: T) -> T: ... + def to_timestamp(self): + return self.model._meta.database.to_timestamp(self) + def truncate(self, part): + return self.model._meta.database.truncate_date(part, self) + +class TimeField(_BaseFormattedField): + @overload + def adapt(self, value: str) -> str: ... + @overload + def adapt(self, value: Union[datetime.datetime, datetime.timedelta]) -> datetime.time: ... + @overload + def adapt(self, value: T) -> T: ... + @property + def hour(self) -> int: ... + @property + def minute(self) -> int: ... + @property + def second(self) -> int: ... + +class TimestampField(BigIntegerField): + valid_resolutions: ClassVar[Container[int]] + # Instance variables + resolution: int + ticks_to_microsecond: int + utc: bool + def __init__(self, *args: object, resolution: int = ..., utc: bool = ..., **kwargs: object): ... + def local_to_utc(self, dt: datetime.datetime) -> datetime.datetime: ... + def utc_to_local(self, dt: datetime.datetime) -> datetime.datetime: ... + def get_timestamp(self, value): + if self.utc: + # If utc-mode is on, then we assume all naive datetimes are in UTC. + return calendar.timegm(value.utctimetuple()) + else: + return time.mktime(value.timetuple()) + def db_value(self, value): + if value is None: + return + + if isinstance(value, datetime.datetime): ... + elif isinstance(value, datetime.date): + value = datetime.datetime(value.year, value.month, value.day) + else: + return int(round(value * self.resolution)) + + timestamp = self.get_timestamp(value) + if self.resolution > 1: + timestamp += value.microsecond * 0.000001 + timestamp *= self.resolution + return int(round(timestamp)) + @overload + def python_value(self, value: Union[int, float]) -> datetime.datetime: ... + @overload + def python_value(self, value: T) -> T: ... + def from_timestamp(self): + expr = (self / Value(self.resolution, converter=False)) if self.resolution > 1 else self + return self.model._meta.database.from_timestamp(expr) + @property + def year(self) -> int: ... + @property + def month(self) -> int: ... + @property + def day(self) -> int: ... + @property + def hour(self) -> int: ... + @property + def minute(self) -> int: ... + @property + def second(self) -> int: ... # TODO (dargueta) Float? + +class IPField(BigIntegerField): + @overload + def db_value(self, val: str) -> int: ... + @overload + def db_value(self, val: None) -> None: ... + @overload + def python_value(self, val: int) -> str: ... + @overload + def python_value(self, val: None) -> None: ... + +class BooleanField(Field): + def adapt(self, value: Any) -> bool: ... + +class BareField(Field): + def __init__(self, adapt=None, *args, **kwargs): + super(BareField, self).__init__(*args, **kwargs) + if adapt is not None: + self.adapt = adapt + def ddl_datatype(self, ctx): + return + +class ForeignKeyField(Field): + accessor_class = ForeignKeyAccessor + def __init__( + self, + model, + field=None, + backref=None, + on_delete=None, + on_update=None, + deferrable=None, + _deferred=None, + rel_model=None, + to_field=None, + object_id_name=None, + lazy_load=True, + related_name=None, + *args, + **kwargs, + ): + kwargs.setdefault("index", True) + + # If lazy_load is disable, we use a different descriptor/accessor that + # will ensure we don't accidentally perform a query. + if not lazy_load: + self.accessor_class = NoQueryForeignKeyAccessor + + super(ForeignKeyField, self).__init__(*args, **kwargs) + + self._is_self_reference = model == "self" + self.rel_model = model + self.rel_field = field + self.declared_backref = backref + self.backref = None + self.on_delete = on_delete + self.on_update = on_update + self.deferrable = deferrable + self.deferred = _deferred + self.object_id_name = object_id_name + self.lazy_load = lazy_load + @property + def field_type(self): + if not isinstance(self.rel_field, AutoField): + return self.rel_field.field_type + elif isinstance(self.rel_field, BigAutoField): + return BigIntegerField.field_type + return IntegerField.field_type + def get_modifiers(self): + if not isinstance(self.rel_field, AutoField): + return self.rel_field.get_modifiers() + return super(ForeignKeyField, self).get_modifiers() + def adapt(self, value): + return self.rel_field.adapt(value) + def db_value(self, value): + if isinstance(value, self.rel_model): + value = getattr(value, self.rel_field.name) + return self.rel_field.db_value(value) + def python_value(self, value): + if isinstance(value, self.rel_model): + return value + return self.rel_field.python_value(value) + def bind(self, model, name, set_attribute=True): + if not self.column_name: + self.column_name = name if name.endswith("_id") else name + "_id" + if not self.object_id_name: + self.object_id_name = self.column_name + if self.object_id_name == name: + self.object_id_name += "_id" + elif self.object_id_name == name: + raise ValueError( + 'ForeignKeyField "%s"."%s" specifies an ' + "object_id_name that conflicts with its field " + "name." % (model._meta.name, name) + ) + if self._is_self_reference: + self.rel_model = model + if isinstance(self.rel_field, str): + self.rel_field = getattr(self.rel_model, self.rel_field) + elif self.rel_field is None: + self.rel_field = self.rel_model._meta.primary_key + + # Bind field before assigning backref, so field is bound when + # calling declared_backref() (if callable). + super(ForeignKeyField, self).bind(model, name, set_attribute) + self.safe_name = self.object_id_name + + if callable(self.declared_backref): + self.backref = self.declared_backref(self) + else: + self.backref, self.declared_backref = self.declared_backref, None + if not self.backref: + self.backref = "%s_set" % model._meta.name + + if set_attribute: + setattr(model, self.object_id_name, ObjectIdAccessor(self)) + if self.backref not in "!+": + setattr(self.rel_model, self.backref, BackrefAccessor(self)) + def foreign_key_constraint(self): + parts = [ + SQL("FOREIGN KEY"), + EnclosedNodeList((self,)), + SQL("REFERENCES"), + self.rel_model, + EnclosedNodeList((self.rel_field,)), + ] + if self.on_delete: + parts.append(SQL("ON DELETE %s" % self.on_delete)) + if self.on_update: + parts.append(SQL("ON UPDATE %s" % self.on_update)) + if self.deferrable: + parts.append(SQL("DEFERRABLE %s" % self.deferrable)) + return NodeList(parts) + def __getattr__(self, attr): + if attr.startswith("__"): + # Prevent recursion error when deep-copying. + raise AttributeError('Cannot look-up non-existant "__" methods.') + if attr in self.rel_model._meta.fields: + return self.rel_model._meta.fields[attr] + raise AttributeError("Foreign-key has no attribute %s, nor is it a " "valid field on the related model." % attr) + +class DeferredForeignKey(Field): + _unresolved = set() + def __init__(self, rel_model_name, **kwargs): + self.field_kwargs = kwargs + self.rel_model_name = rel_model_name.lower() + DeferredForeignKey._unresolved.add(self) + super(DeferredForeignKey, self).__init__(column_name=kwargs.get("column_name"), null=kwargs.get("null")) + __hash__ = object.__hash__ + def __deepcopy__(self, memo=None): + return DeferredForeignKey(self.rel_model_name, **self.field_kwargs) + def set_model(self, rel_model): + field = ForeignKeyField(rel_model, _deferred=True, **self.field_kwargs) + self.model._meta.add_field(self.name, field) + @staticmethod + def resolve(model_cls): + unresolved = sorted(DeferredForeignKey._unresolved, key=operator.attrgetter("_order")) + for dr in unresolved: + if dr.rel_model_name == model_cls.__name__.lower(): + dr.set_model(model_cls) + DeferredForeignKey._unresolved.discard(dr) + +class DeferredThroughModel(object): + def __init__(self): + self._refs = [] + def set_field(self, model, field, name): + self._refs.append((model, field, name)) + def set_model(self, through_model): + for src_model, m2mfield, name in self._refs: + m2mfield.through_model = through_model + src_model._meta.add_field(name, m2mfield) + +class MetaField(Field): + column_name = default = model = name = None + primary_key = False + +class ManyToManyFieldAccessor(FieldAccessor): + model: Type[Model] + rel_model: Type[Model] + through_model: Type[Model] + src_fk: ForeignKeyField + dest_fk: ForeignKeyField + def __init__(self, model: Type[Model], field: ForeignKeyField, name: str): ... + @overload + def __get__(self, instance: None, instance_type: Type[T] = ..., force_query: bool = ...) -> Field: ... + @overload + def __get__( + self, instance: T, instance_type: Type[T] = ..., force_query: bool = ... + ) -> Union[List[str], ManyToManyQuery]: ... + def __set__(self, instance: T, value) -> None: + query = self.__get__(instance, force_query=True) + query.add(value, clear_existing=True) + +class ManyToManyField(MetaField): + accessor_class: ClassVar[Type[ManyToManyFieldAccessor]] + # Instance variables + through_model: Union[Type[Model], DeferredThroughModel] + rel_model: Type[Model] + backref: Optional[str] + def __init__( + self, + model: Type[Model], + backref: Optional[str] = ..., + through_model: Optional[Union[Type[Model], DeferredThroughModel]] = ..., + on_delete: Optional[str] = ..., + on_update: Optional[str] = ..., + _is_backref: bool = ..., + ): ... + def bind(self, model: Type[Model], name: str, set_attribute: bool = ...) -> None: ... + def get_models(self) -> List[Type[Model]]: ... + def get_through_model(self) -> Union[Type[Model], DeferredThroughModel]: ... + +class VirtualField(MetaField, Generic[__TField]): + field_class: Type[__TField] + field_instance: __TField + def __init__(self, field_class: Optional[Type[__TField]] = ..., *args: object, **kwargs: object): ... + def db_value(self, value): + if self.field_instance is not None: + return self.field_instance.db_value(value) + return value + def python_value(self, value): + if self.field_instance is not None: + return self.field_instance.python_value(value) + return value + def bind(self, model: Type[Model], name: str, set_attribute: bool = ...) -> None: ... + +class CompositeKey(MetaField): + sequence = None + def __init__(self, *field_names): + self.field_names = field_names + self._safe_field_names = None + @property + def safe_field_names(self): + if self._safe_field_names is None: + if self.model is None: + return self.field_names + + self._safe_field_names = [self.model._meta.fields[f].safe_name for f in self.field_names] + return self._safe_field_names + def __get__(self, instance, instance_type=None): + if instance is not None: + return tuple([getattr(instance, f) for f in self.safe_field_names]) + return self + def __set__(self, instance, value): + if not isinstance(value, (list, tuple)): + raise TypeError("A list or tuple must be used to set the value of " "a composite primary key.") + if len(value) != len(self.field_names): + raise ValueError("The length of the value must equal the number " "of columns of the composite primary key.") + for idx, field_value in enumerate(value): + setattr(instance, self.field_names[idx], field_value) + def __eq__(self, other): + expressions = [(self.model._meta.fields[field] == value) for field, value in zip(self.field_names, other)] + return reduce(operator.and_, expressions) + def __ne__(self, other): + return ~(self == other) + def __hash__(self): + return hash((self.model.__name__, self.field_names)) + def __sql__(self, ctx): + # If the composite PK is being selected, do not use parens. Elsewhere, + # such as in an expression, we want to use parentheses and treat it as + # a row value. + parens = ctx.scope != SCOPE_SOURCE + return ctx.sql(NodeList([self.model._meta.fields[field] for field in self.field_names], ", ", parens)) + def bind(self, model, name, set_attribute=True): + self.model = model + self.column_name = self.name = self.safe_name = name + setattr(model, self.name, self) + +class _SortedFieldList(object): + __slots__ = ("_keys", "_items") + def __init__(self): + self._keys = [] + self._items = [] + def __getitem__(self, i): + return self._items[i] + def __iter__(self): + return iter(self._items) + def __contains__(self, item): + k = item._sort_key + i = bisect_left(self._keys, k) + j = bisect_right(self._keys, k) + return item in self._items[i:j] + def index(self, field): + return self._keys.index(field._sort_key) + def insert(self, item): + k = item._sort_key + i = bisect_left(self._keys, k) + self._keys.insert(i, k) + self._items.insert(i, item) + def remove(self, item): + idx = self.index(item) + del self._items[idx] + del self._keys[idx] + +# MODELS + +class SchemaManager(object): + model: Type[Model] + context_options: Dict[str, Any] + def __init__(self, model: Type[Model], database: Optional[Database] = None, **context_options: Any): ... + @property + def database(self) -> Database: ... + @database.setter + def database(self, value: Optional[Database]) -> None: ... + def create_table(self, safe: bool = ..., **options: Any) -> None: ... + def create_table_as(self, table_name: str, query: SelectQuery, safe: bool = ..., **meta: Any) -> None: ... + def drop_table(self, safe: bool = ..., **options: Any) -> None: ... + def truncate_table(self, restart_identity: bool = ..., cascade: bool = ...) -> None: ... + def create_indexes(self, safe: bool = ...) -> None: ... + def drop_indexes(self, safe: bool = ...) -> None: ... + def create_sequence(self, field: Field) -> None: ... + def drop_sequence(self, field: Field) -> None: ... + def create_foreign_key(self, field: Field) -> None: ... + def create_sequences(self) -> None: ... + def create_all(self, safe: bool = ..., **table_options: Any) -> None: ... + def drop_sequences(self) -> None: ... + def drop_all(self, safe: bool = ..., drop_sequences: bool = ..., **options: Any) -> None: ... + +class Metadata(object): + model: Type[Model] + database: Optional[Database] + fields: Dict[str, Any] # TODO (dargueta) This may be Dict[str, Field] + columns: Dict[str, Any] # TODO (dargueta) Verify this + combined: Dict[str, Any] # TODO (dargueta) Same as above + sorted_fields: List[Field] + sorted_field_names: List[str] + defaults: Dict[str, Any] + name: str + table_function: Optional[Callable[[Type[Model]], str]] + legacy_table_names: bool + table_name: str + indexes: List[Union[Index, ModelIndex, SQL]] + constraints: Optional[Iterable[Union[Check, SQL]]] + primary_key: Union[Literal[False], Field, CompositeKey, None] + composite_key: Optional[bool] + auto_increment: Optional[bool] + only_save_dirty: bool + depends_on: Optional[Sequence[Type[Model]]] + table_settings: Mapping[str, object] + temporary: bool + refs: Dict[ForeignKeyField, Type[Model]] + backrefs: MutableMapping[ForeignKeyField, List[Type[Model]]] + model_refs: MutableMapping[Type[Model], List[ForeignKeyField]] + model_backrefs: MutableMapping[ForeignKeyField, List[Type[Model]]] + manytomany: Dict[str, ManyToManyField] + options: Mapping[str, object] + table: Optional[Table] + entity: Entity + def __init__( + self, + model: Type[Model], + database: Optional[Database] = ..., + table_name: Optional[str] = ..., + indexes: Optional[Iterable[Union[str, Sequence[str]]]] = ..., + primary_key: Optional[Union[Literal[False], Field, CompositeKey]] = ..., + constraints: Optional[Iterable[Union[Check, SQL]]] = ..., + schema: Optional[str] = ..., + only_save_dirty: bool = ..., + depends_on: Optional[Sequence[Type[Model]]] = ..., + options: Optional[Mapping[str, object]] = ..., + db_table: Optional[str] = ..., + table_function: Optional[Callable[[Type[Model]], str]] = ..., + table_settings: Optional[Mapping[str, object]] = ..., + without_rowid: bool = ..., + temporary: bool = ..., + legacy_table_names: bool = ..., + **kwargs: object, + ): ... + def make_table_name(self) -> str: ... + def model_graph( + self, refs: bool = ..., backrefs: bool = ..., depth_first: bool = ... + ) -> List[Tuple[ForeignKeyField, Type[Model], bool]]: ... + def add_ref(self, field: ForeignKeyField) -> None: ... + def remove_ref(self, field: ForeignKeyField) -> None: ... + def add_manytomany(self, field: ManyToManyField) -> None: ... + def remove_manytomany(self, field: ManyToManyField) -> None: ... + def get_rel_for_model(self, model: Union[Type[Model], ModelAlias]) -> Tuple[List[ForeignKeyField], List[Type[Model]]]: ... + def add_field(self, field_name: str, field: Field, set_attribute: bool = ...) -> None: ... + def remove_field(self, field_name: str) -> None: ... + def set_primary_key(self, name: str, field: Union[Field, CompositeKey]) -> None: ... + def get_primary_keys(self) -> Tuple[Field, ...]: ... + def get_default_dict(self) -> Dict[str, object]: ... + def fields_to_index(self) -> List[ModelIndex]: ... + def set_database(self, database: Database) -> None: ... + def set_table_name(self, table_name: str) -> None: ... + +class SubclassAwareMetadata(Metadata): + models: ClassVar[List[Type[Model]]] + def __init__(self, model: Type[Model], *args: object, **kwargs: object): ... + def map_models(self, fn: Callable[[Type[Model]], Any]) -> None: ... + +class DoesNotExist(Exception): ... + +class ModelBase(type): + inheritable: ClassVar[Set[str]] + def __repr__(self) -> str: ... + def __iter__(self) -> Iterator[Any]: ... + def __getitem__(self, key: object) -> Model: ... + def __setitem__(self, key: object, value: Model) -> None: ... + def __delitem__(self, key: object) -> None: ... + def __contains__(self, key: object) -> bool: ... + def __len__(self) -> int: ... + def __bool__(self) -> bool: ... + def __nonzero__(self) -> bool: ... + def __sql__(self, ctx: Context) -> Context: ... + +class _BoundModelsContext(_callable_context_manager): + models: Iterable[Type[Model]] + database: Database + bind_refs: bool + bind_backrefs: bool + def __init__(self, models: Iterable[Type[Model]], database, bind_refs: bool, bind_backrefs: bool): ... + def __enter__(self) -> Iterable[Type[Model]]: ... + def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: Any) -> None: ... + +class Model(Node, metaclass=ModelBase): + _meta: ClassVar[Metadata] + _schema: ClassVar[SchemaManager] + DoesNotExist: ClassVar[Type[DoesNotExist]] + __data__: MutableMapping[str, object] + __rel__: MutableMapping[str, object] + def __init__(self, *, __no_default__: Union[int, bool] = ..., **kwargs: object): ... + def __str__(self) -> str: ... + @classmethod + def validate_model(cls) -> None: ... + @classmethod + def alias(cls, alias: Optional[str] = ...) -> ModelAlias: ... + @classmethod + def select(cls, *fields: Field) -> ModelSelect: ... + @classmethod + def update(cls, __data: Optional[Iterable[Union[str, Field]]] = ..., **update: Any) -> ModelUpdate: ... + @classmethod + def insert(cls, __data: Optional[Iterable[Union[str, Field]]] = ..., **insert: Any) -> ModelInsert: ... + @overload + @classmethod + def insert_many(cls, rows: Iterable[Mapping[str, object]], fields: None) -> ModelInsert: ... + @overload + @classmethod + def insert_many(cls, rows: Iterable[tuple], fields: Sequence[Field]) -> ModelInsert: ... + @classmethod + def insert_from(cls, query: SelectQuery, fields: Iterable[Union[Field, Text]]) -> ModelInsert: ... + @classmethod + def replace(cls, __data=None, **insert): + return cls.insert(__data, **insert).on_conflict("REPLACE") + @classmethod + def replace_many(cls, rows, fields=None): + return cls.insert_many(rows=rows, fields=fields).on_conflict("REPLACE") + @classmethod + def raw(cls, sql, *params): + return ModelRaw(cls, sql, params) + @classmethod + def delete(cls) -> ModelDelete: ... + @classmethod + def create(cls: Type[T], **query) -> T: ... + @classmethod + def bulk_create(cls, model_list: Iterable[Type[Model]], batch_size: Optional[int] = ...) -> None: ... + @classmethod + def bulk_update( + cls, model_list: Iterable[Type[Model]], fields: Iterable[Union[str, Field]], batch_size: Optional[int] = ... + ) -> int: ... + @classmethod + def noop(cls) -> NoopModelSelect: ... + @classmethod + def get(cls, *query, **filters): + sq = cls.select() + if query: + # Handle simple lookup using just the primary key. + if len(query) == 1 and isinstance(query[0], int): + sq = sq.where(cls._meta.primary_key == query[0]) + else: + sq = sq.where(*query) + if filters: + sq = sq.filter(**filters) + return sq.get() + @classmethod + def get_or_none(cls, *query, **filters): + try: + return cls.get(*query, **filters) + except DoesNotExist: ... + @classmethod + def get_by_id(cls, pk): + return cls.get(cls._meta.primary_key == pk) + @classmethod + def set_by_id(cls, key, value) -> Any: # TODO (dargueta): Verify return type of .execute() + if key is None: + return cls.insert(value).execute() + else: + return cls.update(value).where(cls._meta.primary_key == key).execute() + @classmethod + def delete_by_id(cls, pk: object) -> Any: ... # TODO (dargueta): Verify return type of .execute() + @classmethod + def get_or_create(cls, *, defaults: Mapping[str, object] = ..., **kwargs: object) -> Tuple[Any, bool]: ... + @classmethod + def filter(cls, *dq_nodes: DQ, **filters: Any) -> SelectQuery: ... + def get_id(self) -> Any: ... + def save(self, force_insert: bool = ..., only: Optional[Iterable[Union[str, Field]]] = ...) -> Union[Literal[False], int]: ... + def is_dirty(self) -> bool: ... + @property + def dirty_fields(self) -> List[Field]: ... + def dependencies(self, search_nullable: bool = ...) -> Iterator[Tuple[Union[bool, Node], ForeignKeyField]]: ... + def delete_instance(self: T, recursive: bool = ..., delete_nullable: bool = ...) -> T: ... + def __hash__(self) -> int: ... + def __eq__(self, other: Any) -> bool: ... + def __ne__(self, other: Any) -> bool: ... + def __sql__(self, ctx: Context) -> Context: ... + @classmethod + def bind( + cls, + database: Database, + bind_refs: bool = ..., + bind_backrefs: bool = ..., + _exclude: Optional[MutableSet[Type[Model]]] = ..., + ) -> bool: ... + @classmethod + def bind_ctx(cls, database: Database, bind_refs: bool = ..., bind_backrefs: bool = ...) -> _BoundModelsContext: ... + @classmethod + def table_exists(cls) -> bool: ... + @classmethod + def create_table(cls, safe: bool = ..., *, fail_silently: bool = ..., **options: object) -> None: ... + @classmethod + def drop_table(cls, safe: bool = ..., drop_sequences: bool = ..., **options: object) -> None: ... + @classmethod + def truncate_table(cls, **options: object) -> None: ... + @classmethod + def index(cls, *fields, **kwargs): + return ModelIndex(cls, fields, **kwargs) + @classmethod + def add_index(cls, *fields: Union[str, SQL, Index], **kwargs: object) -> None: ... + +class ModelAlias(Node): + """Provide a separate reference to a model in a query.""" + + model: Type[Model] + alias: Optional[str] + def __init__(self, model: Type[Model], alias: Optional[str] = ...): ... + def __getattr__(self, attr: str): + # Hack to work-around the fact that properties or other objects + # implementing the descriptor protocol (on the model being aliased), + # will not work correctly when we use getattr(). So we explicitly pass + # the model alias to the descriptor's getter. + try: + obj = self.model.__dict__[attr] + except KeyError: ... + else: + if isinstance(obj, ModelDescriptor): + return obj.__get__(None, self) + + model_attr = getattr(self.model, attr) + if isinstance(model_attr, Field): + self.__dict__[attr] = FieldAlias.create(self, model_attr) + return self.__dict__[attr] + return model_attr + def __setattr__(self, attr: str, value: object) -> NoReturn: ... + def get_field_aliases(self) -> List[Field]: ... + def select(self, *selection: Field) -> ModelSelect: ... + def __call__(self, **kwargs): + return self.model(**kwargs) + def __sql__(self, ctx: Context) -> Context: ... + +class FieldAlias(Field): + source: Node + model: Type[Model] + field: Field + # TODO (dargueta): Making an educated guess about `source`; might be `Node` + def __init__(self, source: MetaField, field: Field): ... + @classmethod + def create(cls, source: ModelAlias, field: str): + class _FieldAlias(cls, type(field)): ... + return _FieldAlias(source, field) + def clone(self) -> FieldAlias: ... + def adapt(self, value): + return self.field.adapt(value) + def python_value(self, value): + return self.field.python_value(value) + def db_value(self, value): + return self.field.db_value(value) + def __getattr__(self, attr): + return self.source if attr == "model" else getattr(self.field, attr) + def __sql__(self, ctx: Context) -> Context: ... + +def sort_models(models: Iterable[Type[Model]]) -> List[Type[Model]]: ... + +class _ModelQueryHelper(object): + default_row_type: ClassVar[int] + def objects(self, constructor: Optional[Callable[..., Any]] = ...) -> _ModelQueryHelper: ... + +class ModelRaw(_ModelQueryHelper, RawQuery, Generic[__TModel]): + model: Type[__TModel] + def __init__(self, model: Type[__TModel], sql: str, params: tuple, **kwargs: object): ... + def get(self) -> __TModel: ... + +class BaseModelSelect(_ModelQueryHelper): + def union_all(self, rhs): + return ModelCompoundSelectQuery(self.model, self, "UNION ALL", rhs) + __add__ = union_all + def union(self, rhs): + return ModelCompoundSelectQuery(self.model, self, "UNION", rhs) + __or__ = union + def intersect(self, rhs): + return ModelCompoundSelectQuery(self.model, self, "INTERSECT", rhs) + __and__ = intersect + def except_(self, rhs): + return ModelCompoundSelectQuery(self.model, self, "EXCEPT", rhs) + __sub__ = except_ + def __iter__(self) -> Iterator[Any]: + if not self._cursor_wrapper: + self.execute() + return iter(self._cursor_wrapper) + def prefetch(self, *subqueries: __TSubquery) -> List[Any]: ... + def get(self, database: Optional[Database] = ...) -> Any: ... + def group_by(self, *columns: Union[Type[Model], Table, Field]) -> BaseModelSelect: ... + +class ModelCompoundSelectQuery(BaseModelSelect, CompoundSelectQuery): + model: Type[__TModel] + def __init__(self, model: Type[__TModel], *args: object, **kwargs: object): ... + +class ModelSelect(BaseModelSelect, Select): + model: Type[Model] + def __init__(self, model: Type[Model], fields_or_models: Iterable[__TFieldOrModel], is_default: bool = ...): ... + def clone(self) -> ModelSelect: ... + def select(self, *fields_or_models: __TFieldOrModel) -> ModelSelect: ... + def switch(self, ctx: Optional[Type[Model]] = ...) -> ModelSelect: ... + def join( + self, + dest: Union[Type[Model], Table, ModelAlias, ModelSelect], + join_type: int = ..., + on: Union[Column, Expression, Field, None] = ..., + src: Union[Type[Model], Table, ModelAlias, ModelSelect, None] = ..., + attr: Optional[str] = ..., + ) -> ModelSelect: ... + def join_from( + self, + src: Union[Type[Model], Table, ModelAlias, ModelSelect], + dest: Union[Type[Model], Table, ModelAlias, ModelSelect], + join_type: int = ..., + on: Union[Column, Expression, Field, None] = ..., + attr: Optional[str] = ..., + ) -> ModelSelect: ... + def ensure_join( + self, lm: Type[Model], rm: Type[Model], on: Union[Column, Expression, Field, None] = ..., **join_kwargs: Any + ) -> ModelSelect: ... + # TODO (dargueta): 85% sure about the return value + def convert_dict_to_node(self, qdict: Mapping[str, object]) -> Tuple[List[Expression], List[Field]]: ... + def filter(self, *args: Node, **kwargs: object) -> ModelSelect: ... + def create_table(self, name: str, safe: bool = ..., **meta: Any) -> None: ... + def __sql_selection__(self, ctx: Context, is_subquery: bool = ...) -> Context: ... + +class NoopModelSelect(ModelSelect): + def __sql__(self, ctx: Context) -> Context: ... + +class _ModelWriteQueryHelper(_ModelQueryHelper): + model: Type[Model] + def __init__(self, model: Type[Model], *args: object, **kwargs: object): ... + def returning(self, *returning: Union[Type[Model], Field]) -> _ModelWriteQueryHelper: ... + +class ModelUpdate(_ModelWriteQueryHelper, Update): ... + +class ModelInsert(_ModelWriteQueryHelper, Insert): + default_row_type: ClassVar[int] + def returning(self, *returning: Union[Type[Model], Field]) -> ModelInsert: ... + def get_default_data(self): + return self.model._meta.defaults + def get_default_columns(self) -> Sequence[Field]: ... + +class ModelDelete(_ModelWriteQueryHelper, Delete): ... + +class ManyToManyQuery(ModelSelect): + def __init__( + self, instance: Model, accessor: ManyToManyFieldAccessor, rel: __TFieldOrModel, *args: object, **kwargs: object + ): ... + def add(self, value: Union[SelectQuery, Type[Model], Iterable[str]], clear_existing: bool = ...) -> None: ... + def remove(self, value: Union[SelectQuery, Type[Model], Iterable[str]]) -> Optional[int]: ... + def clear(self) -> int: ... + +class BaseModelCursorWrapper(DictCursorWrapper, Generic[__TModel]): + ncols: int + columns: List[str] + converters: List[__TConvFunc] + fields: List[Field] + model: Type[__TModel] + select: Sequence[str] + def __init__(self, cursor: __ICursor, model: Type[__TModel], columns: Optional[Sequence[str]]): ... + def process_row(self, row: tuple) -> Mapping[str, object]: ... + +class ModelDictCursorWrapper(BaseModelCursorWrapper[__TModel]): + def process_row(self, row: tuple) -> Dict[str, Any]: ... + +class ModelTupleCursorWrapper(ModelDictCursorWrapper[__TModel]): + constructor: ClassVar[Callable[[Sequence[Any]], tuple]] + def process_row(self, row: tuple) -> tuple: ... + +class ModelNamedTupleCursorWrapper(ModelTupleCursorWrapper[__TModel]): ... + +class ModelObjectCursorWrapper(ModelDictCursorWrapper[__TModel]): + constructor: Union[Type[__TModel], Callable[[Any], __TModel]] + is_model: bool + # TODO (dargueta): `select` is some kind of Sequence + def __init__( + self, cursor: __ICursor, model: __TModel, select, constructor: Union[Type[__TModel], Callable[[Any], __TModel]] + ): ... + def process_row(self, row: tuple) -> __TModel: ... + +class ModelCursorWrapper(BaseModelCursorWrapper[__TModel]): + from_list: Any # TODO (dargueta) -- Iterable[Union[Join, ...]] + joins: Any # TODO (dargueta) -- Mapping[, Tuple[?, ?, Callable[..., __TModel], int?]] + key_to_constructor: Dict[Type[__TModel], Callable[..., __TModel]] + src_is_dest: Dict[Type[Model], bool] + src_to_dest: List[tuple] # TODO -- Tuple[, join_type[1], join_type[0], bool, join_type[3]] + column_keys: List # TODO + def __init__(self, cursor: __ICursor, model: Type[__TModel], select, from_list, joins): + super(ModelCursorWrapper, self).__init__(cursor, model, select) + self.from_list = from_list + self.joins = joins + def initialize(self) -> None: + self._initialize_columns() + selected_src = set([field.model for field in self.fields if field is not None]) + select, columns = self.select, self.columns + + self.key_to_constructor = {self.model: self.model} + self.src_is_dest = {} + self.src_to_dest = [] + accum = collections.deque(self.from_list) + dests = set() + + while accum: + curr = accum.popleft() + if isinstance(curr, Join): + accum.append(curr.lhs) + accum.append(curr.rhs) + continue + + if curr not in self.joins: + continue + + is_dict = isinstance(curr, dict) + for key, attr, constructor, join_type in self.joins[curr]: + if key not in self.key_to_constructor: + self.key_to_constructor[key] = constructor + + # (src, attr, dest, is_dict, join_type). + self.src_to_dest.append((curr, attr, key, is_dict, join_type)) + dests.add(key) + accum.append(key) + + # Ensure that we accommodate everything selected. + for src in selected_src: + if src not in self.key_to_constructor: + if is_model(src): + self.key_to_constructor[src] = src + elif isinstance(src, ModelAlias): + self.key_to_constructor[src] = src.model + + # Indicate which sources are also dests. + for src, _, dest, _, _ in self.src_to_dest: + self.src_is_dest[src] = src in dests and (dest in selected_src or src in selected_src) + + self.column_keys = [] + for idx, node in enumerate(select): + key = self.model + field = self.fields[idx] + if field is not None: + if isinstance(field, FieldAlias): + key = field.source + else: + key = field.model + else: + if isinstance(node, Node): + node = node.unwrap() + if isinstance(node, Column): + key = node.source + + self.column_keys.append(key) + def process_row(self, row: tuple) -> __TModel: ... + +class __PrefetchQuery(NamedTuple): + query: Query # TODO (dargueta): Verify + fields: Optional[Sequence[Field]] + is_backref: Optional[bool] + rel_models: Optional[List[Type[Model]]] + field_to_name: Optional[List[Tuple[Field, str]]] + model: Type[Model] + +class PrefetchQuery(__PrefetchQuery): + # TODO (dargueta): The key is a two-tuple but not completely sure what + def populate_instance(self, instance: Model, id_map: Mapping[tuple, Any]): + if self.is_backref: + for field in self.fields: + identifier = instance.__data__[field.name] + key = (field, identifier) + if key in id_map: + setattr(instance, field.name, id_map[key]) + else: + for field, attname in self.field_to_name: + identifier = instance.__data__[field.rel_field.name] + key = (field, identifier) + rel_instances = id_map.get(key, []) + for inst in rel_instances: + setattr(inst, attname, instance) + inst._dirty.clear() + setattr(instance, field.backref, rel_instances) + # TODO (dargueta): Same question here about the key tuple + def store_instance(self, instance: Model, id_map: MutableMapping[tuple, List[Model]]) -> None: ... + +def prefetch_add_subquery(sq: Query, subqueries: Iterable[__TSubquery]) -> List[PrefetchQuery]: ... +def prefetch(sq: Query, *subqueries: __TSubquery) -> List[Any]: ... From b7eebcda73bc04aa46e74f9f4d0bed320e94cddd Mon Sep 17 00:00:00 2001 From: Diego Argueta Date: Mon, 30 Nov 2020 19:12:52 -0800 Subject: [PATCH 02/22] More work on annotations --- third_party/2and3/peewee.pyi | 1176 +++++++++++++--------------------- 1 file changed, 438 insertions(+), 738 deletions(-) diff --git a/third_party/2and3/peewee.pyi b/third_party/2and3/peewee.pyi index 7f3bc7969d7d..5a169604c6a0 100644 --- a/third_party/2and3/peewee.pyi +++ b/third_party/2and3/peewee.pyi @@ -1,56 +1,49 @@ -from bisect import bisect_left -from bisect import bisect_right -from contextlib import contextmanager -from functools import wraps -from typing import Any -from typing import AnyStr -from typing import Callable -from typing import ClassVar -from typing import Container -from typing import ContextManager -from typing import Dict -from typing import Generic -from typing import Iterable -from typing import Iterator -from typing import List -from typing import Mapping -from typing import MutableMapping -from typing import MutableSet -from typing import NamedTuple -from typing import NoReturn -from typing import Optional -from typing import overload -from typing import Set -from typing import Sequence -from typing import Text -from typing import Tuple -from typing import Type -from typing import TypeVar -from typing import Union -import calendar -import collections import datetime import decimal +import enum import operator import re import threading -import time import uuid - -from typing_extensions import Literal -from typing_extensions import Protocol +from bisect import bisect_left, bisect_right +from contextlib import contextmanager +from typing import ( + Any, + AnyStr, + Callable, + ClassVar, + Container, + ContextManager, + Dict, + Generic, + Iterable, + Iterator, + List, + Mapping, + MutableMapping, + MutableSet, + NamedTuple, + NoReturn, + Optional, + Sequence, + Set, + Text, + Tuple, + Type, + TypeVar, + Union, + overload, +) +from typing_extensions import Literal, Protocol T = TypeVar("T") -__TModel = TypeVar("__TModel", bound="Model") -__TConvFunc = Callable[[Any], Any] -__TFunc = TypeVar("__TFunc", bound=Callable) -__TClass = TypeVar("__TClass", bound=type) -__TModelOrTable = Union[Type["Model"], "ModelAlias", "Table"] -__TSubquery = Union[Tuple["Query", Type["Model"]], Type["Model"], "ModelAlias"] -__TContextClass = TypeVar("__TContextClass", bound="Context") -__TField = TypeVar("__TField", bound="Field") -__TFieldOrModel = Union[__TModelOrTable, "Field"] -__TNode = TypeVar("__TNode", bound="Node") +_TModel = TypeVar("_TModel", bound="Model") +_TConvFunc = Callable[[Any], Any] +_TFunc = TypeVar("_TFunc", bound=Callable) +_TClass = TypeVar("_TClass", bound=type) +_TContextClass = TypeVar("_TContextClass", bound="Context") +_TField = TypeVar("_TField", bound="Field") +_TNode = TypeVar("_TNode", bound="Node") __version__: str __all__: List[str] @@ -68,16 +61,34 @@ class __IConnection(Protocol): def commit(self) -> Any: ... def rollback(self) -> Any: ... +class __IAggregate(Protocol): + def step(self, *value: object) -> None: ... + def finalize(self) -> Any: ... + +class __ITableFunction(Protocol): + columns: Sequence[str] + params: Sequence[str] + name: str + print_tracebacks: bool + def initialize(self, **parameters: object) -> None: ... + def iterate(self, idx: int) -> tuple: ... + @classmethod + def register(cls, conn: __IConnection) -> None: ... + def _sqlite_date_part(lookup_type: str, datetime_string: str) -> Optional[str]: ... def _sqlite_date_trunc(lookup_type: str, datetime_string: str) -> Optional[str]: ... class attrdict(dict): def __getattr__(self, attr: str) -> Any: ... - def __setattr__(self, attr: str, value: Any) -> None: ... + def __setattr__(self, attr: str, value: object) -> None: ... def __iadd__(self, rhs: Mapping[str, object]) -> attrdict: ... def __add__(self, rhs: Mapping[str, object]) -> Mapping[str, object]: ... -SENTINEL: object +class _TSentinel(enum.Enum): ... + +# HACK (dargueta): This is a regular object but we need it to annotate the sentinel in +# type arguments. +SENTINEL: _TSentinel OP: attrdict @@ -105,23 +116,18 @@ SNAKE_CASE_STEP2: re.Pattern MODEL_BASE: str # TODO (dargueta) -class _callable_context_manager(object): - def __call__(self, fn): - @wraps(fn) - def inner(*args, **kwargs): - with self: - return fn(*args, **kwargs) - return inner - -class Proxy(object): +class _callable_context_manager: + def __call__(self, fn: _TFunc) -> _TFunc: ... + +class Proxy: obj: Any - def initialize(self, obj: Any) -> None: ... - def attach_callback(self, callback: __TConvFunc) -> __TConvFunc: ... - def passthrough(method: __TFunc) -> __TFunc: ... + def initialize(self, obj: object) -> None: ... + def attach_callback(self, callback: _TConvFunc) -> _TConvFunc: ... + def passthrough(method: _TFunc) -> _TFunc: ... def __enter__(self) -> Any: ... def __exit__(self, exc_type, exc_val, exc_tb) -> Any: ... def __getattr__(self, attr: str) -> Any: ... - def __setattr__(self, attr: str, value: Any) -> None: ... + def __setattr__(self, attr: str, value: object) -> None: ... class DatabaseProxy(Proxy): def connection_context(self) -> ConnectionContext: ... @@ -130,13 +136,13 @@ class DatabaseProxy(Proxy): def transaction(self, *args: object, **kwargs: object) -> _transaction: ... def savepoint(self) -> _savepoint: ... -class ModelDescriptor(object): ... +class ModelDescriptor: ... # SQL Generation. -class AliasManager(object): +class AliasManager: @property - def mapping(self) -> MutableMapping["Source", str]: ... + def mapping(self) -> MutableMapping[Source, str]: ... def add(self, source: Source) -> str: ... def get(self, source: Source, any_depth: bool = ...) -> str: ... def __getitem__(self, source: Source) -> str: ... @@ -155,11 +161,11 @@ class State(__State): def __call__(self, scope: Optional[int] = ..., parentheses: Optional[int] = ..., **kwargs: object) -> State: ... def __getattr__(self, attr_name: str) -> Any: ... -class Context(object): +class Context: stack: List[State] alias_manager: AliasManager state: State - def __init__(self, **settings: Any) -> None: ... + def __init__(self, **settings: object) -> None: ... def as_new(self) -> Context: ... def column_sort_key(self, item: Sequence[Union[ColumnBase, Source]]) -> Tuple[str, ...]: ... @property @@ -167,55 +173,49 @@ class Context(object): @property def parentheses(self) -> bool: ... @property - def subquery(self): - return self.state.subquery - def __call__(self, **overrides: Any) -> Context: ... + def subquery(self) -> Any: ... # TODO (dargueta): Figure out type of "self.state.subquery" + def __call__(self, **overrides: object) -> Context: ... def scope_normal(self) -> ContextManager[Context]: ... def scope_source(self) -> ContextManager[Context]: ... def scope_values(self) -> ContextManager[Context]: ... def scope_cte(self) -> ContextManager[Context]: ... def scope_column(self) -> ContextManager[Context]: ... def __enter__(self) -> Context: ... - def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: Any) -> None: ... - @contextmanager + def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: object) -> None: ... + # @contextmanager def push_alias(self) -> Iterator[None]: ... # TODO (dargueta): Is this right? - def sql(self, obj: Any) -> Context: ... + def sql(self, obj: object) -> Context: ... def literal(self, keyword: str) -> Context: ... - def value(self, value: Any, converter: Optional[__TConvFunc] = ..., add_param: bool = ...) -> Context: ... + def value(self, value: object, converter: Optional[_TConvFunc] = ..., add_param: bool = ...) -> Context: ... def __sql__(self, ctx: Context) -> Context: ... def parse(self, node: Node) -> Tuple[str, Optional[tuple]]: ... def query(self) -> Tuple[str, Optional[tuple]]: ... def query_to_string(query: Node) -> str: ... -class Node(object): +class Node: def clone(self) -> Node: ... def __sql__(self, ctx: Context) -> Context: ... # FIXME (dargueta): Is there a way to make this a proper decorator? @staticmethod - def copy(method: __TFunc) -> __TFunc: - def inner(self: T, *args: object, **kwargs: object) -> T: - clone = self.clone() - method(clone, *args, **kwargs) - return clone - return inner + def copy(method: _TFunc) -> _TFunc: ... def coerce(self, _coerce: bool = ...) -> Node: ... def is_alias(self) -> bool: ... def unwrap(self) -> Node: ... -class ColumnFactory(object): +class ColumnFactory: node: Node def __init__(self, node: Node): ... def __getattr__(self, attr: str) -> Column: ... -class _DynamicColumn(object): +class _DynamicColumn: @overload def __get__(self, instance: None, instance_type: type) -> _DynamicColumn: ... @overload def __get__(self, instance: T, instance_type: Type[T]) -> ColumnFactory: ... -class _ExplicitColumn(object): +class _ExplicitColumn: @overload def __get__(self, instance: None, instance_type: type) -> _ExplicitColumn: ... @overload @@ -233,43 +233,43 @@ class Source(Node): def apply_alias(self, ctx: Context) -> Context: ... def apply_column(self, ctx: Context) -> Context: ... -class _HashableSource(object): +class _HashableSource: def __init__(self, *args: object, **kwargs: object): ... def alias(self, name: str) -> _HashableSource: ... def __hash__(self) -> int: ... @overload def __eq__(self, other: _HashableSource) -> bool: ... @overload - def __eq__(self, other: Any) -> Expression: ... + def __eq__(self, other: object) -> Expression: ... @overload def __ne__(self, other: _HashableSource) -> bool: ... @overload - def __ne__(self, other: Any) -> Expression: ... - def __lt__(self, other: Any) -> Expression: ... - def __le__(self, other: Any) -> Expression: ... - def __gt__(self, other: Any) -> Expression: ... - def __ge__(self, other: Any) -> Expression: ... + def __ne__(self, other: object) -> Expression: ... + def __lt__(self, other: object) -> Expression: ... + def __le__(self, other: object) -> Expression: ... + def __gt__(self, other: object) -> Expression: ... + def __ge__(self, other: object) -> Expression: ... def __join__(join_type: int = ..., inverted: bool = ...) -> Callable[[Any, Any], Join]: ... class BaseTable(Source): - def __and__(self, other: Any) -> Join: ... - def __add__(self, other: Any) -> Join: ... - def __sub__(self, other: Any) -> Join: ... - def __or__(self, other: Any) -> Join: ... - def __mul__(self, other: Any) -> Join: ... - def __rand__(self, other: Any) -> Join: ... - def __radd__(self, other: Any) -> Join: ... - def __rsub__(self, other: Any) -> Join: ... - def __ror__(self, other: Any) -> Join: ... - def __rmul__(self, other: Any) -> Join: ... + def __and__(self, other: object) -> Join: ... + def __add__(self, other: object) -> Join: ... + def __sub__(self, other: object) -> Join: ... + def __or__(self, other: object) -> Join: ... + def __mul__(self, other: object) -> Join: ... + def __rand__(self, other: object) -> Join: ... + def __radd__(self, other: object) -> Join: ... + def __rsub__(self, other: object) -> Join: ... + def __ror__(self, other: object) -> Join: ... + def __rmul__(self, other: object) -> Join: ... class _BoundTableContext(_callable_context_manager): table: Table database: Database def __init__(self, table: Table, database: Database): ... def __enter__(self) -> Table: ... - def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: Any) -> None: ... + def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: object) -> None: ... class Table(_HashableSource, BaseTable): __name__: str @@ -334,8 +334,8 @@ class CTE(_HashableSource, Source): def __sql__(self, ctx: Context) -> Context: ... class ColumnBase(Node): - _converter: Optional[__TConvFunc] - def converter(self, converter: Optional[__TConvFunc] = ...) -> ColumnBase: ... + _converter: Optional[_TConvFunc] + def converter(self, converter: Optional[_TConvFunc] = ...) -> ColumnBase: ... @overload def alias(self, alias: None) -> ColumnBase: ... @overload @@ -347,45 +347,45 @@ class ColumnBase(Node): def desc(self, collation: Optional[str] = ..., nulls: Optional[str] = ...) -> Desc: ... __neg__ = desc def __invert__(self) -> Negated: ... - def __and__(self, other: Any) -> Expression: ... - def __or__(self, other: Any) -> Expression: ... - def __add__(self, other: Any) -> Expression: ... - def __sub__(self, other: Any) -> Expression: ... - def __mul__(self, other: Any) -> Expression: ... - def __div__(self, other: Any) -> Expression: ... - def __truediv__(self, other: Any) -> Expression: ... - def __xor__(self, other: Any) -> Expression: ... - def __radd__(self, other: Any) -> Expression: ... - def __rsub__(self, other: Any) -> Expression: ... - def __rmul__(self, other: Any) -> Expression: ... - def __rdiv__(self, other: Any) -> Expression: ... - def __rtruediv__(self, other: Any) -> Expression: ... - def __rand__(self, other: Any) -> Expression: ... - def __ror__(self, other: Any) -> Expression: ... - def __rxor__(self, other: Any) -> Expression: ... - def __eq__(self, rhs: Optional["Node"]) -> Expression: ... + def __and__(self, other: object) -> Expression: ... + def __or__(self, other: object) -> Expression: ... + def __add__(self, other: object) -> Expression: ... + def __sub__(self, other: object) -> Expression: ... + def __mul__(self, other: object) -> Expression: ... + def __div__(self, other: object) -> Expression: ... + def __truediv__(self, other: object) -> Expression: ... + def __xor__(self, other: object) -> Expression: ... + def __radd__(self, other: object) -> Expression: ... + def __rsub__(self, other: object) -> Expression: ... + def __rmul__(self, other: object) -> Expression: ... + def __rdiv__(self, other: object) -> Expression: ... + def __rtruediv__(self, other: object) -> Expression: ... + def __rand__(self, other: object) -> Expression: ... + def __ror__(self, other: object) -> Expression: ... + def __rxor__(self, other: object) -> Expression: ... + def __eq__(self, rhs: Optional[Node]) -> Expression: ... def __ne__(self, rhs: Optional[Node]) -> Expression: ... - def __lt__(self, other: Any) -> Expression: ... - def __le__(self, other: Any) -> Expression: ... - def __gt__(self, other: Any) -> Expression: ... - def __ge__(self, other: Any) -> Expression: ... - def __lshift__(self, other: Any) -> Expression: ... - def __rshift__(self, other: Any) -> Expression: ... - def __mod__(self, other: Any) -> Expression: ... - def __pow__(self, other: Any) -> Expression: ... - def bin_and(self, other: Any) -> Expression: ... - def bin_or(self, other: Any) -> Expression: ... - def in_(self, other: Any) -> Expression: ... - def not_in(self, other: Any) -> Expression: ... - def regexp(self, other: Any) -> Expression: ... + def __lt__(self, other: object) -> Expression: ... + def __le__(self, other: object) -> Expression: ... + def __gt__(self, other: object) -> Expression: ... + def __ge__(self, other: object) -> Expression: ... + def __lshift__(self, other: object) -> Expression: ... + def __rshift__(self, other: object) -> Expression: ... + def __mod__(self, other: object) -> Expression: ... + def __pow__(self, other: object) -> Expression: ... + def bin_and(self, other: object) -> Expression: ... + def bin_or(self, other: object) -> Expression: ... + def in_(self, other: object) -> Expression: ... + def not_in(self, other: object) -> Expression: ... + def regexp(self, other: object) -> Expression: ... def is_null(self, is_null: bool = ...) -> Expression: ... def contains(self, rhs: Union[Node, str]) -> Expression: ... def startswith(self, rhs: Union[Node, str]) -> Expression: ... def endswith(self, rhs: Union[Node, str]) -> Expression: ... - def between(self, lo: Any, hi: Any) -> Expression: ... - def concat(self, rhs: Any) -> StringExpression: ... - def iregexp(self, rhs: Any) -> Expression: ... - def __getitem__(self, item: Any) -> Expression: ... + def between(self, lo: object, hi: object) -> Expression: ... + def concat(self, rhs: object) -> StringExpression: ... + def iregexp(self, rhs: object) -> Expression: ... + def __getitem__(self, item: object) -> Expression: ... def distinct(self) -> NodeList: ... def collate(self, collation: str) -> NodeList: ... def get_sort_key(self, ctx: Context) -> Tuple[str, ...]: ... @@ -398,21 +398,20 @@ class Column(ColumnBase): def __hash__(self) -> int: ... def __sql__(self, ctx: Context) -> Context: ... -class WrappedNode(ColumnBase, Generic[__TNode]): - node: __TNode +class WrappedNode(ColumnBase, Generic[_TNode]): + node: _TNode _coerce: bool - _converter: Optional[__TConvFunc] - def __init__(self, node: __TNode): ... + _converter: Optional[_TConvFunc] + def __init__(self, node: _TNode): ... def is_alias(self) -> bool: ... - def unwrap(self) -> __TNode: ... + def unwrap(self) -> _TNode: ... -class EntityFactory(object): +class EntityFactory: node: Node def __init__(self, node: Node): ... def __getattr__(self, attr: str) -> Entity: ... -class _DynamicEntity(object): - __slots__ = () +class _DynamicEntity: @overload def __get__(self, instance: None, instance_type: type) -> _DynamicEntity: ... @overload @@ -434,13 +433,10 @@ class Negated(WrappedNode): def __invert__(self) -> Node: ... def __sql__(self, ctx: Context) -> Context: ... -class BitwiseMixin(object): - def __and__(self, other): - return self.bin_and(other) - def __or__(self, other): - return self.bin_or(other) - def __sub__(self, other): - return self.bin_and(other.bin_negated()) +class BitwiseMixin: + def __and__(self, other: object) -> Expression: ... + def __or__(self, other: object) -> Expression: ... + def __sub__(self, other: object) -> Expression: ... def __invert__(self) -> BitwiseNegated: ... class BitwiseNegated(BitwiseMixin, WrappedNode): @@ -449,25 +445,10 @@ class BitwiseNegated(BitwiseMixin, WrappedNode): class Value(ColumnBase): value: object - converter: Optional[__TConvFunc] + converter: Optional[_TConvFunc] multi: bool - def __init__(self, value: object, converter: Optional[__TConvFunc] = ..., unpack: bool = ...): - self.value = value - self.converter = converter - self.multi = unpack and isinstance(self.value, multi_types) - if self.multi: - self.values = [] - for item in self.value: - if isinstance(item, Node): - self.values.append(item) - else: - self.values.append(Value(item, self.converter)) - def __sql__(self, ctx): - if self.multi: - # For multi-part values (e.g. lists of IDs). - return ctx.sql(EnclosedNodeList(self.values)) - - return ctx.value(self.value, self.converter) + def __init__(self, value: object, converter: Optional[_TConvFunc] = ..., unpack: bool = ...): ... + def __sql__(self, ctx: Context) -> Context: ... def AsIs(value: object) -> Value: ... @@ -495,8 +476,8 @@ class Expression(ColumnBase): def __sql__(self, ctx: Context) -> Context: ... class StringExpression(Expression): - def __add__(self, rhs: Any) -> StringExpression: ... - def __radd__(self, lhs: Any) -> StringExpression: ... + def __add__(self, rhs: object) -> StringExpression: ... + def __radd__(self, lhs: object) -> StringExpression: ... class Entity(ColumnBase): def __init__(self, *path: str): ... @@ -511,19 +492,17 @@ class SQL(ColumnBase): def __init__(self, sql: str, params: Mapping[str, object] = ...): ... def __sql__(self, ctx: Context) -> Context: ... -def Check(constraint: str) -> SQL: - return SQL("CHECK (%s)" % constraint) +def Check(constraint: str) -> SQL: ... class Function(ColumnBase): name: str arguments: tuple - def __init__(self, name: str, arguments: tuple, coerce: bool = ..., python_value: Optional[__TConvFunc] = ...): ... + def __init__(self, name: str, arguments: tuple, coerce: bool = ..., python_value: Optional[_TConvFunc] = ...): ... def __getattr__(self, attr: str) -> Callable[..., Function]: ... # TODO (dargueta): `where` is an educated guess def filter(self, where: Optional[Expression] = ...) -> Function: ... - def order_by(self, *ordering) -> Function: - self._order_by = ordering - def python_value(self, func: Optional[__TConvFunc] = ...) -> Function: ... + def order_by(self, *ordering: Union[Field, Expression]) -> Function: ... + def python_value(self, func: Optional[_TConvFunc] = ...) -> Function: ... def over( self, partition_by: Optional[Union[Sequence[Field], Window]] = ..., @@ -600,12 +579,12 @@ class ForUpdate(Node): def __init__( self, expr: Union[Literal[True], str], - of: Optional[Union[__TModelOrTable, List[__TModelOrTable], Set[__TModelOrTable], Tuple[__TModelOrTable, ...]]] = ..., + of: Optional[Union[_TModelOrTable, List[_TModelOrTable], Set[_TModelOrTable], Tuple[_TModelOrTable, ...]]] = ..., nowait: Optional[bool] = ..., ): ... def __sql__(self, ctx: Context) -> Context: ... -def Case(predicate: Optional[Node], expression_tuples: Iterable[Tuple[Expression, Any]], default: Any = ...) -> NodeList: ... +def Case(predicate: Optional[Node], expression_tuples: Iterable[Tuple[Expression, Any]], default: object = ...) -> NodeList: ... class NodeList(ColumnBase): nodes: Sequence[Any] # TODO (dargueta): Narrow this type @@ -632,7 +611,7 @@ class DQ(ColumnBase): query: Dict[str, Any] # TODO (dargueta): Narrow this down? - def __init__(self, **query: Any): ... + def __init__(self, **query: object): ... def __invert__(self) -> DQ: ... def clone(self) -> DQ: ... @@ -670,19 +649,16 @@ class OnConflict(Node): conflict_constraint: Optional[str] = ..., ): ... # undocumented - def get_conflict_statement(self, ctx: Context, query: Query): - return ctx.state.conflict_statement(self, query) - def get_conflict_update(self, ctx, query): - return ctx.state.conflict_update(self, query) - def preserve(self, *columns) -> OnConflict: ... + def get_conflict_statement(self, ctx: Context, query: Query) -> Optional[SQL]: ... + def get_conflict_update(self, ctx: Context, query: Query) -> NodeList: ... + def preserve(self, *columns: Column) -> OnConflict: ... + # Despite the argument name `_data` is documented def update(self, _data: Optional[Mapping[str, object]] = ..., **kwargs: object) -> OnConflict: ... def where(self, *expressions: Expression) -> OnConflict: ... - def conflict_target(self, *constraints) -> OnConflict: ... + def conflict_target(self, *constraints: Column) -> OnConflict: ... def conflict_where(self, *expressions: Expression) -> OnConflict: ... def conflict_constraint(self, constraint: str) -> OnConflict: ... -# BASE QUERY INTERFACE. - class BaseQuery(Node): default_row_type: ClassVar[int] def __init__(self, _database: Optional[Database] = ..., **kwargs: object): ... @@ -691,7 +667,7 @@ class BaseQuery(Node): def dicts(self, as_dict: bool = ...) -> BaseQuery: ... def tuples(self, as_tuple: bool = ...) -> BaseQuery: ... def namedtuples(self, as_namedtuple: bool = ...) -> BaseQuery: ... - def objects(self, constructor: Optional[__TConvFunc] = ...) -> BaseQuery: ... + def objects(self, constructor: Optional[_TConvFunc] = ...) -> BaseQuery: ... def __sql__(self, ctx: Context) -> Context: ... def sql(self) -> Tuple[str, Optional[tuple]]: ... def execute(self, database: Optional[Database] = ...) -> CursorWrapper: ... @@ -771,7 +747,7 @@ class CompoundSelectQuery(SelectBase): lhs: Any # TODO (dargueta) op: str rhs: Any # TODO (dargueta) - def __init__(self, lhs: Any, op: str, rhs: Any): ... + def __init__(self, lhs: object, op: str, rhs: object): ... def exists(self, database: Optional[Database] = ...) -> bool: ... def __sql__(self, ctx: Context) -> Context: ... @@ -793,11 +769,10 @@ class Select(SelectBase): **kwargs: object, ): ... def clone(self) -> Select: ... - def columns(self, *columns, **kwargs: object) -> Select: - self._returning = columns - select = columns - def select_extend(self, *columns) -> Select: - self._returning = tuple(self._returning) + columns + # TODO (dargueta) `Field` might be wrong in this union + def columns(self, *columns: Union[Column, Field], **kwargs: object) -> Select: ... + def select(self, *columns: Union[Column, Field], **kwargs: object) -> Select: ... + def select_extend(self, *columns) -> Select: ... # TODO (dargueta): Is `sources` right? def from_(self, *sources: Union[Source, Type[Model]]) -> Select: ... def join(self, dest: Type[Model], join_type: int = ..., on: Optional[Expression] = ...) -> Select: ... @@ -809,12 +784,9 @@ class Select(SelectBase): @overload def distinct(self, *columns: Field) -> Select: ... def window(self, *windows: Window) -> Select: ... - def for_update(self, for_update: bool = ..., of=None, nowait=None) -> Select: - if not for_update and (of is not None or nowait): - for_update = True - self._for_update = for_update - self._for_update_of = of - self._for_update_nowait = nowait + def for_update( + self, for_update: bool = ..., of: Optional[Union[Table, Iterable[Table]]] = ..., nowait: Optional[bool] = ... + ) -> Select: ... def lateral(self, lateral: bool = ...) -> Select: ... class _WriteQuery(Query): @@ -827,13 +799,9 @@ class _WriteQuery(Query): def __sql__(self, ctx: Context) -> Context: ... class Update(_WriteQuery): - def __init__(self, table: Table, update=None, **kwargs): - super(Update, self).__init__(table, **kwargs) - self._update = update - self._from = None - @Node.copy - def from_(self, *sources) -> None: - self._from = sources + # TODO (dargueta): `update` + def __init__(self, table: Table, update: Optional[Any] = ..., **kwargs: object): ... + def from_(self, *sources) -> Update: ... def __sql__(self, ctx: Context) -> Context: ... class Insert(_WriteQuery): @@ -880,7 +848,7 @@ class Index(Node): class ModelIndex(Index): def __init__( self, - model: Type[__TModel], + model: Type[_TModel], fields: Iterable[Union[Field, Node, str]], unique: bool = ..., safe: bool = ..., @@ -889,8 +857,6 @@ class ModelIndex(Index): name: Optional[str] = ..., ): ... -# DB-API 2.0 EXCEPTIONS. - class PeeweeException(Exception): # This attribute only exists if an exception was passed into the constructor. # Attempting to access it otherwise will result in an AttributeError. @@ -907,16 +873,14 @@ class NotSupportedError(DatabaseError): ... class OperationalError(DatabaseError): ... class ProgrammingError(DatabaseError): ... -class ExceptionWrapper(object): +class ExceptionWrapper: exceptions: Mapping[str, Type[Exception]] def __init__(self, exceptions: Mapping[str, Type[Exception]]): ... def __enter__(self) -> None: ... - def __exit__(self, exc_type: Type[Exception], exc_value: Exception, traceback: Any) -> None: ... + def __exit__(self, exc_type: Type[Exception], exc_value: Exception, traceback: object) -> None: ... EXCEPTIONS: Mapping[str, Type[Exception]] -__exception_wrapper__: ExceptionWrapper - class IndexMetadata(NamedTuple): name: str sql: str @@ -942,7 +906,7 @@ class ViewMetadata(NamedTuple): name: str sql: str -class _ConnectionState(object): +class _ConnectionState: closed: bool conn: Optional[__IConnection] ctx: List[ConnectionContext] @@ -953,22 +917,17 @@ class _ConnectionState(object): class _ConnectionLocal(_ConnectionState, threading.local): ... class ConnectionContext(_callable_context_manager): - __slots__ = ("db",) - def __init__(self, db): - self.db = db - def __enter__(self): - if self.db.is_closed(): - self.db.connect() - def __exit__(self, exc_type, exc_val, exc_tb): - self.db.close() + db: Database + def __enter__(self) -> None: ... + def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: object) -> None: ... class Database(_callable_context_manager): - context_class: ClassVar[Type[__TContextClass]] + context_class: ClassVar[Type[_TContextClass]] field_types: ClassVar[Mapping[str, str]] operations: ClassVar[Mapping[str, Any]] # TODO (dargueta) Verify k/v types param: ClassVar[str] quote: ClassVar[str] - server_version: ClassVar[Optional[Tuple[int, ...]]] + server_version: ClassVar[Optional[Union[int, Tuple[int, ...]]]] commit_select: ClassVar[bool] compound_select_parentheses: ClassVar[int] for_update: ClassVar[bool] @@ -988,7 +947,6 @@ class Database(_callable_context_manager): autorollback: bool thread_safe: bool connect_params: Mapping[str, Any] - server_version: Optional[Union[int, Tuple[int, ...]]] def __init__( self, database: __IConnection, @@ -1002,7 +960,7 @@ class Database(_callable_context_manager): ): ... def init(self, database: __IConnection, **kwargs: object) -> None: ... def __enter__(self) -> Database: ... - def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: Any) -> None: ... + def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: object) -> None: ... def connection_context(self) -> ConnectionContext: ... def connect(self, reuse_if_open: bool = ...) -> bool: ... def close(self) -> bool: ... @@ -1010,14 +968,10 @@ class Database(_callable_context_manager): def is_connection_usable(self) -> bool: ... def connection(self) -> __IConnection: ... def cursor(self, commit: Optional[bool] = ...) -> __ICursor: ... - def execute_sql(self, sql: str, params: Optional[tuple] = ..., commit: Union[bool, Literal[SENTINEL]] = ...) -> __ICursor: ... - def execute(self, query: Query, commit: Union[bool, Literal[SENTINEL]] = ..., **context_options: Any) -> __ICursor: ... + def execute_sql(self, sql: str, params: Optional[tuple] = ..., commit: Union[bool, _TSentinel] = ...) -> __ICursor: ... + def execute(self, query: Query, commit: Union[bool, _TSentinel] = ..., **context_options: object) -> __ICursor: ... def get_context_options(self) -> Mapping[str, object]: ... - def get_sql_context(self, **context_options: Any): - context = self.get_context_options() - if context_options: - context.update(context_options) - return self.context_class(**context) + def get_sql_context(self, **context_options: object) -> _TContextClass: ... def conflict_statement(self, on_conflict: OnConflict, query: Query) -> Optional[SQL]: ... def conflict_update(self, oc: OnConflict, query: Query) -> NodeList: ... def last_insert_id(self, cursor: __ICursor, query_type: Optional[int] = ...) -> int: ... @@ -1046,16 +1000,12 @@ class Database(_callable_context_manager): def get_primary_keys(self, table: str, schema: Optional[str] = ...) -> List[str]: ... def get_foreign_keys(self, table: str, schema: Optional[str] = ...) -> List[ForeignKeyMetadata]: ... def sequence_exists(self, seq: str) -> bool: ... - def create_tables(self, models, **options): - for model in sort_models(models): - model.create_table(**options) - def drop_tables(self, models, **kwargs): - for model in reversed(sort_models(models)): - model.drop_table(**kwargs) - def extract_date(self, date_part: str, date_field: Node) -> Node: ... - def truncate_date(self, date_part: str, date_field: Node) -> Node: ... - def to_timestamp(self, date_field: str) -> Node: ... - def from_timestamp(self, date_field: str) -> Node: ... + def create_tables(self, models: Iterable[Type[Model]], **options: object) -> None: ... + def drop_tables(self, models: Iterable[Type[Model]], **kwargs: object) -> None: ... + def extract_date(self, date_part: str, date_field: Node) -> Function: ... + def truncate_date(self, date_part: str, date_field: Node) -> Function: ... + def to_timestamp(self, date_field: str) -> Function: ... + def from_timestamp(self, date_field: str) -> Function: ... def random(self) -> Node: ... def bind(self, models: Iterable[Type[Model]], bind_refs: bool = ..., bind_backrefs: bool = ...) -> None: ... def bind_ctx( @@ -1063,13 +1013,6 @@ class Database(_callable_context_manager): ) -> _BoundModelsContext: ... def get_noop_select(self, ctx: Context) -> Context: ... -def __pragma__(name): - def __get__(self): - return self.pragma(name) - def __set__(self, value): - return self.pragma(name, value) - return property(__get__, __set__) - class SqliteDatabase(Database): field_types: ClassVar[Mapping[str, int]] operations: ClassVar[Mapping[str, str]] @@ -1097,36 +1040,48 @@ class SqliteDatabase(Database): **kwargs: object, ) -> None: ... def pragma(self, key: str, value: Union[str, bool, int] = ..., permanent: bool = ..., schema: Optional[str] = ...) -> Any: ... - foreign_keys = __pragma__("foreign_keys") - journal_mode = __pragma__("journal_mode") - journal_size_limit = __pragma__("journal_size_limit") - mmap_size = __pragma__("mmap_size") - page_size = __pragma__("page_size") - read_uncommitted = __pragma__("read_uncommitted") - synchronous = __pragma__("synchronous") - wal_autocheckpoint = __pragma__("wal_autocheckpoint") - def register_aggregate(self, klass, name=None, num_params=-1): - self._aggregates[name or klass.__name__.lower()] = (klass, num_params) - if not self.is_closed(): - self._load_aggregates(self.connection()) - def aggregate(self, name: Optional[str] = ..., num_params: int = ...) -> Callable[[__TClass], __TClass]: ... + @property + def foreign_keys(self) -> Any: ... + @foreign_keys.setter + def foreign_keys(self, value: object) -> Any: ... + @property + def journal_mode(self) -> Any: ... + @journal_mode.setter + def journal_mode(self, value: object) -> Any: ... + @property + def journal_size_limit(self) -> Any: ... + @journal_size_limit.setter + def journal_size_limit(self, value: object) -> Any: ... + @property + def mmap_size(self) -> Any: ... + @mmap_size.setter + def mmap_size(self, value: object) -> Any: ... + @property + def page_size(self) -> Any: ... + @page_size.setter + def page_size(self, value: object) -> Any: ... + @property + def read_uncommitted(self) -> Any: ... + @read_uncommitted.setter + def read_uncommitted(self, value: object) -> Any: ... + @property + def synchronous(self) -> Any: ... + @synchronous.setter + def synchronous(self, value: object) -> Any: ... + @property + def wal_autocheckpoint(self) -> Any: ... + @wal_autocheckpoint.setter + def wal_autocheckpoint(self, value: object) -> Any: ... + def register_aggregate(self, klass: Type[__IAggregate], name: Optional[str] = ..., num_params: int = ...): ... + def aggregate(self, name: Optional[str] = ..., num_params: int = ...) -> Callable[[_TClass], _TClass]: ... def register_collation(self, fn: Callable, name: Optional[str] = ...) -> None: ... - def collation(self, name: Optional[str] = ...) -> Callable[[__TFunc], __TFunc]: ... + def collation(self, name: Optional[str] = ...) -> Callable[[_TFunc], _TFunc]: ... def register_function(self, fn: Callable, name: Optional[str] = ..., num_params: int = ...) -> int: ... - def func(self, name: Optional[str] = ..., num_params: int = ...) -> Callable[[__TFunc], __TFunc]: ... + def func(self, name: Optional[str] = ..., num_params: int = ...) -> Callable[[_TFunc], _TFunc]: ... def register_window_function(self, klass: type, name: Optional[str] = ..., num_params: int = ...) -> None: ... - def window_function(self, name: Optional[str] = ..., num_params: int = ...) -> Callable[[__TClass], __TClass]: ... - def register_table_function(self, klass, name=None): - if name is not None: - klass.name = name - self._table_functions.append(klass) - if not self.is_closed(): - klass.register(self.connection()) - def table_function(self, name=None): - def decorator(klass): - self.register_table_function(klass, name) - return klass - return decorator + def window_function(self, name: Optional[str] = ..., num_params: int = ...) -> Callable[[_TClass], _TClass]: ... + def register_table_function(self, klass: Type[__ITableFunction], name: Optional[str] = ...) -> None: ... + def table_function(self, name: Optional[str] = ...) -> Callable[[Type[__ITableFunction]], Type[__ITableFunction]]: ... def unregister_aggregate(self, name: str) -> None: ... def unregister_collation(self, name: str) -> None: ... def unregister_function(self, name: str) -> None: ... @@ -1153,9 +1108,6 @@ class PostgresqlDatabase(Database): sequences: ClassVar[bool] # Instance variables server_version: int - # Technically this *only* exists if we have Postgres >=9.6 and it will always be - # True in that case. - safe_create_index: bool def init( self, database: __IConnection, @@ -1165,10 +1117,7 @@ class PostgresqlDatabase(Database): **kwargs: object, ): ... def is_connection_usable(self) -> bool: ... - @overload - def last_insert_id(self, cursor: __ICursor, query_type: Literal[Insert.SIMPLE] = ...) -> Optional[int]: ... # I think - @overload - def last_insert_id(self, cursor: __ICursor, query_type: Optional[int] = ...) -> __ICursor: ... + def last_insert_id(self, cursor: __ICursor, query_type: Optional[int] = ...) -> Union[Optional[int], __ICursor]: ... def get_views(self, schema: Optional[str] = ...) -> List[ViewMetadata]: ... def get_binary_type(self) -> type: ... def get_noop_select(self, ctx: Context) -> SelectQuery: ... @@ -1183,7 +1132,7 @@ class MySQLDatabase(Database): compound_select_parentheses: ClassVar[int] for_update: ClassVar[bool] index_using_precedes_table: ClassVar[bool] - limit_max = 2 ** 64 - 1 + limit_max: ClassVar[int] safe_create_index: ClassVar[bool] safe_drop_index: ClassVar[bool] sql_mode: ClassVar[str] @@ -1193,18 +1142,13 @@ class MySQLDatabase(Database): def default_values_insert(self, ctx: Context) -> SQL: ... def get_views(self, schema: Optional[str] = ...) -> List[ViewMetadata]: ... def get_binary_type(self) -> type: ... - def extract_date(self, date_part, date_field): - return fn.EXTRACT(NodeList((SQL(date_part), SQL("FROM"), date_field))) - def truncate_date(self, date_part, date_field): - return fn.DATE_FORMAT(date_field, __mysql_date_trunc__[date_part], python_value=simple_date_time) - def to_timestamp(self, date_field): - return fn.UNIX_TIMESTAMP(date_field) - def from_timestamp(self, date_field): - return fn.FROM_UNIXTIME(date_field) - def random(self): - return fn.rand() - def get_noop_select(self, ctx): - return ctx.literal("DO 0") + # TODO (dargueta) Verify return type on these function calls + def extract_date(self, date_part: str, date_field: str) -> Function: ... + def truncate_date(self, date_part: str, date_field: str) -> Function: ... + def to_timestamp(self, date_field: str) -> Function: ... + def from_timestamp(self, date_field: str) -> Function: ... + def random(self) -> Function: ... + def get_noop_select(self, ctx: Context) -> Context: ... # TRANSACTION CONTROL. @@ -1212,13 +1156,13 @@ class _manual(_callable_context_manager): db: Database def __init__(self, db: Database): ... def __enter__(self) -> None: ... - def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: Any) -> None: ... + def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: object) -> None: ... class _atomic(_callable_context_manager): db: Database def __init__(self, db: Database, *args: object, **kwargs: object): ... def __enter__(self) -> Union[_transaction, _savepoint]: ... - def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: Any) -> None: ... + def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: object) -> None: ... class _transaction(_callable_context_manager): db: Database @@ -1226,7 +1170,7 @@ class _transaction(_callable_context_manager): def commit(self, begin: bool = ...) -> None: ... def rollback(self, begin: bool = ...) -> None: ... def __enter__(self) -> _transaction: ... - def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: Any) -> None: ... + def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: object) -> None: ... class _savepoint(_callable_context_manager): db: Database @@ -1236,7 +1180,7 @@ class _savepoint(_callable_context_manager): def commit(self, begin: bool = ...) -> None: ... def rollback(self) -> None: ... def __enter__(self) -> _savepoint: ... - def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: Any) -> None: ... + def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: object) -> None: ... class CursorWrapper(Generic[T]): cursor: __ICursor @@ -1253,43 +1197,21 @@ class CursorWrapper(Generic[T]): def __getitem__(self, item: slice) -> List[T]: ... def __len__(self) -> int: ... def initialize(self) -> None: ... - def iterate(self, cache=True): - row = self.cursor.fetchone() - if row is None: - self.populated = True - self.cursor.close() - raise StopIteration - elif not self.initialized: - self.initialize() # Lazy initialization. - self.initialized = True - self.count += 1 - result = self.process_row(row) - if cache: - self.row_cache.append(result) - return result + def iterate(self, cache: bool = ...) -> T: ... def process_row(self, row: tuple) -> T: ... - def iterator(self): - """Efficient one-pass iteration over the result set.""" - while True: - try: - yield self.iterate(False) - except StopIteration: - return + def iterator(self) -> Iterator[T]: ... def fill_cache(self, n: int = 0) -> None: ... class DictCursorWrapper(CursorWrapper[Mapping[str, object]]): ... -class NamedTupleCursorWrapper(CursorWrapper[NamedTuple]): - tuple_class: Type[NamedTuple] +# FIXME (dargueta): Somehow figure out how to make this a NamedTuple sorta deal +class NamedTupleCursorWrapper(CursorWrapper[tuple]): + tuple_class: Type[tuple] -# TODO: Indicate this inherits from DictCursorWrapper but also change the return type -class ObjectCursorWrapper(DictCursorWrapper): - def __init__(self, cursor, constructor): - super(ObjectCursorWrapper, self).__init__(cursor) - self.constructor = constructor - def process_row(self, row): - row_dict = self._row_to_dict(row) - return self.constructor(**row_dict) +class ObjectCursorWrapper(DictCursorWrapper, Generic[T]): + constructor: Callable[..., T] + def __init__(self, cursor: __ICursor, constructor: Callable[..., T]): ... + def process_row(self, row: tuple) -> T: ... class ResultIterator(Generic[T]): cursor_wrapper: CursorWrapper[T] @@ -1299,7 +1221,7 @@ class ResultIterator(Generic[T]): # FIELDS -class FieldAccessor(object): +class FieldAccessor: model: Type[Model] field: Field name: str @@ -1315,40 +1237,17 @@ class ForeignKeyAccessor(FieldAccessor): name: str rel_model: Type[Model] def __init__(self, model: Type[Model], field: ForeignKeyField, name: str): ... - def get_rel_instance(self, instance: Model) -> Any: - value = instance.__data__.get(self.name) - if value is not None or self.name in instance.__rel__: - if self.name not in instance.__rel__: - obj = self.rel_model.get(self.field.rel_field == value) - instance.__rel__[self.name] = obj - return instance.__rel__[self.name] - elif not self.field.null: - raise self.rel_model.DoesNotExist - return value - def __get__(self, instance, instance_type=None): - if instance is not None: - return self.get_rel_instance(instance) - return self.field - def __set__(self, instance, obj): - if isinstance(obj, self.rel_model): - instance.__data__[self.name] = getattr(obj, self.field.rel_field.name) - instance.__rel__[self.name] = obj - else: - fk_value = instance.__data__.get(self.name) - instance.__data__[self.name] = obj - if obj != fk_value and self.name in instance.__rel__: - del instance.__rel__[self.name] - instance._dirty.add(self.name) + def get_rel_instance(self, instance: Model) -> Any: ... + @overload + def __get__(self, instance: None, instance_type: type) -> Any: ... + @overload + def __get__(self, instance: _TModel, instance_type: Type[_TModel]) -> ForeignKeyField: ... + def __set__(self, instance: _TModel, obj: object) -> None: ... class NoQueryForeignKeyAccessor(ForeignKeyAccessor): - def get_rel_instance(self, instance: Model) -> Any: - value = instance.__data__.get(self.name) - if value is not None: - return instance.__rel__.get(self.name, value) - elif not self.field.null: - raise self.rel_model.DoesNotExist - -class BackrefAccessor(object): + def get_rel_instance(self, instance: Model) -> Any: ... + +class BackrefAccessor: field: ForeignKeyField model: Type[Model] rel_model: Type[Model] @@ -1356,9 +1255,9 @@ class BackrefAccessor(object): @overload def __get__(self, instance: None, instance_type: type) -> BackrefAccessor: ... @overload - def __get__(self, instance: Field, instance_type: Type["Field"]) -> SelectQuery: ... + def __get__(self, instance: Field, instance_type: Type[Field]) -> SelectQuery: ... -class ObjectIdAccessor(object): +class ObjectIdAccessor: """Gives direct access to the underlying id""" field: ForeignKeyField @@ -1366,8 +1265,8 @@ class ObjectIdAccessor(object): @overload def __get__(self, instance: None, instance_type: Type[Model]) -> ForeignKeyField: ... @overload - def __get__(self, instance: __TModel, instance_type: Type[__TModel] = ...) -> Any: ... - def __set__(self, instance: Model, value: Any) -> None: ... + def __get__(self, instance: _TModel, instance_type: Type[_TModel] = ...) -> Any: ... + def __set__(self, instance: Model, value: object) -> None: ... class Field(ColumnBase): accessor_class: ClassVar[Type[FieldAccessor]] @@ -1501,16 +1400,10 @@ class BitField(BitwiseMixin, BigIntegerField): # FIXME (dargueta) Return type isn't 100% accurate; function creates a new class def flag(self, value: Optional[int] = ...) -> ColumnBase: ... -class BigBitFieldData(object): - def __init__(self, instance, name): - self.instance = instance - self.name = name - value = self.instance.__data__.get(self.name) - if not value: - value = bytearray() - elif not isinstance(value, bytearray): - value = bytearray(value) - self._buffer = self.instance.__data__[self.name] = value +class BigBitFieldData: + name: str + instance: Model + def __init__(self, instance: Model, name: str): ... def set_bit(self, idx: int) -> None: ... def clear_bit(self, idx: bool) -> None: ... def toggle_bit(self, idx: int) -> bool: ... @@ -1518,10 +1411,10 @@ class BigBitFieldData(object): def __repr__(self) -> str: ... class BigBitFieldAccessor(FieldAccessor): - def __get__(self, instance, instance_type=None): - if instance is None: - return self.field - return BigBitFieldData(instance, self.name) + @overload + def __get__(self, instance: None, instance_type: Type[_TModel]) -> Field: ... + @overload + def __get__(self, instance: _TModel, instance_type: Type[_TModel]) -> BigBitFieldData: ... def __set__(self, instance: Any, value: Union[memoryview, bytearray, BigBitFieldData, str, bytes]) -> None: ... class BigBitField(BlobField): @@ -1537,10 +1430,10 @@ class UUIDField(Field): def db_value(self, value: AnyStr) -> str: ... @overload def db_value(self, value: T) -> T: ... - def python_value(self, value): - if isinstance(value, uuid.UUID): - return value - return uuid.UUID(value) if value is not None else None + @overload + def python_value(self, value: Union[uuid.UUID, AnyStr]) -> uuid.UUID: ... + @overload + def python_value(self, value: None) -> None: ... class BinaryUUIDField(BlobField): @overload @@ -1552,7 +1445,7 @@ class BinaryUUIDField(BlobField): @overload def python_value(self, value: Union[bytearray, bytes, memoryview, uuid.UUID]) -> uuid.UUID: ... -def format_date_time(value: str, formats: Iterable[str], post_process: Optional[__TConvFunc] = ...) -> str: ... +def format_date_time(value: str, formats: Iterable[str], post_process: Optional[_TConvFunc] = ...) -> str: ... @overload def simple_date_time(value: T) -> T: ... @@ -1578,10 +1471,8 @@ class DateTimeField(_BaseFormattedField): def adapt(self, value: str) -> str: ... @overload def adapt(self, value: T) -> T: ... - def to_timestamp(self): - return self.model._meta.database.to_timestamp(self) - def truncate(self, part): - return self.model._meta.database.truncate_date(part, self) + def to_timestamp(self) -> Function: ... + def truncate(self, part: str) -> Function: ... class DateField(_BaseFormattedField): @property @@ -1596,10 +1487,8 @@ class DateField(_BaseFormattedField): def adapt(self, value: datetime.datetime) -> datetime.date: ... @overload def adapt(self, value: T) -> T: ... - def to_timestamp(self): - return self.model._meta.database.to_timestamp(self) - def truncate(self, part): - return self.model._meta.database.truncate_date(part, self) + def to_timestamp(self) -> Function: ... + def truncate(self, part: str) -> Function: ... class TimeField(_BaseFormattedField): @overload @@ -1624,34 +1513,16 @@ class TimestampField(BigIntegerField): def __init__(self, *args: object, resolution: int = ..., utc: bool = ..., **kwargs: object): ... def local_to_utc(self, dt: datetime.datetime) -> datetime.datetime: ... def utc_to_local(self, dt: datetime.datetime) -> datetime.datetime: ... - def get_timestamp(self, value): - if self.utc: - # If utc-mode is on, then we assume all naive datetimes are in UTC. - return calendar.timegm(value.utctimetuple()) - else: - return time.mktime(value.timetuple()) - def db_value(self, value): - if value is None: - return - - if isinstance(value, datetime.datetime): ... - elif isinstance(value, datetime.date): - value = datetime.datetime(value.year, value.month, value.day) - else: - return int(round(value * self.resolution)) - - timestamp = self.get_timestamp(value) - if self.resolution > 1: - timestamp += value.microsecond * 0.000001 - timestamp *= self.resolution - return int(round(timestamp)) + def get_timestamp(self, value: datetime.datetime) -> float: ... + @overload + def db_value(self, value: None) -> None: ... + @overload + def db_value(self, value: Union[datetime.datetime, datetime.date, float]) -> int: ... @overload def python_value(self, value: Union[int, float]) -> datetime.datetime: ... @overload def python_value(self, value: T) -> T: ... - def from_timestamp(self): - expr = (self / Value(self.resolution, converter=False)) if self.resolution > 1 else self - return self.model._meta.database.from_timestamp(expr) + def from_timestamp(self) -> float: ... @property def year(self) -> int: ... @property @@ -1679,110 +1550,52 @@ class BooleanField(Field): def adapt(self, value: Any) -> bool: ... class BareField(Field): - def __init__(self, adapt=None, *args, **kwargs): - super(BareField, self).__init__(*args, **kwargs) - if adapt is not None: - self.adapt = adapt - def ddl_datatype(self, ctx): - return + # If `adapt` was omitted from the constructor or None, this attribute won't exist. + adapt: Optional[_TConvFunc] + def __init__(self, adapt: Optional[_TConvFunc] = ..., *args: object, **kwargs: object): ... + def ddl_datatype(self, ctx: Context) -> None: ... class ForeignKeyField(Field): accessor_class = ForeignKeyAccessor + rel_model: Union[Type[Model], Literal["self"]] + rel_field: Field + declared_backref: Optional[str] + backref: Optional[str] # TODO (dargueta): Verify + on_delete: Optional[str] + on_update: Optional[str] + deferrable: Optional[str] + deferred: Optional[bool] + object_id_name: Optional[str] + lazy_load: bool + safe_name: str def __init__( self, - model, - field=None, - backref=None, - on_delete=None, - on_update=None, - deferrable=None, - _deferred=None, - rel_model=None, - to_field=None, - object_id_name=None, - lazy_load=True, - related_name=None, - *args, - **kwargs, - ): - kwargs.setdefault("index", True) - - # If lazy_load is disable, we use a different descriptor/accessor that - # will ensure we don't accidentally perform a query. - if not lazy_load: - self.accessor_class = NoQueryForeignKeyAccessor - - super(ForeignKeyField, self).__init__(*args, **kwargs) - - self._is_self_reference = model == "self" - self.rel_model = model - self.rel_field = field - self.declared_backref = backref - self.backref = None - self.on_delete = on_delete - self.on_update = on_update - self.deferrable = deferrable - self.deferred = _deferred - self.object_id_name = object_id_name - self.lazy_load = lazy_load + model: Union[Type[Model], Literal["self"]], + field: Optional[Field] = ..., + # TODO (dargueta): Documentation says this is only a string but code accepts a callable too + backref: Optional[str] = ..., + on_delete: Optional[str] = ..., + on_update: Optional[str] = ..., + deferrable: Optional[str] = ..., + _deferred: Optional[bool] = ..., # undocumented + rel_model: object = ..., # undocumented + to_field: object = ..., # undocumented + object_id_name: Optional[str] = ..., + lazy_load: bool = ..., + # type for related_name is a guess + related_name: Optional[str] = ..., # undocumented + *args: object, + index: bool = ..., + **kwargs: object, + ): ... @property - def field_type(self): - if not isinstance(self.rel_field, AutoField): - return self.rel_field.field_type - elif isinstance(self.rel_field, BigAutoField): - return BigIntegerField.field_type - return IntegerField.field_type - def get_modifiers(self): - if not isinstance(self.rel_field, AutoField): - return self.rel_field.get_modifiers() - return super(ForeignKeyField, self).get_modifiers() - def adapt(self, value): - return self.rel_field.adapt(value) - def db_value(self, value): - if isinstance(value, self.rel_model): - value = getattr(value, self.rel_field.name) - return self.rel_field.db_value(value) - def python_value(self, value): - if isinstance(value, self.rel_model): - return value - return self.rel_field.python_value(value) - def bind(self, model, name, set_attribute=True): - if not self.column_name: - self.column_name = name if name.endswith("_id") else name + "_id" - if not self.object_id_name: - self.object_id_name = self.column_name - if self.object_id_name == name: - self.object_id_name += "_id" - elif self.object_id_name == name: - raise ValueError( - 'ForeignKeyField "%s"."%s" specifies an ' - "object_id_name that conflicts with its field " - "name." % (model._meta.name, name) - ) - if self._is_self_reference: - self.rel_model = model - if isinstance(self.rel_field, str): - self.rel_field = getattr(self.rel_model, self.rel_field) - elif self.rel_field is None: - self.rel_field = self.rel_model._meta.primary_key - - # Bind field before assigning backref, so field is bound when - # calling declared_backref() (if callable). - super(ForeignKeyField, self).bind(model, name, set_attribute) - self.safe_name = self.object_id_name - - if callable(self.declared_backref): - self.backref = self.declared_backref(self) - else: - self.backref, self.declared_backref = self.declared_backref, None - if not self.backref: - self.backref = "%s_set" % model._meta.name - - if set_attribute: - setattr(model, self.object_id_name, ObjectIdAccessor(self)) - if self.backref not in "!+": - setattr(self.rel_model, self.backref, BackrefAccessor(self)) - def foreign_key_constraint(self): + def field_type(self) -> str: ... + def get_modifiers(self) -> Optional[Iterable[object]]: ... + def adapt(self, value: object) -> Any: ... + def db_value(self, value: object) -> Any: ... + def python_value(self, value: object) -> Any: ... + def bind(self, model: Type[Model], name: str, set_attribute: bool = ...) -> None: ... + def foreign_key_constraint(self) -> NodeList: parts = [ SQL("FOREIGN KEY"), EnclosedNodeList((self,)), @@ -1797,13 +1610,7 @@ class ForeignKeyField(Field): if self.deferrable: parts.append(SQL("DEFERRABLE %s" % self.deferrable)) return NodeList(parts) - def __getattr__(self, attr): - if attr.startswith("__"): - # Prevent recursion error when deep-copying. - raise AttributeError('Cannot look-up non-existant "__" methods.') - if attr in self.rel_model._meta.fields: - return self.rel_model._meta.fields[attr] - raise AttributeError("Foreign-key has no attribute %s, nor is it a " "valid field on the related model." % attr) + def __getattr__(self, attr: str) -> Field: ... class DeferredForeignKey(Field): _unresolved = set() @@ -1826,19 +1633,18 @@ class DeferredForeignKey(Field): dr.set_model(model_cls) DeferredForeignKey._unresolved.discard(dr) -class DeferredThroughModel(object): - def __init__(self): - self._refs = [] - def set_field(self, model, field, name): - self._refs.append((model, field, name)) - def set_model(self, through_model): - for src_model, m2mfield, name in self._refs: - m2mfield.through_model = through_model - src_model._meta.add_field(name, m2mfield) +class DeferredThroughModel: + def set_field(self, model: Type[Model], field: Type[Field], name: str) -> None: ... + def set_model(self, through_model: Type[Model]) -> None: ... class MetaField(Field): - column_name = default = model = name = None - primary_key = False + # These are declared as class variables in the source code but are used like local + # variables + column_name: Optional[str] + default: Any + model: Type[Model] + name: Optional[str] + primary_key: bool class ManyToManyFieldAccessor(FieldAccessor): model: Type[Model] @@ -1876,10 +1682,10 @@ class ManyToManyField(MetaField): def get_models(self) -> List[Type[Model]]: ... def get_through_model(self) -> Union[Type[Model], DeferredThroughModel]: ... -class VirtualField(MetaField, Generic[__TField]): - field_class: Type[__TField] - field_instance: __TField - def __init__(self, field_class: Optional[Type[__TField]] = ..., *args: object, **kwargs: object): ... +class VirtualField(MetaField, Generic[_TField]): + field_class: Type[_TField] + field_instance: _TField + def __init__(self, field_class: Optional[Type[_TField]] = ..., *args: object, **kwargs: object): ... def db_value(self, value): if self.field_instance is not None: return self.field_instance.db_value(value) @@ -1892,36 +1698,24 @@ class VirtualField(MetaField, Generic[__TField]): class CompositeKey(MetaField): sequence = None - def __init__(self, *field_names): + field_names: Tuple[str, ...] + def __init__(self, *field_names: str): self.field_names = field_names self._safe_field_names = None @property - def safe_field_names(self): - if self._safe_field_names is None: - if self.model is None: - return self.field_names - - self._safe_field_names = [self.model._meta.fields[f].safe_name for f in self.field_names] - return self._safe_field_names - def __get__(self, instance, instance_type=None): - if instance is not None: - return tuple([getattr(instance, f) for f in self.safe_field_names]) - return self - def __set__(self, instance, value): - if not isinstance(value, (list, tuple)): - raise TypeError("A list or tuple must be used to set the value of " "a composite primary key.") - if len(value) != len(self.field_names): - raise ValueError("The length of the value must equal the number " "of columns of the composite primary key.") - for idx, field_value in enumerate(value): - setattr(instance, self.field_names[idx], field_value) + def safe_field_names(self) -> Union[List[str], Tuple[str, ...]]: ... + @overload + def __get__(self, instance: None, instance_type: type) -> CompositeKey: ... + @overload + def __get__(self, instance: T, instance_type: Type[T]) -> tuple: ... + def __set__(self, instance: Model, value: Union[list, tuple]) -> None: ... def __eq__(self, other): expressions = [(self.model._meta.fields[field] == value) for field, value in zip(self.field_names, other)] return reduce(operator.and_, expressions) def __ne__(self, other): return ~(self == other) - def __hash__(self): - return hash((self.model.__name__, self.field_names)) - def __sql__(self, ctx): + def __hash__(self) -> int: ... + def __sql__(self, ctx: Context) -> Context: # If the composite PK is being selected, do not use parens. Elsewhere, # such as in an expression, we want to use parentheses and treat it as # a row value. @@ -1932,7 +1726,7 @@ class CompositeKey(MetaField): self.column_name = self.name = self.safe_name = name setattr(model, self.name, self) -class _SortedFieldList(object): +class _SortedFieldList: __slots__ = ("_keys", "_items") def __init__(self): self._keys = [] @@ -1960,17 +1754,17 @@ class _SortedFieldList(object): # MODELS -class SchemaManager(object): +class SchemaManager: model: Type[Model] - context_options: Dict[str, Any] - def __init__(self, model: Type[Model], database: Optional[Database] = None, **context_options: Any): ... + context_options: Dict[str, object] + def __init__(self, model: Type[Model], database: Optional[Database] = None, **context_options: object): ... @property def database(self) -> Database: ... @database.setter def database(self, value: Optional[Database]) -> None: ... - def create_table(self, safe: bool = ..., **options: Any) -> None: ... - def create_table_as(self, table_name: str, query: SelectQuery, safe: bool = ..., **meta: Any) -> None: ... - def drop_table(self, safe: bool = ..., **options: Any) -> None: ... + def create_table(self, safe: bool = ..., **options: object) -> None: ... + def create_table_as(self, table_name: str, query: SelectQuery, safe: bool = ..., **meta: object) -> None: ... + def drop_table(self, safe: bool = ..., **options: object) -> None: ... def truncate_table(self, restart_identity: bool = ..., cascade: bool = ...) -> None: ... def create_indexes(self, safe: bool = ...) -> None: ... def drop_indexes(self, safe: bool = ...) -> None: ... @@ -1978,19 +1772,19 @@ class SchemaManager(object): def drop_sequence(self, field: Field) -> None: ... def create_foreign_key(self, field: Field) -> None: ... def create_sequences(self) -> None: ... - def create_all(self, safe: bool = ..., **table_options: Any) -> None: ... + def create_all(self, safe: bool = ..., **table_options: object) -> None: ... def drop_sequences(self) -> None: ... - def drop_all(self, safe: bool = ..., drop_sequences: bool = ..., **options: Any) -> None: ... + def drop_all(self, safe: bool = ..., drop_sequences: bool = ..., **options: object) -> None: ... -class Metadata(object): +class Metadata: model: Type[Model] database: Optional[Database] - fields: Dict[str, Any] # TODO (dargueta) This may be Dict[str, Field] - columns: Dict[str, Any] # TODO (dargueta) Verify this - combined: Dict[str, Any] # TODO (dargueta) Same as above + fields: Dict[str, object] # TODO (dargueta) This may be Dict[str, Field] + columns: Dict[str, object] # TODO (dargueta) Verify this + combined: Dict[str, object] # TODO (dargueta) Same as above sorted_fields: List[Field] sorted_field_names: List[str] - defaults: Dict[str, Any] + defaults: Dict[str, object] name: str table_function: Optional[Callable[[Type[Model]], str]] legacy_table_names: bool @@ -2193,36 +1987,23 @@ class Model(Node, metaclass=ModelBase): @classmethod def add_index(cls, *fields: Union[str, SQL, Index], **kwargs: object) -> None: ... -class ModelAlias(Node): +class ModelAlias(Node, Generic[_TModel]): """Provide a separate reference to a model in a query.""" - model: Type[Model] + model: Type[_TModel] alias: Optional[str] - def __init__(self, model: Type[Model], alias: Optional[str] = ...): ... - def __getattr__(self, attr: str): - # Hack to work-around the fact that properties or other objects - # implementing the descriptor protocol (on the model being aliased), - # will not work correctly when we use getattr(). So we explicitly pass - # the model alias to the descriptor's getter. - try: - obj = self.model.__dict__[attr] - except KeyError: ... - else: - if isinstance(obj, ModelDescriptor): - return obj.__get__(None, self) - - model_attr = getattr(self.model, attr) - if isinstance(model_attr, Field): - self.__dict__[attr] = FieldAlias.create(self, model_attr) - return self.__dict__[attr] - return model_attr + def __init__(self, model: Type[_TModel], alias: Optional[str] = ...): ... + def __getattr__(self, attr: str) -> Any: ... def __setattr__(self, attr: str, value: object) -> NoReturn: ... def get_field_aliases(self) -> List[Field]: ... def select(self, *selection: Field) -> ModelSelect: ... - def __call__(self, **kwargs): - return self.model(**kwargs) + def __call__(self, **kwargs) -> _TModel: ... def __sql__(self, ctx: Context) -> Context: ... +_TModelOrTable = Union[Type[Model], ModelAlias, Table] +_TSubquery = Union[Tuple[Query, Type[Model]], Type[Model], ModelAlias] +_TFieldOrModel = Union[_TModelOrTable, Field] + class FieldAlias(Field): source: Node model: Type[Model] @@ -2230,68 +2011,58 @@ class FieldAlias(Field): # TODO (dargueta): Making an educated guess about `source`; might be `Node` def __init__(self, source: MetaField, field: Field): ... @classmethod - def create(cls, source: ModelAlias, field: str): - class _FieldAlias(cls, type(field)): ... - return _FieldAlias(source, field) + def create(cls, source: ModelAlias, field: str) -> FieldAlias: ... def clone(self) -> FieldAlias: ... - def adapt(self, value): - return self.field.adapt(value) - def python_value(self, value): - return self.field.python_value(value) - def db_value(self, value): - return self.field.db_value(value) - def __getattr__(self, attr): - return self.source if attr == "model" else getattr(self.field, attr) + def adapt(self, value: object) -> Any: ... + def python_value(self, value: object) -> Any: ... + def db_value(self, value: object) -> Any: ... + @overload + def __getattr__(self, attr: Literal["model"]) -> Node: ... + @overload + def __getattr__(self, attr: str) -> Any: ... def __sql__(self, ctx: Context) -> Context: ... def sort_models(models: Iterable[Type[Model]]) -> List[Type[Model]]: ... -class _ModelQueryHelper(object): +class _ModelQueryHelper: default_row_type: ClassVar[int] def objects(self, constructor: Optional[Callable[..., Any]] = ...) -> _ModelQueryHelper: ... -class ModelRaw(_ModelQueryHelper, RawQuery, Generic[__TModel]): - model: Type[__TModel] - def __init__(self, model: Type[__TModel], sql: str, params: tuple, **kwargs: object): ... - def get(self) -> __TModel: ... +class ModelRaw(_ModelQueryHelper, RawQuery, Generic[_TModel]): + model: Type[_TModel] + def __init__(self, model: Type[_TModel], sql: str, params: tuple, **kwargs: object): ... + def get(self) -> _TModel: ... class BaseModelSelect(_ModelQueryHelper): - def union_all(self, rhs): - return ModelCompoundSelectQuery(self.model, self, "UNION ALL", rhs) + def union_all(self, rhs: object) -> ModelCompoundSelectQuery: ... __add__ = union_all - def union(self, rhs): - return ModelCompoundSelectQuery(self.model, self, "UNION", rhs) + def union(self, rhs: object) -> ModelCompoundSelectQuery: ... __or__ = union - def intersect(self, rhs): - return ModelCompoundSelectQuery(self.model, self, "INTERSECT", rhs) + def intersect(self, rhs: object) -> ModelCompoundSelectQuery: ... __and__ = intersect - def except_(self, rhs): - return ModelCompoundSelectQuery(self.model, self, "EXCEPT", rhs) + def except_(self, rhs: object) -> ModelCompoundSelectQuery: ... __sub__ = except_ - def __iter__(self) -> Iterator[Any]: - if not self._cursor_wrapper: - self.execute() - return iter(self._cursor_wrapper) - def prefetch(self, *subqueries: __TSubquery) -> List[Any]: ... + def __iter__(self) -> Iterator[Any]: ... + def prefetch(self, *subqueries: _TSubquery) -> List[Any]: ... def get(self, database: Optional[Database] = ...) -> Any: ... def group_by(self, *columns: Union[Type[Model], Table, Field]) -> BaseModelSelect: ... class ModelCompoundSelectQuery(BaseModelSelect, CompoundSelectQuery): - model: Type[__TModel] - def __init__(self, model: Type[__TModel], *args: object, **kwargs: object): ... + model: Type[_TModel] + def __init__(self, model: Type[_TModel], *args: object, **kwargs: object): ... -class ModelSelect(BaseModelSelect, Select): - model: Type[Model] - def __init__(self, model: Type[Model], fields_or_models: Iterable[__TFieldOrModel], is_default: bool = ...): ... +class ModelSelect(BaseModelSelect, Select, Generic[_TModel]): + model: Type[_TModel] + def __init__(self, model: Type[_TModel], fields_or_models: Iterable[_TFieldOrModel], is_default: bool = ...): ... def clone(self) -> ModelSelect: ... - def select(self, *fields_or_models: __TFieldOrModel) -> ModelSelect: ... + def select(self, *fields_or_models: _TFieldOrModel) -> ModelSelect: ... def switch(self, ctx: Optional[Type[Model]] = ...) -> ModelSelect: ... def join( self, dest: Union[Type[Model], Table, ModelAlias, ModelSelect], join_type: int = ..., - on: Union[Column, Expression, Field, None] = ..., - src: Union[Type[Model], Table, ModelAlias, ModelSelect, None] = ..., + on: Optional[Union[Column, Expression, Field]] = ..., + src: Optional[Union[Type[Model], Table, ModelAlias, ModelSelect]] = ..., attr: Optional[str] = ..., ) -> ModelSelect: ... def join_from( @@ -2299,16 +2070,16 @@ class ModelSelect(BaseModelSelect, Select): src: Union[Type[Model], Table, ModelAlias, ModelSelect], dest: Union[Type[Model], Table, ModelAlias, ModelSelect], join_type: int = ..., - on: Union[Column, Expression, Field, None] = ..., + on: Optional[Union[Column, Expression, Field]] = ..., attr: Optional[str] = ..., ) -> ModelSelect: ... def ensure_join( - self, lm: Type[Model], rm: Type[Model], on: Union[Column, Expression, Field, None] = ..., **join_kwargs: Any + self, lm: Type[Model], rm: Type[Model], on: Optional[Union[Column, Expression, Field]] = ..., **join_kwargs: Any ) -> ModelSelect: ... # TODO (dargueta): 85% sure about the return value def convert_dict_to_node(self, qdict: Mapping[str, object]) -> Tuple[List[Expression], List[Field]]: ... def filter(self, *args: Node, **kwargs: object) -> ModelSelect: ... - def create_table(self, name: str, safe: bool = ..., **meta: Any) -> None: ... + def create_table(self, name: str, safe: bool = ..., **meta: object) -> None: ... def __sql_selection__(self, ctx: Context, is_subquery: bool = ...) -> Context: ... class NoopModelSelect(ModelSelect): @@ -2332,111 +2103,57 @@ class ModelDelete(_ModelWriteQueryHelper, Delete): ... class ManyToManyQuery(ModelSelect): def __init__( - self, instance: Model, accessor: ManyToManyFieldAccessor, rel: __TFieldOrModel, *args: object, **kwargs: object + self, instance: Model, accessor: ManyToManyFieldAccessor, rel: _TFieldOrModel, *args: object, **kwargs: object ): ... def add(self, value: Union[SelectQuery, Type[Model], Iterable[str]], clear_existing: bool = ...) -> None: ... def remove(self, value: Union[SelectQuery, Type[Model], Iterable[str]]) -> Optional[int]: ... def clear(self) -> int: ... -class BaseModelCursorWrapper(DictCursorWrapper, Generic[__TModel]): +class BaseModelCursorWrapper(DictCursorWrapper, Generic[_TModel]): ncols: int columns: List[str] - converters: List[__TConvFunc] + converters: List[_TConvFunc] fields: List[Field] - model: Type[__TModel] + model: Type[_TModel] select: Sequence[str] - def __init__(self, cursor: __ICursor, model: Type[__TModel], columns: Optional[Sequence[str]]): ... + def __init__(self, cursor: __ICursor, model: Type[_TModel], columns: Optional[Sequence[str]]): ... def process_row(self, row: tuple) -> Mapping[str, object]: ... -class ModelDictCursorWrapper(BaseModelCursorWrapper[__TModel]): +class ModelDictCursorWrapper(BaseModelCursorWrapper[_TModel]): def process_row(self, row: tuple) -> Dict[str, Any]: ... -class ModelTupleCursorWrapper(ModelDictCursorWrapper[__TModel]): +class ModelTupleCursorWrapper(ModelDictCursorWrapper[_TModel]): constructor: ClassVar[Callable[[Sequence[Any]], tuple]] def process_row(self, row: tuple) -> tuple: ... -class ModelNamedTupleCursorWrapper(ModelTupleCursorWrapper[__TModel]): ... +class ModelNamedTupleCursorWrapper(ModelTupleCursorWrapper[_TModel]): ... -class ModelObjectCursorWrapper(ModelDictCursorWrapper[__TModel]): - constructor: Union[Type[__TModel], Callable[[Any], __TModel]] +class ModelObjectCursorWrapper(ModelDictCursorWrapper[_TModel]): + constructor: Union[Type[_TModel], Callable[[Any], _TModel]] is_model: bool # TODO (dargueta): `select` is some kind of Sequence def __init__( - self, cursor: __ICursor, model: __TModel, select, constructor: Union[Type[__TModel], Callable[[Any], __TModel]] + self, + cursor: __ICursor, + model: _TModel, + select: Sequence[object], + constructor: Union[Type[_TModel], Callable[[Any], _TModel]], ): ... - def process_row(self, row: tuple) -> __TModel: ... + def process_row(self, row: tuple) -> _TModel: ... -class ModelCursorWrapper(BaseModelCursorWrapper[__TModel]): +class ModelCursorWrapper(BaseModelCursorWrapper[_TModel]): from_list: Any # TODO (dargueta) -- Iterable[Union[Join, ...]] - joins: Any # TODO (dargueta) -- Mapping[, Tuple[?, ?, Callable[..., __TModel], int?]] - key_to_constructor: Dict[Type[__TModel], Callable[..., __TModel]] + joins: Any # TODO (dargueta) -- Mapping[, Tuple[?, ?, Callable[..., _TModel], int?]] + key_to_constructor: Dict[Type[_TModel], Callable[..., _TModel]] src_is_dest: Dict[Type[Model], bool] src_to_dest: List[tuple] # TODO -- Tuple[, join_type[1], join_type[0], bool, join_type[3]] column_keys: List # TODO - def __init__(self, cursor: __ICursor, model: Type[__TModel], select, from_list, joins): + def __init__(self, cursor: __ICursor, model: Type[_TModel], select, from_list, joins): super(ModelCursorWrapper, self).__init__(cursor, model, select) self.from_list = from_list self.joins = joins - def initialize(self) -> None: - self._initialize_columns() - selected_src = set([field.model for field in self.fields if field is not None]) - select, columns = self.select, self.columns - - self.key_to_constructor = {self.model: self.model} - self.src_is_dest = {} - self.src_to_dest = [] - accum = collections.deque(self.from_list) - dests = set() - - while accum: - curr = accum.popleft() - if isinstance(curr, Join): - accum.append(curr.lhs) - accum.append(curr.rhs) - continue - - if curr not in self.joins: - continue - - is_dict = isinstance(curr, dict) - for key, attr, constructor, join_type in self.joins[curr]: - if key not in self.key_to_constructor: - self.key_to_constructor[key] = constructor - - # (src, attr, dest, is_dict, join_type). - self.src_to_dest.append((curr, attr, key, is_dict, join_type)) - dests.add(key) - accum.append(key) - - # Ensure that we accommodate everything selected. - for src in selected_src: - if src not in self.key_to_constructor: - if is_model(src): - self.key_to_constructor[src] = src - elif isinstance(src, ModelAlias): - self.key_to_constructor[src] = src.model - - # Indicate which sources are also dests. - for src, _, dest, _, _ in self.src_to_dest: - self.src_is_dest[src] = src in dests and (dest in selected_src or src in selected_src) - - self.column_keys = [] - for idx, node in enumerate(select): - key = self.model - field = self.fields[idx] - if field is not None: - if isinstance(field, FieldAlias): - key = field.source - else: - key = field.model - else: - if isinstance(node, Node): - node = node.unwrap() - if isinstance(node, Column): - key = node.source - - self.column_keys.append(key) - def process_row(self, row: tuple) -> __TModel: ... + def initialize(self) -> None: ... + def process_row(self, row: tuple) -> _TModel: ... class __PrefetchQuery(NamedTuple): query: Query # TODO (dargueta): Verify @@ -2447,25 +2164,8 @@ class __PrefetchQuery(NamedTuple): model: Type[Model] class PrefetchQuery(__PrefetchQuery): - # TODO (dargueta): The key is a two-tuple but not completely sure what - def populate_instance(self, instance: Model, id_map: Mapping[tuple, Any]): - if self.is_backref: - for field in self.fields: - identifier = instance.__data__[field.name] - key = (field, identifier) - if key in id_map: - setattr(instance, field.name, id_map[key]) - else: - for field, attname in self.field_to_name: - identifier = instance.__data__[field.rel_field.name] - key = (field, identifier) - rel_instances = id_map.get(key, []) - for inst in rel_instances: - setattr(inst, attname, instance) - inst._dirty.clear() - setattr(instance, field.backref, rel_instances) - # TODO (dargueta): Same question here about the key tuple - def store_instance(self, instance: Model, id_map: MutableMapping[tuple, List[Model]]) -> None: ... - -def prefetch_add_subquery(sq: Query, subqueries: Iterable[__TSubquery]) -> List[PrefetchQuery]: ... -def prefetch(sq: Query, *subqueries: __TSubquery) -> List[Any]: ... + def populate_instance(self, instance: Model, id_map: Mapping[Tuple[Any, Any], object]): ... + def store_instance(self, instance: Model, id_map: MutableMapping[Tuple[Any, Any], List[Model]]) -> None: ... + +def prefetch_add_subquery(sq: Query, subqueries: Iterable[_TSubquery]) -> List[PrefetchQuery]: ... +def prefetch(sq: Query, *subqueries: _TSubquery) -> List[Any]: ... From 1f7af9fd47edd7ff5a81828d913e7d7e91209560 Mon Sep 17 00:00:00 2001 From: Diego Argueta <620513-dargueta@users.noreply.github.com> Date: Sat, 16 Jan 2021 14:34:13 -0800 Subject: [PATCH 03/22] Commit final parts of first draft of the PR. --- third_party/2and3/peewee.pyi | 186 ++++++++++------------------------- 1 file changed, 50 insertions(+), 136 deletions(-) diff --git a/third_party/2and3/peewee.pyi b/third_party/2and3/peewee.pyi index 5a169604c6a0..ef2ba0f9bf3f 100644 --- a/third_party/2and3/peewee.pyi +++ b/third_party/2and3/peewee.pyi @@ -1,12 +1,9 @@ import datetime import decimal import enum -import operator import re import threading import uuid -from bisect import bisect_left, bisect_right -from contextlib import contextmanager from typing import ( Any, AnyStr, @@ -16,6 +13,7 @@ from typing import ( ContextManager, Dict, Generic, + Hashable, Iterable, Iterator, List, @@ -302,8 +300,8 @@ class Table(_HashableSource, BaseTable): def __sql__(self, ctx: Context) -> Context: ... class Join(BaseTable): - lhs: Any # TODO - rhs: Any # TODO + lhs: Any # TODO (dargueta) + rhs: Any # TODO (dargueta) join_type: int def __init__(self, lhs, rhs, join_type: int = ..., on: Optional[Expression] = ..., alias: Optional[str] = ...): ... def on(self, predicate: Expression) -> Join: ... @@ -1595,43 +1593,17 @@ class ForeignKeyField(Field): def db_value(self, value: object) -> Any: ... def python_value(self, value: object) -> Any: ... def bind(self, model: Type[Model], name: str, set_attribute: bool = ...) -> None: ... - def foreign_key_constraint(self) -> NodeList: - parts = [ - SQL("FOREIGN KEY"), - EnclosedNodeList((self,)), - SQL("REFERENCES"), - self.rel_model, - EnclosedNodeList((self.rel_field,)), - ] - if self.on_delete: - parts.append(SQL("ON DELETE %s" % self.on_delete)) - if self.on_update: - parts.append(SQL("ON UPDATE %s" % self.on_update)) - if self.deferrable: - parts.append(SQL("DEFERRABLE %s" % self.deferrable)) - return NodeList(parts) + def foreign_key_constraint(self) -> NodeList: ... def __getattr__(self, attr: str) -> Field: ... class DeferredForeignKey(Field): - _unresolved = set() - def __init__(self, rel_model_name, **kwargs): - self.field_kwargs = kwargs - self.rel_model_name = rel_model_name.lower() - DeferredForeignKey._unresolved.add(self) - super(DeferredForeignKey, self).__init__(column_name=kwargs.get("column_name"), null=kwargs.get("null")) - __hash__ = object.__hash__ - def __deepcopy__(self, memo=None): - return DeferredForeignKey(self.rel_model_name, **self.field_kwargs) - def set_model(self, rel_model): - field = ForeignKeyField(rel_model, _deferred=True, **self.field_kwargs) - self.model._meta.add_field(self.name, field) + field_kwargs: Dict[str, object] + rel_model_name: str + def __init__(self, rel_model_name: str, *, column_name: Optional[str] = ..., null: Optional[str] = ..., **kwargs: object): ... + def set_model(self, rel_model: Type[Model]) -> None: ... @staticmethod - def resolve(model_cls): - unresolved = sorted(DeferredForeignKey._unresolved, key=operator.attrgetter("_order")) - for dr in unresolved: - if dr.rel_model_name == model_cls.__name__.lower(): - dr.set_model(model_cls) - DeferredForeignKey._unresolved.discard(dr) + def resolve(model_cls: Type[Model]) -> None: ... + def __hash__(self) -> int: ... class DeferredThroughModel: def set_field(self, model: Type[Model], field: Type[Field], name: str) -> None: ... @@ -1659,9 +1631,7 @@ class ManyToManyFieldAccessor(FieldAccessor): def __get__( self, instance: T, instance_type: Type[T] = ..., force_query: bool = ... ) -> Union[List[str], ManyToManyQuery]: ... - def __set__(self, instance: T, value) -> None: - query = self.__get__(instance, force_query=True) - query.add(value, clear_existing=True) + def __set__(self, instance: T, value) -> None: ... class ManyToManyField(MetaField): accessor_class: ClassVar[Type[ManyToManyFieldAccessor]] @@ -1684,24 +1654,20 @@ class ManyToManyField(MetaField): class VirtualField(MetaField, Generic[_TField]): field_class: Type[_TField] - field_instance: _TField + field_instance: Optional[_TField] def __init__(self, field_class: Optional[Type[_TField]] = ..., *args: object, **kwargs: object): ... - def db_value(self, value): - if self.field_instance is not None: - return self.field_instance.db_value(value) - return value - def python_value(self, value): - if self.field_instance is not None: - return self.field_instance.python_value(value) - return value + def db_value(self, value: object) -> Any: ... + def python_value(self, value: object) -> Any: ... def bind(self, model: Type[Model], name: str, set_attribute: bool = ...) -> None: ... class CompositeKey(MetaField): sequence = None field_names: Tuple[str, ...] - def __init__(self, *field_names: str): - self.field_names = field_names - self._safe_field_names = None + # The following attributes are not set in the constructor an so may not always be + # present. + model: Type["Model"] + column_name: str + def __init__(self, *field_names: str): ... @property def safe_field_names(self) -> Union[List[str], Tuple[str, ...]]: ... @overload @@ -1709,48 +1675,11 @@ class CompositeKey(MetaField): @overload def __get__(self, instance: T, instance_type: Type[T]) -> tuple: ... def __set__(self, instance: Model, value: Union[list, tuple]) -> None: ... - def __eq__(self, other): - expressions = [(self.model._meta.fields[field] == value) for field, value in zip(self.field_names, other)] - return reduce(operator.and_, expressions) - def __ne__(self, other): - return ~(self == other) + def __eq__(self, other: Expression) -> Expression: ... + def __ne__(self, other: Expression) -> Expression: ... def __hash__(self) -> int: ... - def __sql__(self, ctx: Context) -> Context: - # If the composite PK is being selected, do not use parens. Elsewhere, - # such as in an expression, we want to use parentheses and treat it as - # a row value. - parens = ctx.scope != SCOPE_SOURCE - return ctx.sql(NodeList([self.model._meta.fields[field] for field in self.field_names], ", ", parens)) - def bind(self, model, name, set_attribute=True): - self.model = model - self.column_name = self.name = self.safe_name = name - setattr(model, self.name, self) - -class _SortedFieldList: - __slots__ = ("_keys", "_items") - def __init__(self): - self._keys = [] - self._items = [] - def __getitem__(self, i): - return self._items[i] - def __iter__(self): - return iter(self._items) - def __contains__(self, item): - k = item._sort_key - i = bisect_left(self._keys, k) - j = bisect_right(self._keys, k) - return item in self._items[i:j] - def index(self, field): - return self._keys.index(field._sort_key) - def insert(self, item): - k = item._sort_key - i = bisect_left(self._keys, k) - self._keys.insert(i, k) - self._items.insert(i, item) - def remove(self, item): - idx = self.index(item) - del self._items[idx] - del self._keys[idx] + def __sql__(self, ctx: Context) -> Context: ... + def bind(self, model: Type["Model"], name: str, set_attribute: bool = ...) -> None: ... # MODELS @@ -1900,14 +1829,11 @@ class Model(Node, metaclass=ModelBase): @classmethod def insert_from(cls, query: SelectQuery, fields: Iterable[Union[Field, Text]]) -> ModelInsert: ... @classmethod - def replace(cls, __data=None, **insert): - return cls.insert(__data, **insert).on_conflict("REPLACE") + def replace(cls, __data: Optional[Iterable[Union[str, Field]]] = ..., **insert: object) -> OnConflict: ... @classmethod - def replace_many(cls, rows, fields=None): - return cls.insert_many(rows=rows, fields=fields).on_conflict("REPLACE") + def replace_many(cls, rows: Iterable[tuple], fields: Optional[Sequence[Field]] = ...) -> OnConflict: ... @classmethod - def raw(cls, sql, *params): - return ModelRaw(cls, sql, params) + def raw(cls, sql: str, *params: object) -> ModelRaw: ... @classmethod def delete(cls) -> ModelDelete: ... @classmethod @@ -1921,33 +1847,17 @@ class Model(Node, metaclass=ModelBase): @classmethod def noop(cls) -> NoopModelSelect: ... @classmethod - def get(cls, *query, **filters): - sq = cls.select() - if query: - # Handle simple lookup using just the primary key. - if len(query) == 1 and isinstance(query[0], int): - sq = sq.where(cls._meta.primary_key == query[0]) - else: - sq = sq.where(*query) - if filters: - sq = sq.filter(**filters) - return sq.get() + def get(cls, *query: object, **filters: object) -> ModelSelect: ... @classmethod - def get_or_none(cls, *query, **filters): - try: - return cls.get(*query, **filters) - except DoesNotExist: ... + def get_or_none(cls, *query: object, **filters: object) -> Optional[ModelSelect]: ... @classmethod - def get_by_id(cls, pk): - return cls.get(cls._meta.primary_key == pk) + def get_by_id(cls, pk: object) -> ModelSelect: ... + # TODO (dargueta) I'm 99% sure of return value for this one @classmethod - def set_by_id(cls, key, value) -> Any: # TODO (dargueta): Verify return type of .execute() - if key is None: - return cls.insert(value).execute() - else: - return cls.update(value).where(cls._meta.primary_key == key).execute() + def set_by_id(cls, key, value) -> CursorWrapper: ... + # TODO (dargueta) I'm also not 100% about this one's return value. @classmethod - def delete_by_id(cls, pk: object) -> Any: ... # TODO (dargueta): Verify return type of .execute() + def delete_by_id(cls, pk: object) -> CursorWrapper: ... @classmethod def get_or_create(cls, *, defaults: Mapping[str, object] = ..., **kwargs: object) -> Tuple[Any, bool]: ... @classmethod @@ -1960,8 +1870,8 @@ class Model(Node, metaclass=ModelBase): def dependencies(self, search_nullable: bool = ...) -> Iterator[Tuple[Union[bool, Node], ForeignKeyField]]: ... def delete_instance(self: T, recursive: bool = ..., delete_nullable: bool = ...) -> T: ... def __hash__(self) -> int: ... - def __eq__(self, other: Any) -> bool: ... - def __ne__(self, other: Any) -> bool: ... + def __eq__(self, other: object) -> bool: ... + def __ne__(self, other: object) -> bool: ... def __sql__(self, ctx: Context) -> Context: ... @classmethod def bind( @@ -2095,8 +2005,7 @@ class ModelUpdate(_ModelWriteQueryHelper, Update): ... class ModelInsert(_ModelWriteQueryHelper, Insert): default_row_type: ClassVar[int] def returning(self, *returning: Union[Type[Model], Field]) -> ModelInsert: ... - def get_default_data(self): - return self.model._meta.defaults + def get_default_data(self) -> Mapping[str, object]: ... def get_default_columns(self) -> Sequence[Field]: ... class ModelDelete(_ModelWriteQueryHelper, Delete): ... @@ -2142,16 +2051,21 @@ class ModelObjectCursorWrapper(ModelDictCursorWrapper[_TModel]): def process_row(self, row: tuple) -> _TModel: ... class ModelCursorWrapper(BaseModelCursorWrapper[_TModel]): - from_list: Any # TODO (dargueta) -- Iterable[Union[Join, ...]] - joins: Any # TODO (dargueta) -- Mapping[, Tuple[?, ?, Callable[..., _TModel], int?]] + from_list: Iterable[Any] # TODO (dargueta) -- Iterable[Union[Join, ...]] + # TODO (dargueta) -- Mapping[, Tuple[?, ?, Callable[..., _TModel], int?]] + joins: Mapping[Hashable, Tuple[object, object, Callable[..., _TModel], int]] key_to_constructor: Dict[Type[_TModel], Callable[..., _TModel]] src_is_dest: Dict[Type[Model], bool] src_to_dest: List[tuple] # TODO -- Tuple[, join_type[1], join_type[0], bool, join_type[3]] column_keys: List # TODO - def __init__(self, cursor: __ICursor, model: Type[_TModel], select, from_list, joins): - super(ModelCursorWrapper, self).__init__(cursor, model, select) - self.from_list = from_list - self.joins = joins + def __init__( + self, + cursor: __ICursor, + model: Type[_TModel], + select, + from_list: Iterable[object], + joins: Mapping[Hashable, Tuple[object, object, Callable[..., _TModel], int]], + ): ... def initialize(self) -> None: ... def process_row(self, row: tuple) -> _TModel: ... @@ -2164,8 +2078,8 @@ class __PrefetchQuery(NamedTuple): model: Type[Model] class PrefetchQuery(__PrefetchQuery): - def populate_instance(self, instance: Model, id_map: Mapping[Tuple[Any, Any], object]): ... - def store_instance(self, instance: Model, id_map: MutableMapping[Tuple[Any, Any], List[Model]]) -> None: ... + def populate_instance(self, instance: Model, id_map: Mapping[Tuple[object, object], object]): ... + def store_instance(self, instance: Model, id_map: MutableMapping[Tuple[object, object], List[Model]]) -> None: ... def prefetch_add_subquery(sq: Query, subqueries: Iterable[_TSubquery]) -> List[PrefetchQuery]: ... -def prefetch(sq: Query, *subqueries: _TSubquery) -> List[Any]: ... +def prefetch(sq: Query, *subqueries: _TSubquery) -> List[object]: ... From 92dff01a69dc1382793ad2938294de0f707d8c19 Mon Sep 17 00:00:00 2001 From: Diego Argueta Date: Mon, 25 Jan 2021 19:23:33 -0800 Subject: [PATCH 04/22] Incremental fixes --- tests/mypy_test.py | 1 + third_party/2and3/peewee.pyi | 122 ++++++++++++++++++----------------- 2 files changed, 65 insertions(+), 58 deletions(-) diff --git a/tests/mypy_test.py b/tests/mypy_test.py index a495d65e49e6..a12da57c7446 100755 --- a/tests/mypy_test.py +++ b/tests/mypy_test.py @@ -132,6 +132,7 @@ def main(): flags.append("--no-implicit-optional") flags.append("--disallow-any-generics") flags.append("--disallow-subclassing-any") + flags.append("--show-error-codes") if args.warn_unused_ignores: flags.append("--warn-unused-ignores") if args.platform: diff --git a/third_party/2and3/peewee.pyi b/third_party/2and3/peewee.pyi index ef2ba0f9bf3f..ad2b4499ddfb 100644 --- a/third_party/2and3/peewee.pyi +++ b/third_party/2and3/peewee.pyi @@ -121,6 +121,7 @@ class Proxy: obj: Any def initialize(self, obj: object) -> None: ... def attach_callback(self, callback: _TConvFunc) -> _TConvFunc: ... + @staticmethod # This is technically inaccurate but that's how it's used def passthrough(method: _TFunc) -> _TFunc: ... def __enter__(self) -> Any: ... def __exit__(self, exc_type, exc_val, exc_tb) -> Any: ... @@ -219,30 +220,34 @@ class _ExplicitColumn: @overload def __get__(self, instance: T, instance_type: Type[T]) -> NoReturn: ... -class Source(Node): +class _SupportsAlias(Protocol): + def alias(self: T, name: str) -> T: ... + +class Source(_SupportsAlias, Node): c: ClassVar[_DynamicColumn] def __init__(self, alias: Optional[str] = ...): ... - def alias(self, name: str) -> Source: ... def select(self, *columns: Field) -> Select: ... def join(self, dest, join_type: int = ..., on: Optional[Expression] = ...) -> Join: ... def left_outer_join(self, dest, on: Optional[Expression] = ...) -> Join: ... - def cte(self, name: str, recursive: bool = ..., columns=None, materialized=None) -> CTE: ... + def cte(self, name: str, recursive: bool = ..., columns=..., materialized=...) -> CTE: ... # incomplete def get_sort_key(self, ctx) -> Tuple[str, ...]: ... def apply_alias(self, ctx: Context) -> Context: ... def apply_column(self, ctx: Context) -> Context: ... -class _HashableSource: +class _HashableSource(_SupportsAlias): def __init__(self, *args: object, **kwargs: object): ... - def alias(self, name: str) -> _HashableSource: ... def __hash__(self) -> int: ... + # The overrides here are unfortunately a necessary evil. The __eq__/__ne__ methods + # return different types depending on the type of the argument, and both differ + # from `object`'s signature. + @overload # type: ignore + def __eq__(self, other: _HashableSource) -> bool: ... # type: ignore @overload - def __eq__(self, other: _HashableSource) -> bool: ... - @overload - def __eq__(self, other: object) -> Expression: ... + def __eq__(self, other: object) -> Expression: ... # type: ignore + @overload # type: ignore + def __ne__(self, other: _HashableSource) -> bool: ... # type: ignore @overload - def __ne__(self, other: _HashableSource) -> bool: ... - @overload - def __ne__(self, other: object) -> Expression: ... + def __ne__(self, other: object) -> Expression: ... # type: ignore def __lt__(self, other: object) -> Expression: ... def __le__(self, other: object) -> Expression: ... def __gt__(self, other: object) -> Expression: ... @@ -308,7 +313,7 @@ class Join(BaseTable): def __sql__(self, ctx: Context) -> Context: ... class ValuesList(_HashableSource, BaseTable): - def __init__(self, values, columns=None, alias: Optional[str] = ...): ... + def __init__(self, values, columns=..., alias: Optional[str] = ...): ... # incomplete # FIXME (dargueta) `names` might be wrong def columns(self, *names: str) -> ValuesList: ... def __sql__(self, ctx: Context) -> Context: ... @@ -340,11 +345,12 @@ class ColumnBase(Node): def alias(self, alias: str) -> Alias: ... def unalias(self) -> ColumnBase: ... def cast(self, as_type: str) -> Cast: ... - def asc(self, collation: Optional[str] = ..., nulls: Optional[str] = ...) -> Asc: ... + def asc(self, collation: Optional[str] = ..., nulls: Optional[str] = ...) -> _SupportsSQLOrdering: ... __pos__ = asc - def desc(self, collation: Optional[str] = ..., nulls: Optional[str] = ...) -> Desc: ... + def desc(self, collation: Optional[str] = ..., nulls: Optional[str] = ...) -> _SupportsSQLOrdering: ... __neg__ = desc - def __invert__(self) -> Negated: ... + # TODO (dargueta): This always returns Negated but subclasses can return something else + def __invert__(self) -> WrappedNode: ... def __and__(self, other: object) -> Expression: ... def __or__(self, other: object) -> Expression: ... def __add__(self, other: object) -> Expression: ... @@ -361,8 +367,8 @@ class ColumnBase(Node): def __rand__(self, other: object) -> Expression: ... def __ror__(self, other: object) -> Expression: ... def __rxor__(self, other: object) -> Expression: ... - def __eq__(self, rhs: Optional[Node]) -> Expression: ... - def __ne__(self, rhs: Optional[Node]) -> Expression: ... + def __eq__(self, rhs: object) -> Expression: ... + def __ne__(self, rhs: object) -> Expression: ... def __lt__(self, other: object) -> Expression: ... def __le__(self, other: object) -> Expression: ... def __gt__(self, other: object) -> Expression: ... @@ -462,6 +468,9 @@ class Ordering(WrappedNode): def collate(self, collation: Optional[str] = ...) -> Ordering: ... def __sql__(self, ctx: Context) -> Context: ... +class _SupportsSQLOrdering(Protocol): + def __call__(node: Node, collation: Optional[str] = ..., nulls: Optional[str] = ...) -> Ordering: ... + def Asc(node: Node, collation: Optional[str] = ..., nulls: Optional[str] = ...) -> Ordering: ... def Desc(node: Node, collation: Optional[str] = ..., nulls: Optional[str] = ...) -> Ordering: ... @@ -528,7 +537,7 @@ class Window(Node): order_by: Tuple[Union[Field, Expression], ...] start: Optional[Union[str, SQL]] end: Optional[Union[str, SQL]] - frame_type: Optional[Any] # TODO + frame_type: Optional[Any] # incomplete @overload def __init__( self, @@ -585,14 +594,15 @@ class ForUpdate(Node): def Case(predicate: Optional[Node], expression_tuples: Iterable[Tuple[Expression, Any]], default: object = ...) -> NodeList: ... class NodeList(ColumnBase): - nodes: Sequence[Any] # TODO (dargueta): Narrow this type + # TODO (dargueta): Narrow this type + nodes: Sequence[Any] # incomplete glue: str parens: bool - def __init__(self, nodes: Sequence[Any], glue: str = ..., parens: bool = ...): ... + def __init__(self, nodes: Sequence[Any], glue: str = ..., parens: bool = ...): ... # incomplete def __sql__(self, ctx: Context) -> Context: ... -def CommaNodeList(nodes: Sequence[Any]) -> NodeList: ... -def EnclosedNodeList(nodes: Sequence[Any]) -> NodeList: ... +def CommaNodeList(nodes: Sequence[Any]) -> NodeList: ... # incomplete +def EnclosedNodeList(nodes: Sequence[Any]) -> NodeList: ... # incomplete class _Namespace(Node): def __init__(self, name: str): ... @@ -724,20 +734,20 @@ class SelectQuery(Query): class SelectBase(_HashableSource, Source, SelectQuery): @overload - def peek(self, database: Optional[Database] = ..., n: Literal[1] = ...) -> Any: ... + def peek(self, database: Optional[Database] = ..., n: Literal[1] = ...) -> object: ... @overload - def peek(self, database: Optional[Database] = ..., n: int = ...) -> List[Any]: ... + def peek(self, database: Optional[Database] = ..., n: int = ...) -> List[object]: ... @overload - def first(self, database: Optional[Database] = ..., n: Literal[1] = ...) -> Any: ... + def first(self, database: Optional[Database] = ..., n: Literal[1] = ...) -> object: ... @overload - def first(self, database: Optional[Database] = ..., n: int = ...) -> List[Any]: ... + def first(self, database: Optional[Database] = ..., n: int = ...) -> List[object]: ... @overload - def scalar(self, database: Optional[Database] = ..., as_tuple: Literal[False] = ...) -> Any: ... + def scalar(self, database: Optional[Database] = ..., as_tuple: Literal[False] = ...) -> object: ... @overload def scalar(self, database: Optional[Database] = ..., as_tuple: Literal[True] = ...) -> tuple: ... def count(self, database: Optional[Database] = ..., clear_limit: bool = ...) -> int: ... def exists(self, database: Optional[Database] = ...) -> bool: ... - def get(self, database: Optional[Database] = ...) -> Any: ... + def get(self, database: Optional[Database] = ...) -> object: ... # QUERY IMPLEMENTATIONS. @@ -756,7 +766,7 @@ class Select(SelectBase): columns: Optional[Iterable[Union[Column, Field]]] = ..., # TODO (dargueta): `Field` might be wrong # Docs say this is a "[l]ist of columns or values to group by" so we don't have # a whole lot to restrict this to thanks to "or values" - group_by: Sequence[Any] = ..., + group_by: Sequence[object] = ..., having: Optional[Expression] = ..., distinct: Optional[Union[bool, Sequence[Column]]] = ..., windows: Optional[Container[Window]] = ..., @@ -819,7 +829,7 @@ class Insert(_WriteQuery): def on_conflict_ignore(self, ignore: bool = ...) -> Insert: ... def on_conflict_replace(self, replace: bool = ...) -> Insert: ... def on_conflict(self, *args, **kwargs) -> Insert: ... - def get_default_data(self) -> dict: ... + def get_default_data(self) -> Mapping[str, object]: ... def get_default_columns(self) -> Optional[List[Field]]: ... def __sql__(self, ctx: Context) -> Context: ... def handle_result(self, database: Database, cursor: __ICursor) -> Union[__ICursor, int]: ... @@ -1037,7 +1047,9 @@ class SqliteDatabase(Database): timeout: int = ..., **kwargs: object, ) -> None: ... - def pragma(self, key: str, value: Union[str, bool, int] = ..., permanent: bool = ..., schema: Optional[str] = ...) -> Any: ... + def pragma( + self, key: str, value: Union[str, bool, int] = ..., permanent: bool = ..., schema: Optional[str] = ... + ) -> object: ... @property def foreign_keys(self) -> Any: ... @foreign_keys.setter @@ -1198,7 +1210,7 @@ class CursorWrapper(Generic[T]): def iterate(self, cache: bool = ...) -> T: ... def process_row(self, row: tuple) -> T: ... def iterator(self) -> Iterator[T]: ... - def fill_cache(self, n: int = 0) -> None: ... + def fill_cache(self, n: int = ...) -> None: ... class DictCursorWrapper(CursorWrapper[Mapping[str, object]]): ... @@ -1206,10 +1218,10 @@ class DictCursorWrapper(CursorWrapper[Mapping[str, object]]): ... class NamedTupleCursorWrapper(CursorWrapper[tuple]): tuple_class: Type[tuple] -class ObjectCursorWrapper(DictCursorWrapper, Generic[T]): +class ObjectCursorWrapper(DictCursorWrapper[T]): constructor: Callable[..., T] def __init__(self, cursor: __ICursor, constructor: Callable[..., T]): ... - def process_row(self, row: tuple) -> T: ... + def process_row(self, row: tuple) -> T: ... # type: ignore class ResultIterator(Generic[T]): cursor_wrapper: CursorWrapper[T] @@ -1280,7 +1292,7 @@ class Field(ColumnBase): column_name: str default: Any primary_key: bool - constraints: Optional[Iterable[Check, SQL]] + constraints: Optional[Iterable[Union[Callable[[str], SQL], SQL]]] sequence: Optional[str] collation: Optional[str] unindexed: bool @@ -1317,7 +1329,7 @@ class Field(ColumnBase): def to_value(self, value: Any) -> Value: ... def get_sort_key(self, ctx: Context) -> Tuple[int, int]: ... def __sql__(self, ctx: Context) -> Context: ... - def get_modifiers(self) -> None: ... + def get_modifiers(self) -> Any: ... def ddl_datatype(self, ctx: Context) -> SQL: ... def ddl(self, ctx: Context) -> NodeList: ... @@ -1465,9 +1477,6 @@ class DateTimeField(_BaseFormattedField): def minute(self) -> int: ... @property def second(self) -> int: ... - @overload - def adapt(self, value: str) -> str: ... - @overload def adapt(self, value: T) -> T: ... def to_timestamp(self) -> Function: ... def truncate(self, part: str) -> Function: ... @@ -1480,8 +1489,6 @@ class DateField(_BaseFormattedField): @property def day(self) -> int: ... @overload - def adapt(self, value: str) -> str: ... - @overload def adapt(self, value: datetime.datetime) -> datetime.date: ... @overload def adapt(self, value: T) -> T: ... @@ -1489,8 +1496,6 @@ class DateField(_BaseFormattedField): def truncate(self, part: str) -> Function: ... class TimeField(_BaseFormattedField): - @overload - def adapt(self, value: str) -> str: ... @overload def adapt(self, value: Union[datetime.datetime, datetime.timedelta]) -> datetime.time: ... @overload @@ -1665,7 +1670,7 @@ class CompositeKey(MetaField): field_names: Tuple[str, ...] # The following attributes are not set in the constructor an so may not always be # present. - model: Type["Model"] + model: Type[Model] column_name: str def __init__(self, *field_names: str): ... @property @@ -1679,14 +1684,14 @@ class CompositeKey(MetaField): def __ne__(self, other: Expression) -> Expression: ... def __hash__(self) -> int: ... def __sql__(self, ctx: Context) -> Context: ... - def bind(self, model: Type["Model"], name: str, set_attribute: bool = ...) -> None: ... + def bind(self, model: Type[Model], name: str, set_attribute: bool = ...) -> None: ... # MODELS class SchemaManager: model: Type[Model] context_options: Dict[str, object] - def __init__(self, model: Type[Model], database: Optional[Database] = None, **context_options: object): ... + def __init__(self, model: Type[Model], database: Optional[Database] = ..., **context_options: object): ... @property def database(self) -> Database: ... @database.setter @@ -1719,7 +1724,7 @@ class Metadata: legacy_table_names: bool table_name: str indexes: List[Union[Index, ModelIndex, SQL]] - constraints: Optional[Iterable[Union[Check, SQL]]] + constraints: Optional[Iterable[Union[Callable[[str], SQL], SQL]]] primary_key: Union[Literal[False], Field, CompositeKey, None] composite_key: Optional[bool] auto_increment: Optional[bool] @@ -1892,7 +1897,7 @@ class Model(Node, metaclass=ModelBase): @classmethod def truncate_table(cls, **options: object) -> None: ... @classmethod - def index(cls, *fields, **kwargs): + def index(cls, *fields: Union[Field, Node, str], **kwargs: object) -> ModelIndex: return ModelIndex(cls, fields, **kwargs) @classmethod def add_index(cls, *fields: Union[str, SQL, Index], **kwargs: object) -> None: ... @@ -1957,7 +1962,7 @@ class BaseModelSelect(_ModelQueryHelper): def get(self, database: Optional[Database] = ...) -> Any: ... def group_by(self, *columns: Union[Type[Model], Table, Field]) -> BaseModelSelect: ... -class ModelCompoundSelectQuery(BaseModelSelect, CompoundSelectQuery): +class ModelCompoundSelectQuery(BaseModelSelect, CompoundSelectQuery, Generic[_TModel]): model: Type[_TModel] def __init__(self, model: Type[_TModel], *args: object, **kwargs: object): ... @@ -1965,9 +1970,9 @@ class ModelSelect(BaseModelSelect, Select, Generic[_TModel]): model: Type[_TModel] def __init__(self, model: Type[_TModel], fields_or_models: Iterable[_TFieldOrModel], is_default: bool = ...): ... def clone(self) -> ModelSelect: ... - def select(self, *fields_or_models: _TFieldOrModel) -> ModelSelect: ... + def select(self, *fields_or_models: _TFieldOrModel) -> ModelSelect: ... # type: ignore def switch(self, ctx: Optional[Type[Model]] = ...) -> ModelSelect: ... - def join( + def join( # type: ignore self, dest: Union[Type[Model], Table, ModelAlias, ModelSelect], join_type: int = ..., @@ -2026,14 +2031,14 @@ class BaseModelCursorWrapper(DictCursorWrapper, Generic[_TModel]): model: Type[_TModel] select: Sequence[str] def __init__(self, cursor: __ICursor, model: Type[_TModel], columns: Optional[Sequence[str]]): ... - def process_row(self, row: tuple) -> Mapping[str, object]: ... + def process_row(self, row: tuple) -> Mapping[str, object]: ... # type: ignore class ModelDictCursorWrapper(BaseModelCursorWrapper[_TModel]): def process_row(self, row: tuple) -> Dict[str, Any]: ... class ModelTupleCursorWrapper(ModelDictCursorWrapper[_TModel]): constructor: ClassVar[Callable[[Sequence[Any]], tuple]] - def process_row(self, row: tuple) -> tuple: ... + def process_row(self, row: tuple) -> tuple: ... # type: ignore class ModelNamedTupleCursorWrapper(ModelTupleCursorWrapper[_TModel]): ... @@ -2045,19 +2050,20 @@ class ModelObjectCursorWrapper(ModelDictCursorWrapper[_TModel]): self, cursor: __ICursor, model: _TModel, - select: Sequence[object], + select: Sequence[Any], # incomplete constructor: Union[Type[_TModel], Callable[[Any], _TModel]], ): ... - def process_row(self, row: tuple) -> _TModel: ... + def process_row(self, row: tuple) -> _TModel: ... # type: ignore class ModelCursorWrapper(BaseModelCursorWrapper[_TModel]): - from_list: Iterable[Any] # TODO (dargueta) -- Iterable[Union[Join, ...]] + # TODO (dargueta) -- Iterable[Union[Join, ...]] + from_list: Iterable[Any] # incomplete # TODO (dargueta) -- Mapping[, Tuple[?, ?, Callable[..., _TModel], int?]] joins: Mapping[Hashable, Tuple[object, object, Callable[..., _TModel], int]] key_to_constructor: Dict[Type[_TModel], Callable[..., _TModel]] src_is_dest: Dict[Type[Model], bool] src_to_dest: List[tuple] # TODO -- Tuple[, join_type[1], join_type[0], bool, join_type[3]] - column_keys: List # TODO + column_keys: List # incomplete def __init__( self, cursor: __ICursor, @@ -2067,7 +2073,7 @@ class ModelCursorWrapper(BaseModelCursorWrapper[_TModel]): joins: Mapping[Hashable, Tuple[object, object, Callable[..., _TModel], int]], ): ... def initialize(self) -> None: ... - def process_row(self, row: tuple) -> _TModel: ... + def process_row(self, row: tuple) -> _TModel: ... # type: ignore class __PrefetchQuery(NamedTuple): query: Query # TODO (dargueta): Verify From 2fa729271bb98521b4dd38aadab4f995a2645998 Mon Sep 17 00:00:00 2001 From: Diego Argueta Date: Wed, 17 Nov 2021 11:18:46 -0800 Subject: [PATCH 05/22] Add stuff, reformat according to Black --- third_party/2and3/peewee.pyi | 219 +++++++++++++++++++++++++++++------ 1 file changed, 183 insertions(+), 36 deletions(-) diff --git a/third_party/2and3/peewee.pyi b/third_party/2and3/peewee.pyi index ad2b4499ddfb..ca5fcfc9c2a9 100644 --- a/third_party/2and3/peewee.pyi +++ b/third_party/2and3/peewee.pyi @@ -157,7 +157,12 @@ class __State(NamedTuple): class State(__State): def __new__(cls, scope: int = ..., parentheses: bool = ..., **kwargs: object) -> State: ... - def __call__(self, scope: Optional[int] = ..., parentheses: Optional[int] = ..., **kwargs: object) -> State: ... + def __call__( + self, + scope: Optional[int] = ..., + parentheses: Optional[int] = ..., + **kwargs: object, + ) -> State: ... def __getattr__(self, attr_name: str) -> Any: ... class Context: @@ -172,7 +177,9 @@ class Context: @property def parentheses(self) -> bool: ... @property - def subquery(self) -> Any: ... # TODO (dargueta): Figure out type of "self.state.subquery" + def subquery( + self, + ) -> Any: ... # TODO (dargueta): Figure out type of "self.state.subquery" def __call__(self, **overrides: object) -> Context: ... def scope_normal(self) -> ContextManager[Context]: ... def scope_source(self) -> ContextManager[Context]: ... @@ -186,7 +193,12 @@ class Context: # TODO (dargueta): Is this right? def sql(self, obj: object) -> Context: ... def literal(self, keyword: str) -> Context: ... - def value(self, value: object, converter: Optional[_TConvFunc] = ..., add_param: bool = ...) -> Context: ... + def value( + self, + value: object, + converter: Optional[_TConvFunc] = ..., + add_param: bool = ..., + ) -> Context: ... def __sql__(self, ctx: Context) -> Context: ... def parse(self, node: Node) -> Tuple[str, Optional[tuple]]: ... def query(self) -> Tuple[str, Optional[tuple]]: ... @@ -295,11 +307,19 @@ class Table(_HashableSource, BaseTable): @overload def insert(self, insert: Optional[Select], columns: Sequence[Union[str, Field, Column]]) -> Insert: ... @overload - def insert(self, insert: Union[Mapping[str, object], Iterable[Mapping[str, object]]], **kwargs: object): ... + def insert( + self, + insert: Union[Mapping[str, object], Iterable[Mapping[str, object]]], + **kwargs: object, + ): ... @overload def replace(self, insert: Optional[Select], columns: Sequence[Union[str, Field, Column]]) -> Insert: ... @overload - def replace(self, insert: Union[Mapping[str, object], Iterable[Mapping[str, object]]], **kwargs: object): ... + def replace( + self, + insert: Union[Mapping[str, object], Iterable[Mapping[str, object]]], + **kwargs: object, + ): ... def update(self, update: Optional[Mapping[str, object]] = ..., **kwargs: object) -> Update: ... def delete(self) -> Delete: ... def __sql__(self, ctx: Context) -> Context: ... @@ -308,7 +328,14 @@ class Join(BaseTable): lhs: Any # TODO (dargueta) rhs: Any # TODO (dargueta) join_type: int - def __init__(self, lhs, rhs, join_type: int = ..., on: Optional[Expression] = ..., alias: Optional[str] = ...): ... + def __init__( + self, + lhs, + rhs, + join_type: int = ..., + on: Optional[Expression] = ..., + alias: Optional[str] = ..., + ): ... def on(self, predicate: Expression) -> Join: ... def __sql__(self, ctx: Context) -> Context: ... @@ -464,7 +491,13 @@ class Ordering(WrappedNode): direction: str collation: Optional[str] nulls: Optional[str] - def __init__(self, node: Node, direction: str, collation: Optional[str] = ..., nulls: Optional[str] = ...): ... + def __init__( + self, + node: Node, + direction: str, + collation: Optional[str] = ..., + nulls: Optional[str] = ..., + ): ... def collate(self, collation: Optional[str] = ...) -> Ordering: ... def __sql__(self, ctx: Context) -> Context: ... @@ -479,7 +512,13 @@ class Expression(ColumnBase): op: int rhs: Optional[Union[Node, str]] flat: bool - def __init__(self, lhs: Optional[Union[Node, str]], op: int, rhs: Optional[Union[Node, str]], flat: bool = ...): ... + def __init__( + self, + lhs: Optional[Union[Node, str]], + op: int, + rhs: Optional[Union[Node, str]], + flat: bool = ..., + ): ... def __sql__(self, ctx: Context) -> Context: ... class StringExpression(Expression): @@ -504,7 +543,13 @@ def Check(constraint: str) -> SQL: ... class Function(ColumnBase): name: str arguments: tuple - def __init__(self, name: str, arguments: tuple, coerce: bool = ..., python_value: Optional[_TConvFunc] = ...): ... + def __init__( + self, + name: str, + arguments: tuple, + coerce: bool = ..., + python_value: Optional[_TConvFunc] = ..., + ): ... def __getattr__(self, attr: str) -> Callable[..., Function]: ... # TODO (dargueta): `where` is an educated guess def filter(self, where: Optional[Expression] = ...) -> Function: ... @@ -586,12 +631,23 @@ class ForUpdate(Node): def __init__( self, expr: Union[Literal[True], str], - of: Optional[Union[_TModelOrTable, List[_TModelOrTable], Set[_TModelOrTable], Tuple[_TModelOrTable, ...]]] = ..., + of: Optional[ + Union[ + _TModelOrTable, + List[_TModelOrTable], + Set[_TModelOrTable], + Tuple[_TModelOrTable, ...], + ] + ] = ..., nowait: Optional[bool] = ..., ): ... def __sql__(self, ctx: Context) -> Context: ... -def Case(predicate: Optional[Node], expression_tuples: Iterable[Tuple[Expression, Any]], default: object = ...) -> NodeList: ... +def Case( + predicate: Optional[Node], + expression_tuples: Iterable[Tuple[Expression, Any]], + default: object = ..., +) -> NodeList: ... class NodeList(ColumnBase): # TODO (dargueta): Narrow this type @@ -793,13 +849,21 @@ class Select(SelectBase): def distinct(self, *columns: Field) -> Select: ... def window(self, *windows: Window) -> Select: ... def for_update( - self, for_update: bool = ..., of: Optional[Union[Table, Iterable[Table]]] = ..., nowait: Optional[bool] = ... + self, + for_update: bool = ..., + of: Optional[Union[Table, Iterable[Table]]] = ..., + nowait: Optional[bool] = ..., ) -> Select: ... def lateral(self, lateral: bool = ...) -> Select: ... class _WriteQuery(Query): table: Table - def __init__(self, table: Table, returning: Optional[Iterable[Union[Type[Model], Field]]] = ..., **kwargs: object): ... + def __init__( + self, + table: Table, + returning: Optional[Iterable[Union[Type[Model], Field]]] = ..., + **kwargs: object, + ): ... def returning(self, *returning: Union[Type[Model], Field]) -> _WriteQuery: ... def apply_returning(self, ctx: Context) -> Context: ... def execute_returning(self, database: Database) -> CursorWrapper: ... @@ -976,8 +1040,18 @@ class Database(_callable_context_manager): def is_connection_usable(self) -> bool: ... def connection(self) -> __IConnection: ... def cursor(self, commit: Optional[bool] = ...) -> __ICursor: ... - def execute_sql(self, sql: str, params: Optional[tuple] = ..., commit: Union[bool, _TSentinel] = ...) -> __ICursor: ... - def execute(self, query: Query, commit: Union[bool, _TSentinel] = ..., **context_options: object) -> __ICursor: ... + def execute_sql( + self, + sql: str, + params: Optional[tuple] = ..., + commit: Union[bool, _TSentinel] = ..., + ) -> __ICursor: ... + def execute( + self, + query: Query, + commit: Union[bool, _TSentinel] = ..., + **context_options: object, + ) -> __ICursor: ... def get_context_options(self) -> Mapping[str, object]: ... def get_sql_context(self, **context_options: object) -> _TContextClass: ... def conflict_statement(self, on_conflict: OnConflict, query: Query) -> Optional[SQL]: ... @@ -1015,9 +1089,17 @@ class Database(_callable_context_manager): def to_timestamp(self, date_field: str) -> Function: ... def from_timestamp(self, date_field: str) -> Function: ... def random(self) -> Node: ... - def bind(self, models: Iterable[Type[Model]], bind_refs: bool = ..., bind_backrefs: bool = ...) -> None: ... + def bind( + self, + models: Iterable[Type[Model]], + bind_refs: bool = ..., + bind_backrefs: bool = ..., + ) -> None: ... def bind_ctx( - self, models: Iterable[Type[Model]], bind_refs: bool = ..., bind_backrefs: bool = ... + self, + models: Iterable[Type[Model]], + bind_refs: bool = ..., + bind_backrefs: bool = ..., ) -> _BoundModelsContext: ... def get_noop_select(self, ctx: Context) -> Context: ... @@ -1048,7 +1130,11 @@ class SqliteDatabase(Database): **kwargs: object, ) -> None: ... def pragma( - self, key: str, value: Union[str, bool, int] = ..., permanent: bool = ..., schema: Optional[str] = ... + self, + key: str, + value: Union[str, bool, int] = ..., + permanent: bool = ..., + schema: Optional[str] = ..., ) -> object: ... @property def foreign_keys(self) -> Any: ... @@ -1082,7 +1168,12 @@ class SqliteDatabase(Database): def wal_autocheckpoint(self) -> Any: ... @wal_autocheckpoint.setter def wal_autocheckpoint(self, value: object) -> Any: ... - def register_aggregate(self, klass: Type[__IAggregate], name: Optional[str] = ..., num_params: int = ...): ... + def register_aggregate( + self, + klass: Type[__IAggregate], + name: Optional[str] = ..., + num_params: int = ..., + ): ... def aggregate(self, name: Optional[str] = ..., num_params: int = ...) -> Callable[[_TClass], _TClass]: ... def register_collation(self, fn: Callable, name: Optional[str] = ...) -> None: ... def collation(self, name: Optional[str] = ...) -> Callable[[_TFunc], _TFunc]: ... @@ -1307,7 +1398,7 @@ class Field(ColumnBase): column_name: str = ..., default: Any = ..., primary_key: bool = ..., - constraints: Optional[Iterable[Check, SQL]] = ..., + constraints: Optional[Iterable[Union[Callable[[str], SQL], SQL]]] = ..., sequence: Optional[str] = ..., collation: Optional[str] = ..., unindexed: Optional[bool] = ..., @@ -1335,7 +1426,7 @@ class Field(ColumnBase): class IntegerField(Field): @overload - def adapt(self, value: Union[int, str, float, bool]) -> int: ... + def adapt(self, value: Union[int, str, float, bool]) -> int: ... # type: ignore @overload def adapt(self, value: T) -> T: ... @@ -1354,7 +1445,7 @@ class PrimaryKeyField(AutoField): ... class FloatField(Field): @overload - def adapt(self, value: Union[str, int, float, bool]) -> float: ... + def adapt(self, value: Union[str, int, float, bool]) -> float: ... # type: ignore @overload def adapt(self, value: T) -> T: ... @@ -1378,7 +1469,7 @@ class DecimalField(Field): @overload def db_value(self, value: None) -> None: ... @overload - def db_value(self, value: Union[int, float, decimal.Decimal]) -> decimal.Decimal: ... + def db_value(self, value: Union[int, float, decimal.Decimal]) -> decimal.Decimal: ... # type: ignore @overload def db_value(self, value: T) -> T: ... @overload @@ -1425,7 +1516,11 @@ class BigBitFieldAccessor(FieldAccessor): def __get__(self, instance: None, instance_type: Type[_TModel]) -> Field: ... @overload def __get__(self, instance: _TModel, instance_type: Type[_TModel]) -> BigBitFieldData: ... - def __set__(self, instance: Any, value: Union[memoryview, bytearray, BigBitFieldData, str, bytes]) -> None: ... + def __set__( + self, + instance: Any, + value: Union[memoryview, bytearray, BigBitFieldData, str, bytes], + ) -> None: ... class BigBitField(BlobField): accessor_class: ClassVar[Type[BigBitFieldAccessor]] @@ -1604,7 +1699,14 @@ class ForeignKeyField(Field): class DeferredForeignKey(Field): field_kwargs: Dict[str, object] rel_model_name: str - def __init__(self, rel_model_name: str, *, column_name: Optional[str] = ..., null: Optional[str] = ..., **kwargs: object): ... + def __init__( + self, + rel_model_name: str, + *, + column_name: Optional[str] = ..., + null: Optional[str] = ..., + **kwargs: object, + ): ... def set_model(self, rel_model: Type[Model]) -> None: ... @staticmethod def resolve(model_cls: Type[Model]) -> None: ... @@ -1660,7 +1762,12 @@ class ManyToManyField(MetaField): class VirtualField(MetaField, Generic[_TField]): field_class: Type[_TField] field_instance: Optional[_TField] - def __init__(self, field_class: Optional[Type[_TField]] = ..., *args: object, **kwargs: object): ... + def __init__( + self, + field_class: Optional[Type[_TField]] = ..., + *args: object, + **kwargs: object, + ): ... def db_value(self, value: object) -> Any: ... def python_value(self, value: object) -> Any: ... def bind(self, model: Type[Model], name: str, set_attribute: bool = ...) -> None: ... @@ -1691,7 +1798,12 @@ class CompositeKey(MetaField): class SchemaManager: model: Type[Model] context_options: Dict[str, object] - def __init__(self, model: Type[Model], database: Optional[Database] = ..., **context_options: object): ... + def __init__( + self, + model: Type[Model], + database: Optional[Database] = ..., + **context_options: object, + ): ... @property def database(self) -> Database: ... @database.setter @@ -1803,7 +1915,13 @@ class _BoundModelsContext(_callable_context_manager): database: Database bind_refs: bool bind_backrefs: bool - def __init__(self, models: Iterable[Type[Model]], database, bind_refs: bool, bind_backrefs: bool): ... + def __init__( + self, + models: Iterable[Type[Model]], + database, + bind_refs: bool, + bind_backrefs: bool, + ): ... def __enter__(self) -> Iterable[Type[Model]]: ... def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: Any) -> None: ... @@ -1847,7 +1965,10 @@ class Model(Node, metaclass=ModelBase): def bulk_create(cls, model_list: Iterable[Type[Model]], batch_size: Optional[int] = ...) -> None: ... @classmethod def bulk_update( - cls, model_list: Iterable[Type[Model]], fields: Iterable[Union[str, Field]], batch_size: Optional[int] = ... + cls, + model_list: Iterable[Type[Model]], + fields: Iterable[Union[str, Field]], + batch_size: Optional[int] = ..., ) -> int: ... @classmethod def noop(cls) -> NoopModelSelect: ... @@ -1868,7 +1989,11 @@ class Model(Node, metaclass=ModelBase): @classmethod def filter(cls, *dq_nodes: DQ, **filters: Any) -> SelectQuery: ... def get_id(self) -> Any: ... - def save(self, force_insert: bool = ..., only: Optional[Iterable[Union[str, Field]]] = ...) -> Union[Literal[False], int]: ... + def save( + self, + force_insert: bool = ..., + only: Optional[Iterable[Union[str, Field]]] = ..., + ) -> Union[Literal[False], int]: ... def is_dirty(self) -> bool: ... @property def dirty_fields(self) -> List[Field]: ... @@ -1968,7 +2093,12 @@ class ModelCompoundSelectQuery(BaseModelSelect, CompoundSelectQuery, Generic[_TM class ModelSelect(BaseModelSelect, Select, Generic[_TModel]): model: Type[_TModel] - def __init__(self, model: Type[_TModel], fields_or_models: Iterable[_TFieldOrModel], is_default: bool = ...): ... + def __init__( + self, + model: Type[_TModel], + fields_or_models: Iterable[_TFieldOrModel], + is_default: bool = ..., + ): ... def clone(self) -> ModelSelect: ... def select(self, *fields_or_models: _TFieldOrModel) -> ModelSelect: ... # type: ignore def switch(self, ctx: Optional[Type[Model]] = ...) -> ModelSelect: ... @@ -1989,7 +2119,11 @@ class ModelSelect(BaseModelSelect, Select, Generic[_TModel]): attr: Optional[str] = ..., ) -> ModelSelect: ... def ensure_join( - self, lm: Type[Model], rm: Type[Model], on: Optional[Union[Column, Expression, Field]] = ..., **join_kwargs: Any + self, + lm: Type[Model], + rm: Type[Model], + on: Optional[Union[Column, Expression, Field]] = ..., + **join_kwargs: Any, ) -> ModelSelect: ... # TODO (dargueta): 85% sure about the return value def convert_dict_to_node(self, qdict: Mapping[str, object]) -> Tuple[List[Expression], List[Field]]: ... @@ -2017,9 +2151,18 @@ class ModelDelete(_ModelWriteQueryHelper, Delete): ... class ManyToManyQuery(ModelSelect): def __init__( - self, instance: Model, accessor: ManyToManyFieldAccessor, rel: _TFieldOrModel, *args: object, **kwargs: object + self, + instance: Model, + accessor: ManyToManyFieldAccessor, + rel: _TFieldOrModel, + *args: object, + **kwargs: object, ): ... - def add(self, value: Union[SelectQuery, Type[Model], Iterable[str]], clear_existing: bool = ...) -> None: ... + def add( + self, + value: Union[SelectQuery, Type[Model], Iterable[str]], + clear_existing: bool = ..., + ) -> None: ... def remove(self, value: Union[SelectQuery, Type[Model], Iterable[str]]) -> Optional[int]: ... def clear(self) -> int: ... @@ -2031,7 +2174,7 @@ class BaseModelCursorWrapper(DictCursorWrapper, Generic[_TModel]): model: Type[_TModel] select: Sequence[str] def __init__(self, cursor: __ICursor, model: Type[_TModel], columns: Optional[Sequence[str]]): ... - def process_row(self, row: tuple) -> Mapping[str, object]: ... # type: ignore + def process_row(self, row: tuple) -> Mapping[str, object]: ... # type: ignore class ModelDictCursorWrapper(BaseModelCursorWrapper[_TModel]): def process_row(self, row: tuple) -> Dict[str, Any]: ... @@ -2085,7 +2228,11 @@ class __PrefetchQuery(NamedTuple): class PrefetchQuery(__PrefetchQuery): def populate_instance(self, instance: Model, id_map: Mapping[Tuple[object, object], object]): ... - def store_instance(self, instance: Model, id_map: MutableMapping[Tuple[object, object], List[Model]]) -> None: ... + def store_instance( + self, + instance: Model, + id_map: MutableMapping[Tuple[object, object], List[Model]], + ) -> None: ... def prefetch_add_subquery(sq: Query, subqueries: Iterable[_TSubquery]) -> List[PrefetchQuery]: ... def prefetch(sq: Query, *subqueries: _TSubquery) -> List[object]: ... From fa81e4863dd31ac2f422fb9fdb41fc876778bc25 Mon Sep 17 00:00:00 2001 From: Diego Argueta Date: Wed, 17 Nov 2021 13:03:38 -0800 Subject: [PATCH 06/22] A few more tweaks to int/float confusion --- third_party/2and3/peewee.pyi | 222 ++++++----------------------------- 1 file changed, 38 insertions(+), 184 deletions(-) diff --git a/third_party/2and3/peewee.pyi b/third_party/2and3/peewee.pyi index ca5fcfc9c2a9..0d2adef2686a 100644 --- a/third_party/2and3/peewee.pyi +++ b/third_party/2and3/peewee.pyi @@ -156,13 +156,9 @@ class __State(NamedTuple): settings: Dict[str, Any] class State(__State): + subquery: object # TODO (dargueta) def __new__(cls, scope: int = ..., parentheses: bool = ..., **kwargs: object) -> State: ... - def __call__( - self, - scope: Optional[int] = ..., - parentheses: Optional[int] = ..., - **kwargs: object, - ) -> State: ... + def __call__(self, scope: Optional[int] = ..., parentheses: Optional[int] = ..., **kwargs: object) -> State: ... def __getattr__(self, attr_name: str) -> Any: ... class Context: @@ -177,9 +173,7 @@ class Context: @property def parentheses(self) -> bool: ... @property - def subquery( - self, - ) -> Any: ... # TODO (dargueta): Figure out type of "self.state.subquery" + def subquery(self) -> Any: ... # TODO (dargueta): Figure out type of "self.state.subquery" def __call__(self, **overrides: object) -> Context: ... def scope_normal(self) -> ContextManager[Context]: ... def scope_source(self) -> ContextManager[Context]: ... @@ -193,12 +187,7 @@ class Context: # TODO (dargueta): Is this right? def sql(self, obj: object) -> Context: ... def literal(self, keyword: str) -> Context: ... - def value( - self, - value: object, - converter: Optional[_TConvFunc] = ..., - add_param: bool = ..., - ) -> Context: ... + def value(self, value: object, converter: Optional[_TConvFunc] = ..., add_param: bool = ...) -> Context: ... def __sql__(self, ctx: Context) -> Context: ... def parse(self, node: Node) -> Tuple[str, Optional[tuple]]: ... def query(self) -> Tuple[str, Optional[tuple]]: ... @@ -307,19 +296,11 @@ class Table(_HashableSource, BaseTable): @overload def insert(self, insert: Optional[Select], columns: Sequence[Union[str, Field, Column]]) -> Insert: ... @overload - def insert( - self, - insert: Union[Mapping[str, object], Iterable[Mapping[str, object]]], - **kwargs: object, - ): ... + def insert(self, insert: Union[Mapping[str, object], Iterable[Mapping[str, object]]], **kwargs: object): ... @overload def replace(self, insert: Optional[Select], columns: Sequence[Union[str, Field, Column]]) -> Insert: ... @overload - def replace( - self, - insert: Union[Mapping[str, object], Iterable[Mapping[str, object]]], - **kwargs: object, - ): ... + def replace(self, insert: Union[Mapping[str, object], Iterable[Mapping[str, object]]], **kwargs: object): ... def update(self, update: Optional[Mapping[str, object]] = ..., **kwargs: object) -> Update: ... def delete(self) -> Delete: ... def __sql__(self, ctx: Context) -> Context: ... @@ -328,14 +309,7 @@ class Join(BaseTable): lhs: Any # TODO (dargueta) rhs: Any # TODO (dargueta) join_type: int - def __init__( - self, - lhs, - rhs, - join_type: int = ..., - on: Optional[Expression] = ..., - alias: Optional[str] = ..., - ): ... + def __init__(self, lhs, rhs, join_type: int = ..., on: Optional[Expression] = ..., alias: Optional[str] = ...): ... def on(self, predicate: Expression) -> Join: ... def __sql__(self, ctx: Context) -> Context: ... @@ -491,13 +465,7 @@ class Ordering(WrappedNode): direction: str collation: Optional[str] nulls: Optional[str] - def __init__( - self, - node: Node, - direction: str, - collation: Optional[str] = ..., - nulls: Optional[str] = ..., - ): ... + def __init__(self, node: Node, direction: str, collation: Optional[str] = ..., nulls: Optional[str] = ...): ... def collate(self, collation: Optional[str] = ...) -> Ordering: ... def __sql__(self, ctx: Context) -> Context: ... @@ -512,13 +480,7 @@ class Expression(ColumnBase): op: int rhs: Optional[Union[Node, str]] flat: bool - def __init__( - self, - lhs: Optional[Union[Node, str]], - op: int, - rhs: Optional[Union[Node, str]], - flat: bool = ..., - ): ... + def __init__(self, lhs: Optional[Union[Node, str]], op: int, rhs: Optional[Union[Node, str]], flat: bool = ...): ... def __sql__(self, ctx: Context) -> Context: ... class StringExpression(Expression): @@ -543,13 +505,7 @@ def Check(constraint: str) -> SQL: ... class Function(ColumnBase): name: str arguments: tuple - def __init__( - self, - name: str, - arguments: tuple, - coerce: bool = ..., - python_value: Optional[_TConvFunc] = ..., - ): ... + def __init__(self, name: str, arguments: tuple, coerce: bool = ..., python_value: Optional[_TConvFunc] = ...): ... def __getattr__(self, attr: str) -> Callable[..., Function]: ... # TODO (dargueta): `where` is an educated guess def filter(self, where: Optional[Expression] = ...) -> Function: ... @@ -631,23 +587,12 @@ class ForUpdate(Node): def __init__( self, expr: Union[Literal[True], str], - of: Optional[ - Union[ - _TModelOrTable, - List[_TModelOrTable], - Set[_TModelOrTable], - Tuple[_TModelOrTable, ...], - ] - ] = ..., + of: Optional[Union[_TModelOrTable, List[_TModelOrTable], Set[_TModelOrTable], Tuple[_TModelOrTable, ...],]] = ..., nowait: Optional[bool] = ..., ): ... def __sql__(self, ctx: Context) -> Context: ... -def Case( - predicate: Optional[Node], - expression_tuples: Iterable[Tuple[Expression, Any]], - default: object = ..., -) -> NodeList: ... +def Case(predicate: Optional[Node], expression_tuples: Iterable[Tuple[Expression, Any]], default: object = ...) -> NodeList: ... class NodeList(ColumnBase): # TODO (dargueta): Narrow this type @@ -849,21 +794,13 @@ class Select(SelectBase): def distinct(self, *columns: Field) -> Select: ... def window(self, *windows: Window) -> Select: ... def for_update( - self, - for_update: bool = ..., - of: Optional[Union[Table, Iterable[Table]]] = ..., - nowait: Optional[bool] = ..., + self, for_update: bool = ..., of: Optional[Union[Table, Iterable[Table]]] = ..., nowait: Optional[bool] = ... ) -> Select: ... def lateral(self, lateral: bool = ...) -> Select: ... class _WriteQuery(Query): table: Table - def __init__( - self, - table: Table, - returning: Optional[Iterable[Union[Type[Model], Field]]] = ..., - **kwargs: object, - ): ... + def __init__(self, table: Table, returning: Optional[Iterable[Union[Type[Model], Field]]] = ..., **kwargs: object): ... def returning(self, *returning: Union[Type[Model], Field]) -> _WriteQuery: ... def apply_returning(self, ctx: Context) -> Context: ... def execute_returning(self, database: Database) -> CursorWrapper: ... @@ -1040,18 +977,8 @@ class Database(_callable_context_manager): def is_connection_usable(self) -> bool: ... def connection(self) -> __IConnection: ... def cursor(self, commit: Optional[bool] = ...) -> __ICursor: ... - def execute_sql( - self, - sql: str, - params: Optional[tuple] = ..., - commit: Union[bool, _TSentinel] = ..., - ) -> __ICursor: ... - def execute( - self, - query: Query, - commit: Union[bool, _TSentinel] = ..., - **context_options: object, - ) -> __ICursor: ... + def execute_sql(self, sql: str, params: Optional[tuple] = ..., commit: Union[bool, _TSentinel] = ...) -> __ICursor: ... + def execute(self, query: Query, commit: Union[bool, _TSentinel] = ..., **context_options: object) -> __ICursor: ... def get_context_options(self) -> Mapping[str, object]: ... def get_sql_context(self, **context_options: object) -> _TContextClass: ... def conflict_statement(self, on_conflict: OnConflict, query: Query) -> Optional[SQL]: ... @@ -1089,17 +1016,9 @@ class Database(_callable_context_manager): def to_timestamp(self, date_field: str) -> Function: ... def from_timestamp(self, date_field: str) -> Function: ... def random(self) -> Node: ... - def bind( - self, - models: Iterable[Type[Model]], - bind_refs: bool = ..., - bind_backrefs: bool = ..., - ) -> None: ... + def bind(self, models: Iterable[Type[Model]], bind_refs: bool = ..., bind_backrefs: bool = ...) -> None: ... def bind_ctx( - self, - models: Iterable[Type[Model]], - bind_refs: bool = ..., - bind_backrefs: bool = ..., + self, models: Iterable[Type[Model]], bind_refs: bool = ..., bind_backrefs: bool = ... ) -> _BoundModelsContext: ... def get_noop_select(self, ctx: Context) -> Context: ... @@ -1130,11 +1049,7 @@ class SqliteDatabase(Database): **kwargs: object, ) -> None: ... def pragma( - self, - key: str, - value: Union[str, bool, int] = ..., - permanent: bool = ..., - schema: Optional[str] = ..., + self, key: str, value: Union[str, bool, int] = ..., permanent: bool = ..., schema: Optional[str] = ... ) -> object: ... @property def foreign_keys(self) -> Any: ... @@ -1168,12 +1083,7 @@ class SqliteDatabase(Database): def wal_autocheckpoint(self) -> Any: ... @wal_autocheckpoint.setter def wal_autocheckpoint(self, value: object) -> Any: ... - def register_aggregate( - self, - klass: Type[__IAggregate], - name: Optional[str] = ..., - num_params: int = ..., - ): ... + def register_aggregate(self, klass: Type[__IAggregate], name: Optional[str] = ..., num_params: int = ...): ... def aggregate(self, name: Optional[str] = ..., num_params: int = ...) -> Callable[[_TClass], _TClass]: ... def register_collation(self, fn: Callable, name: Optional[str] = ...) -> None: ... def collation(self, name: Optional[str] = ...) -> Callable[[_TFunc], _TFunc]: ... @@ -1426,7 +1336,7 @@ class Field(ColumnBase): class IntegerField(Field): @overload - def adapt(self, value: Union[int, str, float, bool]) -> int: ... # type: ignore + def adapt(self, value: Union[str, float, bool]) -> int: ... # type: ignore @overload def adapt(self, value: T) -> T: ... @@ -1445,7 +1355,7 @@ class PrimaryKeyField(AutoField): ... class FloatField(Field): @overload - def adapt(self, value: Union[str, int, float, bool]) -> float: ... # type: ignore + def adapt(self, value: Union[str, float, bool]) -> float: ... # type: ignore @overload def adapt(self, value: T) -> T: ... @@ -1469,13 +1379,13 @@ class DecimalField(Field): @overload def db_value(self, value: None) -> None: ... @overload - def db_value(self, value: Union[int, float, decimal.Decimal]) -> decimal.Decimal: ... # type: ignore + def db_value(self, value: Union[float, decimal.Decimal]) -> decimal.Decimal: ... # type: ignore @overload def db_value(self, value: T) -> T: ... @overload def python_value(self, value: None) -> None: ... @overload - def python_value(self, value: Union[int, str, float, decimal.Decimal]) -> decimal.Decimal: ... + def python_value(self, value: Union[str, float, decimal.Decimal]) -> decimal.Decimal: ... class _StringField(Field): def adapt(self, value: AnyStr) -> str: ... @@ -1516,11 +1426,7 @@ class BigBitFieldAccessor(FieldAccessor): def __get__(self, instance: None, instance_type: Type[_TModel]) -> Field: ... @overload def __get__(self, instance: _TModel, instance_type: Type[_TModel]) -> BigBitFieldData: ... - def __set__( - self, - instance: Any, - value: Union[memoryview, bytearray, BigBitFieldData, str, bytes], - ) -> None: ... + def __set__(self, instance: Any, value: Union[memoryview, bytearray, BigBitFieldData, str, bytes]) -> None: ... class BigBitField(BlobField): accessor_class: ClassVar[Type[BigBitFieldAccessor]] @@ -1632,7 +1538,7 @@ class TimestampField(BigIntegerField): @property def minute(self) -> int: ... @property - def second(self) -> int: ... # TODO (dargueta) Float? + def second(self) -> float: ... # TODO (dargueta) Int? class IPField(BigIntegerField): @overload @@ -1654,7 +1560,7 @@ class BareField(Field): def ddl_datatype(self, ctx: Context) -> None: ... class ForeignKeyField(Field): - accessor_class = ForeignKeyAccessor + accessor_class: ClassVar[Type[ForeignKeyAccessor]] rel_model: Union[Type[Model], Literal["self"]] rel_field: Field declared_backref: Optional[str] @@ -1699,14 +1605,7 @@ class ForeignKeyField(Field): class DeferredForeignKey(Field): field_kwargs: Dict[str, object] rel_model_name: str - def __init__( - self, - rel_model_name: str, - *, - column_name: Optional[str] = ..., - null: Optional[str] = ..., - **kwargs: object, - ): ... + def __init__(self, rel_model_name: str, *, column_name: Optional[str] = ..., null: Optional[str] = ..., **kwargs: object): ... def set_model(self, rel_model: Type[Model]) -> None: ... @staticmethod def resolve(model_cls: Type[Model]) -> None: ... @@ -1762,12 +1661,7 @@ class ManyToManyField(MetaField): class VirtualField(MetaField, Generic[_TField]): field_class: Type[_TField] field_instance: Optional[_TField] - def __init__( - self, - field_class: Optional[Type[_TField]] = ..., - *args: object, - **kwargs: object, - ): ... + def __init__(self, field_class: Optional[Type[_TField]] = ..., *args: object, **kwargs: object): ... def db_value(self, value: object) -> Any: ... def python_value(self, value: object) -> Any: ... def bind(self, model: Type[Model], name: str, set_attribute: bool = ...) -> None: ... @@ -1798,12 +1692,7 @@ class CompositeKey(MetaField): class SchemaManager: model: Type[Model] context_options: Dict[str, object] - def __init__( - self, - model: Type[Model], - database: Optional[Database] = ..., - **context_options: object, - ): ... + def __init__(self, model: Type[Model], database: Optional[Database] = ..., **context_options: object): ... @property def database(self) -> Database: ... @database.setter @@ -1915,13 +1804,7 @@ class _BoundModelsContext(_callable_context_manager): database: Database bind_refs: bool bind_backrefs: bool - def __init__( - self, - models: Iterable[Type[Model]], - database, - bind_refs: bool, - bind_backrefs: bool, - ): ... + def __init__(self, models: Iterable[Type[Model]], database, bind_refs: bool, bind_backrefs: bool): ... def __enter__(self) -> Iterable[Type[Model]]: ... def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: Any) -> None: ... @@ -1965,10 +1848,7 @@ class Model(Node, metaclass=ModelBase): def bulk_create(cls, model_list: Iterable[Type[Model]], batch_size: Optional[int] = ...) -> None: ... @classmethod def bulk_update( - cls, - model_list: Iterable[Type[Model]], - fields: Iterable[Union[str, Field]], - batch_size: Optional[int] = ..., + cls, model_list: Iterable[Type[Model]], fields: Iterable[Union[str, Field]], batch_size: Optional[int] = ... ) -> int: ... @classmethod def noop(cls) -> NoopModelSelect: ... @@ -1989,11 +1869,7 @@ class Model(Node, metaclass=ModelBase): @classmethod def filter(cls, *dq_nodes: DQ, **filters: Any) -> SelectQuery: ... def get_id(self) -> Any: ... - def save( - self, - force_insert: bool = ..., - only: Optional[Iterable[Union[str, Field]]] = ..., - ) -> Union[Literal[False], int]: ... + def save(self, force_insert: bool = ..., only: Optional[Iterable[Union[str, Field]]] = ...) -> Union[Literal[False], int]: ... def is_dirty(self) -> bool: ... @property def dirty_fields(self) -> List[Field]: ... @@ -2093,12 +1969,7 @@ class ModelCompoundSelectQuery(BaseModelSelect, CompoundSelectQuery, Generic[_TM class ModelSelect(BaseModelSelect, Select, Generic[_TModel]): model: Type[_TModel] - def __init__( - self, - model: Type[_TModel], - fields_or_models: Iterable[_TFieldOrModel], - is_default: bool = ..., - ): ... + def __init__(self, model: Type[_TModel], fields_or_models: Iterable[_TFieldOrModel], is_default: bool = ...): ... def clone(self) -> ModelSelect: ... def select(self, *fields_or_models: _TFieldOrModel) -> ModelSelect: ... # type: ignore def switch(self, ctx: Optional[Type[Model]] = ...) -> ModelSelect: ... @@ -2119,11 +1990,7 @@ class ModelSelect(BaseModelSelect, Select, Generic[_TModel]): attr: Optional[str] = ..., ) -> ModelSelect: ... def ensure_join( - self, - lm: Type[Model], - rm: Type[Model], - on: Optional[Union[Column, Expression, Field]] = ..., - **join_kwargs: Any, + self, lm: Type[Model], rm: Type[Model], on: Optional[Union[Column, Expression, Field]] = ..., **join_kwargs: Any ) -> ModelSelect: ... # TODO (dargueta): 85% sure about the return value def convert_dict_to_node(self, qdict: Mapping[str, object]) -> Tuple[List[Expression], List[Field]]: ... @@ -2151,18 +2018,9 @@ class ModelDelete(_ModelWriteQueryHelper, Delete): ... class ManyToManyQuery(ModelSelect): def __init__( - self, - instance: Model, - accessor: ManyToManyFieldAccessor, - rel: _TFieldOrModel, - *args: object, - **kwargs: object, + self, instance: Model, accessor: ManyToManyFieldAccessor, rel: _TFieldOrModel, *args: object, **kwargs: object ): ... - def add( - self, - value: Union[SelectQuery, Type[Model], Iterable[str]], - clear_existing: bool = ..., - ) -> None: ... + def add(self, value: Union[SelectQuery, Type[Model], Iterable[str]], clear_existing: bool = ...) -> None: ... def remove(self, value: Union[SelectQuery, Type[Model], Iterable[str]]) -> Optional[int]: ... def clear(self) -> int: ... @@ -2228,11 +2086,7 @@ class __PrefetchQuery(NamedTuple): class PrefetchQuery(__PrefetchQuery): def populate_instance(self, instance: Model, id_map: Mapping[Tuple[object, object], object]): ... - def store_instance( - self, - instance: Model, - id_map: MutableMapping[Tuple[object, object], List[Model]], - ) -> None: ... + def store_instance(self, instance: Model, id_map: MutableMapping[Tuple[object, object], List[Model]]) -> None: ... def prefetch_add_subquery(sq: Query, subqueries: Iterable[_TSubquery]) -> List[PrefetchQuery]: ... def prefetch(sq: Query, *subqueries: _TSubquery) -> List[object]: ... From 9175536a5c38d5c52875df940083913d8b2e1d3c Mon Sep 17 00:00:00 2001 From: Diego Argueta <620513-dargueta@users.noreply.github.com> Date: Sun, 20 Feb 2022 08:48:19 -0800 Subject: [PATCH 07/22] Hopefully the last of the linter fixes, thanks to @pylipp --- third_party/2and3/peewee.pyi | 95 ++++++++++++++++++------------------ 1 file changed, 47 insertions(+), 48 deletions(-) diff --git a/third_party/2and3/peewee.pyi b/third_party/2and3/peewee.pyi index 0d2adef2686a..934083167b46 100644 --- a/third_party/2and3/peewee.pyi +++ b/third_party/2and3/peewee.pyi @@ -34,7 +34,7 @@ from typing import ( ) from typing_extensions import Literal, Protocol -T = TypeVar("T") +_T = TypeVar("_T") _TModel = TypeVar("_TModel", bound="Model") _TConvFunc = Callable[[Any], Any] _TFunc = TypeVar("_TFunc", bound=Callable) @@ -213,16 +213,16 @@ class _DynamicColumn: @overload def __get__(self, instance: None, instance_type: type) -> _DynamicColumn: ... @overload - def __get__(self, instance: T, instance_type: Type[T]) -> ColumnFactory: ... + def __get__(self, instance: _T, instance_type: Type[_T]) -> ColumnFactory: ... class _ExplicitColumn: @overload def __get__(self, instance: None, instance_type: type) -> _ExplicitColumn: ... @overload - def __get__(self, instance: T, instance_type: Type[T]) -> NoReturn: ... + def __get__(self, instance: _T, instance_type: Type[_T]) -> NoReturn: ... class _SupportsAlias(Protocol): - def alias(self: T, name: str) -> T: ... + def alias(self: _T, name: str) -> _T: ... class Source(_SupportsAlias, Node): c: ClassVar[_DynamicColumn] @@ -420,7 +420,7 @@ class _DynamicEntity: @overload def __get__(self, instance: None, instance_type: type) -> _DynamicEntity: ... @overload - def __get__(self, instance: T, instance_type: Type[T]) -> EntityFactory: ... + def __get__(self, instance: _T, instance_type: Type[_T]) -> EntityFactory: ... class Alias(WrappedNode): c: ClassVar[_DynamicEntity] @@ -587,7 +587,7 @@ class ForUpdate(Node): def __init__( self, expr: Union[Literal[True], str], - of: Optional[Union[_TModelOrTable, List[_TModelOrTable], Set[_TModelOrTable], Tuple[_TModelOrTable, ...],]] = ..., + of: Optional[Union[_TModelOrTable, List[_TModelOrTable], Set[_TModelOrTable], Tuple[_TModelOrTable, ...]]] = ..., nowait: Optional[bool] = ..., ): ... def __sql__(self, ctx: Context) -> Context: ... @@ -632,7 +632,7 @@ def qualify_names(node: Expression) -> Expression: ... @overload def qualify_names(node: ColumnBase) -> QualifiedNames: ... @overload -def qualify_names(node: T) -> T: ... +def qualify_names(node: _T) -> _T: ... class OnConflict(Node): @overload @@ -1001,7 +1001,7 @@ class Database(_callable_context_manager): def begin(self) -> None: ... def commit(self) -> None: ... def rollback(self) -> None: ... - def batch_commit(self, it: Iterable[T], n: int) -> Iterator[T]: ... + def batch_commit(self, it: Iterable[_T], n: int) -> Iterator[_T]: ... def table_exists(self, table_name: str, schema: Optional[str] = ...) -> str: ... def get_tables(self, schema: Optional[str] = ...) -> List[str]: ... def get_indexes(self, table: str, schema: Optional[str] = ...) -> List[IndexMetadata]: ... @@ -1193,24 +1193,24 @@ class _savepoint(_callable_context_manager): def __enter__(self) -> _savepoint: ... def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: object) -> None: ... -class CursorWrapper(Generic[T]): +class CursorWrapper(Generic[_T]): cursor: __ICursor count: int index: int initialized: bool populated: bool - row_cache: List[T] + row_cache: List[_T] def __init__(self, cursor: __ICursor): ... - def __iter__(self) -> Union[ResultIterator[T], Iterator[T]]: ... + def __iter__(self) -> Union[ResultIterator[_T], Iterator[_T]]: ... @overload - def __getitem__(self, item: int) -> T: ... + def __getitem__(self, item: int) -> _T: ... @overload - def __getitem__(self, item: slice) -> List[T]: ... + def __getitem__(self, item: slice) -> List[_T]: ... def __len__(self) -> int: ... def initialize(self) -> None: ... - def iterate(self, cache: bool = ...) -> T: ... - def process_row(self, row: tuple) -> T: ... - def iterator(self) -> Iterator[T]: ... + def iterate(self, cache: bool = ...) -> _T: ... + def process_row(self, row: tuple) -> _T: ... + def iterator(self) -> Iterator[_T]: ... def fill_cache(self, n: int = ...) -> None: ... class DictCursorWrapper(CursorWrapper[Mapping[str, object]]): ... @@ -1219,16 +1219,16 @@ class DictCursorWrapper(CursorWrapper[Mapping[str, object]]): ... class NamedTupleCursorWrapper(CursorWrapper[tuple]): tuple_class: Type[tuple] -class ObjectCursorWrapper(DictCursorWrapper[T]): - constructor: Callable[..., T] - def __init__(self, cursor: __ICursor, constructor: Callable[..., T]): ... - def process_row(self, row: tuple) -> T: ... # type: ignore +class ObjectCursorWrapper(DictCursorWrapper[_T]): + constructor: Callable[..., _T] + def __init__(self, cursor: __ICursor, constructor: Callable[..., _T]): ... + def process_row(self, row: tuple) -> _T: ... # type: ignore -class ResultIterator(Generic[T]): - cursor_wrapper: CursorWrapper[T] +class ResultIterator(Generic[_T]): + cursor_wrapper: CursorWrapper[_T] index: int - def __init__(self, cursor_wrapper: CursorWrapper[T]): ... - def __iter__(self) -> Iterator[T]: ... + def __init__(self, cursor_wrapper: CursorWrapper[_T]): ... + def __iter__(self) -> Iterator[_T]: ... # FIELDS @@ -1240,7 +1240,7 @@ class FieldAccessor: @overload def __get__(self, instance: None, instance_type: type) -> Field: ... @overload - def __get__(self, instance: T, instance_type: Type[T]) -> Any: ... + def __get__(self, instance: _T, instance_type: Type[_T]) -> Any: ... class ForeignKeyAccessor(FieldAccessor): model: Type[Model] @@ -1324,9 +1324,9 @@ class Field(ColumnBase): def bind(self, model: Type[Model], name: str, set_attribute: bool = ...) -> None: ... @property def column(self) -> Column: ... - def adapt(self, value: T) -> T: ... - def db_value(self, value: T) -> T: ... - def python_value(self, value: T) -> T: ... + def adapt(self, value: _T) -> _T: ... + def db_value(self, value: _T) -> _T: ... + def python_value(self, value: _T) -> _T: ... def to_value(self, value: Any) -> Value: ... def get_sort_key(self, ctx: Context) -> Tuple[int, int]: ... def __sql__(self, ctx: Context) -> Context: ... @@ -1338,7 +1338,7 @@ class IntegerField(Field): @overload def adapt(self, value: Union[str, float, bool]) -> int: ... # type: ignore @overload - def adapt(self, value: T) -> T: ... + def adapt(self, value: _T) -> _T: ... class BigIntegerField(IntegerField): ... class SmallIntegerField(IntegerField): ... @@ -1357,7 +1357,7 @@ class FloatField(Field): @overload def adapt(self, value: Union[str, float, bool]) -> float: ... # type: ignore @overload - def adapt(self, value: T) -> T: ... + def adapt(self, value: _T) -> _T: ... class DoubleField(FloatField): ... @@ -1381,7 +1381,7 @@ class DecimalField(Field): @overload def db_value(self, value: Union[float, decimal.Decimal]) -> decimal.Decimal: ... # type: ignore @overload - def db_value(self, value: T) -> T: ... + def db_value(self, value: _T) -> _T: ... @overload def python_value(self, value: None) -> None: ... @overload @@ -1404,7 +1404,7 @@ class BlobField(Field): @overload def db_value(self, value: Union[str, bytes]) -> bytearray: ... @overload - def db_value(self, value: T) -> T: ... + def db_value(self, value: _T) -> _T: ... class BitField(BitwiseMixin, BigIntegerField): def __init__(self, *args: object, default: Optional[int] = ..., **kwargs: object): ... @@ -1434,13 +1434,13 @@ class BigBitField(BlobField): @overload def db_value(self, value: None) -> None: ... @overload - def db_value(self, value: T) -> bytes: ... + def db_value(self, value: _T) -> bytes: ... class UUIDField(Field): @overload def db_value(self, value: AnyStr) -> str: ... @overload - def db_value(self, value: T) -> T: ... + def db_value(self, value: _T) -> _T: ... @overload def python_value(self, value: Union[uuid.UUID, AnyStr]) -> uuid.UUID: ... @overload @@ -1458,7 +1458,7 @@ class BinaryUUIDField(BlobField): def format_date_time(value: str, formats: Iterable[str], post_process: Optional[_TConvFunc] = ...) -> str: ... @overload -def simple_date_time(value: T) -> T: ... +def simple_date_time(value: _T) -> _T: ... class _BaseFormattedField(Field): # TODO (dargueta): This is a class variable that can be overridden for instances @@ -1478,7 +1478,7 @@ class DateTimeField(_BaseFormattedField): def minute(self) -> int: ... @property def second(self) -> int: ... - def adapt(self, value: T) -> T: ... + def adapt(self, value: _T) -> _T: ... def to_timestamp(self) -> Function: ... def truncate(self, part: str) -> Function: ... @@ -1492,7 +1492,7 @@ class DateField(_BaseFormattedField): @overload def adapt(self, value: datetime.datetime) -> datetime.date: ... @overload - def adapt(self, value: T) -> T: ... + def adapt(self, value: _T) -> _T: ... def to_timestamp(self) -> Function: ... def truncate(self, part: str) -> Function: ... @@ -1500,7 +1500,7 @@ class TimeField(_BaseFormattedField): @overload def adapt(self, value: Union[datetime.datetime, datetime.timedelta]) -> datetime.time: ... @overload - def adapt(self, value: T) -> T: ... + def adapt(self, value: _T) -> _T: ... @property def hour(self) -> int: ... @property @@ -1525,7 +1525,7 @@ class TimestampField(BigIntegerField): @overload def python_value(self, value: Union[int, float]) -> datetime.datetime: ... @overload - def python_value(self, value: T) -> T: ... + def python_value(self, value: _T) -> _T: ... def from_timestamp(self) -> float: ... @property def year(self) -> int: ... @@ -1632,12 +1632,12 @@ class ManyToManyFieldAccessor(FieldAccessor): dest_fk: ForeignKeyField def __init__(self, model: Type[Model], field: ForeignKeyField, name: str): ... @overload - def __get__(self, instance: None, instance_type: Type[T] = ..., force_query: bool = ...) -> Field: ... + def __get__(self, instance: None, instance_type: Type[_T] = ..., force_query: bool = ...) -> Field: ... @overload def __get__( - self, instance: T, instance_type: Type[T] = ..., force_query: bool = ... + self, instance: _T, instance_type: Type[_T] = ..., force_query: bool = ... ) -> Union[List[str], ManyToManyQuery]: ... - def __set__(self, instance: T, value) -> None: ... + def __set__(self, instance: _T, value) -> None: ... class ManyToManyField(MetaField): accessor_class: ClassVar[Type[ManyToManyFieldAccessor]] @@ -1679,7 +1679,7 @@ class CompositeKey(MetaField): @overload def __get__(self, instance: None, instance_type: type) -> CompositeKey: ... @overload - def __get__(self, instance: T, instance_type: Type[T]) -> tuple: ... + def __get__(self, instance: _T, instance_type: Type[_T]) -> tuple: ... def __set__(self, instance: Model, value: Union[list, tuple]) -> None: ... def __eq__(self, other: Expression) -> Expression: ... def __ne__(self, other: Expression) -> Expression: ... @@ -1843,7 +1843,7 @@ class Model(Node, metaclass=ModelBase): @classmethod def delete(cls) -> ModelDelete: ... @classmethod - def create(cls: Type[T], **query) -> T: ... + def create(cls: Type[_T], **query) -> _T: ... @classmethod def bulk_create(cls, model_list: Iterable[Type[Model]], batch_size: Optional[int] = ...) -> None: ... @classmethod @@ -1874,7 +1874,7 @@ class Model(Node, metaclass=ModelBase): @property def dirty_fields(self) -> List[Field]: ... def dependencies(self, search_nullable: bool = ...) -> Iterator[Tuple[Union[bool, Node], ForeignKeyField]]: ... - def delete_instance(self: T, recursive: bool = ..., delete_nullable: bool = ...) -> T: ... + def delete_instance(self: _T, recursive: bool = ..., delete_nullable: bool = ...) -> _T: ... def __hash__(self) -> int: ... def __eq__(self, other: object) -> bool: ... def __ne__(self, other: object) -> bool: ... @@ -1898,8 +1898,7 @@ class Model(Node, metaclass=ModelBase): @classmethod def truncate_table(cls, **options: object) -> None: ... @classmethod - def index(cls, *fields: Union[Field, Node, str], **kwargs: object) -> ModelIndex: - return ModelIndex(cls, fields, **kwargs) + def index(cls, *fields: Union[Field, Node, str], **kwargs: object) -> ModelIndex: ... @classmethod def add_index(cls, *fields: Union[str, SQL, Index], **kwargs: object) -> None: ... @@ -2063,7 +2062,7 @@ class ModelCursorWrapper(BaseModelCursorWrapper[_TModel]): joins: Mapping[Hashable, Tuple[object, object, Callable[..., _TModel], int]] key_to_constructor: Dict[Type[_TModel], Callable[..., _TModel]] src_is_dest: Dict[Type[Model], bool] - src_to_dest: List[tuple] # TODO -- Tuple[, join_type[1], join_type[0], bool, join_type[3]] + src_to_dest: List[tuple] # TODO -- Tuple[, join_type[1], join_type[0], bool, join_type[3]] column_keys: List # incomplete def __init__( self, From 315359479de838f817a572dea4eb7772bb0e2d5d Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Mon, 18 Apr 2022 19:50:23 +0100 Subject: [PATCH 08/22] Fix location of new stub file --- stubs/peewee/METADATA.toml | 1 + {third_party/2and3 => stubs/peewee}/peewee.pyi | 0 2 files changed, 1 insertion(+) create mode 100644 stubs/peewee/METADATA.toml rename {third_party/2and3 => stubs/peewee}/peewee.pyi (100%) diff --git a/stubs/peewee/METADATA.toml b/stubs/peewee/METADATA.toml new file mode 100644 index 000000000000..719172f4da80 --- /dev/null +++ b/stubs/peewee/METADATA.toml @@ -0,0 +1 @@ +version = "3.14.10.*" diff --git a/third_party/2and3/peewee.pyi b/stubs/peewee/peewee.pyi similarity index 100% rename from third_party/2and3/peewee.pyi rename to stubs/peewee/peewee.pyi From 30d62898fc3fe4fecb31a3215e728b7eae737e20 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Mon, 18 Apr 2022 21:35:48 +0100 Subject: [PATCH 09/22] Fix many lint errors --- stubs/peewee/METADATA.toml | 1 + stubs/peewee/peewee.pyi | 1101 +++++++++++++++++++----------------- 2 files changed, 578 insertions(+), 524 deletions(-) diff --git a/stubs/peewee/METADATA.toml b/stubs/peewee/METADATA.toml index 719172f4da80..2ce4d86a8441 100644 --- a/stubs/peewee/METADATA.toml +++ b/stubs/peewee/METADATA.toml @@ -1 +1,2 @@ version = "3.14.10.*" +python2 = true diff --git a/stubs/peewee/peewee.pyi b/stubs/peewee/peewee.pyi index 934083167b46..937e16545b55 100644 --- a/stubs/peewee/peewee.pyi +++ b/stubs/peewee/peewee.pyi @@ -4,6 +4,7 @@ import enum import re import threading import uuid +from _typeshed import Self from typing import ( Any, AnyStr, @@ -11,45 +12,116 @@ from typing import ( ClassVar, Container, ContextManager, - Dict, Generic, Hashable, Iterable, Iterator, - List, Mapping, MutableMapping, MutableSet, NamedTuple, NoReturn, - Optional, + Protocol, Sequence, - Set, Text, - Tuple, - Type, TypeVar, Union, overload, ) -from typing_extensions import Literal, Protocol +from typing_extensions import Literal, TypeAlias _T = TypeVar("_T") -_TModel = TypeVar("_TModel", bound="Model") +_TModel = TypeVar("_TModel", bound=Model) _TConvFunc = Callable[[Any], Any] _TFunc = TypeVar("_TFunc", bound=Callable) _TClass = TypeVar("_TClass", bound=type) -_TContextClass = TypeVar("_TContextClass", bound="Context") -_TField = TypeVar("_TField", bound="Field") -_TNode = TypeVar("_TNode", bound="Node") +_TContextClass = TypeVar("_TContextClass", bound=Context) +_TField = TypeVar("_TField", bound=Field) +_TNode = TypeVar("_TNode", bound=Node) __version__: str -__all__: List[str] + +__all__ = [ + "AnyField", + "AsIs", + "AutoField", + "BareField", + "BigAutoField", + "BigBitField", + "BigIntegerField", + "BinaryUUIDField", + "BitField", + "BlobField", + "BooleanField", + "Case", + "Cast", + "CharField", + "Check", + "chunked", + "Column", + "CompositeKey", + "Context", + "Database", + "DatabaseError", + "DatabaseProxy", + "DataError", + "DateField", + "DateTimeField", + "DecimalField", + "DeferredForeignKey", + "DeferredThroughModel", + "DJANGO_MAP", + "DoesNotExist", + "DoubleField", + "DQ", + "EXCLUDED", + "Field", + "FixedCharField", + "FloatField", + "fn", + "ForeignKeyField", + "IdentityField", + "ImproperlyConfigured", + "Index", + "IntegerField", + "IntegrityError", + "InterfaceError", + "InternalError", + "IPField", + "JOIN", + "ManyToManyField", + "Model", + "ModelIndex", + "MySQLDatabase", + "NotSupportedError", + "OP", + "OperationalError", + "PostgresqlDatabase", + "PrimaryKeyField", # XXX: Deprecated, change to AutoField. + "prefetch", + "ProgrammingError", + "Proxy", + "QualifiedNames", + "SchemaManager", + "SmallIntegerField", + "Select", + "SQL", + "SqliteDatabase", + "Table", + "TextField", + "TimeField", + "TimestampField", + "Tuple", + "UUIDField", + "Value", + "ValuesList", + "Window", +] class __ICursor(Protocol): - description: Tuple[str, Any, Any, Any, Any, Any, Any] + description: tuple[str, Any, Any, Any, Any, Any, Any] rowcount: int - def fetchone(self) -> Optional[tuple]: ... + def fetchone(self) -> tuple | None: ... def fetchmany(self, size: int = ...) -> Iterable[tuple]: ... def fetchall(self) -> Iterable[tuple]: ... @@ -73,13 +145,13 @@ class __ITableFunction(Protocol): @classmethod def register(cls, conn: __IConnection) -> None: ... -def _sqlite_date_part(lookup_type: str, datetime_string: str) -> Optional[str]: ... -def _sqlite_date_trunc(lookup_type: str, datetime_string: str) -> Optional[str]: ... +def _sqlite_date_part(lookup_type: str, datetime_string: str) -> str | None: ... +def _sqlite_date_trunc(lookup_type: str, datetime_string: str) -> str | None: ... class attrdict(dict): def __getattr__(self, attr: str) -> Any: ... def __setattr__(self, attr: str, value: object) -> None: ... - def __iadd__(self, rhs: Mapping[str, object]) -> attrdict: ... + def __iadd__(self: Self, rhs: Mapping[str, object]) -> Self: ... def __add__(self, rhs: Mapping[str, object]) -> Mapping[str, object]: ... class _TSentinel(enum.Enum): ... @@ -152,22 +224,22 @@ class AliasManager: class __State(NamedTuple): scope: int parentheses: bool - # From the source code we know this to be a Dict and not just a MutableMapping. - settings: Dict[str, Any] + # From the source code we know this to be a dict and not just a MutableMapping. + settings: dict[str, Any] class State(__State): subquery: object # TODO (dargueta) - def __new__(cls, scope: int = ..., parentheses: bool = ..., **kwargs: object) -> State: ... - def __call__(self, scope: Optional[int] = ..., parentheses: Optional[int] = ..., **kwargs: object) -> State: ... + def __new__(cls: type[Self], scope: int = ..., parentheses: bool = ..., **kwargs: object) -> Self: ... + def __call__(self, scope: int | None = ..., parentheses: int | None = ..., **kwargs: object) -> State: ... def __getattr__(self, attr_name: str) -> Any: ... class Context: - stack: List[State] + stack: list[State] alias_manager: AliasManager state: State def __init__(self, **settings: object) -> None: ... def as_new(self) -> Context: ... - def column_sort_key(self, item: Sequence[Union[ColumnBase, Source]]) -> Tuple[str, ...]: ... + def column_sort_key(self, item: Sequence[ColumnBase | Source]) -> tuple[str, ...]: ... @property def scope(self) -> int: ... @property @@ -180,17 +252,17 @@ class Context: def scope_values(self) -> ContextManager[Context]: ... def scope_cte(self) -> ContextManager[Context]: ... def scope_column(self) -> ContextManager[Context]: ... - def __enter__(self) -> Context: ... - def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: object) -> None: ... + def __enter__(self: Self) -> Self: ... + def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None: ... # @contextmanager def push_alias(self) -> Iterator[None]: ... # TODO (dargueta): Is this right? def sql(self, obj: object) -> Context: ... def literal(self, keyword: str) -> Context: ... - def value(self, value: object, converter: Optional[_TConvFunc] = ..., add_param: bool = ...) -> Context: ... + def value(self, value: object, converter: _TConvFunc | None = ..., add_param: bool = ...) -> Context: ... def __sql__(self, ctx: Context) -> Context: ... - def parse(self, node: Node) -> Tuple[str, Optional[tuple]]: ... - def query(self) -> Tuple[str, Optional[tuple]]: ... + def parse(self, node: Node) -> tuple[str, tuple | None]: ... + def query(self) -> tuple[str, tuple | None]: ... def query_to_string(query: Node) -> str: ... @@ -213,25 +285,25 @@ class _DynamicColumn: @overload def __get__(self, instance: None, instance_type: type) -> _DynamicColumn: ... @overload - def __get__(self, instance: _T, instance_type: Type[_T]) -> ColumnFactory: ... + def __get__(self, instance: _T, instance_type: type[_T]) -> ColumnFactory: ... class _ExplicitColumn: @overload def __get__(self, instance: None, instance_type: type) -> _ExplicitColumn: ... @overload - def __get__(self, instance: _T, instance_type: Type[_T]) -> NoReturn: ... + def __get__(self, instance: _T, instance_type: type[_T]) -> NoReturn: ... class _SupportsAlias(Protocol): - def alias(self: _T, name: str) -> _T: ... + def alias(self: Self, name: str) -> Self: ... class Source(_SupportsAlias, Node): c: ClassVar[_DynamicColumn] - def __init__(self, alias: Optional[str] = ...): ... + def __init__(self, alias: str | None = ...): ... def select(self, *columns: Field) -> Select: ... - def join(self, dest, join_type: int = ..., on: Optional[Expression] = ...) -> Join: ... - def left_outer_join(self, dest, on: Optional[Expression] = ...) -> Join: ... + def join(self, dest, join_type: int = ..., on: Expression | None = ...) -> Join: ... + def left_outer_join(self, dest, on: Expression | None = ...) -> Join: ... def cte(self, name: str, recursive: bool = ..., columns=..., materialized=...) -> CTE: ... # incomplete - def get_sort_key(self, ctx) -> Tuple[str, ...]: ... + def get_sort_key(self, ctx) -> tuple[str, ...]: ... def apply_alias(self, ctx: Context) -> Context: ... def apply_column(self, ctx: Context) -> Context: ... @@ -273,35 +345,35 @@ class _BoundTableContext(_callable_context_manager): database: Database def __init__(self, table: Table, database: Database): ... def __enter__(self) -> Table: ... - def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: object) -> None: ... + def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None: ... class Table(_HashableSource, BaseTable): __name__: str c: _ExplicitColumn - primary_key: Optional[Union[Field, CompositeKey]] + primary_key: Field | CompositeKey | None def __init__( self, name: str, - columns: Optional[Iterable[str]] = ..., - primary_key: Optional[Union[Field, CompositeKey]] = ..., - schema: Optional[str] = ..., - alias: Optional[str] = ..., - _model: Optional[Type[Model]] = ..., - _database: Optional[Database] = ..., + columns: Iterable[str] | None = ..., + primary_key: Field | CompositeKey | None = ..., + schema: str | None = ..., + alias: str | None = ..., + _model: type[Model] | None = ..., + _database: Database | None = ..., ): ... def clone(self) -> Table: ... - def bind(self, database: Optional[Database] = ...) -> Table: ... - def bind_ctx(self, database: Optional[Database] = ...) -> _BoundTableContext: ... + def bind(self, database: Database | None = ...) -> Table: ... + def bind_ctx(self, database: Database | None = ...) -> _BoundTableContext: ... def select(self, *columns: Column) -> Select: ... @overload - def insert(self, insert: Optional[Select], columns: Sequence[Union[str, Field, Column]]) -> Insert: ... + def insert(self, insert: Select | None, columns: Sequence[str | Field | Column]) -> Insert: ... @overload - def insert(self, insert: Union[Mapping[str, object], Iterable[Mapping[str, object]]], **kwargs: object): ... + def insert(self, insert: Mapping[str, object] | Iterable[Mapping[str, object]], **kwargs: object): ... @overload - def replace(self, insert: Optional[Select], columns: Sequence[Union[str, Field, Column]]) -> Insert: ... + def replace(self, insert: Select | None, columns: Sequence[str | Field | Column]) -> Insert: ... @overload - def replace(self, insert: Union[Mapping[str, object], Iterable[Mapping[str, object]]], **kwargs: object): ... - def update(self, update: Optional[Mapping[str, object]] = ..., **kwargs: object) -> Update: ... + def replace(self, insert: Mapping[str, object] | Iterable[Mapping[str, object]], **kwargs: object): ... + def update(self, update: Mapping[str, object] | None = ..., **kwargs: object) -> Update: ... def delete(self) -> Delete: ... def __sql__(self, ctx: Context) -> Context: ... @@ -309,12 +381,12 @@ class Join(BaseTable): lhs: Any # TODO (dargueta) rhs: Any # TODO (dargueta) join_type: int - def __init__(self, lhs, rhs, join_type: int = ..., on: Optional[Expression] = ..., alias: Optional[str] = ...): ... + def __init__(self, lhs, rhs, join_type: int = ..., on: Expression | None = ..., alias: str | None = ...): ... def on(self, predicate: Expression) -> Join: ... def __sql__(self, ctx: Context) -> Context: ... class ValuesList(_HashableSource, BaseTable): - def __init__(self, values, columns=..., alias: Optional[str] = ...): ... # incomplete + def __init__(self, values, columns=..., alias: str | None = ...): ... # incomplete # FIXME (dargueta) `names` might be wrong def columns(self, *names: str) -> ValuesList: ... def __sql__(self, ctx: Context) -> Context: ... @@ -325,11 +397,11 @@ class CTE(_HashableSource, Source): name: str, query: Select, recursive: bool = ..., - columns: Optional[Iterable[Union[Column, Field, str]]] = ..., + columns: Iterable[Column | Field | str] | None = ..., materialized: bool = ..., ): ... # TODO (dargueta): Is `columns` just for column names? - def select_from(self, *columns: Union[Column, Field]) -> Select: ... + def select_from(self, *columns: Column | Field) -> Select: ... def _get_hash(self) -> int: ... def union_all(self, rhs) -> CTE: ... __add__ = union_all @@ -338,17 +410,17 @@ class CTE(_HashableSource, Source): def __sql__(self, ctx: Context) -> Context: ... class ColumnBase(Node): - _converter: Optional[_TConvFunc] - def converter(self, converter: Optional[_TConvFunc] = ...) -> ColumnBase: ... + _converter: _TConvFunc | None + def converter(self, converter: _TConvFunc | None = ...) -> ColumnBase: ... @overload def alias(self, alias: None) -> ColumnBase: ... @overload def alias(self, alias: str) -> Alias: ... def unalias(self) -> ColumnBase: ... def cast(self, as_type: str) -> Cast: ... - def asc(self, collation: Optional[str] = ..., nulls: Optional[str] = ...) -> _SupportsSQLOrdering: ... + def asc(self, collation: str | None = ..., nulls: str | None = ...) -> _SupportsSQLOrdering: ... __pos__ = asc - def desc(self, collation: Optional[str] = ..., nulls: Optional[str] = ...) -> _SupportsSQLOrdering: ... + def desc(self, collation: str | None = ..., nulls: str | None = ...) -> _SupportsSQLOrdering: ... __neg__ = desc # TODO (dargueta): This always returns Negated but subclasses can return something else def __invert__(self) -> WrappedNode: ... @@ -384,29 +456,29 @@ class ColumnBase(Node): def not_in(self, other: object) -> Expression: ... def regexp(self, other: object) -> Expression: ... def is_null(self, is_null: bool = ...) -> Expression: ... - def contains(self, rhs: Union[Node, str]) -> Expression: ... - def startswith(self, rhs: Union[Node, str]) -> Expression: ... - def endswith(self, rhs: Union[Node, str]) -> Expression: ... + def contains(self, rhs: Node | str) -> Expression: ... + def startswith(self, rhs: Node | str) -> Expression: ... + def endswith(self, rhs: Node | str) -> Expression: ... def between(self, lo: object, hi: object) -> Expression: ... def concat(self, rhs: object) -> StringExpression: ... def iregexp(self, rhs: object) -> Expression: ... def __getitem__(self, item: object) -> Expression: ... def distinct(self) -> NodeList: ... def collate(self, collation: str) -> NodeList: ... - def get_sort_key(self, ctx: Context) -> Tuple[str, ...]: ... + def get_sort_key(self, ctx: Context) -> tuple[str, ...]: ... class Column(ColumnBase): source: Source name: str def __init__(self, source: Source, name: str): ... - def get_sort_key(self, ctx: Context) -> Tuple[str, ...]: ... + def get_sort_key(self, ctx: Context) -> tuple[str, ...]: ... def __hash__(self) -> int: ... def __sql__(self, ctx: Context) -> Context: ... class WrappedNode(ColumnBase, Generic[_TNode]): node: _TNode _coerce: bool - _converter: Optional[_TConvFunc] + _converter: _TConvFunc | None def __init__(self, node: _TNode): ... def is_alias(self) -> bool: ... def unwrap(self) -> _TNode: ... @@ -420,7 +492,7 @@ class _DynamicEntity: @overload def __get__(self, instance: None, instance_type: type) -> _DynamicEntity: ... @overload - def __get__(self, instance: _T, instance_type: Type[_T]) -> EntityFactory: ... + def __get__(self, instance: _T, instance_type: type[_T]) -> EntityFactory: ... class Alias(WrappedNode): c: ClassVar[_DynamicEntity] @@ -450,9 +522,9 @@ class BitwiseNegated(BitwiseMixin, WrappedNode): class Value(ColumnBase): value: object - converter: Optional[_TConvFunc] + converter: _TConvFunc | None multi: bool - def __init__(self, value: object, converter: Optional[_TConvFunc] = ..., unpack: bool = ...): ... + def __init__(self, value: object, converter: _TConvFunc | None = ..., unpack: bool = ...): ... def __sql__(self, ctx: Context) -> Context: ... def AsIs(value: object) -> Value: ... @@ -463,24 +535,24 @@ class Cast(WrappedNode): class Ordering(WrappedNode): direction: str - collation: Optional[str] - nulls: Optional[str] - def __init__(self, node: Node, direction: str, collation: Optional[str] = ..., nulls: Optional[str] = ...): ... - def collate(self, collation: Optional[str] = ...) -> Ordering: ... + collation: str | None + nulls: str | None + def __init__(self, node: Node, direction: str, collation: str | None = ..., nulls: str | None = ...): ... + def collate(self, collation: str | None = ...) -> Ordering: ... def __sql__(self, ctx: Context) -> Context: ... class _SupportsSQLOrdering(Protocol): - def __call__(node: Node, collation: Optional[str] = ..., nulls: Optional[str] = ...) -> Ordering: ... + def __call__(node: Node, collation: str | None = ..., nulls: str | None = ...) -> Ordering: ... -def Asc(node: Node, collation: Optional[str] = ..., nulls: Optional[str] = ...) -> Ordering: ... -def Desc(node: Node, collation: Optional[str] = ..., nulls: Optional[str] = ...) -> Ordering: ... +def Asc(node: Node, collation: str | None = ..., nulls: str | None = ...) -> Ordering: ... +def Desc(node: Node, collation: str | None = ..., nulls: str | None = ...) -> Ordering: ... class Expression(ColumnBase): - lhs: Optional[Union[Node, str]] + lhs: Node | str | None op: int - rhs: Optional[Union[Node, str]] + rhs: Node | str | None flat: bool - def __init__(self, lhs: Optional[Union[Node, str]], op: int, rhs: Optional[Union[Node, str]], flat: bool = ...): ... + def __init__(self, lhs: Node | str | None, op: int, rhs: Node | str | None, flat: bool = ...): ... def __sql__(self, ctx: Context) -> Context: ... class StringExpression(Expression): @@ -490,13 +562,13 @@ class StringExpression(Expression): class Entity(ColumnBase): def __init__(self, *path: str): ... def __getattr__(self, attr: str) -> Entity: ... - def get_sort_key(self, ctx: Context) -> Tuple[str, ...]: ... + def get_sort_key(self, ctx: Context) -> tuple[str, ...]: ... def __hash__(self) -> int: ... def __sql__(self, ctx: Context) -> Context: ... class SQL(ColumnBase): sql: str - params: Optional[Mapping[str, object]] + params: Mapping[str, object] | None def __init__(self, sql: str, params: Mapping[str, object] = ...): ... def __sql__(self, ctx: Context) -> Context: ... @@ -505,21 +577,21 @@ def Check(constraint: str) -> SQL: ... class Function(ColumnBase): name: str arguments: tuple - def __init__(self, name: str, arguments: tuple, coerce: bool = ..., python_value: Optional[_TConvFunc] = ...): ... + def __init__(self, name: str, arguments: tuple, coerce: bool = ..., python_value: _TConvFunc | None = ...): ... def __getattr__(self, attr: str) -> Callable[..., Function]: ... # TODO (dargueta): `where` is an educated guess - def filter(self, where: Optional[Expression] = ...) -> Function: ... - def order_by(self, *ordering: Union[Field, Expression]) -> Function: ... - def python_value(self, func: Optional[_TConvFunc] = ...) -> Function: ... + def filter(self, where: Expression | None = ...) -> Function: ... + def order_by(self, *ordering: Field | Expression) -> Function: ... + def python_value(self, func: _TConvFunc | None = ...) -> Function: ... def over( self, - partition_by: Optional[Union[Sequence[Field], Window]] = ..., - order_by: Optional[Sequence[Union[Field, Expression]]] = ..., - start: Optional[Union[str, SQL]] = ..., - end: Optional[Union[str, SQL]] = ..., - frame_type: Optional[str] = ..., - window: Optional[Window] = ..., - exclude: Optional[SQL] = ..., + partition_by: Sequence[Field] | Window | None = ..., + order_by: Sequence[Field | Expression] | None = ..., + start: str | SQL | None = ..., + end: str | SQL | None = ..., + frame_type: str | None = ..., + window: Window | None = ..., + exclude: SQL | None = ..., ) -> NodeList: ... def __sql__(self, ctx: Context) -> Context: ... @@ -534,47 +606,47 @@ class Window(Node): RANGE: ClassVar[str] ROWS: ClassVar[str] # Instance variables - partition_by: Tuple[Union[Field, Expression], ...] - order_by: Tuple[Union[Field, Expression], ...] - start: Optional[Union[str, SQL]] - end: Optional[Union[str, SQL]] - frame_type: Optional[Any] # incomplete + partition_by: tuple[Field | Expression, ...] + order_by: tuple[Field | Expression, ...] + start: str | SQL | None + end: str | SQL | None + frame_type: Any | None # incomplete @overload def __init__( self, - partition_by: Optional[Union[Sequence[Field], Window]] = ..., - order_by: Optional[Sequence[Union[Field, Expression]]] = ..., - start: Optional[Union[str, SQL]] = ..., + partition_by: Sequence[Field] | Window | None = ..., + order_by: Sequence[Field | Expression] | None = ..., + start: str | SQL | None = ..., end: None = ..., - frame_type: Optional[str] = ..., - extends: Optional[Union[Window, WindowAlias, str]] = ..., - exclude: Optional[SQL] = ..., - alias: Optional[str] = ..., + frame_type: str | None = ..., + extends: Window | WindowAlias | str | None = ..., + exclude: SQL | None = ..., + alias: str | None = ..., _inline: bool = ..., ): ... @overload def __init__( self, - partition_by: Optional[Union[Sequence[Field], Window]] = ..., - order_by: Optional[Sequence[Union[Field, Expression]]] = ..., - start: Union[str, SQL] = ..., - end: Union[str, SQL] = ..., - frame_type: Optional[str] = ..., - extends: Optional[Union[Window, WindowAlias, str]] = ..., - exclude: Optional[SQL] = ..., - alias: Optional[str] = ..., + partition_by: Sequence[Field] | Window | None = ..., + order_by: Sequence[Field | Expression] | None = ..., + start: str | SQL = ..., + end: str | SQL = ..., + frame_type: str | None = ..., + extends: Window | WindowAlias | str | None = ..., + exclude: SQL | None = ..., + alias: str | None = ..., _inline: bool = ..., ): ... - def alias(self, alias: Optional[str] = ...) -> Window: ... + def alias(self, alias: str | None = ...) -> Window: ... def as_range(self) -> Window: ... def as_rows(self) -> Window: ... def as_groups(self) -> Window: ... - def extends(self, window: Optional[Union[Window, WindowAlias, str]] = ...) -> Window: ... - def exclude(self, frame_exclusion: Optional[Union[str, SQL]] = ...) -> Window: ... + def extends(self, window: Window | WindowAlias | str | None = ...) -> Window: ... + def exclude(self, frame_exclusion: str | SQL | None = ...) -> Window: ... @staticmethod - def following(value: Optional[int] = ...) -> SQL: ... + def following(value: int | None = ...) -> SQL: ... @staticmethod - def preceding(value: Optional[int] = ...) -> SQL: ... + def preceding(value: int | None = ...) -> SQL: ... def __sql__(self, ctx: Context) -> Context: ... class WindowAlias(Node): @@ -586,13 +658,13 @@ class WindowAlias(Node): class ForUpdate(Node): def __init__( self, - expr: Union[Literal[True], str], - of: Optional[Union[_TModelOrTable, List[_TModelOrTable], Set[_TModelOrTable], Tuple[_TModelOrTable, ...]]] = ..., - nowait: Optional[bool] = ..., + expr: Literal[True] | str, + of: _TModelOrTable | list[_TModelOrTable] | set[_TModelOrTable] | tuple[_TModelOrTable, ...] | None = ..., + nowait: bool | None = ..., ): ... def __sql__(self, ctx: Context) -> Context: ... -def Case(predicate: Optional[Node], expression_tuples: Iterable[Tuple[Expression, Any]], default: object = ...) -> NodeList: ... +def Case(predicate: Node | None, expression_tuples: Iterable[tuple[Expression, Any]], default: object = ...) -> NodeList: ... class NodeList(ColumnBase): # TODO (dargueta): Narrow this type @@ -617,7 +689,7 @@ class NamespaceAttribute(ColumnBase): EXCLUDED: _Namespace class DQ(ColumnBase): - query: Dict[str, Any] + query: dict[str, Any] # TODO (dargueta): Narrow this down? def __init__(self, **query: object): ... @@ -638,31 +710,31 @@ class OnConflict(Node): @overload def __init__( self, - action: Optional[str] = ..., - update: Optional[Mapping[str, object]] = ..., - preserve: Optional[Union[Field, Iterable[Field]]] = ..., - where: Optional[Expression] = ..., - conflict_target: Optional[Union[Field, Sequence[Field]]] = ..., + action: str | None = ..., + update: Mapping[str, object] | None = ..., + preserve: Field | Iterable[Field] | None = ..., + where: Expression | None = ..., + conflict_target: Field | Sequence[Field] | None = ..., conflict_where: None = ..., - conflict_constraint: Optional[str] = ..., + conflict_constraint: str | None = ..., ): ... @overload def __init__( self, - action: Optional[str] = ..., - update: Optional[Mapping[str, object]] = ..., - preserve: Optional[Union[Field, Iterable[Field]]] = ..., - where: Optional[Expression] = ..., + action: str | None = ..., + update: Mapping[str, object] | None = ..., + preserve: Field | Iterable[Field] | None = ..., + where: Expression | None = ..., conflict_target: None = ..., - conflict_where: Optional[Expression] = ..., - conflict_constraint: Optional[str] = ..., + conflict_where: Expression | None = ..., + conflict_constraint: str | None = ..., ): ... # undocumented - def get_conflict_statement(self, ctx: Context, query: Query) -> Optional[SQL]: ... + def get_conflict_statement(self, ctx: Context, query: Query) -> SQL | None: ... def get_conflict_update(self, ctx: Context, query: Query) -> NodeList: ... def preserve(self, *columns: Column) -> OnConflict: ... # Despite the argument name `_data` is documented - def update(self, _data: Optional[Mapping[str, object]] = ..., **kwargs: object) -> OnConflict: ... + def update(self, _data: Mapping[str, object] | None = ..., **kwargs: object) -> OnConflict: ... def where(self, *expressions: Expression) -> OnConflict: ... def conflict_target(self, *constraints: Column) -> OnConflict: ... def conflict_where(self, *expressions: Expression) -> OnConflict: ... @@ -670,39 +742,38 @@ class OnConflict(Node): class BaseQuery(Node): default_row_type: ClassVar[int] - def __init__(self, _database: Optional[Database] = ..., **kwargs: object): ... - def bind(self, database: Optional[Database] = ...) -> BaseQuery: ... + def __init__(self, _database: Database | None = ..., **kwargs: object): ... + def bind(self, database: Database | None = ...) -> BaseQuery: ... def clone(self) -> BaseQuery: ... def dicts(self, as_dict: bool = ...) -> BaseQuery: ... def tuples(self, as_tuple: bool = ...) -> BaseQuery: ... def namedtuples(self, as_namedtuple: bool = ...) -> BaseQuery: ... - def objects(self, constructor: Optional[_TConvFunc] = ...) -> BaseQuery: ... + def objects(self, constructor: _TConvFunc | None = ...) -> BaseQuery: ... def __sql__(self, ctx: Context) -> Context: ... - def sql(self) -> Tuple[str, Optional[tuple]]: ... - def execute(self, database: Optional[Database] = ...) -> CursorWrapper: ... + def sql(self) -> tuple[str, tuple | None]: ... + def execute(self, database: Database | None = ...) -> CursorWrapper: ... # TODO (dargueta): `Any` is too loose; list types of the cursor wrappers - def iterator(self, database: Optional[Database] = ...) -> Iterator[Any]: ... + def iterator(self, database: Database | None = ...) -> Iterator[Any]: ... def __iter__(self) -> Iterator[Any]: ... @overload def __getitem__(self, value: int) -> Any: ... @overload def __getitem__(self, value: slice) -> Sequence[Any]: ... def __len__(self) -> int: ... - def __str__(self) -> str: ... class RawQuery(BaseQuery): # TODO (dargueta): `tuple` may not be 100% accurate, maybe Sequence[object]? - def __init__(self, sql: Optional[str] = ..., params: Optional[tuple] = ..., **kwargs: object): ... + def __init__(self, sql: str | None = ..., params: tuple | None = ..., **kwargs: object): ... def __sql__(self, ctx: Context) -> Context: ... class Query(BaseQuery): # TODO (dargueta): Verify type of order_by def __init__( self, - where: Optional[Expression] = ..., - order_by: Optional[Sequence[Node]] = ..., - limit: Optional[int] = ..., - offset: Optional[int] = ..., + where: Expression | None = ..., + order_by: Sequence[Node] | None = ..., + limit: int | None = ..., + offset: int | None = ..., **kwargs: object, ): ... def with_cte(self, *cte_list: CTE) -> Query: ... @@ -710,8 +781,8 @@ class Query(BaseQuery): def orwhere(self, *expressions: Expression) -> Query: ... def order_by(self, *values: Node) -> Query: ... def order_by_extend(self, *values: Node) -> Query: ... - def limit(self, value: Optional[int] = ...) -> Query: ... - def offset(self, value: Optional[int] = ...) -> Query: ... + def limit(self, value: int | None = ...) -> Query: ... + def offset(self, value: int | None = ...) -> Query: ... def paginate(self, page: int, paginate_by: int = ...) -> Query: ... def _apply_ordering(self, ctx: Context) -> Context: ... def __sql__(self, ctx: Context) -> Context: ... @@ -735,20 +806,20 @@ class SelectQuery(Query): class SelectBase(_HashableSource, Source, SelectQuery): @overload - def peek(self, database: Optional[Database] = ..., n: Literal[1] = ...) -> object: ... + def peek(self, database: Database | None = ..., n: Literal[1] = ...) -> object: ... @overload - def peek(self, database: Optional[Database] = ..., n: int = ...) -> List[object]: ... + def peek(self, database: Database | None = ..., n: int = ...) -> list[object]: ... @overload - def first(self, database: Optional[Database] = ..., n: Literal[1] = ...) -> object: ... + def first(self, database: Database | None = ..., n: Literal[1] = ...) -> object: ... @overload - def first(self, database: Optional[Database] = ..., n: int = ...) -> List[object]: ... + def first(self, database: Database | None = ..., n: int = ...) -> list[object]: ... @overload - def scalar(self, database: Optional[Database] = ..., as_tuple: Literal[False] = ...) -> object: ... + def scalar(self, database: Database | None = ..., as_tuple: Literal[False] = ...) -> object: ... @overload - def scalar(self, database: Optional[Database] = ..., as_tuple: Literal[True] = ...) -> tuple: ... - def count(self, database: Optional[Database] = ..., clear_limit: bool = ...) -> int: ... - def exists(self, database: Optional[Database] = ...) -> bool: ... - def get(self, database: Optional[Database] = ...) -> object: ... + def scalar(self, database: Database | None = ..., as_tuple: Literal[True] = ...) -> tuple: ... + def count(self, database: Database | None = ..., clear_limit: bool = ...) -> int: ... + def exists(self, database: Database | None = ...) -> bool: ... + def get(self, database: Database | None = ...) -> object: ... # QUERY IMPLEMENTATIONS. @@ -757,36 +828,36 @@ class CompoundSelectQuery(SelectBase): op: str rhs: Any # TODO (dargueta) def __init__(self, lhs: object, op: str, rhs: object): ... - def exists(self, database: Optional[Database] = ...) -> bool: ... + def exists(self, database: Database | None = ...) -> bool: ... def __sql__(self, ctx: Context) -> Context: ... class Select(SelectBase): def __init__( self, - from_list: Optional[Sequence[Union[Column, Field]]] = ..., # TODO (dargueta): `Field` might be wrong - columns: Optional[Iterable[Union[Column, Field]]] = ..., # TODO (dargueta): `Field` might be wrong + from_list: Sequence[Column | Field] | None = ..., # TODO (dargueta): `Field` might be wrong + columns: Iterable[Column | Field] | None = ..., # TODO (dargueta): `Field` might be wrong # Docs say this is a "[l]ist of columns or values to group by" so we don't have # a whole lot to restrict this to thanks to "or values" group_by: Sequence[object] = ..., - having: Optional[Expression] = ..., - distinct: Optional[Union[bool, Sequence[Column]]] = ..., - windows: Optional[Container[Window]] = ..., - for_update: Optional[Union[bool, str]] = ..., - for_update_of: Optional[Union[Table, Iterable[Table]]] = ..., - nowait: Optional[bool] = ..., - lateral: Optional[bool] = ..., # undocumented + having: Expression | None = ..., + distinct: bool | Sequence[Column] | None = ..., + windows: Container[Window] | None = ..., + for_update: bool | str | None = ..., + for_update_of: Table | Iterable[Table] | None = ..., + nowait: bool | None = ..., + lateral: bool | None = ..., # undocumented **kwargs: object, ): ... def clone(self) -> Select: ... # TODO (dargueta) `Field` might be wrong in this union - def columns(self, *columns: Union[Column, Field], **kwargs: object) -> Select: ... - def select(self, *columns: Union[Column, Field], **kwargs: object) -> Select: ... + def columns(self, *columns: Column | Field, **kwargs: object) -> Select: ... + def select(self, *columns: Column | Field, **kwargs: object) -> Select: ... def select_extend(self, *columns) -> Select: ... # TODO (dargueta): Is `sources` right? - def from_(self, *sources: Union[Source, Type[Model]]) -> Select: ... - def join(self, dest: Type[Model], join_type: int = ..., on: Optional[Expression] = ...) -> Select: ... - def group_by(self, *columns: Union[Table, Field]) -> Select: ... - def group_by_extend(self, *values: Union[Table, Field]) -> Select: ... + def from_(self, *sources: Source | type[Model]) -> Select: ... + def join(self, dest: type[Model], join_type: int = ..., on: Expression | None = ...) -> Select: ... + def group_by(self, *columns: Table | Field) -> Select: ... + def group_by_extend(self, *values: Table | Field) -> Select: ... def having(self, *expressions: Expression) -> Select: ... @overload def distinct(self, _: bool) -> Select: ... @@ -794,22 +865,22 @@ class Select(SelectBase): def distinct(self, *columns: Field) -> Select: ... def window(self, *windows: Window) -> Select: ... def for_update( - self, for_update: bool = ..., of: Optional[Union[Table, Iterable[Table]]] = ..., nowait: Optional[bool] = ... + self, for_update: bool = ..., of: Table | Iterable[Table] | None = ..., nowait: bool | None = ... ) -> Select: ... def lateral(self, lateral: bool = ...) -> Select: ... class _WriteQuery(Query): table: Table - def __init__(self, table: Table, returning: Optional[Iterable[Union[Type[Model], Field]]] = ..., **kwargs: object): ... - def returning(self, *returning: Union[Type[Model], Field]) -> _WriteQuery: ... + def __init__(self, table: Table, returning: Iterable[type[Model] | Field] | None = ..., **kwargs: object): ... + def returning(self, *returning: type[Model] | Field) -> _WriteQuery: ... def apply_returning(self, ctx: Context) -> Context: ... def execute_returning(self, database: Database) -> CursorWrapper: ... - def handle_result(self, database: Database, cursor: __ICursor) -> Union[int, __ICursor]: ... + def handle_result(self, database: Database, cursor: __ICursor) -> int | __ICursor: ... def __sql__(self, ctx: Context) -> Context: ... class Update(_WriteQuery): # TODO (dargueta): `update` - def __init__(self, table: Table, update: Optional[Any] = ..., **kwargs: object): ... + def __init__(self, table: Table, update: Any | None = ..., **kwargs: object): ... def from_(self, *sources) -> Update: ... def __sql__(self, ctx: Context) -> Context: ... @@ -817,13 +888,13 @@ class Insert(_WriteQuery): SIMPLE: ClassVar[int] QUERY: ClassVar[int] MULTI: ClassVar[int] - DefaultValuesException: Type[Exception] + DefaultValuesException: type[Exception] def __init__( self, table: Table, - insert: Optional[Union[Mapping[str, object], Iterable[Mapping[str, object]], SelectQuery, SQL]] = ..., - columns: Optional[Iterable[Union[str, Field]]] = ..., # FIXME: Might be `Column` not `Field` - on_conflict: Optional[OnConflict] = ..., + insert: Mapping[str, object] | Iterable[Mapping[str, object]] | SelectQuery | SQL | None = ..., + columns: Iterable[str | Field] | None = ..., # FIXME: Might be `Column` not `Field` + on_conflict: OnConflict | None = ..., **kwargs: object, ): ... def where(self, *expressions: Expression) -> NoReturn: ... @@ -831,9 +902,9 @@ class Insert(_WriteQuery): def on_conflict_replace(self, replace: bool = ...) -> Insert: ... def on_conflict(self, *args, **kwargs) -> Insert: ... def get_default_data(self) -> Mapping[str, object]: ... - def get_default_columns(self) -> Optional[List[Field]]: ... + def get_default_columns(self) -> list[Field] | None: ... def __sql__(self, ctx: Context) -> Context: ... - def handle_result(self, database: Database, cursor: __ICursor) -> Union[__ICursor, int]: ... + def handle_result(self, database: Database, cursor: __ICursor) -> __ICursor | int: ... class Delete(_WriteQuery): def __sql__(self, ctx: Context) -> Context: ... @@ -846,24 +917,24 @@ class Index(Node): expressions, unique: bool = ..., safe: bool = ..., - where: Optional[Expression] = ..., - using: Optional[str] = ..., + where: Expression | None = ..., + using: str | None = ..., ): ... def safe(self, _safe: bool = ...) -> Index: ... def where(self, *expressions: Expression) -> Index: ... - def using(self, _using: Optional[str] = ...) -> Index: ... + def using(self, _using: str | None = ...) -> Index: ... def __sql__(self, ctx: Context) -> Context: ... class ModelIndex(Index): def __init__( self, - model: Type[_TModel], - fields: Iterable[Union[Field, Node, str]], + model: type[_TModel], + fields: Iterable[Field | Node | str], unique: bool = ..., safe: bool = ..., - where: Optional[Expression] = ..., - using: Optional[str] = ..., - name: Optional[str] = ..., + where: Expression | None = ..., + using: str | None = ..., + name: str | None = ..., ): ... class PeeweeException(Exception): @@ -883,17 +954,17 @@ class OperationalError(DatabaseError): ... class ProgrammingError(DatabaseError): ... class ExceptionWrapper: - exceptions: Mapping[str, Type[Exception]] - def __init__(self, exceptions: Mapping[str, Type[Exception]]): ... + exceptions: Mapping[str, type[Exception]] + def __init__(self, exceptions: Mapping[str, type[Exception]]): ... def __enter__(self) -> None: ... - def __exit__(self, exc_type: Type[Exception], exc_value: Exception, traceback: object) -> None: ... + def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: object) -> None: ... -EXCEPTIONS: Mapping[str, Type[Exception]] +EXCEPTIONS: Mapping[str, type[Exception]] class IndexMetadata(NamedTuple): name: str sql: str - columns: List[str] + columns: list[str] unique: bool table: str @@ -917,9 +988,9 @@ class ViewMetadata(NamedTuple): class _ConnectionState: closed: bool - conn: Optional[__IConnection] - ctx: List[ConnectionContext] - transactions: List[Union[_manual, _transaction]] + conn: __IConnection | None + ctx: list[ConnectionContext] + transactions: list[_manual | _transaction] def reset(self) -> None: ... def set_connection(self, conn: __IConnection) -> None: ... @@ -928,21 +999,21 @@ class _ConnectionLocal(_ConnectionState, threading.local): ... class ConnectionContext(_callable_context_manager): db: Database def __enter__(self) -> None: ... - def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: object) -> None: ... + def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None: ... class Database(_callable_context_manager): - context_class: ClassVar[Type[_TContextClass]] + context_class: ClassVar[type[_TContextClass]] field_types: ClassVar[Mapping[str, str]] operations: ClassVar[Mapping[str, Any]] # TODO (dargueta) Verify k/v types param: ClassVar[str] quote: ClassVar[str] - server_version: ClassVar[Optional[Union[int, Tuple[int, ...]]]] + server_version: ClassVar[int | tuple[int, ...] | None] commit_select: ClassVar[bool] compound_select_parentheses: ClassVar[int] for_update: ClassVar[bool] index_schema_prefix: ClassVar[bool] index_using_precedes_table: ClassVar[bool] - limit_max: ClassVar[Optional[int]] + limit_max: ClassVar[int | None] nulls_ordering: ClassVar[bool] returning_clause: ClassVar[bool] safe_create_index: ClassVar[bool] @@ -961,29 +1032,29 @@ class Database(_callable_context_manager): database: __IConnection, thread_safe: bool = ..., autorollback: bool = ..., - field_types: Optional[Mapping[str, str]] = ..., - operations: Optional[Mapping[str, str]] = ..., + field_types: Mapping[str, str] | None = ..., + operations: Mapping[str, str] | None = ..., autocommit: bool = ..., autoconnect: bool = ..., **kwargs: object, ): ... def init(self, database: __IConnection, **kwargs: object) -> None: ... - def __enter__(self) -> Database: ... - def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: object) -> None: ... + def __enter__(self: Self) -> Self: ... + def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None: ... def connection_context(self) -> ConnectionContext: ... def connect(self, reuse_if_open: bool = ...) -> bool: ... def close(self) -> bool: ... def is_closed(self) -> bool: ... def is_connection_usable(self) -> bool: ... def connection(self) -> __IConnection: ... - def cursor(self, commit: Optional[bool] = ...) -> __ICursor: ... - def execute_sql(self, sql: str, params: Optional[tuple] = ..., commit: Union[bool, _TSentinel] = ...) -> __ICursor: ... - def execute(self, query: Query, commit: Union[bool, _TSentinel] = ..., **context_options: object) -> __ICursor: ... + def cursor(self, commit: bool | None = ...) -> __ICursor: ... + def execute_sql(self, sql: str, params: tuple | None = ..., commit: bool | _TSentinel = ...) -> __ICursor: ... + def execute(self, query: Query, commit: bool | _TSentinel = ..., **context_options: object) -> __ICursor: ... def get_context_options(self) -> Mapping[str, object]: ... def get_sql_context(self, **context_options: object) -> _TContextClass: ... - def conflict_statement(self, on_conflict: OnConflict, query: Query) -> Optional[SQL]: ... + def conflict_statement(self, on_conflict: OnConflict, query: Query) -> SQL | None: ... def conflict_update(self, oc: OnConflict, query: Query) -> NodeList: ... - def last_insert_id(self, cursor: __ICursor, query_type: Optional[int] = ...) -> int: ... + def last_insert_id(self, cursor: __ICursor, query_type: int | None = ...) -> int: ... def rows_affected(self, cursor: __ICursor) -> int: ... def default_values_insert(self, ctx: Context) -> Context: ... def session_start(self) -> _transaction: ... @@ -991,9 +1062,9 @@ class Database(_callable_context_manager): def session_rollback(self) -> bool: ... def in_transaction(self) -> bool: ... def push_transaction(self, transaction) -> None: ... - def pop_transaction(self) -> Union[_manual, _transaction]: ... + def pop_transaction(self) -> _manual | _transaction: ... def transaction_depth(self) -> int: ... - def top_transaction(self) -> Optional[Union[_manual, _transaction]]: ... + def top_transaction(self) -> _manual | _transaction | None: ... def atomic(self, *args: object, **kwargs: object) -> _atomic: ... def manual_commit(self) -> _manual: ... def transaction(self, *args: object, **kwargs: object) -> _transaction: ... @@ -1002,23 +1073,23 @@ class Database(_callable_context_manager): def commit(self) -> None: ... def rollback(self) -> None: ... def batch_commit(self, it: Iterable[_T], n: int) -> Iterator[_T]: ... - def table_exists(self, table_name: str, schema: Optional[str] = ...) -> str: ... - def get_tables(self, schema: Optional[str] = ...) -> List[str]: ... - def get_indexes(self, table: str, schema: Optional[str] = ...) -> List[IndexMetadata]: ... - def get_columns(self, table: str, schema: Optional[str] = ...) -> List[ColumnMetadata]: ... - def get_primary_keys(self, table: str, schema: Optional[str] = ...) -> List[str]: ... - def get_foreign_keys(self, table: str, schema: Optional[str] = ...) -> List[ForeignKeyMetadata]: ... + def table_exists(self, table_name: str, schema: str | None = ...) -> str: ... + def get_tables(self, schema: str | None = ...) -> list[str]: ... + def get_indexes(self, table: str, schema: str | None = ...) -> list[IndexMetadata]: ... + def get_columns(self, table: str, schema: str | None = ...) -> list[ColumnMetadata]: ... + def get_primary_keys(self, table: str, schema: str | None = ...) -> list[str]: ... + def get_foreign_keys(self, table: str, schema: str | None = ...) -> list[ForeignKeyMetadata]: ... def sequence_exists(self, seq: str) -> bool: ... - def create_tables(self, models: Iterable[Type[Model]], **options: object) -> None: ... - def drop_tables(self, models: Iterable[Type[Model]], **kwargs: object) -> None: ... + def create_tables(self, models: Iterable[type[Model]], **options: object) -> None: ... + def drop_tables(self, models: Iterable[type[Model]], **kwargs: object) -> None: ... def extract_date(self, date_part: str, date_field: Node) -> Function: ... def truncate_date(self, date_part: str, date_field: Node) -> Function: ... def to_timestamp(self, date_field: str) -> Function: ... def from_timestamp(self, date_field: str) -> Function: ... def random(self) -> Node: ... - def bind(self, models: Iterable[Type[Model]], bind_refs: bool = ..., bind_backrefs: bool = ...) -> None: ... + def bind(self, models: Iterable[type[Model]], bind_refs: bool = ..., bind_backrefs: bool = ...) -> None: ... def bind_ctx( - self, models: Iterable[Type[Model]], bind_refs: bool = ..., bind_backrefs: bool = ... + self, models: Iterable[type[Model]], bind_refs: bool = ..., bind_backrefs: bool = ... ) -> _BoundModelsContext: ... def get_noop_select(self, ctx: Context) -> Context: ... @@ -1027,7 +1098,7 @@ class SqliteDatabase(Database): operations: ClassVar[Mapping[str, str]] index_schema_prefix: ClassVar[bool] limit_max: ClassVar[int] - server_version: ClassVar[Tuple[int, ...]] + server_version: ClassVar[tuple[int, ...]] truncate_table: ClassVar[bool] # Instance variables timeout: int @@ -1035,22 +1106,16 @@ class SqliteDatabase(Database): # Properties cache_size: int def __init__( - self, - database: str, - *args: object, - pragmas: Union[Mapping[str, object], Iterable[Tuple[str, Any]]] = ..., - **kwargs: object, + self, database: str, *args: object, pragmas: Mapping[str, object] | Iterable[tuple[str, Any]] = ..., **kwargs: object ): ... def init( self, database: str, - pragmas: Optional[Union[Mapping[str, object], Iterable[Tuple[str, Any]]]] = ..., + pragmas: Mapping[str, object] | Iterable[tuple[str, Any]] | None = ..., timeout: int = ..., **kwargs: object, ) -> None: ... - def pragma( - self, key: str, value: Union[str, bool, int] = ..., permanent: bool = ..., schema: Optional[str] = ... - ) -> object: ... + def pragma(self, key: str, value: str | bool | int = ..., permanent: bool = ..., schema: str | None = ...) -> object: ... @property def foreign_keys(self) -> Any: ... @foreign_keys.setter @@ -1083,16 +1148,16 @@ class SqliteDatabase(Database): def wal_autocheckpoint(self) -> Any: ... @wal_autocheckpoint.setter def wal_autocheckpoint(self, value: object) -> Any: ... - def register_aggregate(self, klass: Type[__IAggregate], name: Optional[str] = ..., num_params: int = ...): ... - def aggregate(self, name: Optional[str] = ..., num_params: int = ...) -> Callable[[_TClass], _TClass]: ... - def register_collation(self, fn: Callable, name: Optional[str] = ...) -> None: ... - def collation(self, name: Optional[str] = ...) -> Callable[[_TFunc], _TFunc]: ... - def register_function(self, fn: Callable, name: Optional[str] = ..., num_params: int = ...) -> int: ... - def func(self, name: Optional[str] = ..., num_params: int = ...) -> Callable[[_TFunc], _TFunc]: ... - def register_window_function(self, klass: type, name: Optional[str] = ..., num_params: int = ...) -> None: ... - def window_function(self, name: Optional[str] = ..., num_params: int = ...) -> Callable[[_TClass], _TClass]: ... - def register_table_function(self, klass: Type[__ITableFunction], name: Optional[str] = ...) -> None: ... - def table_function(self, name: Optional[str] = ...) -> Callable[[Type[__ITableFunction]], Type[__ITableFunction]]: ... + def register_aggregate(self, klass: type[__IAggregate], name: str | None = ..., num_params: int = ...): ... + def aggregate(self, name: str | None = ..., num_params: int = ...) -> Callable[[_TClass], _TClass]: ... + def register_collation(self, fn: Callable, name: str | None = ...) -> None: ... + def collation(self, name: str | None = ...) -> Callable[[_TFunc], _TFunc]: ... + def register_function(self, fn: Callable, name: str | None = ..., num_params: int = ...) -> int: ... + def func(self, name: str | None = ..., num_params: int = ...) -> Callable[[_TFunc], _TFunc]: ... + def register_window_function(self, klass: type, name: str | None = ..., num_params: int = ...) -> None: ... + def window_function(self, name: str | None = ..., num_params: int = ...) -> Callable[[_TClass], _TClass]: ... + def register_table_function(self, klass: type[__ITableFunction], name: str | None = ...) -> None: ... + def table_function(self, name: str | None = ...) -> Callable[[type[__ITableFunction]], type[__ITableFunction]]: ... def unregister_aggregate(self, name: str) -> None: ... def unregister_collation(self, name: str) -> None: ... def unregister_function(self, name: str) -> None: ... @@ -1102,8 +1167,8 @@ class SqliteDatabase(Database): def unload_extension(self, extension: str) -> None: ... def attach(self, filename: str, name: str) -> bool: ... def detach(self, name: str) -> bool: ... - def begin(self, lock_type: Optional[str] = ...) -> None: ... - def get_views(self, schema: Optional[str] = ...) -> List[ViewMetadata]: ... + def begin(self, lock_type: str | None = ...) -> None: ... + def get_views(self, schema: str | None = ...) -> list[ViewMetadata]: ... def get_binary_type(self) -> type: ... class PostgresqlDatabase(Database): @@ -1123,13 +1188,13 @@ class PostgresqlDatabase(Database): self, database: __IConnection, register_unicode: bool = ..., - encoding: Optional[str] = ..., - isolation_level: Optional[int] = ..., + encoding: str | None = ..., + isolation_level: int | None = ..., **kwargs: object, ): ... def is_connection_usable(self) -> bool: ... - def last_insert_id(self, cursor: __ICursor, query_type: Optional[int] = ...) -> Union[Optional[int], __ICursor]: ... - def get_views(self, schema: Optional[str] = ...) -> List[ViewMetadata]: ... + def last_insert_id(self, cursor: __ICursor, query_type: int | None = ...) -> int | None | __ICursor: ... + def get_views(self, schema: str | None = ...) -> list[ViewMetadata]: ... def get_binary_type(self) -> type: ... def get_noop_select(self, ctx: Context) -> SelectQuery: ... def set_time_zone(self, timezone: str) -> None: ... @@ -1148,10 +1213,10 @@ class MySQLDatabase(Database): safe_drop_index: ClassVar[bool] sql_mode: ClassVar[str] # Instance variables - server_version: Tuple[int, ...] + server_version: tuple[int, ...] def init(self, database: __IConnection, **kwargs: object): ... def default_values_insert(self, ctx: Context) -> SQL: ... - def get_views(self, schema: Optional[str] = ...) -> List[ViewMetadata]: ... + def get_views(self, schema: str | None = ...) -> list[ViewMetadata]: ... def get_binary_type(self) -> type: ... # TODO (dargueta) Verify return type on these function calls def extract_date(self, date_part: str, date_field: str) -> Function: ... @@ -1167,31 +1232,31 @@ class _manual(_callable_context_manager): db: Database def __init__(self, db: Database): ... def __enter__(self) -> None: ... - def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: object) -> None: ... + def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None: ... class _atomic(_callable_context_manager): db: Database def __init__(self, db: Database, *args: object, **kwargs: object): ... - def __enter__(self) -> Union[_transaction, _savepoint]: ... - def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: object) -> None: ... + def __enter__(self) -> _transaction | _savepoint: ... + def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None: ... class _transaction(_callable_context_manager): db: Database def __init__(self, db: Database, *args: object, **kwargs: object): ... def commit(self, begin: bool = ...) -> None: ... def rollback(self, begin: bool = ...) -> None: ... - def __enter__(self) -> _transaction: ... - def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: object) -> None: ... + def __enter__(self: Self) -> Self: ... + def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None: ... class _savepoint(_callable_context_manager): db: Database sid: str quoted_sid: str - def __init__(self, db: Database, sid: Optional[str] = ...): ... + def __init__(self, db: Database, sid: str | None = ...): ... def commit(self, begin: bool = ...) -> None: ... def rollback(self) -> None: ... - def __enter__(self) -> _savepoint: ... - def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: object) -> None: ... + def __enter__(self: Self) -> Self: ... + def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None: ... class CursorWrapper(Generic[_T]): cursor: __ICursor @@ -1199,13 +1264,13 @@ class CursorWrapper(Generic[_T]): index: int initialized: bool populated: bool - row_cache: List[_T] + row_cache: list[_T] def __init__(self, cursor: __ICursor): ... - def __iter__(self) -> Union[ResultIterator[_T], Iterator[_T]]: ... + def __iter__(self) -> ResultIterator[_T] | Iterator[_T]: ... @overload def __getitem__(self, item: int) -> _T: ... @overload - def __getitem__(self, item: slice) -> List[_T]: ... + def __getitem__(self, item: slice) -> list[_T]: ... def __len__(self) -> int: ... def initialize(self) -> None: ... def iterate(self, cache: bool = ...) -> _T: ... @@ -1217,7 +1282,7 @@ class DictCursorWrapper(CursorWrapper[Mapping[str, object]]): ... # FIXME (dargueta): Somehow figure out how to make this a NamedTuple sorta deal class NamedTupleCursorWrapper(CursorWrapper[tuple]): - tuple_class: Type[tuple] + tuple_class: type[tuple] class ObjectCursorWrapper(DictCursorWrapper[_T]): constructor: Callable[..., _T] @@ -1233,26 +1298,26 @@ class ResultIterator(Generic[_T]): # FIELDS class FieldAccessor: - model: Type[Model] + model: type[Model] field: Field name: str - def __init__(self, model: Type[Model], field: Field, name: str): ... + def __init__(self, model: type[Model], field: Field, name: str): ... @overload def __get__(self, instance: None, instance_type: type) -> Field: ... @overload - def __get__(self, instance: _T, instance_type: Type[_T]) -> Any: ... + def __get__(self, instance: _T, instance_type: type[_T]) -> Any: ... class ForeignKeyAccessor(FieldAccessor): - model: Type[Model] + model: type[Model] field: ForeignKeyField name: str - rel_model: Type[Model] - def __init__(self, model: Type[Model], field: ForeignKeyField, name: str): ... + rel_model: type[Model] + def __init__(self, model: type[Model], field: ForeignKeyField, name: str): ... def get_rel_instance(self, instance: Model) -> Any: ... @overload def __get__(self, instance: None, instance_type: type) -> Any: ... @overload - def __get__(self, instance: _TModel, instance_type: Type[_TModel]) -> ForeignKeyField: ... + def __get__(self, instance: _TModel, instance_type: type[_TModel]) -> ForeignKeyField: ... def __set__(self, instance: _TModel, obj: object) -> None: ... class NoQueryForeignKeyAccessor(ForeignKeyAccessor): @@ -1260,46 +1325,45 @@ class NoQueryForeignKeyAccessor(ForeignKeyAccessor): class BackrefAccessor: field: ForeignKeyField - model: Type[Model] - rel_model: Type[Model] + model: type[Model] + rel_model: type[Model] def __init__(self, field: ForeignKeyField): ... @overload def __get__(self, instance: None, instance_type: type) -> BackrefAccessor: ... @overload - def __get__(self, instance: Field, instance_type: Type[Field]) -> SelectQuery: ... + def __get__(self, instance: Field, instance_type: type[Field]) -> SelectQuery: ... +# "Gives direct access to the underlying id" class ObjectIdAccessor: - """Gives direct access to the underlying id""" - field: ForeignKeyField def __init__(self, field: ForeignKeyField): ... @overload - def __get__(self, instance: None, instance_type: Type[Model]) -> ForeignKeyField: ... + def __get__(self, instance: None, instance_type: type[Model]) -> ForeignKeyField: ... @overload - def __get__(self, instance: _TModel, instance_type: Type[_TModel] = ...) -> Any: ... + def __get__(self, instance: _TModel, instance_type: type[_TModel] = ...) -> Any: ... def __set__(self, instance: Model, value: object) -> None: ... class Field(ColumnBase): - accessor_class: ClassVar[Type[FieldAccessor]] + accessor_class: ClassVar[type[FieldAccessor]] auto_increment: ClassVar[bool] - default_index_type: ClassVar[Optional[str]] + default_index_type: ClassVar[str | None] field_type: ClassVar[str] unpack: ClassVar[bool] # Instance variables - model: Type[Model] + model: type[Model] null: bool index: bool unique: bool column_name: str default: Any primary_key: bool - constraints: Optional[Iterable[Union[Callable[[str], SQL], SQL]]] - sequence: Optional[str] - collation: Optional[str] + constraints: Iterable[Callable[[str], SQL] | SQL] | None + sequence: str | None + collation: str | None unindexed: bool - help_text: Optional[str] - verbose_name: Optional[str] - index_type: Optional[str] + help_text: str | None + verbose_name: str | None + index_type: str | None def __init__( self, null: bool = ..., @@ -1308,27 +1372,26 @@ class Field(ColumnBase): column_name: str = ..., default: Any = ..., primary_key: bool = ..., - constraints: Optional[Iterable[Union[Callable[[str], SQL], SQL]]] = ..., - sequence: Optional[str] = ..., - collation: Optional[str] = ..., - unindexed: Optional[bool] = ..., - choices: Optional[Iterable[Tuple[Any, str]]] = ..., - help_text: Optional[str] = ..., - verbose_name: Optional[str] = ..., - index_type: Optional[str] = ..., - db_column: Optional[str] = ..., # Deprecated argument, undocumented + constraints: Iterable[Callable[[str], SQL] | SQL] | None = ..., + sequence: str | None = ..., + collation: str | None = ..., + unindexed: bool | None = ..., + choices: Iterable[tuple[Any, str]] | None = ..., + help_text: str | None = ..., + verbose_name: str | None = ..., + index_type: str | None = ..., + db_column: str | None = ..., # Deprecated argument, undocumented _hidden: bool = ..., ): ... def __hash__(self) -> int: ... - def __repr__(self) -> str: ... - def bind(self, model: Type[Model], name: str, set_attribute: bool = ...) -> None: ... + def bind(self, model: type[Model], name: str, set_attribute: bool = ...) -> None: ... @property def column(self) -> Column: ... def adapt(self, value: _T) -> _T: ... def db_value(self, value: _T) -> _T: ... def python_value(self, value: _T) -> _T: ... def to_value(self, value: Any) -> Value: ... - def get_sort_key(self, ctx: Context) -> Tuple[int, int]: ... + def get_sort_key(self, ctx: Context) -> tuple[int, int]: ... def __sql__(self, ctx: Context) -> Context: ... def get_modifiers(self) -> Any: ... def ddl_datatype(self, ctx: Context) -> SQL: ... @@ -1336,7 +1399,7 @@ class Field(ColumnBase): class IntegerField(Field): @overload - def adapt(self, value: Union[str, float, bool]) -> int: ... # type: ignore + def adapt(self, value: str | float | bool) -> int: ... # type: ignore @overload def adapt(self, value: _T) -> _T: ... @@ -1355,7 +1418,7 @@ class PrimaryKeyField(AutoField): ... class FloatField(Field): @overload - def adapt(self, value: Union[str, float, bool]) -> float: ... # type: ignore + def adapt(self, value: str | float | bool) -> float: ... # type: ignore @overload def adapt(self, value: _T) -> _T: ... @@ -1375,17 +1438,17 @@ class DecimalField(Field): *args: object, **kwargs: object, ): ... - def get_modifiers(self) -> List[int]: ... + def get_modifiers(self) -> list[int]: ... @overload def db_value(self, value: None) -> None: ... @overload - def db_value(self, value: Union[float, decimal.Decimal]) -> decimal.Decimal: ... # type: ignore + def db_value(self, value: float | decimal.Decimal) -> decimal.Decimal: ... # type: ignore @overload def db_value(self, value: _T) -> _T: ... @overload def python_value(self, value: None) -> None: ... @overload - def python_value(self, value: Union[str, float, decimal.Decimal]) -> decimal.Decimal: ... + def python_value(self, value: str | float | decimal.Decimal) -> decimal.Decimal: ... class _StringField(Field): def adapt(self, value: AnyStr) -> str: ... @@ -1395,21 +1458,21 @@ class _StringField(Field): class CharField(_StringField): max_length: int def __init__(self, max_length: int = ..., *args: object, **kwargs: object): ... - def get_modifiers(self) -> Optional[List[int]]: ... + def get_modifiers(self) -> list[int] | None: ... class FixedCharField(CharField): ... class TextField(_StringField): ... class BlobField(Field): @overload - def db_value(self, value: Union[str, bytes]) -> bytearray: ... + def db_value(self, value: str | bytes) -> bytearray: ... @overload def db_value(self, value: _T) -> _T: ... class BitField(BitwiseMixin, BigIntegerField): - def __init__(self, *args: object, default: Optional[int] = ..., **kwargs: object): ... + def __init__(self, *args: object, default: int | None = ..., **kwargs: object): ... # FIXME (dargueta) Return type isn't 100% accurate; function creates a new class - def flag(self, value: Optional[int] = ...) -> ColumnBase: ... + def flag(self, value: int | None = ...) -> ColumnBase: ... class BigBitFieldData: name: str @@ -1419,17 +1482,16 @@ class BigBitFieldData: def clear_bit(self, idx: bool) -> None: ... def toggle_bit(self, idx: int) -> bool: ... def is_set(self, idx: int) -> bool: ... - def __repr__(self) -> str: ... class BigBitFieldAccessor(FieldAccessor): @overload - def __get__(self, instance: None, instance_type: Type[_TModel]) -> Field: ... + def __get__(self, instance: None, instance_type: type[_TModel]) -> Field: ... @overload - def __get__(self, instance: _TModel, instance_type: Type[_TModel]) -> BigBitFieldData: ... - def __set__(self, instance: Any, value: Union[memoryview, bytearray, BigBitFieldData, str, bytes]) -> None: ... + def __get__(self, instance: _TModel, instance_type: type[_TModel]) -> BigBitFieldData: ... + def __set__(self, instance: Any, value: memoryview | bytearray | BigBitFieldData | str | bytes) -> None: ... class BigBitField(BlobField): - accessor_class: ClassVar[Type[BigBitFieldAccessor]] + accessor_class: ClassVar[type[BigBitFieldAccessor]] def __init__(self, *args: object, default: type = ..., **kwargs: object): ... @overload def db_value(self, value: None) -> None: ... @@ -1442,7 +1504,7 @@ class UUIDField(Field): @overload def db_value(self, value: _T) -> _T: ... @overload - def python_value(self, value: Union[uuid.UUID, AnyStr]) -> uuid.UUID: ... + def python_value(self, value: uuid.UUID | AnyStr) -> uuid.UUID: ... @overload def python_value(self, value: None) -> None: ... @@ -1450,20 +1512,20 @@ class BinaryUUIDField(BlobField): @overload def db_value(self, value: None) -> None: ... @overload - def db_value(self, value: Optional[Union[bytearray, bytes, str, uuid.UUID]]) -> bytes: ... + def db_value(self, value: bytearray | bytes | str | uuid.UUID | None) -> bytes: ... @overload def python_value(self, value: None) -> None: ... @overload - def python_value(self, value: Union[bytearray, bytes, memoryview, uuid.UUID]) -> uuid.UUID: ... + def python_value(self, value: bytearray | bytes | memoryview | uuid.UUID) -> uuid.UUID: ... -def format_date_time(value: str, formats: Iterable[str], post_process: Optional[_TConvFunc] = ...) -> str: ... +def format_date_time(value: str, formats: Iterable[str], post_process: _TConvFunc | None = ...) -> str: ... @overload def simple_date_time(value: _T) -> _T: ... class _BaseFormattedField(Field): # TODO (dargueta): This is a class variable that can be overridden for instances - formats: Optional[Container[str]] - def __init__(self, formats: Optional[Container[str]] = ..., *args: object, **kwargs: object): ... + formats: Container[str] | None + def __init__(self, formats: Container[str] | None = ..., *args: object, **kwargs: object): ... class DateTimeField(_BaseFormattedField): @property @@ -1498,7 +1560,7 @@ class DateField(_BaseFormattedField): class TimeField(_BaseFormattedField): @overload - def adapt(self, value: Union[datetime.datetime, datetime.timedelta]) -> datetime.time: ... + def adapt(self, value: datetime.datetime | datetime.timedelta) -> datetime.time: ... @overload def adapt(self, value: _T) -> _T: ... @property @@ -1521,9 +1583,9 @@ class TimestampField(BigIntegerField): @overload def db_value(self, value: None) -> None: ... @overload - def db_value(self, value: Union[datetime.datetime, datetime.date, float]) -> int: ... + def db_value(self, value: datetime.datetime | datetime.date | float) -> int: ... @overload - def python_value(self, value: Union[int, float]) -> datetime.datetime: ... + def python_value(self, value: int | float) -> datetime.datetime: ... @overload def python_value(self, value: _T) -> _T: ... def from_timestamp(self) -> float: ... @@ -1555,148 +1617,146 @@ class BooleanField(Field): class BareField(Field): # If `adapt` was omitted from the constructor or None, this attribute won't exist. - adapt: Optional[_TConvFunc] - def __init__(self, adapt: Optional[_TConvFunc] = ..., *args: object, **kwargs: object): ... + adapt: _TConvFunc | None + def __init__(self, adapt: _TConvFunc | None = ..., *args: object, **kwargs: object): ... def ddl_datatype(self, ctx: Context) -> None: ... class ForeignKeyField(Field): - accessor_class: ClassVar[Type[ForeignKeyAccessor]] - rel_model: Union[Type[Model], Literal["self"]] + accessor_class: ClassVar[type[ForeignKeyAccessor]] + rel_model: type[Model] | Literal["self"] rel_field: Field - declared_backref: Optional[str] - backref: Optional[str] # TODO (dargueta): Verify - on_delete: Optional[str] - on_update: Optional[str] - deferrable: Optional[str] - deferred: Optional[bool] - object_id_name: Optional[str] + declared_backref: str | None + backref: str | None # TODO (dargueta): Verify + on_delete: str | None + on_update: str | None + deferrable: str | None + deferred: bool | None + object_id_name: str | None lazy_load: bool safe_name: str def __init__( self, - model: Union[Type[Model], Literal["self"]], - field: Optional[Field] = ..., + model: type[Model] | Literal["self"], + field: Field | None = ..., # TODO (dargueta): Documentation says this is only a string but code accepts a callable too - backref: Optional[str] = ..., - on_delete: Optional[str] = ..., - on_update: Optional[str] = ..., - deferrable: Optional[str] = ..., - _deferred: Optional[bool] = ..., # undocumented + backref: str | None = ..., + on_delete: str | None = ..., + on_update: str | None = ..., + deferrable: str | None = ..., + _deferred: bool | None = ..., # undocumented rel_model: object = ..., # undocumented to_field: object = ..., # undocumented - object_id_name: Optional[str] = ..., + object_id_name: str | None = ..., lazy_load: bool = ..., # type for related_name is a guess - related_name: Optional[str] = ..., # undocumented + related_name: str | None = ..., # undocumented *args: object, index: bool = ..., **kwargs: object, ): ... @property def field_type(self) -> str: ... - def get_modifiers(self) -> Optional[Iterable[object]]: ... + def get_modifiers(self) -> Iterable[object] | None: ... def adapt(self, value: object) -> Any: ... def db_value(self, value: object) -> Any: ... def python_value(self, value: object) -> Any: ... - def bind(self, model: Type[Model], name: str, set_attribute: bool = ...) -> None: ... + def bind(self, model: type[Model], name: str, set_attribute: bool = ...) -> None: ... def foreign_key_constraint(self) -> NodeList: ... def __getattr__(self, attr: str) -> Field: ... class DeferredForeignKey(Field): - field_kwargs: Dict[str, object] + field_kwargs: dict[str, object] rel_model_name: str - def __init__(self, rel_model_name: str, *, column_name: Optional[str] = ..., null: Optional[str] = ..., **kwargs: object): ... - def set_model(self, rel_model: Type[Model]) -> None: ... + def __init__(self, rel_model_name: str, *, column_name: str | None = ..., null: str | None = ..., **kwargs: object): ... + def set_model(self, rel_model: type[Model]) -> None: ... @staticmethod - def resolve(model_cls: Type[Model]) -> None: ... + def resolve(model_cls: type[Model]) -> None: ... def __hash__(self) -> int: ... class DeferredThroughModel: - def set_field(self, model: Type[Model], field: Type[Field], name: str) -> None: ... - def set_model(self, through_model: Type[Model]) -> None: ... + def set_field(self, model: type[Model], field: type[Field], name: str) -> None: ... + def set_model(self, through_model: type[Model]) -> None: ... class MetaField(Field): # These are declared as class variables in the source code but are used like local # variables - column_name: Optional[str] + column_name: str | None default: Any - model: Type[Model] - name: Optional[str] + model: type[Model] + name: str | None primary_key: bool class ManyToManyFieldAccessor(FieldAccessor): - model: Type[Model] - rel_model: Type[Model] - through_model: Type[Model] + model: type[Model] + rel_model: type[Model] + through_model: type[Model] src_fk: ForeignKeyField dest_fk: ForeignKeyField - def __init__(self, model: Type[Model], field: ForeignKeyField, name: str): ... + def __init__(self, model: type[Model], field: ForeignKeyField, name: str): ... @overload - def __get__(self, instance: None, instance_type: Type[_T] = ..., force_query: bool = ...) -> Field: ... + def __get__(self, instance: None, instance_type: type[_T] = ..., force_query: bool = ...) -> Field: ... @overload - def __get__( - self, instance: _T, instance_type: Type[_T] = ..., force_query: bool = ... - ) -> Union[List[str], ManyToManyQuery]: ... + def __get__(self, instance: _T, instance_type: type[_T] = ..., force_query: bool = ...) -> list[str] | ManyToManyQuery: ... def __set__(self, instance: _T, value) -> None: ... class ManyToManyField(MetaField): - accessor_class: ClassVar[Type[ManyToManyFieldAccessor]] + accessor_class: ClassVar[type[ManyToManyFieldAccessor]] # Instance variables - through_model: Union[Type[Model], DeferredThroughModel] - rel_model: Type[Model] - backref: Optional[str] + through_model: type[Model] | DeferredThroughModel + rel_model: type[Model] + backref: str | None def __init__( self, - model: Type[Model], - backref: Optional[str] = ..., - through_model: Optional[Union[Type[Model], DeferredThroughModel]] = ..., - on_delete: Optional[str] = ..., - on_update: Optional[str] = ..., + model: type[Model], + backref: str | None = ..., + through_model: type[Model] | DeferredThroughModel | None = ..., + on_delete: str | None = ..., + on_update: str | None = ..., _is_backref: bool = ..., ): ... - def bind(self, model: Type[Model], name: str, set_attribute: bool = ...) -> None: ... - def get_models(self) -> List[Type[Model]]: ... - def get_through_model(self) -> Union[Type[Model], DeferredThroughModel]: ... + def bind(self, model: type[Model], name: str, set_attribute: bool = ...) -> None: ... + def get_models(self) -> list[type[Model]]: ... + def get_through_model(self) -> type[Model] | DeferredThroughModel: ... class VirtualField(MetaField, Generic[_TField]): - field_class: Type[_TField] - field_instance: Optional[_TField] - def __init__(self, field_class: Optional[Type[_TField]] = ..., *args: object, **kwargs: object): ... + field_class: type[_TField] + field_instance: _TField | None + def __init__(self, field_class: type[_TField] | None = ..., *args: object, **kwargs: object): ... def db_value(self, value: object) -> Any: ... def python_value(self, value: object) -> Any: ... - def bind(self, model: Type[Model], name: str, set_attribute: bool = ...) -> None: ... + def bind(self, model: type[Model], name: str, set_attribute: bool = ...) -> None: ... class CompositeKey(MetaField): sequence = None - field_names: Tuple[str, ...] + field_names: tuple[str, ...] # The following attributes are not set in the constructor an so may not always be # present. - model: Type[Model] + model: type[Model] column_name: str def __init__(self, *field_names: str): ... @property - def safe_field_names(self) -> Union[List[str], Tuple[str, ...]]: ... + def safe_field_names(self) -> list[str] | tuple[str, ...]: ... @overload def __get__(self, instance: None, instance_type: type) -> CompositeKey: ... @overload - def __get__(self, instance: _T, instance_type: Type[_T]) -> tuple: ... - def __set__(self, instance: Model, value: Union[list, tuple]) -> None: ... + def __get__(self, instance: _T, instance_type: type[_T]) -> tuple: ... + def __set__(self, instance: Model, value: list | tuple) -> None: ... def __eq__(self, other: Expression) -> Expression: ... def __ne__(self, other: Expression) -> Expression: ... def __hash__(self) -> int: ... def __sql__(self, ctx: Context) -> Context: ... - def bind(self, model: Type[Model], name: str, set_attribute: bool = ...) -> None: ... + def bind(self, model: type[Model], name: str, set_attribute: bool = ...) -> None: ... # MODELS class SchemaManager: - model: Type[Model] - context_options: Dict[str, object] - def __init__(self, model: Type[Model], database: Optional[Database] = ..., **context_options: object): ... + model: type[Model] + context_options: dict[str, object] + def __init__(self, model: type[Model], database: Database | None = ..., **context_options: object): ... @property def database(self) -> Database: ... @database.setter - def database(self, value: Optional[Database]) -> None: ... + def database(self, value: Database | None) -> None: ... def create_table(self, safe: bool = ..., **options: object) -> None: ... def create_table_as(self, table_name: str, query: SelectQuery, safe: bool = ..., **meta: object) -> None: ... def drop_table(self, safe: bool = ..., **options: object) -> None: ... @@ -1712,50 +1772,50 @@ class SchemaManager: def drop_all(self, safe: bool = ..., drop_sequences: bool = ..., **options: object) -> None: ... class Metadata: - model: Type[Model] - database: Optional[Database] - fields: Dict[str, object] # TODO (dargueta) This may be Dict[str, Field] - columns: Dict[str, object] # TODO (dargueta) Verify this - combined: Dict[str, object] # TODO (dargueta) Same as above - sorted_fields: List[Field] - sorted_field_names: List[str] - defaults: Dict[str, object] + model: type[Model] + database: Database | None + fields: dict[str, object] # TODO (dargueta) This may be dict[str, Field] + columns: dict[str, object] # TODO (dargueta) Verify this + combined: dict[str, object] # TODO (dargueta) Same as above + sorted_fields: list[Field] + sorted_field_names: list[str] + defaults: dict[str, object] name: str - table_function: Optional[Callable[[Type[Model]], str]] + table_function: Callable[[type[Model]], str] | None legacy_table_names: bool table_name: str - indexes: List[Union[Index, ModelIndex, SQL]] - constraints: Optional[Iterable[Union[Callable[[str], SQL], SQL]]] - primary_key: Union[Literal[False], Field, CompositeKey, None] - composite_key: Optional[bool] - auto_increment: Optional[bool] + indexes: list[Index | ModelIndex | SQL] + constraints: Iterable[Callable[[str], SQL] | SQL] | None + primary_key: Literal[False] | Field | CompositeKey | None + composite_key: bool | None + auto_increment: bool | None only_save_dirty: bool - depends_on: Optional[Sequence[Type[Model]]] + depends_on: Sequence[type[Model]] | None table_settings: Mapping[str, object] temporary: bool - refs: Dict[ForeignKeyField, Type[Model]] - backrefs: MutableMapping[ForeignKeyField, List[Type[Model]]] - model_refs: MutableMapping[Type[Model], List[ForeignKeyField]] - model_backrefs: MutableMapping[ForeignKeyField, List[Type[Model]]] - manytomany: Dict[str, ManyToManyField] + refs: dict[ForeignKeyField, type[Model]] + backrefs: MutableMapping[ForeignKeyField, list[type[Model]]] + model_refs: MutableMapping[type[Model], list[ForeignKeyField]] + model_backrefs: MutableMapping[ForeignKeyField, list[type[Model]]] + manytomany: dict[str, ManyToManyField] options: Mapping[str, object] - table: Optional[Table] + table: Table | None entity: Entity def __init__( self, - model: Type[Model], - database: Optional[Database] = ..., - table_name: Optional[str] = ..., - indexes: Optional[Iterable[Union[str, Sequence[str]]]] = ..., - primary_key: Optional[Union[Literal[False], Field, CompositeKey]] = ..., - constraints: Optional[Iterable[Union[Check, SQL]]] = ..., - schema: Optional[str] = ..., + model: type[Model], + database: Database | None = ..., + table_name: str | None = ..., + indexes: Iterable[str | Sequence[str]] | None = ..., + primary_key: Literal[False] | Field | CompositeKey | None = ..., + constraints: Iterable[Check | SQL] | None = ..., + schema: str | None = ..., only_save_dirty: bool = ..., - depends_on: Optional[Sequence[Type[Model]]] = ..., - options: Optional[Mapping[str, object]] = ..., - db_table: Optional[str] = ..., - table_function: Optional[Callable[[Type[Model]], str]] = ..., - table_settings: Optional[Mapping[str, object]] = ..., + depends_on: Sequence[type[Model]] | None = ..., + options: Mapping[str, object] | None = ..., + db_table: str | None = ..., + table_function: Callable[[type[Model]], str] | None = ..., + table_settings: Mapping[str, object] | None = ..., without_rowid: bool = ..., temporary: bool = ..., legacy_table_names: bool = ..., @@ -1764,31 +1824,30 @@ class Metadata: def make_table_name(self) -> str: ... def model_graph( self, refs: bool = ..., backrefs: bool = ..., depth_first: bool = ... - ) -> List[Tuple[ForeignKeyField, Type[Model], bool]]: ... + ) -> list[tuple[ForeignKeyField, type[Model], bool]]: ... def add_ref(self, field: ForeignKeyField) -> None: ... def remove_ref(self, field: ForeignKeyField) -> None: ... def add_manytomany(self, field: ManyToManyField) -> None: ... def remove_manytomany(self, field: ManyToManyField) -> None: ... - def get_rel_for_model(self, model: Union[Type[Model], ModelAlias]) -> Tuple[List[ForeignKeyField], List[Type[Model]]]: ... + def get_rel_for_model(self, model: type[Model] | ModelAlias) -> tuple[list[ForeignKeyField], list[type[Model]]]: ... def add_field(self, field_name: str, field: Field, set_attribute: bool = ...) -> None: ... def remove_field(self, field_name: str) -> None: ... - def set_primary_key(self, name: str, field: Union[Field, CompositeKey]) -> None: ... - def get_primary_keys(self) -> Tuple[Field, ...]: ... - def get_default_dict(self) -> Dict[str, object]: ... - def fields_to_index(self) -> List[ModelIndex]: ... + def set_primary_key(self, name: str, field: Field | CompositeKey) -> None: ... + def get_primary_keys(self) -> tuple[Field, ...]: ... + def get_default_dict(self) -> dict[str, object]: ... + def fields_to_index(self) -> list[ModelIndex]: ... def set_database(self, database: Database) -> None: ... def set_table_name(self, table_name: str) -> None: ... class SubclassAwareMetadata(Metadata): - models: ClassVar[List[Type[Model]]] - def __init__(self, model: Type[Model], *args: object, **kwargs: object): ... - def map_models(self, fn: Callable[[Type[Model]], Any]) -> None: ... + models: ClassVar[list[type[Model]]] + def __init__(self, model: type[Model], *args: object, **kwargs: object): ... + def map_models(self, fn: Callable[[type[Model]], Any]) -> None: ... class DoesNotExist(Exception): ... class ModelBase(type): - inheritable: ClassVar[Set[str]] - def __repr__(self) -> str: ... + inheritable: ClassVar[set[str]] def __iter__(self) -> Iterator[Any]: ... def __getitem__(self, key: object) -> Model: ... def __setitem__(self, key: object, value: Model) -> None: ... @@ -1800,32 +1859,31 @@ class ModelBase(type): def __sql__(self, ctx: Context) -> Context: ... class _BoundModelsContext(_callable_context_manager): - models: Iterable[Type[Model]] + models: Iterable[type[Model]] database: Database bind_refs: bool bind_backrefs: bool - def __init__(self, models: Iterable[Type[Model]], database, bind_refs: bool, bind_backrefs: bool): ... - def __enter__(self) -> Iterable[Type[Model]]: ... - def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: Any) -> None: ... + def __init__(self, models: Iterable[type[Model]], database, bind_refs: bool, bind_backrefs: bool): ... + def __enter__(self) -> Iterable[type[Model]]: ... + def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None: ... class Model(Node, metaclass=ModelBase): _meta: ClassVar[Metadata] _schema: ClassVar[SchemaManager] - DoesNotExist: ClassVar[Type[DoesNotExist]] + DoesNotExist: ClassVar[type[DoesNotExist]] __data__: MutableMapping[str, object] __rel__: MutableMapping[str, object] - def __init__(self, *, __no_default__: Union[int, bool] = ..., **kwargs: object): ... - def __str__(self) -> str: ... + def __init__(self, *, __no_default__: int | bool = ..., **kwargs: object): ... @classmethod def validate_model(cls) -> None: ... @classmethod - def alias(cls, alias: Optional[str] = ...) -> ModelAlias: ... + def alias(cls, alias: str | None = ...) -> ModelAlias: ... @classmethod def select(cls, *fields: Field) -> ModelSelect: ... @classmethod - def update(cls, __data: Optional[Iterable[Union[str, Field]]] = ..., **update: Any) -> ModelUpdate: ... + def update(cls, __data: Iterable[str | Field] | None = ..., **update: Any) -> ModelUpdate: ... @classmethod - def insert(cls, __data: Optional[Iterable[Union[str, Field]]] = ..., **insert: Any) -> ModelInsert: ... + def insert(cls, __data: Iterable[str | Field] | None = ..., **insert: Any) -> ModelInsert: ... @overload @classmethod def insert_many(cls, rows: Iterable[Mapping[str, object]], fields: None) -> ModelInsert: ... @@ -1833,29 +1891,29 @@ class Model(Node, metaclass=ModelBase): @classmethod def insert_many(cls, rows: Iterable[tuple], fields: Sequence[Field]) -> ModelInsert: ... @classmethod - def insert_from(cls, query: SelectQuery, fields: Iterable[Union[Field, Text]]) -> ModelInsert: ... + def insert_from(cls, query: SelectQuery, fields: Iterable[Field | Text]) -> ModelInsert: ... @classmethod - def replace(cls, __data: Optional[Iterable[Union[str, Field]]] = ..., **insert: object) -> OnConflict: ... + def replace(cls, __data: Iterable[str | Field] | None = ..., **insert: object) -> OnConflict: ... @classmethod - def replace_many(cls, rows: Iterable[tuple], fields: Optional[Sequence[Field]] = ...) -> OnConflict: ... + def replace_many(cls, rows: Iterable[tuple], fields: Sequence[Field] | None = ...) -> OnConflict: ... @classmethod def raw(cls, sql: str, *params: object) -> ModelRaw: ... @classmethod def delete(cls) -> ModelDelete: ... @classmethod - def create(cls: Type[_T], **query) -> _T: ... + def create(cls: type[Self], **query) -> Self: ... @classmethod - def bulk_create(cls, model_list: Iterable[Type[Model]], batch_size: Optional[int] = ...) -> None: ... + def bulk_create(cls, model_list: Iterable[type[Model]], batch_size: int | None = ...) -> None: ... @classmethod def bulk_update( - cls, model_list: Iterable[Type[Model]], fields: Iterable[Union[str, Field]], batch_size: Optional[int] = ... + cls, model_list: Iterable[type[Model]], fields: Iterable[str | Field], batch_size: int | None = ... ) -> int: ... @classmethod def noop(cls) -> NoopModelSelect: ... @classmethod def get(cls, *query: object, **filters: object) -> ModelSelect: ... @classmethod - def get_or_none(cls, *query: object, **filters: object) -> Optional[ModelSelect]: ... + def get_or_none(cls, *query: object, **filters: object) -> ModelSelect | None: ... @classmethod def get_by_id(cls, pk: object) -> ModelSelect: ... # TODO (dargueta) I'm 99% sure of return value for this one @@ -1865,27 +1923,23 @@ class Model(Node, metaclass=ModelBase): @classmethod def delete_by_id(cls, pk: object) -> CursorWrapper: ... @classmethod - def get_or_create(cls, *, defaults: Mapping[str, object] = ..., **kwargs: object) -> Tuple[Any, bool]: ... + def get_or_create(cls, *, defaults: Mapping[str, object] = ..., **kwargs: object) -> tuple[Any, bool]: ... @classmethod def filter(cls, *dq_nodes: DQ, **filters: Any) -> SelectQuery: ... def get_id(self) -> Any: ... - def save(self, force_insert: bool = ..., only: Optional[Iterable[Union[str, Field]]] = ...) -> Union[Literal[False], int]: ... + def save(self, force_insert: bool = ..., only: Iterable[str | Field] | None = ...) -> Literal[False] | int: ... def is_dirty(self) -> bool: ... @property - def dirty_fields(self) -> List[Field]: ... - def dependencies(self, search_nullable: bool = ...) -> Iterator[Tuple[Union[bool, Node], ForeignKeyField]]: ... - def delete_instance(self: _T, recursive: bool = ..., delete_nullable: bool = ...) -> _T: ... + def dirty_fields(self) -> list[Field]: ... + def dependencies(self, search_nullable: bool = ...) -> Iterator[tuple[bool | Node, ForeignKeyField]]: ... + def delete_instance(self: Self, recursive: bool = ..., delete_nullable: bool = ...) -> Self: ... def __hash__(self) -> int: ... def __eq__(self, other: object) -> bool: ... def __ne__(self, other: object) -> bool: ... def __sql__(self, ctx: Context) -> Context: ... @classmethod def bind( - cls, - database: Database, - bind_refs: bool = ..., - bind_backrefs: bool = ..., - _exclude: Optional[MutableSet[Type[Model]]] = ..., + cls, database: Database, bind_refs: bool = ..., bind_backrefs: bool = ..., _exclude: MutableSet[type[Model]] | None = ... ) -> bool: ... @classmethod def bind_ctx(cls, database: Database, bind_refs: bool = ..., bind_backrefs: bool = ...) -> _BoundModelsContext: ... @@ -1898,30 +1952,29 @@ class Model(Node, metaclass=ModelBase): @classmethod def truncate_table(cls, **options: object) -> None: ... @classmethod - def index(cls, *fields: Union[Field, Node, str], **kwargs: object) -> ModelIndex: ... + def index(cls, *fields: Field | Node | str, **kwargs: object) -> ModelIndex: ... @classmethod - def add_index(cls, *fields: Union[str, SQL, Index], **kwargs: object) -> None: ... + def add_index(cls, *fields: str | SQL | Index, **kwargs: object) -> None: ... +# "Provide a separate reference to a model in a query." class ModelAlias(Node, Generic[_TModel]): - """Provide a separate reference to a model in a query.""" - - model: Type[_TModel] - alias: Optional[str] - def __init__(self, model: Type[_TModel], alias: Optional[str] = ...): ... + model: type[_TModel] + alias: str | None + def __init__(self, model: type[_TModel], alias: str | None = ...): ... def __getattr__(self, attr: str) -> Any: ... def __setattr__(self, attr: str, value: object) -> NoReturn: ... - def get_field_aliases(self) -> List[Field]: ... + def get_field_aliases(self) -> list[Field]: ... def select(self, *selection: Field) -> ModelSelect: ... def __call__(self, **kwargs) -> _TModel: ... def __sql__(self, ctx: Context) -> Context: ... -_TModelOrTable = Union[Type[Model], ModelAlias, Table] -_TSubquery = Union[Tuple[Query, Type[Model]], Type[Model], ModelAlias] -_TFieldOrModel = Union[_TModelOrTable, Field] +_TModelOrTable: TypeAlias = type[Model] | ModelAlias | Table +_TSubquery: TypeAlias = Union[tuple[Query, type[Model]], type[Model], ModelAlias] +_TFieldOrModel: TypeAlias = _TModelOrTable | Field class FieldAlias(Field): source: Node - model: Type[Model] + model: type[Model] field: Field # TODO (dargueta): Making an educated guess about `source`; might be `Node` def __init__(self, source: MetaField, field: Field): ... @@ -1937,15 +1990,15 @@ class FieldAlias(Field): def __getattr__(self, attr: str) -> Any: ... def __sql__(self, ctx: Context) -> Context: ... -def sort_models(models: Iterable[Type[Model]]) -> List[Type[Model]]: ... +def sort_models(models: Iterable[type[Model]]) -> list[type[Model]]: ... class _ModelQueryHelper: default_row_type: ClassVar[int] - def objects(self, constructor: Optional[Callable[..., Any]] = ...) -> _ModelQueryHelper: ... + def objects(self, constructor: Callable[..., Any] | None = ...) -> _ModelQueryHelper: ... class ModelRaw(_ModelQueryHelper, RawQuery, Generic[_TModel]): - model: Type[_TModel] - def __init__(self, model: Type[_TModel], sql: str, params: tuple, **kwargs: object): ... + model: type[_TModel] + def __init__(self, model: type[_TModel], sql: str, params: tuple, **kwargs: object): ... def get(self) -> _TModel: ... class BaseModelSelect(_ModelQueryHelper): @@ -1958,41 +2011,41 @@ class BaseModelSelect(_ModelQueryHelper): def except_(self, rhs: object) -> ModelCompoundSelectQuery: ... __sub__ = except_ def __iter__(self) -> Iterator[Any]: ... - def prefetch(self, *subqueries: _TSubquery) -> List[Any]: ... - def get(self, database: Optional[Database] = ...) -> Any: ... - def group_by(self, *columns: Union[Type[Model], Table, Field]) -> BaseModelSelect: ... + def prefetch(self, *subqueries: _TSubquery) -> list[Any]: ... + def get(self, database: Database | None = ...) -> Any: ... + def group_by(self, *columns: type[Model] | Table | Field) -> BaseModelSelect: ... class ModelCompoundSelectQuery(BaseModelSelect, CompoundSelectQuery, Generic[_TModel]): - model: Type[_TModel] - def __init__(self, model: Type[_TModel], *args: object, **kwargs: object): ... + model: type[_TModel] + def __init__(self, model: type[_TModel], *args: object, **kwargs: object): ... class ModelSelect(BaseModelSelect, Select, Generic[_TModel]): - model: Type[_TModel] - def __init__(self, model: Type[_TModel], fields_or_models: Iterable[_TFieldOrModel], is_default: bool = ...): ... + model: type[_TModel] + def __init__(self, model: type[_TModel], fields_or_models: Iterable[_TFieldOrModel], is_default: bool = ...): ... def clone(self) -> ModelSelect: ... def select(self, *fields_or_models: _TFieldOrModel) -> ModelSelect: ... # type: ignore - def switch(self, ctx: Optional[Type[Model]] = ...) -> ModelSelect: ... + def switch(self, ctx: type[Model] | None = ...) -> ModelSelect: ... def join( # type: ignore self, - dest: Union[Type[Model], Table, ModelAlias, ModelSelect], + dest: type[Model] | Table | ModelAlias | ModelSelect, join_type: int = ..., - on: Optional[Union[Column, Expression, Field]] = ..., - src: Optional[Union[Type[Model], Table, ModelAlias, ModelSelect]] = ..., - attr: Optional[str] = ..., + on: Column | Expression | Field | None = ..., + src: type[Model] | Table | ModelAlias | ModelSelect | None = ..., + attr: str | None = ..., ) -> ModelSelect: ... def join_from( self, - src: Union[Type[Model], Table, ModelAlias, ModelSelect], - dest: Union[Type[Model], Table, ModelAlias, ModelSelect], + src: type[Model] | Table | ModelAlias | ModelSelect, + dest: type[Model] | Table | ModelAlias | ModelSelect, join_type: int = ..., - on: Optional[Union[Column, Expression, Field]] = ..., - attr: Optional[str] = ..., + on: Column | Expression | Field | None = ..., + attr: str | None = ..., ) -> ModelSelect: ... def ensure_join( - self, lm: Type[Model], rm: Type[Model], on: Optional[Union[Column, Expression, Field]] = ..., **join_kwargs: Any + self, lm: type[Model], rm: type[Model], on: Column | Expression | Field | None = ..., **join_kwargs: Any ) -> ModelSelect: ... # TODO (dargueta): 85% sure about the return value - def convert_dict_to_node(self, qdict: Mapping[str, object]) -> Tuple[List[Expression], List[Field]]: ... + def convert_dict_to_node(self, qdict: Mapping[str, object]) -> tuple[list[Expression], list[Field]]: ... def filter(self, *args: Node, **kwargs: object) -> ModelSelect: ... def create_table(self, name: str, safe: bool = ..., **meta: object) -> None: ... def __sql_selection__(self, ctx: Context, is_subquery: bool = ...) -> Context: ... @@ -2001,15 +2054,15 @@ class NoopModelSelect(ModelSelect): def __sql__(self, ctx: Context) -> Context: ... class _ModelWriteQueryHelper(_ModelQueryHelper): - model: Type[Model] - def __init__(self, model: Type[Model], *args: object, **kwargs: object): ... - def returning(self, *returning: Union[Type[Model], Field]) -> _ModelWriteQueryHelper: ... + model: type[Model] + def __init__(self, model: type[Model], *args: object, **kwargs: object): ... + def returning(self, *returning: type[Model] | Field) -> _ModelWriteQueryHelper: ... class ModelUpdate(_ModelWriteQueryHelper, Update): ... class ModelInsert(_ModelWriteQueryHelper, Insert): default_row_type: ClassVar[int] - def returning(self, *returning: Union[Type[Model], Field]) -> ModelInsert: ... + def returning(self, *returning: type[Model] | Field) -> ModelInsert: ... def get_default_data(self) -> Mapping[str, object]: ... def get_default_columns(self) -> Sequence[Field]: ... @@ -2019,22 +2072,22 @@ class ManyToManyQuery(ModelSelect): def __init__( self, instance: Model, accessor: ManyToManyFieldAccessor, rel: _TFieldOrModel, *args: object, **kwargs: object ): ... - def add(self, value: Union[SelectQuery, Type[Model], Iterable[str]], clear_existing: bool = ...) -> None: ... - def remove(self, value: Union[SelectQuery, Type[Model], Iterable[str]]) -> Optional[int]: ... + def add(self, value: SelectQuery | type[Model] | Iterable[str], clear_existing: bool = ...) -> None: ... + def remove(self, value: SelectQuery | type[Model] | Iterable[str]) -> int | None: ... def clear(self) -> int: ... class BaseModelCursorWrapper(DictCursorWrapper, Generic[_TModel]): ncols: int - columns: List[str] - converters: List[_TConvFunc] - fields: List[Field] - model: Type[_TModel] + columns: list[str] + converters: list[_TConvFunc] + fields: list[Field] + model: type[_TModel] select: Sequence[str] - def __init__(self, cursor: __ICursor, model: Type[_TModel], columns: Optional[Sequence[str]]): ... + def __init__(self, cursor: __ICursor, model: type[_TModel], columns: Sequence[str] | None): ... def process_row(self, row: tuple) -> Mapping[str, object]: ... # type: ignore class ModelDictCursorWrapper(BaseModelCursorWrapper[_TModel]): - def process_row(self, row: tuple) -> Dict[str, Any]: ... + def process_row(self, row: tuple) -> dict[str, Any]: ... class ModelTupleCursorWrapper(ModelDictCursorWrapper[_TModel]): constructor: ClassVar[Callable[[Sequence[Any]], tuple]] @@ -2043,7 +2096,7 @@ class ModelTupleCursorWrapper(ModelDictCursorWrapper[_TModel]): class ModelNamedTupleCursorWrapper(ModelTupleCursorWrapper[_TModel]): ... class ModelObjectCursorWrapper(ModelDictCursorWrapper[_TModel]): - constructor: Union[Type[_TModel], Callable[[Any], _TModel]] + constructor: type[_TModel] | Callable[[Any], _TModel] is_model: bool # TODO (dargueta): `select` is some kind of Sequence def __init__( @@ -2051,41 +2104,41 @@ class ModelObjectCursorWrapper(ModelDictCursorWrapper[_TModel]): cursor: __ICursor, model: _TModel, select: Sequence[Any], # incomplete - constructor: Union[Type[_TModel], Callable[[Any], _TModel]], + constructor: type[_TModel] | Callable[[Any], _TModel], ): ... def process_row(self, row: tuple) -> _TModel: ... # type: ignore class ModelCursorWrapper(BaseModelCursorWrapper[_TModel]): # TODO (dargueta) -- Iterable[Union[Join, ...]] from_list: Iterable[Any] # incomplete - # TODO (dargueta) -- Mapping[, Tuple[?, ?, Callable[..., _TModel], int?]] - joins: Mapping[Hashable, Tuple[object, object, Callable[..., _TModel], int]] - key_to_constructor: Dict[Type[_TModel], Callable[..., _TModel]] - src_is_dest: Dict[Type[Model], bool] - src_to_dest: List[tuple] # TODO -- Tuple[, join_type[1], join_type[0], bool, join_type[3]] - column_keys: List # incomplete + # TODO (dargueta) -- Mapping[, tuple[?, ?, Callable[..., _TModel], int?]] + joins: Mapping[Hashable, tuple[object, object, Callable[..., _TModel], int]] + key_to_constructor: dict[type[_TModel], Callable[..., _TModel]] + src_is_dest: dict[type[Model], bool] + src_to_dest: list[tuple] # TODO -- tuple[, join_type[1], join_type[0], bool, join_type[3]] + column_keys: list # incomplete def __init__( self, cursor: __ICursor, - model: Type[_TModel], + model: type[_TModel], select, from_list: Iterable[object], - joins: Mapping[Hashable, Tuple[object, object, Callable[..., _TModel], int]], + joins: Mapping[Hashable, tuple[object, object, Callable[..., _TModel], int]], ): ... def initialize(self) -> None: ... def process_row(self, row: tuple) -> _TModel: ... # type: ignore class __PrefetchQuery(NamedTuple): query: Query # TODO (dargueta): Verify - fields: Optional[Sequence[Field]] - is_backref: Optional[bool] - rel_models: Optional[List[Type[Model]]] - field_to_name: Optional[List[Tuple[Field, str]]] - model: Type[Model] + fields: Sequence[Field] | None + is_backref: bool | None + rel_models: list[type[Model]] | None + field_to_name: list[tuple[Field, str]] | None + model: type[Model] class PrefetchQuery(__PrefetchQuery): - def populate_instance(self, instance: Model, id_map: Mapping[Tuple[object, object], object]): ... - def store_instance(self, instance: Model, id_map: MutableMapping[Tuple[object, object], List[Model]]) -> None: ... + def populate_instance(self, instance: Model, id_map: Mapping[tuple[object, object], object]): ... + def store_instance(self, instance: Model, id_map: MutableMapping[tuple[object, object], list[Model]]) -> None: ... -def prefetch_add_subquery(sq: Query, subqueries: Iterable[_TSubquery]) -> List[PrefetchQuery]: ... -def prefetch(sq: Query, *subqueries: _TSubquery) -> List[object]: ... +def prefetch_add_subquery(sq: Query, subqueries: Iterable[_TSubquery]) -> list[PrefetchQuery]: ... +def prefetch(sq: Query, *subqueries: _TSubquery) -> list[object]: ... From d2a95564fdbb3ac4a1fb34ac391f527148641dc8 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Mon, 18 Apr 2022 21:39:57 +0100 Subject: [PATCH 10/22] Add to stricter-pyright ignorelist --- pyrightconfig.stricter.json | 1 + 1 file changed, 1 insertion(+) diff --git a/pyrightconfig.stricter.json b/pyrightconfig.stricter.json index 55eea79bf374..5ac27e5b5ff2 100644 --- a/pyrightconfig.stricter.json +++ b/pyrightconfig.stricter.json @@ -46,6 +46,7 @@ "stubs/openpyxl", "stubs/Pillow", "stubs/paramiko", + "stubs/peewee", "stubs/prettytable", "stubs/protobuf", "stubs/google-cloud-ndb", From 8adc6b0f18c8780784d416b07e75265934bf18a6 Mon Sep 17 00:00:00 2001 From: Akuli Date: Mon, 25 Apr 2022 17:10:45 +0300 Subject: [PATCH 11/22] Fix a few obvious-ish things --- stubs/peewee/peewee.pyi | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/stubs/peewee/peewee.pyi b/stubs/peewee/peewee.pyi index 937e16545b55..db3a1ae5b8f3 100644 --- a/stubs/peewee/peewee.pyi +++ b/stubs/peewee/peewee.pyi @@ -1,7 +1,6 @@ import datetime import decimal import enum -import re import threading import uuid from _typeshed import Self @@ -21,6 +20,7 @@ from typing import ( MutableSet, NamedTuple, NoReturn, + Pattern, Protocol, Sequence, Text, @@ -180,8 +180,8 @@ CSQ_PARENTHESES_NEVER: int CSQ_PARENTHESES_ALWAYS: int CSQ_PARENTHESES_UNNESTED: int -SNAKE_CASE_STEP1: re.Pattern -SNAKE_CASE_STEP2: re.Pattern +SNAKE_CASE_STEP1: Pattern +SNAKE_CASE_STEP2: Pattern MODEL_BASE: str @@ -349,7 +349,8 @@ class _BoundTableContext(_callable_context_manager): class Table(_HashableSource, BaseTable): __name__: str - c: _ExplicitColumn + # mypy doesn't seem to understand very well how descriptors work + c: _ExplicitColumn # type: ignore[misc, assignment] primary_key: Field | CompositeKey | None def __init__( self, @@ -440,8 +441,8 @@ class ColumnBase(Node): def __rand__(self, other: object) -> Expression: ... def __ror__(self, other: object) -> Expression: ... def __rxor__(self, other: object) -> Expression: ... - def __eq__(self, rhs: object) -> Expression: ... - def __ne__(self, rhs: object) -> Expression: ... + def __eq__(self, rhs: object) -> Expression: ... # type: ignore[override] + def __ne__(self, rhs: object) -> Expression: ... # type: ignore[override] def __lt__(self, other: object) -> Expression: ... def __le__(self, other: object) -> Expression: ... def __gt__(self, other: object) -> Expression: ... @@ -538,7 +539,7 @@ class Ordering(WrappedNode): collation: str | None nulls: str | None def __init__(self, node: Node, direction: str, collation: str | None = ..., nulls: str | None = ...): ... - def collate(self, collation: str | None = ...) -> Ordering: ... + def collate(self, collation: str | None = ...) -> Ordering: ... # type: ignore[override] def __sql__(self, ctx: Context) -> Context: ... class _SupportsSQLOrdering(Protocol): @@ -1038,7 +1039,7 @@ class Database(_callable_context_manager): autoconnect: bool = ..., **kwargs: object, ): ... - def init(self, database: __IConnection, **kwargs: object) -> None: ... + def init(self, database: __IConnection, **kwargs: Any) -> None: ... def __enter__(self: Self) -> Self: ... def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None: ... def connection_context(self) -> ConnectionContext: ... @@ -1184,7 +1185,7 @@ class PostgresqlDatabase(Database): sequences: ClassVar[bool] # Instance variables server_version: int - def init( + def init( # type: ignore[override] self, database: __IConnection, register_unicode: bool = ..., From 12f0e6f62f5e11a3066836ee3585b5b98a0500cd Mon Sep 17 00:00:00 2001 From: Akuli Date: Mon, 25 Apr 2022 17:46:28 +0300 Subject: [PATCH 12/22] use _typeshed.Self when appropriate --- stubs/peewee/peewee.pyi | 148 ++++++++++++++++++++-------------------- 1 file changed, 74 insertions(+), 74 deletions(-) diff --git a/stubs/peewee/peewee.pyi b/stubs/peewee/peewee.pyi index db3a1ae5b8f3..09802edfb78c 100644 --- a/stubs/peewee/peewee.pyi +++ b/stubs/peewee/peewee.pyi @@ -246,7 +246,7 @@ class Context: def parentheses(self) -> bool: ... @property def subquery(self) -> Any: ... # TODO (dargueta): Figure out type of "self.state.subquery" - def __call__(self, **overrides: object) -> Context: ... + def __call__(self: Self, **overrides: object) -> Self: ... def scope_normal(self) -> ContextManager[Context]: ... def scope_source(self) -> ContextManager[Context]: ... def scope_values(self) -> ContextManager[Context]: ... @@ -258,7 +258,7 @@ class Context: def push_alias(self) -> Iterator[None]: ... # TODO (dargueta): Is this right? def sql(self, obj: object) -> Context: ... - def literal(self, keyword: str) -> Context: ... + def literal(self: Self, keyword: str) -> Self: ... def value(self, value: object, converter: _TConvFunc | None = ..., add_param: bool = ...) -> Context: ... def __sql__(self, ctx: Context) -> Context: ... def parse(self, node: Node) -> tuple[str, tuple | None]: ... @@ -272,9 +272,9 @@ class Node: # FIXME (dargueta): Is there a way to make this a proper decorator? @staticmethod def copy(method: _TFunc) -> _TFunc: ... - def coerce(self, _coerce: bool = ...) -> Node: ... + def coerce(self: Self, _coerce: bool = ...) -> Self: ... def is_alias(self) -> bool: ... - def unwrap(self) -> Node: ... + def unwrap(self: Self) -> Self: ... class ColumnFactory: node: Node @@ -362,8 +362,8 @@ class Table(_HashableSource, BaseTable): _model: type[Model] | None = ..., _database: Database | None = ..., ): ... - def clone(self) -> Table: ... - def bind(self, database: Database | None = ...) -> Table: ... + def clone(self: Self) -> Self: ... + def bind(self: Self, database: Database | None = ...) -> Self: ... def bind_ctx(self, database: Database | None = ...) -> _BoundTableContext: ... def select(self, *columns: Column) -> Select: ... @overload @@ -383,13 +383,13 @@ class Join(BaseTable): rhs: Any # TODO (dargueta) join_type: int def __init__(self, lhs, rhs, join_type: int = ..., on: Expression | None = ..., alias: str | None = ...): ... - def on(self, predicate: Expression) -> Join: ... + def on(self: Self, predicate: Expression) -> Self: ... def __sql__(self, ctx: Context) -> Context: ... class ValuesList(_HashableSource, BaseTable): def __init__(self, values, columns=..., alias: str | None = ...): ... # incomplete # FIXME (dargueta) `names` might be wrong - def columns(self, *names: str) -> ValuesList: ... + def columns(self: Self, *names: str) -> Self: ... def __sql__(self, ctx: Context) -> Context: ... class CTE(_HashableSource, Source): @@ -412,12 +412,12 @@ class CTE(_HashableSource, Source): class ColumnBase(Node): _converter: _TConvFunc | None - def converter(self, converter: _TConvFunc | None = ...) -> ColumnBase: ... + def converter(self: Self, converter: _TConvFunc | None = ...) -> Self: ... @overload - def alias(self, alias: None) -> ColumnBase: ... + def alias(self: Self, alias: None) -> Self: ... @overload def alias(self, alias: str) -> Alias: ... - def unalias(self) -> ColumnBase: ... + def unalias(self: Self) -> Self: ... def cast(self, as_type: str) -> Cast: ... def asc(self, collation: str | None = ..., nulls: str | None = ...) -> _SupportsSQLOrdering: ... __pos__ = asc @@ -504,7 +504,7 @@ class Alias(WrappedNode): @overload def alias(self, alias: str) -> Alias: ... def unalias(self) -> Node: ... - def is_alias(self) -> bool: ... + def is_alias(self) -> Literal[True]: ... def __sql__(self, ctx: Context) -> Context: ... class Negated(WrappedNode): @@ -581,9 +581,9 @@ class Function(ColumnBase): def __init__(self, name: str, arguments: tuple, coerce: bool = ..., python_value: _TConvFunc | None = ...): ... def __getattr__(self, attr: str) -> Callable[..., Function]: ... # TODO (dargueta): `where` is an educated guess - def filter(self, where: Expression | None = ...) -> Function: ... - def order_by(self, *ordering: Field | Expression) -> Function: ... - def python_value(self, func: _TConvFunc | None = ...) -> Function: ... + def filter(self: Self, where: Expression | None = ...) -> Self: ... + def order_by(self: Self, *ordering: Field | Expression) -> Self: ... + def python_value(self: Self, func: _TConvFunc | None = ...) -> Self: ... def over( self, partition_by: Sequence[Field] | Window | None = ..., @@ -638,12 +638,12 @@ class Window(Node): alias: str | None = ..., _inline: bool = ..., ): ... - def alias(self, alias: str | None = ...) -> Window: ... - def as_range(self) -> Window: ... - def as_rows(self) -> Window: ... - def as_groups(self) -> Window: ... - def extends(self, window: Window | WindowAlias | str | None = ...) -> Window: ... - def exclude(self, frame_exclusion: str | SQL | None = ...) -> Window: ... + def alias(self: Self, alias: str | None = ...) -> Self: ... + def as_range(self: Self) -> Self: ... + def as_rows(self: Self) -> Self: ... + def as_groups(self: Self) -> Self: ... + def extends(self: Self, window: Window | WindowAlias | str | None = ...) -> Self: ... + def exclude(self: Self, frame_exclusion: str | SQL | None = ...) -> Self: ... @staticmethod def following(value: int | None = ...) -> SQL: ... @staticmethod @@ -653,7 +653,7 @@ class Window(Node): class WindowAlias(Node): window: Window def __init__(self, window: Window): ... - def alias(self, window_alias: str) -> WindowAlias: ... + def alias(self: Self, window_alias: str) -> Self: ... def __sql__(self, ctx: Context) -> Context: ... class ForUpdate(Node): @@ -694,8 +694,8 @@ class DQ(ColumnBase): # TODO (dargueta): Narrow this down? def __init__(self, **query: object): ... - def __invert__(self) -> DQ: ... - def clone(self) -> DQ: ... + def __invert__(self: Self) -> Self: ... + def clone(self: Self) -> Self: ... class QualifiedNames(WrappedNode): def __sql__(self, ctx: Context) -> Context: ... @@ -733,23 +733,23 @@ class OnConflict(Node): # undocumented def get_conflict_statement(self, ctx: Context, query: Query) -> SQL | None: ... def get_conflict_update(self, ctx: Context, query: Query) -> NodeList: ... - def preserve(self, *columns: Column) -> OnConflict: ... + def preserve(self: Self, *columns: Column) -> Self: ... # Despite the argument name `_data` is documented - def update(self, _data: Mapping[str, object] | None = ..., **kwargs: object) -> OnConflict: ... - def where(self, *expressions: Expression) -> OnConflict: ... - def conflict_target(self, *constraints: Column) -> OnConflict: ... - def conflict_where(self, *expressions: Expression) -> OnConflict: ... - def conflict_constraint(self, constraint: str) -> OnConflict: ... + def update(self: Self, _data: Mapping[str, object] | None = ..., **kwargs: object) -> Self: ... + def where(self: Self, *expressions: Expression) -> Self: ... + def conflict_target(self: Self, *constraints: Column) -> Self: ... + def conflict_where(self: Self, *expressions: Expression) -> Self: ... + def conflict_constraint(self: Self, constraint: str) -> Self: ... class BaseQuery(Node): default_row_type: ClassVar[int] def __init__(self, _database: Database | None = ..., **kwargs: object): ... - def bind(self, database: Database | None = ...) -> BaseQuery: ... - def clone(self) -> BaseQuery: ... - def dicts(self, as_dict: bool = ...) -> BaseQuery: ... - def tuples(self, as_tuple: bool = ...) -> BaseQuery: ... - def namedtuples(self, as_namedtuple: bool = ...) -> BaseQuery: ... - def objects(self, constructor: _TConvFunc | None = ...) -> BaseQuery: ... + def bind(self: Self, database: Database | None = ...) -> Self: ... + def clone(self: Self) -> Self: ... + def dicts(self: Self, as_dict: bool = ...) -> Self: ... + def tuples(self: Self, as_tuple: bool = ...) -> Self: ... + def namedtuples(self: Self, as_namedtuple: bool = ...) -> Self: ... + def objects(self: Self, constructor: _TConvFunc | None = ...) -> Self: ... def __sql__(self, ctx: Context) -> Context: ... def sql(self) -> tuple[str, tuple | None]: ... def execute(self, database: Database | None = ...) -> CursorWrapper: ... @@ -777,14 +777,14 @@ class Query(BaseQuery): offset: int | None = ..., **kwargs: object, ): ... - def with_cte(self, *cte_list: CTE) -> Query: ... - def where(self, *expressions: Expression) -> Query: ... - def orwhere(self, *expressions: Expression) -> Query: ... - def order_by(self, *values: Node) -> Query: ... - def order_by_extend(self, *values: Node) -> Query: ... - def limit(self, value: int | None = ...) -> Query: ... - def offset(self, value: int | None = ...) -> Query: ... - def paginate(self, page: int, paginate_by: int = ...) -> Query: ... + def with_cte(self: Self, *cte_list: CTE) -> Self: ... + def where(self: Self, *expressions: Expression) -> Self: ... + def orwhere(self: Self, *expressions: Expression) -> Self: ... + def order_by(self: Self, *values: Node) -> Self: ... + def order_by_extend(self: Self, *values: Node) -> Self: ... + def limit(self: Self, value: int | None = ...) -> Self: ... + def offset(self: Self, value: int | None = ...) -> Self: ... + def paginate(self: Self, page: int, paginate_by: int = ...) -> Self: ... def _apply_ordering(self, ctx: Context) -> Context: ... def __sql__(self, ctx: Context) -> Context: ... @@ -849,31 +849,31 @@ class Select(SelectBase): lateral: bool | None = ..., # undocumented **kwargs: object, ): ... - def clone(self) -> Select: ... + def clone(self: Self) -> Self: ... # TODO (dargueta) `Field` might be wrong in this union - def columns(self, *columns: Column | Field, **kwargs: object) -> Select: ... - def select(self, *columns: Column | Field, **kwargs: object) -> Select: ... - def select_extend(self, *columns) -> Select: ... + def columns(self: Self, *columns: Column | Field, **kwargs: object) -> Self: ... + select = columns + def select_extend(self: Self, *columns) -> Self: ... # TODO (dargueta): Is `sources` right? - def from_(self, *sources: Source | type[Model]) -> Select: ... - def join(self, dest: type[Model], join_type: int = ..., on: Expression | None = ...) -> Select: ... - def group_by(self, *columns: Table | Field) -> Select: ... - def group_by_extend(self, *values: Table | Field) -> Select: ... - def having(self, *expressions: Expression) -> Select: ... + def from_(self: Self, *sources: Source | type[Model]) -> Self: ... + def join(self: Self, dest: type[Model], join_type: int = ..., on: Expression | None = ...) -> Self: ... + def group_by(self: Self, *columns: Table | Field) -> Self: ... + def group_by_extend(self: Self, *values: Table | Field) -> Self: ... + def having(self: Self, *expressions: Expression) -> Self: ... @overload - def distinct(self, _: bool) -> Select: ... + def distinct(self: Self, _: bool) -> Self: ... @overload - def distinct(self, *columns: Field) -> Select: ... - def window(self, *windows: Window) -> Select: ... + def distinct(self: Self, *columns: Field) -> Self: ... + def window(self: Self, *windows: Window) -> Self: ... def for_update( - self, for_update: bool = ..., of: Table | Iterable[Table] | None = ..., nowait: bool | None = ... - ) -> Select: ... - def lateral(self, lateral: bool = ...) -> Select: ... + self: Self, for_update: bool = ..., of: Table | Iterable[Table] | None = ..., nowait: bool | None = ... + ) -> Self: ... + def lateral(self: Self, lateral: bool = ...) -> Self: ... class _WriteQuery(Query): table: Table def __init__(self, table: Table, returning: Iterable[type[Model] | Field] | None = ..., **kwargs: object): ... - def returning(self, *returning: type[Model] | Field) -> _WriteQuery: ... + def returning(self: Self, *returning: type[Model] | Field) -> Self: ... def apply_returning(self, ctx: Context) -> Context: ... def execute_returning(self, database: Database) -> CursorWrapper: ... def handle_result(self, database: Database, cursor: __ICursor) -> int | __ICursor: ... @@ -882,7 +882,7 @@ class _WriteQuery(Query): class Update(_WriteQuery): # TODO (dargueta): `update` def __init__(self, table: Table, update: Any | None = ..., **kwargs: object): ... - def from_(self, *sources) -> Update: ... + def from_(self: Self, *sources) -> Self: ... def __sql__(self, ctx: Context) -> Context: ... class Insert(_WriteQuery): @@ -899,9 +899,9 @@ class Insert(_WriteQuery): **kwargs: object, ): ... def where(self, *expressions: Expression) -> NoReturn: ... - def on_conflict_ignore(self, ignore: bool = ...) -> Insert: ... - def on_conflict_replace(self, replace: bool = ...) -> Insert: ... - def on_conflict(self, *args, **kwargs) -> Insert: ... + def on_conflict_ignore(self: Self, ignore: bool = ...) -> Self: ... + def on_conflict_replace(self: Self, replace: bool = ...) -> Self: ... + def on_conflict(self: Self, *args, **kwargs) -> Self: ... def get_default_data(self) -> Mapping[str, object]: ... def get_default_columns(self) -> list[Field] | None: ... def __sql__(self, ctx: Context) -> Context: ... @@ -921,9 +921,9 @@ class Index(Node): where: Expression | None = ..., using: str | None = ..., ): ... - def safe(self, _safe: bool = ...) -> Index: ... - def where(self, *expressions: Expression) -> Index: ... - def using(self, _using: str | None = ...) -> Index: ... + def safe(self: Self, _safe: bool = ...) -> Self: ... + def where(self: Self, *expressions: Expression) -> Self: ... + def using(self: Self, _using: str | None = ...) -> Self: ... def __sql__(self, ctx: Context) -> Context: ... class ModelIndex(Index): @@ -1995,7 +1995,7 @@ def sort_models(models: Iterable[type[Model]]) -> list[type[Model]]: ... class _ModelQueryHelper: default_row_type: ClassVar[int] - def objects(self, constructor: Callable[..., Any] | None = ...) -> _ModelQueryHelper: ... + def objects(self: Self, constructor: Callable[..., Any] | None = ...) -> Self: ... class ModelRaw(_ModelQueryHelper, RawQuery, Generic[_TModel]): model: type[_TModel] @@ -2014,7 +2014,7 @@ class BaseModelSelect(_ModelQueryHelper): def __iter__(self) -> Iterator[Any]: ... def prefetch(self, *subqueries: _TSubquery) -> list[Any]: ... def get(self, database: Database | None = ...) -> Any: ... - def group_by(self, *columns: type[Model] | Table | Field) -> BaseModelSelect: ... + def group_by(self: Self, *columns: type[Model] | Table | Field) -> Self: ... class ModelCompoundSelectQuery(BaseModelSelect, CompoundSelectQuery, Generic[_TModel]): model: type[_TModel] @@ -2023,17 +2023,17 @@ class ModelCompoundSelectQuery(BaseModelSelect, CompoundSelectQuery, Generic[_TM class ModelSelect(BaseModelSelect, Select, Generic[_TModel]): model: type[_TModel] def __init__(self, model: type[_TModel], fields_or_models: Iterable[_TFieldOrModel], is_default: bool = ...): ... - def clone(self) -> ModelSelect: ... - def select(self, *fields_or_models: _TFieldOrModel) -> ModelSelect: ... # type: ignore - def switch(self, ctx: type[Model] | None = ...) -> ModelSelect: ... + def clone(self: Self) -> Self: ... + def select(self: Self, *fields_or_models: _TFieldOrModel) -> Self: ... # type: ignore + def switch(self: Self, ctx: type[Model] | None = ...) -> Self: ... def join( # type: ignore - self, + self: Self, dest: type[Model] | Table | ModelAlias | ModelSelect, join_type: int = ..., on: Column | Expression | Field | None = ..., src: type[Model] | Table | ModelAlias | ModelSelect | None = ..., attr: str | None = ..., - ) -> ModelSelect: ... + ) -> Self: ... def join_from( self, src: type[Model] | Table | ModelAlias | ModelSelect, From 9d44ac6c150d31e2a57cead2f92210c0043f6a3d Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Thu, 19 May 2022 11:20:32 +0100 Subject: [PATCH 13/22] More lint fixes --- stubs/peewee/METADATA.toml | 1 - stubs/peewee/peewee.pyi | 39 +++++++++----------------------------- 2 files changed, 9 insertions(+), 31 deletions(-) diff --git a/stubs/peewee/METADATA.toml b/stubs/peewee/METADATA.toml index 2ce4d86a8441..719172f4da80 100644 --- a/stubs/peewee/METADATA.toml +++ b/stubs/peewee/METADATA.toml @@ -1,2 +1 @@ version = "3.14.10.*" -python2 = true diff --git a/stubs/peewee/peewee.pyi b/stubs/peewee/peewee.pyi index 09802edfb78c..5fb4c698c54f 100644 --- a/stubs/peewee/peewee.pyi +++ b/stubs/peewee/peewee.pyi @@ -4,30 +4,9 @@ import enum import threading import uuid from _typeshed import Self -from typing import ( - Any, - AnyStr, - Callable, - ClassVar, - Container, - ContextManager, - Generic, - Hashable, - Iterable, - Iterator, - Mapping, - MutableMapping, - MutableSet, - NamedTuple, - NoReturn, - Pattern, - Protocol, - Sequence, - Text, - TypeVar, - Union, - overload, -) +from collections.abc import Callable, Container, Hashable, Iterable, Iterator, Mapping, MutableMapping, MutableSet, Sequence +from contextlib import AbstractContextManager +from typing import Any, AnyStr, ClassVar, Generic, NamedTuple, NoReturn, Pattern, Protocol, TypeVar, Union, overload from typing_extensions import Literal, TypeAlias _T = TypeVar("_T") @@ -247,11 +226,11 @@ class Context: @property def subquery(self) -> Any: ... # TODO (dargueta): Figure out type of "self.state.subquery" def __call__(self: Self, **overrides: object) -> Self: ... - def scope_normal(self) -> ContextManager[Context]: ... - def scope_source(self) -> ContextManager[Context]: ... - def scope_values(self) -> ContextManager[Context]: ... - def scope_cte(self) -> ContextManager[Context]: ... - def scope_column(self) -> ContextManager[Context]: ... + def scope_normal(self) -> AbstractContextManager[Context]: ... + def scope_source(self) -> AbstractContextManager[Context]: ... + def scope_values(self) -> AbstractContextManager[Context]: ... + def scope_cte(self) -> AbstractContextManager[Context]: ... + def scope_column(self) -> AbstractContextManager[Context]: ... def __enter__(self: Self) -> Self: ... def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None: ... # @contextmanager @@ -1892,7 +1871,7 @@ class Model(Node, metaclass=ModelBase): @classmethod def insert_many(cls, rows: Iterable[tuple], fields: Sequence[Field]) -> ModelInsert: ... @classmethod - def insert_from(cls, query: SelectQuery, fields: Iterable[Field | Text]) -> ModelInsert: ... + def insert_from(cls, query: SelectQuery, fields: Iterable[Field | str]) -> ModelInsert: ... @classmethod def replace(cls, __data: Iterable[str | Field] | None = ..., **insert: object) -> OnConflict: ... @classmethod From 9a7c5f967207e26d86f82354600c696c24524709 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Wed, 15 Jun 2022 18:12:09 +0100 Subject: [PATCH 14/22] Swap a few 'incomplete' comments for `_typeshed.Incomplete` --- stubs/peewee/peewee.pyi | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/stubs/peewee/peewee.pyi b/stubs/peewee/peewee.pyi index 5fb4c698c54f..88c809f751d8 100644 --- a/stubs/peewee/peewee.pyi +++ b/stubs/peewee/peewee.pyi @@ -3,7 +3,7 @@ import decimal import enum import threading import uuid -from _typeshed import Self +from _typeshed import Incomplete, Self from collections.abc import Callable, Container, Hashable, Iterable, Iterator, Mapping, MutableMapping, MutableSet, Sequence from contextlib import AbstractContextManager from typing import Any, AnyStr, ClassVar, Generic, NamedTuple, NoReturn, Pattern, Protocol, TypeVar, Union, overload @@ -590,7 +590,7 @@ class Window(Node): order_by: tuple[Field | Expression, ...] start: str | SQL | None end: str | SQL | None - frame_type: Any | None # incomplete + frame_type: Incomplete | None @overload def __init__( self, @@ -648,14 +648,14 @@ def Case(predicate: Node | None, expression_tuples: Iterable[tuple[Expression, A class NodeList(ColumnBase): # TODO (dargueta): Narrow this type - nodes: Sequence[Any] # incomplete + nodes: Sequence[Incomplete] glue: str parens: bool - def __init__(self, nodes: Sequence[Any], glue: str = ..., parens: bool = ...): ... # incomplete + def __init__(self, nodes: Sequence[Incomplete], glue: str = ..., parens: bool = ...): ... def __sql__(self, ctx: Context) -> Context: ... -def CommaNodeList(nodes: Sequence[Any]) -> NodeList: ... # incomplete -def EnclosedNodeList(nodes: Sequence[Any]) -> NodeList: ... # incomplete +def CommaNodeList(nodes: Sequence[Incomplete]) -> NodeList: ... +def EnclosedNodeList(nodes: Sequence[Incomplete]) -> NodeList: ... class _Namespace(Node): def __init__(self, name: str): ... @@ -2083,20 +2083,20 @@ class ModelObjectCursorWrapper(ModelDictCursorWrapper[_TModel]): self, cursor: __ICursor, model: _TModel, - select: Sequence[Any], # incomplete + select: Sequence[Incomplete], constructor: type[_TModel] | Callable[[Any], _TModel], ): ... def process_row(self, row: tuple) -> _TModel: ... # type: ignore class ModelCursorWrapper(BaseModelCursorWrapper[_TModel]): # TODO (dargueta) -- Iterable[Union[Join, ...]] - from_list: Iterable[Any] # incomplete + from_list: Iterable[Incomplete] # TODO (dargueta) -- Mapping[, tuple[?, ?, Callable[..., _TModel], int?]] joins: Mapping[Hashable, tuple[object, object, Callable[..., _TModel], int]] key_to_constructor: dict[type[_TModel], Callable[..., _TModel]] src_is_dest: dict[type[Model], bool] src_to_dest: list[tuple] # TODO -- tuple[, join_type[1], join_type[0], bool, join_type[3]] - column_keys: list # incomplete + column_keys: list[Incomplete] def __init__( self, cursor: __ICursor, From 5011f56c4ab9b967ae057bfff02e3b46cd1fb8b6 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Wed, 15 Jun 2022 18:20:46 +0100 Subject: [PATCH 15/22] Fix `TFunc`-related issues --- stubs/peewee/peewee.pyi | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/stubs/peewee/peewee.pyi b/stubs/peewee/peewee.pyi index 88c809f751d8..63db15e19875 100644 --- a/stubs/peewee/peewee.pyi +++ b/stubs/peewee/peewee.pyi @@ -7,16 +7,17 @@ from _typeshed import Incomplete, Self from collections.abc import Callable, Container, Hashable, Iterable, Iterator, Mapping, MutableMapping, MutableSet, Sequence from contextlib import AbstractContextManager from typing import Any, AnyStr, ClassVar, Generic, NamedTuple, NoReturn, Pattern, Protocol, TypeVar, Union, overload -from typing_extensions import Literal, TypeAlias +from typing_extensions import Concatenate, Literal, ParamSpec, TypeAlias _T = TypeVar("_T") _TModel = TypeVar("_TModel", bound=Model) _TConvFunc = Callable[[Any], Any] -_TFunc = TypeVar("_TFunc", bound=Callable) +_TFunc = TypeVar("_TFunc", bound=Callable[..., Any]) _TClass = TypeVar("_TClass", bound=type) _TContextClass = TypeVar("_TContextClass", bound=Context) _TField = TypeVar("_TField", bound=Field) _TNode = TypeVar("_TNode", bound=Node) +_P = ParamSpec("_P") __version__: str @@ -250,7 +251,7 @@ class Node: def __sql__(self, ctx: Context) -> Context: ... # FIXME (dargueta): Is there a way to make this a proper decorator? @staticmethod - def copy(method: _TFunc) -> _TFunc: ... + def copy(method: Callable[_P, Any]) -> Callable[Concatenate[_T, _P], _T]: ... def coerce(self: Self, _coerce: bool = ...) -> Self: ... def is_alias(self) -> bool: ... def unwrap(self: Self) -> Self: ... From 719dd627dc513e3e5c7d78bd14a6ba459be78efb Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Wed, 15 Jun 2022 18:25:53 +0100 Subject: [PATCH 16/22] Nit: Rename several `TypeVar`s to use suffixes instead of prefixes --- stubs/peewee/peewee.pyi | 151 ++++++++++++++++++++-------------------- 1 file changed, 76 insertions(+), 75 deletions(-) diff --git a/stubs/peewee/peewee.pyi b/stubs/peewee/peewee.pyi index 63db15e19875..192495995fcf 100644 --- a/stubs/peewee/peewee.pyi +++ b/stubs/peewee/peewee.pyi @@ -10,15 +10,16 @@ from typing import Any, AnyStr, ClassVar, Generic, NamedTuple, NoReturn, Pattern from typing_extensions import Concatenate, Literal, ParamSpec, TypeAlias _T = TypeVar("_T") -_TModel = TypeVar("_TModel", bound=Model) -_TConvFunc = Callable[[Any], Any] -_TFunc = TypeVar("_TFunc", bound=Callable[..., Any]) -_TClass = TypeVar("_TClass", bound=type) -_TContextClass = TypeVar("_TContextClass", bound=Context) -_TField = TypeVar("_TField", bound=Field) -_TNode = TypeVar("_TNode", bound=Node) +_ModelT = TypeVar("_ModelT", bound=Model) +_FuncT = TypeVar("_FuncT", bound=Callable[..., Any]) +_ClassT = TypeVar("_ClassT", bound=type) +_ContextClassT = TypeVar("_ContextClassT", bound=Context) +_FieldT = TypeVar("_FieldT", bound=Field) +_NodeT = TypeVar("_NodeT", bound=Node) _P = ParamSpec("_P") +_ConvFunc: TypeAlias = Callable[[Any], Any] + __version__: str __all__ = [ @@ -167,14 +168,14 @@ MODEL_BASE: str # TODO (dargueta) class _callable_context_manager: - def __call__(self, fn: _TFunc) -> _TFunc: ... + def __call__(self, fn: _FuncT) -> _FuncT: ... class Proxy: obj: Any def initialize(self, obj: object) -> None: ... - def attach_callback(self, callback: _TConvFunc) -> _TConvFunc: ... + def attach_callback(self, callback: _ConvFunc) -> _ConvFunc: ... @staticmethod # This is technically inaccurate but that's how it's used - def passthrough(method: _TFunc) -> _TFunc: ... + def passthrough(method: _FuncT) -> _FuncT: ... def __enter__(self) -> Any: ... def __exit__(self, exc_type, exc_val, exc_tb) -> Any: ... def __getattr__(self, attr: str) -> Any: ... @@ -239,7 +240,7 @@ class Context: # TODO (dargueta): Is this right? def sql(self, obj: object) -> Context: ... def literal(self: Self, keyword: str) -> Self: ... - def value(self, value: object, converter: _TConvFunc | None = ..., add_param: bool = ...) -> Context: ... + def value(self, value: object, converter: _ConvFunc | None = ..., add_param: bool = ...) -> Context: ... def __sql__(self, ctx: Context) -> Context: ... def parse(self, node: Node) -> tuple[str, tuple | None]: ... def query(self) -> tuple[str, tuple | None]: ... @@ -391,8 +392,8 @@ class CTE(_HashableSource, Source): def __sql__(self, ctx: Context) -> Context: ... class ColumnBase(Node): - _converter: _TConvFunc | None - def converter(self: Self, converter: _TConvFunc | None = ...) -> Self: ... + _converter: _ConvFunc | None + def converter(self: Self, converter: _ConvFunc | None = ...) -> Self: ... @overload def alias(self: Self, alias: None) -> Self: ... @overload @@ -456,13 +457,13 @@ class Column(ColumnBase): def __hash__(self) -> int: ... def __sql__(self, ctx: Context) -> Context: ... -class WrappedNode(ColumnBase, Generic[_TNode]): - node: _TNode +class WrappedNode(ColumnBase, Generic[_NodeT]): + node: _NodeT _coerce: bool - _converter: _TConvFunc | None - def __init__(self, node: _TNode): ... + _converter: _ConvFunc | None + def __init__(self, node: _NodeT): ... def is_alias(self) -> bool: ... - def unwrap(self) -> _TNode: ... + def unwrap(self) -> _NodeT: ... class EntityFactory: node: Node @@ -503,9 +504,9 @@ class BitwiseNegated(BitwiseMixin, WrappedNode): class Value(ColumnBase): value: object - converter: _TConvFunc | None + converter: _ConvFunc | None multi: bool - def __init__(self, value: object, converter: _TConvFunc | None = ..., unpack: bool = ...): ... + def __init__(self, value: object, converter: _ConvFunc | None = ..., unpack: bool = ...): ... def __sql__(self, ctx: Context) -> Context: ... def AsIs(value: object) -> Value: ... @@ -558,12 +559,12 @@ def Check(constraint: str) -> SQL: ... class Function(ColumnBase): name: str arguments: tuple - def __init__(self, name: str, arguments: tuple, coerce: bool = ..., python_value: _TConvFunc | None = ...): ... + def __init__(self, name: str, arguments: tuple, coerce: bool = ..., python_value: _ConvFunc | None = ...): ... def __getattr__(self, attr: str) -> Callable[..., Function]: ... # TODO (dargueta): `where` is an educated guess def filter(self: Self, where: Expression | None = ...) -> Self: ... def order_by(self: Self, *ordering: Field | Expression) -> Self: ... - def python_value(self: Self, func: _TConvFunc | None = ...) -> Self: ... + def python_value(self: Self, func: _ConvFunc | None = ...) -> Self: ... def over( self, partition_by: Sequence[Field] | Window | None = ..., @@ -729,7 +730,7 @@ class BaseQuery(Node): def dicts(self: Self, as_dict: bool = ...) -> Self: ... def tuples(self: Self, as_tuple: bool = ...) -> Self: ... def namedtuples(self: Self, as_namedtuple: bool = ...) -> Self: ... - def objects(self: Self, constructor: _TConvFunc | None = ...) -> Self: ... + def objects(self: Self, constructor: _ConvFunc | None = ...) -> Self: ... def __sql__(self, ctx: Context) -> Context: ... def sql(self) -> tuple[str, tuple | None]: ... def execute(self, database: Database | None = ...) -> CursorWrapper: ... @@ -909,7 +910,7 @@ class Index(Node): class ModelIndex(Index): def __init__( self, - model: type[_TModel], + model: type[_ModelT], fields: Iterable[Field | Node | str], unique: bool = ..., safe: bool = ..., @@ -983,7 +984,7 @@ class ConnectionContext(_callable_context_manager): def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None: ... class Database(_callable_context_manager): - context_class: ClassVar[type[_TContextClass]] + context_class: ClassVar[type[_ContextClassT]] field_types: ClassVar[Mapping[str, str]] operations: ClassVar[Mapping[str, Any]] # TODO (dargueta) Verify k/v types param: ClassVar[str] @@ -1032,7 +1033,7 @@ class Database(_callable_context_manager): def execute_sql(self, sql: str, params: tuple | None = ..., commit: bool | _TSentinel = ...) -> __ICursor: ... def execute(self, query: Query, commit: bool | _TSentinel = ..., **context_options: object) -> __ICursor: ... def get_context_options(self) -> Mapping[str, object]: ... - def get_sql_context(self, **context_options: object) -> _TContextClass: ... + def get_sql_context(self, **context_options: object) -> _ContextClassT: ... def conflict_statement(self, on_conflict: OnConflict, query: Query) -> SQL | None: ... def conflict_update(self, oc: OnConflict, query: Query) -> NodeList: ... def last_insert_id(self, cursor: __ICursor, query_type: int | None = ...) -> int: ... @@ -1130,13 +1131,13 @@ class SqliteDatabase(Database): @wal_autocheckpoint.setter def wal_autocheckpoint(self, value: object) -> Any: ... def register_aggregate(self, klass: type[__IAggregate], name: str | None = ..., num_params: int = ...): ... - def aggregate(self, name: str | None = ..., num_params: int = ...) -> Callable[[_TClass], _TClass]: ... + def aggregate(self, name: str | None = ..., num_params: int = ...) -> Callable[[_ClassT], _ClassT]: ... def register_collation(self, fn: Callable, name: str | None = ...) -> None: ... - def collation(self, name: str | None = ...) -> Callable[[_TFunc], _TFunc]: ... + def collation(self, name: str | None = ...) -> Callable[[_FuncT], _FuncT]: ... def register_function(self, fn: Callable, name: str | None = ..., num_params: int = ...) -> int: ... - def func(self, name: str | None = ..., num_params: int = ...) -> Callable[[_TFunc], _TFunc]: ... + def func(self, name: str | None = ..., num_params: int = ...) -> Callable[[_FuncT], _FuncT]: ... def register_window_function(self, klass: type, name: str | None = ..., num_params: int = ...) -> None: ... - def window_function(self, name: str | None = ..., num_params: int = ...) -> Callable[[_TClass], _TClass]: ... + def window_function(self, name: str | None = ..., num_params: int = ...) -> Callable[[_ClassT], _ClassT]: ... def register_table_function(self, klass: type[__ITableFunction], name: str | None = ...) -> None: ... def table_function(self, name: str | None = ...) -> Callable[[type[__ITableFunction]], type[__ITableFunction]]: ... def unregister_aggregate(self, name: str) -> None: ... @@ -1298,8 +1299,8 @@ class ForeignKeyAccessor(FieldAccessor): @overload def __get__(self, instance: None, instance_type: type) -> Any: ... @overload - def __get__(self, instance: _TModel, instance_type: type[_TModel]) -> ForeignKeyField: ... - def __set__(self, instance: _TModel, obj: object) -> None: ... + def __get__(self, instance: _ModelT, instance_type: type[_ModelT]) -> ForeignKeyField: ... + def __set__(self, instance: _ModelT, obj: object) -> None: ... class NoQueryForeignKeyAccessor(ForeignKeyAccessor): def get_rel_instance(self, instance: Model) -> Any: ... @@ -1321,7 +1322,7 @@ class ObjectIdAccessor: @overload def __get__(self, instance: None, instance_type: type[Model]) -> ForeignKeyField: ... @overload - def __get__(self, instance: _TModel, instance_type: type[_TModel] = ...) -> Any: ... + def __get__(self, instance: _ModelT, instance_type: type[_ModelT] = ...) -> Any: ... def __set__(self, instance: Model, value: object) -> None: ... class Field(ColumnBase): @@ -1466,9 +1467,9 @@ class BigBitFieldData: class BigBitFieldAccessor(FieldAccessor): @overload - def __get__(self, instance: None, instance_type: type[_TModel]) -> Field: ... + def __get__(self, instance: None, instance_type: type[_ModelT]) -> Field: ... @overload - def __get__(self, instance: _TModel, instance_type: type[_TModel]) -> BigBitFieldData: ... + def __get__(self, instance: _ModelT, instance_type: type[_ModelT]) -> BigBitFieldData: ... def __set__(self, instance: Any, value: memoryview | bytearray | BigBitFieldData | str | bytes) -> None: ... class BigBitField(BlobField): @@ -1499,7 +1500,7 @@ class BinaryUUIDField(BlobField): @overload def python_value(self, value: bytearray | bytes | memoryview | uuid.UUID) -> uuid.UUID: ... -def format_date_time(value: str, formats: Iterable[str], post_process: _TConvFunc | None = ...) -> str: ... +def format_date_time(value: str, formats: Iterable[str], post_process: _ConvFunc | None = ...) -> str: ... @overload def simple_date_time(value: _T) -> _T: ... @@ -1598,8 +1599,8 @@ class BooleanField(Field): class BareField(Field): # If `adapt` was omitted from the constructor or None, this attribute won't exist. - adapt: _TConvFunc | None - def __init__(self, adapt: _TConvFunc | None = ..., *args: object, **kwargs: object): ... + adapt: _ConvFunc | None + def __init__(self, adapt: _ConvFunc | None = ..., *args: object, **kwargs: object): ... def ddl_datatype(self, ctx: Context) -> None: ... class ForeignKeyField(Field): @@ -1699,10 +1700,10 @@ class ManyToManyField(MetaField): def get_models(self) -> list[type[Model]]: ... def get_through_model(self) -> type[Model] | DeferredThroughModel: ... -class VirtualField(MetaField, Generic[_TField]): - field_class: type[_TField] - field_instance: _TField | None - def __init__(self, field_class: type[_TField] | None = ..., *args: object, **kwargs: object): ... +class VirtualField(MetaField, Generic[_FieldT]): + field_class: type[_FieldT] + field_instance: _FieldT | None + def __init__(self, field_class: type[_FieldT] | None = ..., *args: object, **kwargs: object): ... def db_value(self, value: object) -> Any: ... def python_value(self, value: object) -> Any: ... def bind(self, model: type[Model], name: str, set_attribute: bool = ...) -> None: ... @@ -1938,15 +1939,15 @@ class Model(Node, metaclass=ModelBase): def add_index(cls, *fields: str | SQL | Index, **kwargs: object) -> None: ... # "Provide a separate reference to a model in a query." -class ModelAlias(Node, Generic[_TModel]): - model: type[_TModel] +class ModelAlias(Node, Generic[_ModelT]): + model: type[_ModelT] alias: str | None - def __init__(self, model: type[_TModel], alias: str | None = ...): ... + def __init__(self, model: type[_ModelT], alias: str | None = ...): ... def __getattr__(self, attr: str) -> Any: ... def __setattr__(self, attr: str, value: object) -> NoReturn: ... def get_field_aliases(self) -> list[Field]: ... def select(self, *selection: Field) -> ModelSelect: ... - def __call__(self, **kwargs) -> _TModel: ... + def __call__(self, **kwargs) -> _ModelT: ... def __sql__(self, ctx: Context) -> Context: ... _TModelOrTable: TypeAlias = type[Model] | ModelAlias | Table @@ -1977,10 +1978,10 @@ class _ModelQueryHelper: default_row_type: ClassVar[int] def objects(self: Self, constructor: Callable[..., Any] | None = ...) -> Self: ... -class ModelRaw(_ModelQueryHelper, RawQuery, Generic[_TModel]): - model: type[_TModel] - def __init__(self, model: type[_TModel], sql: str, params: tuple, **kwargs: object): ... - def get(self) -> _TModel: ... +class ModelRaw(_ModelQueryHelper, RawQuery, Generic[_ModelT]): + model: type[_ModelT] + def __init__(self, model: type[_ModelT], sql: str, params: tuple, **kwargs: object): ... + def get(self) -> _ModelT: ... class BaseModelSelect(_ModelQueryHelper): def union_all(self, rhs: object) -> ModelCompoundSelectQuery: ... @@ -1996,13 +1997,13 @@ class BaseModelSelect(_ModelQueryHelper): def get(self, database: Database | None = ...) -> Any: ... def group_by(self: Self, *columns: type[Model] | Table | Field) -> Self: ... -class ModelCompoundSelectQuery(BaseModelSelect, CompoundSelectQuery, Generic[_TModel]): - model: type[_TModel] - def __init__(self, model: type[_TModel], *args: object, **kwargs: object): ... +class ModelCompoundSelectQuery(BaseModelSelect, CompoundSelectQuery, Generic[_ModelT]): + model: type[_ModelT] + def __init__(self, model: type[_ModelT], *args: object, **kwargs: object): ... -class ModelSelect(BaseModelSelect, Select, Generic[_TModel]): - model: type[_TModel] - def __init__(self, model: type[_TModel], fields_or_models: Iterable[_TFieldOrModel], is_default: bool = ...): ... +class ModelSelect(BaseModelSelect, Select, Generic[_ModelT]): + model: type[_ModelT] + def __init__(self, model: type[_ModelT], fields_or_models: Iterable[_TFieldOrModel], is_default: bool = ...): ... def clone(self: Self) -> Self: ... def select(self: Self, *fields_or_models: _TFieldOrModel) -> Self: ... # type: ignore def switch(self: Self, ctx: type[Model] | None = ...) -> Self: ... @@ -2057,57 +2058,57 @@ class ManyToManyQuery(ModelSelect): def remove(self, value: SelectQuery | type[Model] | Iterable[str]) -> int | None: ... def clear(self) -> int: ... -class BaseModelCursorWrapper(DictCursorWrapper, Generic[_TModel]): +class BaseModelCursorWrapper(DictCursorWrapper, Generic[_ModelT]): ncols: int columns: list[str] - converters: list[_TConvFunc] + converters: list[_ConvFunc] fields: list[Field] - model: type[_TModel] + model: type[_ModelT] select: Sequence[str] - def __init__(self, cursor: __ICursor, model: type[_TModel], columns: Sequence[str] | None): ... + def __init__(self, cursor: __ICursor, model: type[_ModelT], columns: Sequence[str] | None): ... def process_row(self, row: tuple) -> Mapping[str, object]: ... # type: ignore -class ModelDictCursorWrapper(BaseModelCursorWrapper[_TModel]): +class ModelDictCursorWrapper(BaseModelCursorWrapper[_ModelT]): def process_row(self, row: tuple) -> dict[str, Any]: ... -class ModelTupleCursorWrapper(ModelDictCursorWrapper[_TModel]): +class ModelTupleCursorWrapper(ModelDictCursorWrapper[_ModelT]): constructor: ClassVar[Callable[[Sequence[Any]], tuple]] def process_row(self, row: tuple) -> tuple: ... # type: ignore -class ModelNamedTupleCursorWrapper(ModelTupleCursorWrapper[_TModel]): ... +class ModelNamedTupleCursorWrapper(ModelTupleCursorWrapper[_ModelT]): ... -class ModelObjectCursorWrapper(ModelDictCursorWrapper[_TModel]): - constructor: type[_TModel] | Callable[[Any], _TModel] +class ModelObjectCursorWrapper(ModelDictCursorWrapper[_ModelT]): + constructor: type[_ModelT] | Callable[[Any], _ModelT] is_model: bool # TODO (dargueta): `select` is some kind of Sequence def __init__( self, cursor: __ICursor, - model: _TModel, + model: _ModelT, select: Sequence[Incomplete], - constructor: type[_TModel] | Callable[[Any], _TModel], + constructor: type[_ModelT] | Callable[[Any], _ModelT], ): ... - def process_row(self, row: tuple) -> _TModel: ... # type: ignore + def process_row(self, row: tuple) -> _ModelT: ... # type: ignore -class ModelCursorWrapper(BaseModelCursorWrapper[_TModel]): +class ModelCursorWrapper(BaseModelCursorWrapper[_ModelT]): # TODO (dargueta) -- Iterable[Union[Join, ...]] from_list: Iterable[Incomplete] - # TODO (dargueta) -- Mapping[, tuple[?, ?, Callable[..., _TModel], int?]] - joins: Mapping[Hashable, tuple[object, object, Callable[..., _TModel], int]] - key_to_constructor: dict[type[_TModel], Callable[..., _TModel]] + # TODO (dargueta) -- Mapping[, tuple[?, ?, Callable[..., _ModelT], int?]] + joins: Mapping[Hashable, tuple[object, object, Callable[..., _ModelT], int]] + key_to_constructor: dict[type[_ModelT], Callable[..., _ModelT]] src_is_dest: dict[type[Model], bool] src_to_dest: list[tuple] # TODO -- tuple[, join_type[1], join_type[0], bool, join_type[3]] column_keys: list[Incomplete] def __init__( self, cursor: __ICursor, - model: type[_TModel], + model: type[_ModelT], select, from_list: Iterable[object], - joins: Mapping[Hashable, tuple[object, object, Callable[..., _TModel], int]], + joins: Mapping[Hashable, tuple[object, object, Callable[..., _ModelT], int]], ): ... def initialize(self) -> None: ... - def process_row(self, row: tuple) -> _TModel: ... # type: ignore + def process_row(self, row: tuple) -> _ModelT: ... # type: ignore class __PrefetchQuery(NamedTuple): query: Query # TODO (dargueta): Verify From fda102c4413e93e0eca73188b46068aa884f794f Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Wed, 15 Jun 2022 18:32:15 +0100 Subject: [PATCH 17/22] Add some missing type parameters, use some more `Incomplete` where I'm not sure --- stubs/peewee/peewee.pyi | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/stubs/peewee/peewee.pyi b/stubs/peewee/peewee.pyi index 192495995fcf..fd7de4bc54be 100644 --- a/stubs/peewee/peewee.pyi +++ b/stubs/peewee/peewee.pyi @@ -102,9 +102,9 @@ __all__ = [ class __ICursor(Protocol): description: tuple[str, Any, Any, Any, Any, Any, Any] rowcount: int - def fetchone(self) -> tuple | None: ... - def fetchmany(self, size: int = ...) -> Iterable[tuple]: ... - def fetchall(self) -> Iterable[tuple]: ... + def fetchone(self) -> tuple[Incomplete, ...] | None: ... + def fetchmany(self, size: int = ...) -> Iterable[tuple[Incomplete, ...]]: ... + def fetchall(self) -> Iterable[tuple[Incomplete, ...]]: ... class __IConnection(Protocol): def cursor(self) -> __ICursor: ... @@ -122,14 +122,14 @@ class __ITableFunction(Protocol): name: str print_tracebacks: bool def initialize(self, **parameters: object) -> None: ... - def iterate(self, idx: int) -> tuple: ... + def iterate(self, idx: int) -> tuple[Incomplete, ...]: ... @classmethod def register(cls, conn: __IConnection) -> None: ... def _sqlite_date_part(lookup_type: str, datetime_string: str) -> str | None: ... def _sqlite_date_trunc(lookup_type: str, datetime_string: str) -> str | None: ... -class attrdict(dict): +class attrdict(dict[str, object]): def __getattr__(self, attr: str) -> Any: ... def __setattr__(self, attr: str, value: object) -> None: ... def __iadd__(self: Self, rhs: Mapping[str, object]) -> Self: ... @@ -161,8 +161,8 @@ CSQ_PARENTHESES_NEVER: int CSQ_PARENTHESES_ALWAYS: int CSQ_PARENTHESES_UNNESTED: int -SNAKE_CASE_STEP1: Pattern -SNAKE_CASE_STEP2: Pattern +SNAKE_CASE_STEP1: Pattern[str] +SNAKE_CASE_STEP2: Pattern[str] MODEL_BASE: str @@ -242,8 +242,8 @@ class Context: def literal(self: Self, keyword: str) -> Self: ... def value(self, value: object, converter: _ConvFunc | None = ..., add_param: bool = ...) -> Context: ... def __sql__(self, ctx: Context) -> Context: ... - def parse(self, node: Node) -> tuple[str, tuple | None]: ... - def query(self) -> tuple[str, tuple | None]: ... + def parse(self, node: Node) -> tuple[str, tuple[Incomplete, ...] | None]: ... + def query(self) -> tuple[str, tuple[Incomplete, ...] | None]: ... def query_to_string(query: Node) -> str: ... From 45fb0c2f4f2eda3a1eee40b27c9d29f68a4908d6 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Wed, 15 Jun 2022 19:21:18 +0100 Subject: [PATCH 18/22] Fix a few mypy complaints, falling back to `Incomplete` when I'm unsure --- stubs/peewee/peewee.pyi | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/stubs/peewee/peewee.pyi b/stubs/peewee/peewee.pyi index fd7de4bc54be..f69f4beb8ae2 100644 --- a/stubs/peewee/peewee.pyi +++ b/stubs/peewee/peewee.pyi @@ -524,7 +524,7 @@ class Ordering(WrappedNode): def __sql__(self, ctx: Context) -> Context: ... class _SupportsSQLOrdering(Protocol): - def __call__(node: Node, collation: str | None = ..., nulls: str | None = ...) -> Ordering: ... + def __call__(self, node: Node, collation: str | None = ..., nulls: str | None = ...) -> Ordering: ... def Asc(node: Node, collation: str | None = ..., nulls: str | None = ...) -> Ordering: ... def Desc(node: Node, collation: str | None = ..., nulls: str | None = ...) -> Ordering: ... @@ -984,7 +984,7 @@ class ConnectionContext(_callable_context_manager): def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None: ... class Database(_callable_context_manager): - context_class: ClassVar[type[_ContextClassT]] + context_class: ClassVar[type[Incomplete]] field_types: ClassVar[Mapping[str, str]] operations: ClassVar[Mapping[str, Any]] # TODO (dargueta) Verify k/v types param: ClassVar[str] @@ -1502,6 +1502,8 @@ class BinaryUUIDField(BlobField): def format_date_time(value: str, formats: Iterable[str], post_process: _ConvFunc | None = ...) -> str: ... @overload +def simple_date_time(value: str) -> datetime.datetime: ... +@overload def simple_date_time(value: _T) -> _T: ... class _BaseFormattedField(Field): @@ -1723,8 +1725,8 @@ class CompositeKey(MetaField): @overload def __get__(self, instance: _T, instance_type: type[_T]) -> tuple: ... def __set__(self, instance: Model, value: list | tuple) -> None: ... - def __eq__(self, other: Expression) -> Expression: ... - def __ne__(self, other: Expression) -> Expression: ... + def __eq__(self, other): ... + def __ne__(self, other): ... def __hash__(self) -> int: ... def __sql__(self, ctx: Context) -> Context: ... def bind(self, model: type[Model], name: str, set_attribute: bool = ...) -> None: ... @@ -1790,7 +1792,7 @@ class Metadata: table_name: str | None = ..., indexes: Iterable[str | Sequence[str]] | None = ..., primary_key: Literal[False] | Field | CompositeKey | None = ..., - constraints: Iterable[Check | SQL] | None = ..., + constraints: Iterable[Callable[[str], SQL] | SQL] | None = ..., schema: str | None = ..., only_save_dirty: bool = ..., depends_on: Sequence[type[Model]] | None = ..., @@ -1950,7 +1952,7 @@ class ModelAlias(Node, Generic[_ModelT]): def __call__(self, **kwargs) -> _ModelT: ... def __sql__(self, ctx: Context) -> Context: ... -_TModelOrTable: TypeAlias = type[Model] | ModelAlias | Table +_TModelOrTable: TypeAlias = Union[type[Model], ModelAlias, Table] _TSubquery: TypeAlias = Union[tuple[Query, type[Model]], type[Model], ModelAlias] _TFieldOrModel: TypeAlias = _TModelOrTable | Field @@ -1966,10 +1968,7 @@ class FieldAlias(Field): def adapt(self, value: object) -> Any: ... def python_value(self, value: object) -> Any: ... def db_value(self, value: object) -> Any: ... - @overload - def __getattr__(self, attr: Literal["model"]) -> Node: ... - @overload - def __getattr__(self, attr: str) -> Any: ... + def __getattr__(self, attr: str) -> Incomplete: ... def __sql__(self, ctx: Context) -> Context: ... def sort_models(models: Iterable[type[Model]]) -> list[type[Model]]: ... From 3e3165349b516d13576c0159cb1e5952d1b6f765 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Sun, 19 Jun 2022 20:35:47 +0100 Subject: [PATCH 19/22] Mark more things as incomplete --- stubs/peewee/peewee.pyi | 135 ++++++++++++++++++++-------------------- 1 file changed, 69 insertions(+), 66 deletions(-) diff --git a/stubs/peewee/peewee.pyi b/stubs/peewee/peewee.pyi index f69f4beb8ae2..6636c2a8f2a6 100644 --- a/stubs/peewee/peewee.pyi +++ b/stubs/peewee/peewee.pyi @@ -99,6 +99,10 @@ __all__ = [ "Window", ] +AnyField: Incomplete +chunked: Incomplete +Tuple: Incomplete + class __ICursor(Protocol): description: tuple[str, Any, Any, Any, Any, Any, Any] rowcount: int @@ -405,7 +409,7 @@ class ColumnBase(Node): def desc(self, collation: str | None = ..., nulls: str | None = ...) -> _SupportsSQLOrdering: ... __neg__ = desc # TODO (dargueta): This always returns Negated but subclasses can return something else - def __invert__(self) -> WrappedNode: ... + def __invert__(self) -> WrappedNode[Incomplete]: ... def __and__(self, other: object) -> Expression: ... def __or__(self, other: object) -> Expression: ... def __add__(self, other: object) -> Expression: ... @@ -476,7 +480,7 @@ class _DynamicEntity: @overload def __get__(self, instance: _T, instance_type: type[_T]) -> EntityFactory: ... -class Alias(WrappedNode): +class Alias(WrappedNode[Incomplete]): c: ClassVar[_DynamicEntity] def __init__(self, node: Node, alias: str): ... def __hash__(self) -> int: ... @@ -488,7 +492,7 @@ class Alias(WrappedNode): def is_alias(self) -> Literal[True]: ... def __sql__(self, ctx: Context) -> Context: ... -class Negated(WrappedNode): +class Negated(WrappedNode[Incomplete]): def __invert__(self) -> Node: ... def __sql__(self, ctx: Context) -> Context: ... @@ -498,7 +502,7 @@ class BitwiseMixin: def __sub__(self, other: object) -> Expression: ... def __invert__(self) -> BitwiseNegated: ... -class BitwiseNegated(BitwiseMixin, WrappedNode): +class BitwiseNegated(BitwiseMixin, WrappedNode[Incomplete]): def __invert__(self) -> Node: ... def __sql__(self, ctx: Context) -> Context: ... @@ -511,11 +515,11 @@ class Value(ColumnBase): def AsIs(value: object) -> Value: ... -class Cast(WrappedNode): +class Cast(WrappedNode[Incomplete]): def __init__(self, node: Node, cast: str): ... def __sql__(self, ctx: Context) -> Context: ... -class Ordering(WrappedNode): +class Ordering(WrappedNode[Incomplete]): direction: str collation: str | None nulls: str | None @@ -558,8 +562,8 @@ def Check(constraint: str) -> SQL: ... class Function(ColumnBase): name: str - arguments: tuple - def __init__(self, name: str, arguments: tuple, coerce: bool = ..., python_value: _ConvFunc | None = ...): ... + arguments: tuple[Incomplete, ...] + def __init__(self, name: str, arguments: tuple[Incomplete, ...], coerce: bool = ..., python_value: _ConvFunc | None = ...): ... def __getattr__(self, attr: str) -> Callable[..., Function]: ... # TODO (dargueta): `where` is an educated guess def filter(self: Self, where: Expression | None = ...) -> Self: ... @@ -678,7 +682,7 @@ class DQ(ColumnBase): def __invert__(self: Self) -> Self: ... def clone(self: Self) -> Self: ... -class QualifiedNames(WrappedNode): +class QualifiedNames(WrappedNode[Incomplete]): def __sql__(self, ctx: Context) -> Context: ... @overload @@ -732,10 +736,9 @@ class BaseQuery(Node): def namedtuples(self: Self, as_namedtuple: bool = ...) -> Self: ... def objects(self: Self, constructor: _ConvFunc | None = ...) -> Self: ... def __sql__(self, ctx: Context) -> Context: ... - def sql(self) -> tuple[str, tuple | None]: ... - def execute(self, database: Database | None = ...) -> CursorWrapper: ... - # TODO (dargueta): `Any` is too loose; list types of the cursor wrappers - def iterator(self, database: Database | None = ...) -> Iterator[Any]: ... + def sql(self) -> tuple[str, tuple[Incomplete, ...] | None]: ... + def execute(self, database: Database | None = ...) -> CursorWrapper[Incomplete]: ... + def iterator(self, database: Database | None = ...) -> Iterator[Incomplete]: ... def __iter__(self) -> Iterator[Any]: ... @overload def __getitem__(self, value: int) -> Any: ... @@ -745,7 +748,7 @@ class BaseQuery(Node): class RawQuery(BaseQuery): # TODO (dargueta): `tuple` may not be 100% accurate, maybe Sequence[object]? - def __init__(self, sql: str | None = ..., params: tuple | None = ..., **kwargs: object): ... + def __init__(self, sql: str | None = ..., params: tuple[Incomplete, ...] | None = ..., **kwargs: object): ... def __sql__(self, ctx: Context) -> Context: ... class Query(BaseQuery): @@ -798,7 +801,7 @@ class SelectBase(_HashableSource, Source, SelectQuery): @overload def scalar(self, database: Database | None = ..., as_tuple: Literal[False] = ...) -> object: ... @overload - def scalar(self, database: Database | None = ..., as_tuple: Literal[True] = ...) -> tuple: ... + def scalar(self, database: Database | None = ..., as_tuple: Literal[True] = ...) -> tuple[Incomplete, ...]: ... def count(self, database: Database | None = ..., clear_limit: bool = ...) -> int: ... def exists(self, database: Database | None = ...) -> bool: ... def get(self, database: Database | None = ...) -> object: ... @@ -856,7 +859,7 @@ class _WriteQuery(Query): def __init__(self, table: Table, returning: Iterable[type[Model] | Field] | None = ..., **kwargs: object): ... def returning(self: Self, *returning: type[Model] | Field) -> Self: ... def apply_returning(self, ctx: Context) -> Context: ... - def execute_returning(self, database: Database) -> CursorWrapper: ... + def execute_returning(self, database: Database) -> CursorWrapper[Incomplete]: ... def handle_result(self, database: Database, cursor: __ICursor) -> int | __ICursor: ... def __sql__(self, ctx: Context) -> Context: ... @@ -1030,10 +1033,10 @@ class Database(_callable_context_manager): def is_connection_usable(self) -> bool: ... def connection(self) -> __IConnection: ... def cursor(self, commit: bool | None = ...) -> __ICursor: ... - def execute_sql(self, sql: str, params: tuple | None = ..., commit: bool | _TSentinel = ...) -> __ICursor: ... + def execute_sql(self, sql: str, params: tuple[Incomplete, ...] | None = ..., commit: bool | _TSentinel = ...) -> __ICursor: ... def execute(self, query: Query, commit: bool | _TSentinel = ..., **context_options: object) -> __ICursor: ... def get_context_options(self) -> Mapping[str, object]: ... - def get_sql_context(self, **context_options: object) -> _ContextClassT: ... + def get_sql_context(self, **context_options: object): ... def conflict_statement(self, on_conflict: OnConflict, query: Query) -> SQL | None: ... def conflict_update(self, oc: OnConflict, query: Query) -> NodeList: ... def last_insert_id(self, cursor: __ICursor, query_type: int | None = ...) -> int: ... @@ -1132,9 +1135,9 @@ class SqliteDatabase(Database): def wal_autocheckpoint(self, value: object) -> Any: ... def register_aggregate(self, klass: type[__IAggregate], name: str | None = ..., num_params: int = ...): ... def aggregate(self, name: str | None = ..., num_params: int = ...) -> Callable[[_ClassT], _ClassT]: ... - def register_collation(self, fn: Callable, name: str | None = ...) -> None: ... + def register_collation(self, fn: Callable[..., Incomplete], name: str | None = ...) -> None: ... def collation(self, name: str | None = ...) -> Callable[[_FuncT], _FuncT]: ... - def register_function(self, fn: Callable, name: str | None = ..., num_params: int = ...) -> int: ... + def register_function(self, fn: Callable[Incomplete, ...], name: str | None = ..., num_params: int = ...) -> int: ... def func(self, name: str | None = ..., num_params: int = ...) -> Callable[[_FuncT], _FuncT]: ... def register_window_function(self, klass: type, name: str | None = ..., num_params: int = ...) -> None: ... def window_function(self, name: str | None = ..., num_params: int = ...) -> Callable[[_ClassT], _ClassT]: ... @@ -1256,19 +1259,19 @@ class CursorWrapper(Generic[_T]): def __len__(self) -> int: ... def initialize(self) -> None: ... def iterate(self, cache: bool = ...) -> _T: ... - def process_row(self, row: tuple) -> _T: ... + def process_row(self, row: tuple[Incomplete, ...]) -> _T: ... def iterator(self) -> Iterator[_T]: ... def fill_cache(self, n: int = ...) -> None: ... class DictCursorWrapper(CursorWrapper[Mapping[str, object]]): ... # FIXME (dargueta): Somehow figure out how to make this a NamedTuple sorta deal -class NamedTupleCursorWrapper(CursorWrapper[tuple]): - tuple_class: type[tuple] +class NamedTupleCursorWrapper(CursorWrapper[tuple[Incomplete, ...]]): + tuple_class: type[tuple[Incomplete, ...]] -class ObjectCursorWrapper(DictCursorWrapper[_T]): +class ObjectCursorWrapper(DictCursorWrapper, Generic[_T]): constructor: Callable[..., _T] - def __init__(self, cursor: __ICursor, constructor: Callable[..., _T]): ... + def __init__(self, cursor: __ICursor, constructor: Callable[..., _T]) -> None: ... def process_row(self, row: tuple) -> _T: ... # type: ignore class ResultIterator(Generic[_T]): @@ -1300,7 +1303,7 @@ class ForeignKeyAccessor(FieldAccessor): def __get__(self, instance: None, instance_type: type) -> Any: ... @overload def __get__(self, instance: _ModelT, instance_type: type[_ModelT]) -> ForeignKeyField: ... - def __set__(self, instance: _ModelT, obj: object) -> None: ... + def __set__(self, instance: Model, obj: object) -> None: ... class NoQueryForeignKeyAccessor(ForeignKeyAccessor): def get_rel_instance(self, instance: Model) -> Any: ... @@ -1478,7 +1481,7 @@ class BigBitField(BlobField): @overload def db_value(self, value: None) -> None: ... @overload - def db_value(self, value: _T) -> bytes: ... + def db_value(self, value) -> bytes: ... class UUIDField(Field): @overload @@ -1678,10 +1681,10 @@ class ManyToManyFieldAccessor(FieldAccessor): dest_fk: ForeignKeyField def __init__(self, model: type[Model], field: ForeignKeyField, name: str): ... @overload - def __get__(self, instance: None, instance_type: type[_T] = ..., force_query: bool = ...) -> Field: ... + def __get__(self, instance: None, instance_type: type[Incomplete] = ..., force_query: bool = ...) -> Field: ... @overload - def __get__(self, instance: _T, instance_type: type[_T] = ..., force_query: bool = ...) -> list[str] | ManyToManyQuery: ... - def __set__(self, instance: _T, value) -> None: ... + def __get__(self, instance: _T, instance_type: type[Incomplete] = ..., force_query: bool = ...) -> list[str] | ManyToManyQuery: ... + def __set__(self, instance, value) -> None: ... class ManyToManyField(MetaField): accessor_class: ClassVar[type[ManyToManyFieldAccessor]] @@ -1705,7 +1708,7 @@ class ManyToManyField(MetaField): class VirtualField(MetaField, Generic[_FieldT]): field_class: type[_FieldT] field_instance: _FieldT | None - def __init__(self, field_class: type[_FieldT] | None = ..., *args: object, **kwargs: object): ... + def __init__(self, field_class: type[Incomplete] | None = ..., *args: object, **kwargs: object): ... def db_value(self, value: object) -> Any: ... def python_value(self, value: object) -> Any: ... def bind(self, model: type[Model], name: str, set_attribute: bool = ...) -> None: ... @@ -1723,8 +1726,8 @@ class CompositeKey(MetaField): @overload def __get__(self, instance: None, instance_type: type) -> CompositeKey: ... @overload - def __get__(self, instance: _T, instance_type: type[_T]) -> tuple: ... - def __set__(self, instance: Model, value: list | tuple) -> None: ... + def __get__(self, instance: _T, instance_type: type[_T]) -> tuple[Incomplete, ...]: ... + def __set__(self, instance: Model, value: list[Incomplete] | tuple[Incomplete, ...]) -> None: ... def __eq__(self, other): ... def __ne__(self, other): ... def __hash__(self) -> int: ... @@ -1813,7 +1816,7 @@ class Metadata: def remove_ref(self, field: ForeignKeyField) -> None: ... def add_manytomany(self, field: ManyToManyField) -> None: ... def remove_manytomany(self, field: ManyToManyField) -> None: ... - def get_rel_for_model(self, model: type[Model] | ModelAlias) -> tuple[list[ForeignKeyField], list[type[Model]]]: ... + def get_rel_for_model(self, model: type[Model] | ModelAlias[Incomplete]) -> tuple[list[ForeignKeyField], list[type[Model]]]: ... def add_field(self, field_name: str, field: Field, set_attribute: bool = ...) -> None: ... def remove_field(self, field_name: str) -> None: ... def set_primary_key(self, name: str, field: Field | CompositeKey) -> None: ... @@ -1861,9 +1864,9 @@ class Model(Node, metaclass=ModelBase): @classmethod def validate_model(cls) -> None: ... @classmethod - def alias(cls, alias: str | None = ...) -> ModelAlias: ... + def alias(cls, alias: str | None = ...) -> ModelAlias[Incomplete]: ... @classmethod - def select(cls, *fields: Field) -> ModelSelect: ... + def select(cls, *fields: Field) -> ModelSelect[Incomplete]: ... @classmethod def update(cls, __data: Iterable[str | Field] | None = ..., **update: Any) -> ModelUpdate: ... @classmethod @@ -1873,15 +1876,15 @@ class Model(Node, metaclass=ModelBase): def insert_many(cls, rows: Iterable[Mapping[str, object]], fields: None) -> ModelInsert: ... @overload @classmethod - def insert_many(cls, rows: Iterable[tuple], fields: Sequence[Field]) -> ModelInsert: ... + def insert_many(cls, rows: Iterable[tuple[Incomplete, ...]], fields: Sequence[Field]) -> ModelInsert: ... @classmethod def insert_from(cls, query: SelectQuery, fields: Iterable[Field | str]) -> ModelInsert: ... @classmethod def replace(cls, __data: Iterable[str | Field] | None = ..., **insert: object) -> OnConflict: ... @classmethod - def replace_many(cls, rows: Iterable[tuple], fields: Sequence[Field] | None = ...) -> OnConflict: ... + def replace_many(cls, rows: Iterable[tuple[Incomplete, ...]], fields: Sequence[Field] | None = ...) -> OnConflict: ... @classmethod - def raw(cls, sql: str, *params: object) -> ModelRaw: ... + def raw(cls, sql: str, *params: object) -> ModelRaw[Incomplete]: ... @classmethod def delete(cls) -> ModelDelete: ... @classmethod @@ -1895,17 +1898,17 @@ class Model(Node, metaclass=ModelBase): @classmethod def noop(cls) -> NoopModelSelect: ... @classmethod - def get(cls, *query: object, **filters: object) -> ModelSelect: ... + def get(cls, *query: object, **filters: object) -> ModelSelect[Incomplete]: ... @classmethod - def get_or_none(cls, *query: object, **filters: object) -> ModelSelect | None: ... + def get_or_none(cls, *query: object, **filters: object) -> ModelSelect[Incomplete] | None: ... @classmethod - def get_by_id(cls, pk: object) -> ModelSelect: ... + def get_by_id(cls, pk: object) -> ModelSelect[Incomplete]: ... # TODO (dargueta) I'm 99% sure of return value for this one @classmethod - def set_by_id(cls, key, value) -> CursorWrapper: ... + def set_by_id(cls, key, value) -> CursorWrapper[Incomplete]: ... # TODO (dargueta) I'm also not 100% about this one's return value. @classmethod - def delete_by_id(cls, pk: object) -> CursorWrapper: ... + def delete_by_id(cls, pk: object) -> CursorWrapper[Incomplete]: ... @classmethod def get_or_create(cls, *, defaults: Mapping[str, object] = ..., **kwargs: object) -> tuple[Any, bool]: ... @classmethod @@ -1948,12 +1951,12 @@ class ModelAlias(Node, Generic[_ModelT]): def __getattr__(self, attr: str) -> Any: ... def __setattr__(self, attr: str, value: object) -> NoReturn: ... def get_field_aliases(self) -> list[Field]: ... - def select(self, *selection: Field) -> ModelSelect: ... + def select(self, *selection: Field) -> ModelSelect[Incomplete]: ... def __call__(self, **kwargs) -> _ModelT: ... def __sql__(self, ctx: Context) -> Context: ... -_TModelOrTable: TypeAlias = Union[type[Model], ModelAlias, Table] -_TSubquery: TypeAlias = Union[tuple[Query, type[Model]], type[Model], ModelAlias] +_TModelOrTable: TypeAlias = Union[type[Model], ModelAlias[Incomplete], Table] +_TSubquery: TypeAlias = Union[tuple[Query, type[Model]], type[Model], ModelAlias[Incomplete]] _TFieldOrModel: TypeAlias = _TModelOrTable | Field class FieldAlias(Field): @@ -1963,7 +1966,7 @@ class FieldAlias(Field): # TODO (dargueta): Making an educated guess about `source`; might be `Node` def __init__(self, source: MetaField, field: Field): ... @classmethod - def create(cls, source: ModelAlias, field: str) -> FieldAlias: ... + def create(cls, source: ModelAlias[Incomplete], field: str) -> FieldAlias: ... def clone(self) -> FieldAlias: ... def adapt(self, value: object) -> Any: ... def python_value(self, value: object) -> Any: ... @@ -1979,17 +1982,17 @@ class _ModelQueryHelper: class ModelRaw(_ModelQueryHelper, RawQuery, Generic[_ModelT]): model: type[_ModelT] - def __init__(self, model: type[_ModelT], sql: str, params: tuple, **kwargs: object): ... + def __init__(self, model: type[_ModelT], sql: str, params: tuple[Incomplete, ...], **kwargs: object): ... def get(self) -> _ModelT: ... class BaseModelSelect(_ModelQueryHelper): - def union_all(self, rhs: object) -> ModelCompoundSelectQuery: ... + def union_all(self, rhs: object) -> ModelCompoundSelectQuery[Incomplete]: ... __add__ = union_all - def union(self, rhs: object) -> ModelCompoundSelectQuery: ... + def union(self, rhs: object) -> ModelCompoundSelectQuery[Incomplete]: ... __or__ = union - def intersect(self, rhs: object) -> ModelCompoundSelectQuery: ... + def intersect(self, rhs: object) -> ModelCompoundSelectQuery[Incomplete]: ... __and__ = intersect - def except_(self, rhs: object) -> ModelCompoundSelectQuery: ... + def except_(self, rhs: object) -> ModelCompoundSelectQuery[Incomplete]: ... __sub__ = except_ def __iter__(self) -> Iterator[Any]: ... def prefetch(self, *subqueries: _TSubquery) -> list[Any]: ... @@ -2008,30 +2011,30 @@ class ModelSelect(BaseModelSelect, Select, Generic[_ModelT]): def switch(self: Self, ctx: type[Model] | None = ...) -> Self: ... def join( # type: ignore self: Self, - dest: type[Model] | Table | ModelAlias | ModelSelect, + dest: type[Model] | Table | ModelAlias[Incomplete] | ModelSelect[Incomplete], join_type: int = ..., on: Column | Expression | Field | None = ..., - src: type[Model] | Table | ModelAlias | ModelSelect | None = ..., + src: type[Model] | Table | ModelAlias[Incomplete] | ModelSelect[Incomplete] | None = ..., attr: str | None = ..., ) -> Self: ... def join_from( self, - src: type[Model] | Table | ModelAlias | ModelSelect, - dest: type[Model] | Table | ModelAlias | ModelSelect, + src: type[Model] | Table | ModelAlias[Incomplete] | ModelSelect[Incomplete], + dest: type[Model] | Table | ModelAlias[Incomplete] | ModelSelect[Incomplete], join_type: int = ..., on: Column | Expression | Field | None = ..., attr: str | None = ..., - ) -> ModelSelect: ... + ) -> ModelSelect[Incomplete]: ... def ensure_join( self, lm: type[Model], rm: type[Model], on: Column | Expression | Field | None = ..., **join_kwargs: Any - ) -> ModelSelect: ... + ) -> ModelSelect[Incomplete]: ... # TODO (dargueta): 85% sure about the return value def convert_dict_to_node(self, qdict: Mapping[str, object]) -> tuple[list[Expression], list[Field]]: ... - def filter(self, *args: Node, **kwargs: object) -> ModelSelect: ... + def filter(self, *args: Node, **kwargs: object) -> ModelSelect[Incomplete]: ... def create_table(self, name: str, safe: bool = ..., **meta: object) -> None: ... def __sql_selection__(self, ctx: Context, is_subquery: bool = ...) -> Context: ... -class NoopModelSelect(ModelSelect): +class NoopModelSelect(ModelSelect[Incomplete]): def __sql__(self, ctx: Context) -> Context: ... class _ModelWriteQueryHelper(_ModelQueryHelper): @@ -2049,7 +2052,7 @@ class ModelInsert(_ModelWriteQueryHelper, Insert): class ModelDelete(_ModelWriteQueryHelper, Delete): ... -class ManyToManyQuery(ModelSelect): +class ManyToManyQuery(ModelSelect[Incomplete]): def __init__( self, instance: Model, accessor: ManyToManyFieldAccessor, rel: _TFieldOrModel, *args: object, **kwargs: object ): ... @@ -2068,11 +2071,11 @@ class BaseModelCursorWrapper(DictCursorWrapper, Generic[_ModelT]): def process_row(self, row: tuple) -> Mapping[str, object]: ... # type: ignore class ModelDictCursorWrapper(BaseModelCursorWrapper[_ModelT]): - def process_row(self, row: tuple) -> dict[str, Any]: ... + def process_row(self, row: tuple[Incomplete, ...]) -> dict[str, Any]: ... class ModelTupleCursorWrapper(ModelDictCursorWrapper[_ModelT]): - constructor: ClassVar[Callable[[Sequence[Any]], tuple]] - def process_row(self, row: tuple) -> tuple: ... # type: ignore + constructor: ClassVar[Callable[[Sequence[Any]], tuple[Incomplete, ...]]] + def process_row(self, row: tuple[Incomplete, ...]) -> tuple[Incomplete, ...]: ... # type: ignore class ModelNamedTupleCursorWrapper(ModelTupleCursorWrapper[_ModelT]): ... @@ -2096,7 +2099,7 @@ class ModelCursorWrapper(BaseModelCursorWrapper[_ModelT]): joins: Mapping[Hashable, tuple[object, object, Callable[..., _ModelT], int]] key_to_constructor: dict[type[_ModelT], Callable[..., _ModelT]] src_is_dest: dict[type[Model], bool] - src_to_dest: list[tuple] # TODO -- tuple[, join_type[1], join_type[0], bool, join_type[3]] + src_to_dest: list[tuple[Incomplete, Incomplete, Incomplete, bool, Incomplete]] # TODO -- tuple[, join_type[1], join_type[0], bool, join_type[3]] column_keys: list[Incomplete] def __init__( self, From 589ad8dc8f786e44c7bebe24180e5d55fd6a1d7c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 19 Jun 2022 19:37:12 +0000 Subject: [PATCH 20/22] [pre-commit.ci] auto fixes from pre-commit.com hooks --- stubs/peewee/peewee.pyi | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/stubs/peewee/peewee.pyi b/stubs/peewee/peewee.pyi index 6636c2a8f2a6..f4a165bcc17a 100644 --- a/stubs/peewee/peewee.pyi +++ b/stubs/peewee/peewee.pyi @@ -563,7 +563,9 @@ def Check(constraint: str) -> SQL: ... class Function(ColumnBase): name: str arguments: tuple[Incomplete, ...] - def __init__(self, name: str, arguments: tuple[Incomplete, ...], coerce: bool = ..., python_value: _ConvFunc | None = ...): ... + def __init__( + self, name: str, arguments: tuple[Incomplete, ...], coerce: bool = ..., python_value: _ConvFunc | None = ... + ): ... def __getattr__(self, attr: str) -> Callable[..., Function]: ... # TODO (dargueta): `where` is an educated guess def filter(self: Self, where: Expression | None = ...) -> Self: ... @@ -1033,7 +1035,9 @@ class Database(_callable_context_manager): def is_connection_usable(self) -> bool: ... def connection(self) -> __IConnection: ... def cursor(self, commit: bool | None = ...) -> __ICursor: ... - def execute_sql(self, sql: str, params: tuple[Incomplete, ...] | None = ..., commit: bool | _TSentinel = ...) -> __ICursor: ... + def execute_sql( + self, sql: str, params: tuple[Incomplete, ...] | None = ..., commit: bool | _TSentinel = ... + ) -> __ICursor: ... def execute(self, query: Query, commit: bool | _TSentinel = ..., **context_options: object) -> __ICursor: ... def get_context_options(self) -> Mapping[str, object]: ... def get_sql_context(self, **context_options: object): ... @@ -1683,7 +1687,9 @@ class ManyToManyFieldAccessor(FieldAccessor): @overload def __get__(self, instance: None, instance_type: type[Incomplete] = ..., force_query: bool = ...) -> Field: ... @overload - def __get__(self, instance: _T, instance_type: type[Incomplete] = ..., force_query: bool = ...) -> list[str] | ManyToManyQuery: ... + def __get__( + self, instance: _T, instance_type: type[Incomplete] = ..., force_query: bool = ... + ) -> list[str] | ManyToManyQuery: ... def __set__(self, instance, value) -> None: ... class ManyToManyField(MetaField): @@ -1816,7 +1822,9 @@ class Metadata: def remove_ref(self, field: ForeignKeyField) -> None: ... def add_manytomany(self, field: ManyToManyField) -> None: ... def remove_manytomany(self, field: ManyToManyField) -> None: ... - def get_rel_for_model(self, model: type[Model] | ModelAlias[Incomplete]) -> tuple[list[ForeignKeyField], list[type[Model]]]: ... + def get_rel_for_model( + self, model: type[Model] | ModelAlias[Incomplete] + ) -> tuple[list[ForeignKeyField], list[type[Model]]]: ... def add_field(self, field_name: str, field: Field, set_attribute: bool = ...) -> None: ... def remove_field(self, field_name: str) -> None: ... def set_primary_key(self, name: str, field: Field | CompositeKey) -> None: ... @@ -2099,7 +2107,9 @@ class ModelCursorWrapper(BaseModelCursorWrapper[_ModelT]): joins: Mapping[Hashable, tuple[object, object, Callable[..., _ModelT], int]] key_to_constructor: dict[type[_ModelT], Callable[..., _ModelT]] src_is_dest: dict[type[Model], bool] - src_to_dest: list[tuple[Incomplete, Incomplete, Incomplete, bool, Incomplete]] # TODO -- tuple[, join_type[1], join_type[0], bool, join_type[3]] + src_to_dest: list[ + tuple[Incomplete, Incomplete, Incomplete, bool, Incomplete] + ] # TODO -- tuple[, join_type[1], join_type[0], bool, join_type[3]] column_keys: list[Incomplete] def __init__( self, From a63e648420a2c697ab2fe0c8f3c26ee96d016eb2 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Sun, 19 Jun 2022 20:43:48 +0100 Subject: [PATCH 21/22] More fixes --- stubs/peewee/peewee.pyi | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/stubs/peewee/peewee.pyi b/stubs/peewee/peewee.pyi index 6636c2a8f2a6..20ab8e5bf31d 100644 --- a/stubs/peewee/peewee.pyi +++ b/stubs/peewee/peewee.pyi @@ -13,7 +13,6 @@ _T = TypeVar("_T") _ModelT = TypeVar("_ModelT", bound=Model) _FuncT = TypeVar("_FuncT", bound=Callable[..., Any]) _ClassT = TypeVar("_ClassT", bound=type) -_ContextClassT = TypeVar("_ContextClassT", bound=Context) _FieldT = TypeVar("_FieldT", bound=Field) _NodeT = TypeVar("_NodeT", bound=Node) _P = ParamSpec("_P") @@ -836,8 +835,8 @@ class Select(SelectBase): def clone(self: Self) -> Self: ... # TODO (dargueta) `Field` might be wrong in this union def columns(self: Self, *columns: Column | Field, **kwargs: object) -> Self: ... - select = columns - def select_extend(self: Self, *columns) -> Self: ... + select = columns # type: ignore + def select_extend(self: Self, *columns: Incomplete) -> Self: ... # TODO (dargueta): Is `sources` right? def from_(self: Self, *sources: Source | type[Model]) -> Self: ... def join(self: Self, dest: type[Model], join_type: int = ..., on: Expression | None = ...) -> Self: ... @@ -1137,7 +1136,7 @@ class SqliteDatabase(Database): def aggregate(self, name: str | None = ..., num_params: int = ...) -> Callable[[_ClassT], _ClassT]: ... def register_collation(self, fn: Callable[..., Incomplete], name: str | None = ...) -> None: ... def collation(self, name: str | None = ...) -> Callable[[_FuncT], _FuncT]: ... - def register_function(self, fn: Callable[Incomplete, ...], name: str | None = ..., num_params: int = ...) -> int: ... + def register_function(self, fn: Callable[..., Incomplete], name: str | None = ..., num_params: int = ...) -> int: ... def func(self, name: str | None = ..., num_params: int = ...) -> Callable[[_FuncT], _FuncT]: ... def register_window_function(self, klass: type, name: str | None = ..., num_params: int = ...) -> None: ... def window_function(self, name: str | None = ..., num_params: int = ...) -> Callable[[_ClassT], _ClassT]: ... @@ -1683,7 +1682,7 @@ class ManyToManyFieldAccessor(FieldAccessor): @overload def __get__(self, instance: None, instance_type: type[Incomplete] = ..., force_query: bool = ...) -> Field: ... @overload - def __get__(self, instance: _T, instance_type: type[Incomplete] = ..., force_query: bool = ...) -> list[str] | ManyToManyQuery: ... + def __get__(self, instance: _T, instance_type: type[_T] = ..., force_query: bool = ...) -> list[str] | ManyToManyQuery: ... def __set__(self, instance, value) -> None: ... class ManyToManyField(MetaField): From 0e80802f1c3e14c4a8280b5b9c679f6b66812148 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 19 Jun 2022 19:45:49 +0000 Subject: [PATCH 22/22] [pre-commit.ci] auto fixes from pre-commit.com hooks --- stubs/peewee/peewee.pyi | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/stubs/peewee/peewee.pyi b/stubs/peewee/peewee.pyi index 6dde00a8c950..7425bfe3bf85 100644 --- a/stubs/peewee/peewee.pyi +++ b/stubs/peewee/peewee.pyi @@ -1686,9 +1686,7 @@ class ManyToManyFieldAccessor(FieldAccessor): @overload def __get__(self, instance: None, instance_type: type[Incomplete] = ..., force_query: bool = ...) -> Field: ... @overload - def __get__( - self, instance: _T, instance_type: type[_T] = ..., force_query: bool = ... - ) -> list[str] | ManyToManyQuery: ... + def __get__(self, instance: _T, instance_type: type[_T] = ..., force_query: bool = ...) -> list[str] | ManyToManyQuery: ... def __set__(self, instance, value) -> None: ... class ManyToManyField(MetaField):